Improve sdarray select_next function
authorKim Nguyễn <kn@lri.fr>
Fri, 11 May 2012 11:41:38 +0000 (13:41 +0200)
committerKim Nguyễn <kn@lri.fr>
Fri, 11 May 2012 11:41:38 +0000 (13:41 +0200)
src/static_bitsequence/sdarray.cpp

index 0d43d81..fba36bf 100644 (file)
@@ -1,4 +1,3 @@
-
 #include <sdarray.h>
 using std::min;
 using std::max;
@@ -89,7 +88,7 @@ int __getbit(uint *B, int i) {
 }
 
 
-int __getbit2(uchar *B, int i) {
+static int __getbit2(uchar *B, int i) {
   int j,l;
 
   //j = i / D;
@@ -101,18 +100,22 @@ int __getbit2(uchar *B, int i) {
 
 
 #if 1
-uint __getbits(uint *B, int i, int d) {
+uint __getbits_aux(uint *B, int i, int d) {
   qword x,z;
-
+  int j = i;
   B += (i >> logD);
   i &= (D-1);
-  if (i+d <= 2*D) {
+  if (d==8 && (j & 7) == 0) {
+    i = (24 - i) >> 3;
+    x = (uint) (((unsigned char*) B)[i]);
+  } else if (i+d <= 2*D) {
     x = (((qword)B[0]) << D) + B[1];
     x <<= i;
     x >>= (D*2-1-d);
     x >>= 1;
   }
   else {
+    fprintf(stderr, "Warning: %d, %d\n", D, d);
     x = (((qword)B[0])<<D)+B[1];
     z = (x<<D)+B[2];
     x <<= i;
@@ -125,6 +128,23 @@ uint __getbits(uint *B, int i, int d) {
 
   return x;
 }
+
+uint __getbits(uint *B, int i, int d)
+{
+  ulong x;
+//  uint y;
+  // y = __getbits_aux(B, i,d);
+  B += (i >> logD);
+  i &= (D-1);
+  x = ((ulong *) B)[0];
+  x = (x << 32)|(x >> 32);
+  x <<= i;
+  x >>= 2*D - d;
+//  fprintf(stderr, "slow: %i, fast: %i\n",
+//       y, (uint) x);
+  return x;
+}
+
 #endif
 
 #if 0
@@ -391,21 +411,50 @@ int selectd2_construct(selectd2 *select, int n, uchar *buf) {
 }
 
 
-int selectd2_select(selectd2 *select, int i,int f) {
+int selectd2_select1(selectd2 *select, int i) {
   int p,r;
   int il;
   int rr;
   uchar *q;
+  if (i <= 0) return -1;
+  i--;
 
-  if (i == 0) return -1;
+  il = select->p[i>>logL];
+  if (il < 0) {
+    il = -il-1;
+    //p = select->sl[(il<<logL)+(i & (L-1))];
+    p = select->sl[il+(i & (L-1))];
+  }
+  else {
+    p = select->lp[i>>logL];
+    p += select->ss[il+((i & (L-1))>>logLLL)];
+    r = i - (i & (LLL-1));
 
-  #if 0
-  if (i > select->m) {
-    printf("ERROR: m=%d i=%d\n",select->m,i);
-    exit(1);
+    q = &(select->buf[p>>3]);
+
+    rr = p & (8-1);
+    r -= _fast_popcount(*q >> (8-1-rr));
+    
+    while (1) {
+        //rr = _popCount[*q];
+      rr = _fast_popcount(*q);
+      if (r + rr >= i) break;
+      r += rr;
+      //p += 8;
+      q++;
+    }
+      p = (q - select->buf) << 3;
+      p += __selecttbl[((i-r-1)<<8)+(*q)];
   }
-  #endif
+  return p;
+}
 
+int selectd2_select0(selectd2 *select, int i) {
+  int p,r;
+  int il;
+  int rr;
+  uchar *q;
+  if (i <= 0) return -1;
   i--;
 
   il = select->p[i>>logL];
@@ -422,44 +471,29 @@ int selectd2_select(selectd2 *select, int i,int f) {
 
     q = &(select->buf[p>>3]);
 
-    if (f == 1) {
-      rr = p & (8-1);
-      //r -= _popCount[*q >> (8-1-rr)];
-      r -= _fast_popcount(*q >> (8-1-rr));
-      //p = p - rr;
+    rr = p & (8-1);
 
-      while (1) {
-        //rr = _popCount[*q];
-       rr = _fast_popcount(*q);
-        if (r + rr >= i) break;
-        r += rr;
-        //p += 8;
-        q++;
-      }
-      p = (q - select->buf) << 3;
-      p += __selecttbl[((i-r-1)<<8)+(*q)];
-    }
-    else {
-      rr = p & (8-1);
-      //r -= _popCount[(*q ^ 0xff) >> (8-1-rr)];
-      r -= _fast_popcount((*q ^ 0xff) >> (8-1-rr));
-      //p = p - rr;
+    r -= _fast_popcount((*q ^ 0xff) >> (8-1-rr));
 
-      while (1) {
-        //rr = _popCount[*q ^ 0xff];
-       rr = _fast_popcount(*q ^ 0xff);
-        if (r + rr >= i) break;
-        r += rr;
-        //p += 8;
-        q++;
-      }
-      p = (q - select->buf) << 3;
-      p += __selecttbl[((i-r-1)<<8)+(*q ^ 0xff)];
+    while (1) {
+      //rr = _popCount[*q ^ 0xff];
+      rr = _fast_popcount(*q ^ 0xff);
+      if (r + rr >= i) break;
+      r += rr;
+      //p += 8;
+      q++;
     }
+    p = (q - select->buf) << 3;
+    p += __selecttbl[((i-r-1)<<8)+(*q ^ 0xff)];
   }
   return p;
 }
 
+int selectd2_select(selectd2 *select, int i,int f) {
+  return f ? selectd2_select1(select, i) :
+    selectd2_select0(select, i);
+}
+
 
 int selectd2_select2(selectd2 *select, int i,int f, int *st, int *en) {
   int p,r,p2;
@@ -710,80 +744,93 @@ int selects3_select(selects3 *select, int i) {
   d = select->d;
        /*if(select->lasti==(uint)i-1) {
                while(!__getbit2(select->sd1->buf,++select->lasts));
-       } 
+       }
        else {
          select->lasts = selectd2_select(select->sd1,i,1);
        }
        select->lasti = i;
        //lasts3 = select; */
-       x = selectd2_select(select->sd1,i,1) - (i-1);
+       x = selectd2_select1(select->sd1,i) - (i-1);
   //x = (select->lasts-(i-1)) << d;
   x <<= d;
   x += __getbits(select->low,(i-1)*d,d);
   return x;
 }
 
+void pr_byte(FILE* fp, uchar b)
+{
+  uchar * buff = &b;
+  for(int i = 0; i < 8; i++){
+    fprintf(stderr, "%i", __getbit2(buff, i));
+  };
+}
 
 int selects3_selectnext(selects3 *select, int i) {
        //return selects3_select(select,selects3_rank(select,i)+1);
   int d,x,w,y;
   int r,j;
   int z,ii;
+  int xoffset;
   uint *q;
   d = select->d;
   q = select->low;
   ii = i>>d;
-  y = selectd2_select(select->sd0,ii,0)+1;
-       int k2=y-ii;
+  y = selectd2_select0(select->sd0,ii)+1;
+  int k2=y-ii;
   x = y - ii;
-       int x_orig = x;
+  int x_orig = x;
   j = i - (ii<<d);
   r = y & 7;
   y >>= 3;
   z = select->hi[y];
+  xoffset = x * d;
   while (1) {
     if (((z << r) & 0x80) == 0) {
-                       if(x!=x_orig) k2++;
-                       break;
-               }
-    w = __getbits(q,x*d,d);
+      k2 += (x!=x_orig);
+      break;
+    };
+
+    w = __getbits(q,xoffset,d);
     if (w >= j) {
-      if (w == j) {
-                               if(__getbit2(select->hi,(8*y+r))) k2++;
-                               x++;
-                               r++;
-                       }
+      bool t1 = (w == j);
+      bool t2 = (__getbit2(select->hi,((y << 3)+r)));
+      if (t2)  k2 += (t1);
+      x  += t1;
+      r  += t1;
       break;
-    }
+    };
+
     x++;
     r++;
-               if(__getbit2(select->hi,(8*y+r))) k2++;
+    xoffset += d;
+    if(__getbit2(select->hi,( (y << 3)+r))) k2++;
     if (r == 8) {
       r = 0;
       y++;
       z = select->hi[y];
     }
-  }
-       if(x==select->m)
-               return (uint)-1;
-       int c=8*y+r;
-       int fin=0;
-       for(int kk=0;kk<8-r;kk++) {
-               if(__getbit2(select->hi,c)) {
-                       fin=1;
-                       break;
-               }
-               c++;
-       }
-       if(!fin) {
-               int pp = c/8;
-               while(select->hi[pp]==0) {
-                       pp++;
-                       c+=8;
-               }
-               while(!__getbit2(select->hi,c)) c++;
-       }
-       c -= (k2);
+  };
+
+  if(x==select->m) return (uint)-1;
+
+
+  int c=(y << 3)+r;
+  unsigned int mask = 0x00ffffff;
+  unsigned int tmp = (select->hi[y] << (24+r)) | mask;
+  unsigned int c_incr = __builtin_clz(tmp);
+  c += (c_incr & 7);
+  if (c_incr == 8) {
+    c += (8-r);
+    int pp = c >> 3;
+    while(select->hi[pp]==0) {
+      pp++;
+      c += 8;
+    };
+    while(!__getbit2(select->hi,c)) c++;
+
+  };
+  c -= (k2);
+
   return __getbits(q,x*d,d)+((c)<<d);
 }
 
@@ -791,6 +838,7 @@ int selects3_rank(selects3 *select, int i) {
   int d,x,w,y;
   int r,j;
   int z,ii;
+  int xoffset;
   uint *q;
 
   d = select->d;
@@ -798,7 +846,7 @@ int selects3_rank(selects3 *select, int i) {
 
   ii = i>>d;
 
-  y = selectd2_select(select->sd0,ii,0)+1;
+  y = selectd2_select0(select->sd0, ii)+1;
   //  selectd2_select2(select->sd0,ii,0,&y1,&y2);
   //y1++;  y2++;
   //printf("y %d y1 %d  %d\n",y,y1,y2-y1);
@@ -810,15 +858,17 @@ int selects3_rank(selects3 *select, int i) {
   r = y & 7;
   y >>= 3;
   z = select->hi[y];
+  xoffset = x * d;
   while (1) {
     if (((z << r) & 0x80) == 0) break;
-    w = __getbits(q,x*d,d);
+    w = __getbits(q, xoffset, d);
     if (w >= j) {
-      if (w == j) x++;
+      x += (w == j);
       break;
     }
     x++;
     r++;
+    xoffset += d;
     if (r == 8) {
       r = 0;
       y++;