Half way through refactoring
[SXSI/xpathcomp.git] / ptset.ml
index 091d4a8..10c311c 100644 (file)
--- a/ptset.ml
+++ b/ptset.ml
 (* checking                                                                *)
 (*                                                                         *)
 (***************************************************************************)
-
-
-type elt = int
-
-type t = { id : int;
-          key : int; (* hash *)
-          node : node }
-and node = 
-  | Empty
-  | Leaf of int
-  | Branch of int * int * t * t
-
-module Node = 
-  struct
-    type _t = t
-    type t = _t
-    let hash x = x.key       
-    let hash_node = function 
-        | Empty -> 0
-        | Leaf i -> i+1
-            (* power of 2 +/- 1 are fast ! *)
-        | Branch (b,i,l,r) -> 
-            (b lsl 1)+ b + i+(i lsl 4) + (l.key lsl 5)-l.key
-            + (r.key lsl 7) - r.key
-    let hash_node x = (hash_node x) land max_int
-    let equal x y = match (x.node,y.node) with
-      | Empty,Empty -> true
-      | Leaf k1, Leaf k2 when k1 == k2 -> true
-      | Branch(p1,m1,l1,r1), Branch(p2,m2,l2,r2) when m1==m2 && p1==p2 && 
-         (l1.id == l2.id) && (r1.id == r2.id) -> true
-      | _ -> false
-  end
-
-module WH =Weak.Make(Node) 
-(* struct 
-  include Hashtbl.Make(Node)
-    let merge h v =
-      if mem h v then v
-      else (add h v v;v)
+INCLUDE "utils.ml"
+module type S = 
+sig
+  include Set.S
+  val intersect : t -> t -> bool
+  val is_singleton : t -> bool
+  val mem_union : t -> t -> t
+  val hash : t -> int
+  val uid : t -> int
+  val uncons : t -> elt*t
+  val from_list : elt list -> t 
 end
-*)
-let pool = WH.create 4093
-
-(* Neat trick thanks to Alain Frisch ! *)
-
-let gen_uid () = Oo.id (object end) 
-
-let empty = { id = gen_uid ();
-             key = 0;
-             node = Empty }
-
-let _ = WH.add pool empty
-
-let is_empty = function { id = 0 } -> true  | _ -> false
-    
-let rec norm n =
-  let v = { id = gen_uid ();
-           key = Node.hash_node n;
-           node = n } 
-  in
-      WH.merge pool v 
-
-(*  WH.merge pool *)
-
-let branch  p m l r  = norm (Branch(p,m,l,r))
-let leaf k = norm (Leaf k)
-
-(* To enforce the invariant that a branch contains two non empty sub-trees *)
-let branch_ne = function
-  | (_,_,e,t) when is_empty e -> t
-  | (_,_,t,e) when is_empty e -> t
-  | (p,m,t0,t1)   -> branch p m t0 t1
 
-(********** from here on, only use the smart constructors *************)
-
-let zero_bit k m = (k land m) == 0
-
-let singleton k = if k < 0 then failwith "singleton" else leaf k
-
-let rec mem k n = match n.node with
-  | Empty -> false
-  | Leaf j -> k == j
-  | Branch (p, _, l, r) -> if k <= p then mem k l else mem k r
-
-let rec min_elt n = match n.node with
-  | Empty -> raise Not_found
-  | Leaf k -> k
-  | Branch (_,_,s,_) -> min_elt s
-      
-  let rec max_elt n = match n.node with
-    | Empty -> raise Not_found
-    | Leaf k -> k
-    | Branch (_,_,_,t) -> max_elt t
-
-  let elements s =
-    let rec elements_aux acc n = match n.node with
-      | Empty -> acc
-      | Leaf k -> k :: acc
-      | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
-    in
-    elements_aux [] s
-
-  let mask k m  = (k lor (m-1)) land (lnot m)
-
-  let naive_highest_bit x = 
-    assert (x < 256);
-    let rec loop i = 
-      if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
-    in
-    loop 7
-
-  let hbit = Array.init 256 naive_highest_bit
-  
-  let highest_bit_32 x =
-    let n = x lsr 24 in if n != 0 then Array.unsafe_get hbit n lsl 24
-    else let n = x lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
-    else let n = x lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
-    else Array.unsafe_get hbit x
-
-  let highest_bit_64 x =
-    let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
-    else highest_bit_32 x
-
-  let highest_bit = match Sys.word_size with
-    | 32 -> highest_bit_32
-    | 64 -> highest_bit_64
-    | _ -> assert false
-
-  let branching_bit p0 p1 = highest_bit (p0 lxor p1)
-
-  let join p0 t0 p1 t1 =  
-    let m = branching_bit p0 p1  in
-    if zero_bit p0 m then 
-      branch (mask p0 m)  m t0 t1
-    else 
-      branch (mask p0 m) m t1 t0
+module Int : S with type elt = int = 
+struct
+  type elt = int
+  external hash_elt : elt -> int = "%identity"
+  external uid_elt : elt -> int = "%identity"
+  let equal_elt : elt -> elt -> bool = (==);;
     
