ca21d0a99df0b1df65531b5827117be89dd8c8a5
[SXSI/xpathcomp.git] / src / grammar2.ml
1 type t = {
2   start : Bp.t;
3   tags : int array;
4   rules : int array;
5   rules_offset : int;
6   tag_to_id : (string, int) Hashtbl.t;
7   tag_of_id : string array
8 }
9
10
11
12 module Parse =
13 struct
14
15   let buffer = Buffer.create 512
16
17   let parse_tree cin open_tag close_tag =
18     let rec loop () =
19       let c = input_char cin in
20       match c with
21         '\n'| '>' -> ()
22       | ' ' | ',' | '-' -> loop ()
23       | 'a'..'z' | 'B'..'Z' | '0'..'9' | '_' ->
24         Buffer.clear buffer;
25         Buffer.add_char buffer c;
26         loop_tag false
27
28       | 'A' ->  Buffer.clear buffer;
29         Buffer.add_char buffer c;
30         loop_tag true
31       | ')' -> close_tag (); loop ()
32       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
33
34     and loop_tag t =
35       let c = input_char cin in
36       match c with
37       | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' ->
38         Buffer.add_char buffer c;
39         loop_tag t
40       | '(' -> let s = Buffer.contents buffer in
41                open_tag s t;
42                Buffer.clear buffer;
43                loop ()
44       | ' ' -> loop_tag t
45       | ',' | '-'  -> let s = Buffer.contents buffer in
46                       open_tag s t;
47                       close_tag ();
48                       Buffer.clear buffer;
49                       loop ()
50       | ')' -> let s = Buffer.contents buffer in
51                open_tag s t;
52                Buffer.clear buffer;
53                close_tag ();
54                close_tag ();
55                loop ()
56       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
57     in
58     loop ()
59
60
61   let tag_info = Hashtbl.create 1023
62   let tag_of_id  = Hashtbl.create 1023
63   let current_id = ref 4
64   let init() =
65     Hashtbl.clear tag_info;
66     Hashtbl.clear tag_of_id;
67     current_id := 4;
68     Hashtbl.add tag_info "_ROOT" (0, ~-1, false);
69     Hashtbl.add tag_info "_A" (1, ~-1, false);
70     Hashtbl.add tag_info "_T" (2, ~-1, false);
71     Hashtbl.add tag_info "_AT" (3, ~-1, false);
72     Hashtbl.add tag_info "_"  (4, ~-1, false);
73     Hashtbl.add tag_of_id 0 "_ROOT";
74     Hashtbl.add tag_of_id 1 "_A";
75     Hashtbl.add tag_of_id 2 "_T";
76     Hashtbl.add tag_of_id 3 "_AT";
77     Hashtbl.add tag_of_id 4 "_"
78
79
80   let add_tag s nterm =
81     let id, count, nterm =
82       try Hashtbl.find tag_info s with
83         Not_found ->
84           incr current_id;
85           let id = !current_id in
86           Hashtbl.add tag_of_id id s;
87           (!current_id, ~-1, nterm || s = "START")
88     in
89     let r = id, count+1, nterm in
90     Hashtbl.replace tag_info s r;
91     r
92
93
94   type tree = Node of string * tree list
95
96   let parse_small_tree cin =
97     let stack = ref [ Node("", []) ] in
98     let open_tag s isnterm =
99       if s <> "y0" && s <> "y1" then ignore(add_tag s isnterm);
100       stack := Node(s, []) :: !stack
101     in
102     let close_tag () =
103       match !stack with
104         Node(t1, l1) :: Node(t2, l2) :: r ->
105           stack := Node(t2, Node(t1, List.rev l1)::l2) :: r
106       | _ -> assert false
107     in
108     parse_tree cin open_tag close_tag;
109     match !stack with
110       [ Node(_, [ l ]) ] -> l
111     | _ -> raise End_of_file
112
113   let parse_big_tree cin =
114     let bv = Bp.bitmap_create () in
115     let tags = IntArray.create () in
116     let open_tag s isnterm =
117       let id, _, _ = add_tag s isnterm in
118       Bp.bitmap_push_back bv 1;
119       IntArray.push_back tags id
120     in
121     let close_tag () =
122       Bp.bitmap_push_back bv 0
123     in
124     parse_tree cin open_tag close_tag;
125     Bp.create bv, IntArray.pack tags
126
127   let eat_char cin = ignore (input_char cin)
128
129   let h_find ?(msg="") h i =
130     try
131       Hashtbl.find h i
132     with
133       Not_found ->
134         let r = Obj.repr i in
135         if Obj.is_int r then Printf.eprintf "Not_found (%s): %i\n%!" msg (Obj.magic i);
136         if Obj.tag r = Obj.string_tag then Printf.eprintf "Not_found (%s): %s\n%!" msg (Obj.magic i);
137         raise Not_found
138   ;;
139
140   let parse cin =
141     let rules = Hashtbl.create 1023 in
142     init ();
143     (* START *)
144     ignore (parse_small_tree cin);
145     (* > *)
146     (* ignore (input_char cin); *)
147     let bv, tags = parse_big_tree cin in
148     let () =
149       try
150         while true do
151           let lhs = parse_small_tree cin in
152           let rhs = parse_small_tree cin in
153           Hashtbl.add rules lhs rhs
154         done;
155       with End_of_file -> ()
156     in
157     (* First, re-order the tags *)
158     let old_new_mapping =
159       Array.init (Hashtbl.length tag_of_id)
160         (fun i -> h_find ~msg:"1" tag_of_id i)
161     in
162     Array.fast_sort (fun tag1 tag2 ->
163       let t1, count1, isnterm1 =
164         h_find  ~msg:"2" tag_info tag1
165       and t2, count2, isnterm2 =
166         h_find  ~msg:"3" tag_info tag2
167       in
168       if t1 <= 4 && t2 <= 4 then compare t1 t2
169       else if t1 <= 4 then -1
170       else if t2 <= 4 then 1
171       else
172         if (not isnterm1) && (not isnterm2) then compare t1 t2
173         else if isnterm1 && isnterm2 then
174           match tag1, tag2 with
175             "START", "START" -> 0
176           | "START", _ -> ~-1
177           | _, "START" -> 1
178           | _ -> compare count2 count1
179         else if isnterm2 then -1
180         else 1) old_new_mapping;
181     let tag_to_id = Hashtbl.create 503 in
182     Array.iteri (fun i s ->
183       Hashtbl.add tag_to_id s i) old_new_mapping;
184     let renum_tags = Array.copy tags in
185     for i = 0 to Array.length tags - 1 do
186       renum_tags.(i) <-
187         h_find  ~msg:"4" tag_to_id (h_find  ~msg:"5" tag_of_id (tags.(i)))
188     done;
189     let r_array = Array.create (Hashtbl.length rules) 0 in
190     let rules_offset = h_find  ~msg:"6" tag_to_id "START" + 1 in
191     let pos_id2 l =
192       let rec loop i l =
193         match l with
194           [] -> assert false
195         | Node(tag, children) :: ll ->
196           if tag <> "y0" && tag <> "y1" then
197             tag, i
198           else loop (i+1) ll
199       in
200       loop 1 l
201     in
202     Hashtbl.iter (fun lhs rhs ->
203       let Node( head, args ) = lhs in
204       let Node( tag1, params) = rhs in
205       let tag2, pos2 = pos_id2 params in
206       let id1 = h_find ~msg:"7" tag_to_id tag1
207       and id2 = h_find ~msg:"8" tag_to_id tag2 in
208       let conf =
209         if List.length args = 0 then 0
210         else
211           if List.length args = 1 then
212           if List.length params = 1 then 1
213           else if pos2 = 1 then 2
214           else 3
215           else (* 2 parameters *)
216             if List.length params = 1 then 4
217             else if pos2 = 1 then 5
218             else 6
219       in
220       let rule_ = id2 lsl 27 in
221       let rule_ = (rule_ lor id1) lsl 3 in
222       let rule_ = rule_ lor conf in
223       r_array.((h_find  ~msg:"9" tag_to_id head) - rules_offset ) <- rule_
224     ) rules;
225     (*let l = Array.length renum_tags in *)
226     (*let tag32 = Array32.create l 0 in
227     for i = 0 to l - 1 do
228       Array32.set tag32 i (renum_tags.(i) land 0x7ffffff);
229     done; *)
230     (* Remove the non-terminal names from the hash tables *)
231     let tag_to_id2 = Hashtbl.create 31 in
232     Hashtbl.iter (fun s i -> if i < rules_offset then Hashtbl.add tag_to_id2 s i)
233       tag_to_id;
234     { start = bv;
235       tags = renum_tags;
236       rules = r_array;
237       rules_offset = rules_offset;
238       tag_to_id = tag_to_id2;
239       tag_of_id = Array.sub old_new_mapping 0 rules_offset
240     }
241
242 end
243
244 let parse file =
245   let cin = open_in file in
246   let g = Parse.parse cin in
247   close_in cin;
248   g
249
250 let _GRAMMAR_MAGIC = 0xaabbcc
251 let _GRAMMAR_VERSION = 3
252
253 let save g f =
254   let cout = open_out f in
255   let write a = Marshal.to_channel cout a [  ]
256   in
257   write _GRAMMAR_MAGIC;
258   write _GRAMMAR_VERSION;
259   write g.tags;
260   write g.rules;
261   write g.rules_offset;
262   write g.tag_to_id;
263   write g.tag_of_id;
264   flush cout;
265   let fd = Unix.descr_of_out_channel cout in
266   Bp.save g.start fd;
267   close_out cout
268
269 let load f =
270   let cin = open_in f in
271   let read () = Marshal.from_channel cin in
272   if read () != _GRAMMAR_MAGIC then failwith "Invalid grammar file";
273   if read () != _GRAMMAR_VERSION then failwith "Deprecated grammar format";
274   let tags : int array = read () in
275   let rules : int array = read () in
276   let rules_offset : int = read () in
277   let tag_to_id : (string, int) Hashtbl.t = read () in
278   let tag_of_id : string array = read () in
279   let fd = Unix.descr_of_in_channel cin in
280   let pos = pos_in cin in
281   ignore(Unix.lseek fd pos Unix.SEEK_SET);
282   let bp = Bp.load fd in
283   close_in cin;
284   let g = {
285     start = bp;
286     tags = tags;
287     rules = rules;
288     rules_offset = rules_offset;
289     tag_to_id = tag_to_id;
290     tag_of_id = tag_of_id;
291   } in
292   Printf.eprintf "Grammar size:%i kb\n%!"
293     ((Ocaml.size_b g  + Bp.alloc_stats ())/1024);
294   g
295
296
297 type node = [ `Start ] Node.t
298
299 type n_type = [ `NonTerminal ]
300 type t_type = [ `Terminal ]
301 type r_type = [ `Rule ]
302 type any_type = [ n_type | t_type ]
303 type rhs = [ r_type ] Node.t
304
305 type n_symbol = n_type Node.t
306 type t_symbol = t_type Node.t
307 type tn_symbol = [ any_type ] Node.t
308
309
310 type 'a partial =
311   | Cache of 'a
312   | Leaf of int*int * StateSet.t array * node
313   | Node0 of tn_symbol (* No parameters *)
314   | Node1 of tn_symbol * 'a partial
315   | Node2 of tn_symbol * 'a partial * 'a partial
316
317
318 let is_nil  (t : t_symbol) =
319   (Node.to_int t) == 4
320
321 let nil_symbol : t_symbol =
322   (Node.of_int 4)
323
324 let translate_tag _ t  = if t == 4 then ~-1 else t
325 let to_string t tag =
326   if tag < Array.length t.tag_of_id then t.tag_of_id.(Tag.to_int tag)
327   else "<!INVALID TAG!>"
328
329 let register_tag t s =
330   try Hashtbl.find t.tag_to_id s with
331     Not_found -> 4
332
333 let tag_operations t = {
334   Tag.tag = (fun s -> register_tag t s);
335   Tag.to_string = (fun s -> to_string t s);
336   Tag.translate = (fun s -> translate_tag t s);
337 }
338
339 let start_root : node = Node.of_int 0
340 let start_tag g (idx : node) : [<any_type] Node.t =
341   Node.of_int (g.tags.(Bp.preorder_rank g.start (Node.to_int idx)))
342
343 ;;
344
345 let start_first_child t (idx : node) =
346   Node.of_int (Bp.first_child t.start (Node.to_int idx))
347
348 let start_next_sibling t (idx : node) =
349   Node.of_int (Bp.next_sibling t.start (Node.to_int idx))
350
351 let is_non_terminal t (n : [< any_type ] Node.t) =
352   let n = Node.to_int n in
353   n >= t.rules_offset
354
355 let is_terminal t (n : [< any_type ] Node.t) = not(is_non_terminal t n)
356
357 external terminal : [< any_type ] Node.t -> t_symbol = "%identity"
358 external non_terminal : [< any_type ] Node.t -> n_symbol = "%identity"
359
360
361 let tag (n : t_symbol) : Tag.t = Obj.magic n
362
363 let get_rule g (r : n_symbol) : rhs =
364   Node.of_int (g.rules.((Node.to_int r) - g.rules_offset))
365
366 let get_id1 (r : rhs) : tn_symbol =
367   Node.of_int(((Node.to_int r) lsr 3) land 0x7ffffff)
368
369 let get_id2 (r : rhs) : tn_symbol =
370   Node.of_int((Node.to_int r) lsr 30)
371
372 type conf = | C0 (* B(C) *)
373             | C1 (* B(C(y0)) *)
374             | C2 (* B(C, y0) *)
375             | C3 (* B(y0, C) *)
376             | C4 (* B(C(y0, y1)) *)
377             | C5 (* B(C(y0), y1) *)
378             | C6 (* B(y0, C(y1)) *)
379
380 let get_conf (r : rhs) : conf =
381   (Obj.magic ((Node.to_int r) land 0b111))
382
383
384 let get_rank (r : rhs) : int =
385   match get_conf r with
386   | C0 -> 0
387   | C1 | C2 | C3 -> 1
388   | C4 | C5 | C6 -> 2
389
390 let get_id1_rank (r : rhs) : int =
391   match get_conf r with
392   | C0 | C1 | C4 -> 1
393   | _ -> 2
394
395 let get_id2_pos (r : rhs) : int =
396   match get_conf r with
397   | C0 | C1 |C2 | C4 | C5 -> 1
398   | _ -> 2
399
400 let get_id2_rank (r : rhs) : int =
401   match get_conf r with
402   | C0 | C2 | C3 -> 0
403   | C1 | C5 | C6 -> 1
404   | C4 -> 2
405
406 let is_attribute g tag =
407   tag > 4 && (to_string g tag).[0] == '2'
408
409 let dummy_param : 'a partial = Leaf (~-1,~-1, [||], Node.nil)
410
411 (*
412 let rec start_skip g idx count =
413   if idx < Node.null then count else
414     let symbol = start_tag g idx in
415     if is_terminal g symbol then
416       let symbol = terminal symbol in
417       if symbol == nil_symbol then count else
418         let count = count+1 in
419         let fs = start_first_child g idx in
420         let countl = start_skip g fs count in
421         start_skip g fs countl
422     else
423       let nt = non_terminal symbol in
424       let rhs = get_rule g nt in
425       let nparam = get_rank rhs in
426       match nparam with
427       | 0 -> rule_skip g nt dummy_param dummy_param count
428       | 1 -> rule_skip g nt (Leaf(0,StateSet.empty, Node.nil,start_first_child g idx)) dummy_param count
429       | 2 ->
430         let fc = start_first_child g idx in
431         let ns = start_next_sibling g fc in
432         rule_skip g nt (Leaf (0,[||],fc)) (Leaf (1,[||],ns)) count
433       | _ -> assert false
434
435 and rule_skip g t y0 y1 count =
436   let rhs = get_rule g t in
437   let id1 = get_id1 rhs in
438   let id2 = get_id2 rhs in
439   let conf = get_conf rhs in
440   if is_non_terminal g id1 then
441     let id1 = non_terminal id1 in
442     match conf with
443     | C0 ->rule_skip g id1 (Node0 id2) dummy_param count
444     | C1 -> rule_skip g id1 (Node1(id2,y0)) dummy_param count
445     | C2 -> rule_skip g id1 (Node0 id2) y0 count
446     | C3 -> rule_skip g id1 y0 (Node0 id2) count
447     | C4 -> rule_skip g id1 (Node2(id2, y0, y1)) dummy_param count
448     | C5 -> rule_skip g id1 (Node1(id2, y0)) y1 count
449     | C6 -> rule_skip g id1 y0 (Node1(id2, y1)) count
450   else
451     let id1 = terminal id1 in
452     match conf with
453     | C0 | C1 -> assert false
454     | C2  -> terminal_skip g id1 (Node0 id2) y0 count
455     | C3  -> terminal_skip g id1 y0 (Node0 id2) count
456     | C4  -> assert false
457     | C5  -> terminal_skip g id1 (Node1(id2, y0)) y1 count
458     | C6  -> terminal_skip g id1 y0 (Node1(id2, y1)) count
459
460 and terminal_skip g (symbol : t_symbol) y0 y1 count =
461   if symbol == nil_symbol then count else
462     let count = count + 1 in
463     let countl = partial_skip g y0 count in
464     partial_skip g y1 countl
465
466 and partial_skip g l count =
467   match l with
468   | Cache _ -> assert false
469   | Leaf (_,_,_, id) -> start_skip g id count
470   | Node0 id ->
471     if (terminal id) == nil_symbol then count
472     else
473       rule_skip g (non_terminal id) dummy_param dummy_param count
474
475   | Node1 (id, y0) ->
476     rule_skip g (non_terminal id) y0 dummy_param count
477
478   | Node2 (id, y0, y1) ->
479
480     if is_terminal g id then
481       terminal_skip g (terminal id) y0 y1 count
482     else
483       rule_skip g (non_terminal id) y0 y1 count
484
485
486 let dispatch_param0 conf id2 y0 y1 =
487   match conf with
488   | C0 -> Node0 id2
489   | C1 -> Node1(id2,y0)
490   | C2 -> Node0 id2
491   | C3 -> Node0 id2
492   | C4 -> Node2(id2, y0, y1)
493   | C5 -> Node1(id2, y0)
494   | C6 -> y0
495
496 let dispatch_param1 conf id2 y0 y1 =
497   match conf with
498   | C0 -> dummy_param
499   | C1 -> dummy_param
500   | C2 -> y0
501   | C3 -> Node0 id2
502   | C4 -> dummy_param
503   | C5 -> y1
504   | C6 -> Node1(id2, y1)
505
506 *)