Skip to content

Commit

Permalink
Add support for AuthorizedIdentity JWT claim
Browse files Browse the repository at this point in the history
  • Loading branch information
prithvip committed Sep 17, 2024
1 parent bcb2431 commit fd2615e
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 9 deletions.
1 change: 0 additions & 1 deletion presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,6 @@
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<scope>runtime</scope>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.SelectedRole;
import com.facebook.presto.spi.session.ResourceEstimates;
Expand Down Expand Up @@ -76,6 +77,7 @@
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRACE_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER;
import static com.facebook.presto.server.security.ServletSecurityUtils.authorizedIdentity;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.isNullOrEmpty;
Expand All @@ -99,6 +101,7 @@ public final class HttpRequestSessionContext
private final String schema;

private final Identity identity;
private final Optional<AuthorizedIdentity> authorizedIdentity;
private final List<X509Certificate> certificates;

private final String source;
Expand Down Expand Up @@ -155,6 +158,7 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOpt
ImmutableMap.of(),
Optional.empty(),
Optional.empty());
authorizedIdentity = authorizedIdentity(servletRequest);

X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE);
if (certs != null && certs.length > 0) {
Expand Down Expand Up @@ -404,6 +408,12 @@ public Identity getIdentity()
return identity;
}

@Override
public Optional<AuthorizedIdentity> getAuthorizedIdentity()
{
return authorizedIdentity;
}

