diff --git a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go index c059ea9dafee..e10f277c807d 100644 --- a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go +++ b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go @@ -2124,6 +2124,13 @@ func TestTenantLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestTenantLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestTenantLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/apply_join.go b/pkg/sql/apply_join.go index d22e84a0e375..1147b59db955 100644 --- a/pkg/sql/apply_join.go +++ b/pkg/sql/apply_join.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/exec" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/buildutil" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" @@ -298,6 +299,10 @@ func runPlanInsidePlan( defer func() { params.p.curPlan.subqueryPlans = oldSubqueries }() + subqueryEvalCtxFactory := func() *extendedEvalContext { + return params.p.ExtendedEvalContextCopyAndReset() + } + // Create a separate memory account for the results of the subqueries. // Note that we intentionally defer the closure of the account until we // return from this method (after the main query is executed). @@ -306,7 +311,7 @@ func runPlanInsidePlan( if !execCfg.DistSQLPlanner.PlanAndRunSubqueries( ctx, params.p, - params.extendedEvalCtx.copy, + subqueryEvalCtxFactory, plan.subqueryPlans, recv, &subqueryResultMemAcc, @@ -341,9 +346,51 @@ func runPlanInsidePlan( finishedSetupFn, cleanup := getFinishedSetupFn(&plannerCopy) defer cleanup() + var evalCtxFactory func(usedConcurrently bool) *extendedEvalContext + if len(plan.cascades) != 0 || + len(plan.checkPlans) != 0 { + serialEvalCtx := plannerCopy.ExtendedEvalContextCopyAndReset() + evalCtxFactory = func(usedConcurrently bool) *extendedEvalContext { + if usedConcurrently { + return plannerCopy.ExtendedEvalContextCopyAndReset() + } + // Reuse the same object if this factory is not used concurrently. + plannerCopy.ExtendedEvalContextReset(serialEvalCtx) + return serialEvalCtx + } + } execCfg.DistSQLPlanner.PlanAndRun( ctx, evalCtx, planCtx, plannerCopy.Txn(), plan.main, recv, finishedSetupFn, ) + if p := planCtx.getPortalPauseInfo(); p != nil { + if buildutil.CrdbTestBuild && p.resumableFlow.flow == nil { + checkErr := errors.AssertionFailedf("flow for portal %s cannot be found", plannerCopy.pausablePortal.Name) + if recv.commErr != nil { + recv.commErr = errors.CombineErrors(recv.commErr, checkErr) + } else { + return checkErr + } + } + if !p.resumableFlow.cleanup.isComplete { + p.resumableFlow.cleanup.appendFunc(namedFunc{ + fName: "cleanup flow", f: func() { + p.resumableFlow.flow.Cleanup(ctx) + }, + }) + } + } + + if recv.commErr != nil || recv.getError() != nil { + return recv.commErr + } + + execCfg.DistSQLPlanner.PlanAndRunCascadesAndChecks( + ctx, &plannerCopy, evalCtxFactory, &plannerCopy.curPlan.planComponents, recv, + ) + if recv.commErr != nil { + return recv.commErr + } + return resultWriter.Err() } diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index cd38a3628fcd..b0ee6f3061a6 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -2079,19 +2079,17 @@ func (ex *connExecutor) execWithDistSQLEngine( if len(planner.curPlan.subqueryPlans) != 0 || len(planner.curPlan.cascades) != 0 || len(planner.curPlan.checkPlans) != 0 { - var serialEvalCtx extendedEvalContext - ex.initEvalCtx(ctx, &serialEvalCtx, planner) + serialEvalCtx := planner.ExtendedEvalContextCopyAndReset() + ex.initEvalCtx(ctx, serialEvalCtx, planner) evalCtxFactory = func(usedConcurrently bool) *extendedEvalContext { // Reuse the same object if this factory is not used concurrently. - factoryEvalCtx := &serialEvalCtx + factoryEvalCtx := serialEvalCtx if usedConcurrently { - factoryEvalCtx = &extendedEvalContext{} + factoryEvalCtx = planner.ExtendedEvalContextCopyAndReset() ex.initEvalCtx(ctx, factoryEvalCtx, planner) } ex.resetEvalCtx(factoryEvalCtx, planner.txn, planner.ExtendedEvalContext().StmtTimestamp) - factoryEvalCtx.Placeholders = &planner.semaCtx.Placeholders - factoryEvalCtx.Annotations = &planner.semaCtx.Annotations - factoryEvalCtx.SessionID = planner.ExtendedEvalContext().SessionID + planner.ExtendedEvalContextReset(factoryEvalCtx) return factoryEvalCtx } } diff --git a/pkg/sql/logictest/testdata/logic_test/udf_fk b/pkg/sql/logictest/testdata/logic_test/udf_fk new file mode 100644 index 000000000000..0ce5977eb743 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/udf_fk @@ -0,0 +1,228 @@ +# Disable fast path for some test runs. +let $enable_insert_fast_path +SELECT random() < 0.5 + +statement ok +SET enable_insert_fast_path = $enable_insert_fast_path + +statement ok +CREATE TABLE parent (p INT PRIMARY KEY); + +statement ok +CREATE TABLE child (c INT PRIMARY KEY, p INT NOT NULL REFERENCES parent(p)); + + +subtest insert + +statement ok +CREATE FUNCTION f_fk_c(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO child VALUES (k,r) RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_p(r INT) RETURNS RECORD AS $$ + INSERT INTO parent VALUES (r) RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_c_p(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO child VALUES (k,r); + INSERT INTO parent VALUES (r) RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_p_c(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO parent VALUES (r); + INSERT INTO child VALUES (k, r) RETURNING *; +$$ LANGUAGE SQL; + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c(100, 1); + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_p(100, 1); + +query T +SELECT f_fk_p_c(100, 1); +---- +(100,1) + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +WITH x AS (SELECT f_fk_c(101, 2)) INSERT INTO parent VALUES (2); + +query T +WITH x AS (INSERT INTO parent VALUES (2) RETURNING p) SELECT f_fk_c(101, 2); +---- +(101,2) + +statement ok +TRUNCATE parent CASCADE + +statement ok +INSERT INTO parent (p) VALUES (1); + +statement ok +CREATE FUNCTION f_fk_c_multi(k1 INT, r1 INT, k2 INT, r2 INT) RETURNS SETOF RECORD AS $$ + INSERT INTO child VALUES (k1,r1); + INSERT INTO child VALUES (k2,r2); + SELECT * FROM child WHERE c = k1 OR c = k2; +$$ LANGUAGE SQL; + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_multi(101, 1, 102, 2); + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_multi(101, 2, 102, 1); + +query T rowsort +SELECT f_fk_c_multi(101, 1, 102, 1); +---- +(101,1) +(102,1) + +# Sequences advance even if subsequent statements fail foreign key checks. +statement ok +CREATE SEQUENCE s; + +statement ok +CREATE FUNCTION f_fk_c_seq_first(k INT, r INT) RETURNS RECORD AS $$ + SELECT nextval('s'); + INSERT INTO child VALUES (k,r) RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_c_seq_last(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO child VALUES (k,r) RETURNING *; + SELECT nextval('s'); +$$ LANGUAGE SQL; + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_seq_last(103,2); + +statement error pq: currval\(\): currval of sequence \"test.public.s\" is not yet defined in this session +SELECT currval('s'); + +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_seq_first(103,2); + +query I +SELECT currval('s'); +---- +1 + +subtest delete + +statement ok +TRUNCATE parent CASCADE + +statement ok +INSERT INTO parent (p) VALUES (1), (2), (3), (4); + +statement ok +INSERT INTO child (c, p) VALUES (100, 1), (101, 2), (102, 3); + +query I rowsort +SELECT * FROM parent +---- +1 +2 +3 +4 + +query II rowsort +SELECT * FROM child +---- +100 1 +101 2 +102 3 + +statement ok +CREATE FUNCTION f_fk_c_del(k INT) RETURNS RECORD AS $$ + DELETE FROM child WHERE c = k RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_p_del(r INT) RETURNS RECORD AS $$ + DELETE FROM parent WHERE p = r RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_c_p_del(k INT, r INT) RETURNS RECORD AS $$ + DELETE FROM child WHERE c = k RETURNING *; + DELETE FROM parent WHERE p = r RETURNING *; +$$ LANGUAGE SQL; + +statement ok +CREATE FUNCTION f_fk_p_c_del(k INT, r INT) RETURNS RECORD AS $$ + DELETE FROM parent WHERE p = r RETURNING *; + DELETE FROM child WHERE c = k RETURNING *; +$$ LANGUAGE SQL; + +statement ok +SELECT f_fk_p_del(4); + +statement error pq: delete on table "parent" violates foreign key constraint "child_p_fkey" on table "child"\nDETAIL: Key \(p\)=\(3\) is still referenced from table "child"\. +SELECT f_fk_p_del(3); + +statement ok +SELECT f_fk_c_del(102); + +statement ok +SELECT f_fk_p_del(3); + +statement error pq: delete on table "parent" violates foreign key constraint "child_p_fkey" on table "child"\nDETAIL: Key \(p\)=\(2\) is still referenced from table "child"\. +SELECT f_fk_p_c_del(101,2); + +statement ok +SELECT f_fk_c_p_del(101,2); + +statement ok +SELECT f_fk_c_del(100), f_fk_p_del(1); + +query I rowsort +SELECT * FROM parent +---- + +query II rowsort +SELECT * FROM child +---- + + +subtest upsert + +statement ok +TRUNCATE parent CASCADE + +statement ok +CREATE FUNCTION f_fk_c_ocdu(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO child VALUES (k, r) ON CONFLICT (c) DO UPDATE SET p = r RETURNING *; +$$ LANGUAGE SQL; + +statement ok +INSERT INTO parent VALUES (1), (3); + +# Insert +statement ok +SELECT f_fk_c_ocdu(100,1); + +# Update to value not in parent fails. +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_ocdu(100,2); + +# Inserting value not in parent fails. +statement error pq: insert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_ocdu(101,2); + +statement ok +CREATE FUNCTION f_fk_c_ups(k INT, r INT) RETURNS RECORD AS $$ + UPSERT INTO child VALUES (k, r) RETURNING *; +$$ LANGUAGE SQL; + +statement ok +SELECT f_fk_c_ups(102,3); + +statement error pq: upsert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_ups(102,4); + +statement error pq: upsert on table "child" violates foreign key constraint "child_p_fkey" +SELECT f_fk_c_ups(103,4); diff --git a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go index ab65395c96b6..18ef23488a7f 100644 --- a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go @@ -2088,6 +2088,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go index 3720fab4b21e..7315d8d4b21d 100644 --- a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go @@ -2095,6 +2095,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist/generated_test.go b/pkg/sql/logictest/tests/fakedist/generated_test.go index 6d79c2e6e3b9..aa9a5b8bb06f 100644 --- a/pkg/sql/logictest/tests/fakedist/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist/generated_test.go @@ -2109,6 +2109,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go index 034a0e06ecd1..5fd05c313089 100644 --- a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go +++ b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go @@ -2081,6 +2081,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go index 8fa551abc51b..3bb6ba38ce0e 100644 --- a/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go +++ b/pkg/sql/logictest/tests/local-mixed-22.2-23.1/generated_test.go @@ -2039,6 +2039,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_insert( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-vec-off/generated_test.go b/pkg/sql/logictest/tests/local-vec-off/generated_test.go index 70a53f3fb97d..e24a77efcde9 100644 --- a/pkg/sql/logictest/tests/local-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/local-vec-off/generated_test.go @@ -2109,6 +2109,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local/generated_test.go b/pkg/sql/logictest/tests/local/generated_test.go index 165c495c46fe..db3d9b44eb36 100644 --- a/pkg/sql/logictest/tests/local/generated_test.go +++ b/pkg/sql/logictest/tests/local/generated_test.go @@ -2312,6 +2312,13 @@ func TestLogic_udf_delete( runLogicTest(t, "udf_delete") } +func TestLogic_udf_fk( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "udf_fk") +} + func TestLogic_udf_in_column_defaults( t *testing.T, ) { diff --git a/pkg/sql/opt/exec/execbuilder/scalar.go b/pkg/sql/opt/exec/execbuilder/scalar.go index 8af8a06171cd..790715e55de5 100644 --- a/pkg/sql/opt/exec/execbuilder/scalar.go +++ b/pkg/sql/opt/exec/execbuilder/scalar.go @@ -1121,9 +1121,6 @@ func (b *Builder) buildRoutinePlanGenerator( if len(eb.cascades) > 0 { return expectedLazyRoutineError("cascade") } - if len(eb.checks) > 0 { - return expectedLazyRoutineError("check") - } isFinalPlan := i == len(stmts)-1 err = fn(plan, isFinalPlan) if err != nil { diff --git a/pkg/sql/opt/optbuilder/testdata/udf b/pkg/sql/opt/optbuilder/testdata/udf index 352d99fbf27d..538e06d0f5c1 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf +++ b/pkg/sql/opt/optbuilder/testdata/udf @@ -1739,3 +1739,110 @@ project │ └── () └── projections └── ups3(1, 2, 3, 4, 5, 6) + +# -------------------------------------------------- +# UDFs with foreign key constraints. +# -------------------------------------------------- + +exec-ddl +CREATE TABLE parent (p INT PRIMARY KEY); +---- + +exec-ddl +CREATE TABLE child (c INT PRIMARY KEY, p INT NOT NULL REFERENCES parent(p)); +---- + +exec-ddl +CREATE FUNCTION f_fk(k INT, r INT) RETURNS RECORD AS $$ + INSERT INTO child VALUES (k,r) RETURNING *; +$$ LANGUAGE SQL; +---- + +opt format=show-scalars +SELECT f_fk(100, 1), f_fk(101, 2); +---- +values + ├── columns: f_fk:14 f_fk:28 + └── tuple + ├── udf: f_fk + │ ├── params: k:1 r:2 + │ ├── args + │ │ ├── const: 100 + │ │ └── const: 1 + │ └── body + │ └── project + │ ├── columns: column13:13!null + │ ├── insert child + │ │ ├── columns: c:3!null child.p:4!null + │ │ ├── insert-mapping: + │ │ │ ├── column1:7 => c:3 + │ │ │ └── column2:8 => child.p:4 + │ │ ├── return-mapping: + │ │ │ ├── column1:7 => c:3 + │ │ │ └── column2:8 => child.p:4 + │ │ ├── input binding: &1 + │ │ ├── values + │ │ │ ├── columns: column1:7 column2:8 + │ │ │ └── tuple + │ │ │ ├── variable: k:1 + │ │ │ └── variable: r:2 + │ │ └── f-k-checks + │ │ └── f-k-checks-item: child(p) -> parent(p) + │ │ └── anti-join (hash) + │ │ ├── columns: p:9 + │ │ ├── with-scan &1 + │ │ │ ├── columns: p:9 + │ │ │ └── mapping: + │ │ │ └── column2:8 => p:9 + │ │ ├── scan parent + │ │ │ ├── columns: parent.p:10!null + │ │ │ └── flags: disabled not visible index feature + │ │ └── filters + │ │ └── eq + │ │ ├── variable: p:9 + │ │ └── variable: parent.p:10 + │ └── projections + │ └── tuple [as=column13:13] + │ ├── variable: c:3 + │ └── variable: child.p:4 + └── udf: f_fk + ├── params: k:15 r:16 + ├── args + │ ├── const: 101 + │ └── const: 2 + └── body + └── project + ├── columns: column27:27!null + ├── insert child + │ ├── columns: c:17!null child.p:18!null + │ ├── insert-mapping: + │ │ ├── column1:21 => c:17 + │ │ └── column2:22 => child.p:18 + │ ├── return-mapping: + │ │ ├── column1:21 => c:17 + │ │ └── column2:22 => child.p:18 + │ ├── input binding: &2 + │ ├── values + │ │ ├── columns: column1:21 column2:22 + │ │ └── tuple + │ │ ├── variable: k:15 + │ │ └── variable: r:16 + │ └── f-k-checks + │ └── f-k-checks-item: child(p) -> parent(p) + │ └── anti-join (hash) + │ ├── columns: p:23 + │ ├── with-scan &2 + │ │ ├── columns: p:23 + │ │ └── mapping: + │ │ └── column2:22 => p:23 + │ ├── scan parent + │ │ ├── columns: parent.p:24!null + │ │ └── flags: disabled not visible index feature + │ └── filters + │ └── eq + │ ├── variable: p:23 + │ └── variable: parent.p:24 + └── projections + └── tuple [as=column27:27] + ├── variable: c:17 + └── variable: child.p:18 diff --git a/pkg/sql/planner.go b/pkg/sql/planner.go index 01f0c4a75a3e..37d29f829004 100644 --- a/pkg/sql/planner.go +++ b/pkg/sql/planner.go @@ -555,6 +555,22 @@ func internalExtendedEvalCtx( return ret } +// ExtendedEvalContextCopyAndReset returns a function that produces +// extendedEvalContexts for parallel subquery, cascade, and check execution. +func (p *planner) ExtendedEvalContextCopyAndReset() *extendedEvalContext { + evalCtx := p.ExtendedEvalContextCopy() + p.ExtendedEvalContextReset(evalCtx) + return evalCtx +} + +// ExtendedEvalContextReset resets context fields so that the context may be +// reused across subquery, cascade, and check execution. +func (p *planner) ExtendedEvalContextReset(evalCtx *extendedEvalContext) { + evalCtx.Placeholders = &p.semaCtx.Placeholders + evalCtx.Annotations = &p.semaCtx.Annotations + evalCtx.SessionID = p.ExtendedEvalContext().SessionID +} + // SemaCtx provides access to the planner's SemaCtx. func (p *planner) SemaCtx() *tree.SemaContext { return &p.semaCtx