Skip to content

Commit

Permalink
map equality
Browse files Browse the repository at this point in the history
  • Loading branch information
pshirshov committed Aug 24, 2023
1 parent 906ecc3 commit c63e409
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,36 @@ object CSDefnTranslator {
.map(group => q"""HashCode.Combine(${group.join(", ")})""")
.toList

def mkComparator(ref: TextTree[CSValue],
oref: TextTree[CSValue],
tpe: TypeRef): TextTree[CSValue] = {
TypeId.comparator(tpe) match {
case ComparatorType.Direct =>
q"$ref == $oref"
case ComparatorType.ObjectEquals =>
q"((Object)$ref).Equals($oref)"
case ComparatorType.OptionEquals =>
q"Equals($ref, $oref)"
case ComparatorType.SeqEquals =>
q"$ref.SequenceEqual($oref)"
case ComparatorType.SetEquals =>
q"$ref.SetEquals($oref)"
case ComparatorType.MapEquals(valtpe) =>
val vref = q"$oref[key]"
val ovref = q"$ref[key]"

val cmp = mkComparator(vref, ovref, valtpe)

q"($ref.Count == $oref.Count && !$ref.Keys.Any(key => !$oref.Keys.Contains(key)) && !$ref.Keys.Any(key => $cmp))"
}
}

val comparators = outs.map(o => (o._4, o._3._1)).map {
case (f, name) =>
val ref = q"$name"
TypeId.comparator(f.tpe) match {
case ComparatorType.Direct =>
q"$ref == other.$ref"
case ComparatorType.ObjectEquals =>
q"((Object)$ref).Equals(other.$ref)"
case ComparatorType.OptionEquals =>
q"Equals($ref, other.$ref)"
case ComparatorType.SeqEquals =>
q"$ref.SequenceEqual(other.$ref)"
case ComparatorType.SetEquals =>
q"$ref.SetEquals(other.$ref)"
case ComparatorType.MapEquals =>
val oref = q"other.$ref"
q"($ref.Count == $oref.Count && !$ref.Keys.Any(key => !$oref.Keys.Contains(key)) && !$ref.Keys.Any(key => $oref[key] != $ref[key]))"
}
val oref = q"other.$ref"

mkComparator(ref, oref, f.tpe)
}

val hc = if (hcGroups.isEmpty) { q"0" } else { hcGroups.join(" ^\n") }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ object TypeId {
case object OptionEquals extends ComparatorType
case object SeqEquals extends ComparatorType
case object SetEquals extends ComparatorType
case object MapEquals extends ComparatorType
case class MapEquals(valTpe: TypeRef) extends ComparatorType
}

def comparator(ref: TypeRef): ComparatorType = {
Expand All @@ -155,7 +155,7 @@ object TypeId {
ComparatorType.SetEquals

case TypeId.Builtins.map =>
ComparatorType.MapEquals
ComparatorType.MapEquals(c.args.last)
case TypeId.Builtins.lst =>
ComparatorType.SeqEquals
case _ =>
Expand Down

0 comments on commit c63e409

Please sign in to comment.