bb58049d2c8bd8040c761e6c99fbc1af8245607f
[SXSI/xpathcomp.git] / src / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S =
10 sig
11     type elt
12
13   type 'a node
14   module rec Node : sig
15     include  Hcons.S with type data = Data.t
16   end
17   and Data : sig
18     include
19       Hashtbl.HashedType with type t = Node.t node
20   end
21   type data = Data.t
22   type t = Node.t
23
24
25   val empty : t
26   val is_empty : t -> bool
27   val mem : elt -> t -> bool
28   val add : elt -> t -> t
29   val singleton : elt -> t
30   val remove : elt -> t -> t
31   val union : t -> t -> t
32   val inter : t -> t -> t
33   val diff : t -> t -> t
34   val compare : t -> t -> int
35   val equal : t -> t -> bool
36   val subset : t -> t -> bool
37   val iter : (elt -> unit) -> t -> unit
38   val fold : (elt -> 'a -> 'a) -> t -> 'a -> 'a
39   val for_all : (elt -> bool) -> t -> bool
40   val exists : (elt -> bool) -> t -> bool
41   val filter : (elt -> bool) -> t -> t
42   val partition : (elt -> bool) -> t -> t * t
43   val cardinal : t -> int
44   val elements : t -> elt list
45   val min_elt : t -> elt
46   val max_elt : t -> elt
47   val choose : t -> elt
48   val split : elt -> t -> t * bool * t
49
50   val intersect : t -> t -> bool
51   val is_singleton : t -> bool
52   val mem_union : t -> t -> t
53   val hash : t -> int
54   val uid : t -> Uid.t
55   val uncons : t -> elt*t
56   val from_list : elt list -> t
57   val make : data -> t
58   val node : t -> data
59   val stats : unit -> unit
60 end
61
62 module Make ( H : Hcons.SA ) : S with type elt = H.t =
63 struct
64   type elt = H.t
65   type 'a node =
66     | Empty
67     | Leaf of elt
68     | Branch of int * int * 'a * 'a
69
70   module rec Node : Hcons.S with type data = Data.t = Hcons.Make (Data)
71   and Data : Hashtbl.HashedType  with type t = Node.t node =
72   struct
73     type t =  Node.t node
74     let equal x y =
75       match x,y with
76         | Empty,Empty -> true
77         | Leaf k1, Leaf k2 ->  k1 == k2
78         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
79             b1 == b2 && i1 == i2 &&
80               (Node.equal l1 l2) &&
81               (Node.equal r1 r2)
82         | _ -> false
83     let hash = function
84       | Empty -> 0
85       | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
86       | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.Node.id, Uid.to_int r.Node.id)
87   end
88
89   type data = Data.t
90   type t = Node.t
91   let stats = Node.stats
92   let hash = Node.hash
93   let uid = Node.uid
94   let make = Node.make
95   let node _ = failwith "node"
96   let empty = Node.make Empty
97
98   let is_empty s = (Node.node s) == Empty
99
100   let branch p m l r = Node.make (Branch(p,m,l,r))
101
102   let leaf k = Node.make (Leaf k)
103
104   (* To enforce the invariant that a branch contains two non empty sub-trees *)
105   let branch_ne p m t0 t1 =
106     if (is_empty t0) then t1
107     else if is_empty t1 then t0 else branch p m t0 t1
108
109   (********** from here on, only use the smart constructors *************)
110
111   let zero_bit k m = (k land m) == 0
112
113   let singleton k = leaf k
114
115   let is_singleton n =
116     match Node.node n with Leaf _ -> true
117       | _ -> false
118
119   let mem (k:elt) n =
120     let kid = Uid.to_int (H.uid k) in
121     let rec loop n = match Node.node n with
122       | Empty -> false
123       | Leaf j ->  k == j
124       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
125     in loop n
126
127   let rec min_elt n = match Node.node n with
128     | Empty -> raise Not_found
129     | Leaf k -> k
130     | Branch (_,_,s,_) -> min_elt s
131
132   let rec max_elt n = match Node.node n with
133     | Empty -> raise Not_found
134     | Leaf k -> k
135     | Branch (_,_,_,t) -> max_elt t
136
137   let elements s =
138     let rec elements_aux acc n = match Node.node n with
139       | Empty -> acc
140       | Leaf k -> k :: acc
141       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
142     in
143       elements_aux [] s
144
145   let mask k m  = (k lor (m-1)) land (lnot m)
146
147   let naive_highest_bit x =
148     assert (x < 256);
149     let rec loop i =
150       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
151     in
152       loop 7
153
154   let hbit = Array.init 256 naive_highest_bit
155
156   external clz : int -> int = "caml_clz" "noalloc"
157   external leading_bit : int -> int = "caml_leading_bit" "noalloc"
158   let highest_bit x =
159     try
160       let n = (x) lsr 24 in
161         if n != 0 then  hbit.(n) lsl 24
162         else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
163         else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
164         else hbit.(x)
165     with
166         _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
167
168   let highest_bit64 x =
169     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
170       else highest_bit x
171
172   let branching_bit p0 p1 = leading_bit (p0 lxor p1)
173
174   let join p0 t0 p1 t1 =
175     let m = branching_bit p0 p1  in
176     let msk = mask p0 m in
177       if zero_bit p0 m then
178         branch_ne msk m t0 t1
179       else
180         branch_ne msk m t1 t0
181
182   let match_prefix k p m = (mask k m) == p
183
184   let add k t =
185     let kid = Uid.to_int (H.uid k) in
186       assert (kid >=0);
187     let rec ins n = match Node.node n with
188       | Empty -> leaf k
189       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
190       | Branch (p,m,t0,t1)  ->
191           if match_prefix kid p m then
192             if zero_bit kid m then
193               branch_ne p m (ins t0) t1
194             else
195               branch_ne p m t0 (ins t1)
196           else
197             join kid (leaf k)  p n
198     in
199     ins t
200
201   let remove k t =
202     let kid = Uid.to_int(H.uid k) in
203     let rec rmv n = match Node.node n with
204       | Empty -> empty
205       | Leaf j  -> if  k == j then empty else n
206       | Branch (p,m,t0,t1) ->
207           if match_prefix kid p m then
208             if zero_bit kid m then
209               branch_ne p m (rmv t0) t1
210             else
211               branch_ne p m t0 (rmv t1)
212           else
213             n
214     in
215     rmv t
216
217   (* should run in O(1) thanks to Hash consing *)
218
219   let equal a b = Node.equal a b
220
221   let compare a b =  (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
222
223   let rec merge s t =
224     if (equal s t) (* This is cheap thanks to hash-consing *)
225     then s
226     else
227     match Node.node s, Node.node t with
228       | Empty, _  -> t
229       | _, Empty  -> s
230       | Leaf k, _ -> add k t
231       | _, Leaf k -> add k s
232       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
233           if m == n && match_prefix q p m then
234             branch p  m  (merge s0 t0) (merge s1 t1)
235           else if m > n && match_prefix q p m then
236             if zero_bit q m then
237               branch_ne p m (merge s0 t) s1
238             else
239               branch_ne p m s0 (merge s1 t)
240           else if m < n && match_prefix p q n then
241             if zero_bit p n then
242               branch_ne q n (merge s t0) t1
243             else
244               branch_ne q n t0 (merge s t1)
245           else
246             (* The prefixes disagree. *)
247             join p s q t
248
249
250
251
252   let rec subset s1 s2 = (equal s1 s2) ||
253     match (Node.node s1,Node.node s2) with
254       | Empty, _ -> true
255       | _, Empty -> false
256       | Leaf k1, _ -> mem k1 s2
257       | Branch _, Leaf _ -> false
258       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
259           if m1 == m2 && p1 == p2 then
260             subset l1 l2 && subset r1 r2
261           else if m1 < m2 && match_prefix p1 p2 m2 then
262             if zero_bit p1 m2 then
263               subset l1 l2 && subset r1 l2
264             else
265               subset l1 r2 && subset r1 r2
266           else
267             false
268
269
270   let union s1 s2 = merge s1 s2
271     (* Todo replace with e Memo Module *)
272   module MemUnion = Hashtbl.Make(
273     struct
274       type set = t
275       type t = set*set
276       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
277       let equal a b = equal a b || equal b a
278       let hash (x,y) =   (* commutative hash *)
279         let x = Node.hash x
280         and y = Node.hash y
281         in
282           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
283     end)
284   let h_mem = MemUnion.create MED_H_SIZE
285
286   let mem_union s1 s2 =
287     try  MemUnion.find h_mem (s1,s2)
288     with Not_found ->
289           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r
290
291
292   let rec inter s1 s2 =
293     if equal s1 s2
294     then s1
295     else
296       match (Node.node s1,Node.node s2) with
297         | Empty, _ -> empty
298         | _, Empty -> empty
299         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
300         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
301         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
302             if m1 == m2 && p1 == p2 then
303               merge (inter l1 l2)  (inter r1 r2)
304             else if m1 > m2 && match_prefix p2 p1 m1 then
305               inter (if zero_bit p2 m1 then l1 else r1) s2
306             else if m1 < m2 && match_prefix p1 p2 m2 then
307               inter s1 (if zero_bit p1 m2 then l2 else r2)
308             else
309               empty
310
311   let rec diff s1 s2 =
312     if equal s1 s2
313     then empty
314     else
315       match (Node.node s1,Node.node s2) with
316         | Empty, _ -> empty
317         | _, Empty -> s1
318         | Leaf k1, _ -> if mem k1 s2 then empty else s1
319         | _, Leaf k2 -> remove k2 s1
320         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
321             if m1 == m2 && p1 == p2 then
322               merge (diff l1 l2) (diff r1 r2)
323             else if m1 > m2 && match_prefix p2 p1 m1 then
324               if zero_bit p2 m1 then
325                 merge (diff l1 s2) r1
326               else
327                 merge l1 (diff r1 s2)
328             else if m1 < m2 && match_prefix p1 p2 m2 then
329               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
330             else
331           s1
332
333
334 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
335     [exists], [filter], [partition], [choose], [elements]) are
336     implemented as for any other kind of binary trees. *)
337
338 let rec cardinal n = match Node.node n with
339   | Empty -> 0
340   | Leaf _ -> 1
341   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
342
343 let rec iter f n = match Node.node n with
344   | Empty -> ()
345   | Leaf k -> f k
346   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
347
348 let rec fold f s accu = match Node.node s with
349   | Empty -> accu
350   | Leaf k -> f k accu
351   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
352
353
354 let rec for_all p n = match Node.node n with
355   | Empty -> true
356   | Leaf k -> p k
357   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
358
359 let rec exists p n = match Node.node n with
360   | Empty -> false
361   | Leaf k -> p k
362   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
363
364 let rec filter pr n = match Node.node n with
365   | Empty -> empty
366   | Leaf k -> if pr k then n else empty
367   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
368
369 let partition p s =
370   let rec part (t,f as acc) n = match Node.node n with
371     | Empty -> acc
372     | Leaf k -> if p k then (add k t, f) else (t, add k f)
373     | Branch (_,_,t0,t1) -> part (part acc t0) t1
374   in
375   part (empty, empty) s
376
377 let rec choose n = match Node.node n with
378   | Empty -> raise Not_found
379   | Leaf k -> k
380   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
381
382
383 let split x s =
384   let coll k (l, b, r) =
385     if k < x then add k l, b, r
386     else if k > x then l, b, add k r
387     else l, true, r
388   in
389   fold coll s (empty, false, empty)
390
391 (*s Additional functions w.r.t to [Set.S]. *)
392
393 let rec intersect s1 s2 = (equal s1 s2) ||
394   match (Node.node s1,Node.node s2) with
395   | Empty, _ -> false
396   | _, Empty -> false
397   | Leaf k1, _ -> mem k1 s2
398   | _, Leaf k2 -> mem k2 s1
399   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
400       if m1 == m2 && p1 == p2 then
401         intersect l1 l2 || intersect r1 r2
402       else if m1 < m2 && match_prefix p2 p1 m1 then
403         intersect (if zero_bit p2 m1 then l1 else r1) s2
404       else if m1 > m2 && match_prefix p1 p2 m2 then
405         intersect s1 (if zero_bit p1 m2 then l2 else r2)
406       else
407         false
408
409
410
411 let rec uncons n = match Node.node n with
412   | Empty -> raise Not_found
413   | Leaf k -> (k,empty)
414   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
415
416 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
417
418
419 end
420
421 module Int : sig
422   include S with type elt = int
423   val print : Format.formatter -> t -> unit
424 end
425   =
426 struct
427   include Make ( struct type t = int
428                         type data = t
429                         external hash : t -> int = "%identity"
430                         external uid : t -> Uid.t = "%identity"
431                         external equal : t -> t -> bool = "%eq"
432                         external make : t -> int = "%identity"
433                         external node : t -> int = "%identity"
434                         external stats : unit -> unit = "%identity"
435                  end
436                )
437   let print ppf s =
438     Format.pp_print_string ppf "{ ";
439     iter (fun i -> Format.fprintf ppf "%i " i) s;
440     Format.pp_print_string ppf "}";
441     Format.pp_print_flush ppf ()
442  end