Skip to content

Commit

Permalink
hopefully get tests green
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Oct 1, 2023
1 parent 2471d42 commit bfa2b37
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 37 deletions.
3 changes: 3 additions & 0 deletions cli/src/main/protobuf/bosatsu/TypedAst.proto
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ message Literal {
string stringValue = 1;
int64 intValueAs64 = 2;
string intValueAsString = 3;
int32 charValue = 4;
}
}

Expand Down Expand Up @@ -221,6 +222,8 @@ message StrPart {
WildCardPat unnamedStr = 1;
int32 namedStr = 2;
int32 literalStr = 3;
WildCardPat unnamedChar = 4;
int32 namedChar = 5;
}
}

Expand Down
12 changes: 12 additions & 0 deletions cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ object ProtoConverter {
case proto.StrPart.Value.LiteralStr(idx) => str(idx).map(Pattern.StrPart.LitStr(_))
case proto.StrPart.Value.UnnamedStr(_) => Success(Pattern.StrPart.WildStr)
case proto.StrPart.Value.NamedStr(idx) => bindable(idx).map { n => Pattern.StrPart.NamedStr(n) }
case proto.StrPart.Value.UnnamedChar(_) => Success(Pattern.StrPart.WildChar)
case proto.StrPart.Value.NamedChar(idx) => bindable(idx).map { n => Pattern.StrPart.NamedChar(n) }
}

