Remove trailing white spaces
[SXSI/xpathcomp.git] / ptset.ml
index e16cc2c..befb42e 100644 (file)
--- a/ptset.ml
+++ b/ptset.ml
 (* checking                                                                *)
 (*                                                                         *)
 (***************************************************************************)
-
-
-type elt = int
-
-type t = { id : int;
-          key : int; (* hash *)
-          node : node;
-          }
-and node = 
-  | Empty
-  | Leaf of int
-  | Branch of int * int * t * t
-
-
-(* faster if outside of a module *)
-let hash_node x = match x with 
-  | Empty -> 0
-  | Leaf i -> (i+1) land max_int
-      (* power of 2 +/- 1 are fast ! *)
-  | Branch (b,i,l,r) -> 
-      ((b lsl 1)+ b + i+(i lsl 4) + (l.key lsl 5)-l.key
-       + (r.key lsl 7) - r.key) land max_int
-
-module Node = 
-  struct
-    type _t = t
-    type t = _t
-    external hash : t -> int = "%field1"
+INCLUDE "utils.ml"
+module type S = 
+sig
+    type elt
+
+  type 'a node
+  module rec Node : sig
+    include  Hcons.S with type data = Data.t
+  end
+  and Data : sig 
+    include 
+      Hashtbl.HashedType with type t = Node.t node
+  end
+  type data = Data.t
+  type t = Node.t
+
+
+  val empty : t
+  val is_empty : t -> bool
+  val mem : elt -> t -> bool
+  val add : elt -> t -> t
+  val singleton : elt -> t
+  val remove : elt -> t -> t
+  val union : t -> t -> t
+  val inter : t -> t -> t
+  val diff : t -> t -> t
+  val compare : t -> t -> int
+  val equal : t -> t -> bool
+  val subset : t -> t -> bool
+  val iter : (elt -> unit) -> t -> unit
+  val fold : (elt -> 'a -> 'a) -> t -> 'a -> 'a
+  val for_all : (elt -> bool) -> t -> bool
+  val exists : (elt -> bool) -> t -> bool
+  val filter : (elt -> bool) -> t -> t
+  val partition : (elt -> bool) -> t -> t * t
+  val cardinal : t -> int
+  val elements : t -> elt list
+  val min_elt : t -> elt
+  val max_elt : t -> elt
+  val choose : t -> elt
+  val split : elt -> t -> t * bool * t   
+
+  val intersect : t -> t -> bool
+  val is_singleton : t -> bool
+  val mem_union : t -> t -> t
+  val hash : t -> int
+  val uid : t -> Uid.t
+  val uncons : t -> elt*t
+  val from_list : elt list -> t 
+  val make : data -> t
+  val node : t -> data
+    
+  val with_id : Uid.t -> t
+end
+
+module Make ( H : Hcons.SA ) : S with type elt = H.t =
+struct
+  type elt = H.t
+  type 'a node =
+    | Empty
+    | Leaf of elt
+    | Branch of int * int * 'a * 'a
+
+  module rec Node : Hcons.S with type data = Data.t = Hcons.Make (Data)
+  and Data : Hashtbl.HashedType  with type t = Node.t node =
+  struct 
+    type t =  Node.t node
     let equal x y = 
-      if x.id == y.id || x.key == y.key || x.node == y.node then true
-      else
-      match (x.node,y.node) with
-      | Empty,Empty -> true
-      | Leaf k1, Leaf k2 when k1 == k2 -> true
-      | Branch(p1,m1,l1,r1), Branch(p2,m2,l2,r2) when m1==m2 && p1==p2 && 
-         (l1.id == l2.id) && (r1.id == r2.id) -> true
-      | _ -> false
+      match x,y with
+       | Empty,Empty -> true
+       | Leaf k1, Leaf k2 ->  k1 == k2
+       | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
+           b1 == b2 && i1 == i2 &&
+             (Node.equal l1 l2) &&
+             (Node.equal r1 r2) 
+       | _ -> false
+    let hash = function 
+      | Empty -> 0
+      | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
+      | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.Node.id, Uid.to_int r.Node.id)
   end
-
-module WH =Weak.Make(Node) 
-
-let pool = WH.create 4093
-
-(* Neat trick thanks to Alain Frisch ! *)
-
-let gen_uid () = Oo.id (object end) 
-
-let empty = { id = gen_uid ();
-             key = 0;
-             node = Empty }
-
-let _ = WH.add pool empty
-
-let is_empty s = s.id==0
+  type data = Data.t
+  type t = Node.t
+
+  let hash = Node.hash 
+  let uid = Node.uid
+  let make = Node.make
+  let node _ = failwith "node"
+  let empty = Node.make Empty
     
