Implement automaton simplification.
[tatoo.git] / src / ata.ml
index 190c4c7..7b0dff4 100644 (file)
@@ -176,6 +176,15 @@ struct
     let table = get_states_by_move phi in
     Move.fold (fun _ s acc -> StateSet.union s acc) table StateSet.empty
 
+  let rec rename_state phi qfrom qto =
+    let open Boolean in
+    match expr phi with
+      False | True -> phi
+    | Or (phi1, phi2) -> or_ (rename_state phi1 qfrom qto)  (rename_state phi2 qfrom qto)
+    | And (phi1, phi2) ->  and_ (rename_state phi1 qfrom qto)  (rename_state phi2 qfrom qto)
+    | Atom ({ Atom.node = Move(m, q); }, b) when q == qfrom ->
+      let atm = mk_move m qto in if b then atm else not_ atm
+    | Atom _ -> phi
 end
 
 module Transition =
@@ -184,32 +193,28 @@ module Transition =
   type t = State.t * QNameSet.t * Formula.t
   let equal (a, b, c) (d, e, f) =
     a == d && b == e && c == f
-  let hash (a, b, c) =
-    HASHINT4 (PRIME1, a, ((QNameSet.uid b) :> int), ((Formula.uid c) :> int))
+  let hash ((a, b, c) : t) =
+    HASHINT4 (PRIME1, ((a) :> int), ((QNameSet.uid b) :> int), ((Formula.uid c) :> int))
 end)
     let print ppf t =
       let q, l, f = t.node in
       fprintf ppf "%a, %a %s %a"
         State.print q
         QNameSet.print l
-        Pretty.double_right_arrow
+        Pretty.left_arrow
         Formula.print f
   end
 
 
 module TransList : sig
   include Hlist.S with type elt = Transition.t
-  val print : Format.formatter -> ?sep:string -> t -> unit
+  val print : ?sep:string -> Format.formatter -> t -> unit
 end =
   struct
     include Hlist.Make(Transition)
-    let print ppf ?(sep="\n") l =
+    let print ?(sep="\n") ppf l =
       iter (fun t ->
-        let q, lab, f = Transition.node t in
-        fprintf ppf "%a, %a → %a%s"
-          State.print q
-          QNameSet.print lab
-          Formula.print f sep) l
+          fprintf ppf "%a%s" Transition.print t sep) l
   end
 
 
@@ -276,15 +281,15 @@ let print fmt a =
     ) ([], 0, 0) sorted_trs
   in
   let line = Pretty.line (max_all + max_pre + 6) in
-  let prev_q = ref State.dummy in
+  let prev_q = ref State.dummy_state in
   fprintf fmt "%s@\n" line;
   List.iter (fun (q, s1, s2, s3) ->
-    if !prev_q != q && !prev_q != State.dummy then fprintf fmt "%s@\n"  line;
+    if !prev_q != q && !prev_q != State.dummy_state then fprintf fmt "%s@\n"  line;
     prev_q := q;
     fprintf fmt "%s, %s" s1 s2;
     fprintf fmt "%s"
       (Pretty.padding (max_pre - Pretty.length s1 - Pretty.length s2));
-    fprintf fmt " %s  %s@\n" Pretty.right_arrow s3;
+    fprintf fmt " %s  %s@\n" Pretty.left_arrow s3;
   ) strs_strings;
   fprintf fmt "%s@\n" line
 
@@ -387,7 +392,7 @@ let normalize_negations auto =
               with
                 Not_found ->
               (* create a new state and add it to the todo queue *)
-                  let nq = State.make () in
+                  let nq = State.next () in
                   auto.states <- StateSet.add nq auto.states;
                   Hashtbl.add memo_state (q, false) nq;
                   Queue.add (q, false) todo; nq
@@ -409,7 +414,7 @@ let normalize_negations auto =
         with
           Not_found ->
             let nq = if b then q else
-                let nq = State.make () in
+                let nq = State.next () in
                 auto.states <- StateSet.add nq auto.states;
                 nq
             in
@@ -421,6 +426,47 @@ let normalize_negations auto =
   done;
   cleanup_states auto
 
