2219e9b616fba503cfd67263098d6f34b7c167d3
[SXSI/XMLTree.git] / libcds / src / static_sequence / wt_node_internal.cpp
1 /* wt_node_internal.cpp
2  * Copyright (C) 2008, Francisco Claude, all rights reserved.
3  *
4  * wt_node_internal
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21  
22 #include <wt_node_internal.h>
23
24 wt_node_internal::wt_node_internal(uint * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
25         uint * ibitmap = new uint[n/W+1];
26         for(uint i=0;i<n/W+1;i++)
27                 ibitmap[i]=0;
28         for(uint i=0;i<n;i++) 
29                 if(c->is_set(symbols[i],l))
30                         bitset(ibitmap,i);
31         bitmap = bmb->build(ibitmap, n);
32   delete [] ibitmap;
33         uint count_right = bitmap->rank1(n-1);
34         uint count_left = n-count_right+1;
35         uint * left = new uint[count_left+1];
36         uint * right = new uint[count_right+1];
37         count_right = count_left = 0;
38         bool match_left = true, match_right = true;
39         for(uint i=0;i<n;i++) {
40                 if(bitmap->access(i)) {
41                         right[count_right++]=symbols[i];
42                         if(count_right>1)
43                                 if(right[count_right-1]!=right[count_right-2])
44                                         match_right = false;
45                 }
46                 else {
47                         left[count_left++]=symbols[i];
48                         if(count_left>1)
49                                 if(left[count_left-1]!=left[count_left-2])
50                                         match_left = false;
51                 }
52         }
53         if(count_left>0) {
54                 if(match_left/* && c->done(left[0],l+1)*/)
55                         left_child = new wt_node_leaf(left[0], count_left);
56                 else
57                         left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
58         } else {
59                 left_child = NULL;
60         }
61         if(count_right>0) {
62                 if(match_right/* && c->done(right[0],l+1)*/)
63                         right_child = new wt_node_leaf(right[0], count_right);
64                 else
65                         right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
66         } else {
67                 right_child = NULL;
68         }
69         delete [] left;
70         delete [] right;
71 }
72
73 // Deletes symbols array!
74 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
75         uint * ibitmap = new uint[n/W+1];
76         for(uint i=0;i<n/W+1;i++)
77                 ibitmap[i]=0;
78         for(uint i=0;i<n;i++) 
79                 if(c->is_set((uint)symbols[i],l))
80                         bitset(ibitmap,i);
81         bitmap = bmb->build(ibitmap, n);
82   delete [] ibitmap;
83         uint count_right = bitmap->rank1(n-1);
84         uint count_left = n-count_right+1;
85         uchar * left = new uchar[count_left+1];
86         uchar * right = new uchar[count_right+1];
87         count_right = count_left = 0;
88         bool match_left = true, match_right = true;
89         for(uint i=0;i<n;i++) {
90                 if(bitmap->access(i)) {
91                         right[count_right++]=symbols[i];
92                         if(count_right>1)
93                                 if(right[count_right-1]!=right[count_right-2])
94                                         match_right = false;
95                 }
96                 else {
97                         left[count_left++]=symbols[i];
98                         if(count_left>1)
99                                 if(left[count_left-1]!=left[count_left-2])
100                                         match_left = false;
101                 }
102         }
103
104         delete [] symbols; 
105         symbols = 0;
106
107         if(count_left>0) {
108                 if(match_left/* && c->done(left[0],l+1)*/)
109                 {
110                     left_child = new wt_node_leaf((uint)left[0], count_left);
111                     delete [] left;
112                     left = 0;
113                 }
114                 else
115                 {
116                     left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
117                     left = 0; // Already deleted
118                 }
119         } else {
120                 left_child = NULL;
121         }
122         if(count_right>0) {
123                 if(match_right/* && c->done(right[0],l+1)*/)
124                 {
125                     right_child = new wt_node_leaf((uint)right[0], count_right);
126                     delete [] right;
127                     right = 0;
128                 }
129                 else 
130                 {
131                     right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
132                     right = 0; // Already deleted
133                 }
134         } else {
135                 right_child = NULL;
136         }
137         delete [] left; // already deleted if count_left > 0
138         delete [] right;
139 }
140
141 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb, uint left, uint *done) {
142         uint * ibitmap = new uint[n/W+1];
143         for(uint i=0;i<n/W+1;i++)
144                 ibitmap[i]=0;
145         for(uint i=0;i<n;i++) 
146                 if(c->is_set((uint)symbols[i + left],l))
147                         bitset(ibitmap,i);
148         bitmap = bmb->build(ibitmap, n);
149         delete [] ibitmap;
150
151         uint count_right = bitmap->rank1(n-1);
152         uint count_left = n-count_right;
153 /*      uchar * leftarr = new uchar[count_left+1];
154         uchar * rightarr = new uchar[count_right+1];
155         count_right = count_left = 0;
156         for(uint i=0;i<n;i++) {
157                 if(bitmap->access(i)) {
158                         rightarr[count_right++]=symbols[i+left];
159                 }
160                 else {
161                         leftarr[count_left++]=symbols[i+left];
162                 }
163                 }
164 */
165
166         for (uint i=0;i<n;i++)
167             set_field(done, 1, i+left, 0);
168
169         for (uint i = 0; i < n; ) 
170         {
171             uint j = i;
172             uchar swap = symbols[j+left];
173             while (!get_field(done, 1, j+left)) { // swapping
174                 ulong k = j; 
175                 if (!c->is_set(swap,l)) 
176                     j = bitmap->rank0(k)-1;
177                 else 
178                     j = count_left + bitmap->rank1(k)-1;
179                 uchar temp = symbols[j+left];
180                 symbols[j+left] = swap;
181                 swap = temp;
182                 set_field(done,1,k+left,1);
183             }
184
185             while (get_field(done,1,i+left))
186                    ++i;
187         }
188
189         // checking
190         /*       for (uint i=0;i<n;i++)
191             if (!bitget(done,i+left)) 
192                 std::cout << "not swapped: " << i << "\n";
193                for (uint i=0;i<count_left;i++)
194             if (leftarr[i] != symbols[i+left]) //c->is_set(symbols[i+left], l)) 
195             {    
196                 std::cout << symbols[i+left] << " != " << leftarr[i] << " lev = " << l << "\n";
197                 exit(0);
198             }
199         for (uint i=count_left;i<n;i++)
200             if (rightarr[i-count_left] != symbols[i+left]) //!c->is_set(symbols[i+left],l)) 
201                 std::cout << symbols[i+left] << " != " << rightarr[i-count_left] <<  " lev = " << l <<  "\n";    
202         */
203         bool match_left = true, match_right = true;
204         for (uint i=1; i < count_left; i++)
205             if (symbols[i+left] != symbols[i+left-1])
206                 match_left = false;
207         for (uint i=count_left + 1; i < n; i++)
208             if (symbols[i+left] != symbols[i+left-1])
209                 match_right = false;
210
211
212         if(count_left>0) {
213                 if(match_left/* && c->done(left[0],l+1)*/)
214                     left_child = new wt_node_leaf((uint)symbols[left], count_left);
215                 else
216                     left_child = new wt_node_internal(symbols, count_left, l+1, c, bmb, left, done);
217         } else {
218                 left_child = NULL;
219         }
220         if(count_right>0) {
221                 if(match_right/* && c->done(right[0],l+1)*/)
222                     right_child = new wt_node_leaf((uint)symbols[left+count_left], count_right);
223                 else 
224                     right_child = new wt_node_internal(symbols, count_right, l+1, c, bmb, left+count_left, done);
225         } else {
226                 right_child = NULL;
227         }
228 }
229
230
231 wt_node_internal::wt_node_internal() { }
232
233 wt_node_internal::~wt_node_internal() {
234         delete bitmap;
235         if(right_child!=NULL) delete right_child;
236         if(left_child!=NULL) delete left_child;
237 }
238
239 uint wt_node_internal::rank(uint symbol, uint pos, uint l, wt_coder * c) {
240         bool is_set = c->is_set(symbol,l);
241         if(!is_set) {
242                 if(left_child==NULL) return 0;
243                 return left_child->rank(symbol, bitmap->rank0(pos)-1,l+1,c);
244         }
245         else {
246                 if(right_child==NULL) return 0;
247                 return right_child->rank(symbol, bitmap->rank1(pos)-1,l+1,c);
248         }
249 }
250
251 // return value is rank of symbol (less or equal to the given symbol) that has rank > 0, 
252 // the parameter symbol is updated accordinly
253 uint wt_node_internal::rankLessThan(uint &symbol, uint pos, uint l, wt_coder * c) 
254 {
255     bool is_set = c->is_set(symbol,l);
256     using std::cout;
257     using std::endl;
258 //    cout << "l = " << l << ", symbol = " << (uchar)symbol << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
259
260     uint result = -1;
261     if(!is_set) {
262         if(left_child==NULL) return -1;
263         uint rank = bitmap->rank0(pos);
264         if(rank != 0)
265             result = left_child->rankLessThan(symbol,rank-1,l+1,c);
266         return result;
267     }
268
269     uint rank = bitmap->rank1(pos);
270     if (rank != 0 && right_child != NULL)
271         result = right_child->rankLessThan(symbol, rank-1,l+1,c);
272
273 //    cout << "recursion to leftchild at l = " << l << ", symbol = " << (uchar)symbol << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
274     // check left child for symbols <= givenSymbol
275     if (result != (uint)-1 || left_child == NULL)
276         return result;
277     return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
278 }
279
280 uint wt_node_internal::rankLessThan(uint &symbol, uint pos) 
281 {
282     uint result = -1;
283     using std::cout;
284     using std::endl;
285 //    cout << "pos = " << pos << ", symbol = " << (uchar)symbol << endl;
286     
287     if (pos == (uint)-1)
288         return (uint)-1;
289     if(right_child!=NULL)
290         result = right_child->rankLessThan(symbol, bitmap->rank1(pos)-1);
291     if(result == (uint)-1 && left_child!=NULL)
292         return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
293     return result;
294 }
295
296
297 uint wt_node_internal::select(uint symbol, uint pos, uint l, wt_coder * c) {
298         bool is_set = c->is_set(symbol, l);
299         uint ret = 0;
300         if(!is_set) {
301                 if(left_child==NULL)
302                         return (uint)(-1);
303                 uint new_pos = left_child->select(symbol, pos, l+1,c);
304                 if(new_pos+1==0) return (uint)(-1);
305                 ret = bitmap->select0(new_pos)+1;
306         } else {
307                 if(right_child==NULL)
308                         return (uint)(-1);
309                 uint new_pos = right_child->select(symbol, pos, l+1,c);
310                 if(new_pos+1==0) return (uint)(-1);
311                 ret = bitmap->select1(new_pos)+1;
312         }
313         if(ret==0) return (uint)-1;
314         return ret;
315 }
316
317 uint wt_node_internal::access(uint pos) {
318         bool is_set = bitmap->access(pos);
319         if(!is_set) {
320                 assert(left_child!=NULL);
321                 return left_child->access(bitmap->rank0(pos)-1);
322         } else {
323                 assert(right_child!=NULL);
324                 return right_child->access(bitmap->rank1(pos)-1);
325         }
326 }
327
328 // Returns the value at given position and its rank
329 uint wt_node_internal::access(uint pos, uint &rank) 
330 {
331     // p is the internal node we are pointing our finger at each step
332     wt_node_internal *p = this;
333
334     while(1)
335     {
336         bool is_set = p->bitmap->access(pos);
337 //        cout << "is_set = " << is_set << ", pos = " << pos << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
338         if(!is_set)
339         {
340             // recurse left
341             pos = p->bitmap->rank0(pos)-1;
342             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->left_child);
343             if (tmp == NULL)
344             {
345                 // it's a leaf
346                 rank = pos+1;
347                 return p->left_child->access(0);
348             }
349             p = tmp; // new internal node
350         } 
351         else 
352         {
353             // recurse right
354             pos = p->bitmap->rank1(pos)-1;
355             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->right_child);
356             if (tmp == NULL)
357             {
358                 // it's a leaf
359                 rank = pos+1;
360                 return p->right_child->access(0);
361             }
362             p = tmp; // new internal node
363         }
364     }
365 }
366
367 void wt_node_internal::access(vector<int> &result, uint i, uint j, uint min, uint max, uint l, uint pivot)
368 {
369     uint symbol = pivot | (1 << l);
370 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
371
372     if (j < i || max < min)
373         return;
374
375     if (min < symbol)
376     {
377         // Recurse left
378         uint newi = 0;
379         if (i > 0)
380             newi = bitmap->rank0(i - 1);
381         uint newj = bitmap->rank0(j);
382
383         uint newmax = max < symbol - 1 ? max : symbol - 1;
384         if (left_child != NULL && newj > 0)
385             left_child->access(result, newi, newj-1, min, newmax, l-1, pivot);
386     }
387     
388     if (max >= symbol)
389     {
390         // Recurse right
391         uint newi = 0;
392         if (i > 0)
393             newi = bitmap->rank1(i - 1);
394         uint newj = bitmap->rank1(j);
395
396         uint newmin = min > symbol ? min : symbol;
397         if (right_child != NULL && newj > 0)
398             right_child->access(result, newi, newj-1, newmin, max, l-1, symbol);
399     }
400 }
401
402 void wt_node_internal::access(vector<int> &result, uint i, uint j)
403 {
404 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
405
406     if (j < i)
407         return;
408
409     {
410         // Recurse left
411         uint newi = 0;
412         if (i > 0)
413             newi = bitmap->rank0(i - 1);
414         uint newj = bitmap->rank0(j);
415
416         if (left_child != NULL && newj > 0)
417             left_child->access(result, newi, newj-1);
418     }
419     
420     {
421         // Recurse right
422         uint newi = 0;
423         if (i > 0)
424             newi = bitmap->rank1(i - 1);
425         uint newj = bitmap->rank1(j);
426
427         if (right_child != NULL && newj > 0)
428             right_child->access(result, newi, newj-1);
429     }
430 }
431
432
433 uint wt_node_internal::access(uint i, uint j, uint min, uint max, uint l, uint pivot)
434 {
435     uint count = 0;
436     uint symbol = pivot | (1 << l);
437 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
438
439     if (j < i || max < min)
440         return 0;
441
442     if (min < symbol)
443     {
444         // Recurse left
445         uint newi = 0;
446         if (i > 0)
447             newi = bitmap->rank0(i - 1);
448         uint newj = bitmap->rank0(j);
449
450         uint newmax = max < symbol - 1 ? max : symbol - 1;
451         if (left_child != NULL && newj > 0)
452             count += left_child->access(newi, newj-1, min, newmax, l-1, pivot);
453     }
454     
455     if (max >= symbol)
456     {
457         // Recurse right
458         uint newi = 0;
459         if (i > 0)
460             newi = bitmap->rank1(i - 1);
461         uint newj = bitmap->rank1(j);
462
463         uint newmin = min > symbol ? min : symbol;
464         if (right_child != NULL && newj > 0)
465             count += right_child->access(newi, newj-1, newmin, max, l-1, symbol);
466     }
467     return count;
468 }
469
470
471 uint wt_node_internal::size() {
472         uint s = bitmap->size()+sizeof(wt_node_internal);
473         if(left_child!=NULL)
474                 s += left_child->size();
475         if(right_child!=NULL)
476                 s += right_child->size();
477         return s;
478 }
479
480 uint wt_node_internal::save(FILE *fp) {
481   uint wr = WT_NODE_INTERNAL_HDR;
482   wr = fwrite(&wr,sizeof(uint),1,fp);
483   if(wr!=1) return 1;
484   if(bitmap->save(fp)) return 1;
485   if(left_child!=NULL) {
486     if(left_child->save(fp)) return 1;
487   } else {
488     wr = WT_NODE_NULL_HDR;
489     wr = fwrite(&wr,sizeof(uint),1,fp);
490     if(wr!=1) return 1;
491   }
492   if(right_child!=NULL) {
493     if(right_child->save(fp)) return 1;
494   } else {
495     wr = WT_NODE_NULL_HDR;
496     wr = fwrite(&wr,sizeof(uint),1,fp);
497     if(wr!=1) return 1;
498   }
499   return 0;
500 }
501
502 wt_node_internal * wt_node_internal::load(FILE *fp) {
503   uint rd;
504   if(fread(&rd,sizeof(uint),1,fp)!=1) return NULL;
505   if(rd!=WT_NODE_INTERNAL_HDR) return NULL;
506   wt_node_internal * ret = new wt_node_internal();
507   ret->bitmap = static_bitsequence::load(fp);
508   ret->left_child = wt_node::load(fp);
509   ret->right_child = wt_node::load(fp);
510   return ret;
511 }