Skip to content

Commit

Permalink
Issue #8216 - improve testing for end_session_endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Lachlan Roberts <[email protected]>
  • Loading branch information
lachlan-roberts committed Jul 12, 2022
1 parent 26732c9 commit 90fe562
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ public void logout(ServletRequest request)
session.removeAttribute(SessionAuthentication.__J_AUTHENTICATED);
session.removeAttribute(CLAIMS);
session.removeAttribute(RESPONSE);
session.removeAttribute(ISSUER);
}
}

Expand All @@ -269,30 +270,33 @@ private void attemptLogoutRedirect(ServletRequest request)
{
Request baseRequest = Objects.requireNonNull(Request.getBaseRequest(request));
Response baseResponse = baseRequest.getResponse();
String endSessionEndpoint = _openIdConfiguration.getEndSessionEndpoint();
if (endSessionEndpoint == null)
return;

StringBuilder redirectUri = new StringBuilder(128);
URIUtil.appendSchemeHostPort(redirectUri, request.getScheme(), request.getServerName(), request.getServerPort());
redirectUri.append(baseRequest.getContextPath());
redirectUri.append(_logoutRedirectPath);

String endSessionEndpoint = _openIdConfiguration.getEndSessionEndpoint();
HttpSession session = baseRequest.getSession(false);
if (session == null)
if (endSessionEndpoint == null || session == null)
{
baseResponse.sendRedirect(redirectUri.toString(), true);
return;
}

Object openIdResponse = session.getAttribute(OpenIdAuthenticator.RESPONSE);
if (openIdResponse instanceof Map)
if (!(openIdResponse instanceof Map))
{
@SuppressWarnings("rawtypes")
String idToken = (String)((Map)openIdResponse).get("id_token");

baseResponse.sendRedirect(endSessionEndpoint +
"?id_token_hint=" + UrlEncoded.encodeString(idToken, StandardCharsets.UTF_8) +
"&post_logout_redirect_uri=" + UrlEncoded.encodeString(redirectUri.toString(), StandardCharsets.UTF_8),
true);
baseResponse.sendRedirect(redirectUri.toString(), true);
return;
}

@SuppressWarnings("rawtypes")
String idToken = (String)((Map)openIdResponse).get("id_token");
baseResponse.sendRedirect(endSessionEndpoint +
"?id_token_hint=" + UrlEncoded.encodeString(idToken, StandardCharsets.UTF_8) +
"&post_logout_redirect_uri=" + UrlEncoded.encodeString(redirectUri.toString(), StandardCharsets.UTF_8),
true);
}
catch (Throwable t)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class OpenIdConfiguration extends ContainerLifeCycle
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTHORIZATION_ENDPOINT = "authorization_endpoint";
private static final String TOKEN_ENDPOINT = "token_endpoint";
private static final String END_SESSION_ENDPOINT = "end_session_endpoint";
private static final String ISSUER = "issuer";

private final HttpClient httpClient;
Expand Down Expand Up @@ -164,11 +165,11 @@ protected void processMetadata(Map<String, Object> discoveryDocument)
tokenEndpoint = (String)discoveryDocument.get(TOKEN_ENDPOINT);
if (tokenEndpoint == null)
throw new IllegalStateException(TOKEN_ENDPOINT);
endSessionEndpoint = (String)discoveryDocument.get("end_session_endpoint");

// End session endpoint is optional.
if (endSessionEndpoint == null)
throw new IllegalArgumentException("end_session_endpoint");
endSessionEndpoint = (String)discoveryDocument.get(END_SESSION_ENDPOINT);

// We are lenient and not throw here as some major OIDC providers do not conform to this.
if (!Objects.equals(discoveryDocument.get(ISSUER), issuer))
LOG.warn("The issuer in the metadata is not correct.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -35,6 +36,7 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;

Expand Down Expand Up @@ -155,6 +157,11 @@ public void testLoginLogout() throws Exception
assertThat(response.getStatus(), is(HttpStatus.OK_200));
content = response.getContentAsString();
assertThat(content, containsString("not authenticated"));

// Test that the user was logged out successfully on the openid provider.
assertThat(openIdProvider.getLoggedInUsers().getCurrent(), equalTo(0L));
assertThat(openIdProvider.getLoggedInUsers().getMax(), equalTo(1L));
assertThat(openIdProvider.getLoggedInUsers().getTotal(), equalTo(1L));
}

public static class LoginPage extends HttpServlet
Expand All @@ -171,10 +178,9 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t
public static class LogoutPage extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
{
request.getSession().invalidate();
response.sendRedirect("/");
request.logout();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -48,6 +49,7 @@ public class OpenIdProvider extends ContainerLifeCycle
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();

protected final String clientId;
Expand All @@ -58,6 +60,7 @@ public class OpenIdProvider extends ContainerLifeCycle
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();

public static void main(String[] args) throws Exception
{
Expand Down Expand Up @@ -91,9 +94,10 @@ public OpenIdProvider(String clientId, String clientSecret)

ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new OpenIdConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new OpenIdAuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new OpenIdTokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);

addBean(server);
Expand All @@ -112,6 +116,11 @@ public OpenIdConfiguration getOpenIdConfiguration()
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}

public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}

@Override
protected void doStart() throws Exception
{
Expand Down Expand Up @@ -144,7 +153,7 @@ public void addRedirectUri(String uri)
redirectUris.add(uri);
}

public class OpenIdAuthEndpoint extends HttpServlet
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
Expand Down Expand Up @@ -252,7 +261,7 @@ public void redirectUser(HttpServletRequest request, User user, String redirectU
}
}

public class OpenIdTokenEndpoint extends HttpServlet
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
Expand Down Expand Up @@ -285,12 +294,44 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S
"\"token_type\": \"Bearer\"" +
"}";

loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}

public class OpenIdConfigServlet extends HttpServlet
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}

String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no post_logout_redirect_uri");
return;
}

loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}

private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
Expand All @@ -299,6 +340,7 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IO
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";

resp.getWriter().write(discoveryDocument);
Expand Down Expand Up @@ -336,5 +378,13 @@ public String getIdToken(String provider, String clientId)
long expiry = System.currentTimeMillis() + Duration.ofMinutes(1).toMillis();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiry);
}

@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
}
}

0 comments on commit 90fe562

Please sign in to comment.