Added approximate pattern matching with Suffix Filters
[SXSI/TextCollection.git] / RLCSAWrapper.h
index 4e22fe8..8fd1d68 100644 (file)
@@ -23,6 +23,7 @@
 #define _RLCSAWrapper_H_
 
 #include "TextCollection.h"
+#include "IndelQuery.h"
 
 #include "incbwt/rlcsa.h"
 
@@ -34,6 +35,7 @@
 #   define W 32
 #endif
 
+#include <stdexcept>
 #include <set>
 #include <string>
 
@@ -51,7 +53,10 @@ class RLCSAWrapper : public SXSI::TextCollection
 public:
     RLCSAWrapper(const CSA::RLCSA* index)
         : rlcsa(index)
-    { /* NOP */ }
+    {
+        // Init the edit distance look-up tables
+        MyersEditDistanceIncremental::initMyersFourRussians();
+    }
 
     ~RLCSAWrapper()
     {
@@ -131,9 +136,43 @@ public:
     document_result Prefix(uchar const *) const { unsupported(); return document_result(); };
     document_result Suffix(uchar const *) const { unsupported(); return document_result(); };
     document_result Equal(uchar const *) const { unsupported(); return document_result(); };
+
     document_result Contains(uchar const *pattern) const
-    {
-        //std::pair<TextPosition,TextPosition> p = backwards(pattern);
+    {        
+        /*************************************************************************
+         * Approximate pattern matching
+         * 
+         * Using suffix filters. Has to enumerate *all* approx. occurrences (sloooow...)
+         * instead of just returning the best occurrence (which is usually much faster).
+         *
+         * Query format: contains("APM 3 GATTACA")
+         * where
+         *          "APM" is the keyword for approximate queries.
+         *            "3" is the maximum edit distance allowed.
+         *      "GATTACA" is the query word to be aligned.
+         */
+        if (strncmp((char const *)pattern, "APM ", 4) == 0)
+        {
+            // Edit distance allowed.
+            int k = std::atoi((char const *)pattern + 4);
+            if (k < 0 || k == INT_MAX || k == INT_MIN)
+                goto exact_pattern_matching; // Invalid format
+
+            // Find the start of the pattern (i.e. the second ' ')
+            uchar const * tmp = pattern + 4;
+            while (*tmp != ' ' && *tmp != 0) ++tmp;
+            if (*tmp != ' ' || tmp == pattern + 4)
+                goto exact_pattern_matching; // Invalid format
+
+            IndelQuery iq(this);
+            //std::cerr << "Pattern: " << tmp+1 << ", k = " << k << std::endl;
+            return iq.align(tmp+1, k);
+        }
+
+        /*************************************************************************
+         * Exact pattern matching
+         */
+    exact_pattern_matching:
         CSA::pair_type p = rlcsa->count(std::string((char const *)pattern));
         if (p.first > p.second)
             return document_result();
@@ -175,6 +214,45 @@ public:
         this->rlcsa->writeTo(std::string(filename));
     }
 
+    TextCollection::TextPosition getLength() const
+    {
+        return this->rlcsa->getSize() + this->rlcsa->getNumberOfSequences();
+    }
+    
+    inline TextCollection::TextPosition LF(uchar c, TextPosition i) const
+    {
+        ++i;
+        if(i < this->rlcsa->getNumberOfSequences()) 
+            return rlcsa->C(c);
+        return rlcsa->LF(i - this->rlcsa->getNumberOfSequences(), c);
+    }
+
+    uchar* getSuffix(TextPosition pos, unsigned l) const
+    {
+        ++pos;
+        uchar* text = new uchar[l + 1];
+        
+        if(l == 0 || pos < this->rlcsa->getNumberOfSequences())
+        {
+            text[0] = 0;
+            return text;
+        }
+        pos -= this->rlcsa->getNumberOfSequences();
+        
+        unsigned n = rlcsa->displayFromPosition(pos, l, text);
+        text[n] = 0;
+        return text;
+    }
+    
+    DocId getDoc(TextPosition i) const
+    {
+        if(i < this->rlcsa->getNumberOfSequences())
+            return i;
+        
+        CSA::pair_type pt = rlcsa->getRelativePosition(this->rlcsa->locate(i - this->rlcsa->getNumberOfSequences()));
+        return pt.first;
+    }
+
 private:
     const CSA::RLCSA* rlcsa;
 
@@ -185,32 +263,20 @@ private:
 
         std::cerr << "\nFirst symb = " << (char)c << std::endl;
 
-        TextPosition sp = rlcsa->C(c);
-        TextPosition ep = rlcsa->C(c+1)-1;
+        TextPosition sp = LF(c, -1);
+        TextPosition ep = LF(c, getLength()-1) - 1;
         printf("i = %lu, c = %c, sp = %lu, ep = %lu\n", i, pattern[i], sp, ep);
         while (sp<=ep && i>=1) 
         {
             c = (int)pattern[--i];
-            sp = LF(c, sp);
-            ep = LF(c, ep+1)-1;
+            sp = LF(c, sp-1);
+            ep = LF(c, ep)-1;
             printf("new: c = %c, sp = %lu, ep = %lu\n", pattern[i], sp, ep);
         }
 
-        uchar* found = rlcsa->display(sp, std::strlen((char const *)pattern), 5, i);
-        std::cerr << "found: " << found << " (" << i << std::endl;
-
         return std::make_pair(sp, ep);
         }*/
-
-    inline TextCollection::TextPosition
-        LF(uchar c, TextPosition i) const
-    {
-        if(i == (TextPosition)-1 || i < this->rlcsa->getNumberOfSequences()) 
-        { return this->rlcsa->C(c); }
-        return this->rlcsa->LF(i - this->rlcsa->getNumberOfSequences(), c);
-    }
-
-    
+   
     void unsupported() const
     { 
         std::cerr << std::endl << "-------------------------------------------------------------\n"