From a73a8516858c0e80f165c29997f1c7486ab6c389 Mon Sep 17 00:00:00 2001 From: Brennan Lamey <66885902+brennanjl@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:08:14 -0500 Subject: [PATCH] Fix/deterministic ordering (#308) * added several new tests for the sqlanalyzer * fixed various bugs in the database engine, and improved unit tests * added tool to analyze select cores. this is needed to support determinism for CTEs * added the ability to predict attribute types, and generate a table * rebased main, merged proto * minor additions * fixed comment typo * fixed failing unit test * fixed gavins feedback * fixed jons feedback --- go.mod | 2 +- pkg/engine/dataset/data_test.go | 8 +- pkg/engine/db/db.go | 68 +- pkg/engine/db/db_test.go | 190 +++- pkg/engine/db/metadata.go | 13 +- pkg/engine/db/persist.go | 39 +- pkg/engine/db/procedures.go | 6 +- .../db/sql-ddl-generator/generate_test.go | 42 + pkg/engine/db/upgrade_test.go | 8 +- pkg/engine/engine_test.go | 69 +- pkg/engine/sqlanalyzer/analyzer.go | 87 +- pkg/engine/sqlanalyzer/analyzer_test.go | 194 ++++ .../sqlanalyzer/attributes/select_core.go | 290 ++++++ .../attributes/select_core_test.go | 296 ++++++ pkg/engine/sqlanalyzer/attributes/table.go | 82 ++ pkg/engine/sqlanalyzer/attributes/types.go | 256 +++++ pkg/engine/sqlanalyzer/clean/clean.go | 69 ++ pkg/engine/sqlanalyzer/clean/errors.go | 34 + pkg/engine/sqlanalyzer/clean/walker.go | 403 ++++++++ pkg/engine/sqlanalyzer/join/joins.go | 5 +- pkg/engine/sqlanalyzer/mutativity.go | 2 +- pkg/engine/sqlanalyzer/order/analyzer.go | 51 +- pkg/engine/sqlanalyzer/order/order_test.go | 36 +- pkg/engine/sqlanalyzer/order/visitors.go | 9 +- pkg/engine/sqlanalyzer/utils/utils.go | 32 + pkg/engine/sqlparser/tree/collate.go | 29 +- pkg/engine/sqlparser/tree/insert.go | 9 + pkg/engine/sqlparser/tree/join-clause.go | 12 + pkg/engine/sqlparser/tree/operators.go | 47 + pkg/engine/sqlparser/tree/order-by.go | 34 +- pkg/engine/sqlparser/tree/select.go | 18 + pkg/engine/sqlparser/tree/update.go | 33 +- pkg/engine/sqlparser/tree/upsert.go | 11 + pkg/engine/sqlparser/tree/utils_test.go | 7 - pkg/engine/sqlparser/tree/visitor.go | 591 ------------ pkg/engine/sqlparser/tree/walker.go | 913 ++++++++++++++++++ pkg/engine/types/clean.go | 16 - pkg/engine/types/foreign_key.go | 23 + pkg/engine/types/index.go | 15 +- pkg/engine/types/table.go | 80 +- pkg/engine/types/testdata/tables.go | 173 ++++ pkg/sessions/session.go | 2 +- pkg/utils/order/order.go | 53 +- 43 files changed, 3570 insertions(+), 787 deletions(-) create mode 100644 pkg/engine/sqlanalyzer/analyzer_test.go create mode 100644 pkg/engine/sqlanalyzer/attributes/select_core.go create mode 100644 pkg/engine/sqlanalyzer/attributes/select_core_test.go create mode 100644 pkg/engine/sqlanalyzer/attributes/table.go create mode 100644 pkg/engine/sqlanalyzer/attributes/types.go create mode 100644 pkg/engine/sqlanalyzer/clean/clean.go create mode 100644 pkg/engine/sqlanalyzer/clean/errors.go create mode 100644 pkg/engine/sqlanalyzer/clean/walker.go create mode 100644 pkg/engine/sqlanalyzer/utils/utils.go delete mode 100644 pkg/engine/sqlparser/tree/visitor.go create mode 100644 pkg/engine/types/testdata/tables.go diff --git a/go.mod b/go.mod index f9cd645e0..605612ef6 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/decred/dcrd/certgen v1.1.2 github.com/dgraph-io/badger/v3 v3.2103.5 github.com/ethereum/go-ethereum v1.12.0 + github.com/google/go-cmp v0.5.9 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-rc.5 github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 github.com/joho/godotenv v1.5.1 @@ -117,7 +118,6 @@ require ( github.com/google/btree v1.1.2 // indirect github.com/google/flatbuffers v1.12.1 // indirect github.com/google/gnostic v0.5.7-v3refs // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/orderedcode v0.0.1 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect diff --git a/pkg/engine/dataset/data_test.go b/pkg/engine/dataset/data_test.go index 7aa0a1d0d..396e35f40 100644 --- a/pkg/engine/dataset/data_test.go +++ b/pkg/engine/dataset/data_test.go @@ -38,11 +38,11 @@ var ( }, { Type: types.MIN_LENGTH, - Value: 5, + Value: "5", }, { Type: types.MAX_LENGTH, - Value: 32, + Value: "32", }, }, }, @@ -55,11 +55,11 @@ var ( }, { Type: types.MIN, - Value: 13, + Value: "13", }, { Type: types.MAX, - Value: 200, + Value: "200", }, }, }, diff --git a/pkg/engine/db/db.go b/pkg/engine/db/db.go index 2d9606143..5ccc78bbf 100644 --- a/pkg/engine/db/db.go +++ b/pkg/engine/db/db.go @@ -13,14 +13,18 @@ import ( "io" "sync" + "github.com/kwilteam/kwil-db/pkg/log" + "go.uber.org/zap" + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer" - "github.com/kwilteam/kwil-db/pkg/engine/sqlparser" "github.com/kwilteam/kwil-db/pkg/sql" ) type DB struct { Sqldb SqlDB + log log.Logger + // caches metadata from QueryUnsafe. // this is a really bad practice, but works. // essentially, we cache the metadata the first time it is retrieved, during schema @@ -30,6 +34,23 @@ type DB struct { mu sync.RWMutex } +// NewDB wraps the provided database with a DB abstraction layer. +// It will initialize the metadata table if it does not exist. +func NewDB(ctx context.Context, sqldb SqlDB, opts ...DBOpt) (*DB, error) { + db := &DB{ + Sqldb: sqldb, + metadataCache: make(map[metadataType][]*metadata), + log: log.NewNoOp(), + } + + err := db.initMetadataTable(ctx) + if err != nil { + return nil, err + } + + return db, nil +} + func (d *DB) Close() error { return d.Sqldb.Close() } @@ -39,41 +60,28 @@ func (d *DB) Delete() error { } func (d *DB) Prepare(ctx context.Context, query string) (*PreparedStatement, error) { - ast, err := sqlparser.Parse(query) - if err != nil { - return nil, err - } - tables, err := d.ListTables(ctx) if err != nil { return nil, err } - err = sqlanalyzer.ApplyRules(ast, sqlanalyzer.AllRules, &sqlanalyzer.RuleMetadata{ + analyzed, err := sqlanalyzer.ApplyRules(query, sqlanalyzer.AllRules, &sqlanalyzer.RuleMetadata{ Tables: tables, }) if err != nil { + d.log.Debug("failed to analyze query", zap.String("query", query), zap.Error(err)) return nil, err } - mutativity, err := sqlanalyzer.IsMutative(ast) - if err != nil { - return nil, err - } - - generatedSql, err := ast.ToSQL() - if err != nil { - return nil, err - } - - prepStmt, err := d.Sqldb.Prepare(generatedSql) + prepStmt, err := d.Sqldb.Prepare(analyzed.Statement()) if err != nil { + d.log.Error("failed to prepare analyzed statement", zap.String("query", query), zap.Error(err)) return nil, err } return &PreparedStatement{ Statement: prepStmt, - mutative: mutativity, + mutative: analyzed.Mutative(), }, nil } @@ -85,20 +93,6 @@ func (d *DB) Savepoint() (sql.Savepoint, error) { return d.Sqldb.Savepoint() } -func NewDB(ctx context.Context, sqldb SqlDB) (*DB, error) { - db := &DB{ - Sqldb: sqldb, - metadataCache: make(map[metadataType][]*metadata), - } - - err := db.initMetadataTable(ctx) - if err != nil { - return nil, err - } - - return db, nil -} - func (d *DB) CreateSession() (sql.Session, error) { return d.Sqldb.CreateSession() } @@ -106,3 +100,11 @@ func (d *DB) CreateSession() (sql.Session, error) { func (d *DB) ApplyChangeset(changeset io.Reader) error { return d.Sqldb.ApplyChangeset(changeset) } + +type DBOpt func(*DB) + +func WithLogger(logger log.Logger) DBOpt { + return func(db *DB) { + db.log = logger + } +} diff --git a/pkg/engine/db/db_test.go b/pkg/engine/db/db_test.go index fe82f8f1f..4f087d23a 100644 --- a/pkg/engine/db/db_test.go +++ b/pkg/engine/db/db_test.go @@ -2,13 +2,13 @@ package db_test import ( "context" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/kwilteam/kwil-db/pkg/engine/db" "github.com/kwilteam/kwil-db/pkg/engine/db/test" "github.com/kwilteam/kwil-db/pkg/engine/types" - "github.com/kwilteam/kwil-db/pkg/engine/types/testdata" "github.com/kwilteam/kwil-db/pkg/sql" "github.com/stretchr/testify/assert" ) @@ -41,11 +41,16 @@ func Test_CreateTables(t *testing.T) { { name: "create 2 tables", test: func(t *testing.T, datastore *db.DB) { + psts := tblPosts + _ = psts + usrs := tblUsers + _ = usrs + ctx := context.Background() - err := datastore.CreateTable(ctx, &testdata.Table_users) + err := datastore.CreateTable(ctx, tblUsers) assert.NoError(t, err) - err = datastore.CreateTable(ctx, &testdata.Table_posts) + err = datastore.CreateTable(ctx, tblPosts) assert.NoError(t, err) tbls, err := datastore.ListTables(ctx) @@ -55,11 +60,12 @@ func Test_CreateTables(t *testing.T) { containsUsers := false containsPosts := false for _, tbl := range tbls { - if reflect.DeepEqual(tbl, &testdata.Table_users) { + + if deepEqual(tbl, tblUsers) { containsUsers = true } - if reflect.DeepEqual(tbl, &testdata.Table_posts) { + if deepEqual(tbl, tblPosts) { containsPosts = true } } @@ -74,10 +80,10 @@ func Test_CreateTables(t *testing.T) { ctx := context.Background() - err := datastore.StoreProcedure(ctx, &testdata.Procedure_create_user) + err := datastore.StoreProcedure(ctx, procedureCreateUser) assert.NoError(t, err) - err = datastore.StoreProcedure(ctx, &testdata.Procedure_create_post) + err = datastore.StoreProcedure(ctx, procedureCreatePost) assert.NoError(t, err) procs, err := datastore.ListProcedures(ctx) @@ -87,11 +93,11 @@ func Test_CreateTables(t *testing.T) { containsGetUser := false containsGetPost := false for _, proc := range procs { - if reflect.DeepEqual(proc, &testdata.Procedure_create_user) { + if deepEqual(proc, procedureCreateUser) { containsGetUser = true } - if reflect.DeepEqual(proc, &testdata.Procedure_create_post) { + if deepEqual(proc, procedureCreatePost) { containsGetPost = true } } @@ -119,11 +125,11 @@ func Test_CreateTables(t *testing.T) { containsExt1 := false containsExt2 := false for _, ext := range exts { - if reflect.DeepEqual(ext, testExt1) { + if deepEqual(ext, testExt1) { containsExt1 = true } - if reflect.DeepEqual(ext, testExt2) { + if deepEqual(ext, testExt2) { containsExt2 = true } } @@ -150,9 +156,9 @@ func Test_CreateTables(t *testing.T) { } } -var defaultTables = []*types.Table{ - &testdata.Table_users, - &testdata.Table_posts, +// deepEqual does a deep comparison, while considering empty slices as equal to nils. +func deepEqual(a, b any) bool { + return cmp.Equal(a, b, cmpopts.EquateEmpty()) } func Test_Prepare(t *testing.T) { @@ -195,7 +201,11 @@ func Test_Prepare(t *testing.T) { } _, err = datastore.Prepare(ctx, tt.statement) - assert.Equal(t, tt.wantErr, err != nil) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } }) } } @@ -230,3 +240,151 @@ func (b *baseMockDatastore) Savepoint() (sql.Savepoint, error) { func (b *baseMockDatastore) TableExists(ctx context.Context, table string) (bool, error) { return false, nil } + +var ( + tblUsers = &types.Table{ + Name: "users", + Columns: []*types.Column{ + { + Name: "id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.PRIMARY_KEY, + }, + }, + }, + { + Name: "name", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.MIN_LENGTH, + Value: "3", + }, + { + Type: types.MAX_LENGTH, + Value: "255", + }, + { + Type: types.UNIQUE, + }, + }, + }, + { + Name: "age", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.MIN, + Value: "0", + }, + { + Type: types.MAX, + Value: "150", + }, + }, + }, + }, + Indexes: []*types.Index{ + { + Name: "age_index", + Columns: []string{ + "age", + }, + Type: types.BTREE, + }, + }, + } + + tblPosts = &types.Table{ + Name: "posts", + Columns: []*types.Column{ + { + Name: "id1", + Type: types.INT, + }, + { + Name: "id2", // doing this to check composite primary keys + Type: types.INT, + }, + { + Name: "title", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + { + Name: "author_id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + }, + Indexes: []*types.Index{ + { + Name: "primary_key", + Columns: []string{ + "id1", + "id2", + }, + Type: types.PRIMARY, + }, + }, + ForeignKeys: []*types.ForeignKey{ + { + ChildKeys: []string{"author_id"}, + ParentKeys: []string{"id"}, + ParentTable: "users", + Actions: []*types.ForeignKeyAction{ + { + On: types.ON_UPDATE, + Do: types.DO_CASCADE, + }, + }, + }, + }, + } +) + +var defaultTables = []*types.Table{ + tblUsers, + tblPosts, +} + +var ( + procedureCreateUser = &types.Procedure{ + Name: "create_user", + Args: []string{"$id", "$name", "$age"}, + Public: false, + Modifiers: []types.Modifier{types.ModifierAuthenticated}, + Statements: []string{ + "INSERT INTO users (id, name, age) VALUES ($id, $name, $age);", + }, + } + + procedureCreatePost = &types.Procedure{ + Name: "create_post", + Args: []string{"$id1", "$id2", "$title", "$author_id"}, + Public: true, + Modifiers: []types.Modifier{ + types.ModifierAuthenticated, + types.ModifierOwner, + }, + Statements: []string{ + "INSERT INTO posts (id1, id2, title, author_id) VALUES ($id1, $id2, $title, $author_id);", + }, + } +) diff --git a/pkg/engine/db/metadata.go b/pkg/engine/db/metadata.go index 210d8f1ce..28c2e39e2 100644 --- a/pkg/engine/db/metadata.go +++ b/pkg/engine/db/metadata.go @@ -2,10 +2,11 @@ package db import ( "context" - "encoding/json" + "fmt" "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/serialize" ) type metadata struct { @@ -116,7 +117,7 @@ func (d *DB) getMetadata(ctx context.Context, metaType metadataType) ([]*metadat // VersionedMetadata is a generic that wraps a serializable type with a version type VersionedMetadata struct { - Version int `json:"version"` + Version uint `json:"version"` Data []byte `json:"data"` } @@ -129,7 +130,7 @@ func (d *DB) getVersionedMetadata(ctx context.Context, metaType metadataType) ([ var versionedMetas []*VersionedMetadata for _, meta := range metas { versionedMeta := &VersionedMetadata{} - err = json.Unmarshal(meta.Content, versionedMeta) + err = serialize.DecodeInto(meta.Content, versionedMeta) if err != nil { return nil, err } @@ -141,7 +142,7 @@ func (d *DB) getVersionedMetadata(ctx context.Context, metaType metadataType) ([ } func (d *DB) persistVersionedMetadata(ctx context.Context, identifier string, metaType metadataType, meta *VersionedMetadata) error { - bts, err := json.Marshal(meta) + bts, err := serialize.Encode(meta) if err != nil { return err } @@ -155,7 +156,7 @@ func (d *DB) persistVersionedMetadata(ctx context.Context, identifier string, me // serializable is an interface and generic that all serializable types must implement type serializable interface { - types.Table | types.Procedure | types.Extension + types.Table | types.Procedure | encodeableExtension } func decodeMetadata[T serializable](meta []*VersionedMetadata) ([]*T, error) { @@ -164,7 +165,7 @@ func decodeMetadata[T serializable](meta []*VersionedMetadata) ([]*T, error) { for _, value := range meta { tbl := new(T) - err := json.Unmarshal(value.Data, tbl) + err := serialize.DecodeInto(value.Data, tbl) if err != nil { return nil, err } diff --git a/pkg/engine/db/persist.go b/pkg/engine/db/persist.go index b8a562f70..98e2ec200 100644 --- a/pkg/engine/db/persist.go +++ b/pkg/engine/db/persist.go @@ -2,9 +2,10 @@ package db import ( "context" - "encoding/json" "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/serialize" + "github.com/kwilteam/kwil-db/pkg/utils/order" ) const ( @@ -15,7 +16,7 @@ const ( // persistTableMetadata persists the metadata for a table to the database func (d *DB) persistTableMetadata(ctx context.Context, table *types.Table) error { - bts, err := json.Marshal(table) + bts, err := serialize.Encode(table) if err != nil { return err } @@ -38,7 +39,7 @@ func (d *DB) ListTables(ctx context.Context) ([]*types.Table, error) { // StoreProcedure stores a procedure in the database func (d *DB) StoreProcedure(ctx context.Context, procedure *types.Procedure) error { - bts, err := json.Marshal(procedure) + bts, err := serialize.Encode(procedure) if err != nil { return err } @@ -61,7 +62,11 @@ func (d *DB) ListProcedures(ctx context.Context) ([]*types.Procedure, error) { // StoreExtension stores an extension in the database func (d *DB) StoreExtension(ctx context.Context, extension *types.Extension) error { - bts, err := json.Marshal(extension) + bts, err := serialize.Encode(&encodeableExtension{ + Name: extension.Name, + Initialization: order.OrderMap(extension.Initialization), + Alias: extension.Alias, + }) if err != nil { return err } @@ -72,6 +77,16 @@ func (d *DB) StoreExtension(ctx context.Context, extension *types.Extension) err }) } +// encodeableExtension is a modification of the extension struct that can be encoded +// using rlp. This is because the extension struct contains a map[string]string and +// since maps cannot be rlp encoded, we need to convert the map[string]string to a slice +// of key value pairs +type encodeableExtension struct { + Name string + Initialization []*order.KVPair[string, string] + Alias string +} + // ListExtensions lists all extensions in the database func (d *DB) ListExtensions(ctx context.Context) ([]*types.Extension, error) { meta, err := d.getVersionedMetadata(ctx, metadataTypeExtension) @@ -79,5 +94,19 @@ func (d *DB) ListExtensions(ctx context.Context) ([]*types.Extension, error) { return nil, err } - return decodeMetadata[types.Extension](meta) + encodeable, err := decodeMetadata[encodeableExtension](meta) + if err != nil { + return nil, err + } + + var extensions []*types.Extension + for _, ext := range encodeable { + extensions = append(extensions, &types.Extension{ + Name: ext.Name, + Initialization: order.ToMap(ext.Initialization), + Alias: ext.Alias, + }) + } + + return extensions, nil } diff --git a/pkg/engine/db/procedures.go b/pkg/engine/db/procedures.go index e9074074c..0ab8bb93d 100644 --- a/pkg/engine/db/procedures.go +++ b/pkg/engine/db/procedures.go @@ -1,10 +1,10 @@ package db import ( - "encoding/json" "fmt" "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/serialize" ) /* @@ -25,9 +25,9 @@ func decodeVersionedProcedures(meta []*VersionedMetadata) ([]*types.Procedure, e return procedures, nil } -func decodeProcedure(version int, procedureBytes []byte) (*types.Procedure, error) { +func decodeProcedure(version uint, procedureBytes []byte) (*types.Procedure, error) { procedure := &types.Procedure{} - err := json.Unmarshal(procedureBytes, procedure) + err := serialize.DecodeInto(procedureBytes, procedure) if err != nil { return nil, err } diff --git a/pkg/engine/db/sql-ddl-generator/generate_test.go b/pkg/engine/db/sql-ddl-generator/generate_test.go index cd74221ef..b7f361348 100644 --- a/pkg/engine/db/sql-ddl-generator/generate_test.go +++ b/pkg/engine/db/sql-ddl-generator/generate_test.go @@ -330,6 +330,48 @@ func TestGenerateDDL(t *testing.T) { } } +// there used to be a bug where the DDL generator would edit a table's primary key index, +// if one existed. It would add an extra '\"' to the beginning and end of each column name. +func Test_PrimaryIndexModification(t *testing.T) { + testTable := &types.Table{ + Name: "test", + Columns: []*types.Column{ + { + Name: "id1", + Type: types.INT, + }, + { + Name: "id2", // doing this to check composite primary keys + Type: types.INT, + }, + }, + Indexes: []*types.Index{ + { + Name: "primary", + Columns: []string{ + "id1", + "id2", + }, + Type: types.PRIMARY, + }, + }, + } + + _, err := sqlddlgenerator.GenerateDDL(testTable) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // check that the primary key index was not modified + if testTable.Indexes[0].Columns[0] != "id1" { + t.Errorf("primary key index was modified. Expected 'id1', got '%s'", testTable.Indexes[0].Columns[0]) + } + + if testTable.Indexes[0].Columns[1] != "id2" { + t.Errorf("primary key index was modified. Expected 'id2', got '%s'", testTable.Indexes[0].Columns[1]) + } +} + func removeWhitespace(s string) string { return strings.Map(func(r rune) rune { if unicode.IsSpace(r) { diff --git a/pkg/engine/db/upgrade_test.go b/pkg/engine/db/upgrade_test.go index ca5733b64..24464215d 100644 --- a/pkg/engine/db/upgrade_test.go +++ b/pkg/engine/db/upgrade_test.go @@ -2,12 +2,12 @@ package db_test import ( "context" - "encoding/json" "io" "testing" "github.com/kwilteam/kwil-db/pkg/engine/db" "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/serialize" "github.com/kwilteam/kwil-db/pkg/sql" "github.com/stretchr/testify/assert" ) @@ -137,7 +137,7 @@ func (m procedureStore) Query(ctx context.Context, query string, args map[string returnVals := []map[string]interface{}{} for _, procedure := range m.procedures { - serializedProc, err := json.Marshal(procedure.Procedure) + serializedProc, err := serialize.Encode(procedure.Procedure) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func (m procedureStore) Query(ctx context.Context, query string, args map[string Data: serializedProc, } - contentBytes, err := json.Marshal(dbVersionedProcedure) + contentBytes, err := serialize.Encode(dbVersionedProcedure) if err != nil { return nil, err } @@ -174,6 +174,6 @@ func (m *procedureStore) CreateSession() (sql.Session, error) { } type versionedProcedure struct { - Version int + Version uint Procedure *types.Procedure } diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go index 0a4a5e978..591f54538 100644 --- a/pkg/engine/engine_test.go +++ b/pkg/engine/engine_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/kwilteam/kwil-db/pkg/crypto" "github.com/kwilteam/kwil-db/pkg/crypto/addresses" engine "github.com/kwilteam/kwil-db/pkg/engine" @@ -34,13 +36,13 @@ func newTestUser() types.UserIdentifier { var ( testTables = []*types.Table{ - &testdata.Table_users, - &testdata.Table_posts, + testdata.TableUsers, + testdata.TablePosts, } testProcedures = []*types.Procedure{ - &testdata.Procedure_create_user, - &testdata.Procedure_create_post, + procedureCreateUser, + procedureCreatePost, } testInitializedExtensions = []*types.Extension{ @@ -54,7 +56,6 @@ var ( } ) -// TODO: this test is not passing func Test_Open(t *testing.T) { ctx := context.Background() user := newTestUser() @@ -109,11 +110,20 @@ func Test_Open(t *testing.T) { t.Fatal(err) } - assert.ElementsMatch(t, testTables, tables) + for _, table := range tables { + if !deepEqual(table, findTable(table.Name)) { + t.Errorf("tables not equal: %v, %v", table, findTable(table.Name)) + } + } // check if the dataset has the correct procedures procs := dataset.ListProcedures() - assert.ElementsMatch(t, testProcedures, procs) + + for _, proc := range procs { + if !deepEqual(proc, findProc(proc.Name)) { + t.Errorf("procedures not equal: %v, %v", proc, findProc(proc.Name)) + } + } pub, err := user.PubKey() if err != nil { @@ -129,6 +139,26 @@ func Test_Open(t *testing.T) { assert.ElementsMatch(t, []string{"testdb1"}, datasets) } +func findProc(name string) *types.Procedure { + for _, proc := range testProcedures { + if proc.Name == name { + return proc + } + } + + panic("procedure not found") +} + +func findTable(name string) *types.Table { + for _, table := range testTables { + if table.Name == name { + return table + } + } + + panic("table not found") +} + func Test_CreateDataset(t *testing.T) { type execution struct { procedure string @@ -343,3 +373,28 @@ func (m *mockRegister) Unregister(ctx context.Context, name string) error { return nil } + +var ( + procedureCreateUser = &types.Procedure{ + Name: "create_user", + Args: []string{"$id", "$username", "$age"}, + Public: true, + Statements: []string{ + "INSERT INTO users (id, username, age, address) VALUES ($id, $username, $age, @caller);", + }, + } + + procedureCreatePost = &types.Procedure{ + Name: "create_post", + Args: []string{"$id", "$title", "$content", "$date_string"}, + Public: true, + Statements: []string{ + "INSERT INTO posts (id, title, content, author_id, post_date)VALUES ($id, $title, $content, (SELECT id FROM users WHERE address=@caller), $date_string);", + }, + } +) + +// deepEqual does a deep comparison, while considering empty slices as equal to nils. +func deepEqual(a, b any) bool { + return cmp.Equal(a, b, cmpopts.EquateEmpty()) +} diff --git a/pkg/engine/sqlanalyzer/analyzer.go b/pkg/engine/sqlanalyzer/analyzer.go index 74244e406..5c26093ce 100644 --- a/pkg/engine/sqlanalyzer/analyzer.go +++ b/pkg/engine/sqlanalyzer/analyzer.go @@ -4,8 +4,10 @@ import ( "fmt" "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/aggregate" + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/clean" "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/join" "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/order" + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser" "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" "github.com/kwilteam/kwil-db/pkg/engine/types" ) @@ -30,34 +32,63 @@ func (a *acceptWrapper) Accept(walker tree.Walker) (err error) { return a.inner.Accept(walker) } -// ApplyRules analyzes the given statement and returns the statement. -// NOTE: this can change the statement, so it is recommended to clone the statement before analyzing it -// if you want to keep the original statement. -func ApplyRules(stmt accepter, flags VerifyFlag, metadata *RuleMetadata) error { - accept := &acceptWrapper{inner: stmt} +// ApplyRules analyzes the given statement and returns the transformed statement. +// It parses it, and then traverses the AST with the given flags. +// It will alter the statement to make it conform to the given flags, or return an error if it cannot. +func ApplyRules(stmt string, flags VerifyFlag, metadata *RuleMetadata) (*AnalyzedStatement, error) { + copiedMetadata, err := metadata.Clean() + if err != nil { + return nil, fmt.Errorf("error cleaning metadata: %w", err) + } + + parsed, err := sqlparser.Parse(stmt) + if err != nil { + return nil, fmt.Errorf("error parsing statement: %w", err) + } + + accept := &acceptWrapper{inner: parsed} + + clnr := clean.NewStatementCleaner() + err = accept.Accept(clnr) + if err != nil { + return nil, fmt.Errorf("error cleaning statement: %w", err) + } if flags&NoCartesianProduct != 0 { err := accept.Accept(join.NewJoinWalker()) if err != nil { - return fmt.Errorf("error applying join rules: %w", err) + return nil, fmt.Errorf("error applying join rules: %w", err) } } if flags&GuaranteedOrder != 0 { - err := accept.Accept(order.NewOrderWalker(metadata.Tables)) + err := accept.Accept(order.NewOrderWalker(copiedMetadata.Tables)) if err != nil { - return fmt.Errorf("error enforcing guaranteed order: %w", err) + return nil, fmt.Errorf("error enforcing guaranteed order: %w", err) } } if flags&DeterministicAggregates != 0 { err := accept.Accept(aggregate.NewGroupByWalker()) if err != nil { - return fmt.Errorf("error enforcing aggregate determinism: %w", err) + return nil, fmt.Errorf("error enforcing aggregate determinism: %w", err) } } - return nil + mutative, err := isMutative(parsed) + if err != nil { + return nil, fmt.Errorf("error determining mutativity: %w", err) + } + + generated, err := parsed.ToSQL() + if err != nil { + return nil, fmt.Errorf("error generating SQL: %w", err) + } + + return &AnalyzedStatement{ + stmt: generated, + mutative: mutative, + }, nil } // RuleMetadata contains metadata that is needed to enforce a rule @@ -66,6 +97,24 @@ type RuleMetadata struct { Tables []*types.Table } +// Clean copies the tables and cleans them +func (r *RuleMetadata) Clean() (*RuleMetadata, error) { + cleaned := &RuleMetadata{ + Tables: make([]*types.Table, len(r.Tables)), + } + + for i, tbl := range r.Tables { + err := tbl.Clean() + if err != nil { + return nil, fmt.Errorf(`error cleaning table "%s": %w`, tbl.Name, err) + } + + cleaned.Tables[i] = tbl.Copy() + } + + return cleaned, nil +} + type VerifyFlag uint8 const ( @@ -80,3 +129,21 @@ const ( const ( AllRules = NoCartesianProduct | GuaranteedOrder | DeterministicAggregates ) + +// AnalyzedStatement is a statement that has been analyzed by the analyzer +// As we progressively add more types of analysis (e.g. query pricing), we will add more fields to this struct +type AnalyzedStatement struct { + stmt string + mutative bool +} + +// Mutative returns true if the statement will mutate the database +func (a *AnalyzedStatement) Mutative() bool { + return a.mutative +} + +// Statements returns a new statement that is the result of the analysis +// It may contains changes to the original statement, depending on the flags that were passed in +func (a *AnalyzedStatement) Statement() string { + return a.stmt +} diff --git a/pkg/engine/sqlanalyzer/analyzer_test.go b/pkg/engine/sqlanalyzer/analyzer_test.go new file mode 100644 index 000000000..828dcf892 --- /dev/null +++ b/pkg/engine/sqlanalyzer/analyzer_test.go @@ -0,0 +1,194 @@ +package sqlanalyzer_test + +import ( + "testing" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer" + "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/engine/types/testdata" + "github.com/stretchr/testify/assert" +) + +func Test_Analyze(t *testing.T) { + type testCase struct { + name string + stmt string + want string + metadata *sqlanalyzer.RuleMetadata + wantErr bool + } + + tests := []testCase{ + { + name: "simple select", + stmt: "SELECT * FROM users", + want: `SELECT * FROM "users" ORDER BY "users"."id" ASC NULLS LAST;`, + metadata: &sqlanalyzer.RuleMetadata{ + Tables: []*types.Table{ + tblUsers, + }, + }, + }, + { + name: "select with joins and subqueries", + stmt: `SELECT p.id, p.title + FROM posts AS p + INNER JOIN followers AS f ON p.user_id = f.user_id + INNER JOIN users AS u ON u.id = f.user_id + WHERE f.follower_id = ( + SELECT id FROM users WHERE username = $username + ) + ORDER BY date(p.post_date) DESC NULLS LAST + LIMIT 20 OFFSET $offset;`, + want: `SELECT "p"."id", "p"."title" + FROM "posts" AS "p" + INNER JOIN "followers" AS "f" ON "p"."user_id" = "f"."user_id" + INNER JOIN "users" AS "u" ON "u"."id" = "f"."user_id" + WHERE "f"."follower_id" = ( + SELECT "id" FROM "users" WHERE "username" = $username ORDER BY "users"."id" ASC NULLS LAST + ) + ORDER BY date ("p"."post_date") DESC NULLS LAST, + "f"."follower_id" ASC NULLS LAST, "f"."user_id" ASC NULLS LAST, "p"."id" ASC NULLS LAST, "u"."id" ASC NULLS LAST + LIMIT 20 OFFSET $offset;`, + metadata: &sqlanalyzer.RuleMetadata{ + Tables: []*types.Table{ + tblUsers, + tblPosts, + tblFollowers, + }, + }, + }, + { + name: "table joined on self", + stmt: `SELECT u1.id, u1.name, u2.name + FROM users AS u1 + INNER JOIN users AS u2 ON u1.id = u2.id`, + want: `SELECT "u1"."id", "u1"."name", "u2"."name" + FROM "users" AS "u1" + INNER JOIN "users" AS "u2" ON "u1"."id" = "u2"."id" + ORDER BY "u1"."id" ASC NULLS LAST, "u2"."id" ASC NULLS LAST;`, + metadata: &sqlanalyzer.RuleMetadata{ + Tables: []*types.Table{ + tblUsers, + }, + }, + }, + { + name: "common table expression", + stmt: `WITH + users_aged_20 AS ( + SELECT id, username FROM users WHERE age = 20 + ) + SELECT * FROM users_aged_20`, + want: `WITH + "users_aged_20" AS ( + SELECT "users"."id", "users"."username" FROM "users" WHERE "age" = 20 ORDER BY "users"."id" ASC NULLS LAST + ) + SELECT * FROM "users_aged_20" ORDER BY "users_aged_20"."id" ASC NULLS LAST, "users_aged_20"."username" ASC NULLS LAST;`, + metadata: &sqlanalyzer.RuleMetadata{ + Tables: []*types.Table{ + tblUsers, + }, + }, + }, + { + name: "basic insert", + stmt: `INSERT INTO users (id, username, age) VALUES (1, 'user1', 20)`, + want: `INSERT INTO "users" ("id", "username", "age") VALUES (1, 'user1', 20);`, + metadata: &sqlanalyzer.RuleMetadata{ + Tables: []*types.Table{ + tblUsers, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sqlanalyzer.ApplyRules(tt.stmt, sqlanalyzer.AllRules, tt.metadata) + if (err != nil) != tt.wantErr { + t.Errorf("ApplyRules() error = %v, wantErr %v", err, tt.wantErr) + return + } + + assert.Equal(t, removeSpaces(tt.want), removeSpaces(got.Statement())) + }) + } +} + +var ( + tblUsers = testdata.TableUsers + + tblPosts = &types.Table{ + Name: "posts", + Columns: []*types.Column{ + { + Name: "id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.PRIMARY_KEY, + }, + }, + }, + { + Name: "user_id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + { + Name: "title", + Type: types.TEXT, + }, + { + Name: "post_date", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + }, + } + + tblFollowers = &types.Table{ + Name: "followers", + Columns: []*types.Column{ + { + Name: "user_id", + Type: types.INT, + }, + { + Name: "follower_id", + Type: types.INT, + }, + }, + Indexes: []*types.Index{ + { + Name: "primary_key", + Columns: []string{ + "user_id", + "follower_id", + }, + Type: types.PRIMARY, + }, + }, + } +) + +// removeSpaces removes all spaces from a string. +// this is useful for comparing strings, where one is generated +func removeSpaces(s string) string { + var result []rune + for _, ch := range s { + if ch != ' ' && ch != '\n' && ch != '\r' && ch != '\t' { + result = append(result, ch) + } + } + return string(result) +} diff --git a/pkg/engine/sqlanalyzer/attributes/select_core.go b/pkg/engine/sqlanalyzer/attributes/select_core.go new file mode 100644 index 000000000..624350aea --- /dev/null +++ b/pkg/engine/sqlanalyzer/attributes/select_core.go @@ -0,0 +1,290 @@ +/* +Package attributes analyzes a returned relations attributes, maintaining order. +This is useful for determining the relation schema that a query / CTE returns. + +For example, given the following query: + + WITH satoshi_posts AS ( + SELECT id, title, content FROM posts + WHERE user_id = ( + SELECT id FROM users WHERE username = 'satoshi' LIMIT 1 + ) + ) + SELECT id, title FROM satoshi_posts; + +The attributes package will be able to determine that: + 1. The result of this query is a relation with two attributes: id and title + 2. The result of the common table expression satoshi_posts is a relation with three attributes: id, title, and content +*/ +package attributes + +import ( + "fmt" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" + "github.com/kwilteam/kwil-db/pkg/engine/types" +) + +// A RelationAttribute is a column or expression that is part of a relation. +// It contains the logical representation of the attribute, as well as the +// data type +type RelationAttribute struct { + // ResultExpression is the expression that represents the attribute + // This can be things like "column_name", "table"."column_name", "sum(column_name)"", "5", etc. + ResultExpression *tree.ResultColumnExpression + + // Type is the data type of the attribute + Type types.DataType +} + +// GetSelectCoreRelationAttributes will analyze the select core and return the +// identified relation. +// It returns a list of result column expressions. These can be things like: +// tbl1.col, col, col AS alias, col*5 AS alias, etc. +// If a statement has "SELECT * FROM tbl", +// then the result column expressions will be tbl.col_1, tbl.col_2, etc. +func GetSelectCoreRelationAttributes(selectCore *tree.SelectCore, tables []*types.Table) ([]*RelationAttribute, error) { + walker := newSelectCoreWalker(tables) + err := selectCore.Accept(walker) + if err != nil { + return nil, fmt.Errorf("error analyzing select core: %w", err) + } + + return walker.detectedAttributes, nil +} + +func newSelectCoreWalker(tables []*types.Table) *selectCoreAnalyzer { + return &selectCoreAnalyzer{ + Walker: tree.NewBaseWalker(), + context: newSelectCoreContext(nil), + schemaTables: tables, + detectedAttributes: []*RelationAttribute{}, + } +} + +// selectCoreAnalyzer will walk the tree and identify the returned attributes for the select core +type selectCoreAnalyzer struct { + tree.Walker + context *selectCoreContext + schemaTables []*types.Table + + // detectedAttributes is a list of the detected attributes + // from the scope + detectedAttributes []*RelationAttribute +} + +// newScope creates a new scope for the select core +// it sets the parent scope to the current scope +func (s *selectCoreAnalyzer) newScope() { + oldCtx := s.context + s.context = newSelectCoreContext(oldCtx) +} + +// oldScope pops the current scope and returns to the parent scope +// if there is no parent scope, it simply sets the current scope to nil +func (s *selectCoreAnalyzer) oldScope() { + if s.context == nil { + panic("oldScope called with no current scope") + } + if s.context.parent == nil { + s.context = nil + return + } + + s.context = s.context.parent +} + +type selectCoreContext struct { + // Parent is the parent context + parent *selectCoreContext + + // results is the ordered list of query results + results []tree.ResultColumn + + // usedTables is a list of tables used in the select core + usedTables []*types.Table +} + +// relations returns the identified relations +// it will expand the stars and table stars to the list of columns +func (s *selectCoreContext) relations() ([]*RelationAttribute, error) { + results := make([]*RelationAttribute, 0) + + for _, res := range s.results { + exprs, err := s.evaluateResult(res) + if err != nil { + return nil, err + } + + results = append(results, exprs...) + } + + return results, nil +} + +// addResult adds a result to the list of results +func (s *selectCoreContext) addResult(result tree.ResultColumn) { + s.results = append(s.results, result) +} + +// evaluateResult evaluates a result, returning it as a column expression +func (s *selectCoreContext) evaluateResult(result tree.ResultColumn) ([]*RelationAttribute, error) { + results := []*RelationAttribute{} + + switch r := result.(type) { + default: + panic(fmt.Sprintf("unknown result type: %T", r)) + case *tree.ResultColumnExpression: + copied := *r + + dataType, err := predictReturnType(r.Expression, s.usedTables) + if err != nil { + return nil, err + } + + if len(s.usedTables) > 0 { + err := addTableIfNotPresent(s.usedTables[0].Name, &copied) + if err != nil { + return nil, err + } + } + + results = append(results, &RelationAttribute{ + ResultExpression: &copied, + Type: dataType, + }) + case *tree.ResultColumnStar: + for _, tbl := range s.usedTables { + for _, col := range tbl.Columns { + results = append(results, &RelationAttribute{ + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionColumn{ + Table: tbl.Name, + Column: col.Name, + }, + }, + Type: col.Type, + }) + } + } + case *tree.ResultColumnTable: + tbl, err := findTable(s.usedTables, r.TableName) + if err != nil { + return nil, err + } + + for _, col := range tbl.Columns { + results = append(results, &RelationAttribute{ + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionColumn{ + Table: tbl.Name, + Column: col.Name, + }, + }, + Type: col.Type, + }) + } + } + + return results, nil +} + +// newSelectCoreContext creates a new select core context +func newSelectCoreContext(parent *selectCoreContext) *selectCoreContext { + return &selectCoreContext{ + parent: parent, + } +} + +// EnterSelectCore creates a new scope. +func (s *selectCoreAnalyzer) EnterSelectCore(node *tree.SelectCore) error { + s.newScope() + + return nil +} + +// ExitSelectCore pops the current scope. +func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SelectCore) error { + var err error + s.detectedAttributes, err = s.context.relations() + if err != nil { + return err + } + + s.oldScope() + return nil +} + +// EnterTableOrSubqueryTable adds the table to the list of used tables. +func (s *selectCoreAnalyzer) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTable) error { + tbl, err := findTable(s.schemaTables, node.Name) + if err != nil { + return err + } + + identifier := node.Name + if node.Alias != "" { + identifier = node.Alias + } + + s.context.usedTables = append(s.context.usedTables, &types.Table{ + Name: identifier, + Columns: tbl.Columns, + Indexes: tbl.Indexes, + ForeignKeys: tbl.ForeignKeys, + }) + + return nil +} + +// EnterResultColumnExpression adds the result column expression to the list of attributes +func (s *selectCoreAnalyzer) EnterResultColumnExpression(node *tree.ResultColumnExpression) error { + s.context.addResult(node) + return nil +} + +// EnterResultColumnStar adds the result column expression to the list of attributes +func (s *selectCoreAnalyzer) EnterResultColumnStar(node *tree.ResultColumnStar) error { + s.context.addResult(node) + return nil +} + +// EnterResultColumnTable adds the result column expression to the list of attributes +func (s *selectCoreAnalyzer) EnterResultColumnTable(node *tree.ResultColumnTable) error { + s.context.addResult(node) + return nil +} + +// findTable finds a table by name +func findTable(tables []*types.Table, name string) (*types.Table, error) { + for _, t := range tables { + if t.Name == name { + return t, nil + } + } + + return nil, fmt.Errorf(`table "%s" not found`, name) +} + +// findColumn finds a column by name +func findColumn(columns []*types.Column, name string) (*types.Column, error) { + for _, c := range columns { + if c.Name == name { + return c, nil + } + } + + return nil, fmt.Errorf(`column "%s" not found`, name) +} + +// addTableIfNotPresent adds the table name to the column if it is not already present. +func addTableIfNotPresent(tableName string, expr tree.Accepter) error { + return expr.Accept(&tree.ImplementedWalker{ + FuncEnterExpressionColumn: func(col *tree.ExpressionColumn) error { + if col.Table == "" { + col.Table = tableName + } + return nil + }, + }) +} diff --git a/pkg/engine/sqlanalyzer/attributes/select_core_test.go b/pkg/engine/sqlanalyzer/attributes/select_core_test.go new file mode 100644 index 000000000..d98cc4e34 --- /dev/null +++ b/pkg/engine/sqlanalyzer/attributes/select_core_test.go @@ -0,0 +1,296 @@ +package attributes_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/attributes" + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser" + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" + "github.com/kwilteam/kwil-db/pkg/engine/types" + "github.com/kwilteam/kwil-db/pkg/engine/types/testdata" + "github.com/stretchr/testify/assert" +) + +func TestGetSelectCoreRelationAttributes(t *testing.T) { + tests := []struct { + name string + tables []*types.Table + stmt string + want []*attributes.RelationAttribute + resultTableCols []*types.Column + // wantInequality is true if we want the test to fail if the result is equal to want + wantInequality bool + wantErr bool + }{ + { + name: "simple select", + tables: []*types.Table{ + testdata.TableUsers, + }, + stmt: "SELECT id FROM users", + want: []*attributes.RelationAttribute{ + tblCol(types.INT, "users", "id"), + }, + resultTableCols: []*types.Column{ + col("id", types.INT), + }, + }, + { + name: "simple select - failure", + tables: []*types.Table{ + testdata.TableUsers, + }, + stmt: "SELECT id FROM users", + want: []*attributes.RelationAttribute{ + tblCol(types.TEXT, "users", "name"), + }, + wantInequality: true, + }, + { + name: "simple select with alias", + tables: []*types.Table{ + testdata.TableUsers, + }, + stmt: "SELECT id AS user_id FROM users", + want: []*attributes.RelationAttribute{ + tblColAlias(types.INT, "users", "id", "user_id"), + }, + resultTableCols: []*types.Column{ + col("user_id", types.INT), + }, + }, + { + name: "test subquery is ignored", + tables: []*types.Table{ + testdata.TableUsers, + testdata.TablePosts, + }, + stmt: "SELECT id FROM users WHERE id IN (SELECT author_id FROM posts)", + want: []*attributes.RelationAttribute{ + tblCol(types.INT, "users", "id"), + }, + resultTableCols: []*types.Column{ + col("id", types.INT), + }, + }, + { + name: "test star, table star works", + tables: []*types.Table{ + testdata.TableUsers, + }, + stmt: "SELECT users.*, * FROM users", + want: []*attributes.RelationAttribute{ + tblCol(types.INT, "users", "id"), // we expect them twice since it is defined twice + tblCol(types.TEXT, "users", "username"), + tblCol(types.INT, "users", "age"), + tblCol(types.TEXT, "users", "address"), + tblCol(types.INT, "users", "id"), + tblCol(types.TEXT, "users", "username"), + tblCol(types.INT, "users", "age"), + tblCol(types.TEXT, "users", "address"), + }, + resultTableCols: []*types.Column{ + col("id", types.INT), + col("username", types.TEXT), + col("age", types.INT), + col("address", types.TEXT), + col("id:1", types.INT), + col("username:1", types.TEXT), + col("age:1", types.INT), + col("address:1", types.TEXT), + }, + }, + { + name: "test star, table star, literal, untabled column, and tabled column with alias work and join", + tables: []*types.Table{ + testdata.TableUsers, + testdata.TablePosts, + }, + stmt: "SELECT users.*, *, age, users.age AS the_age, 5 as the_literal_5 FROM users INNER JOIN posts ON users.id = posts.author_id", + want: []*attributes.RelationAttribute{ + // all user columns from users.* + tblCol(types.INT, "users", "id"), + tblCol(types.TEXT, "users", "username"), + tblCol(types.INT, "users", "age"), + tblCol(types.TEXT, "users", "address"), + + // all user columns from * + tblCol(types.INT, "users", "id"), + tblCol(types.TEXT, "users", "username"), + tblCol(types.INT, "users", "age"), + tblCol(types.TEXT, "users", "address"), + + // all post columns from * + tblCol(types.INT, "posts", "id"), + tblCol(types.TEXT, "posts", "title"), + tblCol(types.TEXT, "posts", "content"), + tblCol(types.INT, "posts", "author_id"), + tblCol(types.TEXT, "posts", "post_date"), + + // age + tblCol(types.INT, "users", "age"), + + // users.age AS the_age + tblColAlias(types.INT, "users", "age", "the_age"), + + // 5 + literal(types.INT, "5", "the_literal_5"), + }, + resultTableCols: []*types.Column{ + col("id", types.INT), + col("username", types.TEXT), + col("age", types.INT), + col("address", types.TEXT), + col("id:1", types.INT), + col("username:1", types.TEXT), + col("age:1", types.INT), + col("address:1", types.TEXT), + col("id:2", types.INT), + col("title", types.TEXT), + col("content", types.TEXT), + col("author_id", types.INT), + col("post_date", types.TEXT), + col("age:2", types.INT), + col("the_age", types.INT), + col("the_literal_5", types.INT), + }, + }, + { + name: "join with aliases", + tables: []*types.Table{ + testdata.TableUsers, + testdata.TablePosts, + }, + stmt: "SELECT u.id AS user_id, u.username AS username, count(p.id) AS post_count FROM users AS u LEFT JOIN posts AS p ON u.id = p.author_id GROUP BY u.id", + want: []*attributes.RelationAttribute{ + tblColAlias(types.INT, "u", "id", "user_id"), + tblColAlias(types.TEXT, "u", "username", "username"), + { + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionFunction{ + Function: &tree.FunctionCOUNT, + Inputs: []tree.Expression{ + &tree.ExpressionColumn{ + Table: "p", + Column: "id", + }, + }, + }, + Alias: "post_count", + }, + Type: types.INT, + }, + }, + resultTableCols: []*types.Column{ + col("user_id", types.INT), + col("username", types.TEXT), + col("post_count", types.INT), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, err := sqlparser.Parse(tt.stmt) + if err != nil { + t.Errorf("GetSelectCoreRelationAttributes() error = %v", err) + return + } + selectStmt, okj := ast.(*tree.Select) + if !okj { + t.Errorf("test case %s is not a select statement", tt.name) + return + } + + got, err := attributes.GetSelectCoreRelationAttributes(selectStmt.SelectStmt.SelectCores[0], tt.tables) + if (err != nil) != tt.wantErr { + t.Errorf("GetSelectCoreRelationAttributes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got) != len(tt.want) { + t.Errorf("Invalid length. GetSelectCoreRelationAttributes() got = %v, want %v", got, tt.want) + return + } + + same := true + incorrectIdx := -1 + for i, g := range got { + if !cmp.Equal(*g, *tt.want[i], cmpopts.IgnoreUnexported(tree.ResultColumnExpression{}, tree.AnySQLFunction{}, tree.AggregateFunc{}, tree.ExpressionFunction{}, tree.ExpressionColumn{}, tree.ExpressionLiteral{}), cmpopts.EquateEmpty()) { + same = false + incorrectIdx = i + break + } + } + + if same != !tt.wantInequality { + t.Errorf("GetSelectCoreRelationAttributes() got = %v, want %v", got, tt.want) + if incorrectIdx != -1 { + t.Errorf("Incorrect index: %d", incorrectIdx) + } + } + + if tt.wantInequality { + return + } + + genTable, err := attributes.TableFromAttributes("result_table", got, true) + if err != nil { + t.Errorf("GetSelectCoreRelationAttributes() error = %v", err) + return + } + // check that the auto primary key works + assert.Equal(t, len(tt.want), len(genTable.Indexes[0].Columns)) + + // check that the columns are correct + if !cmp.Equal(tt.resultTableCols, genTable.Columns, cmpopts.IgnoreSliceElements(func(v int) bool { return true })) { + t.Errorf("GetSelectCoreRelationAttributes() got = %v, want %v", got, tt.want) + return + } + }) + } +} + +func tblCol(dataType types.DataType, tbl, column string) *attributes.RelationAttribute { + return &attributes.RelationAttribute{ + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionColumn{ + Table: tbl, + Column: column, + }, + }, + Type: dataType, + } +} + +func tblColAlias(dataType types.DataType, tbl, column, alias string) *attributes.RelationAttribute { + return &attributes.RelationAttribute{ + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionColumn{ + Table: tbl, + Column: column, + }, + Alias: alias, + }, + Type: dataType, + } +} + +func literal(dataType types.DataType, lit string, alias string) *attributes.RelationAttribute { + return &attributes.RelationAttribute{ + ResultExpression: &tree.ResultColumnExpression{ + Expression: &tree.ExpressionLiteral{ + Value: lit, + }, + Alias: alias, + }, + Type: dataType, + } +} + +func col(name string, datatype types.DataType) *types.Column { + return &types.Column{ + Name: name, + Type: datatype, + } +} diff --git a/pkg/engine/sqlanalyzer/attributes/table.go b/pkg/engine/sqlanalyzer/attributes/table.go new file mode 100644 index 000000000..92f1faa57 --- /dev/null +++ b/pkg/engine/sqlanalyzer/attributes/table.go @@ -0,0 +1,82 @@ +package attributes + +import ( + "fmt" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" + "github.com/kwilteam/kwil-db/pkg/engine/types" +) + +// TableFromAttributes generates a table structure from a list of relation attributes. +// It will do it's best to interpret the proper name for the attributes. +// If a column is given (either as a table.column or just column), it will apply the new table name. +// If any other expression is given (math on top of a column, aggregate, etc), it will enforce that an alias +// is given in the ResultColumnExpression of the relation attribute. The types of ambiguous naming supported by SQLite for CTEs is +// not clear from their docs, so this is to be safe. +// It takes a boolean to determine if a primary key should be added to the table. +// If true, the primary key is simply a composite key of all of the columns in the table. +// If it will return two columns of the same name, it will add a suffix of ":1", ":2", etc. +func TableFromAttributes(tableName string, attrs []*RelationAttribute, withPrimaryKey bool) (*types.Table, error) { + cols := []*types.Column{} + nameCounts := map[string]int{} + + for _, attr := range attrs { + var colToAdd *types.Column + + // if it's a column, then we can just use that + exprColumn, ok := attr.ResultExpression.Expression.(*tree.ExpressionColumn) + if ok { + colName := exprColumn.Column + if attr.ResultExpression.Alias != "" { + colName = attr.ResultExpression.Alias + } + + colToAdd = &types.Column{ + Name: colName, + Type: attr.Type, + } + } else { + // else we need to make sure it has an alias + if attr.ResultExpression.Alias == "" { + return nil, fmt.Errorf("%w: result columns that contain complex statements must have an alias", ErrInvalidReturnExpression) + } + + colToAdd = &types.Column{ + Name: attr.ResultExpression.Alias, + Type: attr.Type, + } + } + + timesAppeared, ok := nameCounts[colToAdd.Name] + if ok { + nameCounts[colToAdd.Name] = timesAppeared + 1 + colToAdd.Name = fmt.Sprintf("%s:%d", colToAdd.Name, timesAppeared) + } else { + nameCounts[colToAdd.Name] = 1 + } + + cols = append(cols, colToAdd) + } + + table := &types.Table{ + Name: tableName, + Columns: cols, + } + + if withPrimaryKey { + colNames := []string{} + for _, col := range cols { + colNames = append(colNames, col.Name) + } + + table.Indexes = []*types.Index{ + { + Name: fmt.Sprintf("%s_pk", tableName), + Columns: colNames, + Type: types.PRIMARY, + }, + } + } + + return table, nil +} diff --git a/pkg/engine/sqlanalyzer/attributes/types.go b/pkg/engine/sqlanalyzer/attributes/types.go new file mode 100644 index 000000000..856c62f02 --- /dev/null +++ b/pkg/engine/sqlanalyzer/attributes/types.go @@ -0,0 +1,256 @@ +package attributes + +import ( + "fmt" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/utils" + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" + "github.com/kwilteam/kwil-db/pkg/engine/types" +) + +// predictReturnType will attempt to predict the return type of an expression. +// If it is ambiguous but is a valid return expression, it will return types.TEXT. +// If it is invalid, it will return an error. +func predictReturnType(expr tree.Expression, tables []*types.Table) (types.DataType, error) { + w := &returnTypeWalker{ + Walker: tree.NewBaseWalker(), + tables: tables, + } + + err := expr.Accept(w) + if err != nil { + return types.TEXT, fmt.Errorf("error predicting return type: %w", err) + } + + if !w.detected { + return types.TEXT, fmt.Errorf("could not detect return type for expression: %s", expr) + } + + return w.detectedType, nil +} + +// ErrInvalidReturnExpression is returned when an expression cannot be used as a result column +var ErrInvalidReturnExpression = fmt.Errorf("expression cannot be used as a result column") + +// errReturnExpr is used to return an error when an expression cannot be used as a result column +func errReturnExpr(expr tree.Expression) error { + return fmt.Errorf("%w: using expression %s", ErrInvalidReturnExpression, expr) +} + +type returnTypeWalker struct { + tree.Walker + detected bool + detectedType types.DataType + tables []*types.Table +} + +var _ tree.Walker = &returnTypeWalker{} + +func (r *returnTypeWalker) EnterExpressionArithmetic(p0 *tree.ExpressionArithmetic) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionBetween(p0 *tree.ExpressionBetween) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionBinaryComparison(p0 *tree.ExpressionBinaryComparison) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionBindParameter(p0 *tree.ExpressionBindParameter) error { + r.set(types.TEXT) + return nil +} +func (r *returnTypeWalker) EnterExpressionCase(p0 *tree.ExpressionCase) error { + r.set(types.TEXT) + return nil +} +func (r *returnTypeWalker) EnterExpressionCollate(p0 *tree.ExpressionCollate) error { + r.set(types.TEXT) + return nil +} + +// we need to identify the column type +// there are three potential cases here +// 1. the expression declares the table +// - we need to search the table for the column to get the data type +// +// 2. the expression does not declare the table, but usedTables is not empty +// - we need to search the first usedTables table for the column to get the data type, and add the table name to the column +// - if we can't find the column, we return an error +// +// 3. the expression does not declare the table, and usedTables is empty +// - we return an error +func (r *returnTypeWalker) EnterExpressionColumn(p0 *tree.ExpressionColumn) error { + if r.detected { + return nil + } + + // case 1 + if p0.Table != "" { + table, err := findTable(r.tables, p0.Table) + if err != nil { + return err + } + + col, err := findColumn(table.Columns, p0.Column) + if err != nil { + return err + } + + r.set(col.Type) + return nil + } + + // case 2 + if len(r.tables) > 0 && r.tables[0] != nil { + col, err := findColumn(r.tables[0].Columns, p0.Column) + if err != nil { + return err + } + + r.set(col.Type) + return nil + } + + // case 3 + return fmt.Errorf(`%w: could not identify column "%s"`, ErrInvalidReturnExpression, p0.Column) + +} + +func (r *returnTypeWalker) EnterExpressionDistinct(p0 *tree.ExpressionDistinct) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionFunction(p0 *tree.ExpressionFunction) error { + if r.detected { + return nil + } + + switch p0.Function { + + // scalars + case &tree.FunctionABS: + r.set(types.INT) + case &tree.FunctionCOALESCE: + // ambiguous + r.set(types.TEXT) + case &tree.FunctionERROR: + return fmt.Errorf("%w: using function %s", ErrInvalidReturnExpression, p0.Function.Name()) + case &tree.FunctionFORMAT: + r.set(types.TEXT) + case &tree.FunctionGLOB: + r.set(types.INT) + case &tree.FunctionHEX: + r.set(types.TEXT) + case &tree.FunctionIFNULL: + // ambiguous + r.set(types.TEXT) + case &tree.FunctionIIF: + // ambiguous + r.set(types.TEXT) + case &tree.FunctionINSTR: + r.set(types.INT) + case &tree.FunctionLENGTH: + r.set(types.INT) + case &tree.FunctionLIKE: + r.set(types.INT) + case &tree.FunctionLOWER: + r.set(types.TEXT) + case &tree.FunctionLTRIM: + r.set(types.TEXT) + case &tree.FunctionNULLIF: + // ambiguous + r.set(types.TEXT) + case &tree.FunctionQUOTE: + r.set(types.TEXT) + case &tree.FunctionREPLACE: + r.set(types.TEXT) + case &tree.FunctionRTRIM: + r.set(types.TEXT) + case &tree.FunctionSIGN: + r.set(types.INT) + case &tree.FunctionSUBSTR: + r.set(types.TEXT) + case &tree.FunctionTRIM: + r.set(types.TEXT) + case &tree.FunctionTYPEOF: + r.set(types.TEXT) + case &tree.FunctionUNHEX: + r.set(types.TEXT) + case &tree.FunctionUNICODE: + r.set(types.INT) + case &tree.FunctionUPPER: + r.set(types.TEXT) + + // aggregates + case &tree.FunctionCOUNT: + r.set(types.INT) + case &tree.FunctionGROUPCONCAT: + r.set(types.TEXT) + case &tree.FunctionMAX: + r.set(types.INT) + case &tree.FunctionMIN: + r.set(types.INT) + + // datetime (all return text) + case &tree.FunctionDATE, &tree.FunctionTIME, &tree.FunctionDATETIME, &tree.FunctionUNIXEPOCH, &tree.FunctionSTRFTIME: + r.set(types.TEXT) + default: + return fmt.Errorf("unknown function: %s", p0.Function) + } + + return nil +} +func (r *returnTypeWalker) EnterExpressionIsNull(p0 *tree.ExpressionIsNull) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionList(p0 *tree.ExpressionList) error { + return errReturnExpr(p0) +} + +// EnterExpressionLiteral will attempt to detect the type of the literal +func (r *returnTypeWalker) EnterExpressionLiteral(p0 *tree.ExpressionLiteral) error { + if r.detected { + return nil + } + + dataTypes, err := utils.IsLiteral(p0.Value) + if err != nil { + return err + } + switch dataTypes { + case types.TEXT: + r.set(types.TEXT) + case types.INT: + r.set(types.INT) + default: + return fmt.Errorf("unknown literal type for analyzed relation attribute: %s", dataTypes) + } + + return nil +} +func (r *returnTypeWalker) EnterExpressionSelect(p0 *tree.ExpressionSelect) error { + return errReturnExpr(p0) +} +func (r *returnTypeWalker) EnterExpressionStringCompare(p0 *tree.ExpressionStringCompare) error { + r.set(types.INT) + return nil +} +func (r *returnTypeWalker) EnterExpressionUnary(p0 *tree.ExpressionUnary) error { + r.set(types.INT) + return nil +} + +// set sets the detected type if it has not already been set +// since we only want the first detected type +func (r *returnTypeWalker) set(t types.DataType) { + if r.detected { + return + } + + r.detected = true + r.detectedType = t +} diff --git a/pkg/engine/sqlanalyzer/clean/clean.go b/pkg/engine/sqlanalyzer/clean/clean.go new file mode 100644 index 000000000..63cf141ac --- /dev/null +++ b/pkg/engine/sqlanalyzer/clean/clean.go @@ -0,0 +1,69 @@ +/* +Package clean cleans SQL queries. + +This includes making identifiers lower case. + +The walker in this package implements all the tree.Walker methods, even if it +doesn't do anything. This is to ensure that if we need to add more cleaning / validation +rules, we know that we've covered all the nodes. + +For example, EnterDelete does nothing, but if we later set a limit on the amount of +CTEs allowed, then we would add it there. +*/ +package clean + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/utils" +) + +// checks that the string only contains alphanumeric characters and underscores +var identifierRegexp = regexp.MustCompile(`^[a-z][a-z0-9_]*$`) + +// cleanIdentifier checks that the identifier is a valid identifier and returns +// it in lower case +func cleanIdentifier(identifier string) (string, error) { + res := strings.ToLower(identifier) + + if !identifierRegexp.MatchString(res) { + return "", wrapErr(ErrInvalidIdentifier, fmt.Errorf(`identifier must start with letter and only contain alphanumeric characters or underscores, received: "%s"`, identifier)) + } + + return res, nil +} + +// cleanIdentifiers checks several identifiers and returns them in lower case +func cleanIdentifiers(identifiers []string) ([]string, error) { + res := make([]string, len(identifiers)) + + for i, identifier := range identifiers { + var err error + res[i], err = cleanIdentifier(identifier) + if err != nil { + return nil, err + } + } + + return res, nil +} + +// checkLiteral checks that the literal is a valid literal. +// It either must be guarded with single quotes, or it must be a number. +func checkLiteral(literal string) error { + _, err := utils.IsLiteral(literal) + return wrapErr(ErrInvalidLiteral, err) +} + +// checkBindParameter checks that the bind parameter is a valid bind parameter. +// It must start with either a $ or @. +func checkBindParameter(bindParameter string) error { + if !strings.HasPrefix(bindParameter, "$") && !strings.HasPrefix(bindParameter, "@") { + return wrapErr(ErrInvalidBindParameter, errors.New("bind parameter must start with $ or @")) + } + + return nil +} diff --git a/pkg/engine/sqlanalyzer/clean/errors.go b/pkg/engine/sqlanalyzer/clean/errors.go new file mode 100644 index 000000000..ecd617124 --- /dev/null +++ b/pkg/engine/sqlanalyzer/clean/errors.go @@ -0,0 +1,34 @@ +package clean + +import ( + "errors" + "fmt" +) + +var ( + ErrInvalidIdentifier = errors.New("invalid identifier") + ErrInvalidLiteral = errors.New("invalid literal") + ErrInvalidBindParameter = errors.New("invalid bind parameter") + ErrInvalidCollation = errors.New("invalid collation") + ErrInvalidInsertType = errors.New("invalid insert type") + ErrInvalidUpdateType = errors.New("invalid update type") + ErrInvalidSelectType = errors.New("invalid select type") + ErrInvalidJoinOperator = errors.New("invalid join operator") + ErrInvalidOrderType = errors.New("invalid order type") + ErrInvalidNullOrderType = errors.New("invalid null order type") + ErrInvalidReturningClause = errors.New("invalid returning clause") + ErrInvalidCompoundOperator = errors.New("invalid compound operator") + ErrInvalidUpsertType = errors.New("invalid upsert type") + ErrInvalidUnaryOperator = errors.New("invalid unary operator") + ErrInvalidBinaryOperator = errors.New("invalid binary operator") + ErrInvalidStringComparisonOperator = errors.New("invalid string comparison operator") + ErrInvalidArithmeticOperator = errors.New("invalid arithmetic operator") +) + +// wrapErr wraps an error with another, if the second error is not nil +func wrapErr(err error, err2 error) error { + if err2 == nil { + return nil + } + return fmt.Errorf("%w: %s", err, err2) +} diff --git a/pkg/engine/sqlanalyzer/clean/walker.go b/pkg/engine/sqlanalyzer/clean/walker.go new file mode 100644 index 000000000..124df8fd0 --- /dev/null +++ b/pkg/engine/sqlanalyzer/clean/walker.go @@ -0,0 +1,403 @@ +package clean + +import ( + "errors" + + "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" +) + +// TODO: the statement cleaner should also check for table / column existence +func NewStatementCleaner() *StatementCleaner { + return &StatementCleaner{ + Walker: tree.NewBaseWalker(), + } +} + +var _ tree.Walker = &StatementCleaner{} + +type StatementCleaner struct { + tree.Walker +} + +// EnterAggregateFunc checks that the function name is a valid identifier +func (s *StatementCleaner) EnterAggregateFunc(node *tree.AggregateFunc) (err error) { + node.FunctionName, err = cleanIdentifier(node.FunctionName) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterConflictTarget checks that the indexed column names are valid identifiers +func (s *StatementCleaner) EnterConflictTarget(node *tree.ConflictTarget) (err error) { + node.IndexedColumns, err = cleanIdentifiers(node.IndexedColumns) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterCTE checks that the table name and column names are valid identifiers +func (s *StatementCleaner) EnterCTE(node *tree.CTE) (err error) { + node.Table, err = cleanIdentifier(node.Table) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + node.Columns, err = cleanIdentifiers(node.Columns) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterDateTimeFunc checks that the function name is a valid identifier +func (s *StatementCleaner) EnterDateTimeFunc(node *tree.DateTimeFunction) (err error) { + node.FunctionName, err = cleanIdentifier(node.FunctionName) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterDelete does nothing +func (s *StatementCleaner) EnterDelete(node *tree.Delete) (err error) { + return nil +} + +// EnterDeleteStmt does nothing +func (s *StatementCleaner) EnterDeleteStmt(node *tree.DeleteStmt) (err error) { + return nil +} + +// EnterExpressionLiteral checks that the literal is a valid literal +func (s *StatementCleaner) EnterExpressionLiteral(node *tree.ExpressionLiteral) (err error) { + return wrapErr(ErrInvalidLiteral, checkLiteral(node.Value)) +} + +// EnterExpressionBindParameter checks that the bind parameter is a valid bind parameter +func (s *StatementCleaner) EnterExpressionBindParameter(node *tree.ExpressionBindParameter) (err error) { + return wrapErr(ErrInvalidBindParameter, checkBindParameter(node.Parameter)) +} + +// EnterExpressionColumn checks that the table and column names are valid identifiers +func (s *StatementCleaner) EnterExpressionColumn(node *tree.ExpressionColumn) (err error) { + if node.Table != "" { + node.Table, err = cleanIdentifier(node.Table) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + node.Column, err = cleanIdentifier(node.Column) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterExpressionUnary checks that the operator is a valid operator +func (s *StatementCleaner) EnterExpressionUnary(node *tree.ExpressionUnary) (err error) { + return wrapErr(ErrInvalidUnaryOperator, node.Operator.Valid()) +} + +// EnterExpressionBinary checks that the operator is a valid operator +func (s *StatementCleaner) EnterExpressionBinaryComparison(node *tree.ExpressionBinaryComparison) (err error) { + return wrapErr(ErrInvalidBinaryOperator, node.Operator.Valid()) +} + +// EnterExpressionFunction does nothing, since the function implementation is visited separately +func (s *StatementCleaner) EnterExpressionFunction(node *tree.ExpressionFunction) (err error) { + return nil +} + +// EnterExpressionList does nothing +func (s *StatementCleaner) EnterExpressionList(node *tree.ExpressionList) (err error) { + return nil +} + +// EnterExpressionCollate checks that the collation is a valid collation +func (s *StatementCleaner) EnterExpressionCollate(node *tree.ExpressionCollate) (err error) { + if node.Collation.Empty() { + return wrapErr(ErrInvalidCollation, errors.New("collation cannot be empty")) + } + + err = node.Collation.Valid() + if err != nil { + return wrapErr(ErrInvalidCollation, err) + } + + return nil +} + +// EnterExpressionStringCompare checks that the operator is a valid operator +func (s *StatementCleaner) EnterExpressionStringCompare(node *tree.ExpressionStringCompare) (err error) { + return wrapErr(ErrInvalidStringComparisonOperator, node.Operator.Valid()) +} + +// EnterExpressionIsNull does nothing +func (s *StatementCleaner) EnterExpressionIsNull(node *tree.ExpressionIsNull) (err error) { + return nil +} + +// EnterExpressionDistinct does nothing +func (s *StatementCleaner) EnterExpressionDistinct(node *tree.ExpressionDistinct) (err error) { + return nil +} + +// EnterExpressionBetween does nothing +func (s *StatementCleaner) EnterExpressionBetween(node *tree.ExpressionBetween) (err error) { + return nil +} + +// EnterExpressionExists checks that you can only negate EXISTS +func (s *StatementCleaner) EnterExpressionSelect(node *tree.ExpressionSelect) (err error) { + if node.IsNot && !node.IsExists { + return wrapErr(ErrInvalidIdentifier, errors.New("cannot negate non-EXISTS select expression")) + } + + return nil +} + +// EnterExpressionCase does nothing +func (s *StatementCleaner) EnterExpressionCase(node *tree.ExpressionCase) (err error) { + return nil +} + +// EnterExpressionArithmetic checks the validity of the operator +func (s *StatementCleaner) EnterExpressionArithmetic(node *tree.ExpressionArithmetic) (err error) { + return wrapErr(ErrInvalidArithmeticOperator, node.Operator.Valid()) +} + +// EnterScalarFunc checks that the function name is a valid identifier and is a scalar function +func (s *StatementCleaner) EnterScalarFunc(node *tree.ScalarFunction) (err error) { + node.FunctionName, err = cleanIdentifier(node.FunctionName) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterGroupBy does nothing +func (s *StatementCleaner) EnterGroupBy(node *tree.GroupBy) (err error) { + return nil +} + +// EnterInsert does nothing +func (s *StatementCleaner) EnterInsert(node *tree.Insert) (err error) { + return nil +} + +// EnterInsertStmt cleans the insert type, table, table alias, and columns +func (s *StatementCleaner) EnterInsertStmt(node *tree.InsertStmt) (err error) { + err = node.InsertType.Valid() + if err != nil { + return wrapErr(ErrInvalidInsertType, err) + } + + node.Table, err = cleanIdentifier(node.Table) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + + if node.TableAlias != "" { + node.TableAlias, err = cleanIdentifier(node.TableAlias) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + node.Columns, err = cleanIdentifiers(node.Columns) + return wrapErr(ErrInvalidIdentifier, err) +} + +// EnterJoinClause does nothing +func (s *StatementCleaner) EnterJoinClause(node *tree.JoinClause) (err error) { + return nil +} + +// EnterJoinConstraint does nothing +func (s *StatementCleaner) EnterJoinPredicate(node *tree.JoinPredicate) (err error) { + return nil +} + +// EnterJoinOperator validates the join operator +func (s *StatementCleaner) EnterJoinOperator(node *tree.JoinOperator) (err error) { + return wrapErr(ErrInvalidJoinOperator, node.Valid()) +} + +// EnterLimit does nothing +func (s *StatementCleaner) EnterLimit(node *tree.Limit) (err error) { + return nil +} + +// EnterOrderBy does nothing +func (s *StatementCleaner) EnterOrderBy(node *tree.OrderBy) (err error) { + return nil +} + +// EnterOrderingTerm validates the order type and null order type +func (s *StatementCleaner) EnterOrderingTerm(node *tree.OrderingTerm) (err error) { + if err = node.Collation.Valid(); err != nil { + return wrapErr(ErrInvalidCollation, err) + } + + // ordertype and nullorderingtype are both valid as empty, so we don't need to check for that + if err = node.OrderType.Valid(); err != nil { + return wrapErr(ErrInvalidOrderType, err) + } + + if err = node.NullOrdering.Valid(); err != nil { + return wrapErr(ErrInvalidNullOrderType, err) + } + + return nil +} + +// EnterQualifiedTableName checks the table name and alias and indexed by column +func (s *StatementCleaner) EnterQualifiedTableName(node *tree.QualifiedTableName) (err error) { + node.TableName, err = cleanIdentifier(node.TableName) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + + if node.TableAlias != "" { + node.TableAlias, err = cleanIdentifier(node.TableAlias) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + if node.IndexedBy != "" { + node.IndexedBy, err = cleanIdentifier(node.IndexedBy) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + return nil +} + +// EnterResultColumnStar does nothing +func (s *StatementCleaner) EnterResultColumnStar(node *tree.ResultColumnStar) (err error) { + return nil +} + +// EnterResultColumnExpression checks the alias if it exists +func (s *StatementCleaner) EnterResultColumnExpression(node *tree.ResultColumnExpression) (err error) { + if node.Alias != "" { + node.Alias, err = cleanIdentifier(node.Alias) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + return nil +} + +// EnterResultColumnTable checks the table name +func (s *StatementCleaner) EnterResultColumnTable(node *tree.ResultColumnTable) (err error) { + node.TableName, err = cleanIdentifier(node.TableName) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + + return nil +} + +// EnterReturningClause does nothing +func (s *StatementCleaner) EnterReturningClause(node *tree.ReturningClause) (err error) { + return nil +} + +// EnterReturningClauseColumn checks that either all is selected, or that an expression is used. An alias can +// only be used if an expression is used. +func (s *StatementCleaner) EnterReturningClauseColumn(node *tree.ReturningClauseColumn) (err error) { + if node.All && node.Expression != nil { + return wrapErr(ErrInvalidReturningClause, errors.New("all and expression cannot be set at the same time")) + } + + if node.Alias != "" && node.Expression == nil { + return wrapErr(ErrInvalidReturningClause, errors.New("alias cannot be set without an expression")) + } + + return nil +} + +// EnterSelect does nothing +func (s *StatementCleaner) EnterSelect(node *tree.Select) (err error) { + return nil +} + +// EnterSelectCore validates the select type +func (s *StatementCleaner) EnterSelectCore(node *tree.SelectCore) (err error) { + return wrapErr(ErrInvalidSelectType, node.SelectType.Valid()) +} + +// EnterSelectStmt checks that, for each SelectCore besides the last, a compound operator is provided +func (s *StatementCleaner) EnterSelectStmt(node *tree.SelectStmt) (err error) { + for _, core := range node.SelectCores[:len(node.SelectCores)-1] { + if core.Compound == nil { + return wrapErr(ErrInvalidCompoundOperator, errors.New("compound operator must be provided for all SelectCores except the last")) + } + } + + return nil +} + +// EnterFromClause does nothing +func (s *StatementCleaner) EnterFromClause(node *tree.FromClause) (err error) { + return nil +} + +// EnterCompoundOperator validates the compound operator +func (s *StatementCleaner) EnterCompoundOperator(node *tree.CompoundOperator) (err error) { + return wrapErr(ErrInvalidCompoundOperator, node.Operator.Valid()) +} + +// EnterTableOrSubquery checks the table name and alias +func (s *StatementCleaner) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTable) (err error) { + node.Name, err = cleanIdentifier(node.Name) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + + if node.Alias != "" { + node.Alias, err = cleanIdentifier(node.Alias) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + return nil +} + +// EnterTableOrSubquerySelect checks the alias +func (s *StatementCleaner) EnterTableOrSubquerySelect(node *tree.TableOrSubquerySelect) (err error) { + if node.Alias != "" { + node.Alias, err = cleanIdentifier(node.Alias) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + return nil +} + +// EnterTableOrSubqueryList does nothing +func (s *StatementCleaner) EnterTableOrSubqueryList(node *tree.TableOrSubqueryList) (err error) { + return nil +} + +// EnterTableOrSubqueryJoin does nothing +func (s *StatementCleaner) EnterTableOrSubqueryJoin(node *tree.TableOrSubqueryJoin) (err error) { + return nil +} + +// EnterUpdateSetClause checks the column names +func (s *StatementCleaner) EnterUpdateSetClause(node *tree.UpdateSetClause) (err error) { + for i, column := range node.Columns { + node.Columns[i], err = cleanIdentifier(column) + if err != nil { + return wrapErr(ErrInvalidIdentifier, err) + } + } + + return nil +} + +// EnterUpdate does nothing +func (s *StatementCleaner) EnterUpdate(node *tree.Update) (err error) { + return nil +} + +// EnterUpdateStmt validates the update type +func (s *StatementCleaner) EnterUpdateStmt(node *tree.UpdateStmt) (err error) { + return wrapErr(ErrInvalidUpdateType, node.Or.Valid()) +} + +// EnterUpsert validates the upsert type +func (s *StatementCleaner) EnterUpsert(node *tree.Upsert) (err error) { + return wrapErr(ErrInvalidUpsertType, node.Type.Valid()) +} diff --git a/pkg/engine/sqlanalyzer/join/joins.go b/pkg/engine/sqlanalyzer/join/joins.go index bcad09809..9ca4b78d8 100644 --- a/pkg/engine/sqlanalyzer/join/joins.go +++ b/pkg/engine/sqlanalyzer/join/joins.go @@ -41,8 +41,9 @@ import "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" - Are joined with an "=" operator The joinAnalyzer is used in a DFS manner to determine if a join is valid. It can exist in one of the following states: -- BinaryCompareNotFound: The join has not yet encountered a binary comparison -- BinaryCompareFound: The join has encountered a binary comparison +- joinableStatusInvalid: the join is invalid +- joinableStatusContainsColumn: the join contains a column from one of the tables +- joinableStatusValid: the join is valid */ diff --git a/pkg/engine/sqlanalyzer/mutativity.go b/pkg/engine/sqlanalyzer/mutativity.go index 660c687a1..eed580f9d 100644 --- a/pkg/engine/sqlanalyzer/mutativity.go +++ b/pkg/engine/sqlanalyzer/mutativity.go @@ -2,7 +2,7 @@ package sqlanalyzer import "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/mutative" -func IsMutative(stmt accepter) (bool, error) { +func isMutative(stmt accepter) (bool, error) { mutativityWalker := mutative.NewMutativityWalker() err := stmt.Accept(mutativityWalker) diff --git a/pkg/engine/sqlanalyzer/order/analyzer.go b/pkg/engine/sqlanalyzer/order/analyzer.go index d4b3d86f6..6b138a059 100644 --- a/pkg/engine/sqlanalyzer/order/analyzer.go +++ b/pkg/engine/sqlanalyzer/order/analyzer.go @@ -1,12 +1,14 @@ package order import ( - "errors" + "fmt" + "github.com/kwilteam/kwil-db/pkg/engine/sqlanalyzer/attributes" "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree" "github.com/kwilteam/kwil-db/pkg/engine/types" ) +// ExitOrderBy adds the required ordering terms to the statement. func (o *orderAnalyzer) ExitOrderBy(node *tree.OrderBy) error { if o.context.IsCompound { // sort all result columns @@ -31,9 +33,20 @@ func (o *orderAnalyzer) ExitOrderBy(node *tree.OrderBy) error { return nil } +// EnterSelectStmt creates a new scope. +// if the statement does not have an order by clause, one is created. +// it checks if the statement is a compound statement, and if so, sets the flag. func (o *orderAnalyzer) EnterSelectStmt(node *tree.SelectStmt) error { o.newScope() + // a bug was found where nil OrderBy would cause no ordering terms to be added + // this needs to be cleaned up later if there are no ordering terms + if node.OrderBy == nil { + node.OrderBy = &tree.OrderBy{ + OrderingTerms: []*tree.OrderingTerm{}, + } + } + if len(node.SelectCores) > 1 { o.context.IsCompound = true } @@ -43,17 +56,27 @@ func (o *orderAnalyzer) EnterSelectStmt(node *tree.SelectStmt) error { return nil } +// ExitSelectStmt pops the current scope. func (o *orderAnalyzer) ExitSelectStmt(node *tree.SelectStmt) error { + // we created a provisional order by clause in case one does not exist + // we clean it up here if it is empty + if node.OrderBy != nil && len(node.OrderBy.OrderingTerms) == 0 { + node.OrderBy = nil + } + o.oldScope() return nil } +// ExitSelectCore increments the current select position. +// This is can be used to determine compound select position. func (o *orderAnalyzer) ExitSelectCore(node *tree.SelectCore) error { o.context.currentSelectPosition++ return nil } +// EnterTableOrSubqueryTable adds the table to the list of used tables. func (o *orderAnalyzer) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTable) error { if o.context.currentSelectPosition != 0 { return nil @@ -78,6 +101,30 @@ func (o *orderAnalyzer) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTabl return nil } +// EnterCommonTableExpression adds the table to the list of used tables. +// This allows it to be used later for the ordering terms. +func (o *orderAnalyzer) EnterCTE(node *tree.CTE) error { + if len(node.Select.SelectCores) == 0 { + return nil + } + + cteAttributes, err := attributes.GetSelectCoreRelationAttributes(node.Select.SelectCores[0], o.schemaTables) + if err != nil { + return err + } + + cteTable, err := attributes.TableFromAttributes(node.Table, cteAttributes, true) + if err != nil { + return err + } + + o.schemaTables = append(o.schemaTables, cteTable) + + return nil +} + +// we need to add common table expressions to the list of the schemas tables, as well as the list of used tables +// this means we need to detect the structure of the common table expression func findTable(tables []*types.Table, name string) (*types.Table, error) { for _, t := range tables { if t.Name == name { @@ -85,5 +132,5 @@ func findTable(tables []*types.Table, name string) (*types.Table, error) { } } - return nil, errors.New("table not found") + return nil, fmt.Errorf(`table "%s" not found`, name) } diff --git a/pkg/engine/sqlanalyzer/order/order_test.go b/pkg/engine/sqlanalyzer/order/order_test.go index fdcd02779..790708ef8 100644 --- a/pkg/engine/sqlanalyzer/order/order_test.go +++ b/pkg/engine/sqlanalyzer/order/order_test.go @@ -19,6 +19,26 @@ func Test_Ordering(t *testing.T) { } testCases := []testCase{ + { + name: "select star with no ordering", + tables: defaultTables, + selectCores: []*tree.SelectCore{ + Select(). + Columns("*"). + From("users"). + Build(), + }, + expectedOrderingTerms: []*tree.OrderingTerm{ + { + Expression: &tree.ExpressionColumn{ + Table: "users", + Column: "id", + }, + OrderType: tree.OrderTypeAsc, + NullOrdering: tree.NullOrderingTypeLast, + }, + }, + }, { name: "simple ordering", tables: defaultTables, @@ -250,11 +270,17 @@ func Test_Ordering(t *testing.T) { t.Run(tc.name, func(t *testing.T) { walker := order.NewOrderWalker(tc.tables) + // we test for nil orderBy, since a previous bug was caused by having an empty orderBy + var orderBy *tree.OrderBy + if tc.originalOrderingTerms != nil { + orderBy = &tree.OrderBy{ + OrderingTerms: tc.originalOrderingTerms, + } + } + selectStmt := &tree.SelectStmt{ SelectCores: tc.selectCores, - OrderBy: &tree.OrderBy{ - OrderingTerms: tc.originalOrderingTerms, - }, + OrderBy: orderBy, } err := selectStmt.Accept(walker) @@ -262,7 +288,9 @@ func Test_Ordering(t *testing.T) { t.Fatal(err) } - assert.Equal(t, tc.expectedOrderingTerms, selectStmt.OrderBy.OrderingTerms) + if tc.expectedOrderingTerms != nil { + assert.EqualValues(t, tc.expectedOrderingTerms, selectStmt.OrderBy.OrderingTerms) + } }) } } diff --git a/pkg/engine/sqlanalyzer/order/visitors.go b/pkg/engine/sqlanalyzer/order/visitors.go index 0bce6f0a4..b2d031c53 100644 --- a/pkg/engine/sqlanalyzer/order/visitors.go +++ b/pkg/engine/sqlanalyzer/order/visitors.go @@ -6,15 +6,20 @@ import ( ) func NewOrderWalker(tables []*types.Table) tree.Walker { + // copy tables, since we will be modifying the tables slice to register CTEs + tbls := make([]*types.Table, len(tables)) + copy(tbls, tables) + return &orderAnalyzer{ Walker: tree.NewBaseWalker(), - schemaTables: tables, + schemaTables: tbls, } } type orderAnalyzer struct { tree.Walker - context *orderContext + context *orderContext + // schemaTables is a list of all tables in the schema schemaTables []*types.Table } diff --git a/pkg/engine/sqlanalyzer/utils/utils.go b/pkg/engine/sqlanalyzer/utils/utils.go new file mode 100644 index 000000000..de269cd6f --- /dev/null +++ b/pkg/engine/sqlanalyzer/utils/utils.go @@ -0,0 +1,32 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + + "github.com/kwilteam/kwil-db/pkg/engine/types" +) + +// IsLiteral detects if the passed string is convertable to a literal. +// It returns the type of the literal, or an error if it is not a literal. +func IsLiteral(literal string) (types.DataType, error) { + if strings.HasPrefix(literal, "'") && strings.HasSuffix(literal, "'") { + return types.TEXT, nil + } + + if strings.EqualFold(literal, "true") || strings.EqualFold(literal, "false") { + return types.INT, nil + } + + if strings.EqualFold(literal, "null") { + return types.NULL, nil + } + + _, err := strconv.Atoi(literal) + if err != nil { + return types.NULL, fmt.Errorf("invalid literal: could not detect literal type: %s", literal) + } + + return types.INT, nil +} diff --git a/pkg/engine/sqlparser/tree/collate.go b/pkg/engine/sqlparser/tree/collate.go index 54f0a375c..613dffce0 100644 --- a/pkg/engine/sqlparser/tree/collate.go +++ b/pkg/engine/sqlparser/tree/collate.go @@ -1,5 +1,10 @@ package tree +import ( + "fmt" + "strings" +) + type CollationType string const ( @@ -9,16 +14,28 @@ const ( ) func (c CollationType) String() string { - c.check() return string(c) } -func (c CollationType) check() { - if !c.Valid() { - panic("invalid collation type") +// Valid checks if the collation type is valid. +// Empty collation types are considered valid. +func (c *CollationType) Valid() error { + if c.Empty() { + return nil + } + + newC := CollationType(strings.ToUpper(string(*c))) + + switch newC { + case CollationTypeBinary, CollationTypeNoCase, CollationTypeRTrim: + default: + return fmt.Errorf("invalid collation type: %s", c) } + *c = newC + + return nil } -func (c CollationType) Valid() bool { - return c == CollationTypeBinary || c == CollationTypeNoCase || c == CollationTypeRTrim +func (c CollationType) Empty() bool { + return c == "" } diff --git a/pkg/engine/sqlparser/tree/insert.go b/pkg/engine/sqlparser/tree/insert.go index ef4a7f22d..a93f3121c 100644 --- a/pkg/engine/sqlparser/tree/insert.go +++ b/pkg/engine/sqlparser/tree/insert.go @@ -84,6 +84,15 @@ const ( InsertTypeInsertOrReplace ) +func (i InsertType) Valid() error { + switch i { + case InsertTypeInsert, InsertTypeReplace, InsertTypeInsertOrReplace: + return nil + default: + return fmt.Errorf("invalid insert type: %d", i) + } +} + func (i *InsertType) String() string { switch *i { case InsertTypeInsert: diff --git a/pkg/engine/sqlparser/tree/join-clause.go b/pkg/engine/sqlparser/tree/join-clause.go index c90827445..08edfe81b 100644 --- a/pkg/engine/sqlparser/tree/join-clause.go +++ b/pkg/engine/sqlparser/tree/join-clause.go @@ -169,3 +169,15 @@ func (j *JoinOperator) ToSQL() string { stmt.Token.Join() return stmt.String() } + +func (j *JoinOperator) Valid() error { + if j.JoinType < JoinTypeJoin || j.JoinType > JoinTypeFull { + return fmt.Errorf("invalid join type: %d", j.JoinType) + } + + if j.Outer && (j.JoinType == JoinTypeJoin || j.JoinType == JoinTypeInner) { + return fmt.Errorf("outer join cannot be used with generic join or inner join") + } + + return nil +} diff --git a/pkg/engine/sqlparser/tree/operators.go b/pkg/engine/sqlparser/tree/operators.go index ca205dc70..9c47efd69 100644 --- a/pkg/engine/sqlparser/tree/operators.go +++ b/pkg/engine/sqlparser/tree/operators.go @@ -1,8 +1,11 @@ package tree +import "fmt" + type BinaryOperator interface { binary() String() string + Valid() error } type ArithmeticOperator string @@ -26,6 +29,15 @@ func (a ArithmeticOperator) String() string { return string(a) } +func (a ArithmeticOperator) Valid() error { + switch a { + case ArithmeticOperatorAdd, ArithmeticOperatorSubtract, ArithmeticOperatorMultiply, ArithmeticOperatorDivide, ArithmeticOperatorModulus, ArithmeticOperatorBitwiseAnd, ArithmeticOperatorBitwiseOr, ArithmeticOperatorBitwiseLeftShift, ArithmeticOperatorBitwiseRightShift, ArithmeticConcat: + return nil + default: + return fmt.Errorf("invalid arithmetic operator: %s", a) + } +} + type ComparisonOperator string const ( @@ -48,6 +60,15 @@ func (c ComparisonOperator) String() string { return string(c) } +func (c ComparisonOperator) Valid() error { + switch c { + case ComparisonOperatorDoubleEqual, ComparisonOperatorEqual, ComparisonOperatorNotEqualDiamond, ComparisonOperatorNotEqual, ComparisonOperatorGreaterThan, ComparisonOperatorLessThan, ComparisonOperatorGreaterThanOrEqual, ComparisonOperatorLessThanOrEqual, ComparisonOperatorIs, ComparisonOperatorIsNot, ComparisonOperatorIn, ComparisonOperatorNotIn: + return nil + default: + return fmt.Errorf("invalid comparison operator: %s", c) + } +} + type LogicalOperator string const ( @@ -60,6 +81,15 @@ func (l LogicalOperator) String() string { return string(l) } +func (l LogicalOperator) Valid() error { + switch l { + case LogicalOperatorAnd, LogicalOperatorOr: + return nil + default: + return fmt.Errorf("invalid logical operator: %s", l) + } +} + type StringOperator string const ( @@ -77,6 +107,14 @@ func (s StringOperator) binary() {} func (s StringOperator) String() string { return string(s) } +func (s StringOperator) Valid() error { + switch s { + case StringOperatorLike, StringOperatorNotLike, StringOperatorGlob, StringOperatorNotGlob, StringOperatorRegexp, StringOperatorNotRegexp, StringOperatorMatch, StringOperatorNotMatch: + return nil + default: + return fmt.Errorf("invalid string operator: %s", s) + } +} func (s StringOperator) Escapable() bool { switch s { case StringOperatorLike, StringOperatorNotLike: @@ -98,3 +136,12 @@ const ( func (u UnaryOperator) String() string { return string(u) } + +func (u UnaryOperator) Valid() error { + switch u { + case UnaryOperatorPlus, UnaryOperatorMinus, UnaryOperatorNot, UnaryOperatorBitNot: + return nil + default: + return fmt.Errorf("invalid unary operator: %s", u) + } +} diff --git a/pkg/engine/sqlparser/tree/order-by.go b/pkg/engine/sqlparser/tree/order-by.go index 7bb1a151c..3e6388f0d 100644 --- a/pkg/engine/sqlparser/tree/order-by.go +++ b/pkg/engine/sqlparser/tree/order-by.go @@ -1,6 +1,9 @@ package tree import ( + "fmt" + "strings" + sqlwriter "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree/sql-writer" ) @@ -52,9 +55,12 @@ func (o *OrderingTerm) ToSQL() string { stmt.WriteString(o.Expression.ToSQL()) - if o.Collation.Valid() { + err := o.Collation.Valid() + if !o.Collation.Empty() && err == nil { stmt.Token.Collate() stmt.WriteString(o.Collation.String()) + } else if !o.Collation.Empty() && err != nil { + panic(err) } if o.OrderType != OrderTypeNone { @@ -80,6 +86,14 @@ func (n NullOrderingType) String() string { return string(n) } +func (n NullOrderingType) Valid() error { + if n != NullOrderingTypeFirst && n != NullOrderingTypeLast && n != NullOrderingTypeNone { + return fmt.Errorf("invalid null ordering type: %s", n) + } + + return nil +} + type OrderType string const ( @@ -94,7 +108,21 @@ func (o OrderType) String() string { } func (o OrderType) check() { - if o != OrderTypeNone && o != OrderTypeAsc && o != OrderTypeDesc { - panic("invalid order type") + + err := o.Valid() + if err != nil { + panic(err) } } + +func (o *OrderType) Valid() error { + upper := OrderType(strings.ToUpper(string(*o))) + + if upper != OrderTypeAsc && upper != OrderTypeDesc && upper != OrderTypeNone { + return fmt.Errorf("invalid order type: %s", o) + } + + *o = upper + + return nil +} diff --git a/pkg/engine/sqlparser/tree/select.go b/pkg/engine/sqlparser/tree/select.go index 31fbe5fd6..b066ee9ba 100644 --- a/pkg/engine/sqlparser/tree/select.go +++ b/pkg/engine/sqlparser/tree/select.go @@ -143,6 +143,15 @@ const ( SelectTypeDistinct ) +func (s SelectType) Valid() error { + switch s { + case SelectTypeAll, SelectTypeDistinct: + return nil + default: + return fmt.Errorf("invalid select type: %d", s) + } +} + type FromClause struct { JoinClause *JoinClause } @@ -171,6 +180,15 @@ const ( CompoundOperatorTypeExcept ) +func (c CompoundOperatorType) Valid() error { + switch c { + case CompoundOperatorTypeUnion, CompoundOperatorTypeUnionAll, CompoundOperatorTypeIntersect, CompoundOperatorTypeExcept: + return nil + default: + return fmt.Errorf("invalid compound operator type: %d", c) + } +} + func (c *CompoundOperatorType) ToSQL() string { switch *c { case CompoundOperatorTypeUnion: diff --git a/pkg/engine/sqlparser/tree/update.go b/pkg/engine/sqlparser/tree/update.go index 2315169f8..e718deda4 100644 --- a/pkg/engine/sqlparser/tree/update.go +++ b/pkg/engine/sqlparser/tree/update.go @@ -82,20 +82,37 @@ const ( UpdateOrRollback UpdateOr = "ROLLBACK" ) -func (u *UpdateOr) check() { +func (u *UpdateOr) Valid() error { + if u.Empty() { + return nil + } + switch *u { - case UpdateOrAbort: - case UpdateOrFail: - case UpdateOrIgnore: - case UpdateOrReplace: - case UpdateOrRollback: + case UpdateOrAbort, UpdateOrFail, UpdateOrIgnore, UpdateOrReplace, UpdateOrRollback: + return nil default: - panic("unknown UpdateOr") + return fmt.Errorf("unknown UpdateOr: %s", *u) + } +} + +func (u UpdateOr) Empty() bool { + return u == "" +} + +func (u UpdateOr) check() { + if u.Empty() { + return + } + if err := u.Valid(); err != nil { + panic(err) } } func (u *UpdateOr) ToSQL() string { u.check() + if u.Empty() { + return "" + } stmt := sqlwriter.NewWriter() stmt.Token.Or() @@ -108,7 +125,7 @@ func (u *UpdateStmt) ToSQL() string { stmt := sqlwriter.NewWriter() stmt.Token.Update() - if u.Or != "" { + if !u.Or.Empty() { stmt.WriteString(u.Or.ToSQL()) } stmt.WriteString(u.QualifiedTableName.ToSQL()) diff --git a/pkg/engine/sqlparser/tree/upsert.go b/pkg/engine/sqlparser/tree/upsert.go index 99926522d..a1156a8f9 100644 --- a/pkg/engine/sqlparser/tree/upsert.go +++ b/pkg/engine/sqlparser/tree/upsert.go @@ -1,6 +1,8 @@ package tree import ( + "fmt" + sqlwriter "github.com/kwilteam/kwil-db/pkg/engine/sqlparser/tree/sql-writer" ) @@ -11,6 +13,15 @@ const ( UpsertTypeDoUpdate ) +func (u UpsertType) Valid() error { + switch u { + case UpsertTypeDoNothing, UpsertTypeDoUpdate: + return nil + default: + return fmt.Errorf("invalid upsert type: %d", u) + } +} + type Upsert struct { ConflictTarget *ConflictTarget Type UpsertType diff --git a/pkg/engine/sqlparser/tree/utils_test.go b/pkg/engine/sqlparser/tree/utils_test.go index 05a06e86d..4aed81f9a 100644 --- a/pkg/engine/sqlparser/tree/utils_test.go +++ b/pkg/engine/sqlparser/tree/utils_test.go @@ -24,10 +24,3 @@ func compareIgnoringWhitespace(a, b string) bool { return aWithoutWhitespace == bWithoutWhitespace } - -/* - - - - - */ diff --git a/pkg/engine/sqlparser/tree/visitor.go b/pkg/engine/sqlparser/tree/visitor.go deleted file mode 100644 index 87d2497a5..000000000 --- a/pkg/engine/sqlparser/tree/visitor.go +++ /dev/null @@ -1,591 +0,0 @@ -package tree - -/* -type Visitor interface { - VisitAggregateFunc(*AggregateFunc) error - VisitConflictTarget(*ConflictTarget) error - VisitCTE(*CTE) error - VisitDateTimeFunc(*DateTimeFunction) error - VisitDelete(*Delete) error - VisitDeleteStmt(*DeleteStmt) error - VisitExpressionLiteral(*ExpressionLiteral) error - VisitExpressionBindParameter(*ExpressionBindParameter) error - VisitExpressionColumn(*ExpressionColumn) error - VisitExpressionUnary(*ExpressionUnary) error - VisitExpressionBinaryComparison(*ExpressionBinaryComparison) error - VisitExpressionFunction(*ExpressionFunction) error - VisitExpressionList(*ExpressionList) error - VisitExpressionCollate(*ExpressionCollate) error - VisitExpressionStringCompare(*ExpressionStringCompare) error - VisitExpressionIsNull(*ExpressionIsNull) error - VisitExpressionDistinct(*ExpressionDistinct) error - VisitExpressionBetween(*ExpressionBetween) error - VisitExpressionSelect(*ExpressionSelect) error - VisitExpressionCase(*ExpressionCase) error - VisitExpressionArithmetic(*ExpressionArithmetic) error - VisitScalarFunc(*ScalarFunction) error - VisitGroupBy(*GroupBy) error - VisitInsert(*Insert) error - VisitInsertStmt(*InsertStmt) error - VisitJoinClause(*JoinClause) error - VisitJoinPredicate(*JoinPredicate) error - VisitJoinOperator(*JoinOperator) error - VisitLimit(*Limit) error - VisitOrderBy(*OrderBy) error - VisitOrderingTerm(*OrderingTerm) error - VisitQualifiedTableName(*QualifiedTableName) error - VisitResultColumnStar(*ResultColumnStar) error - VisitResultColumnExpression(*ResultColumnExpression) error - VisitResultColumnTable(*ResultColumnTable) error - VisitReturningClause(*ReturningClause) error - VisitReturningClauseColumn(*ReturningClauseColumn) error - VisitSelect(*Select) error - VisitSelectCore(*SelectCore) error - VisitSelectStmt(*SelectStmt) error - VisitFromClause(*FromClause) error - VisitCompoundOperator(*CompoundOperator) error - VisitTableOrSubqueryTable(*TableOrSubqueryTable) error - VisitTableOrSubquerySelect(*TableOrSubquerySelect) error - VisitTableOrSubqueryList(*TableOrSubqueryList) error - VisitTableOrSubqueryJoin(*TableOrSubqueryJoin) error - VisitUpdateSetClause(*UpdateSetClause) error - VisitUpdate(*Update) error - VisitUpdateStmt(*UpdateStmt) error - VisitUpsert(*Upsert) error -} - -func NewBaseVisitor() *BaseVisitor { - return &BaseVisitor{} -} - -type BaseVisitor struct { -} - -func (b *BaseVisitor) VisitAggregateFunc(p0 *AggregateFunc) error { - return nil -} - -func (b *BaseVisitor) VisitCTE(p0 *CTE) error { - return nil -} - -func (b *BaseVisitor) VisitCompoundOperator(p0 *CompoundOperator) error { - return nil -} - -func (b *BaseVisitor) VisitConflictTarget(p0 *ConflictTarget) error { - return nil -} - -func (b *BaseVisitor) VisitDateTimeFunc(p0 *DateTimeFunction) error { - return nil -} - -func (b *BaseVisitor) VisitDelete(p0 *Delete) error { - return nil -} - -func (b *BaseVisitor) VisitDeleteStmt(p0 *DeleteStmt) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionArithmetic(p0 *ExpressionArithmetic) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionBetween(p0 *ExpressionBetween) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionBinaryComparison(p0 *ExpressionBinaryComparison) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionBindParameter(p0 *ExpressionBindParameter) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionCase(p0 *ExpressionCase) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionCollate(p0 *ExpressionCollate) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionColumn(p0 *ExpressionColumn) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionDistinct(p0 *ExpressionDistinct) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionFunction(p0 *ExpressionFunction) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionIsNull(p0 *ExpressionIsNull) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionList(p0 *ExpressionList) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionLiteral(p0 *ExpressionLiteral) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionSelect(p0 *ExpressionSelect) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionStringCompare(p0 *ExpressionStringCompare) error { - return nil -} - -func (b *BaseVisitor) VisitExpressionUnary(p0 *ExpressionUnary) error { - return nil -} - -func (b *BaseVisitor) VisitFromClause(p0 *FromClause) error { - return nil -} - -func (b *BaseVisitor) VisitGroupBy(p0 *GroupBy) error { - return nil -} - -func (b *BaseVisitor) VisitInsert(p0 *Insert) error { - return nil -} - -func (b *BaseVisitor) VisitInsertStmt(p0 *InsertStmt) error { - return nil -} - -func (b *BaseVisitor) VisitJoinClause(p0 *JoinClause) error { - return nil -} - -func (b *BaseVisitor) VisitJoinOperator(p0 *JoinOperator) error { - return nil -} - -func (b *BaseVisitor) VisitJoinPredicate(p0 *JoinPredicate) error { - return nil -} - -func (b *BaseVisitor) VisitLimit(p0 *Limit) error { - return nil -} - -func (b *BaseVisitor) VisitOrderBy(p0 *OrderBy) error { - return nil -} - -func (b *BaseVisitor) VisitOrderingTerm(p0 *OrderingTerm) error { - return nil -} - -func (b *BaseVisitor) VisitQualifiedTableName(p0 *QualifiedTableName) error { - return nil -} - -func (b *BaseVisitor) VisitResultColumnExpression(p0 *ResultColumnExpression) error { - return nil -} - -func (b *BaseVisitor) VisitResultColumnStar(p0 *ResultColumnStar) error { - return nil -} - -func (b *BaseVisitor) VisitResultColumnTable(p0 *ResultColumnTable) error { - return nil -} - -func (b *BaseVisitor) VisitReturningClause(p0 *ReturningClause) error { - return nil -} - -func (b *BaseVisitor) VisitReturningClauseColumn(p0 *ReturningClauseColumn) error { - return nil -} - -func (b *BaseVisitor) VisitScalarFunc(p0 *ScalarFunction) error { - return nil -} - -func (b *BaseVisitor) VisitSelect(p0 *Select) error { - return nil -} - -func (b *BaseVisitor) VisitSelectCore(p0 *SelectCore) error { - return nil -} - -func (b *BaseVisitor) VisitSelectStmt(p0 *SelectStmt) error { - return nil -} - -func (b *BaseVisitor) VisitTableOrSubqueryJoin(p0 *TableOrSubqueryJoin) error { - return nil -} - -func (b *BaseVisitor) VisitTableOrSubqueryList(p0 *TableOrSubqueryList) error { - return nil -} - -func (b *BaseVisitor) VisitTableOrSubquerySelect(p0 *TableOrSubquerySelect) error { - return nil -} - -func (b *BaseVisitor) VisitTableOrSubqueryTable(p0 *TableOrSubqueryTable) error { - return nil -} - -func (b *BaseVisitor) VisitUpdate(p0 *Update) error { - return nil -} - -func (b *BaseVisitor) VisitUpdateSetClause(p0 *UpdateSetClause) error { - return nil -} - -func (b *BaseVisitor) VisitUpdateStmt(p0 *UpdateStmt) error { - return nil -} - -func (b *BaseVisitor) VisitUpsert(p0 *Upsert) error { - return nil -} - -type ExtendableVisitor interface { - Visitor - AddVisitor(Visitor) -} - -type extendableVisitor struct { - decorators []Visitor -} - -func NewExtendableVisitor(decorators ...Visitor) ExtendableVisitor { - return &extendableVisitor{ - decorators: decorators, - } -} - -func (b *extendableVisitor) runDecorators(fn func(visitor Visitor) error) error { - for _, decorator := range b.decorators { - if err := fn(decorator); err != nil { - return err - } - } - return nil -} - -func (b *extendableVisitor) VisitAggregateFunc(p0 *AggregateFunc) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitAggregateFunc(p0) - }) -} - -func (b *extendableVisitor) VisitCTE(p0 *CTE) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitCTE(p0) - }) -} - -func (b *extendableVisitor) VisitCompoundOperator(p0 *CompoundOperator) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitCompoundOperator(p0) - }) -} - -func (b *extendableVisitor) VisitConflictTarget(p0 *ConflictTarget) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitConflictTarget(p0) - }) -} - -func (b *extendableVisitor) VisitDateTimeFunc(p0 *DateTimeFunction) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitDateTimeFunc(p0) - }) -} - -func (b *extendableVisitor) VisitDelete(p0 *Delete) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitDelete(p0) - }) -} - -func (b *extendableVisitor) VisitDeleteStmt(p0 *DeleteStmt) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitDeleteStmt(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionArithmetic(p0 *ExpressionArithmetic) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionArithmetic(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionBetween(p0 *ExpressionBetween) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionBetween(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionBinaryComparison(p0 *ExpressionBinaryComparison) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionBinaryComparison(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionBindParameter(p0 *ExpressionBindParameter) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionBindParameter(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionCase(p0 *ExpressionCase) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionCase(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionCollate(p0 *ExpressionCollate) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionCollate(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionColumn(p0 *ExpressionColumn) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionColumn(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionDistinct(p0 *ExpressionDistinct) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionDistinct(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionFunction(p0 *ExpressionFunction) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionFunction(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionIsNull(p0 *ExpressionIsNull) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionIsNull(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionList(p0 *ExpressionList) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionList(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionLiteral(p0 *ExpressionLiteral) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionLiteral(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionSelect(p0 *ExpressionSelect) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionSelect(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionStringCompare(p0 *ExpressionStringCompare) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionStringCompare(p0) - }) -} - -func (b *extendableVisitor) VisitExpressionUnary(p0 *ExpressionUnary) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitExpressionUnary(p0) - }) -} - -func (b *extendableVisitor) VisitFromClause(p0 *FromClause) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitFromClause(p0) - }) -} - -func (b *extendableVisitor) VisitGroupBy(p0 *GroupBy) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitGroupBy(p0) - }) -} - -func (b *extendableVisitor) VisitInsert(p0 *Insert) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitInsert(p0) - }) -} - -func (b *extendableVisitor) VisitInsertStmt(p0 *InsertStmt) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitInsertStmt(p0) - }) -} - -func (b *extendableVisitor) VisitJoinClause(p0 *JoinClause) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitJoinClause(p0) - }) -} - -func (b *extendableVisitor) VisitJoinOperator(p0 *JoinOperator) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitJoinOperator(p0) - }) -} - -func (b *extendableVisitor) VisitJoinPredicate(p0 *JoinPredicate) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitJoinPredicate(p0) - }) -} - -func (b *extendableVisitor) VisitLimit(p0 *Limit) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitLimit(p0) - }) -} - -func (b *extendableVisitor) VisitOrderBy(p0 *OrderBy) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitOrderBy(p0) - }) -} - -func (b *extendableVisitor) VisitOrderingTerm(p0 *OrderingTerm) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitOrderingTerm(p0) - }) -} - -func (b *extendableVisitor) VisitQualifiedTableName(p0 *QualifiedTableName) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitQualifiedTableName(p0) - }) -} - -func (b *extendableVisitor) VisitResultColumnExpression(p0 *ResultColumnExpression) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitResultColumnExpression(p0) - }) -} - -func (b *extendableVisitor) VisitResultColumnStar(p0 *ResultColumnStar) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitResultColumnStar(p0) - }) -} - -func (b *extendableVisitor) VisitResultColumnTable(p0 *ResultColumnTable) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitResultColumnTable(p0) - }) -} - -func (b *extendableVisitor) VisitReturningClause(p0 *ReturningClause) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitReturningClause(p0) - }) -} - -func (b *extendableVisitor) VisitReturningClauseColumn(p0 *ReturningClauseColumn) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitReturningClauseColumn(p0) - }) -} - -func (b *extendableVisitor) VisitScalarFunc(p0 *ScalarFunction) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitScalarFunc(p0) - }) -} - -func (b *extendableVisitor) VisitSelect(p0 *Select) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitSelect(p0) - }) -} - -func (b *extendableVisitor) VisitSelectCore(p0 *SelectCore) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitSelectCore(p0) - }) -} - -func (b *extendableVisitor) VisitSelectStmt(p0 *SelectStmt) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitSelectStmt(p0) - }) -} - -func (b *extendableVisitor) VisitTableOrSubqueryJoin(p0 *TableOrSubqueryJoin) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitTableOrSubqueryJoin(p0) - }) -} - -func (b *extendableVisitor) VisitTableOrSubqueryList(p0 *TableOrSubqueryList) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitTableOrSubqueryList(p0) - }) -} - -func (b *extendableVisitor) VisitTableOrSubquerySelect(p0 *TableOrSubquerySelect) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitTableOrSubquerySelect(p0) - }) -} - -func (b *extendableVisitor) VisitTableOrSubqueryTable(p0 *TableOrSubqueryTable) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitTableOrSubqueryTable(p0) - }) -} - -func (b *extendableVisitor) VisitUpdate(p0 *Update) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitUpdate(p0) - }) -} - -func (b *extendableVisitor) VisitUpdateSetClause(p0 *UpdateSetClause) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitUpdateSetClause(p0) - }) -} - -func (b *extendableVisitor) VisitUpdateStmt(p0 *UpdateStmt) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitUpdateStmt(p0) - }) -} - -func (b *extendableVisitor) VisitUpsert(p0 *Upsert) error { - return b.runDecorators(func(v Visitor) error { - return v.VisitUpsert(p0) - }) -} - -func (b *extendableVisitor) AddVisitor(v Visitor) { - b.decorators = append(b.decorators, v) -} -*/ diff --git a/pkg/engine/sqlparser/tree/walker.go b/pkg/engine/sqlparser/tree/walker.go index aa2f62f3f..e9f19b0f6 100644 --- a/pkg/engine/sqlparser/tree/walker.go +++ b/pkg/engine/sqlparser/tree/walker.go @@ -150,6 +150,8 @@ type Walker interface { type BaseWalker struct{} +var _ Walker = &BaseWalker{} + func NewBaseWalker() Walker { return &BaseWalker{} } @@ -553,3 +555,914 @@ func (b *BaseWalker) EnterUpsert(p0 *Upsert) error { func (b *BaseWalker) ExitUpsert(p0 *Upsert) error { return nil } + +type BaseAccepter struct{} + +// ImplementedWalker implements the Walker interface. +// Unlike BaseWalker, it holds the methods to be implemented +// as functions in a struct. This makes it easier to implement +// for small, one-off walkers. +type ImplementedWalker struct { + FuncEnterAggregateFunc func(p0 *AggregateFunc) error + FuncExitAggregateFunc func(p0 *AggregateFunc) error + FuncEnterCTE func(p0 *CTE) error + FuncExitCTE func(p0 *CTE) error + FuncEnterCompoundOperator func(p0 *CompoundOperator) error + FuncExitCompoundOperator func(p0 *CompoundOperator) error + FuncEnterConflictTarget func(p0 *ConflictTarget) error + FuncExitConflictTarget func(p0 *ConflictTarget) error + FuncEnterDateTimeFunc func(p0 *DateTimeFunction) error + FuncExitDateTimeFunc func(p0 *DateTimeFunction) error + FuncEnterDelete func(p0 *Delete) error + FuncExitDelete func(p0 *Delete) error + FuncEnterDeleteStmt func(p0 *DeleteStmt) error + FuncExitDeleteStmt func(p0 *DeleteStmt) error + FuncEnterExpressionArithmetic func(p0 *ExpressionArithmetic) error + FuncExitExpressionArithmetic func(p0 *ExpressionArithmetic) error + FuncEnterExpressionBetween func(p0 *ExpressionBetween) error + FuncExitExpressionBetween func(p0 *ExpressionBetween) error + FuncEnterExpressionBinaryComparison func(p0 *ExpressionBinaryComparison) error + FuncExitExpressionBinaryComparison func(p0 *ExpressionBinaryComparison) error + FuncEnterExpressionBindParameter func(p0 *ExpressionBindParameter) error + FuncExitExpressionBindParameter func(p0 *ExpressionBindParameter) error + FuncEnterExpressionCase func(p0 *ExpressionCase) error + FuncExitExpressionCase func(p0 *ExpressionCase) error + FuncEnterExpressionCollate func(p0 *ExpressionCollate) error + FuncExitExpressionCollate func(p0 *ExpressionCollate) error + FuncEnterExpressionColumn func(p0 *ExpressionColumn) error + FuncExitExpressionColumn func(p0 *ExpressionColumn) error + FuncEnterExpressionDistinct func(p0 *ExpressionDistinct) error + FuncExitExpressionDistinct func(p0 *ExpressionDistinct) error + FuncEnterExpressionFunction func(p0 *ExpressionFunction) error + FuncExitExpressionFunction func(p0 *ExpressionFunction) error + FuncEnterExpressionIsNull func(p0 *ExpressionIsNull) error + FuncExitExpressionIsNull func(p0 *ExpressionIsNull) error + FuncEnterExpressionList func(p0 *ExpressionList) error + FuncExitExpressionList func(p0 *ExpressionList) error + FuncEnterExpressionLiteral func(p0 *ExpressionLiteral) error + FuncExitExpressionLiteral func(p0 *ExpressionLiteral) error + FuncEnterExpressionSelect func(p0 *ExpressionSelect) error + FuncExitExpressionSelect func(p0 *ExpressionSelect) error + FuncEnterExpressionStringCompare func(p0 *ExpressionStringCompare) error + FuncExitExpressionStringCompare func(p0 *ExpressionStringCompare) error + FuncEnterExpressionUnary func(p0 *ExpressionUnary) error + FuncExitExpressionUnary func(p0 *ExpressionUnary) error + FuncEnterFromClause func(p0 *FromClause) error + FuncExitFromClause func(p0 *FromClause) error + FuncEnterGroupBy func(p0 *GroupBy) error + FuncExitGroupBy func(p0 *GroupBy) error + FuncEnterInsert func(p0 *Insert) error + FuncExitInsert func(p0 *Insert) error + FuncEnterInsertStmt func(p0 *InsertStmt) error + FuncExitInsertStmt func(p0 *InsertStmt) error + FuncEnterJoinClause func(p0 *JoinClause) error + FuncExitJoinClause func(p0 *JoinClause) error + FuncEnterJoinOperator func(p0 *JoinOperator) error + FuncExitJoinOperator func(p0 *JoinOperator) error + FuncEnterJoinPredicate func(p0 *JoinPredicate) error + FuncExitJoinPredicate func(p0 *JoinPredicate) error + FuncEnterLimit func(p0 *Limit) error + FuncExitLimit func(p0 *Limit) error + FuncEnterOrderBy func(p0 *OrderBy) error + FuncExitOrderBy func(p0 *OrderBy) error + FuncEnterOrderingTerm func(p0 *OrderingTerm) error + FuncExitOrderingTerm func(p0 *OrderingTerm) error + FuncEnterQualifiedTableName func(p0 *QualifiedTableName) error + FuncExitQualifiedTableName func(p0 *QualifiedTableName) error + FuncEnterResultColumnExpression func(p0 *ResultColumnExpression) error + FuncExitResultColumnExpression func(p0 *ResultColumnExpression) error + FuncEnterResultColumnStar func(p0 *ResultColumnStar) error + FuncExitResultColumnStar func(p0 *ResultColumnStar) error + FuncEnterResultColumnTable func(p0 *ResultColumnTable) error + FuncExitResultColumnTable func(p0 *ResultColumnTable) error + FuncEnterReturningClause func(p0 *ReturningClause) error + FuncExitReturningClause func(p0 *ReturningClause) error + FuncEnterReturningClauseColumn func(p0 *ReturningClauseColumn) error + FuncExitReturningClauseColumn func(p0 *ReturningClauseColumn) error + FuncEnterScalarFunc func(p0 *ScalarFunction) error + FuncExitScalarFunc func(p0 *ScalarFunction) error + FuncEnterSelect func(p0 *Select) error + FuncExitSelect func(p0 *Select) error + FuncEnterSelectCore func(p0 *SelectCore) error + FuncExitSelectCore func(p0 *SelectCore) error + FuncEnterSelectStmt func(p0 *SelectStmt) error + FuncExitSelectStmt func(p0 *SelectStmt) error + FuncEnterTableOrSubqueryJoin func(p0 *TableOrSubqueryJoin) error + FuncExitTableOrSubqueryJoin func(p0 *TableOrSubqueryJoin) error + FuncEnterTableOrSubqueryList func(p0 *TableOrSubqueryList) error + FuncExitTableOrSubqueryList func(p0 *TableOrSubqueryList) error + FuncEnterTableOrSubquerySelect func(p0 *TableOrSubquerySelect) error + FuncExitTableOrSubquerySelect func(p0 *TableOrSubquerySelect) error + FuncEnterTableOrSubqueryTable func(p0 *TableOrSubqueryTable) error + FuncExitTableOrSubqueryTable func(p0 *TableOrSubqueryTable) error + FuncEnterUpdate func(p0 *Update) error + FuncExitUpdate func(p0 *Update) error + FuncEnterUpdateSetClause func(p0 *UpdateSetClause) error + FuncExitUpdateSetClause func(p0 *UpdateSetClause) error + FuncEnterUpdateStmt func(p0 *UpdateStmt) error + FuncExitUpdateStmt func(p0 *UpdateStmt) error + FuncEnterUpsert func(p0 *Upsert) error + FuncExitUpsert func(p0 *Upsert) error +} + +var _ Walker = &ImplementedWalker{} + +func (b *ImplementedWalker) EnterAggregateFunc(p0 *AggregateFunc) error { + if b.FuncEnterAggregateFunc == nil { + return nil + } + + return b.FuncEnterAggregateFunc(p0) +} + +func (b *ImplementedWalker) ExitAggregateFunc(p0 *AggregateFunc) error { + if b.FuncExitAggregateFunc == nil { + return nil + } + + return b.FuncExitAggregateFunc(p0) +} + +func (b *ImplementedWalker) EnterCTE(p0 *CTE) error { + if b.FuncEnterCTE == nil { + return nil + } + + return b.FuncEnterCTE(p0) +} + +func (b *ImplementedWalker) ExitCTE(p0 *CTE) error { + if b.FuncExitCTE == nil { + return nil + } + + return b.FuncExitCTE(p0) +} + +func (b *ImplementedWalker) EnterCompoundOperator(p0 *CompoundOperator) error { + if b.FuncEnterCompoundOperator == nil { + return nil + } + + return b.FuncEnterCompoundOperator(p0) +} + +func (b *ImplementedWalker) ExitCompoundOperator(p0 *CompoundOperator) error { + if b.FuncExitCompoundOperator == nil { + return nil + } + + return b.FuncExitCompoundOperator(p0) +} + +func (b *ImplementedWalker) EnterConflictTarget(p0 *ConflictTarget) error { + if b.FuncEnterConflictTarget == nil { + return nil + } + + return b.FuncEnterConflictTarget(p0) +} + +func (b *ImplementedWalker) ExitConflictTarget(p0 *ConflictTarget) error { + if b.FuncExitConflictTarget == nil { + return nil + } + + return b.FuncExitConflictTarget(p0) +} + +func (b *ImplementedWalker) EnterDateTimeFunc(p0 *DateTimeFunction) error { + if b.FuncEnterDateTimeFunc == nil { + return nil + } + + return b.FuncEnterDateTimeFunc(p0) +} + +func (b *ImplementedWalker) ExitDateTimeFunc(p0 *DateTimeFunction) error { + if b.FuncExitDateTimeFunc == nil { + return nil + } + + return b.FuncExitDateTimeFunc(p0) +} + +func (b *ImplementedWalker) EnterDelete(p0 *Delete) error { + if b.FuncEnterDelete == nil { + return nil + } + + return b.FuncEnterDelete(p0) +} + +func (b *ImplementedWalker) ExitDelete(p0 *Delete) error { + if b.FuncExitDelete == nil { + return nil + } + + return b.FuncExitDelete(p0) +} + +func (b *ImplementedWalker) EnterDeleteStmt(p0 *DeleteStmt) error { + if b.FuncEnterDeleteStmt == nil { + return nil + } + + return b.FuncEnterDeleteStmt(p0) +} + +func (b *ImplementedWalker) ExitDeleteStmt(p0 *DeleteStmt) error { + if b.FuncExitDeleteStmt == nil { + return nil + } + + return b.FuncExitDeleteStmt(p0) +} + +func (b *ImplementedWalker) EnterExpressionArithmetic(p0 *ExpressionArithmetic) error { + if b.FuncEnterExpressionArithmetic == nil { + return nil + } + + return b.FuncEnterExpressionArithmetic(p0) +} + +func (b *ImplementedWalker) ExitExpressionArithmetic(p0 *ExpressionArithmetic) error { + if b.FuncExitExpressionArithmetic == nil { + return nil + } + + return b.FuncExitExpressionArithmetic(p0) +} + +func (b *ImplementedWalker) EnterExpressionBetween(p0 *ExpressionBetween) error { + if b.FuncEnterExpressionBetween == nil { + return nil + } + + return b.FuncEnterExpressionBetween(p0) +} + +func (b *ImplementedWalker) ExitExpressionBetween(p0 *ExpressionBetween) error { + if b.FuncExitExpressionBetween == nil { + return nil + } + + return b.FuncExitExpressionBetween(p0) +} + +func (b *ImplementedWalker) EnterExpressionBinaryComparison(p0 *ExpressionBinaryComparison) error { + if b.FuncEnterExpressionBinaryComparison == nil { + return nil + } + + return b.FuncEnterExpressionBinaryComparison(p0) +} + +func (b *ImplementedWalker) ExitExpressionBinaryComparison(p0 *ExpressionBinaryComparison) error { + if b.FuncExitExpressionBinaryComparison == nil { + return nil + } + + return b.FuncExitExpressionBinaryComparison(p0) +} + +func (b *ImplementedWalker) EnterExpressionBindParameter(p0 *ExpressionBindParameter) error { + if b.FuncEnterExpressionBindParameter == nil { + return nil + } + + return b.FuncEnterExpressionBindParameter(p0) +} + +func (b *ImplementedWalker) ExitExpressionBindParameter(p0 *ExpressionBindParameter) error { + if b.FuncExitExpressionBindParameter == nil { + return nil + } + + return b.FuncExitExpressionBindParameter(p0) +} + +func (b *ImplementedWalker) EnterExpressionCase(p0 *ExpressionCase) error { + if b.FuncEnterExpressionCase == nil { + return nil + } + + return b.FuncEnterExpressionCase(p0) +} + +func (b *ImplementedWalker) ExitExpressionCase(p0 *ExpressionCase) error { + if b.FuncExitExpressionCase == nil { + return nil + } + + return b.FuncExitExpressionCase(p0) +} + +func (b *ImplementedWalker) EnterExpressionCollate(p0 *ExpressionCollate) error { + if b.FuncEnterExpressionCollate == nil { + return nil + } + + return b.FuncEnterExpressionCollate(p0) +} + +func (b *ImplementedWalker) ExitExpressionCollate(p0 *ExpressionCollate) error { + if b.FuncExitExpressionCollate == nil { + return nil + } + + return b.FuncExitExpressionCollate(p0) +} + +func (b *ImplementedWalker) EnterExpressionColumn(p0 *ExpressionColumn) error { + if b.FuncEnterExpressionColumn == nil { + return nil + } + + return b.FuncEnterExpressionColumn(p0) +} + +func (b *ImplementedWalker) ExitExpressionColumn(p0 *ExpressionColumn) error { + if b.FuncExitExpressionColumn == nil { + return nil + } + + return b.FuncExitExpressionColumn(p0) +} + +func (b *ImplementedWalker) EnterExpressionDistinct(p0 *ExpressionDistinct) error { + if b.FuncEnterExpressionDistinct == nil { + return nil + } + + return b.FuncEnterExpressionDistinct(p0) +} + +func (b *ImplementedWalker) ExitExpressionDistinct(p0 *ExpressionDistinct) error { + if b.FuncExitExpressionDistinct == nil { + return nil + } + + return b.FuncExitExpressionDistinct(p0) +} + +func (b *ImplementedWalker) EnterExpressionFunction(p0 *ExpressionFunction) error { + if b.FuncEnterExpressionFunction == nil { + return nil + } + + return b.FuncEnterExpressionFunction(p0) +} + +func (b *ImplementedWalker) ExitExpressionFunction(p0 *ExpressionFunction) error { + if b.FuncExitExpressionFunction == nil { + return nil + } + + return b.FuncExitExpressionFunction(p0) +} + +func (b *ImplementedWalker) EnterExpressionIsNull(p0 *ExpressionIsNull) error { + if b.FuncEnterExpressionIsNull == nil { + return nil + } + + return b.FuncEnterExpressionIsNull(p0) +} + +func (b *ImplementedWalker) ExitExpressionIsNull(p0 *ExpressionIsNull) error { + if b.FuncExitExpressionIsNull == nil { + return nil + } + + return b.FuncExitExpressionIsNull(p0) +} + +func (b *ImplementedWalker) EnterExpressionList(p0 *ExpressionList) error { + if b.FuncEnterExpressionList == nil { + return nil + } + + return b.FuncEnterExpressionList(p0) +} + +func (b *ImplementedWalker) ExitExpressionList(p0 *ExpressionList) error { + if b.FuncExitExpressionList == nil { + return nil + } + + return b.FuncExitExpressionList(p0) +} + +func (b *ImplementedWalker) EnterExpressionLiteral(p0 *ExpressionLiteral) error { + if b.FuncEnterExpressionLiteral == nil { + return nil + } + + return b.FuncEnterExpressionLiteral(p0) +} + +func (b *ImplementedWalker) ExitExpressionLiteral(p0 *ExpressionLiteral) error { + if b.FuncExitExpressionLiteral == nil { + return nil + } + + return b.FuncExitExpressionLiteral(p0) +} + +func (b *ImplementedWalker) EnterExpressionSelect(p0 *ExpressionSelect) error { + if b.FuncEnterExpressionSelect == nil { + return nil + } + + return b.FuncEnterExpressionSelect(p0) +} + +func (b *ImplementedWalker) ExitExpressionSelect(p0 *ExpressionSelect) error { + if b.FuncExitExpressionSelect == nil { + return nil + } + + return b.FuncExitExpressionSelect(p0) +} + +func (b *ImplementedWalker) EnterExpressionStringCompare(p0 *ExpressionStringCompare) error { + if b.FuncEnterExpressionStringCompare == nil { + return nil + } + + return b.FuncEnterExpressionStringCompare(p0) +} + +func (b *ImplementedWalker) ExitExpressionStringCompare(p0 *ExpressionStringCompare) error { + if b.FuncExitExpressionStringCompare == nil { + return nil + } + + return b.FuncExitExpressionStringCompare(p0) +} + +func (b *ImplementedWalker) EnterExpressionUnary(p0 *ExpressionUnary) error { + if b.FuncEnterExpressionUnary == nil { + return nil + } + + return b.FuncEnterExpressionUnary(p0) +} + +func (b *ImplementedWalker) ExitExpressionUnary(p0 *ExpressionUnary) error { + if b.FuncExitExpressionUnary == nil { + return nil + } + + return b.FuncExitExpressionUnary(p0) +} + +func (b *ImplementedWalker) EnterFromClause(p0 *FromClause) error { + if b.FuncEnterFromClause == nil { + return nil + } + + return b.FuncEnterFromClause(p0) +} + +func (b *ImplementedWalker) ExitFromClause(p0 *FromClause) error { + if b.FuncExitFromClause == nil { + return nil + } + + return b.FuncExitFromClause(p0) +} + +func (b *ImplementedWalker) EnterGroupBy(p0 *GroupBy) error { + if b.FuncEnterGroupBy == nil { + return nil + } + + return b.FuncEnterGroupBy(p0) +} + +func (b *ImplementedWalker) ExitGroupBy(p0 *GroupBy) error { + if b.FuncExitGroupBy == nil { + return nil + } + + return b.FuncExitGroupBy(p0) +} + +func (b *ImplementedWalker) EnterInsert(p0 *Insert) error { + if b.FuncEnterInsert == nil { + return nil + } + + return b.FuncEnterInsert(p0) +} + +func (b *ImplementedWalker) ExitInsert(p0 *Insert) error { + if b.FuncExitInsert == nil { + return nil + } + + return b.FuncExitInsert(p0) +} + +func (b *ImplementedWalker) EnterInsertStmt(p0 *InsertStmt) error { + if b.FuncEnterInsertStmt == nil { + return nil + } + + return b.FuncEnterInsertStmt(p0) +} + +func (b *ImplementedWalker) ExitInsertStmt(p0 *InsertStmt) error { + if b.FuncExitInsertStmt == nil { + return nil + } + + return b.FuncExitInsertStmt(p0) +} + +func (b *ImplementedWalker) EnterJoinClause(p0 *JoinClause) error { + if b.FuncEnterJoinClause == nil { + return nil + } + + return b.FuncEnterJoinClause(p0) +} + +func (b *ImplementedWalker) ExitJoinClause(p0 *JoinClause) error { + if b.FuncExitJoinClause == nil { + return nil + } + + return b.FuncExitJoinClause(p0) +} + +func (b *ImplementedWalker) EnterJoinOperator(p0 *JoinOperator) error { + if b.FuncEnterJoinOperator == nil { + return nil + } + + return b.FuncEnterJoinOperator(p0) +} + +func (b *ImplementedWalker) ExitJoinOperator(p0 *JoinOperator) error { + if b.FuncExitJoinOperator == nil { + return nil + } + + return b.FuncExitJoinOperator(p0) +} + +func (b *ImplementedWalker) EnterJoinPredicate(p0 *JoinPredicate) error { + if b.FuncEnterJoinPredicate == nil { + return nil + } + + return b.FuncEnterJoinPredicate(p0) +} + +func (b *ImplementedWalker) ExitJoinPredicate(p0 *JoinPredicate) error { + if b.FuncExitJoinPredicate == nil { + return nil + } + + return b.FuncExitJoinPredicate(p0) +} + +func (b *ImplementedWalker) EnterLimit(p0 *Limit) error { + if b.FuncEnterLimit == nil { + return nil + } + + return b.FuncEnterLimit(p0) +} + +func (b *ImplementedWalker) ExitLimit(p0 *Limit) error { + if b.FuncExitLimit == nil { + return nil + } + + return b.FuncExitLimit(p0) +} + +func (b *ImplementedWalker) EnterOrderBy(p0 *OrderBy) error { + if b.FuncEnterOrderBy == nil { + return nil + } + + return b.FuncEnterOrderBy(p0) +} + +func (b *ImplementedWalker) ExitOrderBy(p0 *OrderBy) error { + if b.FuncExitOrderBy == nil { + return nil + } + + return b.FuncExitOrderBy(p0) +} + +func (b *ImplementedWalker) EnterOrderingTerm(p0 *OrderingTerm) error { + if b.FuncEnterOrderingTerm == nil { + return nil + } + + return b.FuncEnterOrderingTerm(p0) +} + +func (b *ImplementedWalker) ExitOrderingTerm(p0 *OrderingTerm) error { + if b.FuncExitOrderingTerm == nil { + return nil + } + + return b.FuncExitOrderingTerm(p0) +} + +func (b *ImplementedWalker) EnterQualifiedTableName(p0 *QualifiedTableName) error { + if b.FuncEnterQualifiedTableName == nil { + return nil + } + + return b.FuncEnterQualifiedTableName(p0) +} + +func (b *ImplementedWalker) ExitQualifiedTableName(p0 *QualifiedTableName) error { + if b.FuncExitQualifiedTableName == nil { + return nil + } + + return b.FuncExitQualifiedTableName(p0) +} + +func (b *ImplementedWalker) EnterResultColumnExpression(p0 *ResultColumnExpression) error { + if b.FuncEnterResultColumnExpression == nil { + return nil + } + + return b.FuncEnterResultColumnExpression(p0) +} + +func (b *ImplementedWalker) ExitResultColumnExpression(p0 *ResultColumnExpression) error { + if b.FuncExitResultColumnExpression == nil { + return nil + } + + return b.FuncExitResultColumnExpression(p0) +} + +func (b *ImplementedWalker) EnterResultColumnStar(p0 *ResultColumnStar) error { + if b.FuncEnterResultColumnStar == nil { + return nil + } + + return b.FuncEnterResultColumnStar(p0) +} + +func (b *ImplementedWalker) ExitResultColumnStar(p0 *ResultColumnStar) error { + if b.FuncExitResultColumnStar == nil { + return nil + } + + return b.FuncExitResultColumnStar(p0) +} + +func (b *ImplementedWalker) EnterResultColumnTable(p0 *ResultColumnTable) error { + if b.FuncEnterResultColumnTable == nil { + return nil + } + + return b.FuncEnterResultColumnTable(p0) +} + +func (b *ImplementedWalker) ExitResultColumnTable(p0 *ResultColumnTable) error { + if b.FuncExitResultColumnTable == nil { + return nil + } + + return b.FuncExitResultColumnTable(p0) +} + +func (b *ImplementedWalker) EnterReturningClause(p0 *ReturningClause) error { + if b.FuncEnterReturningClause == nil { + return nil + } + + return b.FuncEnterReturningClause(p0) +} + +func (b *ImplementedWalker) ExitReturningClause(p0 *ReturningClause) error { + if b.FuncExitReturningClause == nil { + return nil + } + + return b.FuncExitReturningClause(p0) +} + +func (b *ImplementedWalker) EnterReturningClauseColumn(p0 *ReturningClauseColumn) error { + if b.FuncEnterReturningClauseColumn == nil { + return nil + } + + return b.FuncEnterReturningClauseColumn(p0) +} + +func (b *ImplementedWalker) ExitReturningClauseColumn(p0 *ReturningClauseColumn) error { + if b.FuncExitReturningClauseColumn == nil { + return nil + } + + return b.FuncExitReturningClauseColumn(p0) +} + +func (b *ImplementedWalker) EnterScalarFunc(p0 *ScalarFunction) error { + if b.FuncEnterScalarFunc == nil { + return nil + } + + return b.FuncEnterScalarFunc(p0) +} + +func (b *ImplementedWalker) ExitScalarFunc(p0 *ScalarFunction) error { + if b.FuncExitScalarFunc == nil { + return nil + } + + return b.FuncExitScalarFunc(p0) +} + +func (b *ImplementedWalker) EnterSelect(p0 *Select) error { + if b.FuncEnterSelect == nil { + return nil + } + + return b.FuncEnterSelect(p0) +} + +func (b *ImplementedWalker) ExitSelect(p0 *Select) error { + if b.FuncExitSelect == nil { + return nil + } + + return b.FuncExitSelect(p0) +} + +func (b *ImplementedWalker) EnterSelectCore(p0 *SelectCore) error { + if b.FuncEnterSelectCore == nil { + return nil + } + + return b.FuncEnterSelectCore(p0) +} + +func (b *ImplementedWalker) ExitSelectCore(p0 *SelectCore) error { + if b.FuncExitSelectCore == nil { + return nil + } + + return b.FuncExitSelectCore(p0) +} + +func (b *ImplementedWalker) EnterSelectStmt(p0 *SelectStmt) error { + if b.FuncEnterSelectStmt == nil { + return nil + } + + return b.FuncEnterSelectStmt(p0) +} + +func (b *ImplementedWalker) ExitSelectStmt(p0 *SelectStmt) error { + if b.FuncExitSelectStmt == nil { + return nil + } + + return b.FuncExitSelectStmt(p0) +} + +func (b *ImplementedWalker) EnterTableOrSubqueryJoin(p0 *TableOrSubqueryJoin) error { + if b.FuncEnterTableOrSubqueryJoin == nil { + return nil + } + + return b.FuncEnterTableOrSubqueryJoin(p0) +} + +func (b *ImplementedWalker) ExitTableOrSubqueryJoin(p0 *TableOrSubqueryJoin) error { + if b.FuncExitTableOrSubqueryJoin == nil { + return nil + } + + return b.FuncExitTableOrSubqueryJoin(p0) +} + +func (b *ImplementedWalker) EnterTableOrSubqueryList(p0 *TableOrSubqueryList) error { + if b.FuncEnterTableOrSubqueryList == nil { + return nil + } + + return b.FuncEnterTableOrSubqueryList(p0) +} + +func (b *ImplementedWalker) ExitTableOrSubqueryList(p0 *TableOrSubqueryList) error { + if b.FuncExitTableOrSubqueryList == nil { + return nil + } + + return b.FuncExitTableOrSubqueryList(p0) +} + +func (b *ImplementedWalker) EnterTableOrSubquerySelect(p0 *TableOrSubquerySelect) error { + if b.FuncEnterTableOrSubquerySelect == nil { + return nil + } + + return b.FuncEnterTableOrSubquerySelect(p0) +} + +func (b *ImplementedWalker) ExitTableOrSubquerySelect(p0 *TableOrSubquerySelect) error { + if b.FuncExitTableOrSubquerySelect == nil { + return nil + } + + return b.FuncExitTableOrSubquerySelect(p0) +} + +func (b *ImplementedWalker) EnterTableOrSubqueryTable(p0 *TableOrSubqueryTable) error { + if b.FuncEnterTableOrSubqueryTable == nil { + return nil + } + + return b.FuncEnterTableOrSubqueryTable(p0) +} + +func (b *ImplementedWalker) ExitTableOrSubqueryTable(p0 *TableOrSubqueryTable) error { + if b.FuncExitTableOrSubqueryTable == nil { + return nil + } + + return b.FuncExitTableOrSubqueryTable(p0) +} + +func (b *ImplementedWalker) EnterUpdate(p0 *Update) error { + if b.FuncEnterUpdate == nil { + return nil + } + + return b.FuncEnterUpdate(p0) +} + +func (b *ImplementedWalker) ExitUpdate(p0 *Update) error { + if b.FuncExitUpdate == nil { + return nil + } + + return b.FuncExitUpdate(p0) +} + +func (b *ImplementedWalker) EnterUpdateSetClause(p0 *UpdateSetClause) error { + if b.FuncEnterUpdateSetClause == nil { + return nil + } + + return b.FuncEnterUpdateSetClause(p0) +} + +func (b *ImplementedWalker) ExitUpdateSetClause(p0 *UpdateSetClause) error { + if b.FuncExitUpdateSetClause == nil { + return nil + } + + return b.FuncExitUpdateSetClause(p0) +} + +func (b *ImplementedWalker) EnterUpdateStmt(p0 *UpdateStmt) error { + if b.FuncEnterUpdateStmt == nil { + return nil + } + + return b.FuncEnterUpdateStmt(p0) +} + +func (b *ImplementedWalker) ExitUpdateStmt(p0 *UpdateStmt) error { + if b.FuncExitUpdateStmt == nil { + return nil + } + + return b.FuncExitUpdateStmt(p0) +} + +func (b *ImplementedWalker) EnterUpsert(p0 *Upsert) error { + if b.FuncEnterUpsert == nil { + return nil + } + + return b.FuncEnterUpsert(p0) +} + +func (b *ImplementedWalker) ExitUpsert(p0 *Upsert) error { + if b.FuncExitUpsert == nil { + return nil + } + + return b.FuncExitUpsert(p0) +} diff --git a/pkg/engine/types/clean.go b/pkg/engine/types/clean.go index 01df8e0af..3c660eff3 100644 --- a/pkg/engine/types/clean.go +++ b/pkg/engine/types/clean.go @@ -3,7 +3,6 @@ package types import ( "errors" "fmt" - "reflect" "strings" "github.com/kwilteam/kwil-db/pkg/engine/types/validation" @@ -44,21 +43,6 @@ func cleanIdents(idents *[]string) error { return nil } -func cleanScalar(scalar *any) error { - if scalar == nil { - return nil - } - - kind := reflect.TypeOf(*scalar).Kind() - - switch kind { - case reflect.String, reflect.Int, reflect.Float64, reflect.Bool, reflect.Int64, reflect.Float32, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return nil - default: - return fmt.Errorf("invalid scalar type: %s", kind.String()) - } -} - func cleanActionParameters(inputs *[]string) error { if inputs == nil { return nil diff --git a/pkg/engine/types/foreign_key.go b/pkg/engine/types/foreign_key.go index 0b7c14de5..f74c08ff1 100644 --- a/pkg/engine/types/foreign_key.go +++ b/pkg/engine/types/foreign_key.go @@ -42,6 +42,21 @@ func (f *ForeignKey) Clean() error { return nil } +// Copy returns a copy of the foreign key +func (f *ForeignKey) Copy() *ForeignKey { + actions := make([]*ForeignKeyAction, len(f.Actions)) + for i, action := range f.Actions { + actions[i] = action.Copy() + } + + return &ForeignKey{ + ChildKeys: f.ChildKeys, + ParentKeys: f.ParentKeys, + ParentTable: f.ParentTable, + Actions: actions, + } +} + // ForeignKeyAction is used to specify what should occur // if a parent key is updated or deleted type ForeignKeyAction struct { @@ -60,6 +75,14 @@ func (f *ForeignKeyAction) Clean() error { ) } +// Copy returns a copy of the foreign key action +func (f *ForeignKeyAction) Copy() *ForeignKeyAction { + return &ForeignKeyAction{ + On: f.On, + Do: f.Do, + } +} + // ForeignKeyActionOn specifies when a foreign key action should occur. // It can be either "UPDATE" or "DELETE". type ForeignKeyActionOn string diff --git a/pkg/engine/types/index.go b/pkg/engine/types/index.go index 66bc386fb..4bec9aa53 100644 --- a/pkg/engine/types/index.go +++ b/pkg/engine/types/index.go @@ -8,9 +8,9 @@ import ( type IndexType string type Index struct { - Name string `json:"name" clean:"lower"` - Columns []string `json:"columns" clean:"lower"` - Type IndexType `json:"type" clean:"is_enum,index_type"` + Name string `json:"name"` + Columns []string `json:"columns"` + Type IndexType `json:"type"` } func (i *Index) Clean() error { @@ -21,6 +21,15 @@ func (i *Index) Clean() error { ) } +// Copy returns a copy of the index. +func (i *Index) Copy() *Index { + return &Index{ + Name: i.Name, + Columns: i.Columns, + Type: i.Type, + } +} + const ( BTREE IndexType = "BTREE" UNIQUE_BTREE IndexType = "UNIQUE_BTREE" diff --git a/pkg/engine/types/table.go b/pkg/engine/types/table.go index 53313778b..b1eae4ab2 100644 --- a/pkg/engine/types/table.go +++ b/pkg/engine/types/table.go @@ -5,16 +5,12 @@ import ( ) type Table struct { - Name string `json:"name" clean:"lower"` + Name string `json:"name"` Columns []*Column `json:"columns"` Indexes []*Index `json:"indexes,omitempty"` ForeignKeys []*ForeignKey `json:"foreign_keys"` } -func (t *Table) Identifier() string { - return t.Name -} - func (t *Table) Clean() error { hasPrimaryAttribute := false for _, col := range t.Columns { @@ -66,14 +62,14 @@ func (t *Table) Clean() error { func (t *Table) GetPrimaryKey() ([]string, error) { var primaryKey []string - hasAttribitePrimaryKey := false + hasAttributePrimaryKey := false for _, col := range t.Columns { for _, attr := range col.Attributes { if attr.Type == PRIMARY_KEY { - if hasAttribitePrimaryKey { + if hasAttributePrimaryKey { return nil, fmt.Errorf("table %s has multiple primary attributes", t.Name) } - hasAttribitePrimaryKey = true + hasAttributePrimaryKey = true primaryKey = []string{col.Name} } } @@ -86,25 +82,50 @@ func (t *Table) GetPrimaryKey() ([]string, error) { return nil, fmt.Errorf("table %s has multiple primary indexes", t.Name) } hasIndexPrimaryKey = true - primaryKey = idx.Columns + + // copy + // if we do not copy, then the returned slice will allow modification of the index + primaryKey = make([]string, len(idx.Columns)) + copy(primaryKey, idx.Columns) } } - if !hasAttribitePrimaryKey && !hasIndexPrimaryKey { + if !hasAttributePrimaryKey && !hasIndexPrimaryKey { return nil, fmt.Errorf("table %s has no primary key", t.Name) } - if hasAttribitePrimaryKey && hasIndexPrimaryKey { + if hasAttributePrimaryKey && hasIndexPrimaryKey { return nil, fmt.Errorf("table %s has both primary attribute and primary index", t.Name) } return primaryKey, nil } +// Copy returns a copy of the table +func (t *Table) Copy() *Table { + res := &Table{ + Name: t.Name, + } + + for _, col := range t.Columns { + res.Columns = append(res.Columns, col.Copy()) + } + + for _, idx := range t.Indexes { + res.Indexes = append(res.Indexes, idx.Copy()) + } + + for _, fk := range t.ForeignKeys { + res.ForeignKeys = append(res.ForeignKeys, fk.Copy()) + } + + return res +} + type Column struct { - Name string `json:"name" clean:"lower"` - Type DataType `json:"type" clean:"is_enum,data_type"` - Attributes []*Attribute `json:"attributes,omitempty" traverse:"shallow"` + Name string `json:"name"` + Type DataType `json:"type"` + Attributes []*Attribute `json:"attributes,omitempty"` } func (c *Column) Clean() error { @@ -120,6 +141,20 @@ func (c *Column) Clean() error { ) } +// Copy returns a copy of the column +func (c *Column) Copy() *Column { + res := &Column{ + Name: c.Name, + Type: c.Type, + } + + for _, attr := range c.Attributes { + res.Attributes = append(res.Attributes, attr.Copy()) + } + + return res +} + func (c *Column) hasPrimary() bool { for _, attr := range c.Attributes { if attr.Type == PRIMARY_KEY { @@ -130,17 +165,20 @@ func (c *Column) hasPrimary() bool { } type Attribute struct { - Type AttributeType `json:"type" clean:"is_enum,attribute_type"` - Value any `json:"value"` + Type AttributeType `json:"type"` + Value string `json:"value,omitempty"` } func (a *Attribute) Clean() error { - if a.Value == nil { - return a.Type.Clean() - } - return runCleans( a.Type.Clean(), - cleanScalar(&a.Value), ) } + +// Copy returns a copy of the attribute +func (a *Attribute) Copy() *Attribute { + return &Attribute{ + Type: a.Type, + Value: a.Value, + } +} diff --git a/pkg/engine/types/testdata/tables.go b/pkg/engine/types/testdata/tables.go new file mode 100644 index 000000000..754f427c9 --- /dev/null +++ b/pkg/engine/types/testdata/tables.go @@ -0,0 +1,173 @@ +package testdata + +import "github.com/kwilteam/kwil-db/pkg/engine/types" + +var ( + TableUsers = &types.Table{ + Name: "users", + Columns: []*types.Column{ + { + Name: "id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.PRIMARY_KEY, + }, + { + Type: types.NOT_NULL, + }, + }, + }, + { + Name: "username", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.UNIQUE, + }, + { + Type: types.MIN_LENGTH, + Value: "5", + }, + { + Type: types.MAX_LENGTH, + Value: "32", + }, + }, + }, + { + Name: "age", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.MIN, + Value: "13", + }, + { + Type: types.MAX, + Value: "420", + }, + }, + }, + { + Name: "address", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.UNIQUE, + }, + }, + }, + }, + Indexes: []*types.Index{ + { + Name: "age_idx", + Columns: []string{ + "age", + }, + Type: types.BTREE, + }, + }, + } + + TablePosts = &types.Table{ + Name: "posts", + Columns: []*types.Column{ + { + Name: "id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.PRIMARY_KEY, + }, + { + Type: types.NOT_NULL, + }, + }, + }, + { + Name: "title", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.MAX_LENGTH, + Value: "300", + }, + }, + }, + { + Name: "content", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + { + Type: types.MAX_LENGTH, + Value: "10000", + }, + }, + }, + { + Name: "author_id", + Type: types.INT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + { + Name: "post_date", + Type: types.TEXT, + Attributes: []*types.Attribute{ + { + Type: types.NOT_NULL, + }, + }, + }, + }, + Indexes: []*types.Index{ + { + Name: "author_idx", + Columns: []string{ + "author_id", + }, + Type: types.BTREE, + }, + }, + ForeignKeys: []*types.ForeignKey{ + { + ChildKeys: []string{ + "author_id", + }, + ParentKeys: []string{ + "id", + }, + ParentTable: "users", + Actions: []*types.ForeignKeyAction{ + { + On: types.ON_UPDATE, + Do: types.DO_CASCADE, + }, + { + On: types.ON_DELETE, + Do: types.DO_CASCADE, + }, + }, + }, + }, + } +) diff --git a/pkg/sessions/session.go b/pkg/sessions/session.go index 8e2fccb74..f9c3b4c5e 100644 --- a/pkg/sessions/session.go +++ b/pkg/sessions/session.go @@ -368,7 +368,7 @@ func (a *AtomicCommitter) endApply(ctx context.Context) error { func (a *AtomicCommitter) id(ctx context.Context) (id []byte, err error) { hash := sha256.New() - for _, c := range order.OrderMapLexicographically[CommittableId, Committable](a.committables) { + for _, c := range order.OrderMap[CommittableId, Committable](a.committables) { commitId, err := c.Value.ID(ctx) if err != nil { return nil, wrapError(ErrID, err) diff --git a/pkg/utils/order/order.go b/pkg/utils/order/order.go index 28cba3616..4e9eac4d5 100644 --- a/pkg/utils/order/order.go +++ b/pkg/utils/order/order.go @@ -1,15 +1,14 @@ package order -import "sort" +import ( + "cmp" + "sort" +) -// OrderMapLexicographically orders a map lexicographically by its keys. +// OrderMap orders a map lexicographically by its keys. // It permits any map with keys that are generically orderable. -// TODO: once upgraded to go 1.21, an equivalent is in the standard library -func OrderMapLexicographically[S Ordered, T any](m map[S]T) []*struct { - Id S - Value T -} { - keys := make([]S, 0, len(m)) +func OrderMap[O cmp.Ordered, T any](m map[O]T) []*KVPair[O, T] { + keys := make([]O, 0, len(m)) for k := range m { keys = append(keys, k) } @@ -18,17 +17,11 @@ func OrderMapLexicographically[S Ordered, T any](m map[S]T) []*struct { return keys[i] < keys[j] }) - result := make([]*struct { - Id S - Value T - }, 0, len(m)) + result := make([]*KVPair[O, T], 0, len(m)) for _, k := range keys { - result = append(result, &struct { - Id S - Value T - }{ - Id: k, + result = append(result, &KVPair[O, T]{ + Key: k, Value: m[k], }) } @@ -36,22 +29,20 @@ func OrderMapLexicographically[S Ordered, T any](m map[S]T) []*struct { return result } -type Ordered interface { - Integer | Float | ~string -} - -type Float interface { - ~float32 | ~float64 +// KVPair is a pair of key and value for a map. +// The key must be orderable. +type KVPair[O cmp.Ordered, T any] struct { + Key O + Value T } -type Integer interface { - Signed | Unsigned -} +// ToMap converts a slice of flattened pairs to a map. +func ToMap[O cmp.Ordered, T any](pairs []*KVPair[O, T]) map[O]T { + result := make(map[O]T, len(pairs)) -type Unsigned interface { - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr -} + for _, pair := range pairs { + result[pair.Key] = pair.Value + } -type Signed interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 + return result }