...
[SXSI/XMLTree.git] / libcds / src / static_bitsequence / sdarray.cpp
index f904aa3..96810fe 100644 (file)
@@ -39,7 +39,7 @@ int __blog(int x) {
 
 int __setbit(uint *B, int i,int x) {
   int j,l;
-//printf("%u\n",D);
+  //printf("%u\n",D);
   j = i / D;
   l = i % D;
   if (x==0) B[j] &= (~(1<<(D-1-l)));
@@ -177,6 +177,7 @@ void make___selecttbl(void) {
   }
 }
 
+
 unsigned int __popCount(uint x) {
   uint r;
   #if 0
@@ -219,55 +220,59 @@ unsigned int __popCount8(uint x) {
   return r;
 }
 
+
 int selectd2_save(selectd2 * s, FILE * fp) {
-       uint wr = 0;
-       wr += fwrite(&s->n,sizeof(uint),1,fp);
-       wr += fwrite(&s->m,sizeof(uint),1,fp);
-       wr += fwrite(&s->size,sizeof(uint),1,fp);
-       wr += fwrite(&s->ss_len,sizeof(uint),1,fp);
-       wr += fwrite(&s->sl_len,sizeof(uint),1,fp);
-       wr += fwrite(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
-       uint nl = (s->m-1) / L + 1;
-       wr += fwrite(s->lp,sizeof(uint),nl+1,fp);
-       wr += fwrite(s->p,sizeof(uint),nl+1,fp);
-       wr += fwrite(s->ss,sizeof(ushort),s->ss_len,fp);
-       wr += fwrite(s->sl,sizeof(uint),s->sl_len,fp);
-       if(wr!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5) 
-               return 1;
-       return 0;
+  uint wr = 0;
+  wr += fwrite(&s->n,sizeof(uint),1,fp);
+  wr += fwrite(&s->m,sizeof(uint),1,fp);
+  wr += fwrite(&s->size,sizeof(uint),1,fp);
+  wr += fwrite(&s->ss_len,sizeof(uint),1,fp);
+  wr += fwrite(&s->sl_len,sizeof(uint),1,fp);
+  wr += fwrite(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
+  uint nl = (s->m-1) / L + 1;
+  wr += fwrite(s->lp,sizeof(uint),nl+1,fp);
+  wr += fwrite(s->p,sizeof(uint),nl+1,fp);
+  wr += fwrite(s->ss,sizeof(ushort),s->ss_len,fp);
+  wr += fwrite(s->sl,sizeof(uint),s->sl_len,fp);
+  if(wr!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5)
+    return 1;
+  return 0;
 }
 
+
 int selectd2_load(selectd2 * s, FILE * fp) {
-       uint rd = 0;
-       rd += fread(&s->n,sizeof(uint),1,fp);
-       rd += fread(&s->m,sizeof(uint),1,fp);
-       rd += fread(&s->size,sizeof(uint),1,fp);
-       rd += fread(&s->ss_len,sizeof(uint),1,fp);
-       rd += fread(&s->sl_len,sizeof(uint),1,fp);
-       s->buf = new uchar[(s->n+7)/8+1];
-       rd += fread(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
-       uint nl = (s->m-1) / L + 1;
-       s->lp = new uint[nl+1];
-       rd += fread(s->lp,sizeof(uint),nl+1,fp);
-       s->p = new uint[nl+1];
-       rd += fread(s->p,sizeof(uint),nl+1,fp);
-       s->ss = new ushort[s->ss_len];
-       rd += fread(s->ss,sizeof(ushort),s->ss_len,fp);
-       s->sl = new uint[s->sl_len];
-       rd += fread(s->sl,sizeof(uint),s->sl_len,fp);
-       if(rd!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5) 
-               return 1;
-       return 0;
+  uint rd = 0;
+  rd += fread(&s->n,sizeof(uint),1,fp);
+  rd += fread(&s->m,sizeof(uint),1,fp);
+  rd += fread(&s->size,sizeof(uint),1,fp);
+  rd += fread(&s->ss_len,sizeof(uint),1,fp);
+  rd += fread(&s->sl_len,sizeof(uint),1,fp);
+  s->buf = new uchar[(s->n+7)/8+1];
+  rd += fread(s->buf,sizeof(uchar),(s->n+7)/8+1,fp);
+  uint nl = (s->m-1) / L + 1;
+  s->lp = new uint[nl+1];
+  rd += fread(s->lp,sizeof(uint),nl+1,fp);
+  s->p = new uint[nl+1];
+  rd += fread(s->p,sizeof(uint),nl+1,fp);
+  s->ss = new ushort[s->ss_len];
+  rd += fread(s->ss,sizeof(ushort),s->ss_len,fp);
+  s->sl = new uint[s->sl_len];
+  rd += fread(s->sl,sizeof(uint),s->sl_len,fp);
+  if(rd!=s->sl_len+s->ss_len+2*(nl+1)+(s->n+7)/8+1+5)
+    return 1;
+  return 0;
 }
 
+
 void selectd2_free(selectd2 * s) {
-       //delete [] s->buf;
-       delete [] s->lp;
-       delete [] s->p;
-       delete [] s->ss;
-       delete [] s->sl;
+  //delete [] s->buf;
+  delete [] s->lp;
+  delete [] s->p;
+  delete [] s->ss;
+  delete [] s->sl;
 }
 
+
 int selectd2_construct(selectd2 *select, int n, uchar *buf) {
   int i,m;
   int nl;
@@ -301,12 +306,12 @@ int selectd2_construct(selectd2 *select, int n, uchar *buf) {
   }
 
   nl = (m-1) / L + 1;
-  select->size = 0; //ignoring buf, shared with selects3
+  select->size = 0;              //ignoring buf, shared with selects3
   select->lp = new uint[nl+1];
-       for(int k=0;k<nl+1;k++) select->lp[k]=0;
+  for(int k=0;k<nl+1;k++) select->lp[k]=0;
   select->size += (nl+1)*sizeof(uint);
   select->p = new uint[nl+1];
-       for(int k=0;k<nl+1;k++) select->p[k]=0;
+  for(int k=0;k<nl+1;k++) select->p[k]=0;
   select->size += (nl+1)*sizeof(uint);
 
   for (r = 0; r < 2; r++) {
@@ -340,17 +345,17 @@ int selectd2_construct(selectd2 *select, int n, uchar *buf) {
     }
     if (r == 0) {
       select->sl = new uint[ml*L+1];
-                       for(int k=0;k<ml*L+1;k++) select->sl[k]=0;
+      for(int k=0;k<ml*L+1;k++) select->sl[k]=0;
       select->size += sizeof(uint)*(ml*L+1);
-                       select->sl_len = ml*L+1;
+      select->sl_len = ml*L+1;
       select->ss = new ushort[ms*(L/LLL)+1];
-                       for(int k=0;k<ms*(L/LLL)+1;k++) select->ss[k]=0;
-                       select->ss_len = ms*(L/LLL)+1;
+      for(int k=0;k<ms*(L/LLL)+1;k++) select->ss[k]=0;
+      select->ss_len = ms*(L/LLL)+1;
       select->size += sizeof(ushort)*(ms*(L/LLL)+1);
     }
   }
   delete [] s;
-       return 0;
+  return 0;
 }
 
 
@@ -530,57 +535,60 @@ int selectd2_select2(selectd2 *select, int i,int f, int *st, int *en) {
 
 
 int selects3_save(selects3 * s, FILE * fp) {
-       uint wr = 0;
-       wr += fwrite(&s->n,sizeof(uint),1,fp);
-       wr += fwrite(&s->m,sizeof(uint),1,fp);
-       wr += fwrite(&s->size,sizeof(uint),1,fp);
-        wr += fwrite(&s->d,sizeof(uint),1,fp);
-       wr += fwrite(&s->hi_len,sizeof(uint),1,fp);
-        wr += fwrite(&s->low_len,sizeof(uint),1,fp);
-       wr += fwrite(s->hi,sizeof(uchar),s->hi_len,fp);
-       wr += fwrite(s->low,sizeof(uint),s->low_len,fp);
-       if(wr!=(6+s->hi_len+s->low_len))
-               return 1;
-       if(selectd2_save(s->sd0,fp)) return 2;
-       if(selectd2_save(s->sd1,fp)) return 3;
-       return 0;
+  uint wr = 0;
+  wr += fwrite(&s->n,sizeof(uint),1,fp);
+  wr += fwrite(&s->m,sizeof(uint),1,fp);
+  wr += fwrite(&s->size,sizeof(uint),1,fp);
+  wr += fwrite(&s->d,sizeof(uint),1,fp);
+  wr += fwrite(&s->hi_len,sizeof(uint),1,fp);
+  wr += fwrite(&s->low_len,sizeof(uint),1,fp);
+  wr += fwrite(s->hi,sizeof(uchar),s->hi_len,fp);
+  wr += fwrite(s->low,sizeof(uint),s->low_len,fp);
+  if(wr!=(6+s->hi_len+s->low_len))
+    return 1;
+  if(selectd2_save(s->sd0,fp)) return 2;
+  if(selectd2_save(s->sd1,fp)) return 3;
+  return 0;
 }
 
+
 int selects3_load(selects3 * s, FILE * fp) {
-       uint rd = 0;
-       rd += fread(&s->n,sizeof(uint),1,fp);
-       rd += fread(&s->m,sizeof(uint),1,fp);
-       rd += fread(&s->size,sizeof(uint),1,fp);
-       rd += fread(&s->d,sizeof(uint),1,fp);
-       rd += fread(&s->hi_len,sizeof(uint),1,fp);
-       rd += fread(&s->low_len,sizeof(uint),1,fp);
-       s->hi = new uchar[s->hi_len];
-       rd += fread(s->hi,sizeof(uchar),s->hi_len,fp);
-       s->low = new uint[s->low_len];
-       rd += fread(s->low,sizeof(uint),s->low_len,fp);
-       if(rd!=(6+s->hi_len+s->low_len))
-               return 1;
-       s->sd0 = new selectd2;
-       if(selectd2_load(s->sd0,fp)) return 2;
-       s->sd1 = new selectd2;
-       if(selectd2_load(s->sd1,fp)) return 3;
-       delete [] s->sd0->buf;
-       delete [] s->sd1->buf;
-       s->sd0->buf = s->hi;
-       s->sd1->buf = s->hi;
-       return 0;
+  uint rd = 0;
+  rd += fread(&s->n,sizeof(uint),1,fp);
+  rd += fread(&s->m,sizeof(uint),1,fp);
+  rd += fread(&s->size,sizeof(uint),1,fp);
+  rd += fread(&s->d,sizeof(uint),1,fp);
+  rd += fread(&s->hi_len,sizeof(uint),1,fp);
+  rd += fread(&s->low_len,sizeof(uint),1,fp);
+  s->hi = new uchar[s->hi_len];
+  rd += fread(s->hi,sizeof(uchar),s->hi_len,fp);
+  s->low = new uint[s->low_len];
+  rd += fread(s->low,sizeof(uint),s->low_len,fp);
+  if(rd!=(6+s->hi_len+s->low_len))
+    return 1;
+  s->sd0 = new selectd2;
+  if(selectd2_load(s->sd0,fp)) return 2;
+  s->sd1 = new selectd2;
+  if(selectd2_load(s->sd1,fp)) return 3;
+  delete [] s->sd0->buf;
+  delete [] s->sd1->buf;
+  s->sd0->buf = s->hi;
+  s->sd1->buf = s->hi;
+  return 0;
 }
 
+
 void selects3_free(selects3 * s) {
-       delete [] s->hi;
-       delete [] s->low;
-       //delete [] s->sd0->buf;
-       selectd2_free(s->sd0);
-       delete s->sd0;
-       selectd2_free(s->sd1);
-       delete s->sd1;
+  delete [] s->hi;
+  delete [] s->low;
+  //delete [] s->sd0->buf;
+  selectd2_free(s->sd0);
+  delete s->sd0;
+  selectd2_free(s->sd1);
+  delete s->sd1;
 }
 
+
 int selects3_construct(selects3 *select, int n, uint *buf) {
   int i,m;
   int d,mm;
@@ -605,11 +613,11 @@ int selects3_construct(selects3 *select, int n, uint *buf) {
   select->d = d;
 
   buf2 = new uchar[(2*m+8-1)/8+1];
-       for(int k=0;k<(2*m+8-1)/8+1;k++) buf2[k]=0;
-       select->hi_len = (2*m+8-1)/8+1;
+  for(int k=0;k<(2*m+8-1)/8+1;k++) buf2[k]=0;
+  select->hi_len = (2*m+8-1)/8+1;
   low = new uint[(d*m+PBS-1)/PBS+1];
-       for(uint k=0;k<(d*m+PBS-1)/PBS+1;k++) low[k]=0;
-       select->low_len = (d*m+PBS-1)/PBS+1;
+  for(uint k=0;k<(d*m+PBS-1)/PBS+1;k++) low[k]=0;
+  select->low_len = (d*m+PBS-1)/PBS+1;
 
   select->hi = buf2;
   select->low = low;
@@ -638,14 +646,17 @@ int selects3_construct(selects3 *select, int n, uint *buf) {
   select->sd0 = sd0;
 
   for (i=0; i<m*2; i++) __setbit2(buf2,i,1-__getbit2(buf2,i));
-       return 0;
+  return 0;
 }
 
+//selects3 * lasts3=NULL;
+//int lasti=0;
+//int lasts=0;
 
 int selects3_select(selects3 *select, int i) {
   int d,x;
 
-  #if 1
+  #if 0
   if (i > select->m) {
     printf("ERROR: m=%d i=%d\n",select->m,i);
     exit(1);
@@ -655,12 +666,81 @@ int selects3_select(selects3 *select, int i) {
   if (i == 0) return -1;
 
   d = select->d;
-
-  x = selectd2_select(select->sd1,i,1) - (i-1);
-  x <<= d;
+       if(select->lasti==(uint)i-1) {
+               while(!__getbit2(select->sd1->buf,++select->lasts));
+       } 
+       else {
+         select->lasts = selectd2_select(select->sd1,i,1);
+       }
+       select->lasti = i;
+       //lasts3 = select;
+  x = (select->lasts-(i-1)) << d;
   x += __getbits(select->low,(i-1)*d,d);
   return x;
+}
+
 
+int selects3_selectnext(selects3 *select, int i) {
+       return selects3_select(select,selects3_rank(select,i)+1);
+  int d,x,w,y;
+  int r,j;
+  int z,ii;
+  uint *q;
+  d = select->d;
+  q = select->low;
+  ii = i>>d;
+  y = selectd2_select(select->sd0,ii,0)+1;
+       int k2=y-ii;
+  x = y - ii;
+       int x_orig = x;
+  j = i - (ii<<d);
+  r = y & 7;
+  y >>= 3;
+  z = select->hi[y];
+  while (1) {
+    if (((z << r) & 0x80) == 0) {
+                       if(x!=x_orig) k2++;
+                       break;
+               }
+    w = __getbits(q,x*d,d);
+    if (w >= j) {
+      if (w == j) {
+                               if(__getbit2(select->hi,(8*y+r))) k2++;
+                               x++;
+                               r++;
+                       }
+      break;
+    }
+    x++;
+    r++;
+               if(__getbit2(select->hi,(8*y+r))) k2++;
+    if (r == 8) {
+      r = 0;
+      y++;
+      z = select->hi[y];
+    }
+  }
+       if(x==select->m)
+               return (uint)-1;
+       int c=8*y+r;
+       int fin=0;
+       for(int kk=0;kk<8-r;kk++) {
+               if(__getbit2(select->hi,c)) {
+                       fin=1;
+                       break;
+               }
+               c++;
+       }
+       if(!fin) {
+               int pp = c/8;
+               while(select->hi[pp]==0) {
+                       pp++;
+                       c+=8;
+               }
+               while(!__getbit2(select->hi,c)) c++;
+       }
+       c -= (k2);
+  return __getbits(q,x*d,d)+((c)<<d);
 }
 
 int selects3_rank(selects3 *select, int i) {
@@ -701,5 +781,6 @@ int selects3_rank(selects3 *select, int i) {
     }
   }
 
-  return x;
+       return x;
 }
+