Skip to content

Commit

Permalink
Improve external def parsing (#1188)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Mar 28, 2024
1 parent a472c8d commit 18efb84
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object DefRecursionCheck {
case Def(defn) =>
// make this the same shape as a in declaration
checkDef(TopLevel, defn.copy(result = (defn.result, ())))
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
unitValid
}
case _ => unitValid
Expand Down
16 changes: 15 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ object Package {
Type.Const.Defined(p, TypeName(tds.name)) -> tds.region
}.toMap

lazy val extDefRegions: Map[Identifier.Bindable, Region] =
stmts.iterator.collect { case ed: Statement.ExternalDef =>
ed.name -> ed.region
}.toMap

optProg.flatMap {
case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) =>
val inferVarianceParsed
Expand Down Expand Up @@ -336,8 +341,17 @@ object Package {
errs.map(PackageError.TotalityCheckError(p, _))
}

val theseExternals =
parsedTypeEnv
.externalDefs
.collect { case (pack, b, t) if pack === p =>
// by construction this has to have all the regions
(b, (t, extDefRegions(b)))
}
.toMap

val inferenceEither = Infer
.typeCheckLets(p, lets)
.typeCheckLets(p, lets, theseExternals)
.runFully(
withFQN,
Referant.typeConstructors(imps) ++ typeEnv.typeConstructors,
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,20 @@ object PackageError {
) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindExpectedType(tpe, kind, region) =>
val tmap = showTypes(pack, tpe :: Nil)
val context =
lm.showRegion(region, 2, errColor)
.getOrElse(
Doc.str(region)
) // we should highlight the whole region
val doc = Doc.text("expected type ") +
tmap(tpe) + Doc.text(
" to have kind *, which is to say be a valid value, but it is kind "
) + Kind.toDoc(kind) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindInvalidApply(applied, leftK, rightK, region) =>
val leftT = applied.on
Expand Down
63 changes: 49 additions & 14 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ final class SourceConverter(
SourceConverter.InvalidDefTypeParameters(
args,
freeVarsList,
ds,
Right(ds),
region
),
gen
Expand Down Expand Up @@ -1227,7 +1227,7 @@ final class SourceConverter(
values: List[Statement.ValueStatement]
): Result[Unit] = {
val extDefNames =
values.collect { case ed @ Statement.ExternalDef(name, _, _) =>
values.collect { case ed @ Statement.ExternalDef(name, _, _, _) =>
(name, ed.region)
}

Expand Down Expand Up @@ -1259,7 +1259,7 @@ final class SourceConverter(
s match {
case b @ Statement.Bind(_) => Some(Left(b))
case d @ Statement.Def(_) => Some(Right(d))
case Statement.ExternalDef(_, _, _) => None
case Statement.ExternalDef(_, _, _, _) => None
}

def checkDefBind(s: Statement.ValueStatement): Result[Unit] =
Expand Down Expand Up @@ -1391,7 +1391,7 @@ final class SourceConverter(
stmts.toList.flatMap {
case d @ Def(_) =>
(d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
// we don't allow external defs to shadow at all, so skip it here
Nil
case Bind(BindingStatement(bound, decl, _)) =>
Expand Down Expand Up @@ -1456,7 +1456,7 @@ final class SourceConverter(
}

val withEx: List[Either[ExternalDef, Flattened]] =
stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList :::
stmts.collect { case e @ ExternalDef(_, _, _, _) => Left(e) }.toList :::
flatIn.map {
case (b, _, Left(d @ Def(dstmt))) =>
Right(Left(Def(dstmt.copy(name = b))(d.region)))
Expand Down Expand Up @@ -1513,7 +1513,7 @@ final class SourceConverter(
(boundName, rec, l1) :: Nil
}
(topBound1, r)
case Left(ExternalDef(n, _, _)) =>
case Left(ExternalDef(n, _, _, _)) =>
(topBound + n, success(Nil))
}
}(SourceConverter.parallelIor)).map(_.flatten)
Expand All @@ -1526,7 +1526,7 @@ final class SourceConverter(
], List[Statement]]] = {
val stmts = Statement.valuesOf(ss).toList
stmts
.collect { case ed @ Statement.ExternalDef(name, params, result) =>
.collect { case ed @ Statement.ExternalDef(name, ta, params, result) =>
(
params.traverse(p => toType(p._2, ed.region)),
toType(result, ed.region)
Expand All @@ -1547,7 +1547,7 @@ final class SourceConverter(
}
}
}
.map { (tpe: rankn.Type) =>
.flatMap { (tpe: rankn.Type) =>
val freeVars = rankn.Type.freeTyVars(tpe :: Nil)
// these vars were parsed so they are never skolem vars
val freeBound = freeVars.map {
Expand All @@ -1557,10 +1557,34 @@ final class SourceConverter(
sys.error(s"invariant violation: parsed a skolem var: $s")
// $COVERAGE-ON$
}
// TODO: Kind support parsing kinds
val maybeForAll =
rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe)
(name, maybeForAll)
val finalTpe = ta match {
case None =>
success(rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe))
case Some(frees0) =>
val frees = frees0.map { case (ref, optK) => ref.toBoundVar -> optK }
if (frees.iterator.map(_._1).toSet === freeBound.toSet[rankn.Type.Var.Bound]) {
success(rankn.Type.forAll(frees.map {
case (v, None) => (v, Kind.Type)
case (v, Some(k)) => (v, k)
}, tpe))
}
else {
val kindMap = frees.iterator.collect { case (v, Some(k)) => (v, k) }.toMap
val vs = freeBound.map { v => (v, kindMap.getOrElse(v, Kind.Type)) }
val t = rankn.Type.forAll(vs, tpe)
SourceConverter.partial(
SourceConverter.InvalidDefTypeParameters(
frees0,
freeBound,
Left(ed),
ed.region
),
t
)
}
}

finalTpe.map(name -> _)
}
}
// TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]]
Expand Down Expand Up @@ -1887,10 +1911,21 @@ object SourceConverter {
final case class InvalidDefTypeParameters[B](
declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])],
free: List[Type.Var.Bound],
defstmt: DefStatement[Pattern.Parsed, B],
defstmt: Either[Statement.ExternalDef, DefStatement[Pattern.Parsed, B]],
region: Region
) extends Error {

def name: Identifier.Bindable = defstmt match {
case Right(ds) => ds.name
case Left(ed) => ed.name
}

def expectation: String = defstmt match {
case Right(_) => "a subset of"
case Left(_) => "the same as"
}


def message = {
def tstr(l: List[Type.Var.Bound]): String =
l.iterator.map(_.name).mkString("[", ", ", "]")
Expand All @@ -1903,7 +1938,7 @@ object SourceConverter {
.renderTrim(80)

val freeStr = tstr(free)
s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr"
s"${name.asString} found declared types: $decl, not $expectation $freeStr"
}
}

Expand Down
37 changes: 25 additions & 12 deletions core/src/main/scala/org/bykn/bosatsu/Statement.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ sealed abstract class Statement {
Struct(nm, typeArgs, args)(r)
case Enum(nm, typeArgs, parts) =>
Enum(nm, typeArgs, parts)(r)
case ExternalDef(name, args, res) =>
ExternalDef(name, args, res)(r)
case ExternalDef(name, ta, args, res) =>
ExternalDef(name, ta, args, res)(r)
case ExternalStruct(nm, targs) =>
ExternalStruct(nm, targs)(r)
}
Expand Down Expand Up @@ -83,7 +83,7 @@ object Statement {
case Bind(BindingStatement(bound, _, _)) =>
bound.names // TODO Keep identifiers
case Def(defstatement) => defstatement.name :: Nil
case ExternalDef(name, _, _) => name :: Nil
case ExternalDef(name, _, _, _) => name :: Nil
}

/** These are all the free bindable names in the right hand side of this
Expand All @@ -98,7 +98,7 @@ object Statement {
(innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(
_.patternNames
)
case ExternalDef(_, _, _) => SortedSet.empty
case ExternalDef(_, _, _, _) => SortedSet.empty
}

/** These are all the bindings, free or not, in this Statement
Expand All @@ -109,7 +109,7 @@ object Statement {
case Def(defstatement) =>
(defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList
.flatMap(_.patternNames)
case ExternalDef(name, _, _) => SortedSet(name)
case ExternalDef(name, _, _, _) => SortedSet(name)
}
}

Expand All @@ -126,6 +126,7 @@ object Statement {
extends ValueStatement
case class ExternalDef(
name: Bindable,
typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]],
params: List[(Bindable, TypeRef)],
result: TypeRef
)(val region: Region)
Expand Down Expand Up @@ -230,6 +231,10 @@ object Statement {

val externalDef = {

val kindAnnot: P[Kind] =
(maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.parser))
val typeParams = TypeRef.typeParams(kindAnnot.?).?

val args =
P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P
.char(')')
Expand All @@ -239,16 +244,16 @@ object Statement {

(((keySpace(
"def"
) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL)
.map { case (region, ((name, args), resType)) =>
ExternalDef(name, args.toList, resType)(region)
) *> Identifier.bindableParser ~ typeParams ~ args ~ result).region) <* toEOL)
.map { case (region, (((name, tps), args), resType)) =>
ExternalDef(name, tps, args.toList, resType)(region)
}
}

val externalVal =
(argParser <* toEOL).region
.map { case (region, (name, resType)) =>
ExternalDef(name, Nil, resType)(region)
ExternalDef(name, None, Nil, resType)(region)
}

keySpace("external") *> P.oneOf(
Expand Down Expand Up @@ -385,11 +390,19 @@ object Statement {
.char(':') +
colonSep +
indentedCons + Doc.line
case ExternalDef(name, Nil, res) =>
case ExternalDef(name, None, Nil, res) =>
Doc.text("external ") + Document[Bindable].document(name) + Doc.text(
": "
) + res.toDoc + Doc.line
case ExternalDef(name, args, res) =>
case ExternalDef(name, tps, args, res) =>
val taDoc = tps match {
case None => Doc.empty
case Some(ta) =>
TypeRef.docTypeArgs(ta.toList) {
case None => Doc.empty
case Some(k) => colonSpace + Kind.toDoc(k)
}
}
val argDoc = {
val da = Doc.intercalate(
Doc.text(", "),
Expand All @@ -401,7 +414,7 @@ object Statement {
}
Doc.text("external def ") + Document[Bindable].document(
name
) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
) + taDoc + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
case ExternalStruct(nm, typeArgs) =>
val taDoc =
TypeRef.docTypeArgs(typeArgs.toList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ object TypeRefConverter {
import TypeRef._

t match {
case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v)))
case tv @ TypeVar(_) => Applicative[F].pure(TyVar(tv.toBoundVar))
case TypeName(n) => nameToType(n.ident).map(TyConst(_))
case TypeArrow(as, b) =>
(as.traverse(toType(_)), toType(b)).mapN(Fun(_, _))
Expand Down
25 changes: 23 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ object Infer {
rightK: Kind,
region: Region
) extends TypeError
case class KindExpectedType(
tpe: Type,
kind: Kind.Cons,
region: Region
) extends TypeError
case class KindMismatch(
target: Type,
targetKind: Kind,
Expand Down Expand Up @@ -2614,7 +2619,8 @@ object Infer {
*/
def typeCheckLets[A: HasRegion](
pack: PackageName,
ls: List[(Bindable, RecursionKind, Expr[A])]
ls: List[(Bindable, RecursionKind, Expr[A])],
externals: Map[Bindable, (Type, Region)]
): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = {
// Group together lets that don't include each other to get more type errors
// if we can
Expand Down Expand Up @@ -2655,7 +2661,22 @@ object Infer {
else Some(bs :+ item)
}

run(groups)
val checkExternals =
GetEnv.flatMap { env =>
externals
.toList
.sortBy { case (_, (_, region)) => region }
.parTraverse_ { case (_, (t, region)) =>
env.getKind(t, region) match {
case Right(Kind.Type) => unit
case Right(cons @ Kind.Cons(_, _)) =>
fail(Error.KindExpectedType(t, cons, region))
case Left(err) => fail(err)
}
}
}

run(groups).parProductL(checkExternals)
}

/** This is useful to testing purposes.
Expand Down
17 changes: 17 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3990,4 +3990,21 @@ test = TestSuite("bases",
12
)
}

test("external defs with explicit type parameters exactly match") {
val testCode = """
package ErrorCheck
external def foo[b](lst: List[a]) -> a
"""
evalFail(List(testCode)) {
case kie @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(message.contains("Region(30,59)"))
assert(message.contains("[b], not the same as [a]"))
assert(testCode.substring(30, 59) == "def foo[b](lst: List[a]) -> a")
()
}
}
}
Loading

0 comments on commit 18efb84

Please sign in to comment.