From 6362f8c3b96675fa839b2d4fd398ddcf4ece8072 Mon Sep 17 00:00:00 2001 From: Rob23oba Date: Thu, 5 Mar 2026 16:14:54 +0100 Subject: [PATCH 1/2] optimization for casesOn proof overapp --- src/Lean/Compiler/LCNF/ToLCNF.lean | 52 +++++++++++++++++++++++++----- src/Lean/Meta/CasesInfo.lean | 4 ++- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 3f307f68da30..97ca39b79574 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -414,6 +414,40 @@ private def checkComputable (ref : Name) : M Unit := do else if getOriginalConstKind? (← getEnv) ref matches some .axiom | some .quot | some .induct | some .thm then throwNamedError lean.dependsOnNoncomputable f!"`{ref}` not supported by code generator; consider marking definition as `noncomputable`" +/-- +Given a motive `fun ... => ∀ (h₁ : p₁) ... (hₙ : pₙ), ...` where all `pᵢ` are propositions, +returns `min maxArgs n`. +-/ +def countProofOverApp (motive : Expr) (maxArgs : Nat) : MetaM Nat := do + Meta.lambdaTelescope motive fun vars body => do + Meta.forallTelescope body fun forallVars _ => do + let n := min maxArgs forallVars.size + for h : i in *...n do + have h : i < min maxArgs forallVars.size := h + unless ← Meta.isProof forallVars[i] do + return i + return n + +/-- +Given `e : ∀ (h₁ : p₁) ... (hₙ : pₙ), b h₁ ... hₙ`, returns `e (lcProof p₁) ... (lcProof pₙ)`. +-/ +def mkProofOverApp (e : Expr) (n : Nat) : MetaM Expr := do + let mut type ← Meta.inferType e + if type.getNumHeadForalls < n then + type ← Meta.forallBoundedTelescope type n Meta.mkForallFVars + go type n #[] +where + go (type : Expr) (n : Nat) (vars : Array Expr) : MetaM Expr := + match n with + | 0 => return e.beta vars + | k + 1 => + match type with + | .forallE _ t b _ => + let t := t.instantiateRev vars + go b k (vars.push (mkLcProof t)) + | .mdata _ type => go type n vars + | _ => Meta.throwFunctionExpected (e.beta vars) + /-- Eta reduce implicits. We use this function to eliminate introduced by the implicit lambda feature, where it generates terms such as `fun {α} => ReaderT.pure` @@ -541,11 +575,11 @@ where /-- Visit a `matcher`/`casesOn` alternative. -/ - visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do + visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) (nproofs : Nat) : M (Expr × (Alt .pure)) := do withNewScope do match casesAltInfo with | .default numHyps => - let c ← toCode (← visit (mkAppN e (Array.replicate numHyps erasedExpr))) + let c ← toCode (← visit (← liftMetaM <| mkProofOverApp e (numHyps + nproofs))) let altType ← c.inferType return (altType, .default c) | .ctor ctorName numParams => @@ -572,14 +606,16 @@ where not occur in the type of `as : List α`. -/ p.update (← applyToAny p.type) - let c ← toCode (← visit e) + let c ← toCode (← visit (← liftMetaM <| mkProofOverApp e nproofs)) let altType ← c.inferType return (altType, .alt ctorName ps c) visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e casesInfo.arity do let args := e.getAppArgs - let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity])) + let nproofs ← liftMetaM <| countProofOverApp args[casesInfo.motivePos]! (args.size - casesInfo.arity) + let arity := casesInfo.arity + nproofs + let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...arity])) let typeName := casesInfo.indName let .inductInfo indVal ← getConstInfo typeName | unreachable! if casesInfo.numAlts == 0 then @@ -609,8 +645,8 @@ where fieldArgs := fieldArgs.push fieldArg return fieldArgs let f := args[casesInfo.altsRange.lower]! - let result ← visit (mkAppN f fieldArgs) - mkOverApplication result args casesInfo.arity + let result ← visit (← liftMetaM <| mkProofOverApp (mkAppN f fieldArgs) nproofs) + mkOverApplication result args arity else let mut alts := #[] let discr ← visitAppArg args[casesInfo.discrPos]! @@ -618,14 +654,14 @@ where | .fvar discrFVarId => pure discrFVarId | .erased | .type .. => mkAuxLetDecl .erased for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do - let (altType, alt) ← visitAlt numParams args[i]! + let (altType, alt) ← visitAlt numParams args[i]! nproofs resultType := joinTypes altType resultType alts := alts.push alt let cases := ⟨typeName, resultType, discrFVarId, alts⟩ let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) let result := .fvar auxDecl.fvarId - mkOverApplication result args casesInfo.arity + mkOverApplication result args arity visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e arity do diff --git a/src/Lean/Meta/CasesInfo.lean b/src/Lean/Meta/CasesInfo.lean index 6d85a3c56d2a..90bcc25800b7 100644 --- a/src/Lean/Meta/CasesInfo.lean +++ b/src/Lean/Meta/CasesInfo.lean @@ -45,6 +45,7 @@ public structure CasesInfo where indName : Name arity : Nat discrPos : Nat + motivePos : Nat altsRange : Std.Rco Nat altNumParams : Array CasesAltInfo @@ -61,6 +62,7 @@ public def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do assert! r.appArg!.isFVar -- major argument assert! r.getAppFn.isFVar -- motive let some discrPos := xs.idxOf? r.appArg! | unreachable! + let some motivePos := xs.idxOf? r.getAppFn | unreachable! let some indName := (← inferType xs[discrPos]!).getAppFn.constName? | unreachable! -- We recognize the per-ctor elims side condition here let xsTys ← (xs.extract (discrPos+1)).mapM inferType @@ -80,4 +82,4 @@ public def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do let some ctorName := motiveArg.getAppFn.constName? | unreachable! let ctorVal ← getConstInfoCtor ctorName return .ctor ctorName ctorVal.numFields - return some { declName, indName, arity, discrPos, altsRange, altNumParams } + return some { declName, indName, arity, discrPos, motivePos, altsRange, altNumParams } From b29fa21382ea45e6c59c77a938a2ea53a24115a4 Mon Sep 17 00:00:00 2001 From: Rob23oba Date: Thu, 5 Mar 2026 16:36:31 +0100 Subject: [PATCH 2/2] fix test --- tests/elab/sparseCasesOn.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/elab/sparseCasesOn.lean b/tests/elab/sparseCasesOn.lean index 3c300e1b3b73..dc64fc1e2706 100644 --- a/tests/elab/sparseCasesOn.lean +++ b/tests/elab/sparseCasesOn.lean @@ -47,7 +47,7 @@ trace: [Compiler.saveBase] size: 7 let _x.3 := sort u; return _x.3 | _ => - let _x.4 := else.1 _; + let _x.4 := else.1 ◾; return _x.4 -/ #guard_msgs in