Skip to content

Commit

Permalink
GH-37726: [Swift] Update flight behavior to be similar to existing impls
Browse files Browse the repository at this point in the history
  • Loading branch information
abandy committed Oct 25, 2023
1 parent 73589dd commit 4736385
Show file tree
Hide file tree
Showing 15 changed files with 5,521 additions and 176 deletions.
1 change: 1 addition & 0 deletions swift/.swiftlint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ excluded:
- Arrow/Sources/Arrow/Tensor_generated.swift
- ArrowFlight/Sources/ArrowFlight/Flight.grpc.swift
- ArrowFlight/Sources/ArrowFlight/Flight.pb.swift
- ArrowFlight/Sources/ArrowFlight/FlightSql.pb.swift
identifier_name:
min_length: 2 # only warning
allow_zero_lintable_files: false
72 changes: 61 additions & 11 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class ArrowReader {
}

/// Accumulates what an `ArrowReader` has decoded: a schema plus the record batches read so far.
public class ArrowReaderResult {
    /// Raw flatbuffer schema kept next to `schema`; `fileprivate`, so only code in this file can touch it.
    fileprivate var messageSchema: org_apache_arrow_flatbuf_Schema?
    /// Decoded Arrow schema; `nil` until something assigns it.
    public var schema: ArrowSchema?
    /// Record batches collected so far, in the order they were appended.
    public var batches: [RecordBatch] = []
}
Expand Down Expand Up @@ -95,19 +96,19 @@ public class ArrowReader {
}
}

