4e6871a802ddf7bad3f2ec02900376490cb574b1
[SXSI/XMLTree.git] / libcds / src / static_sequence / wt_node_internal.cpp
1 /* wt_node_internal.cpp
2  * Copyright (C) 2008, Francisco Claude, all rights reserved.
3  *
4  * wt_node_internal
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21  
22 #include <wt_node_internal.h>
23
24 wt_node_internal::wt_node_internal(uint * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
25         uint * ibitmap = new uint[n/W+1];
26         for(uint i=0;i<n/W+1;i++)
27                 ibitmap[i]=0;
28         for(uint i=0;i<n;i++) 
29                 if(c->is_set(symbols[i],l))
30                         bitset(ibitmap,i);
31         bitmap = bmb->build(ibitmap, n);
32   delete [] ibitmap;
33         uint count_right = bitmap->rank1(n-1);
34         uint count_left = n-count_right+1;
35         uint * left = new uint[count_left+1];
36         uint * right = new uint[count_right+1];
37         count_right = count_left = 0;
38         bool match_left = true, match_right = true;
39         for(uint i=0;i<n;i++) {
40                 if(bitmap->access(i)) {
41                         right[count_right++]=symbols[i];
42                         if(count_right>1)
43                                 if(right[count_right-1]!=right[count_right-2])
44                                         match_right = false;
45                 }
46                 else {
47                         left[count_left++]=symbols[i];
48                         if(count_left>1)
49                                 if(left[count_left-1]!=left[count_left-2])
50                                         match_left = false;
51                 }
52         }
53         if(count_left>0) {
54                 if(match_left/* && c->done(left[0],l+1)*/)
55                         left_child = new wt_node_leaf(left[0], count_left);
56                 else
57                         left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
58         } else {
59                 left_child = NULL;
60         }
61         if(count_right>0) {
62                 if(match_right/* && c->done(right[0],l+1)*/)
63                         right_child = new wt_node_leaf(right[0], count_right);
64                 else
65                         right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
66         } else {
67                 right_child = NULL;
68         }
69         delete [] left;
70         delete [] right;
71 }
72
73 // Deletes symbols array!
74 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
75         uint * ibitmap = new uint[n/W+1];
76         for(uint i=0;i<n/W+1;i++)
77                 ibitmap[i]=0;
78         for(uint i=0;i<n;i++) 
79                 if(c->is_set((uint)symbols[i],l))
80                         bitset(ibitmap,i);
81         bitmap = bmb->build(ibitmap, n);
82   delete [] ibitmap;
83         uint count_right = bitmap->rank1(n-1);
84         uint count_left = n-count_right+1;
85         uchar * left = new uchar[count_left+1];
86         uchar * right = new uchar[count_right+1];
87         count_right = count_left = 0;
88         bool match_left = true, match_right = true;
89         for(uint i=0;i<n;i++) {
90                 if(bitmap->access(i)) {
91                         right[count_right++]=symbols[i];
92                         if(count_right>1)
93                                 if(right[count_right-1]!=right[count_right-2])
94                                         match_right = false;
95                 }
96                 else {
97                         left[count_left++]=symbols[i];
98                         if(count_left>1)
99                                 if(left[count_left-1]!=left[count_left-2])
100                                         match_left = false;
101                 }
102         }
103
104         delete [] symbols; 
105         symbols = 0;
106
107         if(count_left>0) {
108                 if(match_left/* && c->done(left[0],l+1)*/)
109                 {
110                     left_child = new wt_node_leaf((uint)left[0], count_left);
111                     delete [] left;
112                     left = 0;
113                 }
114                 else
115                 {
116                     left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
117                     left = 0; // Already deleted
118                 }
119         } else {
120                 left_child = NULL;
121         }
122         if(count_right>0) {
123                 if(match_right/* && c->done(right[0],l+1)*/)
124                 {
125                     right_child = new wt_node_leaf((uint)right[0], count_right);
126                     delete [] right;
127                     right = 0;
128                 }
129                 else 
130                 {
131                     right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
132                     right = 0; // Already deleted
133                 }
134         } else {
135                 right_child = NULL;
136         }
137         delete [] left; // already deleted if count_left > 0
138         delete [] right;
139 }
140
141 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb, uint left, uint *done) {
142         uint * ibitmap = new uint[n/W+1];
143         for(uint i=0;i<n/W+1;i++)
144                 ibitmap[i]=0;
145         for(uint i=0;i<n;i++) 
146                 if(c->is_set((uint)symbols[i + left],l))
147                         bitset(ibitmap,i);
148         bitmap = bmb->build(ibitmap, n);
149         delete [] ibitmap;
150
151         uint count_right = bitmap->rank1(n-1);
152         uint count_left = n-count_right;
153 /*      uchar * leftarr = new uchar[count_left+1];
154         uchar * rightarr = new uchar[count_right+1];
155         count_right = count_left = 0;
156         for(uint i=0;i<n;i++) {
157                 if(bitmap->access(i)) {
158                         rightarr[count_right++]=symbols[i+left];
159                 }
160                 else {
161                         leftarr[count_left++]=symbols[i+left];
162                 }
163                 }
164 */
165
166         for (uint i=0;i<n;i++)
167             set_field(done, 1, i+left, 0);
168
169         for (uint i = 0; i < n; ) 
170         {
171             uint j = i;
172             uchar swap = symbols[j+left];
173             while (!get_field(done, 1, j+left)) { // swapping
174                 ulong k = j; 
175                 if (!c->is_set(swap,l)) 
176                     j = bitmap->rank0(k)-1;
177                 else 
178                     j = count_left + bitmap->rank1(k)-1;
179                 uchar temp = symbols[j+left];
180                 symbols[j+left] = swap;
181                 swap = temp;
182                 set_field(done,1,k+left,1);
183             }
184
185             while (get_field(done,1,i+left))
186                    ++i;
187         }
188
189         // checking
190         /*       for (uint i=0;i<n;i++)
191             if (!bitget(done,i+left)) 
192                 std::cout << "not swapped: " << i << "\n";
193                for (uint i=0;i<count_left;i++)
194             if (leftarr[i] != symbols[i+left]) //c->is_set(symbols[i+left], l)) 
195             {    
196                 std::cout << symbols[i+left] << " != " << leftarr[i] << " lev = " << l << "\n";
197                 exit(0);
198             }
199         for (uint i=count_left;i<n;i++)
200             if (rightarr[i-count_left] != symbols[i+left]) //!c->is_set(symbols[i+left],l)) 
201                 std::cout << symbols[i+left] << " != " << rightarr[i-count_left] <<  " lev = " << l <<  "\n";    
202         */
203         bool match_left = true, match_right = true;
204         for (uint i=1; i < count_left; i++)
205             if (symbols[i+left] != symbols[i+left-1])
206                 match_left = false;
207         for (uint i=count_left + 1; i < n; i++)
208             if (symbols[i+left] != symbols[i+left-1])
209                 match_right = false;
210
211
212         if(count_left>0) {
213                 if(match_left/* && c->done(left[0],l+1)*/)
214                     left_child = new wt_node_leaf((uint)symbols[left], count_left);
215                 else
216                     left_child = new wt_node_internal(symbols, count_left, l+1, c, bmb, left, done);
217         } else {
218                 left_child = NULL;
219         }
220         if(count_right>0) {
221                 if(match_right/* && c->done(right[0],l+1)*/)
222                     right_child = new wt_node_leaf((uint)symbols[left+count_left], count_right);
223                 else 
224                     right_child = new wt_node_internal(symbols, count_right, l+1, c, bmb, left+count_left, done);
225         } else {
226                 right_child = NULL;
227         }
228 }
229
230
231 wt_node_internal::wt_node_internal() { }
232
233 wt_node_internal::~wt_node_internal() {
234         delete bitmap;
235         if(right_child!=NULL) delete right_child;
236         if(left_child!=NULL) delete left_child;
237 }
238
239 uint wt_node_internal::rank(uint symbol, uint pos, uint l, wt_coder * c) {
240         bool is_set = c->is_set(symbol,l);
241         if(!is_set) {
242                 if(left_child==NULL) return 0;
243                 return left_child->rank(symbol, bitmap->rank0(pos)-1,l+1,c);
244         }
245         else {
246                 if(right_child==NULL) return 0;
247                 return right_child->rank(symbol, bitmap->rank1(pos)-1,l+1,c);
248         }
249 }
250
251 // return value is rank of symbol (less or equal to the given symbol) that has rank > 0, 
252 // the parameter symbol is updated accordinly
253 uint wt_node_internal::rankLessThan(uint &symbol, uint pos) 
254 {
255     uint result = -1;
256     using std::cout;
257     using std::endl;
258 //    cout << "pos = " << pos << ", symbol = " << symbol << endl;
259     
260     if (pos == (uint)-1)
261         return (uint)-1;
262     if(right_child!=NULL)
263         result = right_child->rankLessThan(symbol, bitmap->rank1(pos)-1);
264     if(result == (uint)-1 && left_child!=NULL)
265         return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
266     return result;
267 }
268
269
270 uint wt_node_internal::select(uint symbol, uint pos, uint l, wt_coder * c) {
271         bool is_set = c->is_set(symbol, l);
272         uint ret = 0;
273         if(!is_set) {
274                 if(left_child==NULL)
275                         return (uint)(-1);
276                 uint new_pos = left_child->select(symbol, pos, l+1,c);
277                 if(new_pos+1==0) return (uint)(-1);
278                 ret = bitmap->select0(new_pos)+1;
279         } else {
280                 if(right_child==NULL)
281                         return (uint)(-1);
282                 uint new_pos = right_child->select(symbol, pos, l+1,c);
283                 if(new_pos+1==0) return (uint)(-1);
284                 ret = bitmap->select1(new_pos)+1;
285         }
286         if(ret==0) return (uint)-1;
287         return ret;
288 }
289
290 uint wt_node_internal::access(uint pos) {
291         bool is_set = bitmap->access(pos);
292         if(!is_set) {
293                 assert(left_child!=NULL);
294                 return left_child->access(bitmap->rank0(pos)-1);
295         } else {
296                 assert(right_child!=NULL);
297                 return right_child->access(bitmap->rank1(pos)-1);
298         }
299 }
300
301 // Returns the value at given position and its rank
302 uint wt_node_internal::access(uint pos, uint &rank) 
303 {
304     // p is the internal node we are pointing our finger at each step
305     wt_node_internal *p = this;
306
307     while(1)
308     {
309         bool is_set = p->bitmap->access(pos);
310 //        cout << "is_set = " << is_set << ", pos = " << pos << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
311         if(!is_set)
312         {
313             // recurse left
314             pos = p->bitmap->rank0(pos)-1;
315             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->left_child);
316             if (tmp == NULL)
317             {
318                 // it's a leaf
319                 rank = pos+1;
320                 return p->left_child->access(0);
321             }
322             p = tmp; // new internal node
323         } 
324         else 
325         {
326             // recurse right
327             pos = p->bitmap->rank1(pos)-1;
328             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->right_child);
329             if (tmp == NULL)
330             {
331                 // it's a leaf
332                 rank = pos+1;
333                 return p->right_child->access(0);
334             }
335             p = tmp; // new internal node
336         }
337     }
338 }
339
340 void wt_node_internal::access(vector<int> &result, uint i, uint j, uint min, uint max, uint l, uint pivot)
341 {
342     uint symbol = pivot | (1 << l);
343 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
344
345     if (j < i || max < min)
346         return;
347
348     if (min < symbol)
349     {
350         // Recurse left
351         uint newi = 0;
352         if (i > 0)
353             newi = bitmap->rank0(i - 1);
354         uint newj = bitmap->rank0(j);
355
356         uint newmax = max < symbol - 1 ? max : symbol - 1;
357         if (left_child != NULL && newj > 0)
358             left_child->access(result, newi, newj-1, min, newmax, l-1, pivot);
359     }
360     
361     if (max >= symbol)
362     {
363         // Recurse right
364         uint newi = 0;
365         if (i > 0)
366             newi = bitmap->rank1(i - 1);
367         uint newj = bitmap->rank1(j);
368
369         uint newmin = min > symbol ? min : symbol;
370         if (right_child != NULL && newj > 0)
371             right_child->access(result, newi, newj-1, newmin, max, l-1, symbol);
372     }
373 }
374
375 void wt_node_internal::access(vector<int> &result, uint i, uint j)
376 {
377 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
378
379     if (j < i)
380         return;
381
382     {
383         // Recurse left
384         uint newi = 0;
385         if (i > 0)
386             newi = bitmap->rank0(i - 1);
387         uint newj = bitmap->rank0(j);
388
389         if (left_child != NULL && newj > 0)
390             left_child->access(result, newi, newj-1);
391     }
392     
393     {
394         // Recurse right
395         uint newi = 0;
396         if (i > 0)
397             newi = bitmap->rank1(i - 1);
398         uint newj = bitmap->rank1(j);
399
400         if (right_child != NULL && newj > 0)
401             right_child->access(result, newi, newj-1);
402     }
403 }
404
405 // Count
406 uint wt_node_internal::access(uint i, uint j, uint min, uint max, uint l, uint pivot)
407 {
408     uint count = 0;
409     uint symbol = pivot | (1 << l);
410 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
411
412     if (j < i || max < min)
413         return 0;
414
415     if (min < symbol)
416     {
417         // Recurse left
418         uint newi = 0;
419         if (i > 0)
420             newi = bitmap->rank0(i - 1);
421         uint newj = bitmap->rank0(j);
422
423         uint newmax = max < symbol - 1 ? max : symbol - 1;
424         if (left_child != NULL && newj > 0)
425             count += left_child->access(newi, newj-1, min, newmax, l-1, pivot);
426     }
427     
428     if (max >= symbol)
429     {
430         // Recurse right
431         uint newi = 0;
432         if (i > 0)
433             newi = bitmap->rank1(i - 1);
434         uint newj = bitmap->rank1(j);
435
436         uint newmin = min > symbol ? min : symbol;
437         if (right_child != NULL && newj > 0)
438             count += right_child->access(newi, newj-1, newmin, max, l-1, symbol);
439     }
440     return count;
441 }
442
443
444 uint wt_node_internal::size() {
445         uint s = bitmap->size()+sizeof(wt_node_internal);
446         if(left_child!=NULL)
447                 s += left_child->size();
448         if(right_child!=NULL)
449                 s += right_child->size();
450         return s;
451 }
452
453 uint wt_node_internal::save(FILE *fp) {
454   uint wr = WT_NODE_INTERNAL_HDR;
455   wr = fwrite(&wr,sizeof(uint),1,fp);
456   if(wr!=1) return 1;
457   if(bitmap->save(fp)) return 1;
458   if(left_child!=NULL) {
459     if(left_child->save(fp)) return 1;
460   } else {
461     wr = WT_NODE_NULL_HDR;
462     wr = fwrite(&wr,sizeof(uint),1,fp);
463     if(wr!=1) return 1;
464   }
465   if(right_child!=NULL) {
466     if(right_child->save(fp)) return 1;
467   } else {
468     wr = WT_NODE_NULL_HDR;
469     wr = fwrite(&wr,sizeof(uint),1,fp);
470     if(wr!=1) return 1;
471   }
472   return 0;
473 }
474
475 wt_node_internal * wt_node_internal::load(FILE *fp) {
476   uint rd;
477   if(fread(&rd,sizeof(uint),1,fp)!=1) return NULL;
478   if(rd!=WT_NODE_INTERNAL_HDR) return NULL;
479   wt_node_internal * ret = new wt_node_internal();
480   ret->bitmap = static_bitsequence::load(fp);
481   ret->left_child = wt_node::load(fp);
482   ret->right_child = wt_node::load(fp);
483   return ret;
484 }