Skip to content

Commit

Permalink
Allow more than one parameter group on def (#1037)
Browse files Browse the repository at this point in the history
* Allow more than one parameter group on def

* use the syntax in tests

* remove commented code
  • Loading branch information
johnynek authored Sep 9, 2023
1 parent 30a7e97 commit ae52f82
Show file tree
Hide file tree
Showing 16 changed files with 139 additions and 94 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/Declaration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ sealed abstract class Declaration {
// may or may not be recursive

val boundRest = bound + d.name
val boundBody = boundRest ++ d.args.patternNames
val boundBody = boundRest ++ d.args.toList.flatMap(_.patternNames)

val acc1 = loop(body.get, boundBody, acc)
loop(rest.padded, boundRest, acc1)
Expand Down Expand Up @@ -311,7 +311,7 @@ sealed abstract class Declaration {
case DefFn(d) =>
// def sets up a binding to itself, which
// may or may not be recursive
val acc1 = (acc + d.name) ++ d.args.patternNames
val acc1 = (acc + d.name) ++ d.args.toList.flatMap(_.patternNames)
val (body, rest) = d.result
val acc2 = loop(body.get, acc1)
loop(rest.padded, acc2)
Expand Down Expand Up @@ -569,7 +569,7 @@ object Declaration {
else if (scope.exists(shadows)) Some(d0)
else loopDec(d0)

val bodyScope = nm :: args.patternNames
val bodyScope = nm :: args.toList.flatMap(_.patternNames)
val restScope = nm :: Nil

(body.traverse(go(bodyScope, _)), rest.traverse(go(restScope, _)))
Expand Down
128 changes: 81 additions & 47 deletions core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.bykn.bosatsu

import cats.data.{NonEmptyList, Validated, ValidatedNel, StateT}
import org.typelevel.paiges.Doc

import cats.implicits._

Expand Down Expand Up @@ -40,10 +41,22 @@ object DefRecursionCheck {
def region = decl.region
def message = "unexpected recur: may only appear unnested inside a def"
}
case class RecurNotOnArg(decl: Declaration.Match, fnname: Bindable, args: NonEmptyList[Pattern.Parsed]) extends RecursionError {
case class RecurNotOnArg(decl: Declaration.Match,
fnname: Bindable,
args: NonEmptyList[NonEmptyList[Pattern.Parsed]]) extends RecursionError {

def region = decl.region
def message = {
val argStr = args.iterator.map { pat => Pattern.document[TypeRef].document(pat).render(80) }.mkString(", ")
val argsDoc =
Doc.intercalate(Doc.empty,
args.toList.map { group =>
(Doc.char('(') +
Doc.intercalate(Doc.comma + Doc.line,
group.toList.map { pat => Pattern.document[TypeRef].document(pat) }) +
Doc.char(')')).grouped
}
)
val argStr = argsDoc.render(80)
s"recur not on an argument to the def of ${fnname.sourceCodeRepr}, args: $argStr"
}
}
Expand Down Expand Up @@ -98,39 +111,39 @@ object DefRecursionCheck {
this match {
case TopLevel => Nil
case InDef(outer, n, _, _) => n :: outer.outerDefNames
case InDefRecurred(id, _, _, _) => id.outerDefNames
case InDefRecurred(id, _, _, _, _) => id.outerDefNames
case InRecurBranch(ir, _) => ir.outerDefNames
}

final def defNamesContain(n: Bindable): Boolean =
this match {
case TopLevel => false
case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n)
case InDefRecurred(id, _, _, _) => id.defNamesContain(n)
case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n)
case InRecurBranch(ir, _) => ir.defNamesContain(n)
}

def inDef(fnname: Bindable, args: NonEmptyList[Pattern.Parsed]): InDef =
def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef =
InDef(this, fnname, args, Set.empty)
}
sealed abstract class InDefState extends State {
final def defname: Bindable =
this match {
case InDef(_, defname, _, _) => defname
case InDefRecurred(ir, _, _, _) => ir.defname
case InRecurBranch(InDefRecurred(ir, _, _, _), _) => ir.defname
case InDefRecurred(ir, _, _, _, _) => ir.defname
case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname
}
}
case object TopLevel extends State
case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[Pattern.Parsed], localScope: Set[Bindable]) extends InDefState {
case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState {

def addLocal(b: Bindable): InDef =
InDef(outer, fnname, args, localScope + b)

def setRecur(index: Int, m: Declaration.Match): InDefRecurred =
InDefRecurred(this, index, m, 0)
def setRecur(index: (Int, Int), m: Declaration.Match): InDefRecurred =
InDefRecurred(this, index._1, index._2, m, 0)
}
case class InDefRecurred(inRec: InDef, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState {
case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState {
def incRecCount: InDefRecurred = copy(recCount = recCount + 1)
}
case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState {
Expand All @@ -142,19 +155,25 @@ object DefRecursionCheck {
*/
def getRecurIndex(
fnname: Bindable,
args: NonEmptyList[Pattern.Parsed],
args: NonEmptyList[NonEmptyList[Pattern.Parsed]],
m: Declaration.Match,
locals: Set[Bindable]): ValidatedNel[RecursionError, Int] = {
locals: Set[Bindable]): ValidatedNel[RecursionError, (Int, Int)] = {
import Declaration._
m.arg match {
case Var(v) =>
v match {
case b: Bindable if locals(b) =>
Validated.invalidNel(RecurNotOnArg(m, fnname, args))
case _ =>
val idx = args.toList.indexWhere { p => p.topNames.contains(v) }
if (idx < 0) Validated.invalidNel(RecurNotOnArg(m, fnname, args))
else Validated.valid(idx)
val idxes = for {
(group, gidx) <- args.iterator.zipWithIndex
(item, idx) <- group.iterator.zipWithIndex
if item.topNames.contains(v)
} yield (gidx, idx)


if (idxes.hasNext) Validated.valid(idxes.next())
else Validated.invalidNel(RecurNotOnArg(m, fnname, args))
}
case _ =>
Validated.invalidNel(RecurNotOnArg(m, fnname, args))
Expand Down Expand Up @@ -229,37 +248,54 @@ object DefRecursionCheck {
})
} yield ()

def checkApply(nm: Bindable, args: NonEmptyList[Declaration], region: Region): St[Unit] =
private def argsOnDefName(fn: Declaration,
groups: NonEmptyList[NonEmptyList[Declaration]]): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] =
fn match {
case Declaration.Var(nm: Bindable) => Some((nm, groups))
case Declaration.Apply(fn1, args, _) =>
argsOnDefName(fn1, args :: groups)
case _ => None
}
def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] =
getSt.flatMap {
case TopLevel =>
// without any recursion, normal typechecking will detect bad states:
args.traverse_(checkDecl)
checkDecl(fn) *> args.traverse_(checkDecl)
case irb@InRecurBranch(inrec, branch) =>
val idx = inrec.index
// here we are calling our recursive function
// make sure we do so on a substructural match
if (nm == irb.defname) {
args.get(idx.toLong) match {
case None =>
// not enough args to check recursion

argsOnDefName(fn, NonEmptyList.one(args)) match {
case Some((nm, groups)) =>
if (nm == irb.defname) {
val group = inrec.group
val idx = inrec.index
groups.get(group.toLong).flatMap(_.get(idx.toLong)) match {
case None =>
// not enough args to check recursion
failSt(InvalidRecursion(nm, region))
case Some(arg) =>
toSt(strictSubstructure(irb.defname, branch, arg)) *>
setSt(irb.incRecCount) // we have recurred again
}
}
else if (irb.defNamesContain(nm)) {
failSt(InvalidRecursion(nm, region))
case Some(arg) =>
toSt(strictSubstructure(irb.defname, branch, arg)) *>
setSt(irb.incRecCount) // we have recurred again
}
}
else if (irb.defNamesContain(nm)) {
failSt(InvalidRecursion(nm, region))
}
else {
// not a recursive call
args.traverse_(checkDecl)
}
else {
// not a recursive call
args.traverse_(checkDecl)
}
case None =>
// this isn't a recursive call
checkDecl(fn) *> args.traverse_(checkDecl)
}
case ir: InDefState =>
// we have either not yet, or already done the recursion
if (ir.defNamesContain(nm)) failSt(InvalidRecursion(nm, region))
else args.traverse_(checkDecl)
argsOnDefName(fn, NonEmptyList.one(args)) match {
case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region))
case _ =>
checkDecl(fn) *> args.traverse_(checkDecl)
}
}
/*
* With the given state, check the given Declaration to see if
* we have valid recursion
Expand All @@ -268,12 +304,10 @@ object DefRecursionCheck {
import Declaration._
decl match {
case Annotation(t, _) => checkDecl(t)
case Apply(Var(nm: Bindable), args, _) =>
checkApply(nm, args, decl.region)
case Apply(fn, args, _) =>
checkDecl(fn) *> args.traverse_(checkDecl)
checkApply(fn, args, decl.region)
case ApplyOp(left, op, right) =>
checkApply(op, NonEmptyList(left, right :: Nil), decl.region)
checkApply(Var(op)(decl.region), NonEmptyList(left, right :: Nil), decl.region)
case Binding(BindingStatement(pat, thisDecl, next)) =>
checkForIllegalBindsSt(pat.names, decl) *>
checkDecl(thisDecl) *>
Expand Down Expand Up @@ -315,7 +349,7 @@ object DefRecursionCheck {
case recur@Match(RecursionKind.Recursive, _, cases) =>
// this is a state change
getSt.flatMap {
case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _) =>
case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) =>
failSt(UnexpectedRecur(recur))
case InDef(_, defname, args, locals) =>
toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx =>
Expand All @@ -326,7 +360,7 @@ object DefRecursionCheck {
case ir@InDef(_, _, _, _) =>
val rec = ir.setRecur(idx, recur)
setSt(rec) *> beginBranch(pat)
case irr@InDefRecurred(_, _, _, _) =>
case irr@InDefRecurred(_, _, _, _, _) =>
setSt(InRecurBranch(irr, pat))
case illegal =>
// $COVERAGE-OFF$ this should be unreachable
Expand Down Expand Up @@ -418,17 +452,17 @@ object DefRecursionCheck {
*/
def checkDef[A](state: State, defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)]): Res = {
val body = defstmt.result._1.get
val nameArgs = defstmt.args.patternNames
val nameArgs = defstmt.args.toList.flatMap(_.patternNames)
val state1 = state.inDef(defstmt.name, defstmt.args)
checkForIllegalBinds(state, defstmt.name :: nameArgs, body) {
val st = setSt(state1) *> checkDecl(body) *> (getSt.flatMap {
case InDef(_, _, _, _) =>
// we never hit a recur
unitSt
case InDefRecurred(_, _, _, cnt) if cnt > 0 =>
case InDefRecurred(_, _, _, _, cnt) if cnt > 0 =>
// we did hit a recur
unitSt
case InDefRecurred(_, _, recur, 0) =>
case InDefRecurred(_, _, _, recur, 0) =>
// we hit a recur, but we didn't recurse
failSt[Unit](RecursiveDefNoRecur(defstmt.copy(result = defstmt.result._1.get), recur))
case unreachable =>
Expand Down
20 changes: 12 additions & 8 deletions core/src/main/scala/org/bykn/bosatsu/DefStatement.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import cats.syntax.all._
case class DefStatement[A, B](
name: Bindable,
typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]],
args: NonEmptyList[A],
args: NonEmptyList[NonEmptyList[A]],
retType: Option[TypeRef],
result: B
)
Expand All @@ -36,12 +36,16 @@ object DefStatement {
}
}
val argDoc =
Doc.char('(') +
Doc.intercalate(
commaSpace,
args.map(Document[A].document(_)).toList
) +
Doc.char(')')
Doc.intercalate(Doc.empty,
args.toList.map { args =>
Doc.char('(') +
Doc.intercalate(
commaSpace,
args.map(Document[A].document(_)).toList
) +
Doc.char(')')
}
)
val line0 =
defDoc + Document[Bindable].document(name) + taDoc + argDoc + res + Doc
.char(':')
Expand All @@ -63,7 +67,7 @@ object DefStatement {
(
Parser.keySpace(
"def"
) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args) <* maybeSpace,
) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args.rep) <* maybeSpace,
result.with1 <* (maybeSpace.with1 ~ P.char(':')),
resultTParser
)
Expand Down
18 changes: 12 additions & 6 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,12 @@ final class SourceConverter(
(unTypedBody, toType(t, region), tag).parMapN(Expr.Annotation(_, _, _))
}

(ds.args.traverse(convertPattern(_, region)), bodyExp, tag).parMapN { (as, b, t) =>
val lambda = Expr.buildPatternLambda(as, b, t)
(Traverse[NonEmptyList]
.compose[NonEmptyList]
.traverse(ds.args)(convertPattern(_, region)),
bodyExp,
tag).parMapN { (groups, b, t) =>
val lambda = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) }
ds.typeArgs match {
case None => success(lambda)
case Some(args) =>
Expand Down Expand Up @@ -200,7 +204,7 @@ final class SourceConverter(
val inExpr = defstmt.result match {
case (_, Padding(_, in)) => withBound(in, defstmt.name :: Nil)
}
val newBindings = defstmt.name :: defstmt.args.patternNames
val newBindings = defstmt.name :: defstmt.args.toList.flatMap(_.patternNames)
// TODO
val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => withBound(res._1.get, newBindings) })

Expand Down Expand Up @@ -1100,7 +1104,7 @@ final class SourceConverter(
case Left(d@Def(dstmt)) =>
val d1 = if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt
val res =
if (dstmt.args.iterator.flatMap(_.names).exists(_ == bind)) {
if (dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind)) {
// the args are shadowing the binding, so we don't need to substitute
dstmt.result
}
Expand Down Expand Up @@ -1145,7 +1149,7 @@ final class SourceConverter(
val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil)
(topBound + nm, r)

case Right(Left(d @ Def(defstmt@DefStatement(_, _, pat, _, _)))) =>
case Right(Left(d @ Def(defstmt@DefStatement(_, _, argGroups, _, _)))) =>
// using body for the outer here is a bummer, but not really a good outer otherwise

val boundName = defstmt.name
Expand All @@ -1157,7 +1161,9 @@ final class SourceConverter(
defstmt,
d.region,
success(defstmt.result.get))(
{ (res: OptIndent[Declaration]) => apply(res.get, pat.iterator.flatMap(_.names).toSet + boundName, topBound1) })
{ (res: OptIndent[Declaration]) =>
apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1)
})

