diff --git a/CHANGES.md b/CHANGES.md
index facafe4..af84110 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,6 +1,8 @@
 ## unreleased
 
+- new API for `Path_condition` which makes enforcing invariants easier
 - expose `Path_condition.pp`
+- substitute equalities in the path condition
 
 ## 0.1 - 2026-02-09
 
diff --git a/shell.nix b/shell.nix
index 25965c8..e61f4ce 100644
--- a/shell.nix
+++ b/shell.nix
@@ -3,6 +3,18 @@
   }) {}
 }:
 
+let
+  smtml = pkgs.ocamlPackages.smtml.overrideAttrs (old: {
+    src = pkgs.fetchFromGitHub {
+      owner = "formalsec";
+      repo = "smtml";
+      rev = "a9dff52e7ef2215c786ee8ce2c24d716db0b5ace";
+      hash = "sha256-TIOOE/bsis6oYV3Dt6TcI/r/aN3S1MQNtxDAnvBbVO0=";
+    };
+    doCheck = false;
+  });
+in
+
 pkgs.mkShell {
   name = "symex-dev-shell";
   dontDetectOcamlConflicts = true;
@@ -13,7 +25,7 @@ pkgs.mkShell {
     bisect_ppx
     #landmarks
     #landmarks-ppx
-    mdx
+    #mdx
     merlin
     ocaml
     ocamlformat
diff --git a/src/path_condition.ml b/src/path_condition.ml
index d14c4b4..2a02653 100644
--- a/src/path_condition.ml
+++ b/src/path_condition.ml
@@ -2,56 +2,155 @@
 (* Copyright © 2021-2024 OCamlPro *)
 (* Written by the Owi programmers *)
 
-module Union_find = Union_find.Make (Smtml.Symbol)
-
-type t = Smtml.Expr.Set.t Union_find.t
-
-let pp : t Fmt.t = fun ppf pc -> (Union_find.pp Smtml.Expr.Set.pp) ppf pc
-
-let empty : t = Union_find.empty
-
-let add_one (condition : Smtml.Expr.t) pc : t =
-  match Smtml.Expr.get_symbols [ condition ] with
-  | hd :: tl ->
-    (* We add the first symbol to the UF *)
-    let pc =
-      let c = Smtml.Expr.Set.singleton condition in
-      Union_find.add ~merge:Smtml.Expr.Set.union hd c pc
-    in
-    (* We union-ize all symbols together, starting with the first one that has already been added *)
-    let pc, _last_sym =
-      List.fold_left
-        (fun (pc, last_sym) sym ->
-          (Union_find.union ~merge:Smtml.Expr.Set.union last_sym sym pc, sym) )
-        (pc, hd) tl
-    in
-    pc
-  | [] ->
-    (* It means smtml did not properly simplified an expression! *)
-    assert false
-
-let add (condition : Smtml.Typed.Bool.t) (pc : t) : t =
+module Union_find = struct
+  include Union_find.Make (Smtml.Symbol)
+
+  type nonrec t = Smtml.Expr.Set.t t
+
+  let pp : t Fmt.t = fun ppf union_find -> pp Smtml.Expr.Set.pp ppf union_find
+end
+
+module Equalities : sig
+  type t = Smtml.Value.t Smtml.Symbol.Map.t
+
+  val pp : t Fmt.t
+end = struct
+  type t = Smtml.Value.t Smtml.Symbol.Map.t
+
+  let pp ppf v =
+    Fmt.pf ppf "@[{%a}@]"
+      (Fmt.iter_bindings
+         ~sep:(fun ppf () -> Fmt.pf ppf ",@ ")
+         Smtml.Symbol.Map.iter
+         (fun ppf (k, v) ->
+           Fmt.pf ppf "%a = %a" Smtml.Symbol.pp k Smtml.Value.pp v ) )
+      v
+end
+
+type t =
+  { union_find : Union_find.t
+  ; equalities : Equalities.t
+  }
+
+let pp : t Fmt.t =
+ fun ppf { union_find; equalities } ->
+  Fmt.pf ppf "union find:@\n @[%a@]@\nequalities:@\n @[%a@]@\n"
+    Union_find.pp union_find Equalities.pp equalities
+
+let empty : t =
+  let union_find = Union_find.empty in
+  let equalities = Smtml.Symbol.Map.empty in
+  { union_find; equalities }
+
+let rec add_one_equality symbol value (pc : t) : t =
+  let equalities = Smtml.Symbol.Map.add symbol value pc.equalities in
+  match Union_find.find_and_remove_opt symbol pc.union_find with
+  | Some (set, union_find) ->
+    (* propagate back the equality in the union find *)
+    let pc = { union_find; equalities } in
+    Smtml.Expr.Set.fold add_one_constraint set pc
+  | None -> { pc with equalities }
+
+and add_one_constraint (condition : Smtml.Expr.t) (pc : t) : t =
+  (* we start by simplifying the constraint by substituting already known equalities *)
+  let condition = Smtml.Expr.inline_symbol_values pc.equalities condition in
+  let condition = Smtml.Expr.simplify condition in
+
+  let pc, shortcut =
+    match Smtml.Expr.view condition with
+    (* if the condition is of the form e1 = e2 *)
+    | Relop (_, Smtml.Ty.Relop.Eq, e1, e2) -> begin
+      match (Smtml.Expr.view e1, Smtml.Expr.view e2) with
+      (* it has the form: symbol = value *)
+      | Smtml.Expr.Symbol symbol, Val value | Val value, Symbol symbol -> begin
+        match Smtml.Symbol.Map.find_opt symbol pc.equalities with
+        | None ->
+          (* we don't have an equality for s=v so we add it *)
+          (add_one_equality symbol value pc, true)
+        | Some value' ->
+          (* that would mean the PC is unsat which is illegal! *)
+          assert (Smtml.Eval.relop symbol.ty Eq value value');
+          (* we discovered an already known equality, nothing to do *)
+          (pc, true)
+      end
+      | _ -> (pc, false)
+    end
+    | _ -> (pc, false)
+  in
+
+  if shortcut then pc
+  else
+    match Smtml.Expr.view condition with
+    | Val True ->
+      (* no need to change anything, the condition is a tautology *)
+      pc
+    | Val False ->
+      (* the PC is unsat *)
+      assert false
+    | _ -> begin
+      match Smtml.Expr.get_symbols [ condition ] with
+      | hd :: tl ->
+        (* We add the first symbol to the UF *)
+        let union_find =
+          let c = Smtml.Expr.Set.singleton condition in
+          Union_find.add ~merge:Smtml.Expr.Set.union hd c pc.union_find
+        in
+        (* We union-ize all symbols together, starting with the first one that has already been added *)
+        let union_find, _last_sym =
+          List.fold_left
+            (fun (union_find, last_sym) sym ->
+              ( Union_find.union ~merge:Smtml.Expr.Set.union last_sym sym
+                  union_find
+              , sym ) )
+            (union_find, hd) tl
+        in
+        { pc with union_find }
+      | [] ->
+        (* either it is a boolean value (because it is a condition) and was matched before, or it is not a value and should have at least one symbol! *)
+        assert false
+    end
+
+let add_checked_sat_condition (condition : Smtml.Typed.Bool.t) (pc : t) : t =
   (* we start by splitting the condition ((P & Q) & R) into a set {P; Q; R} before adding each of P, Q and R into the UF data structure, this way we maximize the independence of the PC *)
   let splitted_condition = Smtml.Typed.Bool.split_conjunctions condition in
-  Smtml.Expr.Set.fold add_one splitted_condition pc
+  Smtml.Expr.Set.fold add_one_constraint splitted_condition pc
 
 (* Get all sub conditions of the path condition as a list of independent sets of constraints. *)
-let slice pc = Union_find.explode pc
+let to_list (pc : t) =
+  let uf = Union_find.explode pc.union_find in
+  (* add equalities too *)
+  Smtml.Symbol.Map.fold
+    (fun symbol value acc ->
+      let eq =
+        Smtml.Expr.Bool.equal (Smtml.Expr.symbol symbol)
+          (Smtml.Expr.value value)
+        |> Smtml.Expr.Set.singleton
+      in
+      eq :: acc )
+    pc.equalities uf
 
 (* Return the set of constraints from [pc] that are relevant for [sym]. *)
-let slice_on_symbol (sym : Smtml.Symbol.t) pc : Smtml.Expr.Set.t =
-  match Union_find.find_opt sym pc with
+let slice_on_symbol (sym : Smtml.Symbol.t) (pc : t) : Smtml.Expr.Set.t =
+  match Union_find.find_opt sym pc.union_find with
   | Some s -> s
-  | None ->
-    (* if there is a symbol, it should have been added to the union-find structure before, otherwise it means `add` has not been called properly before *)
-    assert false
+  | None -> (
+    match Smtml.Symbol.Map.find_opt sym pc.equalities with
+    | Some value ->
+      let eq =
+        Smtml.Expr.Bool.equal (Smtml.Expr.symbol sym) (Smtml.Expr.value value)
+      in
+      Smtml.Expr.Set.singleton eq
+    | None -> Smtml.Expr.Set.empty )
 
 (* Return the set of constraints from [pc] that are relevant for [c]. *)
-let slice_on_condition (c : Smtml.Typed.Bool.t) pc : Smtml.Expr.Set.t =
-  match Smtml.Typed.get_symbols [ c ] with
-  | sym0 :: _tl ->
-    (* we need only the first symbol as all the others should have been merged with it *)
-    slice_on_symbol sym0 pc
-  | [] ->
-    (* It means smtml did not properly simplified a expression! *)
-    assert false
+let slice_on_new_condition (c : Smtml.Typed.Bool.t) (pc : t) : Smtml.Expr.Set.t
+    =
+  let symbols = Smtml.Typed.get_symbols [ c ] in
+  List.fold_left
+    (fun set symbol ->
+      let slice = slice_on_symbol symbol pc in
+      Smtml.Expr.Set.union set slice )
+    Smtml.Expr.Set.empty symbols
+
+let get_known_equalities (pc : t) : Smtml.Value.t Smtml.Symbol.Map.t =
+  pc.equalities
diff --git a/src/path_condition.mli b/src/path_condition.mli
index e1ea611..a247467 100644
--- a/src/path_condition.mli
+++ b/src/path_condition.mli
@@ -8,12 +8,17 @@ val pp : t Fmt.t
 
 val empty : t
 
-val add : Smtml.Typed.Bool.t -> t -> t
+(* You should only use this after checking the condition is SAT in the context of the path condition. If you add it without checking, it may lead to wrong results. *)
+val add_checked_sat_condition : Smtml.Typed.Bool.t -> t -> t
 
-(* CAUTION: this must only be called after the symbol has been added to the path condition *)
+(* Get all the slices that are related to the condition given in argument. The condition is not added to the result. *)
+val slice_on_new_condition : Smtml.Typed.Bool.t -> t -> Smtml.Expr.Set.t
+
+(* Get the slice for a symbol. Return an empty set if the symbol is not part of the path condition. *)
 val slice_on_symbol : Smtml.Symbol.t -> t -> Smtml.Expr.Set.t
 
-(* CAUTION: this must only be called after the condition added to the path condition with `add` *)
-val slice_on_condition : Smtml.Typed.Bool.t -> t -> Smtml.Expr.Set.t
+(* Get all slices of the path condition as a list. *)
+val to_list : t -> Smtml.Expr.Set.t list
 
-val slice : t -> Smtml.Expr.Set.t list
+(* Return a map of known equalities from symbols to values. *)
+val get_known_equalities : t -> Smtml.Value.t Smtml.Symbol.Map.t
diff --git a/src/union_find.ml b/src/union_find.ml
index 3c2c5d4..413467d 100644
--- a/src/union_find.ml
+++ b/src/union_find.ml
@@ -38,6 +38,11 @@ module type S = sig
       [key] does not need to be canonical. *)
   val find_opt : key -> 'a t -> 'a option
 
+  (** [find_and_remove_opt key uf] returns the value of the class of [key] and
+      [uf] without that whole class, or [None] if the class has no value.
+      [key] does not need to be canonical. *)
+  val find_and_remove_opt : key -> 'a t -> ('a * 'a t) option
+
   (** [union ~merge key1 key2 uf] merges the equivalence classes associated with
       [key1] and [key2], calling [merge] on the corresponding values. *)
   val union : merge:('a -> 'a -> 'a) -> key -> key -> 'a t -> 'a t
@@ -63,28 +68,17 @@ module Make (X : VariableType) : S with type key = X.t = struct
     }
 
   let print_set ppf set =
-    if SX.is_empty set then Fmt.pf ppf "{}"
-    else (
-      Fmt.pf ppf "@[{";
-      let first = ref true in
-      SX.iter
-        (fun x ->
-          if !first then first := false else Fmt.pf ppf ",@ ";
-          X.pp ppf x )
-        set;
-      Fmt.pf ppf "}@]" )
+    Fmt.pf ppf "@[{%a}@]"
+      (Fmt.iter ~sep:(fun ppf () -> Fmt.pf ppf ",@ ") SX.iter X.pp)
+      set
 
   let print_map pp ppf map =
-    if MX.is_empty map then Fmt.pf ppf "{}"
-    else (
-      Fmt.pf ppf "@[{";
-      let first = ref true in
-      MX.iter
-        (fun key value ->
-          if !first then first := false else Fmt.pf ppf ",@ ";
-          Fmt.pf ppf "@[(%a@ %a)@]" X.pp key pp value )
-        map;
-      Fmt.pf ppf "}@]" )
+    Fmt.pf ppf "@[{%a}@]"
+      (Fmt.iter_bindings
+         ~sep:(fun ppf () -> Fmt.pf ppf ",@ ")
+         MX.iter
+         (fun ppf (k, v) -> Fmt.pf ppf "@[(%a@ %a)@]" X.pp k pp v) )
+      map
 
   let print_aliases ppf { aliases; _ } = print_set ppf aliases
 
@@ -137,6 +131,21 @@ module Make (X : VariableType) : S with type key = X.t = struct
       (find_node_opt (find_canonical variable t) t)
       (fun node -> node.datum)
 
+  let find_and_remove_opt variable t =
+    let canonical = find_canonical variable t in
+    match find_node_opt canonical t with
+    | None -> None
+    | Some node -> (
+      match node.datum with
+      | None -> None
+      | Some datum ->
+        let node_of_canonicals = MX.remove canonical t.node_of_canonicals in
+        let canonical_elements =
+          MX.filter (fun _k v -> v <> canonical) t.canonical_elements
+        in
+        let t = { canonical_elements; node_of_canonicals } in
+        Some (datum, t) )
+
   let set_canonical_element aliases canonical canonical_elements =
     SX.fold
       (fun alias canonical_elements -> MX.add alias canonical canonical_elements)
diff --git a/test/test_path_condition.ml b/test/test_path_condition.ml
index f65e21a..d80e532 100644
--- a/test/test_path_condition.ml
+++ b/test/test_path_condition.ml
@@ -15,43 +15,59 @@ let print pc = Fmt.str "%a" Symex.Path_condition.pp pc
 let test_print_pc_empty () : unit =
   let pc = Symex.Path_condition.empty |> print in
   Alcotest.(check string)
-    "same string" "((aliases_of_canonicals {}) (payload_of_canonicals {}))" pc
+    "same string"
+    "union find:\n\
+    \ ((aliases_of_canonicals {}) (payload_of_canonicals {}))\n\
+     equalities:\n\
+    \ {}\n"
+    pc
 
 let test_print_pc_one () : unit =
   let pc =
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_const 42l))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_const 42l))
     |> print
   in
   Alcotest.(check string)
     "same string"
