diff --git a/bigquery/storage_bench_test.go b/bigquery/storage_bench_test.go
index 53c9feea5bea..95b959296e94 100644
--- a/bigquery/storage_bench_test.go
+++ b/bigquery/storage_bench_test.go
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
 			}
 		}
 		b.ReportMetric(float64(it.TotalRows), "rows")
-		bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+		bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 		b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
 		b.ReportMetric(float64(maxStreamCount), "max_streams")
 	}
diff --git a/bigquery/storage_integration_test.go b/bigquery/storage_integration_test.go
index 81d68aee8318..cbc9b5afd51b 100644
--- a/bigquery/storage_integration_test.go
+++ b/bigquery/storage_integration_test.go
@@ -257,7 +257,7 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
 		}
 		total++ // as we read the first value separately
 
-		session := it.arrowIterator.(*storageArrowIterator).session
+		session := it.arrowIterator.(*storageArrowIterator).rs
 		bqSession := session.bqSession
 		if len(bqSession.Streams) == 0 {
 			t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
@@ -325,7 +325,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
 		total++
 	}
 
-	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -374,7 +374,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
 	}
 	total++ // as we read the first value separately
 
-	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).rs.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
diff --git a/bigquery/storage_iterator.go b/bigquery/storage_iterator.go
index c715d9686c8d..70abe0cd2ea1 100644
--- a/bigquery/storage_iterator.go
+++ b/bigquery/storage_iterator.go
@@ -37,13 +37,12 @@ type storageArrowIterator struct {
 	done        uint32 // atomic flag
 	initialized bool
 	errs        chan error
-	ctx         context.Context
 
 	schema    Schema
 	rawSchema []byte
 	records   chan *ArrowRecordBatch
 
-	session *readSession
+	rs *readSession
 }
 
 var _ ArrowIterator = &storageArrowIterator{}
@@ -123,8 +122,7 @@ func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) {
 
 func newRawStorageRowIterator(rs *readSession, schema Schema) (*storageArrowIterator, error) {
 	arrowIt := &storageArrowIterator{
-		ctx:     rs.ctx,
-		session: rs,
+		rs:      rs,
 		schema:  schema,
 		records: make(chan *ArrowRecordBatch, rs.settings.maxWorkerCount+1),
 		errs:    make(chan error, rs.settings.maxWorkerCount+1),
@@ -144,7 +142,7 @@ func newStorageRowIterator(rs *readSession, schema Schema) (*RowIterator, error)
 	if err != nil {
 		return nil, err
 	}
-	totalRows := arrowIt.session.bqSession.EstimatedRowCount
+	totalRows := arrowIt.rs.bqSession.EstimatedRowCount
 	it := &RowIterator{
 		ctx:           rs.ctx,
 		arrowIterator: arrowIt,
@@ -191,7 +189,7 @@ func (it *storageArrowIterator) init() error {
 		return nil
 	}
 
-	bqSession := it.session.bqSession
+	bqSession := it.rs.bqSession
 	if bqSession == nil {
 		return errors.New("read session not initialized")
 	}
@@ -203,7 +201,7 @@ func (it *storageArrowIterator) init() error {
 
 	wg := sync.WaitGroup{}
 	wg.Add(len(streams))
-	sem := semaphore.NewWeighted(int64(it.session.settings.maxWorkerCount))
+	sem := semaphore.NewWeighted(int64(it.rs.settings.maxWorkerCount))
 	go func() {
 		wg.Wait()
 		close(it.records)
@@ -213,7 +211,7 @@ func (it *storageArrowIterator) init() error {
 
 	go func() {
 		for _, readStream := range streams {
-			err := sem.Acquire(it.ctx, 1)
+			err := sem.Acquire(it.rs.ctx, 1)
 			if err != nil {
 				wg.Done()
 				continue
@@ -241,17 +239,17 @@ func (it *storageArrowIterator) processStream(readStream string) {
 	bo := gax.Backoff{}
 	var offset int64
 	for {
-		rowStream, err := it.session.readRows(&storagepb.ReadRowsRequest{
+		rowStream, err := it.rs.readRows(&storagepb.ReadRowsRequest{
 			ReadStream: readStream,
 			Offset:     offset,
 		})
 		if err != nil {
-			if it.session.ctx.Err() != nil { // context cancelled, don't try again
+			if it.rs.ctx.Err() != nil { // context cancelled, don't try again
 				return
 			}
 			backoff, shouldRetry := retryReadRows(bo, err)
 			if shouldRetry {
-				if err := gax.Sleep(it.ctx, backoff); err != nil {
+				if err := gax.Sleep(it.rs.ctx, backoff); err != nil {
 					return // context cancelled
 				}
 				continue
@@ -264,12 +262,12 @@ func (it *storageArrowIterator) processStream(readStream string) {
 				return
 			}
 			if err != nil {
-				if it.session.ctx.Err() != nil { // context cancelled, don't queue error
+				if it.rs.ctx.Err() != nil { // context cancelled, don't queue error
 					return
 				}
 				backoff, shouldRetry := retryReadRows(bo, err)
 				if shouldRetry {
-					if err := gax.Sleep(it.ctx, backoff); err != nil {
+					if err := gax.Sleep(it.rs.ctx, backoff); err != nil {
 						return // context cancelled
 					}
 					continue
@@ -338,8 +336,8 @@ func (it *storageArrowIterator) Next() (*ArrowRecordBatch, error) {
 		return record, nil
 	case err := <-it.errs:
 		return nil, err
-	case <-it.ctx.Done():
-		return nil, it.ctx.Err()
+	case <-it.rs.ctx.Done():
+		return nil, it.rs.ctx.Err()
 	}
 }
 
diff --git a/bigquery/storage_iterator_test.go b/bigquery/storage_iterator_test.go
index 8938b0a26173..33dc77201dcc 100644
--- a/bigquery/storage_iterator_test.go
+++ b/bigquery/storage_iterator_test.go
@@ -125,7 +125,7 @@ func TestStorageIteratorRetry(t *testing.T) {
 
 		it.processStream("test-stream")
 
-		if errors.Is(it.ctx.Err(), context.Canceled) || errors.Is(it.ctx.Err(), context.DeadlineExceeded) {
+		if errors.Is(it.rs.ctx.Err(), context.Canceled) || errors.Is(it.rs.ctx.Err(), context.DeadlineExceeded) {
			if tc.wantFail {
				continue
			}
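Note for reviewers (not part of the patch): the net effect of this change is that storageArrowIterator stops holding a separate context.Context alongside its *readSession and keeps only rs, so every former it.ctx / it.session access now goes through it.rs (it.rs.ctx, it.rs.bqSession, it.rs.settings), leaving a single source of truth for cancellation. The sketch below is a self-contained toy illustration of that pattern; the stub types and the done() helper are stand-ins, not the real bigquery package code.

```go
package main

import (
	"context"
	"fmt"
)

// readSession stands in for the real read session: it owns the context.
type readSession struct {
	ctx context.Context
}

// iterator mirrors the post-change shape: no separate ctx field,
// only a single session handle from which the context is derived.
type iterator struct {
	rs *readSession
}

// done exposes cancellation via the session's context, as it.rs.ctx.Done()
// does in the diff.
func (it *iterator) done() <-chan struct{} {
	return it.rs.ctx.Done()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	it := &iterator{rs: &readSession{ctx: ctx}}
	cancel()
	<-it.done() // unblocks because the session's context was cancelled
	fmt.Println("cancelled via rs.ctx")
}
```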