New WVTree constructor
[SXSI/XMLTree.git] / libcds / src / static_sequence / static_sequence_wvtree_noptrs.cpp
1 /* static_sequence_wvtree_noptrs.cpp
2  * Copyright (C) 2008, Francisco Claude, all rights reserved.
3  *
4  * static_sequence_wvtree_noptrs definition
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21  
22 #include <static_sequence_wvtree_noptrs.h>
23
24 static_sequence_wvtree_noptrs::static_sequence_wvtree_noptrs(uint * symbols, uint n, static_bitsequence_builder * bmb, alphabet_mapper * am, bool deleteSymbols) {
25   this->n=n;
26   this->am=am;
27         am->use();
28   for(uint i=0;i<n;i++)
29     symbols[i] = am->map(symbols[i]);
30   max_v=max_value(symbols,n);
31   height=bits(max_v);
32   uint *occurrences=new uint[max_v+1];
33   for(uint i=0;i<=max_v;i++) occurrences[i]=0;
34   for(uint i=0;i<n;i++)
35     occurrences[symbols[i]]++;
36   uint to_add=0;
37   for(uint i=0;i<max_v;i++)
38     if(occurrences[i]==0) to_add++;
39   uint * new_symb = new uint[n+to_add];
40   for(uint i=0;i<n;i++)
41     new_symb[i] = symbols[i];
42
43   if (deleteSymbols)
44   {
45       delete [] symbols;
46       symbols = 0;
47   }
48
49   to_add = 0;
50   for(uint i=0;i<max_v;i++)
51     if(occurrences[i]==0) {
52       occurrences[i]++;
53       new_symb[n+to_add]=i;
54       to_add++;
55     }
56   uint new_n = n+to_add;
57   for(uint i=1;i<=max_v;i++)
58     occurrences[i] += occurrences[i-1];
59   uint *oc = new uint[(new_n+1)/W+1];
60   for(uint i=0;i<(new_n+1)/W+1;i++)
61     oc[i] = 0;
62   for(uint i=0;i<=max_v;i++)
63     bitset(oc,occurrences[i]-1);
64   bitset(oc,new_n);
65   occ = bmb->build(oc,new_n+1);
66   delete [] occurrences;
67   this->n = new_n;
68   uint ** _bm=new uint*[height];
69   for(uint i=0;i<height;i++) {
70     _bm[i] = new uint[new_n/W+1];
71     for(uint j=0;j<new_n/W+1;j++)
72       _bm[i][j]=0;
73   }
74   build_level(_bm,new_symb,0,new_n,0);
75   bitstring = new static_bitsequence*[height];
76   for(uint i=0;i<height;i++) {
77         bitstring[i] = bmb->build(_bm[i],new_n);
78     delete [] _bm[i];
79   }
80   delete [] _bm;
81   
82   if (!deleteSymbols)
83       for(uint i=0;i<n;i++)
84           symbols[i] = am->unmap(symbols[i]);
85
86 // delete [] new_symb; // already deleted in build_level()!
87   delete [] oc;
88 }
89
90 // symbols is an array of elements of "width" bits
91 static_sequence_wvtree_noptrs::static_sequence_wvtree_noptrs(uint * symbols, uint n, unsigned width, static_bitsequence_builder * bmb, alphabet_mapper * am, bool deleteSymbols) {
92   this->n=n;
93   this->am=am;
94         am->use();
95   for(uint i=0;i<n;i++)
96       set_field(symbols, width, i, am->map(get_field(symbols, width, i)));
97   max_v=max_value(symbols, width, n);
98   height=bits(max_v);
99   uint *occurrences=new uint[max_v+1];
100   for(uint i=0;i<=max_v;i++) occurrences[i]=0;
101   for(uint i=0;i<n;i++)
102       occurrences[get_field(symbols, width, i)]++;
103   uint to_add=0;
104   for(uint i=0;i<max_v;i++)
105     if(occurrences[i]==0) to_add++;
106   uint * new_symb = new uint[((n+to_add)*width)/W + 1];
107   for(uint i=0;i<n;i++)
108       set_field(new_symb, width, i, get_field(symbols, width, i));
109
110   if (deleteSymbols)
111   {
112       delete [] symbols;
113       symbols = 0;
114   }
115
116   to_add = 0;
117   for(uint i=0;i<max_v;i++)
118     if(occurrences[i]==0) {
119       occurrences[i]++;
120       set_field(new_symb, width, n+to_add, i);
121       to_add++;
122     }
123   uint new_n = n+to_add;
124   for(uint i=1;i<=max_v;i++)
125     occurrences[i] += occurrences[i-1];
126   uint *oc = new uint[(new_n+1)/W+1];
127   for(uint i=0;i<(new_n+1)/W+1;i++)
128     oc[i] = 0;
129   for(uint i=0;i<=max_v;i++)
130     bitset(oc,occurrences[i]-1);
131   bitset(oc,new_n);
132   occ = bmb->build(oc,new_n+1);
133   delete [] occurrences;
134   this->n = new_n;
135   uint ** _bm=new uint*[height];
136   for(uint i=0;i<height;i++) {
137     _bm[i] = new uint[new_n/W+1];
138     for(uint j=0;j<new_n/W+1;j++)
139       _bm[i][j]=0;
140   }
141   build_level(_bm,new_symb,width,0,new_n,0);
142   bitstring = new static_bitsequence*[height];
143   for(uint i=0;i<height;i++) {
144         bitstring[i] = bmb->build(_bm[i],new_n);
145     delete [] _bm[i];
146   }
147   delete [] _bm;
148   
149   if (!deleteSymbols)
150       for(uint i=0;i<n;i++)
151           set_field(symbols, width, i, am->unmap(get_field(symbols, width, i)));
152
153 // delete [] new_symb; // already deleted in build_level()!
154   delete [] oc;
155 }
156
157 static_sequence_wvtree_noptrs::static_sequence_wvtree_noptrs() {
158 }
159
160 static_sequence_wvtree_noptrs::~static_sequence_wvtree_noptrs() {
161   for(uint i=0;i<height;i++)
162     delete bitstring[i];
163   delete [] bitstring;
164   delete occ;
165         am->unuse();
166 }
167
168 uint static_sequence_wvtree_noptrs::save(FILE *fp) {
169   uint wr = WVTREE_NOPTRS_HDR;
170   wr = fwrite(&wr,sizeof(uint),1,fp);
171   wr += fwrite(&n,sizeof(uint),1,fp);
172   wr += fwrite(&max_v,sizeof(uint),1,fp);
173   wr += fwrite(&height,sizeof(uint),1,fp);
174   if(wr!=4) return 1;
175   if(am->save(fp)) return 1;
176   for(uint i=0;i<height;i++)
177     if(bitstring[i]->save(fp)) return 1;
178         if(occ->save(fp)) return 1;
179   return 0;
180 }
181
182 static_sequence_wvtree_noptrs * static_sequence_wvtree_noptrs::load(FILE *fp) {
183   uint rd;
184   if(fread(&rd,sizeof(uint),1,fp)!=1) return NULL;
185   if(rd!=WVTREE_NOPTRS_HDR) return NULL;
186   static_sequence_wvtree_noptrs * ret = new static_sequence_wvtree_noptrs();
187   rd = fread(&ret->n,sizeof(uint),1,fp);
188   rd += fread(&ret->max_v,sizeof(uint),1,fp);
189   rd += fread(&ret->height,sizeof(uint),1,fp);
190   if(rd!=3) {
191     delete ret;
192     return NULL;
193   }
194   ret->am = alphabet_mapper::load(fp);
195   if(ret->am==NULL) {
196     delete ret;
197     return NULL;
198   }
199         ret->am->use();
200   ret->bitstring = new static_bitsequence*[ret->height];
201   for(uint i=0;i<ret->height;i++) {
202     ret->bitstring[i] = static_bitsequence::load(fp);
203     if(ret->bitstring[i]==NULL){
204       delete ret;
205       return NULL;
206     }
207   }
208         ret->occ = static_bitsequence::load(fp);
209         if(ret->occ==NULL) {
210                 delete ret;
211                 return NULL;
212         }
213   return ret;
214 }
215
216 uint static_sequence_wvtree_noptrs::access(uint pos) {
217   uint level=0;
218   uint ret=0;
219   uint start=0;
220   uint end=n-1;
221   while(level<height) {
222     assert(pos>=start && pos<=end);
223     if(bitstring[level]->access(pos)) {
224       ret=set(ret,level);
225       pos=bitstring[level]->rank1(pos-1)-bitstring[level]->rank1(start-1);
226       start=(bitstring[level]->rank1(end)-bitstring[level]->rank1(start-1));
227       start=end-start+1;
228       pos+=start;
229     }
230     else {
231       pos=pos-start-(bitstring[level]->rank1(pos)-bitstring[level]->rank1(start-1));
232       end=end-start-(bitstring[level]->rank1(end)-bitstring[level]->rank1(start-1));
233       end+=start;
234       pos+=start;
235     }
236     level++;
237   }
238   return am->unmap(ret);
239 }
240
241 uint static_sequence_wvtree_noptrs::rank(uint symbol, uint pos) {
242   symbol = am->map(symbol);
243   uint level=0;
244   uint start=0;
245   uint end=n-1;
246   uint count=0;
247   while(level<height) {
248     if(is_set(symbol,level)) {
249       pos=bitstring[level]->rank1(pos)-bitstring[level]->rank1(start-1)-1;
250       count=pos+1;
251       start=(bitstring[level]->rank1(end)-bitstring[level]->rank1(start-1));
252       start=end-start+1;
253       pos+=start;
254     }
255     else {
256       pos=pos-start+bitstring[level]->rank1(start-1)-bitstring[level]->rank1(pos);
257       count=pos+1;
258       end=end-start-(bitstring[level]->rank1(end)-bitstring[level]->rank1(start-1));
259       end+=start;
260       pos+=start;
261     }
262     level++;
263     if(count==0) return 0;
264   }
265   return count;
266 }
267
268 vector<int> static_sequence_wvtree_noptrs::access(uint i, uint j, uint min, uint max)
269 {
270     vector<int> resultSet;
271 //    cout << "height = " << height << endl;
272     access(resultSet, i, j, am->map(min), am->map(max), 0, 0, 0, n-1);
273     return resultSet;
274 }
275
276 void static_sequence_wvtree_noptrs::access(vector<int> &result, uint i, uint j, uint min, uint max, uint l, uint pivot, uint start, uint end)
277 {
278     uint symbol = pivot | (1 << (height-l-1));
279     //std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], [" << start << ", " << end << "], symbol = " << symbol << std::endl;
280
281     if (l == height)
282     {
283         if (i <= j && pivot >= min && pivot <= max && start <= end)
284             result.push_back(am->unmap((int)pivot));
285         return;
286     }
287
288     if (j < i || max < min || end < start)
289         return;
290
291     if (min < symbol)
292     {
293         // Recurse left
294         uint newi = i + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(i-1);
295         uint newend = end - (bitstring[l]->rank1(end) - bitstring[l]->rank1(start-1));
296         uint newj = j + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(j) + 1;
297
298         uint newmax = max < symbol - 1 ? max : symbol - 1;
299         if (newj > start)
300             access(result, newi, newj-1, min, newmax, l+1, pivot, start, newend);
301     }
302
303     if (max >= symbol)
304     {
305         // Recurse right
306         uint newstart = (bitstring[l]->rank1(end)-bitstring[l]->rank1(start-1));
307         newstart = end - newstart + 1;
308         uint newi = bitstring[l]->rank1(i-1)-bitstring[l]->rank1(start-1) + newstart;
309         uint newj = bitstring[l]->rank1(j)-bitstring[l]->rank1(start-1) + newstart;
310
311         uint newmin = min > symbol ? min : symbol;
312         if (newj > newstart)
313             access(result, newi, newj-1, newmin, max, l+1, symbol, newstart, end);
314     }
315 }
316
317
318 vector<int> static_sequence_wvtree_noptrs::accessAll(uint i, uint j)
319 {
320     vector<int> resultSet;
321     if (j < i)
322         return resultSet;
323
324     resultSet.reserve(j-i+1);
325     accessAll(resultSet, i, j, 0, 0, 0, n-1);
326     return resultSet;
327 }
328
329 void static_sequence_wvtree_noptrs::accessAll(vector<int> &result, uint i, uint j, uint l, uint pivot, uint start, uint end)
330 {
331     uint symbol = pivot | (1 << (height-l-1));
332 //    std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << start << ", " << end << "], symbol = " << symbol << std::endl;
333
334     if (l == height)
335     {
336         if (i <= j && start <= end)
337             result.push_back(am->unmap((int)pivot));
338         return;
339     }
340
341     if (j < i || end < start)
342         return;
343
344     {
345         // Recurse left
346         uint newi = i + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(i-1);
347         uint newend = end - (bitstring[l]->rank1(end) - bitstring[l]->rank1(start-1));
348         uint newj = j + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(j) + 1;
349
350         if (newj > start)
351             accessAll(result, newi, newj-1, l+1, pivot, start, newend);
352     }
353
354     {
355         // Recurse right
356         uint newstart = (bitstring[l]->rank1(end)-bitstring[l]->rank1(start-1));
357         newstart = end - newstart + 1;
358         uint newi = bitstring[l]->rank1(i-1)-bitstring[l]->rank1(start-1) + newstart;
359         uint newj = bitstring[l]->rank1(j)-bitstring[l]->rank1(start-1) + newstart;
360
361         if (newj > newstart)
362             accessAll(result, newi, newj-1, l+1, symbol, newstart, end);
363     }
364 }
365
366
367 uint static_sequence_wvtree_noptrs::count(uint i, uint j, uint min, uint max)
368 {
369     return count(i, j, am->map(min), am->map(max), 0, 0, 0, n-1);
370 }
371
372 uint static_sequence_wvtree_noptrs::count(uint i, uint j, uint min, uint max, uint l, uint pivot, uint start, uint end)
373 {
374     uint symbol = pivot | (1 << (height-l-1));
375     //std::cout << "At l = " << l << ", [" << i << ", " << j  << "], [" << min << ", " << max << "], [" << start << ", " << end << "], symbol = " << symbol << std::endl;
376
377     if (l == height)
378     {
379         if (i <= j && pivot >= min && pivot <= max && start <= end)
380             return 1;
381         return 0;
382     }
383
384     if (j < i || max < min || end < start)
385         return 0;
386
387     uint result = 0;
388     if (min < symbol)
389     {
390         // Recurse left
391         uint newi = i + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(i-1);
392         uint newend = end - (bitstring[l]->rank1(end) - bitstring[l]->rank1(start-1));
393         uint newj = j + bitstring[l]->rank1(start-1) - bitstring[l]->rank1(j) + 1;
394
395         uint newmax = max < symbol - 1 ? max : symbol - 1;
396         if (newj > start)
397             result += count(newi, newj-1, min, newmax, l+1, pivot, start, newend);
398     }
399
400     if (max >= symbol)
401     {
402         // Recurse right
403         uint newstart = (bitstring[l]->rank1(end)-bitstring[l]->rank1(start-1));
404         newstart = end - newstart + 1;
405         uint newi = bitstring[l]->rank1(i-1)-bitstring[l]->rank1(start-1) + newstart;
406         uint newj = bitstring[l]->rank1(j)-bitstring[l]->rank1(start-1) + newstart;
407
408         uint newmin = min > symbol ? min : symbol;
409         if (newj > newstart)
410             result += count(newi, newj-1, newmin, max, l+1, symbol, newstart, end);
411     }
412     return result;
413 }
414
415
416
417 inline uint get_start(uint symbol, uint mask) {
418   return symbol&mask;
419 }
420
421 inline uint get_end(uint symbol, uint mask) {
422   return get_start(symbol,mask)+!mask+1;
423 }
424
425 uint static_sequence_wvtree_noptrs::select(uint symbol, uint j) {
426   symbol = am->map(symbol);
427   uint mask = (1<<height)-2;
428   uint sum=2;
429   uint level = height-1;
430   uint pos=j;
431   while(true) {
432     uint start = get_start(symbol,mask);
433     uint end = min(max_v+1,start+sum);
434     start = (start==0)?0:(occ->select1(start)+1);
435     end = occ->select1(end+1)-1;
436     if(is_set(symbol,level)) {
437       uint ones_start = bitstring[level]->rank1(start-1);
438       pos = bitstring[level]->select1(ones_start+pos)-start+1;
439     }
440     else {
441       uint ones_start = bitstring[level]->rank1(start-1);
442       pos = bitstring[level]->select0(start-ones_start+pos)-start+1;
443     }
444     mask <<=1;
445     sum <<=1;
446     if(level==0) break;
447     level--;
448   }
449   return pos-1;
450 }
451
452 uint static_sequence_wvtree_noptrs::size() {
453   uint ptrs = sizeof(static_sequence_wvtree_noptrs)+height*sizeof(static_sequence*);
454   uint bytesBitstrings = 0;
455   for(uint i=0;i<height;i++)
456     bytesBitstrings += bitstring[i]->size();
457   return bytesBitstrings+occ->size()+ptrs;
458 }
459
460 void static_sequence_wvtree_noptrs::build_level(uint **bm, uint *symbols, uint level, uint length, uint offset) {
461   if(level==height)
462   {
463       delete [] symbols;
464       return;
465   }
466   uint cleft=0;
467   for(uint i=0;i<length;i++)
468     if(!is_set(symbols[i],level))
469       cleft++;
470   uint cright=length-cleft;
471   uint *left=new uint[cleft], *right=new uint[cright];
472   cleft=cright=0;
473   for(uint i=0;i<length;i++)
474   if(!is_set(symbols[i],level)) {
475     left[cleft++]=symbols[i];
476     bitclean(bm[level],offset+i);
477   }
478   else {
479     right[cright++]=symbols[i];
480     bitset(bm[level],offset+i);
481   }
482   
483   delete [] symbols;
484   symbols = 0;
485   
486   build_level(bm,left,level+1,cleft,offset);
487   left = 0; // Gets deleted in recursion.
488   build_level(bm,right,level+1,cright,offset+cleft);
489   right = 0; // Gets deleted in recursion.
490   //delete [] left;
491   //delete [] right;
492 }
493
494 // symbols is an array of elements of "width" bits.
495 void static_sequence_wvtree_noptrs::build_level(uint **bm, uint *symbols, unsigned width, uint level, uint length, uint offset) {
496     if(level==height)
497     {
498         delete [] symbols;
499         return;
500     }
501     uint cleft=0;
502     for(uint i=0;i<length;i++)
503         if(!is_set(get_field(symbols, width, i),level))
504             cleft++;
505     uint cright=length-cleft;
506     uint *left=new uint[(cleft*width)/W + 1], 
507         *right=new uint[(cright*width)/W + 1];
508     cleft=cright=0;
509     for(uint i=0;i<length;i++)
510         if(!is_set(get_field(symbols,width,i),level)) {
511             set_field(left,width,cleft++,get_field(symbols, width,i));
512             bitclean(bm[level],offset+i);
513         }
514         else {
515             set_field(right,width,cright++,get_field(symbols,width,i));
516             bitset(bm[level],offset+i);
517         }
518   
519     delete [] symbols;
520     symbols = 0;
521   
522     build_level(bm,left,width,level+1,cleft,offset);
523     left = 0; // Gets deleted in recursion.
524     build_level(bm,right,width,level+1,cright,offset+cleft);
525     right = 0; // Gets deleted in recursion.
526     //delete [] left;
527     //delete [] right;
528 }
529
530 uint static_sequence_wvtree_noptrs::max_value(uint *symbols, uint n) {
531   uint max_v = 0;
532   for(uint i=0;i<n;i++)
533     max_v = max(symbols[i],max_v);
534   return max_v;
535 }
536
537 uint static_sequence_wvtree_noptrs::max_value(uint *symbols, unsigned width, uint n) {
538   uint max_v = 0;
539   for(uint i=0;i<n;i++)
540       max_v = max(get_field(symbols, width, i),max_v);
541   return max_v;
542 }
543
544 uint static_sequence_wvtree_noptrs::bits(uint val) {
545   uint ret = 0;
546   while(val!=0) {
547     ret++;
548     val >>= 1;
549   }
550   return ret;
551 }
552
553 bool static_sequence_wvtree_noptrs::is_set(uint val, uint ind) {
554   assert(ind<height);
555   return (val & (1<<(height-ind-1)))!=0;
556 }
557
558
559 uint static_sequence_wvtree_noptrs::set(uint val, uint ind) {
560   assert(ind<=height);
561   return val | (1<<(height-ind-1));
562 }