Usable version:
[tatoo.git] / src / ptset.ml
1 (* Original file : *)
2
3 (***********************************************************************)
4 (*                                                                     *)
5 (*  Copyright (C) Jean-Christophe Filliatre                            *)
6 (*                                                                     *)
7 (*  This software is free software; you can redistribute it and/or     *)
8 (*  modify it under the terms of the GNU Library General Public        *)
9 (*  License version 2.1, with the special exception on linking         *)
10 (*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
11 (*                                                                     *)
12 (*  This software is distributed in the hope that it will be useful,   *)
13 (*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
14 (*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
15 (*                                                                     *)
16 (***********************************************************************)
17
18 (* Modified by Kim Nguyen *)
19 (* The Patricia trees are themselves deeply hashconsed. The module
20    provides a Make (and Weak) functor to build hashconsed patricia
21    trees whose elements are Abstract hashconsed values. This allows
22    to build sets of integers without boxing them in an hacons structure
23 *)
24 INCLUDE "utils.ml"
25
26 module type S =
27   sig
28     include Sigs.Set
29     include Hcons.S with type t := t
30   end
31
32 module type HConsBuilder =
33   functor (H : Sigs.HashedType) -> Hcons.S with type data = H.t
34
35 module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
36   S with type elt = H.t =
37 struct
38   type elt = H.t
39
40   type 'a set =
41     | Empty
42     | Leaf of elt
43     | Branch of int * int * 'a * 'a
44
45   module rec Node : Hcons.S with type data = Data.t = HCB(Data)
46   and Data : Sigs.HashedType with type t = Node.t set =
47   struct
48     type t =  Node.t set
49     let equal x y =
50       match x,y with
51         Empty,Empty -> true
52       | Leaf k1, Leaf k2 ->  k1 == k2
53       | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
54         b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)
55
56       | _ -> false
57
58     let hash = function
59       | Empty -> 0
60       | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
61       | Branch (b,i,l,r) ->
62         HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
63   end
64
65   include Node
66
67   let empty = Node.make Empty
68
69   let is_empty s = (Node.node s) == Empty
70
71   let branch p m l r = Node.make (Branch(p,m,l,r))
72
73   let leaf k = Node.make (Leaf k)
74
75   (* To enforce the invariant that a branch contains two non empty
76      sub-trees *)
77   let branch_ne p m t0 t1 =
78     if (is_empty t0) then t1
79     else if is_empty t1 then t0 else branch p m t0 t1
80
81   (******** from here on, only use the smart constructors ************)
82
83   let zero_bit k m = (k land m) == 0
84
85   let singleton k = leaf k
86
87   let is_singleton n =
88     match Node.node n with Leaf _ -> true
89       | _ -> false
90
91   let mem (k:elt) n =
92     let kid = Uid.to_int (H.uid k) in
93     let rec loop n = match Node.node n with
94       | Empty -> false
95       | Leaf j ->  k == j
96       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
97     in loop n
98
99   let rec min_elt n = match Node.node n with
100     | Empty -> raise Not_found
101     | Leaf k -> k
102     | Branch (_,_,s,_) -> min_elt s
103
104   let rec max_elt n = match Node.node n with
105     | Empty -> raise Not_found
106     | Leaf k -> k
107     | Branch (_,_,_,t) -> max_elt t
108
109   let elements s =
110     let rec elements_aux acc n = match Node.node n with
111       | Empty -> acc
112       | Leaf k -> k :: acc
113       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
114     in
115       elements_aux [] s
116
117   let mask k m  = (k lor (m-1)) land (lnot m)
118
119   let naive_highest_bit x =
120     assert (x < 256);
121     let rec loop i =
122       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
123     in
124       loop 7
125
126   let hbit = Array.init 256 naive_highest_bit
127 (*
128   external clz : int -> int = "caml_clz" "noalloc"
129   external leading_bit : int -> int = "caml_leading_bit" "noalloc"
130 *)
131   let highest_bit x =
132     try
133       let n = (x) lsr 24 in
134       if n != 0 then  hbit.(n) lsl 24
135       else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
136       else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
137       else hbit.(x)
138     with
139       _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
140
141   let highest_bit64 x =
142     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
143       else highest_bit x
144
145   let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)
146
147   let join p0 t0 p1 t1 =
148     let m = branching_bit p0 p1  in
149     let msk = mask p0 m in
150       if zero_bit p0 m then
151         branch_ne msk m t0 t1
152       else
153         branch_ne msk m t1 t0
154
155   let match_prefix k p m = (mask k m) == p
156
157   let add k t =
158     let kid = Uid.to_int (H.uid k) in
159       assert (kid >=0);
160     let rec ins n = match Node.node n with
161       | Empty -> leaf k
162       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
163       | Branch (p,m,t0,t1)  ->
164         if match_prefix kid p m then
165           if zero_bit kid m then
166             branch_ne p m (ins t0) t1
167           else
168             branch_ne p m t0 (ins t1)
169         else
170           join kid (leaf k)  p n
171     in
172     ins t
173
174   let remove k t =
175     let kid = Uid.to_int(H.uid k) in
176     let rec rmv n = match Node.node n with
177       | Empty -> empty
178       | Leaf j  -> if  k == j then empty else n
179       | Branch (p,m,t0,t1) ->
180         if match_prefix kid p m then
181           if zero_bit kid m then
182             branch_ne p m (rmv t0) t1
183           else
184             branch_ne p m t0 (rmv t1)
185         else
186           n
187     in
188     rmv t
189
190   (* should run in O(1) thanks to Hash consing *)
191
192   let equal a b = Node.equal a b
193
194   let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
195
196   let rec merge s t =
197     if (equal s t) (* This is cheap thanks to hash-consing *)
198     then s
199     else
200     match Node.node s, Node.node t with
201       | Empty, _  -> t
202       | _, Empty  -> s
203       | Leaf k, _ -> add k t
204       | _, Leaf k -> add k s
205       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
206         if m == n && match_prefix q p m then
207           branch p  m  (merge s0 t0) (merge s1 t1)
208         else if m > n && match_prefix q p m then
209           if zero_bit q m then
210             branch_ne p m (merge s0 t) s1
211           else
212             branch_ne p m s0 (merge s1 t)
213         else if m < n && match_prefix p q n then
214           if zero_bit p n then
215             branch_ne q n (merge s t0) t1
216           else
217             branch_ne q n t0 (merge s t1)
218         else
219           (* The prefixes disagree. *)
220           join p s q t
221
222
223
224
225   let rec subset s1 s2 = (equal s1 s2) ||
226     match (Node.node s1,Node.node s2) with
227       | Empty, _ -> true
228       | _, Empty -> false
229       | Leaf k1, _ -> mem k1 s2
230       | Branch _, Leaf _ -> false
231       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
232         if m1 == m2 && p1 == p2 then
233           subset l1 l2 && subset r1 r2
234         else if m1 < m2 && match_prefix p1 p2 m2 then
235           if zero_bit p1 m2 then
236             subset l1 l2 && subset r1 l2
237           else
238             subset l1 r2 && subset r1 r2
239         else
240           false
241
242
243   let union s1 s2 = merge s1 s2
244     (* Todo replace with e Memo Module *)
245
246   let rec inter s1 s2 =
247     if equal s1 s2
248     then s1
249     else
250       match (Node.node s1,Node.node s2) with
251       | Empty, _ -> empty
252       | _, Empty -> empty
253       | Leaf k1, _ -> if mem k1 s2 then s1 else empty
254       | _, Leaf k2 -> if mem k2 s1 then s2 else empty
255       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
256         if m1 == m2 && p1 == p2 then
257           merge (inter l1 l2)  (inter r1 r2)
258         else if m1 > m2 && match_prefix p2 p1 m1 then
259           inter (if zero_bit p2 m1 then l1 else r1) s2
260         else if m1 < m2 && match_prefix p1 p2 m2 then
261           inter s1 (if zero_bit p1 m2 then l2 else r2)
262         else
263           empty
264
265   let rec diff s1 s2 =
266     if equal s1 s2
267     then empty
268     else
269       match (Node.node s1,Node.node s2) with
270       | Empty, _ -> empty
271       | _, Empty -> s1
272       | Leaf k1, _ -> if mem k1 s2 then empty else s1
273       | _, Leaf k2 -> remove k2 s1
274       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
275         if m1 == m2 && p1 == p2 then
276           merge (diff l1 l2) (diff r1 r2)
277         else if m1 > m2 && match_prefix p2 p1 m1 then
278           if zero_bit p2 m1 then
279             merge (diff l1 s2) r1
280           else
281             merge l1 (diff r1 s2)
282         else if m1 < m2 && match_prefix p1 p2 m2 then
283           if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
284         else
285           s1
286
287
288 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
289     [exists], [filter], [partition], [choose], [elements]) are
290     implemented as for any other kind of binary trees. *)
291
292 let rec cardinal n = match Node.node n with
293   | Empty -> 0
294   | Leaf _ -> 1
295   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
296
297 let rec iter f n = match Node.node n with
298   | Empty -> ()
299   | Leaf k -> f k
300   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
301
302 let rec fold f s accu = match Node.node s with
303   | Empty -> accu
304   | Leaf k -> f k accu
305   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
306
307
308 let rec for_all p n = match Node.node n with
309   | Empty -> true
310   | Leaf k -> p k
311   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
312
313 let rec exists p n = match Node.node n with
314   | Empty -> false
315   | Leaf k -> p k
316   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
317
318 let rec filter pr n = match Node.node n with
319   | Empty -> empty
320   | Leaf k -> if pr k then n else empty
321   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
322
323 let partition p s =
324   let rec part (t,f as acc) n = match Node.node n with
325     | Empty -> acc
326     | Leaf k -> if p k then (add k t, f) else (t, add k f)
327     | Branch (_,_,t0,t1) -> part (part acc t0) t1
328   in
329   part (empty, empty) s
330
331 let rec choose n = match Node.node n with
332   | Empty -> raise Not_found
333   | Leaf k -> k
334   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
335
336
337 let split x s =
338   let coll k (l, b, r) =
339     if k < x then add k l, b, r
340     else if k > x then l, b, add k r
341     else l, true, r
342   in
343   fold coll s (empty, false, empty)
344
345 (*s Additional functions w.r.t to [Set.S]. *)
346
347 let rec intersect s1 s2 = (equal s1 s2) ||
348   match (Node.node s1,Node.node s2) with
349   | Empty, _ -> false
350   | _, Empty -> false
351   | Leaf k1, _ -> mem k1 s2
352   | _, Leaf k2 -> mem k2 s1
353   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
354       if m1 == m2 && p1 == p2 then
355         intersect l1 l2 || intersect r1 r2
356       else if m1 < m2 && match_prefix p2 p1 m1 then
357         intersect (if zero_bit p2 m1 then l1 else r1) s2
358       else if m1 > m2 && match_prefix p1 p2 m2 then
359         intersect s1 (if zero_bit p1 m2 then l2 else r2)
360       else
361         false
362
363
364 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
365
366
367 end
368
369 module Make = Builder(Hcons.Make)
370 module Weak = Builder(Hcons.Weak)
371
372 module PosInt
373   =
374   struct
375     include Make(Hcons.PosInt)
376     let print ppf s =
377       Format.pp_print_string ppf "{ ";
378       iter (fun i -> Format.fprintf ppf "%i " i) s;
379       Format.pp_print_string ppf "}";
380       Format.pp_print_flush ppf ()
381   end