Skip to content

Commit

Permalink
Improve TypedExpr.substitute issue 1126
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Dec 11, 2024
1 parent 389f52b commit a8377a3
Show file tree
Hide file tree
Showing 4 changed files with 450 additions and 93 deletions.
65 changes: 64 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Pattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ sealed abstract class Pattern[+N, +T] {
Nil
}

def substitute(table: Map[Bindable, Bindable]): Pattern[N, T] =
this match {
case Pattern.WildCard | Pattern.Literal(_) => this
case Pattern.Var(b) =>
table.get(b) match {
case None => this
case Some(b1) => Pattern.Var(b1)
}
case Pattern.Named(n, p) =>
val p1 = p.substitute(table)
val n2 = table.get(n) match {
case None => n
case Some(n1) => n1
}
if ((p1 eq p) && (n2 eq n)) this
else Pattern.Named(n2, p1)
case Pattern.Annotation(p, t) =>
val p1 = p.substitute(table)
if (p1 eq p) this
else Pattern.Annotation(p1, t)
case Pattern.Union(h, t) =>
Pattern.Union(h.substitute(table), t.map(_.substitute(table)))
case Pattern.PositionalStruct(n, pats) =>
Pattern.PositionalStruct(n, pats.map(_.substitute(table)))
case Pattern.ListPat(parts) =>
Pattern.ListPat(parts.map(_.substitute(table)))
case Pattern.StrPat(parts) =>
Pattern.StrPat(parts.map(_.substitute(table)))
}

/** List all the names that strictly smaller than anything that would match
* this pattern e.g. a top level var, would not be returned
*/
Expand Down Expand Up @@ -312,7 +342,24 @@ object Pattern {
extends NamedKind
}

sealed abstract class StrPart
sealed abstract class StrPart {
import StrPart._

def substitute(table: Map[Bindable, Bindable]): StrPart =
this match {
case WildStr | LitStr(_) | WildChar => this
case NamedStr(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedStr(n1)
}
case NamedChar(n) =>
table.get(n) match {
case None => this
case Some(n1) => NamedChar(n1)
}
}
}
object StrPart {
final case object WildStr extends StrPart
final case class NamedStr(name: Bindable) extends StrPart
Expand Down Expand Up @@ -354,6 +401,22 @@ object Pattern {
final case class Item[A](pat: A) extends ListPart[A] {
def map[B](fn: A => B): ListPart[B] = Item(fn(pat))
}

implicit class ListPartPat[N, T](val self: ListPart[Pattern[N, T]]) extends AnyVal {
def substitute(table: Map[Bindable, Bindable]): ListPart[Pattern[N, T]] =
self match {
case WildList => WildList
case NamedList(n) =>
table.get(n) match {
case None => self
case Some(n1) => NamedList(n1)
}
case Item(p) =>
val p1 = p.substitute(table)
if (p1 eq p) self
else Item(p1)
}
}
}

/** This will match any list without any binding
Expand Down
Loading

0 comments on commit a8377a3

Please sign in to comment.