Skip to content

Commit

Permalink
Char literal and patterns (#1052)
Browse files Browse the repository at this point in the history
* Add parsing of char patterns

* checkpoint mostly working

* checkpoint with interpreter tests working

* hopefully get tests green

* make scalajs tests pass

* fix python SelectItem generation

* avoid codepoints for scalajs

* increase coverage

* actually add LitTest file

* improve Lit.fromChar test

* Implement char matching in Python

* fix python test
  • Loading branch information
johnynek authored Oct 8, 2023
1 parent 5862794 commit 8bf939b
Show file tree
Hide file tree
Showing 28 changed files with 676 additions and 179 deletions.
3 changes: 3 additions & 0 deletions cli/src/main/protobuf/bosatsu/TypedAst.proto
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ message Literal {
string stringValue = 1;
int64 intValueAs64 = 2;
string intValueAsString = 3;
int32 charValue = 4;
}
}

Expand Down Expand Up @@ -221,6 +222,8 @@ message StrPart {
WildCardPat unnamedStr = 1;
int32 namedStr = 2;
int32 literalStr = 3;
WildCardPat unnamedChar = 4;
int32 namedChar = 5;
}
}

Expand Down
12 changes: 12 additions & 0 deletions cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ object ProtoConverter {
case proto.StrPart.Value.LiteralStr(idx) => str(idx).map(Pattern.StrPart.LitStr(_))
case proto.StrPart.Value.UnnamedStr(_) => Success(Pattern.StrPart.WildStr)
case proto.StrPart.Value.NamedStr(idx) => bindable(idx).map { n => Pattern.StrPart.NamedStr(n) }
case proto.StrPart.Value.UnnamedChar(_) => Success(Pattern.StrPart.WildChar)
case proto.StrPart.Value.NamedChar(idx) => bindable(idx).map { n => Pattern.StrPart.NamedChar(n) }
}