val r = lam.map { (l: Expr[Declaration]) =>
// We rely on DefRecursionCheck to rule out bad recursions
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/Statement.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object Statement {
case Def(defstatement) =>
val innerFrees = defstatement.result.get.freeVars
// but the def name and, args shadow
(innerFrees - defstatement.name) -- defstatement.args.patternNames
(innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(_.patternNames)
case ExternalDef(_, _, _) => SortedSet.empty
}

Expand All @@ -106,7 +106,7 @@ object Statement {
this match {
case Bind(BindingStatement(pat, decl, _)) => decl.allNames ++ pat.names
case Def(defstatement) =>
(defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.patternNames
(defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList.flatMap(_.patternNames)
case ExternalDef(name, _, _) => SortedSet(name)
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ class DeclarationTest extends AnyFunSuite {
val b = Identifier.Backticked("")
val d1 = Literal(Lit.fromInt(0))
val d0 = DefFn(
DefStatement(Name("mfLjwok"),None, NonEmptyList.of(Pattern.Var(Name("foo"))),None,
DefStatement(Name("mfLjwok"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))),None,
(NotSameLine(Padding(10,Indented(10,Var(Backticked(""))))),
Padding(10,Binding(BindingStatement(
Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.of(Pattern.Var(Name("bar"))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.of(Pattern.Var(Name("quux"))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j")))))))))))))))))
Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("bar")))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("quux")))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j")))))))))))))))))

(b, d1, d0)
}
Expand Down
Loading

0 comments on commit ae52f82

Please sign in to comment.