Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues due to mask aliasing #15680

Merged
merged 2 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,8 @@ default List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObj
return ImmutableList.of();
}

default List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
default Optional<ViewExpression> getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
{
return ImmutableList.of();
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom;
import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.SERVER_STARTING_UP;
import static java.lang.String.format;
Expand Down Expand Up @@ -1248,26 +1249,30 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
}

@Override
public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
public Optional<ViewExpression> getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
{
requireNonNull(context, "context is null");
requireNonNull(tableName, "tableName is null");

ImmutableList.Builder<ViewExpression> masks = ImmutableList.builder();

// connector-provided masks take precedence over global masks
ConnectorAccessControl connectorAccessControl = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());
if (connectorAccessControl != null) {
connectorAccessControl.getColumnMasks(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type)
.forEach(masks::add);
connectorAccessControl.getColumnMask(toConnectorSecurityContext(tableName.getCatalogName(), context), tableName.asSchemaTableName(), columnName, type)
.ifPresent(masks::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.getColumnMasks(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
.forEach(masks::add);
systemAccessControl.getColumnMask(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
martint marked this conversation as resolved.
Show resolved Hide resolved
.ifPresent(masks::add);
}

return masks.build();
List<ViewExpression> allMasks = masks.build();
if (allMasks.size() > 1) {
throw new TrinoException(INVALID_COLUMN_MASK, format("Column must have a single mask: %s", columnName));
}

return allMasks.stream().findFirst();
martint marked this conversation as resolved.
Show resolved Hide resolved
}

private ConnectorAccessControl getConnectorAccessControl(TransactionId transactionId, String catalogName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
}

@Override
public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
public Optional<ViewExpression> getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
{
return delegate().getColumnMasks(context, tableName, columnName, type);
return delegate().getColumnMask(context, tableName, columnName, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -492,15 +492,21 @@ public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, Sche
}

@Override
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
checkArgument(context == null, "context must be null");
if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) {
return ImmutableList.of();
if (accessControl.getColumnMask(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) {
return Optional.empty();
}
throw new TrinoException(NOT_SUPPORTED, "Column masking not supported");
}

@Override
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
throw new UnsupportedOperationException();
}

private QualifiedObjectName getQualifiedObjectName(SchemaTableName schemaTableName)
{
return new QualifiedObjectName(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -92,9 +93,9 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
}

@Override
public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
public Optional<ViewExpression> getColumnMask(SecurityContext context, QualifiedObjectName tableName, String columnName, Type type)
{
return delegate.getColumnMasks(context, tableName, columnName, type);
return delegate.getColumnMask(context, tableName, columnName, type);
}

private static void wrapAccessDeniedException(Runnable runnable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public class Analysis
private final Map<NodeRef<Table>, List<Expression>> rowFilters = new LinkedHashMap<>();

private final Multiset<ColumnMaskScopeEntry> columnMaskScopes = HashMultiset.create();
private final Map<NodeRef<Table>, Map<String, List<Expression>>> columnMasks = new LinkedHashMap<>();
private final Map<NodeRef<Table>, Map<String, Expression>> columnMasks = new LinkedHashMap<>();

private final Map<NodeRef<Unnest>, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>();
private Optional<Create> create = Optional.empty();
Expand Down Expand Up @@ -1093,12 +1093,13 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co

public void addColumnMask(Table table, String column, Expression mask)
{
Map<String, List<Expression>> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>());
masks.computeIfAbsent(column, name -> new ArrayList<>())
.add(mask);
Map<String, Expression> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>());
checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column);

masks.put(column, mask);
}

public Map<String, List<Expression>> getColumnMasks(Table table)
public Map<String, Expression> getColumnMasks(Table table)
{
return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of());
}
Expand All @@ -1118,10 +1119,8 @@ public List<TableInfo> getReferencedTables()
.distinct()
.map(fieldName -> new ColumnInfo(
fieldName,
columnMasks.getOrDefault(table, ImmutableMap.of())
.getOrDefault(fieldName, ImmutableList.of()).stream()
.map(Expression::toString)
.collect(toImmutableList())))
Optional.ofNullable(columnMasks.getOrDefault(table, ImmutableMap.of()).get(fieldName))
.map(Expression::toString)))
.collect(toImmutableList());

TableEntry info = entry.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
.collect(toImmutableList());

for (ColumnSchema column : columns) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isPresent()) {
throw semanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported");
}
}
Expand Down Expand Up @@ -785,7 +785,7 @@ protected Scope visitDelete(Delete node, Optional<Scope> scope)

TableSchema tableSchema = metadata.getTableSchema(session, handle);
for (ColumnSchema tableColumn : tableSchema.getColumns()) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) {
throw semanticException(NOT_SUPPORTED, node, "Delete from table with column mask");
}
}
Expand Down Expand Up @@ -1149,7 +1149,7 @@ protected Scope visitTableExecute(TableExecute node, Optional<Scope> scope)

TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle);
for (ColumnMetadata tableColumn : tableMetadata.getColumns()) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) {
throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for table with column masks");
}
}
Expand Down Expand Up @@ -2222,10 +2222,10 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optio
for (int index = 0; index < relationType.getAllFieldCount(); index++) {
Field field = relationType.getFieldByIndex(index);
if (field.getName().isPresent()) {
List<ViewExpression> masks = accessControl.getColumnMasks(session.toSecurityContext(), name, field.getName().get(), field.getType());
Optional<ViewExpression> mask = accessControl.getColumnMask(session.toSecurityContext(), name, field.getName().get(), field.getType());

if (!masks.isEmpty() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) {
masks.forEach(mask -> analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask));
if (mask.isPresent() && checkCanSelectFromColumn(name, field.getName().orElseThrow())) {
analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask.get());
}
}
}
Expand Down Expand Up @@ -3178,7 +3178,7 @@ protected Scope visitUpdate(Update update, Optional<Scope> scope)
// TODO: how to deal with connectors that need to see the pre-image of rows to perform the update without
// flowing that data through the masking logic
for (ColumnSchema tableColumn : allColumns) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isPresent()) {
throw semanticException(NOT_SUPPORTED, update, "Updating a table with column masks is not supported");
}
}
Expand Down Expand Up @@ -3307,7 +3307,7 @@ protected Scope visitMerge(Merge merge, Optional<Scope> scope)
Scope joinScope = createAndAssignScope(merge, scope, targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType()));

for (ColumnSchema column : dataColumnSchemas) {
if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, column.getName(), column.getType()).isEmpty()) {
if (accessControl.getColumnMask(session.toSecurityContext(), tableName, column.getName(), column.getType()).isPresent()) {
throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with column masks");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -293,7 +292,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function<Expres

private RelationPlan addColumnMasks(Table table, RelationPlan plan)
{
Map<String, List<Expression>> columnMasks = analysis.getColumnMasks(table);
Map<String, Expression> columnMasks = analysis.getColumnMasks(table);

// A Table can represent a WITH query, which can have anonymous fields. On the other hand,
// it can't have masks. The loop below expects fields to have proper names, so bail out
Expand All @@ -305,27 +304,33 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan)
PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext)
.withScope(analysis.getAccessControlScope(table), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope

Assignments.Builder assignments = Assignments.builder();
assignments.putIdentities(planBuilder.getRoot().getOutputSymbols());

List<Symbol> fieldMappings = new ArrayList<>();
for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) {
Field field = plan.getDescriptor().getFieldByIndex(i);

for (Expression mask : columnMasks.getOrDefault(field.getName().orElseThrow(), ImmutableList.of())) {
Expression mask = columnMasks.get(field.getName().orElseThrow());
Symbol symbol = plan.getFieldMappings().get(i);
Expression projection = symbol.toSymbolReference();
if (mask != null) {
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask));

Map<Symbol, Expression> assignments = new LinkedHashMap<>();
for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) {
assignments.put(symbol, symbol.toSymbolReference());
}
assignments.put(plan.getFieldMappings().get(i), coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask)));

planBuilder = planBuilder
.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
planBuilder.getRoot(),
Assignments.copyOf(assignments)));
symbol = symbolAllocator.newSymbol(symbol);
projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask));
}

assignments.put(symbol, projection);
fieldMappings.add(symbol);
}

return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext);
planBuilder = planBuilder
.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
planBuilder.getRoot(),
assignments.build()));

return new RelationPlan(planBuilder.getRoot(), plan.getScope(), fieldMappings, outerContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public class TestingAccessControlManager

private final Set<TestingPrivilege> denyPrivileges = new HashSet<>();
private final Map<RowFilterKey, List<ViewExpression>> rowFilters = new HashMap<>();
private final Map<ColumnMaskKey, List<ViewExpression>> columnMasks = new HashMap<>();
private final Map<ColumnMaskKey, ViewExpression> columnMasks = new HashMap<>();
private Predicate<String> deniedCatalogs = s -> true;
private Predicate<String> deniedSchemas = s -> true;
private Predicate<SchemaTableName> deniedTables = s -> true;
Expand Down Expand Up @@ -175,8 +175,7 @@ public void rowFilter(QualifiedObjectName table, String identity, ViewExpression

public void columnMask(QualifiedObjectName table, String column, String identity, ViewExpression mask)
{
columnMasks.computeIfAbsent(new ColumnMaskKey(identity, table, column), key -> new ArrayList<>())
.add(mask);
columnMasks.put(new ColumnMaskKey(identity, table, column), mask);
}

public void reset()
Expand Down Expand Up @@ -746,13 +745,13 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
}

@Override
public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObjectName tableName, String column, Type type)
public Optional<ViewExpression> getColumnMask(SecurityContext context, QualifiedObjectName tableName, String column, Type type)
{
List<ViewExpression> viewExpressions = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column));
if (viewExpressions != null) {
return viewExpressions;
ViewExpression mask = columnMasks.get(new ColumnMaskKey(context.getIdentity().getUser(), tableName, column));
if (mask != null) {
return Optional.of(mask);
}
return super.getColumnMasks(context, tableName, column, type);
return super.getColumnMask(context, tableName, column, type);
}

private boolean shouldDenyPrivilege(String actorName, String entityName, TestingPrivilegeType verb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,16 @@ public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, Sche
.orElseGet(ImmutableList::of);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.ofNullable(columnMasks.apply(tableName, columnName));
}

@Override
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.ofNullable(columnMasks.apply(tableName, columnName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
throw new UnsupportedOperationException();
}

public void grantSchemaPrivileges(String schemaName, Set<Privilege> privileges, TrinoPrincipal grantee, boolean grantOption)
Expand Down
Loading