Skip to content

Commit

Permalink
perf: eagerly elaborate leaf nodes in linear_combination (#15599)
Browse files Browse the repository at this point in the history
Since the arithmetic type is known from the equality goal, we can eagerly elaborate all leaf nodes in `expandLinearCombo` while also avoiding re-elaborating them when deciding whether they are proofs or not. This brings the elaboration of a complicated example (see tests) from 3.91s to 0.29s.

The main issue is that large expressions can be slow to elaborate. In particular, exponentiation relies on the default instance mechanism, which appears can take quadratic time to resolve. Forcing the resolution of default instances within each coefficient of the linear combination bounds the default instance problems.

This is similar to PR #15570, but it sticks with the syntax transformation paradigm.

Co-authored-by: Heather Macbeth <[email protected]>
  • Loading branch information
2 people authored and bjoernkjoshanssen committed Sep 9, 2024
1 parent d26f5e8 commit 87a1651
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 46 deletions.
104 changes: 58 additions & 46 deletions Mathlib/Tactic/LinearCombination.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,60 +51,70 @@ theorem pf_div_c [Div α] (p : a = b) (c : α) : a / c = b / c := p ▸ rfl
theorem c_div_pf [Div α] (p : b = c) (a : α) : a / b = a / c := p ▸ rfl
theorem div_pf [Div α] (p₁ : (a₁:α) = b₁) (p₂ : a₂ = b₂) : a₁ / a₂ = b₁ / b₂ := p₁ ▸ p₂ ▸ rfl

/-- Result of `expandLinearCombo`, either an equality proof or a value. -/
inductive Expanded
/-- A proof of `a = b`. -/
| proof (pf : Syntax.Term)
/-- A value, equivalently a proof of `c = c`. -/
| const (c : Syntax.Term)

/--
Performs macro expansion of a linear combination expression,
using `+`/`-`/`*`/`/` on equations and values.
* `some p` means that `p` is a syntax corresponding to a proof of an equation.
For example, if `h : a = b` then `expandLinearCombo (2 * h)` returns `some (c_add_pf 2 h)`
* `.proof p` means that `p` is a syntax corresponding to a proof of an equation.
For example, if `h : a = b` then `expandLinearCombo (2 * h)` returns `.proof (c_add_pf 2 h)`
which is a proof of `2 * a = 2 * b`.
* `none` means that the input expression is not an equation but a value;
the input syntax itself is used in this case.
* `.const c` means that the input expression is not an equation but a value.
-/
partial def expandLinearCombo (stx : Syntax.Term) : TermElabM (Option Syntax.Term) := do
let mut result ← match stx with
| `(($e)) => expandLinearCombo e
partial def expandLinearCombo (ty : Expr) (stx : Syntax.Term) : TermElabM Expanded := withRef stx do
match stx with
| `(($e)) => expandLinearCombo ty e
| `($e₁ + $e₂) => do
match ← expandLinearCombo e₁, ← expandLinearCombo e₂ with
| none, none => pure none
| some p₁, none => ``(pf_add_c $p₁ $e₂)
| none, some p₂ => ``(c_add_pf $p₂ $e₁)
| some p₁, some p₂ => ``(add_pf $p₁ $p₂)
match ← expandLinearCombo ty e₁, ← expandLinearCombo ty e₂ with
| .const c₁, .const c₂ => .const <$> ``($c₁ + $c₂)
| .proof p₁, .const c₂ => .proof <$> ``(pf_add_c $p₁ $c₂)
| .const c₁, .proof p₂ => .proof <$> ``(c_add_pf $p₂ $c₁)
| .proof p₁, .proof p₂ => .proof <$> ``(add_pf $p₁ $p₂)
| `($e₁ - $e₂) => do
match ← expandLinearCombo e₁, ← expandLinearCombo e₂ with
| none, none => pure none
| some p₁, none => ``(pf_sub_c $p₁ $e₂)
| none, some p₂ => ``(c_sub_pf $p₂ $e₁)
| some p₁, some p₂ => ``(sub_pf $p₁ $p₂)
match ← expandLinearCombo ty e₁, ← expandLinearCombo ty e₂ with
| .const c₁, .const c₂ => .const <$> ``($c₁ - $c₂)
| .proof p₁, .const c₂ => .proof <$> ``(pf_sub_c $p₁ $c₂)
| .const c₁, .proof p₂ => .proof <$> ``(c_sub_pf $p₂ $c₁)
| .proof p₁, .proof p₂ => .proof <$> ``(sub_pf $p₁ $p₂)
| `(-$e) => do
match ← expandLinearCombo e with
| none => pure none
| some p => ``(neg_pf $p)
match ← expandLinearCombo ty e with
| .const c => .const <$> `(-$c)
| .proof p => .proof <$> ``(neg_pf $p)
| `(← $e) => do
match ← expandLinearCombo e with
| none => pure none
| some p => ``(Eq.symm $p)
match ← expandLinearCombo ty e with
| .const c => return .const c
| .proof p => .proof <$> ``(Eq.symm $p)
| `($e₁ * $e₂) => do
match ← expandLinearCombo e₁, ← expandLinearCombo e₂ with
| none, none => pure none
| some p₁, none => ``(pf_mul_c $p₁ $e₂)
| none, some p₂ => ``(c_mul_pf $p₂ $e₁)
| some p₁, some p₂ => ``(mul_pf $p₁ $p₂)
match ← expandLinearCombo ty e₁, ← expandLinearCombo ty e₂ with
| .const c₁, .const c₂ => .const <$> ``($c₁ * $c₂)
| .proof p₁, .const c₂ => .proof <$> ``(pf_mul_c $p₁ $c₂)
| .const c₁, .proof p₂ => .proof <$> ``(c_mul_pf $p₂ $c₁)
| .proof p₁, .proof p₂ => .proof <$> ``(mul_pf $p₁ $p₂)
| `($e⁻¹) => do
match ← expandLinearCombo e with
| none => pure none
| some p => ``(inv_pf $p)
match ← expandLinearCombo ty e with
| .const c => .const <$> `($c⁻¹)
| .proof p => .proof <$> ``(inv_pf $p)
| `($e₁ / $e₂) => do
match ← expandLinearCombo e₁, ← expandLinearCombo e₂ with
| none, none => pure none
| some p₁, none => ``(pf_div_c $p₁ $e₂)
| none, some p₂ => ``(c_div_pf $p₂ $e₁)
| some p₁, some p₂ => ``(div_pf $p₁ $p₂)
| e => do
let e ← elabTerm e none
let eType ← inferType e
let .true := (← withReducible do whnf eType).isEq | pure none
some <$> e.toSyntax
return result.map fun r => ⟨r.raw.setInfo (SourceInfo.fromRef stx true)⟩
match ← expandLinearCombo ty e₁, ← expandLinearCombo ty e₂ with
| .const c₁, .const c₂ => .const <$> ``($c₁ / $c₂)
| .proof p₁, .const c₂ => .proof <$> ``(pf_div_c $p₁ $c₂)
| .const c₁, .proof p₂ => .proof <$> ``(c_div_pf $p₂ $c₁)
| .proof p₁, .proof p₂ => .proof <$> ``(div_pf $p₁ $p₂)
| e =>
-- We have the expected type from the goal, so we can fully synthesize this leaf node.
withSynthesize do
-- It is OK to use `ty` as the expected type even if `e` is a proof.
-- The expected type is just a hint.
let c ← withSynthesizeLight <| Term.elabTerm e ty
if (← whnfR (← inferType c)).isEq then
.proof <$> c.toSyntax
else
.const <$> c.toSyntax

theorem eq_trans₃ (p : (a:α) = b) (p₁ : a = a') (p₂ : b = b') : a' = b' := p₁ ▸ p₂ ▸ p

Expand All @@ -119,12 +129,14 @@ theorem eq_of_add_pow [Ring α] [NoZeroDivisors α] (n : ℕ) (p : (a:α) = b)
def elabLinearCombination
(norm? : Option Syntax.Tactic) (exp? : Option Syntax.NumLit) (input : Option Syntax.Term)
(twoGoals := false) : Tactic.TacticM Unit := Tactic.withMainContext do
let some (ty, _) := (← (← Tactic.getMainGoal).getType').eq? |
throwError "'linear_combination' only proves equalities"
let p ← match input with
| none => `(Eq.refl 0)
| some e => withSynthesize do
match ← expandLinearCombo e with
| none => `(Eq.refl $e)
| some p => pure p
| some e =>
match ← expandLinearCombo ty e with
| .const c => `(Eq.refl $c)
| .proof p => pure p
let norm := norm?.getD (Unhygienic.run `(tactic| ring1))
Tactic.evalTactic <| ← withFreshMacroScope <|
if twoGoals then
Expand Down
43 changes: 43 additions & 0 deletions test/linear_combination.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import Mathlib.Tactic.Linarith

set_option autoImplicit true

private axiom test_sorry : ∀ {α}, α

-- We deliberately mock R here so that we don't have to import the deps
axiom Real : Type
notation "ℝ" => Real
Expand Down Expand Up @@ -118,6 +120,10 @@ example {α} [h : CommRing α] {a b c d e f : α} (h1 : a * d = b * c) (h2 : c *
example (x y z w : ℚ) (hzw : z = w) : x * z + 2 * y * z = x * w + 2 * y * w := by
linear_combination (x + 2 * y) * hzw

example (x : ℤ) : x ^ 2 = x ^ 2 := by linear_combination x ^ 2

example (x y : ℤ) (h : x = 0) : y ^ 2 * x = 0 := by linear_combination y ^ 2 * h

/-! ### Cases that explicitly use a config -/

example (x y : ℚ) (h1 : 3 * x + 2 * y = 10) (h2 : 2 * x + 5 * y = 3) : -11 * y + 1 = 11 + 1 := by
Expand Down Expand Up @@ -230,3 +236,40 @@ example (h : g a = g b) : a ^ 4 = b ^ 4 := by
example {r s a b : ℕ} (h₁ : (r : ℤ) = a + 1) (h₂ : (s : ℤ) = b + 1) :
r * s = (a + 1 : ℤ) * (b + 1) := by
linear_combination (↑b + 1) * h₁ + ↑r * h₂

-- Implementation at the time of the port (Nov 2022) was 110,000 heartbeats.
-- Eagerly elaborating leaf nodes brings this to 7,540 heartbeats.
set_option maxHeartbeats 8000 in
example (K : Type*) [Field K] [CharZero K] {x y z p q : K}
(h₀ : 3 * x ^ 2 + z ^ 2 * p = 0)
(h₁ : z * (2 * y) = 0)
(h₂ : -y ^ 2 + p * x * (2 * z) + q * (3 * z ^ 2) = 0) :
((27 * q ^ 2 + 4 * p ^ 3) * x) ^ 4 = 0 := by
linear_combination (norm := skip)
(256 / 3 * p ^ 12 * x ^ 2 + 128 * q * p ^ 11 * x * z + 2304 * q ^ 2 * p ^ 9 * x ^ 2 +
2592 * q ^ 3 * p ^ 8 * x * z -
64 * q * p ^ 10 * y ^ 2 +
23328 * q ^ 4 * p ^ 6 * x ^ 2 +
17496 * q ^ 5 * p ^ 5 * x * z -
1296 * q ^ 3 * p ^ 7 * y ^ 2 +
104976 * q ^ 6 * p ^ 3 * x ^ 2 +
39366 * q ^ 7 * p ^ 2 * x * z -
8748 * q ^ 5 * p ^ 4 * y ^ 2 +
177147 * q ^ 8 * x ^ 2 -
19683 * q ^ 7 * p * y ^ 2) *
h₀ +
(-(64 / 3 * p ^ 12 * x * y) + 32 * q * p ^ 11 * z * y - 432 * q ^ 2 * p ^ 9 * x * y +
648 * q ^ 3 * p ^ 8 * z * y -
2916 * q ^ 4 * p ^ 6 * x * y +
4374 * q ^ 5 * p ^ 5 * z * y -
6561 * q ^ 6 * p ^ 3 * x * y +
19683 / 2 * q ^ 7 * p ^ 2 * z * y) *
h₁ +
(-(128 / 3 * p ^ 12 * x * z) - 192 * q * p ^ 10 * x ^ 2 - 864 * q ^ 2 * p ^ 9 * x * z -
3888 * q ^ 3 * p ^ 7 * x ^ 2 -
5832 * q ^ 4 * p ^ 6 * x * z -
26244 * q ^ 5 * p ^ 4 * x ^ 2 -
13122 * q ^ 6 * p ^ 3 * x * z -
59049 * q ^ 7 * p * x ^ 2) *
h₂
exact test_sorry

0 comments on commit 87a1651

Please sign in to comment.