Merge branch 'gh-1804' into gh-1144
lidavidm committed May 3, 2024
2 parents 22c5d4a + 775fde2 · commit 18d753b
Showing 4 changed files with 117 additions and 70 deletions.
c/driver/flightsql/dremio_flightsql_test.cc (1 addition, 0 deletions)
@@ -92,6 +92,7 @@ class DremioFlightSqlStatementTest : public ::testing::Test,
   void TestSqlIngestColumnEscaping() {
     GTEST_SKIP() << "Column escaping not implemented";
   }
+  void TestSqlQueryEmpty() { GTEST_SKIP() << "Dremio doesn't support 'acceptPut'"; }
   void TestSqlQueryRowsAffectedDelete() {
     GTEST_SKIP() << "Cannot query rows affected in delete (not implemented)";
   }
c/validation/adbc_validation.h (2 additions, 0 deletions)
@@ -407,6 +407,7 @@ class StatementTest {
   void TestSqlPrepareErrorNoQuery();
   void TestSqlPrepareErrorParamCountMismatch();
 
+  void TestSqlQueryEmpty();
   void TestSqlQueryInts();
   void TestSqlQueryFloats();
   void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) {        \
     TestSqlPrepareErrorParamCountMismatch();                   \
   }                                                            \
+  TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); }      \
   TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }        \
   TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }    \
   TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }  \
c/validation/adbc_validation_statement.cc (35 additions, 0 deletions)
@@ -2062,6 +2062,41 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
               ::testing::Not(IsOkStatus(&error)));
 }
 
+void StatementTest::TestSqlQueryEmpty() {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "QUERYEMPTY", &error), IsOkStatus(&error));
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "CREATE TABLE QUERYEMPTY (FOO INT)", &error),
+      IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "SELECT * FROM QUERYEMPTY WHERE 1=0", &error),
+      IsOkStatus(&error));
+  {
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(reader.rows_affected,
+                ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_EQ(1, reader.schema->n_children);
+
+    while (true) {
+      ASSERT_NO_FATAL_FAILURE(reader.Next());
+      if (!reader.array->release) {
+        break;
+      }
+      ASSERT_EQ(0, reader.array->length);
+    }
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
 void StatementTest::TestSqlQueryInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
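The new validation test exercises the empty-result path end to end: run a SELECT that matches no rows, accept a rows_affected of either 0 or -1 (the ADBC convention for "unknown"), require a valid one-column schema, and then drain the stream until the driver returns a released array, tolerating any number of zero-length batches along the way. A rough Go equivalent of that drain loop, shown as a sketch for illustration only (countRows is a hypothetical helper, not part of this commit; the arrow/go module version is whatever the consumer already uses):

package drain

import (
	"io"

	"github.com/apache/arrow/go/v17/arrow/ipc"
)

// countRows drains an Arrow IPC stream the way TestSqlQueryEmpty drains
// its ArrowArrayStream: keep reading until the producer reports
// end-of-stream, summing row counts. An empty result set yields 0 whether
// the producer sends no batches at all or only zero-length batches.
func countRows(r io.Reader) (int64, error) {
	rdr, err := ipc.NewReader(r)
	if err != nil {
		return 0, err
	}
	defer rdr.Release()

	var rows int64
	for rdr.Next() {
		rows += rdr.Record().NumRows()
	}
	return rows, rdr.Err() // Err is nil on a normal end-of-stream
}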
go/adbc/driver/snowflake/record_reader.go (79 additions, 70 deletions)
@@ -571,23 +571,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 	}
 
 	ch := make(chan arrow.Record, bufferSize)
-	r, err := batches[0].GetStream(ctx)
-	if err != nil {
-		return nil, errToAdbcErr(adbc.StatusIO, err)
-	}
-
-	rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
-	if err != nil {
-		return nil, adbc.Error{
-			Msg:  err.Error(),
-			Code: adbc.StatusInvalidState,
-		}
-	}
-
 	group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
 	ctx, cancelFn := context.WithCancel(ctx)
-
-	schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision)
+	group.SetLimit(prefetchConcurrency)
 
 	defer func() {
 		if err != nil {
@@ -596,80 +582,103 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 		}
 	}()
 
-	group.SetLimit(prefetchConcurrency)
-	group.Go(func() error {
-		defer rr.Release()
-		defer r.Close()
-		if len(batches) > 1 {
-			defer close(ch)
-		}
-
-		for rr.Next() && ctx.Err() == nil {
-			rec := rr.Record()
-			rec, err = recTransform(ctx, rec)
-			if err != nil {
-				return err
-			}
-			ch <- rec
-		}
-		return rr.Err()
-	})
-
 	chs := make([]chan arrow.Record, len(batches))
-	chs[0] = ch
 	rdr := &reader{
 		refCount: 1,
 		chs:      chs,
 		err:      nil,
 		cancelFn: cancelFn,
-		schema:   schema,
 	}
 
-	lastChannelIndex := len(chs) - 1
-	go func() {
-		for i, b := range batches[1:] {
-			batch, batchIdx := b, i+1
-			chs[batchIdx] = make(chan arrow.Record, bufferSize)
-			group.Go(func() error {
-				// close channels (except the last) so that Next can move on to the next channel properly
-				if batchIdx != lastChannelIndex {
-					defer close(chs[batchIdx])
-				}
-
-				rdr, err := batch.GetStream(ctx)
-				if err != nil {
-					return err
-				}
-				defer rdr.Close()
-
-				rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
-				if err != nil {
-					return err
-				}
-				defer rr.Release()
-
-				for rr.Next() && ctx.Err() == nil {
-					rec := rr.Record()
-					rec, err = recTransform(ctx, rec)
-					if err != nil {
-						return err
-					}
-					chs[batchIdx] <- rec
-				}
-
-				return rr.Err()
-			})
-		}
-
-		// place this here so that we always clean up, but they can't be in a
-		// separate goroutine. Otherwise we'll have a race condition between
-		// the call to wait and the calls to group.Go to kick off the jobs
-		// to perform the pre-fetching (GH-1283).
-		rdr.err = group.Wait()
-		// don't close the last channel until after the group is finished,
-		// so that Next() can only return after reader.err may have been set
-		close(chs[lastChannelIndex])
-	}()
+	if len(batches) > 0 {
+		r, err := batches[0].GetStream(ctx)
+		if err != nil {
+			return nil, errToAdbcErr(adbc.StatusIO, err)
+		}
+
+		rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
+		if err != nil {
+			return nil, adbc.Error{
+				Msg:  err.Error(),
+				Code: adbc.StatusInvalidState,
+			}
+		}
+
+		var recTransform recordTransformer
+		rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
+
+		group.Go(func() error {
+			defer rr.Release()
+			defer r.Close()
+			if len(batches) > 1 {
+				defer close(ch)
+			}
+
+			for rr.Next() && ctx.Err() == nil {
+				rec := rr.Record()
+				rec, err = recTransform(ctx, rec)
+				if err != nil {
+					return err
+				}
+				ch <- rec
+			}
+			return rr.Err()
+		})
+
+		chs[0] = ch
+
+		lastChannelIndex := len(chs) - 1
+		go func() {
+			for i, b := range batches[1:] {
+				batch, batchIdx := b, i+1
+				chs[batchIdx] = make(chan arrow.Record, bufferSize)
+				group.Go(func() error {
+					// close channels (except the last) so that Next can move on to the next channel properly
+					if batchIdx != lastChannelIndex {
+						defer close(chs[batchIdx])
+					}
+
+					rdr, err := batch.GetStream(ctx)
+					if err != nil {
+						return err
+					}
+					defer rdr.Close()
+
+					rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
+					if err != nil {
+						return err
+					}
+					defer rr.Release()
+
+					for rr.Next() && ctx.Err() == nil {
+						rec := rr.Record()
+						rec, err = recTransform(ctx, rec)
+						if err != nil {
+							return err
+						}
+						chs[batchIdx] <- rec
+					}
+
+					return rr.Err()
+				})
+			}
+
+			// place this here so that we always clean up, but they can't be in a
+			// separate goroutine. Otherwise we'll have a race condition between
+			// the call to wait and the calls to group.Go to kick off the jobs
+			// to perform the pre-fetching (GH-1283).
+			rdr.err = group.Wait()
+			// don't close the last channel until after the group is finished,
+			// so that Next() can only return after reader.err may have been set
+			close(chs[lastChannelIndex])
+		}()
+	} else {
+		schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+		if err != nil {
+			return nil, err
+		}
+		rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+	}
 
 	return rdr, nil
 }
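The restructuring above fixes the case where Snowflake returns zero batches: previously the reader unconditionally called batches[0].GetStream and took the schema from the first IPC stream, which cannot work for an empty result set. Now the channel setup and prefetch goroutines only run when at least one batch exists; otherwise the schema is derived from Snowflake's row-type metadata via rowTypesToArrowSchema. The concurrency pattern itself is unchanged. A minimal self-contained sketch of that pattern (illustrative names, not the driver's code; assumes only golang.org/x/sync/errgroup):

package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	var group errgroup.Group
	group.SetLimit(2) // cap concurrent producers, as the driver does

	// One channel per batch; the consumer drains them in order.
	chs := make([]chan int, 3)
	for i := range chs {
		chs[i] = make(chan int, 4)
	}
	last := len(chs) - 1

	var groupErr error
	done := make(chan struct{})
	go func() {
		for i := range chs {
			i := i
			group.Go(func() error {
				// Close every channel except the last so the consumer
				// can move on to the next one.
				if i != last {
					defer close(chs[i])
				}
				for v := 0; v < 3; v++ {
					chs[i] <- i*10 + v // stand-in for one decoded record
				}
				return nil
			})
		}
		// Wait in the same goroutine that issued the group.Go calls;
		// calling Wait concurrently with Go is the race noted in the
		// GH-1283 comment. Only then close the last channel, so the
		// consumer cannot see end-of-stream before the error is set.
		groupErr = group.Wait()
		close(chs[last])
		close(done)
	}()

	for _, ch := range chs {
		for v := range ch {
			fmt.Println(v)
		}
	}
	<-done
	fmt.Println("group error:", groupErr)
}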
