From 72cb583f0bda4d7c41972c92ead0598af489c3e8 Mon Sep 17 00:00:00 2001 From: zibet27 Date: Mon, 2 Mar 2026 22:01:07 +0100 Subject: [PATCH 1/3] Server SAML Auth. SSO authentication flow. --- .../api/ktor-server-auth-saml.api | 45 ++ .../src/io/ktor/server/auth/saml/SamlAuth.kt | 99 +++ .../auth/saml/SamlAuthenticationProvider.kt | 345 ++++++++++ .../server/auth/saml/SamlRequestBuilder.kt | 230 +++++++ .../server/auth/saml/SamlResponseProcessor.kt | 243 +++++++ .../auth/saml/SamlSignatureValidator.kt | 131 ++++ .../src/io/ktor/server/auth/saml/SamlUtils.kt | 196 ++++++ .../auth/saml/RelayStateValidationTest.kt | 129 ++++ .../server/auth/saml/SAMLPrincipalTest.kt | 61 ++ .../auth/saml/SAMLRequestBuilderTest.kt | 99 +++ .../auth/saml/SAMLResponseProcessorTest.kt | 641 ++++++++++++++++++ .../io/ktor/server/auth/saml/SamlAuthTest.kt | 539 +++++++++++++++ .../test/io/ktor/server/auth/saml/TestUtil.kt | 355 ++++++++++ 13 files changed, 3113 insertions(+) create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlRequestBuilder.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLPrincipalTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLRequestBuilderTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api index a13ca488d41..9f150ad2d22 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api @@ -63,6 +63,17 @@ public final class io/ktor/server/auth/saml/SamlAlgorithms { public final fun getRECOMMENDED_SIGNATURE_ALGORITHMS ()Ljava/util/Set; } +public final class io/ktor/server/auth/saml/SamlAuthKt { + public static final fun saml (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V + public static final fun saml (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun saml$default (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun saml$default (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V +} + +public final class io/ktor/server/auth/saml/SamlAuthenticationProvider : io/ktor/server/auth/AuthenticationProvider { + public fun onAuthenticate (Lio/ktor/server/auth/AuthenticationContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/ktor/server/auth/saml/SamlAuthnContext { public static final field Companion Lio/ktor/server/auth/saml/SamlAuthnContext$Companion; public static final synthetic fun box-impl (Ljava/lang/String;)Lio/ktor/server/auth/saml/SamlAuthnContext; @@ -189,12 +200,41 @@ public final class io/ktor/server/auth/saml/SamlPrincipal { public final fun hasAttribute (Ljava/lang/String;)Z } +public final class io/ktor/server/auth/saml/SamlRedirectResult { + public fun (Ljava/lang/String;Ljava/lang/String;)V + public final fun getMessageId ()Ljava/lang/String; + public final fun getRedirectUrl ()Ljava/lang/String; +} + public abstract interface class io/ktor/server/auth/saml/SamlReplayCache : java/lang/AutoCloseable { public abstract fun isReplayed (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun recordAssertion (Ljava/lang/String;Lkotlin/time/Instant;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun tryRecordAssertion (Ljava/lang/String;Lkotlin/time/Instant;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/ktor/server/auth/saml/SamlSession { + public static final field Companion Lio/ktor/server/auth/saml/SamlSession$Companion; + public fun (Ljava/lang/String;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getLogoutRequestId ()Ljava/lang/String; + public final fun getRequestId ()Ljava/lang/String; +} + +public final synthetic class io/ktor/server/auth/saml/SamlSession$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/ktor/server/auth/saml/SamlSession$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/ktor/server/auth/saml/SamlSession; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/ktor/server/auth/saml/SamlSession;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/ktor/server/auth/saml/SamlSession$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/ktor/server/auth/saml/SamlSpMetadata { public final fun getAcsUrl ()Ljava/lang/String; public final fun getEncryptionCredential ()Lorg/opensaml/security/x509/BasicX509Credential; @@ -220,6 +260,11 @@ public final class io/ktor/server/auth/saml/SamlSpMetadata { public final fun technicalContact (Lkotlin/jvm/functions/Function1;)V } +public final class io/ktor/server/auth/saml/SamlValidationException : java/lang/Exception { + public fun (Ljava/lang/String;Ljava/lang/Throwable;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V +} + public final class io/ktor/server/auth/saml/SignatureAlgorithm { public static final field Companion Lio/ktor/server/auth/saml/SignatureAlgorithm$Companion; public fun (Ljava/lang/String;Ljava/lang/String;)V diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt new file mode 100644 index 00000000000..19bc0b77522 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt @@ -0,0 +1,99 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.server.auth.* + +/** + * Installs SAML 2.0 authentication. + * + * SAML (Security Assertion Markup Language) 2.0 is an XML-based standard for exchanging + * authentication and authorization data between parties. This plugin implements the + * Web Browser SSO Profile for both Service Provider-initiated and Identity Provider-initiated authentication flows. + * + * ## Example Usage + * + * ```kotlin + * install(Authentication) { + * saml("saml-auth") { + * // Service Provider configuration + * sp = SamlSpMetadata { + * spEntityId = "https://myapp.example.com/saml/metadata" + * acsUrl = "https://myapp.example.com/saml/acs" + * signingCredential = SamlCrypto.loadCredential( + * keystorePath = "/path/to/keystore.jks", + * keystorePassword = "example_pass", + * keyAlias = "sp-key", + * keyPassword = "example_pass" + * ) + * } + * + * // Identity Provider metadata + * idp = parseSamlIdpMetadata(idpMetadataXml) + * + * // Validation logic + * validate { credential -> + * val nameId = credential.nameId + * val email = credential.getAttribute("email") + * if (email != null) { + * SamlPrincipal(credential.assertion) + * } else { + * null + * } + * } + * } + * } + * + * routing { + * authenticate("saml-auth") { + * get("/profile") { + * val principal = call.principal()!! + * call.respondText("Hello ${principal.nameId}") + * } + * } + * } + * ``` + * + * ## Security Features + * + * This implementation follows OWASP SAML Security Cheat Sheet recommendations: + * - **XXE Protection**: Prevents XML External Entity attacks + * - **XSW Protection**: Prevents XML Signature Wrapping + * - **Replay Protection**: Tracks processed assertion IDs to prevent replay attacks + * - **Signature Verification**: Validates XML signatures on assertions using IdP certificates + * - **Timestamp Validation**: Checks NotBefore/NotOnOrAfter with configurable clock skew + * - **Audience Restriction**: Ensures assertions are intended for this Service Provider + * + * @param name Optional name for this authentication provider + * @param configure Configuration block for SAML authentication + * + * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.server.auth.saml.saml) + */ +public fun AuthenticationConfig.saml( + name: String? = null, + configure: SamlConfig.() -> Unit +): Unit = saml(name, description = null, configure) + +/** + * Installs SAML 2.0 authentication with a custom description. + * + * @param name Optional name for this authentication provider + * @param description Optional description of the provider + * @param configure Configuration block for SAML authentication + * + * @see saml + */ +public fun AuthenticationConfig.saml( + name: String? = null, + description: String? = null, + configure: SamlConfig.() -> Unit +) { + LibSaml.ensureInitialized() + val provider = SamlConfig(name, description) + .apply(configure) + .let { SamlAuthenticationProvider(it) } + + register(provider) +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt new file mode 100644 index 00000000000..9314da42d26 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt @@ -0,0 +1,345 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.http.* +import io.ktor.server.auth.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.sessions.* +import kotlinx.coroutines.CancellationException +import kotlinx.serialization.Serializable +import org.opensaml.security.x509.BasicX509Credential +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.net.URI + +/** + * SAML 2.0 authentication provider for Ktor Server. + * + * This provider implements SP-initiated SAML 2.0 Web Browser SSO Profile. + * It handles: + * - Generating and signing AuthnRequests + * - Redirecting users to the IdP for authentication + * - Receiving and validating SAML responses + * - Extracting user identity and attributes + * + * ## Security Features + * + * - XXE protection via secure XML parser + * - XML Signature Wrapping (XSW) protection via SAML20AssertionValidator + * - Replay attack protection via assertion ID cache + * - Signature verification of assertions + * - Timestamp validation with clock skew tolerance + * - Audience restriction validation + */ +public class SamlAuthenticationProvider internal constructor( + internal val config: SamlConfig +) : AuthenticationProvider(config) { + + private val logger: Logger = LoggerFactory.getLogger(SamlAuthenticationProvider::class.java) + private val spMetadata = requireNotNull(config.sp) { + "SP metadata must be configured. Use sp = SamlSpMetadata { ... } to set it." + } + private val spEntityId = requireNotNull(spMetadata.spEntityId) { "SP entity ID must be configured." } + private val idpMetadata = requireNotNull(config.idp) { + "IdP metadata must be configured. Use idp = parseSamlIdpMetadata(...) to set it." + } + private val relayValidator = config.allowedRelayStateUrls.let { + if (it == null) { + logger.warn("RelayState validation is disabled. This is unsafe in production.") + } + RelayValidator(allowedRelayStateUrls = it) + } + private val requestedAuthnContext = config.requestedAuthnContext + + private val authenticationFunction = requireNotNull(config.authenticationFunction) { + "SAML auth validate function must be specified" + } + private val challengeFunction: SamlAuthChallengeFunction = config.challengeFunction ?: { + call.respond(HttpStatusCode.Unauthorized) + } + + internal val signingCredential: BasicX509Credential? = spMetadata.signingCredential + + private val encryptionCredential: BasicX509Credential? = run { + val encryptionCredential = spMetadata.encryptionCredential + when { + // Explicit encryption credential is set - validate it supports decryption + encryptionCredential != null -> { + val keyAlgorithm = encryptionCredential.privateKey?.algorithm ?: "unknown" + require(encryptionCredential.supportsDecryption) { + "Encryption credential must use an RSA key for SAML assertion decryption. " + + "The configured encryption credential uses $keyAlgorithm algorithm. " + + "SAML encryption uses RSA-OAEP for key transport." + } + encryptionCredential + } + // Fall back to signing credential if it supports decryption + signingCredential != null && signingCredential.supportsDecryption -> signingCredential + // Signing credential is set but doesn't support decryption + signingCredential != null -> { + logger.warn( + "Signing credential uses ${signingCredential.privateKey?.algorithm ?: "unknown"} algorithm " + + "which cannot be used for SAML assertion decryption. " + + "If the IdP sends encrypted assertions, decryption will fail. " + + "Configure a separate encryptionCredential with an RSA key to support encrypted assertions." + ) + null + } + // No credentials configured + else -> null + } + } + + private val signatureVerifier by lazy { + SamlSignatureVerifier( + idpMetadata = idpMetadata, + allowedDigestAlgorithms = config.allowedDigestAlgorithms, + allowedSignatureAlgorithms = config.allowedSignatureAlgorithms, + ) + } + + private val replayCache by lazy { + config.replayCache ?: InMemorySamlReplayCache() + } + + private val responseProcessor: SamlResponseProcessor by lazy { + SamlResponseProcessor( + acsUrl = acsUrl, + spEntityId = spEntityId, + idpMetadata = idpMetadata, + replayCache = replayCache, + clockSkew = config.clockSkew, + decryptionCredential = encryptionCredential, + requireDestination = config.requireDestination, + allowIdpInitiatedSso = config.allowIdpInitiatedSso, + requireSignedResponse = config.requireSignedResponse, + requireSignedAssertions = spMetadata.wantAssertionsSigned, + signatureVerifier = signatureVerifier, + ) + } + + private val acsUrl = spMetadata.acsUrl.also { url -> + require(isAbsoluteUrl(url)) { + "ACS URL must be an absolute URL (e.g., https://myapp.example.com/saml/acs). Got: $url" + } + } + + private val sloUrl = spMetadata.sloUrl.also { url -> + if (!isAbsoluteUrl(url)) { + logger.warn("SLO URL should be an absolute URL for proper SAML 2.0 compliance. Got: $url") + } + } + + private val acsPath = Url(acsUrl).encodedPath + private val sloPath = if (isAbsoluteUrl(sloUrl)) Url(sloUrl).encodedPath else sloUrl + + init { + require(config.clockSkew.isPositive()) { "clockSkew must be positive, got: ${config.clockSkew}" } + require(acsPath != sloPath) { "acsPath and sloPath must be different, got: $acsPath" } + } + + override suspend fun onAuthenticate(context: AuthenticationContext) { + val request = context.call.request + when { + request.httpMethod == HttpMethod.Post && request.path() == acsPath -> + context.handleSamlCallback() + + else -> context.handleChallenge() + } + } + + /** + * Handles the SAML callback (ACS endpoint). + * Processes the SAML response and validates the assertion. + */ + private suspend fun AuthenticationContext.handleSamlCallback() { + try { + val parameters = call.receiveParameters() + val samlResponseBase64 = parameters["SAMLResponse"] + val relayState = parameters["RelayState"] + + if (samlResponseBase64 == null) { + challenge(SAML_AUTH_KEY, AuthenticationFailedCause.NoCredentials) { challenge, call -> + challengeFunction(SamlChallengeContext(call), AuthenticationFailedCause.NoCredentials) + challenge.complete() + } + return + } + + val session = call.sessions.get() + val expectedRequestId = session?.requestId + + val credentials = responseProcessor.processResponse(samlResponseBase64, expectedRequestId) + val principal = authenticationFunction(call, credentials) + + if (principal == null) { + challenge(SAML_AUTH_KEY, AuthenticationFailedCause.InvalidCredentials) { challenge, call -> + challengeFunction(SamlChallengeContext(call), AuthenticationFailedCause.InvalidCredentials) + challenge.complete() + } + return + } + + this.principal(name, principal) + call.sessions.clear() + + when { + relayState.isNullOrBlank() -> return + relayValidator.validate(url = relayState) -> call.respondRedirect(url = relayState) + else -> logger.warn("RelayState URL not in allowlist, ignoring: $relayState") + } + } catch (e: CancellationException) { + throw e + } catch (e: SamlValidationException) { + error(SAML_AUTH_KEY, AuthenticationFailedCause.Error(e.message ?: "SAML validation failed")) + } catch (e: Exception) { + val message = "SAML processing error: ${e.message ?: "Unknown error"}" + error(SAML_AUTH_KEY, AuthenticationFailedCause.Error(message)) + } + } + + /** + * Handles the challenge phase (initiates SAML authentication). + * Generates an AuthnRequest and sends it to the IdP using the configured binding. + */ + private fun AuthenticationContext.handleChallenge() { + challenge(SAML_AUTH_KEY, cause = AuthenticationFailedCause.NoCredentials) { challenge, call -> + try { + when (config.authnRequestBinding) { + SamlBinding.HttpRedirect -> { + val result = buildAuthnRequestRedirect( + acsUrl = acsUrl, + spEntityId = spEntityId, + idpSsoUrl = idpMetadata.getSsoUrlFor(SamlBinding.HttpRedirect), + relayState = call.request.uri, + signingCredential = signingCredential, + nameIdFormat = config.nameIdFormat, + forceAuthn = config.forceAuthn, + signatureAlgorithm = config.signatureAlgorithm, + requestedAuthnContext = requestedAuthnContext + ) + call.sessions.set(SamlSession(requestId = result.messageId)) + call.respondRedirect(result.redirectUrl) + } + + SamlBinding.HttpPost -> { + val postData = buildAuthnRequestPost( + acsUrl = acsUrl, + spEntityId = spEntityId, + relayState = call.request.uri, + forceAuthn = config.forceAuthn, + nameIdFormat = config.nameIdFormat, + signingCredential = signingCredential, + requestedAuthnContext = requestedAuthnContext, + signatureAlgorithm = config.signatureAlgorithm, + idpSsoUrl = idpMetadata.getSsoUrlFor(SamlBinding.HttpPost), + ) + call.sessions.set(SamlSession(postData.requestId)) + call.respondText(postData.toAutoSubmitHtml(), ContentType.Text.Html) + } + } + challenge.complete() + } catch (e: Exception) { + logger.error("Failed to initiate SAML authentication", e) + challengeFunction( + SamlChallengeContext(call), + AuthenticationFailedCause.Error("Failed to initiate SAML") + ) + if (!challenge.completed && call.response.status() != null) { + challenge.complete() + } + } + } + } +} + +internal class RelayValidator(private val allowedRelayStateUrls: List?) { + /** + * Validates that the relay state URL is allowed for redirect. + */ + fun validate(url: String): Boolean = when { + allowedRelayStateUrls == null -> true + // Reject dangerous patterns + url.startsWith("//") || url.contains("\\") || url.any { it.isISOControl() } -> false + // Allow relative paths + url.startsWith("/") -> { + val normalized = URI(url).normalize().toString() + if (normalized.startsWith("//") || !normalized.startsWith("/")) return false + if (allowedRelayStateUrls.isEmpty()) return true + val pathPrefixes = allowedRelayStateUrls.filter { it.startsWith("/") && !it.startsWith("//") } + pathPrefixes.any { prefix -> + normalized.startsWith(prefix) && (normalized == prefix || prefix.endsWith("/")) + } + } + // Validate absolute URLs + else -> { + allowedRelayStateUrls.any { prefix -> isAllowedAbsoluteRelayState(url, prefix) } + } + } + + /** + * Validates an absolute URL against an allowed prefix. + * Checks a scheme, host, port, and path with segment boundary validation. + */ + private fun isAllowedAbsoluteRelayState(targetUrl: String, allowedPrefix: String): Boolean { + val target = runCatching { Url(targetUrl) }.getOrNull() ?: return false + val allowed = runCatching { Url(allowedPrefix) }.getOrNull() ?: return false + + // Reject URLs with userinfo (user:pass@host) - bypass technique + if (target.user != null || target.password != null) return false + + // Only allow http and https schemes + if (target.protocol.name !in listOf("http", "https")) return false + if (allowed.protocol.name !in listOf("http", "https")) return false + + // Exact protocol match (case-insensitive) + if (!target.protocol.name.equals(allowed.protocol.name, ignoreCase = true)) return false + + // Exact host match (case-insensitive per RFC 3986) + if (!target.host.equals(allowed.host, ignoreCase = true)) return false + + // Exact port match + if (target.port != allowed.port) return false + + // Path prefix match with segment boundary check + val allowedPath = allowed.encodedPath.ifBlank { "/" } + val targetPath = target.encodedPath.ifBlank { "/" } + + if (!targetPath.startsWith(allowedPath)) return false + if (targetPath == allowedPath) return true + if (allowedPath.endsWith("/")) return true + return targetPath.getOrNull(allowedPath.length) == '/' + } +} + +/** + * Session data for SAML authentication. + * Stores the AuthnRequest ID for InResponseTo validation and optional LogoutRequest ID for SLO. + * + * When using SAML authentication, you need to install the Sessions plugin + * and register this session type: + * + * ```kotlin + * install(Sessions) { + * cookie("SAML_SESSION") + * } + * ``` + * + * @property requestId The ID of the AuthnRequest, used for InResponseTo validation + * @property logoutRequestId The ID of the LogoutRequest (for Single Logout), used for InResponseTo validation + */ +@Serializable +public class SamlSession( + public val requestId: String, + public val logoutRequestId: String? = null +) + +private val SAML_AUTH_KEY: Any = "SAMLAuth" + +private fun isAbsoluteUrl(url: String): Boolean { + return url.startsWith("http://") || url.startsWith("https://") +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlRequestBuilder.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlRequestBuilder.kt new file mode 100644 index 00000000000..83ad29881b9 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlRequestBuilder.kt @@ -0,0 +1,230 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.util.* +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport +import org.opensaml.saml.common.xml.SAMLConstants +import org.opensaml.saml.saml2.core.* +import org.opensaml.security.credential.Credential +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory +import org.opensaml.xmlsec.signature.Signature +import org.opensaml.xmlsec.signature.support.SignatureConstants +import org.opensaml.xmlsec.signature.support.Signer +import kotlin.io.encoding.Base64 +import kotlin.time.Clock +import kotlin.time.ExperimentalTime +import kotlin.time.toJavaInstant + +private val DEFAULT_SIGNATURE_ALGORITHM = SignatureAlgorithm.RSA_SHA256 + +/** + * Builds an AuthnRequest and generates a redirect URL to the IdP. + * + * @param spEntityId The Service Provider entity ID + * @param acsUrl The Assertion Consumer Service URL + * @param idpSsoUrl The IdP's Single Sign-On URL + * @param relayState Optional RelayState parameter (original requested URL) + * @param signingCredential Credential for signing (if null, no signing is performed) + * @param nameIdFormat Optional NameID format to request (e.g., email, persistent) + * @param forceAuthn Whether to force re-authentication at the IdP + * @param signatureAlgorithm Signature algorithm to use for signing + * @param requestedAuthnContext Optional requested authentication context + */ +internal fun buildAuthnRequestRedirect( + spEntityId: String, + acsUrl: String, + idpSsoUrl: String, + relayState: String? = null, + signingCredential: Credential? = null, + nameIdFormat: NameIdFormat? = null, + forceAuthn: Boolean = false, + signatureAlgorithm: SignatureAlgorithm = DEFAULT_SIGNATURE_ALGORITHM, + requestedAuthnContext: SamlAuthnContext? = null +): SamlRedirectResult { + LibSaml.ensureInitialized() + val authnRequest = buildAuthnRequest( + spEntityId, + acsUrl, + destination = idpSsoUrl, + nameIdFormat, + forceAuthn, + requestedAuthnContext + ) + return buildSamlRedirectResult( + messageId = checkNotNull(authnRequest.id), + samlObject = authnRequest, + destinationUrl = idpSsoUrl, + parameterName = "SAMLRequest", + relayState = relayState, + signingCredential = signingCredential, + signatureAlgorithm = signatureAlgorithm + ) +} + +/** + * Builds an AuthnRequest and returns the data for HTTP-POST binding. + * + * In HTTP-POST binding, the AuthnRequest is: + * - Signed using XML Signature (embedded in the document) if signingCredential is provided + * - Base64-encoded (without deflation) + * - Sent in an HTML form that auto-submits via JavaScript + */ +internal fun buildAuthnRequestPost( + spEntityId: String, + acsUrl: String, + idpSsoUrl: String, + relayState: String? = null, + signingCredential: Credential? = null, + nameIdFormat: NameIdFormat? = null, + forceAuthn: Boolean = false, + signatureAlgorithm: SignatureAlgorithm = DEFAULT_SIGNATURE_ALGORITHM, + requestedAuthnContext: SamlAuthnContext? = null +): AuthnRequestPostData { + LibSaml.ensureInitialized() + val authnRequest = buildAuthnRequest( + spEntityId, + acsUrl, + destination = idpSsoUrl, + nameIdFormat, + forceAuthn, + requestedAuthnContext + ) + val requestId = checkNotNull(authnRequest.id) + + if (signingCredential != null) { + authnRequest.addSignature(signingCredential, signatureAlgorithm) + } + + var authnRequestXml = authnRequest.marshalToString() + if (signingCredential != null) { + Signer.signObject(checkNotNull(authnRequest.signature)) + // Re-marshal to include the computed signature value + authnRequestXml = authnRequest.marshalToString() + } + + val encodedRequest = Base64.encode(source = authnRequestXml.toByteArray(Charsets.UTF_8)) + + return AuthnRequestPostData( + requestId = requestId, + idpSsoUrl = idpSsoUrl, + samlRequest = encodedRequest, + relayState = relayState + ) +} + +/** + * Adds a Signature element to an AuthnRequest for HTTP-POST binding. + * The actual signing (computing the signature value) must happen after marshaling. + */ +private fun AuthnRequest.addSignature(credential: Credential, signatureAlgorithm: SignatureAlgorithm) { + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val keyInfoGeneratorFactory = X509KeyInfoGeneratorFactory() + keyInfoGeneratorFactory.setEmitEntityCertificate(true) + val keyInfoGenerator = keyInfoGeneratorFactory.newInstance() + + this.signature = builderFactory.buildXmlObject(Signature.DEFAULT_ELEMENT_NAME) { + this.signingCredential = credential + this.signatureAlgorithm = signatureAlgorithm.uri + this.keyInfo = keyInfoGenerator.generate(credential) + this.canonicalizationAlgorithm = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS + } +} + +@OptIn(ExperimentalTime::class) +private fun buildAuthnRequest( + spEntityId: String, + acsUrl: String, + destination: String, + nameIdFormat: NameIdFormat?, + forceAuthn: Boolean, + requestedAuthnContext: SamlAuthnContext? +): AuthnRequest { + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = spEntityId + } + + val nameIDPolicy = nameIdFormat?.let { format -> + builderFactory.build(NameIDPolicy.DEFAULT_ELEMENT_NAME) { + this.format = format.uri + this.allowCreate = true + } + } + + val reqAuthnContext = requestedAuthnContext?.let { requestedAuthnContext -> + val ref = builderFactory.build(AuthnContextClassRef.DEFAULT_ELEMENT_NAME) { + this.uri = requestedAuthnContext.uri + } + + builderFactory.build(RequestedAuthnContext.DEFAULT_ELEMENT_NAME) { + this.authnContextClassRefs.add(ref) + this.comparison = AuthnContextComparisonTypeEnumeration.EXACT + } + } + + return builderFactory.build(AuthnRequest.DEFAULT_ELEMENT_NAME) { + this.issuer = issuer + this.destination = destination + this.nameIDPolicy = nameIDPolicy + this.id = generateSecureSamlId() + this.assertionConsumerServiceURL = acsUrl + this.issueInstant = Clock.System.now().toJavaInstant() + this.protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI + reqAuthnContext?.let { this.requestedAuthnContext = it } + if (forceAuthn) { + this.isForceAuthn = true + } + } +} + +/** + * AuthnRequest encoded for HTTP-POST binding. + * + * @property requestId The ID of the AuthnRequest (for correlating the response) + * @property idpSsoUrl The IdP's SSO URL (form action) + * @property samlRequest The Base64-encoded SAMLRequest + * @property relayState Optional RelayState for the IdP to return + */ +internal class AuthnRequestPostData( + val requestId: String, + val idpSsoUrl: String, + val samlRequest: String, + val relayState: String? +) { + /** + * Generates an auto-submit HTML form for the HTTP-POST binding. + * + * The form will automatically submit when loaded (requires JavaScript). + * A Submit button is included as a fallback for users without JavaScript. + */ + fun toAutoSubmitHtml(): String { + val relayStateInput = if (relayState != null) { + """""" + } else { + "" + } + return """ + | + | + | + | + | Redirecting to Identity Provider... + | + | + | + |
+ | + | $relayStateInput + | + |
+ | + | + """.trimMargin() + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt new file mode 100644 index 00000000000..5bb112eacd2 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt @@ -0,0 +1,243 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import org.opensaml.saml.common.assertion.ValidationContext +import org.opensaml.saml.common.assertion.ValidationResult +import org.opensaml.saml.saml2.assertion.* +import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator +import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator +import org.opensaml.saml.saml2.core.Assertion +import org.opensaml.saml.saml2.core.EncryptedAssertion +import org.opensaml.saml.saml2.core.Response +import org.opensaml.saml.saml2.core.StatusCode +import org.opensaml.saml.saml2.encryption.Decrypter +import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver +import org.opensaml.security.credential.Credential +import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver +import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver +import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver +import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver +import org.w3c.dom.Document +import java.io.ByteArrayInputStream +import kotlin.io.encoding.Base64 +import kotlin.time.* +import kotlin.time.Duration.Companion.minutes + +/** + * Maximum acceptable age for Response IssueInstant. + */ +private val RESPONSE_LIFETIME = 5.minutes + +/** + * Processes and validates SAML responses from the Identity Provider. + * + * This class handles the complete validation chain for SAML assertions: + * 1. XML parsing with XXE protection + * 2. Response unmarshalling + * 3. Encrypted assertion decryption + * 4. XML signature verification (XSW protection) + * 5. Assertion semantic validation (timestamps, audience, subject confirmation) + * 6. Replay attack detection + */ +@OptIn(ExperimentalTime::class) +internal class SamlResponseProcessor( + private val spEntityId: String, + private val acsUrl: String, + private val idpMetadata: IdPMetadata, + private val decryptionCredential: Credential?, + private val clockSkew: Duration, + private val replayCache: SamlReplayCache, + private val requireSignedAssertions: Boolean, + private val requireSignedResponse: Boolean, + private val requireDestination: Boolean, + private val allowIdpInitiatedSso: Boolean, + private val signatureVerifier: SamlSignatureVerifier +) { + init { + LibSaml.ensureInitialized() + } + + /** + * SAML 2.0 Assertion validator using OpenSAML's built-in validation. + * + * This validator handles: + * - IssueInstant validation (with clock skew) + * - Conditions validation (NotBefore/NotOnOrAfter) + * - Subject confirmation validation (Bearer method) + * - Audience restriction validation + * - Issuer validation + * - Signature validation + */ + private val assertionValidator: SAML20AssertionValidator by lazy { + val conditionValidators = listOf( + AudienceRestrictionConditionValidator() + ) + val subjectConfirmationValidators = listOf( + BearerSubjectConfirmationValidator() + ) + val statementValidators = emptyList() + SAML20AssertionValidator( + conditionValidators, + subjectConfirmationValidators, + statementValidators, + null, // No generic assertion validator extension + signatureVerifier.signatureTrustEngine, + signatureVerifier.signatureProfileValidator + ) + } + + /** + * Processes a Base64-encoded SAML response. + * + * @param samlResponseBase64 The Base64-encoded SAML response from the POST parameter + * @param expectedRequestId The ID of the AuthnRequest that initiated this flow (for InResponseTo validation) + * @return SamlCredential containing the validated assertion + * @throws SamlValidationException if validation fails + */ + suspend fun processResponse(samlResponseBase64: String, expectedRequestId: String?): SamlCredential { + val samlResponseXml = String(bytes = Base64.decode(samlResponseBase64)) + val response = parseResponse(samlResponseXml).also { it.validate(expectedRequestId) } + val assertion = response.extractAssertion().also { it.validate(expectedRequestId) } + return SamlCredential(response, assertion) + } + + private fun parseResponse(xml: String): Response = withValidationException { + val document: Document = LibSaml.parserPool.parse(ByteArrayInputStream(xml.toByteArray())) + document.documentElement.unmarshall() + } + + private fun Response.validate(expectedRequestId: String?) { + val statusCode = status?.statusCode?.value + if (statusCode != StatusCode.SUCCESS) { + val statusMessage = status?.statusMessage?.value ?: "No message" + throw SamlValidationException("SAML response status is not Success: $statusCode - $statusMessage") + } + + when { + expectedRequestId != null -> { + samlAssert(inResponseTo == expectedRequestId) { "InResponseTo mismatch" } + } + + !allowIdpInitiatedSso -> throw SamlValidationException("IdP-initiated SSO is not allowed.") + } + + val issuer = issuer?.value + samlAssert(issuer == idpMetadata.entityId) { "Response issuer mismatch" } + + val destination = destination + samlAssert(!requireDestination || destination != null) { "Response Destination is not present" } + samlAssert(destination == null || destination == acsUrl) { "Response Destination mismatch" } + + val issueInstant = samlRequire(issueInstant?.toKotlinInstant()) { "Response IssueInstant is required" } + + val now = Clock.System.now() + val effectiveMinTime = now - clockSkew - RESPONSE_LIFETIME + val effectiveMaxTime = now + clockSkew + + samlAssert(issueInstant >= effectiveMinTime) { "Response IssueInstant is too old" } + samlAssert(issueInstant <= effectiveMaxTime) { "Response IssueInstant is in the future" } + + if (requireSignedResponse) { + signatureVerifier.verify(signedObject = this) + } + } + + /** + * Extracts and decrypts the assertion from the response. + */ + private fun Response.extractAssertion(): Assertion { + return samlRequire(encryptedAssertions.firstOrNull()?.decrypt() ?: assertions.firstOrNull()) { + "No assertion found in SAML response" + } + } + + private fun EncryptedAssertion.decrypt(): Assertion = withValidationException { + val decryptionCredential = checkNotNull(decryptionCredential) { + "No valid decryption credential is provided" + } + val kekResolver = StaticKeyInfoCredentialResolver(decryptionCredential) + + val encKeyResolvers = listOf( + InlineEncryptedKeyResolver(), + EncryptedElementTypeEncryptedKeyResolver(), + SimpleRetrievalMethodEncryptedKeyResolver() + ) + val encryptedKeyResolver = ChainingEncryptedKeyResolver(encKeyResolvers) + + val decrypter = Decrypter(null, kekResolver, encryptedKeyResolver).apply { + isRootInNewDocument = true + } + decrypter.decrypt(this) + } + + private suspend fun Assertion.validate(expectedRequestId: String?) { + validateAssertionSemantics(expectedRequestId) + checkReplay() + } + + /** + * Validates assertion semantics using OpenSAML's SAML20AssertionValidator. + * + * The validator handles: + * - IssueInstant validation (with clock skew tolerance) + * - Conditions validation (NotBefore / NotOnOrAfter) + * - Subject confirmation validation (Bearer method with InResponseTo, Recipient, NotOnOrAfter) + * - Audience restriction validation + * - Issuer validation + * - Signature validation (if [requireSignedAssertions] is true or assertion is signed) + * + * @throws SamlValidationException if validation fails + */ + private fun Assertion.validateAssertionSemantics(expectedRequestId: String?) = withValidationException { + val validationContext = buildValidationContext(expectedRequestId) + + val result = assertionValidator.validate(this, validationContext) + samlAssert(result == ValidationResult.VALID) { "SAML assertion validation failed" } + } + + /** + * Builds a validation context for assertion validation. + * + * @return ValidationContext with required parameters + */ + private fun buildValidationContext(expectedRequestId: String?): ValidationContext { + val params = mutableMapOf() + params[SAML2AssertionValidationParameters.CLOCK_SKEW] = clockSkew.toJavaDuration() + params[SAML2AssertionValidationParameters.VALID_ISSUERS] = setOf(idpMetadata.entityId) + params[SAML2AssertionValidationParameters.COND_VALID_AUDIENCES] = setOf(spEntityId) + params[SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS] = setOf(acsUrl) + params[SAML2AssertionValidationParameters.SIGNATURE_REQUIRED] = requireSignedAssertions + params[SAML2AssertionValidationParameters.SC_CHECK_ADDRESS] = false + if (expectedRequestId != null) { + params[SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO] = expectedRequestId + params[SAML2AssertionValidationParameters.SC_IN_RESPONSE_TO_REQUIRED] = true + } else { + params[SAML2AssertionValidationParameters.SC_IN_RESPONSE_TO_REQUIRED] = false + } + return ValidationContext(params) + } + + /** + * Checks if the assertion has been replayed (already processed). + * + * @throws SamlValidationException if this is a replay + */ + private suspend fun Assertion.checkReplay() { + val assertionId = samlRequire(id) { "Assertion must have an ID" } + val expirationTime = conditions?.notOnOrAfter?.toKotlinInstant() + ?: (Clock.System.now() + RESPONSE_LIFETIME) + + val recorded = replayCache.tryRecordAssertion(assertionId, expirationTime) + samlAssert(recorded) { + "Assertion has already been processed (replay attack)" + } + } +} + +/** + * Exception thrown when SAML validation fails. + */ +public class SamlValidationException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt new file mode 100644 index 00000000000..c6121c6eee0 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt @@ -0,0 +1,131 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import org.opensaml.saml.common.SignableSAMLObject +import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator +import org.opensaml.security.credential.Credential +import org.opensaml.security.credential.MutableCredential +import org.opensaml.security.credential.impl.CollectionCredentialResolver +import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap +import org.opensaml.xmlsec.signature.support.SignatureValidator +import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine +import org.opensaml.xmlsec.signature.support.impl.SignatureAlgorithmValidator +import java.security.Signature +import kotlin.io.encoding.Base64 + +/** + * Verifies XML signatures on SAML objects. + * + * This class provides comprehensive signature validation: + * 1. Validates signature and digest algorithms against allowlists (if configured) + * 2. SAML signature profile validation (prevents XSW attacks) + * 3. Cryptographic signature verification against IdP credentials + */ +internal class SamlSignatureVerifier( + idpMetadata: IdPMetadata, + private val allowedSignatureAlgorithms: Set? = null, + allowedDigestAlgorithms: Set? = null +) { + /** + * Credentials with entity ID set for proper credential resolution. + */ + val credentials: List = idpMetadata.signingCredentials.map { credential -> + credential.apply { + if (this is MutableCredential) { + entityId = idpMetadata.entityId + } + } + } + + private val algorithmValidator: SignatureAlgorithmValidator? = + if (allowedSignatureAlgorithms != null || allowedDigestAlgorithms != null) { + val whitelistAlgorithms = buildSet { + allowedSignatureAlgorithms?.let { algorithms -> addAll(algorithms.map { it.uri }) } + allowedDigestAlgorithms?.let { algorithms -> addAll(algorithms.map { it.uri }) } + } + SignatureAlgorithmValidator(whitelistAlgorithms, null) + } else { + null + } + + val signatureTrustEngine = run { + val credentialResolver = CollectionCredentialResolver(credentials) + val keyInfoResolver = DefaultSecurityConfigurationBootstrap + .buildBasicInlineKeyInfoCredentialResolver() + ExplicitKeySignatureTrustEngine(credentialResolver, keyInfoResolver) + } + + val signatureProfileValidator = SAMLSignatureProfileValidator() + + /** + * Verifies the XML signature on a SignableSAMLObject. + */ + fun verify(signedObject: SignableSAMLObject) { + val signature = samlRequire(signedObject.signature) { "No signature to verify" } + + if (algorithmValidator != null) { + try { + algorithmValidator.validate(signature) + } catch (e: Exception) { + throw SamlValidationException("Algorithm validation failed: ${e.message}", e) + } + } + + try { + signatureProfileValidator.validate(signature) + } catch (e: Exception) { + throw SamlValidationException("Signature profile validation failed: ${e.message}", e) + } + + val valid = credentials.any { credential -> + runCatching { SignatureValidator.validate(signature, credential) }.isSuccess + } + samlAssert(valid) { "Signature verification failed with all IdP credentials" } + } + + /** + * Verifies a query string signature. + */ + fun verifyQueryString( + queryString: String, + signatureBase64: String, + signatureAlgorithmUri: String + ) { + val signatureAlgorithm = samlRequire(SignatureAlgorithm.from(signatureAlgorithmUri)) { + "Unsupported signature algorithm." + } + + samlAssert(allowedSignatureAlgorithms == null || signatureAlgorithm in allowedSignatureAlgorithms) { + "Signature algorithm not in allowlist." + } + + val signatureBytes = try { + Base64.decode(source = signatureBase64) + } catch (e: Exception) { + throw SamlValidationException("Invalid Base64 signature", e) + } + + val signatureIdx = queryString.indexOf("&Signature=") + val queryStringWithoutSignature = if (signatureIdx >= 0) { + queryString.substring(0, signatureIdx).toByteArray(Charsets.UTF_8) + } else { + throw SamlValidationException("Missing Signature parameter") + } + + // Try to verify with each credential + val verified = credentials.any { credential -> + val publicKey = credential.publicKey ?: return@any false + runCatching { + val signature = Signature.getInstance(signatureAlgorithm.jcaAlgorithm) + signature.initVerify(publicKey) + signature.update(queryStringWithoutSignature) + signature.verify(signatureBytes) + }.getOrDefault(false) + } + + samlAssert(verified) { "HTTP-Redirect signature verification failed" } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt index e115db06e17..6f023746803 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt @@ -4,10 +4,64 @@ package io.ktor.server.auth.saml +import io.ktor.utils.io.charsets.name import org.opensaml.core.xml.XMLObject +import org.opensaml.core.xml.XMLObjectBuilderFactory import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport +import org.opensaml.core.xml.io.MarshallingException import org.opensaml.core.xml.io.UnmarshallingException +import org.opensaml.saml.common.SAMLObject +import org.opensaml.saml.common.SAMLObjectBuilder +import org.opensaml.security.credential.Credential import org.w3c.dom.Element +import java.io.ByteArrayOutputStream +import java.io.StringWriter +import java.net.URLEncoder +import java.security.Signature +import java.util.zip.Deflater +import java.util.zip.DeflaterOutputStream +import javax.xml.namespace.QName +import javax.xml.transform.TransformerFactory +import javax.xml.transform.dom.DOMSource +import javax.xml.transform.stream.StreamResult +import kotlin.io.encoding.Base64 +import kotlin.text.Charsets +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Result of building a SAML redirect (AuthnRequest, LogoutRequest, or LogoutResponse). + * + * @property messageId The ID of the SAML message + * @property redirectUrl The complete redirect URL with all parameters + */ +public class SamlRedirectResult( + public val messageId: String, + public val redirectUrl: String +) + +/** + * @return Secure random ID with the prefix "_" + */ +@OptIn(ExperimentalUuidApi::class) +internal fun generateSecureSamlId(): String = "_" + Uuid.random() + +/** + * Marshals a SAML XMLObject to an XML string. + */ +internal fun XMLObject.marshalToString(): String { + val marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory() + val marshaller = marshallerFactory.getMarshaller(this) + ?: throw MarshallingException("No marshaller found for object: $elementQName") + + val element = marshaller.marshall(this) + + val transformerFactory = TransformerFactory.newInstance() + val transformer = transformerFactory.newTransformer() + val stringWriter = StringWriter() + transformer.transform(DOMSource(element), StreamResult(stringWriter)) + return stringWriter.toString() +} /** * Unmarshalls a DOM Element to a SAML XMLObject of the specified type. @@ -18,3 +72,145 @@ internal inline fun Element.unmarshall(): T { ?: throw UnmarshallingException("No unmarshaller found for element: ${this.localName}") return unmarshaller.unmarshall(this) as T } + +internal fun String.encodeSamlMessage(deflate: Boolean): String { + val bytes = toByteArray(Charsets.UTF_8) + if (!deflate) { + return Base64.encode(source = bytes) + } + val bytesOut = ByteArrayOutputStream() + val deflater = Deflater(Deflater.DEFLATED, true) + DeflaterOutputStream(bytesOut, deflater).use { it.write(bytes) } + return Base64.encode(source = bytesOut.toByteArray()) +} + +@Suppress("UNCHECKED_CAST") +internal inline fun XMLObjectBuilderFactory.build( + key: QName, + crossinline configure: O.() -> Unit +): O { + val builder = getBuilder(key) as SAMLObjectBuilder + return builder.buildObject().apply(configure) +} + +/** + * Builds a SAML redirect result by marshaling a request/response object and constructing the redirect URL. + * + * @param messageId The ID of the SAML message + * @param samlObject The SAML object to marshal + * @param destinationUrl The IdP URL to redirect to + * @param parameterName The query parameter name ("SAMLRequest" or "SAMLResponse") + * @param relayState Optional RelayState parameter + * @param signingCredential Optional credential for signing + * @param signatureAlgorithm The signature algorithm to use if signing + */ +internal fun buildSamlRedirectResult( + messageId: String, + samlObject: XMLObject, + destinationUrl: String, + parameterName: String, + relayState: String?, + signingCredential: Credential?, + signatureAlgorithm: SignatureAlgorithm +): SamlRedirectResult { + val xml = samlObject.marshalToString() + val encodedMessage = xml.encodeSamlMessage(deflate = true) + val redirectUrl = buildSamlRedirectUrl( + destinationUrl = destinationUrl, + parameterName = parameterName, + encodedMessage = encodedMessage, + relayState = relayState, + signingCredential = signingCredential, + signatureAlgorithm = signatureAlgorithm + ) + return SamlRedirectResult(messageId, redirectUrl) +} + +private fun buildSamlRedirectUrl( + destinationUrl: String, + parameterName: String, + encodedMessage: String, + relayState: String?, + signingCredential: Credential?, + signatureAlgorithm: SignatureAlgorithm +): String { + val enc = Charsets.UTF_8.name + val urlEncodedMessage = URLEncoder.encode(encodedMessage, enc) + + val urlBuilder = StringBuilder(destinationUrl) + urlBuilder.append(if (destinationUrl.contains("?")) "&" else "?") + urlBuilder.append(parameterName).append("=").append(urlEncodedMessage) + + if (!relayState.isNullOrBlank()) { + urlBuilder.append("&RelayState=").append(URLEncoder.encode(relayState, enc)) + } + if (signingCredential != null) { + urlBuilder + .append("&SigAlg=") + .append(URLEncoder.encode(signatureAlgorithm.uri, enc)) + + val queryString = urlBuilder.toString().substringAfter("?") + val signature = signQueryString(queryString, signingCredential, signatureAlgorithm) + urlBuilder + .append("&Signature=") + .append(URLEncoder.encode(signature, enc)) + } + + return urlBuilder.toString() +} + +@Suppress("UNCHECKED_CAST") +internal inline fun XMLObjectBuilderFactory.buildXmlObject( + key: QName, + crossinline configure: O.() -> Unit +): O { + val builder = getBuilder(key) as org.opensaml.core.xml.XMLObjectBuilder + return builder.buildObject(key).apply(configure) +} + +/** + * Signs a query string for HTTP-Redirect binding. + * + * In HTTP-Redirect binding, the signature is computed over the raw query string + * (NOT the XML itself). The signature is then appended as a separate parameter. + */ +internal fun signQueryString( + queryString: String, + credential: Credential, + signatureAlgorithm: SignatureAlgorithm +): String { + val privateKey = requireNotNull(credential.privateKey) { + "Credential must have a private key for signing" + } + + val signature = Signature.getInstance(signatureAlgorithm.jcaAlgorithm) + signature.initSign(privateKey) + signature.update(queryString.toByteArray(Charsets.UTF_8)) + + val signatureBytes = signature.sign() + return Base64.encode(source = signatureBytes) +} + +/** + * Checks if [value] is not null, throwing [SamlValidationException] with a message from [lazyMessage] if it is null. + */ +internal inline fun samlRequire(value: T?, crossinline lazyMessage: () -> String): T { + return value ?: throw SamlValidationException(lazyMessage()) +} + +/** + * Checks if [value] is true, throwing [SamlValidationException] with a message from [lazyMessage] if it is false. + */ +internal inline fun samlAssert(value: Boolean, crossinline lazyMessage: () -> String) { + if (!value) { + throw SamlValidationException(lazyMessage()) + } +} + +internal inline fun withValidationException(crossinline block: () -> T): T { + try { + return block() + } catch (e: Exception) { + throw SamlValidationException("SAML validation failed", e) + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt new file mode 100644 index 00000000000..0982a9dd46b --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt @@ -0,0 +1,129 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import kotlin.test.* + +/** + * Unit tests for RelayValidator to prevent open redirect attacks. + */ +class RelayStateValidationTest { + + @Test + fun `relative paths with empty allowlist are accepted`() { + val validator = RelayValidator(allowedRelayStateUrls = emptyList()) + + // Basic relative paths with various components + assertTrue(validator.validate("/dashboard")) + assertTrue(validator.validate("/search?q=test&page=1")) + assertTrue(validator.validate("/page#section")) + assertTrue(validator.validate("/any/path")) + assertTrue(validator.validate("/")) + } + + @Test + fun `dangerous URL patterns are blocked`() { + val validator = RelayValidator(allowedRelayStateUrls = emptyList()) + + // Scheme-relative, backslashes, and control characters + assertFalse(validator.validate("//evil.com/phish")) + assertFalse(validator.validate("/foo\\..\\bar")) + assertFalse(validator.validate("/foo\u0000bar")) + + // Dangerous schemes + assertFalse(validator.validate("javascript:alert(1)")) + assertFalse(validator.validate("data:text/html,")) + assertFalse(validator.validate("ftp://example.com/file")) + assertFalse(validator.validate("file:///etc/passwd")) + } + + @Test + fun `absolute URLs with allowlist validation`() { + val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + + // Allowed origin + assertTrue(validator.validate("https://myapp.example.com/dashboard")) + + // Blocked: wrong origin, port, or scheme + assertFalse(validator.validate("https://evil.com/phish")) + assertFalse(validator.validate("https://myapp.example.com:8443/dashboard")) + assertFalse(validator.validate("http://myapp.example.com/dashboard")) + + // Blocked: bypass attempts + assertFalse(validator.validate("https://myapp.example.com@evil.com/phish")) + assertFalse(validator.validate("https://myapp.example.com.evil.com/phish")) + } + + @Test + fun `path prefix matching with segment boundaries`() { + // Prefix without trailing slash - exact match only + val validatorNoSlash = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/app")) + assertTrue(validatorNoSlash.validate("https://myapp.example.com/app/dashboard")) + assertFalse(validatorNoSlash.validate("https://myapp.example.com/application")) + + // Prefix with trailing slash - allows subpaths + val validatorWithSlash = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/app/")) + assertTrue(validatorWithSlash.validate("https://myapp.example.com/app/dashboard")) + assertFalse(validatorWithSlash.validate("https://myapp.example.com/app")) + } + + @Test + fun `multiple allowed origins`() { + val validator = RelayValidator( + allowedRelayStateUrls = listOf( + "https://app1.example.com/", + "https://app2.example.com/" + ) + ) + + assertTrue(validator.validate("https://app1.example.com/page")) + assertTrue(validator.validate("https://app2.example.com/page")) + assertFalse(validator.validate("https://app3.example.com/page")) + } + + @Test + fun `null allowlist disables validation`() { + val validator = RelayValidator(allowedRelayStateUrls = null) + assertTrue(validator.validate("https://any-domain.com/any-path")) + assertTrue(validator.validate("/any/path")) + } + + @Test + fun `host comparison is case insensitive`() { + val validator = RelayValidator(allowedRelayStateUrls = listOf("https://MyApp.Example.COM/")) + assertTrue(validator.validate("https://myapp.example.com/page")) + } + + @Test + fun `mixed relative and absolute URLs in allowlist`() { + val validator = RelayValidator( + allowedRelayStateUrls = listOf( + "/local/", + "https://external.com/" + ) + ) + + assertTrue(validator.validate("/local/page")) + assertFalse(validator.validate("/other/page")) + assertTrue(validator.validate("https://external.com/page")) + assertFalse(validator.validate("https://other.com/page")) + } + + @Test + fun `default ports are handled correctly`() { + val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + assertTrue(validator.validate("https://myapp.example.com/page")) + + val validatorWithPort = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com:443/")) + assertTrue(validatorWithPort.validate("https://myapp.example.com/page")) + } + + @Test + fun `malformed URLs are rejected`() { + val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + assertFalse(validator.validate("https://[invalid")) + assertFalse(validator.validate("ht!tp://example.com")) + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLPrincipalTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLPrincipalTest.kt new file mode 100644 index 00000000000..68419f6b0e7 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLPrincipalTest.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import kotlin.test.* +import kotlin.time.ExperimentalTime + +@OptIn(ExperimentalTime::class) +class SamlPrincipalTest { + + @Test + fun `test SamlCredential construction`() { + val expectedNameId = "user123@example.com" + val expectedSessionIndex = "session-12345" + val assertion = SamlTestUtils.createTestAssertion( + nameId = expectedNameId, + sessionIndex = expectedSessionIndex + ) + + val response = SamlTestUtils.createTestResponse(assertion) + val credential = SamlCredential(response, assertion) + val principal = SamlPrincipal(assertion) + + assertEquals(expectedNameId, principal.nameId) + assertEquals(response, credential.response) + assertEquals(assertion, credential.assertion) + assertEquals(expectedSessionIndex, principal.sessionIndex) + } + + @Test + fun `test SamlPrincipal throws exception when NameID is missing`() { + val assertion = SamlTestUtils.createTestAssertion().apply { + subject?.nameID = null // Remove the NameID + } + assertFailsWith { + SamlPrincipal(assertion) + } + } + + @Test + fun `test SamlPrincipal sessionIndex is null when AuthnStatement is missing`() { + val assertion = SamlTestUtils.createTestAssertion(sessionIndex = null) + val principal = SamlPrincipal(assertion) + + assertNull(principal.sessionIndex) + } + + @Test + fun `test SamlPrincipal missing attribute`() { + val assertion = SamlTestUtils.createTestAssertion() + val principal = SamlPrincipal(assertion) + + assertFalse(principal.hasAttribute("nonexistent")) + assertNull(principal.getAttribute("nonexistent")) + + val values = principal.getAttributeValues("nonexistent") + assertTrue(values.isEmpty()) + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLRequestBuilderTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLRequestBuilderTest.kt new file mode 100644 index 00000000000..ad61aaad40a --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLRequestBuilderTest.kt @@ -0,0 +1,99 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import java.net.URLEncoder +import kotlin.test.* + +class SamlRequestBuilderTest { + + @Test + fun `test buildAuthnRequestRedirect generates valid request ID and URL`() { + val spEntityId = "test-sp" + val acsUrl = "http://localhost:8080/saml/acs" + val idpSsoUrl = "http://idp.example.com/sso" + + val result = buildAuthnRequestRedirect( + spEntityId = spEntityId, + acsUrl = acsUrl, + idpSsoUrl = idpSsoUrl + ) + + assertNotNull(result.messageId) + assertTrue(result.messageId.startsWith("_")) + assertTrue(result.redirectUrl.startsWith(idpSsoUrl)) + assertTrue(result.redirectUrl.contains("SAMLRequest=")) + assertTrue(result.redirectUrl.startsWith("$idpSsoUrl?")) + + assertFalse(result.redirectUrl.contains("Signature=")) + assertFalse(result.redirectUrl.contains("SigAlg=")) + } + + @Test + fun `test buildAuthnRequestRedirect includes RelayState when provided`() { + val spEntityId = "test-sp" + val acsUrl = "http://localhost:8080/saml/acs" + val idpSsoUrl = "http://idp.example.com/sso" + val relayState = "/protected/resource" + + val result = buildAuthnRequestRedirect( + spEntityId = spEntityId, + acsUrl = acsUrl, + idpSsoUrl = idpSsoUrl, + relayState = relayState + ) + + assertTrue(result.redirectUrl.contains("RelayState=")) + assertTrue(result.redirectUrl.contains(URLEncoder.encode(relayState, "UTF-8"))) + } + + @Test + fun `test buildAuthnRequestRedirect generates unique request IDs`() { + val spEntityId = "test-sp" + val acsUrl = "http://localhost:8080/saml/acs" + val idpSsoUrl = "http://idp.example.com/sso" + + val result1 = buildAuthnRequestRedirect( + spEntityId = spEntityId, + acsUrl = acsUrl, + idpSsoUrl = idpSsoUrl + ) + + val result2 = buildAuthnRequestRedirect( + spEntityId = spEntityId, + acsUrl = acsUrl, + idpSsoUrl = idpSsoUrl + ) + + assertNotEquals(result1.messageId, result2.messageId) + } + + @Test + fun `test buildAuthnRequestRedirect with forceAuthn`() { + val spEntityId = "test-sp" + val acsUrl = "http://localhost:8080/saml/acs" + val idpSsoUrl = "http://idp.example.com/sso" + + // Test with forceAuthn enabled + val result = buildAuthnRequestRedirect( + spEntityId = spEntityId, + acsUrl = acsUrl, + idpSsoUrl = idpSsoUrl, + forceAuthn = true + ) + + assertNotNull(result.messageId) + assertTrue(result.redirectUrl.contains("SAMLRequest=")) + } + + @Test + fun `test NameIdFormat constants`() { + assertEquals("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", NameIdFormat.Email.uri) + assertEquals("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent", NameIdFormat.Persistent.uri) + assertEquals("urn:oasis:names:tc:SAML:2.0:nameid-format:transient", NameIdFormat.Transient.uri) + assertEquals("urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified", NameIdFormat.Unspecified.uri) + assertEquals("urn:custom:format", NameIdFormat("urn:custom:format").uri) + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt new file mode 100644 index 00000000000..77a8ce5a2ce --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt @@ -0,0 +1,641 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import kotlinx.coroutines.test.runTest +import org.opensaml.saml.saml2.core.StatusCode +import kotlin.test.* +import kotlin.time.Clock +import kotlin.time.Duration +import kotlin.time.Duration.Companion.hours +import kotlin.time.Duration.Companion.minutes +import kotlin.time.ExperimentalTime + +@OptIn(ExperimentalTime::class) +class SamlResponseProcessorTest { + private lateinit var replayCache: InMemorySamlReplayCache + + @BeforeTest + fun setup() { + replayCache = InMemorySamlReplayCache() + } + + @AfterTest + fun teardown() { + replayCache.close() + } + + @Test + fun `test decrypt encrypted assertion successfully`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + // Create assertion with audience for SP + val assertion = SamlTestUtils.createTestAssertion( + nameId = "user@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val encryptedAssertion = SamlTestUtils.encryptAssertion(assertion, spCredentials.credential) + val response = SamlTestUtils.createTestResponseWithEncryptedAssertion( + encryptedAssertion = encryptedAssertion, + issuerEntityId = IDP_ENTITY_ID + ) + + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + val credential = processor.processResponse(base64Response, null) + assertEquals("user@example.com", credential.assertion.subject?.nameID?.value) + } + + @Test + fun `test decrypt fails without decryption credential`() = runTest { + val idpMetadata = IdPMetadata( + entityId = IDP_ENTITY_ID, + ssoUrl = "https://idp.example.com/sso", + sloUrl = null, + signingCredentials = listOf(idpCredentials.credential) + ) + val processor = SamlResponseProcessor( + spEntityId = SP_ENTITY_ID, + acsUrl = ACS_URL, + idpMetadata = idpMetadata, + decryptionCredential = null, // No decryption credential + clockSkew = 5.minutes, + replayCache = replayCache, + requireSignedAssertions = false, + requireSignedResponse = false, + requireDestination = false, + allowIdpInitiatedSso = true, + signatureVerifier = SamlSignatureVerifier(idpMetadata) + ) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + val encryptedAssertion = SamlTestUtils.encryptAssertion(assertion, spCredentials.credential) + val response = SamlTestUtils.createTestResponseWithEncryptedAssertion( + encryptedAssertion = encryptedAssertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test decrypt fails with wrong key`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + val otherCredentials = SamlTestUtils.generateTestCredentials() + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + val encryptedAssertion = SamlTestUtils.encryptAssertion(assertion, otherCredentials.credential) + val response = SamlTestUtils.createTestResponseWithEncryptedAssertion( + encryptedAssertion = encryptedAssertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, expectedRequestId = null) + } + } + + @Test + fun `test response signature validation`() = runTest { + val processorRequiresSig = createProcessor(requireSignedAssertions = false, requireSignedResponse = true) + val processorNoSigRequired = createProcessor(requireSignedAssertions = false, requireSignedResponse = false) + + // Valid signed response is accepted + val assertion1 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response1 = SamlTestUtils.createTestResponse(assertion = assertion1, issuerEntityId = IDP_ENTITY_ID) + SamlTestUtils.signResponse(response1, idpCredentials.credential) + assertNotNull(processorRequiresSig.processResponse(SamlTestUtils.encodeResponseToBase64(response1), null)) + + // Unsigned response is rejected when signature required + val assertion2 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response2 = SamlTestUtils.createTestResponse(assertion = assertion2, issuerEntityId = IDP_ENTITY_ID) + assertFailsWith { + processorRequiresSig.processResponse(SamlTestUtils.encodeResponseToBase64(response2), null) + } + + // Test 3: Unsigned response is accepted when signature not required + val assertion3 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response3 = SamlTestUtils.createTestResponse(assertion = assertion3, issuerEntityId = IDP_ENTITY_ID) + assertNotNull(processorNoSigRequired.processResponse(SamlTestUtils.encodeResponseToBase64(response3), null)) + } + + @Test + fun `test response with invalid signature is rejected`() = runTest { + val processor = createProcessor(requireSignedAssertions = false, requireSignedResponse = true) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response = SamlTestUtils.createTestResponse(assertion = assertion, issuerEntityId = IDP_ENTITY_ID) + + val wrongCredentials = SamlTestUtils.generateTestCredentials() + SamlTestUtils.signResponse(response, wrongCredentials.credential) + + assertFailsWith { + processor.processResponse(SamlTestUtils.encodeResponseToBase64(response), null) + } + } + + @Test + fun `test assertion signature validation`() = runTest { + val processorRequiresSig = createProcessor(requireSignedAssertions = true) + val processorNoSigRequired = createProcessor(requireSignedAssertions = false) + + // Valid signed assertion is accepted + val assertion1 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + SamlTestUtils.signAssertion(assertion1, idpCredentials.credential) + val response1 = SamlTestUtils.createTestResponse(assertion = assertion1, issuerEntityId = IDP_ENTITY_ID) + assertNotNull(processorRequiresSig.processResponse(SamlTestUtils.encodeResponseToBase64(response1), null)) + + // Unsigned assertion is rejected when signature required + val assertion2 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + val response2 = SamlTestUtils.createTestResponse(assertion = assertion2, issuerEntityId = IDP_ENTITY_ID) + assertFailsWith { + processorRequiresSig.processResponse(SamlTestUtils.encodeResponseToBase64(response2), null) + } + + // Unsigned assertion is accepted when signature not required + val assertion3 = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response3 = SamlTestUtils.createTestResponse(assertion = assertion3, issuerEntityId = IDP_ENTITY_ID) + assertNotNull(processorNoSigRequired.processResponse(SamlTestUtils.encodeResponseToBase64(response3), null)) + } + + @Test + fun `test assertion with invalid signature is rejected`() = runTest { + val processor = createProcessor(requireSignedAssertions = true) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + + val wrongCredentials = SamlTestUtils.generateTestCredentials() + SamlTestUtils.signAssertion(assertion, wrongCredentials.credential) + + val response = SamlTestUtils.createTestResponse(assertion = assertion, issuerEntityId = IDP_ENTITY_ID) + assertFailsWith { + processor.processResponse(SamlTestUtils.encodeResponseToBase64(response), null) + } + } + + @Test + fun `test reject wrong audience`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = "https://wrong-sp.example.com", // Wrong audience! + recipientUrl = ACS_URL + ) + + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test reject expired assertion`() = runTest { + val processor = createProcessor(requireSignedAssertions = false, clockSkew = 1.minutes) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + notBefore = Clock.System.now() - 1.hours, + notOnOrAfter = Clock.System.now() - 10.minutes + ) + + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test reject assertion not yet valid`() = runTest { + val processor = createProcessor(requireSignedAssertions = false, clockSkew = 1.minutes) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + notBefore = Clock.System.now() + 10.minutes, + notOnOrAfter = Clock.System.now() - 20.minutes + ) + + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test reject response with wrong issuer`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = "https://wrong-idp.example.com" // Wrong issuer! + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test reject response with InResponseTo mismatch`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = "_wrong-request-id" + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + assertFailsWith { + processor.processResponse(base64Response, "_expected-request-id") + } + } + + @Test + fun `test reject error response status`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID + ) + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID, + statusCode = StatusCode.RESPONDER // Error status + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test IdP-initiated SSO handling`() = runTest { + val processorAllowIdpInit = createProcessor(requireSignedAssertions = false, allowIdpInitiatedSso = true) + val processorDisallowIdpInit = createProcessor(requireSignedAssertions = false, allowIdpInitiatedSso = false) + + // IdP-initiated SSO is accepted when allowed + val assertion1 = SamlTestUtils.createTestAssertion( + nameId = "idp-initiated-user@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response1 = SamlTestUtils.createTestResponse( + assertion = assertion1, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = null + ) + val credential1 = processorAllowIdpInit.processResponse( + SamlTestUtils.encodeResponseToBase64(response1), + expectedRequestId = null + ) + assertEquals("idp-initiated-user@example.com", credential1.assertion.subject?.nameID?.value) + + // IdP-initiated SSO is rejected when not allowed + val assertion2 = SamlTestUtils.createTestAssertion( + nameId = "idp-initiated-user@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val response2 = SamlTestUtils.createTestResponse( + assertion = assertion2, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = null + ) + val exception = assertFailsWith { + processorDisallowIdpInit.processResponse( + SamlTestUtils.encodeResponseToBase64(response2), + expectedRequestId = null + ) + } + assertTrue(exception.message!!.contains("IdP-initiated SSO is not allowed")) + + // SP-initiated flow works when IdP-initiated is disabled + val requestId = "_sp-initiated-request-123" + val assertion3 = SamlTestUtils.createTestAssertion( + nameId = "sp-initiated-user@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL, + inResponseTo = requestId + ) + val response3 = SamlTestUtils.createTestResponse( + assertion = assertion3, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = requestId + ) + val credential3 = processorDisallowIdpInit.processResponse( + SamlTestUtils.encodeResponseToBase64(response3), + expectedRequestId = requestId + ) + assertEquals("sp-initiated-user@example.com", credential3.assertion.subject?.nameID?.value) + + // InResponseTo is validated even when IdP-initiated is enabled + val wrongRequestId = "_wrong-response-id" + val assertion4 = SamlTestUtils.createTestAssertion( + nameId = "user@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL, + inResponseTo = wrongRequestId + ) + val response4 = SamlTestUtils.createTestResponse( + assertion = assertion4, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = wrongRequestId + ) + assertFailsWith { + processorAllowIdpInit.processResponse( + SamlTestUtils.encodeResponseToBase64(response4), + expectedRequestId = "_expected-request-id" + ) + } + } + + @Test + fun `test reject replayed assertion`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + val credential = processor.processResponse(base64Response, null) + assertNotNull(credential) + + // The second request with the same assertion should fail (replay) + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test process complete signed and encrypted response`() = runTest { + val processor = createProcessor(requireSignedAssertions = true) + + val requestId = "_test-request-123" + val assertion = SamlTestUtils.createTestAssertion( + nameId = "john.doe@example.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL, + inResponseTo = requestId + ) + + SamlTestUtils.signAssertion(assertion, idpCredentials.credential) + val encryptedAssertion = SamlTestUtils.encryptAssertion(assertion, spCredentials.credential) + val response = SamlTestUtils.createTestResponseWithEncryptedAssertion( + encryptedAssertion = encryptedAssertion, + issuerEntityId = IDP_ENTITY_ID, + inResponseTo = requestId + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + + val credential = processor.processResponse(base64Response, requestId) + + assertEquals("john.doe@example.com", credential.assertion.subject?.nameID?.value) + assertEquals(IDP_ENTITY_ID, credential.assertion.issuer?.value) + } + + @Test + fun `test assertion issuer mismatch`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + val assertion = SamlTestUtils.createTestAssertion( + issuerEntityId = "https://wrong-idp.example.com", // Wrong issuer in assertion! + audienceEntityId = SP_ENTITY_ID + ) + val response = SamlTestUtils.createTestResponse( + assertion = assertion, + issuerEntityId = IDP_ENTITY_ID // Correct issuer in response + ) + val base64Response = SamlTestUtils.encodeResponseToBase64(response) + assertFailsWith { + processor.processResponse(base64Response, null) + } + } + + @Test + fun `test Destination validation`() = runTest { + val processorDefault = createProcessor(requireSignedAssertions = false) + val assertionWrongDest = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val responseWrongDest = SamlTestUtils.createTestResponse( + assertion = assertionWrongDest, + issuerEntityId = IDP_ENTITY_ID, + destination = "https://attacker.example.com/acs" + ) + assertFailsWith { + processorDefault.processResponse(SamlTestUtils.encodeResponseToBase64(responseWrongDest), null) + } + + val processorNoDestRequired = createProcessor(requireSignedAssertions = false, requireDestination = false) + val assertionNoDest = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val responseNoDest = SamlTestUtils.createTestResponse( + assertion = assertionNoDest, + issuerEntityId = IDP_ENTITY_ID, + destination = null + ) + assertNotNull( + processorNoDestRequired.processResponse(SamlTestUtils.encodeResponseToBase64(responseNoDest), null) + ) + + val processorDestRequired = createProcessor(requireSignedAssertions = false, requireDestination = true) + val assertionNoDestRequired = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val responseNoDestRequired = SamlTestUtils.createTestResponse( + assertion = assertionNoDestRequired, + issuerEntityId = IDP_ENTITY_ID, + destination = null + ) + val exception = assertFailsWith { + processorDestRequired.processResponse(SamlTestUtils.encodeResponseToBase64(responseNoDestRequired), null) + } + assertTrue(exception.message!!.contains("Destination")) + + val assertionCorrectDest = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val responseCorrectDest = SamlTestUtils.createTestResponse( + assertion = assertionCorrectDest, + issuerEntityId = IDP_ENTITY_ID, + destination = ACS_URL + ) + assertNotNull(processorDefault.processResponse(SamlTestUtils.encodeResponseToBase64(responseCorrectDest), null)) + } + + @Test + fun `test Recipient validation`() = runTest { + val processor = createProcessor(requireSignedAssertions = false) + + // Wrong recipient is rejected + val assertionWrongRecipient = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = "https://attacker.example.com/acs" + ) + val responseWrongRecipient = SamlTestUtils.createTestResponse( + assertion = assertionWrongRecipient, + issuerEntityId = IDP_ENTITY_ID + ) + assertFailsWith { + processor.processResponse(SamlTestUtils.encodeResponseToBase64(responseWrongRecipient), null) + } + + // Missing recipient is accepted + val assertionNoRecipient = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = null + ) + val responseNoRecipient = SamlTestUtils.createTestResponse( + assertion = assertionNoRecipient, + issuerEntityId = IDP_ENTITY_ID + ) + assertNotNull(processor.processResponse(SamlTestUtils.encodeResponseToBase64(responseNoRecipient), null)) + + // The correct recipient is accepted + val assertionCorrectRecipient = SamlTestUtils.createTestAssertion( + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val responseCorrectRecipient = SamlTestUtils.createTestResponse( + assertion = assertionCorrectRecipient, + issuerEntityId = IDP_ENTITY_ID + ) + assertNotNull(processor.processResponse(SamlTestUtils.encodeResponseToBase64(responseCorrectRecipient), null)) + } + + private fun createProcessor( + requireSignedAssertions: Boolean = true, + requireSignedResponse: Boolean = false, + requireDestination: Boolean = false, + clockSkew: Duration = 5.minutes, + acsUrl: String = ACS_URL, + allowIdpInitiatedSso: Boolean = true + ): SamlResponseProcessor { + val idpMetadata = IdPMetadata( + entityId = IDP_ENTITY_ID, + ssoUrl = "https://idp.example.com/sso", + sloUrl = null, + signingCredentials = listOf(idpCredentials.credential) + ) + return SamlResponseProcessor( + spEntityId = SP_ENTITY_ID, + acsUrl = acsUrl, + idpMetadata = idpMetadata, + decryptionCredential = spCredentials.credential, + clockSkew = clockSkew, + replayCache = replayCache, + requireSignedAssertions = requireSignedAssertions, + requireSignedResponse = requireSignedResponse, + requireDestination = requireDestination, + allowIdpInitiatedSso = allowIdpInitiatedSso, + signatureVerifier = SamlSignatureVerifier(idpMetadata) + ) + } + + companion object { + private const val SP_ENTITY_ID = "https://sp.example.com" + private const val IDP_ENTITY_ID = "https://idp.example.com" + private const val ACS_URL = "https://sp.example.com/saml/acs" + private val idpCredentials: SamlTestUtils.TestCredentials by lazy { + SamlTestUtils.sharedIdpCredentials + } + private val spCredentials: SamlTestUtils.TestCredentials by lazy { + SamlTestUtils.sharedSpCredentials + } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt new file mode 100644 index 00000000000..1641c874775 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt @@ -0,0 +1,539 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.server.auth.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.sessions.* +import io.ktor.server.testing.* +import java.io.File +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.ExperimentalTime + +/** + * Integration tests for SAML authentication. + */ +@OptIn(ExperimentalTime::class) +class SamlAuthTest { + + private fun ApplicationTestBuilder.noRedirectsClient() = createClient { followRedirects = false } + + @Test + fun `test unauthenticated request redirects to IdP`() = testApplication { + configureSamlAuth() + + val response = noRedirectsClient().get("/protected") + + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertNotNull(location) + assertTrue(location.startsWith(IDP_SSO_URL), "Should redirect to IdP SSO URL") + + // Verify AuthnRequest parameters in the redirect URL + val url = Url(location) + assertNotNull(url.parameters["SAMLRequest"], "SAMLRequest parameter should be present") + assertEquals("/protected", url.parameters["RelayState"], "RelayState should contain original URL") + } + + @Test + fun `test redirect contains correct RelayState`() = testApplication { + configureSamlAuth() + + val response = noRedirectsClient().get("/protected/resource?param=value") + + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertNotNull(location) + + val url = Url(location) + assertEquals("/protected/resource?param=value", url.parameters["RelayState"]) + } + + @Test + fun `test unauthenticated request with HTTP-POST binding returns auto-submit form`() = testApplication { + configureSamlAuth(authnRequestBinding = SamlBinding.HttpPost) + + val response = client.get("/protected") + + assertEquals(HttpStatusCode.OK, response.status) + assertEquals(ContentType.Text.Html.withCharset(Charsets.UTF_8), response.contentType()) + + val body = response.bodyAsText() + assertTrue(body.contains("
+ call.respond(HttpStatusCode.Forbidden, "Custom challenge: $cause") + } + ) + + // Missing SAMLResponse should trigger a challenge + val response = client.post(ACS_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("RelayState=/some-page") + } + + assertEquals(HttpStatusCode.Forbidden, response.status) + assertTrue(response.bodyAsText().contains("Custom challenge")) + } + + @Test + fun `test validation function`() = testApplication { + configureSamlAuth( + wantAssertionsSigned = false, + validateFunction = { credential -> + // Only accept users from a specific domain + val nameId = credential.assertion.subject?.nameID?.value + if (nameId?.endsWith("@allowed.com") == true) { + SamlPrincipal(credential.assertion) + } else { + null // Reject + } + } + ) + + val validAssertion = SamlTestUtils.createTestAssertion( + nameId = "user@allowed.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val validSamlResponse = SamlTestUtils.createTestResponse( + assertion = validAssertion, + issuerEntityId = IDP_ENTITY_ID + ) + val validBase64 = SamlTestUtils.encodeResponseToBase64(validSamlResponse) + val validResponse = client.post(ACS_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${validBase64.encodeURLParameter()}") + } + + assertEquals(HttpStatusCode.OK, validResponse.status) + assertEquals("Hello, user@allowed.com", validResponse.bodyAsText()) + + val invalidAssertion = SamlTestUtils.createTestAssertion( + nameId = "user@notallowed.com", + issuerEntityId = IDP_ENTITY_ID, + audienceEntityId = SP_ENTITY_ID, + recipientUrl = ACS_URL + ) + val invalidSamlResponse = SamlTestUtils.createTestResponse( + assertion = invalidAssertion, + issuerEntityId = IDP_ENTITY_ID + ) + val invalidBase64 = SamlTestUtils.encodeResponseToBase64(invalidSamlResponse) + + val response = client.post(ACS_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${invalidBase64.encodeURLParameter()}") + } + assertEquals(HttpStatusCode.Unauthorized, response.status) + } + + @Test + fun `test SP-initiated authentication flow redirect`() = testApplication { + configureSamlAuth(wantAssertionsSigned = false, allowIdpInitiatedSso = false) + + // Trigger SP-initiated flow + val challengeResponse = noRedirectsClient().get("/protected") + assertEquals(HttpStatusCode.Found, challengeResponse.status) + + // Verify the session cookie is set + val sessionCookie = challengeResponse.headers[HttpHeaders.SetCookie] + assertNotNull(sessionCookie, "SAML session cookie should be set") + + // Verify redirect to IdP with SAMLRequest + val redirectUrl = Url(challengeResponse.headers[HttpHeaders.Location]!!) + assertTrue(redirectUrl.toString().startsWith(IDP_SSO_URL), "Should redirect to IdP SSO URL") + assertNotNull(redirectUrl.parameters["SAMLRequest"], "SAMLRequest should be in redirect URL") + } + + @Test + fun `test multiple SAML providers`() = testApplication { + // Reuse shared credentials for IDP1, generate a unique one for IDP2 to test isolation + val idp1Credentials = idpCredentials + val idp2Credentials = SamlTestUtils.generateTestCredentials() + + val idp1Metadata = IdPMetadata( + entityId = "https://idp1.example.com", + ssoUrl = "https://idp1.example.com/sso", + sloUrl = null, + signingCredentials = listOf(idp1Credentials.credential) + ) + + val idp2Metadata = IdPMetadata( + entityId = "https://idp2.example.com", + ssoUrl = "https://idp2.example.com/sso", + sloUrl = null, + signingCredentials = listOf(idp2Credentials.credential) + ) + + // Sessions plugin is required for SAML auth to store the request ID + install(Sessions) { + cookie("SAML_SESSION") + } + + install(Authentication) { + saml("idp1") { + sp = SamlSpMetadata { + spEntityId = SP_ENTITY_ID + acsUrl = "http://localhost/saml/acs/idp1" + wantAssertionsSigned = false + } + idp = idp1Metadata + allowIdpInitiatedSso = true + requireDestination = false + validate { credential -> + SamlPrincipal(credential.assertion) + } + } + saml("idp2") { + sp = SamlSpMetadata { + spEntityId = SP_ENTITY_ID + acsUrl = "http://localhost/saml/acs/idp2" + wantAssertionsSigned = false + } + idp = idp2Metadata + allowIdpInitiatedSso = true + requireDestination = false + validate { credential -> + SamlPrincipal(credential.assertion) + } + } + } + + routing { + authenticate("idp1") { + get("/protected/idp1") { + val principal = call.principal()!! + call.respondText("IDP1: ${principal.nameId}") + } + post("/saml/acs/idp1") { + val principal = call.principal()!! + call.respondText("IDP1: ${principal.nameId}") + } + } + authenticate("idp2") { + get("/protected/idp2") { + val principal = call.principal()!! + call.respondText("IDP2: ${principal.nameId}") + } + post("/saml/acs/idp2") { + val principal = call.principal()!! + call.respondText("IDP2: ${principal.nameId}") + } + } + } + + // Test IDP1 + val assertion1 = SamlTestUtils.createTestAssertion( + nameId = "user1@idp1.com", + issuerEntityId = "https://idp1.example.com", + audienceEntityId = SP_ENTITY_ID, + recipientUrl = "http://localhost/saml/acs/idp1" + ) + val samlResponse1 = SamlTestUtils.createTestResponse( + assertion = assertion1, + issuerEntityId = "https://idp1.example.com" + ) + val base64Response1 = SamlTestUtils.encodeResponseToBase64(samlResponse1) + + val response1 = client.post("/saml/acs/idp1") { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${base64Response1.encodeURLParameter()}") + } + assertEquals(HttpStatusCode.OK, response1.status) + assertEquals("IDP1: user1@idp1.com", response1.bodyAsText()) + + // Test IDP2 + val assertion2 = SamlTestUtils.createTestAssertion( + nameId = "user2@idp2.com", + issuerEntityId = "https://idp2.example.com", + audienceEntityId = SP_ENTITY_ID, + recipientUrl = "http://localhost/saml/acs/idp2" + ) + val samlResponse2 = SamlTestUtils.createTestResponse( + assertion = assertion2, + issuerEntityId = "https://idp2.example.com" + ) + val base64Response2 = SamlTestUtils.encodeResponseToBase64(samlResponse2) + + val response2 = client.post("/saml/acs/idp2") { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${base64Response2.encodeURLParameter()}") + } + assertEquals(HttpStatusCode.OK, response2.status) + assertEquals("IDP2: user2@idp2.com", response2.bodyAsText()) + } + + private fun ApplicationTestBuilder.configureSamlAuth( + wantAssertionsSigned: Boolean = false, + useKeyStore: Boolean = true, + allowIdpInitiatedSso: Boolean = true, + requireDestination: Boolean = false, + authnRequestBinding: SamlBinding = SamlBinding.HttpRedirect, + customChallenge: (suspend SamlChallengeContext.(AuthenticationFailedCause) -> Unit)? = null, + validateFunction: (suspend io.ktor.server.application.ApplicationCall.(SamlCredential) -> Any?)? = null + ) { + install(Sessions) { + cookie("SAML_SESSION") + } + + install(Authentication) { + saml("saml-auth") { + sp = SamlSpMetadata { + spEntityId = SP_ENTITY_ID + acsUrl = ACS_URL + this.wantAssertionsSigned = wantAssertionsSigned + + if (useKeyStore) { + signingCredential = SamlCrypto.loadCredential( + keystorePath = spKeyStoreFile.absolutePath, + keystorePassword = "test-pass", + keyAlias = "sp-key", + keyPassword = "test-pass" + ) + } + } + idp = IdPMetadata( + entityId = IDP_ENTITY_ID, + ssoUrl = IDP_SSO_URL, + sloUrl = null, + signingCredentials = listOf(idpCredentials.credential) + ) + this.allowIdpInitiatedSso = allowIdpInitiatedSso + this.requireDestination = requireDestination + this.authnRequestBinding = authnRequestBinding + + validate( + validateFunction ?: { credential -> + SamlPrincipal(credential.assertion) + } + ) + + if (customChallenge != null) { + challenge(customChallenge) + } + } + } + + routing { + authenticate("saml-auth") { + get("/protected") { + val principal = call.principal()!! + call.respondText("Hello, ${principal.nameId}") + } + get("/protected/{path...}") { + val principal = call.principal()!! + call.respondText("Hello, ${principal.nameId}") + } + post(ACS_PATH) { + val principal = call.principal()!! + call.respondText("Hello, ${principal.nameId}") + } + } + } + } + + companion object { + private const val SP_ENTITY_ID = "https://sp.example.com" + private const val IDP_ENTITY_ID = "https://idp.example.com" + private const val ACS_PATH = "/saml/acs" + private const val ACS_URL = "http://localhost$ACS_PATH" + private const val IDP_SSO_URL = "https://idp.example.com/sso" + + private val idpCredentials: SamlTestUtils.TestCredentials by lazy { + SamlTestUtils.sharedIdpCredentials + } + private val spCredentials: SamlTestUtils.TestCredentials by lazy { + SamlTestUtils.sharedSpCredentials + } + + private val spKeyStoreFile: File by lazy { + File.createTempFile("sp-keystore", ".jks").also { file -> + file.deleteOnExit() + spCredentials.saveToKeyStore( + file = file, + storePassword = "test-pass", + keyAlias = "sp-key", + keyPassword = "test-pass" + ) + } + } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt new file mode 100644 index 00000000000..4cd0964f8b2 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt @@ -0,0 +1,355 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.network.tls.certificates.* +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport +import org.opensaml.core.xml.io.MarshallingException +import org.opensaml.core.xml.schema.XSString +import org.opensaml.saml.saml2.core.* +import org.opensaml.saml.saml2.encryption.Encrypter +import org.opensaml.security.credential.Credential +import org.opensaml.security.x509.BasicX509Credential +import org.opensaml.xmlsec.encryption.support.DataEncryptionParameters +import org.opensaml.xmlsec.encryption.support.EncryptionConstants +import org.opensaml.xmlsec.encryption.support.KeyEncryptionParameters +import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory +import org.opensaml.xmlsec.signature.Signature +import org.opensaml.xmlsec.signature.support.SignatureConstants +import org.opensaml.xmlsec.signature.support.Signer +import java.security.KeyPair +import java.security.KeyStore +import java.security.PrivateKey +import java.security.cert.X509Certificate +import kotlin.io.encoding.Base64 +import kotlin.time.Clock +import kotlin.time.Duration.Companion.seconds +import kotlin.time.ExperimentalTime +import kotlin.time.Instant +import kotlin.time.toJavaInstant +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Test utilities for SAML tests. + */ +@OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) +object SamlTestUtils { + + /** + * Shared IDP credentials for tests that don't need unique credentials. + */ + val sharedIdpCredentials: TestCredentials by lazy { generateTestCredentials() } + + /** + * Shared SP credentials for tests that don't need unique credentials. + */ + val sharedSpCredentials: TestCredentials by lazy { generateTestCredentials() } + + /** + * Test credential holder for signing and encryption operations. + */ + data class TestCredentials( + val credential: BasicX509Credential, + val keyPair: KeyPair, + val certificate: X509Certificate + ) { + /** + * Saves these credentials to a keystore file. + */ + fun saveToKeyStore(file: java.io.File, storePassword: String, keyAlias: String, keyPassword: String) { + val keyStore = KeyStore.getInstance("JKS") + keyStore.load(null, storePassword.toCharArray()) + keyStore.setKeyEntry( + keyAlias, + keyPair.private, + keyPassword.toCharArray(), + arrayOf(certificate) + ) + keyStore.saveToFile(file, storePassword) + } + } + + private const val KEY_ALIAS = "test_key" + private const val KEY_PASSWORD = "test_pass" + + fun generateTestCredentials(): TestCredentials = generateRsaTestCredentials() + + /** + * Generates RSA test credentials for signing and encryption. + */ + fun generateRsaTestCredentials(): TestCredentials { + val keyStore = generateCertificate( + algorithm = "SHA256withRSA", + keyAlias = KEY_ALIAS, + keyPassword = KEY_PASSWORD, + keySizeInBits = 2048 + ) + + val privateKey = keyStore.getKey(KEY_ALIAS, KEY_PASSWORD.toCharArray()) as PrivateKey + val certificate = keyStore.getCertificate(KEY_ALIAS) as X509Certificate + val keyPair = KeyPair(certificate.publicKey, privateKey) + + val credential = BasicX509Credential(certificate, privateKey) + credential.entityId = "test-entity" + + return TestCredentials(credential, keyPair, certificate) + } + + /** + * Signs a SAML assertion using the provided credential. + */ + fun signAssertion(assertion: Assertion, credential: Credential) { + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val keyInfoGeneratorFactory = X509KeyInfoGeneratorFactory() + keyInfoGeneratorFactory.setEmitEntityCertificate(true) + val keyInfoGenerator = keyInfoGeneratorFactory.newInstance() + + val signature = builderFactory.buildXmlObject(Signature.DEFAULT_ELEMENT_NAME) { + signingCredential = credential + signatureAlgorithm = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256 + canonicalizationAlgorithm = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS + keyInfo = keyInfoGenerator.generate(credential) + } + + assertion.signature = signature + + val marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory() + val marshaller = marshallerFactory.getMarshaller(assertion) + ?: throw MarshallingException("No marshaller found for Assertion") + marshaller.marshall(assertion) + + Signer.signObject(signature) + } + + /** + * Signs a SAML response using the provided credential. + */ + fun signResponse(response: Response, credential: Credential) { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val keyInfoGeneratorFactory = X509KeyInfoGeneratorFactory() + keyInfoGeneratorFactory.setEmitEntityCertificate(true) + val keyInfoGenerator = keyInfoGeneratorFactory.newInstance() + + val signature = builderFactory.buildXmlObject(Signature.DEFAULT_ELEMENT_NAME) { + signingCredential = credential + signatureAlgorithm = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256 + canonicalizationAlgorithm = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS + keyInfo = keyInfoGenerator.generate(credential) + } + + response.signature = signature + + val marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory() + val marshaller = marshallerFactory.getMarshaller(response) + ?: throw MarshallingException("No marshaller found for Response") + marshaller.marshall(response) + + Signer.signObject(signature) + } + + /** + * Encrypts a SAML assertion using the provided credential. + */ + fun encryptAssertion(assertion: Assertion, encryptionCredential: Credential): EncryptedAssertion { + LibSaml.ensureInitialized() + val dataEncryptionParams = DataEncryptionParameters() + dataEncryptionParams.algorithm = EncryptionConstants.ALGO_ID_BLOCKCIPHER_AES256_GCM + + val keyEncryptionParams = KeyEncryptionParameters() + keyEncryptionParams.encryptionCredential = encryptionCredential + keyEncryptionParams.algorithm = EncryptionConstants.ALGO_ID_KEYTRANSPORT_RSAOAEP + + val keyInfoGeneratorFactory = X509KeyInfoGeneratorFactory() + keyInfoGeneratorFactory.setEmitEntityCertificate(true) + keyEncryptionParams.keyInfoGenerator = keyInfoGeneratorFactory.newInstance() + + val encrypter = Encrypter(dataEncryptionParams, keyEncryptionParams) + encrypter.keyPlacement = Encrypter.KeyPlacement.PEER + + return encrypter.encrypt(assertion) + } + + fun encodeResponseToBase64(response: Response): String { + return Base64.encode(source = response.marshalToString().toByteArray()) + } + + /** + * Creates a minimal valid SAML assertion for testing. + */ + fun createTestAssertion( + nameId: String = "test-user@example.com", + issuerEntityId: String = "test-idp", + audienceEntityId: String? = null, + recipientUrl: String? = null, + inResponseTo: String? = null, + attributes: Map> = emptyMap(), + sessionIndex: String? = "session-" + Uuid.random().toString(), + notBefore: Instant = Clock.System.now() - 60.seconds, + notOnOrAfter: Instant = Clock.System.now() + 300.seconds + ): Assertion { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = issuerEntityId + } + + val nameIDObj = builderFactory.build(NameID.DEFAULT_ELEMENT_NAME) { + value = nameId + format = NameIDType.EMAIL + } + + // Note: SubjectConfirmationData for Bearer method should NOT have NotBefore + // per SAML 2.0 spec (section 2.4.1.2). Only NotOnOrAfter is allowed. + val subjectConfirmationData = builderFactory.build( + SubjectConfirmationData.DEFAULT_ELEMENT_NAME + ) { + this.notOnOrAfter = notOnOrAfter.toJavaInstant() + recipientUrl?.let { this.recipient = it } + inResponseTo?.let { this.inResponseTo = it } + } + + val subjectConfirmation = builderFactory.build(SubjectConfirmation.DEFAULT_ELEMENT_NAME) { + method = SubjectConfirmation.METHOD_BEARER + this.subjectConfirmationData = subjectConfirmationData + } + + val subject = builderFactory.build(Subject.DEFAULT_ELEMENT_NAME) { + this.nameID = nameIDObj + subjectConfirmations.add(subjectConfirmation) + } + + val conditions = builderFactory.build(Conditions.DEFAULT_ELEMENT_NAME) { + this.notBefore = notBefore.toJavaInstant() + this.notOnOrAfter = notOnOrAfter.toJavaInstant() + + // Add AudienceRestriction if audience is provided + if (audienceEntityId != null) { + val audience = builderFactory.build(Audience.DEFAULT_ELEMENT_NAME) { + uri = audienceEntityId + } + val audienceRestriction = builderFactory.build( + AudienceRestriction.DEFAULT_ELEMENT_NAME + ) { + audiences.add(audience) + } + audienceRestrictions.add(audienceRestriction) + } + } + + val authnStatement = sessionIndex?.let { + builderFactory.build(AuthnStatement.DEFAULT_ELEMENT_NAME) { + this.authnInstant = Clock.System.now().toJavaInstant() + this.sessionIndex = it + } + } + + val attributeStatement = if (attributes.isNotEmpty()) { + builderFactory.build(AttributeStatement.DEFAULT_ELEMENT_NAME) { + attributes.forEach { (name, values) -> + val attribute = builderFactory.build(Attribute.DEFAULT_ELEMENT_NAME) { + this.name = name + } + values.forEach { value -> + val v = builderFactory.buildXmlObject(AttributeValue.DEFAULT_ELEMENT_NAME) { + this.value = value + } + attribute.attributeValues.add(v) + } + this.attributes.add(attribute) + } + } + } else { + null + } + + return builderFactory.build(Assertion.DEFAULT_ELEMENT_NAME) { + id = generateSecureSamlId() + issueInstant = Clock.System.now().toJavaInstant() + this.issuer = issuer + this.subject = subject + this.conditions = conditions + authnStatement?.let { authnStatements.add(it) } + attributeStatement?.let { attributeStatements.add(it) } + } + } + + fun createTestResponse( + assertion: Assertion, + issuerEntityId: String = "test-idp", + inResponseTo: String? = null, + statusCode: String = StatusCode.SUCCESS, + destination: String? = null + ): Response { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = issuerEntityId + } + + val statusCodeObj = builderFactory.build(StatusCode.DEFAULT_ELEMENT_NAME) { + value = statusCode + } + + val status = builderFactory.build(Status.DEFAULT_ELEMENT_NAME) { + this.statusCode = statusCodeObj + } + + return builderFactory.build(Response.DEFAULT_ELEMENT_NAME) { + id = generateSecureSamlId() + issueInstant = Clock.System.now().toJavaInstant() + inResponseTo?.let { this.inResponseTo = it } + destination?.let { this.destination = it } + this.issuer = issuer + this.status = status + assertions.add(assertion) + } + } + + /** + * Creates a test SAML response with an encrypted assertion. + * + * @param encryptedAssertion The encrypted assertion to include + * @param issuerEntityId The entity ID of the issuer (IdP) + * @param inResponseTo The InResponseTo value (request ID) + * @param destination The Destination attribute (ACS URL) + */ + fun createTestResponseWithEncryptedAssertion( + encryptedAssertion: EncryptedAssertion, + issuerEntityId: String = "test-idp", + inResponseTo: String? = null, + destination: String? = null + ): Response { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = issuerEntityId + } + + val statusCode = builderFactory.build(StatusCode.DEFAULT_ELEMENT_NAME) { + value = StatusCode.SUCCESS + } + + val status = builderFactory.build(Status.DEFAULT_ELEMENT_NAME) { + this.statusCode = statusCode + } + + return builderFactory.build(Response.DEFAULT_ELEMENT_NAME) { + id = generateSecureSamlId() + issueInstant = Clock.System.now().toJavaInstant() + inResponseTo?.let { this.inResponseTo = it } + destination?.let { this.destination = it } + this.issuer = issuer + this.status = status + encryptedAssertions.add(encryptedAssertion) + } + } +} From 4716d3c869ab5be993f076e3ed077dd50825417f Mon Sep 17 00:00:00 2001 From: zibet27 Date: Fri, 6 Mar 2026 10:34:07 +0100 Subject: [PATCH 2/3] do not wrap saml validation exceptions --- .../jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt | 1 + .../jvm/src/io/ktor/server/auth/saml/SamlUtils.kt | 2 ++ .../test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt index c6121c6eee0..db5b730e17b 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlSignatureValidator.kt @@ -108,6 +108,7 @@ internal class SamlSignatureVerifier( throw SamlValidationException("Invalid Base64 signature", e) } + // signature should be the last parameter val signatureIdx = queryString.indexOf("&Signature=") val queryStringWithoutSignature = if (signatureIdx >= 0) { queryString.substring(0, signatureIdx).toByteArray(Charsets.UTF_8) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt index 6f023746803..b76434e8998 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt @@ -210,6 +210,8 @@ internal inline fun samlAssert(value: Boolean, crossinline lazyMessage: () -> St internal inline fun withValidationException(crossinline block: () -> T): T { try { return block() + } catch (e: SamlValidationException) { + throw e } catch (e: Exception) { throw SamlValidationException("SAML validation failed", e) } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt index 77a8ce5a2ce..18a9f0a4da6 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt @@ -266,7 +266,7 @@ class SamlResponseProcessorTest { issuerEntityId = IDP_ENTITY_ID, audienceEntityId = SP_ENTITY_ID, notBefore = Clock.System.now() + 10.minutes, - notOnOrAfter = Clock.System.now() - 20.minutes + notOnOrAfter = Clock.System.now() + 20.minutes ) val response = SamlTestUtils.createTestResponse( From fb8a5ae4bda0afb2d64c7af9039def3c338438ec Mon Sep 17 00:00:00 2001 From: zibet27 Date: Tue, 10 Mar 2026 13:27:00 +0100 Subject: [PATCH 3/3] fix IdPMetadata constructor usage --- .../auth/saml/SAMLResponseProcessorTest.kt | 20 ++++++------- .../io/ktor/server/auth/saml/SamlAuthTest.kt | 30 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt index 18a9f0a4da6..1427d949f38 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SAMLResponseProcessorTest.kt @@ -52,12 +52,12 @@ class SamlResponseProcessorTest { @Test fun `test decrypt fails without decryption credential`() = runTest { - val idpMetadata = IdPMetadata( - entityId = IDP_ENTITY_ID, - ssoUrl = "https://idp.example.com/sso", - sloUrl = null, + val idpMetadata = IdPMetadata { + entityId = IDP_ENTITY_ID + ssoUrl = "https://idp.example.com/sso" + sloUrl = null signingCredentials = listOf(idpCredentials.credential) - ) + } val processor = SamlResponseProcessor( spEntityId = SP_ENTITY_ID, acsUrl = ACS_URL, @@ -606,12 +606,12 @@ class SamlResponseProcessorTest { acsUrl: String = ACS_URL, allowIdpInitiatedSso: Boolean = true ): SamlResponseProcessor { - val idpMetadata = IdPMetadata( - entityId = IDP_ENTITY_ID, - ssoUrl = "https://idp.example.com/sso", - sloUrl = null, + val idpMetadata = IdPMetadata { + entityId = IDP_ENTITY_ID + ssoUrl = "https://idp.example.com/sso" + sloUrl = null signingCredentials = listOf(idpCredentials.credential) - ) + } return SamlResponseProcessor( spEntityId = SP_ENTITY_ID, acsUrl = acsUrl, diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt index 1641c874775..182b00607c4 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlAuthTest.kt @@ -329,19 +329,19 @@ class SamlAuthTest { val idp1Credentials = idpCredentials val idp2Credentials = SamlTestUtils.generateTestCredentials() - val idp1Metadata = IdPMetadata( - entityId = "https://idp1.example.com", - ssoUrl = "https://idp1.example.com/sso", - sloUrl = null, + val idp1Metadata = IdPMetadata { + entityId = "https://idp1.example.com" + ssoUrl = "https://idp1.example.com/sso" + sloUrl = null signingCredentials = listOf(idp1Credentials.credential) - ) + } - val idp2Metadata = IdPMetadata( - entityId = "https://idp2.example.com", - ssoUrl = "https://idp2.example.com/sso", - sloUrl = null, + val idp2Metadata = IdPMetadata { + entityId = "https://idp2.example.com" + ssoUrl = "https://idp2.example.com/sso" + sloUrl = null signingCredentials = listOf(idp2Credentials.credential) - ) + } // Sessions plugin is required for SAML auth to store the request ID install(Sessions) { @@ -470,12 +470,12 @@ class SamlAuthTest { ) } } - idp = IdPMetadata( - entityId = IDP_ENTITY_ID, - ssoUrl = IDP_SSO_URL, - sloUrl = null, + idp = IdPMetadata { + entityId = IDP_ENTITY_ID + ssoUrl = IDP_SSO_URL + sloUrl = null signingCredentials = listOf(idpCredentials.credential) - ) + } this.allowIdpInitiatedSso = allowIdpInitiatedSso this.requireDestination = requireDestination this.authnRequestBinding = authnRequestBinding