diff --git a/pkg/ccl/importccl/import_stmt.go b/pkg/ccl/importccl/import_stmt.go index a204d2e7f2d8..3604bb3adeff 100644 --- a/pkg/ccl/importccl/import_stmt.go +++ b/pkg/ccl/importccl/import_stmt.go @@ -606,13 +606,17 @@ func importPlanHook( // expressions are nullable. if len(isTargetCol) != 0 { for _, col := range found.VisibleColumns() { - if !(isTargetCol[col.Name] || col.IsNullable() || col.HasDefault()) { + if !(isTargetCol[col.Name] || col.IsNullable() || col.HasDefault() || col.IsComputed()) { return errors.Newf( "all non-target columns in IMPORT INTO must be nullable "+ - "or have default expressions but violated by column %q", + "or have default expressions, or have computed expressions"+ + " but violated by column %q", col.Name, ) } + if isTargetCol[col.Name] && col.IsComputed() { + return sqlbase.CannotWriteToComputedColError(col.Name) + } } } diff --git a/pkg/ccl/importccl/import_stmt_test.go b/pkg/ccl/importccl/import_stmt_test.go index 268954261fd5..401f19755a7d 100644 --- a/pkg/ccl/importccl/import_stmt_test.go +++ b/pkg/ccl/importccl/import_stmt_test.go @@ -66,6 +66,33 @@ import ( "github.com/stretchr/testify/require" ) +func createAvroData( + t *testing.T, name string, fields []map[string]interface{}, rows []map[string]interface{}, +) string { + var data bytes.Buffer + // Set up a simple schema for the import data. + schema := map[string]interface{}{ + "type": "record", + "name": name, + "fields": fields, + } + schemaStr, err := json.Marshal(schema) + require.NoError(t, err) + codec, err := goavro.NewCodec(string(schemaStr)) + require.NoError(t, err) + // Create an AVRO writer from the schema. + ocf, err := goavro.NewOCFWriter(goavro.OCFConfig{ + W: &data, + Codec: codec, + }) + require.NoError(t, err) + for _, row := range rows { + require.NoError(t, ocf.Append([]interface{}{row})) + } + // Retrieve the AVRO encoded data. 
+ return data.String() +} + func TestImportData(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -1427,14 +1454,6 @@ func TestImportCSVStmt(t *testing.T) { ``, "invalid option \"foo\"", }, - { - "bad-computed-column", - `IMPORT TABLE t (a INT8 PRIMARY KEY, b STRING AS ('hello') STORED, INDEX (b), INDEX (a, b)) CSV DATA (%s) WITH skip = '2'`, - nil, - testFiles.filesWithOpts, - ``, - "computed columns not supported", - }, { "primary-key-dup", `IMPORT TABLE t CREATE USING $1 CSV DATA (%s)`, @@ -3291,6 +3310,135 @@ func TestImportDefault(t *testing.T) { }) } +func TestImportComputed(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + const nodes = 3 + + ctx := context.Background() + baseDir := filepath.Join("testdata", "csv") + tc := testcluster.StartTestCluster(t, nodes, base.TestClusterArgs{ServerArgs: base.TestServerArgs{ExternalIODir: baseDir}}) + defer tc.Stopper().Stop(ctx) + conn := tc.Conns[0] + + sqlDB := sqlutils.MakeSQLRunner(conn) + var data string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + _, _ = w.Write([]byte(data)) + } + })) + avroField := []map[string]interface{}{ + { + "name": "a", + "type": "int", + }, + { + "name": "b", + "type": "int", + }, + } + avroRows := []map[string]interface{}{ + {"a": 1, "b": 2}, {"a": 3, "b": 4}, + } + avroData := createAvroData(t, "t", avroField, avroRows) + pgdumpData := ` +CREATE TABLE users (a INT, b INT, c INT AS (a + b) STORED); +INSERT INTO users (a, b) VALUES (1, 2), (3, 4); +` + defer srv.Close() + tests := []struct { + into bool + name string + data string + create string + targetCols string + format string + // We expect exactly one of expectedResults and expectedError. 
+ expectedResults [][]string + expectedError string + }{ + { + into: true, + name: "addition", + data: "35,23\n67,10", + create: "a INT, b INT, c INT AS (a + b) STORED", + targetCols: "a, b", + format: "CSV", + expectedResults: [][]string{{"35", "23", "58"}, {"67", "10", "77"}}, + }, + { + into: true, + name: "cannot-be-targeted", + data: "1,2,3\n3,4,5", + create: "a INT, b INT, c INT AS (a + b) STORED", + targetCols: "a, b, c", + format: "CSV", + expectedError: `cannot write directly to computed column "c"`, + }, + { + into: true, + name: "import-into-avro", + data: avroData, + create: "a INT, b INT, c INT AS (a + b) STORED", + targetCols: "a, b", + format: "AVRO", + expectedResults: [][]string{{"1", "2", "3"}, {"3", "4", "7"}}, + }, + { + into: false, + name: "import-table-csv", + data: "35,23\n67,10", + create: "a INT, c INT AS (a + b) STORED, b INT", + targetCols: "a, b", + format: "CSV", + expectedError: "not supported by IMPORT INTO CSV", + }, + { + into: false, + name: "import-table-avro", + data: avroData, + create: "a INT, b INT, c INT AS (a + b) STORED", + targetCols: "a, b", + format: "AVRO", + expectedResults: [][]string{{"1", "2", "3"}, {"3", "4", "7"}}, + }, + { + into: false, + name: "pgdump", + data: pgdumpData, + format: "PGDUMP", + expectedResults: [][]string{{"1", "2", "3"}, {"3", "4", "7"}}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer sqlDB.Exec(t, `DROP TABLE IF EXISTS users`) + data = test.data + var importStmt string + if test.into { + sqlDB.Exec(t, fmt.Sprintf(`CREATE TABLE users (%s)`, test.create)) + importStmt = fmt.Sprintf(`IMPORT INTO users (%s) %s DATA (%q)`, + test.targetCols, test.format, srv.URL) + } else { + if test.format == "CSV" || test.format == "AVRO" { + importStmt = fmt.Sprintf( + `IMPORT TABLE users (%s) %s DATA (%q)`, test.create, test.format, srv.URL) + } else { + importStmt = fmt.Sprintf(`IMPORT %s (%q)`, test.format, srv.URL) + } + } + if test.expectedError != "" { + 
sqlDB.ExpectErr(t, test.expectedError, importStmt) + } else { + sqlDB.Exec(t, importStmt) + sqlDB.CheckQueryResults(t, `SELECT * FROM users`, test.expectedResults) + } + }) + } +} + // goos: darwin // goarch: amd64 // pkg: github.com/cockroachdb/cockroach/pkg/ccl/importccl diff --git a/pkg/ccl/importccl/import_table_creation.go b/pkg/ccl/importccl/import_table_creation.go index 83d9a5255840..9d9f370de60f 100644 --- a/pkg/ccl/importccl/import_table_creation.go +++ b/pkg/ccl/importccl/import_table_creation.go @@ -118,11 +118,6 @@ func MakeSimpleTableDescriptor( *tree.UniqueConstraintTableDef: // ignore case *tree.ColumnTableDef: - if def.Computed.Expr != nil { - return nil, unimplemented.NewWithIssueDetailf(42846, "import.computed", - "computed columns not supported: %s", tree.AsString(def)) - } - if err := sql.SimplifySerialInColumnDefWithRowID(ctx, def, &create.Table); err != nil { return nil, err } diff --git a/pkg/ccl/importccl/read_import_csv.go b/pkg/ccl/importccl/read_import_csv.go index 109681482b94..0ac0655c72aa 100644 --- a/pkg/ccl/importccl/read_import_csv.go +++ b/pkg/ccl/importccl/read_import_csv.go @@ -133,6 +133,13 @@ func (p *csvRowProducer) Row() (interface{}, error) { p.rowNum++ expectedColsLen := len(p.expectedColumns) if expectedColsLen == 0 { + // TODO(anzoteh96): this should really be only checked once per import instead of every row. 
+ for _, col := range p.importCtx.tableDesc.VisibleColumns() { + if col.IsComputed() { + return nil, + errors.Newf("%q is computed, which is not supported by IMPORT INTO CSV", col.Name) + } + } expectedColsLen = len(p.importCtx.tableDesc.VisibleColumns()) } diff --git a/pkg/sql/row/row_converter.go b/pkg/sql/row/row_converter.go index dd4b924750cd..f5ba5416a994 100644 --- a/pkg/sql/row/row_converter.go +++ b/pkg/sql/row/row_converter.go @@ -14,6 +14,7 @@ import ( "context" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" "github.com/cockroachdb/cockroach/pkg/sql/sem/transform" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -121,6 +122,9 @@ func GenerateInsertRow( // since we disallow computed columns from referencing other computed // columns, all the columns which could possibly be referenced *are* // available. + if !computedCols[i].IsComputed() { + continue + } d, err := computeExprs[i].Eval(evalCtx) if err != nil { return nil, errors.Wrapf(err, "computed column %s", tree.ErrString((*tree.Name)(&computedCols[i].Name))) @@ -265,9 +269,14 @@ func NewDatumRowConverter( var txCtx transform.ExprTransformContext semaCtx := tree.MakeSemaContext() - cols, defaultExprs, err := sqlbase.ProcessDefaultColumns(ctx, targetColDescriptors, immutDesc, &txCtx, c.EvalCtx, &semaCtx) + relevantColumns := func(col *sqlbase.ColumnDescriptor) bool { + return col.HasDefault() || col.IsComputed() + } + cols := sqlbase.ProcessColumnSet( + targetColDescriptors, immutDesc, relevantColumns) + defaultExprs, err := sqlbase.MakeDefaultExprs(ctx, cols, &txCtx, c.EvalCtx, &semaCtx) if err != nil { - return nil, errors.Wrap(err, "process default columns") + return nil, errors.Wrap(err, "process default and computed columns") } ri, err := MakeInserter( @@ -335,6 +344,9 @@ func NewDatumRowConverter( c.Datums = append(c.Datums, nil) } } + if col.IsComputed() && !isTargetCol(col) { + 
c.Datums = append(c.Datums, nil) + } } if len(c.Datums) != len(cols) { return nil, errors.New("unexpected hidden column") @@ -372,10 +384,28 @@ func (c *DatumRowConverter) Row(ctx context.Context, sourceID int32, rowIndex in c.Datums[i] = datum } } - - // TODO(justin): we currently disallow computed columns in import statements. - var computeExprs []tree.TypedExpr + colsForCompute := make([]sqlbase.ColumnDescriptor, len(c.tableDesc.Columns)) + for _, col := range c.tableDesc.Columns { + colsForCompute[c.computedIVarContainer.Mapping[col.ID]] = col + } + semaCtx := tree.MakeSemaContext() + semaCtx.TypeResolver = c.EvalCtx.TypeResolver + var txCtx transform.ExprTransformContext + computeExprs, err := schemaexpr.MakeComputedExprs( + ctx, + colsForCompute, + c.tableDesc, + tree.NewUnqualifiedTableName(tree.Name(c.tableDesc.Name)), + &txCtx, + c.EvalCtx, + &semaCtx, true /*addingCols*/) + if err != nil { + return errors.Wrapf(err, "error evaluating computed expression for IMPORT INTO") + } var computedCols []sqlbase.ColumnDescriptor + if len(computeExprs) > 0 { + computedCols = colsForCompute + } insertRow, err := GenerateInsertRow( c.defaultCache, computeExprs, c.cols, computedCols, c.EvalCtx, diff --git a/pkg/sql/sqlbase/default_exprs.go b/pkg/sql/sqlbase/default_exprs.go index 1507836bed3c..3ef9f249ba3f 100644 --- a/pkg/sql/sqlbase/default_exprs.go +++ b/pkg/sql/sqlbase/default_exprs.go @@ -90,14 +90,16 @@ func ProcessDefaultColumns( evalCtx *tree.EvalContext, semaCtx *tree.SemaContext, ) ([]ColumnDescriptor, []tree.TypedExpr, error) { - cols = processColumnSet(cols, tableDesc, func(col *ColumnDescriptor) bool { + cols = ProcessColumnSet(cols, tableDesc, func(col *ColumnDescriptor) bool { return col.DefaultExpr != nil }) defaultExprs, err := MakeDefaultExprs(ctx, cols, txCtx, evalCtx, semaCtx) return cols, defaultExprs, err } -func processColumnSet( +// ProcessColumnSet returns columns in cols, and other writable +// columns in tableDesc that fulfill the
criterion given by inSet. +func ProcessColumnSet( cols []ColumnDescriptor, tableDesc *ImmutableTableDescriptor, inSet func(*ColumnDescriptor) bool, ) []ColumnDescriptor { colIDSet := make(map[ColumnID]struct{}, len(cols))