-  let match_prefix k p m = (mask k m) == p
-
-  let add k t =
-    let rec ins n = match n.node with
-      | Empty -> leaf k
-      | Leaf j ->  if j == k then n else join k (leaf k) j n
-      | Branch (p,m,t0,t1)  ->
-         if match_prefix k p m then
-           if zero_bit k m then 
-             branch p m (ins t0) t1
-           else
-             branch p m t0 (ins t1)
-         else
-           join k  (leaf k)  p n
-    in
-    ins t
-      
-  let remove k t =
-    let rec rmv n = match n.node with
-      | Empty -> empty
-      | Leaf j  -> if k == j then empty else n
-      | Branch (p,m,t0,t1) -> 
-         if match_prefix k p m then
-           if zero_bit k m then
-             branch_ne (p, m, rmv t0, t1)
-           else
-             branch_ne (p, m, t0, rmv t1)
-         else
-           n
-    in
-    rmv t
-      
-  (* should run in O(1) thanks to Hash consing *)
-
-  let equal a b = a==b || a.id == b.id
+DEFINE USE_PTSET_INCLUDE
+INCLUDE "ptset_include.ml"
 
-  let compare a b = if a == b then 0 else a.id - b.id
-
-
-  let rec merge s t = 
-    if (equal s t) (* This is cheap thanks to hash-consing *)
-    then s
-    else
-      match s.node,t.node with
-       | Empty, _  -> t
-       | _, Empty  -> s
-       | Leaf k, _ -> add k t
-       | _, Leaf k -> add k s
-       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
-           if m == n && match_prefix q p m then
-             branch p  m  (merge s0 t0) (merge s1 t1)
-           else if m > n && match_prefix q p m then
-             if zero_bit q m then 
-               branch p m (merge s0 t) s1
-              else 
-               branch p m s0 (merge s1 t)
-           else if m < n && match_prefix p q n then     
-             if zero_bit p n then
-               branch q n (merge s t0) t1
-             else
-               branch q n t0 (merge s t1)
-           else
-             (* The prefixes disagree. *)
-             join p s q t
-           
-
-
-  let rec subset s1 s2 = (equal s1 s2) ||
-    match (s1.node,s2.node) with
-      | Empty, _ -> true
-      | _, Empty -> false
-      | Leaf k1, _ -> mem k1 s2
-      | Branch _, Leaf _ -> false
-      | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
-         if m1 == m2 && p1 == p2 then
-           subset l1 l2 && subset r1 r2
-         else if m1 < m2 && match_prefix p1 p2 m2 then
-           if zero_bit p1 m2 then 
-             subset l1 l2 && subset r1 l2
-           else 
-             subset l1 r2 && subset r1 r2
-         else
-           false
-
-  let union s t = 
-      merge s t
-             
-  let rec inter s1 s2 = 
-    if equal s1 s2 
-    then s1
-    else
-      match (s1.node,s2.node) with
-       | Empty, _ -> empty
-       | _, Empty -> empty
-       | Leaf k1, _ -> if mem k1 s2 then s1 else empty
-       | _, Leaf k2 -> if mem k2 s1 then s2 else empty
-       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
-           if m1 == m2 && p1 == p2 then 
-             merge (inter l1 l2)  (inter r1 r2)
-           else if m1 > m2 && match_prefix p2 p1 m1 then
-             inter (if zero_bit p2 m1 then l1 else r1) s2
-           else if m1 < m2 && match_prefix p1 p2 m2 then
-             inter s1 (if zero_bit p1 m2 then l2 else r2)
-           else
-             empty
-
-  let rec diff s1 s2 = 
-    if equal s1 s2 
-    then empty
-    else
-      match (s1.node,s2.node) with
-       | Empty, _ -> empty
-       | _, Empty -> s1
-       | Leaf k1, _ -> if mem k1 s2 then empty else s1
-       | _, Leaf k2 -> remove k2 s1
-       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
-           if m1 == m2 && p1 == p2 then
-             merge (diff l1 l2) (diff r1 r2)
-           else if m1 > m2 && match_prefix p2 p1 m1 then
-             if zero_bit p2 m1 then 
-               merge (diff l1 s2) r1
-             else 
-               merge l1 (diff r1 s2)
-           else if m1 < m2 && match_prefix p1 p2 m2 then
-             if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
-           else
-         s1
-           
-           
-
-
-(*s All the following operations ([cardinal], [iter], [fold], [for_all],
-    [exists], [filter], [partition], [choose], [elements]) are
-    implemented as for any other kind of binary trees. *)
-
-let rec cardinal n = match n.node with
-  | Empty -> 0
-  | Leaf _ -> 1
-  | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
-
-let rec iter f n = match n.node with
-  | Empty -> ()
-  | Leaf k -> f k
-  | Branch (_,_,t0,t1) -> iter f t0; iter f t1
-      
-let rec fold f s accu = match s.node with
-  | Empty -> accu
-  | Leaf k -> f k accu
-  | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
-
-let rec for_all p n = match n.node with
-  | Empty -> true
-  | Leaf k -> p k
-  | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
-
-let rec exists p n = match n.node with
-  | Empty -> false
-  | Leaf k -> p k
-  | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
-
-let rec filter pr n = match n.node with
-  | Empty -> empty
-  | Leaf k -> if pr k then n else empty
-  | Branch (p,m,t0,t1) -> branch_ne (p, m, filter pr t0, filter pr t1)
-
-let partition p s =
-  let rec part (t,f as acc) n = match n.node with
-    | Empty -> acc
-    | Leaf k -> if p k then (add k t, f) else (t, add k f)
-    | Branch (_,_,t0,t1) -> part (part acc t0) t1
-  in
-  part (empty, empty) s
-
-let rec choose n = match n.node with
-  | Empty -> raise Not_found
-  | Leaf k -> k
-  | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
-
-
-let split x s =
-  let coll k (l, b, r) =
-    if k < x then add k l, b, r
-    else if k > x then l, b, add k r
-    else l, true, r 
-  in
-  fold coll s (empty, false, empty)
-
-
-
-let rec dump n =
-  Printf.eprintf "{ id = %i; key = %i ; node=" n.id n.key;
-  match n.node with
-    | Empty -> Printf.eprintf "Empty; }\n"
-    | Leaf k -> Printf.eprintf "Leaf %i; }\n" k
-    | Branch (p,m,l,r) -> 
-       Printf.eprintf "Branch(%i,%i,id=%i,id=%i); }\n"
-         p m l.id r.id;
-       dump l;
-       dump r
-
-(*i*)
-let make l = List.fold_left (fun acc e -> add e acc ) empty l
-(*i*)
-
-(*s Additional functions w.r.t to [Set.S]. *)
-
-let rec intersect s1 s2 = (equal s1 s2) ||
-  match (s1.node,s2.node) with
-  | Empty, _ -> false
-  | _, Empty -> false
-  | Leaf k1, _ -> mem k1 s2
-  | _, Leaf k2 -> mem k2 s1
-  | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
-      if m1 == m2 && p1 == p2 then
-        intersect l1 l2 || intersect r1 r2
-      else if m1 < m2 && match_prefix p2 p1 m1 then
-        intersect (if zero_bit p2 m1 then l1 else r1) s2
-      else if m1 > m2 && match_prefix p1 p2 m2 then
-        intersect s1 (if zero_bit p1 m2 then l2 else r2)
-      else
-        false
-
-
-let hash s = s.key
+end
+module Make ( H : Hcons.S ) : S with type elt = H.t =
+struct
+  type elt = H.t
+  let hash_elt = H.hash
+  let uid_elt = H.uid 
+  let equal_elt = H.equal
+INCLUDE "ptset_include.ml"
+end
 
-let from_list l = List.fold_left (fun acc i -> add i acc) empty l
+(* Have to benchmark wheter this whole include stuff is worth it *)
+module I : S with type elt = int = Make ( struct type t = int 
+                                                type data = t
+                                                external hash : t -> int = "%identity"
+                                                external uid : t -> int = "%identity"
+                                                let equal : t -> t -> bool = (==)
+                                                external make : t -> int = "%identity"
+                                                external node : t -> int = "%identity"
+                                                  
+                                         end
+                                         )