diff --git a/Cslib.lean b/Cslib.lean index a9d5ffc3e..2fc4a8d06 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -1,7 +1,15 @@ module -- shake: keep-all -public import Cslib.Algorithms.Lean.MergeSort.MergeSort -public import Cslib.Algorithms.Lean.TimeM +public import Cslib.AlgorithmsTheory.Algorithms.ListInsertionSort +public import Cslib.AlgorithmsTheory.Algorithms.ListLinearSearch +public import Cslib.AlgorithmsTheory.Algorithms.ListOrderedInsert +public import Cslib.AlgorithmsTheory.Algorithms.MergeSort +public import Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +public import Cslib.AlgorithmsTheory.Lean.TimeM +public import Cslib.AlgorithmsTheory.LowerBounds.ComparisonSort +public import Cslib.AlgorithmsTheory.Models.ListComparisonSearch +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Cslib.AlgorithmsTheory.QueryModel public import Cslib.Computability.Automata.Acceptors.Acceptor public import Cslib.Computability.Automata.Acceptors.OmegaAcceptor public import Cslib.Computability.Automata.DA.Basic diff --git a/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean b/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean new file mode 100644 index 000000000..5bfb62d73 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean @@ -0,0 +1,93 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ +module + +public import Cslib.AlgorithmsTheory.Algorithms.ListOrderedInsert +public import Mathlib.Tactic.NormNum + +@[expose] public section + +/-! +# Insertion sort in a list + +In this file we state and prove the correctness and complexity of insertion sort in lists under +the `SortOpsInsertHead` model. 
This insertionSort evaluates identically to the upstream version of +`List.insertionSort`. +-- + +## Main Definitions + +- `insertionSort` : Insertion sort algorithm in the `SortOpsInsertHead` query model + +## Main results + +- `insertionSort_eval`: `insertionSort` evaluates identically to `List.insertionSort`. +- `insertionSort_permutation` : `insertionSort` outputs a permutation of the input list. +- `insertionSort_sorted` : `insertionSort` outputs a sorted list. +- `insertionSort_complexity` : `insertionSort` takes at most n * (n + 1) comparisons and + (n + 1) * (n + 2) list head-insertions. +-/ + +namespace Cslib + +namespace Algorithms + +open Prog + +/-- The insertion sort algorithm on lists in the `SortOpsInsertHead` query model. -/ +def insertionSort (l : List α) : Prog (SortOpsInsertHead α) (List α) := + match l with + | [] => return [] + | x :: xs => do + let rest ← insertionSort xs + insertOrd x rest + +@[simp] +theorem insertionSort_eval (l : List α) (le : α → α → Bool) : + (insertionSort l).eval (sortModel le) = l.insertionSort (fun x y => le x y = true) := by + induction l with simp_all [insertionSort] + +theorem insertionSort_permutation (l : List α) (le : α → α → Bool) : + ((insertionSort l).eval (sortModel le)).Perm l := by + simp [insertionSort_eval, List.perm_insertionSort] + +theorem insertionSort_sorted + (l : List α) (le : α → α → Bool) + [Std.Total (fun x y => le x y = true)] [IsTrans α (fun x y => le x y = true)] : + ((insertionSort l).eval (sortModel le)).Pairwise (fun x y => le x y = true) := by + simpa using List.pairwise_insertionSort _ _ + +lemma insertionSort_length (l : List α) (le : α → α → Bool) : + ((insertionSort l).eval (sortModel le)).length = l.length := by + simp + +lemma insertionSort_time_compares (head : α) (tail : List α) (le : α → α → Bool) : + ((insertionSort (head :: tail)).time (sortModel le)).compares = + ((insertionSort tail).time (sortModel le)).compares + + ((insertOrd head (tail.insertionSort (fun x y => le x y = true))).time + 
(sortModel le)).compares := by + simp [insertionSort] + +lemma insertionSort_time_inserts (head : α) (tail : List α) (le : α → α → Bool) : + ((insertionSort (head :: tail)).time (sortModel le)).inserts = + ((insertionSort tail).time (sortModel le)).inserts + + ((insertOrd head (tail.insertionSort (fun x y => le x y = true))).time + (sortModel le)).inserts := by + simp [insertionSort] + +theorem insertionSort_complexity (l : List α) (le : α → α → Bool) : + ((insertionSort l).time (sortModel le)) + ≤ ⟨l.length * (l.length + 1), (l.length + 1) * (l.length + 2)⟩ := by + induction l with + | nil => + simp [insertionSort] + | cons head tail ih => + grind [insertOrd_complexity_upper_bound, List.length_insertionSort, SortOpsCost.le_def, + insertionSort_time_compares, insertionSort_time_inserts] + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean b/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean new file mode 100644 index 000000000..9c685886a --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean @@ -0,0 +1,88 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Models.ListComparisonSearch +public import Batteries.Data.List +public import Mathlib.Algebra.Order.Group.Nat +public import Mathlib.Tactic.Set + +@[expose] public section + +/-! +# Linear search in a list + +In this file we state and prove the correctness and complexity of linear search in lists under +the `ListSearch` model. +-- + +## Main Definitions + +- `listLinearSearch` : Linear search algorithm in the `ListSearch` query model + +## Main results + +- `listLinearSearch_eval`: `listLinearSearch` evaluates identically to `List.contains`. 
+- `listLinearSearchM_time_complexity_upper_bound` : `listLinearSearch` takes at most `n` + comparison operations. +- `listLinearSearchM_time_complexity_lower_bound` : some lists force `listLinearSearch` to make + `n` comparisons. +-/ +namespace Cslib + +namespace Algorithms + +open Prog + +open ListSearch in +/-- Linear Search in Lists on top of the `ListSearch` query model. -/ +def listLinearSearch (l : List α) (x : α) : Prog (ListSearch α) Bool := do + match l with + | [] => return false + | l :: ls => + let cmp : Bool ← compare (l :: ls) x + if cmp then + return true + else + listLinearSearch ls x + +@[simp, grind =] +lemma listLinearSearch_eval [BEq α] (l : List α) (x : α) : + (listLinearSearch l x).eval ListSearch.natCost = l.contains x := by + fun_induction l.elem x with simp_all [listLinearSearch] + +lemma listLinearSearchM_correct_true [BEq α] [LawfulBEq α] (l : List α) + {x : α} (x_mem_l : x ∈ l) : (listLinearSearch l x).eval ListSearch.natCost = true := by + simp [x_mem_l] + +lemma listLinearSearchM_correct_false [BEq α] [LawfulBEq α] (l : List α) + {x : α} (x_mem_l : x ∉ l) : (listLinearSearch l x).eval ListSearch.natCost = false := by + simp [x_mem_l] + +lemma listLinearSearchM_time_complexity_upper_bound [BEq α] (l : List α) (x : α) : + (listLinearSearch l x).time ListSearch.natCost ≤ l.length := by + fun_induction l.elem x with + | case1 => simp [listLinearSearch] + | case2 => simp_all [listLinearSearch] + | case3 => + simp [listLinearSearch] + lia + +lemma listLinearSearchM_time_complexity_lower_bound [DecidableEq α] [Nontrivial α] (n : ℕ) : + ∃ (l : List α) (x : α), l.length = n + ∧ (listLinearSearch l x).time ListSearch.natCost = l.length := by + obtain ⟨x, y, hneq⟩ := exists_pair_ne α + use List.replicate n y, x + split_ands + · simp + · induction n <;> simp [listLinearSearch, List.replicate] + grind [ListSearch.natCost_cost, ListSearch.natCost_evalQuery] + +end Algorithms + +end Cslib diff --git 
a/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean b/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean new file mode 100644 index 000000000..d4b4ea9fd --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean @@ -0,0 +1,102 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Mathlib.Algebra.Order.Group.Nat +public import Mathlib.Data.Int.ConditionallyCompleteOrder +public import Mathlib.Data.List.Sort +public import Mathlib.Order.ConditionallyCompleteLattice.Basic + +@[expose] public section + +/-! +# Ordered insertion in a list + +In this file we state and prove the correctness and complexity of ordered insertions in lists under +the `SortOpsInsertHead` model. This ordered insert is later used in `insertionSort`, mirroring +the structure in upstream libraries for the pure Lean code versions of these declarations. + +-- + +## Main Definitions + +- `insertOrd` : ordered insert algorithm in the `SortOpsInsertHead` query model + +## Main results + +- `insertOrd_eval`: `insertOrd` evaluates identically to `List.orderedInsert`. +- `insertOrd_complexity_upper_bound` : Shows that `insertOrd` takes at most `n` comparisons, + and `n + 1` list head-insertion operations. +- `insertOrd_sorted` : Applying `insertOrd` to a sorted list yields a sorted list. +-/ + +namespace Cslib +namespace Algorithms + +open Prog + +open SortOpsInsertHead + +/-- +Performs ordered insertion of `x` into a list `l` in the `SortOpsInsertHead` query model. +If `l` is sorted, then `x` is inserted into `l` such that the resultant list is also sorted. 
+-/ +def insertOrd (x : α) (l : List α) : Prog (SortOpsInsertHead α) (List α) := do + match l with + | [] => insertHead x l + | a :: as => + if (← cmpLE x a : Bool) then + insertHead x (a :: as) + else + let res ← insertOrd x as + insertHead a res + +@[simp] +lemma insertOrd_eval (x : α) (l : List α) (le : α → α → Bool) : + (insertOrd x l).eval (sortModel le) = l.orderedInsert (fun x y => le x y = true) x := by + induction l with + | nil => + simp [insertOrd, sortModel] + | cons head tail ih => + by_cases h_head : le x head + · simp [insertOrd, h_head] + · simp [insertOrd, h_head, ih] + +-- TODO : to upstream +@[simp] +lemma _root_.List.length_orderedInsert (x : α) (l : List α) [DecidableRel r] : + (l.orderedInsert r x).length = l.length + 1 := by + induction l <;> grind + +theorem insertOrd_complexity_upper_bound + (l : List α) (x : α) (le : α → α → Bool) : + (insertOrd x l).time (sortModel le) ≤ ⟨l.length, l.length + 1⟩ := by + induction l with + | nil => + simp [insertOrd, sortModel] + | cons head tail ih => + obtain ⟨ih_compares, ih_inserts⟩ := ih + rw [insertOrd] + by_cases h_head : le x head + · simp [h_head] + · simp [h_head] + grind + +lemma insertOrd_sorted + (l : List α) (x : α) (le : α → α → Bool) + [Std.Total (fun x y => le x y)] + [IsTrans _ (fun x y => le x y)] : + l.Pairwise (fun x y => le x y) + → ((insertOrd x l).eval (sortModel le)).Pairwise (fun x y => le x y = true) := by + rw [insertOrd_eval] + exact List.Pairwise.orderedInsert _ _ + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean b/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean new file mode 100644 index 000000000..7d76807a0 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean @@ -0,0 +1,183 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +import all Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +import all Init.Data.List.Sort.Basic +@[expose] public section + +/-! +# Merge sort in a list + +In this file we state and prove the correctness and complexity of merge sort in lists under +the `SortOps` model. +-- + +## Main Definitions +- `merge` : Merge algorithm for merging two sorted lists in the `SortOps` query model +- `mergeSort` : Merge sort algorithm in the `SortOps` query model + +## Main results + +- `mergeSort_eval`: `mergeSort` evaluates identically to the private `mergeSortNaive`. +- `mergeSort_sorted` : `mergeSort` outputs a sorted list. +- `mergeSort_perm` : The output of `mergeSort` is a permutation of the input list. +- `mergeSort_complexity` : `mergeSort` takes at most `n * ⌈log₂ n⌉` comparisons. +-/ +namespace Cslib.Algorithms + +open SortOps + +/-- Merge two sorted lists using comparisons in the query monad. 
-/ +@[simp] +def merge (x y : List α) : Prog (SortOps α) (List α) := do + match x,y with + | [], ys => return ys + | xs, [] => return xs + | x :: xs', y :: ys' => do + let cmp : Bool ← cmpLE x y + if cmp then + let rest ← merge xs' (y :: ys') + return (x :: rest) + else + let rest ← merge (x :: xs') ys' + return (y :: rest) + +lemma merge_timeComplexity (x y : List α) (le : α → α → Bool) : + (merge x y).time (sortModelNat le) ≤ x.length + y.length := by + fun_induction List.merge x y (le · ·) with + | case1 => simp + | case2 => simp + | case3 x xs y ys hxy ihx => + suffices 1 + (merge xs (y :: ys)).time (sortModelNat le) ≤ xs.length + 1 + (ys.length + 1) by + simpa [hxy] + grind + | case4 x xs y ys hxy ihy => + suffices 1 + (merge (x :: xs) ys).time (sortModelNat le) ≤ xs.length + 1 + (ys.length + 1) by + simpa [hxy] + grind + +@[simp] +lemma merge_eval (x y : List α) (le : α → α → Bool) : + (merge x y).eval (sortModelNat le) = List.merge x y (le · ·) := by + fun_induction List.merge with simp_all [merge] + +lemma merge_length (x y : List α) (le : α → α → Bool) : + ((merge x y).eval (sortModelNat le)).length = x.length + y.length := by + rw [merge_eval] + apply List.length_merge + +/-- +The `mergeSort` algorithm in the `SortOps` query model. It sorts the input list +according to the mergeSort algorithm. 
+-/ +def mergeSort (xs : List α) : Prog (SortOps α) (List α) := do + if xs.length < 2 then return xs + else + let half := xs.length / 2 + let left := xs.take half + let right := xs.drop half + let sortedLeft ← mergeSort left + let sortedRight ← mergeSort right + merge sortedLeft sortedRight + +/-- +Plain-Lean merge sort (no query monad); `mergeSort_eval` shows `mergeSort` evaluates to it. +-/ +private def mergeSortNaive (xs : List α) (le : α → α → Bool) : List α := + if xs.length < 2 then xs + else + let sortedLeft := mergeSortNaive (xs.take (xs.length/2)) le + let sortedRight := mergeSortNaive (xs.drop (xs.length/2)) le + List.merge sortedLeft sortedRight (le · ·) + +private proof_wanted mergeSortNaive_eq_mergeSort + [LinearOrder α] (xs : List α) (le : α → α → Bool) : + mergeSortNaive xs le = xs.mergeSort -- NOTE(review): RHS does not use `le`; confirm intent + +private lemma mergeSortNaive_Perm (xs : List α) (le : α → α → Bool) : + (mergeSortNaive xs le).Perm xs := by + fun_induction mergeSortNaive with + | case1 => simp + | case2 x _ _ _ ih2 ih1 => grw [←List.take_append_drop _ x, List.merge_perm_append, ← ih1, ← ih2] + +@[simp] +private lemma mergeSort_eval (xs : List α) (le : α → α → Bool) : + (mergeSort xs).eval (sortModelNat le) = mergeSortNaive xs le := by + fun_induction mergeSort with + | case1 xs h => + simp [h, mergeSortNaive, Prog.eval] + | case2 xs h n left right ihl ihr => + rw [mergeSortNaive, if_neg h] + simp [ihl, ihr, merge_eval] + rfl + +private lemma mergeSortNaive_length (xs : List α) (le : α → α → Bool) : + (mergeSortNaive xs le).length = xs.length := by + fun_induction mergeSortNaive with + | case1 xs h => + simp + | case2 xs h left right ihl ihr => + rw [List.length_merge] + convert congr($ihl + $ihr) + rw [← List.length_append] + simp + +lemma mergeSort_length (xs : List α) (le : α → α → Bool) : + ((mergeSort xs).eval (sortModelNat le)).length = xs.length := by + rw [mergeSort_eval] + apply mergeSortNaive_length + +lemma merge_sorted_sorted + (xs ys : List α) (le : α → α → Bool) [Std.Total (fun x 
y => le x y)] + [IsTrans _ (fun x y => le x y)] + (hxs_mono : xs.Pairwise (fun x y => le x y)) + (hys_mono : ys.Pairwise (fun x y => le x y)) : + ((merge xs ys).eval (sortModelNat le)).Pairwise (fun x y => le x y) := by + rw [merge_eval] + simpa using hxs_mono.merge hys_mono + +private lemma mergeSortNaive_sorted + (xs : List α) (le : α → α → Bool) [Std.Total ((fun x y => le x y = true))] + [IsTrans _ ((fun x y => le x y = true))] : + (mergeSortNaive xs le).Pairwise ((fun x y => le x y = true)) := by + fun_induction mergeSortNaive with + | case1 xs h => + match xs with | [] | [x] => simp + | case2 xs h left right ihl ihr => + simpa using ihl.merge ihr + +theorem mergeSort_sorted + (xs : List α) (le : α → α → Bool) [Std.Total (fun x y => le x y = true)] + [IsTrans _ (fun x y => le x y = true)] : + ((mergeSort xs).eval (sortModelNat le)).Pairwise ((fun x y => le x y = true)) := by + rw [mergeSort_eval] + apply mergeSortNaive_sorted + +theorem mergeSort_perm (xs : List α) (le : α → α → Bool) : + ((mergeSort xs).eval (sortModelNat le)).Perm xs := by + rw [mergeSort_eval] + apply mergeSortNaive_Perm + +section TimeComplexity + +open Cslib.Algorithms.Lean.TimeM + +-- TODO: reuse the work in `mergeSort_time_le`? 
+theorem mergeSort_complexity (xs : List α) (le : α → α → Bool) : + (mergeSort xs).time (sortModelNat le) ≤ T (xs.length) := by + fun_induction mergeSort with + | case1 => simp [T] + | case2 x => + simp only [FreeM.bind_eq_bind, Prog.time_bind] + grind [some_algebra (x.length - 2), mergeSort_eval, merge_timeComplexity, mergeSortNaive_length] + +end TimeComplexity + +end Cslib.Algorithms diff --git a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean b/Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean similarity index 97% rename from Cslib/Algorithms/Lean/MergeSort/MergeSort.lean rename to Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean index 8ba55d461..081dbf1b7 100644 --- a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean +++ b/Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean @@ -6,7 +6,7 @@ Authors: Sorrachai Yingchareonthawornhcai module -public import Cslib.Algorithms.Lean.TimeM +public import Cslib.AlgorithmsTheory.Lean.TimeM public import Mathlib.Data.Nat.Cast.Order.Ring public import Mathlib.Data.Nat.Lattice public import Mathlib.Data.Nat.Log @@ -158,6 +158,10 @@ private lemma some_algebra (n : ℕ) : /-- Upper bound function for merge sort time complexity: `T(n) = n * ⌈log₂ n⌉` -/ abbrev T (n : ℕ) : ℕ := n * clog 2 n +lemma T_monotone : Monotone T := by + intro i j h_ij + exact Nat.mul_le_mul h_ij (Nat.clog_monotone 2 h_ij) + /-- Solve the recurrence -/ theorem timeMergeSortRec_le (n : ℕ) : timeMergeSortRec n ≤ T n := by fun_induction timeMergeSortRec with diff --git a/Cslib/Algorithms/Lean/TimeM.lean b/Cslib/AlgorithmsTheory/Lean/TimeM.lean similarity index 100% rename from Cslib/Algorithms/Lean/TimeM.lean rename to Cslib/AlgorithmsTheory/Lean/TimeM.lean diff --git a/Cslib/AlgorithmsTheory/LowerBounds/ComparisonSort.lean b/Cslib/AlgorithmsTheory/LowerBounds/ComparisonSort.lean new file mode 100644 index 000000000..03d10004b --- /dev/null +++ b/Cslib/AlgorithmsTheory/LowerBounds/ComparisonSort.lean @@ -0,0 +1,433 @@ +/- +Copyright (c) 2025 Shreyas 
Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Mathlib.Algebra.Order.Group.Nat +public import Mathlib.Algebra.Ring.Nat +public import Mathlib.Data.Fintype.BigOperators +public import Mathlib.Data.Fintype.Perm +public import Mathlib.Data.Nat.Lattice +public import Mathlib.Data.Nat.Log +import all Init.Data.List.Sort.Basic + +public import Mathlib -- NOTE(review): blanket import; check if it can be dropped + +@[expose] public section + +namespace Cslib + +namespace Algorithms + +open Prog + +/-- +Finite pigeonhole/cardinality step: an injection `β → (Fin t → Bool)` gives `card β ≤ 2 ^ t`. +-/ +lemma hDecisionTreeFintype + (β : Type*) [Fintype β] (t : ℕ) + (traceCode : β → (Fin t → Bool)) + (hTraceInj : Function.Injective traceCode) : + Fintype.card β ≤ 2 ^ t := by + simpa [Fintype.card_fun, Fintype.card_bool] using + (Fintype.card_le_of_injective traceCode hTraceInj) + +/-- +Arithmetic lower bound used to derive an `Ω(n log n)` comparison lower bound +from `Nat.log 2 (n!)`. 
+-/ +lemma hFactorialLog (n : ℕ) : + (n / 2) * Nat.log 2 (n / 2) ≤ Nat.log 2 (Nat.factorial n) := by + let k := n / 2 + change k * Nat.log 2 k ≤ Nat.log 2 (Nat.factorial n) + by_cases hk : k = 0 + · simp [hk] + · have hk_pos : 0 < k := Nat.pos_of_ne_zero hk + have hk_le_n : k ≤ n := by + simpa [k] using Nat.div_le_self n 2 + have h2k_le_n : k + k ≤ n := by + simpa [k, two_mul, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm] using Nat.mul_div_le n 2 + have hk_le_sub : k ≤ n - k := (Nat.le_sub_iff_add_le hk_le_n).2 h2k_le_n + have hPowLe : k ^ k ≤ k ^ (n - k) := + Nat.pow_le_pow_right hk_pos hk_le_sub + have hFactorialPow : Nat.factorial k * k ^ (n - k) ≤ Nat.factorial n := + Nat.factorial_mul_pow_sub_le_factorial hk_le_n + have hkPow_le_factorial : k ^ k ≤ Nat.factorial n := by + calc + k ^ k ≤ k ^ (n - k) := hPowLe + _ ≤ Nat.factorial k * k ^ (n - k) := Nat.le_mul_of_pos_left _ (Nat.factorial_pos k) + _ ≤ Nat.factorial n := hFactorialPow + have hLogPow : k * Nat.log 2 k ≤ Nat.log 2 (k ^ k) := by + have hPow : 2 ^ (k * Nat.log 2 k) ≤ k ^ k := by + calc + 2 ^ (k * Nat.log 2 k) = (2 ^ Nat.log 2 k) ^ k := by + rw [Nat.mul_comm, Nat.pow_mul] + _ ≤ k ^ k := Nat.pow_le_pow_left (Nat.pow_log_le_self 2 hk) k + exact Nat.le_log_of_pow_le (by decide : 1 < 2) hPow + have hLogMono : Nat.log 2 (k ^ k) ≤ Nat.log 2 (Nat.factorial n) := + Nat.log_mono_right hkPow_le_factorial + exact le_trans hLogPow hLogMono + +/-- The order on `Fin n` induced by a hidden permutation `σ`. -/ +def permLE {n : ℕ} (σ : Equiv.Perm (Fin n)) : Fin n → Fin n → Bool := + fun x y => decide (σ x ≤ σ y) + +/-- Canonical sorted output for the hidden order induced by `σ`. 
-/ +def permOutput {n : ℕ} (σ : Equiv.Perm (Fin n)) : List (Fin n) := + List.ofFn σ.symm + +lemma permOutput_pairwise {n : ℕ} (σ : Equiv.Perm (Fin n)) : + (permOutput σ).Pairwise (fun x y => permLE σ x y = true) := by + rw [permOutput, List.pairwise_ofFn] + intro i j hij + simpa [permLE, decide_eq_true_eq] using (le_of_lt hij) + +lemma permOutput_injective {n : ℕ} : + Function.Injective (permOutput (n := n)) := by + intro σ τ h + have hsymm : (fun i => σ.symm i) = fun i => τ.symm i := List.ofFn_injective h + ext x + have hAt : σ.symm (τ x) = τ.symm (τ x) := by + simpa using congrArg (fun f => f (τ x)) hsymm + have hσ := congrArg σ hAt + simpa using (congrArg Fin.val hσ).symm + +/-- +Boolean transcript produced by running a comparison program under comparator `le`. +-/ +def traceSort : Prog (SortOps α) β → (α → α → Bool) → List Bool + | .pure _, _ => [] + | .liftBind q cont, le => + match q with + | .cmpLE x y => + let b := le x y + b :: traceSort (cont b) le + +@[simp] lemma traceSort_pure (x : β) (le : α → α → Bool) : + traceSort (.pure x : Prog (SortOps α) β) le = [] := rfl + +@[simp] lemma traceSort_liftBind (x y : α) (cont : Bool → Prog (SortOps α) β) (le : α → α → Bool) : + traceSort (.liftBind (SortOps.cmpLE x y) cont) le = + (le x y) :: traceSort (cont (le x y)) le := by + simp [traceSort] + +lemma traceSort_length_eq_time (P : Prog (SortOps α) β) (le : α → α → Bool) : + (traceSort P le).length = P.time (sortModelNat le) := by + induction P with + | pure a => + simp [traceSort] + | liftBind op cont ih => + cases op with + | cmpLE x y => + simp [traceSort, ih, Nat.add_comm] + +/-- +If two runs of a program have the same comparison transcript, then they have the same output. 
+-/ +lemma eval_eq_of_traceSort_eq + (P : Prog (SortOps α) β) {le₁ le₂ : α → α → Bool} + (h : traceSort P le₁ = traceSort P le₂) : + P.eval (sortModelNat le₁) = P.eval (sortModelNat le₂) := by + induction P generalizing le₁ le₂ with + | pure a => + simp + | liftBind op cont ih => + cases op with + | cmpLE x y => + have hcons : + (le₁ x y) :: traceSort (cont (le₁ x y)) le₁ = + (le₂ x y) :: traceSort (cont (le₂ x y)) le₂ := by + simpa [traceSort] using h + injection hcons with hhead htail + have htail' : + traceSort (cont (le₁ x y)) le₁ = + traceSort (cont (le₁ x y)) le₂ := by + simpa [hhead] using htail + simpa [Prog.eval_liftBind, hhead] using ih (le₁ x y) htail' + +/-- +For a fixed program, one transcript cannot be a strict prefix of another. +-/ +lemma traceSort_prefix_eq + (P : Prog (SortOps α) β) {le₁ le₂ : α → α → Bool} + (h : traceSort P le₁ <+: traceSort P le₂) : + traceSort P le₁ = traceSort P le₂ := by + induction P generalizing le₁ le₂ with + | pure a => + simp [traceSort] + | liftBind op cont ih => + cases op with + | cmpLE x y => + have hcons : + (le₁ x y) :: traceSort (cont (le₁ x y)) le₁ <+: + (le₂ x y) :: traceSort (cont (le₂ x y)) le₂ := by + simpa [traceSort] using h + rcases List.cons_prefix_cons.mp hcons with ⟨hhead, htail⟩ + have htail' : + traceSort (cont (le₁ x y)) le₁ <+: + traceSort (cont (le₁ x y)) le₂ := by + simpa [hhead] using htail + have hEqTail := ih (le₁ x y) htail' + have hEqTail' : + traceSort (cont (le₂ x y)) le₁ = + traceSort (cont (le₂ x y)) le₂ := by + simpa [hhead] using hEqTail + simp [traceSort, hhead, hEqTail'] + +/-- Pad a transcript with `false` bits up to a fixed length `t`. -/ +def padTrace (t : ℕ) (tr : List Bool) : Fin t → Bool := + fun i => (tr[i.1]?).getD false + +lemma isPrefix_of_padTrace_eq + {t : ℕ} {s₁ s₂ : List Bool} + (hs₁ : s₁.length ≤ t) (hLen : s₁.length ≤ s₂.length) + (hPad : padTrace t s₁ = padTrace t s₂) : + s₁ <+: s₂ := by + rw [List.prefix_iff_eq_take] + apply List.ext_getElem?' 
+ intro i hi + have hTakeLen : (s₂.take s₁.length).length = s₁.length := by + simp [List.length_take, Nat.min_eq_left hLen] + have hi₁ : i < s₁.length := by + simpa [hTakeLen] using hi + have hi₂ : i < s₂.length := lt_of_lt_of_le hi₁ hLen + have hit : i < t := lt_of_lt_of_le hi₁ hs₁ + have hAt := congrArg (fun f => f ⟨i, hit⟩) hPad + calc + s₁[i]? = (s₁[i]?).getD false := by simp [hi₁] + _ = (s₂[i]?).getD false := by simpa [padTrace] using hAt + _ = s₂[i]? := by simp [hi₂] + _ = (s₂.take s₁.length)[i]? := by + simpa using (List.getElem?_take_of_lt (l := s₂) (i := i) (j := s₁.length) hi₁).symm + +lemma traceSort_eq_of_padTrace_eq + (P : Prog (SortOps α) β) {le₁ le₂ : α → α → Bool} {t : ℕ} + (hLen₁ : (traceSort P le₁).length ≤ t) + (hLen₂ : (traceSort P le₂).length ≤ t) + (hPad : padTrace t (traceSort P le₁) = padTrace t (traceSort P le₂)) : + traceSort P le₁ = traceSort P le₂ := by + by_cases hcmp : (traceSort P le₁).length ≤ (traceSort P le₂).length + · exact traceSort_prefix_eq P (isPrefix_of_padTrace_eq hLen₁ hcmp hPad) + · have hcmp' : (traceSort P le₂).length ≤ (traceSort P le₁).length := Nat.le_of_not_ge hcmp + have hEq21 : traceSort P le₂ = traceSort P le₁ := by + exact traceSort_prefix_eq P (isPrefix_of_padTrace_eq hLen₂ hcmp' hPad.symm) + exact hEq21.symm + +/-- Worst-case number of comparisons over all hidden permutations of `Fin n`. -/ +def worstTime {n : ℕ} (P : Prog (SortOps (Fin n)) (List (Fin n))) : ℕ := + (Finset.univ : Finset (Equiv.Perm (Fin n))).sup + (fun σ => P.time (sortModelNat (permLE σ))) + +/-- Fixed-length transcript code at depth `worstTime`. 
-/ +def traceCode {n : ℕ} (P : Prog (SortOps (Fin n)) (List (Fin n))) : + Equiv.Perm (Fin n) → (Fin (worstTime P) → Bool) := + fun σ => padTrace (worstTime P) (traceSort P (permLE σ)) + +lemma traceCode_injective + {n : ℕ} (P : Prog (SortOps (Fin n)) (List (Fin n))) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + P.eval (sortModelNat (permLE σ)) = permOutput σ) : + Function.Injective (traceCode P) := by + intro σ τ hCode + have hTimeσ : + P.time (sortModelNat (permLE σ)) ≤ + (Finset.univ : Finset (Equiv.Perm (Fin n))).sup + (fun ρ => P.time (sortModelNat (permLE ρ))) := by + exact Finset.le_sup + (s := (Finset.univ : Finset (Equiv.Perm (Fin n)))) + (f := fun ρ => P.time (sortModelNat (permLE ρ))) + (Finset.mem_univ σ) + have hTimeτ : + P.time (sortModelNat (permLE τ)) ≤ + (Finset.univ : Finset (Equiv.Perm (Fin n))).sup + (fun ρ => P.time (sortModelNat (permLE ρ))) := by + exact Finset.le_sup + (s := (Finset.univ : Finset (Equiv.Perm (Fin n)))) + (f := fun ρ => P.time (sortModelNat (permLE ρ))) + (Finset.mem_univ τ) + have hLenσ : (traceSort P (permLE σ)).length ≤ worstTime P := by + simpa [worstTime, traceSort_length_eq_time] using hTimeσ + have hLenτ : (traceSort P (permLE τ)).length ≤ worstTime P := by + simpa [worstTime, traceSort_length_eq_time] using hTimeτ + have hTrace : + traceSort P (permLE σ) = traceSort P (permLE τ) := by + exact traceSort_eq_of_padTrace_eq P hLenσ hLenτ hCode + have hEval : + P.eval (sortModelNat (permLE σ)) = P.eval (sortModelNat (permLE τ)) := + eval_eq_of_traceSort_eq P hTrace + have hOut : permOutput σ = permOutput τ := by + simpa [hCorrect σ, hCorrect τ] using hEval + exact permOutput_injective hOut + +/-- +Decision-tree lower bound in the strong hidden-permutation model: +`n!` distinct hidden orders require at least `log₂(n!)` worst-case comparisons. 
+-/ +lemma hDecisionTreeLower + {n : ℕ} (P : Prog (SortOps (Fin n)) (List (Fin n))) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + P.eval (sortModelNat (permLE σ)) = permOutput σ) : + Nat.factorial n ≤ 2 ^ worstTime P := by + have hCard : + Fintype.card (Equiv.Perm (Fin n)) ≤ 2 ^ worstTime P := + hDecisionTreeFintype (β := Equiv.Perm (Fin n)) (worstTime P) (traceCode P) + (traceCode_injective P hCorrect) + simpa [Fintype.card_perm] using hCard + +lemma eval_pairwise_of_correct + {n : ℕ} (P : Prog (SortOps (Fin n)) (List (Fin n))) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + P.eval (sortModelNat (permLE σ)) = permOutput σ) + (σ : Equiv.Perm (Fin n)) : + (P.eval (sortModelNat (permLE σ))).Pairwise (fun x y => permLE σ x y = true) := by + simpa [hCorrect σ] using permOutput_pairwise σ + +/-- +`Ω(n log n)` lower bound on `Fin n`: if `P` outputs `permOutput σ` under every hidden order +`permLE σ`, then its worst-case comparison count is at least `(n / 2) * Nat.log 2 (n / 2)`. +-/ +theorem cmpSort_lower_bound + (n : ℕ) (P : Prog (SortOps (Fin n)) (List (Fin n))) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + P.eval (sortModelNat (permLE σ)) = permOutput σ) : + worstTime P ≥ (n / 2) * Nat.log 2 (n / 2) := by + have hDecision : Nat.factorial n ≤ 2 ^ worstTime P := + hDecisionTreeLower P hCorrect + have hLog : + Nat.log 2 (Nat.factorial n) ≤ Nat.log 2 (2 ^ worstTime P) := + Nat.log_mono_right hDecision + have hTime : Nat.log 2 (Nat.factorial n) ≤ worstTime P := by + simpa [Nat.log_pow (b := 2) (x := worstTime P) (by decide : 1 < 2)] using hLog + exact le_trans (hFactorialLog n) hTime + +section HiddenOrderEquiv + +/-- Hidden order induced by a permutation after encoding elements with `e : β ≃ Fin n`. -/ +def permLEEquiv {β : Type} {n : ℕ} + (e : β ≃ Fin n) (σ : Equiv.Perm (Fin n)) : β → β → Bool := + fun x y => decide (σ (e x) ≤ σ (e y)) + +/-- Canonical sorted output induced by `σ`, transported through `e`. 
-/ +def permOutputEquiv {β : Type} {n : ℕ} + (e : β ≃ Fin n) (σ : Equiv.Perm (Fin n)) : List β := + List.ofFn (fun i => e.symm (σ.symm i)) + +lemma permOutputEquiv_pairwise {β : Type} {n : ℕ} + (e : β ≃ Fin n) (σ : Equiv.Perm (Fin n)) : + (permOutputEquiv e σ).Pairwise (fun x y => permLEEquiv e σ x y = true) := by + rw [permOutputEquiv, List.pairwise_ofFn] + intro i j hij + simpa [permLEEquiv, decide_eq_true_eq] using (le_of_lt hij) + +lemma permOutputEquiv_injective {β : Type} {n : ℕ} + (e : β ≃ Fin n) : + Function.Injective (permOutputEquiv e) := by + intro σ τ h + have hsymm : + (fun i => e.symm (σ.symm i)) = fun i => e.symm (τ.symm i) := + List.ofFn_injective h + ext x + have hAt : e.symm (σ.symm (τ x)) = e.symm (τ.symm (τ x)) := by + simpa using congrArg (fun f => f (τ x)) hsymm + have hAt' : σ.symm (τ x) = τ.symm (τ x) := by + simpa using congrArg e hAt + have hσ : τ x = σ x := by + simpa using congrArg σ hAt' + simpa [eq_comm] using congrArg Fin.val hσ + +/-- Worst-case comparisons over hidden permutations, transported through `e`. -/ +def worstTimeEquiv {β : Type} {n : ℕ} + (e : β ≃ Fin n) (P : Prog (SortOps β) (List β)) : ℕ := + (Finset.univ : Finset (Equiv.Perm (Fin n))).sup + (fun σ => Prog.time P (sortModelNat (α := β) (permLEEquiv e σ))) + +/-- Fixed-length transcript code at depth `worstTimeEquiv`. 
-/ +def traceCodeEquiv {β : Type} {n : ℕ} + (e : β ≃ Fin n) (P : Prog (SortOps β) (List β)) : + Equiv.Perm (Fin n) → (Fin (worstTimeEquiv e P) → Bool) := + fun σ => padTrace (worstTimeEquiv e P) (traceSort P (permLEEquiv e σ)) + +/-- +If `P` returns `permOutputEquiv e σ` under every hidden order `permLEEquiv e σ`, then +`traceCodeEquiv e P` is injective. +-/ +lemma traceCodeEquiv_injective + {β : Type} {n : ℕ} + (e : β ≃ Fin n) (P : Prog (SortOps β) (List β)) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + Prog.eval P (sortModelNat (α := β) (permLEEquiv e σ)) = permOutputEquiv e σ) : + Function.Injective (traceCodeEquiv e P) := by + intro σ τ hCode + have hLen (ρ : Equiv.Perm (Fin n)) : + (traceSort P (permLEEquiv e ρ)).length ≤ worstTimeEquiv e P := by + simpa [worstTimeEquiv, traceSort_length_eq_time] using + (Finset.le_sup + (s := (Finset.univ : Finset (Equiv.Perm (Fin n)))) + (f := fun ρ => Prog.time P (sortModelNat (α := β) (permLEEquiv e ρ))) + (Finset.mem_univ ρ)) + have hTrace : + traceSort P (permLEEquiv e σ) = traceSort P (permLEEquiv e τ) := by + exact traceSort_eq_of_padTrace_eq P (hLen σ) (hLen τ) hCode + exact permOutputEquiv_injective e <| by + simpa [hCorrect σ, hCorrect τ] using eval_eq_of_traceSort_eq P hTrace + +/-- +Counting bound: if `P` returns `permOutputEquiv e σ` under every hidden order +`permLEEquiv e σ`, then `Nat.factorial n ≤ 2 ^ worstTimeEquiv e P`. +-/ +lemma hDecisionTreeLowerEquiv + {β : Type} {n : ℕ} + (e : β ≃ Fin n) (P : Prog (SortOps β) (List β)) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + Prog.eval P (sortModelNat (α := β) (permLEEquiv e σ)) = permOutputEquiv e σ) : + Nat.factorial n ≤ 2 ^ worstTimeEquiv e P := by + have hCard : + Fintype.card (Equiv.Perm (Fin n)) ≤ 2 ^ worstTimeEquiv e P := + hDecisionTreeFintype (β := Equiv.Perm (Fin n)) (worstTimeEquiv e P) (traceCodeEquiv e P) + (traceCodeEquiv_injective e P hCorrect) + simpa [Fintype.card_perm] using hCard + +/-- `Ω(n log n)` lower bound on any type equivalent to `Fin n`.
-/ +theorem cmpSort_lower_bound_equiv + {β : Type} {n : ℕ} + (e : β ≃ Fin n) (P : Prog (SortOps β) (List β)) + (hCorrect : ∀ σ : Equiv.Perm (Fin n), + Prog.eval P (sortModelNat (α := β) (permLEEquiv e σ)) = permOutputEquiv e σ) : + worstTimeEquiv e P ≥ (n / 2) * Nat.log 2 (n / 2) := by + have hDecision : Nat.factorial n ≤ 2 ^ worstTimeEquiv e P := + hDecisionTreeLowerEquiv e P hCorrect + have hLog : + Nat.log 2 (Nat.factorial n) ≤ Nat.log 2 (2 ^ worstTimeEquiv e P) := + Nat.log_mono_right hDecision + have hTime : Nat.log 2 (Nat.factorial n) ≤ worstTimeEquiv e P := by + simpa [Nat.log_pow (b := 2) (x := worstTimeEquiv e P) (by decide : 1 < 2)] using hLog + exact le_trans (hFactorialLog n) hTime + +/-- `Ω(n log n)` lower bound stated directly for a finite carrier type `α`. -/ +theorem cmpSort_lower_bound_fintype + (α : Type) [Fintype α] + (P : Prog (SortOps α) (List α)) + (hCorrect : ∀ σ : Equiv.Perm (Fin (Fintype.card α)), + Prog.eval P (sortModelNat (α := α) (permLEEquiv (Fintype.equivFin α) σ)) = + permOutputEquiv (Fintype.equivFin α) σ) : + worstTimeEquiv (Fintype.equivFin α) P ≥ + (Fintype.card α / 2) * Nat.log 2 (Fintype.card α / 2) := by + simpa using cmpSort_lower_bound_equiv (e := Fintype.equivFin α) (P := P) hCorrect + +/-- +Lower bound specialized to a fixed nodup list `l`. +This is a corollary of `cmpSort_lower_bound_equiv` with carrier `{x // x ∈ l}`.
+-/ +theorem cmpSort_lower_bound_infinite_types + {α : Type} [DecidableEq α] + (l : List α) (hNodup : l.Nodup) + (P : Prog (SortOps {x // x ∈ l}) (List {x // x ∈ l})) + (hCorrect : ∀ σ : Equiv.Perm (Fin l.length), + Prog.eval P (sortModelNat (α := {x // x ∈ l}) + (permLEEquiv (List.Nodup.getEquiv l hNodup).symm σ)) = + permOutputEquiv (List.Nodup.getEquiv l hNodup).symm σ) : + worstTimeEquiv (List.Nodup.getEquiv l hNodup).symm P ≥ + (l.length / 2) * Nat.log 2 (l.length / 2) := by + simpa using cmpSort_lower_bound_equiv (List.Nodup.getEquiv l hNodup).symm P hCorrect + +end HiddenOrderEquiv + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean b/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean new file mode 100644 index 000000000..4badcfa2b --- /dev/null +++ b/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean @@ -0,0 +1,52 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel + +@[expose] public section + +/-! +# Query Type for Comparison Search in Lists + +In this file we define a query type `ListSearch` for comparison based searching in Lists, +whose sole query `compare` compares the head of the list with a given argument. It +further defines a model `ListSearch.natCost` for this query. + +-- +## Definitions + +- `ListSearch`: A query type for comparison based search in lists. +- `ListSearch.natCost`: A model for this query with costs in `ℕ`. + +-/ + +namespace Cslib + +namespace Algorithms + +open Prog + +/-- +A query type for searching elements in a list. It supports exactly one query +`compare l val` which returns `true` if the head of the list `l` is equal to `val` +and returns `false` otherwise.
+-/ +inductive ListSearch (α : Type*) : Type → Type _ where + | compare (a : List α) (val : α) : ListSearch α Bool + + +/-- A model of the `ListSearch` query type that assigns the cost as the number of queries. -/ +@[simps] +def ListSearch.natCost [BEq α] : Model (ListSearch α) ℕ where + evalQuery + | .compare l x => some x == l.head? + cost _ := 1 + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean b/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean new file mode 100644 index 000000000..4781fbf06 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean @@ -0,0 +1,150 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Mathlib.Algebra.Group.Nat.Defs +public import Mathlib.Algebra.Group.Prod +public import Mathlib.Data.Nat.Basic +public import Mathlib.Order.Basic +public import Mathlib.Tactic.FastInstance +@[expose] public section + +/-! +# Query Type for Comparison Sort in Lists + +In this file we define two query types: `SortOpsInsertHead`, which is suitable for insertion sort, +and `SortOps` for comparison based sorting in Lists. We define a model `sortModel` for +`SortOpsInsertHead` which uses a custom cost structure `SortOpsCost`. We define a model +`sortModelNat` for `SortOps` which defines a `ℕ` based cost structure. +-- +## Definitions + +- `SortOpsInsertHead`: A query type for comparison based sorting in lists which includes queries for + comparison and head-insertion into Lists. This is a suitable query for ordered insertion + and insertion sort. +- `SortOps`: A query type for comparison based sorting that only includes a comparison query.
+ This is more suitable for comparison based sorts for which it is only desirable to count + comparisons + +-/ +namespace Cslib + +namespace Algorithms + +open Prog + +/-- +A query type for comparison sorting on lists, with comparison and head-insertion queries. +-/ +inductive SortOpsInsertHead (α : Type) : Type → Type where + /-- `cmpLE x y` is intended to return `true` if `x ≤ y` and `false` otherwise. + The specific order relation depends on the model provided for this type. -/ + | cmpLE (x : α) (y : α) : SortOpsInsertHead α Bool + /-- `insertHead x l` is intended to return `x :: l`. -/ + | insertHead (x : α) (l : List α) : SortOpsInsertHead α (List α) + +open SortOpsInsertHead + +section SortOpsCostModel + +/-- +A cost type for counting the operations of `SortOpsInsertHead` with separate fields for +counting calls to `cmpLE` and `insertHead` +-/ +@[ext, grind] +structure SortOpsCost where + /-- `compares` counts the number of calls to `cmpLE` -/ + compares : ℕ + /-- `inserts` counts the number of calls to `insertHead` -/ + inserts : ℕ + +/-- Equivalence between SortOpsCost and a product type.
-/ +def SortOpsCost.equivProd : SortOpsCost ≃ (ℕ × ℕ) where + toFun sortOps := (sortOps.compares, sortOps.inserts) + invFun pair := ⟨pair.1, pair.2⟩ + left_inv _ := rfl + right_inv _ := rfl + +namespace SortOpsCost + +@[simps, grind] +instance : Zero SortOpsCost := ⟨0, 0⟩ + +@[simps] +instance : LE SortOpsCost where + le soc₁ soc₂ := soc₁.compares ≤ soc₂.compares ∧ soc₁.inserts ≤ soc₂.inserts + +instance : LT SortOpsCost where + lt soc₁ soc₂ := soc₁ ≤ soc₂ ∧ ¬soc₂ ≤ soc₁ + +@[grind] +instance : PartialOrder SortOpsCost := + fast_instance% SortOpsCost.equivProd.injective.partialOrder _ .rfl .rfl + +@[simps] +instance : Add SortOpsCost where + add soc₁ soc₂ := ⟨soc₁.compares + soc₂.compares, soc₁.inserts + soc₂.inserts⟩ + +@[simps] +instance : SMul ℕ SortOpsCost where + smul n soc := ⟨n • soc.compares, n • soc.inserts⟩ + +instance : AddCommMonoid SortOpsCost := + fast_instance% + SortOpsCost.equivProd.injective.addCommMonoid _ rfl (fun _ _ => rfl) (fun _ _ => rfl) + +end SortOpsCost + +/-- +A model of `SortOpsInsertHead` that uses `SortOpsCost` as the cost type for operations. + +While this accepts any decidable relation `le`, most sorting algorithms are only well-behaved in the +presence of `[Std.Total le] [IsTrans _ le]`. +-/ +@[simps, grind] +def sortModel {α : Type} (le : α → α → Bool) : + Model (SortOpsInsertHead α) SortOpsCost where + evalQuery + | .cmpLE x y => le x y + | .insertHead x l => x :: l + cost + | .cmpLE _ _ => ⟨1,0⟩ + | .insertHead _ _ => ⟨0,1⟩ + +end SortOpsCostModel + +section NatModel + +/-- +A query type for comparison sorting on lists with only the comparison operation. This +is used in mergeSort. +-/ +inductive SortOps.{u} (α : Type u) : Type → Type _ where + /-- `cmpLE x y` is intended to return `true` if `x ≤ y` and `false` otherwise. + The specific order relation depends on the model provided for this type. -/ + | cmpLE (x : α) (y : α) : SortOps α Bool + +/-- +A model of `SortOps` that uses `ℕ` as the type for the cost of operations.
In this model, +every comparison query has cost `1`. + +While this accepts any decidable relation `le`, most sorting algorithms are only well-behaved in the +presence of `[Std.Total le] [IsTrans _ le]`. +-/ +@[simps] +def sortModelNat {α : Type*} + (le : α → α → Bool) : Model (SortOps α) ℕ where + evalQuery + | .cmpLE x y => le x y + cost _ := 1 + +end NatModel + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/QueryModel.lean b/Cslib/AlgorithmsTheory/QueryModel.lean new file mode 100644 index 000000000..c91beb60d --- /dev/null +++ b/Cslib/AlgorithmsTheory/QueryModel.lean @@ -0,0 +1,149 @@ +/- +Copyright (c) 2025 Tanner Duve. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Tanner Duve, Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.Foundations.Control.Monad.Free +public import Cslib.AlgorithmsTheory.Lean.TimeM + +@[expose] public section + +/-! +# Query model + +This file defines a simple query language modeled as a free monad over a +parametric type of query operations. + +## Main definitions + +- `Model Q c`: A model type for a query type `Q : Type u → Type v` and cost type `c` +- `Prog Q α`: The type of programs of query type `Q` and return type `α`. + This is a free monad under the hood +- `Prog.eval`, `Prog.time`: concrete execution semantics of a `Prog Q α` for a given model of `Q` + +## How to set up an algorithm + +This model is a lightweight framework for specifying and verifying both the correctness +and complexity of algorithms in lean. To specify an algorithm, one must: +1. Define an inductive type of queries. This type must have at least one index parameter + which determines the output type of the query. Additionally, it helps to have a parameter `α` + on which the index type depends. This way, any instance parameters of `α` can be used easily + for the output types. The signatures of `Model.evalQuery` and `Model.cost` are fixed.
+ So you can't supply instances for the index type there. +2. Define a record of the `Model Q C` structure that specifies the evaluation and time (cost) of + each query +3. Write your algorithm as a monadic program in `Prog Q α`. With sufficient type annotations + each query `q : Q` is automatically lifted into `Prog Q α`. + +## Tags +query model, free monad, time complexity, Prog +-/ + +namespace Cslib + +namespace Algorithms + +/-- +A model type for a query type `QType` and cost type `Cost`. It consists of +two fields, which respectively define the evaluation and cost of a query. +-/ +structure Model (QType : Type u → Type v) (Cost : Type w) where + /-- Evaluates a query `q : Q ι` to return a result of type `ι`. -/ + evalQuery : QType ι → ι + /-- Counts the operational cost of a query `q : Q ι` to return a result of type `Cost`. + The cost could represent any desired complexity measure, + including but not limited to time complexity. -/ + cost : QType ι → Cost + + +open Cslib.Algorithms.Lean in +/-- lift `Model.cost` to `TimeM Cost ι` -/ +abbrev Model.timeQuery + (M : Model Q Cost) (x : Q ι) : TimeM Cost ι := + TimeM.mk (M.evalQuery x) (M.cost x) + +/-- +A program is defined as a Free Monad over a Query type `Q` which operates on a base type `α` +which can determine the input and output types of a query.
+-/ +abbrev Prog Q α := FreeM Q α + +/-- +The evaluation function of a program `P : Prog Q α` given a model `M : Model Q Cost` of `Q` +-/ +def Prog.eval + (P : Prog Q α) (M : Model Q Cost) : α := + Id.run <| P.liftM fun x => pure (M.evalQuery x) + +@[simp, grind =] +theorem Prog.eval_pure (a : α) (M : Model Q Cost) : + Prog.eval (FreeM.pure a) M = a := + rfl + +@[simp, grind =] +theorem Prog.eval_bind + (x : Prog Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.eval (FreeM.bind x f) M = Prog.eval (f (x.eval M)) M := by + simp [Prog.eval] + +@[simp, grind =] +theorem Prog.eval_liftBind + (x : Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.eval (FreeM.liftBind x f) M = Prog.eval (f <| M.evalQuery x) M := by + simp [Prog.eval] + +/-- +The cost function of a program `P : Prog Q α` given a model `M : Model Q Cost` of `Q`. +The most common use case of this function is to compute time-complexity, hence the name. + +In practice this is only well-behaved in the presence of `AddCommMonoid Cost`. +-/ +def Prog.time [AddZero Cost] + (P : Prog Q α) (M : Model Q Cost) : Cost := + (P.liftM M.timeQuery).time + +@[simp, grind =] +lemma Prog.time_pure [AddZero Cost] (a : α) (M : Model Q Cost) : + Prog.time (FreeM.pure a) M = 0 := by + simp [time] + +@[simp, grind =] +theorem Prog.time_liftBind [AddZero Cost] + (x : Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.time (FreeM.liftBind x f) M = M.cost x + Prog.time (f <| M.evalQuery x) M := by + simp [Prog.time] + +@[simp, grind =] +lemma Prog.time_bind [AddCommMonoid Cost] (M : Model Q Cost) + (op : Prog Q ι) (cont : ι → Prog Q α) : + Prog.time (op.bind cont) M = + Prog.time op M + Prog.time (cont (Prog.eval op M)) M := by + simp only [eval, time] + induction op with + | pure a => + simp + | liftBind op cont' ih => + specialize ih (M.evalQuery op) + simp_all [add_assoc] + +section Reduction + +/-- A reduction structure from query type `Q₁` to query type `Q₂`.
-/ +structure Reduction (Q₁ Q₂ : Type u → Type u) where + /-- `reduce (q : Q₁ α)` is a program `P : Prog Q₂ α` that is intended to + implement `q` in the query type `Q₂` -/ + reduce : Q₁ α → Prog Q₂ α + +/-- +`Prog.reduceProg` takes a reduction structure from a query `Q₁` to `Q₂` and extends its +`reduce` function to programs on the query type `Q₁`. It is implemented with `liftM`, which +interprets each `Q₁` query of `P` via `red.reduce`. +-/ +abbrev Prog.reduceProg (P : Prog Q₁ α) (red : Reduction Q₁ Q₂) : Prog Q₂ α := + P.liftM red.reduce + +end Reduction + +end Cslib.Algorithms diff --git a/Cslib/Foundations/Control/Monad/Free.lean b/Cslib/Foundations/Control/Monad/Free.lean index b0a828c1b..90d78b6da 100644 --- a/Cslib/Foundations/Control/Monad/Free.lean +++ b/Cslib/Foundations/Control/Monad/Free.lean @@ -96,7 +96,7 @@ variable {F : Type u → Type v} {ι : Type u} {α : Type w} {β : Type w'} {γ instance : Pure (FreeM F) where pure := .pure -@[simp] +@[simp, grind =] theorem pure_eq_pure : (pure : α → FreeM F α) = FreeM.pure := rfl /-- Bind operation for the `FreeM` monad. -/ @@ -115,7 +115,7 @@ protected theorem bind_assoc (x : FreeM F α) (f : α → FreeM F β) (g : β instance : Bind (FreeM F) where bind := .bind -@[simp] +@[simp, grind =] theorem bind_eq_bind {α β : Type w} : Bind.bind = (FreeM.bind : FreeM F α → _ → FreeM F β) := rfl /-- Map a function over a `FreeM` monad. -/ @@ -154,14 +154,21 @@ lemma map_lift (f : ι → α) (op : F ι) : map f (lift op : FreeM F ι) = liftBind op (fun z => (.pure (f z) : FreeM F α)) := rfl /-- `.pure a` followed by `bind` collapses immediately.
-/ -@[simp] +@[simp, grind =] lemma pure_bind (a : α) (f : α → FreeM F β) : (.pure a : FreeM F α).bind f = f a := rfl -@[simp] +/-- Variant of `pure_bind` stated with the `>>=` notation. -/ +@[simp, grind =] +lemma pure_bind' {α β} (a : α) (f : α → FreeM F β) : (.pure a : FreeM F α) >>= f = f a := + pure_bind a f + +@[simp, grind =] lemma bind_pure : ∀ x : FreeM F α, x.bind (.pure) = x | .pure a => rfl | liftBind op k => by simp [FreeM.bind, bind_pure] +/-- Variant of `bind_pure` stated with the `>>=` notation. -/ +@[simp, grind =] +lemma bind_pure' : ∀ x : FreeM F α, x >>= .pure = x := bind_pure + @[simp] lemma bind_pure_comp (f : α → β) : ∀ x : FreeM F α, x.bind (.pure ∘ f) = map f x | .pure a => rfl @@ -223,6 +230,9 @@ lemma liftM_bind [LawfulMonad m] rw [FreeM.bind, liftM_liftBind, liftM_liftBind, bind_assoc] simp_rw [ih] +/-- Coerces a single query `q : Q α` to the program `FreeM.lift q`. -/ +instance {Q α} : CoeOut (Q α) (FreeM Q α) where + coe := FreeM.lift + /-- A predicate stating that `interp : FreeM F α → m α` is an interpreter for the effect handler `handler : ∀ {α}, F α → m α`. diff --git a/CslibTests.lean b/CslibTests.lean index 73292aef3..c1c44021a 100644 --- a/CslibTests.lean +++ b/CslibTests.lean @@ -11,4 +11,6 @@ public import CslibTests.HasFresh public import CslibTests.ImportWithMathlib public import CslibTests.LTS public import CslibTests.LambdaCalculus +public import CslibTests.QueryModel.ProgExamples +public import CslibTests.QueryModel.QueryExamples public import CslibTests.Reduction diff --git a/CslibTests/QueryModel/ProgExamples.lean b/CslibTests/QueryModel/ProgExamples.lean new file mode 100644 index 000000000..e8b764809 --- /dev/null +++ b/CslibTests/QueryModel/ProgExamples.lean @@ -0,0 +1,133 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Mathlib.Algebra.Lie.OfAssociative + +@[expose] public section + +/-!
+# Additional examples of Progs with Query Types + +This file contains three query types and associated `Prog`s +- `Arith` with `ex1` +- `VecSortOps` with `simpleExample` +- `VecSearch` with `linearSearch` +They are meant to be additional examples to guide authors to write +query types and programs on top of them +-/ +namespace Cslib + +namespace Algorithms + +namespace Prog + +section ProgExamples + +/-- Arithmetic queries on `α`: `add`, `mul`, `neg`, and the constants `zero` and `one`. -/ +inductive Arith (α : Type u) : Type u → Type _ where + | add (x y : α) : Arith α α + | mul (x y : α) : Arith α α + | neg (x : α) : Arith α α + | zero : Arith α α + | one : Arith α α + +/-- A model of `Arith` over a ring `α` that charges cost `1` for every query. -/ +def Arith.natCost [Ring α] : Model (Arith α) ℕ where + evalQuery + | .add x y => x + y + | .mul x y => x * y + | .neg x => -x + | .zero => 0 + | .one => 1 + cost _ := 1 + +open Arith in +/-- A small example program built from `Arith` queries. -/ +def ex1 : Prog (Arith α) α := do + let mut x : α ← @zero α + let mut y ← @one α + let z ← (add x y) + let w ← @neg α (← add z y) + add w z + +/-- The array version of the sort operations. -/ +inductive VecSortOps.{u} (α : Type u) : Type u → Type _ where + | swap (a : Vector α n) (i j : Fin n) : VecSortOps α (Vector α n) + -- Note that we have to ULift the result to fit this in the same universe as the other types. + -- We can avoid this only by forcing everything to be in `Type 0`. + | cmp (a : Vector α n) (i j : Fin n) : VecSortOps α (ULift Bool) + | write (a : Vector α n) (i : Fin n) (x : α) : VecSortOps α (Vector α n) + | read (a : Vector α n) (i : Fin n) : VecSortOps α α + | push (a : Vector α n) (elem : α) : VecSortOps α (Vector α (n + 1)) + +/-- The typical means of evaluating a `VecSortOps`.
-/ +@[simp] +def VecSortOps.eval [BEq α] : VecSortOps α β → β + | .write v i x => v.set i x + | .cmp l i j => .up <| l[i] == l[j] + | .read l i => l[i] + | .swap l i j => l.swap i j + | .push a elem => a.push elem + +/-- A model of `VecSortOps` that charges `1` per query, except `push` which is charged `2`. -/ +@[simps] +def VecSortOps.worstCase [DecidableEq α] : Model (VecSortOps α) ℕ where + evalQuery := VecSortOps.eval + cost + | .write _ _ _ => 1 + | .read _ _ => 1 + | .cmp _ _ _ => 1 + | .swap _ _ _ => 1 + | .push _ _ => 2 -- amortized over array insertion and resizing by doubling + +/-- A model of `VecSortOps` that charges `1` for `cmp` and `swap` and `0` for all other queries. -/ +@[simps] +def VecSortOps.cmpSwap [DecidableEq α] : Model (VecSortOps α) ℕ where + evalQuery := VecSortOps.eval + cost + | .cmp _ _ _ => 1 + | .swap _ _ _ => 1 + | _ => 0 + +open VecSortOps in +/-- An example program chaining the `write`, `swap`, `read` and `push` queries. -/ +def simpleExample (v : Vector ℤ n) (i k : Fin n) : + Prog (VecSortOps ℤ) (Vector ℤ (n + 1)) := do + let b : Vector ℤ n ← write v i 10 + let mut c : Vector ℤ n ← swap b i k + let elem ← read c i + push c elem + +/-- A query type for searching in vectors, with a single `compare` query. -/ +inductive VecSearch (α : Type u) : Type → Type _ where + | compare (a : Vector α n) (i : ℕ) (val : α) : VecSearch α Bool + +/-- A model of `VecSearch`: `compare l i x` evaluates to `l[i]? == some x` and costs `1`. -/ +@[simps] +def VecSearch.nat [DecidableEq α] : Model (VecSearch α) ℕ where + evalQuery + | .compare l i x => l[i]? == some x + cost + | .compare _ _ _ => 1 + +open VecSearch in +/-- +Scans `v` from position `index`: returns `true` as soon as a `compare` query against `x` +succeeds, and returns `acc` once `index ≥ n`. +-/ +def linearSearchAux (v : Vector α n) + (x : α) (acc : Bool) (index : ℕ) : Prog (VecSearch α) Bool := do + if h : index ≥ n then + return acc + else + let cmp_res : Bool ← compare v index x + if cmp_res then + return true + else + linearSearchAux v x false (index + 1) + +open VecSearch in +/-- Linear search for `x` in `v`: runs `linearSearchAux` from index `0` with accumulator `false`. -/ +def linearSearch (v : Vector α n) (x : α) : Prog (VecSearch α) Bool:= + linearSearchAux v x false 0 + +end ProgExamples + +end Prog + +end Algorithms + +end Cslib diff --git a/CslibTests/QueryModel/QueryExamples.lean b/CslibTests/QueryModel/QueryExamples.lean new file mode 100644 index 000000000..b4134bed1 --- /dev/null +++ b/CslibTests/QueryModel/QueryExamples.lean @@ -0,0 +1,107 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.QueryModel +public import Mathlib.Algebra.Ring.ULift +public import Mathlib.Data.Nat.Log + +@[expose] public section + +/-! +# Additional examples of Query Types + +This file contains two query types +- `ListOpsWithFind` +- `ArrayOpsWithFind` +which respectively provide query types for List and Array operations +equipped with a searching algorithm, and different models for them. +They are meant to be additional examples to guide authors of query types +-/ +namespace Cslib + +namespace Algorithms + +section Examples + +/-- +ListOpsWithFind provides an example of list query type equipped with a `find` query. +The complexity of this query depends on the search algorithm used. This means +we can define two separate models for modelling situations where linear search +or binary search is used. +-/ +inductive ListOpsWithFind (α : Type u) : Type u → Type _ where + | get (l : List α) (i : Fin l.length) : ListOpsWithFind α α + | find (l : List α) (elem : α) : ListOpsWithFind α (ULift ℕ) + | write (l : List α) (i : Fin l.length) (x : α) : ListOpsWithFind α (List α) + +/-- The typical means of evaluating a `ListOpsWithFind`. -/ +@[simp] +def ListOpsWithFind.eval [BEq α] : ListOpsWithFind α ι → ι + | .write l i x => l.set i x + | .find l elem => l.findIdx (· == elem) + | .get l i => l[i] + +/-- +A model of `ListOpsWithFind` that assumes that `find` is implemented by a +linear search like `Θ(n)` algorithm. +-/ +@[simps] +def ListOpsWithFind.linSearchWorstCase [DecidableEq α] : Model (ListOpsWithFind α) ℕ where + evalQuery := ListOpsWithFind.eval + cost + | .write l _ _ => l.length + | .find l _ => l.length + | .get l _ => l.length + +/-- +A model of `ListOpsWithFind` that assumes that `find` is implemented by a +binary-search like `Θ(log n)` algorithm.
+-/ +@[simps] +def ListOps.binSearchWorstCase [BEq α] : Model (ListOpsWithFind α) ℕ where + -- NOTE(review): declared under `ListOps`, unlike `ListOpsWithFind.linSearchWorstCase`; + -- confirm whether `ListOpsWithFind.binSearchWorstCase` was intended. + evalQuery := ListOpsWithFind.eval + cost + | .find l _ => 1 + Nat.log 2 (l.length) + | .write l _ _ => l.length + | .get l _ => l.length + +/-- +ArrayOpsWithFind is the `Array` version of `ListOpsWithFind`. It comes with +`get` and `write` queries, and additionally a `find` query which corresponds +to a search algorithm. +-/ +inductive ArrayOpsWithFind (α : Type u) : Type u → Type _ where + | get (l : Array α) (i : Fin l.size) : ArrayOpsWithFind α α + | find (l : Array α) (x : α) : ArrayOpsWithFind α (ULift ℕ) + | write (l : Array α) (i : Fin l.size) (x : α) : ArrayOpsWithFind α (Array α) + +/-- The typical means of evaluating an `ArrayOpsWithFind`. -/ +@[simp] +def ArrayOpsWithFind.eval [BEq α] : ArrayOpsWithFind α ι → ι + | .write l i x => l.set i x + | .find l elem => l.findIdx (· == elem) + | .get l i => l[i] + +/-- +A model of `ArrayOpsWithFind` that assumes that `find` is implemented by a +binary-search like `Θ(log n)` algorithm. +-/ +@[simps] +def ArrayOpsWithFind.binSearchWorstCase [BEq α] : Model (ArrayOpsWithFind α) ℕ where + evalQuery := ArrayOpsWithFind.eval + cost + | .find l _ => 1 + Nat.log 2 (l.size) + | .write _ _ _ => 1 + | .get _ _ => 1 + +end Examples + +end Algorithms + +end Cslib