items.toList match {
Expand Down Expand Up @@ -542,6 +544,8 @@ object ProtoConverter {
case _: ArithmeticException =>
proto.Literal.Value.IntValueAsString(i.toString)
}
case c @ Lit.Chr(_) =>
proto.Literal.Value.CharValue(c.toCodePoint)
case Lit.Str(str) =>
proto.Literal.Value.StringValue(str)
}
Expand All @@ -554,6 +558,8 @@ object ProtoConverter {
Failure(new Exception("unexpected unset Literal value in pattern"))
case proto.Literal.Value.StringValue(s) =>
Success(Lit.Str(s))
case proto.Literal.Value.CharValue(cp) =>
Success(Lit.Chr.fromCodePoint(cp))
case proto.Literal.Value.IntValueAs64(l) =>
Success(Lit(l))
case proto.Literal.Value.IntValueAsString(s) =>
Expand Down Expand Up @@ -587,10 +593,16 @@ object ProtoConverter {
parts.traverse {
case Pattern.StrPart.WildStr =>
tabPure(proto.StrPart(proto.StrPart.Value.UnnamedStr(proto.WildCardPat())))
case Pattern.StrPart.WildChar =>
tabPure(proto.StrPart(proto.StrPart.Value.UnnamedChar(proto.WildCardPat())))
case Pattern.StrPart.NamedStr(n) =>
getId(n.sourceCodeRepr).map { idx =>
proto.StrPart(proto.StrPart.Value.NamedStr(idx))
}
case Pattern.StrPart.NamedChar(n) =>
getId(n.sourceCodeRepr).map { idx =>
proto.StrPart(proto.StrPart.Value.NamedChar(idx))
}
case Pattern.StrPart.LitStr(s) =>
getId(s).map { idx =>
proto.StrPart(proto.StrPart.Value.LiteralStr(idx))
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/resources/bosatsu/predef.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package Bosatsu/Predef

export (
Bool(),
Char,
Comparison(),
Int,
Option(),
Expand All @@ -14,6 +15,7 @@ export (
Dict,
add,
add_key,
char_to_String,
cmp_Int,
concat,
concat_String,
Expand Down Expand Up @@ -151,7 +153,9 @@ def range_fold(inclusiveLower: Int, exclusiveUpper: Int, init: a, fn: (a, Int) -
# String functions
#############
external struct String
external struct Char

external def char_to_String(c: Char) -> String
external def string_Order_fn(str0: String, str1: String) -> Comparison
string_Order = Order(string_Order_fn)
external def concat_String(items: List[String]) -> String
Expand Down
58 changes: 38 additions & 20 deletions core/src/main/scala/org/bykn/bosatsu/Declaration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,15 @@ sealed abstract class Declaration {

case StringDecl(parts) =>
val useDouble = parts.exists {
case Right((_, str)) => str.contains('\'') && !str.contains('"')
case Left(_) => false
case StringDecl.Literal(_, str) => str.contains('\'') && !str.contains('"')
case _ => false
}
val q = if (useDouble) '"' else '\''
val inner = Doc.intercalate(Doc.empty,
parts.toList.map {
case Right((_, str)) => Doc.text(StringUtil.escape(q, str))
case Left(decl) => Doc.text("${") + decl.toDoc + Doc.char('}')
case StringDecl.Literal(_, str) => Doc.text(StringUtil.escape(q, str))
case StringDecl.StrExpr(decl) => Doc.text("${") + decl.toDoc + Doc.char('}')
case StringDecl.CharExpr(decl) => Doc.text("$.{") + decl.toDoc + Doc.char('}')
})
Doc.char(q) + inner + Doc.char(q)

Expand Down Expand Up @@ -233,7 +234,8 @@ sealed abstract class Declaration {
case Var(_) => acc
case StringDecl(items) =>
items.foldLeft(acc) {
case (acc, Left(nb)) => loop(nb, bound, acc)
case (acc, StringDecl.StrExpr(nb)) => loop(nb, bound, acc)
case (acc, StringDecl.CharExpr(nb)) => loop(nb, bound, acc)
case (acc, _) => acc
}
case ListDecl(ListLang.Cons(items)) =>
Expand Down Expand Up @@ -345,8 +347,9 @@ sealed abstract class Declaration {
case Var(_) => acc
case StringDecl(nel) =>
nel.foldLeft(acc) {
case (acc0, Left(decl)) => loop(decl, acc0)
case (acc0, Right(_)) => acc0
case (acc0, StringDecl.StrExpr(decl)) => loop(decl, acc0)
case (acc0, StringDecl.CharExpr(decl)) => loop(decl, acc0)
case (acc0, _) => acc0
}
case ListDecl(ListLang.Cons(items)) =>
items.foldLeft(acc) { (acc0, sori) =>
Expand Down Expand Up @@ -523,8 +526,9 @@ object Declaration {
case StringDecl(nel) =>
nel
.traverse {
case Left(nb) => loop(nb).map(Left(_))
case right => Some(right)
case StringDecl.StrExpr(nb) => loop(nb).map(StringDecl.StrExpr(_))
case StringDecl.CharExpr(nb) => loop(nb).map(StringDecl.CharExpr(_))
case lit => Some(lit)
}
.map(StringDecl(_)(decl.region))
case ListDecl(ll) =>
Expand Down Expand Up @@ -669,8 +673,9 @@ object Declaration {
case Var(b) => Var(b)(r)
case StringDecl(nel) =>
val ne1 = nel.map {
case Right((_, s)) => Right((r, s))
case Left(e) => Left(e.replaceRegionsNB(r))
case StringDecl.Literal(_, s) => StringDecl.Literal(r, s)
case StringDecl.CharExpr(e) => StringDecl.CharExpr(e.replaceRegionsNB(r))
case StringDecl.StrExpr(e) => StringDecl.StrExpr(e.replaceRegionsNB(r))
}
StringDecl(ne1)(r)
case ListDecl(ListLang.Cons(items)) =>
Expand Down Expand Up @@ -756,7 +761,13 @@ object Declaration {
/**
* This represents interpolated strings
*/
case class StringDecl(items: NonEmptyList[Either[NonBinding, (Region, String)]])(implicit val region: Region) extends NonBinding
case class StringDecl(items: NonEmptyList[StringDecl.Part])(implicit val region: Region) extends NonBinding
object StringDecl {
sealed abstract class Part
case class Literal(region: Region, toStr: String) extends Part
case class StrExpr(nonBinding: NonBinding) extends Part
case class CharExpr(nonBinding: NonBinding) extends Part
}
/**
* This represents the list construction language
*/
Expand Down Expand Up @@ -788,13 +799,14 @@ object Declaration {
Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), Nil))
case Var(v: Bindable) => Some(Pattern.Var(v))
case Literal(lit) => Some(Pattern.Literal(lit))
case StringDecl(NonEmptyList(Right((_, s)), Nil)) =>
case StringDecl(NonEmptyList(StringDecl.Literal(_, s), Nil)) =>
Some(Pattern.Literal(Lit.Str(s)))
case StringDecl(items) =>
def toStrPart(p: Either[NonBinding, (Region, String)]): Option[Pattern.StrPart] =
def toStrPart(p: StringDecl.Part): Option[Pattern.StrPart] =
p match {
case Right((_, str)) => Some(Pattern.StrPart.LitStr(str))
case Left(Var(v: Bindable)) => Some(Pattern.StrPart.NamedStr(v))
case StringDecl.Literal(_, str) => Some(Pattern.StrPart.LitStr(str))
case StringDecl.StrExpr(Var(v: Bindable)) => Some(Pattern.StrPart.NamedStr(v))
case StringDecl.CharExpr(Var(v: Bindable)) => Some(Pattern.StrPart.NamedChar(v))
case _ => None
}
items.traverse(toStrPart).map(Pattern.StrPat(_))
Expand Down Expand Up @@ -913,7 +925,8 @@ object Declaration {
}

def stringDeclOrLit(inner: Indy[NonBinding]): Indy[NonBinding] = {
val start = P.string("${")
val start = P.string("${").as((a: NonBinding) => StringDecl.StrExpr(a)) |
P.string("$.{").as((a: NonBinding) => StringDecl.CharExpr(a))
val end = P.char('}')
val q1 = '\''
val q2 = '"'
Expand All @@ -929,7 +942,10 @@ object Declaration {
case (r, Right((_, str)) :: Nil) =>
Literal(Lit.Str(str))(r)
case (r, h :: tail) =>
StringDecl(NonEmptyList(h, tail))(r)
StringDecl(NonEmptyList(h, tail).map {
case Right((region, str)) => StringDecl.Literal(region, str)
case Left(expr) => expr
})(r)
}
}
}
Expand Down Expand Up @@ -1069,7 +1085,8 @@ object Declaration {
.region
.map { case (r, l) => DictDecl(l)(r) }

val lits: P[Literal] = Lit.integerParser.region.map { case (r, l) => Literal(l)(r) }
val lits: P[Literal] =
(Lit.integerParser | Lit.codePointParser).region.map { case (r, l) => Literal(l)(r) }

private sealed abstract class ParseMode
private object ParseMode {
Expand Down Expand Up @@ -1171,8 +1188,9 @@ object Declaration {
val slashcontinuation = ((maybeSpace ~ P.char('\\') ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void
// 0 or more args
val params0 = recNonBind.parensLines0Cut
val justDot = P.not(P.string(".\"") | P.string(".'")).with1 *> P.char('.')
val dotApply: P[NonBinding => NonBinding] =
(slashcontinuation.with1 *> P.char('.') *> (fn ~ params0))
(slashcontinuation.with1 *> justDot *> (fn ~ params0))
.region
.map { case (r2, (fn, args)) =>

Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,9 @@ object DefRecursionCheck {
}
case StringDecl(parts) =>
parts.parTraverse_ {
case Left(nb) => checkDecl(nb)
case Right(_) => unitSt
case StringDecl.CharExpr(nb) => checkDecl(nb)
case StringDecl.StrExpr(nb) => checkDecl(nb)
case StringDecl.Literal(_, _) => unitSt
}
case ListDecl(ll) =>
ll match {
Expand Down
47 changes: 44 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/Lit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ sealed abstract class Lit {
def repr: String =
this match {
case Lit.Integer(i) => i.toString
case c @ Lit.Chr(_) =>
".'" + escape('\'', c.asStr) + "'"
case Lit.Str(s) => "\"" + escape('"', s) + "\""
}

Expand All @@ -22,10 +24,35 @@ object Lit {
case class Str(toStr: String) extends Lit {
def unboxToAny: Any = toStr
}
case class Chr(asStr: String) extends Lit {
def toCodePoint: Int = asStr.codePointAt(0)
def unboxToAny: Any = asStr
}
object Chr {
private def build(cp: Int): Chr =
Chr((new java.lang.StringBuilder).appendCodePoint(cp).toString)

private[this] val cache: Array[Chr] =
(0 until 256).map(build).toArray
/**
* @throws IllegalArgumentException on a bad codepoint
*/
def fromCodePoint(cp: Int): Chr =
if ((0 <= cp) && (cp < 256)) cache(cp)
else build(cp)
}

val EmptyStr: Str = Str("")

def fromInt(i: Int): Lit = Integer(BigInteger.valueOf(i.toLong))

def fromChar(c: Char): Lit =
if (0xd800 <= c && c < 0xe000)
throw new IllegalArgumentException(s"utf-16 character int=${c.toInt} is not a valid single codepoint")
else Chr.fromCodePoint(c.toInt)

def fromCodePoint(cp: Int): Lit = Chr.fromCodePoint(cp)

def apply(i: Long): Lit = apply(BigInteger.valueOf(i))
def apply(bi: BigInteger): Lit = Integer(bi)
def apply(str: String): Lit = Str(str)
Expand All @@ -42,18 +69,26 @@ object Lit {
str(q1).orElse(str(q2))
}

val codePointParser: P[Chr] = {
(StringUtil.codepoint(P.string(".\""), P.char('"')) |
StringUtil.codepoint(P.string(".'"), P.char('\''))).map(Chr.fromCodePoint(_))
}

implicit val litOrdering: Ordering[Lit] =
new Ordering[Lit] {
def compare(a: Lit, b: Lit): Int =
(a, b) match {
case (Integer(a), Integer(b)) => a.compareTo(b)
case (Integer(_), Str(_)) => -1
case (Str(_), Integer(_)) => 1
case (Integer(_), Str(_) | Chr(_)) => -1
case (Chr(_), Integer(_)) => 1
case (Chr(a), Chr(b)) => a.compareTo(b)
case (Chr(_), Str(_)) => -1
case (Str(_), Integer(_)| Chr(_)) => 1
case (Str(a), Str(b)) => a.compareTo(b)
}
}

val parser: P[Lit] = integerParser.orElse(stringParser)
val parser: P[Lit] = integerParser | stringParser | codePointParser

implicit val document: Document[Lit] =
Document.instance[Lit] {
Expand All @@ -62,6 +97,12 @@ object Lit {
case Str(str) =>
val q = if (str.contains('\'') && !str.contains('"')) '"' else '\''
Doc.char(q) + Doc.text(escape(q, str)) + Doc.char(q)
case c @ Chr(_) =>
val str = c.asStr
val (start, end) =
if (str.contains('\'') && !str.contains('"')) (".\"", '"')
else (".'", '\'')
Doc.text(start) + Doc.text(escape(end, str)) + Doc.char(end)
}
}

6 changes: 6 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ object Matchless {
sealed abstract class StrPart
object StrPart {
sealed abstract class Glob(val capture: Boolean) extends StrPart
sealed abstract class CharPart(val capture: Boolean) extends StrPart
case object WildStr extends Glob(false)
case object IndexStr extends Glob(true)
case object WildChar extends CharPart(false)
case object IndexChar extends CharPart(true)
case class LitStr(asString: String) extends StrPart
}

Expand Down Expand Up @@ -341,13 +344,16 @@ object Matchless {
// that each name is distinct
// should be checked in the SourceConverter/TotalityChecking code
case Pattern.StrPart.NamedStr(n) => n
case Pattern.StrPart.NamedChar(n) => n
}

val muts = sbinds.traverse { b => makeAnon.map(LocalAnonMut(_)).map((b, _)) }

val pat = items.toList.map {
case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr
case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar
case Pattern.StrPart.WildStr => StrPart.WildStr
case Pattern.StrPart.WildChar => StrPart.WildChar
case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s)
}

Expand Down
Loading

0 comments on commit 8bf939b

Please sign in to comment.