-let rec norm n =
-  let v = { id = gen_uid ();
-           key = hash_node n;
-           node = n } 
-  in
-      WH.merge pool v 
+  let is_empty s = (Node.node s) == Empty
+       
+  let branch p m l r = Node.make (Branch(p,m,l,r))
 
-(*  WH.merge pool *)
+  let leaf k = Node.make (Leaf k)
 
-let branch  p m l r  = norm (Branch(p,m,l,r))
-let leaf k = norm (Leaf k)
-
-(* To enforce the invariant that a branch contains two non empty sub-trees *)
-let branch_ne = function
-  | (_,_,e,t) when is_empty e -> t
-  | (_,_,t,e) when is_empty e -> t
-  | (p,m,t0,t1)   -> branch p m t0 t1
-
-(********** from here on, only use the smart constructors *************)
-
-let zero_bit k m = (k land m) == 0
-
-let singleton k = leaf k
-let is_singleton n = 
-  match n.node with Leaf _ -> true
-    | _ -> false
-
-let rec mem k n = match n.node with
-  | Empty -> false
-  | Leaf j -> k == j
-  | Branch (p, _, l, r) -> if k <= p then mem k l else mem k r
-
-let rec min_elt n = match n.node with
-  | Empty -> raise Not_found
-  | Leaf k -> k
-  | Branch (_,_,s,_) -> min_elt s
+  (* To enforce the invariant that a branch contains two non empty sub-trees *)
+  let branch_ne p m t0 t1 = 
+    if (is_empty t0) then t1
+    else if is_empty t1 then t0 else branch p m t0 t1
       
-  let rec max_elt n = match n.node with
+  (********** from here on, only use the smart constructors *************)
+      
+  let zero_bit k m = (k land m) == 0
+    
+  let singleton k = leaf k
+    
+  let is_singleton n = 
+    match Node.node n with Leaf _ -> true
+      | _ -> false
+         
+  let mem (k:elt) n = 
+    let kid = Uid.to_int (H.uid k) in
+    let rec loop n = match Node.node n with
+      | Empty -> false
+      | Leaf j ->  k == j
+      | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
+    in loop n
+        
+  let rec min_elt n = match Node.node n with
+    | Empty -> raise Not_found
+    | Leaf k -> k
+    | Branch (_,_,s,_) -> min_elt s
+       
+  let rec max_elt n = match Node.node n with
     | Empty -> raise Not_found
     | Leaf k -> k
     | Branch (_,_,_,t) -> max_elt t
-
+       
   let elements s =
-    let rec elements_aux acc n = match n.node with
+    let rec elements_aux acc n = match Node.node n with
       | Empty -> acc
       | Leaf k -> k :: acc
       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
     in
-    elements_aux [] s
-
+      elements_aux [] s
+       
   let mask k m  = (k lor (m-1)) land (lnot m)
-
+    
   let naive_highest_bit x = 
     assert (x < 256);
     let rec loop i = 
       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
     in
-    loop 7
-
+      loop 7
+       
   let hbit = Array.init 256 naive_highest_bit
-  
-  let highest_bit_32 x =
-    let n = x lsr 24 in if n != 0 then Array.unsafe_get hbit n lsl 24
-    else let n = x lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
-    else let n = x lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
-    else Array.unsafe_get hbit x
-
-  let highest_bit_64 x =
-    let n = x lsr 32 in if n != 0 then (highest_bit_32 n) lsl 32
-    else highest_bit_32 x
-
-  let highest_bit = match Sys.word_size with
-    | 32 -> highest_bit_32
-    | 64 -> highest_bit_64
-    | _ -> assert false
 
-  let branching_bit p0 p1 = highest_bit (p0 lxor p1)
 
+  let highest_bit x = let n = (x) lsr 24 in 
+  if n != 0 then Array.unsafe_get hbit n lsl 24
+  else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
+  else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
+  else Array.unsafe_get hbit (x)
+
+IFDEF WORDIZE64
+THEN
+  let highest_bit64 x =
+    let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
+      else highest_bit x
+END
+
+       
+  let branching_bit p0 p1 = highest_bit (p0 lxor p1)
+    
   let join p0 t0 p1 t1 =  
     let m = branching_bit p0 p1  in
-    if zero_bit p0 m then 
-      branch (mask p0 m)  m t0 t1
-    else 
-      branch (mask p0 m) m t1 t0
-    
+      if zero_bit p0 m then 
+       branch (mask p0 m) m t0 t1
+      else 
+       branch (mask p0 m) m t1 t0
+         
   let match_prefix k p m = (mask k m) == p
