From 4bf89dfa121f3f2b55ff1e4a94a96a00d8e7b58e Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Mon, 7 Nov 2022 14:19:20 -0500 Subject: [PATCH] feat: allow `doSeq` in `let x <- e | seq` fixes #1804 --- src/Lean/Elab/Do.lean | 11 ++++------- src/Lean/Parser/Do.lean | 4 ++-- tests/lean/1804.lean | 15 +++++++++++++++ tests/lean/1804.lean.expected.out | 2 ++ 4 files changed, 23 insertions(+), 9 deletions(-) create mode 100644 tests/lean/1804.lean create mode 100644 tests/lean/1804.lean.expected.out diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index a4a23b3dd5b5..379e87d5c4cd 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -681,9 +681,6 @@ def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do def mkDoSeq (doElems : Array Syntax) : Syntax := mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode <| doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]] -def mkSingletonDoSeq (doElem : Syntax) : Syntax := - mkDoSeq #[doElem] - /-- If the given syntax is a `doIf`, return an equivalent `doIf` that has an `else` but no `else if`s or `if let`s. -/ private def expandDoIf? (stx : Syntax) : MacroM (Option Syntax) := match stx with @@ -1301,7 +1298,7 @@ mutual where ``` def doIdDecl := leading_parser ident >> optType >> leftArrow >> doElemParser - def doPatDecl := leading_parser termParser >> leftArrow >> doElemParser >> optional (" | " >> doElemParser) + def doPatDecl := leading_parser termParser >> leftArrow >> doElemParser >> optional (" | " >> doSeq) ``` -/ partial def doLetArrowToCode (doLetArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do @@ -1336,17 +1333,17 @@ mutual else pure doElems.toArray let contSeq := mkDoSeq contSeq - let elseSeq := mkSingletonDoSeq optElse[1] + let elseSeq := optElse[1] let auxDo ← `(do let%$doLetArrow __discr ← $doElem; match%$doLetArrow __discr with | $pattern:term => $contSeq | _ => $elseSeq) doSeqToCode <| getDoSeqElems (getDoSeq auxDo) else throwError "unexpected kind of `do` declaration" partial def doLetElseToCode (doLetElse : Syntax) (doElems : List Syntax) : M CodeBlock := do - -- "let " >> optional "mut " >> termParser >> " := " >> termParser >> checkColGt >> " | " >> doElemParser + -- "let " >> optional "mut " >> termParser >> " := " >> termParser >> checkColGt >> " | " >> doSeq let pattern := doLetElse[2] let val := doLetElse[4] - let elseSeq := mkSingletonDoSeq doLetElse[6] + let elseSeq := doLetElse[6] let contSeq ← if isMutableLet doLetElse then let vars ← (← getPatternVarsEx pattern).mapM fun var => `(doElem| let mut $var := $var) pure (vars ++ doElems.toArray) diff --git a/src/Lean/Parser/Do.lean b/src/Lean/Parser/Do.lean index a599dbcc7503..59099eb152f8 100644 --- a/src/Lean/Parser/Do.lean +++ b/src/Lean/Parser/Do.lean @@ -49,7 +49,7 @@ def notFollowedByRedefinedTermToken := "let " >> optional "mut " >> letDecl @[builtin_doElem_parser] def doLetElse := leading_parser "let " >> optional "mut " >> termParser >> " := " >> termParser >> - checkColGt >> " | " >> doElemParser + checkColGt >> " | " >> doSeq @[builtin_doElem_parser] def doLetRec := leading_parser group ("let " >> nonReservedSymbol "rec ") >> letRecDecls @@ -58,7 +58,7 @@ def doIdDecl := leading_parser doElemParser def doPatDecl := leading_parser atomic (termParser >> ppSpace >> leftArrow) >> - doElemParser >> optional (checkColGt >> " | " >> doElemParser) + doElemParser >> optional (checkColGt >> " | " >> doSeq) @[builtin_doElem_parser] def doLetArrow := leading_parser withPosition ("let " >> optional "mut " >> (doIdDecl <|> doPatDecl)) diff --git a/tests/lean/1804.lean b/tests/lean/1804.lean new file mode 100644 index 000000000000..0e7db5d40d32 --- /dev/null +++ b/tests/lean/1804.lean @@ -0,0 +1,15 @@ +def inc : StateM Nat Nat := do + let s ← get + modify (· + 1) + return s + +def f (x : Bool) : StateM Nat Nat := do + let .true := x | return (← inc) + get + +def g (x : Bool) : StateM Nat Nat := do + let .true := x | do return (← inc) + get + +#eval g true |>.run' 0 -- `0` as expected +#eval f true |>.run' 0 -- should return `0`, not `1` diff --git a/tests/lean/1804.lean.expected.out b/tests/lean/1804.lean.expected.out new file mode 100644 index 000000000000..aa47d0d46d47 --- /dev/null +++ b/tests/lean/1804.lean.expected.out @@ -0,0 +1,2 @@ +0 +0