From dad4bfb1d78fa8b176237b5616cad41d7e664a99 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 13 Jul 2023 11:00:04 -0400
Subject: [PATCH] fix(go/adbc/pkg): follow CGO rules properly
Fixes #729.
---
c/driver/flightsql/sqlite_flightsql_test.cc | 106 ++++++++++++++++++++
go/adbc/go.mod | 2 +-
go/adbc/go.sum | 6 +-
go/adbc/pkg/_tmpl/driver.go.tmpl | 5 +
go/adbc/pkg/_tmpl/utils.c.tmpl | 11 ++
go/adbc/pkg/flightsql/driver.go | 5 +
go/adbc/pkg/flightsql/utils.c | 11 ++
go/adbc/pkg/panicdummy/driver.go | 5 +
go/adbc/pkg/panicdummy/utils.c | 11 ++
go/adbc/pkg/snowflake/driver.go | 5 +
go/adbc/pkg/snowflake/utils.c | 11 ++
11 files changed, 173 insertions(+), 5 deletions(-)
diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index 2bf0441e3a..96d448a0e4 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -15,15 +15,21 @@
// specific language governing permissions and limitations
// under the License.
+#include
+#include
+#include
+
#include
#include
#include
#include
#include
#include
+
#include "validation/adbc_validation.h"
#include "validation/adbc_validation_util.h"
+using adbc_validation::IsOkErrno;
using adbc_validation::IsOkStatus;
#define CHECK_OK(EXPR) \
@@ -103,6 +109,106 @@ class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::Data
};
ADBCV_TEST_DATABASE(SqliteFlightSqlTest)
+TEST_F(SqliteFlightSqlTest, TestGarbageInput) {
+ // Regression test for https://github.com/apache/arrow-adbc/issues/729
+
+ // 0xc000000000 is the base of the Go heap. Go's write barriers ask
+ // the GC to mark both the pointer being written, and the pointer
+ // being *overwritten*. So if Go overwrites a value in a C
+ // structure that looks like a Go pointer, the GC may get confused
+ // and error.
+ void* bad_pointer = reinterpret_cast(uintptr_t(0xc000000240));
+
+ // ADBC functions are expected not to blindly overwrite an
+ // already-allocated value/callers are expected to zero-initialize.
+ database.private_data = bad_pointer;
+ database.private_driver = reinterpret_cast(bad_pointer);
+ ASSERT_THAT(AdbcDatabaseNew(&database, &error), ::testing::Not(IsOkStatus(&error)));
+
+ std::memset(&database, 0, sizeof(database));
+ ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&error));
+ ASSERT_THAT(quirks()->SetupDatabase(&database, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcDatabaseInit(&database, &error), IsOkStatus(&error));
+
+ struct AdbcConnection connection;
+ connection.private_data = bad_pointer;
+ connection.private_driver = reinterpret_cast(bad_pointer);
+ ASSERT_THAT(AdbcConnectionNew(&connection, &error), ::testing::Not(IsOkStatus(&error)));
+
+ std::memset(&connection, 0, sizeof(connection));
+ ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+ struct AdbcStatement statement;
+ statement.private_data = bad_pointer;
+ statement.private_driver = reinterpret_cast(bad_pointer);
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
+ ::testing::Not(IsOkStatus(&error)));
+
+ // This needs to happen in parallel since we need to trigger the
+ // write barrier buffer, which means we need to trigger a GC. The
+ // Go FFI bridge deterministically triggers GC on Release calls.
+
+ auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
+ while (std::chrono::steady_clock::now() < deadline) {
+ std::vector threads;
+ std::random_device rd;
+ for (int i = 0; i < 23; i++) {
+ auto seed = rd();
+ threads.emplace_back([&, seed]() {
+ std::mt19937 gen(seed);
+ std::uniform_int_distribution dist(0xc000000000L, 0xc000002000L);
+ for (int i = 0; i < 23; i++) {
+ void* bad_pointer = reinterpret_cast(uintptr_t(dist(gen)));
+
+ struct AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
+ IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 1", &error),
+ IsOkStatus(&error));
+ // This is not expected to be zero-initialized
+ struct ArrowArrayStream stream;
+ stream.private_data = bad_pointer;
+ stream.release =
+ reinterpret_cast(bad_pointer);
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &stream, nullptr, &error),
+ IsOkStatus(&error));
+
+ struct ArrowSchema schema;
+ std::memset(&schema, 0, sizeof(schema));
+ schema.name = reinterpret_cast(bad_pointer);
+ schema.format = reinterpret_cast(bad_pointer);
+ schema.private_data = bad_pointer;
+ ASSERT_THAT(stream.get_schema(&stream, &schema), IsOkErrno());
+
+ while (true) {
+ struct ArrowArray array;
+ array.private_data = bad_pointer;
+ ASSERT_THAT(stream.get_next(&stream, &array), IsOkErrno());
+ if (array.release) {
+ array.release(&array);
+ } else {
+ break;
+ }
+ }
+
+ schema.release(&schema);
+ stream.release(&stream);
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+ }
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ }
+
+ ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+}
+
class SqliteFlightSqlConnectionTest : public ::testing::Test,
public adbc_validation::ConnectionTest {
public:
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index e4496c24eb..1a412b22cf 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.18
require (
- github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553
+ github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355
github.com/bluele/gcache v0.0.2
github.com/google/uuid v1.3.0
github.com/snowflakedb/gosnowflake v1.6.21
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index c8b128ecbc..70654f961e 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -16,10 +16,8 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/apache/arrow/go/v12 v12.0.0 h1:xtZE63VWl7qLdB0JObIXvvhGjoVNrQ9ciIHG2OK5cmc=
github.com/apache/arrow/go/v12 v12.0.0/go.mod h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
-github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646 h1:hLcsUn9hiiD7jDfJDKOe1tBfOL5v0wgrya5S8XXqzLw=
-github.com/apache/arrow/go/v13 v13.0.0-20230620164925-94af6c3c9646/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
-github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553 h1:LV3nIWJ2254APRpYAcMxWbxoQwt66gnrkZ5NaDs1IPI=
-github.com/apache/arrow/go/v13 v13.0.0-20230710202504-70f447636553/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
+github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355 h1:QuXqLb2HzL5EjY99fFp+iG9NagAruvQIbU/2++x+2VY=
+github.com/apache/arrow/go/v13 v13.0.0-20230713180941-b97597765355/go.mod h1:W69eByFNO0ZR30q1/7Sr9d83zcVZmF2MiP3fFYAWJOc=
github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY=
github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU=
github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY=
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 03a94c021d..68a8e23f1a 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -612,6 +612,11 @@ func {{.Prefix}}StatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcS
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
+ if stmt.private_data != nil {
+ setErr(err, "AdbcStatementNew: statement already allocated")
+ return C.ADBC_STATUS_INVALID_STATE
+ }
+
conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
diff --git a/go/adbc/pkg/_tmpl/utils.c.tmpl b/go/adbc/pkg/_tmpl/utils.c.tmpl
index 29d19bc55f..38222875fd 100644
--- a/go/adbc/pkg/_tmpl/utils.c.tmpl
+++ b/go/adbc/pkg/_tmpl/utils.c.tmpl
@@ -21,6 +21,8 @@
#include "utils.h"
+#include
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -74,6 +76,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetInfo(connection, info_codes, info_codes_length, out, error);
}
@@ -83,6 +86,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
@@ -92,6 +96,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}ConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}
@@ -99,6 +104,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionGetTableTypes(connection, out, error);
}
@@ -107,6 +113,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}ConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
@@ -136,6 +143,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return {{.Prefix}}StatementExecuteQuery(statement, out, rows_affected, error);
}
@@ -170,6 +178,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return {{.Prefix}}StatementGetParameterSchema(statement, schema, error);
}
@@ -183,6 +192,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
+ if (partitions) memset(partitions, 0, sizeof(*partitions));
return {{.Prefix}}StatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 6d5cf75b8b..a61d807915 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -616,6 +616,11 @@ func FlightSQLStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
+ if stmt.private_data != nil {
+ setErr(err, "AdbcStatementNew: statement already allocated")
+ return C.ADBC_STATUS_INVALID_STATE
+ }
+
conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
diff --git a/go/adbc/pkg/flightsql/utils.c b/go/adbc/pkg/flightsql/utils.c
index 3d3d89c5ce..41777a98c4 100644
--- a/go/adbc/pkg/flightsql/utils.c
+++ b/go/adbc/pkg/flightsql/utils.c
@@ -23,6 +23,8 @@
#include "utils.h"
+#include
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetInfo(connection, info_codes, info_codes_length, out,
error);
}
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionGetTableTypes(connection, out, error);
}
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return FlightSQLConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return FlightSQLStatementExecuteQuery(statement, out, rows_affected, error);
}
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return FlightSQLStatementGetParameterSchema(statement, schema, error);
}
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
+ if (partitions) memset(partitions, 0, sizeof(*partitions));
return FlightSQLStatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 374c3cb828..36602f4299 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -616,6 +616,11 @@ func PanicDummyStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSt
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
+ if stmt.private_data != nil {
+ setErr(err, "AdbcStatementNew: statement already allocated")
+ return C.ADBC_STATUS_INVALID_STATE
+ }
+
conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
diff --git a/go/adbc/pkg/panicdummy/utils.c b/go/adbc/pkg/panicdummy/utils.c
index 5978aaa5b6..d0a2936618 100644
--- a/go/adbc/pkg/panicdummy/utils.c
+++ b/go/adbc/pkg/panicdummy/utils.c
@@ -23,6 +23,8 @@
#include "utils.h"
+#include
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return PanicDummyConnectionGetInfo(connection, info_codes, info_codes_length, out,
error);
}
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return PanicDummyConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return PanicDummyConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return PanicDummyConnectionGetTableTypes(connection, out, error);
}
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return PanicDummyConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return PanicDummyStatementExecuteQuery(statement, out, rows_affected, error);
}
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return PanicDummyStatementGetParameterSchema(statement, schema, error);
}
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
+ if (partitions) memset(partitions, 0, sizeof(*partitions));
return PanicDummyStatementExecutePartitions(statement, schema, partitions,
rows_affected, error);
}
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index 31e2f131a5..88ba66d167 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -616,6 +616,11 @@ func SnowflakeStatementNew(cnxn *C.struct_AdbcConnection, stmt *C.struct_AdbcSta
setErr(err, "AdbcStatementNew: Go panicked, driver is in unknown state")
return C.ADBC_STATUS_INTERNAL
}
+ if stmt.private_data != nil {
+ setErr(err, "AdbcStatementNew: statement already allocated")
+ return C.ADBC_STATUS_INVALID_STATE
+ }
+
conn := checkConnInit(cnxn, err, "AdbcStatementNew")
if conn == nil {
return C.ADBC_STATUS_INVALID_STATE
diff --git a/go/adbc/pkg/snowflake/utils.c b/go/adbc/pkg/snowflake/utils.c
index 8c360b0fd1..24d3ca3d90 100644
--- a/go/adbc/pkg/snowflake/utils.c
+++ b/go/adbc/pkg/snowflake/utils.c
@@ -23,6 +23,8 @@
#include "utils.h"
+#include
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -76,6 +78,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return SnowflakeConnectionGetInfo(connection, info_codes, info_codes_length, out,
error);
}
@@ -86,6 +89,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return SnowflakeConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
table_type, column_name, out, error);
}
@@ -95,6 +99,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return SnowflakeConnectionGetTableSchema(connection, catalog, db_schema, table_name,
schema, error);
}
@@ -102,6 +107,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return SnowflakeConnectionGetTableTypes(connection, out, error);
}
@@ -110,6 +116,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return SnowflakeConnectionReadPartition(connection, serialized_partition,
serialized_length, out, error);
}
@@ -139,6 +146,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (out) memset(out, 0, sizeof(*out));
return SnowflakeStatementExecuteQuery(statement, out, rows_affected, error);
}
@@ -173,6 +181,7 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
return SnowflakeStatementGetParameterSchema(statement, schema, error);
}
@@ -186,6 +195,8 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
+ if (schema) memset(schema, 0, sizeof(*schema));
+ if (partitions) memset(partitions, 0, sizeof(*partitions));
return SnowflakeStatementExecutePartitions(statement, schema, partitions, rows_affected,
error);
}