Skip to content

Commit

Permalink
fix: Reduced SQL calls in GetObjects to two, added prefixing DbName f…
Browse files Browse the repository at this point in the history
…or INFORMATION_SCHEMA calls, and removed SQL cursor code

Additional changes:
* Reduced SQL calls by making only 1 - 2 SQL call based on the ObjectsDepth and using that data to populate all the previous depth information
* There is one SHOW TERSE DATABASES that is always called and the catalogPattern is used to filter the databases and prepare the SQL based on the depth for schema, tables, or columns
* The SQL cursor code was removed and the replaced with a static SQL that is prepared in Go based on the databases that match the catalogPattern
* GetObjects populates the MetadataRecords by making the necessary SQL call based on ObjectsDepth
* Modified the logic of GetObjects Init to pass MetadataRecords in getObjectsDbSchemas and getObjectsTables
* Modified tests to check the table type returned
  • Loading branch information
ryan-syed committed Jan 2, 2024
1 parent 650994d commit b76af78
Show file tree
Hide file tree
Showing 4 changed files with 505 additions and 202 deletions.
18 changes: 8 additions & 10 deletions csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,6 @@ public DriverTests()
Dictionary<string, string> options = new Dictionary<string, string>();
_snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(_testConfiguration, out parameters);

string databaseName = _testConfiguration.Metadata.Catalog;
string schemaName = _testConfiguration.Metadata.Schema;

parameters[SnowflakeParameters.DATABASE] = databaseName;
parameters[SnowflakeParameters.SCHEMA] = schemaName;

_database = _snowflakeDriver.Open(parameters);
_connection = _database.Connect(options);
}
Expand Down Expand Up @@ -214,7 +208,7 @@ public void CanGetObjectsTables(string tableNamePattern)
string tableName = _testConfiguration.Metadata.Table;

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.All,
depth: AdbcConnection.GetObjectsDepth.Tables,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableNamePattern,
Expand All @@ -235,6 +229,7 @@ public void CanGetObjectsTables(string tableNamePattern)

AdbcTable table = tables.Where((table) => string.Equals(table.Name, tableName)).FirstOrDefault();
Assert.True(table != null, "table should not be null");
Assert.Equal("BASE TABLE", table.Type);
}

/// <summary>
Expand All @@ -260,18 +255,21 @@ public void CanGetObjectsAll()
using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcColumn> columns = catalogs
AdbcTable table = catalogs
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault()
.Where(t => string.Equals(t.Name, tableName))
.Select(t => t.Columns)
.FirstOrDefault();


Assert.True(table != null, "table should not be null");
Assert.Equal("BASE TABLE", table.Type);
List<AdbcColumn> columns = table.Columns;

Assert.True(columns != null, "Columns cannot be null");
Assert.Equal(_testConfiguration.Metadata.ExpectedColumnCount, columns.Count);

Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/flightsql/flightsql_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
}

// Helper function to build up a map of catalogs to DB schemas
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string) (result map[string][]string, err error) {
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
return
}

func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (result internal.SchemaToTableInfo, err error) {
func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas {
return
}
Expand Down
21 changes: 17 additions & 4 deletions go/adbc/driver/internal/shared_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package internal

import (
"context"
"database/sql"
"regexp"
"strconv"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow"
Expand All @@ -38,8 +40,18 @@ type TableInfo struct {
Schema *arrow.Schema
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string) (map[CatalogAndSchema][]TableInfo, error)
type Metadata struct {
Created time.Time
ColName, DataType string
Dbname, Kind, Schema, TblName, TblType, IdentGen, IdentIncrement, Comment sql.NullString
OrdinalPos int
NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec sql.NullInt16
IsNullable, IsIdent bool
CharMaxLength, CharOctetLength sql.NullInt32
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []Metadata) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []Metadata) (map[CatalogAndSchema][]TableInfo, error)
type SchemaToTableInfo = map[CatalogAndSchema][]TableInfo

// Helper function that compiles a SQL-style pattern (%, _) to a regex
Expand Down Expand Up @@ -87,6 +99,7 @@ type GetObjects struct {
builder *array.RecordBuilder
schemaLookup map[string][]string
tableLookup map[CatalogAndSchema][]TableInfo
MetadataRecords []Metadata
catalogPattern *regexp.Regexp
columnNamePattern *regexp.Regexp

Expand Down Expand Up @@ -123,13 +136,13 @@ type GetObjects struct {
}

func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn, getTbls GetObjTablesFn) error {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema); err != nil {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.MetadataRecords); err != nil {
return err
} else {
g.schemaLookup = catalogToDbSchemas
}

if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType); err != nil {
if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType, g.MetadataRecords); err != nil {
return err
} else {
g.tableLookup = tableLookup
Expand Down
Loading

0 comments on commit b76af78

Please sign in to comment.