Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions src/Lean/Compiler/LCNF/ToLCNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,40 @@ private def checkComputable (ref : Name) : M Unit := do
else if getOriginalConstKind? (← getEnv) ref matches some .axiom | some .quot | some .induct | some .thm then
throwNamedError lean.dependsOnNoncomputable f!"`{ref}` not supported by code generator; consider marking definition as `noncomputable`"

/--
Given a motive `fun ... => ∀ (h₁ : p₁) ... (hₙ : pₙ), ...` where all `pᵢ` are propositions,
returns `min maxArgs n`.
-/
def countProofOverApp (motive : Expr) (maxArgs : Nat) : MetaM Nat := do
Meta.lambdaTelescope motive fun vars body => do
Meta.forallTelescope body fun forallVars _ => do
let n := min maxArgs forallVars.size
for h : i in *...n do
have h : i < min maxArgs forallVars.size := h
unless ← Meta.isProof forallVars[i] do
return i
return n

/--
Given `e : ∀ (h₁ : p₁) ... (hₙ : pₙ), b h₁ ... hₙ`, returns `e (lcProof p₁) ... (lcProof pₙ)`.
-/
def mkProofOverApp (e : Expr) (n : Nat) : MetaM Expr := do
let mut type ← Meta.inferType e
if type.getNumHeadForalls < n then
type ← Meta.forallBoundedTelescope type n Meta.mkForallFVars
go type n #[]
where
go (type : Expr) (n : Nat) (vars : Array Expr) : MetaM Expr :=
match n with
| 0 => return e.beta vars
| k + 1 =>
match type with
| .forallE _ t b _ =>
let t := t.instantiateRev vars
go b k (vars.push (mkLcProof t))
| .mdata _ type => go type n vars
| _ => Meta.throwFunctionExpected (e.beta vars)

/--
Eta reduce implicits. We use this function to eliminate introduced by the implicit lambda feature,
where it generates terms such as `fun {α} => ReaderT.pure`
Expand Down Expand Up @@ -541,11 +575,11 @@ where
/--
Visit a `matcher`/`casesOn` alternative.
-/
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) (nproofs : Nat) : M (Expr × (Alt .pure)) := do
withNewScope do
match casesAltInfo with
| .default numHyps =>
let c ← toCode (← visit (mkAppN e (Array.replicate numHyps erasedExpr)))
let c ← toCode (← visit (← liftMetaM <| mkProofOverApp e (numHyps + nproofs)))
let altType ← c.inferType
return (altType, .default c)
| .ctor ctorName numParams =>
Expand All @@ -572,14 +606,16 @@ where
not occur in the type of `as : List α`.
-/
p.update (← applyToAny p.type)
let c ← toCode (← visit e)
let c ← toCode (← visit (← liftMetaM <| mkProofOverApp e nproofs))
let altType ← c.inferType
return (altType, .alt ctorName ps c)

visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e casesInfo.arity do
let args := e.getAppArgs
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity]))
let nproofs ← liftMetaM <| countProofOverApp args[casesInfo.motivePos]! (args.size - casesInfo.arity)
let arity := casesInfo.arity + nproofs
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...arity]))
let typeName := casesInfo.indName
let .inductInfo indVal ← getConstInfo typeName | unreachable!
if casesInfo.numAlts == 0 then
Expand Down Expand Up @@ -609,23 +645,23 @@ where
fieldArgs := fieldArgs.push fieldArg
return fieldArgs
let f := args[casesInfo.altsRange.lower]!
let result ← visit (mkAppN f fieldArgs)
mkOverApplication result args casesInfo.arity
let result ← visit (← liftMetaM <| mkProofOverApp (mkAppN f fieldArgs) nproofs)
mkOverApplication result args arity
else
let mut alts := #[]
let discr ← visitAppArg args[casesInfo.discrPos]!
let discrFVarId ← match discr with
| .fvar discrFVarId => pure discrFVarId
| .erased | .type .. => mkAuxLetDecl .erased
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do
let (altType, alt) ← visitAlt numParams args[i]!
let (altType, alt) ← visitAlt numParams args[i]! nproofs
resultType := joinTypes altType resultType
alts := alts.push alt
let cases := ⟨typeName, resultType, discrFVarId, alts⟩
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl cases)
let result := .fvar auxDecl.fvarId
mkOverApplication result args casesInfo.arity
mkOverApplication result args arity

visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e arity do
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Meta/CasesInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public structure CasesInfo where
indName : Name
arity : Nat
discrPos : Nat
motivePos : Nat
altsRange : Std.Rco Nat
altNumParams : Array CasesAltInfo

Expand All @@ -61,6 +62,7 @@ public def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
assert! r.appArg!.isFVar -- major argument
assert! r.getAppFn.isFVar -- motive
let some discrPos := xs.idxOf? r.appArg! | unreachable!
let some motivePos := xs.idxOf? r.getAppFn | unreachable!
let some indName := (← inferType xs[discrPos]!).getAppFn.constName? | unreachable!
-- We recognize the per-ctor elims side condition here
let xsTys ← (xs.extract (discrPos+1)).mapM inferType
Expand All @@ -80,4 +82,4 @@ public def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
let some ctorName := motiveArg.getAppFn.constName? | unreachable!
let ctorVal ← getConstInfoCtor ctorName
return .ctor ctorName ctorVal.numFields
return some { declName, indName, arity, discrPos, altsRange, altNumParams }
return some { declName, indName, arity, discrPos, motivePos, altsRange, altNumParams }
2 changes: 1 addition & 1 deletion tests/elab/sparseCasesOn.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trace: [Compiler.saveBase] size: 7
let _x.3 := sort u;
return _x.3
| _ =>
let _x.4 := else.1 _;
let _x.4 := else.1 ;
return _x.4
-/
#guard_msgs in
Expand Down
Loading