Skip to content

Commit

Permalink
Fix aliasing of fields in mask expressions
Browse files Browse the repository at this point in the history
This causes two problems:
* Masks can inadvertently refer to columns that appear earlier
  in the list of columns as the projection is planned. This
  causes the mask expression to see the masked value of other columns
  instead of the underlying value
* A possible bug in other optimizers causes mask expressions to be
  lost when the result of a such an expression is of type ROW
  and there's a dereference of a field downstream
  • Loading branch information
martint committed Jan 17, 2023
1 parent d95101b commit fc57354
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -305,28 +304,33 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan)
PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext)
.withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope

Assignments.Builder assignments = Assignments.builder();
assignments.putIdentities(planBuilder.getRoot().getOutputSymbols());

List<Symbol> fieldMappings = new ArrayList<>();
for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) {
Field field = plan.getDescriptor().getFieldByIndex(i);

Expression mask = columnMasks.get(field.getName().orElseThrow());
Symbol symbol = plan.getFieldMappings().get(i);
Expression projection = symbol.toSymbolReference();
if (mask != null) {
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask));

Map<Symbol, Expression> assignments = new LinkedHashMap<>();
for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) {
assignments.put(symbol, symbol.toSymbolReference());
}
assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask)));

planBuilder = planBuilder
.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
planBuilder.getRoot(),
Assignments.copyOf(assignments)));
symbol = symbolAllocator.newSymbol(symbol);
projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask));
}

assignments.put(symbol, projection);
fieldMappings.add(symbol);
}

return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext);
planBuilder = planBuilder
.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
planBuilder.getRoot(),
assignments.build()));

return new RelationPlan(planBuilder.getRoot(), plan.getScope(), fieldMappings, outerContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.spi.security.Identity;
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.VarcharType;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingAccessControlManager;
Expand All @@ -41,6 +42,7 @@
import static io.trino.connector.MockConnectorEntities.TPCH_NATION_WITH_HIDDEN_COLUMN;
import static io.trino.connector.MockConnectorEntities.TPCH_WITH_HIDDEN_COLUMN_DATA;
import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN;
import static io.trino.testing.TestingAccessControlManager.privilege;
import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME;
Expand Down Expand Up @@ -86,6 +88,26 @@ public void init()
Optional.of(VIEW_OWNER),
false);

ConnectorViewDefinition viewWithNested = new ConnectorViewDefinition(
"""
SELECT * FROM (
VALUES
ROW(ROW(1,2), 0),
ROW(ROW(3,4), 1)
) t(nested, id)
""",
Optional.empty(),
Optional.empty(),
ImmutableList.of(
new ConnectorViewDefinition.ViewColumn("nested", RowType.from(ImmutableList.of(
RowType.field(INTEGER),
RowType.field(INTEGER))).getTypeId(),
Optional.empty()),
new ConnectorViewDefinition.ViewColumn("id", INTEGER.getTypeId(), Optional.empty())),
Optional.empty(),
Optional.of(VIEW_OWNER),
false);

ConnectorMaterializedViewDefinition materializedView = new ConnectorMaterializedViewDefinition(
"SELECT * FROM local.tiny.nation",
Optional.empty(),
Expand Down Expand Up @@ -142,7 +164,8 @@ public void init()
throw new UnsupportedOperationException();
})
.withGetViews((s, prefix) -> ImmutableMap.of(
new SchemaTableName("default", "nation_view"), view))
new SchemaTableName("default", "nation_view"), view,
new SchemaTableName("default", "view_with_nested"), viewWithNested))
.withGetMaterializedViews((s, prefix) -> ImmutableMap.of(
new SchemaTableName("default", "nation_materialized_view"), materializedView,
new SchemaTableName("default", "nation_fresh_materialized_view"), freshMaterializedView,
Expand Down Expand Up @@ -309,7 +332,7 @@ public void testMaterializedView()
.setIdentity(Identity.forUser(USER).build())
.build(),
"SELECT name FROM mock.default.materialized_view_with_casts WHERE nationkey = 1"))
.matches("VALUES CAST('RA' AS VARCHAR(2))");
.matches("VALUES 'RA'");
}

@Test
Expand Down Expand Up @@ -812,7 +835,7 @@ public void testMultipleMasksUsingOtherMaskedColumns()
new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"),
"clerk",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))"));
new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))"));

accessControl.columnMask(
new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"),
Expand All @@ -827,6 +850,20 @@ public void testMultipleMasksUsingOtherMaskedColumns()
new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)"));

assertThat(assertions.query(query))
.matches("VALUES (CAST('***' as varchar(79)), '*', CAST('***#000000951' as varchar(15)))");
.matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))");
}

@Test
public void testColumnAliasing()
{
accessControl.reset();
accessControl.columnMask(
new QualifiedObjectName(MOCK_CATALOG, "default", "view_with_nested"),
"nested",
USER,
new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(id = 0, nested)"));

assertThat(assertions.query("SELECT nested[1] FROM mock.default.view_with_nested"))
.matches("VALUES 1, NULL");
}
}

0 comments on commit fc57354

Please sign in to comment.