Further optimisations, changed the prototype of Tree.mli
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S = 
10 sig
11   include Set.S
12   type data
13   val intersect : t -> t -> bool
14   val is_singleton : t -> bool
15   val mem_union : t -> t -> t
16   val hash : t -> int
17   val uid : t -> int
18   val uncons : t -> elt*t
19   val from_list : elt list -> t 
20   val make : data -> t
21   val node : t -> data
22 end
23
24 module Make ( H : Hcons.S ) : S with type elt = H.t =
25 struct
26   type elt = H.t
27   type 'a node =
28     | Empty
29     | Leaf of elt
30     | Branch of int * int * 'a * 'a
31
32   module rec HNode : Hcons.S with type data = Node.t = Hcons.Make (Node)
33   and Node : Hashtbl.HashedType  with type t = HNode.t node =
34   struct 
35     type t =  HNode.t node
36     let equal x y = 
37       match x,y with
38         | Empty,Empty -> true
39         | Leaf k1, Leaf k2 ->  k1 == k2
40         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
41             b1 == b2 && i1 == i2 &&
42               (HNode.equal l1 l2) &&
43               (HNode.equal r1 r2) 
44         | _ -> false
45     let hash = function 
46       | Empty -> 0
47       | Leaf i -> HASHINT2(HALF_MAX_INT,H.uid i)
48       | Branch (b,i,l,r) -> HASHINT4(b,i,HNode.uid l, HNode.uid r)
49   end
50  ;;
51                              
52   type t = HNode.t
53   type data = t node
54   let hash = HNode.hash 
55   let uid = HNode.uid
56   let make = HNode.make
57   let node _ = failwith "node"
58   let empty = HNode.make Empty
59     
60   let is_empty s = (HNode.node s) == Empty
61        
62   let branch p m l r = HNode.make (Branch(p,m,l,r))
63
64   let leaf k = HNode.make (Leaf k)
65
66   (* To enforce the invariant that a branch contains two non empty sub-trees *)
67   let branch_ne p m t0 t1 = 
68     if (is_empty t0) then t1
69     else if is_empty t1 then t0 else branch p m t0 t1
70       
71   (********** from here on, only use the smart constructors *************)
72       
73   let zero_bit k m = (k land m) == 0
74     
75   let singleton k = leaf k
76     
77   let is_singleton n = 
78     match HNode.node n with Leaf _ -> true
79       | _ -> false
80           
81   let mem (k:elt) n = 
82     let kid = H.uid k in
83     let rec loop n = match HNode.node n with
84       | Empty -> false
85       | Leaf j ->  k == j
86       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
87     in loop n
88          
89   let rec min_elt n = match HNode.node n with
90     | Empty -> raise Not_found
91     | Leaf k -> k
92     | Branch (_,_,s,_) -> min_elt s
93         
94   let rec max_elt n = match HNode.node n with
95     | Empty -> raise Not_found
96     | Leaf k -> k
97     | Branch (_,_,_,t) -> max_elt t
98         
99   let elements s =
100     let rec elements_aux acc n = match HNode.node n with
101       | Empty -> acc
102       | Leaf k -> k :: acc
103       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
104     in
105       elements_aux [] s
106         
107   let mask k m  = (k lor (m-1)) land (lnot m)
108     
109   let naive_highest_bit x = 
110     assert (x < 256);
111     let rec loop i = 
112       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
113     in
114       loop 7
115         
116   let hbit = Array.init 256 naive_highest_bit
117
118
119   let highest_bit x = let n = (x) lsr 24 in 
120   if n != 0 then Array.unsafe_get hbit n lsl 24
121   else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
122   else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
123   else Array.unsafe_get hbit (x)
124
125 IFDEF WORDIZE64
126 THEN
127   let highest_bit64 x =
128     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
129       else highest_bit x
130 END
131
132         
133   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
134     
135   let join p0 t0 p1 t1 =  
136     let m = branching_bit p0 p1  in
137       if zero_bit p0 m then 
138         branch (mask p0 m) m t0 t1
139       else 
140         branch (mask p0 m) m t1 t0
141           
142   let match_prefix k p m = (mask k m) == p
143     
144   let add k t =
145     let kid = H.uid k in
146     let rec ins n = match HNode.node n with
147       | Empty -> leaf k
148       | Leaf j ->  if j == k then n else join kid (leaf k) (H.uid j) n
149       | Branch (p,m,t0,t1)  ->
150           if match_prefix kid p m then
151             if zero_bit kid m then 
152               branch p m (ins t0) t1
153             else
154               branch p m t0 (ins t1)
155           else
156             join kid (leaf k)  p n
157     in
158     ins t
159       
160   let remove k t =
161     let kid = H.uid k in
162     let rec rmv n = match HNode.node n with
163       | Empty -> empty
164       | Leaf j  -> if  k == j then empty else n
165       | Branch (p,m,t0,t1) -> 
166           if match_prefix kid p m then
167             if zero_bit kid m then
168               branch_ne p m (rmv t0) t1
169             else
170               branch_ne p m t0 (rmv t1)
171           else
172             n
173     in
174     rmv t
175       
176   (* should run in O(1) thanks to Hash consing *)
177
178   let equal a b = HNode.equal a b 
179
180   let compare a b =  (HNode.uid a) - (HNode.uid b)
181
182   let rec merge s t = 
183     if (equal s t) (* This is cheap thanks to hash-consing *)
184     then s
185     else
186     match HNode.node s, HNode.node t with
187       | Empty, _  -> t
188       | _, Empty  -> s
189       | Leaf k, _ -> add k t
190       | _, Leaf k -> add k s
191       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
192           if m == n && match_prefix q p m then
193             branch p  m  (merge s0 t0) (merge s1 t1)
194           else if m > n && match_prefix q p m then
195             if zero_bit q m then 
196               branch p m (merge s0 t) s1
197             else 
198               branch p m s0 (merge s1 t)
199           else if m < n && match_prefix p q n then     
200             if zero_bit p n then
201               branch q n (merge s t0) t1
202             else
203               branch q n t0 (merge s t1)
204           else
205             (* The prefixes disagree. *)
206             join p s q t
207                
208         
209                
210                
211   let rec subset s1 s2 = (equal s1 s2) ||
212     match (HNode.node s1,HNode.node s2) with
213       | Empty, _ -> true
214       | _, Empty -> false
215       | Leaf k1, _ -> mem k1 s2
216       | Branch _, Leaf _ -> false
217       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
218           if m1 == m2 && p1 == p2 then
219             subset l1 l2 && subset r1 r2
220           else if m1 < m2 && match_prefix p1 p2 m2 then
221             if zero_bit p1 m2 then 
222               subset l1 l2 && subset r1 l2
223             else 
224               subset l1 r2 && subset r1 r2
225           else
226             false
227
228               
229   let union s1 s2 = merge s1 s2
230     (* Todo replace with e Memo Module *)
231   module MemUnion = Hashtbl.Make(
232     struct 
233       type set = t 
234       type t = set*set 
235       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
236       let equal a b = equal a b || equal b a
237       let hash (x,y) =   (* commutative hash *)
238         let x = HNode.hash x 
239         and y = HNode.hash y 
240         in
241           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
242     end)
243   let h_mem = MemUnion.create MED_H_SIZE
244
245   let mem_union s1 s2 = 
246     try  MemUnion.find h_mem (s1,s2) 
247     with Not_found ->
248           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
249       
250
251   let rec inter s1 s2 = 
252     if equal s1 s2 
253     then s1
254     else
255       match (HNode.node s1,HNode.node s2) with
256         | Empty, _ -> empty
257         | _, Empty -> empty
258         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
259         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
260         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
261             if m1 == m2 && p1 == p2 then 
262               merge (inter l1 l2)  (inter r1 r2)
263             else if m1 > m2 && match_prefix p2 p1 m1 then
264               inter (if zero_bit p2 m1 then l1 else r1) s2
265             else if m1 < m2 && match_prefix p1 p2 m2 then
266               inter s1 (if zero_bit p1 m2 then l2 else r2)
267             else
268               empty
269
270   let rec diff s1 s2 = 
271     if equal s1 s2 
272     then empty
273     else
274       match (HNode.node s1,HNode.node s2) with
275         | Empty, _ -> empty
276         | _, Empty -> s1
277         | Leaf k1, _ -> if mem k1 s2 then empty else s1
278         | _, Leaf k2 -> remove k2 s1
279         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
280             if m1 == m2 && p1 == p2 then
281               merge (diff l1 l2) (diff r1 r2)
282             else if m1 > m2 && match_prefix p2 p1 m1 then
283               if zero_bit p2 m1 then 
284                 merge (diff l1 s2) r1
285               else 
286                 merge l1 (diff r1 s2)
287             else if m1 < m2 && match_prefix p1 p2 m2 then
288               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
289             else
290           s1
291                
292
293 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
294     [exists], [filter], [partition], [choose], [elements]) are
295     implemented as for any other kind of binary trees. *)
296
297 let rec cardinal n = match HNode.node n with
298   | Empty -> 0
299   | Leaf _ -> 1
300   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
301
302 let rec iter f n = match HNode.node n with
303   | Empty -> ()
304   | Leaf k -> f k
305   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
306       
307 let rec fold f s accu = match HNode.node s with
308   | Empty -> accu
309   | Leaf k -> f k accu
310   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
311
312
313 let rec for_all p n = match HNode.node n with
314   | Empty -> true
315   | Leaf k -> p k
316   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
317
318 let rec exists p n = match HNode.node n with
319   | Empty -> false
320   | Leaf k -> p k
321   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
322
323 let rec filter pr n = match HNode.node n with
324   | Empty -> empty
325   | Leaf k -> if pr k then n else empty
326   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
327
328 let partition p s =
329   let rec part (t,f as acc) n = match HNode.node n with
330     | Empty -> acc
331     | Leaf k -> if p k then (add k t, f) else (t, add k f)
332     | Branch (_,_,t0,t1) -> part (part acc t0) t1
333   in
334   part (empty, empty) s
335
336 let rec choose n = match HNode.node n with
337   | Empty -> raise Not_found
338   | Leaf k -> k
339   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
340
341
342 let split x s =
343   let coll k (l, b, r) =
344     if k < x then add k l, b, r
345     else if k > x then l, b, add k r
346     else l, true, r 
347   in
348   fold coll s (empty, false, empty)
349
350 (*s Additional functions w.r.t to [Set.S]. *)
351
352 let rec intersect s1 s2 = (equal s1 s2) ||
353   match (HNode.node s1,HNode.node s2) with
354   | Empty, _ -> false
355   | _, Empty -> false
356   | Leaf k1, _ -> mem k1 s2
357   | _, Leaf k2 -> mem k2 s1
358   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
359       if m1 == m2 && p1 == p2 then
360         intersect l1 l2 || intersect r1 r2
361       else if m1 < m2 && match_prefix p2 p1 m1 then
362         intersect (if zero_bit p2 m1 then l1 else r1) s2
363       else if m1 > m2 && match_prefix p1 p2 m2 then
364         intersect s1 (if zero_bit p1 m2 then l2 else r2)
365       else
366         false
367
368
369
370 let rec uncons n = match HNode.node n with
371   | Empty -> raise Not_found
372   | Leaf k -> (k,empty)
373   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
374    
375 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
376
377
378 end
379
380 module Int : S with type elt = int 
381   =
382   Make ( struct type t = int 
383                 type data = t
384                 external hash : t -> int = "%identity"
385                 external uid : t -> int = "%identity"
386                 let equal : t -> t -> bool = (==)
387                 external make : t -> int = "%identity"
388                 external node : t -> int = "%identity"
389                   
390          end
391        )