private func loadRecordBatch(_ message: org_apache_arrow_flatbuf_Message,
schema: org_apache_arrow_flatbuf_Schema,
arrowSchema: ArrowSchema,
data: Data,
messageEndOffset: Int64
private func loadRecordBatch(
_ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
schema: org_apache_arrow_flatbuf_Schema,
arrowSchema: ArrowSchema,
data: Data,
messageEndOffset: Int64
) -> Result<RecordBatch, ArrowError> {
let recordBatch = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)
let nodesCount = recordBatch?.nodesCount ?? 0
let nodesCount = recordBatch.nodesCount
var bufferIndex: Int32 = 0
var columns: [ArrowArrayHolder] = []
for nodeIndex in 0 ..< nodesCount {
let field = schema.fields(at: nodeIndex)!
let loadInfo = DataLoadInfo(recordBatch: recordBatch!, field: field,
let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
nodeIndex: nodeIndex, bufferIndex: bufferIndex,
fileData: data, messageOffset: messageEndOffset)
var result: Result<ArrowArrayHolder, ArrowError>
Expand All @@ -130,7 +131,9 @@ public class ArrowReader {
return .success(RecordBatch(arrowSchema, columns: columns))
}

public func fromStream(_ fileData: Data) -> Result<ArrowReaderResult, ArrowError> {
public func fromStream( // swiftlint:disable:this function_body_length
_ fileData: Data
) -> Result<ArrowReaderResult, ArrowError> {
let footerLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: fileData.count - 4, as: Int32.self)
}
Expand Down Expand Up @@ -172,8 +175,13 @@ public class ArrowReader {
switch message.headerType {
case .recordbatch:
do {
let recordBatch = try loadRecordBatch(message, schema: footer.schema!, arrowSchema: result.schema!,
data: fileData, messageEndOffset: messageEndOffset).get()
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
let recordBatch = try loadRecordBatch(
rbMessage,
schema: footer.schema!,
arrowSchema: result.schema!,
data: fileData,
messageEndOffset: messageEndOffset).get()
result.batches.append(recordBatch)
} catch let error as ArrowError {
return .failure(error)
Expand Down Expand Up @@ -203,4 +211,46 @@ public class ArrowReader {
return .failure(.unknownError("Error loading file: \(error)"))
}
}

/// Creates an empty `ArrowReaderResult` (no schema, no batches).
///
/// NOTE(review): presumably exists so callers outside this module can obtain an
/// instance to pass to `fromMessage`, since `ArrowReaderResult` declares no public
/// initializer — confirm against the intended API surface.
static public func makeArrowReaderResult() -> ArrowReaderResult {
    ArrowReaderResult()
}

/// Applies a single Arrow message, delivered as separate header and body bytes, to `result`.
///
/// A `.schema` header is decoded with `loadSchema` and stored on `result` (both the
/// decoded `ArrowSchema` and the raw flatbuffer schema). A `.recordbatch` header is
/// decoded with the schema stored by an earlier call, and the resulting batch is
/// appended to `result.batches`.
///
/// - Parameters:
///   - dataHeader: Bytes parsed as the flatbuffer `Message` root.
///   - dataBody: Buffer bytes for the message; only read when the header is a record batch.
///   - result: Accumulator that carries the schema between calls and collects batches.
/// - Returns: `.success` once the message has been applied to `result`; `.failure` when
///   the header payload is absent, a record batch arrives before any schema has been
///   stored, decoding fails, or the header type is not handled.
public func fromMessage(
    _ dataHeader: Data,
    dataBody: Data,
    result: ArrowReaderResult
) -> Result<Void, ArrowError> {
    let mbb = ByteBuffer(data: dataHeader)
    let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: mbb)
    switch message.headerType {
    case .schema:
        // A malformed message can advertise a schema header without carrying one;
        // report it instead of trapping on a force-unwrap.
        guard let sMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self) else {
            return .failure(.unknownError("Schema message is missing its header"))
        }
        switch loadSchema(sMessage) {
        case .success(let schema):
            result.schema = schema
            result.messageSchema = sMessage
            return .success(())
        case .failure(let error):
            return .failure(error)
        }
    case .recordbatch:
        guard let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self) else {
            return .failure(.unknownError("Record batch message is missing its header"))
        }
        // Record batches can only be decoded against a schema stored by a previous call.
        guard let messageSchema = result.messageSchema, let arrowSchema = result.schema else {
            return .failure(.unknownError("Record batch message received before a schema message"))
        }
        switch loadRecordBatch(
            rbMessage, schema: messageSchema, arrowSchema: arrowSchema,
            data: dataBody, messageEndOffset: 0) {
        case .success(let recordBatch):
            result.batches.append(recordBatch)
            return .success(())
        case .failure(let error):
            return .failure(error)
        }
    default:
        return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
    }
}

}
29 changes: 16 additions & 13 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ private func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint
) -> Result<ArrowArrayHolder, ArrowError> {
switch floatType.precision {
case .single:
return makeFixedHolder(Float.self, buffers: buffers)
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat)
case .double:
return makeFixedHolder(Double.self, buffers: buffers)
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble)
default:
return .failure(.unknownType("Float precision \(floatType.precision) currently not supported"))
}
Expand Down Expand Up @@ -99,7 +99,7 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,

private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<UInt8>.stride)
return .success(ArrowArrayHolder(BoolArray(arrowData)))
} catch let error as ArrowError {
Expand All @@ -109,9 +109,12 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder
}
}

private func makeFixedHolder<T>(_: T.Type, buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
private func makeFixedHolder<T>(
_: T.Type, buffers: [ArrowBuffer],
arrowType: ArrowType.Info
) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<T>.stride)
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
} catch let error as ArrowError {
Expand All @@ -132,27 +135,27 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
let bitWidth = intType.bitWidth
if bitWidth == 8 {
if intType.isSigned {
return makeFixedHolder(Int8.self, buffers: buffers)
return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8)
} else {
return makeFixedHolder(UInt8.self, buffers: buffers)
return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8)
}
} else if bitWidth == 16 {
if intType.isSigned {
return makeFixedHolder(Int16.self, buffers: buffers)
return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16)
} else {
return makeFixedHolder(UInt16.self, buffers: buffers)
return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16)
}
} else if bitWidth == 32 {
if intType.isSigned {
return makeFixedHolder(Int32.self, buffers: buffers)
return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32)
} else {
return makeFixedHolder(UInt32.self, buffers: buffers)
return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32)
}
} else if bitWidth == 64 {
if intType.isSigned {
return makeFixedHolder(Int64.self, buffers: buffers)
return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64)
} else {
return makeFixedHolder(UInt64.self, buffers: buffers)
return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64)
}
}
return .failure(.unknownType("Int width \(bitWidth) currently not supported"))
Expand Down
11 changes: 11 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ public class ArrowType {
self.info = info
}

