.
[SXSI/xpathcomp.git] / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S = 
10 sig
11     type elt
12
13   type 'a node
14   module rec Node : sig
15     include  Hcons.S with type data = Data.t
16   end
17   and Data : sig 
18     include 
19       Hashtbl.HashedType with type t = Node.t node
20   end
21   type data = Data.t
22   type t = Node.t
23
24
25   val empty : t
26   val is_empty : t -> bool
27   val mem : elt -> t -> bool
28   val add : elt -> t -> t
29   val singleton : elt -> t
30   val remove : elt -> t -> t
31   val union : t -> t -> t
32   val inter : t -> t -> t
33   val diff : t -> t -> t
34   val compare : t -> t -> int
35   val equal : t -> t -> bool
36   val subset : t -> t -> bool
37   val iter : (elt -> unit) -> t -> unit
38   val fold : (elt -> 'a -> 'a) -> t -> 'a -> 'a
39   val for_all : (elt -> bool) -> t -> bool
40   val exists : (elt -> bool) -> t -> bool
41   val filter : (elt -> bool) -> t -> t
42   val partition : (elt -> bool) -> t -> t * t
43   val cardinal : t -> int
44   val elements : t -> elt list
45   val min_elt : t -> elt
46   val max_elt : t -> elt
47   val choose : t -> elt
48   val split : elt -> t -> t * bool * t   
49
50   val intersect : t -> t -> bool
51   val is_singleton : t -> bool
52   val mem_union : t -> t -> t
53   val hash : t -> int
54   val uid : t -> Uid.t
55   val uncons : t -> elt*t
56   val from_list : elt list -> t 
57   val make : data -> t
58   val node : t -> data
59     
60   val with_id : Uid.t -> t
61 end
62
63 module Make ( H : Hcons.SA ) : S with type elt = H.t =
64 struct
65   type elt = H.t
66   type 'a node =
67     | Empty
68     | Leaf of elt
69     | Branch of int * int * 'a * 'a
70
71   module rec Node : Hcons.S with type data = Data.t = Hcons.Make (Data)
72   and Data : Hashtbl.HashedType  with type t = Node.t node =
73   struct 
74     type t =  Node.t node
75     let equal x y = 
76       match x,y with
77         | Empty,Empty -> true
78         | Leaf k1, Leaf k2 ->  k1 == k2
79         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
80             b1 == b2 && i1 == i2 &&
81               (Node.equal l1 l2) &&
82               (Node.equal r1 r2) 
83         | _ -> false
84     let hash = function 
85       | Empty -> 0
86       | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
87       | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.Node.id, Uid.to_int r.Node.id)
88   end
89  
90   type data = Data.t
91   type t = Node.t
92
93   let hash = Node.hash 
94   let uid = Node.uid
95   let make = Node.make
96   let node _ = failwith "node"
97   let empty = Node.make Empty
98     
99   let is_empty s = (Node.node s) == Empty
100        
101   let branch p m l r = Node.make (Branch(p,m,l,r))
102
103   let leaf k = Node.make (Leaf k)
104
105   (* To enforce the invariant that a branch contains two non empty sub-trees *)
106   let branch_ne p m t0 t1 = 
107     if (is_empty t0) then t1
108     else if is_empty t1 then t0 else branch p m t0 t1
109       
110   (********** from here on, only use the smart constructors *************)
111       
112   let zero_bit k m = (k land m) == 0
113     
114   let singleton k = leaf k
115     
116   let is_singleton n = 
117     match Node.node n with Leaf _ -> true
118       | _ -> false
119           
120   let mem (k:elt) n = 
121     let kid = Uid.to_int (H.uid k) in
122     let rec loop n = match Node.node n with
123       | Empty -> false
124       | Leaf j ->  k == j
125       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
126     in loop n
127          
128   let rec min_elt n = match Node.node n with
129     | Empty -> raise Not_found
130     | Leaf k -> k
131     | Branch (_,_,s,_) -> min_elt s
132         
133   let rec max_elt n = match Node.node n with
134     | Empty -> raise Not_found
135     | Leaf k -> k
136     | Branch (_,_,_,t) -> max_elt t
137         
138   let elements s =
139     let rec elements_aux acc n = match Node.node n with
140       | Empty -> acc
141       | Leaf k -> k :: acc
142       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
143     in
144       elements_aux [] s
145         
146   let mask k m  = (k lor (m-1)) land (lnot m)
147     
148   let naive_highest_bit x = 
149     assert (x < 256);
150     let rec loop i = 
151       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
152     in
153       loop 7
154         
155   let hbit = Array.init 256 naive_highest_bit
156
157
158   let highest_bit x = let n = (x) lsr 24 in 
159   if n != 0 then Array.unsafe_get hbit n lsl 24
160   else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
161   else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
162   else Array.unsafe_get hbit (x)
163
164 IFDEF WORDIZE64
165 THEN
166   let highest_bit64 x =
167     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
168       else highest_bit x
169 END
170
171         
172   let branching_bit p0 p1 = highest_bit (p0 lxor p1)
173     
174   let join p0 t0 p1 t1 =  
175     let m = branching_bit p0 p1  in
176       if zero_bit p0 m then 
177         branch (mask p0 m) m t0 t1
178       else 
179         branch (mask p0 m) m t1 t0
180           
181   let match_prefix k p m = (mask k m) == p
182     
183   let add k t =
184     let kid = Uid.to_int (H.uid k) in
185     let rec ins n = match Node.node n with
186       | Empty -> leaf k
187       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
188       | Branch (p,m,t0,t1)  ->
189           if match_prefix kid p m then
190             if zero_bit kid m then 
191               branch p m (ins t0) t1
192             else
193               branch p m t0 (ins t1)
194           else
195             join kid (leaf k)  p n
196     in
197     ins t
198       
199   let remove k t =
200     let kid = Uid.to_int(H.uid k) in
201     let rec rmv n = match Node.node n with
202       | Empty -> empty
203       | Leaf j  -> if  k == j then empty else n
204       | Branch (p,m,t0,t1) -> 
205           if match_prefix kid p m then
206             if zero_bit kid m then
207               branch_ne p m (rmv t0) t1
208             else
209               branch_ne p m t0 (rmv t1)
210           else
211             n
212     in
213     rmv t
214       
215   (* should run in O(1) thanks to Hash consing *)
216
217   let equal a b = Node.equal a b 
218
219   let compare a b =  (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
220
221   let rec merge s t = 
222     if (equal s t) (* This is cheap thanks to hash-consing *)
223     then s
224     else
225     match Node.node s, Node.node t with
226       | Empty, _  -> t
227       | _, Empty  -> s
228       | Leaf k, _ -> add k t
229       | _, Leaf k -> add k s
230       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
231           if m == n && match_prefix q p m then
232             branch p  m  (merge s0 t0) (merge s1 t1)
233           else if m > n && match_prefix q p m then
234             if zero_bit q m then 
235               branch p m (merge s0 t) s1
236             else 
237               branch p m s0 (merge s1 t)
238           else if m < n && match_prefix p q n then     
239             if zero_bit p n then
240               branch q n (merge s t0) t1
241             else
242               branch q n t0 (merge s t1)
243           else
244             (* The prefixes disagree. *)
245             join p s q t
246                
247         
248                
249                
250   let rec subset s1 s2 = (equal s1 s2) ||
251     match (Node.node s1,Node.node s2) with
252       | Empty, _ -> true
253       | _, Empty -> false
254       | Leaf k1, _ -> mem k1 s2
255       | Branch _, Leaf _ -> false
256       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
257           if m1 == m2 && p1 == p2 then
258             subset l1 l2 && subset r1 r2
259           else if m1 < m2 && match_prefix p1 p2 m2 then
260             if zero_bit p1 m2 then 
261               subset l1 l2 && subset r1 l2
262             else 
263               subset l1 r2 && subset r1 r2
264           else
265             false
266
267               
268   let union s1 s2 = merge s1 s2
269     (* Todo replace with e Memo Module *)
270   module MemUnion = Hashtbl.Make(
271     struct 
272       type set = t 
273       type t = set*set 
274       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
275       let equal a b = equal a b || equal b a
276       let hash (x,y) =   (* commutative hash *)
277         let x = Node.hash x 
278         and y = Node.hash y 
279         in
280           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
281     end)
282   let h_mem = MemUnion.create MED_H_SIZE
283
284   let mem_union s1 s2 = 
285     try  MemUnion.find h_mem (s1,s2) 
286     with Not_found ->
287           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
288       
289
290   let rec inter s1 s2 = 
291     if equal s1 s2 
292     then s1
293     else
294       match (Node.node s1,Node.node s2) with
295         | Empty, _ -> empty
296         | _, Empty -> empty
297         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
298         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
299         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
300             if m1 == m2 && p1 == p2 then 
301               merge (inter l1 l2)  (inter r1 r2)
302             else if m1 > m2 && match_prefix p2 p1 m1 then
303               inter (if zero_bit p2 m1 then l1 else r1) s2
304             else if m1 < m2 && match_prefix p1 p2 m2 then
305               inter s1 (if zero_bit p1 m2 then l2 else r2)
306             else
307               empty
308
309   let rec diff s1 s2 = 
310     if equal s1 s2 
311     then empty
312     else
313       match (Node.node s1,Node.node s2) with
314         | Empty, _ -> empty
315         | _, Empty -> s1
316         | Leaf k1, _ -> if mem k1 s2 then empty else s1
317         | _, Leaf k2 -> remove k2 s1
318         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
319             if m1 == m2 && p1 == p2 then
320               merge (diff l1 l2) (diff r1 r2)
321             else if m1 > m2 && match_prefix p2 p1 m1 then
322               if zero_bit p2 m1 then 
323                 merge (diff l1 s2) r1
324               else 
325                 merge l1 (diff r1 s2)
326             else if m1 < m2 && match_prefix p1 p2 m2 then
327               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
328             else
329           s1
330                
331
332 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
333     [exists], [filter], [partition], [choose], [elements]) are
334     implemented as for any other kind of binary trees. *)
335
336 let rec cardinal n = match Node.node n with
337   | Empty -> 0
338   | Leaf _ -> 1
339   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
340
341 let rec iter f n = match Node.node n with
342   | Empty -> ()
343   | Leaf k -> f k
344   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
345       
346 let rec fold f s accu = match Node.node s with
347   | Empty -> accu
348   | Leaf k -> f k accu
349   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
350
351
352 let rec for_all p n = match Node.node n with
353   | Empty -> true
354   | Leaf k -> p k
355   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
356
357 let rec exists p n = match Node.node n with
358   | Empty -> false
359   | Leaf k -> p k
360   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
361
362 let rec filter pr n = match Node.node n with
363   | Empty -> empty
364   | Leaf k -> if pr k then n else empty
365   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
366
367 let partition p s =
368   let rec part (t,f as acc) n = match Node.node n with
369     | Empty -> acc
370     | Leaf k -> if p k then (add k t, f) else (t, add k f)
371     | Branch (_,_,t0,t1) -> part (part acc t0) t1
372   in
373   part (empty, empty) s
374
375 let rec choose n = match Node.node n with
376   | Empty -> raise Not_found
377   | Leaf k -> k
378   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
379
380
381 let split x s =
382   let coll k (l, b, r) =
383     if k < x then add k l, b, r
384     else if k > x then l, b, add k r
385     else l, true, r 
386   in
387   fold coll s (empty, false, empty)
388
389 (*s Additional functions w.r.t to [Set.S]. *)
390
391 let rec intersect s1 s2 = (equal s1 s2) ||
392   match (Node.node s1,Node.node s2) with
393   | Empty, _ -> false
394   | _, Empty -> false
395   | Leaf k1, _ -> mem k1 s2
396   | _, Leaf k2 -> mem k2 s1
397   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
398       if m1 == m2 && p1 == p2 then
399         intersect l1 l2 || intersect r1 r2
400       else if m1 < m2 && match_prefix p2 p1 m1 then
401         intersect (if zero_bit p2 m1 then l1 else r1) s2
402       else if m1 > m2 && match_prefix p1 p2 m2 then
403         intersect s1 (if zero_bit p1 m2 then l2 else r2)
404       else
405         false
406
407
408
409 let rec uncons n = match Node.node n with
410   | Empty -> raise Not_found
411   | Leaf k -> (k,empty)
412   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
413    
414 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
415
416 let with_id = Node.with_id
417 end
418
419 module Int : sig
420   include S with type elt = int
421   val print : Format.formatter -> t -> unit
422 end
423   = 
424 struct
425   include Make ( struct type t = int 
426                         type data = t
427                         external hash : t -> int = "%identity"
428                         external uid : t -> Uid.t = "%identity"
429                         external equal : t -> t -> bool = "%eq"
430                         external make : t -> int = "%identity"
431                         external node : t -> int = "%identity"
432                         external with_id : Uid.t -> t = "%identity"
433                  end
434                ) 
435   let print ppf s = 
436     Format.pp_print_string ppf "{ ";
437     iter (fun i -> Format.fprintf ppf "%i " i) s;
438     Format.pp_print_string ppf "}";
439     Format.pp_print_flush ppf ()
440  end