Fix typo in debugging message.
[SXSI/xpathcomp.git] / src / ata.ml
index 86a12eb..591fab1 100644 (file)
@@ -1,31 +1,32 @@
 INCLUDE "debug.ml"
 INCLUDE "utils.ml"
+INCLUDE "log.ml"
+
 open Format
 
 type t = {
-    id : int;
-    states : StateSet.t;
-    init : StateSet.t;
-    last : State.t;
-    (* Transitions of the Alternating automaton *)
-    trans : (State.t, (TagSet.t * Transition.t) list) Hashtbl.t;
-    marking_states : StateSet.t;
-    topdown_marking_states : StateSet.t;
-    bottom_states : StateSet.t;
-    true_states : StateSet.t;
+  id : int;
+  states : StateSet.t;
+  init : StateSet.t;
+  last : State.t;
+  (* Transitions of the Alternating automaton *)
+  trans : (State.t, (TagSet.t * Transition.t) list) Hashtbl.t;
+  marking_states : StateSet.t;
+  topdown_marking_states : StateSet.t;
+  bottom_states : StateSet.t;
+  true_states : StateSet.t
 }
 
-
 let print ppf a =
   fprintf ppf
-"Automaton (%i) :
-States %a
-Initial states: %a
-Marking states: %a
-Topdown marking states: %a
-Bottom states: %a
-True states: %a
-Alternating transitions\n"
+    "Automaton (%i) :@\n\
+     States %a@\n\
+     Initial states: %a@\n\
+     Marking states: %a@\n\
+     Topdown marking states: %a@\n\
+     Bottom states: %a@\n\
+     True states: %a@\n\
+     Alternating transitions:@\n"
     a.id
     StateSet.print a.states
     StateSet.print a.init
@@ -35,193 +36,251 @@ Alternating transitions\n"
     StateSet.print a.true_states;
   let trs =
     Hashtbl.fold
-      (fun _ t acc ->
-        List.fold_left (fun acc (_, tr) -> tr::acc) acc t) a.trans []
+      (fun _ t acc -> List.fold_left (fun acc (_, tr) -> tr::acc) acc t)
+      a.trans
+      []
   in
   let sorted_trs = List.stable_sort Transition.compare trs in
   let strings = Transition.format_list sorted_trs in
-    match strings with
-       [] -> ()
-      | line::_ ->
-         let sline = Pretty.line (Pretty.length line) in
-           fprintf ppf "%s\n%!" sline;
-           List.iter (fun s -> fprintf ppf "%s\n%!" s) strings;
-           fprintf ppf "%s\n%!" sline
-
-type jump_kind = NIL
-            | NODE
-            | STAR
-           | JUMP_ONE of Ptset.Int.t
-           | JUMP_MANY of Ptset.Int.t
-           | CAPTURE_MANY of Ptset.Int.t
+  match strings with
+  | [] -> ()
+  | line::_ ->
+    let sline = Pretty.line (Pretty.length line) in
+    fprintf ppf "%s@\n" sline;
+    List.iter (fun s -> fprintf ppf "%s@\n" s) strings;
+    fprintf ppf "%s@\n" sline
+
+
+type jump_kind =
+    NIL
+  | NODE
+  | STAR
+  | JUMP_ONE of Ptset.Int.t
+  | JUMP_MANY of Ptset.Int.t
+  | CAPTURE_MANY of Ptset.Int.t
+
 
 let print_kind fmt k =
-  let () =
+  begin
     match k with
-      | NIL -> fprintf fmt "NIL"
-      | STAR -> fprintf fmt "STAR"
-      | NODE -> fprintf fmt "NODE"
-      | JUMP_ONE(t) -> let t = TagSet.inj_positive t in
-         fprintf fmt "JUMP_ONE(%a)" TagSet.print t
-      | JUMP_MANY(t) -> let t = TagSet.inj_positive t in
-        fprintf fmt "JUMP_MANY(%a)" TagSet.print t
-      | CAPTURE_MANY(t) ->
-         let t = TagSet.inj_positive t in
-           fprintf fmt "JUMP_MANY(%a)" TagSet.print t
-
-  in fprintf fmt "%!"
+    | NIL -> fprintf fmt "NIL"
+    | STAR -> fprintf fmt "STAR"
+    | NODE -> fprintf fmt "NODE"
+
+    | JUMP_ONE(t) ->
+      let t = TagSet.inj_positive t in
+      fprintf fmt "JUMP_ONE(%a)" TagSet.print t
+
+    | JUMP_MANY(t) ->
+      let t = TagSet.inj_positive t in
+      fprintf fmt "JUMP_MANY(%a)" TagSet.print t
+
+    | CAPTURE_MANY(t) ->
+      let t = TagSet.inj_positive t in
+      fprintf fmt "JUMP_MANY(%a)" TagSet.print t
+  end;
+  fprintf fmt "%!"
+
+let pr_trans fmt (ts, (l, r, m)) =
+  Format.fprintf fmt "%a %s %a %a"
+    TagSet.print ts
+    (if m then Pretty.double_right_arrow else Pretty.right_arrow)
+    StateSet.print l
+    StateSet.print r
+
 let compute_jump auto tree states l marking =
   let rel_trans, skip_trans =
     List.fold_left
