(* Provenance: this file was extracted from the following gitweb blobdiff
   (new file):
   X-Git-Url: http://git.nguyen.vg/gitweb/?p=tatoo.git;a=blobdiff_plain;f=src%2Fptset.ml;fp=src%2Fptset.ml;h=f9bbd03eb6074f96d7d6148ec099b3b75ff47ebb;hp=0000000000000000000000000000000000000000;hb=b00bff88c7902e828804c06b7f9dc55222fdc84e;hpb=03b6a364e7240ca827585e7baff225a0aaa33bc6
   diff --git a/src/ptset.ml b/src/ptset.ml
   new file mode 100644
   index 0000000..f9bbd03
   --- /dev/null
   +++ b/src/ptset.ml
   @@ -0,0 +1,381 @@
*)
(* Original file: *)
(***********************************************************************)
(*                                                                     *)
(*  Copyright (C) Jean-Christophe Filliatre                            *)
(*                                                                     *)
(*  This software is free software; you can redistribute it and/or     *)
(*  modify it under the terms of the GNU Library General Public        *)
(*  License version 2.1, with the special exception on linking         *)
(*  described in file http://www.lri.fr/~filliatr/ftp/ocaml/ds/LICENSE *)
(*                                                                     *)
(*  This software is distributed in the hope that it will be useful,   *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.               *)
(*                                                                     *)
(***********************************************************************)

(*
  Time-stamp:
*)

(* Modified by Kim Nguyen *)
(* The Patricia trees are themselves deeply hash-consed. The module
   provides a Make (and Weak) functor to build hash-consed patricia
   trees whose elements are Abstract hash-consed values.
*)

INCLUDE "utils.ml"

include Ptset_sig

(* Type of a functor turning a hashed type into a hash-consing module
   over that type. *)
module type HConsBuilder =
  functor (H : Common_sig.HashedType) -> Hcons.S with type data = H.t

(* [Builder (HCB) (H)] builds sets of [H.t] values. [HCB] is used to
   hash-cons the tree nodes themselves; elements are keyed by the integer
   [Uid.to_int (H.uid e)]. *)
module Builder (HCB : HConsBuilder) (H : Hcons.Abstract) :
  S with type elt = H.t =
struct
  type elt = H.t

  (* Shape of one tree node, parameterised by the type ['a] of its
     sub-trees. [Branch (p, m, l, r)] carries a prefix [p], a branching
     bit [m] and two sub-trees. *)
  type 'a set =
    | Empty
    | Leaf of elt
    | Branch of int * int * 'a * 'a

  (* [Node] is the hash-consed version of [Data]; [Data] is a node shape
     whose children are already hash-consed [Node.t] values. *)
  module rec Node : Hcons.S with type data = Data.t = HCB(Data)
  and Data : Common_sig.HashedType with type t = Node.t set =
  struct
    type t = Node.t set

    (* Shallow equality: leaves are compared physically, children of a
       branch are compared with [Node.equal]. *)
    let equal x y =
      match x,y with
      | Empty,Empty -> true
      | Leaf k1, Leaf k2 -> k1 == k2
      | Branch(b1,i1,l1,r1), Branch(b2,i2,l2,r2) ->
          b1 == b2 && i1 == i2 && (Node.equal l1 l2) && (Node.equal r1 r2)

      | (Empty|Leaf _|Branch _), _ -> false

    (* Shallow hash built from the element uid (leaf) or from the prefix,
       branching bit and the ids of both children (branch). *)
    let hash = function
      | Empty -> 0
      | Leaf i -> HASHINT2 (PRIME1, Uid.to_int (H.uid i))
      | Branch (b,i,l,r) ->
          HASHINT4(b, i, Uid.to_int l.Node.id, Uid.to_int r.Node.id)
  end

  include Node

  (* The empty set. *)
  let empty = Node.make Empty

  (* [is_empty s] is true iff [s] is the empty set. *)
  let is_empty s = (Node.node s) == Empty

  (* Raw hash-consing constructors. *)
  let branch p m l r = Node.make (Branch(p,m,l,r))

  let leaf k = Node.make (Leaf k)

  (* To enforce the invariant that a branch contains two non empty
     sub-trees *)
  let branch_ne p m t0 t1 =
    if (is_empty t0) then t1
    else if is_empty t1 then t0 else branch p m t0 t1

  (******** from here on, only use the smart constructors ************)

  (* [zero_bit k m] is true iff the bit selected by mask [m] is unset
     in [k]. *)
  let zero_bit k m = (k land m) == 0

  (* [singleton k] is the set containing only [k]. *)
  let singleton k = leaf k

  (* [is_singleton n] is true iff [n] is a single leaf. *)
  let is_singleton n =
    match Node.node n with
    | Leaf _ -> true
    | Branch _ | Empty -> false

  (* [mem k n] tests whether [k] belongs to [n]. At a branch the search
     goes left when the uid of [k] is at most the prefix [p], right
     otherwise; the leaf reached is compared physically with [k]. *)
  let mem (k:elt) n =
    let kid = Uid.to_int (H.uid k) in
    let rec loop n = match Node.node n with
      | Empty -> false
      | Leaf j -> k == j
      | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
    in loop n

  (* [min_elt n] follows left sub-trees down to a leaf.
     Raises [Not_found] on the empty set. *)
  let rec min_elt n = match Node.node n with
    | Empty -> raise Not_found
    | Leaf k -> k
    | Branch (_,_,s,_) -> min_elt s

  (* [max_elt n] follows right sub-trees down to a leaf.
     Raises [Not_found] on the empty set. *)
  let rec max_elt n = match Node.node n with
    | Empty -> raise Not_found
    | Leaf k -> k
    | Branch (_,_,_,t) -> max_elt t

  (* [elements s] lists the elements of [s]; elements of a left sub-tree
     come before those of the matching right sub-tree. *)
  let elements s =
    let rec elements_aux acc n = match Node.node n with
      | Empty -> acc
      | Leaf k -> k :: acc
      | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
    in
    elements_aux [] s

  (* [mask k m] sets every bit of [k] below the branching bit [m] and
     clears bit [m] itself; higher bits are kept. *)
  let mask k m = (k lor (m-1)) land (lnot m)

  (* Highest set bit of a value below 256, found by linear scan.
     Returns 1 for both 0 and 1. *)
  let naive_highest_bit x =
    assert (x < 256);
    let rec loop i =
      if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
    in
    loop 7

  (* Lookup table of [naive_highest_bit] for every byte value. *)
  let hbit = Array.init 256 naive_highest_bit
  (*
    external clz : int -> int = "caml_clz" "noalloc"
    external leading_bit : int -> int = "caml_leading_bit" "noalloc"
  *)

  (* [highest_bit x] is the highest set bit of [x], computed one byte at
     a time with [hbit]. Only values fitting in 32 bits are supported:
     anything larger (or negative) indexes [hbit] out of bounds, which is
     reported as [Invalid_argument]. *)
  let highest_bit x =
    try
      let n = (x) lsr 24 in
      if n != 0 then hbit.(n) lsl 24
      else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
      else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
      else hbit.(x)
    with
      _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))

  (* Same as [highest_bit], looking first at the bits above bit 31. *)
  let highest_bit64 x =
    let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
    else highest_bit x

  (* [branching_bit p0 p1] is the highest bit on which [p0] and [p1]
     differ. *)
  let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)

  (* [join p0 t0 p1 t1] combines two trees with distinct prefixes under a
     new branch; the tree whose prefix has the branching bit unset goes
     to the left. *)
  let join p0 t0 p1 t1 =
    let m = branching_bit p0 p1  in
    let msk = mask p0 m in
    if zero_bit p0 m then
      branch_ne msk m t0 t1
    else
      branch_ne msk m t1 t0

  (* [match_prefix k p m] is true iff [k] agrees with prefix [p] above
     the branching bit [m]. *)
  let match_prefix k p m = (mask k m) == p

  (* [add k t] is [t] extended with [k]. The uid of [k] must be
     non-negative (checked by an assertion). *)
  let add k t =
    let kid = Uid.to_int (H.uid k) in
    assert (kid >=0);
    let rec ins n = match Node.node n with
      | Empty -> leaf k
      | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
      | Branch (p,m,t0,t1)  ->
          if match_prefix kid p m then
            if zero_bit kid m then
              branch_ne p m (ins t0) t1
            else
              branch_ne p m t0 (ins t1)
          else
            join kid (leaf k)  p n
    in
    ins t

  (* [remove k t] is [t] without [k]; [t] is returned unchanged when [k]
     cannot be in it. *)
  let remove k t =
    let kid = Uid.to_int(H.uid k) in
    let rec rmv n = match Node.node n with
      | Empty -> empty
      | Leaf j  -> if k == j then empty else n
      | Branch (p,m,t0,t1) ->
          if match_prefix kid p m then
            if zero_bit kid m then
              branch_ne p m (rmv t0) t1
            else
              branch_ne p m t0 (rmv t1)
          else
            n
    in
    rmv t

  (* should run in O(1) thanks to hash consing *)

  let equal a b = Node.equal a b

  (* Orders two sets by the difference of their node uids. *)
  let compare a b = (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))

  (* [merge s t] is the union of [s] and [t]. *)
  let rec merge s t =
    if equal s t (* This is cheap thanks to hash-consing *)
    then s
    else
      match Node.node s, Node.node t with
      | Empty, _  -> t
      | _, Empty  -> s
      | Leaf k, _ -> add k t
      | _, Leaf k -> add k s
      | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
          if m == n && match_prefix q p m then
            branch p  m  (merge s0 t0) (merge s1 t1)
          else if m > n && match_prefix q p m then
            if zero_bit q m then
              branch_ne p m (merge s0 t) s1
            else
              branch_ne p m s0 (merge s1 t)
          else if m < n && match_prefix p q n then
            if zero_bit p n then
              branch_ne q n (merge s t0) t1
            else
              branch_ne q n t0 (merge s t1)
          else
            (* The prefixes disagree. *)
            join p s q t

  (* [subset s1 s2] is true iff every element of [s1] is in [s2]. *)
  let rec subset s1 s2 = (equal s1 s2) ||
    match (Node.node s1,Node.node s2) with
    | Empty, _ -> true
    | _, Empty -> false
    | Leaf k1, _ -> mem k1 s2
    | Branch _, Leaf _ -> false
    | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
        if m1 == m2 && p1 == p2 then
          subset l1 l2 && subset r1 r2
        else if m1 < m2 && match_prefix p1 p2 m2 then
          if zero_bit p1 m2 then
            subset l1 l2 && subset r1 l2
          else
            subset l1 r2 && subset r1 r2
        else
          false

  (* [union s1 s2] is an alias for [merge]. *)
  let union s1 s2 = merge s1 s2
  (* Todo replace with a Memo module *)

  (* [inter s1 s2] is the set of elements present in both [s1] and
     [s2]. *)
  let rec inter s1 s2 =
    if equal s1 s2
    then s1
    else
      match (Node.node s1,Node.node s2) with
      | Empty, _ -> empty
      | _, Empty -> empty
      | Leaf k1, _ -> if mem k1 s2 then s1 else empty
      | _, Leaf k2 -> if mem k2 s1 then s2 else empty
      | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
          if m1 == m2 && p1 == p2 then
            merge (inter l1 l2)  (inter r1 r2)
          else if m1 > m2 && match_prefix p2 p1 m1 then
            inter (if zero_bit p2 m1 then l1 else r1) s2
          else if m1 < m2 && match_prefix p1 p2 m2 then
            inter s1 (if zero_bit p1 m2 then l2 else r2)
          else
            empty

  (* [diff s1 s2] is the set of elements of [s1] that are not in
     [s2]. *)
  let rec diff s1 s2 =
    if equal s1 s2
    then empty
    else
      match (Node.node s1,Node.node s2) with
      | Empty, _ -> empty
      | _, Empty -> s1
      | Leaf k1, _ -> if mem k1 s2 then empty else s1
      | _, Leaf k2 -> remove k2 s1
      | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
          if m1 == m2 && p1 == p2 then
            merge (diff l1 l2) (diff r1 r2)
          else if m1 > m2 && match_prefix p2 p1 m1 then
            if zero_bit p2 m1 then
              merge (diff l1 s2) r1
            else
              merge l1 (diff r1 s2)
          else if m1 < m2 && match_prefix p1 p2 m2 then
            if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
          else
            s1

  (*s All the following operations ([cardinal], [iter], [fold], [for_all],
    [exists], [filter], [partition], [choose], [elements]) are
    implemented as for any other kind of binary trees. *)

  (* Number of elements, obtained by visiting every node. *)
  let rec cardinal n = match Node.node n with
    | Empty -> 0
    | Leaf _ -> 1
    | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1

  (* [iter f n] applies [f] to every element, left sub-tree first. *)
  let rec iter f n = match Node.node n with
    | Empty -> ()
    | Leaf k -> f k
    | Branch (_,_,t0,t1) -> iter f t0; iter f t1

  (* [fold f s accu] folds [f] over the elements; the right sub-tree is
     folded before the left one. *)
  let rec fold f s accu = match Node.node s with
    | Empty -> accu
    | Leaf k -> f k accu
    | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)

  (* [for_all p n] is true iff [p] holds for every element of [n]. *)
  let rec for_all p n = match Node.node n with
    | Empty -> true
    | Leaf k -> p k
    | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1

  (* [exists p n] is true iff [p] holds for some element of [n]. *)
  let rec exists p n = match Node.node n with
    | Empty -> false
    | Leaf k -> p k
    | Branch (_,_,t0,t1) -> exists p t0 || exists p t1

  (* [filter pr n] keeps the elements of [n] satisfying [pr];
     [branch_ne] collapses branches that lost a side. *)
  let rec filter pr n = match Node.node n with
    | Empty -> empty
    | Leaf k -> if pr k then n else empty
    | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)

  (* [partition p s] returns the elements satisfying [p] and those that
     do not, as a pair of sets. *)
  let partition p s =
    let rec part (t,f as acc) n = match Node.node n with
      | Empty -> acc
      | Leaf k -> if p k then (add k t, f) else (t, add k f)
      | Branch (_,_,t0,t1) -> part (part acc t0) t1
    in
    part (empty, empty) s

  (* [choose n] returns the leaf reached by always going left.
     Raises [Not_found] on the empty set. *)
  let rec choose n = match Node.node n with
    | Empty -> raise Not_found
    | Leaf k -> k
    | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)

  (* [split x s] returns [(l, present, r)]: the elements below [x], a
     flag telling whether an element comparing equal to [x] was seen,
     and the elements above [x].
     NOTE(review): [<] and [>] are the polymorphic comparisons applied
     to [elt] values, not a comparison of uids - confirm this is the
     intended ordering for hash-consed elements. *)
  let split x s =
    let coll k (l, b, r) =
      if k < x then add k l, b, r
      else if k > x then l, b, add k r
      else l, true, r
    in
    fold coll s (empty, false, empty)

  (*s Additional functions w.r.t to [Set.S]. *)

  (* [intersect s1 s2] is true iff [s1] and [s2] share at least one
     element. Physically equal sets intersect exactly when they are
     non-empty: the hash-consing shortcut must not answer true for two
     empty sets. *)
  let rec intersect s1 s2 =
    if equal s1 s2 then not (is_empty s1)
    else
      match (Node.node s1,Node.node s2) with
      | Empty, _ -> false
      | _, Empty -> false
      | Leaf k1, _ -> mem k1 s2
      | _, Leaf k2 -> mem k2 s1
      | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
          if m1 == m2 && p1 == p2 then
            intersect l1 l2 || intersect r1 r2
          else if m1 > m2 && match_prefix p2 p1 m1 then
            intersect (if zero_bit p2 m1 then l1 else r1) s2
          else if m1 < m2 && match_prefix p1 p2 m2 then
            intersect s1 (if zero_bit p1 m2 then l2 else r2)
          else
            false

  (* [from_list l] builds the set of the elements of [l]. *)
  let from_list l = List.fold_left (fun acc e -> add e acc) empty l


end

(* Instances of [Builder] over the two hash-consing modules
   [Hcons.Make] and [Hcons.Weak]. *)
module Make = Builder(Hcons.Make)
module Weak = Builder(Hcons.Weak)

(* Sets over [Hcons.PosInt], with a printer. *)
module PosInt
  =
struct
  include Make(Hcons.PosInt)

  (* [print ppf s] prints the elements of [s] between braces, each
     followed by a space, then flushes [ppf]. *)
  let print ppf s =
    Format.pp_print_string ppf "{ ";
    iter (fun i -> Format.fprintf ppf "%i " i) s;
    Format.pp_print_string ppf "}";
    Format.pp_print_flush ppf ()
end