-    "((aliases_of_canonicals {(x0 {})})\n\
-    \ (payload_of_canonicals {(x0 (bool.eq x0 42))}))"
+    "union find:\n\
+    \ ((aliases_of_canonicals {}) (payload_of_canonicals {}))\n\
+     equalities:\n\
+    \ {x0 = 42}\n"
     pc
 
 let test_print_pc_two () : unit =
   let pc =
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_const 42l))
-    |> Symex.Path_condition.add (eq (i32_sym "y0") (i32_const 56l))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_const 42l))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "y0") (i32_const 56l))
     |> print
   in
   Alcotest.(check string)
     "same string"
-    "((aliases_of_canonicals {(x0 {}), (y0 {})})\n\
-    \ (payload_of_canonicals {(x0 (bool.eq x0 42)), (y0 (bool.eq y0 56))}))"
+    "union find:\n\
+    \ ((aliases_of_canonicals {}) (payload_of_canonicals {}))\n\
+     equalities:\n\
+    \ {x0 = 42, y0 = 56}\n"
     pc
 
 let test_print_pc_one_two_sym () : unit =
   let pc =
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_sym "y0"))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_sym "y0"))
     |> print
   in
   Alcotest.(check string)
     "same string"
-    "((aliases_of_canonicals {(x0 {y0})})\n\
-    \ (payload_of_canonicals {(x0 (bool.eq x0 y0))}))"
+    "union find:\n\
+    \ ((aliases_of_canonicals {(x0 {y0})})\n\
+    \ (payload_of_canonicals {(x0 (bool.eq x0 y0))}))\n\
+     equalities:\n\
+    \ {}\n"
     pc
 
 let basics_suite =