-      (fun (acc_rel, acc_skip) ((ts, (l,r,m)) as tr) ->
-        if not m &&
-          ((l == states && r == states)
-           || (l == StateSet.empty && states == r)
-           || (l == states && r = StateSet.empty)
-           || (l == StateSet.empty && r = StateSet.empty))
-        then (acc_rel, tr::acc_skip)
-        else (tr::acc_rel, acc_skip))
+      (fun (acc_rel, acc_skip) ((ts, (l,r,marking)) as tr) ->
+        if not marking &&
+          ((l == states && r == states)
+           || (l == StateSet.empty && states == r)
+           || (l == states && r = StateSet.empty)
+           || (l == StateSet.empty && r = StateSet.empty))
+        then (acc_rel, tr::acc_skip)
+        else (tr::acc_rel, acc_skip))
       ([],[]) l
   in
   let rel_labels = List.fold_left
     (fun acc (ts, _ ) ->
-       Ptset.Int.union (TagSet.positive ts) acc)
+      Ptset.Int.union (TagSet.positive ts) acc)
     Ptset.Int.empty
     rel_trans
   in
-    if Ptset.Int.is_empty rel_labels then NIL
-    else
-      match skip_trans with
-         [ (_, (l, r, _) ) ] when l == r && l == states ->
-           begin
-             match rel_trans with
-               | [ (_, (l, r, m) ) ]
-                   when (rel_labels == (Tree.element_tags tree) ||
-                           Ptset.Int.is_singleton rel_labels)
-                     && (StateSet.diff l auto.true_states) == states && m
-                           -> CAPTURE_MANY(rel_labels)
-               | _ ->
-                   JUMP_MANY(rel_labels)
-           end
-       | [ (_, (l, r, _) ) ] when l == StateSet.empty -> JUMP_ONE(rel_labels)
-       | _ -> if Ptset.Int.mem Tag.pcdata rel_labels then
-           let () = D_TRACE_(Format.eprintf ">>> Computed rel_labels: %a\n%!" TagSet.print (TagSet.inj_positive rel_labels)) in NODE else STAR
+  if Ptset.Int.is_empty rel_labels then NIL
+  else
+    match skip_trans with
+      [ (_, (l, r, _) ) ] when l == r && l == states ->
+        begin
+          match rel_trans with
+          | [ (_, (l, r, m) ) ]
+              when (rel_labels == (Tree.element_tags tree) ||
+                      Ptset.Int.is_singleton rel_labels)
+                && (StateSet.diff l auto.true_states) == states && m
+                -> CAPTURE_MANY(rel_labels)
+          | _ ->
+            JUMP_MANY(rel_labels)
+        end
+
+    | [ (_, (l, r, _) ) ] when l == StateSet.empty -> JUMP_ONE(rel_labels)
+
+    | _ ->
+      if Ptset.Int.mem Tag.pcdata rel_labels then begin
+        LOG(__ "top-down-approx"  3  "Computed rel_labels: %a"
+              TagSet.print (TagSet.inj_positive rel_labels));
+        LOG(__ "top-down-approx"  3  "skip_trans:@\n%a"
+              (Pretty.pp_print_list ~sep:Format.pp_force_newline pr_trans)
+              skip_trans);
+        LOG(__ "top-down-approx"  3  "rel_trans:@\n%a"
+              (Pretty.pp_print_list ~sep:Format.pp_force_newline pr_trans)
+              rel_trans);
+        NODE
+      end else STAR
 
 module Cache = Hashtbl.Make(StateSet)
 let cache = Cache.create 1023
 let init () = Cache.clear cache
 
-let by_labels (tags1,(_,_,m1)) (tags2,(_,_,m2)) =
-  let r = TagSet.compare tags1 tags2 in r
-(*
-    if r == 0 then compare m1 m2 else r
-*)
-let by_states (_,(l1,r1, m1)) (_, (l2,r2,m2)) =
+let by_labels (tags1, _) (tags2, _) = TagSet.compare tags1 tags2
+
+let by_states x1 x2 =
+  let l1, r1, (m1 : bool) = snd x1
+  and l2, r2, (m2 : bool) = snd x2 in
+  (* force m1/m2 to be of type bool for efficient compare *)
   let r = StateSet.compare l1 l2 in
