diff --git a/cmd/login/login.go b/cmd/login/login.go index 0ee07280..01c271a3 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -86,7 +86,7 @@ func (h *handler) execute() error { // Use spinner for the token exchange h.spinner.Start("Exchanging authorization code...") - tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier) + tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier, "", "") if err != nil { h.spinner.StopAll() h.log.Error().Err(err).Msg("code exchange failed") @@ -152,7 +152,12 @@ func (h *handler) startAuthFlow() (string, error) { return "", err } h.lastPKCEVerifier = verifier - h.lastState = oauth.RandomState() + state, err := oauth.RandomState() + if err != nil { + h.spinner.Stop() + return "", err + } + h.lastState = state authURL := h.buildAuthURL(challenge, h.lastState) @@ -209,7 +214,13 @@ func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { return } h.lastPKCEVerifier = verifier - h.lastState = oauth.RandomState() + st, err := oauth.RandomState() + if err != nil { + h.log.Error().Err(err).Msg("failed to generate OAuth state for retry") + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusInternalServerError) + return + } + h.lastState = st h.retryCount++ // Build the new auth URL for redirect diff --git a/cmd/login/login_test.go b/cmd/login/login_test.go index 90ce4c66..34b4e6ec 100644 --- a/cmd/login/login_test.go +++ b/cmd/login/login_test.go @@ -62,8 +62,14 @@ func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { } func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { - state1 := oauth.RandomState() - state2 := oauth.RandomState() + state1, err := oauth.RandomState() + if err != nil { + t.Fatalf("RandomState: %v", err) + } + state2, err := oauth.RandomState() + if err != nil { + t.Fatalf("RandomState: %v", err) + } if state1 == "" || state2 == "" { t.Error("randomState returned empty string") } diff --git a/cmd/secrets/common/browser_flow.go b/cmd/secrets/common/browser_flow.go index 9e12d821..53d60667 100644 --- a/cmd/secrets/common/browser_flow.go +++ b/cmd/secrets/common/browser_flow.go @@ -2,12 +2,12 @@ package common import ( "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" "encoding/hex" "fmt" + "net/http" + rt "runtime" "strings" + "time" "github.com/google/uuid" "github.com/machinebox/graphql" @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" "github.com/smartcontractkit/cre-cli/internal/constants" "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/ui" ) @@ -28,6 +29,13 @@ const createVaultAuthURLMutation = `mutation CreateVaultAuthorizationUrl($reques } }` +const exchangeAuthCodeToTokenMutation = `mutation ExchangeAuthCodeToToken($request: AuthCodeTokenExchangeRequest!) { + exchangeAuthCodeToToken(request: $request) { + accessToken + expiresIn + } +}` + // vaultPermissionForMethod returns the API permission name for the given vault operation. func vaultPermissionForMethod(method string) (string, error) { switch method { @@ -45,7 +53,9 @@ func digestHexString(digest [32]byte) string { } // executeBrowserUpsert handles secrets create/update when the user signs in with their organization account. -// It encrypts the payload, binds a digest, and completes the platform authorization request for this step. +// It encrypts the payload, binds a digest, requests a platform authorization URL, completes OAuth in the browser, +// and exchanges the code via the platform for a short-lived vault JWT (for future DON gateway submission). +// Login tokens in ~/.cre/cre.yaml are not modified; that session stays separate from this vault-only token. func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecretsInputs, method string) error { if h.Credentials.AuthType == credentials.AuthTypeApiKey { return fmt.Errorf("this sign-in flow requires an interactive login; API keys are not supported") @@ -105,7 +115,7 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets return err } - _, challenge, err := generatePKCES256() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { return err } @@ -132,22 +142,72 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets if err := gqlClient.Execute(ctx, gqlReq, &gqlResp); err != nil { return fmt.Errorf("could not complete the authorization request") } - if gqlResp.CreateVaultAuthorizationURL.URL == "" { + authURL := gqlResp.CreateVaultAuthorizationURL.URL + if authURL == "" { return fmt.Errorf("could not complete the authorization request") } - ui.Success("Authorization completed successfully.") - return nil -} + platformState, _ := oauth.StateFromAuthorizeURL(authURL) + + codeCh := make(chan string, 1) + server, listener, err := oauth.NewCallbackHTTPServer(constants.AuthListenAddr, oauth.SecretsCallbackHandler(codeCh, platformState, h.Log)) + if err != nil { + return fmt.Errorf("could not start local callback server: %w", err) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + h.Log.Error().Err(err).Msg("secrets oauth callback server error") + } + }() + + ui.Dim("Opening your browser to complete sign-in...") + if err := oauth.OpenBrowser(authURL, rt.GOOS); err != nil { + ui.Warning("Could not open browser automatically") + ui.Dim("Open this URL in your browser:") + } + ui.URL(authURL) + ui.Line() + ui.Dim("Waiting for authorization... (Press Ctrl+C to cancel)") + + var code string + select { + case code = <-codeCh: + case <-time.After(500 * time.Second): + return fmt.Errorf("timeout waiting for authorization") + case <-ctx.Done(): + return ctx.Err() + } -// generatePKCES256 builds the PKCE verifier and challenge used for secure authorization. -func generatePKCES256() (verifier string, challenge string, err error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", "", fmt.Errorf("pkce random: %w", err) + ui.Dim("Completing vault authorization...") + exchangeReq := graphql.NewRequest(exchangeAuthCodeToTokenMutation) + exchangeReq.Var("request", map[string]any{ + "code": code, + "codeVerifier": verifier, + "redirectUri": constants.AuthRedirectURI, + }) + var exchangeResp struct { + ExchangeAuthCodeToToken struct { + AccessToken string `json:"accessToken"` + ExpiresIn int `json:"expiresIn"` + } `json:"exchangeAuthCodeToToken"` } - verifier = base64.RawURLEncoding.EncodeToString(b) - sum := sha256.Sum256([]byte(verifier)) - challenge = base64.RawURLEncoding.EncodeToString(sum[:]) - return verifier, challenge, nil + if err := gqlClient.Execute(ctx, exchangeReq, &exchangeResp); err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + tok := exchangeResp.ExchangeAuthCodeToToken + if tok.AccessToken == "" { + return fmt.Errorf("token exchange failed: empty access token") + } + // Short-lived vault JWT for future DON secret submission; do not persist or replace cre login tokens. + _ = tok.AccessToken + _ = tok.ExpiresIn + + ui.Success("Vault authorization completed.") + return nil } diff --git a/cmd/secrets/common/browser_flow_test.go b/cmd/secrets/common/browser_flow_test.go index b8c42429..5c391f63 100644 --- a/cmd/secrets/common/browser_flow_test.go +++ b/cmd/secrets/common/browser_flow_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" + + "github.com/smartcontractkit/cre-cli/internal/oauth" ) func TestVaultPermissionForMethod(t *testing.T) { @@ -30,9 +32,9 @@ func TestDigestHexString(t *testing.T) { assert.Equal(t, "0x0102030000000000000000000000000000000000000000000000000000000000", digestHexString(d)) } -// TestGeneratePKCES256 checks PKCE S256 (RFC 7636) used by the browser secrets authorization step. -func TestGeneratePKCES256(t *testing.T) { - verifier, challenge, err := generatePKCES256() +// TestBrowserFlowPKCE checks PKCE S256 (RFC 7636) used by the browser secrets authorization step. +func TestBrowserFlowPKCE(t *testing.T) { + verifier, challenge, err := oauth.GeneratePKCE() require.NoError(t, err) require.NotEmpty(t, verifier) require.NotEmpty(t, challenge) diff --git a/internal/oauth/exchange.go b/internal/oauth/exchange.go index 3af69e3d..68904d05 100644 --- a/internal/oauth/exchange.go +++ b/internal/oauth/exchange.go @@ -18,20 +18,30 @@ import ( // DefaultHTTPClient is used for token exchange when no client is supplied. var DefaultHTTPClient = &http.Client{Timeout: 10 * time.Second} -// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens using -// environment credentials (AuthBase, ClientID) and PKCE code_verifier. -func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier string) (*credentials.CreLoginTokenSet, error) { +// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens (PKCE). +// If oauthClientID is non-empty, it is used as client_id (must match the authorize URL). +// If oauthAuthServerBase is non-empty (scheme + host only), it is used as the token endpoint host; +// otherwise env.AuthBase is used (e.g. cre login builds the authorize URL from env). +func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier, oauthClientID, oauthAuthServerBase string) (*credentials.CreLoginTokenSet, error) { if httpClient == nil { httpClient = DefaultHTTPClient } + clientID := env.ClientID + if oauthClientID != "" { + clientID = oauthClientID + } + authBase := env.AuthBase + if oauthAuthServerBase != "" { + authBase = oauthAuthServerBase + } form := url.Values{} form.Set("grant_type", "authorization_code") - form.Set("client_id", env.ClientID) + form.Set("client_id", clientID) form.Set("code", code) form.Set("redirect_uri", constants.AuthRedirectURI) form.Set("code_verifier", codeVerifier) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, env.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("create request: %w", err) } diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go index a7923a36..f05efc2d 100644 --- a/internal/oauth/exchange_test.go +++ b/internal/oauth/exchange_test.go @@ -44,8 +44,33 @@ func TestExchangeAuthorizationCode(t *testing.T) { ClientID: "cid", } - tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier") + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier", "", "") require.NoError(t, err) require.NotNil(t, tok) assert.Equal(t, "a", tok.AccessToken) } + +func TestExchangeAuthorizationCode_OAuthOverrides(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "override-cid", r.Form.Get("client_id")) + _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ + AccessToken: "b", // #nosec G101 G117 -- test fixture + TokenType: "Bearer", + }) + })) + defer ts.Close() + + env := &environments.EnvironmentSet{ + AuthBase: "https://wrong.example", + ClientID: "wrong", + } + + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "c", "v", "override-cid", ts.URL) + require.NoError(t, err) + assert.Equal(t, "b", tok.AccessToken) +} diff --git a/internal/oauth/htmlPages/secrets_error.html b/internal/oauth/htmlPages/secrets_error.html new file mode 100644 index 00000000..f0d43412 --- /dev/null +++ b/internal/oauth/htmlPages/secrets_error.html @@ -0,0 +1,63 @@ + + + + + Secrets authorization failed + + + + + +
+ + + + +

CRE

+
+
+ + + + + + + + + + +

+ Secrets authorization was unsuccessful +

+

+ Your vault sign-in step could not be completed. Close this window and try + again from your terminal. +

+
+ + diff --git a/internal/oauth/htmlPages/secrets_success.html b/internal/oauth/htmlPages/secrets_success.html new file mode 100644 index 00000000..0eb515a3 --- /dev/null +++ b/internal/oauth/htmlPages/secrets_success.html @@ -0,0 +1,59 @@ + + + + + Secrets authorization complete + + + + + +
+ + + + +

CRE

+
+
+ + + +

+ Your secrets request was signed successfully +

+

+ Vault authorization is complete. You can close this window; the CLI will + finish in your terminal. +

+
+ + diff --git a/internal/oauth/pages.go b/internal/oauth/pages.go index 31d07220..040a4ec3 100644 --- a/internal/oauth/pages.go +++ b/internal/oauth/pages.go @@ -10,10 +10,12 @@ import ( ) const ( - PageError = "htmlPages/error.html" - PageSuccess = "htmlPages/success.html" - PageWaiting = "htmlPages/waiting.html" - StylePage = "htmlPages/output.css" + PageError = "htmlPages/error.html" + PageSuccess = "htmlPages/success.html" + PageSecretsSuccess = "htmlPages/secrets_success.html" + PageSecretsError = "htmlPages/secrets_error.html" + PageWaiting = "htmlPages/waiting.html" + StylePage = "htmlPages/output.css" ) //go:embed htmlPages/*.html diff --git a/internal/oauth/secrets_callback.go b/internal/oauth/secrets_callback.go new file mode 100644 index 00000000..cb7f93af --- /dev/null +++ b/internal/oauth/secrets_callback.go @@ -0,0 +1,41 @@ +package oauth + +import ( + "net/http" + + "github.com/rs/zerolog" +) + +// SecretsCallbackHandler handles the OAuth redirect for the browser secrets flow. +// If expectedState is non-empty (parsed from the platform authorize URL), the callback +// must include the same state; otherwise only a non-empty authorization code is required. +func SecretsCallbackHandler(codeCh chan<- string, expectedState string, log *zerolog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + errorParam := r.URL.Query().Get("error") + errorDesc := r.URL.Query().Get("error_description") + + if errorParam != "" { + log.Error().Str("error", errorParam).Str("description", errorDesc).Msg("auth error in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + + if expectedState != "" { + if st := r.URL.Query().Get("state"); st != expectedState { + log.Error().Str("got", st).Str("want", expectedState).Msg("invalid state in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + } + + code := r.URL.Query().Get("code") + if code == "" { + log.Error().Msg("no code in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + + ServeEmbeddedHTML(log, w, PageSecretsSuccess, http.StatusOK) + codeCh <- code + } +} diff --git a/internal/oauth/secrets_callback_test.go b/internal/oauth/secrets_callback_test.go new file mode 100644 index 00000000..e7071dab --- /dev/null +++ b/internal/oauth/secrets_callback_test.go @@ -0,0 +1,66 @@ +package oauth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +func TestSecretsCallbackHandler_success(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "want-state", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=the-code&state=want-state", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "the-code", <-codeCh) +} + +func TestSecretsCallbackHandler_stateMismatch(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "want", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=c&state=wrong", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + select { + case <-codeCh: + t.Fatal("expected no code") + default: + } +} + +func TestSecretsCallbackHandler_oauthError(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?error=access_denied", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Len(t, codeCh, 0) +} + +func TestSecretsCallbackHandler_noStateRequired(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=only-code", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "only-code", <-codeCh) +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go index 4019de61..bf0de0ec 100644 --- a/internal/oauth/state.go +++ b/internal/oauth/state.go @@ -4,14 +4,46 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "time" + "net/url" ) -// RandomState returns a random OAuth state value for CSRF protection. -func RandomState() string { +// RandomState returns a URL-safe random string suitable for OAuth "state". +func RandomState() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) + return "", fmt.Errorf("oauth: random state: %w", err) } - return base64.RawURLEncoding.EncodeToString(b) + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// StateFromAuthorizeURL returns the OAuth "state" query parameter from an authorize URL, if present. +func StateFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + return u.Query().Get("state"), nil +} + +// ClientIDFromAuthorizeURL returns the "client_id" query parameter from an authorize URL (if present). +// Token exchange must use the same client_id the IdP bound to the authorization code. +func ClientIDFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + return u.Query().Get("client_id"), nil +} + +// OAuthServerBaseFromAuthorizeURL returns the authorization server origin (scheme + host) for the +// given authorize URL. The token endpoint must be on the same host that issued the authorization code. +func OAuthServerBaseFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + if u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("authorize URL missing scheme or host") + } + return u.Scheme + "://" + u.Host, nil } diff --git a/internal/oauth/state_test.go b/internal/oauth/state_test.go new file mode 100644 index 00000000..fba0e450 --- /dev/null +++ b/internal/oauth/state_test.go @@ -0,0 +1,42 @@ +package oauth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRandomState(t *testing.T) { + s, err := RandomState() + require.NoError(t, err) + require.NotEmpty(t, s) + s2, err := RandomState() + require.NoError(t, err) + assert.NotEqual(t, s, s2) +} + +func TestStateFromAuthorizeURL(t *testing.T) { + s, err := StateFromAuthorizeURL("https://id.example/authorize?state=abc&client_id=x") + require.NoError(t, err) + assert.Equal(t, "abc", s) + + s, err = StateFromAuthorizeURL("https://id.example/authorize") + require.NoError(t, err) + assert.Equal(t, "", s) +} + +func TestClientIDFromAuthorizeURL(t *testing.T) { + c, err := ClientIDFromAuthorizeURL("https://auth0.example/authorize?client_id=myapp&response_type=code") + require.NoError(t, err) + assert.Equal(t, "myapp", c) +} + +func TestOAuthServerBaseFromAuthorizeURL(t *testing.T) { + base, err := OAuthServerBaseFromAuthorizeURL("https://tenant.auth0.com/authorize?foo=1") + require.NoError(t, err) + assert.Equal(t, "https://tenant.auth0.com", base) + + _, err = OAuthServerBaseFromAuthorizeURL("/relative") + assert.Error(t, err) +}