@@ -69,7 +85,7 @@ let test_slice_on_symbol_single () =
   let slice =
     let x0 = sym_i32 "x0" in
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add
+    |> Symex.Path_condition.add_checked_sat_condition
          (eq (Smtml.Typed.Bitv32.symbol x0) (i32_const 42l))
     |> Symex.Path_condition.slice_on_symbol x0
   in
@@ -81,8 +97,9 @@ let test_slice_on_symbol_irrelevant () =
   let slice =
     let y0 = sym_i32 "y0" in
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_const 42l))
-    |> Symex.Path_condition.add
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_const 42l))
+    |> Symex.Path_condition.add_checked_sat_condition
          (eq (Smtml.Typed.Bitv32.symbol y0) (i32_const 56l))
     |> Symex.Path_condition.slice_on_symbol y0
   in
@@ -92,9 +109,11 @@ let test_slice_on_symbol_transitive () =
   let slice =
     let x0 = sym_i32 "x0" in
     let y0 = i32_sym "y0" in
+    let z0 = i32_sym "z0" in
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (Smtml.Typed.Bitv32.symbol x0) y0)
-    |> Symex.Path_condition.add (eq y0 (i32_const 5l))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (Smtml.Typed.Bitv32.symbol x0) y0)
+    |> Symex.Path_condition.add_checked_sat_condition (eq y0 z0)
     |> Symex.Path_condition.slice_on_symbol x0
   in
   Alcotest.(check int)
