diff --git a/Cslib.lean b/Cslib.lean
index 3e5977af2..c68826c98 100644
--- a/Cslib.lean
+++ b/Cslib.lean
@@ -120,3 +120,10 @@ public import Cslib.Logics.LinearLogic.CLL.CutElimination
 public import Cslib.Logics.LinearLogic.CLL.EtaExpansion
 public import Cslib.Logics.LinearLogic.CLL.PhaseSemantics.Basic
 public import Cslib.Logics.Propositional.Defs
+public import Cslib.MachineLearning.PACLearning.Defs
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.AdversarialMeasure
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.EHKVProof
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.Helpers
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.InvolutionPairing
+public import Cslib.MachineLearning.PACLearning.VCDimension
diff --git a/Cslib/MachineLearning/PACLearning/Defs.lean b/Cslib/MachineLearning/PACLearning/Defs.lean
new file mode 100644
index 000000000..625c60344
--- /dev/null
+++ b/Cslib/MachineLearning/PACLearning/Defs.lean
@@ -0,0 +1,178 @@
+/-
+Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Samuel Schlesinger
+-/
+
+module
+
+public import Cslib.Init
+public import Mathlib.MeasureTheory.Measure.MeasureSpace
+public import Mathlib.MeasureTheory.Constructions.Pi
+public import Mathlib.Order.SymmDiff
+
+@[expose] public section
+
+/-! # PAC Learning
+
+This file defines the Probably Approximately Correct (PAC) learning model
+introduced by Valiant [Valiant1984]. A concept class `C` over a domain `α` is a
+collection of subsets of `α`. A learning algorithm receives a labeled sample
+drawn i.i.d. from an unknown distribution and must produce a hypothesis that,
+with high probability, has low error with respect to the true concept.
+
+## Main definitions
+
+- `ConceptClass`: a concept class over domain `α`, i.e., a set of subsets.
+- `LabeledSample`: a finite sequence of (point, label) pairs.
+- `sampleOf`: constructs a labeled sample from a sequence of points and a concept.
+- `hypothesisError`: the total error of a hypothesis with respect to a concept under a
+  distribution, defined as the measure of the symmetric difference.
+- `falsePositiveError`: the false positive error `P(h \ c)`.
+- `falseNegativeError`: the false negative error `P(c \ h)`.
+- `Learner`: a function from labeled samples to hypotheses.
+- `IsPACLearner`: the property that a deterministic learner produces a hypothesis
+  with error at most `ε` with probability at least `1 - δ`, for every distribution
+  and concept from the class.
+- `IsRPACLearner`: the randomized variant, where the learner draws internal
+  randomness from a probability space `(Ω, Q)`.
+- `sampleComplexity`: the smallest sample size admitting a deterministic PAC learner.
+- `rsampleComplexity`: the smallest sample size admitting a randomized PAC learner.
+
+## Main statements
+
+- `IsPACLearner.toIsRPACLearner`: every deterministic PAC learner is in particular
+  a randomized PAC learner (with the trivial randomness space `PUnit`).
+- `hypothesisError_eq_add`: total error decomposes as the sum of false positive and
+  false negative errors.
+
+## References
+
+* [L. G. Valiant, *A Theory of the Learnable*][Valiant1984]
+* [A. Ehrenfeucht, D. Haussler, M. Kearns, L. Valiant,
+  *A General Lower Bound on the Number of Examples Needed for Learning*][EHKV1989]
+-/
+
+open MeasureTheory Set
+open scoped ENNReal
+
+namespace Cslib.MachineLearning
+
+/-- A *concept class* over domain `α` is a collection of subsets of `α`. Each subset represents
+a concept (i.e., a binary classifier). -/
+abbrev ConceptClass (α : Type*) := Set (Set α)
+
+/-- A *labeled sample* of size `m` over domain `α` is a sequence of `(point, label)` pairs.
-/
+abbrev LabeledSample (α : Type*) (m : ℕ) := Fin m → (α × Bool)
+
+open Classical in
+/-- Construct a labeled sample from a sequence of points `xs` and a concept `c`.
+Each point is labeled `true` if it belongs to the concept and `false` otherwise. -/
+noncomputable def sampleOf {α : Type*} {m : ℕ} (c : Set α) (xs : Fin m → α) :
+    LabeledSample α m :=
+  fun i => (xs i, decide (xs i ∈ c))
+
+/-- The *error* of a hypothesis `h` with respect to a target concept `c` under distribution `P`,
+defined as the measure of their symmetric difference `h ∆ c`. -/
+noncomputable def hypothesisError {α : Type*} [MeasurableSpace α] (P : Measure α)
+    (h c : Set α) : ℝ≥0∞ :=
+  P (symmDiff h c)
+
+/-- The *false positive error* of a hypothesis `h` with respect to a target concept `c`
+under distribution `P`, defined as the measure of `h \ c` — points classified positive
+but not in the concept. -/
+noncomputable def falsePositiveError {α : Type*} [MeasurableSpace α] (P : Measure α)
+    (h c : Set α) : ℝ≥0∞ :=
+  P (h \ c)
+
+/-- The *false negative error* of a hypothesis `h` with respect to a target concept `c`
+under distribution `P`, defined as the measure of `c \ h` — points in the concept but
+classified negative. -/
+noncomputable def falseNegativeError {α : Type*} [MeasurableSpace α] (P : Measure α)
+    (h c : Set α) : ℝ≥0∞ :=
+  P (c \ h)
+
+/-- The total hypothesis error decomposes as the sum of false positive and false negative
+errors, since `h ∆ c = (h \ c) ∪ (c \ h)` is a disjoint union.
-/ +theorem hypothesisError_eq_add {α : Type*} [MeasurableSpace α] {P : Measure α} + {h c : Set α} (hh : MeasurableSet h) (hc : MeasurableSet c) : + hypothesisError P h c = falsePositiveError P h c + falseNegativeError P h c := by + simp only [hypothesisError, falsePositiveError, falseNegativeError, symmDiff_def, sup_eq_union] + exact measure_union disjoint_sdiff_sdiff (hc.diff hh) + +/-- A learner using `m` samples is a function that takes a labeled sample of size `m` and produces +a hypothesis (a subset of the domain). -/ +abbrev Learner (α : Type*) (m : ℕ) := LabeledSample α m → Set α + +variable {α : Type*} [MeasurableSpace α] + +/-- `IsPACLearner m ε δ C` asserts that there exists a learner using `m` samples that is +`(ε, δ)`-correct for the concept class `C`: for every probability measure `P` on `α` and every +target concept `c ∈ C`, the probability (over i.i.d. samples from `P`) that the learner's +hypothesis has error greater than `ε` is at most `δ`. + +More precisely, we require that the set of sample-vectors whose induced hypothesis has error +exceeding `ε` has `P^m`-measure at most `δ`. -/ +def IsPACLearner (m : ℕ) (ε δ : ℝ≥0∞) (C : ConceptClass α) : Prop := + ∃ A : Learner α m, + ∀ (P : Measure α) [IsProbabilityMeasure P], + ∀ c ∈ C, + (Measure.pi (fun _ : Fin m => P)) + {xs : Fin m → α | hypothesisError P (A (sampleOf c xs)) c > ε} ≤ δ + +/-- `IsRPACLearner m ε δ C` asserts that there exists a *randomized* learner using `m` samples +that is `(ε, δ)`-correct for the concept class `C`. A randomized learner draws internal +randomness `ω` from a probability space `(Ω, Q)` and then acts as the deterministic learner +`A(ω)`. + +For every probability measure `P` on `α` and every target concept `c ∈ C`, the failure +probability function `ω ↦ P^m{xs | error(A(ω)(xs), c) > ε}` must be `Q`-a.e. measurable, +and its expectation over `ω` must be at most `δ`. 
+ +A deterministic PAC learner (`IsPACLearner`) is the special case `Ω = PUnit`; +see `IsPACLearner.toIsRPACLearner`. -/ +def IsRPACLearner (m : ℕ) (ε δ : ℝ≥0∞) (C : ConceptClass α) : Prop := + ∃ (Ω : Type*) (_ : MeasurableSpace Ω) (Q : Measure Ω) (_ : IsProbabilityMeasure Q) + (A : Ω → Learner α m), + ∀ (P : Measure α) [IsProbabilityMeasure P], + ∀ c ∈ C, + AEMeasurable (fun ω => (Measure.pi (fun _ : Fin m => P)) + {xs : Fin m → α | hypothesisError P ((A ω) (sampleOf c xs)) c > ε}) Q ∧ + ∫⁻ ω, (Measure.pi (fun _ : Fin m => P)) + {xs : Fin m → α | hypothesisError P ((A ω) (sampleOf c xs)) c > ε} ∂Q ≤ δ + +/-- Every deterministic PAC learner is in particular a randomized PAC learner +(with the trivial one-point randomness space `PUnit`). -/ +theorem IsPACLearner.toIsRPACLearner {m : ℕ} {ε δ : ℝ≥0∞} {C : ConceptClass α} + (h : IsPACLearner m ε δ C) : IsRPACLearner m ε δ C := by + obtain ⟨A, hA⟩ := h + refine ⟨PUnit, inferInstance, Measure.dirac PUnit.unit, inferInstance, fun _ => A, ?_⟩ + intro P _ c hc + exact ⟨measurable_const.aemeasurable, by + simp only [gt_iff_lt, lintegral_const, measure_univ, mul_one]; exact hA P c hc⟩ + +/-- The *deterministic sample complexity* of a concept class `C` at accuracy `ε` and confidence `δ` +is the smallest sample size `m` such that a deterministic `(ε, δ)`-PAC learner for `C` exists +using `m` samples. See also `rsampleComplexity` for the randomized variant. + +**Caveat**: because `sInf` on `ℕ` returns `0` for the empty set, this definition returns `0` when +no deterministic learner exists (e.g., when `C` has infinite VC dimension). It is only meaningful +when the defining set `{m | IsPACLearner m ε δ C}` is nonempty. 
-/
+noncomputable def sampleComplexity (C : ConceptClass α) (ε δ : ℝ≥0∞) : ℕ :=
+  sInf {m : ℕ | IsPACLearner m ε δ C}
+
+/-- The *randomized sample complexity* of a concept class `C` at accuracy `ε` and confidence `δ`
+is the smallest sample size `m` such that a randomized `(ε, δ)`-PAC learner for `C` exists
+using `m` samples. This is at most `sampleComplexity C ε δ` since every deterministic learner
+is also a randomized learner (see `IsPACLearner.toIsRPACLearner`).
+
+The universe of the randomness space `Ω` is pinned to `Type 0` (via `.{_, 0}`) so that the
+`sInf` is taken over a definite set; without the pin the existential quantifier over `Ω : Type*`
+would range over all universe levels, making the set ill-defined.
+
+**Caveat**: because `sInf` on `ℕ` returns `0` for the empty set, this definition returns `0` when
+no randomized learner exists. It is only meaningful when the defining set is nonempty. -/
+noncomputable def rsampleComplexity (C : ConceptClass α) (ε δ : ℝ≥0∞) : ℕ :=
+  sInf {m : ℕ | IsRPACLearner.{_, 0} m ε δ C}
+
+end Cslib.MachineLearning
diff --git a/Cslib/MachineLearning/PACLearning/SampleComplexityLower.lean b/Cslib/MachineLearning/PACLearning/SampleComplexityLower.lean
new file mode 100644
index 000000000..4be9edb8a
--- /dev/null
+++ b/Cslib/MachineLearning/PACLearning/SampleComplexityLower.lean
@@ -0,0 +1,191 @@
+/-
+Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Samuel Schlesinger
+-/
+
+module
+
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.EHKVProof
+
+@[expose] public section
+
+/-! # Sample Complexity Lower Bound
+
+We use the prefix `ehkv` for Ehrenfeucht–Haussler–Kearns–Valiant throughout.
+
+This module formalizes the main result of [EHKV1989]: a lower bound on the
+number of examples required for distribution-free PAC learning of a concept
+class, in terms of its Vapnik-Chervonenkis dimension.
+ +**Theorem 1** [EHKV1989, Theorem 1]: Assume `0 < ε ≤ 1/8`, +`0 < δ < 1/14`, and `VCdim(C) ≥ 2`. Then any `(ε, δ)`-learning +algorithm for `C` must use sample size at least `(VCdim(C) - 1) / (32ε)`. + +The proof constructs an adversarial distribution `P` on `d + 1` points +(where `d = VCdim(C) - 1`). Via a Markov/Bernoulli bound (Lemma 3), +"bad" samples — those which do not reveal enough of the shattered +set — occur with probability `> 1/2` when the sample size is too +small. An involution pairing argument (Lemma 2) shows that for each +bad sample, at least half of the `2^d` concepts obtained from +shattering force large error, and a counting/contradiction argument +then produces a single concept whose failure probability exceeds `δ`. + +## Proof structure (submodules) + +- `SampleComplexityLower.Helpers`: generic lemmas (Bernoulli inequality, + product measure support, `seenElements`, integration bound) +- `SampleComplexityLower.AdversarialMeasure`: construction of the + adversarial probability distribution on `d + 1` points +- `SampleComplexityLower.InvolutionPairing`: the involution/pairing + argument and complementary-error contradiction +- `SampleComplexityLower.EHKVProof`: Markov bound on bad samples, + half-fraction sum lower bound, and the assembled contradiction + +## Main statements + +- `sample_complexity_lower_bound_randomized`: **Theorem 1** of [EHKV1989] for + randomized learners — the full strength of the result. +- `sample_complexity_lower_bound`: deterministic corollary via + `IsPACLearner.toIsRPACLearner`. +- `sample_complexity_lower_bound_vcDim`: the bound stated in terms of `vcDim`. +- `sampleComplexity_lower_bound_vcDim`: lower bound on `sampleComplexity` via `vcDim`. +- `rsampleComplexity_lower_bound_vcDim`: lower bound on `rsampleComplexity` via `vcDim`. + +## References + +* [A. Ehrenfeucht, D. Haussler, M. Kearns, L. 
Valiant, + *A General Lower Bound on the Number of Examples Needed + for Learning*][EHKV1989] +-/ + +open MeasureTheory Set Finset +open scoped ENNReal + +noncomputable section + +namespace Cslib.MachineLearning + +/-- **Theorem 1 (randomized)** [EHKV1989]: The sample-complexity lower bound +`(VCdim(C) - 1) / (32 * ε) ≤ m` holds for *randomized* `(ε, δ)`-PAC +learners. This is the full strength of the EHKV result. -/ +theorem sample_complexity_lower_bound_randomized + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + {C : ConceptClass α} + {W : Finset α} (hW : SetShatters C (↑W)) + {m : ℕ} {ε δ : ℝ≥0∞} + (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8)) + (hδ_lt : δ < ENNReal.ofReal (1 / 14)) + (hW_card : 2 ≤ W.card) + (hlearn : IsRPACLearner m ε δ C) : + (W.card - 1 : ℝ) / (32 * ε.toReal) ≤ m := by + by_contra h + push Not at h + have hε_ne_top : ε ≠ ⊤ := ne_top_of_le_ne_top ENNReal.ofReal_ne_top hε_le + have hε'_pos : 0 < ε.toReal := ENNReal.toReal_pos (ne_of_gt hε_pos) hε_ne_top + have h32ε_pos : (0 : ℝ) < 32 * ε.toReal := by positivity + have hW_sub : (0 : ℝ) < (W.card : ℝ) - 1 := by + have : (2 : ℝ) ≤ (W.card : ℝ) := by exact_mod_cast hW_card + linarith + have hm_ennreal : (↑m : ℝ≥0∞) < ENNReal.ofReal + ((W.card - 1 : ℝ) / (32 * ε.toReal)) := by + rw [← ENNReal.ofReal_natCast (n := m)] + exact ENNReal.ofReal_lt_ofReal_iff (div_pos hW_sub h32ε_pos) |>.mpr h + obtain ⟨Ω, mΩ, Q, hQ, A, hA⟩ := hlearn + have hA_aem : ∀ (P : Measure α) [IsProbabilityMeasure P], ∀ c ∈ C, + AEMeasurable (fun ω => (Measure.pi (fun _ : Fin m => P)) + {xs : Fin m → α | hypothesisError P ((A ω) (sampleOf c xs)) c > ε}) Q := + fun P _ c hc => (hA P c hc).1 + obtain ⟨P, hP, c, hc, hbad⟩ := + exists_bad_distribution_and_concept_randomized hW hW_card + hε_pos hε_le hδ_lt hm_ennreal Q A hA_aem + haveI := hP + exact absurd ((hA P c hc).2) (not_le_of_gt hbad) + +/-- **Theorem 1** [EHKV1989]: Assume `0 < ε ≤ 1/8`, `0 < δ < 1/14`, +and `VCdim(C) ≥ 2`. 
Then any deterministic `(ε, δ)`-learning algorithm +for `C` must use sample size `m` satisfying `(VCdim(C) - 1) / (32 * ε) ≤ m`. + +This is a corollary of the stronger `sample_complexity_lower_bound_randomized` +via `IsPACLearner.toIsRPACLearner`. -/ +theorem sample_complexity_lower_bound + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + {C : ConceptClass α} + {W : Finset α} (hW : SetShatters C (↑W)) + {m : ℕ} {ε δ : ℝ≥0∞} + (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8)) + (hδ_lt : δ < ENNReal.ofReal (1 / 14)) + (hW_card : 2 ≤ W.card) + (hlearn : IsPACLearner m ε δ C) : + (W.card - 1 : ℝ) / (32 * ε.toReal) ≤ m := by + exact sample_complexity_lower_bound_randomized hW hε_pos hε_le hδ_lt hW_card + (IsPACLearner.toIsRPACLearner.{_, 0} hlearn) + +/-- **Corollary**: The EHKV sample-complexity lower bound stated in terms of `vcDim`. + +If the VC dimension of `C` is at least `2` (and is finite, i.e., the defining set is bounded +above), then any randomized `(ε, δ)`-PAC learner for `C` must use at least +`(vcDim C - 1) / (32ε)` samples. + +This wraps `sample_complexity_lower_bound_randomized` by extracting a shattered witness from +`vcDim`. 
-/ +theorem sample_complexity_lower_bound_vcDim + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + {C : ConceptClass α} + {m : ℕ} {ε δ : ℝ≥0∞} + (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8)) + (hδ_lt : δ < ENNReal.ofReal (1 / 14)) + (hvc : 2 ≤ vcDim C) + (hbdd : BddAbove {n : ℕ | ∃ W : Finset α, W.card = n ∧ SetShatters C (↑W)}) + (hlearn : IsRPACLearner m ε δ C) : + (vcDim C - 1 : ℝ) / (32 * ε.toReal) ≤ m := by + set S := {n : ℕ | ∃ W : Finset α, W.card = n ∧ SetShatters C (↑W)} + have hne : S.Nonempty := by + by_contra hempty + rw [Set.not_nonempty_iff_eq_empty] at hempty + have : (2 : ℕ) ≤ sSup (∅ : Set ℕ) := hempty ▸ hvc + simp at this + obtain ⟨W, hWcard, hW⟩ := Nat.sSup_mem hne hbdd + have hW_card : 2 ≤ W.card := hWcard ▸ hvc + have hvc_eq : vcDim C = W.card := hWcard.symm + simp only [hvc_eq] + exact sample_complexity_lower_bound_randomized hW hε_pos hε_le hδ_lt hW_card hlearn + +/-- Lower bound on deterministic sample complexity in terms of `vcDim`. + +If the VC dimension is at least `2` and finite, then +`(vcDim C - 1) / (32ε) ≤ sampleComplexity C ε δ`, +provided the concept class is learnable (some deterministic PAC learner exists). -/ +theorem sampleComplexity_lower_bound_vcDim + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + {C : ConceptClass α} + {ε δ : ℝ≥0∞} + (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8)) + (hδ_lt : δ < ENNReal.ofReal (1 / 14)) + (hvc : 2 ≤ vcDim C) + (hbdd : BddAbove {n : ℕ | ∃ W : Finset α, W.card = n ∧ SetShatters C (↑W)}) + (hlearnable : {m : ℕ | IsPACLearner m ε δ C}.Nonempty) : + (vcDim C - 1 : ℝ) / (32 * ε.toReal) ≤ sampleComplexity C ε δ := by + have hmem := Nat.sInf_mem hlearnable + exact sample_complexity_lower_bound_vcDim hε_pos hε_le hδ_lt hvc hbdd + (IsPACLearner.toIsRPACLearner.{_, 0} hmem) + +/-- Lower bound on randomized sample complexity in terms of `vcDim`. 
+ +If the VC dimension is at least `2` and finite, then +`(vcDim C - 1) / (32ε) ≤ rsampleComplexity C ε δ`, +provided the concept class is learnable by a randomized learner. -/ +theorem rsampleComplexity_lower_bound_vcDim + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + {C : ConceptClass α} + {ε δ : ℝ≥0∞} + (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8)) + (hδ_lt : δ < ENNReal.ofReal (1 / 14)) + (hvc : 2 ≤ vcDim C) + (hbdd : BddAbove {n : ℕ | ∃ W : Finset α, W.card = n ∧ SetShatters C (↑W)}) + (hlearnable : {m : ℕ | IsRPACLearner.{_, 0} m ε δ C}.Nonempty) : + (vcDim C - 1 : ℝ) / (32 * ε.toReal) ≤ rsampleComplexity C ε δ := by + have hmem := Nat.sInf_mem hlearnable + exact sample_complexity_lower_bound_vcDim hε_pos hε_le hδ_lt hvc hbdd hmem + +end Cslib.MachineLearning diff --git a/Cslib/MachineLearning/PACLearning/SampleComplexityLower/AdversarialMeasure.lean b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/AdversarialMeasure.lean new file mode 100644 index 000000000..b2c041d20 --- /dev/null +++ b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/AdversarialMeasure.lean @@ -0,0 +1,106 @@ +/- +Copyright (c) 2026 Samuel Schlesinger. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Samuel Schlesinger +-/ + +module + +public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.Helpers + +@[expose] public section + +/-! # Adversarial Measure Construction + +Given a shattered set `W` with `|W| ≥ 2`, we construct a discrete +probability measure `P` supported on `W`. 
We pick an arbitrary element +`w₀ ∈ W` as the "heavy" point and put: +- mass `1 - 8ε` on `w₀` +- mass `8ε / (|W| - 1)` on each remaining point in `W \ {w₀}` + +## Main definitions + +- `adversarialMeasure W w₀ ε'`: the adversarial probability measure + +## Main statements + +- `adversarialMeasure_isProbabilityMeasure`: it is a probability measure +- `adversarialMeasure_singleton`: point mass on each `W'` element +- `adversarialMeasure_support`: support contained in `W` +-/ + +open MeasureTheory Set Finset +open scoped ENNReal + +noncomputable section + +namespace Cslib.MachineLearning + +section AdversarialMeasure + +variable {α : Type*} [MeasurableSpace α] + +open Classical in +/-- The adversarial probability measure for the EHKV lower bound. +Concentrated on a finite set `W`, with a heavy point `w₀` carrying +mass `1 - 8ε` and each of the remaining `d = |W| - 1` points +carrying mass `8ε/d`. -/ +def adversarialMeasure (W : Finset α) (w₀ : α) (ε' : ℝ) : + Measure α := + let W' := W.erase w₀ + ENNReal.ofReal (1 - 8 * ε') • Measure.dirac w₀ + + ∑ w ∈ W', ENNReal.ofReal (8 * ε' / W'.card) • Measure.dirac w + +open Classical in +/-- The adversarial measure is a probability measure when `0 < ε ≤ 1/8` +and `|W| ≥ 2`. 
-/ +theorem adversarialMeasure_isProbabilityMeasure + {W : Finset α} {w₀ : α} + {ε' : ℝ} (hε'_pos : 0 < ε') (hε'_le : ε' ≤ 1 / 8) + (hd : 1 ≤ (W.erase w₀).card) : + IsProbabilityMeasure (adversarialMeasure W w₀ ε') := by + set d := (W.erase w₀).card with hd_def + constructor + simp only [adversarialMeasure, Measure.coe_add, Pi.add_apply, + Measure.smul_apply, smul_eq_mul, Measure.dirac_apply' _ MeasurableSet.univ, + Set.indicator_univ, Pi.one_apply, mul_one, + Measure.finset_sum_apply, Finset.sum_const, nsmul_eq_mul] + have hd_pos : (0 : ℝ) < d := Nat.cast_pos.mpr (by omega) + rw [← ENNReal.ofReal_natCast (n := d)] + rw [← ENNReal.ofReal_mul (by exact_mod_cast hd_pos.le)] + rw [mul_div_cancel₀ _ (ne_of_gt hd_pos)] + rw [← ENNReal.ofReal_add (by linarith) (by linarith)] + simp [sub_add_cancel] + +open Classical in +/-- The adversarial measure assigns mass `8ε'/d` to each point in `W'`. -/ +theorem adversarialMeasure_singleton [MeasurableSingletonClass α] + {W : Finset α} {w₀ : α} {ε' : ℝ} + {w : α} (hw : w ∈ W.erase w₀) : + (adversarialMeasure W w₀ ε') {w} = ENNReal.ofReal (8 * ε' / (W.erase w₀).card) := by + have hw_ne : w₀ ≠ w := Ne.symm (ne_of_mem_erase hw) + simp only [adversarialMeasure, Measure.coe_add, Pi.add_apply, Measure.smul_apply, smul_eq_mul, + Measure.finset_sum_apply, Measure.dirac_apply, Set.indicator_apply, Set.mem_singleton_iff] + rw [if_neg hw_ne, mul_zero, zero_add] + simp_rw [Pi.one_apply] + simp_rw [mul_ite, mul_one, mul_zero] + rw [Finset.sum_ite_eq' _ _ _, if_pos hw] + +open Classical in +/-- The adversarial measure is supported on `W`: all mass outside `W` is zero. 
-/ +theorem adversarialMeasure_support [MeasurableSingletonClass α] + {W : Finset α} {w₀ : α} (hw₀ : w₀ ∈ W) {ε' : ℝ} : + (adversarialMeasure W w₀ ε') (↑W : Set α)ᶜ = 0 := by + have hw₀_not_compl : w₀ ∉ (↑W : Set α)ᶜ := by simp [hw₀] + have hw_not_compl : ∀ w ∈ W.erase w₀, w ∉ (↑W : Set α)ᶜ := by + intro w hw; simp [Finset.erase_subset _ _ hw] + simp only [adversarialMeasure, Measure.coe_add, Pi.add_apply, Measure.smul_apply, + smul_eq_mul, Measure.finset_sum_apply, Measure.dirac_apply] + rw [indicator_of_notMem hw₀_not_compl, mul_zero, zero_add] + apply Finset.sum_eq_zero + intro w hw + rw [indicator_of_notMem (hw_not_compl w hw), mul_zero] + +end AdversarialMeasure + +end Cslib.MachineLearning diff --git a/Cslib/MachineLearning/PACLearning/SampleComplexityLower/EHKVProof.lean b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/EHKVProof.lean new file mode 100644 index 000000000..b2832d096 --- /dev/null +++ b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/EHKVProof.lean @@ -0,0 +1,457 @@ +/- +Copyright (c) 2026 Samuel Schlesinger. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Samuel Schlesinger +-/ + +module + +public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.AdversarialMeasure +public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.InvolutionPairing + +@[expose] public section + +/-! # EHKV Proof Assembly + +This module assembles the Markov bound, involution pairing, and adversarial +measure construction into the full EHKV proof by contradiction. + +## Main statements + +- `markov_bad_samples`: **Lemma 3** [EHKV1989] — bad samples occur with + probability `> 1/2` when sample size is too small. +- `ehkv_sum_lower_bound`: the half-fraction sum lower bound via involution. +- `ehkv_final_contradiction`: the final arithmetic contradiction. 
+- `exists_bad_distribution_and_concept_randomized`: for any randomized learner + with too few samples, there exists an adversarial distribution and concept. + +## References + +* [A. Ehrenfeucht, D. Haussler, M. Kearns, L. Valiant, + *A General Lower Bound on the Number of Examples Needed + for Learning*][EHKV1989] +-/ + +open MeasureTheory Set Finset +open scoped ENNReal + +noncomputable section + +namespace Cslib.MachineLearning + +section EHKVProof + +variable {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + +open Classical in +/-- **Lemma 3** [EHKV1989]: Markov bound on bad samples. + +When the sample size `m` satisfies `m < d / (32ε)`, the probability +(under the product measure `P^m`) that the sample reveals at most +half of the shattered set `W'` exceeds `1/2`. -/ +theorem markov_bad_samples + {W : Finset α} {w₀ : α} (hw₀ : w₀ ∈ W) + (hW_card : 2 ≤ W.card) + {ε' : ℝ} (hε'_pos : 0 < ε') (hε'_le : ε' ≤ 1 / 8) + {m : ℕ} (hm : (m : ℝ) < ((W.erase w₀).card : ℝ) / (32 * ε')) + (P : Measure α) [IsProbabilityMeasure P] + (hP_w : ∀ w ∈ W.erase w₀, + P {w} = ENNReal.ofReal (8 * ε' / (W.erase w₀).card)) : + ENNReal.ofReal (1 / 2) < + (Measure.pi (fun _ : Fin m => P)) + {xs : Fin m → α | + (seenElements (W.erase w₀) xs).card ≤ (W.erase w₀).card / 2} := by + -- Setup + set W' := W.erase w₀ with hW'_def + set d := W'.card with hd_def + set μ := Measure.pi (fun _ : Fin m => P) with hμ_def + haveI hμ_prob : IsProbabilityMeasure μ := Measure.pi.instIsProbabilityMeasure _ + have hd_pos : 0 < d := by rw [hd_def, hW'_def, card_erase_of_mem hw₀]; omega + have hd_cast : (0 : ℝ) < (d : ℝ) := Nat.cast_pos.mpr hd_pos + have hp_nonneg : (0 : ℝ) ≤ 8 * ε' / d := by positivity + have hp_le_one : 8 * ε' / d ≤ 1 := by + rw [div_le_one hd_cast] + linarith [show (1 : ℝ) ≤ d from by exact_mod_cast hd_pos] + have hf_meas : Measurable (fun xs : Fin m → α => ((seenElements W' xs).card : ℝ≥0∞)) := + measurable_seenElements_card W' + -- Define "good" and "bad" sets + set good : Set 
(Fin m → α) := {xs | (seenElements W' xs).card ≤ d / 2} + set bad : Set (Fin m → α) := {xs | d / 2 < (seenElements W' xs).card} + have hgood_compl_bad : good = badᶜ := by + ext xs; simp only [good, bad, Set.mem_compl_iff, Set.mem_setOf_eq, not_lt] + have hbad_meas : MeasurableSet bad := by + have : bad = {xs | ((d / 2 : ℕ) : ℝ≥0∞) < ((seenElements W' xs).card : ℝ≥0∞)} := by + ext xs; simp only [bad, Set.mem_setOf_eq, Nat.cast_lt] + rw [this]; exact measurableSet_lt measurable_const hf_meas + -- Step 1: Bound ∫ card ∂μ ≤ ENNReal.ofReal(8*m*ε') via Bernoulli integration + have hE_bound : ∫⁻ xs, ((seenElements W' xs).card : ℝ≥0∞) ∂μ + ≤ ENNReal.ofReal (8 * ↑m * ε') := + calc ∫⁻ xs, ((seenElements W' xs).card : ℝ≥0∞) ∂μ + ≤ ENNReal.ofReal (↑d * (↑m * (8 * ε' / ↑d))) := + expected_seenElements_le hp_nonneg hp_le_one P hP_w + _ = ENNReal.ofReal (8 * ↑m * ε') := by congr 1; field_simp + -- Step 2: Apply Markov's inequality to bound μ(bad) + set k := d / 2 + 1 with hk_def + have hk_pos : (0 : ℝ) < (k : ℝ) := Nat.cast_pos.mpr (by omega) + have hbad_eq : bad = {xs | (k : ℝ≥0∞) ≤ ((seenElements W' xs).card : ℝ≥0∞)} := by + ext xs; simp only [bad, Set.mem_setOf_eq, Nat.cast_le]; omega + have hbad_bound : μ bad ≤ ENNReal.ofReal (8 * ↑m * ε' / ↑k) := by + rw [hbad_eq] + calc μ {xs | (k : ℝ≥0∞) ≤ ↑(seenElements W' xs).card} + ≤ (∫⁻ xs, ↑(seenElements W' xs).card ∂μ) / ↑k := + meas_ge_le_lintegral_div hf_meas.aemeasurable + (by exact_mod_cast (show k ≠ 0 by omega)) (ENNReal.natCast_ne_top k) + _ ≤ ENNReal.ofReal (8 * ↑m * ε') / ↑k := + ENNReal.div_le_div_right hE_bound _ + _ = ENNReal.ofReal (8 * ↑m * ε') / ENNReal.ofReal (↑k) := by + rw [ENNReal.ofReal_natCast] + _ = ENNReal.ofReal (8 * ↑m * ε' / ↑k) := + (ENNReal.ofReal_div_of_pos hk_pos).symm + -- Step 3: Show 8*m*ε'/k < 1/2 via arithmetic + have harith : 8 * ↑m * ε' / ↑k < 1 / 2 := by + have h8mε : 8 * (m : ℝ) * ε' < (↑d : ℝ) / 4 := by + calc 8 * (m : ℝ) * ε' < 8 * ((↑d : ℝ) / (32 * ε')) * ε' := + mul_lt_mul_of_pos_right (by 
linarith) hε'_pos + _ = ↑d / 4 := by field_simp; ring + have h2k : (d : ℝ) < 2 * ↑k := by + exact_mod_cast (show d < 2 * k from by omega) + calc 8 * ↑m * ε' / ↑k + < (↑d / 4) / ↑k := div_lt_div_of_pos_right h8mε hk_pos + _ = ↑d / (4 * ↑k) := by ring + _ < 1 / 2 := by + rw [div_lt_iff₀ (by positivity : (0 : ℝ) < 4 * ↑k)]; linarith + -- Step 4: μ(bad) < ENNReal.ofReal(1/2) + have hbad_lt : μ bad < ENNReal.ofReal (1 / 2) := calc + μ bad ≤ ENNReal.ofReal (8 * ↑m * ε' / ↑k) := hbad_bound + _ < ENNReal.ofReal (1 / 2) := + (ENNReal.ofReal_lt_ofReal_iff (by norm_num : (0 : ℝ) < 1 / 2)).mpr harith + -- Step 5: Complement argument: μ(good) > 1/2 + rw [hgood_compl_bad] + have hfin : μ bad ≠ ⊤ := ne_top_of_lt (lt_of_lt_of_le hbad_lt ENNReal.ofReal_lt_top.le) + have h_sum := prob_add_prob_compl hbad_meas (μ := μ) + rw [← ENNReal.add_lt_add_iff_left hfin, h_sum] + calc μ bad + ENNReal.ofReal (1 / 2) + < ENNReal.ofReal (1 / 2) + ENNReal.ofReal (1 / 2) := + ENNReal.add_lt_add_right ENNReal.ofReal_ne_top hbad_lt + _ = ENNReal.ofReal 1 := by + rw [← ENNReal.ofReal_add (by norm_num) (by norm_num)]; norm_num + _ = 1 := ENNReal.ofReal_one + +open Classical in +/-- Per-learner half-fraction sum lower bound: for any learner `A'` and + any assignment of concepts to subsets of `W' = W \ {w₀}` satisfying + the shattering intersection property, the weighted measure of "bad" + samples is bounded by the sum of failure measures over the powerset. + + This is the key quantitative step in the EHKV argument, independent + of any per-concept failure bound. It enables both the deterministic + and randomized lower bound proofs. 
-/
+theorem ehkv_sum_lower_bound
+    {W : Finset α}
+    (hW_card : 2 ≤ W.card)
+    {w₀ : α} (hw₀ : w₀ ∈ W)
+    {ε' : ℝ} (hε'_pos : 0 < ε')
+    {m : ℕ}
+    (A' : Learner α m)
+    (P : Measure α) [IsProbabilityMeasure P]
+    (hP_w : ∀ w ∈ W.erase w₀,
+      P {w} = ENNReal.ofReal (8 * ε' / (W.erase w₀).card))
+    (hP_supp : P (↑W : Set α)ᶜ = 0)
+    (concepts : Finset α → Set α)
+    (hconcepts_eq : ∀ S ∈ (W.erase w₀).powerset,
+      concepts S ∩ ↑W = {w₀} ∪ ↑S) :
+    (2 ^ (W.erase w₀).card / 2 : ℕ) •
+      (Measure.pi (fun _ : Fin m => P))
+        {xs : Fin m → α |
+          (seenElements (W.erase w₀) xs).card ≤ (W.erase w₀).card / 2} ≤
+    ∑ S ∈ (W.erase w₀).powerset,
+      (Measure.pi (fun _ : Fin m => P))
+        {xs : Fin m → α |
+          hypothesisError P (A' (sampleOf (concepts S) xs)) (concepts S) >
+            ENNReal.ofReal ε'} := by
+  set W' := W.erase w₀ with hW'_def
+  set d := W'.card with hd_def
+  set μ := Measure.pi (fun _ : Fin m => P) with hμ_def
+  have hd_pos : 0 < d := by rw [hd_def, hW'_def, card_erase_of_mem hw₀]; omega
+  -- Dependent wrapper around `concepts`, matching the signature expected by `cMap_sample_agree`
+  let cMap : (S : Finset α) → S ∈ W'.powerset → Set α := fun S _ => concepts S
+  -- `fail c`: the samples on which the learner, given `c`-labeled data, errs by more than ε'
+  set fail : Set α → Set (Fin m → α) :=
+    fun c => {xs | hypothesisError P (A' (sampleOf c xs)) c > ENNReal.ofReal ε'} with hfail_def
+  set B := {xs : Fin m → α | (seenElements W' xs).card ≤ d / 2}
+  set Wm := {xs : Fin m → α | ∀ i, xs i ∈ (↑W : Set α)} with hWm_def
+  have hμ_supp : μ Wmᶜ = 0 := pi_measure_compl_zero hP_supp
+  have hWm_ae : Wm ∈ (ae μ) := mem_ae_iff.mpr hμ_supp
+  -- Every set is null-measurable (P supported on finite W)
+  have hnull_meas : ∀ (S : Set (Fin m → α)), NullMeasurableSet S μ :=
+    nullMeasurableSet_pi_of_finite_support hP_supp
+  -- Half-fraction: for xs ∈ B ∩ Wm, at least 2^d/2 = 2^{d-1} of the 2^d concepts fail
+  have hhalf_fraction : ∀ xs, xs ∈ B → xs ∈ Wm →
+      (2 ^ d / 2 : ℕ) ≤ (W'.powerset.filter (fun S => xs ∈ fail (concepts S))).card := by
+    intro xs hxs_bad hxs_Wm
+    let T := seenElements W' xs
+    let U := W' \ T
+    let σ : Finset α → Finset α := fun S => (S ∩ T) ∪ (U \ S)
+    have hσ_self : ∀ S, S ∈ W'.powerset → σ (σ S) = S := by
+      intro S hS; rw [Finset.mem_powerset] at hS
+      ext x; simp only [σ, U, Finset.mem_union, Finset.mem_inter, Finset.mem_sdiff]; tauto
+    have hσ_mem : ∀ S ∈ W'.powerset, σ S ∈ W'.powerset := by
+      intro S hS; rw [Finset.mem_powerset] at hS ⊢
+      exact union_subset (inter_subset_left.trans hS) (sdiff_subset.trans sdiff_subset)
+    have hσ_agree_T : ∀ S, σ S ∩ T = S ∩ T := by
+      intro S; ext x
+      simp only [σ, U, Finset.mem_inter, Finset.mem_union, Finset.mem_sdiff]; tauto
+    -- Pairing: for each S, xs ∈ fail(concepts S) ∨ xs ∈ fail(concepts(σ S))
+    have hpairing : ∀ S ∈ W'.powerset,
+        xs ∈ fail (concepts S) ∨ xs ∈ fail (concepts (σ S)) := by
+      intro S hS
+      have hσS_mem := hσ_mem S hS
+      have h_sample_agree : ∀ i, xs i ∈ concepts S ↔ xs i ∈ concepts (σ S) :=
+        cMap_sample_agree cMap hconcepts_eq hS hσS_mem
+          (by rw [Finset.inter_comm, hσ_agree_T, Finset.inter_comm]) hxs_Wm
+      have h_same_sample : sampleOf (concepts S) xs = sampleOf (concepts (σ S)) xs :=
+        sampleOf_eq_of_agree h_sample_agree
+      -- On U: concepts S and concepts(σ S) are complementary
+      set h₀_local := A' (sampleOf (concepts S) xs)
+      have hU_in_S : ∀ w ∈ U, w ∈ S → w ∉ σ S := by
+        intro w hwU hwS
+        simp only [σ, U, Finset.mem_union, Finset.mem_inter, Finset.mem_sdiff] at *; tauto
+      have hU_not_S : ∀ w ∈ U, w ∉ S → w ∈ σ S := by
+        intro w hwU hwnS
+        simp only [σ, Finset.mem_union, Finset.mem_sdiff]; exact Or.inr ⟨hwU, hwnS⟩
+      have hU_sub_symmDiff : (↑U : Set α) ⊆
+          symmDiff h₀_local (concepts S) ∪ symmDiff h₀_local (concepts (σ S)) := by
+        intro w hwU
+        have hwU' := mem_coe.mp hwU
+        have hwW : w ∈ (↑W : Set α) :=
+          mem_coe.mpr (erase_subset _ _ (Finset.sdiff_subset hwU'))
+        by_cases hwS : w ∈ S
+        · have hw_cS : w ∈ concepts S := by
+            have : w ∈ ({w₀} ∪ ↑S : Set α) := Or.inr (mem_coe.mpr hwS)
+            rw [← hconcepts_eq S hS] at this; exact this.1
+          have hw_ncσS : w ∉ concepts (σ S) := by
+            intro hw
+            have : w ∈ concepts (σ S) ∩ ↑W := ⟨hw, hwW⟩
+            rw [hconcepts_eq (σ S) hσS_mem] at this
+            rcases this with hw0 | hwσS
+            · exact absurd (Set.mem_singleton_iff.mp hw0)
+                (Finset.ne_of_mem_erase (Finset.sdiff_subset hwU'))
+            · exact hU_in_S w hwU' hwS (mem_coe.mp hwσS)
+          by_cases hw_h : w ∈ h₀_local
+          · right; exact Set.mem_symmDiff.mpr (Or.inl ⟨hw_h, hw_ncσS⟩)
+          · left; exact Set.mem_symmDiff.mpr (Or.inr ⟨hw_cS, hw_h⟩)
+        · have hwσS := hU_not_S w hwU' hwS
+          have hw_cσS : w ∈ concepts (σ S) := by
+            have : w ∈ ({w₀} ∪ ↑(σ S) : Set α) := Or.inr (mem_coe.mpr hwσS)
+            rw [← hconcepts_eq (σ S) hσS_mem] at this; exact this.1
+          have hw_ncS : w ∉ concepts S := by
+            intro hw
+            have : w ∈ concepts S ∩ ↑W := ⟨hw, hwW⟩
+            rw [hconcepts_eq S hS] at this
+            rcases this with hw0 | hwS'
+            · exact absurd (Set.mem_singleton_iff.mp hw0)
+                (Finset.ne_of_mem_erase (Finset.sdiff_subset hwU'))
+            · exact hwS (mem_coe.mp hwS')
+          by_cases hw_h : w ∈ h₀_local
+          · left; exact Set.mem_symmDiff.mpr (Or.inl ⟨hw_h, hw_ncS⟩)
+          · right; exact Set.mem_symmDiff.mpr (Or.inr ⟨hw_cσS, hw_h⟩)
+      -- The unseen mass: P(U) ≥ 4ε'
+      have hT_sub_W' : T ≤ W' := filter_subset _ _
+      have h2U : d ≤ 2 * U.card := by
+        have hTeq : U.card = d - T.card := card_sdiff_of_subset hT_sub_W'
+        have hTle : T.card ≤ d / 2 := hxs_bad
+        omega
+      have hP_U : ENNReal.ofReal (4 * ε') ≤ P (↑U) :=
+        unseen_measure_ge hε'_pos (by omega) h2U
+          (fun w hw => hP_w w (Finset.sdiff_subset hw))
+      by_contra h_neither
+      push Not at h_neither
+      obtain ⟨hS_ok, hσS_ok⟩ := h_neither
+      simp only [hfail_def, Set.mem_setOf_eq, not_lt] at hS_ok hσS_ok
+      rw [show A' (sampleOf (concepts (σ S)) xs) = h₀_local from
+        congr_arg A' h_same_sample.symm] at hσS_ok
+      exact complementary_error_contradiction hε'_pos hU_sub_symmDiff hP_U hS_ok hσS_ok
+    rw [show 2 ^ d / 2 = W'.powerset.card / 2 from by rw [card_powerset]]
+    have := involution_half_count (P := fun S => xs ∈ fail (concepts S))
+      hσ_self hσ_mem hpairing
+    convert this
+  -- Integration interchange: the sum of measures as one lintegral of summed indicators
+  have haem : ∀ S ∈ W'.powerset,
+      AEMeasurable (fun xs =>
+        (fail (concepts S)).indicator (1 : (Fin m → α) → ℝ≥0∞) xs) μ :=
+    fun S _ => (aemeasurable_indicator_const_iff (1 : ℝ≥0∞)).mpr (hnull_meas _)
+  have hsum_eq_integral :
+      ∑ S ∈ W'.powerset, μ (fail (concepts S)) =
+      ∫⁻ xs, ∑ S ∈ W'.powerset, (fail (concepts S)).indicator 1 xs ∂μ := by
+    rw [lintegral_finset_sum' _ haem]
+    congr 1; ext S
+    exact (lintegral_indicator_one₀ (hnull_meas _)).symm
+  -- Lower bound assembly: (2^d/2) • μ(B) ≤ ∑ μ(fail(concepts S))
+  rw [hsum_eq_integral, nsmul_eq_mul]
+  rw [show (↑(2 ^ d / 2 : ℕ) : ℝ≥0∞) * μ B =
+      ∫⁻ xs, B.indicator (fun _ => (↑(2 ^ d / 2 : ℕ) : ℝ≥0∞)) xs ∂μ from
+    (lintegral_indicator_const₀ (hnull_meas _) _).symm]
+  apply lintegral_mono_ae
+  filter_upwards [hWm_ae] with xs hxs_Wm
+  by_cases hxs_B : xs ∈ B
+  · simp only [Set.indicator_apply, hxs_B, ite_true, Pi.one_apply]
+    rw [Finset.sum_boole]
+    exact_mod_cast hhalf_fraction xs hxs_B hxs_Wm
+  · simp only [Set.indicator_apply, hxs_B, ite_false]; exact zero_le _
+
+/-- The final arithmetic contradiction in the EHKV argument: if `(2^d/2) • μ(B) < 2^d • (1/14)`
+then `μ(B) ≤ 1/7`, contradicting the hypothesis `1/2 < μ(B)` (since `1/7 < 1/2`).
-/
+theorem ehkv_final_contradiction
+    {d : ℕ} (hd_pos : 0 < d) {μB : ℝ≥0∞}
+    (hB_prob : ENNReal.ofReal (1 / 2) < μB)
+    (h_combined : (2 ^ d / 2 : ℕ) • μB < (2 ^ d : ℕ) • ENNReal.ofReal (1 / 14)) : False := by
+  have hB_upper : μB ≤ ENNReal.ofReal (1 / 7) := by -- cancel the common factor 2^d/2 on both sides
+    have h2d_nat : (2 ^ d : ℕ) = 2 * (2 ^ d / 2) :=
+      Nat.eq_mul_of_div_eq_right (dvd_pow_self 2 (by omega : d ≠ 0)) rfl
+    have hpow_half_pos : 0 < (2 ^ d / 2 : ℕ) := Nat.div_pos
+      (le_of_eq (pow_one 2).symm |>.trans (Nat.pow_le_pow_right (by omega) hd_pos))
+      (by norm_num) -- NOTE(review): `hpow_half_pos` appears unused below — confirm and consider removing
+    rw [nsmul_eq_mul, nsmul_eq_mul] at h_combined
+    have h_rhs : (↑(2 ^ d : ℕ) : ℝ≥0∞) * ENNReal.ofReal (1 / 14) =
+        ↑(2 ^ d / 2 : ℕ) * ENNReal.ofReal (1 / 7) := by
+      calc (↑(2 ^ d : ℕ) : ℝ≥0∞) * ENNReal.ofReal (1 / 14)
+          = ↑(2 * (2 ^ d / 2) : ℕ) * ENNReal.ofReal (1 / 14) := by rw [← h2d_nat]
+        _ = (2 * ↑(2 ^ d / 2 : ℕ)) * ENNReal.ofReal (1 / 14) := by push_cast; ring_nf
+        _ = ↑(2 ^ d / 2 : ℕ) * (2 * ENNReal.ofReal (1 / 14)) := by ring
+        _ = ↑(2 ^ d / 2 : ℕ) * ENNReal.ofReal (1 / 7) := by
+            congr 1
+            rw [show (2 : ℝ≥0∞) = ENNReal.ofReal 2 from by norm_num,
+              ← ENNReal.ofReal_mul (by norm_num : (0 : ℝ) ≤ 2)]
+            congr 1; norm_num
+    rw [h_rhs] at h_combined
+    exact (lt_of_mul_lt_mul_left' h_combined).le
+  exact absurd hB_prob
+    (not_lt.mpr (hB_upper.trans (ENNReal.ofReal_le_ofReal (by norm_num : (1:ℝ)/7 ≤ 1/2))))
+
+open Classical in
+/-- **Randomized variant of Lemmas 2 + 3** [EHKV1989]: If the sample
+size `m` is strictly less than `(|W| - 1) / (32ε)`, then for any
+*randomized* learner `(Ω, Q, A)` there exists a probability measure `P`
+and a target concept `c ∈ C` such that the learner's integrated error
+exceeds `δ`.
-/
+theorem exists_bad_distribution_and_concept_randomized
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α]
+    {C : ConceptClass α}
+    {W : Finset α} (hW : SetShatters C (↑W))
+    (hW_card : 2 ≤ W.card)
+    {m : ℕ} {ε δ : ℝ≥0∞}
+    (hε_pos : 0 < ε) (hε_le : ε ≤ ENNReal.ofReal (1 / 8))
+    (hδ_lt : δ < ENNReal.ofReal (1 / 14))
+    (hm : (↑m : ℝ≥0∞) < ENNReal.ofReal
+      ((W.card - 1 : ℝ) / (32 * ENNReal.toReal ε)))
+    {Ω : Type*} [MeasurableSpace Ω] (Q : Measure Ω) [IsProbabilityMeasure Q]
+    (A : Ω → Learner α m)
+    (hA_aem : ∀ (P : Measure α) [IsProbabilityMeasure P], ∀ c ∈ C,
+      AEMeasurable (fun ω => (Measure.pi (fun _ : Fin m => P))
+        {xs : Fin m → α | hypothesisError P ((A ω) (sampleOf c xs)) c > ε}) Q) :
+    ∃ (P : Measure α) (_ : IsProbabilityMeasure P),
+      ∃ c ∈ C,
+        δ < ∫⁻ ω, (Measure.pi (fun _ : Fin m => P))
+          {xs : Fin m → α |
+            hypothesisError P ((A ω) (sampleOf c xs)) c > ε} ∂Q := by
+  -- Extract real-valued parameters from the ℝ≥0∞ hypotheses
+  have hε_ne_top : ε ≠ ⊤ := ne_top_of_le_ne_top ENNReal.ofReal_ne_top hε_le
+  set ε' := ε.toReal with hε'_def
+  have hε'_pos : 0 < ε' := ENNReal.toReal_pos (ne_of_gt hε_pos) hε_ne_top
+  have hε'_le : ε' ≤ 1 / 8 := by
+    rw [hε'_def]
+    have := (ENNReal.toReal_le_toReal hε_ne_top ENNReal.ofReal_ne_top).mpr hε_le
+    rwa [ENNReal.toReal_ofReal (by norm_num : (0 : ℝ) ≤ 1 / 8)] at this
+  have hε_eq : ENNReal.ofReal ε' = ε := ENNReal.ofReal_toReal hε_ne_top
+  -- Pick w₀ ∈ W and set up W' = W \ {w₀}
+  have hW_nonempty : W.Nonempty := card_pos.mp (by omega)
+  obtain ⟨w₀, hw₀⟩ := hW_nonempty
+  set W' := W.erase w₀ with hW'_def
+  set d := W'.card with hd_def
+  have hd : 1 ≤ W'.card := by rw [hW'_def, card_erase_of_mem hw₀]; omega
+  have hd_pos : 0 < d := by omega
+  -- Construct the adversarial measure
+  set P := adversarialMeasure W w₀ ε' with hP_def
+  have hP_prob : IsProbabilityMeasure P :=
+    adversarialMeasure_isProbabilityMeasure hε'_pos hε'_le hd
+  have hP_w : ∀ w ∈ W', P {w} = ENNReal.ofReal (8 * ε' / W'.card) :=
+    fun w hw => adversarialMeasure_singleton hw
+  have hP_supp : P (↑W : Set α)ᶜ = 0 := adversarialMeasure_support hw₀
+  -- Sample size bound in ℝ
+  have hm_real : (m : ℝ) < (W'.card : ℝ) / (32 * ε') := by
+    have hW'_eq : (W'.card : ℝ) = (W.card : ℝ) - 1 := by
+      rw [hW'_def, card_erase_of_mem hw₀]
+      simp [Nat.cast_sub (by omega : 1 ≤ W.card)]
+    rw [hW'_eq]
+    have h32ε_pos : (0 : ℝ) < 32 * ε' := by positivity
+    rw [← ENNReal.ofReal_natCast (n := m)] at hm
+    have hW_pos : (0 : ℝ) < (W.card : ℝ) - 1 := by
+      linarith [show (2 : ℝ) ≤ W.card from by exact_mod_cast hW_card]
+    rwa [ENNReal.ofReal_lt_ofReal_iff (div_pos hW_pos h32ε_pos)] at hm
+  haveI := hP_prob
+  set μ := Measure.pi (fun _ : Fin m => P) with hμ_def
+  refine ⟨P, hP_prob, ?_⟩
+  -- It suffices to find c with ∫ μ(fail(c)) ≥ 1/14
+  suffices ∃ c ∈ C, ENNReal.ofReal (1 / 14) ≤
+      ∫⁻ ω, μ {xs : Fin m → α |
+        hypothesisError P ((A ω) (sampleOf c xs)) c > ε} ∂Q by
+    obtain ⟨c, hc, hprob⟩ := this
+    exact ⟨c, hc, lt_of_lt_of_le hδ_lt hprob⟩
+  -- By contradiction: assume every concept's integrated failure mass is < 1/14
+  by_contra h_neg
+  push Not at h_neg
+  -- Choose concepts from shattering
+  have hcMap : ∀ S ∈ W'.powerset, ∃ c ∈ C, c ∩ ↑W = {w₀} ∪ ↑S := by
+    intro S hS
+    exact hW _ (Set.union_subset (Set.singleton_subset_iff.mpr (mem_coe.mpr hw₀))
+      ((coe_subset.mpr (Finset.mem_powerset.mp hS)).trans (coe_subset.mpr (erase_subset _ _))))
+  choose cMap hcMap_mem hcMap_eq using hcMap
+  set concepts : Finset α → Set α :=
+    fun S => if h : S ∈ W'.powerset then cMap S h else ∅ with hconcepts_def
+  have hconcepts_eq : ∀ S ∈ W'.powerset, concepts S ∩ ↑W = {w₀} ∪ ↑S := by
+    intro S hS; simp only [concepts, dif_pos hS]; exact hcMap_eq S hS
+  have hconcepts_mem : ∀ S ∈ W'.powerset, concepts S ∈ C := by
+    intro S hS; simp only [concepts, dif_pos hS]; exact hcMap_mem S hS
+  -- Markov bound: 1/2 < μ(B)
+  set B := {xs : Fin m → α | (seenElements W' xs).card ≤ d / 2}
+  have hB_prob : ENNReal.ofReal (1 / 2) < μ B :=
+    markov_bad_samples hw₀ hW_card hε'_pos hε'_le hm_real P hP_w
+  -- Per-ω lower bound via ehkv_sum_lower_bound, converted from ENNReal.ofReal ε' to ε
+  have hlower_ω : ∀ ω : Ω, (2 ^ d / 2 : ℕ) • μ B ≤
+      ∑ S ∈ W'.powerset, μ {xs : Fin m → α |
+        hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} := by
+    intro ω; simpa only [hε_eq] using
+      ehkv_sum_lower_bound hW_card hw₀ hε'_pos (A ω) P hP_w hP_supp concepts hconcepts_eq
+  -- Integrate over ω and swap sum/integral
+  have hintegrate : (2 ^ d / 2 : ℕ) • μ B ≤
+      ∑ S ∈ W'.powerset, ∫⁻ ω, μ {xs : Fin m → α |
+        hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} ∂Q := by
+    calc (2 ^ d / 2 : ℕ) • μ B
+        = ∫⁻ _ : Ω, ((2 ^ d / 2 : ℕ) • μ B : ℝ≥0∞) ∂Q := by
+          rw [lintegral_const, measure_univ, mul_one]
+      _ ≤ ∫⁻ ω, ∑ S ∈ W'.powerset, μ {xs : Fin m → α |
+            hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} ∂Q :=
+          lintegral_mono (fun ω => hlower_ω ω)
+      _ = ∑ S ∈ W'.powerset, ∫⁻ ω, μ {xs : Fin m → α |
+            hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} ∂Q :=
+          lintegral_finset_sum' _ (fun S hS => hA_aem P (concepts S) (hconcepts_mem S hS))
+  -- Upper bound from h_neg: each concept's integrated failure < 1/14
+  have hfail_bound : ∀ S ∈ W'.powerset,
+      ∫⁻ ω, μ {xs : Fin m → α |
+        hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} ∂Q <
+      ENNReal.ofReal (1 / 14) :=
+    fun S hS => h_neg (concepts S) (hconcepts_mem S hS)
+  have hpow_nonempty : W'.powerset.Nonempty := ⟨∅, Finset.empty_mem_powerset _⟩
+  -- Combined: (2^d/2) • μ(B) < 2^d • (1/14)
+  have h_combined : (2 ^ d / 2 : ℕ) • μ B < (2 ^ d : ℕ) • ENNReal.ofReal (1 / 14) :=
+    lt_of_le_of_lt hintegrate (by
+      calc ∑ S ∈ W'.powerset, ∫⁻ ω, μ {xs : Fin m → α |
+            hypothesisError P ((A ω) (sampleOf (concepts S) xs)) (concepts S) > ε} ∂Q
+          < ∑ _S ∈ W'.powerset, ENNReal.ofReal (1 / 14) :=
+            ENNReal.sum_lt_sum_of_nonempty hpow_nonempty hfail_bound
+        _ = (2 ^ d : ℕ) • ENNReal.ofReal (1 / 14) := by rw [Finset.sum_const, card_powerset])
+  exact ehkv_final_contradiction hd_pos hB_prob h_combined
+
+end EHKVProof
+
+end Cslib.MachineLearning
diff --git a/Cslib/MachineLearning/PACLearning/SampleComplexityLower/Helpers.lean b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/Helpers.lean
new file mode 100644
index 000000000..5b9c9ba4d
--- /dev/null
+++ b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/Helpers.lean
@@ -0,0 +1,209 @@
+/-
+Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Samuel Schlesinger
+-/
+
+module
+
+public import Cslib.MachineLearning.PACLearning.VCDimension
+
+@[expose] public section
+
+/-! # Sample Complexity Lower Bound — Helper Lemmas
+
+Generic reusable lemmas for product measures, sample functions, and
+combinatorics used in the EHKV lower bound proof.
+
+## Main definitions
+
+- `seenElements W' xs`: the elements of a `Finset` that appear in a sample
+
+## Main statements
+
+- `one_sub_pow_le_mul`: Bernoulli's inequality `1 - (1-x)^n ≤ n·x`
+- `sampleOf_eq_of_agree`: agreeing concepts yield the same labeled sample
+- `hypothesisError_eq_of_inter_eq`: error invariance under same intersection
+- `pi_measure_compl_zero`: product measure vanishes off `W^m`
+- `nullMeasurableSet_pi_of_finite_support`: null-measurability from finite support
+- `measurableSet_setOf_exists_pi_eq`: the set of samples containing a given point
+  is measurable
+- `measurable_seenElements_card`: `xs ↦ |seenElements W' xs|` is measurable
+- `expected_seenElements_le`: Bernoulli integration bound on seen elements
+-/
+
+open MeasureTheory Set Finset
+open scoped ENNReal
+
+noncomputable section
+
+namespace Cslib.MachineLearning
+
+open Classical in
+/-- The set of elements of a `Finset` that appear among the coordinates of a sample. -/
+noncomputable def seenElements {α : Type*} (W' : Finset α) {m : ℕ} (xs : Fin m → α) : Finset α :=
+  W'.filter (fun w => ∃ i, xs i = w)
+
+/-- Bernoulli's inequality: `1 - (1 - x) ^ n ≤ n * x` for `0 ≤ x ≤ 1`.
-/
+theorem one_sub_pow_le_mul {x : ℝ} (_hx : 0 ≤ x) (hx1 : x ≤ 1) (n : ℕ) :
+    1 - (1 - x) ^ n ≤ ↑n * x := by
+  have h : -1 ≤ (1 - x) := by linarith
+  linarith [one_add_mul_sub_le_pow h n]
+
+/-- Two concepts that agree on all sample points produce the same labeled
+sample. -/
+theorem sampleOf_eq_of_agree {α : Type*} {m : ℕ} {c₁ c₂ : Set α}
+    {xs : Fin m → α} (h : ∀ i, xs i ∈ c₁ ↔ xs i ∈ c₂) :
+    sampleOf c₁ xs = sampleOf c₂ xs := by
+  funext i; simp [sampleOf, h i]
+
+/-- Two concepts with the same intersection with `W` have the same hypothesis
+error when the measure `P` is supported on `W`. -/
+theorem hypothesisError_eq_of_inter_eq {α : Type*} [MeasurableSpace α]
+    {P : Measure α} {W : Set α} (hP_supp : P Wᶜ = 0)
+    {h₀ c₁ c₂ : Set α} (hinter : c₁ ∩ W = c₂ ∩ W) :
+    hypothesisError P h₀ c₁ = hypothesisError P h₀ c₂ := by
+  simp only [hypothesisError]
+  have hP_restrict : ∀ A : Set α, P A = P (A ∩ W) := by -- P only sees the part of a set inside W
+    intro A
+    have h1 : P A ≤ P (A ∩ W ∪ Wᶜ) :=
+      measure_mono (fun x hx => by_cases (fun hxW : x ∈ W => Or.inl ⟨hx, hxW⟩)
+        (fun hxW => Or.inr hxW))
+    exact le_antisymm ((h1.trans (measure_union_le _ _)).trans (by rw [hP_supp, add_zero]))
+      (measure_mono Set.inter_subset_left)
+  rw [hP_restrict (symmDiff h₀ c₁), hP_restrict (symmDiff h₀ c₂)]
+  have : symmDiff h₀ c₁ ∩ W = symmDiff h₀ c₂ ∩ W := by
+    ext x
+    simp only [Set.mem_inter_iff, Set.mem_symmDiff, and_congr_left_iff]
+    intro hxW
+    have hc_iff : x ∈ c₁ ↔ x ∈ c₂ :=
+      ⟨fun h => ((Set.ext_iff.mp hinter x).mp ⟨h, hxW⟩).1,
+       fun h => ((Set.ext_iff.mp hinter x).mpr ⟨h, hxW⟩).1⟩
+    tauto
+  rw [this]
+
+/-- If a measure `P` on `α` gives zero mass to the complement of a finite set `W`, then
+the product measure `P^m` gives zero mass to the complement of `W^m`.
-/
+theorem pi_measure_compl_zero
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α]
+    {W : Finset α} {P : Measure α} [SigmaFinite P]
+    (hP_supp : P (↑W : Set α)ᶜ = 0)
+    {m : ℕ} :
+    (Measure.pi (fun _ : Fin m => P))
+      {xs : Fin m → α | ∀ i, xs i ∈ (↑W : Set α)}ᶜ = 0 := by
+  set μ := Measure.pi (fun _ : Fin m => P)
+  set Wm := {xs : Fin m → α | ∀ i, xs i ∈ (↑W : Set α)}
+  have hsub : Wmᶜ ⊆ ⋃ i : Fin m, Function.eval i ⁻¹' (↑W : Set α)ᶜ := by -- some coordinate escapes W
+    intro xs hxs; simp only [Wm, Set.mem_compl_iff, Set.mem_setOf_eq, not_forall] at hxs
+    exact Set.mem_iUnion.mpr hxs
+  have hle : μ Wmᶜ ≤ 0 :=
+    calc μ Wmᶜ ≤ μ (⋃ i, Function.eval i ⁻¹' (↑W : Set α)ᶜ) := measure_mono hsub
+      _ ≤ ∑ i : Fin m, μ (Function.eval i ⁻¹' (↑W : Set α)ᶜ) :=
+          measure_iUnion_fintype_le μ _ -- union bound over the m coordinates
+      _ = ∑ _i : Fin m, (0 : ℝ≥0∞) := by
+          congr 1; ext i; exact Measure.pi_eval_preimage_null _ hP_supp
+      _ = 0 := Finset.sum_const_zero
+  exact le_antisymm hle (zero_le _)
+
+/-- If a measure `P` on `α` gives zero mass to the complement of a finite set `W`, then
+every set in the product space is `NullMeasurableSet`.
-/
+theorem nullMeasurableSet_pi_of_finite_support
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α]
+    {W : Finset α} {P : Measure α} [SigmaFinite P] (hP_supp : P (↑W : Set α)ᶜ = 0)
+    {m : ℕ} (S : Set (Fin m → α)) :
+    NullMeasurableSet S (Measure.pi (fun _ : Fin m => P)) := by
+  set μ := Measure.pi (fun _ : Fin m => P)
+  set Wm := {xs : Fin m → α | ∀ i, xs i ∈ (↑W : Set α)}
+  have hμ_supp : μ Wmᶜ = 0 := pi_measure_compl_zero hP_supp
+  have hWm_finite : Wm.Finite := Set.Finite.pi' (fun _ => W.finite_toSet)
+  have hAWm_meas : MeasurableSet (S ∩ Wm) :=
+    (hWm_finite.subset (show S ∩ Wm ⊆ Wm from fun _ h => h.2)).measurableSet
+  have hAWm_diff_null : μ (S \ Wm) = 0 :=
+    measure_mono_null (fun _ ⟨_, hx⟩ => hx) hμ_supp
+  have hA_eq : S = (S ∩ Wm) ∪ (S \ Wm) := by ext x; simp -- split S into a measurable part and a null part
+  rw [hA_eq]
+  exact hAWm_meas.nullMeasurableSet.union (NullMeasurableSet.of_null hAWm_diff_null)
+
+/-- The set of sample vectors in which point `w` appears equals the union of
+coordinate preimages `{xs | xs i = w}`. -/
+theorem setOf_exists_pi_eq_iUnion_preimage {α : Type*} {m : ℕ} (w : α) :
+    {xs : Fin m → α | ∃ i, xs i = w} =
+    ⋃ i : Fin m, (fun xs : Fin m → α => xs i) ⁻¹' {w} := by ext xs; simp
+
+/-- The set of sample vectors containing a given point is measurable. -/
+theorem measurableSet_setOf_exists_pi_eq
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] {m : ℕ} (w : α) :
+    MeasurableSet {xs : Fin m → α | ∃ i, xs i = w} := by
+  rw [setOf_exists_pi_eq_iUnion_preimage w]
+  exact MeasurableSet.iUnion (fun i => measurable_pi_apply i (MeasurableSet.singleton w))
+
+open Classical in
+/-- The cardinality of `seenElements W' xs` as an extended non-negative real equals a finset sum
+of indicator functions over `W'`.
-/
+theorem seenElements_card_eq_sum {α : Type*} {m : ℕ} (W' : Finset α) :
+    (fun xs : Fin m → α => ((seenElements W' xs).card : ℝ≥0∞)) =
+    (fun xs => ∑ w ∈ W', if (∃ i, xs i = w) then (1 : ℝ≥0∞) else 0) := by
+  ext xs; simp only [seenElements, Finset.card_filter]; push_cast; rfl -- card of a filter = sum of indicators
+
+/-- The function `xs ↦ |seenElements W' xs|` is measurable with respect to the
+product σ-algebra. -/
+theorem measurable_seenElements_card
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] {m : ℕ} (W' : Finset α) :
+    Measurable (fun xs : Fin m → α => ((seenElements W' xs).card : ℝ≥0∞)) := by
+  rw [seenElements_card_eq_sum W']
+  exact Finset.measurable_sum W' (fun w _ =>
+    Measurable.ite (measurableSet_setOf_exists_pi_eq w) measurable_const measurable_const)
+
+open Classical in
+/-- **Bernoulli integration bound**: the expected number of elements of `W'`
+seen in a random sample of size `m` is at most `|W'| · m · p`, when each
+element of `W'` has probability `p ≤ 1` under the base measure `P`.
+Follows from summing Bernoulli's inequality `1 - (1-p)^m ≤ m·p` over `W'`.
-/
+theorem expected_seenElements_le
+    {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α]
+    {W' : Finset α} {p : ℝ} (hp_nonneg : 0 ≤ p) (hp_le_one : p ≤ 1)
+    {m : ℕ} (P : Measure α) [IsProbabilityMeasure P]
+    (hP_w : ∀ w ∈ W', P {w} = ENNReal.ofReal p) :
+    ∫⁻ xs, ((seenElements W' xs).card : ℝ≥0∞)
+      ∂(Measure.pi (fun _ : Fin m => P))
+    ≤ ENNReal.ofReal (↑W'.card * (↑m * p)) := by
+  set μ := Measure.pi (fun _ : Fin m => P)
+  have h1p_nonneg : (0 : ℝ) ≤ 1 - p := by linarith
+  -- Rewrite lintegral as sum of measures
+  have hstep1 : ∫⁻ xs, ((seenElements W' xs).card : ℝ≥0∞) ∂μ
+      = ∑ w ∈ W', μ {xs : Fin m → α | ∃ i, xs i = w} := by
+    rw [seenElements_card_eq_sum W',
+      lintegral_finset_sum' W' (fun w _ =>
+        (Measurable.ite (measurableSet_setOf_exists_pi_eq w) measurable_const
+          measurable_const).aemeasurable)]
+    congr 1; ext w
+    rw [show (fun xs : Fin m → α => if (∃ i, xs i = w) then (1 : ℝ≥0∞) else 0) =
+        ({xs : Fin m → α | ∃ i, xs i = w}).indicator 1 from by ext; simp [indicator]]
+    exact lintegral_indicator_one (measurableSet_setOf_exists_pi_eq w)
+  rw [hstep1]
+  -- Bound each term using Bernoulli inequality
+  calc ∑ w ∈ W', μ {xs | ∃ i, xs i = w}
+      ≤ ∑ _w ∈ W', ENNReal.ofReal (↑m * p) := by
+        apply Finset.sum_le_sum; intro w hw
+        have hcompl_eq : μ {xs : Fin m → α | ∃ i, xs i = w}ᶜ = (P {w}ᶜ) ^ m := by -- coordinates are independent
+          have : {xs : Fin m → α | ∃ i, xs i = w}ᶜ =
+              Set.pi Set.univ (fun _ : Fin m => ({w} : Set α)ᶜ) := by
+            ext xs; simp [Set.mem_pi]
+          rw [this, Measure.pi_pi]
+          simp [Finset.prod_const, Finset.card_univ, Fintype.card_fin]
+        have hseen : μ {xs | ∃ i, xs i = w} = 1 - (P {w}ᶜ) ^ m := by
+          have h2 := prob_compl_eq_one_sub (μ := μ) (measurableSet_setOf_exists_pi_eq w).compl
+          rw [compl_compl] at h2; rw [h2, hcompl_eq]
+        have hPwc : P {w}ᶜ = ENNReal.ofReal (1 - p) := by
+          rw [prob_compl_eq_one_sub (MeasurableSet.singleton w), hP_w w hw,
+            ← ENNReal.ofReal_one]
+          exact (ENNReal.ofReal_sub 1 hp_nonneg).symm
+        rw [hseen, hPwc, ← ENNReal.ofReal_pow h1p_nonneg,
+          ← ENNReal.ofReal_one, ← ENNReal.ofReal_sub 1 (pow_nonneg h1p_nonneg _)]
+        exact ENNReal.ofReal_le_ofReal (one_sub_pow_le_mul hp_nonneg hp_le_one m)
+    _ = ENNReal.ofReal (↑W'.card * (↑m * p)) := by
+        rw [Finset.sum_const, nsmul_eq_mul,
+          ← ENNReal.ofReal_natCast (n := W'.card),
+          ← ENNReal.ofReal_mul (by exact_mod_cast Nat.zero_le W'.card)]
+
+end Cslib.MachineLearning
diff --git a/Cslib/MachineLearning/PACLearning/SampleComplexityLower/InvolutionPairing.lean b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/InvolutionPairing.lean
new file mode 100644
index 000000000..51153d22a
--- /dev/null
+++ b/Cslib/MachineLearning/PACLearning/SampleComplexityLower/InvolutionPairing.lean
@@ -0,0 +1,165 @@
+/-
+Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Samuel Schlesinger
+-/
+
+module
+
+public import Cslib.MachineLearning.PACLearning.SampleComplexityLower.Helpers
+
+@[expose] public section
+
+/-! # Involution Pairing Argument
+
+The combinatorial core of the EHKV lower bound proof. For each "bad" sample
+(one that reveals at most half the shattered set), an involution on
+`2^d` concepts pairs each concept with its complement on the unseen
+points. At least one concept per pair forces large error.
+
+## Main statements
+
+- `involution_half_count`: an involution where every pair has a "failing"
+  element implies at least half the elements fail.
+- `cMap_sample_agree`: two concepts agreeing on seen elements yield the
+  same labeled sample.
+- `unseen_measure_ge`: the measure of the unseen set is at least `4ε'`.
+- `complementary_error_contradiction`: two complementary errors can't
+  both be `≤ ε'`.
+
+## References
+
+* [A. Ehrenfeucht, D. Haussler, M. Kearns, L.
Valiant,
+  *A General Lower Bound on the Number of Examples Needed
+  for Learning*][EHKV1989]
+-/
+
+open MeasureTheory Set Finset
+open scoped ENNReal
+
+noncomputable section
+
+namespace Cslib.MachineLearning
+
+section InvolutionPairing
+
+variable {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α]
+
+open Classical in
+omit [MeasurableSpace α] [MeasurableSingletonClass α] in
+/-- If `σ` is an involution on a finset `F` that maps `F` into itself, and
+every element is paired with a "failing" partner (i.e., `P x ∨ P (σ x)`),
+then at least half the elements satisfy `P`. -/
+theorem involution_half_count {ι : Type*}
+    {F : Finset ι} {σ : ι → ι}
+    (hσ_self : ∀ x ∈ F, σ (σ x) = x)
+    (hσ_mem : ∀ x ∈ F, σ x ∈ F)
+    {P : ι → Prop}
+    (hpair : ∀ x ∈ F, P x ∨ P (σ x)) :
+    F.card / 2 ≤ (F.filter P).card := by
+  set G := F \ F.filter P
+  have hG_sub : G ⊆ F := sdiff_subset
+  -- σ maps each non-failing element of G to a failing element of F.filter P
+  have hσ_G_to_F : ∀ x ∈ G, σ x ∈ F.filter P := by
+    intro x hxG
+    have hx_pw := hG_sub hxG
+    have hx_nF : x ∉ F.filter P := (Finset.mem_sdiff.mp hxG).2
+    have hx_not_P : ¬P x := fun h => hx_nF (Finset.mem_filter.mpr ⟨hx_pw, h⟩)
+    rcases hpair x hx_pw with h | h
+    · exact absurd h hx_not_P
+    · exact Finset.mem_filter.mpr ⟨hσ_mem x hx_pw, h⟩
+  -- σ is injective on G (since σ ∘ σ = id on F)
+  have hσ_inj_G : Set.InjOn σ (G : Set ι) := by
+    intro x₁ hx₁ x₂ hx₂ hσeq
+    calc x₁ = σ (σ x₁) := (hσ_self x₁ (hG_sub hx₁)).symm
+      _ = σ (σ x₂) := by rw [hσeq]
+      _ = x₂ := hσ_self x₂ (hG_sub hx₂)
+  -- |G| ≤ |F.filter P| via the injection σ
+  have hcard : G.card ≤ (F.filter P).card :=
+    (card_image_of_injOn hσ_inj_G) ▸
+      card_le_card (fun S hS => let ⟨_, hG, heq⟩ := Finset.mem_image.mp hS; heq ▸ hσ_G_to_F _ hG)
+  -- |F| = |G| + |F.filter P| ≤ 2 * |F.filter P|
+  have hpow_eq : G.card + (F.filter P).card = F.card :=
+    card_sdiff_add_card_eq_card (filter_subset _ _)
+  omega
+
+open Classical in
+omit [MeasurableSpace α] [MeasurableSingletonClass α] in
+/-- Two `cMap` concepts agree on all sample points when their underlying subsets
+have the same intersection with the seen elements `T = seenElements W' xs`. -/
+theorem cMap_sample_agree
+    {W W' : Finset α} {w₀ : α}
+    {m : ℕ} {xs : Fin m → α}
+    (cMap : (S : Finset α) → S ∈ W'.powerset → Set α)
+    (hcMap_eq : ∀ S (hS : S ∈ W'.powerset), cMap S hS ∩ ↑W = {w₀} ∪ ↑S)
+    {S₁ S₂ : Finset α}
+    (hS₁ : S₁ ∈ W'.powerset) (hS₂ : S₂ ∈ W'.powerset)
+    (hinter : S₁ ∩ seenElements W' xs = S₂ ∩ seenElements W' xs)
+    (hxs : ∀ i, xs i ∈ (↑W : Set α)) :
+    ∀ i, xs i ∈ cMap S₁ hS₁ ↔ xs i ∈ cMap S₂ hS₂ := by
+  set T := seenElements W' xs
+  intro i
+  have hxiW := hxs i
+  have step : ∀ {Sa Sb : Finset α} (hSa : Sa ∈ W'.powerset) (hSb : Sb ∈ W'.powerset),
+      Sa ∩ T = Sb ∩ T → xs i ∈ cMap Sa hSa → xs i ∈ cMap Sb hSb := by
+    intro Sa Sb hSa hSb hint hxi
+    have hxi_inter : xs i ∈ cMap Sa hSa ∩ ↑W := ⟨hxi, hxiW⟩
+    rw [hcMap_eq] at hxi_inter
+    rcases hxi_inter with hw0 | hxiSa
+    · -- xs i = w₀, which belongs to every concept in the family
+      have : xs i ∈ ({w₀} ∪ ↑Sb : Set α) := Or.inl hw0
+      rw [← hcMap_eq Sb hSb] at this; exact this.1
+    · -- xs i ∈ Sa, so xs i ∈ T (seen), so xs i ∈ Sa ∩ T = Sb ∩ T, so xs i ∈ Sb
+      have hxiW' : xs i ∈ W' := (Finset.mem_powerset.mp hSa) (mem_coe.mp hxiSa)
+      have hxiT : xs i ∈ T := Finset.mem_filter.mpr ⟨hxiW', ⟨i, rfl⟩⟩
+      have hxiSb : xs i ∈ Sb :=
+        (Finset.mem_inter.mp (hint ▸ Finset.mem_inter.mpr ⟨mem_coe.mp hxiSa, hxiT⟩)).1
+      have : xs i ∈ ({w₀} ∪ ↑Sb : Set α) := Or.inr (mem_coe.mpr hxiSb)
+      rw [← hcMap_eq Sb hSb] at this; exact this.1
+  exact ⟨step hS₁ hS₂ hinter, step hS₂ hS₁ hinter.symm⟩
+
+/-- The measure of an unseen set `U` is at least `4ε'` when each point has measure
+`8ε'/d` and `|U| ≥ d/2`. This is the common measure lower bound used in the
+counting argument and involution pairing.
-/
+theorem unseen_measure_ge {U : Finset α} {d : ℕ} {ε' : ℝ} {P : Measure α}
+    (hε'_pos : 0 < ε') (hd_pos : 0 < d) (h2U : d ≤ 2 * U.card)
+    (hP_each : ∀ w ∈ U, P {w} = ENNReal.ofReal (8 * ε' / ↑d)) :
+    ENNReal.ofReal (4 * ε') ≤ P (↑U) := by
+  have hU_eq : (↑U : Set α) = ⋃ w ∈ U, ({w} : Set α) := by ext x; simp
+  rw [hU_eq, measure_biUnion_finset
+    (fun w _ w' _ hww' => Set.disjoint_singleton.mpr hww')
+    (fun w _ => MeasurableSet.singleton w)] -- P(U) = ∑_{w ∈ U} P {w} since singletons are disjoint
+  rw [Finset.sum_congr rfl hP_each, Finset.sum_const, nsmul_eq_mul,
+    ← ENNReal.ofReal_natCast (n := U.card),
+    ← ENNReal.ofReal_mul (by positivity)]
+  apply ENNReal.ofReal_le_ofReal
+  have hd_cast : (0 : ℝ) < d := Nat.cast_pos.mpr hd_pos
+  calc 4 * ε' = (d : ℝ) / 2 * (8 * ε' / d) := by field_simp; ring
+    _ ≤ (U.card : ℝ) * (8 * ε' / d) := by
+        apply mul_le_mul_of_nonneg_right _ (by positivity)
+        linarith [show (d : ℝ) ≤ 2 * (U.card : ℝ) from by exact_mod_cast h2U]
+
+omit [MeasurableSingletonClass α] in
+/-- If two sets' symmetric differences with a hypothesis cover a set of measure `≥ 4ε'`,
+but each symmetric difference has measure `≤ ε'`, we derive a contradiction.
+This is the core contradiction in the EHKV counting and pairing arguments.
-/
+theorem complementary_error_contradiction {P : Measure α} {h c₁ c₂ : Set α}
+    {U : Set α} {ε' : ℝ} (hε'_pos : 0 < ε')
+    (hU_sub : U ⊆ symmDiff h c₁ ∪ symmDiff h c₂)
+    (hP_U : ENNReal.ofReal (4 * ε') ≤ P U)
+    (herr₁ : P (symmDiff h c₁) ≤ ENNReal.ofReal ε')
+    (herr₂ : P (symmDiff h c₂) ≤ ENNReal.ofReal ε') : False := by
+  have h_contra : ENNReal.ofReal (4 * ε') ≤ ENNReal.ofReal (2 * ε') := -- 4ε' ≤ 2ε' is absurd for ε' > 0
+    calc ENNReal.ofReal (4 * ε')
+        ≤ P U := hP_U
+      _ ≤ P (symmDiff h c₁ ∪ symmDiff h c₂) := measure_mono hU_sub
+      _ ≤ P (symmDiff h c₁) + P (symmDiff h c₂) := measure_union_le _ _
+      _ ≤ ENNReal.ofReal ε' + ENNReal.ofReal ε' := add_le_add herr₁ herr₂
+      _ = ENNReal.ofReal (2 * ε') := by
+          rw [← ENNReal.ofReal_add hε'_pos.le hε'_pos.le]; ring_nf
+  rw [ENNReal.ofReal_le_ofReal_iff (by linarith)] at h_contra
+  linarith
+
+end InvolutionPairing
+
+end Cslib.MachineLearning
diff --git a/Cslib/MachineLearning/PACLearning/VCDimension.lean b/Cslib/MachineLearning/PACLearning/VCDimension.lean
new file mode 100644
index 000000000..f562420de
--- /dev/null
+++ b/Cslib/MachineLearning/PACLearning/VCDimension.lean
@@ -0,0 +1,108 @@
+/-
+Copyright (c) 2026 Samuel Schlesinger. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Samuel Schlesinger
+-/
+
+module
+
+public import Cslib.MachineLearning.PACLearning.Defs
+public import Mathlib.Combinatorics.SetFamily.Shatter
+
+@[expose] public section
+
+/-! # VC Dimension for Concept Classes
+
+This file defines *shattering* and the *Vapnik-Chervonenkis dimension* for
+concept classes modeled as `Set (Set α)`. See also the `Finset`-based
+definitions in `Mathlib.Combinatorics.SetFamily.Shatter`.
+
+The VC dimension of a concept class `C` is the cardinality of the largest
+finite set that `C` shatters: a set `W` is shattered by `C` if every
+subset of `W` can be obtained as the intersection of `W` with some
+concept in `C`.
+
+## Main definitions
+
+- `SetShatters C W`: the concept class `C` shatters the set `W`.
+- `vcDim C`: the VC dimension of `C`, i.e., the supremum of the
+  cardinalities of finite sets shattered by `C`.
+
+## Main statements
+
+- `SetShatters.subset`: shattering is anti-monotone in the shattered set.
+- `SetShatters.superset`: shattering is monotone in the concept class.
+- `Finset.Shatters.toSetShatters`: bridge from Mathlib's `Finset.Shatters`
+  to `SetShatters`.
+
+## References
+
+* [A. Ehrenfeucht, D. Haussler, M. Kearns, L. Valiant,
+  *A General Lower Bound on the Number of Examples Needed
+  for Learning*][EHKV1989]
+-/
+
+open Set
+
+namespace Cslib.MachineLearning
+
+/-- A concept class `C` *shatters* a set `W` if for every subset `W'`
+of `W`, there exists a concept `c ∈ C` such that `c ∩ W = W'`. -/
+def SetShatters (C : ConceptClass α) (W : Set α) : Prop :=
+  ∀ W' ⊆ W, ∃ c ∈ C, c ∩ W = W'
+
+/-- Shattering is anti-monotone in the shattered set: if `C` shatters
+`W` and `V ⊆ W`, then `C` shatters `V`. -/
+theorem SetShatters.subset {C : ConceptClass α} {W V : Set α}
+    (hW : SetShatters C W) (hVW : V ⊆ W) : SetShatters C V := by
+  intro V' hV'V
+  -- We need c ∈ C with c ∩ V = V'.
+  -- Use that C shatters W: pick c with c ∩ W = V' ∪ (W \ V).
+  have hV'W : V' ⊆ W := hV'V.trans hVW
+  have hsub : V' ∪ (W \ V) ⊆ W :=
+    union_subset hV'W diff_subset
+  obtain ⟨c, hc, hc_eq⟩ := hW (V' ∪ (W \ V)) hsub
+  refine ⟨c, hc, ?_⟩
+  ext x; simp only [mem_inter_iff]
+  constructor
+  · rintro ⟨hxc, hxV⟩
+    have := (hc_eq ▸ (⟨hxc, hVW hxV⟩ : x ∈ c ∩ W) : x ∈ V' ∪ (W \ V))
+    exact this.elim id (fun ⟨_, h⟩ => absurd hxV h) -- x ∈ V rules out the W \ V branch
+  · intro hxV'
+    exact ⟨(hc_eq ▸ (Or.inl hxV' : x ∈ V' ∪ (W \ V)) : x ∈ c ∩ W).1, hV'V hxV'⟩
+
+/-- Shattering is monotone in the concept class: if `C` shatters
+`W` and `C ⊆ C'`, then `C'` shatters `W`.
-/
+theorem SetShatters.superset {C C' : ConceptClass α} {W : Set α}
+    (hW : SetShatters C W) (hCC' : C ⊆ C') : SetShatters C' W := by
+  intro W' hW'
+  obtain ⟨c, hc, hcW⟩ := hW W' hW'
+  exact ⟨c, hCC' hc, hcW⟩ -- the same witness works in the larger class
+
+open Classical in
+/-- If a finite set family `𝒜` shatters a finite set `s` in the sense of
+Mathlib's `Finset.Shatters`, then the coerced concept class shatters `↑s`
+in the sense of `SetShatters`. This bridges Mathlib's finset-based shattering
+to the set-based notion used by the PAC learning lower bounds. -/
+theorem Finset.Shatters.toSetShatters {𝒜 : Finset (Finset α)} {s : Finset α}
+    (h : 𝒜.Shatters s) :
+    SetShatters {c : Set α | ∃ t ∈ 𝒜, (↑t : Set α) = c} ↑s := by
+  intro W' hW'
+  have hfin : Set.Finite W' := s.finite_toSet.subset hW'
+  set t := hfin.toFinset
+  have ht_eq : (↑t : Set α) = W' := hfin.coe_toFinset
+  have ht_sub : t ⊆ s := Finset.coe_subset.mp (ht_eq ▸ hW')
+  obtain ⟨u, hu, hsu⟩ := h ht_sub
+  have hut : u ∩ s = t := by rwa [Finset.inter_comm] at hsu
+  exact ⟨↑u, ⟨u, hu, rfl⟩, by rw [← ht_eq]; exact_mod_cast hut⟩
+
+/-- The *Vapnik-Chervonenkis dimension* of a concept class `C` is the
+supremum of the cardinalities of finite sets shattered by `C`.
+Returns `0` when no finite set is shattered (i.e., the defining set is empty).
+
+**Caveat**: because `sSup` on `ℕ` returns `0` for unbounded sets, this definition
+is only meaningful when the VC dimension is finite.
-/
+noncomputable def vcDim (C : ConceptClass α) : ℕ :=
+  sSup {n : ℕ | ∃ W : Finset α, W.card = n ∧ SetShatters C (↑W)} -- sSup on ℕ is 0 for ∅ and for unbounded sets
+
+end Cslib.MachineLearning
diff --git a/references.bib b/references.bib
index 2cccb928f..799df1110 100644
--- a/references.bib
+++ b/references.bib
@@ -243,6 +243,29 @@ @incollection{ Thomas1990
   year = {1990}
 }
 
+@article{ Valiant1984,
+  author = {Valiant, Leslie G.},
+  title = {A Theory of the Learnable},
+  journal = {Communications of the ACM},
+  volume = {27},
+  number = {11},
+  pages = {1134--1142},
+  year = {1984},
+  doi = {10.1145/1968.1972}
+}
+
+@article{ EHKV1989,
+  author = {Ehrenfeucht, Andrzej and Haussler, David and Kearns, Michael and Valiant, Leslie},
+  title = {A General Lower Bound on the Number of Examples Needed for Learning},
+  journal = {Information and Computation},
+  volume = {82},
+  number = {3},
+  pages = {247--261},
+  year = {1989},
+  doi = {10.1016/0890-5401(89)90002-3},
+  issn = {0890-5401}
+}
+
 @book{ Cutland1980,
   author = {Cutland, Nigel J.},
   title = {Computability: An Introduction to Recursive Function Theory},