apacheGH-37726: [Swift] Update flight behavior to be similar to existing impls
abandy committed Sep 18, 2023
1 parent 440dc92 commit 24649c9
Showing 14 changed files with 5,486 additions and 184 deletions.
48 changes: 43 additions & 5 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
@@ -32,6 +32,7 @@ public class ArrowReader {
}

public class ArrowReaderResult {
fileprivate var messageSchema: org_apache_arrow_flatbuf_Schema?
public var schema: ArrowSchema?
public var batches = [RecordBatch]()
}
@@ -95,15 +96,14 @@ public class ArrowReader {
}
}

private func loadRecordBatch(_ message: org_apache_arrow_flatbuf_Message, schema: org_apache_arrow_flatbuf_Schema,
private func loadRecordBatch(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch, schema: org_apache_arrow_flatbuf_Schema,
arrowSchema: ArrowSchema, data: Data, messageEndOffset: Int64) -> Result<RecordBatch, ArrowError> {
let recordBatch = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)
let nodesCount = recordBatch?.nodesCount ?? 0
let nodesCount = recordBatch.nodesCount
var bufferIndex: Int32 = 0
var columns: [ArrowArrayHolder] = []
for nodeIndex in 0 ..< nodesCount {
let field = schema.fields(at: nodeIndex)!
let loadInfo = DataLoadInfo(recordBatch: recordBatch!, field: field,
let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
nodeIndex: nodeIndex, bufferIndex: bufferIndex,
fileData: data, messageOffset: messageEndOffset)
var result: Result<ArrowArrayHolder, ArrowError>
@@ -169,7 +169,8 @@ public class ArrowReader {
switch message.headerType {
case .recordbatch:
do {
let recordBatch = try loadRecordBatch(message, schema: footer.schema!, arrowSchema: result.schema!,
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
let recordBatch = try loadRecordBatch(rbMessage, schema: footer.schema!, arrowSchema: result.schema!,
data: fileData, messageEndOffset: messageEndOffset).get()
result.batches.append(recordBatch)
} catch let error as ArrowError {
@@ -200,4 +201,41 @@ public class ArrowReader {
return .failure(.unknownError("Error loading file: \(error)"))
}
}

static public func MakeArrowReaderResult() -> ArrowReaderResult{
return ArrowReaderResult();
}

public func fromMessage(_ dataHeader: Data, dataBody: Data, result: ArrowReaderResult) -> Result<Void, ArrowError> {
let mbb = ByteBuffer(data: dataHeader)
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: mbb)
switch message.headerType {
case .schema:
let sMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
switch loadSchema(sMessage) {
case .success(let schema):
result.schema = schema
result.messageSchema = sMessage
return .success(())
case .failure(let error):
return .failure(error)
}
case .recordbatch:
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
do {
let recordBatch = try loadRecordBatch(rbMessage, schema: result.messageSchema!, arrowSchema: result.schema!,
data: dataBody, messageEndOffset: 0).get()
result.batches.append(recordBatch)
return .success(())
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}

default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}

}
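
Note: the additions above give ArrowReader a streaming, message-at-a-time entry point. Below is a minimal usage sketch, assuming the header and body Data values arrive from a Flight stream (or from the new ArrowWriter.toMessage overloads further down), that the module is imported as Arrow, and that ArrowReader has a default initializer; the function and variable names are illustrative only.

import Foundation
import Arrow

// Hypothetical consumer: feed a schema message, then a record-batch message,
// into one shared ArrowReaderResult accumulator.
func consume(schemaHeader: Data, batchHeader: Data, batchBody: Data) {
    let reader = ArrowReader()
    let result = ArrowReader.MakeArrowReaderResult()

    // The schema message carries no body; on success the parsed schema is stored on result.
    if case .failure(let error) = reader.fromMessage(schemaHeader, dataBody: Data(), result: result) {
        print("schema error: \(error)")
        return
    }

    // A record-batch message pairs a flatbuffer header with its buffer body.
    switch reader.fromMessage(batchHeader, dataBody: batchBody, result: result) {
    case .success:
        print("decoded \(result.batches.count) batch(es)")
    case .failure(let error):
        print("batch error: \(error)")
    }
}
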
26 changes: 13 additions & 13 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
@@ -45,9 +45,9 @@ fileprivate func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArray
fileprivate func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint, buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
switch floatType.precision {
case .single:
return makeFixedHolder(Float.self, buffers: buffers)
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat)
case .double:
return makeFixedHolder(Double.self, buffers: buffers)
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble)
default:
return .failure(.unknownType("Float precision \(floatType.precision) currently not supported"))
}
@@ -94,7 +94,7 @@ fileprivate func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time, buffe

fileprivate func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<UInt8>.stride)
return .success(ArrowArrayHolder(BoolArray(arrowData)))
} catch let error as ArrowError {
@@ -104,9 +104,9 @@ fileprivate func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHo
}
}

fileprivate func makeFixedHolder<T>(_: T.Type, buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
fileprivate func makeFixedHolder<T>(_: T.Type, buffers: [ArrowBuffer], arrowType: ArrowType.Info) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<T>.stride)
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
} catch let error as ArrowError {
@@ -124,27 +124,27 @@ func makeArrayHolder(_ field: org_apache_arrow_flatbuf_Field, buffers: [ArrowBuf
let bitWidth = intType.bitWidth
if bitWidth == 8 {
if intType.isSigned {
return makeFixedHolder(Int8.self, buffers: buffers)
return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8)
} else {
return makeFixedHolder(UInt8.self, buffers: buffers)
return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8)
}
} else if bitWidth == 16 {
if intType.isSigned {
return makeFixedHolder(Int16.self, buffers: buffers)
return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16)
} else {
return makeFixedHolder(UInt16.self, buffers: buffers)
return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16)
}
} else if bitWidth == 32 {
if intType.isSigned {
return makeFixedHolder(Int32.self, buffers: buffers)
return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32)
} else {
return makeFixedHolder(UInt32.self, buffers: buffers)
return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32)
}
} else if bitWidth == 64 {
if intType.isSigned {
return makeFixedHolder(Int64.self, buffers: buffers)
return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64)
} else {
return makeFixedHolder(UInt64.self, buffers: buffers)
return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64)
}
}
return .failure(.unknownType("Int width \(bitWidth) currently not supported"))
11 changes: 11 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowType.swift
@@ -123,6 +123,17 @@ public class ArrowType {
self.info = info
}

public var id: ArrowTypeId {
switch self.info {
case .PrimitiveInfo(let id):
return id;
case .TimeInfo(let id):
return id;
case .VariableInfo(let id):
return id;
}
}

public enum Info {
case PrimitiveInfo(ArrowTypeId)
case VariableInfo(ArrowTypeId)
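
Note: a small sketch of what the new id accessor on ArrowType provides; ArrowInt32 and ArrowDouble are predefined info values already used elsewhere in this change, and the print calls are for illustration only.

import Arrow

// id unwraps the underlying ArrowTypeId from whichever Info case
// (primitive, time, or variable-length) the type was constructed with,
// so callers no longer need to switch over Info themselves.
let int32Type = ArrowType(ArrowType.ArrowInt32)
let doubleType = ArrowType(ArrowType.ArrowDouble)
print(int32Type.id)
print(doubleType.id)
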
82 changes: 60 additions & 22 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
@@ -38,7 +38,7 @@ public class ArrowWriter {
self.data.append(data)
}
}

public class FileDataWriter : DataWriter {
private var handle: FileHandle
private var current_size: Int = 0
@@ -52,7 +52,7 @@ public class ArrowWriter {
self.current_size += data.count
}
}

public class Info {
public let type: org_apache_arrow_flatbuf_MessageHeader
public let schema: ArrowSchema
@@ -62,7 +62,7 @@ public class ArrowWriter {
self.schema = schema
self.batches = batches
}

public convenience init(_ type: org_apache_arrow_flatbuf_MessageHeader, schema: ArrowSchema) {
self.init(type, schema: schema, batches: [RecordBatch]())
}
@@ -91,7 +91,7 @@ public class ArrowWriter {
return .failure(error)
}
}

private func writeSchema(_ fbb: inout FlatBufferBuilder, schema: ArrowSchema) -> Result<Offset, ArrowError> {
var fieldOffsets = [Offset]()
for field in schema.fields {
@@ -103,16 +103,16 @@ public class ArrowWriter {
}

}

let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
let schemaOffset = org_apache_arrow_flatbuf_Schema.createSchema(&fbb, endianness: .little, fieldsVectorOffset: fieldsOffset)
return .success(schemaOffset)

}

private func writeRecordBatches(_ writer: inout DataWriter, batches: [RecordBatch]) -> Result<[org_apache_arrow_flatbuf_Block], ArrowError> {
var rbBlocks = [org_apache_arrow_flatbuf_Block]()

for batch in batches {
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
@@ -129,15 +129,14 @@ public class ArrowWriter {
return .failure(error)
}
}

return .success(rbBlocks)
}

private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
let schema = batch.schema
var output = Data()
var fbb = FlatBufferBuilder()

// write out field nodes
var fieldNodeOffsets = [Offset]()
fbb.startVector(schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
@@ -148,7 +147,7 @@ public class ArrowWriter {
}

let nodeOffset = fbb.endVector(len: schema.fields.count)

// write out buffers
var buffers = [org_apache_arrow_flatbuf_Buffer]()
var bufferOffset = Int(0)
@@ -167,22 +166,23 @@ public class ArrowWriter {
for buffer in buffers.reversed() {
fbb.create(struct: buffer)
}

let batchBuffersOffset = fbb.endVector(len: buffers.count)
let startRb = org_apache_arrow_flatbuf_RecordBatch.startRecordBatch(&fbb);
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(nodes:nodeOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(buffers:batchBuffersOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.add(length: Int64(batch.length), &fbb)
let recordBatchOffset = org_apache_arrow_flatbuf_RecordBatch.endRecordBatch(&fbb, start: startRb)

let bodySize = Int64(bufferOffset);
let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
org_apache_arrow_flatbuf_Message.add(version: .v5, &fbb)
org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(bodySize), &fbb)
org_apache_arrow_flatbuf_Message.add(headerType: .recordbatch, &fbb)
org_apache_arrow_flatbuf_Message.add(header: recordBatchOffset, &fbb)
let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
fbb.finish(offset: messageOffset)
output.append(fbb.data)
return .success((output, Offset(offset: UInt32(output.count))))
return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
}

private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<Bool, ArrowError> {
@@ -194,7 +194,7 @@ public class ArrowWriter {
writer.append(bufferData)
}
}

return .success(true)
}

@@ -218,10 +218,10 @@ public class ArrowWriter {
case .failure(let error):
return .failure(error)
}

return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
switch writeSchema(&fbb, schema: info.schema) {
@@ -250,10 +250,10 @@ public class ArrowWriter {
case .failure(let error):
return .failure(error)
}

return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
@@ -263,7 +263,7 @@ public class ArrowWriter {
return .failure(error)
}
}

public func toFile(_ fileName: URL, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
do {
try Data().write(to: fileName)
@@ -276,7 +276,7 @@ public class ArrowWriter {

var markerData = FILEMARKER.data(using: .utf8)!;
addPadForAlignment(&markerData)

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
@@ -288,4 +288,42 @@ public class ArrowWriter {

return .success(true)
}

public func toMessage(_ batch: RecordBatch) -> Result<[Data], ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeRecordBatch(batch: batch) {
case .success(let message):
writer.append(message.0)
addPadForAlignment(&writer)
var dataWriter: any DataWriter = InMemDataWriter()
switch writeRecordBatchData(&dataWriter, batch: batch) {
case .success(_):
return .success([(writer as! InMemDataWriter).data, (dataWriter as! InMemDataWriter).data])
case .failure(let error):
return .failure(error)
}
case .failure(let error):
return .failure(error)
}
}

public func toMessage(_ schema: ArrowSchema) -> Result<Data, ArrowError> {
var schemaSize: Int32 = 0;
var fbb = FlatBufferBuilder()
switch writeSchema(&fbb, schema: schema) {
case .success(let schemaOffset):
schemaSize = Int32(schemaOffset.o)
case .failure(let error):
return .failure(error)
}

let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(0), &fbb)
org_apache_arrow_flatbuf_Message.add(headerType: .schema, &fbb)
org_apache_arrow_flatbuf_Message.add(header: Offset(offset: UOffset(schemaSize)), &fbb)
org_apache_arrow_flatbuf_Message.add(version: .v5, &fbb)
let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
fbb.finish(offset: messageOffset)
return .success(fbb.data)
}
}
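
Note: a rough round-trip sketch pairing the new ArrowWriter.toMessage overloads with ArrowReader.fromMessage, assuming a schema and record batch have been built elsewhere and that both the writer and reader can be created with their default initializers; names and error handling are illustrative only.

import Foundation
import Arrow

// Hypothetical round trip: encode a schema and one batch as standalone messages,
// then replay them through the streaming reader.
func roundTrip(schema: ArrowSchema, batch: RecordBatch) throws {
    let writer = ArrowWriter()
    let reader = ArrowReader()
    let result = ArrowReader.MakeArrowReaderResult()

    // toMessage(schema) returns only the flatbuffer header; the body is empty.
    let schemaHeader = try writer.toMessage(schema).get()
    _ = reader.fromMessage(schemaHeader, dataBody: Data(), result: result)

    // toMessage(batch) returns [header, body] as two Data values.
    let batchParts = try writer.toMessage(batch).get()
    _ = reader.fromMessage(batchParts[0], dataBody: batchParts[1], result: result)

    print("round-tripped \(result.batches.count) batch(es)")
}
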