-
+    
   let add k t =
-    let rec ins n = match n.node with
+    let kid = Uid.to_int (H.uid k) in
+    let rec ins n = match Node.node n with
       | Empty -> leaf k
-      | Leaf j ->  if j == k then n else join k (leaf k) j n
+      | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
       | Branch (p,m,t0,t1)  ->
-         if match_prefix k p m then
-           if zero_bit k m then 
+         if match_prefix kid p m then
+           if zero_bit kid m then 
              branch p m (ins t0) t1
            else
              branch p m t0 (ins t1)
          else
-           join k  (leaf k)  p n
+           join kid (leaf k)  p n
     in
     ins t
       
   let remove k t =
-    let rec rmv n = match n.node with
+    let kid = Uid.to_int(H.uid k) in
+    let rec rmv n = match Node.node n with
       | Empty -> empty
-      | Leaf j  -> if k == j then empty else n
+      | Leaf j  -> if  k == j then empty else n
       | Branch (p,m,t0,t1) -> 
-         if match_prefix k p m then
-           if zero_bit k m then
-             branch_ne (p, m, rmv t0, t1)
+         if match_prefix kid p m then
+           if zero_bit kid m then
+             branch_ne p m (rmv t0) t1
            else
-             branch_ne (p, m, t0, rmv t1)
+             branch_ne p m t0 (rmv t1)
          else
            n
     in
@@ -179,18 +214,15 @@ let rec min_elt n = match n.node with
       
   (* should run in O(1) thanks to Hash consing *)
 
-  let equal a b = a==b || a.id == b.id
+  let equal a b = Node.equal a b 
 
-  let compare a b = if a == b then 0 else a.id - b.id
-
-  let h_merge = Hashtbl.create 4097
-  let com_hash x y = (x*y - (x+y)) land max_int
+  let compare a b =  (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
 
   let rec merge s t = 
     if (equal s t) (* This is cheap thanks to hash-consing *)
     then s
     else
-    match s.node,t.node with
+    match Node.node s, Node.node t with
       | Empty, _  -> t
       | _, Empty  -> s
       | Leaf k, _ -> add k t
@@ -216,7 +248,7 @@ let rec min_elt n = match n.node with
               
               
   let rec subset s1 s2 = (equal s1 s2) ||
-    match (s1.node,s2.node) with
+    match (Node.node s1,Node.node s2) with
       | Empty, _ -> true
       | _, Empty -> false
       | Leaf k1, _ -> mem k1 s2
@@ -232,16 +264,34 @@ let rec min_elt n = match n.node with
          else
            false
 
-
              
-
   let union s1 s2 = merge s1 s2
-             
+    (* Todo replace with e Memo Module *)
+  module MemUnion = Hashtbl.Make(
+    struct 
+      type set = t 
+      type t = set*set 
+      let equal (x,y) (z,t) = (equal x z)&&(equal y t)
+      let equal a b = equal a b || equal b a
+      let hash (x,y) =   (* commutative hash *)
+       let x = Node.hash x 
+       and y = Node.hash y 
+       in
+         if x < y then HASHINT2(x,y) else HASHINT2(y,x)
+    end)
+  let h_mem = MemUnion.create MED_H_SIZE
+
+  let mem_union s1 s2 = 
+    try  MemUnion.find h_mem (s1,s2) 
+    with Not_found ->
+         let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
+      
+
   let rec inter s1 s2 = 
     if equal s1 s2 
     then s1
     else
-      match (s1.node,s2.node) with
+      match (Node.node s1,Node.node s2) with
        | Empty, _ -> empty
        | _, Empty -> empty
        | Leaf k1, _ -> if mem k1 s2 then s1 else empty
@@ -260,7 +310,7 @@ let rec min_elt n = match n.node with
     if equal s1 s2 
     then empty
     else
-      match (s1.node,s2.node) with
+      match (Node.node s1,Node.node s2) with
        | Empty, _ -> empty
        | _, Empty -> s1
        | Leaf k1, _ -> if mem k1 s2 then empty else s1
@@ -277,53 +327,52 @@ let rec min_elt n = match n.node with
              if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
            else
          s1
-           
-           
-
+              
 
 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
     [exists], [filter], [partition], [choose], [elements]) are
     implemented as for any other kind of binary trees. *)
 
-let rec cardinal n = match n.node with
+let rec cardinal n = match Node.node n with
   | Empty -> 0
   | Leaf _ -> 1
   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
 
