diff --git a/packages/cli/src/sso/saml/routes/saml.controller.ee.ts b/packages/cli/src/sso/saml/routes/saml.controller.ee.ts index babb4cde9eab2..5ee636c8c50fa 100644 --- a/packages/cli/src/sso/saml/routes/saml.controller.ee.ts +++ b/packages/cli/src/sso/saml/routes/saml.controller.ee.ts @@ -13,7 +13,7 @@ import { getInitSSOFormView } from '../views/initSsoPost'; import { issueCookie } from '@/auth/jwt'; import { validate } from 'class-validator'; import type { PostBindingContext } from 'samlify/types/src/entity'; -import { isSamlLicensedAndEnabled } from '../samlHelpers'; +import { isConnectionTestRequest, isSamlLicensedAndEnabled } from '../samlHelpers'; import type { SamlLoginBinding } from '../types'; import { AuthenticatedRequest } from '@/requests'; import { @@ -111,7 +111,7 @@ export class SamlController { try { const loginResult = await this.samlService.handleSamlLogin(req, binding); // if RelayState is set to the test connection Url, this is a test connection - if (req.body.RelayState && req.body.RelayState === getServiceProviderConfigTestReturnUrl()) { + if (isConnectionTestRequest(req)) { if (loginResult.authenticatedUser) { return res.send(getSamlConnectionTestSuccessView(loginResult.attributes)); } else { @@ -133,7 +133,7 @@ export class SamlController { } throw new AuthError('SAML Authentication failed'); } catch (err) { - if (req.body.RelayState && req.body.RelayState === getServiceProviderConfigTestReturnUrl()) { + if (isConnectionTestRequest(req)) { return res.send(getSamlConnectionTestFailedView(err.message)); } throw new AuthError('SAML Authentication failed: ' + err.message); diff --git a/packages/cli/src/sso/saml/samlHelpers.ts b/packages/cli/src/sso/saml/samlHelpers.ts index 55efbedd6c839..4ff9970b29eed 100644 --- a/packages/cli/src/sso/saml/samlHelpers.ts +++ b/packages/cli/src/sso/saml/samlHelpers.ts @@ -18,6 +18,8 @@ import { isSamlCurrentAuthenticationMethod, setCurrentAuthenticationMethod, } from '../ssoHelpers'; +import { getServiceProviderConfigTestReturnUrl } from './serviceProvider.ee'; +import { SamlConfiguration } from './types/requests'; /** * Check whether the SAML feature is licensed and enabled in the instance */ @@ -173,3 +175,7 @@ export function getMappedSamlAttributesFromFlowResult( } return result; } + +export function isConnectionTestRequest(req: SamlConfiguration.AcsRequest): boolean { + return req.body.RelayState === getServiceProviderConfigTestReturnUrl(); +}