779d24f1ee2d975040cf5592116558ff026dabc2
[tatoo.git] / src / ptset.ml
1 (* Original file: *)
2 (***********************************************************************)
3 (*                                                                     *)
4 (*  Copyright (C) Jean-Christophe Filliatre                            *)
5 (*                                                                     *)
6 (*  This software is free software; you can redistribute it and/or     *)
7 (*  modify it under the terms of the GNU Library General Public        *)
8 (*  License version 2.1, with the special exception on linking         *)
9 (*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
10 (*                                                                     *)
11 (*  This software is distributed in the hope that it will be useful,   *)
12 (*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
13 (*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
14 (*                                                                     *)
15 (***********************************************************************)
16
17 (* Modified by Kim Nguyen *)
18 (* The Patricia trees are themselves deeply hash-consed. The module
19    provides a Make (and Weak) functor to build hash-consed patricia
20    trees whose elements are Abstract hash-consed values.
21 *)
22
23 INCLUDE "utils.ml"
24
25 include Ptset_sig
26
27 module type HConsBuilder =
28   functor (H : Common_sig.HashedType) -> Hcons.S with type data = H.t
29
30 module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
31   S with type elt = H.t =
32 struct
33   type elt = H.t
34
35   type 'a set =
36     | Empty
37     | Leaf of elt
38     | Branch of int * int * 'a * 'a
39
40   module rec Node : Hcons.S with type data = Data.t = HCB(Data)
41                             and Data : Common_sig.HashedType
42                                 with type t = Node.t set
43     =
44   struct
45     type t =  Node.t set
46     let equal x y =
47       match x,y with
48       | Empty,Empty -> true
49       | Leaf k1, Leaf k2 ->  k1 == k2
50       | Branch(b1,i1,l1,r1), Branch(b2,i2,l2,r2) ->
51           b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)
52
53       | (Empty|Leaf _|Branch _), _  -> false
54
55     let hash = function
56     | Empty -> 0
57     | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
58     | Branch (b,i,l,r) ->
59         HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
60   end
61
62   include Node
63
64   let empty = Node.make Empty
65
66   let is_empty s = s.Node.node == Empty
67
68   let branch p m l r = Node.make (Branch(p,m,l,r))
69
70   let leaf k = Node.make (Leaf k)
71
72   (* To enforce the invariant that a branch contains two non empty sub-trees *)
73   let branch_ne p m t0 t1 =
74     if (is_empty t0) then t1
75     else if is_empty t1 then t0 else branch p m t0 t1
76
77   (******** from here on, only use the smart constructors ************)
78
79   let singleton k = leaf k
80
81   let is_singleton n =
82     match n.Node.node with
83       | Leaf _ -> true
84       | Branch _ | Empty -> false
85
86   let mem (k:elt) n =
87     let kid = (H.uid k :> int) in
88     let rec loop n = match n.Node.node with
89     | Empty -> false
90     | Leaf j ->  k == j
91     | Branch (p, _, l, r) -> loop (if kid <= p then l else r)
92     in loop n
93
94   let rec min_elt n = match n.Node.node with
95   | Empty -> raise Not_found
96   | Leaf k -> k
97   | Branch (_,_,s,_) -> min_elt s
98
99   let rec max_elt n = match n.Node.node with
100   | Empty -> raise Not_found
101   | Leaf k -> k
102   | Branch (_,_,_,t) -> max_elt t
103
104   let elements s =
105     let rec elements_aux acc n = match n.Node.node with
106     | Empty -> acc
107     | Leaf k -> k :: acc
108     | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
109     in
110     elements_aux [] s
111
112
113   let zero_bit k m = (k land m) == 0
114
115   let mask k m  = (k lor (m-1)) land (lnot m)
116
117   external int_of_bool : bool -> int = "%identity"
118
119   let hb32 v0 =
120     let v = v0 lor (v0 lsr 1) in
121     let v = v lor (v lsr 2) in
122     let v = v lor (v lsr 4) in
123     let v = v lor (v lsr 8) in
124     let v = v lor (v lsr 16) in
125     ((v + 1) lsr 1) + (int_of_bool (v0 == 0))
126
127   let hb64 v0 =
128     let v = v0 lor (v0 lsr 1) in
129     let v = v lor (v lsr 2) in
130     let v = v lor (v lsr 4) in
131     let v = v lor (v lsr 8) in
132     let v = v lor (v lsr 16) in
133     let v = v lor (v lsr 32) in
134     ((v + 1) lsr 1) + (int_of_bool (v0 == 0))
135
136
137   let branching_bit p0 p1 = hb64 (p0 lxor p1)
138
139   let join p0 t0 p1 t1 =
140     let m = branching_bit p0 p1  in
141     let msk = mask p0 m in
142     if zero_bit p0 m then
143       branch_ne msk m t0 t1
144     else
145       branch_ne msk m t1 t0
146
147   let match_prefix k p m = (mask k m) == p
148
149   let add k t =
150     let kid = Uid.to_int (H.uid k) in
151     let rec ins n = match n.Node.node with
152     | Empty -> leaf k
153     | Leaf j -> if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
154     | Branch (p,m,t0,t1)  ->
155         if match_prefix kid p m then
156           if zero_bit kid m then
157             branch_ne p m (ins t0) t1
158           else
159             branch_ne p m t0 (ins t1)
160         else
161           join kid (leaf k)  p n
162     in
163     ins t
164
165   let remove k t =
166     let kid = (H.uid k :> int) in
167     let rec rmv n = match n.Node.node with
168     | Empty -> empty
169     | Leaf j  -> if  k == j then empty else n
170     | Branch (p,m,t0,t1) ->
171         if match_prefix kid p m then
172           if zero_bit kid m then
173             branch_ne p m (rmv t0) t1
174           else
175             branch_ne p m t0 (rmv t1)
176         else
177           n
178     in
179     rmv t
180
181   (* runs in O(1) thanks to hash consing *)
182
183   let equal a b = a == b
184
185   let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
186
187   let rec merge s t =
188     if equal s t (* This is cheap thanks to hash-consing *)
189     then s
190     else
191       match s.Node.node, t.Node.node with
192       | Empty, _  -> t
193       | _, Empty  -> s
194       | Leaf k, _ -> add k t
195       | _, Leaf k -> add k s
196       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
197           if m == n && match_prefix q p m then
198             branch p  m  (merge s0 t0) (merge s1 t1)
199           else if m > n && match_prefix q p m then
200             if zero_bit q m then
201               branch_ne p m (merge s0 t) s1
202             else
203               branch_ne p m s0 (merge s1 t)
204           else if m < n && match_prefix p q n then
205             if zero_bit p n then
206               branch_ne q n (merge s t0) t1
207             else
208               branch_ne q n t0 (merge s t1)
209           else
210         (* The prefixes disagree. *)
211             join p s q t
212
213
214
215
216   let rec subset s1 s2 = (equal s1 s2) ||
217     match s1.Node.node, s2.Node.node with
218     | Empty, _ -> true
219     | _, Empty -> false
220     | Leaf k1, _ -> mem k1 s2
221     | Branch _, Leaf _ -> false
222     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
223         if m1 == m2 && p1 == p2 then
224           subset l1 l2 && subset r1 r2
225         else if m1 < m2 && match_prefix p1 p2 m2 then
226           if zero_bit p1 m2 then
227             subset l1 l2 && subset r1 l2
228           else
229             subset l1 r2 && subset r1 r2
230         else
231           false
232
233
234   let union s1 s2 = merge s1 s2
235   (* Todo replace with e Memo Module *)
236
237   let rec inter s1 s2 =
238     if equal s1 s2
239     then s1
240     else
241       match s1.Node.node, s2.Node.node with
242       | Empty, _ -> empty
243       | _, Empty -> empty
244       | Leaf k1, _ -> if mem k1 s2 then s1 else empty
245       | _, Leaf k2 -> if mem k2 s1 then s2 else empty
246       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
247           if m1 == m2 && p1 == p2 then
248             merge (inter l1 l2)  (inter r1 r2)
249           else if m1 > m2 && match_prefix p2 p1 m1 then
250             inter (if zero_bit p2 m1 then l1 else r1) s2
251           else if m1 < m2 && match_prefix p1 p2 m2 then
252             inter s1 (if zero_bit p1 m2 then l2 else r2)
253           else
254             empty
255
256   let rec diff s1 s2 =
257     if equal s1 s2
258     then empty
259     else
260       match s1.Node.node, s2.Node.node with
261       | Empty, _ -> empty
262       | _, Empty -> s1
263       | Leaf k1, _ -> if mem k1 s2 then empty else s1
264       | _, Leaf k2 -> remove k2 s1
265       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
266           if m1 == m2 && p1 == p2 then
267             merge (diff l1 l2) (diff r1 r2)
268           else if m1 > m2 && match_prefix p2 p1 m1 then
269             if zero_bit p2 m1 then
270               merge (diff l1 s2) r1
271             else
272               merge l1 (diff r1 s2)
273           else if m1 < m2 && match_prefix p1 p2 m2 then
274             if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
275           else
276             s1
277
278
279   (*s All the following operations ([cardinal], [iter], [fold], [for_all],
280     [exists], [filter], [partition], [choose], [elements]) are
281     implemented as for any other kind of binary trees. *)
282
283   let rec cardinal n = match n.Node.node with
284   | Empty -> 0
285   | Leaf _ -> 1
286   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
287
288   let rec iter f n = match n.Node.node with
289   | Empty -> ()
290   | Leaf k -> f k
291   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
292
293   let rec fold_left f s accu = match s.Node.node with
294   | Empty -> accu
295   | Leaf k -> f k accu
296   | Branch (_,_,t0,t1) -> fold_left f t1 (fold_left f t0 accu)
297
298   let rec fold_right f s accu = match s.Node.node with
299   | Empty -> accu
300   | Leaf k -> f k accu
301   | Branch (_,_,t0,t1) -> fold_right f t0 (fold_right f t1 accu)
302
303   let fold f s accu = fold_left f s accu
304
305   let rec for_all p n = match n.Node.node with
306   | Empty -> true
307   | Leaf k -> p k
308   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
309
310   let rec exists p n = match n.Node.node with
311   | Empty -> false
312   | Leaf k -> p k
313   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
314
315   let rec filter pr n = match n.Node.node with
316   | Empty -> empty
317   | Leaf k -> if pr k then n else empty
318   | Branch (p,m,t0,t1) -> let n0 = filter pr t0 in
319                           let n1 = filter pr t1 in
320                           branch_ne p m n0 n1
321
322   let partition p s =
323     let rec part (t,f as acc) n = match n.Node.node with
324     | Empty -> acc
325     | Leaf k -> if p k then (add k t, f) else (t, add k f)
326     | Branch (_,_,t0,t1) -> part (part acc t0) t1
327     in
328     part (empty, empty) s
329
330   let rec choose n = match n.Node.node with
331   | Empty -> raise Not_found
332   | Leaf k -> k
333   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
334
335
336   let split x s =
337     let coll k (l, b, r) =
338       if k < x then add k l, b, r
339       else if k > x then l, b, add k r
340       else l, true, r
341     in
342     fold coll s (empty, false, empty)
343
344   (*s Additional functions w.r.t to [Set.S]. *)
345
346   let rec intersect s1 s2 = (equal s1 s2) ||
347     match  s1.Node.node, s2.Node.node with
348     | Empty, _ -> false
349     | _, Empty -> false
350     | Leaf k1, _ -> mem k1 s2
351     | _, Leaf k2 -> mem k2 s1
352     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
353         if m1 == m2 && p1 == p2 then
354           intersect l1 l2 || intersect r1 r2
355         else if m1 > m2 && match_prefix p2 p1 m1 then
356           intersect (if zero_bit p2 m1 then l1 else r1) s2
357         else if m1 < m2 && match_prefix p1 p2 m2 then
358           intersect s1 (if zero_bit p1 m2 then l2 else r2)
359         else
360           false
361
362
363   let from_list l = List.fold_left (fun acc e -> add e acc) empty l
364
365
366 end
367
368 module Make = Builder(Hcons.Make)
369 module Weak = Builder(Hcons.Weak)
370
371 module PosInt
372   =
373 struct
374   include Make(Hcons.PosInt)
375   let print ppf s =
376     Format.pp_print_string ppf "{ ";
377     iter (fun i -> Format.fprintf ppf "%i " i) s;
378     Format.pp_print_string ppf "}";
379     Format.pp_print_flush ppf ()
380 end