5c029f772ed28f4871e5299b45c20544823cf54c
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8
9
10 type elt = int
11
12 type t = { id : int;
13            key : int; (* hash *)
14            node : node }
15 and node = 
16   | Empty
17   | Leaf of int
18   | Branch of int * int * t * t
19
20 module Node = 
21   struct
22     type _t = t
23     type t = _t
24     let hash x = x.key       
25     let hash_node = function 
26          | Empty -> 0
27          | Leaf i -> i+1
28              (* power of 2 +/- 1 are fast ! *)
29          | Branch (b,i,l,r) -> 
30              (b lsl 1)+ b + i+(i lsl 4) + (l.key lsl 5)-l.key
31              + (r.key lsl 7) - r.key
32     let hash_node x = (hash_node x) land max_int
33     let equal x y = match (x.node,y.node) with
34       | Empty,Empty -> true
35       | Leaf k1, Leaf k2 when k1 == k2 -> true
36       | Branch(p1,m1,l1,r1), Branch(p2,m2,l2,r2) when m1==m2 && p1==p2 && 
37           (l1.id == l2.id) && (r1.id == r2.id) -> true
38       | _ -> false
39   end
40
41 module WH = Weak.Make(Node)
42
43 let pool = WH.create 4093
44
45 (* Neat trick thanks to Alain Frisch ! *)
46
47 let gen_uid () = Oo.id (object end) 
48
49 let empty = { id = gen_uid ();
50               key = 0;
51               node = Empty }
52
53 let _ = WH.add pool empty
54
55 let is_empty = function { id = 0 } -> true  | _ -> false
56     
57 let rec norm n =
58   let v = { id = gen_uid ();
59             key = Node.hash_node n;
60             node = n } 
61   in
62       WH.merge pool v 
63
64 (*  WH.merge pool *)
65
66 let branch (p,m,l,r) = norm (Branch(p,m,l,r))
67 let leaf k = norm (Leaf k)
68
69 (* To enforce the invariant that a branch contains two non empty sub-trees *)
70 let branch_ne = function
71   | (_,_,e,t) when is_empty e -> t
72   | (_,_,t,e) when is_empty e -> t
73   | (p,m,t0,t1)   -> branch (p,m,t0,t1)
74
75 (********** from here on, only use the smart constructors *************)
76
77 let zero_bit k m = (k land m) == 0
78
79 let singleton k = if k < 0 then failwith "singleton" else leaf k
80
81 let rec mem k n = match n.node with
82   | Empty -> false
83   | Leaf j -> k == j
84   | Branch (p, _, l, r) -> if k <= p then mem k l else mem k r
85
86 let rec min_elt n = match n.node with
87   | Empty -> raise Not_found
88   | Leaf k -> k
89   | Branch (_,_,s,_) -> min_elt s
90       
91   let rec max_elt n = match n.node with
92     | Empty -> raise Not_found
93     | Leaf k -> k
94     | Branch (_,_,_,t) -> max_elt t
95
96   let elements s =
97     let rec elements_aux acc n = match n.node with
98       | Empty -> acc
99       | Leaf k -> k :: acc
100       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
101     in
102     elements_aux [] s
103
104   let mask k m  = (k lor (m-1)) land (lnot m)
105
106   let naive_highest_bit x = 
107     assert (x < 256);
108     let rec loop i = 
109       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
110     in
111     loop 7
112
113   let hbit = Array.init 256 naive_highest_bit
114   
115   let highest_bit_32 x =
116     let n = x lsr 24 in if n != 0 then hbit.(n) lsl 24
117     else let n = x lsr 16 in if n != 0 then hbit.(n) lsl 16
118     else let n = x lsr 8 in if n != 0 then hbit.(n) lsl 8
119     else hbit.(x)
120
121   let highest_bit_64 x =
122     let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
123     else highest_bit_32 x
124
125   let highest_bit = match Sys.word_size with
126     | 32 -> highest_bit_32
127     | 64 -> highest_bit_64
128     | _ -> assert false
129
130   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
131
132   let join (p0,t0,p1,t1) =  
133     let m = branching_bit p0 p1  in
134     if zero_bit p0 m then 
135       branch (mask p0 m, m, t0, t1)
136     else 
137       branch (mask p0 m, m, t1, t0)
138     
139   let match_prefix k p m = (mask k m) == p
140
141   let add k t =
142     let rec ins n = match n.node with
143       | Empty -> leaf k
144       | Leaf j ->  if j == k then n else join (k, leaf k, j, n)
145       | Branch (p,m,t0,t1)  ->
146           if match_prefix k p m then
147             if zero_bit k m then 
148               branch (p, m, ins t0, t1)
149             else
150               branch (p, m, t0, ins t1)
151           else
152             join (k, leaf k, p, n)
153     in
154     ins t
155       
156   let remove k t =
157     let rec rmv n = match n.node with
158       | Empty -> empty
159       | Leaf j  -> if k == j then empty else n
160       | Branch (p,m,t0,t1) -> 
161           if match_prefix k p m then
162             if zero_bit k m then
163               branch_ne (p, m, rmv t0, t1)
164             else
165               branch_ne (p, m, t0, rmv t1)
166           else
167             n
168     in
169     rmv t
170       
171   (* should run in O(1) thanks to Hash consing *)
172
173   let equal = (=)
174
175   let compare = compare
176
177
178   let rec merge (s,t)  = 
179     if (equal s t) (* This is cheap thanks to hash-consing *)
180     then s
181     else
182       match s.node,t.node with
183         | Empty, _  -> t
184         | _, Empty  -> s
185         | Leaf k, _ -> add k t
186         | _, Leaf k -> add k s
187         | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
188             if m == n && match_prefix q p m then
189               branch (p, m, merge (s0,t0), merge (s1,t1))
190             else if m > n && match_prefix q p m then
191               if zero_bit q m then 
192                 branch (p, m, merge (s0,t), s1)
193               else 
194                 branch (p, m, s0, merge (s1,t))
195             else if m < n && match_prefix p q n then
196               
197               if zero_bit p n then
198                 branch (q, n, merge (s,t0), t1)
199               else
200                 branch (q, n, t0, merge (s,t1))
201             else
202               (* The prefixes disagree. *)
203               join (p, s, q, t)
204             
205   let union s t = merge (s,t)
206
207   let rec subset s1 s2 = (equal s1 s2) ||
208     match (s1.node,s2.node) with
209       | Empty, _ -> true
210       | _, Empty -> false
211       | Leaf k1, _ -> mem k1 s2
212       | Branch _, Leaf _ -> false
213       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
214           if m1 == m2 && p1 == p2 then
215             subset l1 l2 && subset r1 r2
216           else if m1 < m2 && match_prefix p1 p2 m2 then
217             if zero_bit p1 m2 then 
218               subset l1 l2 && subset r1 l2
219             else 
220               subset l1 r2 && subset r1 r2
221           else
222             false
223               
224   let rec inter s1 s2 = 
225     if (equal s1 s2) 
226     then s1
227     else
228       match (s1.node,s2.node) with
229         | Empty, _ -> empty
230         | _, Empty -> empty
231         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
232         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
233         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
234             if m1 == m2 && p1 == p2 then 
235               merge (inter l1 l2, inter r1 r2)
236             else if m1 > m2 && match_prefix p2 p1 m1 then
237               inter (if zero_bit p2 m1 then l1 else r1) s2
238             else if m1 < m2 && match_prefix p1 p2 m2 then
239               inter s1 (if zero_bit p1 m2 then l2 else r2)
240             else
241               empty
242
243   let rec diff s1 s2 = 
244     if (equal s1 s2) 
245     then empty
246     else
247       match (s1.node,s2.node) with
248         | Empty, _ -> empty
249         | _, Empty -> s1
250         | Leaf k1, _ -> if mem k1 s2 then empty else s1
251         | _, Leaf k2 -> remove k2 s1
252         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
253             if m1 == m2 && p1 == p2 then
254               merge (diff l1 l2, diff r1 r2)
255             else if m1 > m2 && match_prefix p2 p1 m1 then
256               if zero_bit p2 m1 then 
257                 merge (diff l1 s2, r1) 
258               else 
259                 merge (l1, diff r1 s2)
260             else if m1 < m2 && match_prefix p1 p2 m2 then
261               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
262             else
263           s1
264             
265             
266
267
268 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
269     [exists], [filter], [partition], [choose], [elements]) are
270     implemented as for any other kind of binary trees. *)
271
272 let rec cardinal n = match n.node with
273   | Empty -> 0
274   | Leaf _ -> 1
275   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
276
277 let rec iter f n = match n.node with
278   | Empty -> ()
279   | Leaf k -> f k
280   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
281       
282 let rec fold f s accu = match s.node with
283   | Empty -> accu
284   | Leaf k -> f k accu
285   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
286
287 let rec for_all p n = match n.node with
288   | Empty -> true
289   | Leaf k -> p k
290   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
291
292 let rec exists p n = match n.node with
293   | Empty -> false
294   | Leaf k -> p k
295   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
296
297 let rec filter pr n = match n.node with
298   | Empty -> empty
299   | Leaf k -> if pr k then n else empty
300   | Branch (p,m,t0,t1) -> branch_ne (p, m, filter pr t0, filter pr t1)
301
302 let partition p s =
303   let rec part (t,f as acc) n = match n.node with
304     | Empty -> acc
305     | Leaf k -> if p k then (add k t, f) else (t, add k f)
306     | Branch (_,_,t0,t1) -> part (part acc t0) t1
307   in
308   part (empty, empty) s
309
310 let rec choose n = match n.node with
311   | Empty -> raise Not_found
312   | Leaf k -> k
313   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
314
315
316 let split x s =
317   let coll k (l, b, r) =
318     if k < x then add k l, b, r
319     else if k > x then l, b, add k r
320     else l, true, r 
321   in
322   fold coll s (empty, false, empty)
323
324
325
326 let rec dump n =
327   Printf.eprintf "{ id = %i; key = %i ; node=" n.id n.key;
328   match n.node with
329     | Empty -> Printf.eprintf "Empty; }\n"
330     | Leaf k -> Printf.eprintf "Leaf %i; }\n" k
331     | Branch (p,m,l,r) -> 
332         Printf.eprintf "Branch(%i,%i,id=%i,id=%i); }\n"
333           p m l.id r.id;
334         dump l;
335         dump r
336
337 (*i*)
338 let make l = List.fold_left (fun acc e -> add e acc ) empty l
339 (*i*)
340
341 (*s Additional functions w.r.t to [Set.S]. *)
342
343 let rec intersect s1 s2 = (equal s1 s2) ||
344   match (s1.node,s2.node) with
345   | Empty, _ -> false
346   | _, Empty -> false
347   | Leaf k1, _ -> mem k1 s2
348   | _, Leaf k2 -> mem k2 s1
349   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
350       if m1 == m2 && p1 == p2 then
351         intersect l1 l2 || intersect r1 r2
352       else if m1 < m2 && match_prefix p2 p1 m1 then
353         intersect (if zero_bit p2 m1 then l1 else r1) s2
354       else if m1 > m2 && match_prefix p1 p2 m2 then
355         intersect s1 (if zero_bit p1 m2 then l2 else r2)
356       else
357         false
358
359
360 let hash s = s.key
361
362 let from_list l = List.fold_left (fun acc i -> add i acc) empty l