items.toList match {
Expand Down Expand Up @@ -542,6 +544,8 @@ object ProtoConverter {
case _: ArithmeticException =>
proto.Literal.Value.IntValueAsString(i.toString)
}
case c @ Lit.Chr(_) =>
proto.Literal.Value.CharValue(c.toCodePoint)
case Lit.Str(str) =>
proto.Literal.Value.StringValue(str)
}
Expand All @@ -554,6 +558,8 @@ object ProtoConverter {
Failure(new Exception("unexpected unset Literal value in pattern"))
case proto.Literal.Value.StringValue(s) =>
Success(Lit.Str(s))
case proto.Literal.Value.CharValue(cp) =>
Success(Lit.Chr.fromCodePoint(cp))
case proto.Literal.Value.IntValueAs64(l) =>
Success(Lit(l))
case proto.Literal.Value.IntValueAsString(s) =>
Expand Down Expand Up @@ -587,10 +593,16 @@ object ProtoConverter {
parts.traverse {
case Pattern.StrPart.WildStr =>
tabPure(proto.StrPart(proto.StrPart.Value.UnnamedStr(proto.WildCardPat())))
case Pattern.StrPart.WildChar =>
tabPure(proto.StrPart(proto.StrPart.Value.UnnamedChar(proto.WildCardPat())))
case Pattern.StrPart.NamedStr(n) =>
getId(n.sourceCodeRepr).map { idx =>
proto.StrPart(proto.StrPart.Value.NamedStr(idx))
}
case Pattern.StrPart.NamedChar(n) =>
getId(n.sourceCodeRepr).map { idx =>
proto.StrPart(proto.StrPart.Value.NamedChar(idx))
}
case Pattern.StrPart.LitStr(s) =>
getId(s).map { idx =>
proto.StrPart(proto.StrPart.Value.LiteralStr(idx))
Expand Down
18 changes: 10 additions & 8 deletions core/src/main/scala/org/bykn/bosatsu/Lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,23 @@ object Lit {
case class Str(toStr: String) extends Lit {
def unboxToAny: Any = toStr
}
case class Chr(toCodePoint: Int) extends Lit {
lazy val asStr: String =
(new java.lang.StringBuilder).appendCodePoint(toCodePoint).toString

/**
 * A character literal backed by the String `asStr`.
 *
 * `toCodePoint` reads the code point at index 0 of `asStr`, so it throws
 * if `asStr` is empty.
 */
case class Chr(asStr: String) extends Lit {
  def unboxToAny: Any = asStr
  def toCodePoint: Int = asStr.codePointAt(0)
}
object Chr {
  /**
   * Build a Chr whose backing string is the UTF-16 encoding of `cp`
   */
  def fromCodePoint(cp: Int): Chr = {
    val sb = new java.lang.StringBuilder
    sb.appendCodePoint(cp)
    Chr(sb.toString)
  }
}

val EmptyStr: Str = Str("")

def fromInt(i: Int): Lit = Integer(BigInteger.valueOf(i.toLong))
/**
 * Build a character literal from a single UTF-16 code unit.
 *
 * A surrogate code unit is only half of a pair and is never a code point on
 * its own, so both halves are rejected: high surrogates (0xd800-0xdbff) and
 * low surrogates (0xdc00-0xdfff).
 *
 * @throws IllegalArgumentException if `c` is a surrogate code unit
 */
def fromChar(c: Char): Lit =
  if (c >= 0xd800 && c < 0xe000)
    throw new IllegalArgumentException(s"utf-16 character int=${c.toInt} is not a valid single codepoint")
  else Chr.fromCodePoint(c.toInt)

/**
 * Build a character literal from a code point
 */
def fromCodePoint(cp: Int): Lit = Chr.fromCodePoint(cp)

def apply(i: Long): Lit = apply(BigInteger.valueOf(i))
def apply(bi: BigInteger): Lit = Integer(bi)
Expand All @@ -58,7 +60,7 @@ object Lit {

/**
 * Parses a character literal: a single code point (or escape) written
 * between .\" and \" or between .' and '
 */
val codePointParser: P[Chr] = {
  val doubleQuoted = StringUtil.codepoint(P.string(".\""), P.char('"'))
  val singleQuoted = StringUtil.codepoint(P.string(".'"), P.char('\''))

  (doubleQuoted | singleQuoted).map { cp => Chr.fromCodePoint(cp) }
}

implicit val litOrdering: Ordering[Lit] =
Expand All @@ -68,7 +70,7 @@ object Lit {
case (Integer(a), Integer(b)) => a.compareTo(b)
case (Integer(_), Str(_) | Chr(_)) => -1
case (Chr(_), Integer(_)) => 1
case (Chr(a), Chr(b)) => java.lang.Integer.compare(a, b)
case (Chr(a), Chr(b)) => a.compareTo(b)
case (Chr(_), Str(_)) => -1
case (Str(_), Integer(_)| Chr(_)) => 1
case (Str(a), Str(b)) => a.compareTo(b)
Expand Down
10 changes: 6 additions & 4 deletions core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,12 @@ object MatchlessToValue {
def matchString(str: String, pat: List[Matchless.StrPart], binds: Int): Array[String] = {
import Matchless.StrPart._

val strLen = str.length()
val results = if (binds > 0) new Array[String](binds) else emptyStringArray

def loop(offset: Int, pat: List[Matchless.StrPart], next: Int): Boolean =
pat match {
case Nil => offset == str.length
case Nil => offset == strLen
case LitStr(expect) :: tail =>
val len = expect.length
str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next)
Expand Down Expand Up @@ -570,13 +571,14 @@ object MatchlessToValue {
// checks at all possible later offsets
// a smarter algorithm could see if there
// are Lit parts that can match or not
val checks = (offset until str.length).iterator
var matched = false
var off1 = offset
val n1 = if (h.capture) (next + 1) else next
while (!matched && checks.hasNext) {
off1 = checks.next()
while (!matched && (off1 < strLen)) {
matched = loop(off1, rest, n1)
if (!matched) {
off1 = off1 + 1
}
}

matched && {
Expand Down
64 changes: 40 additions & 24 deletions core/src/main/scala/org/bykn/bosatsu/StringUtil.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.bykn.bosatsu

import cats.parse.{Parser0 => P0, Parser => P}
import cats.parse.{Parser0 => P0, Parser => P, Accumulator, Appender}

abstract class GenericStringUtil {
protected def decodeTable: Map[Char, Char]
Expand Down Expand Up @@ -37,40 +37,56 @@ abstract class GenericStringUtil {
P.char('\\') *> after
}

val singleUtf16Codepoint: P[Char] =
P.charWhere { c =>
/**
 * Parses one code point from UTF-16 input: either a single non-surrogate
 * code unit, or a high surrogate followed by a low surrogate.
 */
val utf16Codepoint: P[Int] = {
  // see: https://en.wikipedia.org/wiki/UTF-16
  //
  // A code point never starts with a low surrogate (0xdc00-0xdfff). Refuse
  // it here: if it were passed on as a Left it would be combined with the
  // following low surrogate as though it were a high surrogate, producing
  // a value above 0x10ffff.
  val first: P[Either[Int, Int]] =
    P.charWhere { c =>
      val ci = c.toInt
      ci < 0xdc00 || ci >= 0xe000
    }
    .map { c =>
      val ci = c.toInt

      // Right: already a complete code point
      if (ci < 0xd800 || ci >= 0xe000) Right(ci)
      // Left: a high surrogate, which still needs its low surrogate
      else Left(ci)
    }

  val second: P[Int => Int] =
    P.charWhere { c =>
      val ci = c.toInt
      (0xdc00 <= ci) && (ci <= 0xdfff)
    }
    .map { low =>
      val lowOff = low - 0xdc00 + 0x10000

      { high =>
        val highPart = (high - 0xd800) * 0x400
        highPart + lowOff
      }
    }

  P.select(first)(second)
}

/**
 * Accumulates a sequence of code points into a String by appending each
 * one to a java.lang.StringBuilder
 */
val codePointAccumulator: Accumulator[Int, String] =
  new Accumulator[Int, String] {
    def newAppender(first: Int): Appender[Int, String] = {
      // one fresh builder per appender, seeded with the first code point
      val buffer = new java.lang.StringBuilder
      buffer.appendCodePoint(first)

      new Appender[Int, String] {
        def append(item: Int) = {
          buffer.appendCodePoint(item)
          this
        }
        def finish(): String = buffer.toString
      }
    }
  }
/**
 * String content without the delimiter: one or more code points, each
 * either an escape or a raw code point at a position where endP does not
 * match
 */
def undelimitedString1(endP: P[Unit]): P[String] = {
  val rawCodePoint = (!endP).with1 *> utf16Codepoint
  val oneCodePoint = escapedToken.orElse(rawCodePoint)

  oneCodePoint.repAs(codePointAccumulator)
}

/**
 * Parses startP, then exactly one code point (an escape, or a raw code
 * point at a position where endP does not match), then endP
 */
def codepoint(startP: P[Any], endP: P[Any]): P[Int] = {
  val rawCodePoint = (!endP).with1 *> utf16Codepoint
  val body = escapedToken.orElse(rawCodePoint)

  (startP *> body) <* endP
}

def escapedString(q: Char): P[String] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,9 @@ object PythonGen {
case i => Code.Apply(Code.DotSelect(i, Code.Ident("__str__")), Nil)
}
}, 1)),
(Identifier.unsafeBindable("char_to_String"),
// we encode chars as strings so this is just identity
({ input => Env.envMonad.pure(input.head) }, 1)),
(Identifier.unsafeBindable("trace"),
({
input => Env.onLast2(input.head, input.tail.head) { (msg, i) =>
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/org/bykn/bosatsu/Gen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ object Generators {
str <- lowerIdent // TODO
} yield Lit.Str(str)

val char = Gen.choose(0, 0xd7ff).map { i => Lit.Chr(i) }
val char = Gen.choose(0, 0xd7ff).map { i => Lit.Chr.fromCodePoint(i) }

val bi = Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) }
Gen.oneOf(str, bi, char)
Expand Down
20 changes: 20 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/ParserTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,26 @@ class ParserTest extends ParserTestBase {
singleq(s"'$dollar{42}bar'", List(Left(Json.JNumberStr("42")), Right("bar")))
}

test("we can decode any utf16") {
  // zero or more code points, accumulated back into a String
  val decoder = StringUtil.utf16Codepoint.repAs(StringUtil.codePointAccumulator) | P.pure("")

  def isSurrogate(cp: Int): Boolean = (0xD800 <= cp && cp <= 0xDFFF)

  // mostly code points below the surrogate range, occasionally any
  // non-surrogate code point up to 0x10ffff
  val genCodePoints: Gen[Int] =
    Gen.frequency(
      (10, Gen.choose(0, 0xd7ff)),
      (1, Gen.choose(0, 0x10ffff).filterNot(isSurrogate(_)))
    )

  forAll(Gen.listOf(genCodePoints)) { cps =>
    // encode the code points as a UTF-16 string, then decode them again
    val encoded = new java.lang.StringBuilder
    cps.foreach { cp => encoded.appendCodePoint(cp) }
    val str = encoded.toString

    val hex = cps.map(_.toHexString)
    val decoded = decoder.parseAll(str).map(_.codePoints.toArray.toList)

    assert(decoded == Right(cps),
      s"hex = $hex, str = ${str.codePoints.toArray.toList} utf16 = ${str.toCharArray().toList.map(_.toInt.toHexString)}")
  }
}

test("Identifier round trips") {
forAll(Generators.identifierGen)(law(Identifier.parser))

Expand Down

0 comments on commit bfa2b37

Please sign in to comment.