f9bbd03eb6074f96d7d6148ec099b3b75ff47ebb
[tatoo.git] / src / ptset.ml
1 (* Original file: *)
2 (***********************************************************************)
3 (*                                                                     *)
4 (*  Copyright (C) Jean-Christophe Filliatre                            *)
5 (*                                                                     *)
6 (*  This software is free software; you can redistribute it and/or     *)
7 (*  modify it under the terms of the GNU Library General Public        *)
8 (*  License version 2.1, with the special exception on linking         *)
9 (*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
10 (*                                                                     *)
11 (*  This software is distributed in the hope that it will be useful,   *)
12 (*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
13 (*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
14 (*                                                                     *)
15 (***********************************************************************)
16
17 (*
18   Time-stamp: <Last modified on 2013-03-10 18:18:54 CET by Kim Nguyen>
19 *)
20
21 (* Modified by Kim Nguyen *)
22 (* The Patricia trees are themselves deeply hash-consed. The module
23    provides a Make (and Weak) functor to build hash-consed patricia
24    trees whose elements are Abstract hash-consed values.
25 *)
26
27 INCLUDE "utils.ml"
28
29 include Ptset_sig
30
31 module type HConsBuilder =
32   functor (H : Common_sig.HashedType) -> Hcons.S with type data = H.t
33
34 module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
35   S with type elt = H.t =
36 struct
37   type elt = H.t
38
39   type 'a set =
40     | Empty
41     | Leaf of elt
42     | Branch of int * int * 'a * 'a
43
44   module rec Node : Hcons.S with type data = Data.t = HCB(Data)
45                             and Data : Common_sig.HashedType with type t = Node.t set =
46   struct
47     type t =  Node.t set
48     let equal x y =
49       match x,y with
50       | Empty,Empty -> true
51       | Leaf k1, Leaf k2 ->  k1 == k2
52       | Branch(b1,i1,l1,r1), Branch(b2,i2,l2,r2) ->
53           b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)
54
55       | (Empty|Leaf _|Branch _), _  -> false
56
57     let hash = function
58     | Empty -> 0
59     | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
60     | Branch (b,i,l,r) ->
61         HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
62   end
63
64   include Node
65
66   let empty = Node.make Empty
67
68   let is_empty s = (Node.node s) == Empty
69
70   let branch p m l r = Node.make (Branch(p,m,l,r))
71
72   let leaf k = Node.make (Leaf k)
73
74                             (* To enforce the invariant that a branch contains two non empty
75                                sub-trees *)
76   let branch_ne p m t0 t1 =
77     if (is_empty t0) then t1
78     else if is_empty t1 then t0 else branch p m t0 t1
79
80                             (******** from here on, only use the smart constructors ************)
81
82   let zero_bit k m = (k land m) == 0
83
84   let singleton k = leaf k
85
86   let is_singleton n =
87     match Node.node n with
88       | Leaf _ -> true
89       | Branch _ | Empty -> false
90
91   let mem (k:elt) n =
92     let kid = Uid.to_int (H.uid k) in
93     let rec loop n = match Node.node n with
94     | Empty -> false
95     | Leaf j ->  k == j
96     | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
97     in loop n
98
99   let rec min_elt n = match Node.node n with
100   | Empty -> raise Not_found
101   | Leaf k -> k
102   | Branch (_,_,s,_) -> min_elt s
103
104   let rec max_elt n = match Node.node n with
105   | Empty -> raise Not_found
106   | Leaf k -> k
107   | Branch (_,_,_,t) -> max_elt t
108
109   let elements s =
110     let rec elements_aux acc n = match Node.node n with
111     | Empty -> acc
112     | Leaf k -> k :: acc
113     | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
114     in
115     elements_aux [] s
116
117   let mask k m  = (k lor (m-1)) land (lnot m)
118
119   let naive_highest_bit x =
120     assert (x < 256);
121     let rec loop i =
122       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
123     in
124     loop 7
125
126   let hbit = Array.init 256 naive_highest_bit
127   (*
128     external clz : int -> int = "caml_clz" "noalloc"
129     external leading_bit : int -> int = "caml_leading_bit" "noalloc"
130   *)
131   let highest_bit x =
132     try
133       let n = (x) lsr 24 in
134       if n != 0 then  hbit.(n) lsl 24
135       else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
136         else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
137           else hbit.(x)
138     with
139       _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
140
141   let highest_bit64 x =
142     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
143       else highest_bit x
144
145   let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)
146
147   let join p0 t0 p1 t1 =
148     let m = branching_bit p0 p1  in
149     let msk = mask p0 m in
150     if zero_bit p0 m then
151     branch_ne msk m t0 t1
152     else
153     branch_ne msk m t1 t0
154
155   let match_prefix k p m = (mask k m) == p
156
157   let add k t =
158     let kid = Uid.to_int (H.uid k) in
159     assert (kid >=0);
160     let rec ins n = match Node.node n with
161     | Empty -> leaf k
162     | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
163     | Branch (p,m,t0,t1)  ->
164         if match_prefix kid p m then
165         if zero_bit kid m then
166         branch_ne p m (ins t0) t1
167         else
168         branch_ne p m t0 (ins t1)
169         else
170         join kid (leaf k)  p n
171     in
172     ins t
173
174   let remove k t =
175     let kid = Uid.to_int(H.uid k) in
176     let rec rmv n = match Node.node n with
177     | Empty -> empty
178     | Leaf j  -> if  k == j then empty else n
179     | Branch (p,m,t0,t1) ->
180         if match_prefix kid p m then
181         if zero_bit kid m then
182         branch_ne p m (rmv t0) t1
183         else
184         branch_ne p m t0 (rmv t1)
185         else
186         n
187     in
188     rmv t
189
190   (* should run in O(1) thanks to hash consing *)
191
192   let equal a b = Node.equal a b
193
194   let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
195
196   let rec merge s t =
197     if equal s t (* This is cheap thanks to hash-consing *)
198     then s
199     else
200     match Node.node s, Node.node t with
201     | Empty, _  -> t
202     | _, Empty  -> s
203     | Leaf k, _ -> add k t
204     | _, Leaf k -> add k s
205     | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
206         if m == n && match_prefix q p m then
207         branch p  m  (merge s0 t0) (merge s1 t1)
208         else if m > n && match_prefix q p m then
209         if zero_bit q m then
210         branch_ne p m (merge s0 t) s1
211         else
212         branch_ne p m s0 (merge s1 t)
213         else if m < n && match_prefix p q n then
214         if zero_bit p n then
215         branch_ne q n (merge s t0) t1
216         else
217         branch_ne q n t0 (merge s t1)
218         else
219                                   (* The prefixes disagree. *)
220         join p s q t
221
222
223
224
225   let rec subset s1 s2 = (equal s1 s2) ||
226     match (Node.node s1,Node.node s2) with
227     | Empty, _ -> true
228     | _, Empty -> false
229     | Leaf k1, _ -> mem k1 s2
230     | Branch _, Leaf _ -> false
231     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
232         if m1 == m2 && p1 == p2 then
233         subset l1 l2 && subset r1 r2
234         else if m1 < m2 && match_prefix p1 p2 m2 then
235         if zero_bit p1 m2 then
236         subset l1 l2 && subset r1 l2
237         else
238         subset l1 r2 && subset r1 r2
239         else
240         false
241
242
243   let union s1 s2 = merge s1 s2
244                             (* Todo replace with e Memo Module *)
245
246   let rec inter s1 s2 =
247     if equal s1 s2
248     then s1
249     else
250     match (Node.node s1,Node.node s2) with
251     | Empty, _ -> empty
252     | _, Empty -> empty
253     | Leaf k1, _ -> if mem k1 s2 then s1 else empty
254     | _, Leaf k2 -> if mem k2 s1 then s2 else empty
255     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
256         if m1 == m2 && p1 == p2 then
257         merge (inter l1 l2)  (inter r1 r2)
258         else if m1 > m2 && match_prefix p2 p1 m1 then
259         inter (if zero_bit p2 m1 then l1 else r1) s2
260         else if m1 < m2 && match_prefix p1 p2 m2 then
261         inter s1 (if zero_bit p1 m2 then l2 else r2)
262         else
263         empty
264
265   let rec diff s1 s2 =
266     if equal s1 s2
267     then empty
268     else
269     match (Node.node s1,Node.node s2) with
270     | Empty, _ -> empty
271     | _, Empty -> s1
272     | Leaf k1, _ -> if mem k1 s2 then empty else s1
273     | _, Leaf k2 -> remove k2 s1
274     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
275         if m1 == m2 && p1 == p2 then
276         merge (diff l1 l2) (diff r1 r2)
277         else if m1 > m2 && match_prefix p2 p1 m1 then
278         if zero_bit p2 m1 then
279         merge (diff l1 s2) r1
280         else
281         merge l1 (diff r1 s2)
282         else if m1 < m2 && match_prefix p1 p2 m2 then
283         if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
284         else
285         s1
286
287
288   (*s All the following operations ([cardinal], [iter], [fold], [for_all],
289     [exists], [filter], [partition], [choose], [elements]) are
290     implemented as for any other kind of binary trees. *)
291
292   let rec cardinal n = match Node.node n with
293   | Empty -> 0
294   | Leaf _ -> 1
295   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
296
297   let rec iter f n = match Node.node n with
298   | Empty -> ()
299   | Leaf k -> f k
300   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
301
302   let rec fold f s accu = match Node.node s with
303   | Empty -> accu
304   | Leaf k -> f k accu
305   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
306
307
308   let rec for_all p n = match Node.node n with
309   | Empty -> true
310   | Leaf k -> p k
311   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
312
313   let rec exists p n = match Node.node n with
314   | Empty -> false
315   | Leaf k -> p k
316   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
317
318   let rec filter pr n = match Node.node n with
319   | Empty -> empty
320   | Leaf k -> if pr k then n else empty
321   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
322
323   let partition p s =
324     let rec part (t,f as acc) n = match Node.node n with
325     | Empty -> acc
326     | Leaf k -> if p k then (add k t, f) else (t, add k f)
327     | Branch (_,_,t0,t1) -> part (part acc t0) t1
328     in
329     part (empty, empty) s
330
331   let rec choose n = match Node.node n with
332   | Empty -> raise Not_found
333   | Leaf k -> k
334   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
335
336
337   let split x s =
338     let coll k (l, b, r) =
339       if k < x then add k l, b, r
340       else if k > x then l, b, add k r
341       else l, true, r
342     in
343     fold coll s (empty, false, empty)
344
345   (*s Additional functions w.r.t to [Set.S]. *)
346
347   let rec intersect s1 s2 = (equal s1 s2) ||
348     match (Node.node s1,Node.node s2) with
349     | Empty, _ -> false
350     | _, Empty -> false
351     | Leaf k1, _ -> mem k1 s2
352     | _, Leaf k2 -> mem k2 s1
353     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
354         if m1 == m2 && p1 == p2 then
355         intersect l1 l2 || intersect r1 r2
356         else if m1 > m2 && match_prefix p2 p1 m1 then
357         intersect (if zero_bit p2 m1 then l1 else r1) s2
358         else if m1 < m2 && match_prefix p1 p2 m2 then
359         intersect s1 (if zero_bit p1 m2 then l2 else r2)
360         else
361         false
362
363
364   let from_list l = List.fold_left (fun acc e -> add e acc) empty l
365
366
367 end
368
369 module Make = Builder(Hcons.Make)
370 module Weak = Builder(Hcons.Weak)
371
372 module PosInt
373   =
374 struct
375   include Make(Hcons.PosInt)
376   let print ppf s =
377     Format.pp_print_string ppf "{ ";
378     iter (fun i -> Format.fprintf ppf "%i " i) s;
379     Format.pp_print_string ppf "}";
380     Format.pp_print_flush ppf ()
381 end