From 6a15f86400bc6db784efe311e1fafdd3cff94f69 Mon Sep 17 00:00:00 2001
From: Michael Auerswald <michael.auerswald@gmail.com>
Date: Tue, 4 Apr 2023 12:06:39 +0200
Subject: [PATCH] improve-saml-test-connection return

---
 packages/cli/src/sso/saml/constants.ts        |  6 ++--
 .../src/sso/saml/routes/saml.controller.ee.ts | 16 ++++++++---
 packages/cli/src/sso/saml/saml.service.ee.ts  | 28 +++++++++++--------
 .../cli/src/sso/saml/serviceProvider.ee.ts    |  5 ++++
 .../cli/src/sso/saml/types/samlPreferences.ts |  4 +++
 5 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/packages/cli/src/sso/saml/constants.ts b/packages/cli/src/sso/saml/constants.ts
index 3729f3ce51660..4b1bca0fce510 100644
--- a/packages/cli/src/sso/saml/constants.ts
+++ b/packages/cli/src/sso/saml/constants.ts
@@ -3,8 +3,6 @@ export class SamlUrls {
 
 	static readonly initSSO = '/initsso';
 
-	static readonly restInitSSO = this.samlRESTRoot + this.initSSO;
-
 	static readonly acs = '/acs';
 
 	static readonly restAcs = this.samlRESTRoot + this.acs;
@@ -17,9 +15,9 @@ export class SamlUrls {
 
 	static readonly configTest = '/config/test';
 
-	static readonly configToggleEnabled = '/config/toggle';
+	static readonly configTestReturn = '/config/test/return';
 
-	static readonly restConfig = this.samlRESTRoot + this.config;
+	static readonly configToggleEnabled = '/config/toggle';
 
 	static readonly defaultRedirect = '/';
 
diff --git a/packages/cli/src/sso/saml/routes/saml.controller.ee.ts b/packages/cli/src/sso/saml/routes/saml.controller.ee.ts
index a5ba82699bd0b..731d67f3753df 100644
--- a/packages/cli/src/sso/saml/routes/saml.controller.ee.ts
+++ b/packages/cli/src/sso/saml/routes/saml.controller.ee.ts
@@ -16,7 +16,11 @@ import type { PostBindingContext } from 'samlify/types/src/entity';
 import { isSamlLicensedAndEnabled } from '../samlHelpers';
 import type { SamlLoginBinding } from '../types';
 import { AuthenticatedRequest } from '@/requests';
-import { getServiceProviderEntityId, getServiceProviderReturnUrl } from '../serviceProvider.ee';
+import {
+	getServiceProviderConfigTestReturnUrl,
+	getServiceProviderEntityId,
+	getServiceProviderReturnUrl,
+} from '../serviceProvider.ee';
 
 @RestController('/sso/saml')
 export class SamlController {
@@ -100,6 +104,10 @@ export class SamlController {
 	private async acsHandler(req: express.Request, res: express.Response, binding: SamlLoginBinding) {
 		const loginResult = await this.samlService.handleSamlLogin(req, binding);
 		if (loginResult) {
+			// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
+			if (req.body.RelayState && req.body.RelayState === getServiceProviderConfigTestReturnUrl()) {
+				return res.status(202).send(loginResult.attributes);
+			}
 			if (loginResult.authenticatedUser) {
 				// Only sign in user if SAML is enabled, otherwise treat as test connection
 				if (isSamlLicensedAndEnabled()) {
@@ -134,11 +142,11 @@ export class SamlController {
 	 */
 	@Get(SamlUrls.configTest, { middlewares: [samlLicensedOwnerMiddleware] })
 	async configTestGet(req: AuthenticatedRequest, res: express.Response) {
-		return this.handleInitSSO(res);
+		return this.handleInitSSO(res, getServiceProviderConfigTestReturnUrl());
 	}
 
-	private async handleInitSSO(res: express.Response) {
-		const result = this.samlService.getLoginRequestUrl();
+	private async handleInitSSO(res: express.Response, relayState?: string) {
+		const result = this.samlService.getLoginRequestUrl(relayState);
 		if (result?.binding === 'redirect') {
 			return result.context.context;
 		} else if (result?.binding === 'post') {
diff --git a/packages/cli/src/sso/saml/saml.service.ee.ts b/packages/cli/src/sso/saml/saml.service.ee.ts
index 1bf8e36b9ffc2..436aae8e03404 100644
--- a/packages/cli/src/sso/saml/saml.service.ee.ts
+++ b/packages/cli/src/sso/saml/saml.service.ee.ts
@@ -20,12 +20,13 @@ import {
 	setSamlLoginLabel,
 	updateUserFromSamlAttributes,
 } from './samlHelpers';
-import type { Settings } from '../../databases/entities/Settings';
+import type { Settings } from '@/databases/entities/Settings';
 import axios from 'axios';
 import https from 'https';
 import type { SamlLoginBinding } from './types';
 import type { BindingContext, PostBindingContext } from 'samlify/types/src/entity';
 import { validateMetadata, validateResponse } from './samlValidator';
+import { getInstanceBaseUrl } from '@/UserManagement/UserManagementHelper';
 
 @Service()
 export class SamlService {
@@ -48,6 +49,7 @@ export class SamlService {
 		loginLabel: 'SAML',
 		wantAssertionsSigned: true,
 		wantMessageSigned: true,
+		relayState: getInstanceBaseUrl(),
 		signatureConfig: {
 			prefix: 'ds',
 			location: {
@@ -92,7 +94,10 @@ export class SamlService {
 		return getServiceProviderInstance(this._samlPreferences);
 	}
 
-	getLoginRequestUrl(binding?: SamlLoginBinding): {
+	getLoginRequestUrl(
+		relayState?: string,
+		binding?: SamlLoginBinding,
+	): {
 		binding: SamlLoginBinding;
 		context: BindingContext | PostBindingContext;
 	} {
@@ -100,28 +105,29 @@ export class SamlService {
 		if (binding === 'post') {
 			return {
 				binding,
-				context: this.getPostLoginRequestUrl(),
+				context: this.getPostLoginRequestUrl(relayState),
 			};
 		} else {
 			return {
 				binding,
-				context: this.getRedirectLoginRequestUrl(),
+				context: this.getRedirectLoginRequestUrl(relayState),
 			};
 		}
 	}
 
-	private getRedirectLoginRequestUrl(): BindingContext {
-		const loginRequest = this.getServiceProviderInstance().createLoginRequest(
-			this.getIdentityProviderInstance(),
-			'redirect',
-		);
+	private getRedirectLoginRequestUrl(relayState?: string): BindingContext {
+		const sp = this.getServiceProviderInstance();
+		sp.entitySetting.relayState = relayState ?? getInstanceBaseUrl();
+		const loginRequest = sp.createLoginRequest(this.getIdentityProviderInstance(), 'redirect');
 		//TODO:SAML: debug logging
 		LoggerProxy.debug(loginRequest.context);
 		return loginRequest;
 	}
 
-	private getPostLoginRequestUrl(): PostBindingContext {
-		const loginRequest = this.getServiceProviderInstance().createLoginRequest(
+	private getPostLoginRequestUrl(relayState?: string): PostBindingContext {
+		const sp = this.getServiceProviderInstance();
+		sp.entitySetting.relayState = relayState ?? getInstanceBaseUrl();
+		const loginRequest = sp.createLoginRequest(
 			this.getIdentityProviderInstance(),
 			'post',
 		) as PostBindingContext;
diff --git a/packages/cli/src/sso/saml/serviceProvider.ee.ts b/packages/cli/src/sso/saml/serviceProvider.ee.ts
index 5d992830120a0..f6d707eaf86be 100644
--- a/packages/cli/src/sso/saml/serviceProvider.ee.ts
+++ b/packages/cli/src/sso/saml/serviceProvider.ee.ts
@@ -15,6 +15,10 @@ export function getServiceProviderReturnUrl(): string {
 	return getInstanceBaseUrl() + SamlUrls.restAcs;
 }
 
+export function getServiceProviderConfigTestReturnUrl(): string {
+	return getInstanceBaseUrl() + SamlUrls.configTestReturn;
+}
+
 // TODO:SAML: make these configurable for the end user
 export function getServiceProviderInstance(prefs: SamlPreferences): ServiceProviderInstance {
 	if (serviceProviderInstance === undefined) {
@@ -24,6 +28,7 @@ export function getServiceProviderInstance(prefs: SamlPreferences): ServiceProvi
 			wantAssertionsSigned: prefs.wantAssertionsSigned,
 			wantMessageSigned: prefs.wantMessageSigned,
 			signatureConfig: prefs.signatureConfig,
+			relayState: prefs.relayState,
 			nameIDFormat: ['urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress'],
 			assertionConsumerService: [
 				{
diff --git a/packages/cli/src/sso/saml/types/samlPreferences.ts b/packages/cli/src/sso/saml/types/samlPreferences.ts
index c5c72bcd0f85e..da02f1ebc4cef 100644
--- a/packages/cli/src/sso/saml/types/samlPreferences.ts
+++ b/packages/cli/src/sso/saml/types/samlPreferences.ts
@@ -57,4 +57,8 @@ export class SamlPreferences {
 			action: 'after',
 		},
 	};
+
+	@IsString()
+	@IsOptional()
+	relayState?: string = '';
 }