From 21ffd82c05c93b873ae3c27128eb8604ed0c735f Mon Sep 17 00:00:00 2001 From: abandy Date: Sat, 27 Jan 2024 17:25:54 -0500 Subject: [PATCH] GH-39720: [Swift] Switch reader to use arrow field instead of proto for building arrays (#39721) This PR updates the ArrowReaderHelper to use an ArrowField object for building an Array instead of a protobuf field obj. This removes leveraging protobuf from building out the Arrays and makes the code easier to reuse (like for the C Data Interface) * Closes: #39720 Authored-by: Alva Bandy Signed-off-by: Sutou Kouhei --- .../Sources/Arrow/ArrowArrayBuilder.swift | 12 +- swift/Arrow/Sources/Arrow/ArrowData.swift | 4 +- .../Sources/Arrow/ArrowReaderHelper.swift | 159 ++++++++---------- swift/Arrow/Sources/Arrow/ArrowType.swift | 44 +++++ swift/Arrow/Sources/Arrow/ProtoUtil.swift | 72 ++++++++ swift/Arrow/Tests/ArrowTests/ArrayTests.swift | 34 ++++ 6 files changed, 221 insertions(+), 104 deletions(-) create mode 100644 swift/Arrow/Sources/Arrow/ProtoUtil.swift diff --git a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift index 32728dc7eeaa4..b78f0ccd74997 100644 --- a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift @@ -36,12 +36,12 @@ public class ArrowArrayBuilder> public func finish() throws -> ArrowArray { let buffers = self.bufferBuilder.finish() - let arrowData = try ArrowData(self.type, buffers: buffers, nullCount: self.nullCount, stride: self.getStride()) + let arrowData = try ArrowData(self.type, buffers: buffers, nullCount: self.nullCount) return U(arrowData) } public func getStride() -> Int { - MemoryLayout.stride + return self.type.getStride() } } @@ -73,20 +73,12 @@ public class Date32ArrayBuilder: ArrowArrayBuilder Int { - MemoryLayout.stride - } } public class Date64ArrayBuilder: ArrowArrayBuilder { fileprivate convenience init() throws { try self.init(ArrowType(ArrowType.ArrowDate64)) } - - public override func getStride() -> Int { - MemoryLayout.stride - } } public class Time32ArrayBuilder: ArrowArrayBuilder, Time32Array> { diff --git a/swift/Arrow/Sources/Arrow/ArrowData.swift b/swift/Arrow/Sources/Arrow/ArrowData.swift index 60281a8d24133..93986b5955bd8 100644 --- a/swift/Arrow/Sources/Arrow/ArrowData.swift +++ b/swift/Arrow/Sources/Arrow/ArrowData.swift @@ -24,7 +24,7 @@ public class ArrowData { public let length: UInt public let stride: Int - init(_ arrowType: ArrowType, buffers: [ArrowBuffer], nullCount: UInt, stride: Int) throws { + init(_ arrowType: ArrowType, buffers: [ArrowBuffer], nullCount: UInt) throws { let infoType = arrowType.info switch infoType { case let .primitiveInfo(typeId): @@ -45,7 +45,7 @@ public class ArrowData { self.buffers = buffers self.nullCount = nullCount self.length = buffers[1].length - self.stride = stride + self.stride = arrowType.getStride() } public func isNull(_ at: UInt) -> Bool { diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift index 7b3ec04b3aa36..fb4a13b766f10 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift @@ -21,8 +21,8 @@ import Foundation private func makeBinaryHolder(_ buffers: [ArrowBuffer], nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBinary), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowBinary) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(BinaryArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -34,8 +34,8 @@ private func makeBinaryHolder(_ buffers: [ArrowBuffer], private func makeStringHolder(_ buffers: [ArrowBuffer], nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowString) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(StringArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -44,33 +44,17 @@ private func makeStringHolder(_ buffers: [ArrowBuffer], } } -private func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint, - buffers: [ArrowBuffer], - nullCount: UInt -) -> Result { - switch floatType.precision { - case .single: - return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat, nullCount: nullCount) - case .double: - return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble, nullCount: nullCount) - default: - return .failure(.unknownType("Float precision \(floatType.precision) currently not supported")) - } -} - -private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date, +private func makeDateHolder(_ field: ArrowField, buffers: [ArrowBuffer], nullCount: UInt ) -> Result { do { - if dateType.unit == .day { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + if field.type.id == .date32 { + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(Date32Array(arrowData))) } - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(Date64Array(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -79,22 +63,26 @@ private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date, } } -private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time, +private func makeTimeHolder(_ field: ArrowField, buffers: [ArrowBuffer], nullCount: UInt ) -> Result { do { - if timeType.unit == .second || timeType.unit == .millisecond { - let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds - let arrowData = try ArrowData(ArrowTypeTime32(arrowUnit), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) - return .success(ArrowArrayHolder(FixedArray(arrowData))) + if field.type.id == .time32 { + if let arrowType = field.type as? ArrowTypeTime32 { + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolder(FixedArray(arrowData))) + } else { + return .failure(.invalid("Incorrect field type for time: \(field.type)")) + } } - let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds - let arrowData = try ArrowData(ArrowTypeTime64(arrowUnit), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) - return .success(ArrowArrayHolder(FixedArray(arrowData))) + if let arrowType = field.type as? ArrowTypeTime64 { + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolder(FixedArray(arrowData))) + } else { + return .failure(.invalid("Incorrect field type for time: \(field.type)")) + } } catch let error as ArrowError { return .failure(error) } catch { @@ -105,8 +93,8 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time, private func makeBoolHolder(_ buffers: [ArrowBuffer], nullCount: UInt) -> Result { do { - let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + let arrowType = ArrowType(ArrowType.ArrowBool) + let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(BoolArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -116,13 +104,11 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer], } private func makeFixedHolder( - _: T.Type, buffers: [ArrowBuffer], - arrowType: ArrowType.Info, + _: T.Type, field: ArrowField, buffers: [ArrowBuffer], nullCount: UInt ) -> Result { do { - let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers, - nullCount: nullCount, stride: MemoryLayout.stride) + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) return .success(ArrowArrayHolder(FixedArray(arrowData))) } catch let error as ArrowError { return .failure(error) @@ -131,67 +117,56 @@ private func makeFixedHolder( } } -func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity function_body_length +func makeArrayHolder( _ field: org_apache_arrow_flatbuf_Field, buffers: [ArrowBuffer], nullCount: UInt ) -> Result { - let type = field.typeType - switch type { - case .int: - let intType = field.type(type: org_apache_arrow_flatbuf_Int.self)! - let bitWidth = intType.bitWidth - if bitWidth == 8 { - if intType.isSigned { - return makeFixedHolder(Int8.self, buffers: buffers, - arrowType: ArrowType.ArrowInt8, nullCount: nullCount) - } else { - return makeFixedHolder(UInt8.self, buffers: buffers, - arrowType: ArrowType.ArrowUInt8, nullCount: nullCount) - } - } else if bitWidth == 16 { - if intType.isSigned { - return makeFixedHolder(Int16.self, buffers: buffers, - arrowType: ArrowType.ArrowInt16, nullCount: nullCount) - } else { - return makeFixedHolder(UInt16.self, buffers: buffers, - arrowType: ArrowType.ArrowUInt16, nullCount: nullCount) - } - } else if bitWidth == 32 { - if intType.isSigned { - return makeFixedHolder(Int32.self, buffers: buffers, - arrowType: ArrowType.ArrowInt32, nullCount: nullCount) - } else { - return makeFixedHolder(UInt32.self, buffers: buffers, - arrowType: ArrowType.ArrowUInt32, nullCount: nullCount) - } - } else if bitWidth == 64 { - if intType.isSigned { - return makeFixedHolder(Int64.self, buffers: buffers, - arrowType: ArrowType.ArrowInt64, nullCount: nullCount) - } else { - return makeFixedHolder(UInt64.self, buffers: buffers, - arrowType: ArrowType.ArrowUInt64, nullCount: nullCount) - } - } - return .failure(.unknownType("Int width \(bitWidth) currently not supported")) - case .bool: + let arrowField = fromProto(field: field) + return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount) +} + +func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity + _ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt +) -> Result { + let typeId = field.type.id + switch typeId { + case .int8: + return makeFixedHolder(Int8.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint8: + return makeFixedHolder(UInt8.self, field: field, buffers: buffers, nullCount: nullCount) + case .int16: + return makeFixedHolder(Int16.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint16: + return makeFixedHolder(UInt16.self, field: field, buffers: buffers, nullCount: nullCount) + case .int32: + return makeFixedHolder(Int32.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint32: + return makeFixedHolder(UInt32.self, field: field, buffers: buffers, nullCount: nullCount) + case .int64: + return makeFixedHolder(Int64.self, field: field, buffers: buffers, nullCount: nullCount) + case .uint64: + return makeFixedHolder(UInt64.self, field: field, buffers: buffers, nullCount: nullCount) + case .boolean: return makeBoolHolder(buffers, nullCount: nullCount) - case .floatingpoint: - let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)! - return makeFloatHolder(floatType, buffers: buffers, nullCount: nullCount) - case .utf8: + case .float: + return makeFixedHolder(Float.self, field: field, buffers: buffers, nullCount: nullCount) + case .double: + return makeFixedHolder(Double.self, field: field, buffers: buffers, nullCount: nullCount) + case .string: return makeStringHolder(buffers, nullCount: nullCount) case .binary: return makeBinaryHolder(buffers, nullCount: nullCount) - case .date: - let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)! - return makeDateHolder(dateType, buffers: buffers, nullCount: nullCount) - case .time: - let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)! - return makeTimeHolder(timeType, buffers: buffers, nullCount: nullCount) + case .date32: + return makeDateHolder(field, buffers: buffers, nullCount: nullCount) + case .time32: + return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) + case .time64: + return makeTimeHolder(field, buffers: buffers, nullCount: nullCount) default: - return .failure(.unknownType("Type \(type) currently not supported")) + return .failure(.unknownType("Type \(typeId) currently not supported")) } } diff --git a/swift/Arrow/Sources/Arrow/ArrowType.swift b/swift/Arrow/Sources/Arrow/ArrowType.swift index e63647d0797ee..f5a869f7cdaff 100644 --- a/swift/Arrow/Sources/Arrow/ArrowType.swift +++ b/swift/Arrow/Sources/Arrow/ArrowType.swift @@ -19,6 +19,8 @@ import Foundation public typealias Time32 = Int32 public typealias Time64 = Int64 +public typealias Date32 = Int32 +public typealias Date64 = Int64 func FlatBuffersVersion_23_1_4() { // swiftlint:disable:this identifier_name } @@ -165,6 +167,48 @@ public class ArrowType { return ArrowType.ArrowUnknown } } + + public func getStride( // swiftlint:disable:this cyclomatic_complexity + ) -> Int { + switch self.id { + case .int8: + return MemoryLayout.stride + case .int16: + return MemoryLayout.stride + case .int32: + return MemoryLayout.stride + case .int64: + return MemoryLayout.stride + case .uint8: + return MemoryLayout.stride + case .uint16: + return MemoryLayout.stride + case .uint32: + return MemoryLayout.stride + case .uint64: + return MemoryLayout.stride + case .float: + return MemoryLayout.stride + case .double: + return MemoryLayout.stride + case .boolean: + return MemoryLayout.stride + case .date32: + return MemoryLayout.stride + case .date64: + return MemoryLayout.stride + case .time32: + return MemoryLayout.stride + case .time64: + return MemoryLayout.stride + case .binary: + return MemoryLayout.stride + case .string: + return MemoryLayout.stride + default: + fatalError("Stride requested for unknown type: \(self)") + } + } } extension ArrowType.Info: Equatable { diff --git a/swift/Arrow/Sources/Arrow/ProtoUtil.swift b/swift/Arrow/Sources/Arrow/ProtoUtil.swift new file mode 100644 index 0000000000000..f7fd725fe1140 --- /dev/null +++ b/swift/Arrow/Sources/Arrow/ProtoUtil.swift @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation + +func fromProto( // swiftlint:disable:this cyclomatic_complexity + field: org_apache_arrow_flatbuf_Field +) -> ArrowField { + let type = field.typeType + var arrowType = ArrowType(ArrowType.ArrowUnknown) + switch type { + case .int: + let intType = field.type(type: org_apache_arrow_flatbuf_Int.self)! + let bitWidth = intType.bitWidth + if bitWidth == 8 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt8 : ArrowType.ArrowUInt8) + } else if bitWidth == 16 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt16 : ArrowType.ArrowUInt16) + } else if bitWidth == 32 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt32 : ArrowType.ArrowUInt32) + } else if bitWidth == 64 { + arrowType = ArrowType(intType.isSigned ? ArrowType.ArrowInt64 : ArrowType.ArrowUInt64) + } + case .bool: + arrowType = ArrowType(ArrowType.ArrowBool) + case .floatingpoint: + let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)! + if floatType.precision == .single { + arrowType = ArrowType(ArrowType.ArrowFloat) + } else if floatType.precision == .double { + arrowType = ArrowType(ArrowType.ArrowDouble) + } + case .utf8: + arrowType = ArrowType(ArrowType.ArrowString) + case .binary: + arrowType = ArrowType(ArrowType.ArrowBinary) + case .date: + let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)! + if dateType.unit == .day { + arrowType = ArrowType(ArrowType.ArrowDate32) + } else { + arrowType = ArrowType(ArrowType.ArrowDate64) + } + case .time: + let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)! + if timeType.unit == .second || timeType.unit == .millisecond { + let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds + arrowType = ArrowTypeTime32(arrowUnit) + } else { + let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds + arrowType = ArrowTypeTime64(arrowUnit) + } + default: + arrowType = ArrowType(ArrowType.ArrowUnknown) + } + + return ArrowField(field.name ?? "", type: arrowType, isNullable: field.nullable) +} diff --git a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift index 069dbfc88f3ac..f5bfa0506e62f 100644 --- a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift +++ b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift @@ -211,4 +211,38 @@ final class ArrayTests: XCTestCase { XCTAssertEqual(microArray[1], 20000) XCTAssertEqual(microArray[2], 987654321) } + + func checkHolderForType(_ checkType: ArrowType) throws { + let buffers = [ArrowBuffer(length: 0, capacity: 0, + rawPointer: UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero)), + ArrowBuffer(length: 0, capacity: 0, + rawPointer: UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero))] + let field = ArrowField("", type: checkType, isNullable: true) + switch makeArrayHolder(field, buffers: buffers, nullCount: 0) { + case .success(let holder): + XCTAssertEqual(holder.type.id, checkType.id) + case .failure(let err): + throw err + } + } + + func testArrayHolders() throws { + try checkHolderForType(ArrowType(ArrowType.ArrowInt8)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt8)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt16)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt16)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt32)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt32)) + try checkHolderForType(ArrowType(ArrowType.ArrowInt64)) + try checkHolderForType(ArrowType(ArrowType.ArrowUInt64)) + try checkHolderForType(ArrowTypeTime32(.seconds)) + try checkHolderForType(ArrowTypeTime32(.milliseconds)) + try checkHolderForType(ArrowTypeTime64(.microseconds)) + try checkHolderForType(ArrowTypeTime64(.nanoseconds)) + try checkHolderForType(ArrowType(ArrowType.ArrowBinary)) + try checkHolderForType(ArrowType(ArrowType.ArrowFloat)) + try checkHolderForType(ArrowType(ArrowType.ArrowDouble)) + try checkHolderForType(ArrowType(ArrowType.ArrowBool)) + try checkHolderForType(ArrowType(ArrowType.ArrowString)) + } }