diff --git a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala index 8c81bf11c..995025ca6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala @@ -913,7 +913,7 @@ object Declaration { } def stringDeclOrLit(inner: Indy[NonBinding]): Indy[NonBinding] = { - val start = P.string("${") + val start = P.string("${").as((a: NonBinding) => a) val end = P.char('}') val q1 = '\'' val q2 = '"' diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index e79c11175..1adac62f8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -144,10 +144,13 @@ sealed abstract class Pattern[+N, +T] { else inner case Pattern.StrPat(items) => Pattern.StrPat(items.map { - case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => wl + case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => wl case in@Pattern.StrPart.NamedStr(n) => if (keep(n)) in else Pattern.StrPart.WildStr + case in@Pattern.StrPart.NamedChar(n) => + if (keep(n)) in + else Pattern.StrPart.WildChar }) case Pattern.ListPat(items) => Pattern.ListPat(items.map { @@ -182,10 +185,13 @@ sealed abstract class Pattern[+N, +T] { else (s1 + v, l1) case Pattern.StrPat(items) => items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) { - case (res, Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => res + case (res, Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => res case ((s1, l1), Pattern.StrPart.NamedStr(v)) => if (s1(v)) (s1, v :: l1) else (s1 + v, l1) + case ((s1, l1), Pattern.StrPart.NamedChar(v)) => + if (s1(v)) (s1, v :: l1) + else (s1 + v, l1) } case Pattern.ListPat(items) => items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) { @@ -274,18 +280,24 @@ object Pattern { object StrPart { final case object WildStr extends StrPart final case class NamedStr(name: Bindable) extends StrPart + final case object WildChar extends StrPart + final case class NamedChar(name: Bindable) extends StrPart final case class LitStr(asString: String) extends StrPart // this is to circumvent scala warnings because these bosatsu // patterns like right. private[this] val dollar = "$" private[this] val wildDoc = Doc.text(s"$dollar{_}") + private[this] val wildCharDoc = Doc.text(s"${dollar}.{_}") private[this] val prefix = Doc.text(s"$dollar{") + private[this] val prefixChar = Doc.text(s"${dollar}.{") def document(q: Char): Document[StrPart] = Document.instance { case WildStr => wildDoc + case WildChar => wildCharDoc case NamedStr(b) => prefix + Document[Bindable].document(b) + Doc.char('}') + case NamedChar(b) => prefixChar + Document[Bindable].document(b) + Doc.char('}') case LitStr(s) => Doc.text(StringUtil.escape(q, s)) } } @@ -535,7 +547,10 @@ object Pattern { s match { case StrPart.NamedStr(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) + case StrPart.NamedChar(n) => + NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Any) case StrPart.WildStr => NamedSeqPattern.Wild + case StrPart.WildChar => NamedSeqPattern.Any case StrPart.LitStr(s) => // reverse so we can build right associated s.toList.reverse match { @@ -615,9 +630,15 @@ object Pattern { (a, b) match { case (WildStr, WildStr) => 0 case (WildStr, _) => -1 - case (LitStr(_), WildStr) => 1 + case (WildChar, WildStr) => 1 + case (WildChar, WildChar) => 0 + case (WildChar, _) => -1 + case (LitStr(_), WildStr | WildChar) => 1 case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb) - case (LitStr(_), NamedStr(_)) => -1 + case (LitStr(_), NamedStr(_) | NamedChar(_)) => -1 + case (NamedChar(_), WildStr | WildChar | LitStr(_)) => 1 + case (NamedChar(na), NamedChar(nb)) => ordBin.compare(na, nb) + case (NamedChar(_), NamedStr(_)) => -1 case (NamedStr(na), NamedStr(nb)) => ordBin.compare(na, nb) case (NamedStr(_), _) => 1 } @@ -906,14 +927,21 @@ object Pattern { private[this] val pwild = P.char('_').as(WildCard) private[this] val plit: P[Pattern[Nothing, Nothing]] = { val intp = Lit.integerParser.map(Literal(_)) - val start = P.string("${") + val startStr = P.string("${").as { (opt: Option[Bindable]) => + opt.fold(StrPart.WildStr: StrPart)(StrPart.NamedStr(_)) + } + val startChar = P.string("$.{").as { (opt: Option[Bindable]) => + opt.fold(StrPart.WildChar: StrPart)(StrPart.NamedChar(_)) + } + val start = startStr | startChar val end = P.char('}') - val pwild = P.char('_').as(StrPart.WildStr) - val pname = Identifier.bindableParser.map(StrPart.NamedStr(_)) + val pwild = P.char('_').as(None) + val pname = Identifier.bindableParser.map(Some(_)) + val part: P[Option[Bindable]] = pwild | pname def strp(q: Char): P[List[StrPart]] = - StringUtil.interpolatedString(q, start, pwild.orElse(pname), end) + StringUtil.interpolatedString(q, start, part, end) .map(_.map { case Left(p) => p case Right((_, str)) => StrPart.LitStr(str) diff --git a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala index 04386f529..e0886a1e3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala @@ -49,13 +49,13 @@ abstract class GenericStringUtil { end *> undelimitedString1(end).orElse(P.pure("")) <* end } - def interpolatedString[A](quoteChar: Char, istart: P[Unit], interp: P0[A], iend: P[Unit]): P[List[Either[A, (Region, String)]]] = { + def interpolatedString[A, B](quoteChar: Char, istart: P[A => B], interp: P0[A], iend: P[Unit]): P[List[Either[B, (Region, String)]]] = { val strQuote = P.char(quoteChar) - val strLit: P[String] = undelimitedString1(strQuote.orElse(istart)) - val notStr: P[A] = (istart ~ interp ~ iend).map { case ((_, a), _) => a } + val strLit: P[String] = undelimitedString1(strQuote.orElse(istart.void)) + val notStr: P[B] = (istart ~ interp ~ iend).map { case ((fn, a), _) => fn(a) } - val either: P[Either[A, (Region, String)]] = + val either: P[Either[B, (Region, String)]] = ((P.index.with1 ~ strLit ~ P.index).map { case ((s, str), l) => Right((Region(s, l), str)) }) .orElse(notStr.map(Left(_))) diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index af4872339..ab3913574 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -213,7 +213,7 @@ class ParserTest extends ParserTestBase { def singleq(str1: String, res: List[Either[Json, String]]) = parseTestAll( StringUtil - .interpolatedString('\'', P.string("${"), Json.parser, P.char('}')) + .interpolatedString('\'', P.string("${").as((j: Json) => j), Json.parser, P.char('}')) .map(_.map { case Right((_, str)) => Right(str) case Left(l) => Left(l)