removed cruft, fixed ptset.ml
[SXSI/xpathcomp.git] / ptset.ml
index 10c311c..ea84ddf 100644 (file)
--- a/ptset.ml
+++ b/ptset.ml
@@ -18,28 +18,367 @@ sig
   val from_list : elt list -> t 
 end
 
-module Int : S with type elt = int = 
-struct
-  type elt = int
-  external hash_elt : elt -> int = "%identity"
-  external uid_elt : elt -> int = "%identity"
-  let equal_elt : elt -> elt -> bool = (==);;
-    
-DEFINE USE_PTSET_INCLUDE
-INCLUDE "ptset_include.ml"
-
-end
 module Make ( H : Hcons.S ) : S with type elt = H.t =
 struct
   type elt = H.t
-  let hash_elt = H.hash
-  let uid_elt = H.uid 
-  let equal_elt = H.equal
-INCLUDE "ptset_include.ml"
+
+  type 'a node =
+    | Empty
+    | Leaf of elt
+    | Branch of int * int * 'a * 'a
+       
+  module rec HNode : Hcons.S with type data = Node.t = Hcons.Make (Node)
+  and Node : Hashtbl.HashedType  with type t = HNode.t node =
+  struct 
+    type t =  HNode.t node
+    let equal x y = 
+      match x,y with
+       | Empty,Empty -> true
+       | Leaf k1, Leaf k2 -> H.equal k1 k2
+       | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
+           b1 == b2 && i1 == i2 &&
+             (HNode.equal l1 l2) &&
+             (HNode.equal r1 r2) 
+       | _ -> false
+    let hash = function 
+      | Empty -> 0
+      | Leaf i -> HASHINT2(HALF_MAX_INT,H.hash i)
+      | Branch (b,i,l,r) -> HASHINT4(b,i,HNode.hash l, HNode.hash r)
+  end
+ ;;
+                            
+  type t = HNode.t
+  let hash = HNode.hash 
+  let uid = HNode.uid
+    
+  let empty = HNode.make Empty
+    
+  let is_empty s = (HNode.node s) == Empty
+       
+  let branch p m l r = HNode.make (Branch(p,m,l,r))
+
+  let leaf k = HNode.make (Leaf k)
+
+  (* To enforce the invariant that a branch contains two non empty sub-trees *)
+  let branch_ne p m t0 t1 = 
+    if (is_empty t0) then t1
+    else if is_empty t1 then t0 else branch p m t0 t1
+      
+  (********** from here on, only use the smart constructors *************)
+      
+  let zero_bit k m = (k land m) == 0
+    
+  let singleton k = leaf k
+    
+  let is_singleton n = 
+    match HNode.node n with Leaf _ -> true
+      | _ -> false
+         
+  let mem (k:elt) n = 
+    let kid = H.uid k in
+    let rec loop n = match HNode.node n with
+      | Empty -> false
+      | Leaf j -> H.equal k j
+      | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
+    in loop n
+        
+  let rec min_elt n = match HNode.node n with
+    | Empty -> raise Not_found
+    | Leaf k -> k
+    | Branch (_,_,s,_) -> min_elt s
+       
+  let rec max_elt n = match HNode.node n with
+    | Empty -> raise Not_found
+    | Leaf k -> k
+    | Branch (_,_,_,t) -> max_elt t
+       
+  let elements s =
+    let rec elements_aux acc n = match HNode.node n with
+      | Empty -> acc
+      | Leaf k -> k :: acc
+      | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
+    in
+      elements_aux [] s
+       
+  let mask k m  = (k lor (m-1)) land (lnot m)
+    
+  let naive_highest_bit x = 
+    assert (x < 256);
+    let rec loop i = 
+      if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
+    in
+      loop 7
+       
+  let hbit = Array.init 256 naive_highest_bit
+
+
+  let highest_bit x = let n = (x) lsr 24 in 
+  if n != 0 then Array.unsafe_get hbit n lsl 24
+  else let n = (x) lsr 16 in if n != 0 then Array.unsafe_get hbit n lsl 16
+  else let n = (x) lsr 8 in if n != 0 then Array.unsafe_get hbit n lsl 8
+  else Array.unsafe_get hbit (x)
+
+IFDEF WORDIZE64
+THEN
+  let highest_bit64 x =
+    let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
+      else highest_bit x
+END
+
+       
+  let branching_bit p0 p1 = highest_bit (p0 lxor p1)
+    
+  let join p0 t0 p1 t1 =  
+    let m = branching_bit p0 p1  in
+      if zero_bit p0 m then 
+       branch (mask p0 m) m t0 t1
+      else 
+       branch (mask p0 m) m t1 t0
+         
+  let match_prefix k p m = (mask k m) == p
+    
+  let add k t =
+    let kid = H.uid k in
+    let rec ins n = match HNode.node n with
+      | Empty -> leaf k
+      | Leaf j ->  if H.equal j k then n else join kid (leaf k) (H.uid j) n
+      | Branch (p,m,t0,t1)  ->
+         if match_prefix kid p m then
+           if zero_bit kid m then 
+             branch p m (ins t0) t1
+           else
+             branch p m t0 (ins t1)
+         else
+           join kid (leaf k)  p n
+    in
+    ins t
+      
+  let remove k t =
+    let kid = H.uid k in
+    let rec rmv n = match HNode.node n with
+      | Empty -> empty
+      | Leaf j  -> if H.equal k j then empty else n
+      | Branch (p,m,t0,t1) -> 
+         if match_prefix kid p m then
+           if zero_bit kid m then
+             branch_ne p m (rmv t0) t1
+           else
+             branch_ne p m t0 (rmv t1)
+         else
+           n
+    in
+    rmv t
+      
+  (* should run in O(1) thanks to Hash consing *)
+
+  let equal a b = HNode.equal a b 
+
+  let compare a b =  (HNode.uid a) - (HNode.uid b)
+
+  let rec merge s t = 
+    if (equal s t) (* This is cheap thanks to hash-consing *)
+    then s
+    else
+    match HNode.node s, HNode.node t with
+      | Empty, _  -> t
+      | _, Empty  -> s
+      | Leaf k, _ -> add k t
+      | _, Leaf k -> add k s
+      | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
+         if m == n && match_prefix q p m then
+           branch p  m  (merge s0 t0) (merge s1 t1)
+         else if m > n && match_prefix q p m then
+           if zero_bit q m then 
+             branch p m (merge s0 t) s1
+            else 
+             branch p m s0 (merge s1 t)
+         else if m < n && match_prefix p q n then     
+           if zero_bit p n then
+             branch q n (merge s t0) t1
+           else
+             branch q n t0 (merge s t1)
+         else
+           (* The prefixes disagree. *)
+           join p s q t
+              
+       
+              
+              
+  let rec subset s1 s2 = (equal s1 s2) ||
+    match (HNode.node s1,HNode.node s2) with
+      | Empty, _ -> true
+      | _, Empty -> false
+      | Leaf k1, _ -> mem k1 s2
+      | Branch _, Leaf _ -> false
+      | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
+         if m1 == m2 && p1 == p2 then
+           subset l1 l2 && subset r1 r2
+         else if m1 < m2 && match_prefix p1 p2 m2 then
+           if zero_bit p1 m2 then 
+             subset l1 l2 && subset r1 l2
+           else 
+             subset l1 r2 && subset r1 r2
+         else
+           false
+
+             
+  let union s1 s2 = merge s1 s2
+    (* Todo replace with e Memo Module *)
+  module MemUnion = Hashtbl.Make(
+    struct 
+      type set = t 
+      type t = set*set 
+      let equal (x,y) (z,t) = (equal x z)&&(equal y t)
+      let equal a b = equal a b || equal b a
+      let hash (x,y) =   (* commutative hash *)
+       let x = HNode.hash x 
+       and y = HNode.hash y 
+       in
+         if x < y then HASHINT2(x,y) else HASHINT2(y,x)
+    end)
+  let h_mem = MemUnion.create MED_H_SIZE
+
+  let mem_union s1 s2 = 
+    try  MemUnion.find h_mem (s1,s2) 
+    with Not_found ->
+         let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r 
+      
+
+  let rec inter s1 s2 = 
+    if equal s1 s2 
+    then s1
+    else
+      match (HNode.node s1,HNode.node s2) with
+       | Empty, _ -> empty
+       | _, Empty -> empty
+       | Leaf k1, _ -> if mem k1 s2 then s1 else empty
+       | _, Leaf k2 -> if mem k2 s1 then s2 else empty
+       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
+           if m1 == m2 && p1 == p2 then 
+             merge (inter l1 l2)  (inter r1 r2)
+           else if m1 > m2 && match_prefix p2 p1 m1 then
+             inter (if zero_bit p2 m1 then l1 else r1) s2
+           else if m1 < m2 && match_prefix p1 p2 m2 then
+             inter s1 (if zero_bit p1 m2 then l2 else r2)
+           else
+             empty
+
+  let rec diff s1 s2 = 
+    if equal s1 s2 
+    then empty
+    else
+      match (HNode.node s1,HNode.node s2) with
+       | Empty, _ -> empty
+       | _, Empty -> s1
+       | Leaf k1, _ -> if mem k1 s2 then empty else s1
+       | _, Leaf k2 -> remove k2 s1
+       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
+           if m1 == m2 && p1 == p2 then
+             merge (diff l1 l2) (diff r1 r2)
+           else if m1 > m2 && match_prefix p2 p1 m1 then
+             if zero_bit p2 m1 then 
+               merge (diff l1 s2) r1
+             else 
+               merge l1 (diff r1 s2)
+           else if m1 < m2 && match_prefix p1 p2 m2 then
+             if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
+           else
+         s1
+              
+
+(*s All the following operations ([cardinal], [iter], [fold], [for_all],
+    [exists], [filter], [partition], [choose], [elements]) are
+    implemented as for any other kind of binary trees. *)
+
+let rec cardinal n = match HNode.node n with
+  | Empty -> 0
+  | Leaf _ -> 1
+  | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
+
+let rec iter f n = match HNode.node n with
+  | Empty -> ()
+  | Leaf k -> f k
+  | Branch (_,_,t0,t1) -> iter f t0; iter f t1
+      
+let rec fold f s accu = match HNode.node s with
+  | Empty -> accu
+  | Leaf k -> f k accu
+  | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
+
+
+let rec for_all p n = match HNode.node n with
+  | Empty -> true
+  | Leaf k -> p k
+  | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
+
+let rec exists p n = match HNode.node n with
+  | Empty -> false
+  | Leaf k -> p k
+  | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
+
+let rec filter pr n = match HNode.node n with
+  | Empty -> empty
+  | Leaf k -> if pr k then n else empty
+  | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
+
+let partition p s =
+  let rec part (t,f as acc) n = match HNode.node n with
+    | Empty -> acc
+    | Leaf k -> if p k then (add k t, f) else (t, add k f)
+    | Branch (_,_,t0,t1) -> part (part acc t0) t1
+  in
+  part (empty, empty) s
+
+let rec choose n = match HNode.node n with
+  | Empty -> raise Not_found
+  | Leaf k -> k
+  | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
+
+
+let split x s =
+  let coll k (l, b, r) =
+    if k < x then add k l, b, r
+    else if k > x then l, b, add k r
+    else l, true, r 
+  in
+  fold coll s (empty, false, empty)
+
+
+let make l = List.fold_left (fun acc e -> add e acc ) empty l
+(*i*)
+
+(*s Additional functions w.r.t to [Set.S]. *)
+
+let rec intersect s1 s2 = (equal s1 s2) ||
+  match (HNode.node s1,HNode.node s2) with
+  | Empty, _ -> false
+  | _, Empty -> false
+  | Leaf k1, _ -> mem k1 s2
+  | _, Leaf k2 -> mem k2 s1
+  | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
+      if m1 == m2 && p1 == p2 then
+        intersect l1 l2 || intersect r1 r2
+      else if m1 < m2 && match_prefix p2 p1 m1 then
+        intersect (if zero_bit p2 m1 then l1 else r1) s2
+      else if m1 > m2 && match_prefix p1 p2 m2 then
+        intersect s1 (if zero_bit p1 m2 then l2 else r2)
+      else
+        false
+
+
+
+let rec uncons n = match HNode.node n with
+  | Empty -> raise Not_found
+  | Leaf k -> (k,empty)
+  | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
+   
+let from_list l = List.fold_left (fun acc e -> add e acc) empty l
+
+
 end
 
 (* Have to benchmark wheter this whole include stuff is worth it *)
-module I : S with type elt = int = Make ( struct type t = int 
+module Int : S with type elt = int = Make ( struct type t = int 
                                                 type data = t
                                                 external hash : t -> int = "%identity"
                                                 external uid : t -> int = "%identity"