Skip to content

Commit

Permalink
Fix issue 446 (#1158)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Feb 26, 2024
1 parent 5d84d3f commit 88da254
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 43 deletions.
22 changes: 22 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/ListUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,26 @@ private[bosatsu] object ListUtil {
if (bs eq as) nel
else NonEmptyList.fromListUnsafe(bs)
}

def distinctByHashSet[A](nel: NonEmptyList[A]): NonEmptyList[A] = {
// This code leverages the scala type ::[A] which is the nonempty
// list in order to avoid allocations building a NonEmptyList
// since a :: tailnel will have to allocate twice vs 1 time.
def revCons(item: ::[A], tail: List[A]): NonEmptyList[A] =
item.tail match {
case nel: ::[A] => revCons(nel, item.head :: tail)
case _: Nil.type => NonEmptyList(item.head, tail)
}
@annotation.tailrec
def loop(prior: Set[A], tail: List[A], front: ::[A]): NonEmptyList[A] =
tail match {
case Nil => revCons(front, Nil)
case head :: next =>
if (prior(head)) loop(prior, next, front)
else loop(prior + head, next, ::(head, front))
}

val h = nel.head
loop(Set.empty + h, nel.tail, ::(h, Nil))
}
}
13 changes: 8 additions & 5 deletions core/src/main/scala/org/bykn/bosatsu/Package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,14 @@ object Package {
// here we make a pass to get all the local names
val optProg = SourceConverter
.toProgram(p, imps.map(i => i.copy(pack = i.pack.name)), stmts)
.leftMap(
_.map(
PackageError.SourceConverterErrorIn(_, p): PackageError
).toNonEmptyList
)
.leftMap { scerrs =>
scerrs.groupByNem(_.region)
.transform { (region, errs) =>
val uniqs = ListUtil.distinctByHashSet(errs.toNonEmptyList)
PackageError.SourceConverterErrorsIn(region, uniqs, p): PackageError
}
.toNonEmptyList
}

lazy val typeDefRegions: Map[Type.Const.Defined, Region] =
stmts.iterator.collect { case tds: TypeDefinitionStatement =>
Expand Down
61 changes: 50 additions & 11 deletions core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -561,25 +561,64 @@ object PackageError {
}
}

case class SourceConverterErrorIn(
err: SourceConverter.Error,
case class SourceConverterErrorsIn(
region: Region,
errs: NonEmptyList[SourceConverter.Error],
pack: PackageName
) extends PackageError {
def message(
sourceMap: Map[PackageName, (LocationMap, String)],
errColor: Colorize
) = {
val (lm, _) = sourceMap.getMapSrc(pack)
val msg = {
val context =
lm.showRegion(err.region, 2, errColor)
.getOrElse(
Doc.str(err.region)
) // we should highlight the whole region

Doc.text(err.message) + Doc.hardLine + context
val context =
lm.showRegion(region, 2, errColor)
.getOrElse(
Doc.str(region)
) // we should highlight the whole region
val headDoc = sourceMap.headLine(pack, Some(region))

val (missing, notMissing) = errs.toList.partitionMap {
case ma: SourceConverter.MissingArg => Left(ma)
case notMa => Right(notMa)
}
val mdocs = missing.groupBy { ma => (ma.name, ma.syntax) }
.toList
.sortBy { case ((name, _), _) => name }
.map { case ((_, syn), mas) =>
val allMissing = mas.map(_.missing)

val missingDoc = Doc.intercalate(Doc.comma + Doc.space,
allMissing.sorted.map { m => Doc.text(m.asString) })

val fieldStr = if (allMissing.lengthCompare(1) == 0) "field" else "fields"

val hint =
syn match {
case SourceConverter.ConstructorSyntax.Pat(_) =>
Doc.line + Doc.text("if you want to ignore those fields, add a ... to signify ignoring missing.")
case _ =>
// we can't ignore fields when constructing
Doc.empty
}
(Doc.text(s"missing $fieldStr: ") + missingDoc + Doc.line + Doc.text("in") +
Doc.line + syn.toDoc + hint
).nested(4)
}

val mdoc = Doc.intercalate(Doc.hardLine, mdocs)
val notMDoc = Doc.intercalate(Doc.hardLine, notMissing.map { se => Doc.text(se.message) })
val msg = if (missing.nonEmpty) {
if (notMissing.nonEmpty) {
mdoc + Doc.hardLine + notMDoc
}
else mdoc
}
else {
notMDoc
}
val doc = sourceMap.headLine(pack, Some(err.region)) + Doc.hardLine + msg

val doc = headDoc + Doc.hardLine + msg + Doc.hardLine + context

doc.render(80)
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ object Pattern {
}
)
prefix +
Doc.text(" {") +
Doc.text(" { ") +
kvargs +
suffix +
Doc.text(" }")
Expand Down
50 changes: 25 additions & 25 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,7 @@ x = 1
main = match x:
case Foo: 2
""")) { case te @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case te @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val msg = te.message(Map.empty, Colorize.None)
assert(!msg.contains("Name("))
assert(msg.contains("package B\nunknown constructor Foo"))
Expand All @@ -1720,7 +1720,7 @@ struct X
main = match 1:
case X1: 0
""")) { case te @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case te @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
te.message(
Map.empty,
Expand Down Expand Up @@ -2296,7 +2296,7 @@ get = Pair(first, ...) -> first
# missing second
first = 1
res = get(Pair { first })
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2311,7 +2311,7 @@ get = Pair(first, ...) -> first
first = 1
second = 3
res = get(Pair { first, second, third })
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2323,7 +2323,7 @@ struct Pair(first, second)
get = Pair { first } -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2336,7 +2336,7 @@ struct Pair(first, second)
get = Pair(first) -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2349,7 +2349,7 @@ struct Pair(first, second)
get = \Pair { first, sec: _ } -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2362,7 +2362,7 @@ struct Pair(first, second)
get = Pair { first, sec: _, ... } -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2375,7 +2375,7 @@ struct Pair(first, second)
get = Pair(first, _, _) -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}

Expand All @@ -2388,7 +2388,7 @@ struct Pair(first, second)
get = Pair(first, _, _, ...) -> first
res = get(Pair(1, "two"))
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
s.message(Map.empty, Colorize.None); ()
}
}
Expand Down Expand Up @@ -2613,7 +2613,7 @@ external def foo(x: String) -> List[String]
def foo(x): x
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
s.message(
Map.empty,
Expand All @@ -2630,7 +2630,7 @@ external def foo(x: String) -> List[String]
foo = 1
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
s.message(
Map.empty,
Expand All @@ -2646,7 +2646,7 @@ package A
external def foo(x: String) -> List[String]
external def foo(x: String) -> List[String]
""")) { case s @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case s @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
s.message(
Map.empty,
Expand Down Expand Up @@ -2686,7 +2686,7 @@ package Err
struct Foo[a](a)
main = Foo(1, "2")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand All @@ -2702,7 +2702,7 @@ package Err
struct Foo[a](a: a, b: b)
main = Foo(1, "2")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand All @@ -2718,7 +2718,7 @@ package Err
enum Enum[a]: Foo(a)
main = Foo(1, "2")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand All @@ -2734,7 +2734,7 @@ package Err
enum Enum[a]: Foo(a: a), Bar(a: b)
main = Foo(1, "2")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand Down Expand Up @@ -3004,7 +3004,7 @@ struct Foo
struct Foo(x)

main = Foo(1)
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand All @@ -3024,7 +3024,7 @@ enum Bar: Foo
struct Foo(x)

main = Foo(1)
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand Down Expand Up @@ -3132,7 +3132,7 @@ out = match (1,2):
case (a, a): a
test = Assertion(True, "")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand All @@ -3149,7 +3149,7 @@ out = match [(1,2), (1, 0)]:
case _: 0
test = Assertion(True, "")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand Down Expand Up @@ -3198,7 +3198,7 @@ struct Bar(baz: Either[Int, String])
test = Assertion(True, "")
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand Down Expand Up @@ -3287,7 +3287,7 @@ def foo[a](a: a) -> a:
def and_again[b](x: a): x
and_again(again(x))
""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) =>
""")) { case sce @ PackageError.SourceConverterErrorsIn(_, _, _) =>
assert(
sce.message(
Map.empty,
Expand Down Expand Up @@ -3870,7 +3870,7 @@ z = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
"""
evalFail(List(testCode)) {
case kie @ PackageError.SourceConverterErrorIn(_, _) =>
case kie @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(
message.contains(
Expand All @@ -3896,7 +3896,7 @@ res = z matches (1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
"""
evalFail(List(testCode1)) {
case kie @ PackageError.SourceConverterErrorIn(_, _) =>
case kie @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(
message.contains(
Expand Down
9 changes: 9 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/ListUtilTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,13 @@ class ListUtilTest extends AnyFunSuite {
}
}
}

test("distinctByHashSet works like List.distinct") {
forAll { (nel: NonEmptyList[Byte] ) =>
val asList = nel.toList.distinct
val viaFn = ListUtil.distinctByHashSet(nel).toList

assert(viaFn == asList)
}
}
}
3 changes: 2 additions & 1 deletion core/src/test/scala/org/bykn/bosatsu/ParserTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ trait ParseFns {
("...(" + s.drop(idx - 20).take(30) + ")...")
}

def firstDiff(s1: String, s2: String): String =
@annotation.tailrec
final def firstDiff(s1: String, s2: String): String =
if (s1 == s2) ""
else if (s1.isEmpty) s2
else if (s2.isEmpty) s1
Expand Down

0 comments on commit 88da254

Please sign in to comment.