d39ea4b126be3f11cd79adc1b2363ad6e1a82cb4
[SXSI/XMLTree.git] / libcds / src / static_sequence / wt_node_internal.cpp
1 /* wt_node_internal.cpp
2  * Copyright (C) 2008, Francisco Claude, all rights reserved.
3  *
4  * wt_node_internal
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21  
22 #include <wt_node_internal.h>
23
24 wt_node_internal::wt_node_internal(uint * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
25         uint * ibitmap = new uint[n/W+1];
26         for(uint i=0;i<n/W+1;i++)
27                 ibitmap[i]=0;
28         for(uint i=0;i<n;i++) 
29                 if(c->is_set(symbols[i],l))
30                         bitset(ibitmap,i);
31         bitmap = bmb->build(ibitmap, n);
32   delete [] ibitmap;
33         uint count_right = bitmap->rank1(n-1);
34         uint count_left = n-count_right+1;
35         uint * left = new uint[count_left+1];
36         uint * right = new uint[count_right+1];
37         count_right = count_left = 0;
38         bool match_left = true, match_right = true;
39         for(uint i=0;i<n;i++) {
40                 if(bitmap->access(i)) {
41                         right[count_right++]=symbols[i];
42                         if(count_right>1)
43                                 if(right[count_right-1]!=right[count_right-2])
44                                         match_right = false;
45                 }
46                 else {
47                         left[count_left++]=symbols[i];
48                         if(count_left>1)
49                                 if(left[count_left-1]!=left[count_left-2])
50                                         match_left = false;
51                 }
52         }
53         if(count_left>0) {
54                 if(match_left/* && c->done(left[0],l+1)*/)
55                         left_child = new wt_node_leaf(left[0], count_left);
56                 else
57                         left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
58         } else {
59                 left_child = NULL;
60         }
61         if(count_right>0) {
62                 if(match_right/* && c->done(right[0],l+1)*/)
63                         right_child = new wt_node_leaf(right[0], count_right);
64                 else
65                         right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
66         } else {
67                 right_child = NULL;
68         }
69         delete [] left;
70         delete [] right;
71 }
72
73 // Deletes symbols array!
74 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
75         uint * ibitmap = new uint[n/W+1];
76         for(uint i=0;i<n/W+1;i++)
77                 ibitmap[i]=0;
78         for(uint i=0;i<n;i++) 
79                 if(c->is_set((uint)symbols[i],l))
80                         bitset(ibitmap,i);
81         bitmap = bmb->build(ibitmap, n);
82   delete [] ibitmap;
83         uint count_right = bitmap->rank1(n-1);
84         uint count_left = n-count_right+1;
85         uchar * left = new uchar[count_left+1];
86         uchar * right = new uchar[count_right+1];
87         count_right = count_left = 0;
88         bool match_left = true, match_right = true;
89         for(uint i=0;i<n;i++) {
90                 if(bitmap->access(i)) {
91                         right[count_right++]=symbols[i];
92                         if(count_right>1)
93                                 if(right[count_right-1]!=right[count_right-2])
94                                         match_right = false;
95                 }
96                 else {
97                         left[count_left++]=symbols[i];
98                         if(count_left>1)
99                                 if(left[count_left-1]!=left[count_left-2])
100                                         match_left = false;
101                 }
102         }
103
104         delete [] symbols; 
105         symbols = 0;
106
107         if(count_left>0) {
108                 if(match_left/* && c->done(left[0],l+1)*/)
109                 {
110                     left_child = new wt_node_leaf((uint)left[0], count_left);
111                     delete [] left;
112                     left = 0;
113                 }
114                 else
115                 {
116                     left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
117                     left = 0; // Already deleted
118                 }
119         } else {
120                 left_child = NULL;
121         }
122         if(count_right>0) {
123                 if(match_right/* && c->done(right[0],l+1)*/)
124                 {
125                     right_child = new wt_node_leaf((uint)right[0], count_right);
126                     delete [] right;
127                     right = 0;
128                 }
129                 else 
130                 {
131                     right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
132                     right = 0; // Already deleted
133                 }
134         } else {
135                 right_child = NULL;
136         }
137         delete [] left; // already deleted if count_left > 0
138         delete [] right;
139 }
140
141
142 wt_node_internal::wt_node_internal() { }
143
144 wt_node_internal::~wt_node_internal() {
145         delete bitmap;
146         if(right_child!=NULL) delete right_child;
147         if(left_child!=NULL) delete left_child;
148 }
149
150 uint wt_node_internal::rank(uint symbol, uint pos, uint l, wt_coder * c) {
151         bool is_set = c->is_set(symbol,l);
152         if(!is_set) {
153                 if(left_child==NULL) return 0;
154                 return left_child->rank(symbol, bitmap->rank0(pos)-1,l+1,c);
155         }
156         else {
157                 if(right_child==NULL) return 0;
158                 return right_child->rank(symbol, bitmap->rank1(pos)-1,l+1,c);
159         }
160 }
161
162 // return value is rank of symbol (less or equal to the given symbol) that has rank > 0, 
163 // the parameter symbol is updated accordinly
164 uint wt_node_internal::rankLessThan(uint &symbol, uint pos, uint l, wt_coder * c) 
165 {
166     bool is_set = c->is_set(symbol,l);
167     using std::cout;
168     using std::endl;
169 //    cout << "l = " << l << ", symbol = " << (uchar)symbol << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
170
171     uint result = -1;
172     if(!is_set) {
173         if(left_child==NULL) return -1;
174         uint rank = bitmap->rank0(pos);
175         if(rank != 0)
176             result = left_child->rankLessThan(symbol,rank-1,l+1,c);
177         return result;
178     }
179
180     uint rank = bitmap->rank1(pos);
181     if (rank != 0 && right_child != NULL)
182         result = right_child->rankLessThan(symbol, rank-1,l+1,c);
183
184 //    cout << "recursion to leftchild at l = " << l << ", symbol = " << (uchar)symbol << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
185     // check left child for symbols <= givenSymbol
186     if (result != (uint)-1 || left_child == NULL)
187         return result;
188     return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
189 }
190
191 uint wt_node_internal::rankLessThan(uint &symbol, uint pos) 
192 {
193     uint result = -1;
194     using std::cout;
195     using std::endl;
196 //    cout << "pos = " << pos << ", symbol = " << (uchar)symbol << endl;
197     
198     if (pos == (uint)-1)
199         return (uint)-1;
200     if(right_child!=NULL)
201         result = right_child->rankLessThan(symbol, bitmap->rank1(pos)-1);
202     if(result == (uint)-1 && left_child!=NULL)
203         return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
204     return result;
205 }
206
207
208 uint wt_node_internal::select(uint symbol, uint pos, uint l, wt_coder * c) {
209         bool is_set = c->is_set(symbol, l);
210         uint ret = 0;
211         if(!is_set) {
212                 if(left_child==NULL)
213                         return (uint)(-1);
214                 uint new_pos = left_child->select(symbol, pos, l+1,c);
215                 if(new_pos+1==0) return (uint)(-1);
216                 ret = bitmap->select0(new_pos)+1;
217         } else {
218                 if(right_child==NULL)
219                         return (uint)(-1);
220                 uint new_pos = right_child->select(symbol, pos, l+1,c);
221                 if(new_pos+1==0) return (uint)(-1);
222                 ret = bitmap->select1(new_pos)+1;
223         }
224         if(ret==0) return (uint)-1;
225         return ret;
226 }
227
228 uint wt_node_internal::access(uint pos) {
229         bool is_set = bitmap->access(pos);
230         if(!is_set) {
231                 assert(left_child!=NULL);
232                 return left_child->access(bitmap->rank0(pos)-1);
233         } else {
234                 assert(right_child!=NULL);
235                 return right_child->access(bitmap->rank1(pos)-1);
236         }
237 }
238
239 // Returns the value at given position and its rank
240 uint wt_node_internal::access(uint pos, uint &rank) 
241 {
242     // p is the internal node we are pointing our finger at each step
243     wt_node_internal *p = this;
244
245     while(1)
246     {
247         bool is_set = p->bitmap->access(pos);
248 //        cout << "is_set = " << is_set << ", pos = " << pos << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
249         if(!is_set)
250         {
251             // recurse left
252             pos = p->bitmap->rank0(pos)-1;
253             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->left_child);
254             if (tmp == NULL)
255             {
256                 // it's a leaf
257                 rank = pos+1;
258                 return p->left_child->access(0);
259             }
260             p = tmp; // new internal node
261         } 
262         else 
263         {
264             // recurse right
265             pos = p->bitmap->rank1(pos)-1;
266             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->right_child);
267             if (tmp == NULL)
268             {
269                 // it's a leaf
270                 rank = pos+1;
271                 return p->right_child->access(0);
272             }
273             p = tmp; // new internal node
274         }
275     }
276 }
277
278 void wt_node_internal::access(vector<int> &result, uint i, uint j, uint min, uint max, uint l, uint pivot)
279 {
280     uint symbol = pivot | (1 << l);
281 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
282
283     if (j < i || max < min)
284         return;
285
286     if (min < symbol)
287     {
288         // Recurse left
289         uint newi = 0;
290         if (i > 0)
291             newi = bitmap->rank0(i - 1);
292         uint newj = bitmap->rank0(j);
293
294         uint newmax = max < symbol - 1 ? max : symbol - 1;
295         if (left_child != NULL && newj > 0)
296             left_child->access(result, newi, newj-1, min, newmax, l-1, pivot);
297     }
298     
299     if (max >= symbol)
300     {
301         // Recurse right
302         uint newi = 0;
303         if (i > 0)
304             newi = bitmap->rank1(i - 1);
305         uint newj = bitmap->rank1(j);
306
307         uint newmin = min > symbol ? min : symbol;
308         if (right_child != NULL && newj > 0)
309             right_child->access(result, newi, newj-1, newmin, max, l-1, symbol);
310     }
311 }
312
313 void wt_node_internal::access(vector<int> &result, uint i, uint j)
314 {
315 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
316
317     if (j < i)
318         return;
319
320     {
321         // Recurse left
322         uint newi = 0;
323         if (i > 0)
324             newi = bitmap->rank0(i - 1);
325         uint newj = bitmap->rank0(j);
326
327         if (left_child != NULL && newj > 0)
328             left_child->access(result, newi, newj-1);
329     }
330     
331     {
332         // Recurse right
333         uint newi = 0;
334         if (i > 0)
335             newi = bitmap->rank1(i - 1);
336         uint newj = bitmap->rank1(j);
337
338         if (right_child != NULL && newj > 0)
339             right_child->access(result, newi, newj-1);
340     }
341 }
342
343
344 uint wt_node_internal::access(uint i, uint j, uint min, uint max, uint l, uint pivot)
345 {
346     uint count = 0;
347     uint symbol = pivot | (1 << l);
348 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
349
350     if (j < i || max < min)
351         return 0;
352
353     if (min < symbol)
354     {
355         // Recurse left
356         uint newi = 0;
357         if (i > 0)
358             newi = bitmap->rank0(i - 1);
359         uint newj = bitmap->rank0(j);
360
361         uint newmax = max < symbol - 1 ? max : symbol - 1;
362         if (left_child != NULL && newj > 0)
363             count += left_child->access(newi, newj-1, min, newmax, l-1, pivot);
364     }
365     
366     if (max >= symbol)
367     {
368         // Recurse right
369         uint newi = 0;
370         if (i > 0)
371             newi = bitmap->rank1(i - 1);
372         uint newj = bitmap->rank1(j);
373
374         uint newmin = min > symbol ? min : symbol;
375         if (right_child != NULL && newj > 0)
376             count += right_child->access(newi, newj-1, newmin, max, l-1, symbol);
377     }
378     return count;
379 }
380
381
382 uint wt_node_internal::size() {
383         uint s = bitmap->size()+sizeof(wt_node_internal);
384         if(left_child!=NULL)
385                 s += left_child->size();
386         if(right_child!=NULL)
387                 s += right_child->size();
388         return s;
389 }
390
391 uint wt_node_internal::save(FILE *fp) {
392   uint wr = WT_NODE_INTERNAL_HDR;
393   wr = fwrite(&wr,sizeof(uint),1,fp);
394   if(wr!=1) return 1;
395   if(bitmap->save(fp)) return 1;
396   if(left_child!=NULL) {
397     if(left_child->save(fp)) return 1;
398   } else {
399     wr = WT_NODE_NULL_HDR;
400     wr = fwrite(&wr,sizeof(uint),1,fp);
401     if(wr!=1) return 1;
402   }
403   if(right_child!=NULL) {
404     if(right_child->save(fp)) return 1;
405   } else {
406     wr = WT_NODE_NULL_HDR;
407     wr = fwrite(&wr,sizeof(uint),1,fp);
408     if(wr!=1) return 1;
409   }
410   return 0;
411 }
412
413 wt_node_internal * wt_node_internal::load(FILE *fp) {
414   uint rd;
415   if(fread(&rd,sizeof(uint),1,fp)!=1) return NULL;
416   if(rd!=WT_NODE_INTERNAL_HDR) return NULL;
417   wt_node_internal * ret = new wt_node_internal();
418   ret->bitmap = static_bitsequence::load(fp);
419   ret->left_child = wt_node::load(fp);
420   ret->right_child = wt_node::load(fp);
421   return ret;
422 }