@@ -109,24 +128,28 @@ let slice_on_symbol_suite =
     ; test_case "transitive" `Quick test_slice_on_symbol_transitive
     ] )
 
-(* `slice_on_condition` *)
+(* `slice_on_new_condition` *)
 
 let test_slice_on_condition_basic () =
   let slice =
     let c1 = eq (i32_sym "x0") (i32_const 10l) in
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add c1
-    |> Symex.Path_condition.slice_on_condition c1
+    |> Symex.Path_condition.add_checked_sat_condition c1
+    |> Symex.Path_condition.slice_on_new_condition c1
   in
   Alcotest.(check int) "condition slice size" 1 (Smtml.Expr.Set.cardinal slice)
 
 let test_slice_on_condition_transitive () =
   let slice =
-    let c1 = eq (i32_sym "x0") (i32_sym "y0") in
+    let x0 = i32_sym "x0" in
+    let y0 = i32_sym "y0" in
+    let z0 = i32_sym "z0" in
+    let c1 = eq x0 y0 in
+    let c2 = eq y0 z0 in
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add c1
-    |> Symex.Path_condition.add (eq (i32_sym "y0") (i32_const 7l))
-    |> Symex.Path_condition.slice_on_condition c1
+    |> Symex.Path_condition.add_checked_sat_condition c1
+    |> Symex.Path_condition.add_checked_sat_condition c2
+    |> Symex.Path_condition.slice_on_new_condition c1
   in
   Alcotest.(check int) "transitive slice size" 2 (Smtml.Expr.Set.cardinal slice)
 
@@ -137,27 +160,31 @@ let slice_on_condition_suite =
     ; test_case "transitive" `Quick test_slice_on_condition_transitive
     ] )
 
