diff --git a/docs/auth.md b/docs/auth.md index 2dca9226f2c..8a60c6a5751 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -1049,6 +1049,73 @@ Content-Type: application/json } ``` +### POST /auth/token_exchange + +This endpoint exchanges an external OIDC `id_token` for a normal Cozy OAuth +client and token pair on the target instance. + +It is intended for browser-based admin applications that authenticate with an +external identity provider, then need to call Cozy APIs directly on an +organization instance. + +The target Cozy instance is the request host. The exchanged `id_token` must: + +- be signed by the configured OIDC provider +- match the configured issuer and audience +- contain an `org_id` claim equal to the target instance organization id +- contain an `org_role` claim equal to `owner` or `admin` + +The request body is JSON: + +- `id_token`, the external OIDC token +- `scope`, currently limited to `io.cozy.files` + +Example: + +```http +POST /auth/token_exchange HTTP/1.1 +Host: myorg123.example.com +Content-Type: application/json +Accept: application/json + +{ + "id_token": "eyJhbGciOiJSUzI1NiIsImtpZCI6InRva2VuLWV4Y2hhbmdlIn0...", + "scope": "io.cozy.files" +} +``` + +Response: + +```http +HTTP/1.1 200 OK +Content-Type: application/json +Cache-Control: no-store +Pragma: no-cache +``` + +```json +{ + "access_token": "eyJhbGciOiJS", + "token_type": "bearer", + "refresh_token": "eyJhbGciOiJS", + "scope": "io.cozy.files", + "client_id": "64ce5cb0-bd4c-11e6-880e-b3b7dfda89d3", + "client_secret": "Oung7oi5", + "registration_access_token": "reg123" +} +``` + +The returned OAuth client is a normal Cozy OAuth client: + +- `client_id`, `client_secret`, `access_token`, and `refresh_token` can be + used directly with `cozy-client` +- `registration_access_token` can be used with + `DELETE /auth/register/:client-id` to revoke that exchanged client + +When the external `id_token` contains a `sid` claim, the created OAuth client +is bound to that upstream OIDC session so it can be revoked by OIDC +backchannel logout. + ### POST /auth/session_code This endpoint can be used by the flagship application in order to create a diff --git a/web/auth/auth.go b/web/auth/auth.go index d8ff5ed0fad..7abcef90fc0 100644 --- a/web/auth/auth.go +++ b/web/auth/auth.go @@ -19,6 +19,7 @@ import ( "github.com/cozy/cozy-stack/pkg/config/config" "github.com/cozy/cozy-stack/pkg/crypto" "github.com/cozy/cozy-stack/pkg/limits" + "github.com/cozy/cozy-stack/pkg/utils" "github.com/cozy/cozy-stack/web/middlewares" "github.com/labstack/echo/v4" ) @@ -510,6 +511,95 @@ func registerPreflight(c echo.Context) error { return corsPreflight(echo.POST)(c) } +func tokenExchangeCORS(c echo.Context) bool { + origin := c.Request().Header.Get(echo.HeaderOrigin) + if origin == "" { + return true + } + inst := middlewares.GetInstance(c) + if !tokenExchangeOriginAllowed(origin, inst) { + return false + } + + res := c.Response() + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, origin) + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + return true +} + +// Allow CORS from *.org_domain +func tokenExchangeOriginAllowed(origin string, inst *instance.Instance) bool { + if inst == nil || inst.OrgDomain == "" { + return false + } + originHost := utils.ExtractInstanceHost(origin) + if originHost == "" { + return false + } + orgDomain := utils.NormalizeDomain(inst.OrgDomain) + return originHost == orgDomain || strings.HasSuffix(originHost, "."+orgDomain) +} + +func tokenExchangePreflight(c echo.Context) error { + if !tokenExchangeCORS(c) { + return c.NoContent(http.StatusForbidden) + } + req := c.Request() + res := c.Response() + res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) + res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) + res.Header().Set(echo.HeaderAccessControlAllowMethods, echo.POST) + res.Header().Set(echo.HeaderAccessControlMaxAge, middlewares.MaxAgeCORS) + if h := req.Header.Get(echo.HeaderAccessControlRequestHeaders); h != "" { + res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) + } + return c.NoContent(http.StatusNoContent) +} + +func tokenExchange(c echo.Context) error { + if !tokenExchangeCORS(c) { + return c.JSON(http.StatusForbidden, echo.Map{ + "error": "the origin of this application is not allowed", + }) + } + c.Response().Header().Set("Cache-Control", "no-store") + c.Response().Header().Set("Pragma", "no-cache") + + inst := middlewares.GetInstance(c) + var reqBody tokenExchangeRequest + if err := c.Bind(&reqBody); err != nil { + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": "invalid request body", + }) + } + reqBody.IDToken = strings.TrimSpace(reqBody.IDToken) + reqBody.Scope = strings.TrimSpace(reqBody.Scope) + if reqBody.IDToken == "" { + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": "the id_token parameter is mandatory", + }) + } + if reqBody.Scope == "" { + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": "the scope parameter is mandatory", + }) + } + + out, err := executeTokenExchange(c, inst, reqBody) + if err != nil { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { + return c.JSON(httpErr.Code, echo.Map{ + "error": fmt.Sprint(httpErr.Message), + }) + } + return err + } + + return c.JSON(http.StatusOK, out) +} + func registerFromWebApp(c echo.Context) error { res := c.Response() origin := c.Request().Header.Get(echo.HeaderOrigin) @@ -683,6 +773,8 @@ func Routes(router *echo.Group) { authHandler.Register(router.Group("/authorize", noCSRF)) router.POST("/access_token", accessToken) + router.POST("/token_exchange", tokenExchange, middlewares.AcceptJSON, middlewares.ContentTypeJSON) + router.OPTIONS("/token_exchange", tokenExchangePreflight) // Flagship app router.POST("/session_code", CreateSessionCode) diff --git a/web/auth/auth_test.go b/web/auth/auth_test.go index 72d0299c948..f98a15893df 100644 --- a/web/auth/auth_test.go +++ b/web/auth/auth_test.go @@ -4,10 +4,16 @@ package auth_test import ( + "context" + "crypto/rand" + "crypto/rsa" "encoding/base64" "encoding/hex" + "encoding/json" "fmt" + "math/big" "net/http" + "net/http/httptest" "net/url" "testing" "time" @@ -15,6 +21,7 @@ import ( "github.com/cozy/cozy-stack/model/instance" "github.com/cozy/cozy-stack/model/instance/lifecycle" "github.com/cozy/cozy-stack/model/oauth" + oidcbinding "github.com/cozy/cozy-stack/model/oidc/binding" "github.com/cozy/cozy-stack/model/permission" "github.com/cozy/cozy-stack/model/session" "github.com/cozy/cozy-stack/model/stack" @@ -2111,6 +2118,400 @@ func TestAuth(t *testing.T) { }) } +func TestTokenExchange(t *testing.T) { + if testing.Short() { + t.Skip("an instance is required for this test: test skipped due to the use of --short flag") + } + + config.UseTestFile(t) + conf := config.GetConfig() + conf.Assets = "../../assets" + _ = web.LoadSupportedLocales() + if _, err := couchdb.CheckStatus(context.Background()); err != nil { + t.Skip("couchdb is required for this test") + } + + privateKey, kid, jwksURL := newTokenExchangeSigningKey(t) + const ( + contextName = "token-exchange-test" + clientID = "cozy-twake-int" + issuer = "https://sign-up.example.com/" + ) + conf.Authentication = map[string]interface{}{ + contextName: map[string]interface{}{ + "oidc": makeTokenExchangeOIDCConfig(issuer, clientID, jwksURL), + }, + } + + setup := testutils.NewSetup(t, t.Name()) + + testInstance := setup.GetTestInstance(&lifecycle.Options{ + Domain: "token-exchange.cozy.example.net", + OrgDomain: "example.com", + OrgID: "myorg123", + ContextName: contextName, + Email: "token-exchange@example.com", + }) + + ts := setup.GetTestServer("/test", fakeAPI, func(r *echo.Echo) *echo.Echo { + handler, err := web.CreateSubdomainProxy(r, &stack.Services{}, apps.Serve) + require.NoError(t, err, "Cant start subdomain proxy") + return handler + }) + ts.Config.Handler.(*echo.Echo).HTTPErrorHandler = errors.ErrorHandler + require.NoError(t, dynamic.InitDynamicAssetFS(config.FsURL().String()), "Could not init dynamic FS") + + e := testutils.CreateTestClient(t, ts.URL) + + t.Run("RequiresMandatoryParameters", func(t *testing.T) { + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithJSON(map[string]string{}). + Expect(). + Status(http.StatusBadRequest). + JSON().Object(). + ValueEqual("error", "the id_token parameter is mandatory") + }) + + t.Run("AllowsOrgDomainOrigins", func(t *testing.T) { + for _, origin := range []string{ + "https://example.com", + "https://admin.example.com", + "https://workspace.sales.example.com", + } { + resp := e.OPTIONS("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Origin", origin). + WithHeader("Access-Control-Request-Headers", "content-type"). + Expect(). + Status(http.StatusNoContent) + + resp.Header("Access-Control-Allow-Origin").Equal(origin) + resp.Header("Access-Control-Allow-Methods").Equal(http.MethodPost) + resp.Header("Access-Control-Allow-Headers").Equal("content-type") + } + }) + + t.Run("RejectsOtherOrigins", func(t *testing.T) { + e.OPTIONS("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Origin", "https://admin.other.com"). + Expect(). + Status(http.StatusForbidden). + Header("Access-Control-Allow-Origin").Empty() + }) + + t.Run("ExchangesToken", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "azp": clientID, + "sub": "admin-user", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + resp := e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK) + resp.Header("Cache-Control").Equal("no-store") + resp.Header("Pragma").Equal("no-cache") + resp.Header("Access-Control-Allow-Origin").Equal("https://admin.example.com") + obj := resp.JSON().Object() + + clientIDValue := obj.Value("client_id").String().Raw() + clientSecretValue := obj.Value("client_secret").String().Raw() + registrationToken := obj.Value("registration_access_token").String().Raw() + accessToken := obj.Value("access_token").String().Raw() + refreshToken := obj.Value("refresh_token").String().Raw() + + obj.ValueEqual("token_type", "bearer") + obj.ValueEqual("scope", "io.cozy.files") + require.NotEmpty(t, clientIDValue) + require.NotEmpty(t, clientSecretValue) + require.NotEmpty(t, registrationToken) + assertValidToken(t, testInstance, accessToken, consts.AccessTokenAudience, clientIDValue, "io.cozy.files") + assertValidToken(t, testInstance, refreshToken, consts.RefreshTokenAudience, clientIDValue, "io.cozy.files") + + client, err := oauth.FindClient(testInstance, clientIDValue) + require.NoError(t, err) + require.Equal(t, clientIDValue, client.ClientID) + require.Equal(t, clientSecretValue, client.ClientSecret) + require.False(t, client.Pending) + require.Equal(t, []string{"https://admin.example.com"}, client.RedirectURIs) + }) + + t.Run("BindsExchangedClientToOIDCSID", func(t *testing.T) { + const sid = "token-exchange-sid-123" + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user-bound", + "sid": sid, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + obj := e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK). + JSON().Object() + + clientIDValue := obj.Value("client_id").String().Raw() + client, err := oauth.FindClient(testInstance, clientIDValue) + require.NoError(t, err) + require.Equal(t, sid, client.OIDCSessionID) + + boundClients, err := oidcbinding.ListOAuthClients(contextName, sid) + require.NoError(t, err) + require.Len(t, boundClients, 1) + require.Equal(t, contextName, boundClients[0].OIDCProviderKey) + require.Equal(t, testInstance.Domain, boundClients[0].Domain) + require.Equal(t, clientIDValue, boundClients[0].OAuthClientID) + }) + + t.Run("AcceptsAdminOrgRole", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user-role", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "admin", + }) + + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK) + }) + + t.Run("RevokesExchangedClientIndependently", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user-revoke", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + obj := e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK). + JSON().Object() + + clientIDValue := obj.Value("client_id").String().Raw() + registrationToken := obj.Value("registration_access_token").String().Raw() + + e.DELETE("/auth/register/"+clientIDValue). + WithHost(testInstance.Domain). + WithHeader("Authorization", "Bearer "+registrationToken). + Expect(). + Status(http.StatusNoContent) + + _, err := oauth.FindClient(testInstance, clientIDValue) + require.Error(t, err) + require.True(t, couchdb.IsNotFoundError(err)) + }) + + t.Run("CreatesIndependentOAuthClients", func(t *testing.T) { + firstToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user-1", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + secondToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user-2", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + first := e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": firstToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK). + JSON().Object() + + second := e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://workspace.sales.example.com"). + WithJSON(map[string]string{ + "id_token": secondToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusOK). + JSON().Object() + + firstClientID := first.Value("client_id").String().Raw() + firstClientSecret := first.Value("client_secret").String().Raw() + firstRegistrationToken := first.Value("registration_access_token").String().Raw() + secondClientID := second.Value("client_id").String().Raw() + secondClientSecret := second.Value("client_secret").String().Raw() + secondRegistrationToken := second.Value("registration_access_token").String().Raw() + + require.NotEqual(t, firstClientID, secondClientID) + require.NotEqual(t, firstClientSecret, secondClientSecret) + require.NotEqual(t, firstRegistrationToken, secondRegistrationToken) + + firstClient, err := oauth.FindClient(testInstance, firstClientID) + require.NoError(t, err) + require.Equal(t, []string{"https://admin.example.com"}, firstClient.RedirectURIs) + + secondClient, err := oauth.FindClient(testInstance, secondClientID) + require.NoError(t, err) + require.Equal(t, []string{"https://workspace.sales.example.com"}, secondClient.RedirectURIs) + }) + + t.Run("RejectsInvalidToken", func(t *testing.T) { + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": "not-a-token", + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusBadRequest). + JSON().Object(). + ValueEqual("error", "invalid token") + }) + + t.Run("RejectsMissingAdminAuthorization", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "member-user", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + }) + + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusBadRequest). + JSON().Object(). + ValueEqual("error", "admin authorization is required") + }) + + t.Run("RejectsOrgIDMismatch", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": "other-org", + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.files", + }). + Expect(). + Status(http.StatusBadRequest). + JSON().Object(). + ValueEqual("error", "org_id mismatch") + }) + + t.Run("RejectsInvalidScope", func(t *testing.T) { + idToken := makeTokenExchangeSignedJWT(t, privateKey, kid, map[string]interface{}{ + "iss": issuer, + "aud": []string{clientID}, + "sub": "admin-user", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "org_id": testInstance.OrgID, + "org_domain": testInstance.OrgDomain, + "org_role": "owner", + }) + + e.POST("/auth/token_exchange"). + WithHost(testInstance.Domain). + WithHeader("Accept", "application/json"). + WithHeader("Origin", "https://admin.example.com"). + WithJSON(map[string]string{ + "id_token": idToken, + "scope": "io.cozy.contacts", + }). + Expect(). + Status(http.StatusBadRequest). + JSON().Object(). + ValueEqual("error", "invalid scope") + }) +} + func getLoginCSRFToken(e *httpexpect.Expect) string { return e.GET("/auth/login"). WithHost(domain). @@ -2146,7 +2547,69 @@ func assertValidToken(t *testing.T, testInstance *instance.Instance, token, audi }, &claims) assert.NoError(t, err) assert.Equal(t, audience, claims.Audience[0]) - assert.Equal(t, domain, claims.Issuer) + assert.Equal(t, testInstance.Domain, claims.Issuer) assert.Equal(t, subject, claims.Subject) assert.Equal(t, scope, claims.Scope) } + +func newTokenExchangeSigningKey(t *testing.T) (*rsa.PrivateKey, string, string) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "token-exchange-test-key" + jwksServer := newTokenExchangeJWKSHandler(t, privateKey, kid) + t.Cleanup(jwksServer.Close) + + return privateKey, kid, jwksServer.URL + "/jwks" +} + +func makeTokenExchangeOIDCConfig(issuer, clientID, jwksURL string) map[string]interface{} { + return map[string]interface{}{ + "allow_oauth_token": true, + "issuer": issuer, + "client_id": clientID, + "client_secret": "provider-secret", + "scope": "openid profile email ADMIN", + "redirect_uri": "https://admin.example.com/callback", + "authorize_url": "https://sign-up.example.com/authorize", + "token_url": "https://sign-up.example.com/token", + "userinfo_url": "https://sign-up.example.com/userinfo", + "userinfo_instance_field": "workplaceFqdn", + "id_token_jwk_url": jwksURL, + } +} + +func makeTokenExchangeSignedJWT(t *testing.T, privateKey *rsa.PrivateKey, kid string, claims map[string]interface{}) string { + t.Helper() + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(claims)) + token.Header["kid"] = kid + signed, err := token.SignedString(privateKey) + require.NoError(t, err) + return signed +} + +func newTokenExchangeJWKSHandler(t *testing.T, privateKey *rsa.PrivateKey, kid string) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/jwks" { + w.WriteHeader(http.StatusNotFound) + return + } + e := big.NewInt(int64(privateKey.PublicKey.E)).Bytes() + payload := map[string]interface{}{ + "keys": []map[string]string{{ + "kty": "RSA", + "use": "sig", + "kid": kid, + "alg": "RS256", + "n": base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(e), + }}, + } + require.NoError(t, json.NewEncoder(w).Encode(payload)) + })) +} diff --git a/web/auth/token_exchange.go b/web/auth/token_exchange.go new file mode 100644 index 00000000000..42262215774 --- /dev/null +++ b/web/auth/token_exchange.go @@ -0,0 +1,254 @@ +package auth + +import ( + "crypto/subtle" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/cozy/cozy-stack/model/instance" + "github.com/cozy/cozy-stack/model/oauth" + oidcbinding "github.com/cozy/cozy-stack/model/oidc/binding" + oidcprovider "github.com/cozy/cozy-stack/model/oidc/provider" + "github.com/cozy/cozy-stack/pkg/config/config" + "github.com/cozy/cozy-stack/pkg/consts" + "github.com/cozy/cozy-stack/pkg/couchdb" + "github.com/cozy/cozy-stack/pkg/limits" + "github.com/cozy/cozy-stack/pkg/utils" + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" +) + +const ( + tokenExchangeAdminRole = "admin" + tokenExchangeOwnerRole = "owner" + tokenExchangeAllowedScope = "io.cozy.files" + tokenExchangeOAuthClientName = "Twake Admin Panel" + tokenExchangeOAuthClientSoftwareID = "twake-admin-panel" +) + +type tokenExchangeRequest struct { + IDToken string `json:"id_token"` + Scope string `json:"scope"` +} + +type tokenExchangeResponse struct { + AccessTokenReponse + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RegistrationToken string `json:"registration_access_token"` +} + +func executeTokenExchange(c echo.Context, inst *instance.Instance, req tokenExchangeRequest) (*tokenExchangeResponse, error) { + claims, err := validateTokenExchangeIDToken(inst, req.IDToken) + if err != nil { + return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.Scope != tokenExchangeAllowedScope { + return nil, echo.NewHTTPError(http.StatusBadRequest, "invalid scope") + } + + client, err := createTokenExchangeOAuthClient(c, inst) + if err != nil { + return nil, err + } + defer LockOAuthClient(inst, client.ClientID)() + + if err := bindTokenExchangeOIDCSession(inst, client, claims); err != nil { + if delErr := client.Delete(inst); delErr != nil { + inst.Logger().WithNamespace("token_exchange").Warnf("Cannot delete orphaned OAuth client %s: %s", client.CouchID, delErr.Description) + } + return nil, err + } + + return buildTokenExchangeResponse(inst, client, req.Scope) +} + +func validateTokenExchangeIDToken(inst *instance.Instance, raw string) (jwt.MapClaims, error) { + if inst == nil { + return nil, errors.New("instance is missing") + } + conf, err := oidcprovider.LoadConfig( + inst.ContextName, + oidcprovider.RequireClientID, + oidcprovider.RequireIDTokenKeyURL, + oidcprovider.RequireIssuerOrTokenURL, + ) + if err != nil || !conf.AllowOAuthToken { + return nil, errors.New("this endpoint is not enabled") + } + + claims, err := oidcprovider.VerifyIDToken(raw, conf) + if err != nil { + return nil, errors.New("invalid token") + } + + expectedIssuer, err := oidcprovider.GetIssuer(inst.ContextName, conf) + if err != nil { + inst.Logger().WithNamespace("token_exchange").Errorf("Cannot get OIDC issuer for context %s: %s", inst.ContextName, err) + return nil, echo.NewHTTPError(http.StatusInternalServerError, "internal server error") + } + issuer, err := claims.GetIssuer() + if err != nil || issuer == "" || issuer != expectedIssuer { + return nil, errors.New("invalid token issuer") + } + if !tokenExchangeAudienceMatches(claims, conf.ClientID) { + return nil, errors.New("invalid token audience") + } + issuedAt, err := claims.GetIssuedAt() + if err != nil || issuedAt == nil { + return nil, errors.New("invalid token") + } + if issuedAt.Time.After(time.Now().Add(5 * time.Minute)) { + return nil, errors.New("invalid token") + } + if !tokenExchangeHasAdminRole(claims) { + return nil, errors.New("admin authorization is required") + } + + orgID, _ := tokenExchangeClaimString(claims, "org_id") + if orgID == "" || subtle.ConstantTimeCompare([]byte(orgID), []byte(inst.OrgID)) == 0 { + return nil, errors.New("org_id mismatch") + } + if inst.OrgDomain != "" { + orgDomain, ok := tokenExchangeClaimString(claims, "org_domain") + if !ok || orgDomain == "" { + return nil, errors.New("org_domain claim is required") + } + if subtle.ConstantTimeCompare([]byte(orgDomain), []byte(inst.OrgDomain)) == 0 { + return nil, errors.New("org_domain mismatch") + } + } + + return claims, nil +} + +func createTokenExchangeOAuthClient(c echo.Context, inst *instance.Instance) (*oauth.Client, error) { + redirectURI, err := tokenExchangeRedirectURI(c, inst) + if err != nil { + return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if err := config.GetRateLimiter().CheckRateLimit(inst, limits.OAuthClientType); limits.IsLimitReachedOrExceeded(err) { + return nil, echo.NewHTTPError(http.StatusNotFound, "Not found") + } + + client := &oauth.Client{ + RedirectURIs: []string{redirectURI}, + ClientName: tokenExchangeOAuthClientName, + ClientKind: "browser", + SoftwareID: tokenExchangeOAuthClientSoftwareID, + } + if regErr := client.Create(inst); regErr != nil { + return nil, echo.NewHTTPError(regErr.Code, regErr.Description) + } + + storedClient, err := oauth.FindClient(inst, client.ClientID) + if err != nil { + return nil, err + } + storedClient.RegistrationToken = client.RegistrationToken + return storedClient, nil +} + +func bindTokenExchangeOIDCSession(inst *instance.Instance, client *oauth.Client, claims jwt.MapClaims) error { + sessionID, _ := tokenExchangeClaimString(claims, "sid") + if sessionID != "" { + client.OIDCSessionID = sessionID + } + + client.Pending = false + client.ClientID = "" + if err := couchdb.UpdateDoc(inst, client); err != nil { + inst.Logger().WithNamespace("oidc").Warnf("Cannot update OAuth client %s: %s", client.CouchID, err) + return err + } + + if sessionID != "" { + if err := oidcbinding.BindOAuthClient(inst.ContextName, inst.Domain, sessionID, client.CouchID); err != nil { + inst.Logger().WithNamespace("oidc").Errorf("Cannot bind OIDC session %s to OAuth client %s: %s", sessionID, client.CouchID, err) + return fmt.Errorf("cannot bind OIDC session to OAuth client: %w", err) + } + } + + client.ClientID = client.CouchID + return nil +} + +func buildTokenExchangeResponse(inst *instance.Instance, client *oauth.Client, scope string) (*tokenExchangeResponse, error) { + out := &tokenExchangeResponse{ + AccessTokenReponse: AccessTokenReponse{ + Type: "bearer", + Scope: scope, + }, + ClientID: client.ClientID, + ClientSecret: client.ClientSecret, + RegistrationToken: client.RegistrationToken, + } + + refreshToken, err := client.CreateJWT(inst, consts.RefreshTokenAudience, scope) + if err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "Can't generate refresh token") + } + out.Refresh = refreshToken + + accessToken, err := client.CreateJWT(inst, consts.AccessTokenAudience, scope) + if err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "Can't generate access token") + } + out.Access = accessToken + + client.LastRefreshedAt = time.Now() + if err := couchdb.UpdateDoc(inst, client); err != nil { + inst.Logger().WithNamespace("token_exchange").Warnf("Cannot update LastRefreshedAt for client %s: %s", client.CouchID, err) + } + + return out, nil +} + +func tokenExchangeAudienceMatches(claims jwt.MapClaims, clientID string) bool { + aud, err := claims.GetAudience() + if err == nil && len(aud) > 0 { + for _, value := range aud { + if value == clientID { + return true + } + } + return false + } + azp, _ := tokenExchangeClaimString(claims, "azp") + return azp == clientID +} + +func tokenExchangeHasAdminRole(claims jwt.MapClaims) bool { + orgRole, ok := tokenExchangeClaimString(claims, "org_role") + return ok && (strings.EqualFold(orgRole, tokenExchangeAdminRole) || + strings.EqualFold(orgRole, tokenExchangeOwnerRole)) +} + +func tokenExchangeClaimString(claims jwt.MapClaims, key string) (string, bool) { + raw, ok := claims[key] + if !ok || raw == nil { + return "", false + } + value, ok := raw.(string) + return value, ok +} + +func tokenExchangeRedirectURI(c echo.Context, inst *instance.Instance) (string, error) { + if origin := c.Request().Header.Get(echo.HeaderOrigin); origin != "" { + u, err := url.Parse(origin) + if err == nil && u.Scheme != "" && u.Host != "" && utils.StripPort(u.Host) != inst.Domain { + u.Path = "" + u.RawQuery = "" + u.Fragment = "" + return u.String(), nil + } + } + if inst != nil && inst.OrgDomain != "" { + return "https://admin." + inst.OrgDomain, nil + } + return "", errors.New("cannot determine redirect URI") +} diff --git a/web/middlewares/secure.go b/web/middlewares/secure.go index 038ebd899f6..4e41c7f4fbb 100644 --- a/web/middlewares/secure.go +++ b/web/middlewares/secure.go @@ -3,6 +3,7 @@ package middlewares import ( "fmt" "net/url" + "regexp" "strings" "time" @@ -301,15 +302,21 @@ func (b cspBuilder) makeCSPHeader(header, cspAllowList string, sources []CSPSour } } } - // Add matrix.{org_domain} to frame-src directive if present (for iframes) - if header == "frame-src" && b.instance != nil && b.instance.OrgDomain != "" { + // Add matrix.{org_domain} to frame-src directive if present (for iframes). + // OrgDomain is validated to contain only safe characters before being injected into the header. + if header == "frame-src" && b.instance != nil && isSafeDomain(b.instance.OrgDomain) { headers = append(headers, "matrix."+b.instance.OrgDomain) } - // Add api-login-{org_id}.{domain without prefix} to connect-src directive if present - if header == "connect-src" && b.instance != nil && b.instance.OrgID != "" { + // Add api-login-{org_id}.{domain without prefix} to connect-src directive if present. + // OrgID, OrgDomain, and domain are all validated before being injected into the header. + if header == "connect-src" && b.instance != nil && isSafeDomain(b.instance.OrgID) { _, domain, found := strings.Cut(b.instance.Domain, ".") - if found { + if found && isSafeDomain(domain) { headers = append(headers, "api-login-"+b.instance.OrgID+"."+domain) + headers = append(headers, b.instance.OrgID+"."+domain) + } + if isSafeDomain(b.instance.OrgDomain) { + headers = append(headers, b.instance.OrgID+"."+b.instance.OrgDomain) } } if len(headers) == 0 { @@ -344,3 +351,14 @@ func appendCSPRule(currentRules, ruleType string, appendedValues ...string) (new } return } + +// safeDomainRe matches hostname values safe to embed in a CSP header +// (alphanumeric, hyphens, dots — no spaces, semicolons, or other +// CSP-significant characters that could break or inject new directives). +var safeDomainRe = regexp.MustCompile(`^[a-zA-Z0-9.-]+$`) + +// isSafeDomain returns true if s is a non-empty value safe for embedding in a +// CSP header as a hostname or hostname component. +func isSafeDomain(s string) bool { + return s != "" && safeDomainRe.MatchString(s) +} diff --git a/web/middlewares/secure_test.go b/web/middlewares/secure_test.go index f8c0c2aab07..0474d490c6e 100644 --- a/web/middlewares/secure_test.go +++ b/web/middlewares/secure_test.go @@ -8,25 +8,13 @@ import ( "time" "github.com/cozy/cozy-stack/model/instance" - "github.com/cozy/cozy-stack/pkg/assets/dynamic" "github.com/cozy/cozy-stack/pkg/config/config" - "github.com/cozy/cozy-stack/tests/testutils" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestSecure(t *testing.T) { - if testing.Short() { - t.Skip("an instance is required for this test: test skipped due to the use of --short flag") - } - config.UseTestFile(t) - config.GetConfig().Assets = "../../assets" - setup := testutils.NewSetup(t, t.Name()) - - setup.SetupSwiftTest() - require.NoError(t, dynamic.InitDynamicAssetFS(config.FsURL().String()), "Could not init dynamic FS") t.Run("SecureMiddlewareHSTS", func(t *testing.T) { e := echo.New() @@ -233,12 +221,8 @@ func TestSecure(t *testing.T) { csp := rec.Header().Get(echo.HeaderContentSecurityPolicy) - // Verify that api-login-myorg123.cozy.example.com appears only once (in connect-src) - expectedDomain := "api-login-myorg123.cozy.example.com" - count := strings.Count(csp, expectedDomain) - assert.Equal(t, 1, count, - "%s should appear exactly once (in connect-src), but found %d times. CSP: %s", - expectedDomain, count, csp) + apiLoginDomain := "api-login-myorg123.cozy.example.com" + orgInstanceDomain := "myorg123.cozy.example.com" // Verify that connect-src contains the api-login domain connectSrcIndex := strings.Index(csp, "connect-src ") @@ -250,8 +234,10 @@ func TestSecure(t *testing.T) { "connect-src should end with semicolon") connectSrcContent := csp[connectSrcIndex : connectSrcIndex+connectSrcEnd] - assert.Contains(t, connectSrcContent, expectedDomain, - "connect-src should contain %s. Found: %s", expectedDomain, connectSrcContent) + assert.Contains(t, connectSrcContent, apiLoginDomain, + "connect-src should contain %s. Found: %s", apiLoginDomain, connectSrcContent) + assert.Contains(t, connectSrcContent, orgInstanceDomain, + "connect-src should contain %s. Found: %s", orgInstanceDomain, connectSrcContent) // Verify that other directives do NOT contain the api-login domain otherDirectives := []string{ @@ -276,10 +262,62 @@ func TestSecure(t *testing.T) { directiveEnd := strings.Index(csp[directiveIndex:], ";") if directiveEnd != -1 { directiveContent := csp[directiveIndex : directiveIndex+directiveEnd] - assert.NotContains(t, directiveContent, expectedDomain, - "Directive %s should NOT contain %s. Found: %s", directivePattern, expectedDomain, directiveContent) + assert.NotContains(t, directiveContent, apiLoginDomain, + "Directive %s should NOT contain %s. Found: %s", directivePattern, apiLoginDomain, directiveContent) + assert.NotContains(t, directiveContent, orgInstanceDomain, + "Directive %s should NOT contain %s. Found: %s", directivePattern, orgInstanceDomain, directiveContent) } } } }) + + t.Run("SecureMiddlewareCSPWithOrgIDAndOrgDomainConnectSrc", func(t *testing.T) { + e := echo.New() + req, _ := http.NewRequest(echo.GET, "http://app.cozy.local/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + inst := &instance.Instance{ + Domain: "alice.cozy.example.com", + OrgID: "myorg123", + OrgDomain: "example.com", + } + c.Set("instance", inst) + h := Secure(&SecureConfig{ + CSPDefaultSrc: []CSPSource{CSPSrcSelf}, + CSPScriptSrc: []CSPSource{CSPSrcSelf}, + CSPFrameSrc: []CSPSource{CSPSrcSelf}, + CSPConnectSrc: []CSPSource{CSPSrcSelf}, + CSPFontSrc: []CSPSource{CSPSrcSelf}, + CSPImgSrc: []CSPSource{CSPSrcSelf}, + CSPManifestSrc: []CSPSource{CSPSrcSelf}, + CSPMediaSrc: []CSPSource{CSPSrcSelf}, + CSPObjectSrc: []CSPSource{CSPSrcSelf}, + CSPStyleSrc: []CSPSource{CSPSrcSelf}, + CSPWorkerSrc: []CSPSource{CSPSrcSelf}, + CSPFrameAncestors: []CSPSource{CSPSrcSelf}, + CSPBaseURI: []CSPSource{CSPSrcSelf}, + CSPFormAction: []CSPSource{CSPSrcSelf}, + })(echo.NotFoundHandler) + _ = h(c) + + csp := rec.Header().Get(echo.HeaderContentSecurityPolicy) + expectedDomain := "myorg123.example.com" + + count := strings.Count(csp, expectedDomain) + assert.Equal(t, 1, count, + "%s should appear exactly once (in connect-src), but found %d times. CSP: %s", + expectedDomain, count, csp) + + connectSrcIndex := strings.Index(csp, "connect-src ") + assert.NotEqual(t, -1, connectSrcIndex, + "connect-src should be present in CSP. Full CSP: %s", csp) + + connectSrcEnd := strings.Index(csp[connectSrcIndex:], ";") + assert.NotEqual(t, -1, connectSrcEnd, + "connect-src should end with semicolon") + + connectSrcContent := csp[connectSrcIndex : connectSrcIndex+connectSrcEnd] + assert.Contains(t, connectSrcContent, expectedDomain, + "connect-src should contain %s. Found: %s", expectedDomain, connectSrcContent) + }) }