diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index e6b9eb40786..70cbd66b29f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -200,7 +200,11 @@ object GpuOrcScan { } // Replace values that cause overflow with nulls, same with CPU ORC. withResource(overflowFlags) { _ => - casted.copyWithBooleanColumnAsValidity(overflowFlags) + // This is an integer type so we don't have to worry about + // nested DTypes here. + withResource(Scalar.fromNull(toType)) { NULL => + overflowFlags.ifElse(casted, NULL) + } } } } @@ -334,7 +338,9 @@ object GpuOrcScan { // next convert to long, // then down cast long to the target integral type. val longDoubles = withResource(doubleCanFitInLong(col)) { fitLongs => - col.copyWithBooleanColumnAsValidity(fitLongs) + withResource(Scalar.fromNull(fromDt)) { NULL => + fitLongs.ifElse(col, NULL) + } } withResource(longDoubles) { _ => withResource(longDoubles.castTo(DType.INT64)) { longs => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 1af0f7143f1..67cfb421266 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -297,8 +297,10 @@ case class GpuArrayContains(left: Expression, right: Expression) val containsKeyOrNotContainsNull = withResource(notContainsNull) { containsResult.or(_) } - withResource(containsKeyOrNotContainsNull) { - containsResult.copyWithBooleanColumnAsValidity(_) + withResource(containsKeyOrNotContainsNull) { lcnn => + withResource(Scalar.fromNull(DType.BOOL8)) { NULL => + lcnn.ifElse(containsResult, NULL) + } } }