removed cruft, fixed ptset.ml
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S = 
10 sig
11   include Set.S
12   val intersect : t -> t -> bool
13   val is_singleton : t -> bool
14   val mem_union : t -> t -> t
15   val hash : t -> int
16   val uid : t -> int
17   val uncons : t -> elt*t
18   val from_list : elt list -> t 
19 end
20
21 module Make ( H : Hcons.S ) : S with type elt = H.t =
22 struct
23   type elt = H.t
24
25   type 'a node =
26     | Empty
27     | Leaf of elt
28     | Branch of int * int * 'a * 'a
29         
30   module rec HNode : Hcons.S with type data = Node.t = Hcons.Make (Node)
31   and Node : Hashtbl.HashedType  with type t = HNode.t node =
32   struct 
33     type t =  HNode.t node
34     let equal x y = 
35       match x,y with
36         | Empty,Empty -> true
37         | Leaf k1, Leaf k2 -> H.equal k1 k2
38         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
39             b1 == b2 && i1 == i2 &&
40               (HNode.equal l1 l2) &&
41               (HNode.equal r1 r2) 
42         | _ -> false
43     let hash = function 
44       | Empty -> 0
45       | Leaf i -> HASHINT2(HALF_MAX_INT,H.hash i)
46       | Branch (b,i,l,r) -> HASHINT4(b,i,HNode.hash l, HNode.hash r)
47   end
48  ;;
49                              
50   type t = HNode.t
51   let hash = HNode.hash 
52   let uid = HNode.uid
53     
54   let empty = HNode.make Empty
55     
56   let is_empty s = (HNode.node s) == Empty
57        
58   let branch p m l r = HNode.make (Branch(p,m,l,r))
59
60   let leaf k = HNode.make (Leaf k)
61
62   (* To enforce the invariant that a branch contains two non empty sub-trees *)
63   let branch_ne p m t0 t1 = 
64     if (is_empty t0) then t1
65     else if is_empty t1 then t0 else branch p m t0 t1
66       
67   (********** from here on, only use the smart constructors *************)
68       
69   let zero_bit k m = (k land m) == 0
70     
71   let singleton k = leaf k
72     
73   let is_singleton n = 
74     match HNode.node n with Leaf _ -> true
75       | _ -> false
76           
77   let mem (k:elt) n = 
78     let kid = H.uid k in
79     let rec loop n = match HNode.node n with
80       | Empty -> false
81       | Leaf j -> H.equal k j
82       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
83     in loop n
84          
85   let rec min_elt n = match HNode.node n with
86     | Empty -> raise Not_found
87     | Leaf k -> k
88     | Branch (_,_,s,_) -> min_elt s
89         
90   let rec max_elt n = match HNode.node n with
91     | Empty -> raise Not_found
92     | Leaf k -> k
93     | Branch (_,_,_,t) -> max_elt t
94         
95   let elements s =
96     let rec elements_aux acc n = match HNode.node n with
97       | Empty -> acc
98       | Leaf k -> k :: acc
99       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
100     in
101       elements_aux [] s
102         
103   let mask k m  = (k lor (m-1)) land (lnot m)
104     
105   let naive_highest_bit x = 
106     assert (x < 256);
107     let rec loop i = 
108       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
109     in
110       loop 7
111         
112   let hbit = Array.init 256 naive_highest_bit
113
114
115   let highest_bit x = let n = (x) lsr 24 in 
116   if n != 0 then Array.unsafe_get hbit n lsl 24
117   else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
118   else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
119   else Array.unsafe_get hbit (x)
120
121 IFDEF WORDIZE64
122 THEN
123   let highest_bit64 x =
124     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
125       else highest_bit x
126 END
127
128         
129   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
130     
131   let join p0 t0 p1 t1 =  
132     let m = branching_bit p0 p1  in
133       if zero_bit p0 m then 
134         branch (mask p0 m) m t0 t1
135       else 
136         branch (mask p0 m) m t1 t0
137           
138   let match_prefix k p m = (mask k m) == p
139     
140   let add k t =
141     let kid = H.uid k in
142     let rec ins n = match HNode.node n with
143       | Empty -> leaf k
144       | Leaf j ->  if H.equal j k then n else join kid (leaf k) (H.uid j) n
145       | Branch (p,m,t0,t1)  ->
146           if match_prefix kid p m then
147             if zero_bit kid m then 
148               branch p m (ins t0) t1
149             else
150               branch p m t0 (ins t1)
151           else
152             join kid (leaf k)  p n
153     in
154     ins t
155       
156   let remove k t =
157     let kid = H.uid k in
158     let rec rmv n = match HNode.node n with
159       | Empty -> empty
160       | Leaf j  -> if H.equal k j then empty else n
161       | Branch (p,m,t0,t1) -> 
162           if match_prefix kid p m then
163             if zero_bit kid m then
164               branch_ne p m (rmv t0) t1
165             else
166               branch_ne p m t0 (rmv t1)
167           else
168             n
169     in
170     rmv t
171       
172   (* should run in O(1) thanks to Hash consing *)
173
174   let equal a b = HNode.equal a b 
175
176   let compare a b =  (HNode.uid a) - (HNode.uid b)
177
178   let rec merge s t = 
179     if (equal s t) (* This is cheap thanks to hash-consing *)
180     then s
181     else
182     match HNode.node s, HNode.node t with
183       | Empty, _  -> t
184       | _, Empty  -> s
185       | Leaf k, _ -> add k t
186       | _, Leaf k -> add k s
187       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
188           if m == n && match_prefix q p m then
189             branch p  m  (merge s0 t0) (merge s1 t1)
190           else if m > n && match_prefix q p m then
191             if zero_bit q m then 
192               branch p m (merge s0 t) s1
193             else 
194               branch p m s0 (merge s1 t)
195           else if m < n && match_prefix p q n then     
196             if zero_bit p n then
197               branch q n (merge s t0) t1
198             else
199               branch q n t0 (merge s t1)
200           else
201             (* The prefixes disagree. *)
202             join p s q t
203                
204         
205                
206                
207   let rec subset s1 s2 = (equal s1 s2) ||
208     match (HNode.node s1,HNode.node s2) with
209       | Empty, _ -> true
210       | _, Empty -> false
211       | Leaf k1, _ -> mem k1 s2
212       | Branch _, Leaf _ -> false
213       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
214           if m1 == m2 && p1 == p2 then
215             subset l1 l2 && subset r1 r2
216           else if m1 < m2 && match_prefix p1 p2 m2 then
217             if zero_bit p1 m2 then 
218               subset l1 l2 && subset r1 l2
219             else 
220               subset l1 r2 && subset r1 r2
221           else
222             false
223
224               
225   let union s1 s2 = merge s1 s2
226     (* Todo replace with e Memo Module *)
227   module MemUnion = Hashtbl.Make(
228     struct 
229       type set = t 
230       type t = set*set 
231       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
232       let equal a b = equal a b || equal b a
233       let hash (x,y) =   (* commutative hash *)
234         let x = HNode.hash x 
235         and y = HNode.hash y 
236         in
237           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
238     end)
239   let h_mem = MemUnion.create MED_H_SIZE
240
241   let mem_union s1 s2 = 
242     try  MemUnion.find h_mem (s1,s2) 
243     with Not_found ->
244           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
245       
246
247   let rec inter s1 s2 = 
248     if equal s1 s2 
249     then s1
250     else
251       match (HNode.node s1,HNode.node s2) with
252         | Empty, _ -> empty
253         | _, Empty -> empty
254         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
255         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
256         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
257             if m1 == m2 && p1 == p2 then 
258               merge (inter l1 l2)  (inter r1 r2)
259             else if m1 > m2 && match_prefix p2 p1 m1 then
260               inter (if zero_bit p2 m1 then l1 else r1) s2
261             else if m1 < m2 && match_prefix p1 p2 m2 then
262               inter s1 (if zero_bit p1 m2 then l2 else r2)
263             else
264               empty
265
266   let rec diff s1 s2 = 
267     if equal s1 s2 
268     then empty
269     else
270       match (HNode.node s1,HNode.node s2) with
271         | Empty, _ -> empty
272         | _, Empty -> s1
273         | Leaf k1, _ -> if mem k1 s2 then empty else s1
274         | _, Leaf k2 -> remove k2 s1
275         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
276             if m1 == m2 && p1 == p2 then
277               merge (diff l1 l2) (diff r1 r2)
278             else if m1 > m2 && match_prefix p2 p1 m1 then
279               if zero_bit p2 m1 then 
280                 merge (diff l1 s2) r1
281               else 
282                 merge l1 (diff r1 s2)
283             else if m1 < m2 && match_prefix p1 p2 m2 then
284               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
285             else
286           s1
287                
288
289 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
290     [exists], [filter], [partition], [choose], [elements]) are
291     implemented as for any other kind of binary trees. *)
292
293 let rec cardinal n = match HNode.node n with
294   | Empty -> 0
295   | Leaf _ -> 1
296   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
297
298 let rec iter f n = match HNode.node n with
299   | Empty -> ()
300   | Leaf k -> f k
301   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
302       
303 let rec fold f s accu = match HNode.node s with
304   | Empty -> accu
305   | Leaf k -> f k accu
306   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
307
308
309 let rec for_all p n = match HNode.node n with
310   | Empty -> true
311   | Leaf k -> p k
312   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
313
314 let rec exists p n = match HNode.node n with
315   | Empty -> false
316   | Leaf k -> p k
317   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
318
319 let rec filter pr n = match HNode.node n with
320   | Empty -> empty
321   | Leaf k -> if pr k then n else empty
322   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
323
324 let partition p s =
325   let rec part (t,f as acc) n = match HNode.node n with
326     | Empty -> acc
327     | Leaf k -> if p k then (add k t, f) else (t, add k f)
328     | Branch (_,_,t0,t1) -> part (part acc t0) t1
329   in
330   part (empty, empty) s
331
332 let rec choose n = match HNode.node n with
333   | Empty -> raise Not_found
334   | Leaf k -> k
335   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
336
337
338 let split x s =
339   let coll k (l, b, r) =
340     if k < x then add k l, b, r
341     else if k > x then l, b, add k r
342     else l, true, r 
343   in
344   fold coll s (empty, false, empty)
345
346
347 let make l = List.fold_left (fun acc e -> add e acc ) empty l
348 (*i*)
349
350 (*s Additional functions w.r.t to [Set.S]. *)
351
352 let rec intersect s1 s2 = (equal s1 s2) ||
353   match (HNode.node s1,HNode.node s2) with
354   | Empty, _ -> false
355   | _, Empty -> false
356   | Leaf k1, _ -> mem k1 s2
357   | _, Leaf k2 -> mem k2 s1
358   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
359       if m1 == m2 && p1 == p2 then
360         intersect l1 l2 || intersect r1 r2
361       else if m1 < m2 && match_prefix p2 p1 m1 then
362         intersect (if zero_bit p2 m1 then l1 else r1) s2
363       else if m1 > m2 && match_prefix p1 p2 m2 then
364         intersect s1 (if zero_bit p1 m2 then l2 else r2)
365       else
366         false
367
368
369
370 let rec uncons n = match HNode.node n with
371   | Empty -> raise Not_found
372   | Leaf k -> (k,empty)
373   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
374    
375 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
376
377
378 end
379
380 (* Have to benchmark wheter this whole include stuff is worth it *)
381 module Int : S with type elt = int = Make ( struct type t = int 
382                                                  type data = t
383                                                  external hash : t -> int = "%identity"
384                                                  external uid : t -> int = "%identity"
385                                                  let equal : t -> t -> bool = (==)
386                                                  external make : t -> int = "%identity"
387                                                  external node : t -> int = "%identity"
388                                                    
389                                           end
390                                           )