From f8a9e0134a79ae807a197f86e5b30a1ec7629d04 Mon Sep 17 00:00:00 2001 From: rmartinc Date: Tue, 12 Sep 2023 11:34:18 +0200 Subject: [PATCH] Ensure that the EncryptedKey is passed to the DecryptionKeyLocator for SAML Closes https://github.com/keycloak/keycloak/issues/22974 --- .../add/DescriptorSettings.tsx | 79 +++++++++++-------- .../core/util/XMLEncryptionUtil.java | 49 ++++++++---- .../saml/SAMLDecryptionKeysLocator.java | 1 + .../protocol/saml/SamlEncryptionTest.java | 54 +++++++++++-- 4 files changed, 129 insertions(+), 54 deletions(-) diff --git a/js/apps/admin-ui/src/identity-providers/add/DescriptorSettings.tsx b/js/apps/admin-ui/src/identity-providers/add/DescriptorSettings.tsx index 37c156a75c2c..34e90712fa82 100644 --- a/js/apps/admin-ui/src/identity-providers/add/DescriptorSettings.tsx +++ b/js/apps/admin-ui/src/identity-providers/add/DescriptorSettings.tsx @@ -49,6 +49,11 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => { name: "config.wantAuthnRequestsSigned", }); + const wantAssertionsEncrypted = useWatch({ + control, + name: "config.wantAssertionsEncrypted", + }); + const validateSignature = useWatch({ control, name: "config.validateSignature", @@ -376,41 +381,6 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => { )} > - - } - fieldId="kc-encryptionAlgorithm" - > - ( - - )} - > - { label="wantAssertionsEncrypted" isReadOnly={readOnly} /> + + {wantAssertionsEncrypted === "true" && ( + + } + fieldId="kc-encryptionAlgorithm" + > + ( + + )} + > + + )} + getKeys(EncryptedData encryptedData) { // Map keys to PrivateKey return keysStream .map(KeyWrapper::getPrivateKey) + .filter(Objects::nonNull) .map(Key::getEncoded) .map(encoded -> { try { diff --git a/services/src/test/java/org/keycloak/protocol/saml/SamlEncryptionTest.java b/services/src/test/java/org/keycloak/protocol/saml/SamlEncryptionTest.java index e6a5c7e22b46..7e62e88c2399 100644 --- a/services/src/test/java/org/keycloak/protocol/saml/SamlEncryptionTest.java +++ b/services/src/test/java/org/keycloak/protocol/saml/SamlEncryptionTest.java @@ -22,9 +22,12 @@ import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; +import java.util.Collections; +import java.util.function.Function; import javax.crypto.Cipher; import javax.crypto.NoSuchPaddingException; import org.apache.xml.security.encryption.XMLCipher; +import org.apache.xml.security.exceptions.XMLSecurityException; import org.apache.xml.security.utils.EncryptionConstants; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; @@ -36,17 +39,19 @@ import org.keycloak.dom.saml.v2.assertion.NameIDType; import org.keycloak.dom.saml.v2.protocol.ResponseType; import org.keycloak.models.KeycloakSession; -import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder; -import org.keycloak.protocol.saml.SAMLEncryptionAlgorithms; import org.keycloak.saml.SAML2LoginResponseBuilder; import org.keycloak.saml.SAMLRequestParser; +import org.keycloak.saml.common.constants.JBossSAMLConstants; import org.keycloak.saml.common.constants.JBossSAMLURIConstants; import org.keycloak.saml.common.util.DocumentUtil; import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder; import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil; +import org.keycloak.saml.processing.core.util.XMLEncryptionUtil; import org.keycloak.services.DefaultKeycloakSession; import org.keycloak.services.DefaultKeycloakSessionFactory; import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.NodeList; /** *

