Skip to content

Commit

Permalink
Add annotation on lambdas to improve type error locations
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 9, 2023
1 parent ae52f82 commit fc99ca2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
12 changes: 12 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ sealed abstract class Pattern[+N, +T] {

loop(this)._2.distinct.sorted
}

/**
* @return the type if we can directly see it
*/
def simpleTypeOf: Option[T] =
this match {
case Pattern.Named(_, p) => p.simpleTypeOf
case Pattern.Annotation(_, t) => Some(t)
case Pattern.Union(_, _) | Pattern.ListPat(_) | Pattern.Literal(_) |
Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) |
Pattern.PositionalStruct(_, _) => None
}
}

object Pattern {
Expand Down
45 changes: 34 additions & 11 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.bykn.bosatsu

import cats.{Applicative, Traverse}
import cats.data.{ Chain, Ior, NonEmptyChain, NonEmptyList, State }
import cats.implicits._
import org.bykn.bosatsu.rankn.{ParsedTypeEnv, Type, TypeEnv}
import scala.collection.immutable.SortedSet
import scala.collection.mutable.{Map => MMap}
Expand All @@ -11,6 +10,8 @@ import org.typelevel.paiges.{Doc, Document}
// this is used to make slightly nicer syntax on Error creation
import scala.language.implicitConversions

import cats.syntax.all._

import ListLang.{KVPair, SpliceOrItem}

import Identifier.{Bindable, Constructor}
Expand Down Expand Up @@ -89,17 +90,37 @@ final class SourceConverter(
ds: DefStatement[Pattern.Parsed, B], region: Region, tag: Result[Declaration])(
resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = {
val unTypedBody = resultExpr(ds.result)
val bodyExp =
ds.retType.fold(unTypedBody) { t =>
(unTypedBody, toType(t, region), tag).parMapN(Expr.Annotation(_, _, _))

val bodyType: Option[Result[Type]] = ds.retType.map(toType(_, region))

val bodyExp: Result[Expr[Declaration]] =
bodyType.fold(unTypedBody) { t =>
(unTypedBody, t, tag).parMapN(Expr.Annotation(_, _, _))
}

val travNE2 = Traverse[NonEmptyList].compose[NonEmptyList]

type Pat = Pattern[(PackageName, Constructor), Type]
val convertedArgs: Result[NonEmptyList[NonEmptyList[Pat]]] =
travNE2.traverse(ds.args)(convertPattern(_, region))

// If we have the full type of the lambda, apply it. This
// helps in recursive cases since we can see at the call site
// rather than the final recursive let binding that an application
// was incorrect. Without this, type errors become very non-specific.
val maybeFullyTyped: Result[Option[Type]] =
(convertedArgs, bodyType.sequence).parMapN { case (args, optResTpe) =>
(travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { case (argsTpe, resTpe) =>
argsTpe.toList.foldRight(resTpe) { (args, res) => rankn.Type.Fun(args, res) }
}
}

(Traverse[NonEmptyList]
.compose[NonEmptyList]
.traverse(ds.args)(convertPattern(_, region)),
(convertedArgs,
bodyExp,
tag).parMapN { (groups, b, t) =>
val lambda = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) }
tag,
maybeFullyTyped).parMapN { (groups, b, t, fullType) =>
val lambda0 = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) }
val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t))
ds.typeArgs match {
case None => success(lambda)
case Some(args) =>
Expand Down Expand Up @@ -189,7 +210,7 @@ final class SourceConverter(
}
case pat =>
// TODO: we need the region on the pattern...
(convertPattern(pat, decl.region), erest, rrhs).parMapN { (newPattern, e, rhs) =>
(convertPattern(pat, decl.region - value.region), erest, rrhs).parMapN { (newPattern, e, rhs) =>
val expBranches = NonEmptyList.of((newPattern, e))
Expr.Match(rhs, expBranches, decl)
}
Expand Down Expand Up @@ -996,8 +1017,10 @@ final class SourceConverter(
val ident = alloc()
NonEmptyList.one((ident, decl))
case complex =>
// TODO, flattening the pattern (a, b, c, d) = (1, 2, 3, 4) might be nice...
// flattening the pattern (a, b, c, d) = (1, 2, 3, 4) might be nice...
// that is not done yet, it will allocate the tuple, just to destructure it
// but that optimization is done later since it is allocation and deallocation
// of a struct
val (prefix, rightHandSide) =
if (decl.isCheap) {
// no need to make a new var to point to a var
Expand Down
21 changes: 21 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2891,6 +2891,27 @@ Region(396,450)""")

}

test("error early on a bad type in a recursive function") {
val testCode = """
package BadRec
enum N: Z, S(n: N)
def toInt(n: N, acc: Int) -> Int:
recur n:
case Z: acc
case S(n): toInt(n, "foo")
"""
evalFail(List(testCode)) { case kie@PackageError.TypeErrorIn(_, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(message.contains("Region(122,127)"))
val badRegion = testCode.substring(122, 127)
assert(badRegion == "\"foo\"")
()
}
}

test("declaring a generic parameter works fine") {
runBosatsuTest(List("""
package Generic
Expand Down

0 comments on commit fc99ca2

Please sign in to comment.