Skip to content

Commit

Permalink
Add parsing of char patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 28, 2023
1 parent cef4c90 commit 8e8aef3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/Declaration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ object Declaration {
}

def stringDeclOrLit(inner: Indy[NonBinding]): Indy[NonBinding] = {
val start = P.string("${")
val start = P.string("${").as((a: NonBinding) => a)
val end = P.char('}')
val q1 = '\''
val q2 = '"'
Expand Down
44 changes: 36 additions & 8 deletions core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,13 @@ sealed abstract class Pattern[+N, +T] {
else inner
case Pattern.StrPat(items) =>
Pattern.StrPat(items.map {
case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => wl
case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => wl
case in@Pattern.StrPart.NamedStr(n) =>
if (keep(n)) in
else Pattern.StrPart.WildStr
case in@Pattern.StrPart.NamedChar(n) =>
if (keep(n)) in
else Pattern.StrPart.WildChar
})
case Pattern.ListPat(items) =>
Pattern.ListPat(items.map {
Expand Down Expand Up @@ -182,10 +185,13 @@ sealed abstract class Pattern[+N, +T] {
else (s1 + v, l1)
case Pattern.StrPat(items) =>
items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) {
case (res, Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => res
case (res, Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => res
case ((s1, l1), Pattern.StrPart.NamedStr(v)) =>
if (s1(v)) (s1, v :: l1)
else (s1 + v, l1)
case ((s1, l1), Pattern.StrPart.NamedChar(v)) =>
if (s1(v)) (s1, v :: l1)
else (s1 + v, l1)
}
case Pattern.ListPat(items) =>
items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) {
Expand Down Expand Up @@ -274,18 +280,24 @@ object Pattern {
object StrPart {
final case object WildStr extends StrPart
final case class NamedStr(name: Bindable) extends StrPart
final case object WildChar extends StrPart
final case class NamedChar(name: Bindable) extends StrPart
final case class LitStr(asString: String) extends StrPart

// this is to circumvent scala warnings because these bosatsu
// patterns like right.
private[this] val dollar = "$"
private[this] val wildDoc = Doc.text(s"$dollar{_}")
private[this] val wildCharDoc = Doc.text(s"${dollar}.{_}")
private[this] val prefix = Doc.text(s"$dollar{")
private[this] val prefixChar = Doc.text(s"${dollar}.{")

def document(q: Char): Document[StrPart] =
Document.instance {
case WildStr => wildDoc
case WildChar => wildCharDoc
case NamedStr(b) => prefix + Document[Bindable].document(b) + Doc.char('}')
case NamedChar(b) => prefixChar + Document[Bindable].document(b) + Doc.char('}')
case LitStr(s) => Doc.text(StringUtil.escape(q, s))
}
}
Expand Down Expand Up @@ -535,7 +547,10 @@ object Pattern {
s match {
case StrPart.NamedStr(n) =>
NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild)
case StrPart.NamedChar(n) =>
NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Any)
case StrPart.WildStr => NamedSeqPattern.Wild
case StrPart.WildChar => NamedSeqPattern.Any
case StrPart.LitStr(s) =>
// reverse so we can build right associated
s.toList.reverse match {
Expand Down Expand Up @@ -615,9 +630,15 @@ object Pattern {
(a, b) match {
case (WildStr, WildStr) => 0
case (WildStr, _) => -1
case (LitStr(_), WildStr) => 1
case (WildChar, WildStr) => 1
case (WildChar, WildChar) => 0
case (WildChar, _) => -1
case (LitStr(_), WildStr | WildChar) => 1
case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb)
case (LitStr(_), NamedStr(_)) => -1
case (LitStr(_), NamedStr(_) | NamedChar(_)) => -1
case (NamedChar(_), WildStr | WildChar | LitStr(_)) => 1
case (NamedChar(na), NamedChar(nb)) => ordBin.compare(na, nb)
case (NamedChar(_), NamedStr(_)) => -1
case (NamedStr(na), NamedStr(nb)) => ordBin.compare(na, nb)
case (NamedStr(_), _) => 1
}
Expand Down Expand Up @@ -906,14 +927,21 @@ object Pattern {
private[this] val pwild = P.char('_').as(WildCard)
private[this] val plit: P[Pattern[Nothing, Nothing]] = {
val intp = Lit.integerParser.map(Literal(_))
val start = P.string("${")
val startStr = P.string("${").as { (opt: Option[Bindable]) =>
opt.fold(StrPart.WildStr: StrPart)(StrPart.NamedStr(_))
}
val startChar = P.string("$.{").as { (opt: Option[Bindable]) =>
opt.fold(StrPart.WildChar: StrPart)(StrPart.NamedChar(_))
}
val start = startStr | startChar
val end = P.char('}')

val pwild = P.char('_').as(StrPart.WildStr)
val pname = Identifier.bindableParser.map(StrPart.NamedStr(_))
val pwild = P.char('_').as(None)
val pname = Identifier.bindableParser.map(Some(_))
val part: P[Option[Bindable]] = pwild | pname

def strp(q: Char): P[List[StrPart]] =
StringUtil.interpolatedString(q, start, pwild.orElse(pname), end)
StringUtil.interpolatedString(q, start, part, end)
.map(_.map {
case Left(p) => p
case Right((_, str)) => StrPart.LitStr(str)
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/org/bykn/bosatsu/StringUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ abstract class GenericStringUtil {
end *> undelimitedString1(end).orElse(P.pure("")) <* end
}

def interpolatedString[A](quoteChar: Char, istart: P[Unit], interp: P0[A], iend: P[Unit]): P[List[Either[A, (Region, String)]]] = {
def interpolatedString[A, B](quoteChar: Char, istart: P[A => B], interp: P0[A], iend: P[Unit]): P[List[Either[B, (Region, String)]]] = {
val strQuote = P.char(quoteChar)

val strLit: P[String] = undelimitedString1(strQuote.orElse(istart))
val notStr: P[A] = (istart ~ interp ~ iend).map { case ((_, a), _) => a }
val strLit: P[String] = undelimitedString1(strQuote.orElse(istart.void))
val notStr: P[B] = (istart ~ interp ~ iend).map { case ((fn, a), _) => fn(a) }

val either: P[Either[A, (Region, String)]] =
val either: P[Either[B, (Region, String)]] =
((P.index.with1 ~ strLit ~ P.index).map { case ((s, str), l) => Right((Region(s, l), str)) })
.orElse(notStr.map(Left(_)))

Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/org/bykn/bosatsu/ParserTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class ParserTest extends ParserTestBase {
def singleq(str1: String, res: List[Either[Json, String]]) =
parseTestAll(
StringUtil
.interpolatedString('\'', P.string("${"), Json.parser, P.char('}'))
.interpolatedString('\'', P.string("${").as((j: Json) => j), Json.parser, P.char('}'))
.map(_.map {
case Right((_, str)) => Right(str)
case Left(l) => Left(l)
Expand Down

0 comments on commit 8e8aef3

Please sign in to comment.