Simple test class that checks SAML encryption with different algorithms. @@ -56,8 +61,6 @@ */ public class SamlEncryptionTest { - private static final KeyPair rsaKeyPair; - static { try { KeyPairGenerator rsa = KeyPairGenerator.getInstance("RSA"); @@ -68,6 +71,17 @@ public class SamlEncryptionTest { } } + private static final KeyPair rsaKeyPair; + private static final XMLEncryptionUtil.DecryptionKeyLocator keyLocator = data -> { + try { + Assert.assertNotNull("EncryptedData does not contain KeyInfo", data.getKeyInfo()); + Assert.assertNotNull("EncryptedData does not contain EncryptedKey", data.getKeyInfo().itemEncryptedKey(0)); + return Collections.singletonList(rsaKeyPair.getPrivate()); + } catch (XMLSecurityException e) { + throw new IllegalArgumentException("EncryptedData does not contain KeyInfo ", e); + } + }; + @BeforeClass public static void beforeClass() { Cipher cipher = null; @@ -86,6 +100,11 @@ public static void beforeClass() { } private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf) throws Exception { + testEncryption(pair, alg, keySize, keyWrapAlg, keyWrapHashMethod, keyWrapMgf, Function.identity()); + } + + private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, + String keyWrapHashMethod, String keyWrapMgf, Function transformer) throws Exception { SAML2LoginResponseBuilder builder = new SAML2LoginResponseBuilder(); builder.requestID("requestId") .destination("http://localhost") @@ -120,18 +139,38 @@ private void testEncryption(KeyPair pair, String alg, int keySize, String keyWra Document samlDocument = builder.buildDocument(samlModel); bindingBuilder.postBinding(samlDocument); + samlDocument = transformer.apply(samlDocument); + String samlResponse = DocumentUtil.getDocumentAsString(samlDocument); SAMLDocumentHolder holder = SAMLRequestParser.parseResponseDocument(samlResponse.getBytes(StandardCharsets.UTF_8)); ResponseType responseType = (ResponseType) holder.getSamlObject(); Assert.assertTrue("Assertion is not encrypted", AssertionUtil.isAssertionEncrypted(responseType)); - AssertionType assertion = AssertionUtil.getAssertion(holder, responseType, pair.getPrivate()); + AssertionUtil.decryptAssertion(responseType, keyLocator); + AssertionType assertion = responseType.getAssertions().get(0).getAssertion(); Assert.assertEquals("issuer", assertion.getIssuer().getValue()); MatcherAssert.assertThat(assertion.getSubject().getSubType().getBaseID(), Matchers.instanceOf(NameIDType.class)); NameIDType nameId = (NameIDType) assertion.getSubject().getSubType().getBaseID(); Assert.assertEquals("nameId", nameId.getValue()); } + private Document moveEncryptedKeyToRetrievalMethod(Document doc) { + NodeList nodes = doc.getElementsByTagNameNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), JBossSAMLConstants.ENCRYPTED_KEY.get()); + Element encKey = (Element) nodes.item(0); + Element keyInfo = (Element) encKey.getParentNode(); + + // remove the encKey, insert into EncryptedAssertion and substitute it with a RetrievalMethod + keyInfo.removeChild(encKey); + encKey.setAttribute("Id", "encryption-key-123"); + keyInfo.getParentNode().getParentNode().appendChild(encKey); + Element retrievalMethod = doc.createElementNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), "xenc:RetrievalMethod"); + retrievalMethod.setAttribute("Type", "http://www.w3.org/2001/04/xmlenc#EncryptedKey"); + retrievalMethod.setAttribute("URI", "encryption-key-123"); + keyInfo.appendChild(retrievalMethod); + + return doc; + } + @Test public void testDefault() throws Exception { testEncryption(rsaKeyPair, null, -1, null, null, null); @@ -164,4 +203,9 @@ public void testKeyWrapsWithSha512() throws Exception { public void testRsaOaep11WithSha512AndMgfSha512() throws Exception { testEncryption(rsaKeyPair, "AES", 256, XMLCipher.RSA_OAEP_11, XMLCipher.SHA512, EncryptionConstants.MGF1_SHA512); } + + @Test + public void testEncryptionWithRetrievalMethod() throws Exception { + testEncryption(rsaKeyPair, null, -1, null, null, null, this::moveEncryptedKeyToRetrievalMethod); + } }