ca21c73d7e825c673f05ac27bd1d20e04856c06a
[SXSI/XMLTree.git] / libcds / src / static_sequence / wt_node_internal.cpp
1 /* wt_node_internal.cpp
2  * Copyright (C) 2008, Francisco Claude, all rights reserved.
3  *
4  * wt_node_internal
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21  
22 #include <wt_node_internal.h>
23
24 wt_node_internal::wt_node_internal(uint * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb) {
25         uint * ibitmap = new uint[n/W+1];
26         for(uint i=0;i<n/W+1;i++)
27                 ibitmap[i]=0;
28         for(uint i=0;i<n;i++) 
29                 if(c->is_set(symbols[i],l))
30                         bitset(ibitmap,i);
31         bitmap = bmb->build(ibitmap, n);
32   delete [] ibitmap;
33         uint count_right = bitmap->rank1(n-1);
34         uint count_left = n-count_right+1;
35         uint * left = new uint[count_left+1];
36         uint * right = new uint[count_right+1];
37         count_right = count_left = 0;
38         bool match_left = true, match_right = true;
39         for(uint i=0;i<n;i++) {
40                 if(bitmap->access(i)) {
41                         right[count_right++]=symbols[i];
42                         if(count_right>1)
43                                 if(right[count_right-1]!=right[count_right-2])
44                                         match_right = false;
45                 }
46                 else {
47                         left[count_left++]=symbols[i];
48                         if(count_left>1)
49                                 if(left[count_left-1]!=left[count_left-2])
50                                         match_left = false;
51                 }
52         }
53         if(count_left>0) {
54                 if(match_left/* && c->done(left[0],l+1)*/)
55                         left_child = new wt_node_leaf(left[0], count_left);
56                 else
57                         left_child = new wt_node_internal(left, count_left, l+1, c, bmb);
58         } else {
59                 left_child = NULL;
60         }
61         if(count_right>0) {
62                 if(match_right/* && c->done(right[0],l+1)*/)
63                         right_child = new wt_node_leaf(right[0], count_right);
64                 else
65                         right_child = new wt_node_internal(right, count_right, l+1, c, bmb);
66         } else {
67                 right_child = NULL;
68         }
69         delete [] left;
70         delete [] right;
71 }
72
73 wt_node_internal::wt_node_internal(uchar * symbols, uint n, uint l, wt_coder * c, static_bitsequence_builder * bmb, uint left, uint *done) {
74         uint * ibitmap = new uint[n/W+1];
75         for(uint i=0;i<n/W+1;i++)
76                 ibitmap[i]=0;
77         for(uint i=0;i<n;i++) 
78                 if(c->is_set((uint)symbols[i + left],l))
79                         bitset(ibitmap,i);
80         bitmap = bmb->build(ibitmap, n);
81         delete [] ibitmap;
82
83         uint count_right = bitmap->rank1(n-1);
84         uint count_left = n-count_right;
85
86         for (uint i=0;i<n;i++)
87             set_field(done, 1, i+left, 0);
88
89         for (uint i = 0; i < n; ) 
90         {
91             uint j = i;
92             uchar swap = symbols[j+left];
93             while (!get_field(done, 1, j+left)) { // swapping
94                 ulong k = j; 
95                 if (!c->is_set(swap,l)) 
96                     j = bitmap->rank0(k)-1;
97                 else 
98                     j = count_left + bitmap->rank1(k)-1;
99                 uchar temp = symbols[j+left];
100                 symbols[j+left] = swap;
101                 swap = temp;
102                 set_field(done,1,k+left,1);
103             }
104
105             while (get_field(done,1,i+left))
106                    ++i;
107         }
108
109         bool match_left = true, match_right = true;
110         for (uint i=1; i < count_left; i++)
111             if (symbols[i+left] != symbols[i+left-1])
112                 match_left = false;
113         for (uint i=count_left + 1; i < n; i++)
114             if (symbols[i+left] != symbols[i+left-1])
115                 match_right = false;
116
117
118         if(count_left>0) {
119                 if(match_left/* && c->done(left[0],l+1)*/)
120                     left_child = new wt_node_leaf((uint)symbols[left], count_left);
121                 else
122                     left_child = new wt_node_internal(symbols, count_left, l+1, c, bmb, left, done);
123         } else {
124                 left_child = NULL;
125         }
126         if(count_right>0) {
127                 if(match_right/* && c->done(right[0],l+1)*/)
128                     right_child = new wt_node_leaf((uint)symbols[left+count_left], count_right);
129                 else 
130                     right_child = new wt_node_internal(symbols, count_right, l+1, c, bmb, left+count_left, done);
131         } else {
132                 right_child = NULL;
133         }
134 }
135
136
137 wt_node_internal::wt_node_internal() { }
138
139 wt_node_internal::~wt_node_internal() {
140         delete bitmap;
141         if(right_child!=NULL) delete right_child;
142         if(left_child!=NULL) delete left_child;
143 }
144
145 uint wt_node_internal::rank(uint symbol, uint pos, uint l, wt_coder * c) {
146         bool is_set = c->is_set(symbol,l);
147         if(!is_set) {
148                 if(left_child==NULL) return 0;
149                 return left_child->rank(symbol, bitmap->rank0(pos)-1,l+1,c);
150         }
151         else {
152                 if(right_child==NULL) return 0;
153                 return right_child->rank(symbol, bitmap->rank1(pos)-1,l+1,c);
154         }
155 }
156
157 // return value is rank of symbol (less or equal to the given symbol) that has rank > 0, 
158 // the parameter symbol is updated accordinly
159 uint wt_node_internal::rankLessThan(uint &symbol, uint pos) 
160 {
161     uint result = -1;
162     using std::cout;
163     using std::endl;
164 //    cout << "pos = " << pos << ", symbol = " << symbol << endl;
165     
166     if (pos == (uint)-1)
167         return (uint)-1;
168     if(right_child!=NULL)
169         result = right_child->rankLessThan(symbol, bitmap->rank1(pos)-1);
170     if(result == (uint)-1 && left_child!=NULL)
171         return left_child->rankLessThan(symbol, bitmap->rank0(pos)-1);
172     return result;
173 }
174
175
176 uint wt_node_internal::select(uint symbol, uint pos, uint l, wt_coder * c) {
177         bool is_set = c->is_set(symbol, l);
178         uint ret = 0;
179         if(!is_set) {
180                 if(left_child==NULL)
181                         return (uint)(-1);
182                 uint new_pos = left_child->select(symbol, pos, l+1,c);
183                 if(new_pos+1==0) return (uint)(-1);
184                 ret = bitmap->select0(new_pos)+1;
185         } else {
186                 if(right_child==NULL)
187                         return (uint)(-1);
188                 uint new_pos = right_child->select(symbol, pos, l+1,c);
189                 if(new_pos+1==0) return (uint)(-1);
190                 ret = bitmap->select1(new_pos)+1;
191         }
192         if(ret==0) return (uint)-1;
193         return ret;
194 }
195
196 uint wt_node_internal::access(uint pos) {
197         bool is_set = bitmap->access(pos);
198         if(!is_set) {
199                 assert(left_child!=NULL);
200                 return left_child->access(bitmap->rank0(pos)-1);
201         } else {
202                 assert(right_child!=NULL);
203                 return right_child->access(bitmap->rank1(pos)-1);
204         }
205 }
206
207 // Returns the value at given position and its rank
208 uint wt_node_internal::access(uint pos, uint &rank) 
209 {
210     // p is the internal node we are pointing our finger at each step
211     wt_node_internal *p = this;
212
213     while(1)
214     {
215         bool is_set = p->bitmap->access(pos);
216 //        cout << "is_set = " << is_set << ", pos = " << pos << ", rank0 = " << bitmap->rank0(pos) << ", rank1 = " << bitmap->rank1(pos) << endl;
217         if(!is_set)
218         {
219             // recurse left
220             pos = p->bitmap->rank0(pos)-1;
221             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->left_child);
222             if (tmp == NULL)
223             {
224                 // it's a leaf
225                 rank = pos+1;
226                 return p->left_child->access(0);
227             }
228             p = tmp; // new internal node
229         } 
230         else 
231         {
232             // recurse right
233             pos = p->bitmap->rank1(pos)-1;
234             wt_node_internal *tmp = dynamic_cast<wt_node_internal *>(p->right_child);
235             if (tmp == NULL)
236             {
237                 // it's a leaf
238                 rank = pos+1;
239                 return p->right_child->access(0);
240             }
241             p = tmp; // new internal node
242         }
243     }
244 }
245
246 void wt_node_internal::access(vector<int> &result, uint i, uint j, uint min, uint max, uint l, uint pivot)
247 {
248     uint symbol = pivot | (1 << l);
249 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
250
251     if (j < i || max < min)
252         return;
253
254     if (min < symbol)
255     {
256         // Recurse left
257         uint newi = 0;
258         if (i > 0)
259             newi = bitmap->rank0(i - 1);
260         uint newj = bitmap->rank0(j);
261
262         uint newmax = max < symbol - 1 ? max : symbol - 1;
263         if (left_child != NULL && newj > 0)
264             left_child->access(result, newi, newj-1, min, newmax, l-1, pivot);
265     }
266     
267     if (max >= symbol)
268     {
269         // Recurse right
270         uint newi = 0;
271         if (i > 0)
272             newi = bitmap->rank1(i - 1);
273         uint newj = bitmap->rank1(j);
274
275         uint newmin = min > symbol ? min : symbol;
276         if (right_child != NULL && newj > 0)
277             right_child->access(result, newi, newj-1, newmin, max, l-1, symbol);
278     }
279 }
280
281 void wt_node_internal::access(vector<int> &result, uint i, uint j)
282 {
283 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
284
285     if (j < i)
286         return;
287
288     {
289         // Recurse left
290         uint newi = 0;
291         if (i > 0)
292             newi = bitmap->rank0(i - 1);
293         uint newj = bitmap->rank0(j);
294
295         if (left_child != NULL && newj > 0)
296             left_child->access(result, newi, newj-1);
297     }
298     
299     {
300         // Recurse right
301         uint newi = 0;
302         if (i > 0)
303             newi = bitmap->rank1(i - 1);
304         uint newj = bitmap->rank1(j);
305
306         if (right_child != NULL && newj > 0)
307             right_child->access(result, newi, newj-1);
308     }
309 }
310
311 // Count
312 uint wt_node_internal::access(uint i, uint j, uint min, uint max, uint l, uint pivot)
313 {
314     uint count = 0;
315     uint symbol = pivot | (1 << l);
316 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], symbol = " << symbol << std::endl;
317
318     if (j < i || max < min)
319         return 0;
320
321     if (min < symbol)
322     {
323         // Recurse left
324         uint newi = 0;
325         if (i > 0)
326             newi = bitmap->rank0(i - 1);
327         uint newj = bitmap->rank0(j);
328
329         uint newmax = max < symbol - 1 ? max : symbol - 1;
330         if (left_child != NULL && newj > 0)
331             count += left_child->access(newi, newj-1, min, newmax, l-1, pivot);
332     }
333     
334     if (max >= symbol)
335     {
336         // Recurse right
337         uint newi = 0;
338         if (i > 0)
339             newi = bitmap->rank1(i - 1);
340         uint newj = bitmap->rank1(j);
341
342         uint newmin = min > symbol ? min : symbol;
343         if (right_child != NULL && newj > 0)
344             count += right_child->access(newi, newj-1, newmin, max, l-1, symbol);
345     }
346     return count;
347 }
348
349
350 uint wt_node_internal::size() {
351         uint s = bitmap->size()+sizeof(wt_node_internal);
352         if(left_child!=NULL)
353                 s += left_child->size();
354         if(right_child!=NULL)
355                 s += right_child->size();
356         return s;
357 }
358
359 uint wt_node_internal::save(FILE *fp) {
360   uint wr = WT_NODE_INTERNAL_HDR;
361   wr = fwrite(&wr,sizeof(uint),1,fp);
362   if(wr!=1) return 1;
363   if(bitmap->save(fp)) return 1;
364   if(left_child!=NULL) {
365     if(left_child->save(fp)) return 1;
366   } else {
367     wr = WT_NODE_NULL_HDR;
368     wr = fwrite(&wr,sizeof(uint),1,fp);
369     if(wr!=1) return 1;
370   }
371   if(right_child!=NULL) {
372     if(right_child->save(fp)) return 1;
373   } else {
374     wr = WT_NODE_NULL_HDR;
375     wr = fwrite(&wr,sizeof(uint),1,fp);
376     if(wr!=1) return 1;
377   }
378   return 0;
379 }
380
381 wt_node_internal * wt_node_internal::load(FILE *fp) {
382   uint rd;
383   if(fread(&rd,sizeof(uint),1,fp)!=1) return NULL;
384   if(rd!=WT_NODE_INTERNAL_HDR) return NULL;
385   wt_node_internal * ret = new wt_node_internal();
386   ret->bitmap = static_bitsequence::load(fp);
387   ret->left_child = wt_node::load(fp);
388   ret->right_child = wt_node::load(fp);
389   return ret;
390 }