diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilder.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilder.java index 885e77d3eeec0..8051b08c78a86 100644 --- a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilder.java +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilder.java @@ -168,6 +168,10 @@ public SamlIdentityProvider build() throws ValidationException { ex.addValidationError("Service provider defaults must be specified"); } + if (allowedNameIdFormats == null || allowedNameIdFormats.isEmpty()) { + ex.addValidationError("At least 1 allowed NameID format must be specified"); + } + if (ex.validationErrors().isEmpty() == false) { throw ex; } @@ -260,6 +264,9 @@ public SamlIdentityProviderBuilder singleLogoutEndpoint(String binding, URL endp } public SamlIdentityProviderBuilder allowedNameIdFormat(String nameIdFormat) { + if (this.allowedNameIdFormats == null) { + this.allowedNameIdFormats = new HashSet<>(); + } this.allowedNameIdFormats.add(nameIdFormat); return this; } diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilderTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilderTests.java index 1ca0a955dd08a..e5f995204ac0c 100644 --- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilderTests.java +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/idp/SamlIdentityProviderBuilderTests.java @@ -21,8 +21,10 @@ import org.elasticsearch.xpack.idp.saml.test.IdpSamlTestCase; import org.hamcrest.Matchers; import org.mockito.Mockito; +import org.opensaml.saml.saml2.core.NameID; import org.opensaml.security.x509.X509Credential; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.security.PrivateKey; @@ -43,13 +45,16 @@ import static org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProviderBuilder.IDP_SLO_REDIRECT_ENDPOINT; import static org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProviderBuilder.IDP_SSO_POST_ENDPOINT; import static org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProviderBuilder.IDP_SSO_REDIRECT_ENDPOINT; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.opensaml.saml.common.xml.SAMLConstants.SAML2_POST_BINDING_URI; import static org.opensaml.saml.common.xml.SAMLConstants.SAML2_REDIRECT_BINDING_URI; +import static org.opensaml.saml.saml2.core.NameIDType.EMAIL; import static org.opensaml.saml.saml2.core.NameIDType.PERSISTENT; import static org.opensaml.saml.saml2.core.NameIDType.TRANSIENT; @@ -592,4 +597,39 @@ public void testCreateMetadataSigningCredentialFromKeyStoreWithMultipleEntriesBu ); } + public void testCreateViaMethodCalls() throws Exception { + final String entityId = randomAlphaOfLength(4) + ":" + randomAlphaOfLength(6) + "/" + randomAlphaOfLengthBetween(4, 12); + final URL redirectUrl = new URL( + randomFrom("http", "https") + + "://" + + String.join(".", randomArray(2, 5, String[]::new, () -> randomAlphaOfLengthBetween(3, 6))) + + "/" + + String.join("/", randomArray(1, 3, String[]::new, () -> randomAlphaOfLengthBetween(2, 4))) + ); + + final X509Credential credential = readCredentials("RSA", randomFrom(1024, 2048)); + final String nameIdFormat = randomFrom(NameID.TRANSIENT, PERSISTENT, EMAIL); + + final SamlServiceProviderResolver serviceResolver = Mockito.mock(SamlServiceProviderResolver.class); + final WildcardServiceProviderResolver wildcardResolver = Mockito.mock(WildcardServiceProviderResolver.class); + final ServiceProviderDefaults spDefaults = new ServiceProviderDefaults( + randomAlphaOfLength(2), + nameIdFormat, + Duration.ofMinutes(randomIntBetween(1, 10)) + ); + final SamlIdentityProvider idp = SamlIdentityProvider.builder(serviceResolver, wildcardResolver) + .entityId(entityId) + .singleSignOnEndpoint(SAML2_REDIRECT_BINDING_URI, redirectUrl) + .signingCredential(credential) + .serviceProviderDefaults(spDefaults) + .allowedNameIdFormat(nameIdFormat) + .build(); + + assertThat(idp.getEntityId(), is(entityId)); + assertThat(idp.getSingleSignOnEndpoint(SAML2_REDIRECT_BINDING_URI), is(redirectUrl)); + assertThat(idp.getSigningCredential(), is(credential)); + assertThat(idp.getServiceProviderDefaults(), is(spDefaults)); + assertThat(idp.getAllowedNameIdFormats(), contains(nameIdFormat)); + } + }