From ae52f82f9b9e0bf6fc26491a1c4d3917a0e4e8da Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sat, 9 Sep 2023 07:46:35 -1000 Subject: [PATCH] Allow more than one parameter group on def (#1037) * Allow more than one parameter group on def * use the syntax in tests * remove commented code --- .../scala/org/bykn/bosatsu/Declaration.scala | 6 +- .../org/bykn/bosatsu/DefRecursionCheck.scala | 128 +++++++++++------- .../scala/org/bykn/bosatsu/DefStatement.scala | 20 +-- .../org/bykn/bosatsu/SourceConverter.scala | 18 ++- .../scala/org/bykn/bosatsu/Statement.scala | 4 +- .../org/bykn/bosatsu/DeclarationTest.scala | 4 +- .../org/bykn/bosatsu/EvaluationTest.scala | 6 +- .../src/test/scala/org/bykn/bosatsu/Gen.scala | 6 +- .../scala/org/bykn/bosatsu/ParserTest.scala | 3 +- test_workspace/ApplicativeTraverse.bosatsu | 2 +- test_workspace/Dict.bosatsu | 6 +- test_workspace/List.bosatsu | 6 +- test_workspace/Option.bosatsu | 2 +- test_workspace/Queue.bosatsu | 10 +- test_workspace/TreeList.bosatsu | 10 +- test_workspace/euler4.bosatsu | 2 +- 16 files changed, 139 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala index 99704f21c..8c81bf11c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala @@ -198,7 +198,7 @@ sealed abstract class Declaration { // may or may not be recursive val boundRest = bound + d.name - val boundBody = boundRest ++ d.args.patternNames + val boundBody = boundRest ++ d.args.toList.flatMap(_.patternNames) val acc1 = loop(body.get, boundBody, acc) loop(rest.padded, boundRest, acc1) @@ -311,7 +311,7 @@ sealed abstract class Declaration { case DefFn(d) => // def sets up a binding to itself, which // may or may not be recursive - val acc1 = (acc + d.name) ++ d.args.patternNames + val acc1 = (acc + d.name) ++ d.args.toList.flatMap(_.patternNames) val (body, rest) = d.result val acc2 = loop(body.get, acc1) loop(rest.padded, acc2) @@ -569,7 +569,7 @@ object Declaration { else if (scope.exists(shadows)) Some(d0) else loopDec(d0) - val bodyScope = nm :: args.patternNames + val bodyScope = nm :: args.toList.flatMap(_.patternNames) val restScope = nm :: Nil (body.traverse(go(bodyScope, _)), rest.traverse(go(restScope, _))) diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index 95b38f688..62847032d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -1,6 +1,7 @@ package org.bykn.bosatsu import cats.data.{NonEmptyList, Validated, ValidatedNel, StateT} +import org.typelevel.paiges.Doc import cats.implicits._ @@ -40,10 +41,22 @@ object DefRecursionCheck { def region = decl.region def message = "unexpected recur: may only appear unnested inside a def" } - case class RecurNotOnArg(decl: Declaration.Match, fnname: Bindable, args: NonEmptyList[Pattern.Parsed]) extends RecursionError { + case class RecurNotOnArg(decl: Declaration.Match, + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]]) extends RecursionError { + def region = decl.region def message = { - val argStr = args.iterator.map { pat => Pattern.document[TypeRef].document(pat).render(80) }.mkString(", ") + val argsDoc = + Doc.intercalate(Doc.empty, + args.toList.map { group => + (Doc.char('(') + + Doc.intercalate(Doc.comma + Doc.line, + group.toList.map { pat => Pattern.document[TypeRef].document(pat) }) + + Doc.char(')')).grouped + } + ) + val argStr = argsDoc.render(80) s"recur not on an argument to the def of ${fnname.sourceCodeRepr}, args: $argStr" } } @@ -98,7 +111,7 @@ object DefRecursionCheck { this match { case TopLevel => Nil case InDef(outer, n, _, _) => n :: outer.outerDefNames - case InDefRecurred(id, _, _, _) => id.outerDefNames + case InDefRecurred(id, _, _, _, _) => id.outerDefNames case InRecurBranch(ir, _) => ir.outerDefNames } @@ -106,31 +119,31 @@ object DefRecursionCheck { this match { case TopLevel => false case InDef(outer, dn, _, _) => (dn == n) || outer.defNamesContain(n) - case InDefRecurred(id, _, _, _) => id.defNamesContain(n) + case InDefRecurred(id, _, _, _, _) => id.defNamesContain(n) case InRecurBranch(ir, _) => ir.defNamesContain(n) } - def inDef(fnname: Bindable, args: NonEmptyList[Pattern.Parsed]): InDef = + def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef = InDef(this, fnname, args, Set.empty) } sealed abstract class InDefState extends State { final def defname: Bindable = this match { case InDef(_, defname, _, _) => defname - case InDefRecurred(ir, _, _, _) => ir.defname - case InRecurBranch(InDefRecurred(ir, _, _, _), _) => ir.defname + case InDefRecurred(ir, _, _, _, _) => ir.defname + case InRecurBranch(InDefRecurred(ir, _, _, _, _), _) => ir.defname } } case object TopLevel extends State - case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[Pattern.Parsed], localScope: Set[Bindable]) extends InDefState { + case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState { def addLocal(b: Bindable): InDef = InDef(outer, fnname, args, localScope + b) - def setRecur(index: Int, m: Declaration.Match): InDefRecurred = - InDefRecurred(this, index, m, 0) + def setRecur(index: (Int, Int), m: Declaration.Match): InDefRecurred = + InDefRecurred(this, index._1, index._2, m, 0) } - case class InDefRecurred(inRec: InDef, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { + case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { def incRecCount: InDefRecurred = copy(recCount = recCount + 1) } case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed) extends InDefState { @@ -142,9 +155,9 @@ object DefRecursionCheck { */ def getRecurIndex( fnname: Bindable, - args: NonEmptyList[Pattern.Parsed], + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], m: Declaration.Match, - locals: Set[Bindable]): ValidatedNel[RecursionError, Int] = { + locals: Set[Bindable]): ValidatedNel[RecursionError, (Int, Int)] = { import Declaration._ m.arg match { case Var(v) => @@ -152,9 +165,15 @@ object DefRecursionCheck { case b: Bindable if locals(b) => Validated.invalidNel(RecurNotOnArg(m, fnname, args)) case _ => - val idx = args.toList.indexWhere { p => p.topNames.contains(v) } - if (idx < 0) Validated.invalidNel(RecurNotOnArg(m, fnname, args)) - else Validated.valid(idx) + val idxes = for { + (group, gidx) <- args.iterator.zipWithIndex + (item, idx) <- group.iterator.zipWithIndex + if item.topNames.contains(v) + } yield (gidx, idx) + + + if (idxes.hasNext) Validated.valid(idxes.next()) + else Validated.invalidNel(RecurNotOnArg(m, fnname, args)) } case _ => Validated.invalidNel(RecurNotOnArg(m, fnname, args)) @@ -229,37 +248,54 @@ object DefRecursionCheck { }) } yield () - def checkApply(nm: Bindable, args: NonEmptyList[Declaration], region: Region): St[Unit] = + private def argsOnDefName(fn: Declaration, + groups: NonEmptyList[NonEmptyList[Declaration]]): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] = + fn match { + case Declaration.Var(nm: Bindable) => Some((nm, groups)) + case Declaration.Apply(fn1, args, _) => + argsOnDefName(fn1, args :: groups) + case _ => None + } + def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] = getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: - args.traverse_(checkDecl) + checkDecl(fn) *> args.traverse_(checkDecl) case irb@InRecurBranch(inrec, branch) => - val idx = inrec.index - // here we are calling our recursive function - // make sure we do so on a substructural match - if (nm == irb.defname) { - args.get(idx.toLong) match { - case None => - // not enough args to check recursion + + argsOnDefName(fn, NonEmptyList.one(args)) match { + case Some((nm, groups)) => + if (nm == irb.defname) { + val group = inrec.group + val idx = inrec.index + groups.get(group.toLong).flatMap(_.get(idx.toLong)) match { + case None => + // not enough args to check recursion + failSt(InvalidRecursion(nm, region)) + case Some(arg) => + toSt(strictSubstructure(irb.defname, branch, arg)) *> + setSt(irb.incRecCount) // we have recurred again + } + } + else if (irb.defNamesContain(nm)) { failSt(InvalidRecursion(nm, region)) - case Some(arg) => - toSt(strictSubstructure(irb.defname, branch, arg)) *> - setSt(irb.incRecCount) // we have recurred again - } - } - else if (irb.defNamesContain(nm)) { - failSt(InvalidRecursion(nm, region)) - } - else { - // not a recursive call - args.traverse_(checkDecl) + } + else { + // not a recursive call + args.traverse_(checkDecl) + } + case None => + // this isn't a recursive call + checkDecl(fn) *> args.traverse_(checkDecl) } case ir: InDefState => // we have either not yet, or already done the recursion - if (ir.defNamesContain(nm)) failSt(InvalidRecursion(nm, region)) - else args.traverse_(checkDecl) + argsOnDefName(fn, NonEmptyList.one(args)) match { + case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region)) + case _ => + checkDecl(fn) *> args.traverse_(checkDecl) } + } /* * With the given state, check the given Declaration to see if * we have valid recursion @@ -268,12 +304,10 @@ object DefRecursionCheck { import Declaration._ decl match { case Annotation(t, _) => checkDecl(t) - case Apply(Var(nm: Bindable), args, _) => - checkApply(nm, args, decl.region) case Apply(fn, args, _) => - checkDecl(fn) *> args.traverse_(checkDecl) + checkApply(fn, args, decl.region) case ApplyOp(left, op, right) => - checkApply(op, NonEmptyList(left, right :: Nil), decl.region) + checkApply(Var(op)(decl.region), NonEmptyList(left, right :: Nil), decl.region) case Binding(BindingStatement(pat, thisDecl, next)) => checkForIllegalBindsSt(pat.names, decl) *> checkDecl(thisDecl) *> @@ -315,7 +349,7 @@ object DefRecursionCheck { case recur@Match(RecursionKind.Recursive, _, cases) => // this is a state change getSt.flatMap { - case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _) => + case TopLevel | InRecurBranch(_, _) | InDefRecurred(_, _, _, _, _) => failSt(UnexpectedRecur(recur)) case InDef(_, defname, args, locals) => toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx => @@ -326,7 +360,7 @@ object DefRecursionCheck { case ir@InDef(_, _, _, _) => val rec = ir.setRecur(idx, recur) setSt(rec) *> beginBranch(pat) - case irr@InDefRecurred(_, _, _, _) => + case irr@InDefRecurred(_, _, _, _, _) => setSt(InRecurBranch(irr, pat)) case illegal => // $COVERAGE-OFF$ this should be unreachable @@ -418,17 +452,17 @@ object DefRecursionCheck { */ def checkDef[A](state: State, defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)]): Res = { val body = defstmt.result._1.get - val nameArgs = defstmt.args.patternNames + val nameArgs = defstmt.args.toList.flatMap(_.patternNames) val state1 = state.inDef(defstmt.name, defstmt.args) checkForIllegalBinds(state, defstmt.name :: nameArgs, body) { val st = setSt(state1) *> checkDecl(body) *> (getSt.flatMap { case InDef(_, _, _, _) => // we never hit a recur unitSt - case InDefRecurred(_, _, _, cnt) if cnt > 0 => + case InDefRecurred(_, _, _, _, cnt) if cnt > 0 => // we did hit a recur unitSt - case InDefRecurred(_, _, recur, 0) => + case InDefRecurred(_, _, _, recur, 0) => // we hit a recur, but we didn't recurse failSt[Unit](RecursiveDefNoRecur(defstmt.copy(result = defstmt.result._1.get), recur)) case unreachable => diff --git a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala index 4f5eccada..932814983 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala @@ -12,7 +12,7 @@ import cats.syntax.all._ case class DefStatement[A, B]( name: Bindable, typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]], - args: NonEmptyList[A], + args: NonEmptyList[NonEmptyList[A]], retType: Option[TypeRef], result: B ) @@ -36,12 +36,16 @@ object DefStatement { } } val argDoc = - Doc.char('(') + - Doc.intercalate( - commaSpace, - args.map(Document[A].document(_)).toList - ) + - Doc.char(')') + Doc.intercalate(Doc.empty, + args.toList.map { args => + Doc.char('(') + + Doc.intercalate( + commaSpace, + args.map(Document[A].document(_)).toList + ) + + Doc.char(')') + } + ) val line0 = defDoc + Document[Bindable].document(name) + taDoc + argDoc + res + Doc .char(':') @@ -63,7 +67,7 @@ object DefStatement { ( Parser.keySpace( "def" - ) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args) <* maybeSpace, + ) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args.rep) <* maybeSpace, result.with1 <* (maybeSpace.with1 ~ P.char(':')), resultTParser ) diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index ddf618e8b..64c161090 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -94,8 +94,12 @@ final class SourceConverter( (unTypedBody, toType(t, region), tag).parMapN(Expr.Annotation(_, _, _)) } - (ds.args.traverse(convertPattern(_, region)), bodyExp, tag).parMapN { (as, b, t) => - val lambda = Expr.buildPatternLambda(as, b, t) + (Traverse[NonEmptyList] + .compose[NonEmptyList] + .traverse(ds.args)(convertPattern(_, region)), + bodyExp, + tag).parMapN { (groups, b, t) => + val lambda = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) } ds.typeArgs match { case None => success(lambda) case Some(args) => @@ -200,7 +204,7 @@ final class SourceConverter( val inExpr = defstmt.result match { case (_, Padding(_, in)) => withBound(in, defstmt.name :: Nil) } - val newBindings = defstmt.name :: defstmt.args.patternNames + val newBindings = defstmt.name :: defstmt.args.toList.flatMap(_.patternNames) // TODO val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => withBound(res._1.get, newBindings) }) @@ -1100,7 +1104,7 @@ final class SourceConverter( case Left(d@Def(dstmt)) => val d1 = if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt val res = - if (dstmt.args.iterator.flatMap(_.names).exists(_ == bind)) { + if (dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind)) { // the args are shadowing the binding, so we don't need to substitute dstmt.result } @@ -1145,7 +1149,7 @@ final class SourceConverter( val r = apply(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) (topBound + nm, r) - case Right(Left(d @ Def(defstmt@DefStatement(_, _, pat, _, _)))) => + case Right(Left(d @ Def(defstmt@DefStatement(_, _, argGroups, _, _)))) => // using body for the outer here is a bummer, but not really a good outer otherwise val boundName = defstmt.name @@ -1157,7 +1161,9 @@ final class SourceConverter( defstmt, d.region, success(defstmt.result.get))( - { (res: OptIndent[Declaration]) => apply(res.get, pat.iterator.flatMap(_.names).toSet + boundName, topBound1) }) + { (res: OptIndent[Declaration]) => + apply(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) + }) val r = lam.map { (l: Expr[Declaration]) => // We rely on DefRecursionCheck to rule out bad recursions diff --git a/core/src/main/scala/org/bykn/bosatsu/Statement.scala b/core/src/main/scala/org/bykn/bosatsu/Statement.scala index 3063a3f91..4469a94b9 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Statement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Statement.scala @@ -95,7 +95,7 @@ object Statement { case Def(defstatement) => val innerFrees = defstatement.result.get.freeVars // but the def name and, args shadow - (innerFrees - defstatement.name) -- defstatement.args.patternNames + (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(_.patternNames) case ExternalDef(_, _, _) => SortedSet.empty } @@ -106,7 +106,7 @@ object Statement { this match { case Bind(BindingStatement(pat, decl, _)) => decl.allNames ++ pat.names case Def(defstatement) => - (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.patternNames + (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList.flatMap(_.patternNames) case ExternalDef(name, _, _) => SortedSet(name) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala index 775e61c41..f162145d5 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala @@ -123,10 +123,10 @@ class DeclarationTest extends AnyFunSuite { val b = Identifier.Backticked("") val d1 = Literal(Lit.fromInt(0)) val d0 = DefFn( - DefStatement(Name("mfLjwok"),None, NonEmptyList.of(Pattern.Var(Name("foo"))),None, + DefStatement(Name("mfLjwok"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))),None, (NotSameLine(Padding(10,Indented(10,Var(Backticked(""))))), Padding(10,Binding(BindingStatement( - Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.of(Pattern.Var(Name("bar"))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.of(Pattern.Var(Name("quux"))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j"))))))))))))))))) + Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("bar")))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("quux")))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j"))))))))))))))))) (b, d1, d0) } diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index cb49fdcbf..2845f3d19 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -1486,7 +1486,7 @@ def fn(x): main = fn """)) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: x\nRegion(25,43)\n") + assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,43)\n") () } @@ -1500,7 +1500,7 @@ def fn(x): main = fn """)) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: x\nRegion(25,42)\n") + assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,42)\n") () } @@ -2577,7 +2577,7 @@ def bar(y, _: String, x): test = Assertion(True, "") """)) { case re@PackageError.RecursionError(_, _) => - assert(re.message(Map.empty, Colorize.None) == "in file: , package S\nrecur not on an argument to the def of bar, args: y, _: String, x\nRegion(107,165)\n") + assert(re.message(Map.empty, Colorize.None) == "in file: , package S\nrecur not on an argument to the def of bar, args: (y, _: String, x)\nRegion(107,165)\n") () } } diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index 68b9c32af..ab203f898 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -146,11 +146,11 @@ object Generators { def defGen[T](dec: Gen[T]): Gen[DefStatement[Pattern.Parsed, T]] = for { name <- bindIdentGen - args <- nonEmpty(argGen) + args <- smallNonEmptyList(smallNonEmptyList(argGen, 8), 20) tpes <- smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) retType <- Gen.option(typeRefGen) body <- dec - } yield DefStatement(name, NonEmptyList.fromList(tpes), args.map(argToPat), retType, body) + } yield DefStatement(name, NonEmptyList.fromList(tpes), args.map(_.map(argToPat)), retType, body) def genSpliceOrItem[A](spliceGen: Gen[A], itemGen: Gen[A]): Gen[ListLang.SpliceOrItem[A]] = Gen.oneOf(spliceGen.map(ListLang.SpliceOrItem.Splice(_)), @@ -279,7 +279,7 @@ object Generators { fn <- fnGen dotApply <- dotApplyGen useDot = dotApply && isVar(fn) // f.bar needs the fn to be a var - argsGen = if (useDot) arg.map(NonEmptyList.one(_)) else nonEmpty(arg) + argsGen = if (useDot) arg.map(NonEmptyList.one(_)) else smallNonEmptyList(arg, 8) args <- argsGen } yield Apply(fn, args, ApplyKind.Parens)(emptyRegion)) // TODO this should pass if we use `foo.bar(a, b)` syntax } diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index 2c24e091c..af4872339 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -577,7 +577,8 @@ foo""" parseTestAll( Declaration.parser(""), defWithComment, - Declaration.DefFn(DefStatement(Identifier.Name("foo"), None, NonEmptyList.of(Pattern.Var(Identifier.Name("a"))), None, + Declaration.DefFn(DefStatement(Identifier.Name("foo"), None, + NonEmptyList.one(NonEmptyList.one(Pattern.Var(Identifier.Name("a")))), None, (OptIndent.paddedIndented(1, 2, Declaration.CommentNB(CommentStatement(NonEmptyList.of(" comment here"), Padding(0, mkVar("a"))))), Padding(0, mkVar("foo")))))) diff --git a/test_workspace/ApplicativeTraverse.bosatsu b/test_workspace/ApplicativeTraverse.bosatsu index a07ad0ec2..411e304fd 100644 --- a/test_workspace/ApplicativeTraverse.bosatsu +++ b/test_workspace/ApplicativeTraverse.bosatsu @@ -90,7 +90,7 @@ def trav_l[f: * -> *](app: Applicative[f], fn: a -> f[b], lst: List[a]) -> f[Lis trav_list_opt = (fn, lst) -> trav_l(applicative_Option, fn, lst) # Here is equality on Option[List[Int]] -eq_opt_list_int = (a, b) -> eq_Option((l1, l2) -> eq_List(eq_Int, l1, l2), a, b) +eq_opt_list_int = eq_Option(eq_List(eq_Int)) operator == = eq_opt_list_int diff --git a/test_workspace/Dict.bosatsu b/test_workspace/Dict.bosatsu index b1d7591c8..b6a4484ec 100644 --- a/test_workspace/Dict.bosatsu +++ b/test_workspace/Dict.bosatsu @@ -5,8 +5,8 @@ from Bosatsu/List import eq_List export eq_Dict, eq_Pair -def eq_Pair(eq_a, eq_b, (l1, l2), (r1, r2)) -> Bool: +def eq_Pair(eq_a, eq_b)((l1, l2), (r1, r2)) -> Bool: eq_b(l2, r2) if eq_a(l1, r1) else False -def eq_Dict(eq_key, eq_value, left, right) -> Bool: - eq_List((d1, d2) -> eq_Pair(eq_key, eq_value, d1, d2), items(left), items(right)) +def eq_Dict(eq_key, eq_value)(left, right) -> Bool: + eq_List(eq_Pair(eq_key, eq_value))(items(left), items(right)) diff --git a/test_workspace/List.bosatsu b/test_workspace/List.bosatsu index c3937f02c..78aade8e0 100644 --- a/test_workspace/List.bosatsu +++ b/test_workspace/List.bosatsu @@ -34,14 +34,14 @@ def head(xs: List[a]) -> Option[a]: []: None [h, *_]: Some(h) -def eq_List(fn: (a, a) -> Bool, a: List[a], b: List[a]) -> Bool: +def eq_List(fn: (a, a) -> Bool)(a: List[a], b: List[a]) -> Bool: recur a: []: b matches [] [ah, *at]: match b: []: False [bh, *bt]: - eq_List(fn, at, bt) if fn(ah, bh) else False + eq_List(fn)(at, bt) if fn(ah, bh) else False def zip(left: List[a], right: List[b]) -> List[(a, b)]: recur left: @@ -83,7 +83,7 @@ def sort(ord: Order[a], list: List[a]) -> List[a]: # Test code below ########################## -operator =*= = (x, y) -> eq_List(eq_Int, x, y) +operator =*= = eq_List(eq_Int) def not(x): False if x else True diff --git a/test_workspace/Option.bosatsu b/test_workspace/Option.bosatsu index ec5440528..3fb307129 100644 --- a/test_workspace/Option.bosatsu +++ b/test_workspace/Option.bosatsu @@ -2,7 +2,7 @@ package Bosatsu/Option export eq_Option -def eq_Option(eq: (a, a) -> Bool, left: Option[a], right: Option[a]) -> Bool: +def eq_Option(eq: (a, a) -> Bool)(left: Option[a], right: Option[a]) -> Bool: match (left, right): (Some(a), Some(b)): eq(a, b) (None, None): True diff --git a/test_workspace/Queue.bosatsu b/test_workspace/Queue.bosatsu index e5fd11861..e94297668 100644 --- a/test_workspace/Queue.bosatsu +++ b/test_workspace/Queue.bosatsu @@ -45,7 +45,7 @@ def fold_Queue(Queue(f, b): Queue[a], init: b, fold_fn: (b, a) -> b) -> b: def reverse_Queue(Queue(f, b): Queue[a]) -> Queue[a]: Queue(b.reverse(), f.reverse()) -def eq_Queue(eq_fn: (a, a) -> Bool, left: Queue[a], right: Queue[a]) -> Bool: +def eq_Queue(eq_fn: (a, a) -> Bool)(left: Queue[a], right: Queue[a]) -> Bool: res = left.fold_Queue((True, right), \(g, right), al -> if g: match unpush(right): @@ -66,15 +66,15 @@ def to_List(Queue(f, b): Queue[a]) -> List[a]: ## TestSuites below ######## -def eq_Opt(eq_inner): - (a, b) -> match (a, b): +def eq_Opt(eq_inner)(a, b): + match (a, b): (Some(a), Some(b)): eq_inner(a, b) (None, None): True _: False eq_oi = eq_Opt(eq_Int) -eq_qi = (q1, q2) -> eq_Queue(eq_Int, q1, q2) -eq_li = (l1, l2) -> eq_List(eq_Int, l1, l2) +eq_qi = eq_Queue(eq_Int) +eq_li = eq_List(eq_Int) q12 = empty.push(1).push(2) diff --git a/test_workspace/TreeList.bosatsu b/test_workspace/TreeList.bosatsu index 46c0ccdfd..56040f692 100644 --- a/test_workspace/TreeList.bosatsu +++ b/test_workspace/TreeList.bosatsu @@ -93,7 +93,7 @@ def fold(TreeList(trees): TreeList[a], init: b, fn: (b, a) -> b) -> b: def to_List(list: TreeList[a]) -> List[a]: fold(list, [], \l, h -> [h, *l]).reverse() -def eq_TreeList(fn, a, b): +def eq_TreeList(fn)(a, b): (res, _) = a.fold((True, b), \(current, b), h -> if current: match decons(b): @@ -105,15 +105,15 @@ def eq_TreeList(fn, a, b): ) res -def eq_Opt(fn, a, b): +def eq_Opt(fn)(a, b): match (a, b): (Some(a), Some(b)): fn(a, b) (None, None): True _: False -eq_oi = (a, b) -> eq_Opt(eq_Int, a, b) -eq_ti = (a, b) -> eq_TreeList(eq_Int, a, b) -eq_li = (a, b) -> eq_List(eq_Int, a, b) +eq_oi = eq_Opt(eq_Int) +eq_ti = eq_TreeList(eq_Int) +eq_li = eq_List(eq_Int) operator +/ = cons diff --git a/test_workspace/euler4.bosatsu b/test_workspace/euler4.bosatsu index 0a700bb82..de4dba30d 100644 --- a/test_workspace/euler4.bosatsu +++ b/test_workspace/euler4.bosatsu @@ -58,7 +58,7 @@ def digit_list(n): reverse(rev_list) def is_palindrome(eq_fn, lst): - eq_List(eq_fn, lst, reverse(lst)) + eq_List(eq_fn)(lst, reverse(lst)) def num_is_palindrome(n): digits = digit_list(n)