brettkoonce/lean4-mlir

Lean 4 → MLIR → GPU

Interactive proof blueprint: brettkoonce.github.io/lean4-mlir/blueprint/ (or PDF) — clickable dependency DAG for the full VJP proof suite (no sorrys, zero project axioms), from pdiv primitives up to the whole-network VJPs (ViT, ResNet, MobileNetV2, ConvNeXt, EfficientNet).

Lean 4 as a specification language for neural networks. Declare architecture in Lean, generate StableHLO MLIR (forward + loss + backward + optimizer all in one fused function), compile to GPU via IREE, train end-to-end. No Python runtime, no autograd library — the gradients are computed at codegen time in Lean.

Companion code for the upcoming book Verified Deep Learning with Lean 4 (follow-up to Convolutional Neural Networks with Swift for TensorFlow, Apress).

Current version: v0.6.0 — Object detection joins the framework. A YOLOv1 person detector on Pascal VOC, bootstrapped from Chapter 6's ResNet-34 backbone with a 1×1 convolutional detection head — the FC head can't learn localization on a small dataset (diagnosed end to end in planning/yolo_v5.md); the focal-loss objectness term is the remaining piece, flagged in the bestiary's new detection intro. The IREE/Lean training path gains global-norm gradient clipping, env-var checkpoint resume (LEAN_MLIR_INIT_LOAD / LEAN_MLIR_START_STEP), and per-step LR warmup. Blueprint adds demo-anchored intros for object detection (§11.2.2) and diffusion (§11.2.7), redraws the MNIST/CIFAR training curves as native pgfplots (vector, regenerable from logs/), and adds a BatchNorm "why normalizing helps" intuition.
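Global-norm gradient clipping, mentioned in the v0.6.0 notes, rescales all gradients by one shared factor whenever their combined L2 norm exceeds a threshold, so the update direction is preserved. The following is a minimal Python sketch of the standard technique. It is not the IREE/Lean trainer's code, and clip_by_global_norm and max_norm are hypothetical names; the trainer's actual threshold is not given here.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient tensors (flat lists here) so that their
    global L2 norm, taken over all parameters together, is at most max_norm.

    Returns (clipped_grads, norm_before_clipping).
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total <= max_norm:
        # Already within budget (this also covers an all-zero gradient).
        return [list(t) for t in grads], total
    scale = max_norm / total  # one shared factor keeps the update direction
    return [[g * scale for g in t] for t in grads], total


# Two "parameter tensors" whose combined norm is sqrt(3^2 + 4^2) = 5.
grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before)  # 5.0
print(clipped)      # approximately [[0.6, 0.0], [0.0, 0.8]]
```

Clipping each tensor separately would change the relative sizes of the gradients. The global form changes only their overall length.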

The v0.5.7 headline still holds: two parallel-agent audits closed. The "canonical correct := rfl" pattern at non-smooth operators (ReLU, the composed MLP, MaxPool2) now has machine-checked smooth-point bridges: relu_codegen_matches_canonical and maxPool2_codegen_matches_canonical prove the canonical-witness backward equals the codegen formula wherever every coordinate avoids the kink. A HasVJPAt pointwise framework provides smooth-input variants of the three kinked-operator instances whose correct field is a real chain-rule proof rather than rfl. The comparator suite extends from 38 → 41 theorems independently kernel-rechecked against [propext, Quot.sound, Classical.choice]. Blueprint gets a half-dozen flow improvements (GAP defined at first material use in Ch 6, Diffusion split into its own Bestiary subsection, ResNet entry expanded to the full standard family including R-18, Tomáš Skřivan's Scientific Computing in Lean credited at the top of the acknowledgments). Android bottom-cutoff bug fixed (issue #2); Umami cookieless analytics replaces planned GA. First Zenodo deposit lands with this release.

The v0.5.6 headline still holds: Chapter 9 lands its ConvNeXt-T worked example (84.94% val on Imagenette, paper-faithful recipe); Chapter 10 gets a Data Augmentation section with a 9-row ViT recipe ablation table — CutMix is the load-bearing knob at 9.5K images, and stacking RandAugment + Random Erasing on top of it hurts val accuracy. Bestiary gets paper-exact entries for VGG, ResNet-50/101/152, WRN, and DenseNet, plus the "N new primitives" claim reframed around the Ch 2-10 reader's toolbox (what's free) rather than the codebase (what's already in Types.lean). Found and fixed a long-standing eval-pipeline bug along the way: centerCrop was running on already-224 val data, reading past per-image bounds and making heavy-aug runs appear to collapse. New LEAN_MLIR_EVAL_ONLY=1 mode re-evals saved checkpoints in ~5 sec each.

The v0.5.5 headline still holds: Swish/SiLU as a first-class activation (forward + backward + proved swish_has_vjp_correct) plus the independent-kernel comparator re-check covering 38 theorems via public *_has_vjp_correct wrappers, and Ch 2's "Why VJPs, not Jacobians?" bridge + canonical-pdiv witness explainer + three-pillar TikZ spine diagram.

On top of that, a differential-test suite in tests/vjp_oracle/ uses JAX's value_and_grad as an oracle for the hand-derived VJPs in LeanMlir/Proofs/. Nine test cases cover every axiom family — dense, conv, BN, maxPool, residual (biPath), depthwise, SE (elementwise product), attention, and the transformer block — each verified to 1–2 ULP of JAX autodiff. The ULP-floor cross-backend agreement (Lean→IREE→GPU vs Lean→JAX→XLA on both NVIDIA and AMD) established in v0.5.4 still holds; see traces/CROSS_BACKEND_RESULTS.md for the four-corner verification tables.
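ULP distance counts how many representable float32 values lie between two results. The Python sketch below shows one common way to compute it. It reinterprets each float32 bit pattern as an integer, folds negative values so that integer order matches float order, and subtracts. The oracle's own diff scripts may measure the distance differently, and the sketch does not handle NaN inputs.

```python
import struct


def f32_ordered_bits(x: float) -> int:
    """Map a value (rounded to float32) to an integer whose order matches
    float order, so adjacent float32 values differ by exactly 1."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if bits & 0x8000_0000:            # sign bit set: negative float
        return -(bits & 0x7FFF_FFFF)  # mirror negatives; -0.0 maps to 0
    return bits


def ulp_distance_f32(a: float, b: float) -> int:
    """Number of float32 steps between a and b after rounding both to float32."""
    return abs(f32_ordered_bits(a) - f32_ordered_bits(b))


print(ulp_distance_f32(1.0, 1.0 + 2.0 ** -23))  # 1: adjacent float32 values
print(ulp_distance_f32(0.0, -0.0))              # 0: signed zeros coincide
```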

Three phases

This project went through three implementations of the same idea — "Lean 4 as a specification language for deep learning" — each shedding more dependencies than the last.

Phase 1 — Pure Lean 4. mnist-lean4/: everything in Lean, Float64 as the only datatype, hand-written gradients, C FFI to OpenBLAS / hipBLAS for the matmuls. Worked end-to-end on MNIST through ResNet-34, but performance was poor: every operation crossed the FFI boundary, with no fusion, no autodiff, and no JIT.

Phase 2 — Lean → JAX. jax/: Lean as a metaprogramming layer that emits idiomatic JAX Python (jax/Jax/Codegen.lean, ~2100 lines). The generated script gets value_and_grad autodiff and XLA JIT for free, runs on any JAX-supported device. Trades the pure-Lean story for a working stack and real GPU performance. See jax/README.md for details.

Phase 3 — Lean → StableHLO → MLIR → device. (this README) No Python runtime at all. Lean directly emits StableHLO MLIR, IREE compiles it to a GPU flatbuffer, a thin C FFI loads and runs it. The pure-math version of phase 2 — autodiff is done at codegen time in Lean (LeanMlir/MlirCodegen.lean, ~7500 lines), not at runtime by a framework. See RESULTS.md for the per-architecture numbers.

The VJP correctness proofs live in LeanMlir/Proofs/ — chapter-by-chapter, for tensor ops, MLP, CNN, residual, batch norm, depthwise, SE, LayerNorm, and attention, up to whole-network backward passes (ViT, ResNet, MobileNetV2, ConvNeXt, EfficientNet). What they establish: each reference forward function, written in exact real arithmetic (ℝ), has a backward pass equal to its Mathlib fderiv Jacobian-transpose — with zero project axioms (#print axioms closes under the Lean-core triple alone).
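Softmax, the new primitive behind attention, gives a numerical illustration of what "a backward pass equal to its Jacobian-transpose" means. The Python sketch below compares the closed-form VJP, dx = y ⊙ (dy − ⟨dy, y⟩), with an explicitly materialised Jᵀ·dy. It is a floating-point illustration only; the Lean proofs work over ℝ with Mathlib's fderiv.

```python
import math


def softmax(x):
    m = max(x)                          # subtract the max for numerical stability
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]


def softmax_jacobian(x):
    """Explicit Jacobian: J[i][j] = d softmax_i / d x_j = y_i * (delta_ij - y_j)."""
    y = softmax(x)
    n = len(x)
    return [[y[i] * ((1.0 if i == j else 0.0) - y[j]) for j in range(n)]
            for i in range(n)]


def softmax_vjp(x, dy):
    """Closed-form VJP dx = y * (dy - <dy, y>); J is never materialised."""
    y = softmax(x)
    dot = sum(a * b for a, b in zip(dy, y))
    return [yi * (gi - dot) for yi, gi in zip(y, dy)]


def jacobian_transpose_times(J, v):
    """(J^T v)_j = sum_i J[i][j] * v[i]."""
    return [sum(J[i][j] * v[i] for i in range(len(J))) for j in range(len(J[0]))]


x = [0.3, -1.2, 2.0, 0.5]
dy = [1.0, 0.0, -0.5, 0.25]
closed = softmax_vjp(x, dy)
explicit = jacobian_transpose_times(softmax_jacobian(x), dy)
print(max(abs(a - b) for a, b in zip(closed, explicit)))  # rounding-level only
```

The closed form costs O(n) and the explicit Jacobian costs O(n²). That difference is the usual argument for propagating VJPs rather than full Jacobians.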

The whole-network results come in two forms, set by the architecture's activations:

  • Unconditional — ViT, ConvNeXt, EfficientNet. These use only smooth ops (GELU / Swish / sigmoid, softmax, LayerNorm, convolution — no ReLU, no max-pool), so the VJP holds at every input, with the LayerNorm/BatchNorm 0 < ε positivity as the only side conditions (vit_full_has_vjp, convnext_has_vjp, efficientnet_has_vjp : HasVJP …).

  • Conditional + concretely instantiated — MLP, MNIST-CNN, ResNet, MobileNetV2. ReLU, ReLU6 and max-pool are genuinely non-differentiable at their kinks, so the generic whole-network VJP is stated at a smooth point (*_has_vjp_at, under per-site "off the kink" hypotheses). Each is then instantiated on a concrete (small, representative) network with every smoothness hypothesis discharged, giving a hypothesis-free correctness theorem (MlpConcrete, Spatial/Mini, CnnConcrete, MobileNetV2Concrete) — proof that the kink-avoidance conditions are jointly satisfiable on the real forward, not vacuous.
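The following is a Python sketch of a smooth-point VJP for ReLU. Away from the kink, the mask backward agrees with a central finite difference. At z = 0 the derivative does not exist and the two disagree, which is why the generic theorems carry off-the-kink hypotheses. The backward uses 0 at z = 0; this is a common subgradient convention assumed here, since the README does not say which value the codegen picks. The helper names are illustrative.

```python
def relu(z):
    return [max(v, 0.0) for v in z]


def relu_vjp(z, dy):
    """Mask backward: pass dy where z > 0. At z == 0 it returns 0,
    one subgradient choice; the true derivative does not exist there."""
    return [g if v > 0.0 else 0.0 for v, g in zip(z, dy)]


def off_the_kink(z):
    """Smooth-point hypothesis: no coordinate sits exactly on the kink."""
    return all(v != 0.0 for v in z)


def fd_vjp(f, z, dy, h=1e-6):
    """Central-difference estimate of dy^T J, one input coordinate at a time."""
    out = []
    for j in range(len(z)):
        zp = list(z)
        zm = list(z)
        zp[j] += h
        zm[j] -= h
        df = [(a - b) / (2 * h) for a, b in zip(f(zp), f(zm))]
        out.append(sum(g * d for g, d in zip(dy, df)))
    return out


z, dy = [1.5, -0.7, 0.2], [1.0, 2.0, 3.0]
print(off_the_kink(z), relu_vjp(z, dy), fd_vjp(relu, z, dy))  # agree here
print(relu_vjp([0.0], [1.0]), fd_vjp(relu, [0.0], [1.0]))     # [0.0] vs [0.5]
```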

Axiom closure on every one of these is a CI invariant (tests/AuditAxioms.lean); the generic headline theorems are additionally re-checked by the independent tests/comparator/ kernel pass.

These proofs are about the reference definitions in Proofs/, not the Float32 StableHLO the codegen emits — the two are written separately and no Lean theorem currently links them. The connection is instead twofold. (1) Structural: codegen and proofs were developed independently and arrived at the same decomposition — every backward pass factors through the standalone gradient of one new primitive per architecture (softmax for attention, the spatial reductions for BN, the rank-1 collapse for SE), and everything else is composition via the chain rule on tools from earlier chapters — and the codegen cites the matching proof inline in the MLIR it generates. (2) Numerical: finite-difference checks (LeanMlir/Proofs/check_jacobians.py) and JAX value_and_grad oracles (tests/vjp_oracle/) exercise the emitted formulas — including at the ReLU/MaxPool kinks, where the codegen substitutes the standard subgradient convention. See the "Codegen trust boundary" section of LeanMlir/Proofs/README.md for the precise gap. Closing it formally — a forward-extraction lemma tying a proven *Forward to the codegen's emitted graph — is open future work.
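A toy residual block, y = tanh(W x) + x, shows two things at once: the finite-difference side of the numerical check, and the "composition via the chain rule" structure. The following Python sketch is not check_jacobians.py. The VJP chains the tanh and dense VJPs along the branch and adds the identity path's cotangent at the fan-in. A central difference taken along a random direction u should then match ⟨vjp(dy), u⟩. The sketch uses tanh so that the test point is smooth everywhere, and the random inputs are arbitrary.

```python
import math
import random


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def matvec_T(W, v):
    """W^T v."""
    return [sum(W[i][j] * v[i] for i in range(len(W))) for j in range(len(W[0]))]


def block(W, x):
    """Residual block y = tanh(W x) + x: two paths that fan back in."""
    return [math.tanh(a) + b for a, b in zip(matvec(W, x), x)]


def block_vjp(W, x, dy):
    """Chain rule on the branch, plus the identity skip; cotangents add at fan-in."""
    a = matvec(W, x)
    d_a = [g * (1.0 - math.tanh(v) ** 2) for g, v in zip(dy, a)]  # tanh' = 1 - tanh^2
    branch = matvec_T(W, d_a)                                    # dense VJP: W^T d_a
    return [b + g for b, g in zip(branch, dy)]                   # + identity path


def directional_check(f, vjp, x, dy, u, h=1e-5):
    """Return (<dy, (f(x+hu) - f(x-hu)) / 2h>, <vjp(dy), u>)."""
    xp = [a + h * b for a, b in zip(x, u)]
    xm = [a - h * b for a, b in zip(x, u)]
    fd = sum(g * (p - m) / (2 * h) for g, p, m in zip(dy, f(xp), f(xm)))
    an = sum(a * b for a, b in zip(vjp(dy), u))
    return fd, an


rng = random.Random(0)
n = 4
W = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]
x, dy, u = ([rng.uniform(-1, 1) for _ in range(n)] for _ in range(3))
fd, an = directional_check(lambda v: block(W, v), lambda g: block_vjp(W, x, g), x, dy, u)
print(fd, an)  # should agree to many digits
```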

What is and isn't verified

All proofs are over exact reals (ℝ). The emitted MLIR and GPU execution are Float32; iree-compile, the IREE runtime, and the FFI are trusted. Within that boundary, the verification is tiered by dataset / backend:

Tier 1 — MNIST (linear, mlp, cnn): forward + backward bridged. The reference forward and backward are proven faithful to the Mathlib fderiv math as rendered StableHLO graphs (mlpFwdGraph_faithful, mlpBackGraph_faithful, cnnFwdGraph_faithful, cnnBackGraph_faithful; for linear also the param-grad Jacobians wGrad/bGrad_is*Jacobian and sgdW/sgdB_descends_certified_grad). All audited to the 3-axiom closure. Caveat: the train-step .mlir is currently assembled from these proven op-graphs with a hand-written grad/SGD tail (see linearTrainStepModuleV); folding that tail into the rendered AST so the whole train-step module is render(provenGraph) is in progress. Tier 1 also now carries the Float32 bridge (below): forward, gradient, and SGD-step rounding budgets for the linear/MLP nets, and for linear a proven descent guarantee.

Tier 2 — CIFAR (cifar, cifar-bn): forward bridged, backward WIP. cifarFwdGraph_faithful / cifarBnFwdGraph_faithful (plus op-level bnBack_faithful) hold; the whole-net backward graph and the train step are not yet rendered from a proof.

Tier 3 — Imagenette (ResNet-34, MobileNetV2, ConvNeXt, EfficientNet, ViT): ℝ whole-net VJP proven; codegen bridge WIP. The whole-network VJP is proven over ℝ (resnet34_has_vjp_at, vit_full_has_vjp, convnext_has_vjp, efficientnet_has_vjp, mobilenetv2_has_vjp_at). The rendered-MLIR bridge is forward-graph-only (resnet/mnv2/convnext) or op-level-only (efficientnet/vit), and the GPU trainers behind the Imagenette numbers below run the unverified MlirCodegen.lean path. No theorem links those proofs to that codegen yet.

Tier 4 — ImageNet-1k (phase-2 Lean→JAX bridge): scale baseline, gradients not Lean-verified. Full 1000-class ImageNet runs use the phase-2 path (jax/Jax/Codegen.lean, ~1100 lines: NetSpec → idiomatic JAX Python), where JAX's value_and_grad computes the gradients and XLA does the compilation — the Lean VJP proofs are not in the loop. The only proof-adjacent Lean artifact is the shared NetSpec ADT (the same architecture spec whose phase-3 backward is proven over ℝ); the emitter itself is unverified. This tier exists to (a) establish scale baselines the verified-IREE codegen can't yet reach — ConvNeXt-T 75.93% / EfficientNet-B0 72.31% / ResNet-34 72.02% / MobileNetV2 68.33% / ViT-Tiny 65.64% top-1, full 50k val (jax/runs/*/RESULTS.md) — and (b) serve as the differential-test oracle: tests/vjp_oracle/ uses JAX value_and_grad as ground truth to cross-check the Tier 1–3 Lean-derived VJPs to 1–2 ULP. So Tier 4 is the least-verified tier by gradient provenance but the one that empirically anchors the others. Whether phase-3 verified codegen can reach ImageNet scale is open.

The ℝ→Float32 bridge (Tier 1)

All tier proofs are over exact reals; LeanMlir/Proofs/FloatBridge.lean + SgdDescent.lean/SgdDescentLinear.lean/SgdDescentMlp.lean close the rounding gap for the Tier-1 nets, hypothesis-style (zero project axioms — a FloatModel is any rounding operator with relative error u; binary32 instantiates it with u = 2⁻²⁴ on the normal range, subnormals open). The chain, every link in the 3-axiom audit:

  • Forward (mlp_float_close_uniform): dot/dense budgets in the classical compounded form, valid for every summation association (IREE may reassociate reductions freely). ReLU is exact in float — the op that forces the off-the-kink hypotheses over ℝ is the free op here.
  • Backward (mlp_{w2,w1,w0,b2,b1,b0}_step_float_close): every rounded SGD parameter entry within an explicit budget of θ − lr·(aᵢ·cⱼ) — the same emitWeightGrad/emitBiasGrad entries mlp_render_*_certified prove equal to the pdiv-Jacobian contractions. The ReLU masks need quantitative margins (ez < |zᵢ|: rounding must not flip a sign).
  • Loss head (softmax_ce_cot_close): the rounded softmax−onehot cotangent vs the certified gradient, given an exp accuracy hypothesis (|fexp t − exp t| ≤ eexp·exp t — GPU exp has no IEEE spec; eexp is the constant tests/vjp_oracle/ validates at 1–2 ULP).
  • Descent (sgd_descends, linear_sgd_descends, mlp_{output,hidden,input}_sgd_descends): an η-accurate gradient step still decreases the loss — with the smoothness hypothesis proven, not assumed: explicit constant 2a²/(1−2aD) for the linear net, and through the MLP's ReLU kinks per weight layer under quantitative margins (the step's ℓ1 radius cannot flip a mask sign, so the sign pattern freezes along the segment): 2d₃w₂²a²/(1−2w₂aD) for the hidden layer, 2d₃d₂²w₁²w₂²a²/(1−2w₂d₂w₁aD) for the input layer; the output layer is the linear theorem at the hidden activation, margin-free. No Hessian anywhere (the same softmax ratio sandwich as the float budgets).
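For IEEE binary32 round-to-nearest, the FloatModel assumption (every rounded result within relative error u of the exact value) can be spot-checked with a short Python sketch. It uses struct to round a Python double to float32, and the sample ranges are arbitrary choices. The sketch also shows why subnormals are left open: near zero, binary32 keeps too few significant bits for a relative-error bound to hold.

```python
import random
import struct

U = 2.0 ** -24  # unit roundoff of binary32 under round-to-nearest


def fl32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def rel_error(x: float) -> float:
    """Relative rounding error |fl(x) - x| / |x| for nonzero x."""
    return abs(fl32(x) - x) / abs(x)


rng = random.Random(0)
# Magnitudes well inside the binary32 normal range (~1.2e-38 .. 3.4e38).
samples = [rng.choice((-1.0, 1.0)) * rng.uniform(1.0, 2.0) * 10.0 ** rng.randint(-30, 30)
           for _ in range(10_000)]
worst_normal = max(rel_error(x) for x in samples)
print(worst_normal <= U)  # True: the relative-error model holds on normals

# A subnormal input: only a handful of significant bits remain here.
print(rel_error(1e-45))   # about 0.4, far above U
```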

Measured vs proven (scripts/margin_probe.py, an f32/f64 twin of the 97.8% GPU run; numeric capstones instantiated at the trained magnitudes |W| ≤ 3/5):

quantity          worst-case bound (theorem)                 measured
logit drift       ≤ 5100 (mnist_mlp_float_budget)            1.6·10⁻⁵
cotangent         ≤ 21/1000 at δ=1/100 (mnist_cot_budget)    2.2·10⁻⁶
W₂ SGD step       ≤ 5/4 (mnist_w2_step_float_budget)         7.5·10⁻⁹
ReLU mask flips   0 under margins                            0 / 29.5M

The worst-case-vs-measured gap (up to ~10⁸) is the quantitative case for a-posteriori certificates past toy depth; the zero flip count says the margin hypotheses describe real training, not a technicality.
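The gap can be recomputed from the three magnitude rows of the table above. The mask-flip row is a count, not a magnitude, so it is left out. A short Python sketch:

```python
# (proven worst-case bound, measured value) pairs from the table above.
rows = {
    "logit drift": (5100.0, 1.6e-5),
    "cotangent": (21 / 1000, 2.2e-6),
    "W2 SGD step": (5 / 4, 7.5e-9),
}
ratios = {name: bound / measured for name, (bound, measured) in rows.items()}
for name, r in ratios.items():
    print(f"{name:12s} bound / measured = {r:.1e}")
# logit drift  bound / measured = 3.2e+08
# cotangent    bound / measured = 9.5e+03
# W2 SGD step  bound / measured = 1.7e+08
```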

Not yet verified anywhere:

  • MlirCodegen.lean itself (~7500 lines, zero theorems).
  • Outside Tier 1, the train-step text that iree-compile actually consumes.
  • Within the float bridge, subnormals (the model is relative-error-only).
  • Within the float bridge, the joint all-layers descent step and the bias columns. The per-weight-layer constants are proven for linear + MLP. For the CNN the new ingredients are proven (quantitative max-pool selection margins that freeze the argmax routing, pool ℓ1-contraction, conv-kernel drift with the weight-sharing factor), but the conv-layer capstone assembly is open (planning/sgd_descent_cnn.md). Every-parameter-at-once descent, where the logits are no longer affine in the moving parameters, is also open.
  • Any link from the Lean-side FloatModel to IREE's actual kernels beyond the empirical probe.

Concrete-instance honesty. The conditional capstones (MLP, MNIST-CNN, CIFAR, MobileNetV2, ResNet-34) are instantiated to discharge their off-the-kink hypotheses. MlpConcrete, Micro/Mini/Spatial (MNIST) and Tiny (CIFAR) are live witnesses (non-constant forward, nonzero Jacobian). MobileNetV2Concrete, CnnConcrete, and ResNet34Concrete are degenerate constant-output nets (zero Jacobian) — they prove the hypothesis bundle is satisfiable but say nothing about a realistic gradient; live witnesses for the deep/BN/ReLU6 nets are follow-up.

Pipeline

Lean NetSpec  (~15 lines)
   │
   │  MlirCodegen.generateTrainStep
   ▼
StableHLO MLIR  (500 KB - 2 MB of text, forward+loss+backward+Adam fused)
   │
   │  iree-compile (~10-15 min for ROCm gfx1100)
   ▼
VMFB flatbuffer  (1.8-3 MB)
   │
   │  IREE runtime via libiree_ffi.so
   ▼
GPU execution  (HIP/ROCm or CUDA)

The same Lean → MLIR pipeline handles every architecture. Adding a new architecture means extending LeanMlir/MlirCodegen.lean with:

  • forward emission for the new layer types
  • VJP / backward emission
  • FwdRec recording for backward intermediates

The training executable, FFI, and IREE runtime are unchanged.
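A toy Python emitter can make the three extension points concrete for a bias-free dense layer y = x·W. The emitter is hypothetical and far simpler than MlirCodegen.lean, and its output format is not the real one. It emits the forward op, then emits the VJP ops (dx = dy·Wᵀ, dW = xᵀ·dy) as additional StableHLO text, and keeps a small record of which forward values the backward pass reads. That record stands in for FwdRec. The op spellings assume StableHLO's current pretty-printed syntax (stablehlo.dot, stablehlo.transpose with dims = [...]), so check them against your toolchain.

```python
def tensor_type(shape):
    return "tensor<" + "x".join(str(d) for d in shape) + "xf32>"


def emit_dense(batch, n_in, n_out):
    """Emit a toy StableHLO function computing y = x @ W and its VJP.

    Hypothetical illustration of forward emission, VJP emission, and
    recording forward values for the backward pass.
    """
    x_t, w_t = tensor_type([batch, n_in]), tensor_type([n_in, n_out])
    y_t = tensor_type([batch, n_out])
    xT_t, wT_t = tensor_type([n_in, batch]), tensor_type([n_out, n_in])
    saved = {"x": "%x"}  # stand-in for FwdRec: forward values the backward reads
    lines = [
        f"func.func @dense_fwd_bwd(%x: {x_t}, %w: {w_t}, %dy: {y_t}) -> ({y_t}, {x_t}, {w_t}) {{",
        # forward emission: y = x @ W
        f"  %y = stablehlo.dot %x, %w : ({x_t}, {w_t}) -> {y_t}",
        # VJP emission: dx = dy @ W^T
        f"  %wT = stablehlo.transpose %w, dims = [1, 0] : ({w_t}) -> {wT_t}",
        f"  %dx = stablehlo.dot %dy, %wT : ({y_t}, {wT_t}) -> {x_t}",
        # VJP emission: dW = x^T @ dy, reading the recorded forward input
        f"  %xT = stablehlo.transpose {saved['x']}, dims = [1, 0] : ({x_t}) -> {xT_t}",
        f"  %dw = stablehlo.dot %xT, %dy : ({xT_t}, {y_t}) -> {w_t}",
        f"  return %y, %dx, %dw : {y_t}, {x_t}, {w_t}",
        "}",
    ]
    return "\n".join(lines)


print(emit_dense(batch=32, n_in=784, n_out=10))
```

Because the gradient ops are emitted as ordinary StableHLO alongside the forward ops, the compiler sees one function, which it can fuse as a whole.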

Cross-backend verification

Phase 2 and Phase 3 share the same Lean NetSpec ADT but compile through completely independent stacks (JAX/XLA vs IREE). Differential testing confirms both stacks produce the same training dynamics on the same input, for both MLP (670K params, 12 epochs) and CNN (1.7M params with conv+BN, 15 epochs):

comparison                          MLP step-1 Δ    CNN step-1 Δ
phase 2 (JAX) vs phase 3 (IREE)     ~2e-7           ~1e-5 to 1e-4
phase 3 ROCm vs phase 3 CUDA        0               0
phase 2 CPU vs phase 2 CUDA         ~4e-6           ~1e-4

MLP hits the float32 ULP floor because it is dense-only. CNN's noise floor is ~100× looser because each conv-BN layer does two reductions over ~100k-element tensors, and XLA's reduction trees differ from IREE's: both pipelines do correct math, just with different summation orders. Phase 3 on ROCm and on CUDA is bit-identical at step 1 on both networks. Reproducible in 5 minutes via traces/CROSS_BACKEND_RESULTS.md.
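Two correct reductions can disagree because floating-point addition is not associative: the same float32 sum, evaluated in a different order, rounds differently. The Python sketch below emulates float32 with struct and compares a left-to-right accumulator with a pairwise tree on 100k Gaussian values. These two orders are stand-ins; they are not the actual reduction trees that XLA or IREE emit. Both stay within the classical n·u·Σ|xᵢ| error bound, while generally differing from each other.

```python
import math
import random
import struct


def fl32(x):
    """Round to the nearest float32 (held in a Python float)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_sequential_f32(xs):
    acc = 0.0
    for v in xs:
        acc = fl32(acc + v)  # round after every add, like a float32 accumulator
    return acc


def sum_pairwise_f32(xs):
    """Tree reduction: a different, equally valid association of the same sum."""
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return fl32(sum_pairwise_f32(xs[:mid]) + sum_pairwise_f32(xs[mid:]))


rng = random.Random(0)
xs = [fl32(rng.gauss(0.0, 1.0)) for _ in range(100_000)]
exact = math.fsum(xs)  # correctly rounded reference sum
seq, tree = sum_sequential_f32(xs), sum_pairwise_f32(xs)
print(seq - exact, tree - exact, seq - tree)
```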

VJP oracle

A separate per-axiom differential test in tests/vjp_oracle/ uses JAX's value_and_grad as a correctness oracle for every hand-derived backward pass in LeanMlir/Proofs/. Each test case is a minimal NetSpec that exercises one axiom in isolation. The oracle compares the step-2 loss from phase 3, which uses the Lean-derived gradients, against the same loss from phase 2, which uses JAX's autodiff gradients. Step 2 is the first loss that depends on the backward pass, since step 1's gradient update feeds into it.
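A toy one-parameter SGD problem shows why step 2 is the first informative comparison. All the values are hypothetical, and this is not an oracle test case. The step-1 loss is computed before any gradient has been applied, so a wrong backward pass only becomes visible in the loss one step later.

```python
def loss(w, x, t):
    return (w * x - t) ** 2


def grad(w, x, t):
    return 2.0 * (w * x - t) * x


def two_step_losses(grad_fn, w0=0.5, x=2.0, t=3.0, lr=0.05):
    """Losses logged at steps 1 and 2 of plain SGD."""
    l1 = loss(w0, x, t)                 # step 1: forward only, no gradient used yet
    w1 = w0 - lr * grad_fn(w0, x, t)    # the backward pass feeds this update...
    l2 = loss(w1, x, t)                 # ...so step 2 is the first loss it affects
    return l1, l2


good = two_step_losses(grad)
buggy = two_step_losses(lambda w, x, t: 1.01 * grad(w, x, t))  # 1% gradient error
print(good, buggy)  # identical step-1 losses, different step-2 losses
```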

Nine cases, all passing on the mars (ROCm + CPU) and ares (CUDA) machines:

case         axiom                                                    step-2 Δ
dense        dense_has_vjp + softmaxCE_grad                           2.7e-07
dense-relu   relu_has_vjp + vjp_comp                                  4.8e-07
conv         conv2d_has_vjp + flatten_has_vjp                         2.2e-07
convbn       convBn_has_vjp (BN-mode)                                 2.2e-06
conv-pool    maxPool_has_vjp (argmax tiebreaks)                       1.2e-04
residual     biPath_has_vjp (additive fan-in)                         3.1e-07
depthwise    depthwise-conv VJP via .invertedResidual                 1.1e-05
mbconv       elemwiseProduct_has_vjp (SE gate) + Swish                1.6e-06
attention    patchEmbed + transformerBlock_has_vjp_mat + classifier   1.8e-07

Run with tests/vjp_oracle/run.sh. Adding a new axiom means dropping a minimal Lean spec under tests/vjp_oracle/phase{2,3}/ plus one line in the lakefiles — see tests/vjp_oracle/README.md.

The oracle also surfaced a real heInitParams bug (shape-peek heuristic misfiring at patchEmbed + transformer-block boundaries) and a JAX-ROCm crash on gfx1100 (filed as ROCm/MIOpen#3955; repro lives at upstream-issues/2026-04-rocm-miopen-conv-segv/).

Results (Imagenette, 10 classes, 224×224)

Trained from scratch on a single AMD 7900 XTX (gfx1100), Adam, batch 32, cosine LR + 3-epoch warmup, label smoothing 0.1, weight decay 1e-4, random crop (256→224) + horizontal flip, running BN stats for eval.

Model                 Params    Val accuracy
ResNet-34             21.3M     90.29%
ResNet-50             23.5M     89.40%
EfficientNetV2-S      38.2M     88.50%
EfficientNet-B0       7.2M      87.58%
MobileNetV2           2.2M      87.09%
MobileNetV3-Large     3.0M      86.48%
MobileNetV4-Medium    4.1M      84.58%
ViT-Tiny              5.5M      71.70%

Per-epoch eval histories and ablation tables in RESULTS.md.
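The following Python sketch covers the learning-rate schedule and target smoothing named in the training recipe above. Three choices are common conventions assumed here: per-step linear warmup, cosine decay to zero, and label smoothing that spreads ε uniformly over all classes, including the true one. The base LR and step counts in the usage lines are hypothetical, because the README does not give them.

```python
import math


def lr_at(step, total_steps, warmup_steps, base_lr):
    """Per-step linear warmup to base_lr, then cosine decay toward 0."""
    if step < warmup_steps:
        # (step + 1) so the very first update does not use a learning rate of 0.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


def smooth_labels(label, num_classes, eps=0.1):
    """Label smoothing: mix the one-hot target with a uniform distribution."""
    off = eps / num_classes
    return [1.0 - eps + off if k == label else off for k in range(num_classes)]


# Hypothetical numbers: 100 steps/epoch, 10 epochs, 3 warmup epochs, base LR 1e-3.
steps_per_epoch, epochs = 100, 10
total, warmup = steps_per_epoch * epochs, 3 * steps_per_epoch
print([round(lr_at(s, total, warmup, 1e-3), 6) for s in (0, 299, 300, 650, 999)])
print(smooth_labels(label=2, num_classes=10))
```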

Quick start

1. Install Lean 4

curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh

2. Install IREE

You need the IREE runtime built for your GPU (CUDA or ROCm). The FFI shim in ffi/ links against libiree_runtime_unified.a from the IREE build tree. See IREE_BUILD.md for build instructions.

3. Get data

./download_mnist.sh        # MNIST (Ch 3-4 trainers)
./download_cifar.sh        # CIFAR-10 (Ch 5 trainers)
./download_imagenette.sh   # Imagenette 320px → preprocessed binary (Ch 6+)

4. Build a trainer

lake build resnet34-train

This compiles the Lean trainer (which generates MLIR + drives IREE + runs the training loop). Other targets, in roughly book order: mnist-mlp-train, mnist-cnn-train, cifar-cnn-train, cifar-bn-train, resnet50-train, mobilenet-v2-train, mobilenet-v3-train, mobilenet-v4-train, efficientnet-train, efficientnet-v2-train, vgg-train, vit-tiny-train.

5. Run

The first invocation generates and compiles the vmfbs (slow — IREE compilation takes 10-15 min for ResNet-sized models). Subsequent runs reuse the cached vmfbs unless you clear .lake/build/.

HIP_VISIBLE_DEVICES=0 IREE_BACKEND=rocm .lake/build/bin/resnet34-train

# Or via the included shell wrapper that sets the env vars correctly
bash run.sh resnet34                  # GPU 0, ROCm (defaults)
bash run.sh efficientnet-v2 1 cuda    # GPU 1, CUDA

For CUDA, set IREE_BACKEND=cuda (the default) and select the GPU with CUDA_VISIBLE_DEVICES instead of HIP_VISIBLE_DEVICES.

Lean specs

The same NetSpec type is used by all three phases. A spec is a list of Layer values:

def resnet34 : NetSpec where
  name := "ResNet-34"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,
    .maxPool 2 2,
    .residualBlock  64  64 3 1,
    .residualBlock  64 128 4 2,
    .residualBlock 128 256 6 2,
    .residualBlock 256 512 3 2,
    .globalAvgPool,
    .dense 512 10 .identity
  ]

def vitTiny : NetSpec where
  name := "ViT-Tiny"
  imageH := 224
  imageW := 224
  layers := [
    .patchEmbed 3 192 16 196,             -- (224/16)^2 = 196 patches
    .transformerEncoder 192 3 768 12,     -- 12 blocks, 3 heads, MLP dim 768
    .dense 192 10 .identity
  ]

def mobilenetV4Medium : NetSpec where
  name := "MobileNet V4-Medium"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,
    .fusedMbConv 32 48 4 3 2 1 false,
    .uib  48  80 4 2 3 5,    -- ExtraDW
    .uib  80 160 6 2 0 3,    -- IB (= MBConv)
    .uib 160 160 4 1 5 0,    -- ConvNeXt
    .uib 160 160 4 1 0 0,    -- FFN
    -- ... 11 more UIB blocks
    .convBn 256 1280 1 1 .same,
    .globalAvgPool,
    .dense 1280 10 .identity
  ]
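The ViT-Tiny spec above can be checked against its 5.5M parameter count in the results table by tallying the parameters by hand. The Python sketch assumes the standard ViT layout: biases on every dense layer and on the patch-embedding conv, a learned CLS token, learned positional embeddings over 197 tokens, and a final LayerNorm. Spec.lean's own param-counting may differ in those details. The number of attention heads does not affect the count.

```python
def vit_param_count(img=224, patch=16, chans=3, dim=192, depth=12, mlp=768, classes=10):
    n_patches = (img // patch) ** 2                       # (224/16)^2 = 196
    patch_embed = chans * patch * patch * dim + dim       # conv kernel + bias
    cls_and_pos = dim + (n_patches + 1) * dim             # CLS token + positional embedding
    layer_norm = 2 * dim                                  # gamma + beta
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)  # QKV + output projection
    ffn = (dim * mlp + mlp) + (mlp * dim + dim)           # two dense layers
    block = 2 * layer_norm + attn + ffn                   # pre-norm transformer block
    head = layer_norm + dim * classes + classes           # final LN + classifier
    return patch_embed + cls_and_pos + depth * block + head


print(vit_param_count())  # 5526346, i.e. about 5.5M
```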

Project structure

lean4-mlir/
├── README.md               -- this file (phase 3)
├── RESULTS.md              -- per-architecture eval histories + ablations
├── IREE_BUILD.md           -- how to build libiree_ffi.so from scratch
├── ROCM.md                 -- ROCm setup notes
├── BENCHMARK.md            -- ROCm vs CUDA performance comparison
├── lakefile.lean           -- Lake build config (libraries + ~30 execs)
│
├── LeanMlir.lean           -- umbrella module
├── LeanMlir/
│   ├── MlirCodegen.lean    -- ~7500 lines, NetSpec → StableHLO MLIR
│   ├── IreeRuntime.lean    -- Lean ↔ libiree_ffi.so bindings
│   ├── F32Array.lean       -- ByteArray-backed float32 helpers
│   ├── Spec.lean           -- NetSpec / Layer / param-counting
│   ├── Types.lean          -- core types (Layer, Activation, Padding, ...)
│   ├── MnistData.lean      -- IDX file loader (older training paths)
│   └── Proofs/             -- VJP correctness proofs (~36,700 lines)
│       ├── Tensor.lean
│       ├── MLP.lean
│       ├── CNN.lean
│       ├── Residual.lean
│       ├── BatchNorm.lean
│       ├── Depthwise.lean
│       ├── SE.lean
│       ├── LayerNorm.lean
│       ├── Attention.lean
│       ├── FloatBridge.lean      -- ℝ→Float32: rounding budgets (Tier 1)
│       ├── SgdDescent.lean       -- inexact-gradient descent over ℝ
│       ├── SgdDescentLinear.lean -- Lipschitz constants, linear loss
│       ├── SgdDescentMlp.lean    -- ...through the MLP's ReLU kinks (margins)
│       └── SgdDescentCnn.lean    -- ...toward the CNN: pool selection margins
│
├── Main*Train.lean         -- phase 3 trainers (one per architecture)
│   ├── MainResnetTrain.lean
│   ├── MainResnet50Train.lean
│   ├── MainMobilenetV2Train.lean
│   ├── MainMobilenetV3Train.lean
│   ├── MainMobilenetV4Train.lean
│   ├── MainEfficientNetTrain.lean
│   ├── MainEfficientNetV2Train.lean
│   ├── MainVitTrain.lean
│   ├── MainVggTrain.lean
│   ├── MainMnistMlpTrain.lean
│   ├── MainMnistCnnTrain.lean
│   ├── MainCifarCnnBnTrain.lean
│   ├── MainCifarCnnTrain.lean
│   └── MainAblation.lean
│
├── tests/                  -- unit tests + smoke tests + differential tests
│   ├── Test*.lean          -- runtime / FFI / codegen sanity tests
│   ├── BenchResnet.lean
│   ├── diff_traces.py      -- JSONL trace diff helper
│   ├── cross_backend_mnist_mlp.sh
│   └── vjp_oracle/         -- JAX-autodiff oracle for hand-derived VJPs
│       ├── README.md
│       ├── run.sh
│       ├── diff_step.py
│       ├── phase3/         -- Lean→IREE test trainers
│       └── phase2/         -- (mirrored at jax/tests/vjp_oracle/phase2/)
│
├── upstream-issues/        -- isolated reproducers + backtraces for bugs
│   └── 2026-04-rocm-miopen-conv-segv/  -- ROCm/MIOpen#3955
│
├── ffi/
│   ├── iree_ffi.{c,h}      -- IREE runtime wrapper
│   ├── iree_lean_ffi.c     -- Lean FFI bindings
│   ├── f32_helpers.c       -- data loading, He init, EMA, augmentation
│   └── libiree_ffi.so      -- compiled shared library
│
├── jax/                    -- phase 2 (Lean → JAX Python)
│   ├── README.md
│   ├── Jax.lean
│   ├── Jax/{Codegen,Runner}.lean
│   ├── Main*.lean          -- 14 JAX-driven architecture specs
│   └── tests/vjp_oracle/phase2/  -- phase-2 mirror of oracle specs
│
├── mnist-lean4/            -- phase 1 (pure Lean 4 + C BLAS)
│
├── traces/                 -- committed cross-backend training traces
│   ├── CROSS_BACKEND_RESULTS.md
│   ├── TRACE_FORMAT.md
│   └── mnist_{mlp,cnn}.*.jsonl
│
├── data/                   -- downloaded + preprocessed datasets
└── run_*.sh                -- shell wrappers for tmux env propagation

Supported layers (phase 3 codegen)

Layer                              Description
dense                              Fully connected (with optional activation)
conv2d                             Standard convolution
convBn                             Conv + batch norm + ReLU/ReLU6/Swish/h-swish
residualBlock                      BasicBlock (ResNet-18/34)
bottleneckBlock                    Bottleneck (ResNet-50/101/152)
invertedResidual                   Expand → depthwise → project + skip (MobileNetV2)
mbConv                             Inverted residual + Squeeze-Excitation, Swish (EfficientNet)
mbConvV3                           Inverted residual + h-swish + h-sigmoid SE (MobileNetV3, exact math)
fusedMbConv                        k×k regular conv replaces (1×1 expand + depthwise) (EfficientNetV2)
uib                                Universal Inverted Bottleneck: optional pre-DW + expand + optional post-DW + project (MobileNetV4)
patchEmbed                         Conv patch projection + CLS token + positional embedding (ViT)
transformerEncoder                 LN → MHSA → + → LN → MLP → +, with exact tanh-form GELU
maxPool, globalAvgPool, flatten    Structural

Activations supported with exact backward: ReLU, ReLU6, Swish, h-swish, h-sigmoid, GELU (tanh form). Layer-norm and batch-norm both have proper VJPs and (for BN) running statistics for eval.
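The Python sketch below gives the tanh-form GELU used in transformerEncoder and its exact derivative, together with a central-difference check. The backward differentiates the tanh approximation itself, not the erf-based GELU; the sketch assumes that is what "exact backward" means for the tanh form. The constants 0.044715 and √(2/π) are the standard ones for this approximation.

```python
import math

C = math.sqrt(2.0 / math.pi)
A = 0.044715


def gelu_tanh(x):
    """GELU, tanh form: 0.5 * x * (1 + tanh(C * (x + A * x^3)))."""
    return 0.5 * x * (1.0 + math.tanh(C * (x + A * x ** 3)))


def gelu_tanh_grad(x):
    """Exact derivative of the tanh form:
    0.5 * (1 + t) + 0.5 * x * (1 - t^2) * u', with u = C * (x + A x^3), t = tanh(u)."""
    t = math.tanh(C * (x + A * x ** 3))
    du = C * (1.0 + 3.0 * A * x ** 2)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * du


h = 1e-5
for x in (-3.0, -0.5, 0.0, 0.7, 2.5):
    fd = (gelu_tanh(x + h) - gelu_tanh(x - h)) / (2 * h)
    print(x, gelu_tanh_grad(x), fd)  # analytic vs finite difference
```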

Lean version

Tested with Lean 4.30.0 / Lake 5.0.0, IREE built from source against ROCm 7.2.0 / gfx1100.

Citing this work

@software{koonce2026leanmlir,
  author  = {Brett Koonce},
  title   = {Verified Deep Learning with Lean 4: Formal Backpropagation from MLP to Attention, via MLIR},
  url     = {https://github.com/brettkoonce/lean4-mlir},
  doi     = {10.5281/zenodo.20402133},
  version = {0.6.0},
  year    = {2026},
}

About

LLM-assisted Lean specification of neural architectures with verified IREE codegen.
