Added RLCSA index option
[SXSI/TextCollection.git] / incbwt / rlcsa_builder.cpp
index 43d5a7d..b11a0e8 100644 (file)
@@ -1,17 +1,27 @@
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
 #include <iostream>
 
 #include "rlcsa_builder.h"
+#include "misc/utils.h"
+
+#ifdef MULTITHREAD_SUPPORT
+#include <omp.h>
+#endif
 
 
 namespace CSA
 {
 
 
-RLCSABuilder::RLCSABuilder(usint _block_size, usint _sample_rate, usint _buffer_size) :
+RLCSABuilder::RLCSABuilder(usint _block_size, usint _sample_rate, usint _buffer_size, usint _threads) :
   block_size(_block_size), sample_rate(_sample_rate), buffer_size(_buffer_size),
+  threads(_threads),
   buffer(0)
 {
   this->reset();
+  this->build_time = this->search_time = this->sort_time = this->merge_time = 0.0;
 }
 
 RLCSABuilder::~RLCSABuilder()
@@ -33,9 +43,9 @@ RLCSABuilder::insertSequence(char* sequence, usint length, bool delete_sequence)
 
   if(this->buffer == 0)
   {
-    clock_t start = clock();
+    double start = readTimer();
     RLCSA* temp = new RLCSA((uchar*)sequence, length, this->block_size, this->sample_rate, false, false);
-    this->build_time += clock() - start;
+    this->build_time += readTimer() - start;
     this->addRLCSA(temp, (uchar*)sequence, length + 1, delete_sequence);
     return;
   }
@@ -50,13 +60,16 @@ RLCSABuilder::insertSequence(char* sequence, usint length, bool delete_sequence)
   }
   else
   {
-    this->flush();
-    this->buffer = new uchar[this->buffer_size];
+    if(this->chars > 0)
+    {
+      this->flush();
+      this->buffer = new uchar[this->buffer_size];
+    }
     if(length >= this->buffer_size - 1)
     {
-      clock_t start = clock();
+      double start = readTimer();
       RLCSA* temp = new RLCSA((uchar*)sequence, length, this->block_size, this->sample_rate, false, false);
-      this->build_time += clock() - start;
+      this->build_time += readTimer() - start;
       this->addRLCSA(temp, (uchar*)sequence, length + 1, delete_sequence);
     }
     else
@@ -70,6 +83,28 @@ RLCSABuilder::insertSequence(char* sequence, usint length, bool delete_sequence)
   }
 }
 
+void
+RLCSABuilder::insertFromFile(const std::string& base_name)
+{
+  if(!this->ok) { return; }
+
+  if(this->buffer != 0 && this->chars > 0)
+  {
+    this->flush();
+    this->buffer = new uchar[this->buffer_size];
+  }
+
+  std::ifstream input(base_name.c_str(), std::ios_base::binary);
+  if(!input) { return; }
+  RLCSA* increment = new RLCSA(base_name);
+  usint data_size = increment->getSize() + increment->getNumberOfSequences();
+  uchar* data = new uchar[data_size];
+  input.read((char*)data, data_size);
+  input.close();
+
+  this->addRLCSA(increment, data, data_size, true);
+}
+
 RLCSA*
 RLCSABuilder::getRLCSA()
 {
@@ -109,19 +144,25 @@ RLCSABuilder::isOk()
 double
 RLCSABuilder::getBuildTime()
 {
-  return this->build_time / (double)CLOCKS_PER_SEC;
+  return this->build_time;
 }
 
 double
 RLCSABuilder::getSearchTime()
 {
-  return this->search_time / (double)CLOCKS_PER_SEC;
+  return this->search_time;
+}
+
+double
+RLCSABuilder::getSortTime()
+{
+  return this->sort_time;
 }
 
 double
 RLCSABuilder::getMergeTime()
 {
-  return this->merge_time / (double)CLOCKS_PER_SEC;
+  return this->merge_time;
 }
 
 //--------------------------------------------------------------------------
@@ -129,10 +170,10 @@ RLCSABuilder::getMergeTime()
 void
 RLCSABuilder::flush()
 {
-  clock_t start = clock();
+  double start = readTimer();
   RLCSA* temp = new RLCSA(this->buffer, this->chars, this->block_size, this->sample_rate, true, (this->index == 0));
-  this->build_time += clock() - start;
-  this->addRLCSA(temp, this->buffer, this->chars, true);
+  this->build_time += readTimer() - start;
+  this->addRLCSA(temp, this->buffer, this->chars, (this->index != 0));
   this->buffer = 0; this->chars = 0;
 }
 
@@ -141,40 +182,66 @@ RLCSABuilder::addRLCSA(RLCSA* increment, uchar* sequence, usint length, bool del
 {
   if(this->index != 0)
   {
-    clock_t start = clock();
+    double start = readTimer();
 
-    usint* positions = new usint[length];
-    usint begin = 0;
+    usint sequences = increment->getNumberOfSequences();
+    usint* end_markers = new usint[sequences];
+    usint curr = 0;
     for(usint i = 0; i < length - 1; i++)
     {
-      if(sequence[i] == 0)
+      if(sequence[i] == 0) { end_markers[curr++] = i; }
+    }
+    end_markers[sequences - 1] = length - 1;
+
+    usint* positions = new usint[length]; usint begin;
+    #ifdef MULTITHREAD_SUPPORT
+    usint chunk = std::max((usint)1, sequences / (8 * this->threads));
+    omp_set_num_threads(this->threads);
+    #pragma omp parallel private(begin)
+    {
+      #pragma omp for schedule(dynamic, chunk)
+      for(sint i = 0; i < (sint)sequences; i++)
       {
-        this->index->reportPositions(&(sequence[begin]), i - begin, &(positions[begin]));
-        begin = i + 1;
+        if(i > 0) { begin = end_markers[i - 1] + 1; } else { begin = 0; }
+        this->index->reportPositions(sequence + begin, end_markers[i] - begin, positions + begin);
       }
     }
-    this->index->reportPositions(&(sequence[begin]), length - 1 - begin, &(positions[begin]));
+    #else
+    for(sint i = 0; i < (sint)sequences; i++)
+    {
+      if(i > 0) { begin = end_markers[i - 1] + 1; } else { begin = 0; }
+      this->index->reportPositions(sequence + begin, end_markers[i] - begin, positions + begin);
+    }
+    #endif
+    delete[] end_markers;
+    if(delete_sequence) { delete[] sequence; }
+
+    double mark = readTimer();
+    this->search_time += mark - start;
 
+    #ifdef MULTITHREAD_SUPPORT
+    omp_set_num_threads(this->threads);
+    #endif
     std::sort(positions, positions + length);
     for(usint i = 0; i < length; i++)
     {
       positions[i] += i + 1;  // +1 because the insertion will be after positions[i]
     }
-    if(delete_sequence) { delete[] sequence; }
 
-    clock_t mark = clock();
-    this->search_time += mark - start;
+    double sort = readTimer();
+    this->sort_time += sort - mark;
 
-    RLCSA* merged = new RLCSA(*(this->index), *increment, positions, this->block_size);
+    RLCSA* merged = new RLCSA(*(this->index), *increment, positions, this->block_size, this->threads);
     delete[] positions;
     delete this->index;
     delete increment;
     this->index = merged;
 
-    this->merge_time += clock() - mark;  
+    this->merge_time += readTimer() - sort;  
   }
   else
   {
+    if(delete_sequence) { delete[] sequence; }
     this->index = increment;
   }