Merge branch 'lucca-tests-bench' into lucca-optim
[tatoo.git] / src / formula.ml
1 (***********************************************************************)
2 (*                                                                     *)
3 (*                               TAToo                                 *)
4 (*                                                                     *)
5 (*                     Kim Nguyen, LRI UMR8623                         *)
6 (*                   Université Paris-Sud & CNRS                       *)
7 (*                                                                     *)
8 (*  Copyright 2010-2012 Université Paris-Sud and Centre National de la *)
9 (*  Recherche Scientifique. All rights reserved.  This file is         *)
10 (*  distributed under the terms of the GNU Lesser General Public       *)
11 (*  License, with the special exception on linking described in file   *)
12 (*  ../LICENSE.                                                        *)
13 (*                                                                     *)
14 (***********************************************************************)
15 INCLUDE "utils.ml"
16
17 open Format
18 type move = [ `Left | `Right | `Self ]
19 type 'hcons expr =
20   | False | True
21   | Or of 'hcons * 'hcons
22   | And of 'hcons * 'hcons
23   | Atom of (move * bool * State.t)
24
25 type 'hcons node = {
26   pos : 'hcons expr;
27   mutable neg : 'hcons;
28   st : StateSet.t * StateSet.t * StateSet.t;
29   size: int; (* Todo check if this is needed *)
30 }
31
32 external hash_const_variant : [> ] -> int = "%identity"
33 external vb : bool -> int = "%identity"
34
35 module rec Node : Hcons.S
36   with type data = Data.t = Hcons.Make (Data)
37   and Data : Hashtbl.HashedType  with type t = Node.t node =
38   struct
39     type t =  Node.t node
40     let equal x y = x.size == y.size &&
41       match x.pos, y.pos with
42       | a,b when a == b -> true
43       | Or(xf1, xf2), Or(yf1, yf2)
44       | And(xf1, xf2), And(yf1,yf2)  -> (xf1 == yf1) && (xf2 == yf2)
45       | Atom(d1, p1, s1), Atom(d2 ,p2 ,s2) -> d1 == d2 && p1 == p2 && s1 == s2
46       | _ -> false
47
48     let hash f =
49       match f.pos with
50       | False -> 0
51       | True -> 1
52       | Or (f1, f2) ->
53         HASHINT3 (PRIME1, Uid.to_int f1.Node.id, Uid.to_int f2.Node.id)
54       | And (f1, f2) ->
55         HASHINT3(PRIME3, Uid.to_int f1.Node.id, Uid.to_int f2.Node.id)
56
57       | Atom(d, p, s) -> HASHINT4(PRIME5, hash_const_variant d,vb p,s)
58   end
59
60 type t = Node.t
61 let hash x = x.Node.key
62 let uid x = x.Node.id
63 let equal = Node.equal
64 let expr f = f.Node.node.pos
65 let st f = f.Node.node.st
66 let size f = f.Node.node.size
67 let compare f1 f2 = compare f1.Node.id  f2.Node.id
68 let prio f =
69   match expr f with
70     | True | False -> 10
71     | Atom _ -> 8
72     | And _ -> 6
73     | Or _ -> 1
74       
75 (* Begin Lucca Hirschi *)
76 module type HcEval =
77 sig
78   type t = StateSet.t*StateSet.t*StateSet.t*Node.t
79   val equal : t -> t -> bool
80   val hash : t -> int
81 end
82   
83 type dStateS = StateSet.t*StateSet.t
84 module type HcInfer =
85 sig
86   type t = dStateS*dStateS*dStateS*Node.t
87   val equal : t -> t -> bool
88   val hash : t -> int
89 end
90     
91 module HcEval : HcEval = struct
92   type t =
93       StateSet.t*StateSet.t*StateSet.t*Node.t
94   let equal (s,l,r,f) (s',l',r',f') = StateSet.equal s s' &&
95     StateSet.equal l l' && StateSet.equal r r' && Node.equal f f'
96   let hash (s,l,r,f) =
97     HASHINT4(StateSet.hash s, StateSet.hash l, StateSet.hash r, Node.hash f)
98 end
99   
100 let dequal (x,y) (x',y') = StateSet.equal x x' && StateSet.equal y y'
101 let dhash (x,y) = HASHINT2(StateSet.hash x, StateSet.hash y)
102 module HcInfer : HcInfer = struct
103   type t = dStateS*dStateS*dStateS*Node.t
104   let equal (s,l,r,f) (s',l',r',f') = dequal s s' &&
105     dequal l l' && dequal r r' && Node.equal f f'
106   let hash (s,l,r,f) =
107     HASHINT4(dhash s, dhash l, dhash r, Node.hash f)
108 end
109
110 module HashEval = Hashtbl.Make(HcEval)
111 module HashInfer = Hashtbl.Make(HcInfer)
112 type hcEval = bool Hashtbl.Make(HcEval).t
113 type hcInfer = bool Hashtbl.Make(HcInfer).t
114
115 let rec eval_form (q,qf,qn) f hashEval =
116 try HashEval.find hashEval (q,qf,qn,f)
117 with _ ->
118   let res = match expr f with
119     | False -> false
120     | True -> true 
121     | And(f1,f2) -> eval_form (q,qf,qn) f1 hashEval &&
122       eval_form (q,qf,qn) f2 hashEval
123     | Or(f1,f2) -> eval_form (q,qf,qn) f1 hashEval ||
124       eval_form (q,qf,qn) f2 hashEval
125     | Atom(dir, b, s) -> 
126       let set = match dir with
127         |`Left -> qf | `Right -> qn | `Self -> q in
128       if b then StateSet.mem s set
129       else not (StateSet.mem s set) in
130   HashEval.add hashEval (q,qf,qn,f) res;
131   res
132
133 let rec infer_form sq sqf sqn f hashInfer =
134 try HashInfer.find hashInfer (sq,sqf,sqn,f)
135 with _ ->
136   let res = match expr f with
137     | False -> false
138     | True -> true
139     | And(f1,f2) -> infer_form sq sqf sqn f1 hashInfer &&
140       infer_form sq sqf sqn f2 hashInfer
141     | Or(f1,f2) -> infer_form sq sqf sqn f1 hashInfer ||
142       infer_form sq sqf sqn f2 hashInfer
143     | Atom(dir, b, s) -> 
144       let setq, setr = match dir with
145         | `Left -> sqf | `Right -> sqn | `Self -> sq in
146     (* WG: WE SUPPOSE THAT Q^r and Q^q are disjoint ! *)
147       let mem =  StateSet.mem s setq || StateSet.mem s setr in
148       if b then mem else not mem in
149   HashInfer.add hashInfer (sq,sqf,sqn,f) res;
150   res   
151 (* End *)
152       
153 let rec print ?(parent=false) ppf f =
154   if parent then fprintf ppf "(";
155   let _ = match expr f with
156     | True -> fprintf ppf "%s" Pretty.top
157     | False -> fprintf ppf "%s" Pretty.bottom
158     | And(f1,f2) ->
159       print ~parent:(prio f > prio f1) ppf f1;
160       fprintf ppf " %s "  Pretty.wedge;
161       print ~parent:(prio f > prio f2) ppf f2;
162     | Or(f1,f2) ->
163       (print ppf f1);
164       fprintf ppf " %s " Pretty.vee;
165       (print ppf f2);
166     | Atom(dir, b, s) ->
167       let _ = flush_str_formatter() in
168       let fmt = str_formatter in
169         let a_str, d_str =
170           match  dir with
171           | `Left ->  Pretty.down_arrow, Pretty.subscript 1
172           | `Right -> Pretty.down_arrow, Pretty.subscript 2
173           | `Self -> Pretty.down_arrow, Pretty.subscript 0
174         in
175         fprintf fmt "%s%s" a_str d_str;
176         State.print fmt s;
177         let str = flush_str_formatter() in
178         if b then fprintf ppf "%s" str
179         else Pretty.pp_overline ppf str
180   in
181     if parent then fprintf ppf ")"
182
183 let print ppf f =  print ~parent:false ppf f
184
185 let is_true f = (expr f) == True
186 let is_false f = (expr f) == False
187
188
189 let cons pos neg s1 s2 size1 size2 =
190   let nnode = Node.make { pos = neg; neg = (Obj.magic 0); st = s2; size = size2 } in
191   let pnode = Node.make { pos = pos; neg = nnode ; st = s1; size = size1 } in
192     (Node.node nnode).neg <- pnode; (* works because the neg field isn't taken into
193                                        account for hashing ! *)
194     pnode,nnode
195
196
197 let empty_triple = StateSet.empty, StateSet.empty, StateSet.empty
198 let true_,false_ = cons True False empty_triple empty_triple 0 0
199 let atom_ d p s =
200   let si = StateSet.singleton s in
201   let ss = match d with
202     | `Left -> StateSet.empty, si, StateSet.empty
203     | `Right -> StateSet.empty, StateSet.empty, si
204     | `Self -> si, StateSet.empty, StateSet.empty
205   in fst (cons (Atom(d,p,s)) (Atom(d,not p,s)) ss ss 1 1)
206
207 let not_ f = f.Node.node.neg
208
209 let union_triple (s1,l1,r1) (s2,l2, r2) =
210   StateSet.union s1 s2,
211   StateSet.union l1 l2,
212   StateSet.union r1 r2
213
214 let merge_states f1 f2 =
215   let sp =
216     union_triple (st f1) (st f2)
217   and sn =
218     union_triple (st (not_ f1)) (st (not_ f2))
219   in
220     sp,sn
221
222 let order f1 f2 = if uid f1  < uid f2 then f2,f1 else f1,f2
223
224 let or_ f1 f2 =
225   (* Tautologies: x|x, x|not(x) *)
226
227   if equal f1 f2 then f1
228   else if equal f1 (not_ f2) then true_
229
230   (* simplification *)
231   else if is_true f1 || is_true f2 then true_
232   else if is_false f1 && is_false f2 then false_
233   else if is_false f1 then f2
234   else if is_false f2 then f1
235
236   (* commutativity of | *)
237   else
238     let f1, f2 = order f1 f2 in
239     let psize = (size f1) + (size f2) in
240     let nsize = (size (not_ f1)) + (size (not_ f2)) in
241     let sp, sn = merge_states f1 f2 in
242       fst (cons (Or(f1,f2)) (And(not_ f1, not_ f2)) sp sn psize nsize)
243
244
245 let and_ f1 f2 =
246   not_ (or_ (not_ f1) (not_ f2))
247
248
249 let of_bool = function true -> true_ | false -> false_
250
251
252 module Infix = struct
253   let ( +| ) f1 f2 = or_ f1 f2
254
255   let ( *& ) f1 f2 = and_ f1 f2
256
257   let ( *+ ) d s = atom_ d true s
258   let ( *- ) d s = atom_ d false s
259 end