Add function to cast tags to integers.
[SXSI/xpathcomp.git] / src / grammar2.ml
1 type t = {
2   start : Bp.t;
3   tags : int array;
4   rules : int array;
5   rules_offset : int;
6   tag_to_id : (string, int) Hashtbl.t;
7   tag_of_id : string array
8 }
9
10
11
12 module Parse =
13 struct
14
15   let buffer = Buffer.create 512
16
17   let parse_tree cin open_tag close_tag =
18     let rec loop () =
19       let c = input_char cin in
20       match c with
21         '\n'| '>' -> ()
22       | ' ' | ',' | '-' -> loop ()
23       | 'a'..'z' | 'B'..'Z' | '0'..'9' | '_' ->
24         Buffer.clear buffer;
25         Buffer.add_char buffer c;
26         loop_tag false
27
28       | 'A' ->  Buffer.clear buffer;
29         Buffer.add_char buffer c;
30         loop_tag true
31       | ')' -> close_tag (); loop ()
32       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
33
34     and loop_tag t =
35       let c = input_char cin in
36       match c with
37       | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' ->
38         Buffer.add_char buffer c;
39         loop_tag t
40       | '(' -> let s = Buffer.contents buffer in
41                open_tag s t;
42                Buffer.clear buffer;
43                loop ()
44       | ' ' -> loop_tag t
45       | ',' | '-'  -> let s = Buffer.contents buffer in
46                       open_tag s t;
47                       close_tag ();
48                       Buffer.clear buffer;
49                       loop ()
50       | ')' -> let s = Buffer.contents buffer in
51                open_tag s t;
52                Buffer.clear buffer;
53                close_tag ();
54                close_tag ();
55                loop ()
56       | _ -> failwith ("Invalid character: " ^ (String.make 1 c))
57     in
58     loop ()
59
60
61   let tag_info = Hashtbl.create 1023
62   let tag_of_id  = Hashtbl.create 1023
63   let current_id = ref 4
64   let init() =
65     Hashtbl.clear tag_info;
66     Hashtbl.clear tag_of_id;
67     current_id := 4;
68     Hashtbl.add tag_info "_ROOT" (0, ~-1, false);
69     Hashtbl.add tag_info "_A" (1, ~-1, false);
70     Hashtbl.add tag_info "_T" (2, ~-1, false);
71     Hashtbl.add tag_info "_AT" (3, ~-1, false);
72     Hashtbl.add tag_info "_"  (4, ~-1, false);
73     Hashtbl.add tag_of_id 0 "_ROOT";
74     Hashtbl.add tag_of_id 1 "_A";
75     Hashtbl.add tag_of_id 2 "_T";
76     Hashtbl.add tag_of_id 3 "_AT";
77     Hashtbl.add tag_of_id 4 "_"
78
79
80   let add_tag s nterm =
81     let id, count, nterm =
82       try Hashtbl.find tag_info s with
83         Not_found ->
84           incr current_id;
85           let id = !current_id in
86           Hashtbl.add tag_of_id id s;
87           (!current_id, ~-1, nterm || s = "START")
88     in
89     let r = id, count+1, nterm in
90     Hashtbl.replace tag_info s r;
91     r
92
93
94   type tree = Node of string * tree list
95
96   let parse_small_tree cin =
97     let stack = ref [ Node("", []) ] in
98     let open_tag s isnterm =
99       if s <> "y0" && s <> "y1" then ignore(add_tag s isnterm);
100       stack := Node(s, []) :: !stack
101     in
102     let close_tag () =
103       match !stack with
104         Node(t1, l1) :: Node(t2, l2) :: r ->
105           stack := Node(t2, Node(t1, List.rev l1)::l2) :: r
106       | _ -> assert false
107     in
108     parse_tree cin open_tag close_tag;
109     match !stack with
110       [ Node(_, [ l ]) ] -> l
111     | _ -> raise End_of_file
112
113   let parse_big_tree cin =
114     let bv = Bp.bitmap_create () in
115     let tags = IntArray.create () in
116     let open_tag s isnterm =
117       let id, _, _ = add_tag s isnterm in
118       Bp.bitmap_push_back bv 1;
119       IntArray.push_back tags id
120     in
121     let close_tag () =
122       Bp.bitmap_push_back bv 0
123     in
124     parse_tree cin open_tag close_tag;
125     Bp.create bv, IntArray.pack tags
126
127   let eat_char cin = ignore (input_char cin)
128
129   let h_find ?(msg="") h i =
130     try
131       Hashtbl.find h i
132     with
133       Not_found ->
134         let r = Obj.repr i in
135         if Obj.is_int r then Printf.eprintf "Not_found (%s): %i\n%!" msg (Obj.magic i);
136         if Obj.tag r = Obj.string_tag then Printf.eprintf "Not_found (%s): %s\n%!" msg (Obj.magic i);
137         raise Not_found
138   ;;
139
140   let parse cin =
141     let rules = Hashtbl.create 1023 in
142     init ();
143     (* START *)
144     ignore (parse_small_tree cin);
145     (* > *)
146     (* ignore (input_char cin); *)
147     let bv, tags = parse_big_tree cin in
148     let () =
149       try
150         while true do
151           let lhs = parse_small_tree cin in
152           let rhs = parse_small_tree cin in
153           Hashtbl.add rules lhs rhs
154         done;
155       with End_of_file -> ()
156     in
157     (* First, re-order the tags *)
158     let old_new_mapping =
159       Array.init (Hashtbl.length tag_of_id)
160         (fun i -> h_find ~msg:"1" tag_of_id i)
161     in
162     Array.fast_sort (fun tag1 tag2 ->
163       let t1, count1, isnterm1 =
164         h_find  ~msg:"2" tag_info tag1
165       and t2, count2, isnterm2 =
166         h_find  ~msg:"3" tag_info tag2
167       in
168       if t1 <= 4 && t2 <= 4 then compare t1 t2
169       else if t1 <= 4 then -1
170       else if t2 <= 4 then 1
171       else
172         if (not isnterm1) && (not isnterm2) then compare t1 t2
173         else if isnterm1 && isnterm2 then
174           match tag1, tag2 with
175             "START", "START" -> 0
176           | "START", _ -> ~-1
177           | _, "START" -> 1
178           | _ -> compare count2 count1
179         else if isnterm2 then -1
180         else 1) old_new_mapping;
181     let tag_to_id = Hashtbl.create 503 in
182     Array.iteri (fun i s ->
183       Hashtbl.add tag_to_id s i) old_new_mapping;
184     let renum_tags = Array.copy tags in
185     for i = 0 to Array.length tags - 1 do
186       renum_tags.(i) <-
187         h_find  ~msg:"4" tag_to_id (h_find  ~msg:"5" tag_of_id (tags.(i)))
188     done;
189     let r_array = Array.create (Hashtbl.length rules) 0 in
190     let rules_offset = h_find  ~msg:"6" tag_to_id "START" + 1 in
191     let pos_id2 l =
192       let rec loop i l =
193         match l with
194           [] -> assert false
195         | Node(tag, children) :: ll ->
196           if tag <> "y0" && tag <> "y1" then
197             tag, i
198           else loop (i+1) ll
199       in
200       loop 1 l
201     in
202     Hashtbl.iter (fun lhs rhs ->
203       let Node( head, _ ) = lhs in
204       let Node( tag1, params) = rhs in
205       let tag2, pos2 = pos_id2 params in
206       let id1 = h_find ~msg:"7" tag_to_id tag1
207       and id2 = h_find ~msg:"8" tag_to_id tag2
208       in
209       let rule_ = id2 lsl 27 in
210       let rule_ = (rule_ lor id1) lsl 2 in
211       let rule_ = (rule_ lor pos2) lsl 2 in
212       let rule_ = rule_ lor (List.length params) in
213       r_array.((h_find  ~msg:"9" tag_to_id head) - rules_offset ) <- rule_
214     ) rules;
215     (*let l = Array.length renum_tags in *)
216     (*let tag32 = Array32.create l 0 in
217     for i = 0 to l - 1 do
218       Array32.set tag32 i (renum_tags.(i) land 0x7ffffff);
219     done; *)
220     (* Remove the non-terminal names from the hash tables *)
221     let tag_to_id2 = Hashtbl.create 31 in
222     Hashtbl.iter (fun s i -> if i < rules_offset then Hashtbl.add tag_to_id2 s i)
223       tag_to_id;
224     { start = bv;
225       tags = renum_tags;
226       rules = r_array;
227       rules_offset = rules_offset;
228       tag_to_id = tag_to_id2;
229       tag_of_id = Array.sub old_new_mapping 0 rules_offset
230     }
231
232 end
233
234 let parse file =
235   let cin = open_in file in
236   let g = Parse.parse cin in
237   close_in cin;
238   g
239
240 let _GRAMMAR_MAGIC = 0xaabbcc
241 let _GRAMMAR_VERSION = 2
242
243 let save g f =
244   let cout = open_out f in
245   let write a = Marshal.to_channel cout a [  ]
246   in
247   write _GRAMMAR_MAGIC;
248   write _GRAMMAR_VERSION;
249   write g.tags;
250   write g.rules;
251   write g.rules_offset;
252   write g.tag_to_id;
253   write g.tag_of_id;
254   flush cout;
255   let fd = Unix.descr_of_out_channel cout in
256   Bp.save g.start fd;
257   close_out cout
258
259 let load f =
260   let cin = open_in f in
261   let pr_pos () =
262     Printf.eprintf "Position: %i kiB\n" (pos_in cin / 1024)
263   in
264   let read () = Marshal.from_channel cin in
265   if read () != _GRAMMAR_MAGIC then failwith "Invalid grammar file";
266   if read () != _GRAMMAR_VERSION then failwith "Deprecated grammar format";
267   pr_pos();
268   let tags : int array = read () in
269   pr_pos();
270   let rules : int array = read () in
271   pr_pos();
272   let rules_offset : int = read () in
273   pr_pos();
274   let tag_to_id : (string, int) Hashtbl.t = read () in
275   pr_pos();
276   let tag_of_id : string array = read () in
277   pr_pos();
278   let fd = Unix.descr_of_in_channel cin in
279   let pos = pos_in cin in
280   ignore(Unix.lseek fd pos Unix.SEEK_SET);
281   let bp = Bp.load fd in
282   close_in cin;
283   {
284     start = bp;
285     tags = tags;
286     rules = rules;
287     rules_offset = rules_offset;
288     tag_to_id = tag_to_id;
289     tag_of_id = tag_of_id;
290   }
291
292
293 type node = [ `Start ] Node.t
294
295 type n_type = [ `NonTerminal ]
296 type t_type = [ `Terminal ]
297 type r_type = [ `Rule ]
298 type any_type = [ n_type | t_type ]
299 type rhs = [ r_type ] Node.t
300
301 type n_symbol = n_type Node.t
302 type t_symbol = t_type Node.t
303 type tn_symbol = [ any_type ] Node.t
304
305
306 let is_nil  (t : t_symbol) =
307   (Node.to_int t) == 4
308
309 let nil_symbol : t_symbol =
310   (Node.of_int 4)
311
312 let translate_tag _ t  = if t == 4 then ~-1 else t
313 let to_string t tag =
314   if tag < Array.length t.tag_of_id then t.tag_of_id.(Tag.to_int tag)
315   else "<!INVALIDTAG!>"
316
317 let register_tag t s =
318   try Hashtbl.find t.tag_to_id s with
319     Not_found -> 4
320
321 let tag_operations t = {
322   Tag.tag = (fun s -> register_tag t s);
323   Tag.to_string = (fun s -> to_string t s);
324   Tag.translate = (fun s -> translate_tag t s);
325 }
326
327 let start_root : node = Node.of_int 0
328 let start_tag t (idx : node) =
329   t.tags.(Bp.preorder_rank t.start (Node.to_int idx))
330
331 let start_first_child t (idx : node) =
332   Bp.first_child t.start (Node.to_int idx)
333
334 let start_next_sibling t (idx : node) =
335   Bp.next_sibling t.start (Node.to_int idx)
336
337 let is_non_terminal t (n : [< any_type ] Node.t) =
338   let n = Node.to_int n in
339   n >= t.rules_offset
340
341 let is_terminal t (n : [< any_type ] Node.t) = not(is_non_terminal t n)
342
343 external terminal : [< any_type ] Node.t -> t_symbol = "%identity"
344 external non_terminal : [< any_type ] Node.t -> t_symbol = "%identity"
345
346
347 let tag (n : t_symbol) : Tag.t = Obj.magic n
348
349 let get_rule g (r : n_symbol) : rhs =
350   Node.of_int (g.rules.((Node.to_int r) - g.rules_offset))
351
352 let get_id1 (r : rhs) : tn_symbol =
353   Node.of_int(
354     ((Node.to_int r) lsr 4) land 0x7ffffff)
355
356 let get_id2 (r : rhs) : tn_symbol =
357   Node.of_int((Node.to_int r) lsr 31)
358
359 let get_param_pos (r : rhs) : int =
360   ((Node.to_int r) lsr 2) land 0b11
361
362 let num_params (r : rhs) : int =
363   (Node.to_int r) land 0b11
364