Added TextCollection::Save() and TextCollection::Load() functionality
[SXSI/TextCollection.git] / CSA.cpp
diff --git a/CSA.cpp b/CSA.cpp
index a15c733..754dfa8 100644 (file)
--- a/CSA.cpp
+++ b/CSA.cpp
@@ -25,6 +25,7 @@
 #include <set>
 #include <vector>
 #include <utility>
+#include <stdexcept>
 #include <cassert>
 #include <cstring> // For strlen()
 using std::vector;
@@ -190,58 +191,22 @@ void CSA::MakeStatic()
     delete dynFMI;
     dynFMI = 0;
 
-    ulong i, min = 0,
-             max;
-    for (i=0;i<256;i++)
-        C[i]=0;
-    for (i=0;i<n;++i)
-        C[(int)bwt[i]]++;
-    for (i=0;i<256;i++)
-        if (C[i]>0) {min = i; break;}          
-    for (i=255;i>=min;--i)
-        if (C[i]>0) {max = i; break;}
-    
-    // Print frequencies
-/*    for(i = 0; i < 256; i++)
-        if (C[i]>0) printf("C[%lu] = %lu\n", i, C[i]);
-        fflush(stdout);*/
-
-    ulong prev=C[0], temp;
-    C[0]=0;
-    for (i=1;i<256;i++) {          
-        temp = C[i];
-        C[i]=C[i-1]+prev;
-        prev = temp;
-    }
-    this->codetable = node::makecodetable(bwt,n);
-    alphabetrank = new THuffAlphabetRank(bwt,n, this->codetable,0);   
-    //if (alphabetrank->Test(bwt,n)) printf("alphabetrank ok\n");    
-
-/*    for (i = 0; i < n; ++i)
-    {
-        uchar c = alphabetrank->charAtPos(i);
-        TextPosition j = C[c]+alphabetrank->rank(c, i)-1;
-        printf("LF[%lu] = %lu\n", i, j);
-        }*/
+    makewavelet(bwt);     
 
     // Calculate BWT end-marker position (of last inserted text)
-    i = 0;
+    ulong i = 0;
     while (bwt[i] != '\0')
     {
-        uchar c = alphabetrank->charAtPos(i);
+        uchar c = bwt[i];
         i = C[c]+alphabetrank->rank(c, i)-1;
     }
     bwtEndPos = i;
     //printf("bwtEndPos = %lu\n", bwtEndPos);
-     
+
     delete [] bwt;
 
     // Make sampling tables
     maketables();
-    // to avoid running out of unsigned, the sizes are computed in specific order (large/small)*small
-    // |class CSA| +256*|TCodeEntry|+|C[]|+|suffixes[]+positions[]|+...       
-    //printf("FMindex takes %d B\n",
-    //    6*W/8+256*3*W/8+256*W/8+ (2*n/(samplerate*8))*W+sampled->SpaceRequirementInBits()/8+alphabetrank->SpaceRequirementInBits()/8+W/8);
 }
 
 
@@ -648,18 +613,128 @@ TextCollection::DocId CSA::DocIdAtTextPos(TextPosition i) const
     return a;
 }
 
-void CSA::Load(FILE *filename, unsigned samplerate)
+/**
+ * Save index to a file handle
+ *
+ * Throws a std::runtime_error exception on i/o error.
+ * First byte that is saved represents the version number of the save file.
+ * In version 1 files, the data fields are:
+ *        <uchar version> version info;
+ *        <unsigned s>    samplerate;
+ *        <ulong n>       length of the BWT;
+ *        <ulong bwt>     end marker position in BWT;
+ *        <uchar *>       BWT string of length n;
+ *        <unsigned r>    number of texts;
+ *        <vector textLength>
+ *                        array of <ulong, ulong> pairs.
+ * 
+ * TODO: Save the data structures instead of BWT sequence?
+ */
+void CSA::Save(FILE *file) const
 {
-    // TODO
+    // Saving version 1 data:
+    uchar versionFlag = 1;
+    if (std::fwrite(&versionFlag, 1, 1, file) != 1)
+        throw std::runtime_error("CSA::Save(): file write error (version flag).");
 
+    if (std::fwrite(&(this->samplerate), sizeof(unsigned), 1, file) != 1)
+        throw std::runtime_error("CSA::Save(): file write error (samplerate).");
+
+    if (std::fwrite(&(this->n), sizeof(TextPosition), 1, file) != 1)
+        throw std::runtime_error("CSA::Save(): file write error (n).");
+
+    if (std::fwrite(&(this->bwtEndPos), sizeof(TextPosition), 1, file) != 1)
+        throw std::runtime_error("CSA::Save(): file write error (bwt end position).");
+    
+    for (ulong offset = 0; offset < n; offset ++)
+    {
+        uchar c = alphabetrank->charAtPos(offset);
+        if (std::fwrite(&c, sizeof(uchar), 1, file) != 1)
+            throw std::runtime_error("CSA::Save(): file write error (bwt sequence).");
+    }
+
+    unsigned r = textLength.size();
+    if (std::fwrite(&r, sizeof(unsigned), 1, file) != 1)
+        throw std::runtime_error("CSA::Save(): file write error (r).");
+    
+    for (r = 0; r < textLength.size(); ++ r)
+    {
+        if (std::fwrite(&(textLength[r].first), sizeof(TextPosition), 1, file) != 1)
+            throw std::runtime_error("CSA::Save(): file write error (text length).");
+        if (std::fwrite(&(textLength[r].second), sizeof(TextPosition), 1, file) != 1)
+            throw std::runtime_error("CSA::Save(): file write error (text start).");
+    }    
 }
 
