more fixes
[SXSI/XMLTree.git] / libcds / src / static_bitsequence / sdarray.cpp
1
2 #include <sdarray.h>
3
4 #if 0
5 typedef unsigned int qword;
6 #define logD 4
7 #else
8 typedef unsigned long long qword;
9 #define logD 5
10 #endif
11 #define PBS (sizeof(uint)*8)
12 #define D (1<<logD)
13 #define logM 5
14 #define M (1<<logM)
15 #define logP 8
16 #define P (1<<logP)
17 #define logLL 16                 // size of word
18 #define LL (1<<logLL)
19 //#define logLLL 7
20 #define logLLL 5
21 //#define LLL 128
22 //#define LLL 32
23 #define LLL (1<<logLLL)
24 //#define logL 10
25 //#define logL (logLL-3)
26 #define logL (logLL-1-5)
27 #define L (1<<logL)
28
29 int __blog(int x) {
30   int l;
31   l = 0;
32   while (x>0) {
33     x>>=1;
34     l++;
35   }
36   return l;
37 }
38
39
40 int __setbit(uint *B, int i,int x) {
41   int j,l;
42 //printf("%u\n",D);
43   j = i / D;
44   l = i % D;
45   if (x==0) B[j] &= (~(1<<(D-1-l)));
46   else if (x==1) B[j] |= (1<<(D-1-l));
47   else {
48     printf("error __setbit x=%d\n",x);
49     exit(1);
50   }
51   return x;
52 }
53
54
55 int __setbit2(uchar *B, int i,int x) {
56   int j,l;
57
58   j = i / 8;
59   l = i % 8;
60   if (x==0) B[j] &= (~(1<<(8-1-l)));
61   else if (x==1) B[j] |= (1<<(8-1-l));
62   else {
63     printf("error __setbit2 x=%d\n",x);
64     exit(1);
65   }
66   return x;
67 }
68
69
70 int __setbits(uint *B, int i, int d, int x) {
71   int j;
72
73   for (j=0; j<d; j++) {
74     __setbit(B,i+j,(x>>(d-j-1))&1);
75   }
76   return x;
77 }
78
79
80 int __getbit(uint *B, int i) {
81   int j,l;
82
83   //j = i / D;
84   //l = i % D;
85   j = i >> logD;
86   l = i & (D-1);
87   return (B[j] >> (D-1-l)) & 1;
88 }
89
90
91 int __getbit2(uchar *B, int i) {
92   int j,l;
93
94   //j = i / D;
95   //l = i % D;
96   j = i >> 3;
97   l = i & (8-1);
98   return (B[j] >> (8-1-l)) & 1;
99 }
100
101
102 #if 1
103 uint __getbits(uint *B, int i, int d) {
104   qword x,z;
105
106   B += (i >> logD);
107   i &= (D-1);
108   if (i+d <= 2*D) {
109     x = (((qword)B[0]) << D) + B[1];
110     x <<= i;
111     x >>= (D*2-1-d);
112     x >>= 1;
113   }
114   else {
115     x = (((qword)B[0])<<D)+B[1];
116     z = (x<<D)+B[2];
117     x <<= i;
118     x &= (((qword)1L<<D)-1)<<D;
119     z <<= i;
120     z >>= D;
121     x += z;
122     x >>= (2*D-d);
123   }
124
125   return x;
126 }
127 #endif
128
129 #if 0
130 uint __getbits(uint *B, int i, int d) {
131   uint j,x;
132
133   x = 0;
134   for (j=0; j<d; j++) {
135     x <<= 1;
136     x += __getbit(B,i+j);
137   }
138   return x;
139 }
140 #endif
141
142 static const unsigned int _popCount[] = {
143   0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,
144   1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
145   1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
146   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
147   1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
148   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
149   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
150   3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
151   1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
152   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
153   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
154   3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
155   2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
156   3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
157   3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
158   4,5,5,6,5,6,6,7,5,6,6,7,6,7,7,8
159 };
160
161 static unsigned int __selecttbl[8*256];
162
163 void make___selecttbl(void) {
164   int i,x,r;
165   uint buf[1];
166
167   for (x = 0; x < 256; x++) {
168     __setbits(buf,0,8,x);
169     for (r=0; r<8; r++) __selecttbl[(r<<8)+x] = -1;
170     r = 0;
171     for (i=0; i<8; i++) {
172       if (__getbit(buf,i)) {
173         __selecttbl[(r<<8)+x] = i;
174         r++;
175       }
176     }
177   }
178 }
179
180 unsigned int __popCount(uint x) {
181   uint r;
182   #if 0
183   r = x;
184   r = r - ((r>>1) & 0x77777777) - ((r>>2) & 0x33333333) - ((r>>3) & 0x11111111);
185   r = ((r + (r>>4)) & 0x0f0f0f0f) % 0xff;
186   #elif 1
187   r = x;
188   r = ((r & 0xaaaaaaaa)>>1) + (r & 0x55555555);
189   r = ((r & 0xcccccccc)>>2) + (r & 0x33333333);
190   //r = ((r & 0xf0f0f0f0)>>4) + (r & 0x0f0f0f0f);
191   r = ((r>>4) + r) & 0x0f0f0f0f;
192   //r = ((r & 0xff00ff00)>>8) + (r & 0x00ff00ff);
193   r = (r>>8) + r;
194   //r = ((r & 0xffff0000)>>16) + (r & 0x0000ffff);
195   r = ((r>>16) + r) & 63;
196   #else
197   r = _popCount[x & 0xff];
198   x >>= 8;
199   r += _popCount[x & 0xff];
200   x >>= 8;
201   r += _popCount[x & 0xff];
202   x >>= 8;
203   r += _popCount[x & 0xff];
204   #endif
205   return r;
206 }
207
208
209 unsigned int __popCount8(uint x) {
210   uint r;
211   #if 1
212   r = x;
213   r = ((r & 0xaa)>>1) + (r & 0x55);
214   r = ((r & 0xcc)>>2) + (r & 0x33);
215   r = ((r>>4) + r) & 0x0f;
216   #else
217   r = _popCount[x & 0xff];
218   #endif
219   return r;
220 }
221
222 int selectd2_save(selectd2 * s, FILE * fp) {
223         uint wr = 0;
224         wr += fwrite(&s->n,sizeof(uint),1,fp);
225         wr += fwrite(&s->m,sizeof(uint),1,fp);
226         wr += fwrite(&s->size,sizeof(uint),1,fp);
227         wr += fwrite(&s->ss_len,sizeof(uint),1,fp);
228         wr += fwrite(&s->sl_len,sizeof(uint),1,fp);
229         wr += fwrite(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
230         uint nl = (s->m-1) / L + 1;
231         wr += fwrite(s->lp,sizeof(uint),nl+1,fp);
232         wr += fwrite(s->p,sizeof(uint),nl+1,fp);
233         wr += fwrite(s->ss,sizeof(ushort),s->ss_len,fp);
234         wr += fwrite(s->sl,sizeof(uint),s->sl_len,fp);
235         if(wr!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5) 
236                 return 1;
237         return 0;
238 }
239
240 int selectd2_load(selectd2 * s, FILE * fp) {
241         uint rd = 0;
242         rd += fread(&s->n,sizeof(uint),1,fp);
243         rd += fread(&s->m,sizeof(uint),1,fp);
244         rd += fread(&s->size,sizeof(uint),1,fp);
245         rd += fread(&s->ss_len,sizeof(uint),1,fp);
246         rd += fread(&s->sl_len,sizeof(uint),1,fp);
247         s->buf = new uchar[(s->n+7)/8+1];
248         rd += fread(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
249         uint nl = (s->m-1) / L + 1;
250         s->lp = new uint[nl+1];
251         rd += fread(s->lp,sizeof(uint),nl+1,fp);
252         s->p = new uint[nl+1];
253         rd += fread(s->p,sizeof(uint),nl+1,fp);
254         s->ss = new ushort[s->ss_len];
255         rd += fread(s->ss,sizeof(ushort),s->ss_len,fp);
256         s->sl = new uint[s->sl_len];
257         rd += fread(s->sl,sizeof(uint),s->sl_len,fp);
258         if(rd!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5) 
259                 return 1;
260         return 0;
261 }
262
263 void selectd2_free(selectd2 * s) {
264         //delete [] s->buf;
265         delete [] s->lp;
266         delete [] s->p;
267         delete [] s->ss;
268         delete [] s->sl;
269 }
270
271 int selectd2_construct(selectd2 *select, int n, uchar *buf) {
272   int i,m;
273   int nl;
274   int p,pp;
275   int il,is,ml,ms;
276   int r;
277   uint *s;
278
279   make___selecttbl();
280
281   if (L/LLL == 0) {
282     printf("ERROR: L=%d LLL=%d\n",L,LLL);
283     exit(1);
284   }
285
286   m = 0;
287   for (i=0; i<n; i++) m += __getbit2(buf,i);
288   select->n = n;
289   select->m = m;
290   //printf("n=%d m=%d\n",n,m);
291
292   select->buf = buf;
293
294   s = new uint[m];
295   m = 0;
296   for (i=0; i<n; i++) {
297     if (__getbit2(buf,i)) {
298       m++;
299       s[m-1] = i;
300     }
301   }
302
303   nl = (m-1) / L + 1;
304   select->size = 0; //ignoring buf, shared with selects3
305   select->lp = new uint[nl+1];
306         for(int k=0;k<nl+1;k++) select->lp[k]=0;
307   select->size += (nl+1)*sizeof(uint);
308   select->p = new uint[nl+1];
309         for(int k=0;k<nl+1;k++) select->p[k]=0;
310   select->size += (nl+1)*sizeof(uint);
311
312   for (r = 0; r < 2; r++) {
313     ml = ms = 0;
314     for (il = 0; il < nl; il++) {
315       pp = s[il*L];
316       select->lp[il] = pp;
317       i = min((il+1)*L-1,m-1);
318       p = s[i];
319       //printf("%d ",p-pp);
320       if (p - pp >= LL) {
321         if (r == 1) {
322           for (is = 0; is < L; is++) {
323             if (il*L+is >= m) break;
324             select->sl[ml*L+is] = s[il*L+is];
325           }
326         }
327         select->p[il] = -((ml<<logL)+1);
328         ml++;
329       }
330       else {
331         if (r == 1) {
332           for (is = 0; is < L/LLL; is++) {
333             if (il*L+is*LLL >= m) break;
334             select->ss[ms*(L/LLL)+is] = s[il*L+is*LLL] - pp;
335           }
336         }
337         select->p[il] = ms << (logL-logLLL);
338         ms++;
339       }
340     }
341     if (r == 0) {
342       select->sl = new uint[ml*L+1];
343                         for(int k=0;k<ml*L+1;k++) select->sl[k]=0;
344       select->size += sizeof(uint)*(ml*L+1);
345                         select->sl_len = ml*L+1;
346       select->ss = new ushort[ms*(L/LLL)+1];
347                         for(int k=0;k<ms*(L/LLL)+1;k++) select->ss[k]=0;
348                         select->ss_len = ms*(L/LLL)+1;
349       select->size += sizeof(ushort)*(ms*(L/LLL)+1);
350     }
351   }
352   delete [] s;
353         return 0;
354 }
355
356
357 int selectd2_select(selectd2 *select, int i,int f) {
358   int p,r;
359   int il;
360   int rr;
361   uchar *q;
362
363   if (i == 0) return -1;
364
365   #if 0
366   if (i > select->m) {
367     printf("ERROR: m=%d i=%d\n",select->m,i);
368     exit(1);
369   }
370   #endif
371
372   i--;
373
374   il = select->p[i>>logL];
375   if (il < 0) {
376     il = -il-1;
377     //p = select->sl[(il<<logL)+(i & (L-1))];
378     p = select->sl[il+(i & (L-1))];
379   }
380   else {
381     p = select->lp[i>>logL];
382     //p += select->ss[(il<<(logL-logLLL))+(i & (L-1))/LLL];
383     p += select->ss[il+((i & (L-1))>>logLLL)];
384     r = i - (i & (LLL-1));
385
386     q = &(select->buf[p>>3]);
387
388     if (f == 1) {
389       rr = p & (8-1);
390       r -= _popCount[*q >> (8-1-rr)];
391       //p = p - rr;
392
393       while (1) {
394         rr = _popCount[*q];
395         if (r + rr >= i) break;
396         r += rr;
397         //p += 8;
398         q++;
399       }
400       p = (q - select->buf) << 3;
401       p += __selecttbl[((i-r-1)<<8)+(*q)];
402     }
403     else {
404       rr = p & (8-1);
405       r -= _popCount[(*q ^ 0xff) >> (8-1-rr)];
406       //p = p - rr;
407
408       while (1) {
409         rr = _popCount[*q ^ 0xff];
410         if (r + rr >= i) break;
411         r += rr;
412         //p += 8;
413         q++;
414       }
415       p = (q - select->buf) << 3;
416       p += __selecttbl[((i-r-1)<<8)+(*q ^ 0xff)];
417     }
418   }
419   return p;
420 }
421
422
423 int selectd2_select2(selectd2 *select, int i,int f, int *st, int *en) {
424   int p,r,p2;
425   int il;
426   int rr;
427   uchar *q;
428
429   if (i == 0) {
430     *st = -1;
431     return -1;
432   }
433
434   #if 0
435   if (i > select->m) {
436     printf("ERROR: m=%d i=%d\n",select->m,i);
437     exit(1);
438   }
439   #endif
440
441   i--;
442
443   il = select->p[i>>logL];
444   if (il < 0) {
445     il = -il-1;
446     //p = select->sl[(il<<logL)+(i & (L-1))];
447     p = select->sl[il+(i & (L-1))];
448
449     if ((i>>logL) == ((i+1)>>logL)) {
450       p2 = select->sl[il+((i+1) & (L-1))];
451     }
452     else {
453       p2 = selectd2_select(select,i+2,f);
454     }
455   }
456   else {
457     p = select->lp[i>>logL];
458     //p += select->ss[(il<<(logL-logLLL))+(i & (L-1))/LLL];
459     p += select->ss[il+((i & (L-1))>>logLLL)];
460     r = i - (i & (LLL-1));
461
462     q = &(select->buf[p>>3]);
463
464     if (f == 1) {
465       rr = p & (8-1);
466       r -= _popCount[*q >> (8-1-rr)];
467       //p = p - rr;
468
469       while (1) {
470         rr = _popCount[*q];
471         if (r + rr >= i) break;
472         r += rr;
473         //p += 8;
474         q++;
475       }
476       p = (q - select->buf) << 3;
477       p += __selecttbl[((i-r-1)<<8)+(*q)];
478
479       if ((i>>logL) == ((i+1)>>logL)) {
480         i++;
481         while (1) {
482           rr = _popCount[*q];
483           if (r + rr >= i) break;
484           r += rr;
485           q++;
486         }
487         p2 = (q - select->buf) << 3;
488         p2 += __selecttbl[((i-r-1)<<8)+(*q)];
489       }
490       else {
491         p2 = selectd2_select(select,i+2,f);
492       }
493
494     }
495     else {
496       rr = p & (8-1);
497       r -= _popCount[(*q ^ 0xff) >> (8-1-rr)];
498       //p = p - rr;
499
500       while (1) {
501         rr = _popCount[*q ^ 0xff];
502         if (r + rr >= i) break;
503         r += rr;
504         //p += 8;
505         q++;
506       }
507       p = (q - select->buf) << 3;
508       p += __selecttbl[((i-r-1)<<8)+(*q ^ 0xff)];
509
510       if ((i>>logL) == ((i+1)>>logL)) {
511         i++;
512         while (1) {
513           rr = _popCount[*q ^ 0xff];
514           if (r + rr >= i) break;
515           r += rr;
516           q++;
517         }
518         p2 = (q - select->buf) << 3;
519         p2 += __selecttbl[((i-r-1)<<8)+(*q ^ 0xff)];
520       }
521       else {
522         p2 = selectd2_select(select,i+2,f);
523       }
524     }
525   }
526   *st = p;
527   *en = p2;
528   return p;
529 }
530
531
532 int selects3_save(selects3 * s, FILE * fp) {
533         uint wr = 0;
534         wr += fwrite(&s->n,sizeof(uint),1,fp);
535         wr += fwrite(&s->m,sizeof(uint),1,fp);
536         wr += fwrite(&s->size,sizeof(uint),1,fp);
537         wr += fwrite(&s->d,sizeof(uint),1,fp);
538         wr += fwrite(&s->hi_len,sizeof(uint),1,fp);
539         wr += fwrite(&s->low_len,sizeof(uint),1,fp);
540         wr += fwrite(s->hi,sizeof(uchar),s->hi_len,fp);
541         wr += fwrite(s->low,sizeof(uint),s->low_len,fp);
542         if(wr!=(6+s->hi_len+s->low_len))
543                 return 1;
544         if(selectd2_save(s->sd0,fp)) return 2;
545         if(selectd2_save(s->sd1,fp)) return 3;
546         return 0;
547 }
548
549 int selects3_load(selects3 * s, FILE * fp) {
550         uint rd = 0;
551         rd += fread(&s->n,sizeof(uint),1,fp);
552         rd += fread(&s->m,sizeof(uint),1,fp);
553         rd += fread(&s->size,sizeof(uint),1,fp);
554         rd += fread(&s->d,sizeof(uint),1,fp);
555         rd += fread(&s->hi_len,sizeof(uint),1,fp);
556         rd += fread(&s->low_len,sizeof(uint),1,fp);
557         s->hi = new uchar[s->hi_len];
558         rd += fread(s->hi,sizeof(uchar),s->hi_len,fp);
559         s->low = new uint[s->low_len];
560         rd += fread(s->low,sizeof(uint),s->low_len,fp);
561         if(rd!=(6+s->hi_len+s->low_len))
562                 return 1;
563         s->sd0 = new selectd2;
564         if(selectd2_load(s->sd0,fp)) return 2;
565         s->sd1 = new selectd2;
566         if(selectd2_load(s->sd1,fp)) return 3;
567         delete [] s->sd0->buf;
568         delete [] s->sd1->buf;
569         s->sd0->buf = s->hi;
570         s->sd1->buf = s->hi;
571         return 0;
572 }
573
574 void selects3_free(selects3 * s) {
575         delete [] s->hi;
576         delete [] s->low;
577         //delete [] s->sd0->buf;
578         selectd2_free(s->sd0);
579         delete s->sd0;
580         selectd2_free(s->sd1);
581         delete s->sd1;
582 }
583
584 int selects3_construct(selects3 *select, int n, uint *buf) {
585   int i,m;
586   int d,mm;
587   uint *low;
588   uchar *buf2;
589   selectd2 *sd0,*sd1;
590
591   m = 0;
592   for (i=0; i<n; i++) m += __getbit(buf,i);
593   select->n = n;
594   select->m = m;
595
596   if (m == 0) return 0;
597
598   mm = m;
599   d = 0;
600   while (mm < n) {
601     mm <<= 1;
602     d++;
603   }
604
605   select->d = d;
606
607   buf2 = new uchar[(2*m+8-1)/8+1];
608         for(int k=0;k<(2*m+8-1)/8+1;k++) buf2[k]=0;
609         select->hi_len = (2*m+8-1)/8+1;
610   low = new uint[(d*m+PBS-1)/PBS+1];
611         for(uint k=0;k<(d*m+PBS-1)/PBS+1;k++) low[k]=0;
612         select->low_len = (d*m+PBS-1)/PBS+1;
613
614   select->hi = buf2;
615   select->low = low;
616   select->size = sizeof(uchar)*((2*m+8-1)/8+1) + sizeof(uint)*((d*m+PBS-1)/PBS+1);
617
618   for (i=0; i<m*2; i++) __setbit2(buf2,i,0);
619
620   m = 0;
621   for (i=0; i<n; i++) {
622     if (__getbit(buf,i)) {
623       __setbit2(buf2,(i>>d)+m,1);
624       __setbits(low,m*d,d,i & ((1<<d)-1));
625       m++;
626     }
627   }
628
629   sd1 = new selectd2;
630   sd0 = new selectd2;
631   select->size += 2*sizeof(selectd2);
632
633   selectd2_construct(sd1,m*2,buf2);
634   select->sd1 = sd1;
635
636   for (i=0; i<m*2; i++) __setbit2(buf2,i,1-__getbit2(buf2,i));
637   selectd2_construct(sd0,m*2,buf2);
638   select->sd0 = sd0;
639
640   for (i=0; i<m*2; i++) __setbit2(buf2,i,1-__getbit2(buf2,i));
641         return 0;
642 }
643
644
645 int selects3_select(selects3 *select, int i) {
646   int d,x;
647
648   #if 1
649   if (i > select->m) {
650     printf("ERROR: m=%d i=%d\n",select->m,i);
651     exit(1);
652   }
653   #endif
654
655   if (i == 0) return -1;
656
657   d = select->d;
658
659   x = selectd2_select(select->sd1,i,1) - (i-1);
660   x <<= d;
661   x += __getbits(select->low,(i-1)*d,d);
662   return x;
663
664 }
665
666 int selects3_rank(selects3 *select, int i) {
667   int d,x,w,y;
668   int r,j;
669   int z,ii;
670   uint *q;
671
672   d = select->d;
673   q = select->low;
674
675   ii = i>>d;
676   y = selectd2_select(select->sd0,ii,0)+1;
677   //  selectd2_select2(select->sd0,ii,0,&y1,&y2);
678   //y1++;  y2++;
679   //printf("y %d y1 %d  %d\n",y,y1,y2-y1);
680
681   x = y - ii;
682
683   j = i - (ii<<d);
684
685   r = y & 7;
686   y >>= 3;
687   z = select->hi[y];
688   while (1) {
689     if (((z << r) & 0x80) == 0) break;
690     w = __getbits(q,x*d,d);
691     if (w >= j) {
692       if (w == j) x++;
693       break;
694     }
695     x++;
696     r++;
697     if (r == 8) {
698       r = 0;
699       y++;
700       z = select->hi[y];
701     }
702   }
703
704   return x;
705 }