From 7f9641f70353306b4eb3ad770a771fa7a71b97f7 Mon Sep 17 00:00:00 2001 From: Alix Trieu Date: Mon, 27 Apr 2026 20:26:24 +0200 Subject: [PATCH 1/2] Feat: Signed Barrett Reduction algorithm --- Cslib.lean | 1 + .../Algorithms/BarrettReduction/Aux.lean | 199 ++++++++++++ .../Algorithms/BarrettReduction/Signed.lean | 301 ++++++++++++++++++ 3 files changed, 501 insertions(+) create mode 100644 Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean create mode 100644 Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean diff --git a/Cslib.lean b/Cslib.lean index 7db43680b..0ea0c77b4 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -36,6 +36,7 @@ public import Cslib.Computability.URM.Defs public import Cslib.Computability.URM.Execution public import Cslib.Computability.URM.StandardForm public import Cslib.Computability.URM.StraightLine +public import Cslib.Crypto.Algorithms.BarrettReduction.Signed public import Cslib.Crypto.Protocols.PerfectSecrecy.Basic public import Cslib.Crypto.Protocols.PerfectSecrecy.Defs public import Cslib.Crypto.Protocols.PerfectSecrecy.Encryption diff --git a/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean b/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean new file mode 100644 index 000000000..ec517935c --- /dev/null +++ b/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean @@ -0,0 +1,199 @@ +/- +Copyright (c) 2026 Alix Trieu. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Alix Trieu +-/ + +module + +public import Cslib.Init +public import Mathlib.Data.Nat.Log +public import Mathlib.Algebra.Order.Round +public import Mathlib.Data.Rat.Floor +public import Mathlib.Algebra.Order.Floor.Defs +public import Mathlib.Data.Int.DivMod +import Mathlib.Tactic + +/- +# Auxiliary definitions and lemmas + +- Defines `clog2`, a base 2 upper logarithm and some associated lemmas +- Additional facts about `bmod`, `floor` and `round` +-/ + +@[expose] +public section + +namespace Nat + +def clog2 : ℕ → ℕ := Nat.clog 2 + +lemma le_clog2_self (n : ℕ) : + n ≤ 2 ^ (n.clog2) := by + apply le_pow_clog (by simp) n + +lemma log2_le_clog2 (n : ℕ) : + n.log2 ≤ n.clog2 := by + rw [log2_eq_log_two] + apply Nat.log_le_clog 2 n + +lemma le_pow_iff_clog2_le {x y : ℕ} : + x ≤ 2 ^ y ↔ clog2 x ≤ y := + by symm; apply Nat.clog_le_iff_le_pow; simp + +lemma clog2_le_log2 (n : ℕ) : + n.clog2 ≤ n.log2 + 1 := by + rw [log2_eq_log_two] + rw [← le_pow_iff_clog2_le] + apply le_of_lt + cases n with + | zero => simp + | succ n => + rw [← log2_eq_log_two, ← Nat.log2_lt (by simp)] + simp + +lemma clog2_eq (n : ℕ) : + n.clog2 = if 2 ^ n.log2 < n then n.log2 + 1 else n.log2 := by + have H₀ := clog2_le_log2 n + have H₁ := log2_le_clog2 n + split_ifs with Hcond <;> rw [← Nat.lt_clog_iff_pow_lt (by simp), ← clog2] at Hcond <;> linarith + +end Nat + +namespace Int + +lemma abs_bmod_le (x : ℤ) (m : ℕ) (Hm : 0 < m) : + |x.bmod m| ≤ m / 2 := by + rw [abs_le]; apply And.intro + · apply Int.le_bmod Hm + · transitivity + · apply Int.bmod_le Hm + · omega + +lemma bmod_eq' (x : ℤ) (m : ℕ) : + x.bmod m = x - m * (round (x / (m: ℚ))) := by + rw [round_eq, Int.bmod] + have X: x % m < (m + 1) / 2 ↔ 2 * (x % m) < m := by omega + cases Nat.eq_zero_or_pos m with + | inl Hz => rw [Hz]; simp + | inr Hpos => + rw [div_add_div] <;> + simp only [mul_one, Nat.cast_eq_zero, ne_eq, OfNat.ofNat_ne_zero, not_false_eq_true] <;> + try linarith + split_ifs with Hcond <;> rw [X] at Hcond + · rw [Int.emod_def]; simp only [sub_right_inj, _root_.mul_eq_mul_left_iff, natCast_eq_zero] + left; rw [show m * (2:ℚ) = ↑(2 * m) by simp; linarith] + rw [show x * 2 + (m:ℚ) = ↑(2 * x + m) by simp; linarith] + rw [Rat.floor_intCast_div_natCast]; symm + apply ((@Int.ediv_emod_unique _ _ (2 * (x % m) + m) _ (by omega)).mpr ?_).left + apply And.intro + · nth_rw 3 [← Int.mul_ediv_add_emod x m]; simp + linarith + · have X := @Int.emod_nonneg x m (by omega) + simp only [Nat.cast_mul, Nat.cast_ofNat]; apply And.intro <;> linarith + · rw [show m * (2:ℚ) = ↑(2 * m) by simp; linarith] + rw [show x * 2 + (m:ℚ) = ↑(2 * x + m) by simp; linarith] + rw [Rat.floor_intCast_div_natCast] + rw [Int.emod_def]; simp only [Nat.cast_mul, Nat.cast_ofNat] + nth_rw 3 [← mul_one m] + rw [Int.sub_sub, Nat.cast_mul, ← mul_add]; simp only [Nat.cast_one, sub_right_inj, + _root_.mul_eq_mul_left_iff, natCast_eq_zero] + left; symm + apply ((@Int.ediv_emod_unique _ _ (2 * (x % m) - m) _ (by omega)).mpr ?_).left + apply And.intro + · nth_rw 3 [← Int.mul_ediv_add_emod x m] + linarith + · have X := @Int.emod_lt_of_pos x m (by omega) + simp only [Int.sub_nonneg]; apply And.intro <;> try linarith + +lemma emod_def' (x : ℤ) (m : ℕ) : + x % ↑m = if x.bmod m < 0 then m + x.bmod m else x.bmod m := by + simp [Int.bmod_def] + split_ifs <;> try omega + · cases Nat.eq_zero_or_pos m with + | inl Hz => rw [Hz]; simp + | inr Hpos => + have X := @Int.emod_nonneg x m (by omega); linarith + · cases Nat.eq_zero_or_pos m with + | inl Hz => rw [Hz]; simp + | inr Hpos => + have X := @Int.emod_lt_of_pos x m (by omega); linarith + +lemma bmod_eq_of_abs_lt {n : ℤ} {m : ℕ} (hlt : |n| < m / 2) : + n.bmod m = n := by + rw [abs_lt] at hlt + apply Int.bmod_eq_of_le <;> omega + +lemma bmod_bmod_eq_of_le {x : ℤ} {m1 m2 : ℕ} (h : 0 < m1) (h' : m1 ≤ m2) : + (x.bmod m1).bmod m2 = x.bmod m1 := by + have X0 := @Int.le_bmod x m1 h + have X1 := @Int.bmod_le x m1 h + rw [@Int.bmod_eq_of_le _ m2] <;> omega + +lemma bmod_bmod_eq_of_lt {x : ℤ} {m1 m2 : ℕ} (h : 0 < m1) (h' : m1 < m2) : + (x.bmod m1).bmod m2 = x.bmod m1 := by + rw [bmod_bmod_eq_of_le] <;> omega +end Int + +end + +@[expose] +public section + +variable {α : Type*} +variable [Field α] [LinearOrder α] [IsStrictOrderedRing α] [FloorRing α] + +lemma floor_sub_abs (a b : α) : + |⌊a⌋ - ⌊b⌋| ≤ ⌈|a - b|⌉ := by + wlog Hab: a ≥ b + · rw [abs_sub_comm ⌊a⌋, abs_sub_comm a] + apply this; apply le_of_not_ge at Hab; assumption + · rw [abs_of_nonneg, abs_of_nonneg] <;> + [skip; linarith; (simp only [Int.sub_nonneg]; apply Int.floor_mono; assumption)] + nth_rw 2 [← Int.fract_add_floor a] + nth_rw 2 [← Int.fract_add_floor b] + rw [show (Int.fract a + ↑⌊a⌋ - (Int.fract b + ↑⌊b⌋)=(Int.fract a - Int.fract b) + ↑(⌊a⌋ - ⌊b⌋)) + by rw [Int.cast_sub]; linarith] + rw [Int.ceil_add_intCast]; simp only [le_add_iff_nonneg_left] + rw [show (0 = -1 + 1) by omega] + apply Int.add_one_le_of_lt + rw [Int.lt_ceil]; simp + have Ha₀: 0 ≤ Int.fract a := by apply Int.fract_nonneg + have Hb₁: Int.fract b < 1 := by apply Int.fract_lt_one + linarith + +lemma floor_lt_iff (a b : α) : + ⌊a⌋ < ⌊b⌋ ↔ ∃ (n: ℤ), a < ↑n ∧ ↑n ≤ b := by + apply Iff.intro + · intro H; cases lt_or_ge a ↑⌊b⌋ with + | inl Hlt => use ↑⌊b⌋; apply And.intro + · assumption + · exact Int.floor_le b + | inr Hge => + apply Int.le_floor.mpr at Hge; linarith + · intro ⟨n, Ha, Hb⟩ + have H := Int.floor_le_floor Hb + rw [Int.floor_intCast] at H + apply @lt_of_lt_of_le _ _ _ n + · exact Int.floor_lt.mpr Ha + · assumption + +lemma round_sub_abs (a b : α) : + |round a - round b| ≤ ⌈|a - b|⌉ := by + rw [round_eq, round_eq] + rw [show (a - b = (a + 1/2) - (b + 1/2)) by linarith] + apply floor_sub_abs + +lemma round_lt_iff (a b : α) : + round a < round b ↔ ∃ (n: ℤ), a < n + 1/2 ∧ n + 1/2 ≤ b := by + apply Iff.intro + · rw [round_eq, round_eq]; intro H + rw [floor_lt_iff] at H + let ⟨n, Ha, Hb⟩ := H + use (n - 1); apply And.intro <;> (simp; linarith) + · intro ⟨n, Ha, Hb⟩ + rw [round_eq, round_eq] + rw [floor_lt_iff] + use (n + 1); apply And.intro <;> (simp; linarith) + +end diff --git a/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean b/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean new file mode 100644 index 000000000..f28e0060d --- /dev/null +++ b/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean @@ -0,0 +1,301 @@ +/- +Copyright (c) 2026 Alix Trieu. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Alix Trieu +-/ + +module + +public import Cslib.Init +public import Mathlib.Data.Nat.Log +public import Mathlib.Data.Rat.Init +public import Mathlib.Data.Rat.Floor +public import Mathlib.Algebra.Order.Field.Rat +public import Mathlib.Algebra.Order.Round +import Mathlib.Tactic +public import Cslib.Crypto.Algorithms.BarrettReduction.Aux + +/- +# Signed Barrett Reduction + +This file formalizes signed variant of the Barrett reduction algorithm used in many +schemes such as ML-DSA or ML-KEM. + +This formalization is inspired by Section 2.4 of the following paper +Efficient Multiplication of Somewhat Small Integers Using Number-Theoretic Transforms +Hanno Becker, Vincent Hwang, Matthias J. Kannwischer, Lorenz Panny, and Bo-Yin Yang +IWSEC 2022 + +The main theorem is `barrett_reduce_spec`. + +See example at the end of file for how to use it. +-/ + +namespace Cslib.Crypto.Algorithms.BarrettReduction.Signed + +notation "⌊" x "⌉" => round (x : ℚ) + +def is_approx (δ : ℚ) (α : ℚ → ℤ) : Prop := + ∀ (x: ℚ), |(x - α x)| ≤ δ + +lemma round_is_approx : is_approx (1/2) round := by + intro x; apply abs_sub_round + +def round_to_even (x : ℚ) : ℤ := + 2 * ⌊(x / 2)⌉ + +def mod_approx (α : ℚ → ℤ) (x : ℤ) (N : ℕ) : ℤ := x - ↑N * (α (x/N)) + +public def smod (x : ℤ) (N : ℕ) : ℤ := mod_approx round x N + +notation x "mod±" N => smod x N + +lemma smod_is_bmod (x : ℤ) (N : ℕ) : + (x mod± N) = (x.bmod N) := by + rw [Int.bmod_eq_self_sub_mul_bdiv, smod, mod_approx] + rw [Int.bdiv]; split_ifs with HN + · rw [HN]; simp + · simp only [mul_ite, sub_right_inj] + rw [round_eq, show (↑x / ↑N + 1 / (2:ℚ)) = (↑(2 * x + N) / ↑(2 * N)) by + rw [← Rat.mkRat_eq_div, ← Rat.mkRat_eq_div] + rw [show (1/2:ℚ) = mkRat 1 2 by cbv] + rw [Rat.mkRat_add_mkRat] <;> try omega + rw [Rat.mkRat_eq_iff] <;> try omega + simp; linarith] + rw [Rat.floor_intCast_div_natCast]; simp only [Nat.cast_mul, Nat.cast_ofNat] + have X: x % N < (N + 1) / 2 ↔ 2 * (x % N) < N := by omega + rw [show ((2 * x + ↑N) / (2 * ↑N)) = if x % ↑N < (↑N + 1) / 2 then (x / ↑N) else (x / ↑N + 1) by + refine ((@Int.ediv_emod_unique (2 * x + ↑N) (2 * ↑N) + (if x % ↑N < (↑N + 1) / 2 then 2 * (x % ↑N) + ↑N else 2 * (x % ↑N) - ↑N) + (if x % ↑N < (↑N + 1) / 2 then x / ↑N else x / ↑N + 1) (by omega)).mpr ?_).left + apply And.intro + · split_ifs with A <;> nth_rw 3 [← Int.mul_ediv_add_emod x N] <;> linarith + · apply And.intro + · have Y := @Int.emod_nonneg x N (by omega) + split_ifs with A + · linarith + · rw [X] at A; linarith + · split_ifs with A <;> rw [X] at A + · linarith + · have Y := @Int.emod_lt_of_pos x N (by omega) + linarith] + split <;> simp + +def barrett_mul (R : ℕ) (a b : ℤ) (q : ℕ) : ℤ := + a * b - q * ⌊((a * ⌊((b * R) / q)⌉) / R)⌉ + +-- This is Fact 2 of cited paper above. +-- M is the bitwidth of the considered integer type, e.g., 16, 32, 64, etc. +lemma barrett_mul_spec (a b : ℤ) (M R k q : ℕ) + (H1_le_k : 1 ≤ k) + (Hk : |((b * R) / (q : ℚ)) - ⌊((b * R) / q)⌉| ≤ (1 / (2 ^ k))) + (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2 - |b|.toNat.clog2)) + (HM : 2 ≤ M) + (Hb : |b| ≤ 2 ^ (M - 2)) + (Ha' : |a| ≤ 2 ^ ((M - 2) - |b|.toNat.clog2 + (k - 1))) : + barrett_mul R a b q = (a * b).bmod q := by + have Hqpos: q > 0 := by exact Odd.pos HOddq + have HRpos: R > 0 := by subst R; exact Nat.two_pow_pos _ + rw [← smod_is_bmod, barrett_mul, smod, mod_approx] + simp only [Int.cast_mul, sub_right_inj, mul_eq_mul_left_iff, Int.natCast_eq_zero]; left + let δ := a * (round ((b * R) / (q: ℚ))) / (R: ℚ) - ((a * b) / q) + rw [show ↑a * ↑(round (↑b * ↑R / (q:ℚ))) / (R: ℚ) = ((a * b) / q) + δ by simp [δ]] + cases eq_or_ne ⌊(a * b) / q⌉ ⌊(a * b) / q + δ⌉ with + | inl _ => omega + | inr Hne => + exfalso + have Hδ₀: |δ| ≤ ↑|a| / (2^k * ↑R) := by + rw [show δ = (a / R) * (round ((b * R) / (q: ℚ)) - (b * R) / (q: ℚ)) by + unfold δ; qify at Hqpos; qify at HRpos + rw [mul_sub, ← mul_div_right_comm, ← mul_div_mul_comm] + rw [← mul_assoc, mul_comm (R: ℚ) q, mul_div_mul_right]; linarith] + rw [abs_mul, abs_sub_comm, abs_div, @abs_of_nonneg _ _ _ (↑R:ℚ) _ (Nat.cast_nonneg' R)] + rw [show ↑|a| / (2^k * ↑R) = |↑a| / (↑R:ℚ) * (1 / 2^k) by + rw [div_mul_div_comm, mul_comm]; simp] + apply mul_le_mul_of_nonneg_left + · apply Hk + · apply div_nonneg <;> simp + have Hδ₁: 1 / (2*q: ℚ) ≤ |δ| := by + cases lt_or_gt_of_ne Hne with + | inl Hlt => + rw [round_lt_iff] at Hlt + let ⟨n, Ha, Hb⟩ := Hlt + have Hδ₀: 0 < (↑n:ℚ) + 1 / 2 - ↑(a* b) / ↑q ∧ ↑n + 1 / 2 - ↑(a * b) / ↑q ≤ δ := by + apply And.intro <;> qify <;> linarith + rw [← div_one (↑n:ℚ), div_add_div, div_sub_div] at Hδ₀ <;> + [skip ; simp; (qify at Hqpos; linarith); simp ; simp] + simp only [mul_one, one_mul, Int.cast_mul] at Hδ₀ + let ⟨Hδ₁, Hδ₂⟩ := Hδ₀ + rw [div_pos_iff_of_pos_right] at Hδ₁ <;> [skip; (qify at Hqpos; linarith only [Hqpos])] + have H: 1 ≤ (n * 2 + 1) * q - 2 * (a * b) := by + suffices 0 < (n * 2 + 1) * q - 2 * (a * b) by linarith + qify; linarith + transitivity δ <;> [skip; apply le_abs_self] + transitivity ((↑n * 2 + 1) * ↑q - 2 * (a * b)) / (2 * ↑q) <;> [skip; assumption] + refine (div_le_div_iff_of_pos_right ?_).mpr ?_ <;> (qify at Hqpos; qify at H; linarith) + | inr Hgt => + rw [round_lt_iff] at Hgt + let ⟨n, Ha, Hb⟩ := Hgt + have Hδ₀: δ < (↑n:ℚ) + 1 / 2 - ↑(a * b) / ↑q ∧ ↑n + 1 / 2 - ↑(a * b) / ↑q ≤ (0:ℚ) := by + apply And.intro <;> qify <;> linarith + rw [← div_one (↑n:ℚ), div_add_div, div_sub_div] at Hδ₀ <;> + [skip ; simp; (qify at Hqpos; linarith); simp ; simp] + simp only [mul_one, one_mul, Int.cast_mul] at Hδ₀ + let ⟨Hδ₁, Hδ₂⟩ := Hδ₀ + have X: δ < 0 := by apply lt_of_lt_of_le <;> assumption + apply neg_lt_neg at Hδ₁ + apply neg_le_neg at Hδ₂ + rw [← abs_of_neg X, ← neg_div] at Hδ₁ + rw [← neg_div] at Hδ₂ + simp only [neg_sub] at Hδ₁; simp only [neg_zero, neg_sub] at Hδ₂ + rw [div_nonneg_iff] at Hδ₂ + cases Hδ₂ with + | inl H => + let ⟨H₀, _⟩ := H + have H₁: 0 ≤ 2 * (a * b) - (n * 2 + 1) * q := by qify; linarith + have H₂: 0 = 2 * (a * b) - (n * 2 + 1) * q ∨ 1 ≤ 2 * (a * b) - (n * 2 + 1) * q := by omega + cases H₂ with + | inl Hz => + symm at Hz; rw [Int.sub_eq_zero] at Hz + have Heven: Even (2 * (a * b)) := by exact even_two_mul (a * b) + have Hodd: Odd (2 * (a * b)) := by + rw [Hz]; apply Odd.mul + · rw [Odd]; use n; linarith + · simp only [Int.odd_coe_nat]; assumption + exfalso; apply (@Int.not_odd_iff_even (2 * (a * b))).mpr <;> assumption + | inr Hle => + apply le_of_lt; apply lt_of_le_of_lt _ Hδ₁ + refine (div_le_div_iff_of_pos_right ?_).mpr ?_ <;> (qify at Hqpos; qify at Hle;linarith) + | inr H => + let ⟨_, H'⟩ := H + qify at Hqpos; linarith + suffices X: ↑|a| / (2 ^ k * ↑R) < 1 / (2 * (q: ℚ)) by linarith + rw [div_lt_div_iff₀] <;> [simp only [Int.cast_abs, one_mul]; + (qify at HRpos; rw [mul_pos_iff]; left; apply And.intro + · exact pow_pos rfl k + · linarith); (qify at Hqpos; linarith)] + rw [HR] + cases eq_or_ne |↑a| (0: ℚ) with + | inl Heq => + rw [Heq]; simp + | inr Hne => + simp only [Nat.cast_pow, Nat.cast_ofNat]; rw [← pow_add] + apply lt_of_le_of_lt + · qify at Ha' + apply (Rat.mul_le_mul_of_nonneg_right Ha' ( by qify at Hqpos; linarith)) + · nth_rw 4 [← pow_one 2] + rw [← mul_assoc, ← pow_add] + apply lt_of_lt_of_le + · have X: q < 2 ^ (q.log2 + 1) := by rw [← Nat.log2_lt] <;> linarith + qify at X; apply (Rat.mul_lt_mul_of_pos_left X) + apply pow_pos; rfl + · rw [← pow_add] + apply pow_le_pow_right₀ + · simp + · have X: |b|.toNat.clog2 ≤ M - 2 := by + rw [← Nat.le_pow_iff_clog2_le]; zify + rw [Int.toNat_of_nonneg] + · omega + · apply abs_nonneg + ring_nf; omega + +public def barrett_reduce (R : ℕ) (a : ℤ) (q : ℕ) : ℤ := + a - q * ⌊((a * ⌊(R / q)⌉) / R)⌉ + +public theorem barrett_reduce_spec (a : ℤ) (M R k q : ℕ) + (H1_le_k : 1 ≤ k) + (Hk : |(R / (q : ℚ)) - ⌊(R / q)⌉| ≤ (1 / (2 ^ k))) + (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2)) + (HM : 2 ≤ M) + (Ha' : |a| ≤ 2 ^ ((M - 2) + (k - 1))) : + barrett_reduce R a q = a.bmod q := by + nth_rw 2 [← mul_one a] + rw [← barrett_mul_spec a 1 M R k q] <;> try assumption + · rw [barrett_reduce, barrett_mul] + rw [mul_one]; simp + · simp only [Int.cast_one, one_mul, one_div]; simp at Hk; assumption + · simp only [abs_one]; refine one_le_pow₀ (by simp) + +end Cslib.Crypto.Algorithms.BarrettReduction.Signed + +section MLKEMExample +open Cslib.Crypto.Algorithms.BarrettReduction.Signed + +/- +Proof of correctness for the signed Barrett reduction +used in the reference implementation of Kyber/MLKEM +https://github.com/pq-crystals/kyber/blob/main/ref/reduce.c#L25-L42 +-/ + +def M : ℕ := 16 -- 16 bits +def q : ℕ := 3329 -- prime modulus used for MLKEM +def R : ℕ := 2 ^ 26 +def k : ℕ := 2 + +-- This follows closely the original C code, though in ℤ +def mlkem_barrett_reduce (a : ℤ) : ℤ := + let v := 20159 + let t := (v * a + (1 <<< 25)) >>> (26: ℕ) + let t := t * 3329 + a - t + +lemma mlkem_barrett_reduce_correct (a : ℤ) (Ha : |a| ≤ 2 ^ 15) : + mlkem_barrett_reduce a = a.bmod q := by + rw [← barrett_reduce_spec a M R k q] + · rw [mlkem_barrett_reduce, barrett_reduce, R, q]; simp only [Nat.cast_ofNat, + Nat.reduceShiftLeft, Nat.reducePow, sub_right_inj] + rw [show (round (67108864 / (3329:ℚ))) = 20159 by decide +kernel] + rw [Int.shiftRight_eq_div_pow, round_eq] + rw [div_add_div] <;> try decide + simp only [Nat.reducePow, Nat.cast_ofNat, Int.cast_ofNat, mul_one] + rw [show ↑a * 20159 * 2 + 67108864 = (20159 * a + 33554432) * (2:ℚ) by linarith] + rw [mul_div_mul_right _ _ (by simp)] + rw [show 20159 * ↑a + (33554432:ℚ) = ↑(20159 * a + (33554432:ℤ)) by simp] + rw [show (67108864:ℚ) = ↑(67108864:ℕ) by simp] + rw [Rat.floor_intCast_div_natCast]; simp; omega + · simp [k] + · decide +kernel + · use (q/2); decide + · decide + · decide + · transitivity + · apply Ha + · decide + +-- This is basically the C code translated manually into Lean +def mlkem_barrett_reduce_impl (a : Int16) : Int16 := + let v: Int16 := 20159 + let t: Int32 := (v.toInt32 * a.toInt32 + ((1: Int32) <<< 25)) >>> 26 + let t: Int16 := t.toInt16 * 3329 + a - t + +lemma mlkem_barrett_reduce_impl_correct (a : Int16) : + Int16.toInt (mlkem_barrett_reduce_impl a) = (Int16.toInt a).bmod q := by + rw [← mlkem_barrett_reduce_correct] + · rw [mlkem_barrett_reduce, mlkem_barrett_reduce_impl] + rw [Int16.toInt_sub, Int16.toInt_mul] + simp only [Nat.reduceLeDiff, Int16.toInt32_ofNat, Int32.toInt_toInt16, Nat.reducePow, + Int16.reduceToInt, Int.bmod_mul_bmod, Int.sub_bmod_bmod, Nat.cast_ofNat, Nat.reduceShiftLeft] + rw [show ((1:Int32) <<< 25 = 33554432) by decide] + rw [← Int32.toInt_toBitVec, Int32.toBitVec_shiftRight] + simp only [Int32.toBitVec_add, Int32.toBitVec_mul, Int32.toBitVec_ofNat, BitVec.ofNat_eq_ofNat, + Int16.toBitVec_toInt32, BitVec.reduceSMod, BitVec.sshiftRight_eq', BitVec.toNat_ofNat, + Nat.reducePow, Nat.reduceMod, BitVec.toInt_sshiftRight, BitVec.toInt_add, BitVec.toInt_mul, + BitVec.reduceToInt, Int.bmod_add_bmod] + rw [BitVec.toInt_signExtend_of_le] <;> [skip;simp] + rw [Int16.toInt_toBitVec] + have Hle := Int16.le_toInt a + have Hlt := Int16.toInt_lt a + rw [@Int.bmod_eq_of_le _ 4294967296] <;> [skip; (simp; omega); (simp; omega)] + rw [show (a.toInt - (20159 * a.toInt + 33554432) >>> 26 * 3329 = mlkem_barrett_reduce a.toInt) + by rw [mlkem_barrett_reduce]; simp] + rw [mlkem_barrett_reduce_correct, q] + · rw [Int.bmod_bmod_eq_of_lt] <;> omega + · rw [abs_le]; omega + · rw [abs_le']; split_ands + · apply Int.le_of_lt + apply Int16.toInt_lt + · apply Int.neg_le_of_neg_le + apply Int16.le_toInt + +end MLKEMExample From 82f80d8ef71a6a1f69daaf605b801e9a3de5f6f5 Mon Sep 17 00:00:00 2001 From: Alix Trieu Date: Fri, 1 May 2026 11:13:33 +0200 Subject: [PATCH 2/2] Apply @eric-wieser 'corrections --- .../Algorithms/BarrettReduction/Aux.lean | 34 +++---- .../Algorithms/BarrettReduction/Signed.lean | 93 +++++++++---------- references.bib | 20 ++++ 3 files changed, 82 insertions(+), 65 deletions(-) diff --git a/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean b/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean index ec517935c..d301d7531 100644 --- a/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean +++ b/Cslib/Crypto/Algorithms/BarrettReduction/Aux.lean @@ -14,7 +14,7 @@ public import Mathlib.Algebra.Order.Floor.Defs public import Mathlib.Data.Int.DivMod import Mathlib.Tactic -/- +/-! # Auxiliary definitions and lemmas - Defines `clog2`, a base 2 upper logarithm and some associated lemmas @@ -26,23 +26,23 @@ public section namespace Nat -def clog2 : ℕ → ℕ := Nat.clog 2 +abbrev clog2 : ℕ → ℕ := Nat.clog 2 lemma le_clog2_self (n : ℕ) : - n ≤ 2 ^ (n.clog2) := by + n ≤ 2 ^ (n.clog2) := by apply le_pow_clog (by simp) n lemma log2_le_clog2 (n : ℕ) : - n.log2 ≤ n.clog2 := by + n.log2 ≤ n.clog2 := by rw [log2_eq_log_two] apply Nat.log_le_clog 2 n lemma le_pow_iff_clog2_le {x y : ℕ} : - x ≤ 2 ^ y ↔ clog2 x ≤ y := + x ≤ 2 ^ y ↔ clog2 x ≤ y := by symm; apply Nat.clog_le_iff_le_pow; simp lemma clog2_le_log2 (n : ℕ) : - n.clog2 ≤ n.log2 + 1 := by + n.clog2 ≤ n.log2 + 1 := by rw [log2_eq_log_two] rw [← le_pow_iff_clog2_le] apply le_of_lt @@ -53,7 +53,7 @@ lemma clog2_le_log2 (n : ℕ) : simp lemma clog2_eq (n : ℕ) : - n.clog2 = if 2 ^ n.log2 < n then n.log2 + 1 else n.log2 := by + n.clog2 = if 2 ^ n.log2 < n then n.log2 + 1 else n.log2 := by have H₀ := clog2_le_log2 n have H₁ := log2_le_clog2 n split_ifs with Hcond <;> rw [← Nat.lt_clog_iff_pow_lt (by simp), ← clog2] at Hcond <;> linarith @@ -63,7 +63,7 @@ end Nat namespace Int lemma abs_bmod_le (x : ℤ) (m : ℕ) (Hm : 0 < m) : - |x.bmod m| ≤ m / 2 := by + |x.bmod m| ≤ m / 2 := by rw [abs_le]; apply And.intro · apply Int.le_bmod Hm · transitivity @@ -71,7 +71,7 @@ lemma abs_bmod_le (x : ℤ) (m : ℕ) (Hm : 0 < m) : · omega lemma bmod_eq' (x : ℤ) (m : ℕ) : - x.bmod m = x - m * (round (x / (m: ℚ))) := by + x.bmod m = x - m * (round (x / (m: ℚ))) := by rw [round_eq, Int.bmod] have X: x % m < (m + 1) / 2 ↔ 2 * (x % m) < m := by omega cases Nat.eq_zero_or_pos m with @@ -107,7 +107,7 @@ lemma bmod_eq' (x : ℤ) (m : ℕ) : simp only [Int.sub_nonneg]; apply And.intro <;> try linarith lemma emod_def' (x : ℤ) (m : ℕ) : - x % ↑m = if x.bmod m < 0 then m + x.bmod m else x.bmod m := by + x % ↑m = if x.bmod m < 0 then m + x.bmod m else x.bmod m := by simp [Int.bmod_def] split_ifs <;> try omega · cases Nat.eq_zero_or_pos m with @@ -120,18 +120,18 @@ lemma emod_def' (x : ℤ) (m : ℕ) : have X := @Int.emod_lt_of_pos x m (by omega); linarith lemma bmod_eq_of_abs_lt {n : ℤ} {m : ℕ} (hlt : |n| < m / 2) : - n.bmod m = n := by + n.bmod m = n := by rw [abs_lt] at hlt apply Int.bmod_eq_of_le <;> omega lemma bmod_bmod_eq_of_le {x : ℤ} {m1 m2 : ℕ} (h : 0 < m1) (h' : m1 ≤ m2) : - (x.bmod m1).bmod m2 = x.bmod m1 := by + (x.bmod m1).bmod m2 = x.bmod m1 := by have X0 := @Int.le_bmod x m1 h have X1 := @Int.bmod_le x m1 h rw [@Int.bmod_eq_of_le _ m2] <;> omega lemma bmod_bmod_eq_of_lt {x : ℤ} {m1 m2 : ℕ} (h : 0 < m1) (h' : m1 < m2) : - (x.bmod m1).bmod m2 = x.bmod m1 := by + (x.bmod m1).bmod m2 = x.bmod m1 := by rw [bmod_bmod_eq_of_le] <;> omega end Int @@ -144,7 +144,7 @@ variable {α : Type*} variable [Field α] [LinearOrder α] [IsStrictOrderedRing α] [FloorRing α] lemma floor_sub_abs (a b : α) : - |⌊a⌋ - ⌊b⌋| ≤ ⌈|a - b|⌉ := by + |⌊a⌋ - ⌊b⌋| ≤ ⌈|a - b|⌉ := by wlog Hab: a ≥ b · rw [abs_sub_comm ⌊a⌋, abs_sub_comm a] apply this; apply le_of_not_ge at Hab; assumption @@ -163,7 +163,7 @@ lemma floor_sub_abs (a b : α) : linarith lemma floor_lt_iff (a b : α) : - ⌊a⌋ < ⌊b⌋ ↔ ∃ (n: ℤ), a < ↑n ∧ ↑n ≤ b := by + ⌊a⌋ < ⌊b⌋ ↔ ∃ (n: ℤ), a < ↑n ∧ ↑n ≤ b := by apply Iff.intro · intro H; cases lt_or_ge a ↑⌊b⌋ with | inl Hlt => use ↑⌊b⌋; apply And.intro @@ -179,13 +179,13 @@ lemma floor_lt_iff (a b : α) : · assumption lemma round_sub_abs (a b : α) : - |round a - round b| ≤ ⌈|a - b|⌉ := by + |round a - round b| ≤ ⌈|a - b|⌉ := by rw [round_eq, round_eq] rw [show (a - b = (a + 1/2) - (b + 1/2)) by linarith] apply floor_sub_abs lemma round_lt_iff (a b : α) : - round a < round b ↔ ∃ (n: ℤ), a < n + 1/2 ∧ n + 1/2 ≤ b := by + round a < round b ↔ ∃ (n: ℤ), a < n + 1/2 ∧ n + 1/2 ≤ b := by apply Iff.intro · rw [round_eq, round_eq]; intro H rw [floor_lt_iff] at H diff --git a/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean b/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean index f28e0060d..8978a3616 100644 --- a/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean +++ b/Cslib/Crypto/Algorithms/BarrettReduction/Signed.lean @@ -15,18 +15,15 @@ public import Mathlib.Algebra.Order.Round import Mathlib.Tactic public import Cslib.Crypto.Algorithms.BarrettReduction.Aux -/- +/-! # Signed Barrett Reduction This file formalizes signed variant of the Barrett reduction algorithm used in many schemes such as ML-DSA or ML-KEM. -This formalization is inspired by Section 2.4 of the following paper -Efficient Multiplication of Somewhat Small Integers Using Number-Theoretic Transforms -Hanno Becker, Vincent Hwang, Matthias J. Kannwischer, Lorenz Panny, and Bo-Yin Yang -IWSEC 2022 +This formalization is inspired by Section 2.4 of [BeckerHKPY22] -The main theorem is `barrett_reduce_spec`. +The main theorem is `barrettReduce_spec`. See example at the end of file for how to use it. -/ @@ -35,24 +32,24 @@ namespace Cslib.Crypto.Algorithms.BarrettReduction.Signed notation "⌊" x "⌉" => round (x : ℚ) -def is_approx (δ : ℚ) (α : ℚ → ℤ) : Prop := +def IsApprox (δ : ℚ) (α : ℚ → ℤ) : Prop := ∀ (x: ℚ), |(x - α x)| ≤ δ -lemma round_is_approx : is_approx (1/2) round := by +lemma round_isApprox : IsApprox (1/2) round := by intro x; apply abs_sub_round -def round_to_even (x : ℚ) : ℤ := +def roundToEven (x : ℚ) : ℤ := 2 * ⌊(x / 2)⌉ -def mod_approx (α : ℚ → ℤ) (x : ℤ) (N : ℕ) : ℤ := x - ↑N * (α (x/N)) +def modApprox (α : ℚ → ℤ) (x : ℤ) (N : ℕ) : ℤ := x - ↑N * (α (x/N)) -public def smod (x : ℤ) (N : ℕ) : ℤ := mod_approx round x N +public def smod (x : ℤ) (N : ℕ) : ℤ := modApprox round x N notation x "mod±" N => smod x N lemma smod_is_bmod (x : ℤ) (N : ℕ) : - (x mod± N) = (x.bmod N) := by - rw [Int.bmod_eq_self_sub_mul_bdiv, smod, mod_approx] + (x mod± N) = (x.bmod N) := by + rw [Int.bmod_eq_self_sub_mul_bdiv, smod, modApprox] rw [Int.bdiv]; split_ifs with HN · rw [HN]; simp · simp only [mul_ite, sub_right_inj] @@ -81,22 +78,22 @@ lemma smod_is_bmod (x : ℤ) (N : ℕ) : linarith] split <;> simp -def barrett_mul (R : ℕ) (a b : ℤ) (q : ℕ) : ℤ := +def barrettMul (R : ℕ) (a b : ℤ) (q : ℕ) : ℤ := a * b - q * ⌊((a * ⌊((b * R) / q)⌉) / R)⌉ -- This is Fact 2 of cited paper above. -- M is the bitwidth of the considered integer type, e.g., 16, 32, 64, etc. -lemma barrett_mul_spec (a b : ℤ) (M R k q : ℕ) - (H1_le_k : 1 ≤ k) - (Hk : |((b * R) / (q : ℚ)) - ⌊((b * R) / q)⌉| ≤ (1 / (2 ^ k))) - (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2 - |b|.toNat.clog2)) - (HM : 2 ≤ M) - (Hb : |b| ≤ 2 ^ (M - 2)) - (Ha' : |a| ≤ 2 ^ ((M - 2) - |b|.toNat.clog2 + (k - 1))) : - barrett_mul R a b q = (a * b).bmod q := by +lemma barrettMul_spec (a b : ℤ) (M R k q : ℕ) + (H1_le_k : 1 ≤ k) + (Hk : |((b * R) / (q : ℚ)) - ⌊((b * R) / q)⌉| ≤ (1 / (2 ^ k))) + (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2 - |b|.toNat.clog2)) + (HM : 2 ≤ M) + (Hb : |b| ≤ 2 ^ (M - 2)) + (Ha' : |a| ≤ 2 ^ ((M - 2) - |b|.toNat.clog2 + (k - 1))) : + barrettMul R a b q = (a * b).bmod q := by have Hqpos: q > 0 := by exact Odd.pos HOddq have HRpos: R > 0 := by subst R; exact Nat.two_pow_pos _ - rw [← smod_is_bmod, barrett_mul, smod, mod_approx] + rw [← smod_is_bmod, barrettMul, smod, modApprox] simp only [Int.cast_mul, sub_right_inj, mul_eq_mul_left_iff, Int.natCast_eq_zero]; left let δ := a * (round ((b * R) / (q: ℚ))) / (R: ℚ) - ((a * b) / q) rw [show ↑a * ↑(round (↑b * ↑R / (q:ℚ))) / (R: ℚ) = ((a * b) / q) + δ by simp [δ]] @@ -199,19 +196,19 @@ lemma barrett_mul_spec (a b : ℤ) (M R k q : ℕ) · apply abs_nonneg ring_nf; omega -public def barrett_reduce (R : ℕ) (a : ℤ) (q : ℕ) : ℤ := +public def barrettReduce (R : ℕ) (a : ℤ) (q : ℕ) : ℤ := a - q * ⌊((a * ⌊(R / q)⌉) / R)⌉ -public theorem barrett_reduce_spec (a : ℤ) (M R k q : ℕ) - (H1_le_k : 1 ≤ k) - (Hk : |(R / (q : ℚ)) - ⌊(R / q)⌉| ≤ (1 / (2 ^ k))) - (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2)) - (HM : 2 ≤ M) - (Ha' : |a| ≤ 2 ^ ((M - 2) + (k - 1))) : - barrett_reduce R a q = a.bmod q := by +public theorem barrettReduce_spec (a : ℤ) (M R k q : ℕ) + (H1_le_k : 1 ≤ k) + (Hk : |(R / (q : ℚ)) - ⌊(R / q)⌉| ≤ (1 / (2 ^ k))) + (HOddq : Odd q) (HR : R = 2 ^ (M - 1 + q.log2)) + (HM : 2 ≤ M) + (Ha' : |a| ≤ 2 ^ ((M - 2) + (k - 1))) : + barrettReduce R a q = a.bmod q := by nth_rw 2 [← mul_one a] - rw [← barrett_mul_spec a 1 M R k q] <;> try assumption - · rw [barrett_reduce, barrett_mul] + rw [← barrettMul_spec a 1 M R k q] <;> try assumption + · rw [barrettReduce, barrettMul] rw [mul_one]; simp · simp only [Int.cast_one, one_mul, one_div]; simp at Hk; assumption · simp only [abs_one]; refine one_le_pow₀ (by simp) @@ -221,9 +218,9 @@ end Cslib.Crypto.Algorithms.BarrettReduction.Signed section MLKEMExample open Cslib.Crypto.Algorithms.BarrettReduction.Signed -/- +/-! Proof of correctness for the signed Barrett reduction -used in the reference implementation of Kyber/MLKEM +used in the reference implementation of Kyber/ML-KEM https://github.com/pq-crystals/kyber/blob/main/ref/reduce.c#L25-L42 -/ @@ -233,16 +230,16 @@ def R : ℕ := 2 ^ 26 def k : ℕ := 2 -- This follows closely the original C code, though in ℤ -def mlkem_barrett_reduce (a : ℤ) : ℤ := +def mlkemBarrettReduce (a : ℤ) : ℤ := let v := 20159 let t := (v * a + (1 <<< 25)) >>> (26: ℕ) let t := t * 3329 a - t -lemma mlkem_barrett_reduce_correct (a : ℤ) (Ha : |a| ≤ 2 ^ 15) : - mlkem_barrett_reduce a = a.bmod q := by - rw [← barrett_reduce_spec a M R k q] - · rw [mlkem_barrett_reduce, barrett_reduce, R, q]; simp only [Nat.cast_ofNat, +lemma mlkemBarrettReduce_correct (a : ℤ) (Ha : |a| ≤ 2 ^ 15) : + mlkemBarrettReduce a = a.bmod q := by + rw [← barrettReduce_spec a M R k q] + · rw [mlkemBarrettReduce, barrettReduce, R, q]; simp only [Nat.cast_ofNat, Nat.reduceShiftLeft, Nat.reducePow, sub_right_inj] rw [show (round (67108864 / (3329:ℚ))) = 20159 by decide +kernel] rw [Int.shiftRight_eq_div_pow, round_eq] @@ -263,16 +260,16 @@ lemma mlkem_barrett_reduce_correct (a : ℤ) (Ha : |a| ≤ 2 ^ 15) : · decide -- This is basically the C code translated manually into Lean -def mlkem_barrett_reduce_impl (a : Int16) : Int16 := +def mlkemBarrettReduceImpl (a : Int16) : Int16 := let v: Int16 := 20159 let t: Int32 := (v.toInt32 * a.toInt32 + ((1: Int32) <<< 25)) >>> 26 let t: Int16 := t.toInt16 * 3329 a - t -lemma mlkem_barrett_reduce_impl_correct (a : Int16) : - Int16.toInt (mlkem_barrett_reduce_impl a) = (Int16.toInt a).bmod q := by - rw [← mlkem_barrett_reduce_correct] - · rw [mlkem_barrett_reduce, mlkem_barrett_reduce_impl] +lemma mlkemBarrettReduceImpl_correct (a : Int16) : + Int16.toInt (mlkemBarrettReduceImpl a) = (Int16.toInt a).bmod q := by + rw [← mlkemBarrettReduce_correct] + · rw [mlkemBarrettReduce, mlkemBarrettReduceImpl] rw [Int16.toInt_sub, Int16.toInt_mul] simp only [Nat.reduceLeDiff, Int16.toInt32_ofNat, Int32.toInt_toInt16, Nat.reducePow, Int16.reduceToInt, Int.bmod_mul_bmod, Int.sub_bmod_bmod, Nat.cast_ofNat, Nat.reduceShiftLeft] @@ -287,9 +284,9 @@ lemma mlkem_barrett_reduce_impl_correct (a : Int16) : have Hle := Int16.le_toInt a have Hlt := Int16.toInt_lt a rw [@Int.bmod_eq_of_le _ 4294967296] <;> [skip; (simp; omega); (simp; omega)] - rw [show (a.toInt - (20159 * a.toInt + 33554432) >>> 26 * 3329 = mlkem_barrett_reduce a.toInt) - by rw [mlkem_barrett_reduce]; simp] - rw [mlkem_barrett_reduce_correct, q] + rw [show (a.toInt - (20159 * a.toInt + 33554432) >>> 26 * 3329 = mlkemBarrettReduce a.toInt) + by rw [mlkemBarrettReduce]; simp] + rw [mlkemBarrettReduce_correct, q] · rw [Int.bmod_bmod_eq_of_lt] <;> omega · rw [abs_le]; omega · rw [abs_le']; split_ands diff --git a/references.bib b/references.bib index 1f8c0dc51..292e2cff8 100644 --- a/references.bib +++ b/references.bib @@ -52,6 +52,26 @@ @article{ Barendregt1984 year={1984} } +@inproceedings{BeckerHKPY22, + author = {Hanno Becker and + Vincent Hwang and + Matthias J. Kannwischer and + Lorenz Panny and + Bo{-}Yin Yang}, + editor = {Chen{-}Mou Cheng and + Mitsuaki Akiyama}, + title = {Efficient Multiplication of Somewhat Small Integers Using Number-Theoretic + Transforms}, + booktitle = {Advances in Information and Computer Security - 17th International + Workshop on Security, {IWSEC} 2022, Tokyo, Japan, August 31 - September + 2, 2022, Proceedings}, + series = {Lecture Notes in Computer Science}, + pages = {3--23}, + publisher = {Springer}, + year = {2022}, + url = {https://doi.org/10.1007/978-3-031-15255-9\_1} +} + @book{ Hopcroft2006, author = {Hopcroft, John E. and Motwani, Rajeev and Ullman, Jeffrey D.}, title = {Introduction to Automata Theory, Languages, and Computation (3rd Edition)},