+exception Found of State.t * State.t
+
+let simplify_epsilon auto =
+  let rec loop old_states =
+    if old_states != auto.states then begin
+      let old_states = auto.states in
+      try
+        Hashtbl.iter
+          (fun qfrom v -> match v with
+               [ (labels, phi) ] ->
+               if labels == QNameSet.any then begin
+                 match (Formula.expr phi) with
+                   Boolean.Atom ( {Atom.node = Move(`Stay, qto); _ }, true) -> raise (Found (qfrom, qto))
+                 | _ -> ()
+               end
+             | _ -> ()
+          ) auto.transitions
+      with Found (qfrom, qto) ->
+        Hashtbl.remove auto.transitions qfrom;
+        let new_trans = Hashtbl.fold (fun q tr_lst acc ->
+            let new_tr_lst =
+              List.map (fun (lab, phi) ->
+                  (lab, Formula.rename_state phi qfrom qto))
+                tr_lst
+            in
+            (q, new_tr_lst) :: acc) auto.transitions []
+        in
+        Hashtbl.reset auto.transitions;
+        List.iter (fun (q, l) -> Hashtbl.add auto.transitions q l) new_trans;
+        auto.states <- StateSet.remove qfrom auto.states;
+        if (StateSet.mem qfrom auto.starting_states) then
+          auto.starting_states <- StateSet.add qto (StateSet.remove qfrom auto.starting_states);
+        if (StateSet.mem qfrom auto.selecting_states) then
+          auto.selecting_states <- StateSet.add qto (StateSet.remove qfrom auto.selecting_states);
+        loop old_states
+    end
+  in
+  loop StateSet.empty
+
+
+
 (* [compute_dependencies auto] returns a hash table storing for each
    states [q] a Move.table containing the set of states on which [q]
    depends (loosely). [q] depends on [q'] if there is a transition
@@ -500,8 +546,8 @@ let compute_rank auto =
   done;
   let by_rank = Hashtbl.create 17 in
   List.iter (fun (r,s) ->
-    let set = try Hashtbl.find by_rank r with Not_found -> StateSet.empty in
-    Hashtbl.replace by_rank r (StateSet.union s set)) !rank_list;
+      let set = try Hashtbl.find by_rank r with Not_found -> StateSet.empty in
+      Hashtbl.replace by_rank r (StateSet.union s set)) !rank_list;
   auto.ranked_states <-
     Array.init (Hashtbl.length by_rank) (fun i -> Hashtbl.find by_rank i)
 
@@ -554,9 +600,12 @@ module Builder =
       in
       Hashtbl.replace a.transitions q ntrs
 
+
+
     let finalize a =
       complete_transitions a;
       normalize_negations a;
+      simplify_epsilon a;
       compute_rank a;
       a
   end
@@ -597,7 +646,7 @@ let rename_states mapper a =
 let copy a =
   let mapper = Hashtbl.create MED_H_SIZE in
   let () =
-    StateSet.iter (fun q -> Hashtbl.add mapper q (State.make())) a.states
+    StateSet.iter (fun q -> Hashtbl.add mapper q (State.next())) a.states
   in
   rename_states mapper a
 
@@ -657,7 +706,7 @@ let link a1 a2 q link_phi =
 let union a1 a2 =
   let a1 = copy a1 in
   let a2 = copy a2 in
-  let q = State.make () in
+  let q = State.next () in
   let link_phi =
     StateSet.fold
       (fun q phi -> Formula.(or_ (stay q) phi))
@@ -669,7 +718,7 @@ let union a1 a2 =
 let inter a1 a2 =
   let a1 = copy a1 in
   let a2 = copy a2 in
-  let q = State.make () in
+  let q = State.next () in
   let link_phi =
     StateSet.fold
       (fun q phi -> Formula.(and_ (stay q) phi))
@@ -680,7 +729,7 @@ let inter a1 a2 =
 
 let neg a =
   let a = copy a in
-  let q = State.make () in
+  let q = State.next () in
   let link_phi =
     StateSet.fold
       (fun q phi -> Formula.(and_ (not_(stay q)) phi))