Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add annotation on lambdas to improve type error locations #1038

Merged
merged 1 commit into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ sealed abstract class Pattern[+N, +T] {

loop(this)._2.distinct.sorted
}

/**
* @return the type if we can directly see it
*/
def simpleTypeOf: Option[T] =
this match {
case Pattern.Named(_, p) => p.simpleTypeOf
case Pattern.Annotation(_, t) => Some(t)
case Pattern.Union(_, _) | Pattern.ListPat(_) | Pattern.Literal(_) |
Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) |
Pattern.PositionalStruct(_, _) => None
}
}

object Pattern {
Expand Down
45 changes: 34 additions & 11 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.bykn.bosatsu

import cats.{Applicative, Traverse}
import cats.data.{ Chain, Ior, NonEmptyChain, NonEmptyList, State }
import cats.implicits._
import org.bykn.bosatsu.rankn.{ParsedTypeEnv, Type, TypeEnv}
import scala.collection.immutable.SortedSet
import scala.collection.mutable.{Map => MMap}
Expand All @@ -11,6 +10,8 @@ import org.typelevel.paiges.{Doc, Document}
// this is used to make slightly nicer syntax on Error creation
import scala.language.implicitConversions

import cats.syntax.all._

import ListLang.{KVPair, SpliceOrItem}

import Identifier.{Bindable, Constructor}
Expand Down Expand Up @@ -89,17 +90,37 @@ final class SourceConverter(
ds: DefStatement[Pattern.Parsed, B], region: Region, tag: Result[Declaration])(
resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = {
val unTypedBody = resultExpr(ds.result)
val bodyExp =
ds.retType.fold(unTypedBody) { t =>
(unTypedBody, toType(t, region), tag).parMapN(Expr.Annotation(_, _, _))

val bodyType: Option[Result[Type]] = ds.retType.map(toType(_, region))

val bodyExp: Result[Expr[Declaration]] =
bodyType.fold(unTypedBody) { t =>
(unTypedBody, t, tag).parMapN(Expr.Annotation(_, _, _))
}

val travNE2 = Traverse[NonEmptyList].compose[NonEmptyList]

type Pat = Pattern[(PackageName, Constructor), Type]
val convertedArgs: Result[NonEmptyList[NonEmptyList[Pat]]] =
travNE2.traverse(ds.args)(convertPattern(_, region))

// If we have the full type of the lambda, apply it. This
// helps in recursive cases since we can see at the call site
// rather than the final recursive let binding that an application
// was incorrect. Without this, type errors become very non-specific.
val maybeFullyTyped: Result[Option[Type]] =
(convertedArgs, bodyType.sequence).parMapN { case (args, optResTpe) =>
(travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { case (argsTpe, resTpe) =>
argsTpe.toList.foldRight(resTpe) { (args, res) => rankn.Type.Fun(args, res) }
}
}

(Traverse[NonEmptyList]
.compose[NonEmptyList]
.traverse(ds.args)(convertPattern(_, region)),
(convertedArgs,
bodyExp,
tag).parMapN { (groups, b, t) =>
val lambda = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) }
tag,
maybeFullyTyped).parMapN { (groups, b, t, fullType) =>
val lambda0 = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) }
val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t))
ds.typeArgs match {
case None => success(lambda)
case Some(args) =>
Expand Down Expand Up @@ -189,7 +210,7 @@ final class SourceConverter(
}
case pat =>
// TODO: we need the region on the pattern...
(convertPattern(pat, decl.region), erest, rrhs).parMapN { (newPattern, e, rhs) =>
(convertPattern(pat, decl.region - value.region), erest, rrhs).parMapN { (newPattern, e, rhs) =>
val expBranches = NonEmptyList.of((newPattern, e))
Expr.Match(rhs, expBranches, decl)
}
Expand Down Expand Up @@ -996,8 +1017,10 @@ final class SourceConverter(
val ident = alloc()
NonEmptyList.one((ident, decl))
case complex =>
// TODO, flattening the pattern (a, b, c, d) = (1, 2, 3, 4) might be nice...
// flattening the pattern (a, b, c, d) = (1, 2, 3, 4) might be nice...
// that is not done yet, it will allocate the tuple, just to destructure it
// but that optimization is done later since it is allocation and deallocation
// of a struct
val (prefix, rightHandSide) =
if (decl.isCheap) {
// no need to make a new var to point to a var
Expand Down
21 changes: 21 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2891,6 +2891,27 @@ Region(396,450)""")

}

test("error early on a bad type in a recursive function") {
val testCode = """
package BadRec

enum N: Z, S(n: N)

def toInt(n: N, acc: Int) -> Int:
recur n:
case Z: acc
case S(n): toInt(n, "foo")

"""
evalFail(List(testCode)) { case [email protected](_, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(message.contains("Region(122,127)"))
val badRegion = testCode.substring(122, 127)
assert(badRegion == "\"foo\"")
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this change the bad region is the entire body of toInt which is very nondescript.

()
}
}

test("declaring a generic parameter works fine") {
runBosatsuTest(List("""
package Generic
Expand Down