Skip to content

Commit

Permalink
Redesign function access control methods
Browse files Browse the repository at this point in the history
Simplify function access control methods for better integration with
catalog functions.
* Return boolean from catalog access check calls instead of throwing
* Return boolean from function check calls instead of throwing
* Merge simple function checks with table function check
* Add canCreateViewWithExecuteFunction for use in ViewAccessControl
  • Loading branch information
dain committed Oct 5, 2023
1 parent d23cae4 commit 051e53c
Show file tree
Hide file tree
Showing 35 changed files with 496 additions and 435 deletions.
21 changes: 5 additions & 16 deletions core/trino-main/src/main/java/io/trino/security/AccessControl.java
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,7 @@ default void checkCanSetViewAuthorization(SecurityContext context, QualifiedObje
*
* @throws AccessDeniedException if not allowed
*/
void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption);

/**
* Check if identity is allowed to create a view that executes the function.
*
* @throws AccessDeniedException if not allowed
*/
void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption);
void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, TrinoPrincipal grantee, boolean grantOption);

/**
* Check if identity is allowed to grant a privilege to the grantee on the specified schema.
Expand Down Expand Up @@ -555,18 +548,14 @@ void checkCanRevokeRoles(SecurityContext context,
void checkCanExecuteProcedure(SecurityContext context, QualifiedObjectName procedureName);

/**
* Check if identity is allowed to execute function
*
* @throws AccessDeniedException if not allowed
* Is the identity allowed to execute function?
*/
void checkCanExecuteFunction(SecurityContext context, String functionName);
boolean canExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName);

/**
* Check if identity is allowed to execute function
*
* @throws AccessDeniedException if not allowed
* Is the identity allowed to create a view that executes the specified function?
*/
void checkCanExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName);
boolean canCreateViewWithExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName);

