Merge branch 'handle-stdout'
[SXSI/xpathcomp.git] / src / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S =
10 sig
11     type elt
12
13   type 'a node
14   module rec Node : sig
15     include  Hcons.S with type data = Data.t
16   end
17   and Data : sig
18     include
19       Hashtbl.HashedType with type t = Node.t node
20   end
21   type data = Data.t
22   type t = Node.t
23
24
25   val empty : t
26   val is_empty : t -> bool
27   val mem : elt -> t -> bool
28   val add : elt -> t -> t
29   val singleton : elt -> t
30   val remove : elt -> t -> t
31   val union : t -> t -> t
32   val inter : t -> t -> t
33   val diff : t -> t -> t
34   val compare : t -> t -> int
35   val equal : t -> t -> bool
36   val subset : t -> t -> bool
37   val iter : (elt -> unit) -> t -> unit
38   val fold : (elt -> 'a -> 'a) -> t -> 'a -> 'a
39   val for_all : (elt -> bool) -> t -> bool
40   val exists : (elt -> bool) -> t -> bool
41   val filter : (elt -> bool) -> t -> t
42   val partition : (elt -> bool) -> t -> t * t
43   val cardinal : t -> int
44   val elements : t -> elt list
45   val min_elt : t -> elt
46   val max_elt : t -> elt
47   val choose : t -> elt
48   val split : elt -> t -> t * bool * t
49
50   val intersect : t -> t -> bool
51   val is_singleton : t -> bool
52   val mem_union : t -> t -> t
53   val hash : t -> int
54   val uid : t -> Uid.t
55   val uncons : t -> elt*t
56   val from_list : elt list -> t
57   val make : data -> t
58   val node : t -> data
59   val stats : unit -> unit
60   val init : unit -> unit
61 end
62
63 module Make ( H : Hcons.SA ) : S with type elt = H.t =
64 struct
65   type elt = H.t
66   type 'a node =
67     | Empty
68     | Leaf of elt
69     | Branch of int * int * 'a * 'a
70
71   module rec Node : Hcons.S with type data = Data.t = Hcons.Make (Data)
72   and Data : Hashtbl.HashedType  with type t = Node.t node =
73   struct
74     type t =  Node.t node
75     let equal x y =
76       match x,y with
77         | Empty,Empty -> true
78         | Leaf k1, Leaf k2 ->  k1 == k2
79         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
80             b1 == b2 && i1 == i2 &&
81               (Node.equal l1 l2) &&
82               (Node.equal r1 r2)
83         | _ -> false
84     let hash = function
85       | Empty -> 0
86       | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
87       | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.Node.id, Uid.to_int r.Node.id)
88   end
89
90   type data = Data.t
91   type t = Node.t
92   let stats = Node.stats
93   let init = Node.init
94   let hash = Node.hash
95   let uid = Node.uid
96   let make = Node.make
97   let node _ = failwith "node"
98   let empty = Node.make Empty
99
100   let is_empty s = (Node.node s) == Empty
101
102   let branch p m l r = Node.make (Branch(p,m,l,r))
103
104   let leaf k = Node.make (Leaf k)
105
106   (* To enforce the invariant that a branch contains two non empty sub-trees *)
107   let branch_ne p m t0 t1 =
108     if (is_empty t0) then t1
109     else if is_empty t1 then t0 else branch p m t0 t1
110
111   (********** from here on, only use the smart constructors *************)
112
113   let zero_bit k m = (k land m) == 0
114
115   let singleton k = leaf k
116
117   let is_singleton n =
118     match Node.node n with Leaf _ -> true
119       | _ -> false
120
121   let mem (k:elt) n =
122     let kid = Uid.to_int (H.uid k) in
123     let rec loop n = match Node.node n with
124       | Empty -> false
125       | Leaf j ->  k == j
126       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
127     in loop n
128
129   let rec min_elt n = match Node.node n with
130     | Empty -> raise Not_found
131     | Leaf k -> k
132     | Branch (_,_,s,_) -> min_elt s
133
134   let rec max_elt n = match Node.node n with
135     | Empty -> raise Not_found
136     | Leaf k -> k
137     | Branch (_,_,_,t) -> max_elt t
138
139   let elements s =
140     let rec elements_aux acc n = match Node.node n with
141       | Empty -> acc
142       | Leaf k -> k :: acc
143       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
144     in
145       elements_aux [] s
146
147   let mask k m  = (k lor (m-1)) land (lnot m)
148
149   let naive_highest_bit x =
150     assert (x < 256);
151     let rec loop i =
152       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
153     in
154       loop 7
155
156   let hbit = Array.init 256 naive_highest_bit
157
158   external clz : int -> int = "caml_clz" "noalloc"
159   external leading_bit : int -> int = "caml_leading_bit" "noalloc"
160   let highest_bit x =
161     try
162       let n = (x) lsr 24 in
163         if n != 0 then  hbit.(n) lsl 24
164         else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
165         else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
166         else hbit.(x)
167     with
168         _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
169
170   let highest_bit64 x =
171     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
172       else highest_bit x
173
174   let branching_bit p0 p1 = leading_bit (p0 lxor p1)
175
176   let join p0 t0 p1 t1 =
177     let m = branching_bit p0 p1  in
178     let msk = mask p0 m in
179       if zero_bit p0 m then
180         branch_ne msk m t0 t1
181       else
182         branch_ne msk m t1 t0
183
184   let match_prefix k p m = (mask k m) == p
185
186   let add k t =
187     let kid = Uid.to_int (H.uid k) in
188       assert (kid >=0);
189     let rec ins n = match Node.node n with
190       | Empty -> leaf k
191       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
192       | Branch (p,m,t0,t1)  ->
193           if match_prefix kid p m then
194             if zero_bit kid m then
195               branch_ne p m (ins t0) t1
196             else
197               branch_ne p m t0 (ins t1)
198           else
199             join kid (leaf k)  p n
200     in
201     ins t
202
203   let remove k t =
204     let kid = Uid.to_int(H.uid k) in
205     let rec rmv n = match Node.node n with
206       | Empty -> empty
207       | Leaf j  -> if  k == j then empty else n
208       | Branch (p,m,t0,t1) ->
209           if match_prefix kid p m then
210             if zero_bit kid m then
211               branch_ne p m (rmv t0) t1
212             else
213               branch_ne p m t0 (rmv t1)
214           else
215             n
216     in
217     rmv t
218
219   (* should run in O(1) thanks to Hash consing *)
220
221   let equal a b = Node.equal a b
222
223   let compare a b =  (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
224
225   let rec merge s t =
226     if (equal s t) (* This is cheap thanks to hash-consing *)
227     then s
228     else
229     match Node.node s, Node.node t with
230       | Empty, _  -> t
231       | _, Empty  -> s
232       | Leaf k, _ -> add k t
233       | _, Leaf k -> add k s
234       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
235           if m == n && match_prefix q p m then
236             branch p  m  (merge s0 t0) (merge s1 t1)
237           else if m > n && match_prefix q p m then
238             if zero_bit q m then
239               branch_ne p m (merge s0 t) s1
240             else
241               branch_ne p m s0 (merge s1 t)
242           else if m < n && match_prefix p q n then
243             if zero_bit p n then
244               branch_ne q n (merge s t0) t1
245             else
246               branch_ne q n t0 (merge s t1)
247           else
248             (* The prefixes disagree. *)
249             join p s q t
250
251
252
253
254   let rec subset s1 s2 = (equal s1 s2) ||
255     match (Node.node s1,Node.node s2) with
256       | Empty, _ -> true
257       | _, Empty -> false
258       | Leaf k1, _ -> mem k1 s2
259       | Branch _, Leaf _ -> false
260       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
261           if m1 == m2 && p1 == p2 then
262             subset l1 l2 && subset r1 r2
263           else if m1 < m2 && match_prefix p1 p2 m2 then
264             if zero_bit p1 m2 then
265               subset l1 l2 && subset r1 l2
266             else
267               subset l1 r2 && subset r1 r2
268           else
269             false
270
271
272   let union s1 s2 = merge s1 s2
273     (* Todo replace with e Memo Module *)
274   module MemUnion = Hashtbl.Make(
275     struct
276       type set = t
277       type t = set*set
278       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
279       let equal a b = equal a b || equal b a
280       let hash (x,y) =   (* commutative hash *)
281         let x = Node.hash x
282         and y = Node.hash y
283         in
284           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
285     end)
286   let h_mem = MemUnion.create MED_H_SIZE
287
288   let mem_union s1 s2 =
289     try  MemUnion.find h_mem (s1,s2)
290     with Not_found ->
291           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r
292
293
294   let rec inter s1 s2 =
295     if equal s1 s2
296     then s1
297     else
298       match (Node.node s1,Node.node s2) with
299         | Empty, _ -> empty
300         | _, Empty -> empty
301         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
302         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
303         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
304             if m1 == m2 && p1 == p2 then
305               merge (inter l1 l2)  (inter r1 r2)
306             else if m1 > m2 && match_prefix p2 p1 m1 then
307               inter (if zero_bit p2 m1 then l1 else r1) s2
308             else if m1 < m2 && match_prefix p1 p2 m2 then
309               inter s1 (if zero_bit p1 m2 then l2 else r2)
310             else
311               empty
312
313   let rec diff s1 s2 =
314     if equal s1 s2
315     then empty
316     else
317       match (Node.node s1,Node.node s2) with
318       | Empty, _ -> empty
319       | _, Empty -> s1
320       | Leaf k1, _ -> if mem k1 s2 then empty else s1
321       | _, Leaf k2 -> remove k2 s1
322       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
323         if m1 == m2 && p1 == p2 then
324           merge (diff l1 l2) (diff r1 r2)
325         else if m1 > m2 && match_prefix p2 p1 m1 then
326           if zero_bit p2 m1 then
327             merge (diff l1 s2) r1
328           else
329             merge l1 (diff r1 s2)
330         else if m1 < m2 && match_prefix p1 p2 m2 then
331           if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
332         else
333           s1
334
335
336 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
337     [exists], [filter], [partition], [choose], [elements]) are
338     implemented as for any other kind of binary trees. *)
339
340 let rec cardinal n = match Node.node n with
341   | Empty -> 0
342   | Leaf _ -> 1
343   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
344
345 let rec iter f n = match Node.node n with
346   | Empty -> ()
347   | Leaf k -> f k
348   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
349
350 let rec fold f s accu = match Node.node s with
351   | Empty -> accu
352   | Leaf k -> f k accu
353   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
354
355
356 let rec for_all p n = match Node.node n with
357   | Empty -> true
358   | Leaf k -> p k
359   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
360
361 let rec exists p n = match Node.node n with
362   | Empty -> false
363   | Leaf k -> p k
364   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
365
366 let rec filter pr n = match Node.node n with
367   | Empty -> empty
368   | Leaf k -> if pr k then n else empty
369   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
370
371 let partition p s =
372   let rec part (t,f as acc) n = match Node.node n with
373     | Empty -> acc
374     | Leaf k -> if p k then (add k t, f) else (t, add k f)
375     | Branch (_,_,t0,t1) -> part (part acc t0) t1
376   in
377   part (empty, empty) s
378
379 let rec choose n = match Node.node n with
380   | Empty -> raise Not_found
381   | Leaf k -> k
382   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
383
384
385 let split x s =
386   let coll k (l, b, r) =
387     if k < x then add k l, b, r
388     else if k > x then l, b, add k r
389     else l, true, r
390   in
391   fold coll s (empty, false, empty)
392
393 (*s Additional functions w.r.t to [Set.S]. *)
394
395 let rec intersect s1 s2 = (equal s1 s2) ||
396   match (Node.node s1,Node.node s2) with
397   | Empty, _ -> false
398   | _, Empty -> false
399   | Leaf k1, _ -> mem k1 s2
400   | _, Leaf k2 -> mem k2 s1
401   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
402       if m1 == m2 && p1 == p2 then
403         intersect l1 l2 || intersect r1 r2
404       else if m1 < m2 && match_prefix p2 p1 m1 then
405         intersect (if zero_bit p2 m1 then l1 else r1) s2
406       else if m1 > m2 && match_prefix p1 p2 m2 then
407         intersect s1 (if zero_bit p1 m2 then l2 else r2)
408       else
409         false
410
411
412
413 let rec uncons n = match Node.node n with
414   | Empty -> raise Not_found
415   | Leaf k -> (k,empty)
416   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
417
418 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
419
420
421 end
422
423 module Int : sig
424   include S with type elt = int
425   val print : Format.formatter -> t -> unit
426 end
427   =
428 struct
429   include Make ( struct type t = int
430                         type data = t
431                         external hash : t -> int = "%identity"
432                         external uid : t -> Uid.t = "%identity"
433                         external equal : t -> t -> bool = "%eq"
434                         external make : t -> int = "%identity"
435                         external node : t -> int = "%identity"
436                         external stats : unit -> unit = "%identity"
437                         external init : unit -> unit = "%identity"
438                  end
439                )
440   let print ppf s =
441     Format.pp_print_string ppf "{ ";
442     iter (fun i -> Format.fprintf ppf "%i " i) s;
443     Format.pp_print_string ppf "}";
444     Format.pp_print_flush ppf ()
445  end