Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bigquery): expose Apache Arrow data through ArrowIterator #8506

Merged
merged 13 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 78 additions & 20 deletions bigquery/arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,107 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"math/big"

"cloud.google.com/go/civil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"google.golang.org/api/iterator"
)

type arrowDecoder struct {
tableSchema Schema
rawArrowSchema []byte
arrowSchema *arrow.Schema
// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID
type ArrowRecordBatch struct {
reader io.Reader
// Serialized Arrow Record Batch.
Data []byte
// Serialized Arrow Schema.
Schema []byte
// Source partition ID. In the Storage API world, it represents the ReadStream.
PartitionID string
}

// Read makes ArrowRecordBatch implements io.Reader
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
if r.reader == nil {
buf := bytes.NewBuffer(r.Schema)
buf.Write(r.Data)
r.reader = buf
}
return r.reader.Read(p)
}

// ArrowIterator represents a way to iterate through a stream of arrow records.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
type ArrowIterator interface {
Next() (*ArrowRecordBatch, error)
Schema() Schema
SerializedArrowSchema() []byte
}

func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
bqSession := session.bqSession
if bqSession == nil {
return nil, errors.New("read session not initialized")
// NewArrowIteratorReader allows to consume an ArrowIterator as an io.Reader.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
func NewArrowIteratorReader(it ArrowIterator) io.Reader {
return &arrowIteratorReader{
it: it,
}
arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
}

type arrowIteratorReader struct {
buf *bytes.Buffer
it ArrowIterator
}

// Read makes ArrowIteratorReader implement io.Reader
func (r *arrowIteratorReader) Read(p []byte) (int, error) {
if r.it == nil {
return -1, errors.New("bigquery: nil ArrowIterator")
}
if r.buf == nil { // init with schema
buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
r.buf = buf
}
n, err := r.buf.Read(p)
if err == io.EOF {
batch, err := r.it.Next()
if err == iterator.Done {
return 0, io.EOF
}
r.buf.Write(batch.Data)
return r.Read(p)
}
return n, err
}

type arrowDecoder struct {
tableSchema Schema
arrowSchema *arrow.Schema
}

func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would probably be worthwhile (though certainly could be done as a follow-up) to allow passing a memory.Allocator interface here that would be stored in the arrowDecoder to allow a user to configure how memory gets allocated for the arrow batches (it would be passed as ipc.WithAllocator(mem) to ipc.NewReader)

In most cases users would probalby just use memory.DefaultAllocator but in other cases, depending on the constraints of the system, they might want to use a custom allocator such as a malloc based allocator that uses C memory to avoid garbage collection passes, or any other custom allocation they might want for specialized situations.

The other benefit of this would be that you could use memory.CheckedAllocator in unit tests to verify that everything properly has Release called if necessary, etc.

buf := bytes.NewBuffer(arrowSerializedSchema)
r, err := ipc.NewReader(buf)
if err != nil {
return nil, err
}
defer r.Release()
p := &arrowDecoder{
tableSchema: schema,
rawArrowSchema: arrowSerializedSchema,
arrowSchema: r.Schema(),
tableSchema: schema,
arrowSchema: r.Schema(),
}
return p, nil
}

func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
buf := bytes.NewBuffer(ap.rawArrowSchema)
buf.Write(serializedArrowRecordBatch)
return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
return ipc.NewReader(arrowRecordBatch, ipc.WithSchema(ap.arrowSchema))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above, it would be great (but could be a follow-up) if either newArrowDecoder accepted a memory.Allocator which would get used here as ipc.WithAllocator(mem) or if this method optionally took an allocator (defaulting to memory.DefaultAllocator if nil was passed)

}

// decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
Expand All @@ -79,8 +137,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([
}

// decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion bigquery/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ type RowIterator struct {
ctx context.Context
src *rowSource

arrowIterator *arrowIterator
arrowIterator ArrowIterator
arrowDecoder *arrowDecoder

pageInfo *iterator.PageInfo
nextFunc func() error
Expand Down
2 changes: 1 addition & 1 deletion bigquery/storage_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
}
}
b.ReportMetric(float64(it.TotalRows), "rows")
bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
b.ReportMetric(float64(maxStreamCount), "max_streams")
}
Expand Down
87 changes: 82 additions & 5 deletions bigquery/storage_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
"time"