-void CSA::Save(FILE *filename)
+
+/**
+ * Load index from a file handle
+ *
+ * Throws a std::runtime_error exception on i/o error.
+ * For more info, see CSA::Save().
+ */
+void CSA::Load(FILE *file, unsigned samplerate)
 {
-    // TODO
+    // Delete everything
+    delete dynFMI;       dynFMI = 0;
+    delete alphabetrank; alphabetrank = 0;
+    delete sampled;      sampled = 0;
+    delete [] suffixes;  suffixes = 0;
+    delete [] positions; positions = 0;
+    delete [] codetable; codetable = 0;
+
+    endmarkers.clear();
+    textLength.clear();
+
+    this->samplerate = samplerate;
+    this->n = 0;
+
+    uchar versionFlag = 0;
+    if (std::fread(&versionFlag, 1, 1, file) != 1)
+        throw std::runtime_error("CSA::Load(): file read error (version flag).");
+    if (versionFlag != 1)
+        throw std::runtime_error("CSA::Load(): invalid start byte.");
+
+    if (std::fread(&samplerate, sizeof(unsigned), 1, file) != 1)
+        throw std::runtime_error("CSA::Load(): file read error (samplerate).");
+    if (this->samplerate == 0)
+        this->samplerate = samplerate;
+
+    if (std::fread(&(this->n), sizeof(TextPosition), 1, file) != 1)
+        throw std::runtime_error("CSA::Load(): file read error (n).");
+
+    if (std::fread(&(this->bwtEndPos), sizeof(TextPosition), 1, file) != 1)
+        throw std::runtime_error("CSA::Load(): file read error (bwt end position).");
+
+    uchar *bwt = new uchar[n];
+    for (ulong offset = 0; offset < n; offset ++)
+        if (std::fread((bwt + offset), sizeof(uchar), 1, file) != 1)
+            throw std::runtime_error("CSA::Load(): file read error (bwt sequence).");
+
+    unsigned r = 0;
+    if (std::fread(&r, sizeof(unsigned), 1, file) != 1)
+        throw std::runtime_error("CSA::Load(): file read error (r).");
+    
+    while (r > 0)
+    {
+        TextPosition length = 0, start = 0;
+        if (std::fread(&length, sizeof(TextPosition), 1, file) != 1)
+            throw std::runtime_error("CSA::Load(): file read error (text length).");
+        if (std::fread(&start, sizeof(TextPosition), 1, file) != 1)
+            throw std::runtime_error("CSA::Load(): file read error (text start).");
+
+        textLength.push_back(make_pair(length, start));
+        --r;
+    }
+
+    // Construct data structures
+    makewavelet(bwt);
+    delete [] bwt;
+    maketables();
 }
 
 
+
 /**
  * Rest of the functions follow...
  */
@@ -789,6 +864,7 @@ CSA::TextPosition CSA::Lookup(TextPosition i) const // Time complexity: O(sample
 }
 
 CSA::~CSA() {
+    delete dynFMI;
     delete alphabetrank;       
     delete sampled;
     delete [] suffixes;
@@ -796,6 +872,42 @@ CSA::~CSA() {
     delete [] codetable;
 }
 
+void CSA::makewavelet(uchar *bwt)
+{
+    ulong i, min = 0,
+             max;
+    for (i=0;i<256;i++)
+        C[i]=0;
+    for (i=0;i<n;++i)
+        C[(int)bwt[i]]++;
+    for (i=0;i<256;i++)
+        if (C[i]>0) {min = i; break;}          
+    for (i=255;i>=min;--i)
+        if (C[i]>0) {max = i; break;}
+    
+    // Print frequencies
+/*    for(i = 0; i < 256; i++)
+        if (C[i]>0) printf("C[%lu] = %lu\n", i, C[i]);
+        fflush(stdout);*/
+
+    ulong prev=C[0], temp;
+    C[0]=0;
+    for (i=1;i<256;i++) {          
+        temp = C[i];
+        C[i]=C[i-1]+prev;
+        prev = temp;
+    }
+    this->codetable = node::makecodetable(bwt,n);
+    alphabetrank = new THuffAlphabetRank(bwt,n, this->codetable,0);   
+    //if (alphabetrank->Test(bwt,n)) printf("alphabetrank ok\n");    
+
+/*    for (i = 0; i < n; ++i)
+    {
+        uchar c = alphabetrank->charAtPos(i);
+        TextPosition j = C[c]+alphabetrank->rank(c, i)-1;
+        printf("LF[%lu] = %lu\n", i, j);
+        }*/
+}
 
 void CSA::maketables()
 {