/**
* Check if identity is allowed to execute given table procedure on given table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.security.Identity;
import io.trino.spi.security.PrincipalType;
import io.trino.spi.security.Privilege;
import io.trino.spi.security.SystemAccessControl;
import io.trino.spi.security.SystemAccessControlFactory;
Expand All @@ -68,7 +67,9 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.function.Predicate;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.isNullOrEmpty;
Expand All @@ -77,6 +78,7 @@
import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_MASK;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.SERVER_STARTING_UP;
import static io.trino.spi.security.AccessDeniedException.denyCatalogAccess;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -930,20 +932,7 @@ public void checkCanSetMaterializedViewProperties(SecurityContext securityContex
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContext, String functionName, Identity grantee, boolean grantOption)
{
requireNonNull(securityContext, "securityContext is null");
requireNonNull(functionName, "functionName is null");

systemAuthorizationCheck(control -> control.checkCanGrantExecuteFunctionPrivilege(
securityContext.toSystemSecurityContext(),
functionName,
new TrinoPrincipal(PrincipalType.USER, grantee.getUser()),
grantOption));
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption)
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName, TrinoPrincipal grantee, boolean grantOption)
{
requireNonNull(securityContext, "securityContext is null");
requireNonNull(functionKind, "functionKind is null");
Expand All @@ -953,7 +942,7 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContex
securityContext.toSystemSecurityContext(),
functionKind,
functionName.asCatalogSchemaRoutineName(),
new TrinoPrincipal(PrincipalType.USER, grantee.getUser()),
grantee,
grantOption));

catalogAuthorizationCheck(
Expand All @@ -963,7 +952,7 @@ public void checkCanGrantExecuteFunctionPrivilege(SecurityContext securityContex
context,
functionKind,
functionName.asSchemaRoutineName(),
new TrinoPrincipal(PrincipalType.USER, grantee.getUser()),
grantee,
grantOption));
}

Expand Down Expand Up @@ -1246,32 +1235,45 @@ public void checkCanExecuteProcedure(SecurityContext securityContext, QualifiedO
}

@Override
public void checkCanExecuteFunction(SecurityContext context, String functionName)
public boolean canExecuteFunction(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName)
{
requireNonNull(context, "context is null");
requireNonNull(securityContext, "securityContext is null");
requireNonNull(functionKind, "functionKind is null");
requireNonNull(functionName, "functionName is null");

systemAuthorizationCheck(control -> control.checkCanExecuteFunction(context.toSystemSecurityContext(), functionName));
if (!canAccessCatalog(securityContext, functionName.getCatalogName())) {
return false;
}

if (!systemAuthorizationTest(control -> control.canExecuteFunction(securityContext.toSystemSecurityContext(), functionKind, functionName.asCatalogSchemaRoutineName()))) {
return false;
}

return catalogAuthorizationTest(
functionName.getCatalogName(),
securityContext,
(control, context) -> control.canExecuteFunction(context, functionKind, functionName.asSchemaRoutineName()));
}

@Override
public void checkCanExecuteFunction(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName)
public boolean canCreateViewWithExecuteFunction(SecurityContext securityContext, FunctionKind functionKind, QualifiedObjectName functionName)
{
requireNonNull(securityContext, "securityContext is null");
requireNonNull(functionKind, "functionKind is null");
requireNonNull(functionName, "functionName is null");

checkCanAccessCatalog(securityContext, functionName.getCatalogName());
if (!canAccessCatalog(securityContext, functionName.getCatalogName())) {
return false;
}

systemAuthorizationCheck(control -> control.checkCanExecuteFunction(
securityContext.toSystemSecurityContext(),
functionKind,
functionName.asCatalogSchemaRoutineName()));
if (!systemAuthorizationTest(control -> control.canCreateViewWithExecuteFunction(securityContext.toSystemSecurityContext(), functionKind, functionName.asCatalogSchemaRoutineName()))) {
return false;
}

catalogAuthorizationCheck(
return catalogAuthorizationTest(
functionName.getCatalogName(),
securityContext,
(control, context) -> control.checkCanExecuteFunction(context, functionKind, functionName.asSchemaRoutineName()));
(control, context) -> control.canCreateViewWithExecuteFunction(context, functionKind, functionName.asSchemaRoutineName()));
}

@Override
Expand Down Expand Up @@ -1372,16 +1374,33 @@ public CounterStat getAuthorizationFail()

private void checkCanAccessCatalog(SecurityContext securityContext, String catalogName)
{
try {
for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.checkCanAccessCatalog(securityContext.toSystemSecurityContext(), catalogName);
if (!canAccessCatalog(securityContext, catalogName)) {
denyCatalogAccess(catalogName);
}
}

private boolean canAccessCatalog(SecurityContext securityContext, String catalogName)
{
for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
if (!systemAccessControl.canAccessCatalog(securityContext.toSystemSecurityContext(), catalogName)) {
authorizationFail.update(1);
return false;
}
authorizationSuccess.update(1);
}
catch (TrinoException e) {
authorizationFail.update(1);
throw e;
authorizationSuccess.update(1);
return true;
}

private boolean systemAuthorizationTest(Predicate<SystemAccessControl> check)
{
for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
if (!check.test(systemAccessControl)) {
authorizationFail.update(1);
return false;
}
}
authorizationSuccess.update(1);
return true;
}

private void systemAuthorizationCheck(Consumer<SystemAccessControl> check)
Expand All @@ -1398,6 +1417,23 @@ private void systemAuthorizationCheck(Consumer<SystemAccessControl> check)
}
}

private boolean catalogAuthorizationTest(String catalogName, SecurityContext securityContext, BiPredicate<ConnectorAccessControl, ConnectorSecurityContext> check)
{
ConnectorAccessControl connectorAccessControl = getConnectorAccessControl(securityContext.getTransactionId(), catalogName);
if (connectorAccessControl == null) {
return true;
}

boolean result = check.test(connectorAccessControl, toConnectorSecurityContext(catalogName, securityContext));
if (result) {
authorizationSuccess.update(1);
}
else {
authorizationFail.update(1);
}
return result;
}

private void catalogAuthorizationCheck(String catalogName, SecurityContext securityContext, BiConsumer<ConnectorAccessControl, ConnectorSecurityContext> check)
{
ConnectorAccessControl connectorAccessControl = getConnectorAccessControl(securityContext.getTransactionId(), catalogName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,19 @@ public void checkCanSetMaterializedViewProperties(SecurityContext context, Quali
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption)
public boolean canExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
return true;
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption)
public boolean canCreateViewWithExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
return true;
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, TrinoPrincipal grantee, boolean grantOption)
{
}

Expand Down Expand Up @@ -387,16 +394,6 @@ public void checkCanExecuteProcedure(SecurityContext context, QualifiedObjectNam
{
}

@Override
public void checkCanExecuteFunction(SecurityContext context, String functionName)
{
}

@Override
public void checkCanExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
}

@Override
public void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName tableName, String procedureName)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import static io.trino.spi.security.AccessDeniedException.denyDropSchema;
import static io.trino.spi.security.AccessDeniedException.denyDropTable;
import static io.trino.spi.security.AccessDeniedException.denyDropView;
import static io.trino.spi.security.AccessDeniedException.denyExecuteFunction;
import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure;
import static io.trino.spi.security.AccessDeniedException.denyExecuteQuery;
import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure;
Expand Down Expand Up @@ -401,13 +400,7 @@ public void checkCanSetMaterializedViewProperties(SecurityContext context, Quali
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption)
{
denyGrantExecuteFunctionPrivilege(functionName, context.getIdentity(), grantee);
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption)
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, TrinoPrincipal grantee, boolean grantOption)
{
denyGrantExecuteFunctionPrivilege(functionName.toString(), context.getIdentity(), grantee);
}
Expand Down Expand Up @@ -521,15 +514,15 @@ public void checkCanExecuteProcedure(SecurityContext context, QualifiedObjectNam
}

@Override
public void checkCanExecuteFunction(SecurityContext context, String functionName)
public boolean canExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
denyExecuteFunction(functionName);
return false;
}

@Override
public void checkCanExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
public boolean canCreateViewWithExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
denyExecuteFunction(functionName.toString());
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,7 @@ public void checkCanSetMaterializedViewProperties(SecurityContext context, Quali
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, String functionName, Identity grantee, boolean grantOption)
{
delegate().checkCanGrantExecuteFunctionPrivilege(context, functionName, grantee, grantOption);
}

@Override
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, Identity grantee, boolean grantOption)
public void checkCanGrantExecuteFunctionPrivilege(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName, TrinoPrincipal grantee, boolean grantOption)
{
delegate().checkCanGrantExecuteFunctionPrivilege(context, functionKind, functionName, grantee, grantOption);
}
Expand Down Expand Up @@ -474,15 +468,15 @@ public void checkCanExecuteProcedure(SecurityContext context, QualifiedObjectNam
}

@Override
public void checkCanExecuteFunction(SecurityContext context, String functionName)
public boolean canExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
delegate().checkCanExecuteFunction(context, functionName);
return delegate().canExecuteFunction(context, functionKind, functionName);
}

@Override
public void checkCanExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
public boolean canCreateViewWithExecuteFunction(SecurityContext context, FunctionKind functionKind, QualifiedObjectName functionName)
{
delegate().checkCanExecuteFunction(context, functionKind, functionName);
return delegate().canCreateViewWithExecuteFunction(context, functionKind, functionName);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.trino.spi.connector.SchemaRoutineName;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.security.Identity;
import io.trino.spi.security.Privilege;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.spi.security.ViewExpression;
Expand Down Expand Up @@ -333,7 +332,7 @@ public void checkCanGrantExecuteFunctionPrivilege(ConnectorSecurityContext conte
securityContext,
functionKind,
getQualifiedObjectName(functionName),
Identity.ofUser(grantee.getName()),
grantee,
grantOption);
}

Expand Down Expand Up @@ -475,10 +474,17 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public void checkCanExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function)
public boolean canExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function)
{
checkArgument(context == null, "context must be null");
accessControl.checkCanExecuteFunction(securityContext, functionKind, new QualifiedObjectName(catalogName, function.getSchemaName(), function.getRoutineName()));
return accessControl.canExecuteFunction(securityContext, functionKind, new QualifiedObjectName(catalogName, function.getSchemaName(), function.getRoutineName()));
}

@Override
public boolean canCreateViewWithExecuteFunction(ConnectorSecurityContext context, FunctionKind functionKind, SchemaRoutineName function)
{
checkArgument(context == null, "context must be null");
return accessControl.canCreateViewWithExecuteFunction(securityContext, functionKind, new QualifiedObjectName(catalogName, function.getSchemaName(), function.getRoutineName()));
}

@Override
Expand Down
Loading

0 comments on commit 051e53c

Please sign in to comment.