-    if r == 0 then
-      let r' = StateSet.compare r1 r2 in
-       if r' == 0 then compare m1 m2
-       else r'
-    else r
+  if r != 0 then r
+  else
+    let r' = StateSet.compare r1 r2 in
+    if r' != 0 then r'
+    else compare m1 m2
 
 let merge_states (tags1, (l1, r1, m1)) (tags2, (l2, r2, m2)) =
-  if tags1 == tags2 then (tags1,(StateSet.union l1 l2, StateSet.union r1 r2, m1 || m2))
+  if tags1 == tags2 then
+    tags1, (StateSet.union l1 l2, StateSet.union r1 r2, m1 || m2)
   else assert false
 
 let merge_labels (tags1, (l1, r1, m1)) (tags2, (l2, r2, m2)) =
-  if (l1 == l2) && (r1 == r2) && (m1 == m2) then (TagSet.cup tags1 tags2),(l1,r1,m1)
+  if l1 == l2 && r1 == r2 && m1 == m2 then
+    (TagSet.cup tags1 tags2), (l1, r1, m1)
   else assert false
 
 let rec merge_trans comp f l =
   match l with
-    | [] |[ _ ] -> l
-    | tr1::tr2::ll ->
-       if comp tr1 tr2 == 0 then merge_trans comp f ((f tr1 tr2)::ll)
-       else tr1 :: (merge_trans comp f (tr2::ll))
+  | [] |[ _ ] -> l
+  | tr1::tr2::ll ->
+    if comp tr1 tr2 == 0 then merge_trans comp f ((f tr1 tr2)::ll)
+    else tr1 :: (merge_trans comp f (tr2::ll))
 
+let fold_trans_of_states auto f states acc =
+  StateSet.fold
+    (fun q tr_acc ->
+      List.fold_left f tr_acc (Hashtbl.find auto.trans q))
+    states
+    acc
 
 let top_down_approx auto states tree =
