Debug swcsa
[SXSI/TextCollection.git] / RLCSAWrapper.h
1 /******************************************************************************
2  *   Copyright (C) 2010 by Niko Välimäki                                      *
3  *                                                                            *
4  *   RLCSA implementation for the TextCollection interface                  *
5  *                                                                            *
6  *   This program is free software; you can redistribute it and/or modify     *
7  *   it under the terms of the GNU Lesser General Public License as published *
8  *   by the Free Software Foundation; either version 2 of the License, or     *
9  *   (at your option) any later version.                                      *
10  *                                                                            *
11  *   This program is distributed in the hope that it will be useful,          *
12  *   but WITHOUT ANY WARRANTY; without even the implied warranty of           *
13  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            *
14  *   GNU Lesser General Public License for more details.                      *
15  *                                                                            *
16  *   You should have received a copy of the GNU Lesser General Public License *
17  *   along with this program; if not, write to the                            *
18  *   Free Software Foundation, Inc.,                                          *
19  *   51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.            *
20  *****************************************************************************/
21
22 #ifndef _RLCSAWrapper_H_
23 #define _RLCSAWrapper_H_
24
25 #include "TextCollection.h"
26 #include "IndelQuery.h"
27 #include "PssmQuery.h"
28
29 #include "incbwt/rlcsa.h"
30
31 // Re-define word size to ulong:
32 #undef W
33 #if __WORDSIZE == 64
34 #   define W 64
35 #else
36 #   define W 32
37 #endif
38
39 #include <stdexcept>
40 #include <set>
41 #include <string>
42 #include <cmath>
43
44 namespace SXSI 
45 {
46
47 /**
48  * Partial implementation of the TextCollection interface
49  *
50  * Supports index construction, save, load and simple search.
51  * Use FMIndex implementation for full support.
52  */
53 class RLCSAWrapper : public SXSI::TextCollection 
54 {
55 public:
56     RLCSAWrapper(const CSA::RLCSA* index)
57         : rlcsa(index)
58     {
59         // Init the edit distance look-up tables
60         MyersEditDistanceIncremental::initMyersFourRussians();
61     }
62
63     ~RLCSAWrapper()
64     {
65         delete rlcsa; rlcsa = 0;
66     }
67
68     bool EmptyText(DocId k) const
69     {
70         return false; // Empty texts are not indexed
71     }
72
73     /**
74      * Extracting one text.
75      *
76      * Call DeleteText() for each pointer returned by GetText()
77      * to avoid possible memory leaks.
78      */
79     uchar * GetText(DocId i) const
80     {
81         return rlcsa->display(i);
82     }
83     void DeleteText(uchar *text) const
84     { 
85         delete [] text;
86     }
87
88     /**
89      * Returns a pointer to the beginning of texts i, i+1, ..., j.
90      * Texts are separated by a '\0' byte.
91      *
92      * Call DeleteText() for each pointer returned by GetText()
93      * to avoid possible memory leaks.
94      */
95     uchar * GetText(DocId i, DocId j) const
96     { 
97         std::cerr << "RLCSAWrapper::GetText(i,j): unsupported method!" << std::endl
98                   << "Use the default (FMIndex) text collection instead!" << std::endl;
99         std::exit(1);
100     }
101
102     bool IsPrefix(uchar const *) const { unsupported(); return false; };
103     bool IsSuffix(uchar const *) const { unsupported(); return false; };
104     bool IsEqual(uchar const *) const { unsupported(); return false; };
105     bool IsContains(uchar const *) const { unsupported(); return false; };
106     bool IsLessThan(uchar const *) const { unsupported(); return false; };
107
108     bool IsPrefix(uchar const *, DocId, DocId) const { unsupported(); return false; };
109     bool IsSuffix(uchar const *, DocId, DocId) const { unsupported(); return false; };
110     bool IsEqual(uchar const *, DocId, DocId) const { unsupported(); return false; };
111     bool IsContains(uchar const *, DocId, DocId) const { unsupported(); return false; };
112     bool IsLessThan(uchar const *, DocId, DocId) const { unsupported(); return false; };
113
114     ulong Count(uchar const *pattern) const
115     {
116 //        std::pair<TextPosition, TextPosition> p = backwards(pattern);
117         CSA::pair_type p = rlcsa->count(std::string((char const *)pattern));
118
119         if (p.first <= p.second)
120             return p.second - p.first + 1;
121         else
122             return 0;
123     }
124
125     unsigned CountPrefix(uchar const *) const { unsupported(); return 0; };
126     unsigned CountSuffix(uchar const *) const { unsupported(); return 0; };
127     unsigned CountEqual(uchar const *) const { unsupported(); return 0; };
128     unsigned CountContains(uchar const *) const { unsupported(); return 0; };
129     unsigned CountLessThan(const unsigned char*) const { unsupported(); return 0; };
130
131     unsigned CountPrefix(uchar const *, DocId, DocId) const { unsupported(); return 0; };
132     unsigned CountSuffix(uchar const *, DocId, DocId) const { unsupported(); return 0; };
133     unsigned CountEqual(uchar const *, DocId, DocId) const { unsupported(); return 0; };
134     unsigned CountContains(uchar const *, DocId, DocId) const { unsupported(); return 0; };
135     unsigned CountLessThan(uchar const *, DocId, DocId) const { unsupported(); return 0; };
136     
137     // Definition of document_result is inherited from SXSI::TextCollection.
138     document_result Prefix(uchar const *) const { unsupported(); return document_result(); };
139     document_result Suffix(uchar const *) const { unsupported(); return document_result(); };
140     document_result Equal(uchar const *) const { unsupported(); return document_result(); };
141
142     document_result Contains(uchar const *pattern) const
143     {        
144         /*************************************************************************
145          * Approximate pattern matching
146          * 
147          * Using suffix filters. Has to enumerate *all* approx. occurrences (sloooow...)
148          * instead of just returning the best occurrence (which is usually much faster).
149          *
150          * Query format: contains("APM 3 GATTACA")
151          * where
152          *          "APM" is the keyword for approximate queries.
153          *            "3" is the maximum edit distance allowed.
154          *      "GATTACA" is the query word to be aligned.
155          */
156         if (strncmp((char const *)pattern, "APM ", 4) == 0)
157         {
158             // Edit distance allowed.
159             int k = std::atoi((char const *)pattern + 4);
160             if (k < 0 || k == INT_MAX || k == INT_MIN)
161                 goto exact_pattern_matching; // Invalid format
162
163             // Find the start of the pattern (i.e. the second ' ')
164             uchar const * tmp = pattern + 4;
165             while (*tmp != ' ' && *tmp != 0) ++tmp;
166             if (*tmp != ' ' || tmp == pattern + 4)
167                 goto exact_pattern_matching; // Invalid format
168
169             IndelQuery iq(this);
170 //            std::cerr << "RLCSAWrapper::Contains(): Pattern: " << tmp+1 << ", k = " << k << std::endl;
171             return iq.align(tmp+1, k);
172         }
173
174         /*************************************************************************
175          * Position Specific Scoring Matrix (PSSM) matching
176          * See PssmQuery.h for usage information.
177          */
178         if (strncmp((char const *)pattern, "PSSM ", 4) == 0)
179         {
180             // Parse threshold
181             double thr = std::atof((char const *)pattern + 5);
182             if (thr <= 0)
183                 goto exact_pattern_matching; // Invalid format
184             
185             // Find the start of the pattern (i.e. the second ' ')
186             uchar const * tmp = pattern + 5;
187             while (*tmp != ' ' && *tmp != 0) ++tmp;
188             if (*tmp != ' ' || tmp == pattern + 5)
189                 goto exact_pattern_matching; // Invalid format
190             
191             PssmQuery pq(this, std::log(thr));
192             //std::cerr << "Pattern: " << tmp+1 << ", log(threshold) = " << std::log(thr) << std::endl;
193             return pq.align(tmp+1, 0);
194         }
195
196         /*************************************************************************
197          * Exact pattern matching
198          */
199     exact_pattern_matching:
200         CSA::pair_type p = rlcsa->count(std::string((char const *)pattern));
201         if (p.first > p.second)
202             return document_result();
203
204         document_result dr;
205         dr.reserve(p.second - p.first + 2);
206         for (ulong i = p.first; i <= p.second; ++i)
207             dr.push_back(rlcsa->getSequenceForPosition(rlcsa->locate(i)));
208         return dr;
209     }
210     document_result LessThan(uchar const *) const { unsupported(); return document_result(); };
211     document_result KMismaches(uchar const *, unsigned) const { unsupported(); return document_result(); };
212     document_result KErrors(uchar const *, unsigned) const { unsupported(); return document_result(); };
213
214     document_result Prefix(uchar const *, DocId, DocId) const { unsupported(); return document_result(); };
215     document_result Suffix(uchar const *, DocId, DocId) const { unsupported(); return document_result(); };
216     document_result Equal(uchar const *, DocId, DocId) const { unsupported(); return document_result(); };
217     document_result Contains(uchar const *, DocId, DocId) const { unsupported(); return document_result(); };
218     document_result LessThan(uchar const *, DocId, DocId) const { unsupported(); return document_result(); };
219
220     // Definition of full_result is inherited from SXSI::TextCollection.
221     full_result FullContains(uchar const *) const { unsupported(); return full_result(); };
222     full_result FullContains(uchar const *, DocId, DocId) const { unsupported(); return full_result(); };
223     full_result FullKMismatches(uchar const *, unsigned) const { unsupported(); return full_result(); };
224     full_result FullKErrors(uchar const *, unsigned) const { unsupported(); return full_result(); };
225
226     // Index from/to disk
227     RLCSAWrapper(FILE *file, char const *filename)
228         : rlcsa(new CSA::RLCSA(std::string(filename)))
229     { 
230         // Init the edit distance look-up tables
231         MyersEditDistanceIncremental::initMyersFourRussians();
232     }
233
234     void Save(FILE *file, char const *filename) const
235     {
236         const char type = 'R';
237         // Saving type info:
238         if (std::fwrite(&type, 1, 1, file) != 1)
239             throw std::runtime_error("RLCSAWrapper::Save(): file write error (type flag).");
240         fflush(file);
241         
242         this->rlcsa->writeTo(std::string(filename));
243     }
244
245     TextCollection::TextPosition getLength() const
246     {
247         return this->rlcsa->getSize() + this->rlcsa->getNumberOfSequences();
248     }
249     
250     inline TextCollection::TextPosition LF(uchar c, TextPosition i) const
251     {
252         ++i;
253         if(i < this->rlcsa->getNumberOfSequences()) 
254             return rlcsa->C(c);
255         return rlcsa->LF(i - this->rlcsa->getNumberOfSequences(), c);
256     }
257
258     uchar* getSuffix(TextPosition pos, unsigned l) const
259     {
260         ++pos;
261         uchar* text = new uchar[l + 1];
262         
263         if(l == 0 || pos < this->rlcsa->getNumberOfSequences())
264         {
265             text[0] = 0;
266             return text;
267         }
268         pos -= this->rlcsa->getNumberOfSequences();
269         
270         unsigned n = rlcsa->displayFromPosition(pos, l, text);
271         text[n] = 0;
272         return text;
273     }
274     
275     DocId getDoc(TextPosition i) const
276     {
277         if(i < this->rlcsa->getNumberOfSequences())
278             return i;
279         
280         CSA::pair_type pt = rlcsa->getRelativePosition(this->rlcsa->locate(i - this->rlcsa->getNumberOfSequences()));
281         return pt.first;
282     }
283
284 private:
285     const CSA::RLCSA* rlcsa;
286
287 /*    std::pair<TextPosition, TextPosition> backwards(uchar const *pattern) const
288     {
289         TextPosition i = std::strlen((char const *)pattern) - 1;
290         int c = (int)pattern[i]; 
291
292         std::cerr << "\nFirst symb = " << (char)c << std::endl;
293
294         TextPosition sp = LF(c, -1);
295         TextPosition ep = LF(c, getLength()-1) - 1;
296         printf("i = %lu, c = %c, sp = %lu, ep = %lu\n", i, pattern[i], sp, ep);
297         while (sp<=ep && i>=1) 
298         {
299             c = (int)pattern[--i];
300             sp = LF(c, sp-1);
301             ep = LF(c, ep)-1;
302             printf("new: c = %c, sp = %lu, ep = %lu\n", pattern[i], sp, ep);
303         }
304
305         return std::make_pair(sp, ep);
306         }*/
307    
308     void unsupported() const
309     { 
310         std::cerr << std::endl << "-------------------------------------------------------------\n"
311             << "RLCSAWrapper: unsupported method!\nSee RLCSAWrapper.h for more details.\n"
312             << "The default index (FMIndex) implements this method!" << std::endl;
313         std::exit(5);
314     }
315 }; // class RLCSAWrapper
316
317 } // namespace SXSI
318
319 #endif