diff --git a/server/tables/pgcatalog/iters.go b/server/tables/pgcatalog/iters.go new file mode 100644 index 0000000000..64b73ad0c9 --- /dev/null +++ b/server/tables/pgcatalog/iters.go @@ -0,0 +1,52 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgcatalog + +import "github.com/dolthub/go-mysql-server/sql" + +// emptyRowIter implements the sql.RowIter for empty table. +func emptyRowIter() (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +// currentDatabaseSchemaIter iterates over all schemas in the current database, calling cb +// for each schema. Once all schemas have been processed or the callback returns +// false or an error, the iteration stops. +func currentDatabaseSchemaIter(ctx *sql.Context, c sql.Catalog, cb func(schema sql.DatabaseSchema) (bool, error)) (sql.Database, error) { + currentDB := ctx.GetCurrentDatabase() + db, err := c.Database(ctx, currentDB) + if err != nil { + return nil, err + } + + if schDB, ok := db.(sql.SchemaDatabase); ok { + schemas, err := schDB.AllSchemas(ctx) + if err != nil { + return nil, err + } + + for _, schema := range schemas { + cont, err := cb(schema) + if err != nil { + return nil, err + } + if !cont { + break + } + } + } + + return db, nil +} diff --git a/server/tables/pgcatalog/pg_attribute.go b/server/tables/pgcatalog/pg_attribute.go index f3e3570678..71a4915f16 100644 --- a/server/tables/pgcatalog/pg_attribute.go +++ b/server/tables/pgcatalog/pg_attribute.go @@ -43,11 +43,6 @@ func (p PgAttributeHandler) Name() string { return PgAttributeName } -// emptyRowIter implements the sql.RowIter for empty table. -func emptyRowIter() (sql.RowIter, error) { - return sql.RowsToRowIter(), nil -} - // RowIter implements the interface tables.Handler. func (p PgAttributeHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { doltSession := dsess.DSessFromSess(ctx.Session) @@ -55,7 +50,7 @@ func (p PgAttributeHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { var cols []*sql.Column - err := currentDatabaseSchemaIter(ctx, c, func(db sql.Database) (bool, error) { + _, err := currentDatabaseSchemaIter(ctx, c, func(db sql.DatabaseSchema) (bool, error) { err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { for _, col := range t.Schema() { cols = append(cols, col) diff --git a/server/tables/pgcatalog/pg_class.go b/server/tables/pgcatalog/pg_class.go index bd7be1b271..673ccf69b1 100644 --- a/server/tables/pgcatalog/pg_class.go +++ b/server/tables/pgcatalog/pg_class.go @@ -43,48 +43,6 @@ func (p PgClassHandler) Name() string { return PgClassName } -// currentDatabaseSchemaIter iterates over all schemas in the current database, -// calling cb for the database and each schema. Once all schemas have been -// processed or the callback returns false or an error, the iteration stops. -func currentDatabaseSchemaIter(ctx *sql.Context, c sql.Catalog, cb func(db sql.Database) (bool, error)) error { - currentDB := ctx.GetCurrentDatabase() - dbs := c.AllDatabases(ctx) - - for _, db := range dbs { - if currentDB != "" && db.Name() != currentDB { - continue - } - - if schDB, ok := db.(sql.SchemaDatabase); ok { - schemas, err := schDB.AllSchemas(ctx) - if err != nil { - return err - } - - for _, schema := range schemas { - cont, err := cb(schema) - if err != nil { - return err - } - if !cont { - break - } - } - } - - cont, err := cb(db) - if err != nil { - return err - } - if !cont { - break - } - - } - - return nil -} - // RowIter implements the interface tables.Handler. func (p PgClassHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { doltSession := dsess.DSessFromSess(ctx.Session) @@ -92,7 +50,7 @@ func (p PgClassHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { var classes []Class - err := currentDatabaseSchemaIter(ctx, c, func(db sql.Database) (bool, error) { + currentDB, err := currentDatabaseSchemaIter(ctx, c, func(db sql.DatabaseSchema) (bool, error) { // Get tables and table indexes err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { hasIndexes := false @@ -118,24 +76,24 @@ func (p PgClassHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { return false, err } - // Get views - if vdb, ok := db.(sql.ViewDatabase); ok { - views, err := vdb.AllViews(ctx) - if err != nil { - return false, err - } - - for _, view := range views { - classes = append(classes, Class{name: view.Name, hasIndexes: false, kind: "v"}) - } - } - return true, nil }) if err != nil { return nil, err } + // Get views + if vdb, ok := currentDB.(sql.ViewDatabase); ok { + views, err := vdb.AllViews(ctx) + if err != nil { + return nil, err + } + + for _, view := range views { + classes = append(classes, Class{name: view.Name, hasIndexes: false, kind: "v"}) + } + } + return &pgClassRowIter{ classes: classes, idx: 0, diff --git a/server/tables/pgcatalog/pg_tables.go b/server/tables/pgcatalog/pg_tables.go index a1b768dfea..d41e796ea0 100644 --- a/server/tables/pgcatalog/pg_tables.go +++ b/server/tables/pgcatalog/pg_tables.go @@ -17,6 +17,8 @@ package pgcatalog import ( "io" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/server/tables" @@ -43,8 +45,34 @@ func (p PgTablesHandler) Name() string { // RowIter implements the interface tables.Handler. func (p PgTablesHandler) RowIter(ctx *sql.Context) (sql.RowIter, error) { - // TODO: Implement pg_tables row iter - return emptyRowIter() + doltSession := dsess.DSessFromSess(ctx.Session) + c := sqle.NewDefault(doltSession.Provider()).Analyzer.Catalog + + var tables []sql.Table + var schemas []string + + // TODO: This should include a few information_schema tables + _, err := currentDatabaseSchemaIter(ctx, c, func(sch sql.DatabaseSchema) (bool, error) { + err := sql.DBTableIter(ctx, sch, func(t sql.Table) (cont bool, err error) { + tables = append(tables, t) + schemas = append(schemas, sch.SchemaName()) + return true, nil + }) + if err != nil { + return false, err + } + + return true, nil + }) + if err != nil { + return nil, err + } + + return &pgTablesRowIter{ + tables: tables, + schemas: schemas, + idx: 0, + }, nil } // Schema implements the interface tables.Handler. @@ -69,13 +97,45 @@ var pgTablesSchema = sql.Schema{ // pgTablesRowIter is the sql.RowIter for the pg_tables table. type pgTablesRowIter struct { + tables []sql.Table + schemas []string + idx int } var _ sql.RowIter = (*pgTablesRowIter)(nil) // Next implements the interface sql.RowIter. func (iter *pgTablesRowIter) Next(ctx *sql.Context) (sql.Row, error) { - return nil, io.EOF + if iter.idx >= len(iter.tables) { + return nil, io.EOF + } + iter.idx++ + table := iter.tables[iter.idx-1] + schema := iter.schemas[iter.idx-1] + + hasIndexes := false + if it, ok := table.(sql.IndexAddressable); ok { + idxs, err := it.GetIndexes(ctx) + if err != nil { + return nil, err + } + + if len(idxs) > 0 { + hasIndexes = true + } + } + + // TODO: Implement the rest of these pg_tables columns + return sql.Row{ + schema, // schemaname + table.Name(), // tablename + "", // tableowner + "", // tablespace + hasIndexes, // hasindexes + false, // hasrules + false, // hastriggers + false, // rowsecurity + }, nil } // Close implements the interface sql.RowIter. diff --git a/testing/go/pgcatalog_test.go b/testing/go/pgcatalog_test.go index f0a63e080e..d91df9f6c7 100644 --- a/testing/go/pgcatalog_test.go +++ b/testing/go/pgcatalog_test.go @@ -3251,10 +3251,23 @@ func TestPgTables(t *testing.T) { RunScripts(t, []ScriptTest{ { Name: "pg_tables", + SetUpScript: []string{ + `CREATE SCHEMA testschema;`, + `SET search_path TO testschema;`, + `CREATE TABLE testing (pk INT primary key, v1 INT);`, + + // Should show classes for all schemas + `CREATE SCHEMA testschema2;`, + `SET search_path TO testschema2;`, + }, Assertions: []ScriptTestAssertion{ { - Query: `SELECT * FROM "pg_catalog"."pg_tables";`, - Expected: []sql.Row{}, + Query: `SELECT * FROM "pg_catalog"."pg_tables" WHERE tablename='testing';`, + Expected: []sql.Row{{"testschema", "testing", "", "", "t", "f", "f", "f"}}, + }, + { + Query: `SELECT count(*) FROM "pg_catalog"."pg_tables" WHERE schemaname='pg_catalog';`, + Expected: []sql.Row{{139}}, }, { // Different cases and quoted, so it fails Query: `SELECT * FROM "PG_catalog"."pg_tables";`, @@ -3265,8 +3278,12 @@ func TestPgTables(t *testing.T) { ExpectedErr: "not", }, { // Different cases but non-quoted, so it works - Query: "SELECT tablename FROM PG_catalog.pg_TABLES ORDER BY tablename;", - Expected: []sql.Row{}, + Query: "SELECT schemaname, tablename FROM PG_catalog.pg_TABLES ORDER BY tablename DESC LIMIT 3;", + Expected: []sql.Row{ + {"testschema", "testing"}, + {"pg_catalog", "pg_views"}, + {"pg_catalog", "pg_user_mappings"}, + }, }, }, },