-  try
-    Cache.find cache states
-  with
-      Not_found ->
-       let jump =
-         begin
-           let trs =
-             StateSet.fold
-               (fun q acc -> List.fold_left
-                  (fun acc_tr (ts, tr) ->
-                      let pos =
-                        if ts == TagSet.star
-                        then Tree.element_tags tree
-                        else if ts == TagSet.any then Tree.node_tags tree
-                        else TagSet.positive ts
-                      in
-                      let _, _, m, f = Transition.node tr in
-                      let (_, _, ls), (_, _, rs) = Formula.st f in
-                        if Ptset.Int.is_empty pos then acc_tr else
-                        (TagSet.inj_positive pos,(ls, rs, m))::acc_tr
-                  )
-                  acc
-                  (Hashtbl.find auto.trans q)
-               )
-               states
-               []
-           in
-             (* all labels in the tree compute what transition would be taken *)
-           let all_trs =
-             Ptset.Int.fold (fun tag acc ->
-                               List.fold_left (fun acc' (ts, lhs) ->
-                                                 if TagSet.mem tag ts then
-                                                   (TagSet.singleton tag, lhs)::acc'
-                                                 else acc') acc trs)
-               (Tree.node_tags tree) []
-           in
-             (* now merge together states that have common labels *)
-           let uniq_states_trs =
-             merge_trans by_labels merge_states (List.sort by_labels all_trs)
-           in
-             (* now merge together labels that have common states *)
-           let td_approx =
-             merge_trans by_states merge_labels (List.sort by_states uniq_states_trs)
-           in
-             D_TRACE_(
-               let is_pairwise_disjoint l =
-                 List.for_all (fun ((ts, _) as tr) ->
-                                 List.for_all (fun ((ts', _) as tr') ->
-                                                 (ts == ts' && (by_states tr tr' == 0)) ||
-                                                   TagSet.is_empty (TagSet.cap ts ts')) l) l
-               in
-               let is_complete l = TagSet.positive
-                 (List.fold_left (fun acc (ts, _) -> TagSet.cup acc ts) TagSet.empty l)
-                 ==
-                 (Tree.node_tags tree)
-               in
-                 eprintf "Top-down approximation (%b, %b):\n%!"
-                   (is_pairwise_disjoint td_approx)
-                   (is_complete td_approx);
-                 List.iter (fun (ts,(l,r, m)) ->
-                              let ts = if TagSet.cardinal ts >10
-                              then TagSet.diff TagSet.any 
-                                (TagSet.diff
-                                   (TagSet.inj_positive (Tree.node_tags tree))
-                                   ts)
-                              else ts 
-                              in
-                              eprintf "%a, %a, %b -> %a, %a\n%!"
-                                StateSet.print states
-                                TagSet.print ts
-                                m
-                                StateSet.print l
-                                StateSet.print r
-                           ) td_approx;
-                 eprintf "\n%!"
-
-             );
+  try Cache.find cache states with
+  | Not_found ->
+    let trs =
+      (* Collect all (ts, (l, r, m)) where
+         ts is a tagset, l and r are left and right set of states
+         m is marking flag
+      *)
+      fold_trans_of_states auto
+        (fun acc_tr (ts, tr) ->
+          let pos =
+            if ts == TagSet.star then Tree.element_tags tree
+            else if ts == TagSet.any then Tree.node_tags tree
+            else TagSet.positive ts
+          in
+          let _, _, m, f = Transition.node tr in
+          let ls, rs = Formula.st f in
+          if Ptset.Int.is_empty pos then acc_tr
+          else
+            (TagSet.inj_positive pos, (ls, rs, m))::acc_tr
+        )
+        states
+        []
+    in
+    (* for all labels in the tree compute which transition is taken *)
+    let all_trs =
+      Ptset.Int.fold (fun tag acc ->
+        List.fold_left (fun acc' (ts, rhs) ->
+          if TagSet.mem tag ts then
+            (TagSet.singleton tag, rhs)::acc'
+          else acc') acc trs)
+        (Tree.node_tags tree) []
+    in
+    (* merge together states that have common labels *)
+    let uniq_states_trs =
+      merge_trans by_labels merge_states (List.sort by_labels all_trs)
+    in
+    (* merge together labels that have common states *)
+    let td_approx =
+      merge_trans by_states merge_labels
+        (List.sort by_states uniq_states_trs)
+    in
+    LOG(
+      let is_pairwise_disjoint l =
+        List.for_all (fun ((ts, _) as tr) ->
+          List.for_all (fun ((ts', _) as tr') ->
+            (ts == ts' && (by_states tr tr' == 0)) ||
+              TagSet.is_empty (TagSet.cap ts ts')) l) l
+      in
+      let is_complete l = TagSet.positive
+        (List.fold_left (fun acc (ts, _) -> TagSet.cup acc ts)
+           TagSet.empty l)
+        ==
+        (Tree.node_tags tree)
+      in
+      let pr_td_approx fmt td_approx =
+        List.iter (fun (ts,(l,r, m)) ->
+          let ts = if TagSet.cardinal ts >10
+            then TagSet.diff TagSet.any
+              (TagSet.diff
+                 (TagSet.inj_positive (Tree.node_tags tree))
+                 ts)
+            else ts
+          in
+          fprintf fmt "%a, %a, %b -> %a, %a@\n"
+            StateSet.print states
+            TagSet.print ts
+            m
+            StateSet.print l
+            StateSet.print r
+        ) td_approx;
+        fprintf fmt "\n%!"
+      in
+      __ "top-down-approx" 2 " pairwise-disjoint:%b, complete:%b:@\n%a"
+        (is_pairwise_disjoint td_approx)
+        (is_complete td_approx)
+        pr_td_approx td_approx
+    );
     let jump =
-      compute_jump auto tree states td_approx (List.exists (fun (_,(_,_,b)) -> b) td_approx)
+      compute_jump
+        auto tree states td_approx
+        (List.exists (fun (_,(_,_,b)) -> b) td_approx)
     in
-      jump
-         end
-       in
-         Cache.add cache states jump; jump
+    Cache.add cache states jump; jump
+
+
+
+let get_trans ?(attributes=TagSet.empty) auto states tag =
+  StateSet.fold (fun q acc ->
+    List.fold_left (fun ((tr_acc, l_acc, r_acc) as acc) (ts, tr) ->
+(*      let ts = if ts == TagSet.star then TagSet.diff ts attributes else ts
+      in *)
+      let b = TagSet.mem tag ts in
+      LOG(__ "transition" 3 "tag=<%s>, %s: %a"
+        (Tag.to_string tag)
+        (if b then "    taking" else "not taking")
+        Transition.print tr);
+      if b then
+        let _, _, _, f = Transition.node tr in
+        let l, r = Formula.st f in
+        (Translist.cons tr tr_acc,
+         StateSet.union l l_acc,
+         StateSet.union r r_acc)
+      else acc) acc (Hashtbl.find auto.trans q))
+    states
+    (Translist.nil, StateSet.empty, StateSet.empty)