diff --git a/config/discover.ml b/config/discover.ml index 892359a7..77db78ec 100644 --- a/config/discover.ml +++ b/config/discover.ml @@ -133,6 +133,9 @@ let discover_stats () = (match Unix.getenv "OCANREN_STATS" with | exception Not_found -> [] | _ -> [ "-D"; "STATS" ]); + (match Unix.getenv "OCANREN_TRACE" with + | exception Not_found -> [] + | _ -> [ "-D"; "TRACE" ]); (match Unix.getenv "OCANREN_NON_ABSTRACT_GOAL" with | exception Not_found -> [] | _ -> [ "-D"; "NON_ABSTRACT_GOAL" ]); diff --git a/samples/add.ml b/samples/add.ml index a31c71e6..8a1f6b5a 100644 --- a/samples/add.ml +++ b/samples/add.ml @@ -1,7 +1,6 @@ module L = List open GT -open Printf open OCanren open OCanren.Std @@ -13,12 +12,19 @@ let addo x y z = x == S x' & z == S z' & addo x' y z' } -let _ = - L.iter (fun (q, r) -> printf "q=%s, r=%s\n" q r) @@ - Stream.take ~n:(-1) @@ - ocanrun (q, r : ^Nat.nat) {addo q r 2} -> (show(Nat.logic) q, show(Nat.logic) r) +let () = + let counter = Stdlib.ref 1 in + (ocanrun (q, r : ^Nat.nat) {addo q r 2} -> (Trace.extract_last (), show(Nat.logic) q, show(Nat.logic) r)) + |> Stream.take ~n:(-1) + |> L.iter begin fun (trace, q, r) -> + Format.printf "q=%s, r=%s\n" q r ; + Format.printf "TRACE: %a\n" Trace.pp trace ; + Trace.marshal_to_file (Format.sprintf "add_%d.trace" !counter) trace ; + incr counter + end ; + Format.printf "Saved %d traces\n" (!counter - 1) let _ = - L.iter (fun q -> printf "q=%s\n" q) @@ + L.iter (fun q -> Format.printf "q=%s\n" q) @@ Stream.take ~n:(-1) @@ ocanrun (q : ^Nat.nat) {addo q 1 0} -> (show(Nat.logic) q) diff --git a/samples/dune b/samples/dune index 27fced1f..e098939a 100644 --- a/samples/dune +++ b/samples/dune @@ -7,8 +7,8 @@ (:standard -rectypes)))) (executables - (names tree sorting WGC add) - (modules tree sorting WGC add) + (names tree sorting WGC add show_trace) + (modules tree sorting WGC add show_trace) (flags (:standard ;-dsource @@ -43,7 +43,7 @@ (file %{project_root}/camlp5/pp5+ocanren+dump.exe) (file %{project_root}/camlp5/pp5+gt+plugins+ocanren+logger+dump.exe) (file %{project_root}/camlp5/pp5+ocanren+o.exe)) - (libraries GT OCanren)) + (libraries GT OCanren benchmark)) (executable (name bench_reverso) @@ -123,6 +123,6 @@ OCanren-ppx.ppx_distrib OCanren-ppx.ppx_deriving_reify GT.ppx_all - ppx_expect)) + ppx_expect_nobase)) (preprocessor_deps (file %{project_root}/ppx/pp_ocanren_all.exe))) diff --git a/samples/show_trace.ml b/samples/show_trace.ml new file mode 100644 index 00000000..ffe9c989 --- /dev/null +++ b/samples/show_trace.ml @@ -0,0 +1,19 @@ +open OCanren + + +let () = + if Array.length Sys.argv > 2 then begin + Printf.eprintf "Usage: %s [filename]\n" @@ Sys.argv.(0) ; + exit (-1) + end + +let filename = + if Array.length Sys.argv > 1 then + Sys.argv.(1) + else begin + print_string "Filename: " ; + flush stdout ; + read_line () + end + +let () = Format.printf "%a\n" Trace.pp @@ Trace.unmarshal_from_file filename diff --git a/src/core/Core.ml b/src/core/Core.ml index aa8903cf..cb585dcf 100644 --- a/src/core/Core.ml +++ b/src/core/Core.ml @@ -260,6 +260,7 @@ module State = ; ctrs : Disequality.t ; prunes: Prunes.t ; scope : Term.Var.scope + ; trace : (Term.t * Term.t * t) option } type reified = Env.t * Term.t @@ -270,6 +271,7 @@ module State = ; ctrs = Disequality.empty ; prunes = Prunes.empty ; scope = Term.Var.new_scope () + ; trace = None } let env {env} = env @@ -290,7 +292,8 @@ module State = match Disequality.recheck env subst ctrs prefix with | None -> None | Some ctrs -> - let next_state = { st with subst ; ctrs } in + let trace = IFDEF TRACE THEN Some (Term.repr x, Term.repr y, st) ELSE st.trace END in + let next_state = { st with subst ; ctrs ; trace } in if PrunesControl.is_exceeded () then begin let () = PrunesControl.reset_cur_counter () in @@ -311,33 +314,106 @@ module State = | Prunes.Violated -> None | NonViolated -> Some { st with ctrs } + IFDEF TRACE THEN + + (* assigned in Trace module *) + let save_trace = ref @@ Obj.magic 0 + + END + (* always returns non-empty list *) - let reify x { env ; subst ; ctrs } = + let reify x ({ env ; subst ; ctrs } as st) = let answ = Subst.reify env subst x in match Disequality.reify env subst ctrs x with | [] -> (* [Answer.make env answ] *) assert false - | diseqs -> ListLabels.map diseqs ~f:begin fun diseq -> - let rec helper forbidden t = Term.map t ~fval:Term.repr - ~fvar:begin fun v -> Term.repr @@ - if Term.VarSet.mem v forbidden then v - else { v with Term.Var.constraints = Disequality.Answer.extract diseq v - |> List.filter begin fun dt -> - match Env.var env dt with - | Some u -> not @@ Term.VarSet.mem u forbidden - | None -> true - end - |> List.map (fun x -> helper (Term.VarSet.add v forbidden) x) - (* TODO: represent [Var.constraints] as [Set]; - * TODO: hide all manipulations on [Var.t] inside [Var] module; - *) - |> List.sort Term.compare - } - end - in - Answer.make env @@ helper Term.VarSet.empty answ - end + | diseqs -> + IFDEF TRACE THEN !save_trace st ELSE let _ = st in () END ; + ListLabels.map diseqs ~f:begin fun diseq -> + let rec helper forbidden t = Term.map t ~fval:Term.repr + ~fvar:begin fun v -> Term.repr @@ + if Term.VarSet.mem v forbidden then v + else { v with Term.Var.constraints = Disequality.Answer.extract diseq v + |> List.filter begin fun dt -> + match Env.var env dt with + | Some u -> not @@ Term.VarSet.mem u forbidden + | None -> true + end + |> List.map (fun x -> helper (Term.VarSet.add v forbidden) x) + (* TODO: represent [Var.constraints] as [Set]; + * TODO: hide all manipulations on [Var.t] inside [Var] module; + *) + |> List.sort Term.compare + } + end + in + Answer.make env @@ helper Term.VarSet.empty answ + end + end + +IFDEF TRACE THEN + +module Trace : + sig + + type t + + val pp : Format.formatter -> t -> unit + + val extract_last : unit -> t + + val marshal : out_channel -> t -> unit + val unmarshal : ?env:Term.Var.env -> ?scope:Term.Var.scope -> in_channel -> t + + val marshal_to_file : string -> t -> unit + val unmarshal_from_file : ?env:Term.Var.env -> ?scope:Term.Var.scope -> string -> t + end = struct + + type t = (Term.t * Term.t) list + + let saved_state = ref None + + let () = State.save_trace := fun st -> saved_state := Some st + + let pp = + let hlp ppf (l, r) = Format.fprintf ppf "%a = %a" Term.pp l Term.pp r in + let pp_sep ppf () = Format.fprintf ppf "; " in + Format.pp_print_list ~pp_sep hlp + + let extract = + let rec extract acc = function + | Some (l, r, st) -> extract ((l, r)::acc) st.State.trace + | None -> acc + in + + extract [] + + let extract_last () = match !saved_state with + | Some st -> extract st.State.trace + | None -> raise Not_found + + let rec marshal chan = function + | [] -> () + | (t1, t2)::xs -> + Term.marshal chan t1 ; + Term.marshal chan t2 ; + marshal chan xs + + let[@tail_mod_cons] rec unmarshal ?(env=0) ?(scope=Term.Var.non_local_scope) chan = + match Term.unmarshal ~env ~scope chan with + | exception End_of_file -> [] + | t1 -> + let t2 = Term.unmarshal ~env ~scope chan in + (t1, t2) :: unmarshal ~env ~scope chan + + let marshal_to_file filename value = + Out_channel.with_open_bin filename @@ fun oc -> marshal oc value + + let unmarshal_from_file ?env ?scope filename = + In_channel.with_open_bin filename @@ unmarshal ?env ?scope end +END + let (!!!) = Obj.magic type 'a goal' = State.t -> 'a @@ -816,4 +892,4 @@ let is_ground_bool let ans = Subst.reify (State.env st) (State.subst st) (Obj.magic v) in if (Term.is_var ans) then onvar() else on_ground (Obj.magic ans : bool) -END \ No newline at end of file +END diff --git a/src/core/Core.mli b/src/core/Core.mli index 258822df..3c9cfd12 100644 --- a/src/core/Core.mli +++ b/src/core/Core.mli @@ -304,6 +304,23 @@ val disj_counter : unit -> int val delay_counter : unit -> int END +IFDEF TRACE THEN + +module Trace : sig + type t + + val pp : Format.formatter -> t -> unit + val extract_last : unit -> t + + val marshal : out_channel -> t -> unit + val unmarshal : ?env:Term.Var.env -> ?scope:Term.Var.scope -> in_channel -> t + + val marshal_to_file : string -> t -> unit + val unmarshal_from_file : ?env:Term.Var.env -> ?scope:Term.Var.scope -> string -> t +end + +END + (** The call [debug_var var reifier callback] performs reification of variable [var] in a current state using [reifier] and passes list of answer to [callback] (multiple answers can arise in presence of disequality constraints). The [callback] can investigate reified value and construct required goal to continue search. See also: {!structural}. @@ -338,4 +355,4 @@ val is_ground : 'a ilogic -> State.t -> (bool -> unit) -> unit val is_ground_bool : bool ilogic -> State.t -> onvar:(unit->unit) -> on_ground:(bool -> unit) -> unit -END \ No newline at end of file +END diff --git a/src/core/Term.ml b/src/core/Term.ml index 75e0a164..2713ab76 100644 --- a/src/core/Term.ml +++ b/src/core/Term.ml @@ -314,3 +314,80 @@ let rec compare x y = let rec hash x = fold x ~init:1 ~fvar:(fun acc v -> Hashtbl.hash (Var.hash v, List.fold_left (fun acc x -> Hashtbl.hash (acc, hash x)) acc v.Var.constraints)) ~fval:(fun acc x -> Hashtbl.hash (acc, Hashtbl.hash x)) + +module Int64_IO : sig + + val input : in_channel -> int64 + val output : out_channel -> int64 -> unit +end = struct + + let buf = Bytes.create 8 + + let output chan x = + Bytes.set_int64_le buf 0 x ; + output_bytes chan buf + + let input chan = + really_input chan buf 0 8 ; + Bytes.get_int64_le buf 0 +end + +let output_int64 = Int64_IO.output +let input_int64 = Int64_IO.input + +(* + * Variable: [0 : byte] [index : int64 ] + * Int value: [1 : byte] [value : int64 ] + * Float value: [2 : byte] [value : float64] + * String value: [3 : byte] [length : int64 ] [bytes...] + * Functor: [4 : byte] [tag : int64 ] [arity : int64] [subterms...] + *) +let rec marshal chan x = + let tx = Obj.tag x in + if is_box tx then + let sx = Obj.size x in + if Var.has_var_structure tx sx x then begin + output_byte chan 0 ; + output_int64 chan @@ Int64.of_int (obj x).Var.index + end else begin + output_byte chan 4 ; + output_int64 chan @@ Int64.of_int tx ; + output_int64 chan @@ Int64.of_int sx ; + for i = 0 to sx - 1 do + marshal chan @@ Obj.field x i + done + end + else begin + check_val tx ; + if is_int tx then begin + output_byte chan 1 ; + output_int64 chan @@ Int64.of_int @@ obj x + end else if is_dbl tx then begin + output_byte chan 2 ; + output_int64 chan @@ Int64.bits_of_float @@ obj x + end else if is_str tx then begin + output_byte chan 3 ; + let value : string = obj x in + output_int64 chan @@ Int64.of_int @@ String.length value ; + output_string chan value + end else + assert false + end + +let rec unmarshal ~env ~scope chan = + match input_byte chan with + | 0 -> repr @@ Var.make ~env ~scope @@ Int64.to_int @@ input_int64 chan + | 1 -> repr (Int64.to_int @@ input_int64 chan : int) + | 2 -> repr (Int64.float_of_bits @@ input_int64 chan : float) + | 3 -> + let length = Int64.to_int @@ input_int64 chan in + Obj.repr (Stdlib.Option.get @@ In_channel.really_input_string chan length : string) + | 4 -> + let tx = Int64.to_int @@ input_int64 chan in + let sx = Int64.to_int @@ input_int64 chan in + let res = Obj.new_block tx sx in + for i = 0 to sx - 1 do + Obj.set_field res i @@ unmarshal ~env ~scope chan + done ; + res + | x -> invalid_arg @@ Printf.sprintf "Term.unmarshal: incorrect data (%d)" x diff --git a/src/core/Term.mli b/src/core/Term.mli index ca3c75c3..8935dae1 100644 --- a/src/core/Term.mli +++ b/src/core/Term.mli @@ -117,3 +117,6 @@ val fold2 : val equal : t -> t -> bool val compare : t -> t -> int val hash : t -> int + +val marshal : out_channel -> t -> unit +val unmarshal : env:Var.env -> scope:Var.scope -> in_channel -> t