Add text-attribute tags to the star tagset.
[SXSI/xpathcomp.git] / src / ptset.ml
1 (***************************************************************************)
2 (* Implementation for sets of positive integers implemented as deeply hash-*)
3 (* consed Patricia trees. Provide fast set operations, fast membership as  *)
4 (* well as fast min and max elements. Hash consing provides O(1) equality  *)
5 (* checking                                                                *)
6 (*                                                                         *)
7 (***************************************************************************)
8 INCLUDE "utils.ml"
9 module type S =
10 sig
11     type elt
12
13   type 'a node
14   module rec Node : sig
15     include  Hcons.S with type data = Data.t
16   end
17   and Data : sig
18     include
19       Hashtbl.HashedType with type t = Node.t node
20   end
21   type data = Data.t
22   type t = Node.t
23
24
25   val empty : t
26   val is_empty : t -> bool
27   val mem : elt -> t -> bool
28   val add : elt -> t -> t
29   val singleton : elt -> t
30   val remove : elt -> t -> t
31   val union : t -> t -> t
32   val inter : t -> t -> t
33   val diff : t -> t -> t
34   val compare : t -> t -> int
35   val equal : t -> t -> bool
36   val subset : t -> t -> bool
37   val iter : (elt -> unit) -> t -> unit
38   val fold : (elt -> 'a -> 'a) -> t -> 'a -> 'a
39   val for_all : (elt -> bool) -> t -> bool
40   val exists : (elt -> bool) -> t -> bool
41   val filter : (elt -> bool) -> t -> t
42   val partition : (elt -> bool) -> t -> t * t
43   val cardinal : t -> int
44   val elements : t -> elt list
45   val min_elt : t -> elt
46   val max_elt : t -> elt
47   val choose : t -> elt
48   val split : elt -> t -> t * bool * t
49
50   val intersect : t -> t -> bool
51   val is_singleton : t -> bool
52   val mem_union : t -> t -> t
53   val hash : t -> int
54   val uid : t -> Uid.t
55   val uncons : t -> elt*t
56   val from_list : elt list -> t
57   val make : data -> t
58   val node : t -> data
59   val stats : unit -> unit
60 end
61
62 module Make ( H : Hcons.SA ) : S with type elt = H.t =
63 struct
64   type elt = H.t
65   type 'a node =
66     | Empty
67     | Leaf of elt
68     | Branch of int * int * 'a * 'a
69
70   module rec Node : Hcons.S with type data = Data.t = Hcons.Make (Data)
71   and Data : Hashtbl.HashedType  with type t = Node.t node =
72   struct
73     type t =  Node.t node
74     let equal x y =
75       match x,y with
76         | Empty,Empty -> true
77         | Leaf k1, Leaf k2 ->  k1 == k2
78         | Branch(b1,i1,l1,r1),Branch(b2,i2,l2,r2) ->
79             b1 == b2 && i1 == i2 &&
80               (Node.equal l1 l2) &&
81               (Node.equal r1 r2)
82         | _ -> false
83     let hash = function
84       | Empty -> 0
85       | Leaf i -> HASHINT2(HALF_MAX_INT,Uid.to_int (H.uid i))
86       | Branch (b,i,l,r) -> HASHINT4(b,i,Uid.to_int l.Node.id, Uid.to_int r.Node.id)
87   end
88
89   type data = Data.t
90   type t = Node.t
91   let stats = Node.stats
92   let hash = Node.hash
93   let uid = Node.uid
94   let make = Node.make
95   let node _ = failwith "node"
96   let empty = Node.make Empty
97
98   let is_empty s = (Node.node s) == Empty
99
100   let branch p m l r = Node.make (Branch(p,m,l,r))
101
102   let leaf k = Node.make (Leaf k)
103
104   (* To enforce the invariant that a branch contains two non empty sub-trees *)
105   let branch_ne p m t0 t1 =
106     if (is_empty t0) then t1
107     else if is_empty t1 then t0 else branch p m t0 t1
108
109   (********** from here on, only use the smart constructors *************)
110
111   let zero_bit k m = (k land m) == 0
112
113   let singleton k = leaf k
114
115   let is_singleton n =
116     match Node.node n with Leaf _ -> true
117       | _ -> false
118
119   let mem (k:elt) n =
120     let kid = Uid.to_int (H.uid k) in
121     let rec loop n = match Node.node n with
122       | Empty -> false
123       | Leaf j ->  k == j
124       | Branch (p, _, l, r) -> if kid <= p then loop l else loop r
125     in loop n
126
127   let rec min_elt n = match Node.node n with
128     | Empty -> raise Not_found
129     | Leaf k -> k
130     | Branch (_,_,s,_) -> min_elt s
131
132   let rec max_elt n = match Node.node n with
133     | Empty -> raise Not_found
134     | Leaf k -> k
135     | Branch (_,_,_,t) -> max_elt t
136
137   let elements s =
138     let rec elements_aux acc n = match Node.node n with
139       | Empty -> acc
140       | Leaf k -> k :: acc
141       | Branch (_,_,l,r) -> elements_aux (elements_aux acc r) l
142     in
143       elements_aux [] s
144
145   let mask k m  = (k lor (m-1)) land (lnot m)
146
147   let naive_highest_bit x =
148     assert (x < 256);
149     let rec loop i =
150       if i = 0 then 1 else if x lsr i = 1 then 1 lsl i else loop (i-1)
151     in
152       loop 7
153
154   let hbit = Array.init 256 naive_highest_bit
155
156
157   let highest_bit x =
158     try
159       let n = (x) lsr 24 in
160         if n != 0 then  hbit.(n) lsl 24
161         else let n = (x) lsr 16 in if n != 0 then hbit.(n) lsl 16
162         else let n = (x) lsr 8 in if n != 0 then hbit.(n) lsl 8
163         else hbit.(x)
164     with
165         _ -> raise (Invalid_argument ("highest_bit " ^ (string_of_int x)))
166
167   let highest_bit64 x =
168     let n = x lsr 32 in if n != 0 then highest_bit n lsl 32
169       else highest_bit x
170
171   let branching_bit p0 p1 = highest_bit64 (p0 lxor p1)
172
173   let join p0 t0 p1 t1 =
174     let m = branching_bit p0 p1  in
175       if zero_bit p0 m then
176         branch (mask p0 m) m t0 t1
177       else
178         branch (mask p0 m) m t1 t0
179
180   let match_prefix k p m = (mask k m) == p
181
182   let add k t =
183     let kid = Uid.to_int (H.uid k) in
184       assert (kid >=0);
185     let rec ins n = match Node.node n with
186       | Empty -> leaf k
187       | Leaf j ->  if j == k then n else join kid (leaf k) (Uid.to_int (H.uid j)) n
188       | Branch (p,m,t0,t1)  ->
189           if match_prefix kid p m then
190             if zero_bit kid m then
191               branch p m (ins t0) t1
192             else
193               branch p m t0 (ins t1)
194           else
195             join kid (leaf k)  p n
196     in
197     ins t
198
199   let remove k t =
200     let kid = Uid.to_int(H.uid k) in
201     let rec rmv n = match Node.node n with
202       | Empty -> empty
203       | Leaf j  -> if  k == j then empty else n
204       | Branch (p,m,t0,t1) ->
205           if match_prefix kid p m then
206             if zero_bit kid m then
207               branch_ne p m (rmv t0) t1
208             else
209               branch_ne p m t0 (rmv t1)
210           else
211             n
212     in
213     rmv t
214
215   (* should run in O(1) thanks to Hash consing *)
216
217   let equal a b = Node.equal a b
218
219   let compare a b =  (Uid.to_int (Node.uid a)) - (Uid.to_int (Node.uid b))
220
221   let rec merge s t =
222     if (equal s t) (* This is cheap thanks to hash-consing *)
223     then s
224     else
225     match Node.node s, Node.node t with
226       | Empty, _  -> t
227       | _, Empty  -> s
228       | Leaf k, _ -> add k t
229       | _, Leaf k -> add k s
230       | Branch (p,m,s0,s1), Branch (q,n,t0,t1) ->
231           if m == n && match_prefix q p m then
232             branch p  m  (merge s0 t0) (merge s1 t1)
233           else if m > n && match_prefix q p m then
234             if zero_bit q m then
235               branch p m (merge s0 t) s1
236             else
237               branch p m s0 (merge s1 t)
238           else if m < n && match_prefix p q n then
239             if zero_bit p n then
240               branch q n (merge s t0) t1
241             else
242               branch q n t0 (merge s t1)
243           else
244             (* The prefixes disagree. *)
245             join p s q t
246
247
248
249
250   let rec subset s1 s2 = (equal s1 s2) ||
251     match (Node.node s1,Node.node s2) with
252       | Empty, _ -> true
253       | _, Empty -> false
254       | Leaf k1, _ -> mem k1 s2
255       | Branch _, Leaf _ -> false
256       | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
257           if m1 == m2 && p1 == p2 then
258             subset l1 l2 && subset r1 r2
259           else if m1 < m2 && match_prefix p1 p2 m2 then
260             if zero_bit p1 m2 then
261               subset l1 l2 && subset r1 l2
262             else
263               subset l1 r2 && subset r1 r2
264           else
265             false
266
267
268   let union s1 s2 = merge s1 s2
269     (* Todo replace with e Memo Module *)
270   module MemUnion = Hashtbl.Make(
271     struct
272       type set = t
273       type t = set*set
274       let equal (x,y) (z,t) = (equal x z)&&(equal y t)
275       let equal a b = equal a b || equal b a
276       let hash (x,y) =   (* commutative hash *)
277         let x = Node.hash x
278         and y = Node.hash y
279         in
280           if x < y then HASHINT2(x,y) else HASHINT2(y,x)
281     end)
282   let h_mem = MemUnion.create MED_H_SIZE
283
284   let mem_union s1 s2 =
285     try  MemUnion.find h_mem (s1,s2)
286     with Not_found ->
287           let r = merge s1 s2 in MemUnion.add h_mem (s1,s2) r;r
288
289
290   let rec inter s1 s2 =
291     if equal s1 s2
292     then s1
293     else
294       match (Node.node s1,Node.node s2) with
295         | Empty, _ -> empty
296         | _, Empty -> empty
297         | Leaf k1, _ -> if mem k1 s2 then s1 else empty
298         | _, Leaf k2 -> if mem k2 s1 then s2 else empty
299         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
300             if m1 == m2 && p1 == p2 then
301               merge (inter l1 l2)  (inter r1 r2)
302             else if m1 > m2 && match_prefix p2 p1 m1 then
303               inter (if zero_bit p2 m1 then l1 else r1) s2
304             else if m1 < m2 && match_prefix p1 p2 m2 then
305               inter s1 (if zero_bit p1 m2 then l2 else r2)
306             else
307               empty
308
309   let rec diff s1 s2 =
310     if equal s1 s2
311     then empty
312     else
313       match (Node.node s1,Node.node s2) with
314         | Empty, _ -> empty
315         | _, Empty -> s1
316         | Leaf k1, _ -> if mem k1 s2 then empty else s1
317         | _, Leaf k2 -> remove k2 s1
318         | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
319             if m1 == m2 && p1 == p2 then
320               merge (diff l1 l2) (diff r1 r2)
321             else if m1 > m2 && match_prefix p2 p1 m1 then
322               if zero_bit p2 m1 then
323                 merge (diff l1 s2) r1
324               else
325                 merge l1 (diff r1 s2)
326             else if m1 < m2 && match_prefix p1 p2 m2 then
327               if zero_bit p1 m2 then diff s1 l2 else diff s1 r2
328             else
329           s1
330
331
332 (*s All the following operations ([cardinal], [iter], [fold], [for_all],
333     [exists], [filter], [partition], [choose], [elements]) are
334     implemented as for any other kind of binary trees. *)
335
336 let rec cardinal n = match Node.node n with
337   | Empty -> 0
338   | Leaf _ -> 1
339   | Branch (_,_,t0,t1) -> cardinal t0 + cardinal t1
340
341 let rec iter f n = match Node.node n with
342   | Empty -> ()
343   | Leaf k -> f k
344   | Branch (_,_,t0,t1) -> iter f t0; iter f t1
345
346 let rec fold f s accu = match Node.node s with
347   | Empty -> accu
348   | Leaf k -> f k accu
349   | Branch (_,_,t0,t1) -> fold f t0 (fold f t1 accu)
350
351
352 let rec for_all p n = match Node.node n with
353   | Empty -> true
354   | Leaf k -> p k
355   | Branch (_,_,t0,t1) -> for_all p t0 && for_all p t1
356
357 let rec exists p n = match Node.node n with
358   | Empty -> false
359   | Leaf k -> p k
360   | Branch (_,_,t0,t1) -> exists p t0 || exists p t1
361
362 let rec filter pr n = match Node.node n with
363   | Empty -> empty
364   | Leaf k -> if pr k then n else empty
365   | Branch (p,m,t0,t1) -> branch_ne p m (filter pr t0) (filter pr t1)
366
367 let partition p s =
368   let rec part (t,f as acc) n = match Node.node n with
369     | Empty -> acc
370     | Leaf k -> if p k then (add k t, f) else (t, add k f)
371     | Branch (_,_,t0,t1) -> part (part acc t0) t1
372   in
373   part (empty, empty) s
374
375 let rec choose n = match Node.node n with
376   | Empty -> raise Not_found
377   | Leaf k -> k
378   | Branch (_, _,t0,_) -> choose t0   (* we know that [t0] is non-empty *)
379
380
381 let split x s =
382   let coll k (l, b, r) =
383     if k < x then add k l, b, r
384     else if k > x then l, b, add k r
385     else l, true, r
386   in
387   fold coll s (empty, false, empty)
388
389 (*s Additional functions w.r.t to [Set.S]. *)
390
391 let rec intersect s1 s2 = (equal s1 s2) ||
392   match (Node.node s1,Node.node s2) with
393   | Empty, _ -> false
394   | _, Empty -> false
395   | Leaf k1, _ -> mem k1 s2
396   | _, Leaf k2 -> mem k2 s1
397   | Branch (p1,m1,l1,r1), Branch (p2,m2,l2,r2) ->
398       if m1 == m2 && p1 == p2 then
399         intersect l1 l2 || intersect r1 r2
400       else if m1 < m2 && match_prefix p2 p1 m1 then
401         intersect (if zero_bit p2 m1 then l1 else r1) s2
402       else if m1 > m2 && match_prefix p1 p2 m2 then
403         intersect s1 (if zero_bit p1 m2 then l2 else r2)
404       else
405         false
406
407
408
409 let rec uncons n = match Node.node n with
410   | Empty -> raise Not_found
411   | Leaf k -> (k,empty)
412   | Branch (p,m,s,t) -> let h,ns = uncons s in h,branch_ne p m ns t
413
414 let from_list l = List.fold_left (fun acc e -> add e acc) empty l
415
416
417 end
418
419 module Int : sig
420   include S with type elt = int
421   val print : Format.formatter -> t -> unit
422 end
423   =
424 struct
425   include Make ( struct type t = int
426                         type data = t
427                         external hash : t -> int = "%identity"
428                         external uid : t -> Uid.t = "%identity"
429                         external equal : t -> t -> bool = "%eq"
430                         external make : t -> int = "%identity"
431                         external node : t -> int = "%identity"
432                         external stats : unit -> unit = "%identity"
433                  end
434                )
435   let print ppf s =
436     Format.pp_print_string ppf "{ ";
437     iter (fun i -> Format.fprintf ppf "%i " i) s;
438     Format.pp_print_string ppf "}";
439     Format.pp_print_flush ppf ()
440  end