removed cruft, fixed ptset.ml
[SXSI/xpathcomp.git] / ptset_include.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 IFDEF USE_PTSET_INCLUDE
9 THEN
10 INCLUDE "utils.ml"
11 (*
12   Cannot be used like this:
13   Need to be included after the following declrations:
14   type elt = ...
15   let equal_elt : elt -> elt -> bool = ...
16   let hash_elt : elt -> int = ...
17   let uid_elt : elt -> int = ...
18 *)
19 type 'a node =
20   | Empty
21   | Leaf of elt
22   | Branch of int * int * 'a * 'a
23
24 module rec HNode : Hcons.S with type data = Node.t = Hcons.Make (Node)
25 and Node : Hashtbl.HashedType  with type t = HNode.t node =
26 struct 
27   type t =  HNode.t node
28   let equal x y = 
29     match x,y with
30       | Empty,Empty -> true
31       | Leaf k1, Leaf k2 -> equal_elt k1 k2
32       | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
33           b1 == b2 && i1 == i2 &&
34             (HNode.equal l1 l2) &&
35             (HNode.equal r1 r2) 
36       | _ -> false
37   let hash = function 
38     | Empty -> 0
39     | Leaf i -> HASHINT2(HALF_MAX_INT,hash_elt i)
40     | Branch (b,i,l,r) -> HASHINT4(b,i,HNode.hash l, HNode.hash r)
41 end
42
43 type t = HNode.t
44 let hash = HNode.hash 
45 let uid = HNode.uid
46
47 let empty = HNode.make Empty
48
49 let is_empty s = (HNode.node s) == Empty
50     
51 (*  WH.merge pool *)
52
53 let branch p m l r = HNode.make (Branch(p,m,l,r))
54 let leaf k = HNode.make (Leaf k)
55
56 (* To enforce the invariant that a branch contains two non empty sub-trees *)
57 let branch_ne p m t0 t1 = 
58   if (is_empty t0) then t1
59   else if is_empty t1 then t0 else branch p m t0 t1
60
61 (********** from here on, only use the smart constructors *************)
62
63 let zero_bit k m = (k land m) == 0
64
65 let singleton k = leaf k
66
67 let is_singleton n = 
68   match HNode.node n with Leaf _ -> true
69     | _ -> false
70
71 let mem (k:elt) n = 
72   let kid = uid_elt k in
73   let rec loop n = match HNode.node n with
74     | Empty -> false
75     | Leaf j -> equal_elt k j
76     | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
77   in loop n
78
79 let rec min_elt n = match HNode.node n with
80   | Empty -> raise Not_found
81   | Leaf k -> k
82   | Branch (_,_,s,_) -> min_elt s
83       
84 let rec max_elt n = match HNode.node n with
85   | Empty -> raise Not_found
86   | Leaf k -> k
87   | Branch (_,_,_,t) -> max_elt t
88
89   let elements s =
90     let rec elements_aux acc n = match HNode.node n with
91       | Empty -> acc
92       | Leaf k -> k :: acc
93       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
94     in
95     elements_aux [] s
96
97   let mask k m  = (k lor (m-1)) land (lnot m)
98
99   let naive_highest_bit x = 
100     assert (x < 256);
101     let rec loop i = 
102       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
103     in
104     loop 7
105
106   let hbit = Array.init 256 naive_highest_bit
107   
108   let highest_bit_32 x =
109     let n = x lsr 24 in if n != 0 then Array.unsafe_get hbit n lsl 24
110     else let n = x lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
111     else let n = x lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
112     else Array.unsafe_get hbit x
113
114   let highest_bit_64 x =
115     let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
116     else highest_bit_32 x
117
118   let highest_bit = match Sys.word_size with
119     | 32 -> highest_bit_32
120     | 64 -> highest_bit_64
121     | _ -> assert false
122
123   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
124
125   let join p0 t0 p1 t1 =  
126     let m = branching_bit p0 p1  in
127     if zero_bit p0 m then 
128       branch (mask p0 m) m t0 t1
129     else 
130       branch (mask p0 m) m t1 t0
131     
132   let match_prefix k p m = (mask k m) == p
133
134   let add k t =
135     let kid = uid_elt k in
136     let rec ins n = match HNode.node n with
137       | Empty -> leaf k
138       | Leaf j ->  if equal_elt j k then n else join kid (leaf k) (uid_elt j) n
139       | Branch (p,m,t0,t1)  ->
140           if match_prefix kid p m then
141             if zero_bit kid m then 
142               branch p m (ins t0) t1
143             else
144               branch p m t0 (ins t1)
145           else
146             join kid (leaf k)  p n
147     in
148     ins t
149       
150   let remove k t =
151     let kid = uid_elt k in
152     let rec rmv n = match HNode.node n with
153       | Empty -> empty
154       | Leaf j  -> if equal_elt k j then empty else n
155       | Branch (p,m,t0,t1) -> 
156           if match_prefix kid p m then
157             if zero_bit kid m then
158               branch_ne p m (rmv t0) t1
159             else
160               branch_ne p m t0 (rmv t1)
161           else
162             n
163     in
164     rmv t
165       
166   (* should run in O(1) thanks to Hash consing *)
167
168   let equal a b = HNode.equal a b 
169
170   let compare a b =  (HNode.uid a) - (HNode.uid b)
171
172   let rec merge s t = 
173     if (equal s t) (* This is cheap thanks to hash-consing *)
174     then s
175     else
176     match HNode.node s, HNode.node t with
177       | Empty, _  -> t
178       | _, Empty  -> s
179       | Leaf k, _ -> add k t
180       | _, Leaf k -> add k s
181       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
182           if m == n && match_prefix q p m then
183             branch p  m  (merge s0 t0) (merge s1 t1)
184           else if m > n && match_prefix q p m then
185             if zero_bit q m then 
186               branch p m (merge s0 t) s1
187             else 
188               branch p m s0 (merge s1 t)
189           else if m < n && match_prefix p q n then     
190             if zero_bit p n then
191               branch q n (merge s t0) t1
192             else
193               branch q n t0 (merge s t1)
194           else
195             (* The prefixes disagree. *)
196             join p s q t
197                
198         
199                
200                
201   let rec subset s1 s2 = (equal s1 s2) ||
202     match (HNode.node s1,HNode.node s2) with
203       | Empty, _ -> true
204       | _, Empty -> false
205       | Leaf k1, _ -> mem k1 s2
206       | Branch _, Leaf _ -> false
207       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
208           if m1 == m2 && p1 == p2 then
209             subset l1 l2 && subset r1 r2
210           else if m1 < m2 && match_prefix p1 p2 m2 then
211             if zero_bit p1 m2 then 
212               subset l1 l2 && subset r1 l2
213             else 
214               subset l1 r2 && subset r1 r2
215           else
216             false
217
218               
219   let union s1 s2 = merge s1 s2
220     (* Todo replace with e Memo Module *)
221   module MemUnion = Hashtbl.Make(
222     struct 
223       type set = t 
224       type t = set*set 
225       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
226       let equal a b = equal a b || equal b a
227       let hash (x,y) =   (* commutative hash *)
228         let x = HNode.hash x 
229         and y = HNode.hash y 
230         in
231           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
232     end)
233   let h_mem = MemUnion.create MED_H_SIZE
234
235   let mem_union s1 s2 = 
236     try  MemUnion.find h_mem (s1,s2) 
237     with Not_found ->
238           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
239       
240
241   let rec inter s1 s2 = 
242     if equal s1 s2 
243     then s1
244     else
245       match (HNode.node s1,HNode.node s2) with
246         | Empty, _ -> empty
247         | _, Empty -> empty
248         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
249         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
250         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
251             if m1 == m2 && p1 == p2 then 
252               merge (inter l1 l2)  (inter r1 r2)
253             else if m1 > m2 && match_prefix p2 p1 m1 then
254               inter (if zero_bit p2 m1 then l1 else r1) s2
255             else if m1 < m2 && match_prefix p1 p2 m2 then
256               inter s1 (if zero_bit p1 m2 then l2 else r2)
257             else
258               empty
259
260   let rec diff s1 s2 = 
261     if equal s1 s2 
262     then empty
263     else
264       match (HNode.node s1,HNode.node s2) with
265         | Empty, _ -> empty
266         | _, Empty -> s1
267         | Leaf k1, _ -> if mem k1 s2 then empty else s1
268         | _, Leaf k2 -> remove k2 s1
269         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
270             if m1 == m2 && p1 == p2 then
271               merge (diff l1 l2) (diff r1 r2)
272             else if m1 > m2 && match_prefix p2 p1 m1 then
273               if zero_bit p2 m1 then 
274                 merge (diff l1 s2) r1
275               else 
276                 merge l1 (diff r1 s2)
277             else if m1 < m2 && match_prefix p1 p2 m2 then
278               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
279             else
280           s1
281                
282
283 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
284     [exists], [filter], [partition], [choose], [elements]) are
285     implemented as for any other kind of binary trees. *)
286
287 let rec cardinal n = match HNode.node n with
288   | Empty -> 0
289   | Leaf _ -> 1
290   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
291
292 let rec iter f n = match HNode.node n with
293   | Empty -> ()
294   | Leaf k -> f k
295   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
296       
297 let rec fold f s accu = match HNode.node s with
298   | Empty -> accu
299   | Leaf k -> f k accu
300   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
301
302
303 let rec for_all p n = match HNode.node n with
304   | Empty -> true
305   | Leaf k -> p k
306   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
307
308 let rec exists p n = match HNode.node n with
309   | Empty -> false
310   | Leaf k -> p k
311   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
312
313 let rec filter pr n = match HNode.node n with
314   | Empty -> empty
315   | Leaf k -> if pr k then n else empty
316   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
317
318 let partition p s =
319   let rec part (t,f as acc) n = match HNode.node n with
320     | Empty -> acc
321     | Leaf k -> if p k then (add k t, f) else (t, add k f)
322     | Branch (_,_,t0,t1) -> part (part acc t0) t1
323   in
324   part (empty, empty) s
325
326 let rec choose n = match HNode.node n with
327   | Empty -> raise Not_found
328   | Leaf k -> k
329   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
330
331
332 let split x s =
333   let coll k (l, b, r) =
334     if k < x then add k l, b, r
335     else if k > x then l, b, add k r
336     else l, true, r 
337   in
338   fold coll s (empty, false, empty)
339
340
341 let make l = List.fold_left (fun acc e -> add e acc ) empty l
342 (*i*)
343
344 (*s Additional functions w.r.t to [Set.S]. *)
345
346 let rec intersect s1 s2 = (equal s1 s2) ||
347   match (HNode.node s1,HNode.node s2) with
348   | Empty, _ -> false
349   | _, Empty -> false
350   | Leaf k1, _ -> mem k1 s2
351   | _, Leaf k2 -> mem k2 s1
352   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
353       if m1 == m2 && p1 == p2 then
354         intersect l1 l2 || intersect r1 r2
355       else if m1 < m2 && match_prefix p2 p1 m1 then
356         intersect (if zero_bit p2 m1 then l1 else r1) s2
357       else if m1 > m2 && match_prefix p1 p2 m2 then
358         intersect s1 (if zero_bit p1 m2 then l2 else r2)
359       else
360         false
361
362
363
364 let rec uncons n = match HNode.node n with
365   | Empty -> raise Not_found
366   | Leaf k -> (k,empty)
367   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
368    
369 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
370
371 END