@Override
public List<X509Certificate> getCertificates()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ else if (context.getTimeZoneId() != null) {
private Identity authenticateIdentity(QueryId queryId, SessionContext context)
{
checkPermissions(accessControl, securityConfig, queryId, context);
Optional<AuthorizedIdentity> authorizedIdentity = getAuthorizedIdentity(accessControl, securityConfig, queryId, context);
Optional<AuthorizedIdentity> authorizedIdentity = context.getAuthorizedIdentity();
authorizedIdentity = authorizedIdentity.isPresent() ? authorizedIdentity : getAuthorizedIdentity(accessControl, securityConfig, queryId, context);

return authorizedIdentity.map(identity -> new Identity(
context.getIdentity().getUser(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.session.ResourceEstimates;
import com.facebook.presto.spi.tracing.Tracer;
Expand All @@ -34,6 +35,11 @@ public interface SessionContext
{
Identity getIdentity();

default Optional<AuthorizedIdentity> getAuthorizedIdentity()
{
return Optional.empty();
}

default List<X509Certificate> getCertificates()
{
return ImmutableList.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.airlift.http.server.BasicPrincipal;
import com.facebook.airlift.security.pem.PemReader;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableMap;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
Expand All @@ -28,6 +30,7 @@
import io.jsonwebtoken.SignatureException;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.jackson.io.JacksonDeserializer;

import javax.crypto.spec.SecretKeySpec;
import javax.inject.Inject;
Expand All @@ -41,6 +44,8 @@
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE;
import static com.facebook.presto.server.security.ServletSecurityUtils.setAuthorizedIdentity;
import static com.google.common.base.CharMatcher.inRange;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.nullToEmpty;
Expand Down Expand Up @@ -73,7 +78,8 @@ public JsonWebTokenAuthenticator(JsonWebTokenConfig config)
keyLoader = new StaticKeyLoader(config.getKeyFile());
}

JwtParser jwtParser = Jwts.parser()
JwtParser jwtParser = Jwts.parserBuilder()
.deserializeJsonWith(new JacksonDeserializer<>(ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, AuthorizedIdentity.class)))
.setSigningKeyResolver(new SigningKeyResolver()
{
// interface uses raw types and this can not be fixed here
Expand All @@ -90,7 +96,8 @@ public Key resolveSigningKey(JwsHeader header, String plaintext)
{
return keyLoader.apply(header);
}
});
})
.build();

if (config.getRequiredIssuer() != null) {
jwtParser.requireIssuer(config.getRequiredIssuer());
Expand Down Expand Up @@ -118,6 +125,12 @@ public Principal authenticate(HttpServletRequest request)

try {
Jws<Claims> claimsJws = jwtParser.parseClaimsJws(token);

AuthorizedIdentity authorizedIdentity = claimsJws.getBody().get(AUTHORIZED_IDENTITY_ATTRIBUTE, AuthorizedIdentity.class);
if (authorizedIdentity != null) {
setAuthorizedIdentity(request, authorizedIdentity);
}

String subject = claimsJws.getBody().getSubject();
return new BasicPrincipal(subject);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.security;

import com.facebook.presto.spi.security.AuthorizedIdentity;

import javax.servlet.http.HttpServletRequest;

import java.util.Optional;

public class ServletSecurityUtils
{
public static final String AUTHORIZED_IDENTITY_ATTRIBUTE = "presto.authorized-identity";

private ServletSecurityUtils() {}

public static void setAuthorizedIdentity(HttpServletRequest servletRequest, AuthorizedIdentity authorizedIdentity)
{
servletRequest.setAttribute(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity);
}

public static Optional<AuthorizedIdentity> authorizedIdentity(HttpServletRequest servletRequest)
{
return Optional.ofNullable((AuthorizedIdentity) servletRequest.getAttribute(AUTHORIZED_IDENTITY_ATTRIBUTE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package com.facebook.presto.server;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;

import javax.servlet.AsyncContext;
Expand All @@ -35,6 +34,7 @@
import java.security.Principal;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

Expand All @@ -53,7 +53,7 @@ public MockHttpServletRequest(ListMultimap<String, String> headers, String remot
{
this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null"));
this.remoteAddress = requireNonNull(remoteAddress, "remoteAddress is null");
this.attributes = ImmutableMap.copyOf(requireNonNull(attributes, "attributes is null"));
this.attributes = new HashMap<>(requireNonNull(attributes, "attributes is null"));
}

@Override
Expand Down Expand Up @@ -371,7 +371,7 @@ public String getRemoteHost()
@Override
public void setAttribute(String name, Object o)
{
throw new UnsupportedOperationException();
attributes.put(name, o);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.SelectedRole;
import com.facebook.presto.sql.parser.IdentifierSymbol;
Expand Down Expand Up @@ -55,6 +56,7 @@
import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER;
import static com.facebook.presto.common.type.StandardTypes.INTEGER;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE;
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
Expand Down Expand Up @@ -211,6 +213,24 @@ public void testExtraCredentials()
.build());
}

@Test
public void testAuthorizedIdentity()
{
AuthorizedIdentity authorizedIdentity = new AuthorizedIdentity("username", "reasonForSelect", false);
HttpServletRequest request = new MockHttpServletRequest(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.put(PRESTO_SOURCE, "testSource")
.put(PRESTO_CATALOG, "testCatalog")
.put(PRESTO_SCHEMA, "testSchema")
.build(),
"testRemote",
ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity));

HttpRequestSessionContext context = new HttpRequestSessionContext(request, new SqlParserOptions());
assertEquals(context.getAuthorizedIdentity(), Optional.of(authorizedIdentity));
}

protected static String urlEncode(String value)
{
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AllowAllAccessControl;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.sql.SqlEnvironmentConfig;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.google.common.collect.ImmutableListMultimap;
Expand Down Expand Up @@ -54,6 +55,7 @@
import static com.facebook.presto.server.TestHttpRequestSessionContext.createFunctionAdd;
import static com.facebook.presto.server.TestHttpRequestSessionContext.createSqlFunctionIdAdd;
import static com.facebook.presto.server.TestHttpRequestSessionContext.urlEncode;
import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE;
import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
Expand All @@ -64,6 +66,7 @@ public class TestQuerySessionSupplier
private static final SqlInvokedFunction SQL_FUNCTION_ADD = createFunctionAdd();
private static final String SERIALIZED_SQL_FUNCTION_ID_ADD = jsonCodec(SqlFunctionId.class).toJson(SQL_FUNCTION_ID_ADD);
private static final String SERIALIZED_SQL_FUNCTION_ADD = jsonCodec(SqlInvokedFunction.class).toJson(SQL_FUNCTION_ADD);
private static final AuthorizedIdentity AUTHORIZED_IDENTITY = new AuthorizedIdentity("userName", "reasonForSelect", false);

private static final HttpServletRequest TEST_REQUEST = new MockHttpServletRequest(
ImmutableListMultimap.<String, String>builder()
Expand All @@ -81,7 +84,7 @@ public class TestQuerySessionSupplier
.put(PRESTO_SESSION_FUNCTION, format("%s=%s", urlEncode(SERIALIZED_SQL_FUNCTION_ID_ADD), urlEncode(SERIALIZED_SQL_FUNCTION_ADD)))
.build(),
"testRemote",
ImmutableMap.of());
ImmutableMap.of(AUTHORIZED_IDENTITY_ATTRIBUTE, AUTHORIZED_IDENTITY));

@Test
public void testCreateSession()
Expand Down Expand Up @@ -123,6 +126,8 @@ public WarningCollector create(WarningHandlingLevel warningHandlingLevel)
.put("query2", "select * from bar")
.build());
assertEquals(session.getSessionFunctions(), ImmutableMap.of(SQL_FUNCTION_ID_ADD, SQL_FUNCTION_ADD));
assertEquals(session.getIdentity().getSelectedUser().get(), AUTHORIZED_IDENTITY.getUserName());
assertEquals(session.getIdentity().getReasonForSelect(), AUTHORIZED_IDENTITY.getReasonForSelect());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.security;

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.presto.server.MockHttpServletRequest;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Files;
import io.jsonwebtoken.Jwts;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

import javax.servlet.http.HttpServletRequest;

import java.io.IOException;
import java.nio.file.Path;
import java.security.Principal;

import static com.facebook.presto.server.security.ServletSecurityUtils.AUTHORIZED_IDENTITY_ATTRIBUTE;
import static com.facebook.presto.server.security.ServletSecurityUtils.authorizedIdentity;
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
import static com.google.common.io.Files.createTempDir;
import static com.google.common.io.MoreFiles.deleteRecursively;
import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static io.jsonwebtoken.JwsHeader.KEY_ID;
import static io.jsonwebtoken.SignatureAlgorithm.HS256;
import static io.jsonwebtoken.security.Keys.secretKeyFor;
import static java.nio.file.Files.readAllBytes;
import static java.util.Base64.getMimeDecoder;
import static java.util.Base64.getMimeEncoder;

public class TestJsonWebTokenAuthenticator
{
private static final String KEY_ID_FOO = "foo";
private static final String TEST_PRINCIPAL = "testPrincipal";

private Path temporaryDirectory;
private Path keyFile;
private JsonWebTokenConfig jsonWebTokenConfig;

@BeforeTest
public void setup()
throws IOException
{
temporaryDirectory = createTempDir().toPath();
keyFile = temporaryDirectory.resolve(KEY_ID_FOO + ".key");
byte[] key = getMimeEncoder().encode(secretKeyFor(HS256).getEncoded());
Files.write(key, keyFile.toFile());
jsonWebTokenConfig = new JsonWebTokenConfig().setKeyFile(keyFile.toAbsolutePath().toString());
}

@AfterTest(alwaysRun = true)
public void cleanup()
throws IOException
{
deleteRecursively(temporaryDirectory, ALLOW_INSECURE);
}

@Test
public void testJsonWebTokenWithAuthorizedUserClaim()
throws IOException, AuthenticationException
{
AuthorizedIdentity authorizedIdentity = new AuthorizedIdentity("user", "reasonForSelect", false);
String jsonWebToken = createJsonWebToken(keyFile, TEST_PRINCIPAL, authorizedIdentity);
HttpServletRequest request = new MockHttpServletRequest(
ImmutableListMultimap.of(AUTHORIZATION, "Bearer " + jsonWebToken),
"remoteAddress",
ImmutableMap.of());
Principal principal = new JsonWebTokenAuthenticator(jsonWebTokenConfig).authenticate(request);

assertEquals(principal.getName(), TEST_PRINCIPAL);
assertEquals(authorizedIdentity(request).get(), authorizedIdentity);
}

private static String createJsonWebToken(Path keyFile, String principal, AuthorizedIdentity authorizedIdentity)
throws IOException
{
byte[] key = getMimeDecoder().decode(readAllBytes(keyFile.toAbsolutePath()));
return Jwts.builder()
.signWith(HS256, key)
.setHeaderParam(KEY_ID, KEY_ID_FOO)
.setSubject(principal)
.claim(AUTHORIZED_IDENTITY_ATTRIBUTE, authorizedIdentity)
.compact();
}
}
Loading

0 comments on commit fd2615e

Please sign in to comment.