Skip to content

Commit

Permalink
GH-42041: [Swift] Fix nullable type decoder issue (#42043)
Browse files Browse the repository at this point in the history
### Rationale for this change

There is an issue when decoding nullable types.  The previous method of checking for nil values always returned false for nullable types due too the ArrowArray types being non nullable.

### What changes are included in this PR?
This PR adds a IsNull method to the ArrowDecoder to be used for null checks.  Also, a check for nullable types has been added to the Unkeyed decode method.

### Are these changes tested?
Yes, tests have been added/modified to test this fix.
* GitHub Issue: #42041

Authored-by: Alva Bandy <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
  • Loading branch information
abandy authored Jun 9, 2024
1 parent 399408c commit 7aaea3d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 16 deletions.
31 changes: 26 additions & 5 deletions swift/Arrow/Sources/Arrow/ArrowDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ public class ArrowDecoder: Decoder {
let array: AnyArray = try self.getCol(col)
return array.asAny(self.rbIndex) as? T
}

func isNull(_ key: CodingKey) throws -> Bool {
let array: AnyArray = try self.getCol(key.stringValue)
return array.asAny(self.rbIndex) == nil
}

func isNull(_ col: Int) throws -> Bool {
let array: AnyArray = try self.getCol(col)
return array.asAny(self.rbIndex) == nil
}
}

private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
Expand All @@ -126,11 +136,17 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {

mutating func decodeNil() throws -> Bool {
defer {increment()}
return try self.decoder.doDecode(self.currentIndex) == nil
return try self.decoder.isNull(self.currentIndex)
}

mutating func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if type == Int8.self || type == Int16.self ||
if type == Int8?.self || type == Int16?.self ||
type == Int32?.self || type == Int64?.self ||
type == UInt8?.self || type == UInt16?.self ||
type == UInt32?.self || type == UInt64?.self ||
type == String?.self || type == Double?.self ||
type == Float?.self || type == Date?.self ||
type == Int8.self || type == Int16.self ||
type == Int32.self || type == Int64.self ||
type == UInt8.self || type == UInt16.self ||
type == UInt32.self || type == UInt64.self ||
Expand Down Expand Up @@ -173,7 +189,7 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
}

func decodeNil(forKey key: Key) throws -> Bool {
return try self.decoder.doDecode(key) == nil
try self.decoder.isNull(key)
}

func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool {
Expand Down Expand Up @@ -273,7 +289,7 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {

func decodeNil() -> Bool {
do {
return try self.decoder.doDecode(0) == nil
return try self.decoder.isNull(0)
} catch {
return false
}
Expand Down Expand Up @@ -338,7 +354,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
}

func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if type == Date.self {
if type == Int8.self || type == Int16.self ||
type == Int32.self || type == Int64.self ||
type == UInt8.self || type == UInt16.self ||
type == UInt32.self || type == UInt64.self ||
type == String.self || type == Double.self ||
type == Float.self || type == Date.self {
return try self.decoder.doDecode(0)!
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
Expand Down
73 changes: 62 additions & 11 deletions swift/Arrow/Tests/ArrowTests/CodableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ final class CodableTests: XCTestCase {
public var propUInt32: UInt32
public var propUInt64: UInt64
public var propFloat: Float
public var propDouble: Double
public var propDouble: Double?
public var propString: String
public var propDate: Date

Expand All @@ -53,7 +53,6 @@ final class CodableTests: XCTestCase {

func testArrowKeyedDecoder() throws { // swiftlint:disable:this function_body_length
let date1 = Date(timeIntervalSinceReferenceDate: 86400 * 5000 + 352)

let boolBuilder = try ArrowArrayBuilders.loadBoolArrayBuilder()
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let int16Builder: NumberArrayBuilder<Int16> = try ArrowArrayBuilders.loadNumberArrayBuilder()
Expand All @@ -78,7 +77,7 @@ final class CodableTests: XCTestCase {
uint32Builder.append(70, 71, 72)
uint64Builder.append(80, 81, 82)
floatBuilder.append(90.1, 91.1, 92.1)
doubleBuilder.append(100.1, 101.1, 102.1)
doubleBuilder.append(101.1, nil, nil)
stringBuilder.append("test0", "test1", "test2")
dateBuilder.append(date1, date1, date1)
let result = RecordBatch.Builder()
Expand All @@ -102,7 +101,6 @@ final class CodableTests: XCTestCase {
var testClasses = try decoder.decode(TestClass.self)
for index in 0..<testClasses.count {
let testClass = testClasses[index]
var col = 0
XCTAssertEqual(testClass.propBool, index % 2 == 0 ? false : true)
XCTAssertEqual(testClass.propInt8, Int8(index + 10))
XCTAssertEqual(testClass.propInt16, Int16(index + 20))
Expand All @@ -113,7 +111,11 @@ final class CodableTests: XCTestCase {
XCTAssertEqual(testClass.propUInt32, UInt32(index + 70))
XCTAssertEqual(testClass.propUInt64, UInt64(index + 80))
XCTAssertEqual(testClass.propFloat, Float(index) + 90.1)
XCTAssertEqual(testClass.propDouble, Double(index) + 100.1)
if index == 0 {
XCTAssertEqual(testClass.propDouble, 101.1)
} else {
XCTAssertEqual(testClass.propDouble, nil)
}
XCTAssertEqual(testClass.propString, "test\(index)")
XCTAssertEqual(testClass.propDate, date1)
}
Expand All @@ -122,9 +124,9 @@ final class CodableTests: XCTestCase {
}
}

func testArrowSingleDecoder() throws {
func testArrowSingleDecoderWithoutNull() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
int8Builder.append(10, 11, 12, nil)
int8Builder.append(10, 11, 12)
let result = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.finish()
Expand All @@ -134,7 +136,28 @@ final class CodableTests: XCTestCase {
let testData = try decoder.decode(Int8?.self)
for index in 0..<testData.count {
let val: Int8? = testData[index]
if val != nil {
XCTAssertEqual(val!, Int8(index + 10))
}
case .failure(let err):
throw err
}
}

func testArrowSingleDecoderWithNull() throws {
let int8WNilBuilder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
int8WNilBuilder.append(10, nil, 12, nil)
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8WNilBuilder.toHolder())
.finish()
switch resultWNil {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode(Int8?.self)
for index in 0..<testData.count {
let val: Int8? = testData[index]
if index % 2 == 1 {
XCTAssertNil(val)
} else {
XCTAssertEqual(val!, Int8(index + 10))
}
}
Expand All @@ -143,11 +166,11 @@ final class CodableTests: XCTestCase {
}
}

func testArrowUnkeyedDecoder() throws {
func testArrowUnkeyedDecoderWithoutNull() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12)
stringBuilder.append("test0", "test1", "test2")
int8Builder.append(10, 11, 12, 13)
stringBuilder.append("test0", "test1", "test2", "test3")
let result = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringBuilder.toHolder())
Expand All @@ -167,4 +190,32 @@ final class CodableTests: XCTestCase {
}
}

func testArrowUnkeyedDecoderWithNull() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12, 13)
stringWNilBuilder.append(nil, "test1", nil, "test3")
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringWNilBuilder.toHolder())
.finish()
switch resultWNil {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([Int8: String?].self)
var index: Int8 = 0
for data in testData {
let str = data[10 + index]
if index % 2 == 0 {
XCTAssertNil(str!)
} else {
XCTAssertEqual(str, "test\(index)")
}
index += 1
}
case .failure(let err):
throw err
}

}
}

0 comments on commit 7aaea3d

Please sign in to comment.