Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TypedExpr.substitute issue 1126 #1313

Merged
merged 8 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite {
To inspect the code, change the hash, and it will print the code out
*/
testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")(
"260c81bc79b6232a3f174cb9afc04143"
"01b16c11c1597e46371d356111276af5"
)
}
}
65 changes: 64 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ sealed abstract class Pattern[+N, +T] {
Nil
}

def substitute(table: Map[Bindable, Bindable]): Pattern[N, T] =
this match {
case Pattern.WildCard | Pattern.Literal(_) => this
case Pattern.Var(b) =>
table.get(b) match {
case None => this
case Some(b1) => Pattern.Var(b1)
}
case Pattern.Named(n, p) =>
val p1 = p.substitute(table)
val n2 = table.get(n) match {
case None => n
case Some(n1) => n1
}
if ((p1 eq p) && (n2 eq n)) this
else Pattern.Named(n2, p1)
case Pattern.Annotation(p, t) =>
val p1 = p.substitute(table)
if (p1 eq p) this
else Pattern.Annotation(p1, t)
case Pattern.Union(h, t) =>
Pattern.Union(h.substitute(table), t.map(_.substitute(table)))
case Pattern.PositionalStruct(n, pats) =>
Pattern.PositionalStruct(n, pats.map(_.substitute(table)))
case Pattern.ListPat(parts) =>
Pattern.ListPat(parts.map(_.substitute(table)))
case Pattern.StrPat(parts) =>
Pattern.StrPat(parts.map(_.substitute(table)))
}

/** List all the names that strictly smaller than anything that would match
* this pattern e.g. a top level var, would not be returned
*/
Expand Down Expand Up @@ -312,7 +342,24 @@ object Pattern {
extends NamedKind
}

sealed abstract class StrPart
sealed abstract class StrPart {
import StrPart._

def substitute(table: Map[Bindable, Bindable]): StrPart =
this match {
case WildStr | LitStr(_) | WildChar => this
case NamedStr(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedStr(n1)
}
case NamedChar(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedChar(n1)
}
}
}
object StrPart {
final case object WildStr extends StrPart
final case class NamedStr(name: Bindable) extends StrPart
Expand Down Expand Up @@ -354,6 +401,22 @@ object Pattern {
final case class Item[A](pat: A) extends ListPart[A] {
def map[B](fn: A => B): ListPart[B] = Item(fn(pat))
}

implicit class ListPartPat[N, T](val self: ListPart[Pattern[N, T]]) extends AnyVal {
def substitute(table: Map[Bindable, Bindable]): ListPart[Pattern[N, T]] =
self match {
case WildList => WildList
case NamedList(n) =>
table.get(n) match {
case None => self
case Some(n1) => NamedList(n1)
}
case Item(p) =>
val p1 = p.substitute(table)
if (p1 eq p) self
else Item(p1)
}
}
}

/** This will match any list without any binding
Expand Down
Loading
Loading