diff --git a/saml-authentication-server/src/main/java/jetbrains/buildServer/auth/saml/plugin/SamlAuthenticationScheme.java b/saml-authentication-server/src/main/java/jetbrains/buildServer/auth/saml/plugin/SamlAuthenticationScheme.java index 979fae4..dc7dfa4 100755 --- a/saml-authentication-server/src/main/java/jetbrains/buildServer/auth/saml/plugin/SamlAuthenticationScheme.java +++ b/saml-authentication-server/src/main/java/jetbrains/buildServer/auth/saml/plugin/SamlAuthenticationScheme.java @@ -33,6 +33,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; import javax.xml.xpath.XPathException; import java.io.IOException; import java.net.MalformedURLException; @@ -88,6 +89,13 @@ public boolean isMultipleInstancesAllowed() { public void sendAuthnRequest(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response) throws IOException, SettingsException { var samlSettings = buildSettings(); var auth = new Auth(samlSettings, request, response); + if (request.getSession() != null) { + Object urlKey = request.getSession().getAttribute("URL_KEY"); + if (urlKey instanceof String) { + auth.login((String)urlKey); + return; + } + } auth.login(); } @@ -106,6 +114,7 @@ public HttpAuthenticationResult processAuthenticationRequest(@NotNull HttpServle LOG.debug(String.format("SAML: incoming authentication request %s %s",request.getMethod(), request.getRequestURL())); var saml = request.getParameter(SamlPluginConstants.SAML_RESPONSE_REQUEST_PARAMETER); + var relayState = request.getParameter("RelayState"); if (StringUtil.isEmpty(saml)) { LOG.debug(String.format("%s parameter not found - returning N/A", SamlPluginConstants.SAML_RESPONSE_REQUEST_PARAMETER)); @@ -174,15 +183,30 @@ public HttpAuthenticationResult processAuthenticationRequest(@NotNull HttpServle LOG.info(String.format("SAML request authenticated for user %s/%s", user.getUsername(), user.getName())); - return HttpAuthenticationResult.authenticated( - new ServerPrincipal(user.getRealm(), user.getUsername(), null, settings.isCreateUsersAutomatically(), new HashMap<>()), - true).withRedirect(request.getContextPath() + "/"); + return authenticated(request, settings, user, relayState); } catch (Exception e) { LOG.error(e); return sendUnauthorizedRequest(request, response, String.format("Failed to authenticate request: %s", e.getMessage())); } } + private static String getRedirectUrl(HttpServletRequest request) { + HttpSession session = request.getSession(); + if (session == null) { + return request.getContextPath() + "/"; + } + String url = (String) session.getAttribute("URL_KEY"); + session.removeAttribute("URL_KEY"); + return url != null ? url : request.getContextPath() + "/"; + } + + @NotNull + private static HttpAuthenticationResult authenticated(@NotNull HttpServletRequest request, SamlPluginSettings settings, SUser user, String relayState) { + return HttpAuthenticationResult.authenticated( + new ServerPrincipal(user.getRealm(), user.getUsername(), null, settings.isCreateUsersAutomatically(), new HashMap<>()), + true).withRedirect(relayState != null ? relayState : getRedirectUrl(request)); + } + @NotNull private String getAttribute(@NotNull Auth saml, @NotNull SamlAttributeMappingSettings attributeMappingSettings) { switch (attributeMappingSettings.getMappingType()) {