diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java index a40ac25b972..8d2409d0d5b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java @@ -88,7 +88,12 @@ private OpenSamlInitializationService() { * @throws Saml2Exception if OpenSAML failed to initialize */ public static boolean initialize() { - return initialize((registry) -> { + return initialize(true, (registry) -> { + }); + } + + public static boolean initializedAlready() { + return initialize(false, (registry) -> { }); } @@ -104,19 +109,25 @@ public static boolean initialize() { * failed to initialize */ public static void requireInitialize(Consumer registryConsumer) { - if (!initialize(registryConsumer)) { + if (!initialize(true, registryConsumer)) { throw new Saml2Exception("OpenSAML was already initialized previously"); } } - private static boolean initialize(Consumer registryConsumer) { + private static boolean initialize(boolean initOpenSaml, Consumer registryConsumer) { if (initialized.compareAndSet(false, true)) { log.trace("Initializing OpenSAML"); - try { - InitializationService.initialize(); - } - catch (Exception ex) { - throw new Saml2Exception(ex); + if (initOpenSaml) { + try { + InitializationService.initialize(); + } catch (Exception ex) { + throw new Saml2Exception(ex); + } + } else { + if (ConfigurationService.get(XMLObjectProviderRegistry.class) == null) { + log.debug("OpenSAML not ready"); + return false; + } } registryConsumer.accept(ConfigurationService.get(XMLObjectProviderRegistry.class)); log.debug("Initialized OpenSAML"); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java index b5d56408076..9a8b45fbbaa 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java @@ -31,17 +31,34 @@ * * @author Josh Cummings */ -public class OpenSamlInitializationServiceTests { +class OpenSamlInitializationServiceTests { @Test - public void initializeWhenInvokedMultipleTimesThenInitializesOnce() { + void initializeWhenInvokedMultipleTimesThenInitializesOnce() { OpenSamlInitializationService.initialize(); XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); assertThat(registry.getBuilderFactory().getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME)).isNotNull(); assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> OpenSamlInitializationService.requireInitialize((r) -> { + .isThrownBy(() -> OpenSamlInitializationService.requireInitialize(r -> { })) .withMessageContaining("OpenSAML was already initialized previously"); } + @Test + void initializedAlreadyWhenInitializedThenReturnsTrue() { + Saml2Utils.fipsCompliantOpenSamlInit(); + assertThat(OpenSamlInitializationService.initializedAlready()).isIn(true, false); + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + assertThat(registry.getBuilderFactory().getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME)).isNotNull(); + } + + @Test + void initializedAlreadyWhenInitializedThenReturnsBuildIsNull() { + if (OpenSamlInitializationService.initializedAlready()) { + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + assertThat(registry.getBuilderFactory().getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME)).isNotNull(); + } else { + assertThat(ConfigurationService.get(XMLObjectProviderRegistry.class)).isNull(); + } + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java index e512fa5a99e..ff33d0c26aa 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java @@ -20,13 +20,23 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.Properties; +import java.util.ServiceLoader; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; +import org.opensaml.core.config.ConfigurationService; +import org.opensaml.core.config.InitializationException; +import org.opensaml.core.config.Initializer; +import org.opensaml.core.config.provider.PropertiesAdapter; +import org.opensaml.security.config.GlobalNamedCurveRegistryInitializer; +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; import org.springframework.security.saml2.Saml2Exception; +import static org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap.CONFIG_PROPERTY_ECDH_DEFAULT_KDF; + public final class Saml2Utils { private Saml2Utils() { @@ -67,4 +77,21 @@ public static String samlInflate(byte[] b) { } } + public static void fipsCompliantOpenSamlInit() { + Properties props = new Properties(); + props.setProperty(CONFIG_PROPERTY_ECDH_DEFAULT_KDF, DefaultSecurityConfigurationBootstrap.PBKDF2); + ConfigurationService.setDefaultConfigurationPropertiesSource(() -> new PropertiesAdapter(props)); + Class toSkip = GlobalNamedCurveRegistryInitializer.class; + ServiceLoader.load(Initializer.class).stream() + .filter(provider -> provider.type() != toSkip) + .forEach(Saml2Utils::init); + } + + private static void init(ServiceLoader.Provider provider) { + try { + provider.get().init(); + } catch (InitializationException ex) { + throw new Saml2Exception(ex); + } + } }