9d415f81ca0396b37ca05d2a4f2380f24fd5c13a
[tatoo.git] / src / ptset.ml
1 (* Original file: *)
2 (***********************************************************************)
3 (*                                                                     *)
4 (*  Copyright (C) Jean-Christophe Filliatre                            *)
5 (*                                                                     *)
6 (*  This software is free software; you can redistribute it and/or     *)
7 (*  modify it under the terms of the GNU Library General Public        *)
8 (*  License version 2.1, with the special exception on linking         *)
9 (*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
10 (*                                                                     *)
11 (*  This software is distributed in the hope that it will be useful,   *)
12 (*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
13 (*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
14 (*                                                                     *)
15 (***********************************************************************)
16
17 (* Modified by Kim Nguyen *)
18 (* The Patricia trees are themselves deeply hash-consed. The module
19    provides a Make (and Weak) functor to build hash-consed patricia
20    trees whose elements are Abstract hash-consed values.
21 *)
22
23 INCLUDE "utils.ml"
24
25 include Ptset_sig
26
27 module type HConsBuilder =
28   functor (H : Common_sig.HashedType) -> Hcons.S with type data = H.t
29
30 module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
31   S with type elt = H.t =
32 struct
33   type elt = H.t
34
35   type 'a set =
36     | Empty
37     | Leaf of elt
38     | Branch of int * int * 'a * 'a
39
40   module rec Node : Hcons.S with type data = Data.t = HCB(Data)
41                             and Data : Common_sig.HashedType with type t = Node.t set =
42   struct
43     type t =  Node.t set
44     let equal x y =
45       match x,y with
46       | Empty,Empty -> true
47       | Leaf k1, Leaf k2 ->  k1 == k2
48       | Branch(b1,i1,l1,r1), Branch(b2,i2,l2,r2) ->
49           b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)
50
51       | (Empty|Leaf _|Branch _), _  -> false
52
53     let hash = function
54     | Empty -> 0
55     | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
56     | Branch (b,i,l,r) ->
57         HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
58   end
59
60   include Node
61
62   let empty = Node.make Empty
63
64   let is_empty s = (Node.node s) == Empty
65
66   let branch p m l r = Node.make (Branch(p,m,l,r))
67
68   let leaf k = Node.make (Leaf k)
69
70                             (* To enforce the invariant that a branch contains two non empty
71                                sub-trees *)
72   let branch_ne p m t0 t1 =
73     if (is_empty t0) then t1
74     else if is_empty t1 then t0 else branch p m t0 t1
75
76                             (******** from here on, only use the smart constructors ************)
77
78   let zero_bit k m = (k land m) == 0
79
80   let singleton k = leaf k
81
82   let is_singleton n =
83     match Node.node n with
84       | Leaf _ -> true
85       | Branch _ | Empty -> false
86
87   let mem (k:elt) n =
88     let kid = Uid.to_int (H.uid k) in
89     let rec loop n = match Node.node n with
90     | Empty -> false
91     | Leaf j ->  k == j
92     | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
93     in loop n
94
95   let rec min_elt n = match Node.node n with
96   | Empty -> raise Not_found
97   | Leaf k -> k
98   | Branch (_,_,s,_) -> min_elt s
99
100   let rec max_elt n = match Node.node n with
101   | Empty -> raise Not_found
102   | Leaf k -> k
103   | Branch (_,_,_,t) -> max_elt t
104
105   let elements s =
106     let rec elements_aux acc n = match Node.node n with
107     | Empty -> acc
108     | Leaf k -> k :: acc
109     | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
110     in
111     elements_aux [] s
112
113   let mask k m  = (k lor (m-1)) land (lnot m)
114
115   let naive_highest_bit x =
116     assert (x < 256);
117     let rec loop i =
118       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
119     in
120     loop 7
121
122   let hbit = Array.init 256 naive_highest_bit
123   (*
124     external clz : int -> int = "caml_clz" "noalloc"
125     external leading_bit : int -> int = "caml_leading_bit" "noalloc"
126   *)
127   let highest_bit x =
128     try
129       let n = (x) lsr 24 in
130       if n != 0 then  hbit.(n) lsl 24
131       else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
132         else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
133           else hbit.(x)
134     with
135       _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
136
137   let highest_bit64 x =
138     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
139       else highest_bit x
140
141   let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)
142
143   let join p0 t0 p1 t1 =
144     let m = branching_bit p0 p1  in
145     let msk = mask p0 m in
146     if zero_bit p0 m then
147     branch_ne msk m t0 t1
148     else
149     branch_ne msk m t1 t0
150
151   let match_prefix k p m = (mask k m) == p
152
153   let add k t =
154     let kid = Uid.to_int (H.uid k) in
155     assert (kid >=0);
156     let rec ins n = match Node.node n with
157     | Empty -> leaf k
158     | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
159     | Branch (p,m,t0,t1)  ->
160         if match_prefix kid p m then
161         if zero_bit kid m then
162         branch_ne p m (ins t0) t1
163         else
164         branch_ne p m t0 (ins t1)
165         else
166         join kid (leaf k)  p n
167     in
168     ins t
169
170   let remove k t =
171     let kid = Uid.to_int(H.uid k) in
172     let rec rmv n = match Node.node n with
173     | Empty -> empty
174     | Leaf j  -> if  k == j then empty else n
175     | Branch (p,m,t0,t1) ->
176         if match_prefix kid p m then
177         if zero_bit kid m then
178         branch_ne p m (rmv t0) t1
179         else
180         branch_ne p m t0 (rmv t1)
181         else
182         n
183     in
184     rmv t
185
186   (* should run in O(1) thanks to hash consing *)
187
188   let equal a b = Node.equal a b
189
190   let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
191
192   let rec merge s t =
193     if equal s t (* This is cheap thanks to hash-consing *)
194     then s
195     else
196     match Node.node s, Node.node t with
197     | Empty, _  -> t
198     | _, Empty  -> s
199     | Leaf k, _ -> add k t
200     | _, Leaf k -> add k s
201     | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
202         if m == n && match_prefix q p m then
203         branch p  m  (merge s0 t0) (merge s1 t1)
204         else if m > n && match_prefix q p m then
205         if zero_bit q m then
206         branch_ne p m (merge s0 t) s1
207         else
208         branch_ne p m s0 (merge s1 t)
209         else if m < n && match_prefix p q n then
210         if zero_bit p n then
211         branch_ne q n (merge s t0) t1
212         else
213         branch_ne q n t0 (merge s t1)
214         else
215                                   (* The prefixes disagree. *)
216         join p s q t
217
218
219
220
221   let rec subset s1 s2 = (equal s1 s2) ||
222     match (Node.node s1,Node.node s2) with
223     | Empty, _ -> true
224     | _, Empty -> false
225     | Leaf k1, _ -> mem k1 s2
226     | Branch _, Leaf _ -> false
227     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
228         if m1 == m2 && p1 == p2 then
229         subset l1 l2 && subset r1 r2
230         else if m1 < m2 && match_prefix p1 p2 m2 then
231         if zero_bit p1 m2 then
232         subset l1 l2 && subset r1 l2
233         else
234         subset l1 r2 && subset r1 r2
235         else
236         false
237
238
239   let union s1 s2 = merge s1 s2
240                             (* Todo replace with e Memo Module *)
241
242   let rec inter s1 s2 =
243     if equal s1 s2
244     then s1
245     else
246     match (Node.node s1,Node.node s2) with
247     | Empty, _ -> empty
248     | _, Empty -> empty
249     | Leaf k1, _ -> if mem k1 s2 then s1 else empty
250     | _, Leaf k2 -> if mem k2 s1 then s2 else empty
251     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
252         if m1 == m2 && p1 == p2 then
253         merge (inter l1 l2)  (inter r1 r2)
254         else if m1 > m2 && match_prefix p2 p1 m1 then
255         inter (if zero_bit p2 m1 then l1 else r1) s2
256         else if m1 < m2 && match_prefix p1 p2 m2 then
257         inter s1 (if zero_bit p1 m2 then l2 else r2)
258         else
259         empty
260
261   let rec diff s1 s2 =
262     if equal s1 s2
263     then empty
264     else
265     match (Node.node s1,Node.node s2) with
266     | Empty, _ -> empty
267     | _, Empty -> s1
268     | Leaf k1, _ -> if mem k1 s2 then empty else s1
269     | _, Leaf k2 -> remove k2 s1
270     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
271         if m1 == m2 && p1 == p2 then
272         merge (diff l1 l2) (diff r1 r2)
273         else if m1 > m2 && match_prefix p2 p1 m1 then
274         if zero_bit p2 m1 then
275         merge (diff l1 s2) r1
276         else
277         merge l1 (diff r1 s2)
278         else if m1 < m2 && match_prefix p1 p2 m2 then
279         if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
280         else
281         s1
282
283
284   (*s All the following operations ([cardinal], [iter], [fold], [for_all],
285     [exists], [filter], [partition], [choose], [elements]) are
286     implemented as for any other kind of binary trees. *)
287
288   let rec cardinal n = match Node.node n with
289   | Empty -> 0
290   | Leaf _ -> 1
291   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
292
293   let rec iter f n = match Node.node n with
294   | Empty -> ()
295   | Leaf k -> f k
296   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
297
298   let rec fold f s accu = match Node.node s with
299   | Empty -> accu
300   | Leaf k -> f k accu
301   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
302
303
304   let rec for_all p n = match Node.node n with
305   | Empty -> true
306   | Leaf k -> p k
307   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
308
309   let rec exists p n = match Node.node n with
310   | Empty -> false
311   | Leaf k -> p k
312   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
313
314   let rec filter pr n = match Node.node n with
315   | Empty -> empty
316   | Leaf k -> if pr k then n else empty
317   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
318
319   let partition p s =
320     let rec part (t,f as acc) n = match Node.node n with
321     | Empty -> acc
322     | Leaf k -> if p k then (add k t, f) else (t, add k f)
323     | Branch (_,_,t0,t1) -> part (part acc t0) t1
324     in
325     part (empty, empty) s
326
327   let rec choose n = match Node.node n with
328   | Empty -> raise Not_found
329   | Leaf k -> k
330   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
331
332
333   let split x s =
334     let coll k (l, b, r) =
335       if k < x then add k l, b, r
336       else if k > x then l, b, add k r
337       else l, true, r
338     in
339     fold coll s (empty, false, empty)
340
341   (*s Additional functions w.r.t to [Set.S]. *)
342
343   let rec intersect s1 s2 = (equal s1 s2) ||
344     match (Node.node s1,Node.node s2) with
345     | Empty, _ -> false
346     | _, Empty -> false
347     | Leaf k1, _ -> mem k1 s2
348     | _, Leaf k2 -> mem k2 s1
349     | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
350         if m1 == m2 && p1 == p2 then
351         intersect l1 l2 || intersect r1 r2
352         else if m1 > m2 && match_prefix p2 p1 m1 then
353         intersect (if zero_bit p2 m1 then l1 else r1) s2
354         else if m1 < m2 && match_prefix p1 p2 m2 then
355         intersect s1 (if zero_bit p1 m2 then l2 else r2)
356         else
357         false
358
359
360   let from_list l = List.fold_left (fun acc e -> add e acc) empty l
361
362
363 end
364
365 module Make = Builder(Hcons.Make)
366 module Weak = Builder(Hcons.Weak)
367
368 module PosInt
369   =
370 struct
371   include Make(Hcons.PosInt)
372   let print ppf s =
373     Format.pp_print_string ppf "{ ";
374     iter (fun i -> Format.fprintf ppf "%i " i) s;
375     Format.pp_print_string ppf "}";
376     Format.pp_print_flush ppf ()
377 end