"cloud.google.com/go/internal/testutil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/math"
"github.com/google/go-cmp/cmp"
"google.golang.org/api/iterator"
)
Expand Down Expand Up @@ -250,11 +254,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
}
total++ // as we read the first value separately

bqSession := it.arrowIterator.session.bqSession
session := it.arrowIterator.(*storageArrowIterator).session
bqSession := session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
}
streamSettings := it.arrowIterator.session.settings.maxStreamCount
streamSettings := session.settings.maxStreamCount
if tc.maxExpectedStreams > 0 {
if streamSettings > tc.maxExpectedStreams {
t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
Expand Down Expand Up @@ -317,7 +322,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
total++
}

bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
}
Expand Down Expand Up @@ -366,7 +371,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
}
total++ // as we read the first value separately

bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
}
Expand Down Expand Up @@ -418,11 +423,83 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
}
// resources are cleaned asynchronously
time.Sleep(time.Second)
if !it.arrowIterator.isDone() {
arrowIt := it.arrowIterator.(*storageArrowIterator)
if !arrowIt.isDone() {
t.Fatal("expected stream to be done")
}
}

func TestIntegration_StorageReadArrow(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@k-anshul @zeroshade this integration test show an example on how that interface would be used.

if client == nil {
t.Skip("Integration tests skipped")
}
ctx := context.Background()
table := "`bigquery-public-data.usa_names.usa_1910_current`"
sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)

q := storageOptimizedClient.Query(sql)
job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
if err != nil {
t.Fatal(err)
}
it, err := job.Read(ctx)
if err != nil {
t.Fatal(err)
}

arrowIt, err := it.ArrowIterator()
if err != nil {
t.Fatalf("expected iterator to be accelerated: %v", err)
}
arrowItReader := NewArrowIteratorReader(arrowIt)

records := []arrow.Record{}
r, err := ipc.NewReader(arrowItReader)
numrec := 0
for r.Next() {
rec := r.Record()
rec.Retain()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by the release/retains here, but I've not been spending much time with arrow recently. If you retain individual records do you need to release them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are going to be releases by the ipc.Reader later on the r.Release() call right after.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not exactly. If you call Retain on an individual record, then you will need to call Release on that record.

The ipc.Reader keeps only the current Record, reusing that member. When you call Next() it will release the record it had before loading the next one. This is why you need to call Retain on the records that you put into the slice, so that they aren't deallocated by the ipc.Reader calling Release on them. However you should also add a defer rec.Release() in the loop to ensure that record gets released, the ipc.Reader will not retain any references to those records and therefore will not call Release on them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, I didn't know that. I'll make the changes to call rec.Release on each record.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to verify that everything is released/retained correctly, you could use memory.CheckedAllocator and defer mem.AssertSize(t, 0) then pass the checked allocator to everything (like to ipc.NewReader) so that it is used for all the memory allocations.

Not absolutely necessary, but an optional way to add some assurances if desired.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm gonna add support for changing the allocator just internally for now for test purposes, I liked the idea of verifying that there are no memory leaks. Thanks for the tip.

defer rec.Release()
records = append(records, rec)
numrec += int(rec.NumRows())
}
r.Release()

arrowSchema := r.Schema()
arrowTable := array.NewTableFromRecords(arrowSchema, records)
defer arrowTable.Release()
if arrowTable.NumRows() != int64(it.TotalRows) {
t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
}
if arrowTable.NumCols() != 3 {
t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
}

sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
sumQuery := client.Query(sumSQL)
sumIt, err := sumQuery.Read(ctx)
if err != nil {
t.Fatal(err)
}
sumValues := []Value{}
err = sumIt.Next(&sumValues)
if err != nil {
t.Fatal(err)
}
totalFromSQL := sumValues[0].(int64)

tr := array.NewTableReader(arrowTable, arrowTable.NumRows())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add defer tr.Release() please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after adding support for changing the allocator I found that without the tr.Release a memory leak was happening. Good catch and awesome tips on how to catch those leaks 🎉

var totalFromArrow int64
for tr.Next() {
rec := tr.Record()
vec := rec.Column(1).(*array.Int64)
totalFromArrow += math.Int64.Sum(vec)
}
if totalFromArrow != totalFromSQL {
t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
}
}

func countIteratorRows(it *RowIterator) (total uint64, err error) {
for {
var dst []Value
Expand Down
Loading
Loading