From bcbd4ec192efaff95fc060285f7c947e17862915 Mon Sep 17 00:00:00 2001 From: Floris van Doorn Date: Fri, 28 Apr 2023 06:56:19 +0000 Subject: [PATCH] perf: improve performance of `to_additive` (#3632) * `applyReplacementFun` now treats applications `f x_1 ... x_n` as atomic, and recurses directly into `f` and `x_i` (before it recursed on the partial appliations `f x_1 ... x_j`) * I had to reimplement the way `to_additive` reorders arguments, so at the same time I also made it more flexible. We can now reorder with an arbitrary permutation, and you have to specify this by providing a permutation using cycle notation (e.g. `(reorder := 1 2 3, 8 9)` means we're permuting the first three arguments and swapping arguments 8 and 9). This implements the first item of #1074. * `additiveTest` now memorizes the test on previously-visited subexpressions. Thanks to @kmill for this suggestion! The performance on (one of) the slowest declaration(s) to additivize (`MonoidLocalization.lift`) is summarized below (note: `dsimp only` refers to adding a single `dsimp only` tactic in the declaration, which was done in #3580) ``` original: 27400ms better applyReplacementFun: 1550ms better applyReplacementFun + better additiveTest: 176ms dsimp only: 6710ms better applyReplacementFun + dsimp only: 425ms better applyReplacementFun + better additiveTest + dsimp only: 128ms ``` --- Mathlib/Algebra/Group/Defs.lean | 10 +- Mathlib/Algebra/Group/OrderSynonym.lean | 24 +-- Mathlib/Algebra/Group/ULift.lean | 4 +- Mathlib/Algebra/Hom/Group.lean | 4 +- Mathlib/Algebra/Hom/Iterate.lean | 4 +- Mathlib/Data/Array/Defs.lean | 18 +- Mathlib/Data/Pi/Algebra.lean | 8 +- Mathlib/GroupTheory/GroupAction/Prod.lean | 10 +- Mathlib/Tactic/ToAdditive.lean | 242 +++++++++++----------- test/toAdditive.lean | 19 +- 10 files changed, 185 insertions(+), 158 deletions(-) diff --git a/Mathlib/Algebra/Group/Defs.lean b/Mathlib/Algebra/Group/Defs.lean index 18ec32b7052ff..bb2639412213c 100644 --- a/Mathlib/Algebra/Group/Defs.lean +++ b/Mathlib/Algebra/Group/Defs.lean @@ -88,16 +88,16 @@ infixl:65 " -ᵥ " => VSub.vsub infixr:73 " • " => HSMul.hSMul attribute [to_additive existing] Mul Div HMul instHMul HDiv instHDiv HSMul -attribute [to_additive (reorder := 1) SMul] Pow -attribute [to_additive (reorder := 1)] HPow -attribute [to_additive existing (reorder := 1 5) hSMul] HPow.hPow -attribute [to_additive existing (reorder := 1 4) smul] Pow.pow +attribute [to_additive (reorder := 1 2) SMul] Pow +attribute [to_additive (reorder := 1 2)] HPow +attribute [to_additive existing (reorder := 1 2, 5 6) hSMul] HPow.hPow +attribute [to_additive existing (reorder := 1 2, 4 5) smul] Pow.pow @[to_additive (attr := default_instance)] instance instHSMul [SMul α β] : HSMul α β β where hSMul := SMul.smul -attribute [to_additive existing (reorder := 1)] instHPow +attribute [to_additive existing (reorder := 1 2)] instHPow universe u diff --git a/Mathlib/Algebra/Group/OrderSynonym.lean b/Mathlib/Algebra/Group/OrderSynonym.lean index 38a1eedd11d5a..c429b79fcf0a5 100644 --- a/Mathlib/Algebra/Group/OrderSynonym.lean +++ b/Mathlib/Algebra/Group/OrderSynonym.lean @@ -37,12 +37,12 @@ instance [h : Inv α] : Inv αᵒᵈ := h @[to_additive] instance [h : Div α] : Div αᵒᵈ := h -@[to_additive (attr := to_additive) (reorder := 1) instSMulOrderDual] +@[to_additive (attr := to_additive) (reorder := 1 2) instSMulOrderDual] instance [h : Pow α β] : Pow αᵒᵈ β := h #align order_dual.has_pow instPowOrderDual #align order_dual.has_smul instSMulOrderDual -@[to_additive (attr := to_additive) (reorder := 1) instSMulOrderDual'] +@[to_additive (attr := to_additive) (reorder := 1 2) instSMulOrderDual'] instance instPowOrderDual' [h : Pow α β] : Pow α βᵒᵈ := h #align order_dual.has_pow' instPowOrderDual' #align order_dual.has_smul' instSMulOrderDual' @@ -140,25 +140,25 @@ theorem ofDual_div [Div α] (a b : αᵒᵈ) : ofDual (a / b) = ofDual a / ofDua #align of_dual_div ofDual_div #align of_dual_sub ofDual_sub -@[to_additive (attr := simp, to_additive) (reorder := 1 4) toDual_smul] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) toDual_smul] theorem toDual_pow [Pow α β] (a : α) (b : β) : toDual (a ^ b) = toDual a ^ b := rfl #align to_dual_pow toDual_pow #align to_dual_smul toDual_smul #align to_dual_vadd toDual_vadd -@[to_additive (attr := simp, to_additive) (reorder := 1 4) ofDual_smul] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) ofDual_smul] theorem ofDual_pow [Pow α β] (a : αᵒᵈ) (b : β) : ofDual (a ^ b) = ofDual a ^ b := rfl #align of_dual_pow ofDual_pow #align of_dual_smul ofDual_smul #align of_dual_vadd ofDual_vadd -@[to_additive (attr := simp, to_additive) (reorder := 1 4) toDual_smul'] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) toDual_smul'] theorem pow_toDual [Pow α β] (a : α) (b : β) : a ^ toDual b = a ^ b := rfl #align pow_to_dual pow_toDual #align to_dual_smul' toDual_smul' #align to_dual_vadd' toDual_vadd' -@[to_additive (attr := simp, to_additive) (reorder := 1 4) ofDual_smul'] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) ofDual_smul'] theorem pow_ofDual [Pow α β] (a : α) (b : βᵒᵈ) : a ^ ofDual b = a ^ b := rfl #align pow_of_dual pow_ofDual #align of_dual_smul' ofDual_smul' @@ -179,12 +179,12 @@ instance [h : Inv α] : Inv (Lex α) := h @[to_additive] instance [h : Div α] : Div (Lex α) := h -@[to_additive (attr := to_additive) (reorder := 1) instSMulLex] +@[to_additive (attr := to_additive) (reorder := 1 2) instSMulLex] instance [h : Pow α β] : Pow (Lex α) β := h #align lex.has_pow instPowLex #align lex.has_smul instSMulLex -@[to_additive (attr := to_additive) (reorder := 1) instSMulLex'] +@[to_additive (attr := to_additive) (reorder := 1 2) instSMulLex'] instance instPowLex' [h : Pow α β] : Pow α (Lex β) := h #align lex.has_pow' instPowLex' #align lex.has_smul' instSMulLex' @@ -280,25 +280,25 @@ theorem ofLex_div [Div α] (a b : Lex α) : ofLex (a / b) = ofLex a / ofLex b := #align of_lex_div ofLex_div #align of_lex_sub ofLex_sub -@[to_additive (attr := simp, to_additive) (reorder := 1 4) toLex_smul] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) toLex_smul] theorem toLex_pow [Pow α β] (a : α) (b : β) : toLex (a ^ b) = toLex a ^ b := rfl #align to_lex_pow toLex_pow #align to_lex_smul toLex_smul #align to_lex_vadd toLex_vadd -@[to_additive (attr := simp, to_additive) (reorder := 1 4) ofLex_smul] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) ofLex_smul] theorem ofLex_pow [Pow α β] (a : Lex α) (b : β) : ofLex (a ^ b) = ofLex a ^ b := rfl #align of_lex_pow ofLex_pow #align of_lex_smul ofLex_smul #align of_lex_vadd ofLex_vadd -@[to_additive (attr := simp, to_additive) (reorder := 1 4) toLex_smul'] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) toLex_smul'] theorem pow_toLex [Pow α β] (a : α) (b : β) : a ^ toLex b = a ^ b := rfl #align pow_to_lex pow_toLex #align to_lex_smul' toLex_smul' #align to_lex_vadd' toLex_vadd' -@[to_additive (attr := simp, to_additive) (reorder := 1 4) ofLex_smul'] +@[to_additive (attr := simp, to_additive) (reorder:= 1 2, 4 5) ofLex_smul'] theorem pow_ofLex [Pow α β] (a : α) (b : Lex β) : a ^ ofLex b = a ^ b := rfl #align pow_of_lex pow_ofLex #align of_lex_smul' ofLex_smul' diff --git a/Mathlib/Algebra/Group/ULift.lean b/Mathlib/Algebra/Group/ULift.lean index abd933388c027..887b7c4ee63f2 100644 --- a/Mathlib/Algebra/Group/ULift.lean +++ b/Mathlib/Algebra/Group/ULift.lean @@ -89,12 +89,12 @@ theorem smul_down [SMul α β] (a : α) (b : ULift.{v} β) : (a • b).down = a #align ulift.smul_down ULift.smul_down #align ulift.vadd_down ULift.vadd_down -@[to_additive existing (reorder := 1) smul] +@[to_additive existing (reorder := 1 2) smul] instance pow [Pow α β] : Pow (ULift α) β := ⟨fun x n => up (x.down ^ n)⟩ #align ulift.has_pow ULift.pow -@[to_additive existing (attr := simp) (reorder := 1) smul_down] +@[to_additive existing (attr := simp) (reorder := 1 2) smul_down] theorem pow_down [Pow α β] (a : ULift.{v} α) (b : β) : (a ^ b).down = a.down ^ b := rfl #align ulift.pow_down ULift.pow_down diff --git a/Mathlib/Algebra/Hom/Group.lean b/Mathlib/Algebra/Hom/Group.lean index 2738be9742826..5cf4a010307ea 100644 --- a/Mathlib/Algebra/Hom/Group.lean +++ b/Mathlib/Algebra/Hom/Group.lean @@ -444,7 +444,7 @@ theorem map_div [Group G] [DivisionMonoid H] [MonoidHomClass F G H] (f : F) : #align map_div map_div #align map_sub map_sub -@[to_additive (attr := simp) (reorder := 8)] +@[to_additive (attr := simp) (reorder := 8 9)] theorem map_pow [Monoid G] [Monoid H] [MonoidHomClass F G H] (f : F) (a : G) : ∀ n : ℕ, f (a ^ n) = f a ^ n | 0 => by rw [pow_zero, pow_zero, map_one] @@ -461,7 +461,7 @@ theorem map_zpow' [DivInvMonoid G] [DivInvMonoid H] [MonoidHomClass F G H] #align map_zsmul' map_zsmul' /-- Group homomorphisms preserve integer power. -/ -@[to_additive (attr := simp) (reorder := 8) +@[to_additive (attr := simp) (reorder := 8 9) "Additive group homomorphisms preserve integer scaling."] theorem map_zpow [Group G] [DivisionMonoid H] [MonoidHomClass F G H] (f : F) (g : G) (n : ℤ) : f (g ^ n) = f g ^ n := map_zpow' f (map_inv f) g n diff --git a/Mathlib/Algebra/Hom/Iterate.lean b/Mathlib/Algebra/Hom/Iterate.lean index 8c6d6c45667f9..aa4221dc30c3f 100644 --- a/Mathlib/Algebra/Hom/Iterate.lean +++ b/Mathlib/Algebra/Hom/Iterate.lean @@ -103,14 +103,14 @@ theorem iterate_map_smul (f : M →+ M) (n m : ℕ) (x : M) : (f^[n]) (m • x) f.toMultiplicative.iterate_map_pow n x m #align add_monoid_hom.iterate_map_smul AddMonoidHom.iterate_map_smul -attribute [to_additive (reorder := 5)] MonoidHom.iterate_map_pow +attribute [to_additive (reorder := 5 6)] MonoidHom.iterate_map_pow #align add_monoid_hom.iterate_map_nsmul AddMonoidHom.iterate_map_nsmul theorem iterate_map_zsmul (f : G →+ G) (n : ℕ) (m : ℤ) (x : G) : (f^[n]) (m • x) = m • (f^[n]) x := f.toMultiplicative.iterate_map_zpow n x m #align add_monoid_hom.iterate_map_zsmul AddMonoidHom.iterate_map_zsmul -attribute [to_additive existing (reorder := 5)] MonoidHom.iterate_map_zpow +attribute [to_additive existing (reorder := 5 6)] MonoidHom.iterate_map_zpow end AddMonoidHom diff --git a/Mathlib/Data/Array/Defs.lean b/Mathlib/Data/Array/Defs.lean index 7e89810770cf7..0b06e120c7914 100644 --- a/Mathlib/Data/Array/Defs.lean +++ b/Mathlib/Data/Array/Defs.lean @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Arthur Paulino, Floris van Doorn -/ -import Std +import Lean /-! ## Definitions on Arrays @@ -14,3 +14,19 @@ proofs about these definitions, those are contained in other files in `Mathlib.D -/ namespace Array + +/-- Permute the array using a sequence of indices defining a cyclic permutation. + If the list of indices `l = [i₁, i₂, ..., iₙ]` are all distinct then + `(cyclicPermute! a l)[iₖ₊₁] = a[iₖ]` and `(cyclicPermute! a l)[i₀] = a[iₙ]` -/ +def cyclicPermute! [Inhabited α] : Array α → List Nat → Array α +| a, [] => a +| a, i :: is => cyclicPermuteAux a is a[i]! i +where cyclicPermuteAux : Array α → List Nat → α → Nat → Array α +| a, [], x, i0 => a.set! i0 x +| a, i :: is, x, i0 => + let (y, a) := a.swapAt! i x + cyclicPermuteAux a is y i0 + +/-- Permute the array using a list of cycles. -/ +def permute! [Inhabited α] (a : Array α) (ls : List (List Nat)) : Array α := +ls.foldl (init := a) (·.cyclicPermute! ·) diff --git a/Mathlib/Data/Pi/Algebra.lean b/Mathlib/Data/Pi/Algebra.lean index c31b6d57a9524..9d0ebf88c8a7b 100644 --- a/Mathlib/Data/Pi/Algebra.lean +++ b/Mathlib/Data/Pi/Algebra.lean @@ -117,28 +117,28 @@ instance instSMul [∀ i, SMul α <| f i] : SMul α (∀ i : I, f i) := instance instPow [∀ i, Pow (f i) β] : Pow (∀ i, f i) β := ⟨fun x b i => x i ^ b⟩ -@[to_additive (attr := simp, to_additive) (reorder := 5) smul_apply] +@[to_additive (attr := simp, to_additive) (reorder := 5 6) smul_apply] theorem pow_apply [∀ i, Pow (f i) β] (x : ∀ i, f i) (b : β) (i : I) : (x ^ b) i = x i ^ b := rfl #align pi.pow_apply Pi.pow_apply #align pi.smul_apply Pi.smul_apply #align pi.vadd_apply Pi.vadd_apply -@[to_additive (attr := to_additive) (reorder := 5) smul_def] +@[to_additive (attr := to_additive) (reorder := 5 6) smul_def] theorem pow_def [∀ i, Pow (f i) β] (x : ∀ i, f i) (b : β) : x ^ b = fun i => x i ^ b := rfl #align pi.pow_def Pi.pow_def #align pi.smul_def Pi.smul_def #align pi.vadd_def Pi.vadd_def -@[to_additive (attr := simp, to_additive) (reorder := 2 5) smul_const] +@[to_additive (attr := simp, to_additive) (reorder:= 2 3, 5 6) smul_const] theorem const_pow [Pow α β] (a : α) (b : β) : const I a ^ b = const I (a ^ b) := rfl #align pi.const_pow Pi.const_pow #align pi.smul_const Pi.smul_const #align pi.vadd_const Pi.vadd_const -@[to_additive (attr := to_additive) (reorder := 6) smul_comp] +@[to_additive (attr := to_additive) (reorder := 6 7) smul_comp] theorem pow_comp [Pow γ α] (x : β → γ) (a : α) (y : I → β) : (x ^ a) ∘ y = x ∘ y ^ a := rfl #align pi.pow_comp Pi.pow_comp diff --git a/Mathlib/GroupTheory/GroupAction/Prod.lean b/Mathlib/GroupTheory/GroupAction/Prod.lean index db387f43d6869..dfd5034599ce3 100644 --- a/Mathlib/GroupTheory/GroupAction/Prod.lean +++ b/Mathlib/GroupTheory/GroupAction/Prod.lean @@ -89,12 +89,12 @@ instance pow : Pow (α × β) E where pow p c := (p.1 ^ c, p.2 ^ c) #align prod.has_pow Prod.pow #align prod.has_smul Prod.smul -@[to_additive existing (attr := simp) (reorder := 6) smul_fst] +@[to_additive existing (attr := simp) (reorder := 6 7) smul_fst] theorem pow_fst (p : α × β) (c : E) : (p ^ c).fst = p.fst ^ c := rfl #align prod.pow_fst Prod.pow_fst -@[to_additive existing (attr := simp) (reorder := 6) smul_snd] +@[to_additive existing (attr := simp) (reorder := 6 7) smul_snd] theorem pow_snd (p : α × β) (c : E) : (p ^ c).snd = p.snd ^ c := rfl #align prod.pow_snd Prod.pow_snd @@ -102,17 +102,17 @@ theorem pow_snd (p : α × β) (c : E) : (p ^ c).snd = p.snd ^ c := /- Note that the `c` arguments to this lemmas cannot be in the more natural right-most positions due to limitations in `to_additive` and `to_additive_reorder`, which will silently fail to reorder more than two adjacent arguments -/ -@[to_additive existing (attr := simp) (reorder := 6) smul_mk] +@[to_additive existing (attr := simp) (reorder := 6 7) smul_mk] theorem pow_mk (c : E) (a : α) (b : β) : Prod.mk a b ^ c = Prod.mk (a ^ c) (b ^ c) := rfl #align prod.pow_mk Prod.pow_mk -@[to_additive existing (reorder := 6) smul_def] +@[to_additive existing (reorder := 6 7) smul_def] theorem pow_def (p : α × β) (c : E) : p ^ c = (p.1 ^ c, p.2 ^ c) := rfl #align prod.pow_def Prod.pow_def -@[to_additive existing (attr := simp) (reorder := 6) smul_swap] +@[to_additive existing (attr := simp) (reorder := 6 7) smul_swap] theorem pow_swap (p : α × β) (c : E) : (p ^ c).swap = p.swap ^ c := rfl #align prod.pow_swap Prod.pow_swap diff --git a/Mathlib/Tactic/ToAdditive.lean b/Mathlib/Tactic/ToAdditive.lean index fa07ffcfebd87..1038f6440da39 100644 --- a/Mathlib/Tactic/ToAdditive.lean +++ b/Mathlib/Tactic/ToAdditive.lean @@ -6,6 +6,7 @@ Ported by: E.W.Ayers -/ import Mathlib.Init.Data.Nat.Notation import Mathlib.Data.String.Defs +import Mathlib.Data.Array.Defs import Mathlib.Data.KVMap import Mathlib.Lean.Expr.ReplaceRec import Mathlib.Lean.EnvExtension @@ -36,13 +37,13 @@ syntax (name := to_additive_ignore_args) "to_additive_ignore_args" num* : attr /-- The `to_additive_relevant_arg` attribute. -/ syntax (name := to_additive_relevant_arg) "to_additive_relevant_arg" num : attr /-- The `to_additive_reorder` attribute. -/ -syntax (name := to_additive_reorder) "to_additive_reorder" num* : attr +syntax (name := to_additive_reorder) "to_additive_reorder" (num+),+ : attr /-- The `to_additive_change_numeral` attribute. -/ syntax (name := to_additive_change_numeral) "to_additive_change_numeral" num* : attr /-- An `attr := ...` option for `to_additive`. -/ syntax toAdditiveAttrOption := &"attr" ":=" Parser.Term.attrInstance,* /-- An `reorder := ...` option for `to_additive`. -/ -syntax toAdditiveReorderOption := &"reorder" ":=" num+ +syntax toAdditiveReorderOption := &"reorder" ":=" (num+),+ /-- Options to `to_additive`. -/ syntax toAdditiveParenthesizedOption := "(" toAdditiveAttrOption <|> toAdditiveReorderOption ")" /-- Options to `to_additive`. -/ @@ -150,14 +151,10 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ← /-- An attribute that stores all the declarations that needs their arguments reordered when -applying `@[to_additive]`. Currently, we only support swapping consecutive arguments. -The list of the natural numbers contains the positions of the first of the two arguments -to be swapped. -If the first two arguments are swapped, the first two universe variables are also swapped. -Example: `@[to_additive_reorder 1 4]` swaps the first two arguments and the arguments in -positions 4 and 5. +applying `@[to_additive]`. It is applied automatically by the `(reorder := ...)` syntax of +`to_additive`, and should not usually be added manually. -/ -initialize reorderAttr : NameMapExtension (List Nat) ← +initialize reorderAttr : NameMapExtension (List $ List Nat) ← registerNameMapAttribute { name := `to_additive_reorder descr := @@ -166,12 +163,12 @@ initialize reorderAttr : NameMapExtension (List Nat) ← We keep it as an attribute for now so that mathport can still use it, and it can generate a warning." add := fun - | _, stx@`(attr| to_additive_reorder $[$ids:num]*) => do + | _, stx@`(attr| to_additive_reorder $[$[$reorders:num]*],*) => do Linter.logLintIf linter.toAdditiveReorder stx m!"Using this attribute is deprecated. Use `@[to_additive (reorder := )]` {"" }instead.\nThat will also generate the additive version with the arguments swapped, {"" }so you are probably able to remove the manually written additive declaration." - pure <| Array.toList <| ids.map (·.1.isNatLit?.get!) + pure <| reorders.toList.map (·.toList.map (·.raw.isNatLit?.get! - 1)) | _, _ => throwUnsupportedSyntax } /-- @@ -192,8 +189,7 @@ The reason is that whether we additivize a declaration is an all-or-nothing deci we will not be able to additivize declarations that (e.g.) talk about multiplication on `ℕ × α` anyway. -Warning: adding `@[to_additive_reorder]` with an equal or smaller number than the number in this -attribute is currently not supported. +Warning: interactions between this and the `(reorder := ...)` argument are not well-tested. -/ initialize relevantArgAttr : NameMapExtension Nat ← registerNameMapAttribute { @@ -259,8 +255,8 @@ structure Config : Type where /-- If `allowAutoName` is `false` (default) then `@[to_additive]` will check whether the given name can be auto-generated. -/ allowAutoName : Bool := false - /-- The arguments that should be reordered by `to_additive` -/ - reorder : List Nat := [] + /-- The arguments that should be reordered by `to_additive`, using cycle notation. -/ + reorder : List (List Nat) := [] /-- The attributes which we want to give to both the multiplicative and additive versions. For certain attributes (such as `simp` and `simps`) this will also add generated lemmas to the translation dictionary. -/ @@ -278,28 +274,44 @@ structure Config : Type where variable [Monad M] [MonadOptions M] [MonadEnv M] -/-- Auxilliary function for `additiveTest`. The bool argument *only* matters when applied -to exactly a constant. -/ -def additiveTestAux (findTranslation? : Name → Option Name) - (ignore : Name → Option (List ℕ)) : Bool → Expr → Bool := visit where - /-- see `additiveTestAux` -/ - visit : Bool → Expr → Bool - | b, .const n _ => b || (findTranslation? n).isSome - | _, x@(.app e a) => Id.run do - if !visit true e then - return false - -- make sure that we don't treat `(fun x => α) (n + 1)` as a type that depends on `Nat` - if x.isConstantApplication then - return true - if let some n := e.getAppFn.constName? then - if let some l := ignore n then - if e.getAppNumArgs + 1 ∈ l then - return true - visit false a - | _, .lam _ _ t _ => visit false t - | _, .forallE _ _ t _ => visit false t - | _, .letE _ _ e body _ => visit false e && visit false body - | _, _ => true +open Lean.Expr.FindImpl in +/-- Implementation function for `additiveTest`. + We cache previous applications of the function, using the same method that `Expr.find?` uses, + to avoid visiting the same subexpression many times. Note that we only need to cache the + expressions without taking the value of `inApp` into account, since `inApp` only matters when + the expression is a constant. However, for this reason we have to make sure that we never + cache constant expressions, so that's why the `if`s in the implementation are in this order. + + Note that this function is still called many times by `applyReplacementFun` + and we're not remembering the cache between these calls. -/ +unsafe def additiveTestUnsafe (findTranslation? : Name → Option Name) + (ignore : Name → Option (List ℕ)) (e : Expr) : Bool := + let size := cacheSize + let rec visit (e : Expr) (inApp := false) : OptionT FindM Unit := do + if e.isConst then + if inApp || (findTranslation? e.constName).isSome then + failure + else + return + if ← visited e size then + failure + match e with + | x@(.app e a) => + visit e true <|> do + -- make sure that we don't treat `(fun x => α) (n + 1)` as a type that depends on `Nat` + guard !x.isConstantApplication + if let some n := e.getAppFn.constName? then + if let some l := ignore n then + if e.getAppNumArgs + 1 ∈ l then + failure + visit a + | .lam _ _ t _ => visit t + | .forallE _ _ t _ => visit t + | .letE _ _ e body _ => visit e <|> visit body + | .mdata _ b => visit b + | .proj _ _ b => visit b + | _ => failure + Option.isNone <| Id.run <| (visit e).run' initCache /-- `additiveTest e` tests whether the expression `e` contains no constant @@ -312,7 +324,7 @@ We ignore all arguments specified by the `ignore` `NameMap`. -/ def additiveTest (findTranslation? : Name → Option Name) (ignore : Name → Option (List ℕ)) (e : Expr) : Bool := - additiveTestAux findTranslation? ignore false e + unsafe additiveTestUnsafe findTranslation? ignore e /-- Swap the first two elements of a list -/ def _root_.List.swapFirstTwo {α : Type _} : List α → List α @@ -339,78 +351,77 @@ e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorder -/ def applyReplacementFun (e : Expr) : MetaM Expr := do let env ← getEnv - let reorderFn : Name → List ℕ := fun nm ↦ (reorderAttr.find? env nm |>.getD []) - let isRelevant : Name → ℕ → Bool := fun nm i ↦ i == (relevantArgAttr.find? env nm).getD 0 + let reorderFn : Name → List (List ℕ) := fun nm ↦ (reorderAttr.find? env nm |>.getD []) + let relevantArg : Name → ℕ := fun nm ↦ (relevantArgAttr.find? env nm).getD 0 return aux (findTranslation? <| ← getEnv) reorderFn (ignoreArgsAttr.find? env) - (changeNumeralAttr.find? env) isRelevant (← getBoolOption `trace.to_additive_detail) e + (changeNumeralAttr.find? env) relevantArg (← getBoolOption `trace.to_additive_detail) e where /-- Implementation of `applyReplacementFun`. -/ aux (findTranslation? : Name → Option Name) - (reorderFn : Name → List ℕ) (ignore : Name → Option (List ℕ)) - (changeNumeral? : Name → Option (List Nat)) (isRelevant : Name → ℕ → Bool) (trace : Bool) : + (reorderFn : Name → List (List ℕ)) (ignore : Name → Option (List ℕ)) + (changeNumeral? : Name → Option (List Nat)) (relevantArg : Name → ℕ) (trace : Bool) : Expr → Expr := Lean.Expr.replaceRec fun r e ↦ Id.run do if trace then dbg_trace s!"replacing at {e}" match e with - | .const n₀ ls => do + | .const n₀ ls₀ => do let n₁ := n₀.mapPrefix findTranslation? - if trace && n₀ != n₁ then - dbg_trace s!"changing {n₀} to {n₁}" - let ls : List Level := if 1 ∈ reorderFn n₀ then ls.swapFirstTwo else ls - return some <| Lean.mkConst n₁ ls + let ls₁ : List Level := if 0 ∈ (reorderFn n₀).join then ls₀.swapFirstTwo else ls₀ + if trace then + if n₀ != n₁ then + dbg_trace s!"changing {n₀} to {n₁}" + if 0 ∈ (reorderFn n₀).join then + dbg_trace s!"reordering the universe variables from {ls₀} to {ls₁}" + return some <| Lean.mkConst n₁ ls₁ | .app g x => do let gf := g.getAppFn if gf.isBVar && x.isLit then if trace then dbg_trace s!"applyReplacementFun: Variables applied to numerals are not changed {g.app x}" return some <| g.app x - if let some nm := gf.constName? then - let gArgs := g.getAppArgs - -- e = `(nm y₁ .. yₙ x) - /- Test if arguments should be reordered. -/ - if h : gArgs.size > 0 then - let c1 : Bool := gArgs.size ∈ reorderFn nm - let c2 := additiveTest findTranslation? ignore gArgs[0] - if c1 && c2 then - -- interchange `x` and the last argument of `g` - let x := r x - let gf := r g.appFn! - let ga := r g.appArg! - let e₂ := mkApp2 gf x ga + let gArgs := g.getAppArgs + let mut gAllArgs := gArgs.push x + let (gfAdditive, gAllArgsAdditive) ← + if let some nm := gf.constName? then + -- e = `(nm y₁ .. yₙ x) + /- Test if the head should not be replaced. -/ + let relevantArgId := relevantArg nm + let gfAdditive := + if relevantArgId < gAllArgs.size && gf.isConst && + not (additiveTest findTranslation? ignore gAllArgs[relevantArgId]!) then Id.run <| do + if trace then + dbg_trace + s!"{gAllArgs[relevantArgId]!} contains a fixed type, so {nm} is not changed" + gf + else + r gf + /- Test if arguments should be reordered. -/ + let reorder := reorderFn nm + if !reorder.isEmpty && relevantArgId < gAllArgs.size && + additiveTest findTranslation? ignore gAllArgs[relevantArgId]! then + gAllArgs := gAllArgs.permute! reorder if trace then - dbg_trace s!"reordering {nm}: {x} ↔ {ga}\nBefore: {e}\nAfter: {e₂}" - return some e₂ - /- Test if the head should not be replaced. -/ - let c1 := isRelevant nm gArgs.size - let c2 := gf.isConst - let c3 := additiveTest findTranslation? ignore x - if trace && c1 && c2 && c3 then - dbg_trace s!"{x} doesn't contain a fixed type, so we will change {nm}" - if c1 && c2 && not c3 then - if trace then - dbg_trace s!"{x} contains a fixed type, so {nm} is not changed" - let x ← r x - let args ← gArgs.mapM r - return some $ mkApp (mkAppN gf args) x - /- Do not replace numerals in specific types. -/ - let gAllArgs := gArgs.push x - let firstArg := gAllArgs[0] - if let some changedArgNrs := changeNumeral? nm then - if additiveTest findTranslation? ignore firstArg then - if trace then - dbg_trace s!"applyReplacementFun: We change the numerals in {g.app x}. { - ""}However, we will still recurse into all the non-numeral arguments." - -- In this case, we still update all arguments of `g` that are not numerals, - -- since all other arguments can contain subexpressions like - -- `(fun x ↦ ℕ) (1 : G)`, and we have to update the `(1 : G)` to `(0 : G)` - let newArgs ← gAllArgs.mapIdx fun argNr arg ↦ - if changedArgNrs.contains argNr then - r <| changeNumeral arg - else - r arg - return some <| mkAppN gf newArgs - return e.updateApp! (← r g) (← r x) + dbg_trace s!"reordering the arguments of {nm} using the cyclic permutations {reorder}" + /- Do not replace numerals in specific types. -/ + let firstArg := gAllArgs[0]! + if let some changedArgNrs := changeNumeral? nm then + if additiveTest findTranslation? ignore firstArg then + if trace then + dbg_trace s!"applyReplacementFun: We change the numerals in this expression. { + ""}However, we will still recurse into all the non-numeral arguments." + -- In this case, we still update all arguments of `g` that are not numerals, + -- since all other arguments can contain subexpressions like + -- `(fun x ↦ ℕ) (1 : G)`, and we have to update the `(1 : G)` to `(0 : G)` + gAllArgs := gAllArgs.mapIdx fun argNr arg ↦ + if changedArgNrs.contains argNr then + changeNumeral arg + else + arg + pure <| (gfAdditive, ← gAllArgs.mapM r) + else + pure (← r gf, ← gAllArgs.mapM r) + return some <| mkAppN gfAdditive gAllArgsAdditive | .proj n₀ idx e => do let n₁ := n₀.mapPrefix findTranslation? if trace then @@ -428,7 +439,7 @@ def etaExpandN (n : Nat) (e : Expr): MetaM Expr := do `reorder.find n`. -/ def expand (e : Expr) : MetaM Expr := do let env ← getEnv - let reorderFn : Name → List ℕ := fun nm ↦ (reorderAttr.find? env nm |>.getD []) + let reorderFn : Name → List (List ℕ) := fun nm ↦ (reorderAttr.find? env nm |>.getD []) let e₂ ← Lean.Meta.transform (input := e) (post := fun e => return .done e) <| fun e ↦ do let e0 := e.getAppFn let es := e.getAppArgs @@ -437,7 +448,7 @@ def expand (e : Expr) : MetaM Expr := do if reorder.isEmpty then -- no need to expand if nothing needs reordering return .continue - let needed_n := reorder.foldr Nat.max 0 + 1 + let needed_n := reorder.join.foldr Nat.max 0 + 1 -- the second disjunct is a temporary fix to avoid infinite loops. -- We may need to use `replaceRec` or something similar to not change the head of an application if needed_n ≤ es.size || es.size == 0 then @@ -453,38 +464,25 @@ def expand (e : Expr) : MetaM Expr := do return e₂ /-- Reorder pi-binders. See doc of `reorderAttr` for the interpretation of the argument -/ -def reorderForall (src : Expr) (reorder : List Nat := []) : MetaM Expr := do +def reorderForall (src : Expr) (reorder : List (List Nat) := []) : MetaM Expr := do if reorder == [] then return src forallTelescope src fun xs e => do - let xs ← reorder.foldrM (init := xs) fun i xs => - if h : i < xs.size then - pure <| xs.swap ⟨i - 1, Nat.lt_of_le_of_lt i.pred_le h⟩ ⟨i, h⟩ - else - throwError "the declaration does not have enough arguments to reorder the given arguments: { - xs.size} ≤ {i}" - mkForallFVars xs e + mkForallFVars (xs.permute! reorder) e /-- Reorder lambda-binders. See doc of `reorderAttr` for the interpretation of the argument -/ -def reorderLambda (src : Expr) (reorder : List Nat := []) : MetaM Expr := do +def reorderLambda (src : Expr) (reorder : List (List Nat) := []) : MetaM Expr := do if reorder == [] then return src lambdaTelescope src fun xs e => do - let xs ← reorder.foldrM (init := xs) fun i xs => - if h : i < xs.size then - pure <| xs.swap ⟨i - 1, Nat.lt_of_le_of_lt i.pred_le h⟩ ⟨i, h⟩ - else - throwError "the declaration does not have enough arguments to reorder the given arguments. { - xs.size} ≤ {i}.\nIf this is a field projection, make sure to use `@[to_additive]` on {"" - }the field first." - mkLambdaFVars xs e + mkLambdaFVars (xs.permute! reorder) e /-- Run applyReplacementFun on the given `srcDecl` to make a new declaration with name `tgt` -/ def updateDecl - (tgt : Name) (srcDecl : ConstantInfo) (reorder : List Nat := []) + (tgt : Name) (srcDecl : ConstantInfo) (reorder : List (List Nat) := []) : MetaM ConstantInfo := do let mut decl := srcDecl.updateName tgt - if 1 ∈ reorder then + if 0 ∈ reorder.join then decl := decl.updateLevelParams decl.levelParams.swapFirstTwo decl := decl.updateType <| ← applyReplacementFun <| ← reorderForall (← expand decl.type) reorder if let some v := decl.value? then @@ -874,18 +872,19 @@ def proceedFields (src tgt : Name) : CoreM Unit := do /-- Elaboration of the configuration options for `to_additive`. -/ def elabToAdditive : Syntax → CoreM Config | `(attr| to_additive%$tk $[?%$trace]? $[$opts:toAdditiveOption]* $[$tgt]? $[$doc]?) => do - let mut attrs : Array Syntax := #[] + let mut attrs := #[] let mut reorder := [] let mut existing := some false for stx in opts do match stx with | `(toAdditiveOption| (attr := $[$stxs],*)) => attrs := attrs ++ stxs - | `(toAdditiveOption| (reorder := $[$reorders:num]*)) => - reorder := reorder ++ reorders.toList.map (·.raw.isNatLit?.get!) + | `(toAdditiveOption| (reorder := $[$[$reorders:num]*],*)) => + reorder := reorder ++ reorders.toList.map (·.toList.map (·.raw.isNatLit?.get! - 1)) | `(toAdditiveOption| existing) => existing := some true | _ => throwUnsupportedSyntax + reorder := reorder.reverse trace[to_additive_detail] "attributes: {attrs}; reorder arguments: {reorder}" return { trace := trace.isSome tgt := match tgt with | some tgt => tgt.getId | none => Name.anonymous @@ -1077,6 +1076,12 @@ The transport tries to do the right thing in most cases using several heuristics described below. However, in some cases it fails, and requires manual intervention. +Use the `(reorder := ...)` syntax to reorder the arguments in the generated additive declaration. +This is specified using cycle notation. For example `(reorder := 1 2, 5 6)` swaps the first two +arguments with each other and the fifth and the sixth argument and `(reorder := 3 4 5)` will move +the fifth argument before the third argument. This is mostly useful to translate declarations using +`Pow` to those using `SMul`. + Use the `(attr := ...)` syntax to apply attributes to both the multiplicative and the additive version: @@ -1181,8 +1186,7 @@ mismatch error. multiplicative and additive version. This might mean that arguments have an "unnatural" order (e.g. `Monoid.npow n x` corresponds to `x ^ n`, but it is convenient that `Monoid.npow` has this argument order, since it matches `AddMonoid.nsmul n x`. - * If this is not possible, add the `[to_additive_reorder k]` to the multiplicative declaration - to indicate that the `k`-th and `(k+1)`-st arguments are reordered in the additive version. + * If this is not possible, add `(reorder := ...)` argument to `to_additive`. If neither of these solutions work, and `to_additive` is unable to automatically generate the additive version of a declaration, manually write and prove the additive version. diff --git a/test/toAdditive.lean b/test/toAdditive.lean index 287ab6994368e..dad9e0c1c4b13 100644 --- a/test/toAdditive.lean +++ b/test/toAdditive.lean @@ -30,8 +30,8 @@ class my_has_scalar (M : Type u) (α : Type v) := (smul : M → α → α) instance : my_has_scalar Nat Nat := ⟨fun a b => a * b⟩ -attribute [to_additive (reorder := 1) my_has_scalar] my_has_pow -attribute [to_additive (reorder := 1 4)] my_has_pow.pow +attribute [to_additive (reorder := 1 2) my_has_scalar] my_has_pow +attribute [to_additive (reorder:= 1 2, 4 5)] my_has_pow.pow @[to_additive bar1] def foo1 {α : Type u} [my_has_pow α ℕ] (x : α) (n : ℕ) : α := @my_has_pow.pow α ℕ _ x n @@ -99,16 +99,16 @@ theorem bar11_works : bar11 = foo11 := by rfl @[to_additive bar12] def foo12 (_ : Nat) (_ : Int) : Fin 37 := ⟨2, by decide⟩ -@[to_additive (reorder := 1 4) bar13] +@[to_additive (reorder:= 1 2, 4 5) bar13] lemma foo13 {α β : Type u} [my_has_pow α β] (x : α) (y : β) : x ^ y = x ^ y := rfl -@[to_additive (reorder := 1 4) bar14] +@[to_additive (reorder:= 1 2, 4 5) bar14] def foo14 {α β : Type u} [my_has_pow α β] (x : α) (y : β) : α := (x ^ y) ^ y -@[to_additive (reorder := 1 4) bar15] +@[to_additive (reorder:= 1 2, 4 5) bar15] lemma foo15 {α β : Type u} [my_has_pow α β] (x : α) (y : β) : foo14 x y = (x ^ y) ^ y := rfl -@[to_additive (reorder := 1 4) bar16] +@[to_additive (reorder:= 1 2, 4 5) bar16] lemma foo16 {α β : Type u} [my_has_pow α β] (x : α) (y : β) : foo14 x y = (x ^ y) ^ y := foo15 x y initialize testExt : SimpExtension ← @@ -313,6 +313,13 @@ theorem isUnit'_iff_exists_inv [CommMonoid M] {a : M} : IsUnit' a ↔ ∃ b, a * theorem isUnit'_iff_exists_inv' [CommMonoid M] {a : M} : IsUnit' a ↔ ∃ b, b * a = 1 := by simp [isUnit'_iff_exists_inv, mul_comm] +/-! Test a permutation with a cycle of length > 2. -/ +@[to_additive (reorder := 3 4 5)] +def reorderMulThree {α : Type _} [Mul α] (x y z : α) : α := x * y * z + +example {α : Type _} [Add α] (x y z : α) : reorderAddThree z x y = x + y + z := rfl + + def Ones : ℕ → Q(Nat) | 0 => q(1) | (n+1) => q($(Ones n) + $(Ones n))