Removed testing cruft
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8
9
10 type elt = int
11
12 type t = { id : int;
13            key : int; (* hash *)
14            node : node }
15 and node = 
16   | Empty
17   | Leaf of int
18   | Branch of int * int * t * t
19
20 module Node = 
21   struct
22     type _t = t
23     type t = _t
24     let hash x = x.key       
25     let hash_node = function 
26          | Empty -> 0
27          | Leaf i -> i+1
28              (* power of 2 +/- 1 are fast ! *)
29          | Branch (b,i,l,r) -> 
30              (b lsl 1)+ b + i+(i lsl 4) + (l.key lsl 5)-l.key
31              + (r.key lsl 7) - r.key
32     let hash_node x = (hash_node x) land max_int
33     let equal x y = match (x.node,y.node) with
34       | Empty,Empty -> true
35       | Leaf k1, Leaf k2 when k1 == k2 -> true
36       | Branch(p1,m1,l1,r1), Branch(p2,m2,l2,r2) when m1==m2 && p1==p2 && 
37           (l1.id == l2.id) && (r1.id == r2.id) -> true
38       | _ -> false
39   end
40
41 module WH =Weak.Make(Node) 
42 (* struct 
43   include Hashtbl.Make(Node)
44     let merge h v =
45       if mem h v then v
46       else (add h v v;v)
47 end
48 *)
49 let pool = WH.create 4093
50
51 (* Neat trick thanks to Alain Frisch ! *)
52
53 let gen_uid () = Oo.id (object end) 
54
55 let empty = { id = gen_uid ();
56               key = 0;
57               node = Empty }
58
59 let _ = WH.add pool empty
60
61 let is_empty s = s.id==0
62     
63 let rec norm n =
64   let v = { id = gen_uid ();
65             key = Node.hash_node n;
66             node = n } 
67   in
68       WH.merge pool v 
69
70 (*  WH.merge pool *)
71
72 let branch  p m l r  = norm (Branch(p,m,l,r))
73 let leaf k = norm (Leaf k)
74
75 (* To enforce the invariant that a branch contains two non empty sub-trees *)
76 let branch_ne = function
77   | (_,_,e,t) when is_empty e -> t
78   | (_,_,t,e) when is_empty e -> t
79   | (p,m,t0,t1)   -> branch p m t0 t1
80
81 (********** from here on, only use the smart constructors *************)
82
83 let zero_bit k m = (k land m) == 0
84
85 let singleton k = if k < 0 then failwith "singleton" else leaf k
86
87 let rec mem k n = match n.node with
88   | Empty -> false
89   | Leaf j -> k == j
90   | Branch (p, _, l, r) -> if k <= p then mem k l else mem k r
91
92 let rec min_elt n = match n.node with
93   | Empty -> raise Not_found
94   | Leaf k -> k
95   | Branch (_,_,s,_) -> min_elt s
96       
97   let rec max_elt n = match n.node with
98     | Empty -> raise Not_found
99     | Leaf k -> k
100     | Branch (_,_,_,t) -> max_elt t
101
102   let elements s =
103     let rec elements_aux acc n = match n.node with
104       | Empty -> acc
105       | Leaf k -> k :: acc
106       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
107     in
108     elements_aux [] s
109
110   let mask k m  = (k lor (m-1)) land (lnot m)
111
112   let naive_highest_bit x = 
113     assert (x < 256);
114     let rec loop i = 
115       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
116     in
117     loop 7
118
119   let hbit = Array.init 256 naive_highest_bit
120   
121   let highest_bit_32 x =
122     let n = x lsr 24 in if n != 0 then Array.unsafe_get hbit n lsl 24
123     else let n = x lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
124     else let n = x lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
125     else Array.unsafe_get hbit x
126
127   let highest_bit_64 x =
128     let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
129     else highest_bit_32 x
130
131   let highest_bit = match Sys.word_size with
132     | 32 -> highest_bit_32
133     | 64 -> highest_bit_64
134     | _ -> assert false
135
136   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
137
138   let join p0 t0 p1 t1 =  
139     let m = branching_bit p0 p1  in
140     if zero_bit p0 m then 
141       branch (mask p0 m)  m t0 t1
142     else 
143       branch (mask p0 m) m t1 t0
144     
145   let match_prefix k p m = (mask k m) == p
146
147   let add k t =
148     let rec ins n = match n.node with
149       | Empty -> leaf k
150       | Leaf j ->  if j == k then n else join k (leaf k) j n
151       | Branch (p,m,t0,t1)  ->
152           if match_prefix k p m then
153             if zero_bit k m then 
154               branch p m (ins t0) t1
155             else
156               branch p m t0 (ins t1)
157           else
158             join k  (leaf k)  p n
159     in
160     ins t
161       
162   let remove k t =
163     let rec rmv n = match n.node with
164       | Empty -> empty
165       | Leaf j  -> if k == j then empty else n
166       | Branch (p,m,t0,t1) -> 
167           if match_prefix k p m then
168             if zero_bit k m then
169               branch_ne (p, m, rmv t0, t1)
170             else
171               branch_ne (p, m, t0, rmv t1)
172           else
173             n
174     in
175     rmv t
176       
177   (* should run in O(1) thanks to Hash consing *)
178
179   let equal a b = a==b || a.id == b.id
180
181   let compare a b = if a == b then 0 else a.id - b.id
182
183
184   let rec merge s t = 
185     if (equal s t) (* This is cheap thanks to hash-consing *)
186     then s
187     else
188       match s.node,t.node with
189         | Empty, _  -> t
190         | _, Empty  -> s
191         | Leaf k, _ -> add k t
192         | _, Leaf k -> add k s
193         | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
194             if m == n && match_prefix q p m then
195               branch p  m  (merge s0 t0) (merge s1 t1)
196             else if m > n && match_prefix q p m then
197               if zero_bit q m then 
198                 branch p m (merge s0 t) s1
199               else 
200                 branch p m s0 (merge s1 t)
201             else if m < n && match_prefix p q n then     
202               if zero_bit p n then
203                 branch q n (merge s t0) t1
204               else
205                 branch q n t0 (merge s t1)
206             else
207               (* The prefixes disagree. *)
208               join p s q t
209             
210
211
212   let rec subset s1 s2 = (equal s1 s2) ||
213     match (s1.node,s2.node) with
214       | Empty, _ -> true
215       | _, Empty -> false
216       | Leaf k1, _ -> mem k1 s2
217       | Branch _, Leaf _ -> false
218       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
219           if m1 == m2 && p1 == p2 then
220             subset l1 l2 && subset r1 r2
221           else if m1 < m2 && match_prefix p1 p2 m2 then
222             if zero_bit p1 m2 then 
223               subset l1 l2 && subset r1 l2
224             else 
225               subset l1 r2 && subset r1 r2
226           else
227             false
228
229   let union s t = 
230       merge s t
231               
232   let rec inter s1 s2 = 
233     if equal s1 s2 
234     then s1
235     else
236       match (s1.node,s2.node) with
237         | Empty, _ -> empty
238         | _, Empty -> empty
239         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
240         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
241         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
242             if m1 == m2 && p1 == p2 then 
243               merge (inter l1 l2)  (inter r1 r2)
244             else if m1 > m2 && match_prefix p2 p1 m1 then
245               inter (if zero_bit p2 m1 then l1 else r1) s2
246             else if m1 < m2 && match_prefix p1 p2 m2 then
247               inter s1 (if zero_bit p1 m2 then l2 else r2)
248             else
249               empty
250
251   let rec diff s1 s2 = 
252     if equal s1 s2 
253     then empty
254     else
255       match (s1.node,s2.node) with
256         | Empty, _ -> empty
257         | _, Empty -> s1
258         | Leaf k1, _ -> if mem k1 s2 then empty else s1
259         | _, Leaf k2 -> remove k2 s1
260         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
261             if m1 == m2 && p1 == p2 then
262               merge (diff l1 l2) (diff r1 r2)
263             else if m1 > m2 && match_prefix p2 p1 m1 then
264               if zero_bit p2 m1 then 
265                 merge (diff l1 s2) r1
266               else 
267                 merge l1 (diff r1 s2)
268             else if m1 < m2 && match_prefix p1 p2 m2 then
269               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
270             else
271           s1
272             
273             
274
275
276 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
277     [exists], [filter], [partition], [choose], [elements]) are
278     implemented as for any other kind of binary trees. *)
279
280 let rec cardinal n = match n.node with
281   | Empty -> 0
282   | Leaf _ -> 1
283   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
284
285 let rec iter f n = match n.node with
286   | Empty -> ()
287   | Leaf k -> f k
288   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
289       
290 let rec fold f s accu = match s.node with
291   | Empty -> accu
292   | Leaf k -> f k accu
293   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
294
295 let rec for_all p n = match n.node with
296   | Empty -> true
297   | Leaf k -> p k
298   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
299
300 let rec exists p n = match n.node with
301   | Empty -> false
302   | Leaf k -> p k
303   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
304
305 let rec filter pr n = match n.node with
306   | Empty -> empty
307   | Leaf k -> if pr k then n else empty
308   | Branch (p,m,t0,t1) -> branch_ne (p, m, filter pr t0, filter pr t1)
309
310 let partition p s =
311   let rec part (t,f as acc) n = match n.node with
312     | Empty -> acc
313     | Leaf k -> if p k then (add k t, f) else (t, add k f)
314     | Branch (_,_,t0,t1) -> part (part acc t0) t1
315   in
316   part (empty, empty) s
317
318 let rec choose n = match n.node with
319   | Empty -> raise Not_found
320   | Leaf k -> k
321   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
322
323
324 let split x s =
325   let coll k (l, b, r) =
326     if k < x then add k l, b, r
327     else if k > x then l, b, add k r
328     else l, true, r 
329   in
330   fold coll s (empty, false, empty)
331
332
333
334 let rec dump n =
335   Printf.eprintf "{ id = %i; key = %i ; node=" n.id n.key;
336   match n.node with
337     | Empty -> Printf.eprintf "Empty; }\n"
338     | Leaf k -> Printf.eprintf "Leaf %i; }\n" k
339     | Branch (p,m,l,r) -> 
340         Printf.eprintf "Branch(%i,%i,id=%i,id=%i); }\n"
341           p m l.id r.id;
342         dump l;
343         dump r
344
345 (*i*)
346 let make l = List.fold_left (fun acc e -> add e acc ) empty l
347 (*i*)
348
349 (*s Additional functions w.r.t to [Set.S]. *)
350
351 let rec intersect s1 s2 = (equal s1 s2) ||
352   match (s1.node,s2.node) with
353   | Empty, _ -> false
354   | _, Empty -> false
355   | Leaf k1, _ -> mem k1 s2
356   | _, Leaf k2 -> mem k2 s1
357   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
358       if m1 == m2 && p1 == p2 then
359         intersect l1 l2 || intersect r1 r2
360       else if m1 < m2 && match_prefix p2 p1 m1 then
361         intersect (if zero_bit p2 m1 then l1 else r1) s2
362       else if m1 > m2 && match_prefix p1 p2 m2 then
363         intersect s1 (if zero_bit p1 m2 then l2 else r2)
364       else
365         false
366
367
368 let hash s = s.key
369
370 let from_list l = List.fold_left (fun acc i -> add i acc) empty l