-(* `slice` *)
+(* `to_list` *)
 
 let test_slice_empty () =
-  let slices = Symex.Path_condition.empty |> Symex.Path_condition.slice in
+  let slices = Symex.Path_condition.empty |> Symex.Path_condition.to_list in
   Alcotest.(check int) "no slices" 0 (List.length slices)
 
 let test_slice_independent () =
   let slices =
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_const 1l))
-    |> Symex.Path_condition.add (eq (i32_sym "y0") (i32_const 2l))
-    |> Symex.Path_condition.slice
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_const 1l))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "y0") (i32_const 2l))
+    |> Symex.Path_condition.to_list
   in
   Alcotest.(check int) "two independent slices" 2 (List.length slices)
 
 let test_slice_connected () =
   let slices =
     Symex.Path_condition.empty
-    |> Symex.Path_condition.add (eq (i32_sym "x0") (i32_sym "y0"))
-    |> Symex.Path_condition.add (eq (i32_sym "y0") (i32_const 3l))
-    |> Symex.Path_condition.slice
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "x0") (i32_sym "y0"))
+    |> Symex.Path_condition.add_checked_sat_condition
+         (eq (i32_sym "y0") (i32_sym "z0"))
+    |> Symex.Path_condition.to_list
   in
   Alcotest.(check int) "one connected slice" 1 (List.length slices)
 