Skip to content

Commit

Permalink
feat: allow doSeq in let x <- e | seq
Browse files Browse the repository at this point in the history
fixes #1804
  • Loading branch information
digama0 authored and leodemoura committed Nov 8, 2022
1 parent 9f2182f commit 4bf89df
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
11 changes: 4 additions & 7 deletions src/Lean/Elab/Do.lean
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,6 @@ def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do
def mkDoSeq (doElems : Array Syntax) : Syntax :=
mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode <| doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]]

def mkSingletonDoSeq (doElem : Syntax) : Syntax :=
mkDoSeq #[doElem]

/--
If the given syntax is a `doIf`, return an equivalent `doIf` that has an `else` but no `else if`s or `if let`s. -/
private def expandDoIf? (stx : Syntax) : MacroM (Option Syntax) := match stx with
Expand Down Expand Up @@ -1301,7 +1298,7 @@ mutual
where
```
def doIdDecl := leading_parser ident >> optType >> leftArrow >> doElemParser
def doPatDecl := leading_parser termParser >> leftArrow >> doElemParser >> optional (" | " >> doElemParser)
def doPatDecl := leading_parser termParser >> leftArrow >> doElemParser >> optional (" | " >> doSeq)
```
-/
partial def doLetArrowToCode (doLetArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do
Expand Down Expand Up @@ -1336,17 +1333,17 @@ mutual
else
pure doElems.toArray
let contSeq := mkDoSeq contSeq
let elseSeq := mkSingletonDoSeq optElse[1]
let elseSeq := optElse[1]
let auxDo ← `(do let%$doLetArrow __discr ← $doElem; match%$doLetArrow __discr with | $pattern:term => $contSeq | _ => $elseSeq)
doSeqToCode <| getDoSeqElems (getDoSeq auxDo)
else
throwError "unexpected kind of `do` declaration"

partial def doLetElseToCode (doLetElse : Syntax) (doElems : List Syntax) : M CodeBlock := do
-- "let " >> optional "mut " >> termParser >> " := " >> termParser >> checkColGt >> " | " >> doElemParser
-- "let " >> optional "mut " >> termParser >> " := " >> termParser >> checkColGt >> " | " >> doSeq
let pattern := doLetElse[2]
let val := doLetElse[4]
let elseSeq := mkSingletonDoSeq doLetElse[6]
let elseSeq := doLetElse[6]
let contSeq ← if isMutableLet doLetElse then
let vars ← (← getPatternVarsEx pattern).mapM fun var => `(doElem| let mut $var := $var)
pure (vars ++ doElems.toArray)
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Parser/Do.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def notFollowedByRedefinedTermToken :=
"let " >> optional "mut " >> letDecl
@[builtin_doElem_parser] def doLetElse := leading_parser
"let " >> optional "mut " >> termParser >> " := " >> termParser >>
checkColGt >> " | " >> doElemParser
checkColGt >> " | " >> doSeq

@[builtin_doElem_parser] def doLetRec := leading_parser
group ("let " >> nonReservedSymbol "rec ") >> letRecDecls
Expand All @@ -58,7 +58,7 @@ def doIdDecl := leading_parser
doElemParser
def doPatDecl := leading_parser
atomic (termParser >> ppSpace >> leftArrow) >>
doElemParser >> optional (checkColGt >> " | " >> doElemParser)
doElemParser >> optional (checkColGt >> " | " >> doSeq)
@[builtin_doElem_parser] def doLetArrow := leading_parser
withPosition ("let " >> optional "mut " >> (doIdDecl <|> doPatDecl))

Expand Down
15 changes: 15 additions & 0 deletions tests/lean/1804.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
def inc : StateM Nat Nat := do
let s ← get
modify (· + 1)
return s

def f (x : Bool) : StateM Nat Nat := do
let .true := x | return (← inc)
get

def g (x : Bool) : StateM Nat Nat := do
let .true := x | do return (← inc)
get

#eval g true |>.run' 0 -- `0` as expected
#eval f true |>.run' 0 -- should return `0`, not `1`
2 changes: 2 additions & 0 deletions tests/lean/1804.lean.expected.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0
0

0 comments on commit 4bf89df

Please sign in to comment.