Skip to content

Commit

Permalink
Leave only single startOAuth2Challenge method in OAuth2Service class
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo authored and kokosing committed Feb 8, 2022
1 parent 9b5bbd1 commit 4f39b2b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import javax.inject.Inject;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;

import java.io.IOException;
import java.net.URI;
Expand All @@ -59,7 +58,6 @@
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static io.trino.server.security.jwt.JwtUtil.newJwtBuilder;
import static io.trino.server.security.jwt.JwtUtil.newJwtParserBuilder;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOCATION;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -142,21 +140,7 @@ public OAuth2Service(OAuth2Client client, @ForOAuth2 SigningKeyResolver signingK
this.webUiOAuthEnabled = requireNonNull(webUiOAuthEnabled, "webUiOAuthEnabled is null").isPresent();
}

public Response startOAuth2Challenge(UriInfo uriInfo)
{
return startOAuth2Challenge(
uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT),
Optional.empty());
}

public Response startOAuth2Challenge(UriInfo uriInfo, String handlerState)
{
return startOAuth2Challenge(
uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT),
Optional.of(handlerState));
}

private Response startOAuth2Challenge(URI callbackUri, Optional<String> handlerState)
public Response startOAuth2Challenge(URI callbackUri, Optional<String> handlerState)
{
Instant challengeExpiration = now().plus(challengeTimeout);
String state = newJwtBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
import javax.ws.rs.core.UriInfo;

import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.security.oauth2.OAuth2TokenExchange.MAX_POLL_TIME;
import static io.trino.server.security.oauth2.OAuth2TokenExchange.hashAuthId;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -74,7 +76,7 @@ public OAuth2TokenExchangeResource(OAuth2TokenExchange tokenExchange, OAuth2Serv
@Produces(MediaType.APPLICATION_JSON)
public Response initiateTokenExchange(@PathParam("authIdHash") String authIdHash, @Context UriInfo uriInfo)
{
return service.startOAuth2Challenge(uriInfo, authIdHash);
return service.startOAuth2Challenge(uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), Optional.ofNullable(authIdHash));
}

@ResourceSecurity(PUBLIC)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static io.trino.server.ServletSecurityUtils.sendErrorMessage;
import static io.trino.server.ServletSecurityUtils.sendWwwAuthenticate;
import static io.trino.server.ServletSecurityUtils.setAuthenticatedIdentity;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.DISABLED_LOCATION;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.DISABLED_LOCATION_URI;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.TRINO_FORM_LOGIN;
Expand Down Expand Up @@ -137,7 +138,9 @@ private void needAuthentication(ContainerRequestContext request)
sendWwwAuthenticate(request, "Unauthorized", ImmutableSet.of(TRINO_FORM_LOGIN));
return;
}
request.abortWith(service.startOAuth2Challenge(request.getUriInfo()));
request.abortWith(service.startOAuth2Challenge(
request.getUriInfo().getBaseUri().resolve(CALLBACK_ENDPOINT),
Optional.empty()));
}

private static boolean isValidPrincipal(Object principal)
Expand Down

0 comments on commit 4f39b2b

Please sign in to comment.