/// The `ArrowTypeId` carried by `info`, whichever `Info` case it is.
public var id: ArrowTypeId {
    switch info {
    case .primitiveInfo(let typeId), .timeInfo(let typeId), .variableInfo(let typeId):
        return typeId
    }
}

public enum Info {
case primitiveInfo(ArrowTypeId)
case variableInfo(ArrowTypeId)
Expand Down
57 changes: 51 additions & 6 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public protocol DataWriter {
func append(_ data: Data)
}

public class ArrowWriter {
public class ArrowWriter { // swiftlint:disable:this type_body_length
public class InMemDataWriter: DataWriter {
public private(set) var data: Data
public var count: Int { return data.count }
Expand Down Expand Up @@ -110,12 +110,15 @@ public class ArrowWriter {
endianness: .little,
fieldsVectorOffset: fieldsOffset)
return .success(schemaOffset)

}

private func writeRecordBatches(_ writer: inout DataWriter,
batches: [RecordBatch]
private func writeRecordBatches(
_ writer: inout DataWriter,
batches: [RecordBatch]
) -> Result<[org_apache_arrow_flatbuf_Block], ArrowError> {
var rbBlocks = [org_apache_arrow_flatbuf_Block]()

for batch in batches {
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
Expand All @@ -141,7 +144,6 @@ public class ArrowWriter {

private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
let schema = batch.schema
var output = Data()
var fbb = FlatBufferBuilder()

// write out field nodes
Expand All @@ -156,6 +158,7 @@ public class ArrowWriter {
}

let nodeOffset = fbb.endVector(len: schema.fields.count)

// write out buffers
var buffers = [org_apache_arrow_flatbuf_Buffer]()
var bufferOffset = Int(0)
Expand All @@ -179,16 +182,17 @@ public class ArrowWriter {
let startRb = org_apache_arrow_flatbuf_RecordBatch.startRecordBatch(&fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(nodes: nodeOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(buffers: batchBuffersOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.add(length: Int64(batch.length), &fbb)
let recordBatchOffset = org_apache_arrow_flatbuf_RecordBatch.endRecordBatch(&fbb, start: startRb)
let bodySize = Int64(bufferOffset)
let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(bodySize), &fbb)
org_apache_arrow_flatbuf_Message.add(headerType: .recordbatch, &fbb)
org_apache_arrow_flatbuf_Message.add(header: recordBatchOffset, &fbb)
let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
fbb.finish(offset: messageOffset)
output.append(fbb.data)
return .success((output, Offset(offset: UInt32(output.count))))
return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
}

private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<Bool, ArrowError> {
Expand Down Expand Up @@ -298,4 +302,45 @@ public class ArrowWriter {

return .success(true)
}

/// Encodes `batch` as two separate payloads: the record batch message and its body data.
///
/// - Parameter batch: Record batch to encode.
/// - Returns: A two-element array — first the bytes produced by `writeRecordBatch`
///   followed by alignment padding, then the bytes produced by `writeRecordBatchData` —
///   or the first error reported by either step.
public func toMessage(_ batch: RecordBatch) -> Result<[Data], ArrowError> {
    var headerWriter: any DataWriter = InMemDataWriter()
    switch writeRecordBatch(batch: batch) {
    case .failure(let error):
        return .failure(error)
    case .success(let encoded):
        headerWriter.append(encoded.0)
        addPadForAlignment(&headerWriter)
    }

    var bodyWriter: any DataWriter = InMemDataWriter()
    if case .failure(let error) = writeRecordBatchData(&bodyWriter, batch: batch) {
        return .failure(error)
    }

    // Both writers were created above as InMemDataWriter, hence the forced casts.
    let headerBytes = (headerWriter as! InMemDataWriter).data // swiftlint:disable:this force_cast
    let bodyBytes = (bodyWriter as! InMemDataWriter).data // swiftlint:disable:this force_cast
    return .success([headerBytes, bodyBytes])
}

/// Encodes `schema` as a standalone flatbuffer `Message` whose header type is `.schema`.
///
/// - Parameter schema: Schema to encode.
/// - Returns: The finished flatbuffer bytes, or the error reported by `writeSchema`.
public func toMessage(_ schema: ArrowSchema) -> Result<Data, ArrowError> {
    var fbb = FlatBufferBuilder()
    let schemaSize: Int32
    switch writeSchema(&fbb, schema: schema) {
    case .failure(let error):
        return .failure(error)
    case .success(let schemaOffset):
        schemaSize = Int32(schemaOffset.o)
    }
    let headerOffset = Offset(offset: UOffset(schemaSize))

    // A schema message has no body, so bodyLength is zero.
    let messageStart = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
    org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(0), &fbb)
    org_apache_arrow_flatbuf_Message.add(headerType: .schema, &fbb)
    org_apache_arrow_flatbuf_Message.add(header: headerOffset, &fbb)
    org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
    let messageEnd = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: messageStart)
    fbb.finish(offset: messageEnd)
    return .success(fbb.data)
}
}
14 changes: 8 additions & 6 deletions swift/Arrow/Tests/ArrowTests/TableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ final class TableTests: XCTestCase {
}

func testTable() throws {
let uint8Builder: NumberArrayBuilder<UInt8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
uint8Builder.append(10)
uint8Builder.append(22)
let doubleBuilder: NumberArrayBuilder<Double> = try ArrowArrayBuilders.loadNumberArrayBuilder()
doubleBuilder.append(11.11)
doubleBuilder.append(22.22)
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
stringBuilder.append("test10")
stringBuilder.append("test22")
Expand All @@ -46,14 +46,14 @@ final class TableTests: XCTestCase {
date32Builder.append(date1)
date32Builder.append(date2)
let table = try ArrowTable.Builder()
.addColumn("col1", arrowArray: uint8Builder.finish())
.addColumn("col1", arrowArray: doubleBuilder.finish())
.addColumn("col2", arrowArray: stringBuilder.finish())
.addColumn("col3", arrowArray: date32Builder.finish())
.finish()
let schema = table.schema
XCTAssertEqual(schema.fields.count, 3)
XCTAssertEqual(schema.fields[0].name, "col1")
XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8)
XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowDouble)
XCTAssertEqual(schema.fields[0].isNullable, false)
XCTAssertEqual(schema.fields[1].name, "col2")
XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString)
Expand All @@ -62,12 +62,14 @@ final class TableTests: XCTestCase {
XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString)
XCTAssertEqual(schema.fields[1].isNullable, false)
XCTAssertEqual(table.columns.count, 3)
let col1: ChunkedArray<UInt8> = table.columns[0].data()
let col1: ChunkedArray<Double> = table.columns[0].data()
let col2: ChunkedArray<String> = table.columns[1].data()
let col3: ChunkedArray<Date> = table.columns[2].data()
XCTAssertEqual(col1.length, 2)
XCTAssertEqual(col2.length, 2)
XCTAssertEqual(col3.length, 2)
XCTAssertEqual(col1[0], 11.11)
XCTAssertEqual(col2[1], "test22")
}

func testTableWithChunkedData() throws {
Expand Down
Loading

0 comments on commit 4736385

Please sign in to comment.