Finish porting to the Grammar2 API
[SXSI/xpathcomp.git] / src / grammar2.ml
1 type t = {
2   start : Bp.t;
3   tags : int array;
4   rules : int array;
5   rules_offset : int;
6   tag_to_id : (string, int) Hashtbl.t;
7   tag_of_id : string array
8 }
9
10
11
12 module Parse =
13 struct
14
15   let buffer = Buffer.create 512
16
17   let parse_tree cin open_tag close_tag =
18     let rec loop () =
19       let c = input_char cin in
20       match c with
21         '\n'| '>' -> ()
22       | ' ' | ',' | '-' -> loop ()
23       | 'a'..'z' | 'B'..'Z' | '0'..'9' | '_' ->
24         Buffer.clear buffer;
25         Buffer.add_char buffer c;
26         loop_tag false
27
28       | 'A' ->  Buffer.clear buffer;
29         Buffer.add_char buffer c;
30         loop_tag true
31       | ')' -> close_tag (); loop ()
32       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
33
34     and loop_tag t =
35       let c = input_char cin in
36       match c with
37       | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' ->
38         Buffer.add_char buffer c;
39         loop_tag t
40       | '(' -> let s = Buffer.contents buffer in
41                open_tag s t;
42                Buffer.clear buffer;
43                loop ()
44       | ' ' -> loop_tag t
45       | ',' | '-'  -> let s = Buffer.contents buffer in
46                       open_tag s t;
47                       close_tag ();
48                       Buffer.clear buffer;
49                       loop ()
50       | ')' -> let s = Buffer.contents buffer in
51                open_tag s t;
52                Buffer.clear buffer;
53                close_tag ();
54                close_tag ();
55                loop ()
56       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
57     in
58     loop ()
59
60
61   let tag_info = Hashtbl.create 1023
62   let tag_of_id  = Hashtbl.create 1023
63   let current_id = ref 4
64   let init() =
65     Hashtbl.clear tag_info;
66     Hashtbl.clear tag_of_id;
67     current_id := 4;
68     Hashtbl.add tag_info "_ROOT" (0, ~-1, false);
69     Hashtbl.add tag_info "_A" (1, ~-1, false);
70     Hashtbl.add tag_info "_T" (2, ~-1, false);
71     Hashtbl.add tag_info "_AT" (3, ~-1, false);
72     Hashtbl.add tag_info "_"  (4, ~-1, false);
73     Hashtbl.add tag_of_id 0 "_ROOT";
74     Hashtbl.add tag_of_id 1 "_A";
75     Hashtbl.add tag_of_id 2 "_T";
76     Hashtbl.add tag_of_id 3 "_AT";
77     Hashtbl.add tag_of_id 4 "_"
78
79
80   let add_tag s nterm =
81     let id, count, nterm =
82       try Hashtbl.find tag_info s with
83         Not_found ->
84           incr current_id;
85           let id = !current_id in
86           Hashtbl.add tag_of_id id s;
87           (!current_id, ~-1, nterm || s = "START")
88     in
89     let r = id, count+1, nterm in
90     Hashtbl.replace tag_info s r;
91     r
92
93
94   type tree = Node of string * tree list
95
96   let parse_small_tree cin =
97     let stack = ref [ Node("", []) ] in
98     let open_tag s isnterm =
99       if s <> "y0" && s <> "y1" then ignore(add_tag s isnterm);
100       stack := Node(s, []) :: !stack
101     in
102     let close_tag () =
103       match !stack with
104         Node(t1, l1) :: Node(t2, l2) :: r ->
105           stack := Node(t2, Node(t1, List.rev l1)::l2) :: r
106       | _ -> assert false
107     in
108     parse_tree cin open_tag close_tag;
109     match !stack with
110       [ Node(_, [ l ]) ] -> l
111     | _ -> raise End_of_file
112
113   let parse_big_tree cin =
114     let bv = Bp.bitmap_create () in
115     let tags = IntArray.create () in
116     let open_tag s isnterm =
117       let id, _, _ = add_tag s isnterm in
118       Bp.bitmap_push_back bv 1;
119       IntArray.push_back tags id
120     in
121     let close_tag () =
122       Bp.bitmap_push_back bv 0
123     in
124     parse_tree cin open_tag close_tag;
125     Bp.create bv, IntArray.pack tags
126
127   let eat_char cin = ignore (input_char cin)
128
129   let h_find ?(msg="") h i =
130     try
131       Hashtbl.find h i
132     with
133       Not_found ->
134         let r = Obj.repr i in
135         if Obj.is_int r then Printf.eprintf "Not_found (%s): %i\n%!" msg (Obj.magic i);
136         if Obj.tag r = Obj.string_tag then Printf.eprintf "Not_found (%s): %s\n%!" msg (Obj.magic i);
137         raise Not_found
138   ;;
139
140   let parse cin =
141     let rules = Hashtbl.create 1023 in
142     init ();
143     (* START *)
144     ignore (parse_small_tree cin);
145     (* > *)
146     (* ignore (input_char cin); *)
147     let bv, tags = parse_big_tree cin in
148     let () =
149       try
150         while true do
151           let lhs = parse_small_tree cin in
152           let rhs = parse_small_tree cin in
153           Hashtbl.add rules lhs rhs
154         done;
155       with End_of_file -> ()
156     in
157     (* First, re-order the tags *)
158     let old_new_mapping =
159       Array.init (Hashtbl.length tag_of_id)
160         (fun i -> h_find ~msg:"1" tag_of_id i)
161     in
162     Array.fast_sort (fun tag1 tag2 ->
163       let t1, count1, isnterm1 =
164         h_find  ~msg:"2" tag_info tag1
165       and t2, count2, isnterm2 =
166         h_find  ~msg:"3" tag_info tag2
167       in
168       if t1 <= 4 && t2 <= 4 then compare t1 t2
169       else if t1 <= 4 then -1
170       else if t2 <= 4 then 1
171       else
172         if (not isnterm1) && (not isnterm2) then compare t1 t2
173         else if isnterm1 && isnterm2 then
174           match tag1, tag2 with
175             "START", "START" -> 0
176           | "START", _ -> ~-1
177           | _, "START" -> 1
178           | _ -> compare count2 count1
179         else if isnterm2 then -1
180         else 1) old_new_mapping;
181     let tag_to_id = Hashtbl.create 503 in
182     Array.iteri (fun i s ->
183       Hashtbl.add tag_to_id s i) old_new_mapping;
184     let renum_tags = Array.copy tags in
185     for i = 0 to Array.length tags - 1 do
186       renum_tags.(i) <-
187         h_find  ~msg:"4" tag_to_id (h_find  ~msg:"5" tag_of_id (tags.(i)))
188     done;
189     let r_array = Array.create (Hashtbl.length rules) 0 in
190     let rules_offset = h_find  ~msg:"6" tag_to_id "START" + 1 in
191     let pos_id2 l =
192       let rec loop i l =
193         match l with
194           [] -> assert false
195         | Node(tag, children) :: ll ->
196           if tag <> "y0" && tag <> "y1" then
197             tag, i
198           else loop (i+1) ll
199       in
200       loop 1 l
201     in
202     Hashtbl.iter (fun lhs rhs ->
203       let Node( head, args ) = lhs in
204       let Node( tag1, params) = rhs in
205       let tag2, pos2 = pos_id2 params in
206       let id1 = h_find ~msg:"7" tag_to_id tag1
207       and id2 = h_find ~msg:"8" tag_to_id tag2 in
208       let rule_ = id2 lsl 27 in
209       let rule_ = (rule_ lor id1) lsl 2 in
210       let rule_ = (rule_ lor pos2) lsl 2 in
211       let rule_ = (rule_ lor (List.length params)) lsl 2 in
212       let rule_ = rule_ lor (List.length args) in
213       r_array.((h_find  ~msg:"9" tag_to_id head) - rules_offset ) <- rule_
214     ) rules;
215     (*let l = Array.length renum_tags in *)
216     (*let tag32 = Array32.create l 0 in
217     for i = 0 to l - 1 do
218       Array32.set tag32 i (renum_tags.(i) land 0x7ffffff);
219     done; *)
220     (* Remove the non-terminal names from the hash tables *)
221     let tag_to_id2 = Hashtbl.create 31 in
222     Hashtbl.iter (fun s i -> if i < rules_offset then Hashtbl.add tag_to_id2 s i)
223       tag_to_id;
224     { start = bv;
225       tags = renum_tags;
226       rules = r_array;
227       rules_offset = rules_offset;
228       tag_to_id = tag_to_id2;
229       tag_of_id = Array.sub old_new_mapping 0 rules_offset
230     }
231
232 end
233
234 let parse file =
235   let cin = open_in file in
236   let g = Parse.parse cin in
237   close_in cin;
238   g
239
240 let _GRAMMAR_MAGIC = 0xaabbcc
241 let _GRAMMAR_VERSION = 2
242
243 let save g f =
244   let cout = open_out f in
245   let write a = Marshal.to_channel cout a [  ]
246   in
247   write _GRAMMAR_MAGIC;
248   write _GRAMMAR_VERSION;
249   write g.tags;
250   write g.rules;
251   write g.rules_offset;
252   write g.tag_to_id;
253   write g.tag_of_id;
254   flush cout;
255   let fd = Unix.descr_of_out_channel cout in
256   Bp.save g.start fd;
257   close_out cout
258
259 let load f =
260   let cin = open_in f in
261   let read () = Marshal.from_channel cin in
262   if read () != _GRAMMAR_MAGIC then failwith "Invalid grammar file";
263   if read () != _GRAMMAR_VERSION then failwith "Deprecated grammar format";
264   let tags : int array = read () in
265   let rules : int array = read () in
266   let rules_offset : int = read () in
267   let tag_to_id : (string, int) Hashtbl.t = read () in
268   let tag_of_id : string array = read () in
269   let fd = Unix.descr_of_in_channel cin in
270   let pos = pos_in cin in
271   ignore(Unix.lseek fd pos Unix.SEEK_SET);
272   let bp = Bp.load fd in
273   close_in cin;
274   {
275     start = bp;
276     tags = tags;
277     rules = rules;
278     rules_offset = rules_offset;
279     tag_to_id = tag_to_id;
280     tag_of_id = tag_of_id;
281   }
282
283
284 type node = [ `Start ] Node.t
285
286 type n_type = [ `NonTerminal ]
287 type t_type = [ `Terminal ]
288 type r_type = [ `Rule ]
289 type any_type = [ n_type | t_type ]
290 type rhs = [ r_type ] Node.t
291
292 type n_symbol = n_type Node.t
293 type t_symbol = t_type Node.t
294 type tn_symbol = [ any_type ] Node.t
295
296
297 type partial =
298     Leaf of node
299   | Node of tn_symbol * partial array
300
301
302 let is_nil  (t : t_symbol) =
303   (Node.to_int t) == 4
304
305 let nil_symbol : t_symbol =
306   (Node.of_int 4)
307
308 let translate_tag _ t  = if t == 4 then ~-1 else t
309 let to_string t tag =
310   if tag < Array.length t.tag_of_id then t.tag_of_id.(Tag.to_int tag)
311   else "<!INVALID TAG!>"
312
313 let register_tag t s =
314   try Hashtbl.find t.tag_to_id s with
315     Not_found -> 4
316
317 let tag_operations t = {
318   Tag.tag = (fun s -> register_tag t s);
319   Tag.to_string = (fun s -> to_string t s);
320   Tag.translate = (fun s -> translate_tag t s);
321 }
322
323 let start_root : node = Node.of_int 0
324 let start_tag g (idx : node) : [<any_type] Node.t =
325   Node.of_int (g.tags.(Bp.preorder_rank g.start (Node.to_int idx)))
326
327 ;;
328
329 let start_first_child t (idx : node) =
330   Node.of_int (Bp.first_child t.start (Node.to_int idx))
331
332 let start_next_sibling t (idx : node) =
333   Node.of_int (Bp.next_sibling t.start (Node.to_int idx))
334
335 let is_non_terminal t (n : [< any_type ] Node.t) =
336   let n = Node.to_int n in
337   n >= t.rules_offset
338
339 let is_terminal t (n : [< any_type ] Node.t) = not(is_non_terminal t n)
340
341 external terminal : [< any_type ] Node.t -> t_symbol = "%identity"
342 external non_terminal : [< any_type ] Node.t -> n_symbol = "%identity"
343
344
345 let tag (n : t_symbol) : Tag.t = Obj.magic n
346
347 let get_rule g (r : n_symbol) : rhs =
348   Node.of_int (g.rules.((Node.to_int r) - g.rules_offset))
349
350 let get_id1 (r : rhs) : tn_symbol =
351   Node.of_int(
352     ((Node.to_int r) lsr 6) land 0x7ffffff)
353
354 let get_id2 (r : rhs) : tn_symbol =
355   Node.of_int((Node.to_int r) lsr 33)
356
357 let get_rank (r : rhs) : int =
358   (Node.to_int r) land 0b11
359
360 let get_id1_rank (r : rhs) : int =
361   ((Node.to_int r) lsr 2) land 0b11
362
363 let get_id2_pos (r : rhs) : int =
364   ((Node.to_int r) lsr 4) land 0b11
365
366 let get_id2_rank (r : rhs) : int =
367   get_rank r  + 1 - get_id1_rank r