Skip to content

Commit

Permalink
Improve TypedExpr.substitute issue 1126 (#1313)
Browse files Browse the repository at this point in the history
* Improve TypedExpr.substitute issue 1126

* add some pattern tests

* fix the law checking

* fix inlining for small lambdas

* cleanup

* fix lurking bug in python generation

* improve codegen

* minor polish
  • Loading branch information
johnynek authored Dec 12, 2024
1 parent 389f52b commit 6b10791
Show file tree
Hide file tree
Showing 9 changed files with 654 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite {
To inspect the code, change the hash, and it will print the code out
*/
testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")(
"260c81bc79b6232a3f174cb9afc04143"
"ccbf676b90cf04397c908d23f86b6434"
)
}
}
65 changes: 64 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ sealed abstract class Pattern[+N, +T] {
Nil
}

def substitute(table: Map[Bindable, Bindable]): Pattern[N, T] =
this match {
case Pattern.WildCard | Pattern.Literal(_) => this
case Pattern.Var(b) =>
table.get(b) match {
case None => this
case Some(b1) => Pattern.Var(b1)
}
case Pattern.Named(n, p) =>
val p1 = p.substitute(table)
val n2 = table.get(n) match {
case None => n
case Some(n1) => n1
}
if ((p1 eq p) && (n2 eq n)) this
else Pattern.Named(n2, p1)
case Pattern.Annotation(p, t) =>
val p1 = p.substitute(table)
if (p1 eq p) this
else Pattern.Annotation(p1, t)
case Pattern.Union(h, t) =>
Pattern.Union(h.substitute(table), t.map(_.substitute(table)))
case Pattern.PositionalStruct(n, pats) =>
Pattern.PositionalStruct(n, pats.map(_.substitute(table)))
case Pattern.ListPat(parts) =>
Pattern.ListPat(parts.map(_.substitute(table)))
case Pattern.StrPat(parts) =>
Pattern.StrPat(parts.map(_.substitute(table)))
}

/** List all the names that strictly smaller than anything that would match
* this pattern e.g. a top level var, would not be returned
*/
Expand Down Expand Up @@ -312,7 +342,24 @@ object Pattern {
extends NamedKind
}

sealed abstract class StrPart
sealed abstract class StrPart {
import StrPart._

def substitute(table: Map[Bindable, Bindable]): StrPart =
this match {
case WildStr | LitStr(_) | WildChar => this
case NamedStr(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedStr(n1)
}
case NamedChar(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedChar(n1)
}
}
}
object StrPart {
final case object WildStr extends StrPart
final case class NamedStr(name: Bindable) extends StrPart
Expand Down Expand Up @@ -354,6 +401,22 @@ object Pattern {
final case class Item[A](pat: A) extends ListPart[A] {
def map[B](fn: A => B): ListPart[B] = Item(fn(pat))
}

implicit class ListPartPat[N, T](val self: ListPart[Pattern[N, T]]) extends AnyVal {
def substitute(table: Map[Bindable, Bindable]): ListPart[Pattern[N, T]] =
self match {
case WildList => WildList
case NamedList(n) =>
table.get(n) match {
case None => self
case Some(n1) => NamedList(n1)
}
case Item(p) =>
val p1 = p.substitute(table)
if (p1 eq p) self
else Item(p1)
}
}
}

/** This will match any list without any binding
Expand Down
Loading

0 comments on commit 6b10791

Please sign in to comment.