From a01fcab588ae6e47376bf2398e28ff7356882afc Mon Sep 17 00:00:00 2001 From: zibet27 Date: Mon, 2 Mar 2026 22:10:54 +0100 Subject: [PATCH 1/4] Server SAML Auth. Logout support. --- .../api/ktor-server-auth-saml.api | 4 + .../src/io/ktor/server/auth/saml/SamlAuth.kt | 76 +++ .../auth/saml/SamlAuthenticationProvider.kt | 140 ++++++ .../server/auth/saml/SamlLogoutBuilder.kt | 151 ++++++ .../server/auth/saml/SamlLogoutProcessor.kt | 215 ++++++++ .../src/io/ktor/server/auth/saml/SamlUtils.kt | 18 + .../auth/saml/SamlLogoutIntegrationTest.kt | 375 ++++++++++++++ .../ktor/server/auth/saml/SamlLogoutTest.kt | 468 ++++++++++++++++++ .../test/io/ktor/server/auth/saml/TestUtil.kt | 282 +++++++++++ 9 files changed, 1729 insertions(+) create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutBuilder.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api index 9f150ad2d22..e2b909ce89e 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api @@ -68,6 +68,10 @@ public final class io/ktor/server/auth/saml/SamlAuthKt { public static final fun saml (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static synthetic fun saml$default (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static synthetic fun saml$default (Lio/ktor/server/auth/AuthenticationConfig;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun samlLogout (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/auth/saml/SamlPrincipal;Lio/ktor/server/auth/saml/SamlSpMetadata;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SignatureAlgorithm;)Lio/ktor/server/auth/saml/SamlRedirectResult; + public static final fun samlLogout (Lio/ktor/server/application/ApplicationCall;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SamlSpMetadata;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SignatureAlgorithm;)Lio/ktor/server/auth/saml/SamlRedirectResult; + public static synthetic fun samlLogout$default (Lio/ktor/server/application/ApplicationCall;Lio/ktor/server/auth/saml/SamlPrincipal;Lio/ktor/server/auth/saml/SamlSpMetadata;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SignatureAlgorithm;ILjava/lang/Object;)Lio/ktor/server/auth/saml/SamlRedirectResult; + public static synthetic fun samlLogout$default (Lio/ktor/server/application/ApplicationCall;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SamlSpMetadata;Ljava/lang/String;Ljava/lang/String;Lio/ktor/server/auth/saml/SignatureAlgorithm;ILjava/lang/Object;)Lio/ktor/server/auth/saml/SamlRedirectResult; } public final class io/ktor/server/auth/saml/SamlAuthenticationProvider : io/ktor/server/auth/AuthenticationProvider { diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt index 19bc0b77522..98aa40645d9 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuth.kt @@ -4,7 +4,9 @@ package io.ktor.server.auth.saml +import io.ktor.server.application.* import io.ktor.server.auth.* +import io.ktor.server.sessions.* /** * Installs SAML 2.0 authentication. @@ -97,3 +99,77 @@ public fun AuthenticationConfig.saml( register(provider) } + +/** + * Initiates SP-initiated SAML Single Logout. + * + * @param principal The authenticated SAML principal containing NameID and session info + * @param spMetadata The Service Provider metadata containing the entity ID and signing credential + * @param idpSloUrl The IdP's SLO URL + * @param relayState Optional URL to redirect to after logout completes + * @param signatureAlgorithm The signature algorithm to use for signing the LogoutRequest + * @return [SamlRedirectResult] containing the request ID and redirect URL to the IdP + */ +public fun ApplicationCall.samlLogout( + principal: SamlPrincipal, + spMetadata: SamlSpMetadata, + idpSloUrl: String, + relayState: String? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 +): SamlRedirectResult = samlLogout( + nameId = principal.nameId, + idpSloUrl = idpSloUrl, + spMetadata = spMetadata, + sessionIndex = principal.sessionIndex, + relayState = relayState, + signatureAlgorithm = signatureAlgorithm +) + +/** + * Initiates SP-initiated SAML Single Logout with explicit NameID. + * + * @param nameId The NameID of the user to log out + * @param idpSloUrl The IdP's SLO URL + * @param spMetadata The Service Provider metadata containing the entity ID and signing credential + * @param sessionIndex The session index from the AuthnStatement (optional but recommended) + * @param relayState Optional URL to redirect to after logout completes + * @param signatureAlgorithm The signature algorithm to use for signing the LogoutRequest + * @return [SamlRedirectResult] containing the request ID and redirect URL to the IdP + */ +public fun ApplicationCall.samlLogout( + nameId: String, + idpSloUrl: String, + spMetadata: SamlSpMetadata, + sessionIndex: String?, + relayState: String? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 +): SamlRedirectResult { + LibSaml.ensureInitialized() + + val spEntityId = spMetadata.spEntityId + require(!spEntityId.isNullOrBlank()) { "spEntityId must not be blank for logout" } + require(nameId.isNotBlank()) { "nameId must not be blank for logout" } + require(idpSloUrl.isNotBlank()) { "idpSloUrl must not be blank for logout" } + + val result = buildLogoutRequestRedirect( + nameId = nameId, + idpSloUrl = idpSloUrl, + relayState = relayState, + spEntityId = spEntityId, + sessionIndex = sessionIndex, + signingCredential = spMetadata.signingCredential, + signatureAlgorithm = signatureAlgorithm + ) + + // Store the logout request ID in the session for InResponseTo validation + val currentSession = checkNotNull(sessions.get()) { + "No current session found. Did you forget to call authenticate() or sessions.install()?" + } + val newSession = SamlSession( + requestId = currentSession.requestId, + logoutRequestId = result.messageId + ) + sessions.set(newSession) + + return result +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt index 9314da42d26..0f26818d588 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt @@ -11,6 +11,7 @@ import io.ktor.server.response.* import io.ktor.server.sessions.* import kotlinx.coroutines.CancellationException import kotlinx.serialization.Serializable +import org.opensaml.saml.saml2.core.StatusCode import org.opensaml.security.x509.BasicX509Credential import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -137,6 +138,21 @@ public class SamlAuthenticationProvider internal constructor( private val acsPath = Url(acsUrl).encodedPath private val sloPath = if (isAbsoluteUrl(sloUrl)) Url(sloUrl).encodedPath else sloUrl + private val enableSingleLogout = config.enableSingleLogout + private val logoutProcessor: SamlLogoutProcessor? by lazy { + if (!enableSingleLogout) return@lazy null + SamlLogoutProcessor( + sloUrl = sloUrl, + idpMetadata = idpMetadata, + replayCache = replayCache, + clockSkew = config.clockSkew, + signatureVerifier = signatureVerifier, + requireDestination = config.requireDestination, + requireSignedLogoutRequest = config.requireSignedLogoutRequest, + requireSignedLogoutResponse = config.requireSignedResponse, + ) + } + init { require(config.clockSkew.isPositive()) { "clockSkew must be positive, got: ${config.clockSkew}" } require(acsPath != sloPath) { "acsPath and sloPath must be different, got: $acsPath" } @@ -148,6 +164,7 @@ public class SamlAuthenticationProvider internal constructor( request.httpMethod == HttpMethod.Post && request.path() == acsPath -> context.handleSamlCallback() + enableSingleLogout && request.path() == sloPath -> context.handleSloEndpoint() else -> context.handleChallenge() } } @@ -255,6 +272,123 @@ public class SamlAuthenticationProvider internal constructor( } } } + + /** + * Handles the Single Logout (SLO) endpoint. + * Routes to either LogoutRequest or LogoutResponse handling based on the parameters. + */ + private suspend fun AuthenticationContext.handleSloEndpoint() { + try { + val parameters = when (call.request.httpMethod) { + HttpMethod.Get -> call.request.queryParameters + HttpMethod.Post -> call.receiveParameters() + else -> { + logger.debug("SLO endpoint called with unsupported method: {}", call.request.httpMethod) + return call.respond(HttpStatusCode.MethodNotAllowed) + } + } + + val samlRequest = parameters["SAMLRequest"] + val samlResponse = parameters["SAMLResponse"] + + when { + samlRequest != null -> handleIdpLogoutRequest(samlRequest, parameters) + samlResponse != null -> handleLogoutResponse(samlResponse, parameters) + else -> { + logger.debug("SLO endpoint called without SAMLRequest or SAMLResponse") + call.respond(HttpStatusCode.BadRequest, "Missing SAMLRequest or SAMLResponse parameter") + } + } + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.debug("SLO endpoint error", e) + call.respond(HttpStatusCode.InternalServerError, "Logout processing failed") + } + } + + /** + * Handles IdP-initiated LogoutRequest. + * Processes the request, terminates the local session, and sends back a LogoutResponse. + */ + private suspend fun AuthenticationContext.handleIdpLogoutRequest( + samlRequestBase64: String, + parameters: Parameters + ) { + val processor = checkNotNull(logoutProcessor) + try { + val logoutRequest = processor.processRequest( + samlRequestBase64 = samlRequestBase64, + binding = call.request.samlBinding(), + queryString = call.request.queryString(), + signatureParam = parameters["Signature"], + signatureAlgorithmParam = parameters["SigAlg"], + ) + call.sessions.clear() + + val idpSloUrl = requireNotNull(idpMetadata.sloUrl) { "IdP SLO URL not found" } + val logoutResponse = buildLogoutResponseRedirect( + spEntityId = spEntityId, + idpSloUrl = idpSloUrl, + inResponseTo = logoutRequest.requestId, + statusCodeValue = StatusCode.SUCCESS, + relayState = parameters["RelayState"], + signingCredential = signingCredential, + signatureAlgorithm = config.signatureAlgorithm + ) + + call.respondRedirect(logoutResponse.redirectUrl) + } catch (e: SamlValidationException) { + logger.error("IdP LogoutRequest validation failed", e) + call.respond(HttpStatusCode.BadRequest, "Invalid logout request") + } + } + + /** + * Handles LogoutResponse from IdP (after SP-initiated logout). + * Validates the response and completes the logout process. + */ + private suspend fun AuthenticationContext.handleLogoutResponse( + samlResponseBase64: String, + parameters: Parameters + ) { + val processor = checkNotNull(logoutProcessor) { "Logout response processor not initialized" } + + try { + val session = call.sessions.get() + val expectedRequestId = session?.logoutRequestId + + val result = processor.processResponse( + samlResponseBase64 = samlResponseBase64, + expectedRequestId = expectedRequestId, + binding = call.request.samlBinding(), + queryString = call.request.queryString(), + signatureParam = parameters["Signature"], + signatureAlgorithmParam = parameters["SigAlg"], + ) + + // Clear the SAML session regardless of status + call.sessions.clear() + + if (!result.isSuccess) { + val statusMessage = result.statusMessage ?: "No message" + logger.warn("IdP logout failed with status ${result.statusCode}: $statusMessage") + call.respond(HttpStatusCode.BadGateway, "IdP logout failed: $statusMessage") + return + } + + // Redirect to RelayState or respond with success + val relayState = parameters["RelayState"] + if (!relayState.isNullOrBlank() && relayValidator.validate(url = relayState)) { + call.respondRedirect(relayState) + } else { + call.respond(HttpStatusCode.OK, "Logout completed") + } + } catch (e: SamlValidationException) { + logger.debug("LogoutResponse validation failed", e) + call.respond(HttpStatusCode.BadRequest, "Invalid logout response") + } + } } internal class RelayValidator(private val allowedRelayStateUrls: List?) { @@ -343,3 +477,9 @@ private val SAML_AUTH_KEY: Any = "SAMLAuth" private fun isAbsoluteUrl(url: String): Boolean { return url.startsWith("http://") || url.startsWith("https://") } + +private fun ApplicationRequest.samlBinding(): SamlBinding = when (httpMethod) { + HttpMethod.Post -> SamlBinding.HttpPost + HttpMethod.Get -> SamlBinding.HttpRedirect + else -> error("Unsupported HTTP method: $httpMethod") +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutBuilder.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutBuilder.kt new file mode 100644 index 00000000000..f9a214cd87e --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutBuilder.kt @@ -0,0 +1,151 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport +import org.opensaml.saml.saml2.core.* +import org.opensaml.security.credential.Credential +import kotlin.time.Clock +import kotlin.time.ExperimentalTime +import kotlin.time.toJavaInstant + +/** + * Builds a LogoutRequest and returns the redirect URL for HTTP-Redirect binding. + * + * @param spEntityId The Service Provider's entity ID (Issuer) + * @param idpSloUrl The IdP's Single Logout Service URL + * @param nameId The NameID of the principal to log out + * @param nameIdFormat The format of the NameID (optional) + * @param sessionIndex The session index from the AuthnStatement (optional) + * @param relayState Optional RelayState for post-logout redirect + * @param signingCredential Credential for signing (if null, no signing is performed) + * @param signatureAlgorithm Signature algorithm (default: RSA-SHA256) + */ +@OptIn(ExperimentalTime::class) +internal fun buildLogoutRequestRedirect( + spEntityId: String, + idpSloUrl: String, + nameId: String, + nameIdFormat: NameIdFormat? = null, + sessionIndex: String? = null, + relayState: String? = null, + signingCredential: Credential? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 +): SamlRedirectResult { + LibSaml.ensureInitialized() + val logoutRequest = buildLogoutRequest( + spEntityId = spEntityId, + idpSloUrl = idpSloUrl, + nameId = nameId, + nameIdFormat = nameIdFormat, + sessionIndex = sessionIndex + ) + return buildSamlRedirectResult( + messageId = checkNotNull(logoutRequest.id), + samlObject = logoutRequest, + destinationUrl = idpSloUrl, + parameterName = "SAMLRequest", + relayState = relayState, + signingCredential = signingCredential, + signatureAlgorithm = signatureAlgorithm + ) +} + +@OptIn(ExperimentalTime::class) +private fun buildLogoutRequest( + spEntityId: String, + idpSloUrl: String, + nameId: String, + nameIdFormat: NameIdFormat?, + sessionIndex: String? +): LogoutRequest { + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = spEntityId + } + val nameIdElement = builderFactory.build(NameID.DEFAULT_ELEMENT_NAME) { + value = nameId + nameIdFormat?.let { format = it.uri } + } + val sessionIndexElement = sessionIndex?.let { + builderFactory.build(SessionIndex.DEFAULT_ELEMENT_NAME) { + value = it + } + } + return builderFactory.build(LogoutRequest.DEFAULT_ELEMENT_NAME) { + id = generateSecureSamlId() + issueInstant = Clock.System.now().toJavaInstant() + destination = idpSloUrl + this.issuer = issuer + this.nameID = nameIdElement + sessionIndexElement?.let { sessionIndexes.add(it) } + } +} + +/** + * Builds a LogoutResponse and returns the redirect URL for HTTP-Redirect binding. + * + * @param spEntityId The Service Provider's entity ID (Issuer) + * @param idpSloUrl The IdP's Single Logout Service URL + * @param inResponseTo The ID of the LogoutRequest this is responding to + * @param statusCodeValue The status code (default: SUCCESS) + * @param relayState Optional RelayState for post-logout redirect + * @param signingCredential Credential for signing (if null, no signing is performed) + * @param signatureAlgorithm Signature algorithm (default: RSA-SHA256) + */ +@OptIn(ExperimentalTime::class) +internal fun buildLogoutResponseRedirect( + spEntityId: String, + idpSloUrl: String, + inResponseTo: String, + statusCodeValue: String = StatusCode.SUCCESS, + relayState: String? = null, + signingCredential: Credential? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 +): SamlRedirectResult { + LibSaml.ensureInitialized() + val logoutResponse = buildLogoutResponse( + spEntityId = spEntityId, + idpSloUrl = idpSloUrl, + inResponseTo = inResponseTo, + statusCodeValue = statusCodeValue + ) + return buildSamlRedirectResult( + messageId = checkNotNull(logoutResponse.id), + samlObject = logoutResponse, + destinationUrl = idpSloUrl, + parameterName = "SAMLResponse", + relayState = relayState, + signingCredential = signingCredential, + signatureAlgorithm = signatureAlgorithm + ) +} + +@OptIn(ExperimentalTime::class) +private fun buildLogoutResponse( + spEntityId: String, + idpSloUrl: String, + inResponseTo: String, + statusCodeValue: String +): LogoutResponse { + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + val issuer = builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = spEntityId + } + val statusCode = builderFactory.build(StatusCode.DEFAULT_ELEMENT_NAME) { + value = statusCodeValue + } + val status = builderFactory.build(Status.DEFAULT_ELEMENT_NAME) { + this.statusCode = statusCode + } + return builderFactory.build(LogoutResponse.DEFAULT_ELEMENT_NAME) { + this.id = generateSecureSamlId() + this.issueInstant = Clock.System.now().toJavaInstant() + this.destination = idpSloUrl + this.inResponseTo = inResponseTo + this.issuer = issuer + this.status = status + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt new file mode 100644 index 00000000000..ba79586b86f --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt @@ -0,0 +1,215 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import org.opensaml.saml.saml2.core.LogoutRequest +import org.opensaml.saml.saml2.core.LogoutResponse +import org.opensaml.saml.saml2.core.StatusCode +import org.w3c.dom.Document +import kotlin.time.Clock +import kotlin.time.Duration +import kotlin.time.Duration.Companion.minutes +import kotlin.time.ExperimentalTime +import kotlin.time.Instant +import kotlin.time.toKotlinInstant + +/** + * Result of processing a SAML LogoutRequest. + * + * @property requestId The ID of the LogoutRequest + * @property nameId The NameID of the subject to log out + * @property sessionIndex The session index to log out (optional) + */ +internal class LogoutRequestResult( + val requestId: String, + val nameId: String, + val sessionIndex: String? +) + +/** + * Result of processing a SAML LogoutResponse. + * + * @property statusCode The SAML status code + * @property statusMessage Optional status message from the IdP + * @property inResponseTo The ID of the LogoutRequest this response to + */ +internal class LogoutResult( + val statusCode: String, + val statusMessage: String?, + val inResponseTo: String? +) { + val isSuccess: Boolean get() = statusCode == StatusCode.SUCCESS +} + +/** + * Maximum acceptable age for LogoutRequest/LogoutResponse IssueInstant. + */ +private val LOGOUT_MESSAGE_LIFETIME = 5.minutes + +/** + * Processor for SAML LogoutRequest and LogoutResponse messages. + * + * This class validates and extracts information from SAML 2.0 logout + * messages received from the IdP during Single Logout (SLO). + * + * ## Security Features + * + * - Issuer validation (required by default) + * - IssueInstant freshness validation with configurable clock skew + * - Replay attack protection for LogoutRequest IDs + * - Signature verification (configurable) + * - Destination validation + */ +@OptIn(ExperimentalTime::class) +internal class SamlLogoutProcessor( + private val sloUrl: String, + private val idpMetadata: IdPMetadata, + private val requireSignedLogoutRequest: Boolean, + private val requireSignedLogoutResponse: Boolean, + private val requireDestination: Boolean, + private val signatureVerifier: SamlSignatureVerifier, + private val clockSkew: Duration, + private val replayCache: SamlReplayCache, +) { + init { + LibSaml.ensureInitialized() + } + + /** + * Processes a Base64-encoded SAML LogoutRequest. + * + * @param samlRequestBase64 The Base64-encoded LogoutRequest (deflated for HTTP-Redirect binding) + * @param binding The SAML binding used (HTTP-Redirect or HTTP-POST) + * @param queryString The raw query string for HTTP-Redirect binding signature verification. + * Must preserve the exact encoding from the IdP. The Signature parameter will be removed internally. + * @return LogoutRequestResult containing the request ID, nameId, and sessionIndex + * @throws SamlValidationException if the request is malformed or invalid + */ + suspend fun processRequest( + samlRequestBase64: String, + binding: SamlBinding, + queryString: String? = null, + signatureParam: String? = null, + signatureAlgorithmParam: String? = null + ): LogoutRequestResult { + val logoutRequest = withValidationException { + val isDeflated = binding == SamlBinding.HttpRedirect + val requestXml = samlRequestBase64.decodeSamlMessage(isDeflated) + + val document: Document = LibSaml.parserPool.parse(requestXml.toByteArray().inputStream()) + document.documentElement.unmarshall() + } + + val requestId = samlRequire(logoutRequest.id) { "LogoutRequest must have an ID" } + + // Issuer is required for security - ensures the request is from the expected IdP + val issuer = samlRequire(logoutRequest.issuer?.value) { "LogoutRequest Issuer is required" } + samlAssert(issuer == idpMetadata.entityId) { "Issuer mismatch" } + + // Validate IssueInstant freshness + val issueInstant = samlRequire(logoutRequest.issueInstant?.toKotlinInstant()) { + "LogoutRequest IssueInstant is required" + } + validateIssueInstant(issueInstant, "LogoutRequest") + + val expirationTime = Clock.System.now() + LOGOUT_MESSAGE_LIFETIME + clockSkew + val recorded = replayCache.tryRecordAssertion(assertionId = requestId, expirationTime) + samlAssert(recorded) { + "LogoutRequest has already been processed (replay attack)" + } + + val destination = logoutRequest.destination + samlAssert(!requireDestination || destination != null) { "LogoutRequest Destination is not present" } + samlAssert(destination == null || destination == sloUrl) { "Destination mismatch" } + + if (requireSignedLogoutRequest) { + if (binding == SamlBinding.HttpRedirect) { + signatureVerifier.verifyQueryString( + queryString = checkNotNull(queryString), + signatureBase64 = samlRequire(signatureParam) { "Signature is missing" }, + signatureAlgorithmUri = samlRequire(signatureAlgorithmParam) { "SigAlg is missing" } + ) + } else { + signatureVerifier.verify(signedObject = logoutRequest) + } + } + + val nameId = samlRequire(logoutRequest.nameID?.value) { "LogoutRequest must contain a NameID" } + + val sessionIndex = logoutRequest.sessionIndexes.firstOrNull()?.value + + return LogoutRequestResult( + requestId = requestId, + nameId = nameId, + sessionIndex = sessionIndex + ) + } + + /** + * Processes a Base64-encoded SAML LogoutResponse. + * + * @param samlResponseBase64 The Base64-encoded LogoutResponse (deflated for HTTP-Redirect binding) + * @param expectedRequestId The ID of the LogoutRequest that was sent (for InResponseTo validation) + * @param binding The SAML binding used (HTTP-Redirect or HTTP-POST) + * @param queryString The raw query string for HTTP-Redirect binding signature verification. + * @throws SamlValidationException if the response is malformed or invalid + */ + fun processResponse( + samlResponseBase64: String, + expectedRequestId: String?, + binding: SamlBinding, + queryString: String? = null, + signatureParam: String? = null, + signatureAlgorithmParam: String? = null + ): LogoutResult { + val responseXml = samlResponseBase64.decodeSamlMessage(isDeflated = binding == SamlBinding.HttpRedirect) + val document: Document = LibSaml.parserPool.parse(responseXml.toByteArray().inputStream()) + val logoutResponse = document.documentElement.unmarshall() + + val inResponseTo = logoutResponse.inResponseTo + samlAssert(expectedRequestId == null || inResponseTo == expectedRequestId) { "InResponseTo mismatch" } + + // Issuer is required for security - ensures the response is from the expected IdP + val issuer = samlRequire(logoutResponse.issuer?.value) { "LogoutResponse Issuer is required" } + samlAssert(issuer == idpMetadata.entityId) { "Issuer mismatch" } + + // Validate IssueInstant freshness + val issueInstant = samlRequire(logoutResponse.issueInstant?.toKotlinInstant()) { + "LogoutResponse IssueInstant is required" + } + validateIssueInstant(issueInstant, "LogoutResponse") + + val destination = logoutResponse.destination + samlAssert(!requireDestination || destination != null) { "LogoutResponse Destination is not present" } + samlAssert(destination == null || destination == sloUrl) { "Destination mismatch" } + + if (requireSignedLogoutResponse) { + if (binding == SamlBinding.HttpRedirect) { + signatureVerifier.verifyQueryString( + queryString = checkNotNull(queryString), + signatureBase64 = samlRequire(signatureParam) { "Signature is missing" }, + signatureAlgorithmUri = samlRequire(signatureAlgorithmParam) { "SigAlg is missing" }, + ) + } else { + signatureVerifier.verify(signedObject = logoutResponse) + } + } + + val status = samlRequire(logoutResponse.status) { "LogoutResponse has no Status element" } + val statusCode = samlRequire(status.statusCode?.value) { "LogoutResponse Status has no StatusCode" } + val statusMessage = status.statusMessage?.value + + return LogoutResult(statusCode, statusMessage, inResponseTo) + } + + private fun validateIssueInstant(issueInstant: Instant, messageType: String) { + val now = Clock.System.now() + val effectiveMinTime = now - clockSkew - LOGOUT_MESSAGE_LIFETIME + val effectiveMaxTime = now + clockSkew + + samlAssert(issueInstant >= effectiveMinTime) { "$messageType IssueInstant is too old" } + samlAssert(issueInstant <= effectiveMaxTime) { "$messageType IssueInstant is in the future" } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt index 6f023746803..c1571967729 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt @@ -20,6 +20,8 @@ import java.net.URLEncoder import java.security.Signature import java.util.zip.Deflater import java.util.zip.DeflaterOutputStream +import java.util.zip.Inflater +import java.util.zip.InflaterInputStream import javax.xml.namespace.QName import javax.xml.transform.TransformerFactory import javax.xml.transform.dom.DOMSource @@ -84,6 +86,22 @@ internal fun String.encodeSamlMessage(deflate: Boolean): String { return Base64.encode(source = bytesOut.toByteArray()) } +/** + * Decodes a Base64-encoded SAML message. + * + * @param isDeflated Whether the message is deflated (HTTP-Redirect binding: true, HTTP-POST: false) + * @return Decoded XML string + */ +internal fun String.decodeSamlMessage(isDeflated: Boolean): String { + val decodedBytes = Base64.decode(source = this) + if (!isDeflated) { + return decodedBytes.toString(Charsets.UTF_8) + } + val inflater = Inflater(true) + val inflaterInputStream = InflaterInputStream(decodedBytes.inputStream(), inflater) + return inflaterInputStream.readBytes().toString(Charsets.UTF_8) +} + @Suppress("UNCHECKED_CAST") internal inline fun XMLObjectBuilderFactory.build( key: QName, diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt new file mode 100644 index 00000000000..cf10d67fa85 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt @@ -0,0 +1,375 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.server.auth.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.sessions.* +import io.ktor.server.testing.* +import org.opensaml.saml.saml2.core.StatusCode +import java.io.File +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.ExperimentalTime + +/** + * Integration tests for SAML Single Logout functionality + */ +@OptIn(ExperimentalTime::class) +class SamlLogoutIntegrationTest { + + private fun ApplicationTestBuilder.noRedirectsClient() = createClient { followRedirects = false } + + @Test + fun `test samlLogout builds redirect with RelayState and updates session`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + val basicResponse = noRedirectsClient().get("/test-logout") + assertEquals(HttpStatusCode.Found, basicResponse.status) + + val location = basicResponse.headers[HttpHeaders.Location] + assertNotNull(location, "Should redirect to IdP SLO URL") + assertTrue(location.startsWith(IDP_SLO_URL), "Should redirect to IdP SLO URL") + assertTrue(location.contains("SAMLRequest="), "Redirect URL should contain SAMLRequest") + + val sessionCookie = basicResponse.headers[HttpHeaders.SetCookie] + assertNotNull(sessionCookie, "Session should be updated with logoutRequestId") + + // Test logout with RelayState + val relayResponse = noRedirectsClient().get("/test-logout-with-relay") + assertEquals(HttpStatusCode.Found, relayResponse.status) + + val relayLocation = relayResponse.headers[HttpHeaders.Location] + assertNotNull(relayLocation) + assertTrue(relayLocation.contains("RelayState="), "Redirect URL should contain RelayState") + } + + @Test + fun `test IdP-initiated logout via HTTP-POST with RelayState`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + // First request without RelayState + val logoutRequestXml1 = SamlTestUtils.createLogoutRequest( + issuer = IDP_ENTITY_ID, + destination = SLO_URL, + nameId = "user@example.com", + sessionIndex = "_session123", + requestId = "_request_1" + ) + val base64Request1 = SamlTestUtils.encodeForPost(logoutRequestXml1) + + val response = noRedirectsClient().post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLRequest=${base64Request1.encodeURLParameter()}") + } + + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertNotNull(location, "Should redirect to IdP with LogoutResponse") + assertTrue(location.startsWith(IDP_SLO_URL), "Should redirect to IdP SLO URL") + assertTrue(location.contains("SAMLResponse="), "Redirect should contain SAMLResponse") + + // Second request with RelayState (using a different request ID to avoid replay detection) + val logoutRequestXml2 = SamlTestUtils.createLogoutRequest( + issuer = IDP_ENTITY_ID, + destination = SLO_URL, + nameId = "user@example.com", + sessionIndex = "_session123", + requestId = "_request_2" + ) + val base64Request2 = SamlTestUtils.encodeForPost(logoutRequestXml2) + + val relayResponse = noRedirectsClient().post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLRequest=${base64Request2.encodeURLParameter()}&RelayState=/post-logout") + } + + assertEquals(HttpStatusCode.Found, relayResponse.status) + val relayLocation = relayResponse.headers[HttpHeaders.Location] + assertNotNull(relayLocation) + assertTrue(relayLocation.contains("RelayState="), "RelayState should be preserved") + } + + @Test + fun `test IdP-initiated logout via HTTP-GET`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + val logoutRequestXml = SamlTestUtils.createLogoutRequest( + issuer = IDP_ENTITY_ID, + destination = SLO_URL, + nameId = "user@example.com" + ) + val encodedRequest = logoutRequestXml.encodeSamlMessage(deflate = true) + + val response = noRedirectsClient().get(SLO_PATH) { + parameter("SAMLRequest", encodedRequest) + } + + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertNotNull(location) + assertTrue(location.contains("SAMLResponse="), "Should respond with SAMLResponse") + } + + @Test + fun `test LogoutResponse processing with success and failure status`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + val successResponseXml = SamlTestUtils.createLogoutResponse( + inResponseTo = "_test_request_id", + statusCode = StatusCode.SUCCESS, + issuer = IDP_ENTITY_ID, + destination = SLO_URL + ) + val successBase64 = SamlTestUtils.encodeForPost(successResponseXml) + + val successResponse = client.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${successBase64.encodeURLParameter()}") + } + + assertEquals(HttpStatusCode.OK, successResponse.status) + assertTrue(successResponse.bodyAsText().contains("Logout completed")) + + val failureResponseXml = SamlTestUtils.createLogoutResponse( + inResponseTo = "_test_request", + statusCode = StatusCode.RESPONDER, + statusMessage = "Logout failed at IdP", + issuer = IDP_ENTITY_ID, + destination = SLO_URL + ) + val failureBase64 = SamlTestUtils.encodeForPost(failureResponseXml) + + val failureResponse = client.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${failureBase64.encodeURLParameter()}") + } + + // Non-success LogoutResponse should result in BadGateway (IdP failed to complete logout) + assertEquals(HttpStatusCode.BadGateway, failureResponse.status) + assertTrue(failureResponse.bodyAsText().contains("IdP logout failed")) + } + + @Test + fun `test LogoutResponse with RelayState redirects`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + val logoutResponseXml = SamlTestUtils.createLogoutResponse( + inResponseTo = "_test_request", + statusCode = StatusCode.SUCCESS, + issuer = IDP_ENTITY_ID, + destination = SLO_URL + ) + val base64Response = SamlTestUtils.encodeForPost(logoutResponseXml) + + val response = noRedirectsClient().post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${base64Response.encodeURLParameter()}&RelayState=/post-logout-page") + } + + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertEquals("/post-logout-page", location) + } + + @Test + fun `test SLO endpoint disabled when enableSingleLogout is false`() = testApplication { + configureSamlAuth(enableSingleLogout = false) + + val logoutRequestXml = SamlTestUtils.createLogoutRequest( + issuer = IDP_ENTITY_ID, + destination = SLO_URL, + nameId = "user@example.com" + ) + val base64Request = SamlTestUtils.encodeForPost(logoutRequestXml) + + val response = noRedirectsClient().post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLRequest=${base64Request.encodeURLParameter()}") + } + + // When SLO is disabled, the endpoint triggers an auth challenge (redirect to IdP for SSO) + assertEquals(HttpStatusCode.Found, response.status) + val location = response.headers[HttpHeaders.Location] + assertNotNull(location) + assertTrue(location.startsWith(IDP_SSO_URL), "Should redirect to IdP SSO URL when SLO is disabled") + } + + @Test + fun `test SLO rejects invalid requests`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + // Test unsupported HTTP method + val putResponse = client.put(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLRequest=test") + } + assertEquals(HttpStatusCode.MethodNotAllowed, putResponse.status) + + // Test missing SAMLRequest and SAMLResponse + val missingResponse = client.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("RelayState=/some-page") + } + assertEquals(HttpStatusCode.BadRequest, missingResponse.status) + } + + @Test + fun `test SLO rejects wrong issuer`() = testApplication { + configureSamlAuth(enableSingleLogout = true) + + // Test LogoutRequest with the wrong issuer + val badRequestXml = SamlTestUtils.createLogoutRequest( + issuer = "https://malicious-idp.example.com", + destination = SLO_URL, + nameId = "user@example.com" + ) + val badRequestBase64 = SamlTestUtils.encodeForPost(badRequestXml) + + val requestResponse = client.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLRequest=${badRequestBase64.encodeURLParameter()}") + } + assertEquals(HttpStatusCode.BadRequest, requestResponse.status) + + // Test LogoutResponse with the wrong issuer + val badResponseXml = SamlTestUtils.createLogoutResponse( + inResponseTo = "_test_request", + statusCode = StatusCode.SUCCESS, + issuer = "https://malicious-idp.example.com", + destination = SLO_URL + ) + val badResponseBase64 = SamlTestUtils.encodeForPost(badResponseXml) + + val responseResponse = client.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + setBody("SAMLResponse=${badResponseBase64.encodeURLParameter()}") + } + assertEquals(HttpStatusCode.BadRequest, responseResponse.status) + } + + private fun ApplicationTestBuilder.configureSamlAuth(enableSingleLogout: Boolean = false) { + install(Sessions) { + cookie("SAML_SESSION") + } + + val spMetadata = SamlSpMetadata { + spEntityId = SP_ENTITY_ID + acsUrl = ACS_URL + sloUrl = SLO_URL + wantAssertionsSigned = false + + if (enableSingleLogout) { + signingCredential = SamlCrypto.loadCredential( + keystorePath = spKeyStoreFile.absolutePath, + keystorePassword = "test-pass", + keyAlias = "sp-key", + keyPassword = "test-pass" + ) + } + } + + val testIdpMetadata = SamlTestUtils.createTestIdPMetadataWithSlo( + entityId = IDP_ENTITY_ID, + ssoUrl = IDP_SSO_URL, + sloUrl = IDP_SLO_URL + ) + + install(Authentication) { + saml("saml-auth") { + this.sp = spMetadata + this.enableSingleLogout = enableSingleLogout + this.idp = testIdpMetadata + allowIdpInitiatedSso = true + requireDestination = false + requireSignedResponse = false + requireSignedLogoutRequest = false + validate { credential -> + SamlPrincipal(credential.assertion) + } + } + } + + routing { + authenticate("saml-auth") { + get("/protected") { + val principal = call.principal()!! + call.respondText("Hello, ${principal.nameId}") + } + post(ACS_PATH) { + val principal = call.principal()!! + call.respondText("Hello, ${principal.nameId}") + } + // SLO endpoint must be under an authenticated block for the provider to handle it + get(SLO_PATH) { + call.respondText("SLO not handled") + } + post(SLO_PATH) { + call.respondText("SLO not handled") + } + } + + if (enableSingleLogout) { + get("/test-logout") { + // Create a session for the test (normally this would be created during authentication) + call.sessions.set(SamlSession(requestId = "_auth_request_123")) + val result = call.samlLogout( + nameId = "user@example.com", + idpSloUrl = IDP_SLO_URL, + spMetadata = spMetadata, + sessionIndex = "_session123" + ) + call.respondRedirect(result.redirectUrl) + } + + get("/test-logout-with-relay") { + // Create a session for the test (normally this would be created during authentication) + call.sessions.set(SamlSession(requestId = "_auth_request_123")) + val result = call.samlLogout( + nameId = "user@example.com", + idpSloUrl = IDP_SLO_URL, + spMetadata = spMetadata, + sessionIndex = "_session123", + relayState = "/post-logout-page" + ) + call.respondRedirect(result.redirectUrl) + } + } + } + } + + companion object { + private const val SP_ENTITY_ID = "https://sp.example.com" + private const val IDP_ENTITY_ID = "https://idp.example.com" + private const val ACS_PATH = "/saml/acs" + private const val ACS_URL = "http://localhost$ACS_PATH" + private const val SLO_PATH = "/saml/slo" + private const val SLO_URL = "http://localhost$SLO_PATH" + private const val IDP_SSO_URL = "https://idp.example.com/sso" + private const val IDP_SLO_URL = "https://idp.example.com/slo" + + private val spCredentials: SamlTestUtils.TestCredentials by lazy { + SamlTestUtils.sharedSpCredentials + } + + private val spKeyStoreFile: File by lazy { + File.createTempFile("sp-keystore", ".jks").also { file -> + file.deleteOnExit() + spCredentials.saveToKeyStore( + file = file, + storePassword = "test-pass", + keyAlias = "sp-key", + keyPassword = "test-pass" + ) + } + } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt new file mode 100644 index 00000000000..ce1f68dbeab --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt @@ -0,0 +1,468 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import kotlinx.coroutines.runBlocking +import org.opensaml.saml.saml2.core.StatusCode +import kotlin.test.* +import kotlin.time.Clock +import kotlin.time.Duration.Companion.minutes +import kotlin.time.Duration.Companion.seconds +import kotlin.time.ExperimentalTime + +/** + * Tests for SAML Single Logout (SLO) functionality + */ +@OptIn(ExperimentalTime::class) +class SamlLogoutTest { + + @Test + fun `build logout request redirect with and without signing`() { + val credentials = SamlTestUtils.generateTestCredentials() + + val unsignedResult = buildLogoutRequestRedirect( + spEntityId = "https://sp.example.com", + idpSloUrl = "https://idp.example.com/saml/slo", + nameId = "user@example.com", + nameIdFormat = NameIdFormat.Email, + sessionIndex = "_session123", + relayState = "/dashboard", + signingCredential = null + ) + assertNotNull(unsignedResult.messageId) + assertTrue(unsignedResult.messageId.startsWith("_")) + assertTrue(unsignedResult.redirectUrl.startsWith("https://idp.example.com/saml/slo")) + assertTrue(unsignedResult.redirectUrl.contains("SAMLRequest=")) + assertTrue(unsignedResult.redirectUrl.contains("RelayState=")) + assertFalse(unsignedResult.redirectUrl.contains("Signature=")) + + val signedResult = buildLogoutRequestRedirect( + spEntityId = "https://sp.example.com", + idpSloUrl = "https://idp.example.com/saml/slo", + nameId = "user@example.com", + sessionIndex = null, + signingCredential = credentials.credential + ) + assertNotNull(signedResult.messageId) + assertTrue(signedResult.redirectUrl.contains("SAMLRequest=")) + assertTrue(signedResult.redirectUrl.contains("SigAlg=")) + assertTrue(signedResult.redirectUrl.contains("Signature=")) + } + + @Test + fun `process logout response with different status codes`() { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val processor = createProcessor(idpMetadata) + + val successResponse = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo" + ) + val successResult = processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(successResponse), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + assertTrue(successResult.isSuccess) + assertEquals(StatusCode.SUCCESS, successResult.statusCode) + assertEquals("_request123", successResult.inResponseTo) + + val failedResponse = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request456", + statusCode = StatusCode.RESPONDER, + statusMessage = "Logout failed at IdP", + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo" + ) + val failedResult = processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(failedResponse), + expectedRequestId = "_request456", + binding = SamlBinding.HttpPost + ) + assertFalse(failedResult.isSuccess) + assertEquals(StatusCode.RESPONDER, failedResult.statusCode) + assertEquals("Logout failed at IdP", failedResult.statusMessage) + } + + @Test + fun `process logout response validation failures`() { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val processor = createProcessor(idpMetadata) + + val mismatchedInResponseTo = SamlTestUtils.createLogoutResponse( + inResponseTo = "_different_request", + statusCode = StatusCode.SUCCESS, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo" + ) + assertFailsWith { + processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(mismatchedInResponseTo), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + } + + val wrongIssuer = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = "https://malicious-idp.example.com", + destination = "https://sp.example.com/saml/slo" + ) + assertFailsWith { + processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(wrongIssuer), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + } + } + + @Test + fun `IdP metadata SLO URL parsing`() { + val withSloUrl = parseSamlIdpMetadata( + """ + + + + + + + + + """.trimIndent(), + validateCertificateExpiration = false + ) + assertEquals("https://idp.example.com", withSloUrl.entityId) + assertEquals("https://idp.example.com/sso", withSloUrl.ssoUrl) + assertEquals("https://idp.example.com/slo", withSloUrl.sloUrl) + + val withoutSloUrl = parseSamlIdpMetadata( + """ + + + + + + + """.trimIndent(), + validateCertificateExpiration = false + ) + assertEquals("https://idp.example.com", withoutSloUrl.entityId) + assertEquals("https://idp.example.com/sso", withoutSloUrl.ssoUrl) + assertNull(withoutSloUrl.sloUrl) + } + + @Test + fun `HTTP-Redirect signature verification for requests and responses`() = runBlocking { + val credentials = SamlTestUtils.sharedIdpCredentials + val idpMetadata = SamlTestUtils.createTestIdPMetadataWithSlo(credentials = credentials) + val processor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = true, + requireSignedLogoutResponse = true, + requireDestination = true, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = 60.seconds, + replayCache = InMemorySamlReplayCache() + ) + + val signedRequest = SamlTestUtils.createSignedLogoutRequestRedirect( + credentials = credentials, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + sessionIndex = "_session123" + ) + val requestResult = processor.processRequest( + samlRequestBase64 = signedRequest.samlMessageBase64, + binding = SamlBinding.HttpRedirect, + queryString = signedRequest.fullQueryString, + signatureParam = signedRequest.signatureBase64, + signatureAlgorithmParam = signedRequest.signatureAlgorithmUri + ) + assertEquals("user@example.com", requestResult.nameId) + assertEquals("_session123", requestResult.sessionIndex) + + val signedResponse = SamlTestUtils.createSignedLogoutResponseRedirect( + credentials = credentials, + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo" + ) + val responseResult = processor.processResponse( + samlResponseBase64 = signedResponse.samlMessageBase64, + expectedRequestId = "_request123", + binding = SamlBinding.HttpRedirect, + queryString = signedResponse.fullQueryString, + signatureParam = signedResponse.signatureBase64, + signatureAlgorithmParam = signedResponse.signatureAlgorithmUri + ) + assertTrue(responseResult.isSuccess) + assertEquals("_request123", responseResult.inResponseTo) + } + + @Test + fun `signature verification failures`(): Unit = runBlocking { + val signingCredentials = SamlTestUtils.generateTestCredentials() + val verificationCredentials = SamlTestUtils.generateTestCredentials() + val idpMetadata = SamlTestUtils.createTestIdPMetadataWithSlo(credentials = verificationCredentials) + val processor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = true, + requireSignedLogoutResponse = true, + requireDestination = true, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = 60.seconds, + replayCache = InMemorySamlReplayCache() + ) + + val wrongKey = SamlTestUtils.createSignedLogoutRequestRedirect( + credentials = signingCredentials, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com" + ) + assertFailsWith { + processor.processRequest( + samlRequestBase64 = wrongKey.samlMessageBase64, + binding = SamlBinding.HttpRedirect, + queryString = wrongKey.fullQueryString + ) + } + + val validCredentials = SamlTestUtils.sharedIdpCredentials + val validMetadata = SamlTestUtils.createTestIdPMetadataWithSlo(credentials = validCredentials) + val validProcessor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = validMetadata, + requireSignedLogoutRequest = true, + requireSignedLogoutResponse = true, + requireDestination = true, + signatureVerifier = SamlSignatureVerifier(validMetadata), + clockSkew = 60.seconds, + replayCache = InMemorySamlReplayCache() + ) + val signedMessage = SamlTestUtils.createSignedLogoutRequestRedirect( + credentials = validCredentials, + issuer = validMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com" + ) + val tamperedQueryString = signedMessage.fullQueryString.replace("&Signature=", "&extra=param&Signature=") + assertFailsWith { + validProcessor.processRequest( + samlRequestBase64 = signedMessage.samlMessageBase64, + binding = SamlBinding.HttpRedirect, + queryString = tamperedQueryString + ) + } + } + + @Test + fun `missing Issuer is rejected`(): Unit = runBlocking { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val processor = createProcessor(idpMetadata) + + val requestWithoutIssuer = SamlTestUtils.createLogoutRequest( + issuer = null, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com" + ) + val requestException = assertFailsWith { + processor.processRequest( + samlRequestBase64 = SamlTestUtils.encodeForPost(requestWithoutIssuer), + binding = SamlBinding.HttpPost + ) + } + assertTrue(requestException.message!!.contains("Issuer is required")) + + val responseWithoutIssuer = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = null, + destination = "https://sp.example.com/saml/slo" + ) + assertFailsWith { + processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(responseWithoutIssuer), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + } + } + + @Test + fun `IssueInstant validation for LogoutRequest`() = runBlocking { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val processor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = false, + requireSignedLogoutResponse = false, + requireDestination = false, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = 60.seconds, + replayCache = InMemorySamlReplayCache() + ) + + val oldIssueInstant = Clock.System.now() - 10.minutes + val oldRequest = SamlTestUtils.createLogoutRequest( + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + issueInstant = oldIssueInstant + ) + assertFailsWith { + processor.processRequest( + samlRequestBase64 = SamlTestUtils.encodeForPost(oldRequest), + binding = SamlBinding.HttpPost + ) + } + + val futureIssueInstant = Clock.System.now() + 5.minutes + val futureRequest = SamlTestUtils.createLogoutRequest( + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + issueInstant = futureIssueInstant + ) + assertFailsWith { + processor.processRequest( + samlRequestBase64 = SamlTestUtils.encodeForPost(futureRequest), + binding = SamlBinding.HttpPost + ) + } + + val withinSkew = Clock.System.now() + 30.seconds + val validRequest = SamlTestUtils.createLogoutRequest( + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + issueInstant = withinSkew + ) + val result = processor.processRequest( + samlRequestBase64 = SamlTestUtils.encodeForPost(validRequest), + binding = SamlBinding.HttpPost + ) + assertEquals("user@example.com", result.nameId) + } + + @Test + fun `IssueInstant validation for LogoutResponse`() { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val processor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = false, + requireSignedLogoutResponse = false, + requireDestination = false, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = 60.seconds, + replayCache = InMemorySamlReplayCache() + ) + + val oldIssueInstant = Clock.System.now() - 10.minutes + val oldResponse = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + issueInstant = oldIssueInstant + ) + assertFailsWith { + processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(oldResponse), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + } + + val futureIssueInstant = Clock.System.now() + 5.minutes + val futureResponse = SamlTestUtils.createLogoutResponse( + inResponseTo = "_request123", + statusCode = StatusCode.SUCCESS, + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + issueInstant = futureIssueInstant + ) + assertFailsWith { + processor.processResponse( + samlResponseBase64 = SamlTestUtils.encodeForPost(futureResponse), + expectedRequestId = "_request123", + binding = SamlBinding.HttpPost + ) + } + } + + @Test + fun `replay protection for LogoutRequest`() = runBlocking { + val idpMetadata = SamlTestUtils.createTestIdPMetadata() + val replayCache = InMemorySamlReplayCache() + val processor = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = false, + requireSignedLogoutResponse = false, + requireDestination = false, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = 60.seconds, + replayCache = replayCache + ) + + val firstRequest = SamlTestUtils.createLogoutRequest( + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + requestId = "_fixed_request_id" + ) + val encodedRequest = SamlTestUtils.encodeForPost(firstRequest) + val result = processor.processRequest( + samlRequestBase64 = encodedRequest, + binding = SamlBinding.HttpPost + ) + assertEquals("user@example.com", result.nameId) + + assertFailsWith { + processor.processRequest(samlRequestBase64 = encodedRequest, binding = SamlBinding.HttpPost) + } + + val secondRequest = SamlTestUtils.createLogoutRequest( + issuer = idpMetadata.entityId, + destination = "https://sp.example.com/saml/slo", + nameId = "user@example.com", + requestId = "_different_request_id" + ) + val result2 = processor.processRequest( + samlRequestBase64 = SamlTestUtils.encodeForPost(secondRequest), + binding = SamlBinding.HttpPost + ) + assertEquals("user@example.com", result2.nameId) + + replayCache.close() + } + + private fun createProcessor( + idpMetadata: IdPMetadata, + requireSignedLogoutRequest: Boolean = false, + requireSignedLogoutResponse: Boolean = false, + requireDestination: Boolean = false, + clockSkew: kotlin.time.Duration = 60.seconds + ) = SamlLogoutProcessor( + sloUrl = "https://sp.example.com/saml/slo", + idpMetadata = idpMetadata, + requireSignedLogoutRequest = requireSignedLogoutRequest, + requireSignedLogoutResponse = requireSignedLogoutResponse, + requireDestination = requireDestination, + signatureVerifier = SamlSignatureVerifier(idpMetadata), + clockSkew = clockSkew, + replayCache = InMemorySamlReplayCache() + ) +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt index 4cd0964f8b2..19c754744ea 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt @@ -8,6 +8,7 @@ import io.ktor.network.tls.certificates.* import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport import org.opensaml.core.xml.io.MarshallingException import org.opensaml.core.xml.schema.XSString +import org.opensaml.saml.common.SAMLVersion import org.opensaml.saml.saml2.core.* import org.opensaml.saml.saml2.encryption.Encrypter import org.opensaml.security.credential.Credential @@ -352,4 +353,285 @@ object SamlTestUtils { encryptedAssertions.add(encryptedAssertion) } } + + /** + * Creates a test IdP metadata with SLO support. + */ + fun createTestIdPMetadata( + entityId: String = "https://idp.example.com", + ssoUrl: String = "https://idp.example.com/sso", + sloUrl: String? = "https://idp.example.com/slo" + ): IdPMetadata { + val credentials = generateTestCredentials() + return IdPMetadata( + entityId = entityId, + ssoUrl = ssoUrl, + sloUrl = sloUrl, + signingCredentials = listOf(credentials.credential) + ) + } + + /** + * Creates a SAML LogoutResponse XML. + * + * @param inResponseTo The ID of the LogoutRequest being responded to + * @param statusCode The SAML status code + * @param statusMessage Optional status message + * @param issuer The issuer entity ID (null to omit Issuer element) + * @param destination The destination URL + * @param issueInstant Custom IssueInstant (defaults to current time) + */ + fun createLogoutResponse( + inResponseTo: String, + statusCode: String, + statusMessage: String? = null, + issuer: String? = null, + destination: String, + issueInstant: Instant = Clock.System.now() + ): String { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val issuerObj = issuer?.let { + builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = it + } + } + + val statusCodeObj = builderFactory.build(StatusCode.DEFAULT_ELEMENT_NAME) { + value = statusCode + } + + val statusMessageObj = statusMessage?.let { + builderFactory.build(StatusMessage.DEFAULT_ELEMENT_NAME) { + value = it + } + } + + val status = builderFactory.build(Status.DEFAULT_ELEMENT_NAME) { + this.statusCode = statusCodeObj + statusMessageObj?.let { this.statusMessage = it } + } + + val logoutResponse = builderFactory.build(LogoutResponse.DEFAULT_ELEMENT_NAME) { + id = generateSecureSamlId() + this.issueInstant = issueInstant.toJavaInstant() + this.inResponseTo = inResponseTo + this.destination = destination + version = SAMLVersion.VERSION_20 + issuerObj?.let { this.issuer = it } + this.status = status + } + + return logoutResponse.marshalToString() + } + + /** + * Base64 encodes a string for HTTP-POST binding. + */ + fun encodeForPost(xml: String): String = Base64.encode(source = xml.toByteArray()) + + /** + * Creates a SAML LogoutRequest XML for testing IdP-initiated logout. + */ + fun createLogoutRequest( + issuer: String? = null, + destination: String, + nameId: String, + nameIdFormat: String? = null, + sessionIndex: String? = null, + issueInstant: Instant = Clock.System.now(), + requestId: String = generateSecureSamlId() + ): String { + LibSaml.ensureInitialized() + val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory() + + val issuerObj = issuer?.let { + builderFactory.build(Issuer.DEFAULT_ELEMENT_NAME) { + value = it + } + } + + val nameIdObj = builderFactory.build(NameID.DEFAULT_ELEMENT_NAME) { + value = nameId + nameIdFormat?.let { format = it } + } + + val sessionIndexElement = sessionIndex?.let { + builderFactory.build(SessionIndex.DEFAULT_ELEMENT_NAME) { + value = it + } + } + + val logoutRequest = builderFactory.build(LogoutRequest.DEFAULT_ELEMENT_NAME) { + id = requestId + this.issueInstant = issueInstant.toJavaInstant() + this.destination = destination + version = SAMLVersion.VERSION_20 + issuerObj?.let { this.issuer = it } + this.nameID = nameIdObj + sessionIndexElement?.let { sessionIndexes.add(it) } + } + + return logoutRequest.marshalToString() + } + + fun createTestIdPMetadataWithSlo( + entityId: String = "https://idp.example.com", + ssoUrl: String = "https://idp.example.com/sso", + sloUrl: String = "https://idp.example.com/slo", + credentials: TestCredentials = sharedIdpCredentials + ): IdPMetadata { + return IdPMetadata( + entityId = entityId, + ssoUrl = ssoUrl, + sloUrl = sloUrl, + signingCredentials = listOf(credentials.credential) + ) + } + + /** + * Result of creating a signed SAML message for HTTP-Redirect binding. + * + * @property fullQueryString The complete query string including the Signature parameter + * @property samlMessageBase64 The Base64-encoded (and deflated) SAML message + */ + data class SignedRedirectMessage( + val fullQueryString: String, + val samlMessageBase64: String, + val signatureBase64: String, + val signatureAlgorithmUri: String, + ) + + /** + * Creates a signed LogoutRequest for HTTP-Redirect binding testing. + * + * This creates the exact query string that would be signed by an IdP, + * preserving the encoding to ensure signature verification succeeds. + * + * @param credentials The signing credentials + * @param issuer The issuer entity ID + * @param destination The SP's SLO URL + * @param nameId The NameID of the subject to log out + * @param sessionIndex Optional session index + * @param relayState Optional RelayState + * @param signatureAlgorithm The signature algorithm to use + * @return SignedRedirectMessage containing all components needed for verification + */ + fun createSignedLogoutRequestRedirect( + credentials: TestCredentials, + issuer: String = "https://idp.example.com", + destination: String = "https://sp.example.com/saml/slo", + nameId: String = "user@example.com", + sessionIndex: String? = "_session123", + relayState: String? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 + ): SignedRedirectMessage { + LibSaml.ensureInitialized() + + val logoutRequestXml = createLogoutRequest( + issuer = issuer, + destination = destination, + nameId = nameId, + sessionIndex = sessionIndex + ) + + // Deflate and Base64 encode for HTTP-Redirect binding + val samlMessageBase64 = logoutRequestXml.toByteArray().deflateForRedirect() + + // Build the query string in the order required by SAML spec + val enc = "UTF-8" + val queryParts = mutableListOf() + queryParts.add("SAMLRequest=${java.net.URLEncoder.encode(samlMessageBase64, enc)}") + if (relayState != null) { + queryParts.add("RelayState=${java.net.URLEncoder.encode(relayState, enc)}") + } + queryParts.add("SigAlg=${java.net.URLEncoder.encode(signatureAlgorithm.uri, enc)}") + val queryStringWithoutSignature = queryParts.joinToString("&") + + // Sign the query string + val signature = signQueryString(queryStringWithoutSignature, credentials.credential, signatureAlgorithm) + val fullQueryString = "$queryStringWithoutSignature&Signature=${java.net.URLEncoder.encode(signature, enc)}" + + return SignedRedirectMessage( + fullQueryString = fullQueryString, + samlMessageBase64 = samlMessageBase64, + signatureBase64 = signature, + signatureAlgorithmUri = signatureAlgorithm.uri + ) + } + + /** + * Creates a signed LogoutResponse for HTTP-Redirect binding testing. + * + * @param credentials The signing credentials + * @param inResponseTo The ID of the LogoutRequest being responded to + * @param statusCode The SAML status code + * @param issuer The issuer entity ID + * @param destination The SP's SLO URL + * @param relayState Optional RelayState + * @param signatureAlgorithm The signature algorithm to use + * @return SignedRedirectMessage containing all components needed for verification + */ + fun createSignedLogoutResponseRedirect( + credentials: TestCredentials, + inResponseTo: String = "_request123", + statusCode: String = StatusCode.SUCCESS, + issuer: String = "https://idp.example.com", + destination: String = "https://sp.example.com/saml/slo", + relayState: String? = null, + signatureAlgorithm: SignatureAlgorithm = SignatureAlgorithm.RSA_SHA256 + ): SignedRedirectMessage { + LibSaml.ensureInitialized() + + val logoutResponseXml = createLogoutResponse( + inResponseTo = inResponseTo, + statusCode = statusCode, + issuer = issuer, + destination = destination + ) + + // Deflate and Base64 encode for HTTP-Redirect binding + val samlMessageBase64 = logoutResponseXml.toByteArray().deflateForRedirect() + + // Build the query string in the order required by SAML spec + val enc = "UTF-8" + val queryParts = mutableListOf() + queryParts.add("SAMLResponse=${java.net.URLEncoder.encode(samlMessageBase64, enc)}") + if (relayState != null) { + queryParts.add("RelayState=${java.net.URLEncoder.encode(relayState, enc)}") + } + queryParts.add("SigAlg=${java.net.URLEncoder.encode(signatureAlgorithm.uri, enc)}") + val queryStringWithoutSignature = queryParts.joinToString("&") + + // Sign the query string + val signature = signQueryString(queryStringWithoutSignature, credentials.credential, signatureAlgorithm) + val fullQueryString = "$queryStringWithoutSignature&Signature=${java.net.URLEncoder.encode(signature, enc)}" + + return SignedRedirectMessage( + fullQueryString = fullQueryString, + samlMessageBase64 = samlMessageBase64, + signatureBase64 = signature, + signatureAlgorithmUri = signatureAlgorithm.uri + ) + } + + /** + * Deflates and Base64 encodes bytes for HTTP-Redirect binding. + */ + private fun ByteArray.deflateForRedirect(): String { + val deflater = java.util.zip.Deflater(java.util.zip.Deflater.DEFAULT_COMPRESSION, true) + deflater.setInput(this) + deflater.finish() + + val outputStream = java.io.ByteArrayOutputStream() + val buffer = ByteArray(1024) + while (!deflater.finished()) { + val count = deflater.deflate(buffer) + outputStream.write(buffer, 0, count) + } + deflater.end() + + return Base64.encode(source = outputStream.toByteArray()) + } } From c57c3684c63dc0b3f3006e0894fdaf5598cad455 Mon Sep 17 00:00:00 2001 From: zibet27 Date: Fri, 6 Mar 2026 15:07:44 +0100 Subject: [PATCH 2/4] always expect InResposeTo --- .../auth/saml/SamlAuthenticationProvider.kt | 5 ++ .../server/auth/saml/SamlLogoutProcessor.kt | 11 +-- .../server/auth/saml/SamlResponseProcessor.kt | 27 +++---- .../src/io/ktor/server/auth/saml/SamlUtils.kt | 14 +++- .../auth/saml/SamlLogoutIntegrationTest.kt | 71 +++++++++++++++++-- 5 files changed, 99 insertions(+), 29 deletions(-) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt index 0f26818d588..e0295e25756 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt @@ -292,6 +292,11 @@ public class SamlAuthenticationProvider internal constructor( val samlResponse = parameters["SAMLResponse"] when { + samlRequest != null && samlResponse != null -> { + logger.debug("SLO endpoint failed. Both `SAMLRequest` and `SAMLRequest` are present") + call.respond(HttpStatusCode.BadRequest, "Malformed SAML request") + } + samlRequest != null -> handleIdpLogoutRequest(samlRequest, parameters) samlResponse != null -> handleLogoutResponse(samlResponse, parameters) else -> { diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt index ba79586b86f..d47e7275a20 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlLogoutProcessor.kt @@ -164,12 +164,15 @@ internal class SamlLogoutProcessor( signatureParam: String? = null, signatureAlgorithmParam: String? = null ): LogoutResult { - val responseXml = samlResponseBase64.decodeSamlMessage(isDeflated = binding == SamlBinding.HttpRedirect) - val document: Document = LibSaml.parserPool.parse(responseXml.toByteArray().inputStream()) - val logoutResponse = document.documentElement.unmarshall() + val logoutResponse = withValidationException { + val responseXml = samlResponseBase64.decodeSamlMessage(isDeflated = binding == SamlBinding.HttpRedirect) + val document: Document = LibSaml.parserPool.parse(responseXml.toByteArray().inputStream()) + document.documentElement.unmarshall() + } val inResponseTo = logoutResponse.inResponseTo - samlAssert(expectedRequestId == null || inResponseTo == expectedRequestId) { "InResponseTo mismatch" } + samlRequire(expectedRequestId) { "Unexpected logout response" } + samlAssert(inResponseTo == expectedRequestId) { "InResponseTo mismatch" } // Issuer is required for security - ensures the response is from the expected IdP val issuer = samlRequire(logoutResponse.issuer?.value) { "LogoutResponse Issuer is required" } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt index 5bb112eacd2..0f57cfbdc4a 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlResponseProcessor.kt @@ -98,36 +98,31 @@ internal class SamlResponseProcessor( * @throws SamlValidationException if validation fails */ suspend fun processResponse(samlResponseBase64: String, expectedRequestId: String?): SamlCredential { - val samlResponseXml = String(bytes = Base64.decode(samlResponseBase64)) - val response = parseResponse(samlResponseXml).also { it.validate(expectedRequestId) } + val response = parseResponse(samlResponseBase64).also { it.validate(expectedRequestId) } val assertion = response.extractAssertion().also { it.validate(expectedRequestId) } return SamlCredential(response, assertion) } - private fun parseResponse(xml: String): Response = withValidationException { - val document: Document = LibSaml.parserPool.parse(ByteArrayInputStream(xml.toByteArray())) + private fun parseResponse(samlResponseBase64: String): Response = withValidationException { + val xmlStream = ByteArrayInputStream(Base64.decode(source = samlResponseBase64)) + val document: Document = LibSaml.parserPool.parse(xmlStream) document.documentElement.unmarshall() } private fun Response.validate(expectedRequestId: String?) { val statusCode = status?.statusCode?.value - if (statusCode != StatusCode.SUCCESS) { + samlAssert(statusCode == StatusCode.SUCCESS) { val statusMessage = status?.statusMessage?.value ?: "No message" - throw SamlValidationException("SAML response status is not Success: $statusCode - $statusMessage") + "SAML response status is not Success: $statusCode - $statusMessage" } - when { - expectedRequestId != null -> { - samlAssert(inResponseTo == expectedRequestId) { "InResponseTo mismatch" } - } - - !allowIdpInitiatedSso -> throw SamlValidationException("IdP-initiated SSO is not allowed.") + if (expectedRequestId != null) { + samlAssert(inResponseTo == expectedRequestId) { "InResponseTo mismatch" } + } else { + samlAssert(allowIdpInitiatedSso) { "IdP-initiated SSO is not allowed." } } - val issuer = issuer?.value - samlAssert(issuer == idpMetadata.entityId) { "Response issuer mismatch" } - - val destination = destination + samlAssert(issuer?.value == idpMetadata.entityId) { "Response issuer mismatch" } samlAssert(!requireDestination || destination != null) { "Response Destination is not present" } samlAssert(destination == null || destination == acsUrl) { "Response Destination mismatch" } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt index c1571967729..9c691b7bcf9 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlUtils.kt @@ -82,7 +82,11 @@ internal fun String.encodeSamlMessage(deflate: Boolean): String { } val bytesOut = ByteArrayOutputStream() val deflater = Deflater(Deflater.DEFLATED, true) - DeflaterOutputStream(bytesOut, deflater).use { it.write(bytes) } + try { + DeflaterOutputStream(bytesOut, deflater).use { it.write(bytes) } + } finally { + deflater.end() + } return Base64.encode(source = bytesOut.toByteArray()) } @@ -98,8 +102,12 @@ internal fun String.decodeSamlMessage(isDeflated: Boolean): String { return decodedBytes.toString(Charsets.UTF_8) } val inflater = Inflater(true) - val inflaterInputStream = InflaterInputStream(decodedBytes.inputStream(), inflater) - return inflaterInputStream.readBytes().toString(Charsets.UTF_8) + try { + val inflaterInputStream = InflaterInputStream(decodedBytes.inputStream(), inflater) + return inflaterInputStream.readBytes().toString(Charsets.UTF_8) + } finally { + inflater.end() + } } @Suppress("UNCHECKED_CAST") diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt index cf10d67fa85..15b8089f7e0 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutIntegrationTest.kt @@ -123,24 +123,69 @@ class SamlLogoutIntegrationTest { fun `test LogoutResponse processing with success and failure status`() = testApplication { configureSamlAuth(enableSingleLogout = true) + val testClient = noRedirectsClient() + + // Initiate SP-initiated logout to populate session with logoutRequestId + val initiateResponse = testClient.get("/test-logout") + assertEquals(HttpStatusCode.Found, initiateResponse.status) + val sessionCookie = initiateResponse.headers[HttpHeaders.SetCookie] + assertNotNull(sessionCookie, "Session cookie should be set") + val logoutRequestId = initiateResponse.headers["X-Logout-Request-Id"] + assertNotNull(logoutRequestId, "LogoutRequest ID should be returned in header") + + // Test SUCCESS case: InResponseTo matches the stored logoutRequestId val successResponseXml = SamlTestUtils.createLogoutResponse( - inResponseTo = "_test_request_id", + inResponseTo = logoutRequestId, statusCode = StatusCode.SUCCESS, issuer = IDP_ENTITY_ID, destination = SLO_URL ) val successBase64 = SamlTestUtils.encodeForPost(successResponseXml) - val successResponse = client.post(SLO_PATH) { + val successResponse = testClient.post(SLO_PATH) { contentType(ContentType.Application.FormUrlEncoded) + header(HttpHeaders.Cookie, sessionCookie.substringBefore(";")) setBody("SAMLResponse=${successBase64.encodeURLParameter()}") } assertEquals(HttpStatusCode.OK, successResponse.status) assertTrue(successResponse.bodyAsText().contains("Logout completed")) + // Test FAILURE case: InResponseTo does NOT match the stored logoutRequestId + // Re-initiate logout to get a fresh session with logoutRequestId + val reinitiateResponse = testClient.get("/test-logout") + assertEquals(HttpStatusCode.Found, reinitiateResponse.status) + val freshSessionCookie = reinitiateResponse.headers[HttpHeaders.SetCookie] + assertNotNull(freshSessionCookie) + + val mismatchedResponseXml = SamlTestUtils.createLogoutResponse( + inResponseTo = "_different_request_id", + statusCode = StatusCode.SUCCESS, + issuer = IDP_ENTITY_ID, + destination = SLO_URL + ) + val mismatchedBase64 = SamlTestUtils.encodeForPost(mismatchedResponseXml) + + val mismatchedResponse = testClient.post(SLO_PATH) { + contentType(ContentType.Application.FormUrlEncoded) + header(HttpHeaders.Cookie, freshSessionCookie.substringBefore(";")) + setBody("SAMLResponse=${mismatchedBase64.encodeURLParameter()}") + } + + // Mismatched InResponseTo should result in BadRequest (InResponseTo mismatch is caught as validation error) + assertEquals(HttpStatusCode.BadRequest, mismatchedResponse.status) + assertTrue(mismatchedResponse.bodyAsText().contains("Invalid logout response")) + + // Test IdP failure case: InResponseTo matches but IdP reports failure status + val reinitiateResponse2 = testClient.get("/test-logout") + assertEquals(HttpStatusCode.Found, reinitiateResponse2.status) + val freshSessionCookie2 = reinitiateResponse2.headers[HttpHeaders.SetCookie] + assertNotNull(freshSessionCookie2) + val logoutRequestId2 = reinitiateResponse2.headers["X-Logout-Request-Id"] + assertNotNull(logoutRequestId2) + val failureResponseXml = SamlTestUtils.createLogoutResponse( - inResponseTo = "_test_request", + inResponseTo = logoutRequestId2, statusCode = StatusCode.RESPONDER, statusMessage = "Logout failed at IdP", issuer = IDP_ENTITY_ID, @@ -148,8 +193,9 @@ class SamlLogoutIntegrationTest { ) val failureBase64 = SamlTestUtils.encodeForPost(failureResponseXml) - val failureResponse = client.post(SLO_PATH) { + val failureResponse = testClient.post(SLO_PATH) { contentType(ContentType.Application.FormUrlEncoded) + header(HttpHeaders.Cookie, freshSessionCookie2.substringBefore(";")) setBody("SAMLResponse=${failureBase64.encodeURLParameter()}") } @@ -162,16 +208,27 @@ class SamlLogoutIntegrationTest { fun `test LogoutResponse with RelayState redirects`() = testApplication { configureSamlAuth(enableSingleLogout = true) + val testClient = noRedirectsClient() + + // Initiate SP-initiated logout to populate session with logoutRequestId + val initiateResponse = testClient.get("/test-logout") + assertEquals(HttpStatusCode.Found, initiateResponse.status) + val sessionCookie = initiateResponse.headers[HttpHeaders.SetCookie] + assertNotNull(sessionCookie) + val logoutRequestId = initiateResponse.headers["X-Logout-Request-Id"] + assertNotNull(logoutRequestId) + val logoutResponseXml = SamlTestUtils.createLogoutResponse( - inResponseTo = "_test_request", + inResponseTo = logoutRequestId, statusCode = StatusCode.SUCCESS, issuer = IDP_ENTITY_ID, destination = SLO_URL ) val base64Response = SamlTestUtils.encodeForPost(logoutResponseXml) - val response = noRedirectsClient().post(SLO_PATH) { + val response = testClient.post(SLO_PATH) { contentType(ContentType.Application.FormUrlEncoded) + header(HttpHeaders.Cookie, sessionCookie.substringBefore(";")) setBody("SAMLResponse=${base64Response.encodeURLParameter()}&RelayState=/post-logout-page") } @@ -327,6 +384,8 @@ class SamlLogoutIntegrationTest { spMetadata = spMetadata, sessionIndex = "_session123" ) + // Include the messageId in a header for tests to use when constructing LogoutResponse + call.response.header("X-Logout-Request-Id", result.messageId) call.respondRedirect(result.redirectUrl) } From bb585b8d58a0a660e9c862543c5a417305f2e227 Mon Sep 17 00:00:00 2001 From: zibet27 Date: Tue, 10 Mar 2026 13:43:17 +0100 Subject: [PATCH 3/4] add test cerificate for idp metadata --- .../ktor/server/auth/saml/SamlLogoutTest.kt | 44 ++++++++++++++++--- .../test/io/ktor/server/auth/saml/TestUtil.kt | 24 +++++----- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt index ce1f68dbeab..8816f6b2e76 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/SamlLogoutTest.kt @@ -4,8 +4,10 @@ package io.ktor.server.auth.saml +import io.ktor.network.tls.certificates.buildKeyStore import kotlinx.coroutines.runBlocking import org.opensaml.saml.saml2.core.StatusCode +import kotlin.io.encoding.Base64 import kotlin.test.* import kotlin.time.Clock import kotlin.time.Duration.Companion.minutes @@ -127,8 +129,17 @@ class SamlLogoutTest { val withSloUrl = parseSamlIdpMetadata( """ - + + + + + $TEST_CERTIFICATE_BASE64 + + + @@ -144,8 +155,17 @@ class SamlLogoutTest { val withoutSloUrl = parseSamlIdpMetadata( """ - + + + + + $TEST_CERTIFICATE_BASE64 + + + @@ -174,7 +194,7 @@ class SamlLogoutTest { val signedRequest = SamlTestUtils.createSignedLogoutRequestRedirect( credentials = credentials, - issuer = idpMetadata.entityId, + issuer = idpMetadata.entityId!!, destination = "https://sp.example.com/saml/slo", nameId = "user@example.com", sessionIndex = "_session123" @@ -193,7 +213,7 @@ class SamlLogoutTest { credentials = credentials, inResponseTo = "_request123", statusCode = StatusCode.SUCCESS, - issuer = idpMetadata.entityId, + issuer = idpMetadata.entityId!!, destination = "https://sp.example.com/saml/slo" ) val responseResult = processor.processResponse( @@ -226,7 +246,7 @@ class SamlLogoutTest { val wrongKey = SamlTestUtils.createSignedLogoutRequestRedirect( credentials = signingCredentials, - issuer = idpMetadata.entityId, + issuer = idpMetadata.entityId!!, destination = "https://sp.example.com/saml/slo", nameId = "user@example.com" ) @@ -252,7 +272,7 @@ class SamlLogoutTest { ) val signedMessage = SamlTestUtils.createSignedLogoutRequestRedirect( credentials = validCredentials, - issuer = validMetadata.entityId, + issuer = validMetadata.entityId!!, destination = "https://sp.example.com/saml/slo", nameId = "user@example.com" ) @@ -465,4 +485,16 @@ class SamlLogoutTest { clockSkew = clockSkew, replayCache = InMemorySamlReplayCache() ) + + companion object { + private val TEST_CERTIFICATE_BASE64: String by lazy { + val keyStore = buildKeyStore { + certificate("test") { + password = "test" + } + } + val cert = keyStore.getCertificate("test") as java.security.cert.X509Certificate + Base64.encode(cert.encoded) + } + } } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt index 19c754744ea..93c0296dd87 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/TestUtil.kt @@ -363,12 +363,12 @@ object SamlTestUtils { sloUrl: String? = "https://idp.example.com/slo" ): IdPMetadata { val credentials = generateTestCredentials() - return IdPMetadata( - entityId = entityId, - ssoUrl = ssoUrl, - sloUrl = sloUrl, - signingCredentials = listOf(credentials.credential) - ) + return IdPMetadata { + this.entityId = entityId + this.ssoUrl = ssoUrl + this.sloUrl = sloUrl + this.signingCredentials = listOf(credentials.credential) + } } /** @@ -482,12 +482,12 @@ object SamlTestUtils { sloUrl: String = "https://idp.example.com/slo", credentials: TestCredentials = sharedIdpCredentials ): IdPMetadata { - return IdPMetadata( - entityId = entityId, - ssoUrl = ssoUrl, - sloUrl = sloUrl, - signingCredentials = listOf(credentials.credential) - ) + return IdPMetadata { + this.entityId = entityId + this.ssoUrl = ssoUrl + this.sloUrl = sloUrl + this.signingCredentials = listOf(credentials.credential) + } } /** From 28ae100ab689803f97127a2a3882dc691688882d Mon Sep 17 00:00:00 2001 From: zibet27 Date: Wed, 18 Mar 2026 11:02:50 +0100 Subject: [PATCH 4/4] improve RelayStateValidator --- gradlew.bat | 186 ++++++++--------- .../api/ktor-server-auth-saml.api | 34 +++- .../server/auth/saml/RelayStateValidator.kt | 189 ++++++++++++++++++ .../auth/saml/SamlAuthenticationProvider.kt | 90 ++------- .../io/ktor/server/auth/saml/SamlConfig.kt | 29 +-- .../auth/saml/RelayStateValidationTest.kt | 131 ++++++++---- 6 files changed, 434 insertions(+), 225 deletions(-) create mode 100644 ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/RelayStateValidator.kt diff --git a/gradlew.bat b/gradlew.bat index e509b2dd8fe..c4bdd3ab8e3 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,93 +1,93 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. -@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem -@rem SPDX-License-Identifier: Apache-2.0 -@rem - -@if "%DEBUG%"=="" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%"=="" set DIRNAME=. -@rem This is normally unused -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if %ERRORLEVEL% equ 0 goto execute - -echo. 1>&2 -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 -echo. 1>&2 -echo Please set the JAVA_HOME variable in your environment to match the 1>&2 -echo location of your Java installation. 1>&2 - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. 1>&2 -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 -echo. 1>&2 -echo Please set the JAVA_HOME variable in your environment to match the 1>&2 -echo location of your Java installation. 1>&2 - -goto fail - -:execute -@rem Setup the command line - - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* - -:end -@rem End local scope for the variables with windows NT shell -if %ERRORLEVEL% equ 0 goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -set EXIT_CODE=%ERRORLEVEL% -if %EXIT_CODE% equ 0 set EXIT_CODE=1 -if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% -exit /b %EXIT_CODE% - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api index bdab2b42558..b7d0e89adfa 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/api/ktor-server-auth-saml.api @@ -70,6 +70,36 @@ public final class io/ktor/server/auth/saml/NameIdFormat$Companion { public final fun getUnspecified-yVcJgns ()Ljava/lang/String; } +public abstract interface class io/ktor/server/auth/saml/RelayStateValidator { + public static final field Companion Lio/ktor/server/auth/saml/RelayStateValidator$Companion; + public abstract fun validate (Ljava/lang/String;)Z +} + +public final class io/ktor/server/auth/saml/RelayStateValidator$AllowList : io/ktor/server/auth/saml/RelayStateValidator { + public fun (Ljava/util/List;)V + public fun ([Ljava/lang/String;)V + public fun validate (Ljava/lang/String;)Z +} + +public final class io/ktor/server/auth/saml/RelayStateValidator$Companion { + public final fun containsDangerousPatterns (Ljava/lang/String;)Z + public final fun isAbsoluteUrlSafe (Ljava/lang/String;)Z + public final fun normalizeRelativePath (Ljava/lang/String;)Ljava/lang/String; +} + +public final class io/ktor/server/auth/saml/RelayStateValidator$Custom : io/ktor/server/auth/saml/RelayStateValidator { + public fun (Lkotlin/jvm/functions/Function1;)V + public fun validate (Ljava/lang/String;)Z +} + +public final class io/ktor/server/auth/saml/RelayStateValidator$Default : io/ktor/server/auth/saml/RelayStateValidator { + public static final field INSTANCE Lio/ktor/server/auth/saml/RelayStateValidator$Default; + public fun equals (Ljava/lang/Object;)Z + public fun hashCode ()I + public fun toString ()Ljava/lang/String; + public fun validate (Ljava/lang/String;)Z +} + public final class io/ktor/server/auth/saml/SamlAlgorithms { public static final field INSTANCE Lio/ktor/server/auth/saml/SamlAlgorithms; public final fun getRECOMMENDED_DIGEST_ALGORITHMS ()Ljava/util/Set; @@ -131,7 +161,6 @@ public final class io/ktor/server/auth/saml/SamlConfig : io/ktor/server/auth/Aut public final fun challenge (Lkotlin/jvm/functions/Function3;)V public final fun getAllowIdpInitiatedSso ()Z public final fun getAllowedDigestAlgorithms ()Ljava/util/Set; - public final fun getAllowedRelayStateUrls ()Ljava/util/List; public final fun getAllowedSignatureAlgorithms ()Ljava/util/Set; public final fun getAuthnRequestBinding ()Lio/ktor/server/auth/saml/SamlBinding; public final fun getClockSkew-UwyO8pc ()J @@ -140,6 +169,7 @@ public final class io/ktor/server/auth/saml/SamlConfig : io/ktor/server/auth/Aut public final fun getForceAuthn ()Z public final fun getIdp ()Lio/ktor/server/auth/saml/IdPMetadata; public final fun getNameIdFormat-W8VwlJw ()Ljava/lang/String; + public final fun getRelayStateValidator ()Lio/ktor/server/auth/saml/RelayStateValidator; public final fun getRequestedAuthnContext-4ml6bek ()Ljava/lang/String; public final fun getRequireDestination ()Z public final fun getRequireSignedLogoutRequest ()Z @@ -149,7 +179,6 @@ public final class io/ktor/server/auth/saml/SamlConfig : io/ktor/server/auth/Aut public final fun replayCache (Lio/ktor/server/auth/saml/SamlReplayCache;)V public final fun setAllowIdpInitiatedSso (Z)V public final fun setAllowedDigestAlgorithms (Ljava/util/Set;)V - public final fun setAllowedRelayStateUrls (Ljava/util/List;)V public final fun setAllowedSignatureAlgorithms (Ljava/util/Set;)V public final fun setAuthnRequestBinding (Lio/ktor/server/auth/saml/SamlBinding;)V public final fun setClockSkew-LRDsOJo (J)V @@ -158,6 +187,7 @@ public final class io/ktor/server/auth/saml/SamlConfig : io/ktor/server/auth/Aut public final fun setForceAuthn (Z)V public final fun setIdp (Lio/ktor/server/auth/saml/IdPMetadata;)V public final fun setNameIdFormat-eInvKLk (Ljava/lang/String;)V + public final fun setRelayStateValidator (Lio/ktor/server/auth/saml/RelayStateValidator;)V public final fun setRequestedAuthnContext-Br-yMbY (Ljava/lang/String;)V public final fun setRequireDestination (Z)V public final fun setRequireSignedLogoutRequest (Z)V diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/RelayStateValidator.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/RelayStateValidator.kt new file mode 100644 index 00000000000..e2ead32e314 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/RelayStateValidator.kt @@ -0,0 +1,189 @@ +/* + * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.auth.saml + +import io.ktor.http.* +import java.net.URI + +/** + * Validates RelayState URLs to prevent open redirect attacks. + * + * SAML RelayState is a parameter passed through the authentication flow that typically + * contains a URL to redirect users to after authentication. + * + * Three implementations are provided: + * - [RelayStateValidator.Default]: Allows all URLs that pass basic safety checks (default) + * - [RelayStateValidator.AllowList]: Only allows URLs matching a configured allowlist + * - [RelayStateValidator.Custom]: Uses a user-provided validation function + * + * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.server.auth.saml.RelayStateValidator) + */ +public interface RelayStateValidator { + /** + * Validates that the relay state URL is safe for redirect. + * + * @param url The RelayState URL to validate + * @return `true` if the URL is allowed, `false` otherwise + */ + public fun validate(url: String): Boolean + + /** + * Default validator that allows all URLs passing basic safety checks. + * + * This validator permits: + * - Relative paths starting with `/` + * - Absolute HTTP/HTTPS URLs + * + * It rejects: + * - Protocol-relative URLs (`//example.com`) + * - URLs with backslashes + * - URLs with control characters + * - Non-HTTP(S) schemes (javascript:, data:, ftp:, etc.) + * - URLs with userinfo (`https://user:pass@host`) + */ + public data object Default : RelayStateValidator { + override fun validate(url: String): Boolean = when { + containsDangerousPatterns(url) -> false + url.startsWith("/") -> normalizeRelativePath(url) != null + else -> isAbsoluteUrlSafe(url) + } + } + + /** + * Validator that only allows URLs matching a configured allowlist. + * + * URLs must pass basic safety checks AND match one of the allowed prefixes. + * + * @param allowedUrls List of allowed URL prefixes (must not be empty). Can include: + * - Relative paths starting with `/` (e.g., `/app/`) + * - Absolute URLs (e.g., `https://myapp.example.com/`) + * + * For absolute URLs, matching checks: + * - Exact protocol match (case-insensitive) + * - Exact host match (case-insensitive per RFC 3986) + * - Exact port match + * - Path prefix match with segment boundary validation + * + * @throws IllegalArgumentException if [allowedUrls] is empty + */ + public class AllowList(private val allowedUrls: List) : RelayStateValidator { + + init { + require(allowedUrls.isNotEmpty()) { "AllowList requires at least one allowed URL" } + } + + public constructor(vararg allowedUrls: String) : this(allowedUrls.toList()) + + override fun validate(url: String): Boolean { + if (containsDangerousPatterns(url)) { + return false + } + if (url.startsWith("/")) { + val normalized = normalizeRelativePath(url) ?: return false + val pathPrefixes = allowedUrls.filter { it.startsWith("/") && !it.startsWith("//") } + return pathPrefixes.any { prefix -> + normalized.startsWith(prefix) && (normalized == prefix || prefix.endsWith("/")) + } + } + if (!isAbsoluteUrlSafe(url)) { + return false + } + return allowedUrls.any { prefix -> isAllowedAbsoluteRelayState(url, prefix) } + } + + /** + * Validates an absolute URL against an allowed prefix. + * Checks a scheme, host, port, and path with segment boundary validation. + */ + private fun isAllowedAbsoluteRelayState(targetUrl: String, allowedPrefix: String): Boolean { + val target = runCatching { Url(targetUrl) }.getOrNull() ?: return false + val allowed = runCatching { Url(allowedPrefix) }.getOrNull() ?: return false + + // Reject URLs with userinfo (user:pass@host) - bypass technique + if (target.user != null || target.password != null) return false + + // Only allow http and https schemes + if (target.protocol.name !in listOf("http", "https")) return false + if (allowed.protocol.name !in listOf("http", "https")) return false + + // Exact protocol match (case-insensitive) + if (!target.protocol.name.equals(allowed.protocol.name, ignoreCase = true)) return false + + // Exact host match (case-insensitive per RFC 3986) + if (!target.host.equals(allowed.host, ignoreCase = true)) return false + + // Exact port match + if (target.port != allowed.port) return false + + // Path prefix match with segment boundary check + val allowedPath = allowed.encodedPath.ifBlank { "/" } + val targetPath = target.encodedPath.ifBlank { "/" } + + if (!targetPath.startsWith(allowedPath)) return false + if (targetPath == allowedPath) return true + if (allowedPath.endsWith("/")) return true + return targetPath.getOrNull(allowedPath.length) == '/' + } + } + + /** + * Validator that uses a custom validation function. + * + * **Note:** The custom function is responsible for all validation, + * including basic safety checks. Consider using [Default] or [AllowList] + * for common use cases, or use the helper functions in [RelayStateValidator.Companion] + * to implement safety checks. + * + * @param validator Function that returns `true` if the URL is allowed + */ + public class Custom(private val validator: (String) -> Boolean) : RelayStateValidator { + override fun validate(url: String): Boolean = validator(url) + } + + public companion object { + /** + * Dangerous patterns include: + * - Protocol-relative URLs (starting with `//`) + * - Backslashes (potential path traversal on Windows) + * - Control characters + */ + public fun containsDangerousPatterns(url: String): Boolean { + return url.startsWith("//") || url.contains("\\") || url.any { it.isISOControl() } + } + + /** + * Validates and normalizes a relative path. + * + * A valid relative path: + * - Starts with `/` + * - After normalization, still starts with `/` and doesn't become protocol-relative + * + * @param url The URL to normalize + * @return The normalized path if valid, or `null` if invalid + */ + public fun normalizeRelativePath(url: String): String? { + if (!url.startsWith("/")) return null + val normalized = URI(url).normalize().toString() + return if (normalized.startsWith("/") && !normalized.startsWith("//")) normalized else null + } + + private val PROTOCOLS = listOf("http", "https") + + /** + * Checks if an absolute URL is safe for redirect. + * + * A safe absolute URL: + * - Uses HTTP or HTTPS scheme + * - Has no userinfo (user:pass@host) + */ + public fun isAbsoluteUrlSafe(url: String): Boolean { + val parsed = runCatching { Url(url) }.getOrNull() ?: return false + // Reject URLs with userinfo + if (parsed.user != null || parsed.password != null) return false + // Only allow http and https schemes + return parsed.protocol.name in PROTOCOLS + } + } +} diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt index e0295e25756..b6c1e618984 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlAuthenticationProvider.kt @@ -15,7 +15,6 @@ import org.opensaml.saml.saml2.core.StatusCode import org.opensaml.security.x509.BasicX509Credential import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.net.URI /** * SAML 2.0 authentication provider for Ktor Server. @@ -48,12 +47,7 @@ public class SamlAuthenticationProvider internal constructor( private val idpMetadata = requireNotNull(config.idp) { "IdP metadata must be configured. Use idp = parseSamlIdpMetadata(...) to set it." } - private val relayValidator = config.allowedRelayStateUrls.let { - if (it == null) { - logger.warn("RelayState validation is disabled. This is unsafe in production.") - } - RelayValidator(allowedRelayStateUrls = it) - } + private val relayStateValidator = config.relayStateValidator private val requestedAuthnContext = config.requestedAuthnContext private val authenticationFunction = requireNotNull(config.authenticationFunction) { @@ -206,8 +200,8 @@ public class SamlAuthenticationProvider internal constructor( when { relayState.isNullOrBlank() -> return - relayValidator.validate(url = relayState) -> call.respondRedirect(url = relayState) - else -> logger.warn("RelayState URL not in allowlist, ignoring: $relayState") + relayStateValidator.validate(url = relayState) -> call.respondRedirect(url = relayState) + else -> logger.warn("RelayState URL validation failed, ignoring: $relayState") } } catch (e: CancellationException) { throw e @@ -293,7 +287,7 @@ public class SamlAuthenticationProvider internal constructor( when { samlRequest != null && samlResponse != null -> { - logger.debug("SLO endpoint failed. Both `SAMLRequest` and `SAMLRequest` are present") + logger.debug("SLO endpoint failed. Both `SAMLRequest` and `SAMLResponse` are present") call.respond(HttpStatusCode.BadRequest, "Malformed SAML request") } @@ -332,12 +326,22 @@ public class SamlAuthenticationProvider internal constructor( call.sessions.clear() val idpSloUrl = requireNotNull(idpMetadata.sloUrl) { "IdP SLO URL not found" } + val validatedRelayState = parameters["RelayState"].takeIf { relayState -> + when { + relayState.isNullOrBlank() -> false + relayStateValidator.validate(relayState) -> true + else -> { + logger.warn("RelayState URL validation failed in IdP logout request, ignoring: $relayState") + false + } + } + } val logoutResponse = buildLogoutResponseRedirect( spEntityId = spEntityId, idpSloUrl = idpSloUrl, inResponseTo = logoutRequest.requestId, statusCodeValue = StatusCode.SUCCESS, - relayState = parameters["RelayState"], + relayState = validatedRelayState, signingCredential = signingCredential, signatureAlgorithm = config.signatureAlgorithm ) @@ -384,9 +388,12 @@ public class SamlAuthenticationProvider internal constructor( // Redirect to RelayState or respond with success val relayState = parameters["RelayState"] - if (!relayState.isNullOrBlank() && relayValidator.validate(url = relayState)) { + if (!relayState.isNullOrBlank() && relayStateValidator.validate(url = relayState)) { call.respondRedirect(relayState) } else { + if (!relayState.isNullOrBlank()) { + logger.warn("RelayState URL validation failed in IdP logout response, ignoring: $relayState") + } call.respond(HttpStatusCode.OK, "Logout completed") } } catch (e: SamlValidationException) { @@ -396,65 +403,6 @@ public class SamlAuthenticationProvider internal constructor( } } -internal class RelayValidator(private val allowedRelayStateUrls: List?) { - /** - * Validates that the relay state URL is allowed for redirect. - */ - fun validate(url: String): Boolean = when { - allowedRelayStateUrls == null -> true - // Reject dangerous patterns - url.startsWith("//") || url.contains("\\") || url.any { it.isISOControl() } -> false - // Allow relative paths - url.startsWith("/") -> { - val normalized = URI(url).normalize().toString() - if (normalized.startsWith("//") || !normalized.startsWith("/")) return false - if (allowedRelayStateUrls.isEmpty()) return true - val pathPrefixes = allowedRelayStateUrls.filter { it.startsWith("/") && !it.startsWith("//") } - pathPrefixes.any { prefix -> - normalized.startsWith(prefix) && (normalized == prefix || prefix.endsWith("/")) - } - } - // Validate absolute URLs - else -> { - allowedRelayStateUrls.any { prefix -> isAllowedAbsoluteRelayState(url, prefix) } - } - } - - /** - * Validates an absolute URL against an allowed prefix. - * Checks a scheme, host, port, and path with segment boundary validation. - */ - private fun isAllowedAbsoluteRelayState(targetUrl: String, allowedPrefix: String): Boolean { - val target = runCatching { Url(targetUrl) }.getOrNull() ?: return false - val allowed = runCatching { Url(allowedPrefix) }.getOrNull() ?: return false - - // Reject URLs with userinfo (user:pass@host) - bypass technique - if (target.user != null || target.password != null) return false - - // Only allow http and https schemes - if (target.protocol.name !in listOf("http", "https")) return false - if (allowed.protocol.name !in listOf("http", "https")) return false - - // Exact protocol match (case-insensitive) - if (!target.protocol.name.equals(allowed.protocol.name, ignoreCase = true)) return false - - // Exact host match (case-insensitive per RFC 3986) - if (!target.host.equals(allowed.host, ignoreCase = true)) return false - - // Exact port match - if (target.port != allowed.port) return false - - // Path prefix match with segment boundary check - val allowedPath = allowed.encodedPath.ifBlank { "/" } - val targetPath = target.encodedPath.ifBlank { "/" } - - if (!targetPath.startsWith(allowedPath)) return false - if (targetPath == allowedPath) return true - if (allowedPath.endsWith("/")) return true - return targetPath.getOrNull(allowedPath.length) == '/' - } -} - /** * Session data for SAML authentication. * Stores the AuthnRequest ID for InResponseTo validation and optional LogoutRequest ID for SLO. diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlConfig.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlConfig.kt index e0de01e133c..057e4252f2b 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlConfig.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/src/io/ktor/server/auth/saml/SamlConfig.kt @@ -183,32 +183,15 @@ public class SamlConfig internal constructor( public var allowIdpInitiatedSso: Boolean = false /** - * List of allowed URL prefixes for RelayState redirection. + * Validator for RelayState URLs to prevent open redirect attacks. * - * When configured, the SP will only redirect to URLs that start with one of these prefixes - * after successful authentication. This prevents open redirect attacks where an attacker - * crafts a SAML response with a malicious RelayState parameter. + * When a SAML response includes a RelayState parameter, the SP may redirect the user + * to that URL after successful authentication. Without proper validation, this can be + * exploited for open redirect attacks. * - * By default, an empty list is used which blocks all external redirects (only relative paths - * starting with "/" are allowed). - * - * To allow all RelayState URLs (UNSAFE, not recommended for production), set to `null`. - * - * **Example:** - * ```kotlin - * saml("saml-auth") { - * // Only allow redirects to same origin - * allowedRelayStateUrls = listOf("https://myapp.example.com/") - * - * // Or allow multiple domains - * allowedRelayStateUrls = listOf( - * "https://myapp.example.com/", - * "https://app.example.com/" - * ) - * } - * ``` + * @see RelayStateValidator for detailed documentation on each validator type */ - public var allowedRelayStateUrls: List? = emptyList() + public var relayStateValidator: RelayStateValidator = RelayStateValidator.Default /** * Whether to require the Destination attribute in SAML responses. diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt index 0982a9dd46b..1039ec429c1 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-saml/jvm/test/io/ktor/server/auth/saml/RelayStateValidationTest.kt @@ -7,13 +7,13 @@ package io.ktor.server.auth.saml import kotlin.test.* /** - * Unit tests for RelayValidator to prevent open redirect attacks. + * Unit tests for RelayStateValidator to prevent open redirect attacks. */ class RelayStateValidationTest { @Test - fun `relative paths with empty allowlist are accepted`() { - val validator = RelayValidator(allowedRelayStateUrls = emptyList()) + fun `Default validator allows relative paths`() { + val validator = RelayStateValidator.Default // Basic relative paths with various components assertTrue(validator.validate("/dashboard")) @@ -24,8 +24,17 @@ class RelayStateValidationTest { } @Test - fun `dangerous URL patterns are blocked`() { - val validator = RelayValidator(allowedRelayStateUrls = emptyList()) + fun `Default validator allows safe absolute URLs`() { + val validator = RelayStateValidator.Default + + assertTrue(validator.validate("https://any-domain.com/any-path")) + assertTrue(validator.validate("http://localhost/test")) + assertTrue(validator.validate("https://example.com:8080/path")) + } + + @Test + fun `Default validator blocks dangerous patterns`() { + val validator = RelayStateValidator.Default // Scheme-relative, backslashes, and control characters assertFalse(validator.validate("//evil.com/phish")) @@ -40,8 +49,39 @@ class RelayStateValidationTest { } @Test - fun `absolute URLs with allowlist validation`() { - val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + fun `Default validator blocks URLs with userinfo`() { + val validator = RelayStateValidator.Default + + assertFalse(validator.validate("https://user@evil.com/phish")) + assertFalse(validator.validate("https://user:pass@evil.com/phish")) + } + + @Test + fun `AllowList validator requires non-empty list`() { + assertFailsWith { + RelayStateValidator.AllowList(emptyList()) + } + } + + @Test + fun `AllowList validator blocks dangerous patterns`() { + val validator = RelayStateValidator.AllowList("/") + + // Scheme-relative, backslashes, and control characters + assertFalse(validator.validate("//evil.com/phish")) + assertFalse(validator.validate("/foo\\..\\bar")) + assertFalse(validator.validate("/foo\u0000bar")) + + // Dangerous schemes + assertFalse(validator.validate("javascript:alert(1)")) + assertFalse(validator.validate("data:text/html,")) + assertFalse(validator.validate("ftp://example.com/file")) + assertFalse(validator.validate("file:///etc/passwd")) + } + + @Test + fun `AllowList validator with absolute URL allowlist`() { + val validator = RelayStateValidator.AllowList("https://myapp.example.com/") // Allowed origin assertTrue(validator.validate("https://myapp.example.com/dashboard")) @@ -57,25 +97,23 @@ class RelayStateValidationTest { } @Test - fun `path prefix matching with segment boundaries`() { + fun `AllowList validator path prefix matching with segment boundaries`() { // Prefix without trailing slash - exact match only - val validatorNoSlash = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/app")) + val validatorNoSlash = RelayStateValidator.AllowList("https://myapp.example.com/app") assertTrue(validatorNoSlash.validate("https://myapp.example.com/app/dashboard")) assertFalse(validatorNoSlash.validate("https://myapp.example.com/application")) // Prefix with trailing slash - allows subpaths - val validatorWithSlash = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/app/")) + val validatorWithSlash = RelayStateValidator.AllowList("https://myapp.example.com/app/") assertTrue(validatorWithSlash.validate("https://myapp.example.com/app/dashboard")) assertFalse(validatorWithSlash.validate("https://myapp.example.com/app")) } @Test - fun `multiple allowed origins`() { - val validator = RelayValidator( - allowedRelayStateUrls = listOf( - "https://app1.example.com/", - "https://app2.example.com/" - ) + fun `AllowList validator with multiple allowed origins`() { + val validator = RelayStateValidator.AllowList( + "https://app1.example.com/", + "https://app2.example.com/" ) assertTrue(validator.validate("https://app1.example.com/page")) @@ -84,25 +122,16 @@ class RelayStateValidationTest { } @Test - fun `null allowlist disables validation`() { - val validator = RelayValidator(allowedRelayStateUrls = null) - assertTrue(validator.validate("https://any-domain.com/any-path")) - assertTrue(validator.validate("/any/path")) - } - - @Test - fun `host comparison is case insensitive`() { - val validator = RelayValidator(allowedRelayStateUrls = listOf("https://MyApp.Example.COM/")) + fun `AllowList validator host comparison is case insensitive`() { + val validator = RelayStateValidator.AllowList("https://MyApp.Example.COM/") assertTrue(validator.validate("https://myapp.example.com/page")) } @Test - fun `mixed relative and absolute URLs in allowlist`() { - val validator = RelayValidator( - allowedRelayStateUrls = listOf( - "/local/", - "https://external.com/" - ) + fun `AllowList validator with mixed relative and absolute URLs`() { + val validator = RelayStateValidator.AllowList( + "/local/", + "https://external.com/" ) assertTrue(validator.validate("/local/page")) @@ -112,18 +141,48 @@ class RelayStateValidationTest { } @Test - fun `default ports are handled correctly`() { - val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + fun `AllowList validator handles default ports correctly`() { + val validator = RelayStateValidator.AllowList("https://myapp.example.com/") assertTrue(validator.validate("https://myapp.example.com/page")) - val validatorWithPort = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com:443/")) + val validatorWithPort = RelayStateValidator.AllowList("https://myapp.example.com:443/") assertTrue(validatorWithPort.validate("https://myapp.example.com/page")) } @Test - fun `malformed URLs are rejected`() { - val validator = RelayValidator(allowedRelayStateUrls = listOf("https://myapp.example.com/")) + fun `AllowList validator rejects malformed URLs`() { + val validator = RelayStateValidator.AllowList("https://myapp.example.com/") assertFalse(validator.validate("https://[invalid")) assertFalse(validator.validate("ht!tp://example.com")) } + + @Test + fun `Custom validator uses provided function`() { + val validator = RelayStateValidator.Custom { url -> + url.startsWith("/safe/") && !url.contains("..") + } + + assertTrue(validator.validate("/safe/page")) + assertTrue(validator.validate("/safe/nested/page")) + assertFalse(validator.validate("/unsafe/page")) + assertFalse(validator.validate("/safe/../unsafe")) + } + + @Test + fun `Custom validator can allow all URLs`() { + val validator = RelayStateValidator.Custom { true } + + assertTrue(validator.validate("https://any.com/path")) + assertTrue(validator.validate("/any/path")) + // Note: Custom validator bypasses safety checks - user is responsible + assertTrue(validator.validate("javascript:alert(1)")) + } + + @Test + fun `Custom validator can block all URLs`() { + val validator = RelayStateValidator.Custom { false } + + assertFalse(validator.validate("https://safe.com/path")) + assertFalse(validator.validate("/safe/path")) + } }