Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Oct 12, 2017
1 parent e8e8fee commit 200cd20
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,43 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.sources.v2.reader._

/**
* A base class for data source reader holder and defines equals/hashCode methods.
* A base class for data source reader holder with customized equals/hashCode methods.
*/
trait DataSourceReaderHolder {

/**
* The full output of the data source reader, without column pruning.
*/
def fullOutput: Seq[AttributeReference]
def reader: DataSourceV2Reader

override def equals(other: Any): Boolean = other match {
case other: DataSourceV2Relation =>
val basicEquals = this.fullOutput == other.fullOutput &&
this.reader.getClass == other.reader.getClass &&
this.reader.readSchema() == other.reader.readSchema()
/**
* The held data source reader.
*/
def reader: DataSourceV2Reader

val samePushedFilters = (this.reader, other.reader) match {
case (l: SupportsPushDownCatalystFilters, r: SupportsPushDownCatalystFilters) =>
l.pushedCatalystFilters().toSeq == r.pushedCatalystFilters().toSeq
case (l: SupportsPushDownFilters, r: SupportsPushDownFilters) =>
l.pushedFilters().toSeq == r.pushedFilters().toSeq
case _ => true
}
/**
* The metadata of this data source reader that can be used for equality test.
*/
private def metadata: Seq[Any] = {
val filters: Any = reader match {
case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet
case s: SupportsPushDownFilters => s.pushedFilters().toSet
case _ => Nil
}
Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
}

basicEquals && samePushedFilters
def canEqual(other: Any): Boolean

override def equals(other: Any): Boolean = other match {
case other: DataSourceReaderHolder =>
canEqual(other) && metadata.length == other.metadata.length &&
metadata.zip(other.metadata).forall { case (l, r) => l == r }
case _ => false
}

override def hashCode(): Int = {
val state = Seq(fullOutput, reader.getClass, reader.readSchema())
val filters: Any = reader match {
case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSeq
case s: SupportsPushDownFilters => s.pushedFilters().toSeq
case _ => Nil
}
(state :+ filters).map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
}

lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ case class DataSourceV2Relation(
fullOutput: Seq[AttributeReference],
reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]

override def computeStats(): Statistics = reader match {
case r: SupportsReportStatistics =>
Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ case class DataSourceV2ScanExec(
fullOutput: Seq[AttributeReference],
@transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

override def references: AttributeSet = AttributeSet.empty

override lazy val metrics = Map(
Expand Down

0 comments on commit 200cd20

Please sign in to comment.