Skip to content

Commit

Permalink
Prevent inserting nulls coming from expressions into nonnullable columns
Browse files Browse the repository at this point in the history
  • Loading branch information
homar authored and findepi committed Aug 10, 2022
1 parent d8a0b2b commit 7667716
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries;
import static io.trino.metadata.MetadataUtil.createQualifiedObjectName;
import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND;
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED;
Expand Down Expand Up @@ -451,29 +452,31 @@ private RelationPlan getInsertPlan(
continue;
}
Symbol output = symbolAllocator.newSymbol(column.getName(), column.getType());
Expression expression;
Type tableType = column.getType();
int index = insertColumns.indexOf(columns.get(column.getName()));
if (index < 0) {
if (supportsMissingColumnsOnInsert) {
continue;
}
Expression cast = new Cast(new NullLiteral(), toSqlType(column.getType()));
assignments.put(output, cast);
insertedColumnsBuilder.add(column);
expression = new Cast(new NullLiteral(), toSqlType(column.getType()));
}
else {
Symbol input = visibleFieldMappings.get(index);
Type tableType = column.getType();
Type queryType = symbolAllocator.getTypes().get(input);

if (queryType.equals(tableType) || typeCoercion.isTypeOnlyCoercion(queryType, tableType)) {
assignments.put(output, input.toSymbolReference());
expression = input.toSymbolReference();
}
else {
Expression cast = noTruncationCast(input.toSymbolReference(), queryType, tableType);
assignments.put(output, cast);
expression = noTruncationCast(input.toSymbolReference(), queryType, tableType);
}
insertedColumnsBuilder.add(column);
}
if (!column.isNullable()) {
expression = new CoalesceExpression(expression, createNullNotAllowedFailExpression(column.getName(), tableType));
}
assignments.put(output, expression);
insertedColumnsBuilder.add(column);
}

ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), plan.getRoot(), assignments.build());
Expand Down Expand Up @@ -532,6 +535,13 @@ private RelationPlan getInsertPlan(
statisticsMetadata);
}

private Expression createNullNotAllowedFailExpression(String columnName, Type type)
{
return new Cast(failFunction(metadata, session, INVALID_ARGUMENTS, format(
"NULL value not allowed for NOT NULL column: %s",
columnName)), toSqlType(type));
}

private static Function<Expression, Expression> failIfPredicateIsNotMet(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage)
{
FunctionCall fail = failFunction(metadata, session, errorCode, errorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,15 @@ public void testInsertIntoNotNullColumn()
assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)");
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
}

try (TestTable table = new TestTable(getQueryRunner()::execute, "not_null_no_cast", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)")) {
assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)");
// This is enforced by the engine and not the connector
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,14 +815,14 @@ public void testTableWithNonNullableColumns()
assertUpdate("INSERT INTO " + tableName + " VALUES(2, 20, 200)", 1);
assertThatThrownBy(() -> query("INSERT INTO " + tableName + " VALUES(null, 30, 300)"))
.hasMessageContaining("NULL value not allowed for NOT NULL column: col1");
assertThatThrownBy(() -> query("INSERT INTO " + tableName + " VALUES(TRY(5/0), 40, 400)"))
.hasMessageContaining("NULL value not allowed for NOT NULL column: col1");

//TODO this should fail https://github.com/trinodb/trino/issues/13434
assertUpdate("INSERT INTO " + tableName + " VALUES(TRY(5/0), 40, 400)", 1);
//TODO these 2 should fail https://github.com/trinodb/trino/issues/13435
assertUpdate("UPDATE " + tableName + " SET col2 = NULL where col3 = 100", 1);
assertUpdate("UPDATE " + tableName + " SET col2 = TRY(5/0) where col3 = 200", 1);

assertQuery("SELECT * FROM " + tableName, "VALUES(1, null, 100), (2, null, 200), (null, 40, 400)");
assertQuery("SELECT * FROM " + tableName, "VALUES(1, null, 100), (2, null, 200)");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,32 @@ public void testDeleteWithLike()
.hasStackTraceContaining("TrinoException: Unsupported delete");
}

// Overridden because the method from BaseConnectorTest fails on one of the assertions, see TODO below
@Test
@Override
public void testInsertIntoNotNullColumn()
{
try (TestTable table = new TestTable(getQueryRunner()::execute, "insert_not_null", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)")) {
assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)");
assertQueryFails(format("INSERT INTO %s (nullable_col) VALUES (1)", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col"));
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col"));
// TODO (https://github.com/trinodb/trino/issues/13551) This doesn't fail for other connectors so
// probably shouldn't fail for MariaDB either. Once fixed, remove test override.
assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation WHERE regionkey < 0", table.getName()), ".*Field 'not_null_col' doesn't have a default value.*");
}

try (TestTable table = new TestTable(getQueryRunner()::execute, "commuted_not_null", "(nullable_col BIGINT, not_null_col BIGINT NOT NULL)")) {
assertUpdate(format("INSERT INTO %s (not_null_col) VALUES (2)", table.getName()), 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)");
// This is enforced by the engine and not the connector
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
}
}

@Override
public void testNativeQueryCreateStatement()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2809,6 +2809,11 @@ public void testInsertIntoNotNullColumn()
assertQuery("SELECT * FROM " + table.getName(), "VALUES (NULL, 2)");
// The error message comes from remote databases when ConnectorMetadata.supportsMissingColumnsOnInsert is true
assertQueryFails(format("INSERT INTO %s (nullable_col) VALUES (1)", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col"));
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (NULL, 3)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col, nullable_col) VALUES (TRY(5/0), 4)", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (not_null_col) VALUES (TRY(6/0))", table.getName()), "NULL value not allowed for NOT NULL column: not_null_col");
assertQueryFails(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation", table.getName()), errorMessageForInsertIntoNotNullColumn("not_null_col"));
assertUpdate(format("INSERT INTO %s (nullable_col) SELECT nationkey FROM nation WHERE regionkey < 0", table.getName()), 0);
}

try (TestTable table = new TestTable(getQueryRunner()::execute, "commuted_not_null", "(nullable_col BIGINT, not_null_col BIGINT NOT NULL)")) {
Expand Down

0 comments on commit 7667716

Please sign in to comment.