From cfdfcf1fa7e5172bfe5bb91479ec0a36a6484212 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Wed, 25 Sep 2024 15:47:09 +1000 Subject: [PATCH 1/4] chore: upstream some monad lemmas --- src/Init/Control/Lawful/Basic.lean | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Init/Control/Lawful/Basic.lean b/src/Init/Control/Lawful/Basic.lean index f4927b2a5597..bd0c5d184bc9 100644 --- a/src/Init/Control/Lawful/Basic.lean +++ b/src/Init/Control/Lawful/Basic.lean @@ -33,6 +33,10 @@ attribute [simp] id_map @[simp] theorem id_map' [Functor m] [LawfulFunctor m] (x : m α) : (fun a => a) <$> x = x := id_map x +theorem Functor.map_map [Functor f] [LawfulFunctor f] (m : α → β) (g : β → γ) (x : f α) : + g <$> m <$> x = (g ∘ m) <$> x := + (comp_map _ _ _).symm + /-- The `Applicative` typeclass only contains the operations of an applicative functor. `LawfulApplicative` further asserts that these operations satisfy the laws of an applicative functor: @@ -114,6 +118,16 @@ theorem seqRight_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x *> theorem seqLeft_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x <* y = x >>= fun a => y >>= fun _ => pure a := by rw [seqLeft_eq]; simp [map_eq_pure_bind, seq_eq_bind_map] +theorem map_bind [Monad m] [LawfulMonad m](x : m α) {g : α → m β} {f : β → γ} : + f <$> (x >>= fun a => g a) = x >>= fun a => f <$> g a := by + rw [← bind_pure_comp, LawfulMonad.bind_assoc] + simp [bind_pure_comp] + +theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) : + ((f <$> x) >>= fun b => g b) = (x >>= fun a => g (f a)) := by + rw [← bind_pure_comp] + simp [bind_assoc, pure_bind] + /-- An alternative constructor for `LawfulMonad` which has more defaultable fields in the common case. From 41f7de9ec544a8cddac149b929572d05a28fe97b Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Wed, 25 Sep 2024 16:08:35 +1000 Subject: [PATCH 2/4] feat: adjust simp attributes on monad lemmas --- src/Init/Control/Lawful/Basic.lean | 21 +++++++++++++-------- src/Init/Control/Lawful/Instances.lean | 14 +++++++------- src/Init/Data/Array/Lemmas.lean | 2 +- src/Init/Data/List/Lemmas.lean | 5 +++++ 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/Init/Control/Lawful/Basic.lean b/src/Init/Control/Lawful/Basic.lean index bd0c5d184bc9..e2ef00bbe2c5 100644 --- a/src/Init/Control/Lawful/Basic.lean +++ b/src/Init/Control/Lawful/Basic.lean @@ -33,8 +33,8 @@ attribute [simp] id_map @[simp] theorem id_map' [Functor m] [LawfulFunctor m] (x : m α) : (fun a => a) <$> x = x := id_map x -theorem Functor.map_map [Functor f] [LawfulFunctor f] (m : α → β) (g : β → γ) (x : f α) : - g <$> m <$> x = (g ∘ m) <$> x := +@[simp] theorem Functor.map_map [Functor f] [LawfulFunctor f] (m : α → β) (g : β → γ) (x : f α) : + g <$> m <$> x = (fun a => g (m a)) <$> x := (comp_map _ _ _).symm /-- @@ -87,12 +87,16 @@ class LawfulMonad (m : Type u → Type v) [Monad m] extends LawfulApplicative m seq_assoc x g h := (by simp [← bind_pure_comp, ← bind_map, bind_assoc, pure_bind]) export LawfulMonad (bind_pure_comp bind_map pure_bind bind_assoc) -attribute [simp] pure_bind bind_assoc +attribute [simp] pure_bind bind_assoc bind_pure_comp @[simp] theorem bind_pure [Monad m] [LawfulMonad m] (x : m α) : x >>= pure = x := by show x >>= (fun a => pure (id a)) = x rw [bind_pure_comp, id_map] +/-- +Use `simp [← bind_pure_comp]` rather than `simp [map_eq_pure_bind]`, +as `bind_pure_comp` is in the default simp set, so also using `map_eq_pure_bind` would cause a loop. +-/ theorem map_eq_pure_bind [Monad m] [LawfulMonad m] (f : α → β) (x : m α) : f <$> x = x >>= fun a => pure (f a) := by rw [← bind_pure_comp] @@ -113,20 +117,21 @@ theorem seq_eq_bind {α β : Type u} [Monad m] [LawfulMonad m] (mf : m (α → theorem seqRight_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x *> y = x >>= fun _ => y := by rw [seqRight_eq] - simp [map_eq_pure_bind, seq_eq_bind_map, const] + simp only [map_eq_pure_bind, const, seq_eq_bind_map, bind_assoc, pure_bind, id_eq, bind_pure] theorem seqLeft_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x <* y = x >>= fun a => y >>= fun _ => pure a := by - rw [seqLeft_eq]; simp [map_eq_pure_bind, seq_eq_bind_map] + rw [seqLeft_eq] + simp only [map_eq_pure_bind, seq_eq_bind_map, bind_assoc, pure_bind, const_apply] -theorem map_bind [Monad m] [LawfulMonad m](x : m α) {g : α → m β} {f : β → γ} : +@[simp] theorem map_bind [Monad m] [LawfulMonad m] (x : m α) {g : α → m β} {f : β → γ} : f <$> (x >>= fun a => g a) = x >>= fun a => f <$> g a := by rw [← bind_pure_comp, LawfulMonad.bind_assoc] simp [bind_pure_comp] -theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) : +@[simp] theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) : ((f <$> x) >>= fun b => g b) = (x >>= fun a => g (f a)) := by rw [← bind_pure_comp] - simp [bind_assoc, pure_bind] + simp only [bind_assoc, pure_bind] /-- An alternative constructor for `LawfulMonad` which has more diff --git a/src/Init/Control/Lawful/Instances.lean b/src/Init/Control/Lawful/Instances.lean index 05f7db6e60ee..c239c19ee2bc 100644 --- a/src/Init/Control/Lawful/Instances.lean +++ b/src/Init/Control/Lawful/Instances.lean @@ -25,7 +25,7 @@ theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by @[simp] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl @[simp] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α → ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by - simp[ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont, map_eq_pure_bind] + simp [ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont] @[simp] theorem bind_throw [Monad m] [LawfulMonad m] (f : α → ExceptT ε m β) : (throw e >>= f) = throw e := by simp [throw, throwThe, MonadExceptOf.throw, bind, ExceptT.bind, ExceptT.bindCont, ExceptT.mk] @@ -43,7 +43,7 @@ theorem run_bind [Monad m] (x : ExceptT ε m α) @[simp] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : ExceptT ε m α) : (f <$> x).run = Except.map f <$> x.run := by - simp [Functor.map, ExceptT.map, map_eq_pure_bind] + simp [Functor.map, ExceptT.map, ←bind_pure_comp] apply bind_congr intro a; cases a <;> simp [Except.map] @@ -62,7 +62,7 @@ protected theorem seqLeft_eq {α β ε : Type u} {m : Type u → Type v} [Monad intro | Except.error _ => simp | Except.ok _ => - simp [map_eq_pure_bind]; apply bind_congr; intro b; + simp [←bind_pure_comp]; apply bind_congr; intro b; cases b <;> simp [comp, Except.map, const] protected theorem seqRight_eq [Monad m] [LawfulMonad m] (x : ExceptT ε m α) (y : ExceptT ε m β) : x *> y = const α id <$> x <*> y := by @@ -175,7 +175,7 @@ theorem ext {x y : StateT σ m α} (h : ∀ s, x.run s = y.run s) : x = y := simp [bind, StateT.bind, run] @[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by - simp [Functor.map, StateT.map, run, map_eq_pure_bind] + simp [Functor.map, StateT.map, run, ←bind_pure_comp] @[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl @@ -210,13 +210,13 @@ theorem run_bind_lift {α σ : Type u} [Monad m] [LawfulMonad m] (x : m α) (f : theorem seqRight_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x *> y = const α id <$> x <*> y := by apply ext; intro s - simp [map_eq_pure_bind, const] + simp [←bind_pure_comp, const] apply bind_congr; intro p; cases p simp [Prod.eta] theorem seqLeft_eq [Monad m] [LawfulMonad m] (x : StateT σ m α) (y : StateT σ m β) : x <* y = const β <$> x <*> y := by apply ext; intro s - simp [map_eq_pure_bind] + simp [←bind_pure_comp] instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where id_map := by intros; apply ext; intros; simp[Prod.eta] @@ -224,7 +224,7 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where seqLeft_eq := seqLeft_eq seqRight_eq := seqRight_eq pure_seq := by intros; apply ext; intros; simp - bind_pure_comp := by intros; apply ext; intros; simp; apply LawfulMonad.bind_pure_comp + bind_pure_comp := by intros; apply ext; intros; simp bind_map := by intros; rfl pure_bind := by intros; apply ext; intros; simp bind_assoc := by intros; apply ext; intros; simp diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 33a717e5463c..781de5f0aafb 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -652,7 +652,7 @@ theorem mapM_eq_mapM_toList [Monad m] [LawfulMonad m] (f : α → m β) (arr : A conv => rhs; rw [← List.reverse_reverse arr.toList] induction arr.toList.reverse with | nil => simp - | cons a l ih => simp [ih]; simp [map_eq_pure_bind] + | cons a l ih => simp [ih] @[deprecated mapM_eq_mapM_toList (since := "2024-09-09")] abbrev mapM_eq_mapM_data := @mapM_eq_mapM_toList diff --git a/src/Init/Data/List/Lemmas.lean b/src/Init/Data/List/Lemmas.lean index e6beec0b2a2b..ac50fcfa5bc4 100644 --- a/src/Init/Data/List/Lemmas.lean +++ b/src/Init/Data/List/Lemmas.lean @@ -1654,6 +1654,11 @@ theorem filterMap_eq_cons_iff {l} {b} {bs} : /-! ### append -/ +@[simp] theorem nil_append_fun : (([] : List α) ++ ·) = id := rfl + +@[simp] theorem cons_append_fun (a : α) (as : List α) : + (fun bs => ((a :: as) ++ bs)) = fun bs => a :: (as ++ bs) := rfl + theorem getElem_append {l₁ l₂ : List α} (n : Nat) (h) : (l₁ ++ l₂)[n] = if h' : n < l₁.length then l₁[n] else l₂[n - l₁.length]'(by simp at h h'; exact Nat.sub_lt_left_of_lt_add h' h) := by split <;> rename_i h' From 127f89e77c51dd324cef60a54b08e322aa69362e Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Wed, 25 Sep 2024 16:24:42 +1000 Subject: [PATCH 3/4] fix proof --- tests/lean/run/do_eqv_proofs.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lean/run/do_eqv_proofs.lean b/tests/lean/run/do_eqv_proofs.lean index 6b822d3dd1ee..42a4d2b825cc 100644 --- a/tests/lean/run/do_eqv_proofs.lean +++ b/tests/lean/run/do_eqv_proofs.lean @@ -7,7 +7,7 @@ theorem ex1 [Monad m] [LawfulMonad m] (b : Bool) (ma : m α) (mb : α → m α) (ma >>= fun x => if b then mb x else pure x) := by cases b <;> simp -attribute [simp] map_eq_pure_bind seq_eq_bind_map +attribute [simp] seq_eq_bind_map theorem ex2 [Monad m] [LawfulMonad m] (b : Bool) (ma : m α) (mb : α → m α) (a : α) : (do let mut x ← ma From fea0e85e76e8f2018df35e3ed65096fe74751452 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Wed, 25 Sep 2024 20:07:51 +1000 Subject: [PATCH 4/4] cleanup --- src/Init/Control/Lawful/Basic.lean | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/Init/Control/Lawful/Basic.lean b/src/Init/Control/Lawful/Basic.lean index 95ba792d8370..2d7db8d7587a 100644 --- a/src/Init/Control/Lawful/Basic.lean +++ b/src/Init/Control/Lawful/Basic.lean @@ -123,26 +123,16 @@ theorem seqLeft_eq_bind [Monad m] [LawfulMonad m] (x : m α) (y : m β) : x <* y rw [seqLeft_eq] simp only [map_eq_pure_bind, seq_eq_bind_map, bind_assoc, pure_bind, const_apply] -@[simp] theorem map_bind [Monad m] [LawfulMonad m] (x : m α) {g : α → m β} {f : β → γ} : - f <$> (x >>= fun a => g a) = x >>= fun a => f <$> g a := by +@[simp] theorem map_bind [Monad m] [LawfulMonad m] (f : β → γ) (x : m α) (g : α → m β) : + f <$> (x >>= g) = x >>= fun a => f <$> g a := by rw [← bind_pure_comp, LawfulMonad.bind_assoc] simp [bind_pure_comp] -@[simp] theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) : +@[simp] theorem bind_map_left [Monad m] [LawfulMonad m] (f : α → β) (x : m α) (g : β → m γ) : ((f <$> x) >>= fun b => g b) = (x >>= fun a => g (f a)) := by rw [← bind_pure_comp] simp only [bind_assoc, pure_bind] -theorem map_bind [Monad m] [LawfulMonad m] (x : m α) {g : α → m β} {f : β → γ} : - f <$> (x >>= fun a => g a) = x >>= fun a => f <$> g a := by - rw [← bind_pure_comp, LawfulMonad.bind_assoc] - simp [bind_pure_comp] - -theorem bind_map_left [Monad m] [LawfulMonad m] (x : m α) (f : α → β) (g : β → m γ) : - ((f <$> x) >>= fun b => g b) = (x >>= fun a => g (f a)) := by - rw [← bind_pure_comp] - simp [bind_assoc, pure_bind] - /-- An alternative constructor for `LawfulMonad` which has more defaultable fields in the common case.