Skip to content

Commit

Permalink
Provide query owner full identity to view and kill query security checks
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Nov 16, 2021
1 parent 32371f7 commit f7aac0d
Show file tree
Hide file tree
Showing 19 changed files with 186 additions and 57 deletions.
24 changes: 17 additions & 7 deletions core/trino-main/src/main/java/io/trino/SessionRepresentation.java
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ public String getTimeZone()
return timeZoneKey.getId();
}

public Identity toIdentity()
{
return toIdentity(emptyMap());
}

public Identity toIdentity(Map<String, String> extraCredentials)
{
return Identity.forUser(user)
.withGroups(groups)
.withPrincipal(principal.map(BasicPrincipal::new))
.withEnabledRoles(enabledRoles)
.withConnectorRoles(catalogRoles)
.withExtraCredentials(extraCredentials)
.build();
}

public Session toSession(SessionPropertyManager sessionPropertyManager)
{
return toSession(sessionPropertyManager, emptyMap());
Expand All @@ -320,13 +336,7 @@ public Session toSession(SessionPropertyManager sessionPropertyManager, Map<Stri
new QueryId(queryId),
transactionId,
clientTransactionSupport,
Identity.forUser(user)
.withGroups(groups)
.withPrincipal(principal.map(BasicPrincipal::new))
.withEnabledRoles(enabledRoles)
.withConnectorRoles(catalogRoles)
.withExtraCredentials(extraCredentials)
.build(),
toIdentity(extraCredentials),
source,
catalog,
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void killQuery(String queryId, String message, ConnectorSession session)
checkState(dispatchManager.isPresent(), "No dispatch manager is set. kill_query procedure should be executed on coordinator.");
DispatchQuery dispatchQuery = dispatchManager.get().getQuery(query);

checkCanKillQueryOwnedBy(((FullConnectorSession) session).getSession().getIdentity(), dispatchQuery.getSession().getUser(), accessControl);
checkCanKillQueryOwnedBy(((FullConnectorSession) session).getSession().getIdentity(), dispatchQuery.getSession().getIdentity(), accessControl);

// check before killing to provide the proper error message (this is racy)
if (dispatchQuery.isDone()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.spi.type.Type;

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -81,21 +82,22 @@ public interface AccessControl
*
* @throws AccessDeniedException if not allowed
*/
void checkCanViewQueryOwnedBy(Identity identity, String queryOwner);
void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner);

/**
* Filter the list of users to those the identity view query owned by the user. The method
* will not be called with the current user in the set.
* @return
*/
Set<String> filterQueriesOwnedBy(Identity identity, Set<String> queryOwners);
Collection<Identity> filterQueriesOwnedBy(Identity identity, Collection<Identity> queryOwners);

/**
* Checks if identity can kill a query owned by the specified user. The method
* will not be called when the current user is the query owner.
*
* @throws AccessDeniedException if not allowed
*/
void checkCanKillQueryOwnedBy(Identity identity, String queryOwner);
void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner);

/**
* Filter the list of catalogs to those visible to the identity.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.Principal;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -263,24 +264,24 @@ public void checkCanExecuteQuery(Identity identity)
}

@Override
public void checkCanViewQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner)
{
requireNonNull(identity, "identity is null");

systemAuthorizationCheck(control -> control.checkCanViewQueryOwnedBy(new SystemSecurityContext(identity, Optional.empty()), queryOwner));
}

@Override
public Set<String> filterQueriesOwnedBy(Identity identity, Set<String> queryOwners)
public Collection<Identity> filterQueriesOwnedBy(Identity identity, Collection<Identity> queryOwners)
{
for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
queryOwners = systemAccessControl.filterViewQueryOwnedBy(new SystemSecurityContext(identity, Optional.empty()), queryOwners);
queryOwners = systemAccessControl.filterViewQuery(new SystemSecurityContext(identity, Optional.empty()), queryOwners);
}
return queryOwners;
}

@Override
public void checkCanKillQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner)
{
requireNonNull(identity, "identity is null");
requireNonNull(queryOwner, "queryOwner is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableSet;
import io.trino.SessionRepresentation;
import io.trino.server.BasicQueryInfo;
import io.trino.spi.security.Identity;

import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -28,38 +29,85 @@ public final class AccessControlUtil
{
private AccessControlUtil() {}

public static void checkCanViewQueryOwnedBy(Identity identity, String queryOwner, AccessControl accessControl)
public static void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner, AccessControl accessControl)
{
if (identity.getUser().equals(queryOwner)) {
if (identity.getUser().equals(queryOwner.getUser())) {
return;
}
accessControl.checkCanViewQueryOwnedBy(identity, queryOwner);
}

public static List<BasicQueryInfo> filterQueries(Identity identity, List<BasicQueryInfo> queries, AccessControl accessControl)
{
String currentUser = identity.getUser();
Set<String> owners = queries.stream()
Collection<Identity> owners = queries.stream()
.map(BasicQueryInfo::getSession)
.map(SessionRepresentation::getUser)
.filter(owner -> !owner.equals(currentUser))
.collect(toImmutableSet());
.map(SessionRepresentation::toIdentity)
.filter(owner -> !owner.getUser().equals(identity.getUser()))
.map(FullIdentityEquality::new)
.distinct()
.map(FullIdentityEquality::getIdentity)
.collect(toImmutableList());
owners = accessControl.filterQueriesOwnedBy(identity, owners);

Set<String> allowedOwners = ImmutableSet.<String>builder()
.add(currentUser)
.addAll(owners)
.build();
Set<FullIdentityEquality> allowedOwners = owners.stream()
.map(FullIdentityEquality::new)
.collect(toImmutableSet());
return queries.stream()
.filter(queryInfo -> allowedOwners.contains(queryInfo.getSession().getUser()))
.filter(queryInfo -> {
Identity queryIdentity = queryInfo.getSession().toIdentity();
return queryIdentity.getUser().equals(identity.getUser()) || allowedOwners.contains(new FullIdentityEquality(queryIdentity));
})
.collect(toImmutableList());
}

public static void checkCanKillQueryOwnedBy(Identity identity, String queryOwner, AccessControl accessControl)
public static void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner, AccessControl accessControl)
{
if (identity.getUser().equals(queryOwner)) {
if (identity.getUser().equals(queryOwner.getUser())) {
return;
}
accessControl.checkCanKillQueryOwnedBy(identity, queryOwner);
}

private static class FullIdentityEquality
{
private final Identity identity;

public FullIdentityEquality(Identity identity)
{
this.identity = identity;
}

public Identity getIdentity()
{
return identity;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FullIdentityEquality that = (FullIdentityEquality) o;
return Objects.equals(identity.getUser(), that.identity.getUser()) &&
Objects.equals(identity.getGroups(), that.identity.getGroups()) &&
Objects.equals(identity.getPrincipal(), that.identity.getPrincipal()) &&
Objects.equals(identity.getEnabledRoles(), that.identity.getEnabledRoles()) &&
Objects.equals(identity.getCatalogRoles(), that.identity.getCatalogRoles());
}

@Override
public int hashCode()
{
return Objects.hash(
identity.getUser(),
identity.getGroups(),
identity.getPrincipal(),
identity.getEnabledRoles(),
identity.getCatalogRoles());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.security.TrinoPrincipal;

import java.security.Principal;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -55,18 +56,18 @@ public void checkCanExecuteQuery(Identity identity)
}

@Override
public void checkCanViewQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner)
{
}

@Override
public Set<String> filterQueriesOwnedBy(Identity identity, Set<String> queryOwners)
public Collection<Identity> filterQueriesOwnedBy(Identity identity, Collection<Identity> queryOwners)
{
return queryOwners;
}

@Override
public void checkCanKillQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.QualifiedObjectName;
import io.trino.spi.connector.CatalogSchemaName;
Expand All @@ -23,6 +24,7 @@
import io.trino.spi.security.TrinoPrincipal;

import java.security.Principal;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -121,19 +123,19 @@ public void checkCanExecuteQuery(Identity identity)
}

@Override
public void checkCanViewQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner)
{
denyViewQuery();
}

@Override
public Set<String> filterQueriesOwnedBy(Identity identity, Set<String> queryOwners)
public Collection<Identity> filterQueriesOwnedBy(Identity identity, Collection<Identity> queryOwners)
{
return ImmutableSet.of();
return ImmutableList.of();
}

@Override
public void checkCanKillQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner)
{
denyKillQuery();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.type.Type;

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -82,19 +83,19 @@ public void checkCanExecuteQuery(Identity identity)
}

@Override
public void checkCanViewQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner)
{
delegate().checkCanViewQueryOwnedBy(identity, queryOwner);
}

@Override
public Set<String> filterQueriesOwnedBy(Identity identity, Set<String> queryOwners)
public Collection<Identity> filterQueriesOwnedBy(Identity identity, Collection<Identity> queryOwners)
{
return delegate().filterQueriesOwnedBy(identity, queryOwners);
}

@Override
public void checkCanKillQueryOwnedBy(Identity identity, String queryOwner)
public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner)
{
delegate().checkCanKillQueryOwnedBy(identity, queryOwner);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public Response getQueryInfo(@PathParam("queryId") QueryId queryId, @Context Htt
return Response.status(Status.GONE).build();
}
try {
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().getUser(), accessControl);
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().toIdentity(), accessControl);
return Response.ok(queryInfo.get()).build();
}
catch (AccessDeniedException e) {
Expand All @@ -117,7 +117,7 @@ public void cancelQuery(@PathParam("queryId") QueryId queryId, @Context HttpServ

try {
BasicQueryInfo queryInfo = dispatchManager.getQueryInfo(queryId);
checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().getUser(), accessControl);
checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().toIdentity(), accessControl);
dispatchManager.cancelQuery(queryId);
}
catch (AccessDeniedException e) {
Expand Down Expand Up @@ -150,7 +150,7 @@ private Response failQuery(QueryId queryId, TrinoException queryException, HttpS
try {
BasicQueryInfo queryInfo = dispatchManager.getQueryInfo(queryId);

checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().getUser(), accessControl);
checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().toIdentity(), accessControl);

// check before killing to provide the proper error code (this is racy)
if (queryInfo.getState().isDone()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public QueryStateInfo getQueryStateInfo(@PathParam("queryId") String queryId, @C
{
try {
BasicQueryInfo queryInfo = dispatchManager.getQueryInfo(new QueryId(queryId));
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().getUser(), accessControl);
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().toIdentity(), accessControl);
return getQueryStateInfo(queryInfo);
}
catch (AccessDeniedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public Response getQueryInfo(@PathParam("queryId") QueryId queryId, @Context Htt
Optional<QueryInfo> queryInfo = dispatchManager.getFullQueryInfo(queryId);
if (queryInfo.isPresent()) {
try {
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().getUser(), accessControl);
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().toIdentity(), accessControl);
return Response.ok(queryInfo.get()).build();
}
catch (AccessDeniedException e) {
Expand Down Expand Up @@ -130,7 +130,7 @@ private Response failQuery(QueryId queryId, TrinoException queryException, HttpS
try {
BasicQueryInfo queryInfo = dispatchManager.getQueryInfo(queryId);

checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().getUser(), accessControl);
checkCanKillQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.getSession().toIdentity(), accessControl);

// check before killing to provide the proper error code (this is racy)
if (queryInfo.getState().isDone()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public Response getThreads(
Optional<QueryInfo> queryInfo = dispatchManager.getFullQueryInfo(queryId);
if (queryInfo.isPresent()) {
try {
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().getUser(), accessControl);
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders, alternateHeaderName), queryInfo.get().getSession().toIdentity(), accessControl);
return proxyJsonResponse(nodeId, "v1/task/" + task);
}
catch (AccessDeniedException e) {
Expand Down
Loading

0 comments on commit f7aac0d

Please sign in to comment.