From 766771696aff1f35f5dbeb517b55beb4f2d80f74 Mon Sep 17 00:00:00 2001 From: Konrad Dziedzic Date: Tue, 2 Aug 2022 15:52:42 +0200 Subject: [PATCH] Prevent inserting nulls coming from expressions into nonnullable columns --- .../io/trino/sql/planner/LogicalPlanner.java | 26 +++++++++++++------ .../BaseClickHouseConnectorTest.java | 9 +++++++ .../BaseDeltaLakeMinioConnectorTest.java | 6 ++--- .../mariadb/BaseMariaDbConnectorTest.java | 26 +++++++++++++++++++ .../io/trino/testing/BaseConnectorTest.java | 5 ++++ 5 files changed, 61 insertions(+), 11 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index e0e14f4dbe2ce..fb3aa3c5b796d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -123,6 +123,7 @@ import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries; import static io.trino.metadata.MetadataUtil.createQualifiedObjectName; import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED; @@ -451,29 +452,31 @@ private RelationPlan getInsertPlan( continue; } Symbol output = symbolAllocator.newSymbol(column.getName(), column.getType()); + Expression expression; + Type tableType = column.getType(); int index = insertColumns.indexOf(columns.get(column.getName())); if (index < 0) { if (supportsMissingColumnsOnInsert) { continue; } - Expression cast = new Cast(new NullLiteral(), toSqlType(column.getType())); - assignments.put(output, cast); - insertedColumnsBuilder.add(column); + expression = new Cast(new NullLiteral(), toSqlType(column.getType())); } else { Symbol input = visibleFieldMappings.get(index); - Type tableType = column.getType(); Type queryType = symbolAllocator.getTypes().get(input); if (queryType.equals(tableType) || typeCoercion.isTypeOnlyCoercion(queryType, tableType)) { - assignments.put(output, input.toSymbolReference()); + expression = input.toSymbolReference(); } else { - Expression cast = noTruncationCast(input.toSymbolReference(), queryType, tableType); - assignments.put(output, cast); + expression = noTruncationCast(input.toSymbolReference(), queryType, tableType); } - insertedColumnsBuilder.add(column); } + if (!column.isNullable()) { + expression = new CoalesceExpression(expression, createNullNotAllowedFailExpression(column.getName(), tableType)); + } + assignments.put(output, expression); + insertedColumnsBuilder.add(column); } ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), plan.getRoot(), assignments.build()); @@ -532,6 +535,13 @@ private RelationPlan getInsertPlan( statisticsMetadata); } + private Expression createNullNotAllowedFailExpression(String columnName, Type type) + { + return new Cast(failFunction(metadata, session, INVALID_ARGUMENTS, format( + "NULL value not allowed for NOT NULL column: %s", + columnName)), toSqlType(type)); + } + private static Function failIfPredicateIsNotMet(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage) { FunctionCall fail = failFunction(metadata, session, errorCode, errorMessage); diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java index fe18af7c31c20..952c8c9e06324 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java @@ -544,6 +544,15 @@ public void testInsertIntoNotNullColumn() assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)"); assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "not_null_no_cast", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)")) { + assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)"); + // This is enforced by the engine and not the connector + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + } } @Override diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java index 35ad764aab709..6529cd8eb33d6 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java @@ -815,14 +815,14 @@ public void testTableWithNonNullableColumns() assertUpdate("INSERT INTO " + tableName + " VALUES(2, 20, 200)", 1); assertThatThrownBy(() -> query("INSERT INTO " + tableName + " VALUES(null, 30, 300)")) .hasMessageContaining("NULL value not allowed for NOT NULL column: col1"); + assertThatThrownBy(() -> query("INSERT INTO " + tableName + " VALUES(TRY(5/0), 40, 400)")) + .hasMessageContaining("NULL value not allowed for NOT NULL column: col1"); - //TODO this should fail https://github.com/trinodb/trino/issues/13434 - assertUpdate("INSERT INTO " + tableName + " VALUES(TRY(5/0), 40, 400)", 1); //TODO these 2 should fail https://github.com/trinodb/trino/issues/13435 assertUpdate("UPDATE " + tableName + " SET col2 = NULL where col3 = 100", 1); assertUpdate("UPDATE " + tableName + " SET col2 = TRY(5/0) where col3 = 200", 1); - assertQuery("SELECT * FROM " + tableName, "VALUES(1, null, 100), (2, null, 200), (null, 40, 400)"); + assertQuery("SELECT * FROM " + tableName, "VALUES(1, null, 100), (2, null, 200)"); } @Override diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java index 697f543b0ec05..9d5f32ce502c0 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java @@ -262,6 +262,32 @@ public void testDeleteWithLike() .hasStackTraceContaining("TrinoException: Unsupported delete"); } + // Overridden because the method from BaseConnectorTest fails on one of the assertions, see TODO below + @Test + @Override + public void testInsertIntoNotNullColumn() + { + try (TestTable table = new TestTable(getQueryRunner()::execute, "insert_not_null", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)")) { + assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)"); + assertQueryFails(format("INSERT INTO %s (nullable_col) VALUES (1)", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col")); + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col")); + // TODO (https://github.com/trinodb/trino/issues/13551) This doesn't fail for other connectors so + // probably shouldn't fail for MariaDB either. Once fixed, remove test override. + assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation WHERE regionkey < 0", table.getName()), ".*Field 'not_null_col' doesn't have a default value.*"); + } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "commuted_not_null", "(nullable_col BIGINT, not_null_col BIGINT NOT NULL)")) { + assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)"); + // This is enforced by the engine and not the connector + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + } + } + @Override public void testNativeQueryCreateStatement() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 7dcd27f693d0a..e1d2dadb31d9e 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -2809,6 +2809,11 @@ public void testInsertIntoNotNullColumn() assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)"); // The error message comes from remote databases when ConnectorMetadata.supportsMissingColumnsOnInsert is true assertQueryFails(format("INSERT INTO %s (nullable_col) VALUES (1)", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col")); + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col"); + assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col")); + assertUpdate(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation WHERE regionkey < 0", table.getName()), 0); } try (TestTable table = new TestTable(getQueryRunner()::execute, "commuted_not_null", "(nullable_col BIGINT, not_null_col BIGINT NOT NULL)")) {