feat(bigquery): expose Apache Arrow data through ArrowIterator (#8506)
We have planned work to support Arrow data fetching on other query APIs, so we need an interface that can serve all of those query paths and also act as a base for other Arrow projects like ADBC. This PR detaches the Storage API from the Arrow decoder and introduces a new ArrowIterator interface. The interface is implemented here by the Storage iterator and can later be implemented for other query interfaces that support Arrow.
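For context, a minimal sketch of how the new surface is meant to be consumed end to end (mirroring the integration test added below; the project ID and query are placeholders, and error handling is abbreviated):

package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/bigquery"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

func main() {
	ctx := context.Background()
	client, err := bigquery.NewClient(ctx, "your-project-id") // placeholder project ID
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	q := client.Query(`SELECT name, number, state FROM ` +
		"`bigquery-public-data.usa_names.usa_1910_current`" +
		` WHERE state = "CA"`)
	it, err := q.Read(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// ArrowIterator is only available when the Storage Read API was used
	// for this query; otherwise an error is returned.
	arrowIt, err := it.ArrowIterator()
	if err != nil {
		log.Fatal(err)
	}

	// NewArrowIteratorReader adapts the iterator into an io.Reader that
	// yields a standard Arrow IPC stream (serialized schema, then record
	// batches), so the stock Arrow IPC reader can decode it.
	r, err := ipc.NewReader(bigquery.NewArrowIteratorReader(arrowIt))
	if err != nil {
		log.Fatal(err)
	}
	defer r.Release()
	for r.Next() {
		rec := r.Record()
		fmt.Printf("read batch with %d rows\n", rec.NumRows())
	}
}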

Resolves #8100
alvarowolfx authored and bhshkh committed Nov 3, 2023
1 parent b5f205c commit 8e36f5b
Showing 6 changed files with 241 additions and 71 deletions.
105 changes: 85 additions & 20 deletions bigquery/arrow.go
@@ -19,49 +19,114 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"math/big"

"cloud.google.com/go/civil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/memory"
"google.golang.org/api/iterator"
)

-type arrowDecoder struct {
-	tableSchema    Schema
-	rawArrowSchema []byte
-	arrowSchema    *arrow.Schema
-}
+// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID
+type ArrowRecordBatch struct {
+	reader io.Reader
+	// Serialized Arrow Record Batch.
+	Data []byte
+	// Serialized Arrow Schema.
+	Schema []byte
+	// Source partition ID. In the Storage API world, it represents the ReadStream.
+	PartitionID string
+}

+// Read makes ArrowRecordBatch implement io.Reader.
+func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
+	if r.reader == nil {
+		buf := bytes.NewBuffer(r.Schema)
+		buf.Write(r.Data)
+		r.reader = buf
+	}
+	return r.reader.Read(p)
+}

+// ArrowIterator represents a way to iterate through a stream of arrow records.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+type ArrowIterator interface {
+	Next() (*ArrowRecordBatch, error)
+	Schema() Schema
+	SerializedArrowSchema() []byte
+}

-func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
-	bqSession := session.bqSession
-	if bqSession == nil {
-		return nil, errors.New("read session not initialized")
-	}
-	arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
+// NewArrowIteratorReader allows an ArrowIterator to be consumed as an io.Reader.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+func NewArrowIteratorReader(it ArrowIterator) io.Reader {
+	return &arrowIteratorReader{
+		it: it,
+	}
+}

+type arrowIteratorReader struct {
+	buf *bytes.Buffer
+	it  ArrowIterator
+}

+// Read makes ArrowIteratorReader implement io.Reader
+func (r *arrowIteratorReader) Read(p []byte) (int, error) {
+	if r.it == nil {
+		return -1, errors.New("bigquery: nil ArrowIterator")
+	}
+	if r.buf == nil { // init with schema
+		buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
+		r.buf = buf
+	}
+	n, err := r.buf.Read(p)
+	if err == io.EOF {
+		batch, err := r.it.Next()
+		if err == iterator.Done {
+			return 0, io.EOF
+		}
+		r.buf.Write(batch.Data)
+		return r.Read(p)
+	}
+	return n, err
+}

+type arrowDecoder struct {
+	allocator   memory.Allocator
+	tableSchema Schema
+	arrowSchema *arrow.Schema
+}

func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
	buf := bytes.NewBuffer(arrowSerializedSchema)
	r, err := ipc.NewReader(buf)
	if err != nil {
		return nil, err
	}
	defer r.Release()
	p := &arrowDecoder{
-		tableSchema:    schema,
-		rawArrowSchema: arrowSerializedSchema,
-		arrowSchema:    r.Schema(),
+		tableSchema: schema,
+		arrowSchema: r.Schema(),
+		allocator:   memory.DefaultAllocator,
	}
	return p, nil
}

-func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
-	buf := bytes.NewBuffer(ap.rawArrowSchema)
-	buf.Write(serializedArrowRecordBatch)
-	return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
-}
+func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
+	return ipc.NewReader(
+		arrowRecordBatch,
+		ipc.WithSchema(ap.arrowSchema),
+		ipc.WithAllocator(ap.allocator),
+	)
+}

// decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
-func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
	if err != nil {
		return nil, err
	}
@@ -79,8 +144,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
}

// decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
-func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
	if err != nil {
		return nil, err
	}
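The interface above is deliberately small. As a hedged illustration of how another query path might satisfy it (all names below are invented for the example; only the Storage API implementation exists in this PR), an in-memory implementation could look like:

// sliceArrowIterator is a hypothetical in-memory ArrowIterator, shown
// only to illustrate the contract a new query path would implement.
type sliceArrowIterator struct {
	batches   []*ArrowRecordBatch
	schema    Schema
	rawSchema []byte // serialized Arrow schema in IPC format
}

func (it *sliceArrowIterator) Next() (*ArrowRecordBatch, error) {
	if len(it.batches) == 0 {
		// Like the Storage implementation, signal exhaustion with iterator.Done.
		return nil, iterator.Done
	}
	b := it.batches[0]
	it.batches = it.batches[1:]
	return b, nil
}

func (it *sliceArrowIterator) Schema() Schema { return it.schema }

func (it *sliceArrowIterator) SerializedArrowSchema() []byte { return it.rawSchema }

Any such implementation can be handed to NewArrowIteratorReader, which frames the serialized schema followed by each batch's Data as a single Arrow IPC stream.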
3 changes: 2 additions & 1 deletion bigquery/iterator.go
@@ -44,7 +44,8 @@ type RowIterator struct {
	ctx context.Context
	src *rowSource

-	arrowIterator *arrowIterator
+	arrowIterator ArrowIterator
+	arrowDecoder  *arrowDecoder

	pageInfo *iterator.PageInfo
	nextFunc func() error
2 changes: 1 addition & 1 deletion bigquery/storage_bench_test.go
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
		}
	}
	b.ReportMetric(float64(it.TotalRows), "rows")
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
	b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
	b.ReportMetric(float64(maxStreamCount), "max_streams")
}
93 changes: 88 additions & 5 deletions bigquery/storage_integration_test.go
@@ -22,6 +22,11 @@ import (
"time"

"cloud.google.com/go/internal/testutil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/math"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/google/go-cmp/cmp"
"google.golang.org/api/iterator"
)
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
		}
		total++ // as we read the first value separately

-		bqSession := it.arrowIterator.session.bqSession
+		session := it.arrowIterator.(*storageArrowIterator).session
+		bqSession := session.bqSession
		if len(bqSession.Streams) == 0 {
			t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
		}
-		streamSettings := it.arrowIterator.session.settings.maxStreamCount
+		streamSettings := session.settings.maxStreamCount
		if tc.maxExpectedStreams > 0 {
			if streamSettings > tc.maxExpectedStreams {
				t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
		total++
	}

-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
	if len(bqSession.Streams) == 0 {
		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
	}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
	}
	total++ // as we read the first value separately

-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
	if len(bqSession.Streams) == 0 {
		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
	}
@@ -418,11 +424,88 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
	}
	// resources are cleaned asynchronously
	time.Sleep(time.Second)
-	if !it.arrowIterator.isDone() {
+	arrowIt := it.arrowIterator.(*storageArrowIterator)
+	if !arrowIt.isDone() {
		t.Fatal("expected stream to be done")
	}
}

+func TestIntegration_StorageReadArrow(t *testing.T) {
+	if client == nil {
+		t.Skip("Integration tests skipped")
+	}
+	ctx := context.Background()
+	table := "`bigquery-public-data.usa_names.usa_1910_current`"
+	sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)
+
+	q := storageOptimizedClient.Query(sql)
+	job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
+	if err != nil {
+		t.Fatal(err)
+	}
+	it, err := job.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	checkedAllocator := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	it.arrowDecoder.allocator = checkedAllocator
+	defer checkedAllocator.AssertSize(t, 0)
+
+	arrowIt, err := it.ArrowIterator()
+	if err != nil {
+		t.Fatalf("expected iterator to be accelerated: %v", err)
+	}
+	arrowItReader := NewArrowIteratorReader(arrowIt)
+
+	records := []arrow.Record{}
+	r, err := ipc.NewReader(arrowItReader, ipc.WithAllocator(checkedAllocator))
+	if err != nil {
+		t.Fatal(err)
+	}
+	numrec := 0
+	for r.Next() {
+		rec := r.Record()
+		rec.Retain()
+		defer rec.Release()
+		records = append(records, rec)
+		numrec += int(rec.NumRows())
+	}
+	r.Release()
+
+	arrowSchema := r.Schema()
+	arrowTable := array.NewTableFromRecords(arrowSchema, records)
+	defer arrowTable.Release()
+	if arrowTable.NumRows() != int64(it.TotalRows) {
+		t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
+	}
+	if arrowTable.NumCols() != 3 {
+		t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
+	}
+
+	sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
+	sumQuery := client.Query(sumSQL)
+	sumIt, err := sumQuery.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sumValues := []Value{}
+	err = sumIt.Next(&sumValues)
+	if err != nil {
+		t.Fatal(err)
+	}
+	totalFromSQL := sumValues[0].(int64)
+
+	tr := array.NewTableReader(arrowTable, arrowTable.NumRows())
+	defer tr.Release()
+	var totalFromArrow int64
+	for tr.Next() {
+		rec := tr.Record()
+		vec := rec.Column(1).(*array.Int64)
+		totalFromArrow += math.Int64.Sum(vec)
+	}
+	if totalFromArrow != totalFromSQL {
+		t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
+	}
+}

func countIteratorRows(it *RowIterator) (total uint64, err error) {
	for {
		var dst []Value