Merged -correctxpath branch
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8
9
10 type elt = int
11
12 type t = { id : int;
13            key : int; (* hash *)
14            node : node;
15            }
16 and node = 
17   | Empty
18   | Leaf of int
19   | Branch of int * int * t * t
20
21
22 (* faster if outside of a module *)
23 let hash_node x = match x with 
24   | Empty -> 0
25   | Leaf i -> (i+1) land max_int
26       (* power of 2 +/- 1 are fast ! *)
27   | Branch (b,i,l,r) -> 
28       ((b lsl 1)+ b + i+(i lsl 4) + (l.key lsl 5)-l.key
29        + (r.key lsl 7) - r.key) land max_int
30
31 module Node = 
32   struct
33     type _t = t
34     type t = _t
35     external hash : t -> int = "%field1"
36     let equal x y = 
37       if x.id == y.id || x.key == y.key || x.node == y.node then true
38       else
39       match (x.node,y.node) with
40       | Empty,Empty -> true
41       | Leaf k1, Leaf k2 when k1 == k2 -> true
42       | Branch(p1,m1,l1,r1), Branch(p2,m2,l2,r2) when m1==m2 && p1==p2 && 
43           (l1.id == l2.id) && (r1.id == r2.id) -> true
44       | _ -> false
45   end
46
47 module WH =Weak.Make(Node) 
48 (* struct 
49   include Hashtbl.Make(Node)
50     let merge h v =
51       if mem h v then v
52       else (add h v v;v)
53 end
54 *)
55 let pool = WH.create 4093
56
57 (* Neat trick thanks to Alain Frisch ! *)
58
59 let gen_uid () = Oo.id (object end) 
60
61 let empty = { id = gen_uid ();
62               key = 0;
63               node = Empty }
64
65 let _ = WH.add pool empty
66
67 let is_empty s = s.id==0
68     
69 let rec norm n =
70   let v = { id = gen_uid ();
71             key = hash_node n;
72             node = n } 
73   in
74       WH.merge pool v 
75
76 (*  WH.merge pool *)
77
78 let branch  p m l r  = norm (Branch(p,m,l,r))
79 let leaf k = norm (Leaf k)
80
81 (* To enforce the invariant that a branch contains two non empty sub-trees *)
82 let branch_ne = function
83   | (_,_,e,t) when is_empty e -> t
84   | (_,_,t,e) when is_empty e -> t
85   | (p,m,t0,t1)   -> branch p m t0 t1
86
87 (********** from here on, only use the smart constructors *************)
88
89 let zero_bit k m = (k land m) == 0
90
91 let singleton k = leaf k
92
93 let rec mem k n = match n.node with
94   | Empty -> false
95   | Leaf j -> k == j
96   | Branch (p, _, l, r) -> if k <= p then mem k l else mem k r
97
98 let rec min_elt n = match n.node with
99   | Empty -> raise Not_found
100   | Leaf k -> k
101   | Branch (_,_,s,_) -> min_elt s
102       
103   let rec max_elt n = match n.node with
104     | Empty -> raise Not_found
105     | Leaf k -> k
106     | Branch (_,_,_,t) -> max_elt t
107
108   let elements s =
109     let rec elements_aux acc n = match n.node with
110       | Empty -> acc
111       | Leaf k -> k :: acc
112       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
113     in
114     elements_aux [] s
115
116   let mask k m  = (k lor (m-1)) land (lnot m)
117
118   let naive_highest_bit x = 
119     assert (x < 256);
120     let rec loop i = 
121       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
122     in
123     loop 7
124
125   let hbit = Array.init 256 naive_highest_bit
126   
127   let highest_bit_32 x =
128     let n = x lsr 24 in if n != 0 then Array.unsafe_get hbit n lsl 24
129     else let n = x lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
130     else let n = x lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
131     else Array.unsafe_get hbit x
132
133   let highest_bit_64 x =
134     let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
135     else highest_bit_32 x
136
137   let highest_bit = match Sys.word_size with
138     | 32 -> highest_bit_32
139     | 64 -> highest_bit_64
140     | _ -> assert false
141
142   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
143
144   let join p0 t0 p1 t1 =  
145     let m = branching_bit p0 p1  in
146     if zero_bit p0 m then 
147       branch (mask p0 m)  m t0 t1
148     else 
149       branch (mask p0 m) m t1 t0
150     
151   let match_prefix k p m = (mask k m) == p
152
153   let add k t =
154     let rec ins n = match n.node with
155       | Empty -> leaf k
156       | Leaf j ->  if j == k then n else join k (leaf k) j n
157       | Branch (p,m,t0,t1)  ->
158           if match_prefix k p m then
159             if zero_bit k m then 
160               branch p m (ins t0) t1
161             else
162               branch p m t0 (ins t1)
163           else
164             join k  (leaf k)  p n
165     in
166     ins t
167       
168   let remove k t =
169     let rec rmv n = match n.node with
170       | Empty -> empty
171       | Leaf j  -> if k == j then empty else n
172       | Branch (p,m,t0,t1) -> 
173           if match_prefix k p m then
174             if zero_bit k m then
175               branch_ne (p, m, rmv t0, t1)
176             else
177               branch_ne (p, m, t0, rmv t1)
178           else
179             n
180     in
181     rmv t
182       
183   (* should run in O(1) thanks to Hash consing *)
184
185   let equal a b = a==b || a.id == b.id
186
187   let compare a b = if a == b then 0 else a.id - b.id
188
189
190   let rec merge s t = 
191     if (equal s t) (* This is cheap thanks to hash-consing *)
192     then s
193     else
194       match s.node,t.node with
195         | Empty, _  -> t
196         | _, Empty  -> s
197         | Leaf k, _ -> add k t
198         | _, Leaf k -> add k s
199         | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
200             if m == n && match_prefix q p m then
201               branch p  m  (merge s0 t0) (merge s1 t1)
202             else if m > n && match_prefix q p m then
203               if zero_bit q m then 
204                 branch p m (merge s0 t) s1
205               else 
206                 branch p m s0 (merge s1 t)
207             else if m < n && match_prefix p q n then     
208               if zero_bit p n then
209                 branch q n (merge s t0) t1
210               else
211                 branch q n t0 (merge s t1)
212             else
213               (* The prefixes disagree. *)
214               join p s q t
215             
216
217
218   let rec subset s1 s2 = (equal s1 s2) ||
219     match (s1.node,s2.node) with
220       | Empty, _ -> true
221       | _, Empty -> false
222       | Leaf k1, _ -> mem k1 s2
223       | Branch _, Leaf _ -> false
224       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
225           if m1 == m2 && p1 == p2 then
226             subset l1 l2 && subset r1 r2
227           else if m1 < m2 && match_prefix p1 p2 m2 then
228             if zero_bit p1 m2 then 
229               subset l1 l2 && subset r1 l2
230             else 
231               subset l1 r2 && subset r1 r2
232           else
233             false
234
235   let union s t = 
236       merge s t
237               
238   let rec inter s1 s2 = 
239     if equal s1 s2 
240     then s1
241     else
242       match (s1.node,s2.node) with
243         | Empty, _ -> empty
244         | _, Empty -> empty
245         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
246         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
247         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
248             if m1 == m2 && p1 == p2 then 
249               merge (inter l1 l2)  (inter r1 r2)
250             else if m1 > m2 && match_prefix p2 p1 m1 then
251               inter (if zero_bit p2 m1 then l1 else r1) s2
252             else if m1 < m2 && match_prefix p1 p2 m2 then
253               inter s1 (if zero_bit p1 m2 then l2 else r2)
254             else
255               empty
256
257   let rec diff s1 s2 = 
258     if equal s1 s2 
259     then empty
260     else
261       match (s1.node,s2.node) with
262         | Empty, _ -> empty
263         | _, Empty -> s1
264         | Leaf k1, _ -> if mem k1 s2 then empty else s1
265         | _, Leaf k2 -> remove k2 s1
266         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
267             if m1 == m2 && p1 == p2 then
268               merge (diff l1 l2) (diff r1 r2)
269             else if m1 > m2 && match_prefix p2 p1 m1 then
270               if zero_bit p2 m1 then 
271                 merge (diff l1 s2) r1
272               else 
273                 merge l1 (diff r1 s2)
274             else if m1 < m2 && match_prefix p1 p2 m2 then
275               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
276             else
277           s1
278             
279             
280
281
282 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
283     [exists], [filter], [partition], [choose], [elements]) are
284     implemented as for any other kind of binary trees. *)
285
286 let rec cardinal n = match n.node with
287   | Empty -> 0
288   | Leaf _ -> 1
289   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
290
291 let rec iter f n = match n.node with
292   | Empty -> ()
293   | Leaf k -> f k
294   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
295       
296 let rec fold f s accu = match s.node with
297   | Empty -> accu
298   | Leaf k -> f k accu
299   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
300
301 let rec for_all p n = match n.node with
302   | Empty -> true
303   | Leaf k -> p k
304   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
305
306 let rec exists p n = match n.node with
307   | Empty -> false
308   | Leaf k -> p k
309   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
310
311 let rec filter pr n = match n.node with
312   | Empty -> empty
313   | Leaf k -> if pr k then n else empty
314   | Branch (p,m,t0,t1) -> branch_ne (p, m, filter pr t0, filter pr t1)
315
316 let partition p s =
317   let rec part (t,f as acc) n = match n.node with
318     | Empty -> acc
319     | Leaf k -> if p k then (add k t, f) else (t, add k f)
320     | Branch (_,_,t0,t1) -> part (part acc t0) t1
321   in
322   part (empty, empty) s
323
324 let rec choose n = match n.node with
325   | Empty -> raise Not_found
326   | Leaf k -> k
327   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
328
329
330 let split x s =
331   let coll k (l, b, r) =
332     if k < x then add k l, b, r
333     else if k > x then l, b, add k r
334     else l, true, r 
335   in
336   fold coll s (empty, false, empty)
337
338
339
340 let rec dump n =
341   Printf.eprintf "{ id = %i; key = %i ; node=" n.id n.key;
342   match n.node with
343     | Empty -> Printf.eprintf "Empty; }\n"
344     | Leaf k -> Printf.eprintf "Leaf %i; }\n" k
345     | Branch (p,m,l,r) -> 
346         Printf.eprintf "Branch(%i,%i,id=%i,id=%i); }\n"
347           p m l.id r.id;
348         dump l;
349         dump r
350
351 (*i*)
352 let make l = List.fold_left (fun acc e -> add e acc ) empty l
353 (*i*)
354
355 (*s Additional functions w.r.t to [Set.S]. *)
356
357 let rec intersect s1 s2 = (equal s1 s2) ||
358   match (s1.node,s2.node) with
359   | Empty, _ -> false
360   | _, Empty -> false
361   | Leaf k1, _ -> mem k1 s2
362   | _, Leaf k2 -> mem k2 s1
363   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
364       if m1 == m2 && p1 == p2 then
365         intersect l1 l2 || intersect r1 r2
366       else if m1 < m2 && match_prefix p2 p1 m1 then
367         intersect (if zero_bit p2 m1 then l1 else r1) s2
368       else if m1 > m2 && match_prefix p1 p2 m2 then
369         intersect s1 (if zero_bit p1 m2 then l2 else r2)
370       else
371         false
372
373
374 let hash s = s.key
375
376 let from_list l = List.fold_left (fun acc i -> add i acc) empty l
377
378 type int_vector
379
380 external int_vector_alloc : int -> int_vector = "caml_int_vector_alloc"
381 external int_vector_set : int_vector -> int -> int -> unit = "caml_int_vector_set"
382 external int_vector_length : int_vector -> int  = "caml_int_vector_length"
383 external int_vector_empty : unit -> int_vector = "caml_int_vector_empty"
384
385 let empty_vector = int_vector_empty ()
386
387 let to_int_vector_ext s =
388   let l = cardinal s in
389   let v = int_vector_alloc l in
390   let i = ref 0 in
391     iter (fun e -> int_vector_set v !i e; incr i) s;
392     v
393
394 let hash_vectors = Hashtbl.create 4097
395
396 let to_int_vector s =
397   try 
398     Hashtbl.find hash_vectors s.key
399   with
400       Not_found -> 
401         let v = to_int_vector_ext s in
402           Hashtbl.add hash_vectors s.key v;
403           v
404
405