-let rec iter f n = match n.node with
+let rec iter f n = match Node.node n with
   | Empty -> ()
   | Leaf k -> f k
   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
       
-let rec fold f s accu = match s.node with
+let rec fold f s accu = match Node.node s with
   | Empty -> accu
   | Leaf k -> f k accu
   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
 
-let rec for_all p n = match n.node with
+
+let rec for_all p n = match Node.node n with
   | Empty -> true
   | Leaf k -> p k
   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
 
-let rec exists p n = match n.node with
+let rec exists p n = match Node.node n with
   | Empty -> false
   | Leaf k -> p k
   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
 
-let rec filter pr n = match n.node with
+let rec filter pr n = match Node.node n with
   | Empty -> empty
   | Leaf k -> if pr k then n else empty
-  | Branch (p,m,t0,t1) -> branch_ne (p, m, filter pr t0, filter pr t1)
+  | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
 
 let partition p s =
-  let rec part (t,f as acc) n = match n.node with
+  let rec part (t,f as acc) n = match Node.node n with
     | Empty -> acc
     | Leaf k -> if p k then (add k t, f) else (t, add k f)
     | Branch (_,_,t0,t1) -> part (part acc t0) t1
   in
   part (empty, empty) s
 
-let rec choose n = match n.node with
+let rec choose n = match Node.node n with
   | Empty -> raise Not_found
   | Leaf k -> k
   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
@@ -337,27 +386,10 @@ let split x s =
   in
   fold coll s (empty, false, empty)
 
-
-
-let rec dump n =
-  Printf.eprintf "{ id = %i; key = %i ; node=" n.id n.key;
-  match n.node with
-    | Empty -> Printf.eprintf "Empty; }\n"
-    | Leaf k -> Printf.eprintf "Leaf %i; }\n" k
-    | Branch (p,m,l,r) -> 
-       Printf.eprintf "Branch(%i,%i,id=%i,id=%i); }\n"
-         p m l.id r.id;
-       dump l;
-       dump r
-
-(*i*)
-let make l = List.fold_left (fun acc e -> add e acc ) empty l
-(*i*)
-
 (*s Additional functions w.r.t to [Set.S]. *)
 
 let rec intersect s1 s2 = (equal s1 s2) ||
-  match (s1.node,s2.node) with
+  match (Node.node s1,Node.node s2) with
   | Empty, _ -> false
   | _, Empty -> false
   | Leaf k1, _ -> mem k1 s2
@@ -373,35 +405,36 @@ let rec intersect s1 s2 = (equal s1 s2) ||
         false
 
 
-let hash s = s.key
 
-let from_list l = List.fold_left (fun acc i -> add i acc) empty l
-
-type int_vector
-
-external int_vector_alloc : int -> int_vector = "caml_int_vector_alloc"
-external int_vector_set : int_vector -> int -> int -> unit = "caml_int_vector_set"
-external int_vector_length : int_vector -> int  = "caml_int_vector_length"
-external int_vector_empty : unit -> int_vector = "caml_int_vector_empty"
-
-let empty_vector = int_vector_empty ()
-
-let to_int_vector_ext s =
-  let l = cardinal s in
-  let v = int_vector_alloc l in
-  let i = ref 0 in
-    iter (fun e -> int_vector_set v !i e; incr i) s;
-    v
-
-let hash_vectors = Hashtbl.create 4097
-
-let to_int_vector s =
-  try 
-    Hashtbl.find hash_vectors s.key
-  with
-      Not_found -> 
-       let v = to_int_vector_ext s in
-         Hashtbl.add hash_vectors s.key v;
-         v
-
-    
+let rec uncons n = match Node.node n with
+  | Empty -> raise Not_found
+  | Leaf k -> (k,empty)
+  | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
+   
+let from_list l = List.fold_left (fun acc e -> add e acc) empty l
+
+let with_id = Node.with_id
+end
+
+module Int : sig
+  include S with type elt = int
+  val print : Format.formatter -> t -> unit
+end
+  = 
+struct
+  include Make ( struct type t = int 
+                       type data = t
+                       external hash : t -> int = "%identity"
+                       external uid : t -> Uid.t = "%identity"
+                       external equal : t -> t -> bool = "%eq"
+                       external make : t -> int = "%identity"
+                       external node : t -> int = "%identity"
+                       external with_id : Uid.t -> t = "%identity"
+                end
+              ) 
+  let print ppf s = 
+    Format.pp_print_string ppf "{ ";
+    iter (fun i -> Format.fprintf ppf "%i " i) s;
+    Format.pp_print_string ppf "}";
+    Format.pp_print_flush ppf ()
+ end