Skip to content

Commit

Permalink
checkpoint with interpreter tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Oct 1, 2023
1 parent 1cc9966 commit 2471d42
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 30 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/Lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object Lit {
lazy val asStr: String =
(new java.lang.StringBuilder).appendCodePoint(toCodePoint).toString

def unboxToAny: Any = toCodePoint
def unboxToAny: Any = asStr
}

val EmptyStr: Str = Str("")
Expand Down
38 changes: 36 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,21 @@ object MatchlessToValue {
case LitStr(expect) :: tail =>
val len = expect.length
str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next)
case (c: CharPart) :: tail => ???
case (c: CharPart) :: tail =>
try {
val nextOffset = str.offsetByCodePoints(offset, 1)
val n =
if (c.capture) {
results(next) = str.substring(offset, nextOffset)
next + 1
}
else next

loop(nextOffset, tail, n)
}
catch {
case _: IndexOutOfBoundsException => false
}
case (h: Glob) :: tail =>
tail match {
case Nil =>
Expand All @@ -550,7 +564,27 @@ object MatchlessToValue {
results(next) = str.substring(offset)
}
true
case (c: CharPart) :: tail2 => ???
case rest @ ((_: CharPart) :: _) =>
// (.*)(.)tail2
// this is a naive algorithm that just
// checks at all possible later offsets
// a smarter algorithm could see if there
// are Lit parts that can match or not
val checks = (offset until str.length).iterator
var matched = false
var off1 = offset
val n1 = if (h.capture) (next + 1) else next
while (!matched && checks.hasNext) {
off1 = checks.next()
matched = loop(off1, rest, n1)
}

matched && {
if (h.capture) {
results(next) = str.substring(offset, off1)
}
true
}
case LitStr(expect) :: tail2 =>
val next1 = if (h.capture) next + 1 else next

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ object PackageError {
case InvalidStrPat(pat, _) =>
Doc.text(s"invalid string pattern: ") +
Document[Pattern.Parsed].document(pat) +
Doc.text(" (adjacent bindings aren't allowed)")
Doc.text(" (adjacent string bindings aren't allowed)")
case MultipleSplicesInPattern(_, _) =>
// TODO: get printing of compiled patterns working well
//val docp = Document[Pattern.Parsed].document(Pattern.ListPat(pat)) +
Expand Down
27 changes: 11 additions & 16 deletions core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ object Pattern {
case SeqPart.Lit(c) :: tail =>
loop(tail, c :: front)
case SeqPart.AnyElem :: tail =>
loop(tail, Nil).prepend(StrPart.WildChar)
loop(tail, Nil)
.prepend(StrPart.WildChar)
.prependList(lit(front))
case SeqPart.Wildcard :: SeqPart.AnyElem :: tail =>
// *_, _ is the same as _, *_
loop(SeqPart.AnyElem :: SeqPart.Wildcard :: tail, front)
Expand All @@ -542,6 +544,8 @@ object Pattern {
}

def toNamedSeqPattern(sp: StrPat): NamedSeqPattern[Char] = {
val empty: NamedSeqPattern[Char] = NamedSeqPattern.NEmpty

def partToNsp(s: StrPart): NamedSeqPattern[Char] =
s match {
case StrPart.NamedStr(n) =>
Expand All @@ -551,24 +555,15 @@ object Pattern {
case StrPart.WildStr => NamedSeqPattern.Wild
case StrPart.WildChar => NamedSeqPattern.Any
case StrPart.LitStr(s) =>
// reverse so we can build right associated
s.toList.reverse match {
case Nil => NamedSeqPattern.NEmpty
case h :: tail =>
tail.foldLeft(NamedSeqPattern.fromLit(h)) { (right, head) =>
NamedSeqPattern.NCat(NamedSeqPattern.fromLit(head), right)
}
if (s.isEmpty) empty
else s.toList.foldRight(empty) { (c, tail) =>
NamedSeqPattern.NCat(NamedSeqPattern.fromLit(c), tail)
}
}

def loop(sp: List[StrPart]): NamedSeqPattern[Char] =
sp match {
case Nil => NamedSeqPattern.NEmpty
case h :: t =>
NamedSeqPattern.NCat(partToNsp(h), loop(t))
}

loop(sp.parts.toList)
sp.parts.toList.foldRight(empty) { (h, t) =>
NamedSeqPattern.NCat(partToNsp(h), t)
}
}

def fromLitStr(s: String): StrPat =
Expand Down
14 changes: 13 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,17 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) {

case sp@StrPat(_) =>
val simp = sp.toSeqPattern
if (simp.normalize == simp) validUnit
def hasAdjacentWild[A](seq: SeqPattern[A]): Boolean =
seq match {
case SeqPattern.Empty => false
case SeqPattern.Cat(SeqPart.Wildcard, tail) =>
tail match {
case SeqPattern.Cat(SeqPart.Wildcard, _) => true
case notStartWild => hasAdjacentWild(notStartWild)
}
case SeqPattern.Cat(_, tail) => hasAdjacentWild(tail)
}
if (!hasAdjacentWild(simp)) validUnit
else Left(NonEmptyList(InvalidStrPat(sp, inEnv), Nil))

case PositionalStruct(name, args) =>
Expand Down Expand Up @@ -315,6 +325,8 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) {
case (WildCard, right@StrPat(_)) =>
// _ is the same as "${_}" for well typed expressions
strPatternSetOps.difference(StrPat(NonEmptyList(StrPart.WildStr, Nil)), right)
case (WildCard, Literal(Lit.Str(str))) =>
difference(WildCard, StrPat.fromLitStr(str))
case (Var(v), right@StrPat(_)) =>
// v is the same as "${v}" for well typed expressions
strPatternSetOps.difference(StrPat(NonEmptyList(StrPart.NamedStr(v), Nil)), right)
Expand Down
6 changes: 2 additions & 4 deletions core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ object SeqPattern {
}

def fromList[A](ps: List[SeqPart[A]]): SeqPattern[A] =
ps match {
case h :: tail =>
Cat(h, fromList(tail))
case Nil => Empty
ps.foldRight(Empty: SeqPattern[A]) { (h, tail) =>
Cat(h, tail)
}

val Wild: SeqPattern[Nothing] = Cat(SeqPart.Wildcard, Empty)
Expand Down
12 changes: 9 additions & 3 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2349,7 +2349,7 @@ main = match x:
""")) { case sce@PackageError.TotalityCheckError(_, _) =>
val dollar = '$'
assert(sce.message(Map.empty, Colorize.None) ==
s"in file: <unknown source>, package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent bindings aren't allowed)")
s"in file: <unknown source>, package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent string bindings aren't allowed)")
()
}
}
Expand Down Expand Up @@ -3143,7 +3143,13 @@ good2 = match "$.{just_x}":
test2 = Assertion(good2, "interpolation match")
all = TestSuite("chars", [test1, test2])
"""), "Foo", 2)
def last(str) -> Option[Char]:
match str:
case "": None
case "${_}$.{c}": Some(c)
test3 = Assertion(last("foo") matches Some(.'o'), "last test")
all = TestSuite("chars", [test1, test2, test3])
"""), "Foo", 3)
}
}
17 changes: 15 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import org.typelevel.paiges.Document

import Identifier.Constructor

import cats.implicits._

class TotalityTest extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] {
type Pat = Pattern[(PackageName, Constructor), Type]

Expand Down Expand Up @@ -405,4 +403,19 @@ enum Either: Left(l), Right(r)

check("""["${foo}$.{_}", "$.{bar}$.{_}$.{_}"]""")
}
test("string match totality") {
val tc = TotalityCheck(predefTE)

//val ps = patterns("""["", "${_}$.{_}"]""")
//val ps = patterns("""[""]""")
val ps = patterns("""["${_}$.{_}", ""]""")
val diff = tc.missingBranches(ps)
assert(diff == Nil)

/*
val ps1 = patterns("""["", "$.{_}${_}"]""")
val diff1 = tc.missingBranches(ps1)
assert(diff1 == Nil)
*/
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite {
forAll(genNamed, genSeq)(namedMatchesPatternLaw(_, _))
}

test("* - [] - [_, *] == empty") {
val diff1 = setOps.difference(Cat(Wildcard, Empty), Empty)
assert(diff1.flatMap(setOps.difference(_, Cat(AnyElem, Cat(Wildcard, Empty)))) == Nil)
assert(diff1.flatMap(setOps.difference(_, Cat(Wildcard, Cat(AnyElem, Empty)))) == Nil)
}

/*
test("if x - y is empty, (x + z) - (y + z) is empty") {
forAll { (x0: Pattern, y0: Pattern, z0: Pattern) =>
Expand Down

0 comments on commit 2471d42

Please sign in to comment.