Skip to content

Commit

Permalink
Ensure that the EncryptedKey is passed to the DecryptionKeyLocator fo…
Browse files Browse the repository at this point in the history
…r SAML

Closes keycloak#22974
  • Loading branch information
rmartinc authored and mposolda committed Sep 20, 2023
1 parent 48e4e97 commit f8a9e01
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 54 deletions.
79 changes: 44 additions & 35 deletions js/apps/admin-ui/src/identity-providers/add/DescriptorSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
name: "config.wantAuthnRequestsSigned",
});

const wantAssertionsEncrypted = useWatch({
control,
name: "config.wantAssertionsEncrypted",
});

const validateSignature = useWatch({
control,
name: "config.validateSignature",
Expand Down Expand Up @@ -376,41 +381,6 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
)}
></Controller>
</FormGroup>
<FormGroup
label={t("encryptionAlgorithm")}
labelIcon={
<HelpItem
helpText={t("encryptionAlgorithmHelp")}
fieldLabelId="identity-provider:encryptionAlgorithm"
/>
}
fieldId="kc-encryptionAlgorithm"
>
<Controller
name="config.encryptionAlgorithm"
defaultValue="RSA-OAEP"
control={control}
render={({ field }) => (
<Select
toggleId="kc-encryptionAlgorithm"
onToggle={(isExpanded) =>
setEncryptionAlgorithmDropdownOpen(isExpanded)
}
isOpen={encryptionAlgorithmDropdownOpen}
onSelect={(_, value) => {
field.onChange(value.toString());
setEncryptionAlgorithmDropdownOpen(false);
}}
selections={field.value}
variant={SelectVariant.single}
isDisabled={readOnly}
>
<SelectOption value="RSA-OAEP" />
<SelectOption value="RSA1_5" />
</Select>
)}
></Controller>
</FormGroup>
<FormGroup
label={t("samlSignatureKeyName")}
labelIcon={
Expand Down Expand Up @@ -461,6 +431,45 @@ const Fields = ({ readOnly }: DescriptorSettingsProps) => {
label="wantAssertionsEncrypted"
isReadOnly={readOnly}
/>

{wantAssertionsEncrypted === "true" && (
<FormGroup
label={t("encryptionAlgorithm")}
labelIcon={
<HelpItem
helpText={t("encryptionAlgorithmHelp")}
fieldLabelId="encryptionAlgorithm"
/>
}
fieldId="kc-encryptionAlgorithm"
>
<Controller
name="config.encryptionAlgorithm"
defaultValue="RSA-OAEP"
control={control}
render={({ field }) => (
<Select
toggleId="kc-encryptionAlgorithm"
onToggle={(isExpanded) =>
setEncryptionAlgorithmDropdownOpen(isExpanded)
}
isOpen={encryptionAlgorithmDropdownOpen}
onSelect={(_, value) => {
field.onChange(value.toString());
setEncryptionAlgorithmDropdownOpen(false);
}}
selections={field.value}
variant={SelectVariant.single}
isDisabled={readOnly}
>
<SelectOption value="RSA-OAEP" />
<SelectOption value="RSA1_5" />
</Select>
)}
></Controller>
</FormGroup>
)}

<SwitchField
field="config.forceAuthn"
label="forceAuthentication"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.xml.security.encryption.EncryptedKey;
import org.apache.xml.security.encryption.XMLCipher;
import org.apache.xml.security.encryption.XMLEncryptionException;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.utils.EncryptionConstants;

import org.keycloak.saml.common.PicketLinkLogger;
Expand Down Expand Up @@ -252,27 +253,25 @@ public static Element decryptElementInDocument(Document documentWithEncryptedEle
if (encDataElement == null)
throw logger.domMissingElementError("No element representing the encrypted data found");

// Look at siblings for the key
Element encKeyElement = getNextElementNode(encDataElement.getNextSibling());
if (encKeyElement == null) {
// Search the enc data element for enc key
NodeList nodeList = encDataElement.getElementsByTagNameNS(EncryptionConstants.EncryptionSpecNS, EncryptionConstants._TAG_ENCRYPTEDKEY);

if (nodeList == null || nodeList.getLength() == 0)
throw logger.nullValueError("Encrypted Key not found in the enc data");

encKeyElement = (Element) nodeList.item(0);
}

XMLCipher cipher;
EncryptedData encryptedData;
EncryptedKey encryptedKey;
try {
cipher = XMLCipher.getInstance();
cipher.init(XMLCipher.DECRYPT_MODE, null);
encryptedData = cipher.loadEncryptedData(documentWithEncryptedElement, encDataElement);
encryptedKey = cipher.loadEncryptedKey(documentWithEncryptedElement, encKeyElement);
} catch (XMLEncryptionException e1) {
if (encryptedData.getKeyInfo() == null) {
throw logger.domMissingElementError("No element representing KeyInfo found in the EncryptedData");
}

encryptedKey = encryptedData.getKeyInfo().itemEncryptedKey(0);
if (encryptedKey == null) {
// the encrypted key is not inside the encrypted data, locate it
Element encKeyElement = locateEncryptedKeyElement(encDataElement);
encryptedKey = cipher.loadEncryptedKey(documentWithEncryptedElement, encKeyElement);
encryptedData.getKeyInfo().add(encryptedKey);
}
} catch (XMLSecurityException e1) {
throw logger.processingError(e1);
}

Expand Down Expand Up @@ -325,6 +324,28 @@ public static Element decryptElementInDocument(Document documentWithEncryptedEle
return decryptedDoc.getDocumentElement();
}

/**
* Locates the EncryptedKey element once the EncryptedData element is found.
* A exception is thrown if not found.
*
* @param encDataElement The EncryptedData element found
* @return The EncryptedKey element
*/
private static Element locateEncryptedKeyElement(Element encDataElement) {
// Look at siblings for the key
Element encKeyElement = getNextElementNode(encDataElement.getNextSibling());
if (encKeyElement == null) {
// Search the enc data element for enc key
NodeList nodeList = encDataElement.getElementsByTagNameNS(EncryptionConstants.EncryptionSpecNS, EncryptionConstants._TAG_ENCRYPTEDKEY);

if (nodeList == null || nodeList.getLength() == 0)
throw logger.nullValueError("Encrypted Key not found in the enc data");

encKeyElement = (Element) nodeList.item(0);
}
return encKeyElement;
}

/**
* From the secret key, get the W3C XML Encryption URL
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ public List<PrivateKey> getKeys(EncryptedData encryptedData) {
// Map keys to PrivateKey
return keysStream
.map(KeyWrapper::getPrivateKey)
.filter(Objects::nonNull)
.map(Key::getEncoded)
.map(encoded -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import org.apache.xml.security.encryption.XMLCipher;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.utils.EncryptionConstants;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
Expand All @@ -36,17 +39,19 @@
import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.saml.JaxrsSAML2BindingBuilder;
import org.keycloak.protocol.saml.SAMLEncryptionAlgorithms;
import org.keycloak.saml.SAML2LoginResponseBuilder;
import org.keycloak.saml.SAMLRequestParser;
import org.keycloak.saml.common.constants.JBossSAMLConstants;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.AssertionUtil;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.services.DefaultKeycloakSession;
import org.keycloak.services.DefaultKeycloakSessionFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/**
* <p>Simple test class that checks SAML encryption with different algorithms.
Expand All @@ -56,8 +61,6 @@
*/
public class SamlEncryptionTest {

private static final KeyPair rsaKeyPair;

static {
try {
KeyPairGenerator rsa = KeyPairGenerator.getInstance("RSA");
Expand All @@ -68,6 +71,17 @@ public class SamlEncryptionTest {
}
}

private static final KeyPair rsaKeyPair;
private static final XMLEncryptionUtil.DecryptionKeyLocator keyLocator = data -> {
try {
Assert.assertNotNull("EncryptedData does not contain KeyInfo", data.getKeyInfo());
Assert.assertNotNull("EncryptedData does not contain EncryptedKey", data.getKeyInfo().itemEncryptedKey(0));
return Collections.singletonList(rsaKeyPair.getPrivate());
} catch (XMLSecurityException e) {
throw new IllegalArgumentException("EncryptedData does not contain KeyInfo ", e);
}
};

@BeforeClass
public static void beforeClass() {
Cipher cipher = null;
Expand All @@ -86,6 +100,11 @@ public static void beforeClass() {
}

private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg, String keyWrapHashMethod, String keyWrapMgf) throws Exception {
testEncryption(pair, alg, keySize, keyWrapAlg, keyWrapHashMethod, keyWrapMgf, Function.identity());
}

private void testEncryption(KeyPair pair, String alg, int keySize, String keyWrapAlg,
String keyWrapHashMethod, String keyWrapMgf, Function<Document,Document> transformer) throws Exception {
SAML2LoginResponseBuilder builder = new SAML2LoginResponseBuilder();
builder.requestID("requestId")
.destination("http://localhost")
Expand Down Expand Up @@ -120,18 +139,38 @@ private void testEncryption(KeyPair pair, String alg, int keySize, String keyWra
Document samlDocument = builder.buildDocument(samlModel);
bindingBuilder.postBinding(samlDocument);

samlDocument = transformer.apply(samlDocument);

String samlResponse = DocumentUtil.getDocumentAsString(samlDocument);

SAMLDocumentHolder holder = SAMLRequestParser.parseResponseDocument(samlResponse.getBytes(StandardCharsets.UTF_8));
ResponseType responseType = (ResponseType) holder.getSamlObject();
Assert.assertTrue("Assertion is not encrypted", AssertionUtil.isAssertionEncrypted(responseType));
AssertionType assertion = AssertionUtil.getAssertion(holder, responseType, pair.getPrivate());
AssertionUtil.decryptAssertion(responseType, keyLocator);
AssertionType assertion = responseType.getAssertions().get(0).getAssertion();
Assert.assertEquals("issuer", assertion.getIssuer().getValue());
MatcherAssert.assertThat(assertion.getSubject().getSubType().getBaseID(), Matchers.instanceOf(NameIDType.class));
NameIDType nameId = (NameIDType) assertion.getSubject().getSubType().getBaseID();
Assert.assertEquals("nameId", nameId.getValue());
}

private Document moveEncryptedKeyToRetrievalMethod(Document doc) {
NodeList nodes = doc.getElementsByTagNameNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), JBossSAMLConstants.ENCRYPTED_KEY.get());
Element encKey = (Element) nodes.item(0);
Element keyInfo = (Element) encKey.getParentNode();

// remove the encKey, insert into EncryptedAssertion and substitute it with a RetrievalMethod
keyInfo.removeChild(encKey);
encKey.setAttribute("Id", "encryption-key-123");
keyInfo.getParentNode().getParentNode().appendChild(encKey);
Element retrievalMethod = doc.createElementNS(JBossSAMLURIConstants.XMLENC_NSURI.get(), "xenc:RetrievalMethod");
retrievalMethod.setAttribute("Type", "http://www.w3.org/2001/04/xmlenc#EncryptedKey");
retrievalMethod.setAttribute("URI", "encryption-key-123");
keyInfo.appendChild(retrievalMethod);

return doc;
}

@Test
public void testDefault() throws Exception {
testEncryption(rsaKeyPair, null, -1, null, null, null);
Expand Down Expand Up @@ -164,4 +203,9 @@ public void testKeyWrapsWithSha512() throws Exception {
public void testRsaOaep11WithSha512AndMgfSha512() throws Exception {
testEncryption(rsaKeyPair, "AES", 256, XMLCipher.RSA_OAEP_11, XMLCipher.SHA512, EncryptionConstants.MGF1_SHA512);
}

@Test
public void testEncryptionWithRetrievalMethod() throws Exception {
testEncryption(rsaKeyPair, null, -1, null, null, null, this::moveEncryptedKeyToRetrievalMethod);
}
}

0 comments on commit f8a9e01

Please sign in to comment.