refactor(bigquery): single ctx usage with storage read (#10673)
alvarowolfx authored Aug 12, 2024
1 parent ab9a961 commit 2ca6e9d
Showing 4 changed files with 18 additions and 20 deletions.
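
The refactor itself is mechanical: storageArrowIterator previously carried its own ctx field, copied from the readSession that created it, and cancellation was checked sometimes through it.ctx and sometimes through it.session.ctx. This commit drops the duplicated field and renames session to rs, so every check goes through the single context owned by the read session. A minimal sketch of the before/after shape, using simplified hypothetical types rather than the library's real ones:

```go
package main

import (
	"context"
	"fmt"
)

// readSession stands in for the library's readSession; only the
// context field matters for this sketch.
type readSession struct {
	ctx context.Context
}

// Before: the iterator kept its own copy of the context alongside the
// session, leaving two places that could be consulted for cancellation.
type iteratorBefore struct {
	ctx     context.Context
	session *readSession
}

// After: the iterator holds only the session and reads the context
// through it, so there is a single source of truth.
type iteratorAfter struct {
	rs *readSession
}

func (it *iteratorAfter) cancelled() bool {
	return it.rs.ctx.Err() != nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	it := &iteratorAfter{rs: &readSession{ctx: ctx}}
	fmt.Println("cancelled:", it.cancelled()) // false
	cancel()
	fmt.Println("cancelled:", it.cancelled()) // true
}
```

With one owner for the context, no two code paths can disagree about whether the read has been cancelled.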
bigquery/storage_bench_test.go (2 changes: 1 addition & 1 deletion)
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
 		}
 	}
 	b.ReportMetric(float64(it.TotalRows), "rows")
-	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 	b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
 	b.ReportMetric(float64(maxStreamCount), "max_streams")
 }
bigquery/storage_integration_test.go (6 changes: 3 additions & 3 deletions)
@@ -257,7 +257,7 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
 		}
 		total++ // as we read the first value separately

-		session := it.arrowIterator.(*storageArrowIterator).session
+		session := it.arrowIterator.(*storageArrowIterator).rs
 		bqSession := session.bqSession
 		if len(bqSession.Streams) == 0 {
 			t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
@@ -325,7 +325,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
 		total++
 	}

-	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -374,7 +374,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
 	}
 	total++ // as we read the first value separately

-	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
bigquery/storage_iterator.go (28 changes: 13 additions & 15 deletions)
@@ -37,13 +37,12 @@ type storageArrowIterator struct {
 	done        uint32 // atomic flag
 	initialized bool
 	errs        chan error
-	ctx         context.Context

 	schema    Schema
 	rawSchema []byte
 	records   chan *ArrowRecordBatch

-	session *readSession
+	rs *readSession
 }

 var _ ArrowIterator = &storageArrowIterator{}
Expand Down Expand Up @@ -123,8 +122,7 @@ func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) {

func newRawStorageRowIterator(rs *readSession, schema Schema) (*storageArrowIterator, error) {
arrowIt := &storageArrowIterator{
ctx: rs.ctx,
session: rs,
rs: rs,
schema: schema,
records: make(chan *ArrowRecordBatch, rs.settings.maxWorkerCount+1),
errs: make(chan error, rs.settings.maxWorkerCount+1),
@@ -144,7 +142,7 @@ func newStorageRowIterator(rs *readSession, schema Schema) (*RowIterator, error)
 	if err != nil {
 		return nil, err
 	}
-	totalRows := arrowIt.session.bqSession.EstimatedRowCount
+	totalRows := arrowIt.rs.bqSession.EstimatedRowCount
 	it := &RowIterator{
 		ctx:           rs.ctx,
 		arrowIterator: arrowIt,
@@ -191,7 +189,7 @@ func (it *storageArrowIterator) init() error {
 		return nil
 	}

-	bqSession := it.session.bqSession
+	bqSession := it.rs.bqSession
 	if bqSession == nil {
 		return errors.New("read session not initialized")
 	}
@@ -203,7 +201,7 @@

 	wg := sync.WaitGroup{}
 	wg.Add(len(streams))
-	sem := semaphore.NewWeighted(int64(it.session.settings.maxWorkerCount))
+	sem := semaphore.NewWeighted(int64(it.rs.settings.maxWorkerCount))
 	go func() {
 		wg.Wait()
 		close(it.records)
@@ -213,7 +211,7 @@

 	go func() {
 		for _, readStream := range streams {
-			err := sem.Acquire(it.ctx, 1)
+			err := sem.Acquire(it.rs.ctx, 1)
 			if err != nil {
 				wg.Done()
 				continue
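
The two hunks above are part of a bounded fan-out in init(): the WaitGroup gets one slot per stream, a weighted semaphore caps concurrent readers at maxWorkerCount, and a helper goroutine closes the shared records channel once every stream is accounted for. A runnable sketch of that shape, with hard-coded stream names and an inline body standing in for processStream (both hypothetical):

```go
package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

func main() {
	ctx := context.Background()
	streams := []string{"s1", "s2", "s3", "s4"}
	records := make(chan string, len(streams))

	// Cap concurrent readers, mirroring settings.maxWorkerCount.
	sem := semaphore.NewWeighted(2)

	var wg sync.WaitGroup
	wg.Add(len(streams))
	go func() {
		wg.Wait()
		close(records) // all readers finished: no more sends
	}()

	go func() {
		for _, s := range streams {
			if err := sem.Acquire(ctx, 1); err != nil {
				wg.Done() // reader never started (ctx cancelled)
				continue
			}
			go func(s string) {
				defer wg.Done()
				defer sem.Release(1)
				records <- "record from " + s // stand-in for real reads
			}(s)
		}
	}()

	for r := range records {
		fmt.Println(r)
	}
}
```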
@@ -241,17 +239,17 @@ func (it *storageArrowIterator) processStream(readStream string) {
 	bo := gax.Backoff{}
 	var offset int64
 	for {
-		rowStream, err := it.session.readRows(&storagepb.ReadRowsRequest{
+		rowStream, err := it.rs.readRows(&storagepb.ReadRowsRequest{
 			ReadStream: readStream,
 			Offset:     offset,
 		})
 		if err != nil {
-			if it.session.ctx.Err() != nil { // context cancelled, don't try again
+			if it.rs.ctx.Err() != nil { // context cancelled, don't try again
 				return
 			}
 			backoff, shouldRetry := retryReadRows(bo, err)
 			if shouldRetry {
-				if err := gax.Sleep(it.ctx, backoff); err != nil {
+				if err := gax.Sleep(it.rs.ctx, backoff); err != nil {
 					return // context cancelled
 				}
 				continue
@@ -264,12 +262,12 @@
 			return
 		}
 		if err != nil {
-			if it.session.ctx.Err() != nil { // context cancelled, don't queue error
+			if it.rs.ctx.Err() != nil { // context cancelled, don't queue error
 				return
 			}
 			backoff, shouldRetry := retryReadRows(bo, err)
 			if shouldRetry {
-				if err := gax.Sleep(it.ctx, backoff); err != nil {
+				if err := gax.Sleep(it.rs.ctx, backoff); err != nil {
 					return // context cancelled
 				}
 				continue
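
Both retry sites in processStream now consult the same it.rs.ctx, before retrying and while sleeping. Here is the pattern in isolation, a hedged sketch with a hypothetical doWork in place of the ReadRows call and a trivial transient-error check in place of retryReadRows; gax.Sleep returns the context's error if cancellation happens mid-backoff:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/googleapis/gax-go/v2"
)

var errTransient = errors.New("transient")

// doWork is a hypothetical stand-in: the first two attempts fail.
func doWork(attempt int) error {
	if attempt < 2 {
		return errTransient
	}
	return nil
}

func run(ctx context.Context) error {
	bo := gax.Backoff{Initial: 10 * time.Millisecond}
	for attempt := 0; ; attempt++ {
		err := doWork(attempt)
		if err == nil {
			return nil
		}
		if ctx.Err() != nil {
			return ctx.Err() // context cancelled, don't try again
		}
		if errors.Is(err, errTransient) {
			if err := gax.Sleep(ctx, bo.Pause()); err != nil {
				return err // cancelled while backing off
			}
			continue
		}
		return err // permanent failure
	}
}

func main() {
	fmt.Println(run(context.Background())) // <nil> after two retries
}
```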
@@ -338,8 +336,8 @@ func (it *storageArrowIterator) Next() (*ArrowRecordBatch, error) {
 		return record, nil
 	case err := <-it.errs:
 		return nil, err
-	case <-it.ctx.Done():
-		return nil, it.ctx.Err()
+	case <-it.rs.ctx.Done():
+		return nil, it.rs.ctx.Err()
 	}
 }
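
The tail of Next() is a three-way select: a decoded record, a worker error, or cancellation of the session context, whichever arrives first. Reduced to a runnable toy with string records instead of *ArrowRecordBatch:

```go
package main

import (
	"context"
	"fmt"
)

// next blocks until one of the three events fires, mirroring the
// select in storageArrowIterator.Next.
func next(ctx context.Context, records <-chan string, errs <-chan error) (string, error) {
	select {
	case r := <-records:
		return r, nil
	case err := <-errs:
		return "", err
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	records := make(chan string, 1)
	records <- "row-batch-1"
	fmt.Println(next(context.Background(), records, make(chan error)))
}
```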

bigquery/storage_iterator_test.go (2 changes: 1 addition & 1 deletion)
@@ -125,7 +125,7 @@ func TestStorageIteratorRetry(t *testing.T) {

 		it.processStream("test-stream")

-		if errors.Is(it.ctx.Err(), context.Canceled) || errors.Is(it.ctx.Err(), context.DeadlineExceeded) {
+		if errors.Is(it.rs.ctx.Err(), context.Canceled) || errors.Is(it.rs.ctx.Err(), context.DeadlineExceeded) {
 			if tc.wantFail {
 				continue
 			}
