diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift index 7e0c69b1e79e8..9aa8a65137d28 100644 --- a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift @@ -104,6 +104,16 @@ public class ArrowDecoder: Decoder { let array: AnyArray = try self.getCol(col) return array.asAny(self.rbIndex) as? T } + + func isNull(_ key: CodingKey) throws -> Bool { + let array: AnyArray = try self.getCol(key.stringValue) + return array.asAny(self.rbIndex) == nil + } + + func isNull(_ col: Int) throws -> Bool { + let array: AnyArray = try self.getCol(col) + return array.asAny(self.rbIndex) == nil + } } private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { @@ -126,11 +136,17 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { mutating func decodeNil() throws -> Bool { defer {increment()} - return try self.decoder.doDecode(self.currentIndex) == nil + return try self.decoder.isNull(self.currentIndex) } mutating func decode(_ type: T.Type) throws -> T where T: Decodable { - if type == Int8.self || type == Int16.self || + if type == Int8?.self || type == Int16?.self || + type == Int32?.self || type == Int64?.self || + type == UInt8?.self || type == UInt16?.self || + type == UInt32?.self || type == UInt64?.self || + type == String?.self || type == Double?.self || + type == Float?.self || type == Date?.self || + type == Int8.self || type == Int16.self || type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || @@ -173,7 +189,7 @@ private struct ArrowKeyedDecoding: KeyedDecodingContainerProtoco } func decodeNil(forKey key: Key) throws -> Bool { - return try self.decoder.doDecode(key) == nil + try self.decoder.isNull(key) } func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { @@ -273,7 +289,7 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { func decodeNil() -> Bool { do { - return try self.decoder.doDecode(0) == nil + return try self.decoder.isNull(0) } catch { return false } @@ -338,7 +354,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { } func decode(_ type: T.Type) throws -> T where T: Decodable { - if type == Date.self { + if type == Int8.self || type == Int16.self || + type == Int32.self || type == Int64.self || + type == UInt8.self || type == UInt16.self || + type == UInt32.self || type == UInt64.self || + type == String.self || type == Double.self || + type == Float.self || type == Date.self { return try self.decoder.doDecode(0)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift b/swift/Arrow/Tests/ArrowTests/CodableTests.swift index e7359467ae1c5..d7d3414cf6250 100644 --- a/swift/Arrow/Tests/ArrowTests/CodableTests.swift +++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift @@ -30,7 +30,7 @@ final class CodableTests: XCTestCase { public var propUInt32: UInt32 public var propUInt64: UInt64 public var propFloat: Float - public var propDouble: Double + public var propDouble: Double? public var propString: String public var propDate: Date @@ -53,7 +53,6 @@ final class CodableTests: XCTestCase { func testArrowKeyedDecoder() throws { // swiftlint:disable:this function_body_length let date1 = Date(timeIntervalSinceReferenceDate: 86400 * 5000 + 352) - let boolBuilder = try ArrowArrayBuilders.loadBoolArrayBuilder() let int8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() let int16Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() @@ -78,7 +77,7 @@ final class CodableTests: XCTestCase { uint32Builder.append(70, 71, 72) uint64Builder.append(80, 81, 82) floatBuilder.append(90.1, 91.1, 92.1) - doubleBuilder.append(100.1, 101.1, 102.1) + doubleBuilder.append(101.1, nil, nil) stringBuilder.append("test0", "test1", "test2") dateBuilder.append(date1, date1, date1) let result = RecordBatch.Builder() @@ -102,7 +101,6 @@ final class CodableTests: XCTestCase { var testClasses = try decoder.decode(TestClass.self) for index in 0.. = try ArrowArrayBuilders.loadNumberArrayBuilder() - int8Builder.append(10, 11, 12, nil) + int8Builder.append(10, 11, 12) let result = RecordBatch.Builder() .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) .finish() @@ -134,7 +136,28 @@ final class CodableTests: XCTestCase { let testData = try decoder.decode(Int8?.self) for index in 0.. = try ArrowArrayBuilders.loadNumberArrayBuilder() + int8WNilBuilder.append(10, nil, 12, nil) + let resultWNil = RecordBatch.Builder() + .addColumn("propInt8", arrowArray: try int8WNilBuilder.toHolder()) + .finish() + switch resultWNil { + case .success(let rb): + let decoder = ArrowDecoder(rb) + let testData = try decoder.decode(Int8?.self) + for index in 0.. = try ArrowArrayBuilders.loadNumberArrayBuilder() let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() - int8Builder.append(10, 11, 12) - stringBuilder.append("test0", "test1", "test2") + int8Builder.append(10, 11, 12, 13) + stringBuilder.append("test0", "test1", "test2", "test3") let result = RecordBatch.Builder() .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) .addColumn("propString", arrowArray: try stringBuilder.toHolder()) @@ -167,4 +190,32 @@ final class CodableTests: XCTestCase { } } + func testArrowUnkeyedDecoderWithNull() throws { + let int8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() + int8Builder.append(10, 11, 12, 13) + stringWNilBuilder.append(nil, "test1", nil, "test3") + let resultWNil = RecordBatch.Builder() + .addColumn("propInt8", arrowArray: try int8Builder.toHolder()) + .addColumn("propString", arrowArray: try stringWNilBuilder.toHolder()) + .finish() + switch resultWNil { + case .success(let rb): + let decoder = ArrowDecoder(rb) + let testData = try decoder.decode([Int8: String?].self) + var index: Int8 = 0 + for data in testData { + let str = data[10 + index] + if index % 2 == 0 { + XCTAssertNil(str!) + } else { + XCTAssertEqual(str, "test\(index)") + } + index += 1 + } + case .failure(let err): + throw err + } + + } }