Refactor the code to have a unique place for signature definition.
[tatoo.git] / src / ptset.ml
1 (* Original file : *)
2
3 (***********************************************************************)
4 (*                                                                     *)
5 (*  Copyright (C) Jean-Christophe Filliatre                            *)
6 (*                                                                     *)
7 (*  This software is free software; you can redistribute it and/or     *)
8 (*  modify it under the terms of the GNU Library General Public        *)
9 (*  License version 2.1, with the special exception on linking         *)
10 (*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
11 (*                                                                     *)
12 (*  This software is distributed in the hope that it will be useful,   *)
13 (*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
14 (*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
15 (*                                                                     *)
16 (***********************************************************************)
17
18 (* Modified by Kim Nguyen *)
19 (* The Patricia trees are themselves deeply hashconsed. The module
20    provides a Make (and Weak) functor to build hashconsed patricia
21    trees whose elements are Abstract hashconsed values. This allows
22    to build sets of integers without boxing them in an hacons structure
23 *)
24 INCLUDE "utils.ml"
25
26 include Sigs.PTSET
27
28 module type HConsBuilder =
29   functor (H : Sigs.AUX.HashedType) -> Hcons.S with type data = H.t
30
31 module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
32   S with type elt = H.t =
33 struct
34   type elt = H.t
35
36   type 'a set =
37     | Empty
38     | Leaf of elt
39     | Branch of int * int * 'a * 'a
40
41   module rec Node : Hcons.S with type data = Data.t = HCB(Data)
42   and Data : Sigs.AUX.HashedType with type t = Node.t set =
43   struct
44     type t =  Node.t set
45     let equal x y =
46       match x,y with
47         Empty,Empty -> true
48       | Leaf k1, Leaf k2 ->  k1 == k2
49       | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
50         b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)
51
52       | _ -> false
53
54     let hash = function
55       | Empty -> 0
56       | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
57       | Branch (b,i,l,r) ->
58         HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
59   end
60
61   include Node
62
63   let empty = Node.make Empty
64
65   let is_empty s = (Node.node s) == Empty
66
67   let branch p m l r = Node.make (Branch(p,m,l,r))
68
69   let leaf k = Node.make (Leaf k)
70
71   (* To enforce the invariant that a branch contains two non empty
72      sub-trees *)
73   let branch_ne p m t0 t1 =
74     if (is_empty t0) then t1
75     else if is_empty t1 then t0 else branch p m t0 t1
76
77   (******** from here on, only use the smart constructors ************)
78
79   let zero_bit k m = (k land m) == 0
80
81   let singleton k = leaf k
82
83   let is_singleton n =
84     match Node.node n with Leaf _ -> true
85       | _ -> false
86
87   let mem (k:elt) n =
88     let kid = Uid.to_int (H.uid k) in
89     let rec loop n = match Node.node n with
90       | Empty -> false
91       | Leaf j ->  k == j
92       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
93     in loop n
94
95   let rec min_elt n = match Node.node n with
96     | Empty -> raise Not_found
97     | Leaf k -> k
98     | Branch (_,_,s,_) -> min_elt s
99
100   let rec max_elt n = match Node.node n with
101     | Empty -> raise Not_found
102     | Leaf k -> k
103     | Branch (_,_,_,t) -> max_elt t
104
105   let elements s =
106     let rec elements_aux acc n = match Node.node n with
107       | Empty -> acc
108       | Leaf k -> k :: acc
109       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
110     in
111       elements_aux [] s
112
113   let mask k m  = (k lor (m-1)) land (lnot m)
114
115   let naive_highest_bit x =
116     assert (x < 256);
117     let rec loop i =
118       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
119     in
120       loop 7
121
122   let hbit = Array.init 256 naive_highest_bit
123 (*
124   external clz : int -> int = "caml_clz" "noalloc"
125   external leading_bit : int -> int = "caml_leading_bit" "noalloc"
126 *)
127   let highest_bit x =
128     try
129       let n = (x) lsr 24 in
130       if n != 0 then  hbit.(n) lsl 24
131       else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
132       else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
133       else hbit.(x)
134     with
135       _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
136
137   let highest_bit64 x =
138     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
139       else highest_bit x
140
141   let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)
142
143   let join p0 t0 p1 t1 =
144     let m = branching_bit p0 p1  in
145     let msk = mask p0 m in
146       if zero_bit p0 m then
147         branch_ne msk m t0 t1
148       else
149         branch_ne msk m t1 t0
150
151   let match_prefix k p m = (mask k m) == p
152
153   let add k t =
154     let kid = Uid.to_int (H.uid k) in
155       assert (kid >=0);
156     let rec ins n = match Node.node n with
157       | Empty -> leaf k
158       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
159       | Branch (p,m,t0,t1)  ->
160         if match_prefix kid p m then
161           if zero_bit kid m then
162             branch_ne p m (ins t0) t1
163           else
164             branch_ne p m t0 (ins t1)
165         else
166           join kid (leaf k)  p n
167     in
168     ins t
169
170   let remove k t =
171     let kid = Uid.to_int(H.uid k) in
172     let rec rmv n = match Node.node n with
173       | Empty -> empty
174       | Leaf j  -> if  k == j then empty else n
175       | Branch (p,m,t0,t1) ->
176         if match_prefix kid p m then
177           if zero_bit kid m then
178             branch_ne p m (rmv t0) t1
179           else
180             branch_ne p m t0 (rmv t1)
181         else
182           n
183     in
184     rmv t
185
186   (* should run in O(1) thanks to Hash consing *)
187
188   let equal a b = Node.equal a b
189
190   let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
191
192   let rec merge s t =
193     if (equal s t) (* This is cheap thanks to hash-consing *)
194     then s
195     else
196     match Node.node s, Node.node t with
197       | Empty, _  -> t
198       | _, Empty  -> s
199       | Leaf k, _ -> add k t
200       | _, Leaf k -> add k s
201       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
202         if m == n && match_prefix q p m then
203           branch p  m  (merge s0 t0) (merge s1 t1)
204         else if m > n && match_prefix q p m then
205           if zero_bit q m then
206             branch_ne p m (merge s0 t) s1
207           else
208             branch_ne p m s0 (merge s1 t)
209         else if m < n && match_prefix p q n then
210           if zero_bit p n then
211             branch_ne q n (merge s t0) t1
212           else
213             branch_ne q n t0 (merge s t1)
214         else
215           (* The prefixes disagree. *)
216           join p s q t
217
218
219
220
221   let rec subset s1 s2 = (equal s1 s2) ||
222     match (Node.node s1,Node.node s2) with
223       | Empty, _ -> true
224       | _, Empty -> false
225       | Leaf k1, _ -> mem k1 s2
226       | Branch _, Leaf _ -> false
227       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
228         if m1 == m2 && p1 == p2 then
229           subset l1 l2 && subset r1 r2
230         else if m1 < m2 && match_prefix p1 p2 m2 then
231           if zero_bit p1 m2 then
232             subset l1 l2 && subset r1 l2
233           else
234             subset l1 r2 && subset r1 r2
235         else
236           false
237
238
239   let union s1 s2 = merge s1 s2
240     (* Todo replace with e Memo Module *)
241
242   let rec inter s1 s2 =
243     if equal s1 s2
244     then s1
245     else
246       match (Node.node s1,Node.node s2) with
247       | Empty, _ -> empty
248       | _, Empty -> empty
249       | Leaf k1, _ -> if mem k1 s2 then s1 else empty
250       | _, Leaf k2 -> if mem k2 s1 then s2 else empty
251       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
252         if m1 == m2 && p1 == p2 then
253           merge (inter l1 l2)  (inter r1 r2)
254         else if m1 > m2 && match_prefix p2 p1 m1 then
255           inter (if zero_bit p2 m1 then l1 else r1) s2
256         else if m1 < m2 && match_prefix p1 p2 m2 then
257           inter s1 (if zero_bit p1 m2 then l2 else r2)
258         else
259           empty
260
261   let rec diff s1 s2 =
262     if equal s1 s2
263     then empty
264     else
265       match (Node.node s1,Node.node s2) with
266       | Empty, _ -> empty
267       | _, Empty -> s1
268       | Leaf k1, _ -> if mem k1 s2 then empty else s1
269       | _, Leaf k2 -> remove k2 s1
270       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
271         if m1 == m2 && p1 == p2 then
272           merge (diff l1 l2) (diff r1 r2)
273         else if m1 > m2 && match_prefix p2 p1 m1 then
274           if zero_bit p2 m1 then
275             merge (diff l1 s2) r1
276           else
277             merge l1 (diff r1 s2)
278         else if m1 < m2 && match_prefix p1 p2 m2 then
279           if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
280         else
281           s1
282
283
284 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
285     [exists], [filter], [partition], [choose], [elements]) are
286     implemented as for any other kind of binary trees. *)
287
288 let rec cardinal n = match Node.node n with
289   | Empty -> 0
290   | Leaf _ -> 1
291   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
292
293 let rec iter f n = match Node.node n with
294   | Empty -> ()
295   | Leaf k -> f k
296   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
297
298 let rec fold f s accu = match Node.node s with
299   | Empty -> accu
300   | Leaf k -> f k accu
301   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
302
303
304 let rec for_all p n = match Node.node n with
305   | Empty -> true
306   | Leaf k -> p k
307   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
308
309 let rec exists p n = match Node.node n with
310   | Empty -> false
311   | Leaf k -> p k
312   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
313
314 let rec filter pr n = match Node.node n with
315   | Empty -> empty
316   | Leaf k -> if pr k then n else empty
317   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
318
319 let partition p s =
320   let rec part (t,f as acc) n = match Node.node n with
321     | Empty -> acc
322     | Leaf k -> if p k then (add k t, f) else (t, add k f)
323     | Branch (_,_,t0,t1) -> part (part acc t0) t1
324   in
325   part (empty, empty) s
326
327 let rec choose n = match Node.node n with
328   | Empty -> raise Not_found
329   | Leaf k -> k
330   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
331
332
333 let split x s =
334   let coll k (l, b, r) =
335     if k < x then add k l, b, r
336     else if k > x then l, b, add k r
337     else l, true, r
338   in
339   fold coll s (empty, false, empty)
340
341 (*s Additional functions w.r.t to [Set.S]. *)
342
343 let rec intersect s1 s2 = (equal s1 s2) ||
344   match (Node.node s1,Node.node s2) with
345   | Empty, _ -> false
346   | _, Empty -> false
347   | Leaf k1, _ -> mem k1 s2
348   | _, Leaf k2 -> mem k2 s1
349   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
350       if m1 == m2 && p1 == p2 then
351         intersect l1 l2 || intersect r1 r2
352       else if m1 < m2 && match_prefix p2 p1 m1 then
353         intersect (if zero_bit p2 m1 then l1 else r1) s2
354       else if m1 > m2 && match_prefix p1 p2 m2 then
355         intersect s1 (if zero_bit p1 m2 then l2 else r2)
356       else
357         false
358
359
360 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
361
362
363 end
364
365 module Make = Builder(Hcons.Make)
366 module Weak = Builder(Hcons.Weak)
367
368 module PosInt
369   =
370   struct
371     include Make(Hcons.PosInt)
372     let print ppf s =
373       Format.pp_print_string ppf "{ ";
374       iter (fun i -> Format.fprintf ppf "%i " i) s;
375       Format.pp_print_string ppf "}";
376       Format.pp_print_flush ppf ()
377   end