[NSE-1108] Allow using different cases in column names (#1098)
* Allow using different cases in SQL

Signed-off-by: Yuan Zhou <[email protected]>

* fix window

Signed-off-by: Yuan Zhou <[email protected]>

* Revert "fix window"

This reverts commit 32dfc62.

* new fix for window

Signed-off-by: Yuan Zhou <[email protected]>

* fix join

Note: this fix relies on a lower-case table schema; a better solution would be switching to BoundReference:
https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala#L69
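
For reference, a minimal sketch of what the BoundReference approach would buy, using stock Spark Catalyst APIs (illustrative only, not part of this patch): binding replaces the name-based attribute with an ordinal, so the case used in the SQL text stops mattering downstream.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    // Hypothetical schema: the table declares "ID", the query says "id".
    val attrs = Seq(
      AttributeReference("ID", LongType, nullable = false)(),
      AttributeReference("Name", StringType)())

    // Name-based lookup is case-sensitive and misses the column.
    val byName = attrs.find(_.name == "id") // None

    // Binding resolves the attribute to its position instead; the result is
    // BoundReference(0, LongType, false), which carries no name at all.
    val bound = BindReferences.bindReference(attrs.head: Expression, attrs)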
zhouyuan authored Sep 14, 2022
1 parent e58a4b3 commit f834def
Showing 5 changed files with 22 additions and 30 deletions.
@@ -254,13 +254,15 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
       f.children
         .flatMap {
           case a: AttributeReference =>
+            val attr = ConverterUtils.getAttrFromExpr(a)
             Some(TreeBuilder.makeField(
-              Field.nullable(a.name,
-                CodeGeneration.getResultType(a.dataType))))
+              Field.nullable(attr.name,
+                CodeGeneration.getResultType(attr.dataType))))
           case c: Cast if c.child.isInstanceOf[AttributeReference] =>
+            val attr = ConverterUtils.getAttrFromExpr(c)
             Some(TreeBuilder.makeField(
-              Field.nullable(c.child.asInstanceOf[AttributeReference].name,
-                CodeGeneration.getResultType(c.dataType))))
+              Field.nullable(attr.name,
+                CodeGeneration.getResultType(attr.dataType))))
           case _: Cast | _ : Literal =>
             None
           case _ =>
@@ -271,9 +273,9 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
     // TODO(yuan): using ConverterUtils.getAttrFromExpr
     val groupingExpressions: Seq[AttributeReference] = partitionSpec.map{
       case a: AttributeReference =>
-        a
+        ConverterUtils.getAttrFromExpr(a)
       case c: Cast if c.child.isInstanceOf[AttributeReference] =>
-        c.child.asInstanceOf[AttributeReference]
+        ConverterUtils.getAttrFromExpr(c)
       case _: Cast | _ : Literal =>
         null
       case n: KnownFloatingPointNormalized =>
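
The window fix matters at the Arrow boundary because Arrow Field names are plain strings: two fields that differ only in case are distinct to the native side. A hedged, standalone example with stock Arrow APIs (not this project's code):

    import org.apache.arrow.vector.types.pojo.{ArrowType, Field}

    val upper = Field.nullable("ID", new ArrowType.Int(32, true))
    val lower = Field.nullable("id", new ArrowType.Int(32, true))
    assert(upper.getName != lower.getName) // a name-based native lookup misses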
@@ -57,15 +57,14 @@ object ColumnarConditionedProbeJoin extends Logging {
       buildInputAttributes: Seq[Attribute],
       builder_type: Int = 1,
       is_broadcast: Boolean = false): TreeNode = {
-    val buildInputFieldList: List[Field] = buildInputAttributes.toList.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+    val buildInputAttrList: List[Attribute] = buildInputAttributes.toList.map(attr => {
+      attr.withName(attr.name.toLowerCase)
     })
     val buildKeysFunctionList: List[TreeNode] = buildKeys.toList.map(expr => {
       val (nativeNode, returnType) = if (!is_broadcast) {
         ConverterUtils.getColumnarFuncNode(expr)
       } else {
-        ConverterUtils.getColumnarFuncNode(expr, buildInputAttributes)
+        ConverterUtils.getColumnarFuncNode(expr, buildInputAttrList)
       }
       if (s"${nativeNode.toProtobuf}".contains("none#")) {
         throw new UnsupportedOperationException(
@@ -105,12 +104,10 @@ object ColumnarConditionedProbeJoin extends Logging {
       builder_type: Int = 0,
       isNullAwareAntiJoin: Boolean = false): TreeNode = {
     val buildInputFieldList: List[Field] = buildInputAttributes.toList.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+      ConverterUtils.createArrowField(attr)
     })
     val streamInputFieldList: List[Field] = streamInputAttributes.toList.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+      ConverterUtils.createArrowField(attr)
     })

     val buildKeysFunctionList: List[TreeNode] = buildKeys.toList.map(expr => {
@@ -132,9 +129,7 @@ object ColumnarConditionedProbeJoin extends Logging {
     })

     val resultFunctionList: List[TreeNode] = output.toList.map(field => {
-      val field_node = Field.nullable(
-        s"${field.name}#${field.exprId.id}",
-        CodeGeneration.getResultType(field.dataType))
+      val field_node = ConverterUtils.createArrowField(field)
       TreeBuilder.makeField(field_node)
     })
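
The build-side change above swaps the eagerly-built Field list for lower-cased attributes. A toy sketch of the failure it avoids (hypothetical names and ids; the real mangling is the "name#exprId" pattern visible in the removed lines):

    case class Attr(name: String, exprId: Long)
    def mangled(a: Attr): String = s"${a.name}#${a.exprId}"

    val schemaSide = Attr("Order_ID", 12L) // spelling from the table schema
    val querySide  = Attr("order_id", 12L) // spelling from the SQL text

    // Case-sensitive native lookup: "Order_ID#12" never matches "order_id#12".
    assert(mangled(schemaSide) != mangled(querySide))
    // Lower-casing both sides restores the match.
    assert(mangled(schemaSide.copy(name = schemaSide.name.toLowerCase)) ==
      mangled(querySide))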
@@ -48,7 +48,7 @@ object ColumnarExpressionConverter extends Logging {
         BindReferences.bindReference(expr, attributeSeq, true)
       if (bindReference == expr) {
         if (expIdx == -1) {
-          return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(
+          return new ColumnarAttributeReference(a.name.toLowerCase, a.dataType, a.nullable, a.metadata)(
             a.exprId,
             a.qualifier)
         } else {
@@ -58,7 +58,7 @@ object ColumnarExpressionConverter extends Logging {
         val b = bindReference.asInstanceOf[BoundReference]
         new ColumnarBoundReference(b.ordinal, b.dataType, b.nullable)
       } else {
-        return new ColumnarAttributeReference(a.name, a.dataType, a.nullable, a.metadata)(
+        return new ColumnarAttributeReference(a.name.toLowerCase, a.dataType, a.nullable, a.metadata)(
           a.exprId,
           a.qualifier)
       }
@@ -455,8 +455,7 @@ class ColumnarHashAggregation(
     }

     val originalInputFieldList = originalInputAttributes.toList.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+      ConverterUtils.createArrowField(attr)
     })

     //////////////// Project original input to aggregateExpression input //////////////////
@@ -501,7 +500,7 @@ class ColumnarHashAggregation(
     val inputAttrs = originalInputAttributes.zipWithIndex
       .filter {
         case (attr, i) =>
-          !groupingAttributes.contains(attr) && !partialProjectOrdinalList.toList.contains(i)
+          !groupingAttributes.contains(attr.withName(attr.name.toLowerCase)) && !partialProjectOrdinalList.toList.contains(i)
       }
       .map(_._1)
     inputAttrQueue = scala.collection.mutable.Queue(inputAttrs: _*)
@@ -516,10 +515,7 @@ class ColumnarHashAggregation(

     val aggregateAttributeFieldList =
       allAggregateResultAttributes.map(attr => {
-        Field
-          .nullable(
-            s"${attr.name}#${attr.exprId.id}",
-            CodeGeneration.getResultType(attr.dataType))
+        ConverterUtils.createArrowField(attr)
       })

     val nativeFuncNodes = groupingNativeFuncNodes ::: aggrNativeFuncNodes
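
The contains fix above is needed because AttributeReference equality includes the name string. A self-contained illustration with a hypothetical column (assumes stock Catalyst behavior):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val asWritten = AttributeReference("ID", IntegerType)() // from the query text
    val grouping  = asWritten.withName("id")                // normalized grouping attr

    assert(asWritten != grouping) // same exprId, but the names differ
    assert(asWritten.withName(asWritten.name.toLowerCase) == grouping)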
@@ -325,7 +325,7 @@ object ConverterUtils extends Logging {
       case a: AggregateExpression =>
         getAttrFromExpr(a.aggregateFunction.children(0))
       case a: AttributeReference =>
-        a
+        a.withName(a.name.toLowerCase)
       case a: Alias =>
         if (skipAlias) {
           if (a.child.isInstanceOf[AttributeReference] || a.child.isInstanceOf[Coalesce]) {
@@ -483,8 +483,7 @@ object ConverterUtils extends Logging {

   def toArrowSchema(attributes: Seq[Attribute]): Schema = {
     val fields = attributes.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
+      createArrowField(attr)
     })
     new Schema(fields.toList.asJava)
   }
@@ -630,7 +629,7 @@ object ConverterUtils extends Logging {
   }

   def createArrowField(attr: Attribute): Field =
-    createArrowField(s"${attr.name}#${attr.exprId.id}", attr.dataType)
+    createArrowField(s"${attr.name.toLowerCase}#${attr.exprId.id}", attr.dataType)

   private def asTimestampType(inType: ArrowType): ArrowType.Timestamp = {
     if (inType.getTypeID != ArrowTypeID.Timestamp) {
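
Taken together, getAttrFromExpr and createArrowField now funnel every attribute through the same lower-casing, so the mangled Arrow field name is stable however the SQL spells the column. A rough standalone sketch (illustrative, not the project's code):

    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
    import org.apache.spark.sql.types.StringType

    // Mirror of the normalized mangling used by createArrowField.
    def normalizedArrowName(attr: Attribute): String =
      s"${attr.name.toLowerCase}#${attr.exprId.id}"

    val a = AttributeReference("UserName", StringType)()
    val b = a.withName("USERNAME") // same column, different spelling
    assert(normalizedArrowName(a) == normalizedArrowName(b))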
