Exception less mainloop in ata.ml
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S = 
10 sig
11   include Set.S
12   type data
13   val intersect : t -> t -> bool
14   val is_singleton : t -> bool
15   val mem_union : t -> t -> t
16   val hash : t -> int
17   val uid : t -> Uid.t
18   val uncons : t -> elt*t
19   val from_list : elt list -> t 
20   val make : data -> t
21   val node : t -> data
22     
23   val with_id : Uid.t -> t
24 end
25
26 module Make ( H : Hcons.SA ) : S with type elt = H.t =
27 struct
28   type elt = H.t
29   type 'a node =
30     | Empty
31     | Leaf of elt
32     | Branch of int * int * 'a * 'a
33
34   module rec HNode : Hcons.S with type data = Node.t = Hcons.Make (Node)
35   and Node : Hashtbl.HashedType  with type t = HNode.t node =
36   struct 
37     type t =  HNode.t node
38     let equal x y = 
39       match x,y with
40         | Empty,Empty -> true
41         | Leaf k1, Leaf k2 ->  k1 == k2
42         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
43             b1 == b2 && i1 == i2 &&
44               (HNode.equal l1 l2) &&
45               (HNode.equal r1 r2) 
46         | _ -> false
47     let hash = function 
48       | Empty -> 0
49       | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
50       | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.HNode.id, Uid.to_int r.HNode.id)
51   end
52  ;;
53                              
54   type t = HNode.t
55   type data = t node
56   let hash = HNode.hash 
57   let uid = HNode.uid
58   let make = HNode.make
59   let node _ = failwith "node"
60   let empty = HNode.make Empty
61     
62   let is_empty s = (HNode.node s) == Empty
63        
64   let branch p m l r = HNode.make (Branch(p,m,l,r))
65
66   let leaf k = HNode.make (Leaf k)
67
68   (* To enforce the invariant that a branch contains two non empty sub-trees *)
69   let branch_ne p m t0 t1 = 
70     if (is_empty t0) then t1
71     else if is_empty t1 then t0 else branch p m t0 t1
72       
73   (********** from here on, only use the smart constructors *************)
74       
75   let zero_bit k m = (k land m) == 0
76     
77   let singleton k = leaf k
78     
79   let is_singleton n = 
80     match HNode.node n with Leaf _ -> true
81       | _ -> false
82           
83   let mem (k:elt) n = 
84     let kid = Uid.to_int (H.uid k) in
85     let rec loop n = match HNode.node n with
86       | Empty -> false
87       | Leaf j ->  k == j
88       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
89     in loop n
90          
91   let rec min_elt n = match HNode.node n with
92     | Empty -> raise Not_found
93     | Leaf k -> k
94     | Branch (_,_,s,_) -> min_elt s
95         
96   let rec max_elt n = match HNode.node n with
97     | Empty -> raise Not_found
98     | Leaf k -> k
99     | Branch (_,_,_,t) -> max_elt t
100         
101   let elements s =
102     let rec elements_aux acc n = match HNode.node n with
103       | Empty -> acc
104       | Leaf k -> k :: acc
105       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
106     in
107       elements_aux [] s
108         
109   let mask k m  = (k lor (m-1)) land (lnot m)
110     
111   let naive_highest_bit x = 
112     assert (x < 256);
113     let rec loop i = 
114       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
115     in
116       loop 7
117         
118   let hbit = Array.init 256 naive_highest_bit
119
120
121   let highest_bit x = let n = (x) lsr 24 in 
122   if n != 0 then Array.unsafe_get hbit n lsl 24
123   else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
124   else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
125   else Array.unsafe_get hbit (x)
126
127 IFDEF WORDIZE64
128 THEN
129   let highest_bit64 x =
130     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
131       else highest_bit x
132 END
133
134         
135   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
136     
137   let join p0 t0 p1 t1 =  
138     let m = branching_bit p0 p1  in
139       if zero_bit p0 m then 
140         branch (mask p0 m) m t0 t1
141       else 
142         branch (mask p0 m) m t1 t0
143           
144   let match_prefix k p m = (mask k m) == p
145     
146   let add k t =
147     let kid = Uid.to_int (H.uid k) in
148     let rec ins n = match HNode.node n with
149       | Empty -> leaf k
150       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
151       | Branch (p,m,t0,t1)  ->
152           if match_prefix kid p m then
153             if zero_bit kid m then 
154               branch p m (ins t0) t1
155             else
156               branch p m t0 (ins t1)
157           else
158             join kid (leaf k)  p n
159     in
160     ins t
161       
162   let remove k t =
163     let kid = Uid.to_int(H.uid k) in
164     let rec rmv n = match HNode.node n with
165       | Empty -> empty
166       | Leaf j  -> if  k == j then empty else n
167       | Branch (p,m,t0,t1) -> 
168           if match_prefix kid p m then
169             if zero_bit kid m then
170               branch_ne p m (rmv t0) t1
171             else
172               branch_ne p m t0 (rmv t1)
173           else
174             n
175     in
176     rmv t
177       
178   (* should run in O(1) thanks to Hash consing *)
179
180   let equal a b = HNode.equal a b 
181
182   let compare a b =  (Uid.to_int (HNode.uid a)) - (Uid.to_int (HNode.uid b))
183
184   let rec merge s t = 
185     if (equal s t) (* This is cheap thanks to hash-consing *)
186     then s
187     else
188     match HNode.node s, HNode.node t with
189       | Empty, _  -> t
190       | _, Empty  -> s
191       | Leaf k, _ -> add k t
192       | _, Leaf k -> add k s
193       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
194           if m == n && match_prefix q p m then
195             branch p  m  (merge s0 t0) (merge s1 t1)
196           else if m > n && match_prefix q p m then
197             if zero_bit q m then 
198               branch p m (merge s0 t) s1
199             else 
200               branch p m s0 (merge s1 t)
201           else if m < n && match_prefix p q n then     
202             if zero_bit p n then
203               branch q n (merge s t0) t1
204             else
205               branch q n t0 (merge s t1)
206           else
207             (* The prefixes disagree. *)
208             join p s q t
209                
210         
211                
212                
213   let rec subset s1 s2 = (equal s1 s2) ||
214     match (HNode.node s1,HNode.node s2) with
215       | Empty, _ -> true
216       | _, Empty -> false
217       | Leaf k1, _ -> mem k1 s2
218       | Branch _, Leaf _ -> false
219       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
220           if m1 == m2 && p1 == p2 then
221             subset l1 l2 && subset r1 r2
222           else if m1 < m2 && match_prefix p1 p2 m2 then
223             if zero_bit p1 m2 then 
224               subset l1 l2 && subset r1 l2
225             else 
226               subset l1 r2 && subset r1 r2
227           else
228             false
229
230               
231   let union s1 s2 = merge s1 s2
232     (* Todo replace with e Memo Module *)
233   module MemUnion = Hashtbl.Make(
234     struct 
235       type set = t 
236       type t = set*set 
237       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
238       let equal a b = equal a b || equal b a
239       let hash (x,y) =   (* commutative hash *)
240         let x = HNode.hash x 
241         and y = HNode.hash y 
242         in
243           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
244     end)
245   let h_mem = MemUnion.create MED_H_SIZE
246
247   let mem_union s1 s2 = 
248     try  MemUnion.find h_mem (s1,s2) 
249     with Not_found ->
250           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
251       
252
253   let rec inter s1 s2 = 
254     if equal s1 s2 
255     then s1
256     else
257       match (HNode.node s1,HNode.node s2) with
258         | Empty, _ -> empty
259         | _, Empty -> empty
260         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
261         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
262         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
263             if m1 == m2 && p1 == p2 then 
264               merge (inter l1 l2)  (inter r1 r2)
265             else if m1 > m2 && match_prefix p2 p1 m1 then
266               inter (if zero_bit p2 m1 then l1 else r1) s2
267             else if m1 < m2 && match_prefix p1 p2 m2 then
268               inter s1 (if zero_bit p1 m2 then l2 else r2)
269             else
270               empty
271
272   let rec diff s1 s2 = 
273     if equal s1 s2 
274     then empty
275     else
276       match (HNode.node s1,HNode.node s2) with
277         | Empty, _ -> empty
278         | _, Empty -> s1
279         | Leaf k1, _ -> if mem k1 s2 then empty else s1
280         | _, Leaf k2 -> remove k2 s1
281         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
282             if m1 == m2 && p1 == p2 then
283               merge (diff l1 l2) (diff r1 r2)
284             else if m1 > m2 && match_prefix p2 p1 m1 then
285               if zero_bit p2 m1 then 
286                 merge (diff l1 s2) r1
287               else 
288                 merge l1 (diff r1 s2)
289             else if m1 < m2 && match_prefix p1 p2 m2 then
290               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
291             else
292           s1
293                
294
295 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
296     [exists], [filter], [partition], [choose], [elements]) are
297     implemented as for any other kind of binary trees. *)
298
299 let rec cardinal n = match HNode.node n with
300   | Empty -> 0
301   | Leaf _ -> 1
302   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
303
304 let rec iter f n = match HNode.node n with
305   | Empty -> ()
306   | Leaf k -> f k
307   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
308       
309 let rec fold f s accu = match HNode.node s with
310   | Empty -> accu
311   | Leaf k -> f k accu
312   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
313
314
315 let rec for_all p n = match HNode.node n with
316   | Empty -> true
317   | Leaf k -> p k
318   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
319
320 let rec exists p n = match HNode.node n with
321   | Empty -> false
322   | Leaf k -> p k
323   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
324
325 let rec filter pr n = match HNode.node n with
326   | Empty -> empty
327   | Leaf k -> if pr k then n else empty
328   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
329
330 let partition p s =
331   let rec part (t,f as acc) n = match HNode.node n with
332     | Empty -> acc
333     | Leaf k -> if p k then (add k t, f) else (t, add k f)
334     | Branch (_,_,t0,t1) -> part (part acc t0) t1
335   in
336   part (empty, empty) s
337
338 let rec choose n = match HNode.node n with
339   | Empty -> raise Not_found
340   | Leaf k -> k
341   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
342
343
344 let split x s =
345   let coll k (l, b, r) =
346     if k < x then add k l, b, r
347     else if k > x then l, b, add k r
348     else l, true, r 
349   in
350   fold coll s (empty, false, empty)
351
352 (*s Additional functions w.r.t to [Set.S]. *)
353
354 let rec intersect s1 s2 = (equal s1 s2) ||
355   match (HNode.node s1,HNode.node s2) with
356   | Empty, _ -> false
357   | _, Empty -> false
358   | Leaf k1, _ -> mem k1 s2
359   | _, Leaf k2 -> mem k2 s1
360   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
361       if m1 == m2 && p1 == p2 then
362         intersect l1 l2 || intersect r1 r2
363       else if m1 < m2 && match_prefix p2 p1 m1 then
364         intersect (if zero_bit p2 m1 then l1 else r1) s2
365       else if m1 > m2 && match_prefix p1 p2 m2 then
366         intersect s1 (if zero_bit p1 m2 then l2 else r2)
367       else
368         false
369
370
371
372 let rec uncons n = match HNode.node n with
373   | Empty -> raise Not_found
374   | Leaf k -> (k,empty)
375   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
376    
377 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
378
379 let with_id = HNode.with_id
380 end
381
382 module Int : sig
383   include S with type elt = int
384   val print : Format.formatter -> t -> unit
385 end
386   = 
387 struct
388   include Make ( struct type t = int 
389                         type data = t
390                         external hash : t -> int = "%identity"
391                         external uid : t -> Uid.t = "%identity"
392                         external equal : t -> t -> bool = "%eq"
393                         external make : t -> int = "%identity"
394                         external node : t -> int = "%identity"
395                         external with_id : Uid.t -> t = "%identity"
396                  end
397                ) 
398   let print ppf s = 
399     Format.pp_print_string ppf "{ ";
400     iter (fun i -> Format.fprintf ppf "%i " i) s;
401     Format.pp_print_string ppf "}";
402     Format.pp_print_flush ppf ()
403  end