Skip to content

Commit

Permalink
GH-42041:Fix nullable type decoder issue
Browse files Browse the repository at this point in the history
  • Loading branch information
abandy committed Jun 9, 2024
1 parent 399408c commit f95f5c4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 16 deletions.
33 changes: 28 additions & 5 deletions swift/Arrow/Sources/Arrow/ArrowDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ public class ArrowDecoder: Decoder {
let array: AnyArray = try self.getCol(col)
return array.asAny(self.rbIndex) as? T
}

func isNull(_ key: CodingKey) throws -> Bool {
let array: AnyArray = try self.getCol(key.stringValue)
return array.asAny(self.rbIndex) == nil
}

func isNull(_ col: Int) throws -> Bool {
let array: AnyArray = try self.getCol(col)
return array.asAny(self.rbIndex) == nil
}
}

private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
Expand All @@ -126,11 +136,19 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {

mutating func decodeNil() throws -> Bool {
defer {increment()}
return try self.decoder.doDecode(self.currentIndex) == nil
return try self.decoder.isNull(self.currentIndex)
}

mutating func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if type == Int8.self || type == Int16.self ||
if type == Int8?.self || type == Int16?.self ||
type == Int32?.self || type == Int64?.self ||
type == UInt8?.self || type == UInt16?.self ||
type == UInt32?.self || type == UInt64?.self ||
type == String?.self || type == Double?.self ||
type == Float?.self || type == Date?.self {
defer {increment()}
return try self.decoder.doDecode(self.currentIndex)!
} else if type == Int8.self || type == Int16.self ||
type == Int32.self || type == Int64.self ||
type == UInt8.self || type == UInt16.self ||
type == UInt32.self || type == UInt64.self ||
Expand Down Expand Up @@ -173,7 +191,7 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
}

func decodeNil(forKey key: Key) throws -> Bool {
return try self.decoder.doDecode(key) == nil
try self.decoder.isNull(key)
}

func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool {
Expand Down Expand Up @@ -273,7 +291,7 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {

func decodeNil() -> Bool {
do {
return try self.decoder.doDecode(0) == nil
return try self.decoder.isNull(0)
} catch {
return false
}
Expand Down Expand Up @@ -338,7 +356,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
}

func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if type == Date.self {
if type == Int8.self || type == Int16.self ||
type == Int32.self || type == Int64.self ||
type == UInt8.self || type == UInt16.self ||
type == UInt32.self || type == UInt64.self ||
type == String.self || type == Double.self ||
type == Float.self || type == Date.self {
return try self.decoder.doDecode(0)!
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
Expand Down
67 changes: 56 additions & 11 deletions swift/Arrow/Tests/ArrowTests/CodableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import XCTest
@testable import Arrow

final class CodableTests: XCTestCase {
public class TestClass: Codable {
public class TestClass: Codable, NoArgInit {
public var propBool: Bool
public var propInt8: Int8
public var propInt16: Int16
Expand All @@ -30,7 +30,7 @@ final class CodableTests: XCTestCase {
public var propUInt32: UInt32
public var propUInt64: UInt64
public var propFloat: Float
public var propDouble: Double
public var propDouble: Double?
public var propString: String
public var propDate: Date

Expand All @@ -53,7 +53,6 @@ final class CodableTests: XCTestCase {

func testArrowKeyedDecoder() throws { // swiftlint:disable:this function_body_length
let date1 = Date(timeIntervalSinceReferenceDate: 86400 * 5000 + 352)

let boolBuilder = try ArrowArrayBuilders.loadBoolArrayBuilder()
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let int16Builder: NumberArrayBuilder<Int16> = try ArrowArrayBuilders.loadNumberArrayBuilder()
Expand All @@ -78,7 +77,7 @@ final class CodableTests: XCTestCase {
uint32Builder.append(70, 71, 72)
uint64Builder.append(80, 81, 82)
floatBuilder.append(90.1, 91.1, 92.1)
doubleBuilder.append(100.1, 101.1, 102.1)
doubleBuilder.append(101.1, nil, nil)
stringBuilder.append("test0", "test1", "test2")
dateBuilder.append(date1, date1, date1)
let result = RecordBatch.Builder()
Expand All @@ -102,7 +101,6 @@ final class CodableTests: XCTestCase {
var testClasses = try decoder.decode(TestClass.self)
for index in 0..<testClasses.count {
let testClass = testClasses[index]
var col = 0
XCTAssertEqual(testClass.propBool, index % 2 == 0 ? false : true)
XCTAssertEqual(testClass.propInt8, Int8(index + 10))
XCTAssertEqual(testClass.propInt16, Int16(index + 20))
Expand All @@ -113,7 +111,11 @@ final class CodableTests: XCTestCase {
XCTAssertEqual(testClass.propUInt32, UInt32(index + 70))
XCTAssertEqual(testClass.propUInt64, UInt64(index + 80))
XCTAssertEqual(testClass.propFloat, Float(index) + 90.1)
XCTAssertEqual(testClass.propDouble, Double(index) + 100.1)
if index == 0 {
XCTAssertEqual(testClass.propDouble, 101.1)
} else {
XCTAssertEqual(testClass.propDouble, nil)
}
XCTAssertEqual(testClass.propString, "test\(index)")
XCTAssertEqual(testClass.propDate, date1)
}
Expand All @@ -124,7 +126,7 @@ final class CodableTests: XCTestCase {

func testArrowSingleDecoder() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
int8Builder.append(10, 11, 12, nil)
int8Builder.append(10, 11, 12)
let result = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.finish()
Expand All @@ -134,7 +136,26 @@ final class CodableTests: XCTestCase {
let testData = try decoder.decode(Int8?.self)
for index in 0..<testData.count {
let val: Int8? = testData[index]
if val != nil {
XCTAssertEqual(val!, Int8(index + 10))
}
case .failure(let err):
throw err
}

let int8WNilBuilder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
int8WNilBuilder.append(10, nil, 12, nil)
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8WNilBuilder.toHolder())
.finish()
switch resultWNil {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode(Int8?.self)
for index in 0..<testData.count {
let val: Int8? = testData[index]
if index % 2 == 1 {
XCTAssertNil(val)
} else {
XCTAssertEqual(val!, Int8(index + 10))
}
}
Expand All @@ -146,8 +167,8 @@ final class CodableTests: XCTestCase {
func testArrowUnkeyedDecoder() throws {
let int8Builder: NumberArrayBuilder<Int8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
int8Builder.append(10, 11, 12)
stringBuilder.append("test0", "test1", "test2")
int8Builder.append(10, 11, 12, 13)
stringBuilder.append("test0", "test1", "test2", "test3")
let result = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringBuilder.toHolder())
Expand All @@ -165,6 +186,30 @@ final class CodableTests: XCTestCase {
case .failure(let err):
throw err
}
}

let stringWNilBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
stringWNilBuilder.append(nil, "test1", nil, "test3")
let resultWNil = RecordBatch.Builder()
.addColumn("propInt8", arrowArray: try int8Builder.toHolder())
.addColumn("propString", arrowArray: try stringWNilBuilder.toHolder())
.finish()
switch resultWNil {
case .success(let rb):
let decoder = ArrowDecoder(rb)
let testData = try decoder.decode([Int8: String?].self)
var index: Int8 = 0
for data in testData {
let str = data[10 + index]
if index % 2 == 0 {
XCTAssertNil(str!)
} else {
XCTAssertEqual(str, "test\(index)")
}
index += 1
}
case .failure(let err):
throw err
}

}
}

0 comments on commit f95f5c4

Please sign in to comment.