From 00d67273ed4684a1c2d5f284a296aa0410c948b7 Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Wed, 25 Mar 2026 14:58:15 +0000 Subject: [PATCH 1/5] Refactor: extract oauth from login --- cmd/login/login.go | 193 ++---------------- cmd/login/login_test.go | 18 +- internal/oauth/browser.go | 20 ++ internal/oauth/exchange.go | 59 ++++++ internal/oauth/exchange_test.go | 50 +++++ .../oauth}/htmlPages/error.html | 0 .../oauth}/htmlPages/output.css | 0 .../oauth}/htmlPages/success.html | 0 .../oauth}/htmlPages/waiting.html | 0 internal/oauth/pages.go | 87 ++++++++ internal/oauth/pkce.go | 20 ++ internal/oauth/pkce_test.go | 22 ++ internal/oauth/server.go | 24 +++ internal/oauth/state.go | 17 ++ 14 files changed, 323 insertions(+), 187 deletions(-) create mode 100644 internal/oauth/browser.go create mode 100644 internal/oauth/exchange.go create mode 100644 internal/oauth/exchange_test.go rename {cmd/login => internal/oauth}/htmlPages/error.html (100%) rename {cmd/login => internal/oauth}/htmlPages/output.css (100%) rename {cmd/login => internal/oauth}/htmlPages/success.html (100%) rename {cmd/login => internal/oauth}/htmlPages/waiting.html (100%) create mode 100644 internal/oauth/pages.go create mode 100644 internal/oauth/pkce.go create mode 100644 internal/oauth/pkce_test.go create mode 100644 internal/oauth/server.go create mode 100644 internal/oauth/state.go diff --git a/cmd/login/login.go b/cmd/login/login.go index 9251c78b..0ee07280 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -2,17 +2,10 @@ package login import ( "context" - "crypto/rand" - "crypto/sha256" - "embed" - "encoding/base64" - "encoding/json" "fmt" - "io" "net" "net/http" "net/url" - "os/exec" rt "runtime" "strings" "time" @@ -24,28 +17,19 @@ import ( "github.com/smartcontractkit/cre-cli/internal/constants" "github.com/smartcontractkit/cre-cli/internal/credentials" "github.com/smartcontractkit/cre-cli/internal/environments" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/runtime" "github.com/smartcontractkit/cre-cli/internal/tenantctx" "github.com/smartcontractkit/cre-cli/internal/ui" ) var ( - httpClient = &http.Client{Timeout: 10 * time.Second} - errorPage = "htmlPages/error.html" - successPage = "htmlPages/success.html" - waitingPage = "htmlPages/waiting.html" - stylePage = "htmlPages/output.css" - // OrgMembershipErrorSubstring is the error message substring returned by Auth0 // when a user doesn't belong to any organization during the auth flow. // This typically happens during sign-up when the organization hasn't been created yet. OrgMembershipErrorSubstring = "user does not belong to any organization" ) -//go:embed htmlPages/*.html -//go:embed htmlPages/*.css -var htmlFiles embed.FS - func New(runtimeCtx *runtime.Context) *cobra.Command { cmd := &cobra.Command{ Use: "login", @@ -102,7 +86,7 @@ func (h *handler) execute() error { // Use spinner for the token exchange h.spinner.Start("Exchanging authorization code...") - tokenSet, err := h.exchangeCodeForTokens(context.Background(), code) + tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier) if err != nil { h.spinner.StopAll() h.log.Error().Err(err).Msg("code exchange failed") @@ -162,13 +146,13 @@ func (h *handler) startAuthFlow() (string, error) { } }() - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { h.spinner.Stop() return "", err } h.lastPKCEVerifier = verifier - h.lastState = randomState() + h.lastState = oauth.RandomState() authURL := h.buildAuthURL(challenge, h.lastState) @@ -180,7 +164,7 @@ func (h *handler) startAuthFlow() (string, error) { ui.URL(authURL) ui.Line() - if err := openBrowser(authURL, rt.GOOS); err != nil { + if err := oauth.OpenBrowser(authURL, rt.GOOS); err != nil { ui.Warning("Could not open browser automatically") ui.Dim("Please open the URL above in your browser") ui.Line() @@ -199,19 +183,7 @@ func (h *handler) startAuthFlow() (string, error) { } func (h *handler) setupServer(codeCh chan string) (*http.Server, net.Listener, error) { - mux := http.NewServeMux() - mux.HandleFunc("/callback", h.callbackHandler(codeCh)) - - // TODO: Add a fallback port in case the default port is in use - listener, err := net.Listen("tcp", constants.AuthListenAddr) - if err != nil { - return nil, nil, fmt.Errorf("failed to listen on %s: %w", constants.AuthListenAddr, err) - } - - return &http.Server{ - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - }, listener, nil + return oauth.NewCallbackHTTPServer(constants.AuthListenAddr, h.callbackHandler(codeCh)) } func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { @@ -225,120 +197,52 @@ func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { if strings.Contains(errorDesc, OrgMembershipErrorSubstring) { if h.retryCount >= maxOrgNotFoundRetries { h.log.Error().Int("retries", h.retryCount).Msg("organization setup timed out after maximum retries") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } // Generate new authentication credentials for the retry - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { h.log.Error().Err(err).Msg("failed to prepare authentication retry") - h.serveEmbeddedHTML(w, errorPage, http.StatusInternalServerError) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusInternalServerError) return } h.lastPKCEVerifier = verifier - h.lastState = randomState() + h.lastState = oauth.RandomState() h.retryCount++ // Build the new auth URL for redirect authURL := h.buildAuthURL(challenge, h.lastState) h.log.Debug().Int("attempt", h.retryCount).Int("max", maxOrgNotFoundRetries).Msg("organization setup in progress, retrying") - h.serveWaitingPage(w, authURL) + oauth.ServeWaitingPage(h.log, w, authURL) return } // Generic Auth0 error h.log.Error().Str("error", errorParam).Str("description", errorDesc).Msg("auth error in callback") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } if st := r.URL.Query().Get("state"); st == "" || h.lastState == "" || st != h.lastState { h.log.Error().Msg("invalid state in response") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } code := r.URL.Query().Get("code") if code == "" { h.log.Error().Msg("no code in response") - h.serveEmbeddedHTML(w, errorPage, http.StatusBadRequest) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusBadRequest) return } - h.serveEmbeddedHTML(w, successPage, http.StatusOK) + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageSuccess, http.StatusOK) codeCh <- code } } -func (h *handler) serveEmbeddedHTML(w http.ResponseWriter, filePath string, status int) { - htmlContent, err := htmlFiles.ReadFile(filePath) - if err != nil { - h.log.Error().Err(err).Str("file", filePath).Msg("failed to read embedded HTML file") - h.sendHTTPError(w) - return - } - - cssContent, err := htmlFiles.ReadFile(stylePage) - if err != nil { - h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file") - h.sendHTTPError(w) - return - } - - modified := strings.Replace( - string(htmlContent), - ``, - fmt.Sprintf("", string(cssContent)), - 1, - ) - - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(status) - if _, err := w.Write([]byte(modified)); err != nil { - h.log.Error().Err(err).Msg("failed to write HTML response") - } -} - -// serveWaitingPage serves the waiting page with the redirect URL injected. -// This is used when handling organization membership errors during sign-up flow. -func (h *handler) serveWaitingPage(w http.ResponseWriter, redirectURL string) { - htmlContent, err := htmlFiles.ReadFile(waitingPage) - if err != nil { - h.log.Error().Err(err).Str("file", waitingPage).Msg("failed to read waiting page HTML file") - h.sendHTTPError(w) - return - } - - cssContent, err := htmlFiles.ReadFile(stylePage) - if err != nil { - h.log.Error().Err(err).Str("file", stylePage).Msg("failed to read embedded CSS file") - h.sendHTTPError(w) - return - } - - // Inject CSS inline - modified := strings.Replace( - string(htmlContent), - ``, - fmt.Sprintf("", string(cssContent)), - 1, - ) - - // Inject the redirect URL - modified = strings.Replace(modified, "{{REDIRECT_URL}}", redirectURL, 1) - - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(modified)); err != nil { - h.log.Error().Err(err).Msg("failed to write waiting page response") - } -} - -func (h *handler) sendHTTPError(w http.ResponseWriter) { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) -} - func (h *handler) buildAuthURL(codeChallenge, state string) string { params := url.Values{} params.Set("client_id", h.environmentSet.ClientID) @@ -355,41 +259,6 @@ func (h *handler) buildAuthURL(codeChallenge, state string) string { return h.environmentSet.AuthBase + constants.AuthAuthorizePath + "?" + params.Encode() } -func (h *handler) exchangeCodeForTokens(ctx context.Context, code string) (*credentials.CreLoginTokenSet, error) { - form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("client_id", h.environmentSet.ClientID) - form.Set("code", code) - form.Set("redirect_uri", constants.AuthRedirectURI) - form.Set("code_verifier", h.lastPKCEVerifier) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.environmentSet.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := httpClient.Do(req) // #nosec G704 -- URL is from trusted environment config - if err != nil { - return nil, fmt.Errorf("perform request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body) - } - - var tokenSet credentials.CreLoginTokenSet - if err := json.Unmarshal(body, &tokenSet); err != nil { - return nil, fmt.Errorf("unmarshal token set: %w", err) - } - return &tokenSet, nil -} - func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) error { creds := &credentials.Credentials{ Tokens: tokenSet, @@ -404,35 +273,3 @@ func (h *handler) fetchTenantConfig(tokenSet *credentials.CreLoginTokenSet) erro return tenantctx.FetchAndWriteContext(context.Background(), gqlClient, envName, h.log) } - -func openBrowser(urlStr string, goos string) error { - switch goos { - case "darwin": - return exec.Command("open", urlStr).Start() - case "linux": - return exec.Command("xdg-open", urlStr).Start() - case "windows": - return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start() - default: - return fmt.Errorf("unsupported OS: %s", goos) - } -} - -func generatePKCE() (verifier, challenge string, err error) { - b := make([]byte, 32) - if _, err = rand.Read(b); err != nil { - return "", "", err - } - verifier = base64.RawURLEncoding.EncodeToString(b) - sum := sha256.Sum256([]byte(verifier)) - challenge = base64.RawURLEncoding.EncodeToString(sum[:]) - return verifier, challenge, nil -} - -func randomState() string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - return base64.RawURLEncoding.EncodeToString(b) -} diff --git a/cmd/login/login_test.go b/cmd/login/login_test.go index 782f2d18..90ce4c66 100644 --- a/cmd/login/login_test.go +++ b/cmd/login/login_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/credentials" "github.com/smartcontractkit/cre-cli/internal/environments" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/ui" ) @@ -51,9 +52,9 @@ func TestSaveCredentials_WritesYAML(t *testing.T) { } func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { - verifier, challenge, err := generatePKCE() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { - t.Fatalf("generatePKCE error: %v", err) + t.Fatalf("GeneratePKCE error: %v", err) } if verifier == "" || challenge == "" { t.Error("PKCE verifier or challenge is empty") @@ -61,8 +62,8 @@ func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { } func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { - state1 := randomState() - state2 := randomState() + state1 := oauth.RandomState() + state2 := oauth.RandomState() if state1 == "" || state2 == "" { t.Error("randomState returned empty string") } @@ -72,16 +73,16 @@ func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { } func TestOpenBrowser_UnsupportedOS(t *testing.T) { - err := openBrowser("http://example.com", "plan9") + err := oauth.OpenBrowser("http://example.com", "plan9") if err == nil || !strings.Contains(err.Error(), "unsupported OS") { t.Errorf("expected unsupported OS error, got %v", err) } } func TestServeEmbeddedHTML_ErrorOnMissingFile(t *testing.T) { - h := &handler{log: &zerolog.Logger{}, spinner: ui.NewSpinner()} + log := zerolog.Nop() w := httptest.NewRecorder() - h.serveEmbeddedHTML(w, "htmlPages/doesnotexist.html", http.StatusOK) + oauth.ServeEmbeddedHTML(&log, w, "htmlPages/doesnotexist.html", http.StatusOK) resp := w.Result() if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 error, got %d", resp.StatusCode) @@ -274,12 +275,11 @@ func TestCallbackHandler_GenericAuth0Error(t *testing.T) { func TestServeWaitingPage(t *testing.T) { logger := zerolog.Nop() - h := &handler{log: &logger, spinner: ui.NewSpinner()} w := httptest.NewRecorder() redirectURL := "https://auth.example.com/authorize?client_id=test&state=abc123" - h.serveWaitingPage(w, redirectURL) + oauth.ServeWaitingPage(&logger, w, redirectURL) resp := w.Result() body, _ := io.ReadAll(resp.Body) diff --git a/internal/oauth/browser.go b/internal/oauth/browser.go new file mode 100644 index 00000000..99e13424 --- /dev/null +++ b/internal/oauth/browser.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "fmt" + "os/exec" +) + +// OpenBrowser opens urlStr in the default browser for the given GOOS value. +func OpenBrowser(urlStr string, goos string) error { + switch goos { + case "darwin": + return exec.Command("open", urlStr).Start() + case "linux": + return exec.Command("xdg-open", urlStr).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", urlStr).Start() + default: + return fmt.Errorf("unsupported OS: %s", goos) + } +} diff --git a/internal/oauth/exchange.go b/internal/oauth/exchange.go new file mode 100644 index 00000000..3af69e3d --- /dev/null +++ b/internal/oauth/exchange.go @@ -0,0 +1,59 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/environments" +) + +// DefaultHTTPClient is used for token exchange when no client is supplied. +var DefaultHTTPClient = &http.Client{Timeout: 10 * time.Second} + +// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens using +// environment credentials (AuthBase, ClientID) and PKCE code_verifier. +func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier string) (*credentials.CreLoginTokenSet, error) { + if httpClient == nil { + httpClient = DefaultHTTPClient + } + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", env.ClientID) + form.Set("code", code) + form.Set("redirect_uri", constants.AuthRedirectURI) + form.Set("code_verifier", codeVerifier) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, env.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := httpClient.Do(req) // #nosec G704 -- URL is from trusted environment config + if err != nil { + return nil, fmt.Errorf("perform request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body) + } + + var tokenSet credentials.CreLoginTokenSet + if err := json.Unmarshal(body, &tokenSet); err != nil { + return nil, fmt.Errorf("unmarshal token set: %w", err) + } + return &tokenSet, nil +} diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go new file mode 100644 index 00000000..fcf19faf --- /dev/null +++ b/internal/oauth/exchange_test.go @@ -0,0 +1,50 @@ +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/environments" +) + +func TestExchangeAuthorizationCode(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "authorization_code", r.Form.Get("grant_type")) + assert.Equal(t, "cid", r.Form.Get("client_id")) + assert.Equal(t, "auth-code", r.Form.Get("code")) + assert.Equal(t, constants.AuthRedirectURI, r.Form.Get("redirect_uri")) + assert.Equal(t, "verifier", r.Form.Get("code_verifier")) + + _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ + AccessToken: "a", + TokenType: "Bearer", + }) + })) + defer ts.Close() + + env := &environments.EnvironmentSet{ + AuthBase: ts.URL, + ClientID: "cid", + } + + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier") + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "a", tok.AccessToken) +} diff --git a/cmd/login/htmlPages/error.html b/internal/oauth/htmlPages/error.html similarity index 100% rename from cmd/login/htmlPages/error.html rename to internal/oauth/htmlPages/error.html diff --git a/cmd/login/htmlPages/output.css b/internal/oauth/htmlPages/output.css similarity index 100% rename from cmd/login/htmlPages/output.css rename to internal/oauth/htmlPages/output.css diff --git a/cmd/login/htmlPages/success.html b/internal/oauth/htmlPages/success.html similarity index 100% rename from cmd/login/htmlPages/success.html rename to internal/oauth/htmlPages/success.html diff --git a/cmd/login/htmlPages/waiting.html b/internal/oauth/htmlPages/waiting.html similarity index 100% rename from cmd/login/htmlPages/waiting.html rename to internal/oauth/htmlPages/waiting.html diff --git a/internal/oauth/pages.go b/internal/oauth/pages.go new file mode 100644 index 00000000..31d07220 --- /dev/null +++ b/internal/oauth/pages.go @@ -0,0 +1,87 @@ +package oauth + +import ( + "embed" + "fmt" + "net/http" + "strings" + + "github.com/rs/zerolog" +) + +const ( + PageError = "htmlPages/error.html" + PageSuccess = "htmlPages/success.html" + PageWaiting = "htmlPages/waiting.html" + StylePage = "htmlPages/output.css" +) + +//go:embed htmlPages/*.html +//go:embed htmlPages/*.css +var htmlFiles embed.FS + +// ServeEmbeddedHTML serves an embedded HTML page with inline CSS. +func ServeEmbeddedHTML(log *zerolog.Logger, w http.ResponseWriter, filePath string, status int) { + htmlContent, err := htmlFiles.ReadFile(filePath) + if err != nil { + log.Error().Err(err).Str("file", filePath).Msg("failed to read embedded HTML file") + sendHTTPError(w) + return + } + + cssContent, err := htmlFiles.ReadFile(StylePage) + if err != nil { + log.Error().Err(err).Str("file", StylePage).Msg("failed to read embedded CSS file") + sendHTTPError(w) + return + } + + modified := strings.Replace( + string(htmlContent), + ``, + fmt.Sprintf("", string(cssContent)), + 1, + ) + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(status) + if _, err := w.Write([]byte(modified)); err != nil { + log.Error().Err(err).Msg("failed to write HTML response") + } +} + +// ServeWaitingPage serves the waiting page with the redirect URL injected. +func ServeWaitingPage(log *zerolog.Logger, w http.ResponseWriter, redirectURL string) { + htmlContent, err := htmlFiles.ReadFile(PageWaiting) + if err != nil { + log.Error().Err(err).Str("file", PageWaiting).Msg("failed to read waiting page HTML file") + sendHTTPError(w) + return + } + + cssContent, err := htmlFiles.ReadFile(StylePage) + if err != nil { + log.Error().Err(err).Str("file", StylePage).Msg("failed to read embedded CSS file") + sendHTTPError(w) + return + } + + modified := strings.Replace( + string(htmlContent), + ``, + fmt.Sprintf("", string(cssContent)), + 1, + ) + + modified = strings.Replace(modified, "{{REDIRECT_URL}}", redirectURL, 1) + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(modified)); err != nil { + log.Error().Err(err).Msg("failed to write waiting page response") + } +} + +func sendHTTPError(w http.ResponseWriter) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) +} diff --git a/internal/oauth/pkce.go b/internal/oauth/pkce.go new file mode 100644 index 00000000..0ed0e8a0 --- /dev/null +++ b/internal/oauth/pkce.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +// GeneratePKCE returns an RFC 7636 S256 code verifier and code challenge. +func GeneratePKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err = rand.Read(b); err != nil { + return "", "", fmt.Errorf("pkce random: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} diff --git a/internal/oauth/pkce_test.go b/internal/oauth/pkce_test.go new file mode 100644 index 00000000..50b7c376 --- /dev/null +++ b/internal/oauth/pkce_test.go @@ -0,0 +1,22 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGeneratePKCE_S256(t *testing.T) { + verifier, challenge, err := GeneratePKCE() + require.NoError(t, err) + require.NotEmpty(t, verifier) + require.NotEmpty(t, challenge) + + sum := sha256.Sum256([]byte(verifier)) + decoded, err := base64.RawURLEncoding.DecodeString(challenge) + require.NoError(t, err) + assert.Equal(t, sum[:], decoded) +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go new file mode 100644 index 00000000..4ca2abe5 --- /dev/null +++ b/internal/oauth/server.go @@ -0,0 +1,24 @@ +package oauth + +import ( + "fmt" + "net" + "net/http" + "time" +) + +// NewCallbackHTTPServer listens on listenAddr and serves callback on /callback. +func NewCallbackHTTPServer(listenAddr string, callback http.HandlerFunc) (*http.Server, net.Listener, error) { + mux := http.NewServeMux() + mux.HandleFunc("/callback", callback) + + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, nil, fmt.Errorf("failed to listen on %s: %w", listenAddr, err) + } + + return &http.Server{ + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + }, listener, nil +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go new file mode 100644 index 00000000..4019de61 --- /dev/null +++ b/internal/oauth/state.go @@ -0,0 +1,17 @@ +package oauth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "time" +) + +// RandomState returns a random OAuth state value for CSRF protection. +func RandomState() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return base64.RawURLEncoding.EncodeToString(b) +} From fd957ca24da181c451c0450be069bc54c29f53c6 Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Wed, 25 Mar 2026 15:18:43 +0000 Subject: [PATCH 2/5] Lint --- internal/oauth/exchange_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go index fcf19faf..a7923a36 100644 --- a/internal/oauth/exchange_test.go +++ b/internal/oauth/exchange_test.go @@ -21,6 +21,7 @@ func TestExchangeAuthorizationCode(t *testing.T) { http.Error(w, "method", http.StatusMethodNotAllowed) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) if err := r.ParseForm(); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -32,7 +33,7 @@ func TestExchangeAuthorizationCode(t *testing.T) { assert.Equal(t, "verifier", r.Form.Get("code_verifier")) _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ - AccessToken: "a", + AccessToken: "a", // #nosec G101 G117 -- test fixture, not a real credential TokenType: "Bearer", }) })) From a2ae9cd213b46ef7643c9522b69a1c2a74e987ae Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Thu, 26 Mar 2026 06:45:58 +0000 Subject: [PATCH 3/5] complete vault secrets browser OAuth (callback + exchange) --- cmd/login/login.go | 17 +++- cmd/login/login_test.go | 10 ++- cmd/secrets/common/browser_flow.go | 79 ++++++++++++++----- cmd/secrets/common/browser_flow_test.go | 8 +- internal/environments/environments_test.go | 8 +- internal/oauth/exchange.go | 20 +++-- internal/oauth/exchange_test.go | 27 ++++++- internal/oauth/htmlPages/secrets_error.html | 63 +++++++++++++++ internal/oauth/htmlPages/secrets_success.html | 59 ++++++++++++++ internal/oauth/pages.go | 10 ++- internal/oauth/secrets_callback.go | 41 ++++++++++ internal/oauth/secrets_callback_test.go | 66 ++++++++++++++++ internal/oauth/state.go | 42 ++++++++-- internal/oauth/state_test.go | 42 ++++++++++ 14 files changed, 447 insertions(+), 45 deletions(-) create mode 100644 internal/oauth/htmlPages/secrets_error.html create mode 100644 internal/oauth/htmlPages/secrets_success.html create mode 100644 internal/oauth/secrets_callback.go create mode 100644 internal/oauth/secrets_callback_test.go create mode 100644 internal/oauth/state_test.go diff --git a/cmd/login/login.go b/cmd/login/login.go index 0ee07280..01c271a3 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -86,7 +86,7 @@ func (h *handler) execute() error { // Use spinner for the token exchange h.spinner.Start("Exchanging authorization code...") - tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier) + tokenSet, err := oauth.ExchangeAuthorizationCode(context.Background(), nil, h.environmentSet, code, h.lastPKCEVerifier, "", "") if err != nil { h.spinner.StopAll() h.log.Error().Err(err).Msg("code exchange failed") @@ -152,7 +152,12 @@ func (h *handler) startAuthFlow() (string, error) { return "", err } h.lastPKCEVerifier = verifier - h.lastState = oauth.RandomState() + state, err := oauth.RandomState() + if err != nil { + h.spinner.Stop() + return "", err + } + h.lastState = state authURL := h.buildAuthURL(challenge, h.lastState) @@ -209,7 +214,13 @@ func (h *handler) callbackHandler(codeCh chan string) http.HandlerFunc { return } h.lastPKCEVerifier = verifier - h.lastState = oauth.RandomState() + st, err := oauth.RandomState() + if err != nil { + h.log.Error().Err(err).Msg("failed to generate OAuth state for retry") + oauth.ServeEmbeddedHTML(h.log, w, oauth.PageError, http.StatusInternalServerError) + return + } + h.lastState = st h.retryCount++ // Build the new auth URL for redirect diff --git a/cmd/login/login_test.go b/cmd/login/login_test.go index 90ce4c66..34b4e6ec 100644 --- a/cmd/login/login_test.go +++ b/cmd/login/login_test.go @@ -62,8 +62,14 @@ func TestGeneratePKCE_ReturnsValidChallenge(t *testing.T) { } func TestRandomState_IsRandomAndNonEmpty(t *testing.T) { - state1 := oauth.RandomState() - state2 := oauth.RandomState() + state1, err := oauth.RandomState() + if err != nil { + t.Fatalf("RandomState: %v", err) + } + state2, err := oauth.RandomState() + if err != nil { + t.Fatalf("RandomState: %v", err) + } if state1 == "" || state2 == "" { t.Error("randomState returned empty string") } diff --git a/cmd/secrets/common/browser_flow.go b/cmd/secrets/common/browser_flow.go index 9e12d821..3888974f 100644 --- a/cmd/secrets/common/browser_flow.go +++ b/cmd/secrets/common/browser_flow.go @@ -2,12 +2,12 @@ package common import ( "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" "encoding/hex" "fmt" + "net/http" + rt "runtime" "strings" + "time" "github.com/google/uuid" "github.com/machinebox/graphql" @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" "github.com/smartcontractkit/cre-cli/internal/constants" "github.com/smartcontractkit/cre-cli/internal/credentials" + "github.com/smartcontractkit/cre-cli/internal/oauth" "github.com/smartcontractkit/cre-cli/internal/ui" ) @@ -45,7 +46,8 @@ func digestHexString(digest [32]byte) string { } // executeBrowserUpsert handles secrets create/update when the user signs in with their organization account. -// It encrypts the payload, binds a digest, and completes the platform authorization request for this step. +// It encrypts the payload, binds a digest, requests a platform authorization URL, completes OAuth in the browser, +// exchanges the code for tokens, and saves credentials. func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecretsInputs, method string) error { if h.Credentials.AuthType == credentials.AuthTypeApiKey { return fmt.Errorf("this sign-in flow requires an interactive login; API keys are not supported") @@ -105,7 +107,7 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets return err } - _, challenge, err := generatePKCES256() + verifier, challenge, err := oauth.GeneratePKCE() if err != nil { return err } @@ -132,22 +134,63 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets if err := gqlClient.Execute(ctx, gqlReq, &gqlResp); err != nil { return fmt.Errorf("could not complete the authorization request") } - if gqlResp.CreateVaultAuthorizationURL.URL == "" { + authURL := gqlResp.CreateVaultAuthorizationURL.URL + if authURL == "" { return fmt.Errorf("could not complete the authorization request") } - ui.Success("Authorization completed successfully.") - return nil -} + oauthIssuerBase, err := oauth.OAuthServerBaseFromAuthorizeURL(authURL) + if err != nil { + return fmt.Errorf("invalid authorization URL from server: %w", err) + } + platformState, _ := oauth.StateFromAuthorizeURL(authURL) + oauthClientIDFromURL, _ := oauth.ClientIDFromAuthorizeURL(authURL) + oauthClientIDForExchange := strings.TrimSpace(oauthClientIDFromURL) + + codeCh := make(chan string, 1) + server, listener, err := oauth.NewCallbackHTTPServer(constants.AuthListenAddr, oauth.SecretsCallbackHandler(codeCh, platformState, h.Log)) + if err != nil { + return fmt.Errorf("could not start local callback server: %w", err) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + h.Log.Error().Err(err).Msg("secrets oauth callback server error") + } + }() -// generatePKCES256 builds the PKCE verifier and challenge used for secure authorization. -func generatePKCES256() (verifier string, challenge string, err error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", "", fmt.Errorf("pkce random: %w", err) + ui.Dim("Opening your browser to complete sign-in...") + if err := oauth.OpenBrowser(authURL, rt.GOOS); err != nil { + ui.Warning("Could not open browser automatically") + ui.Dim("Open this URL in your browser:") } - verifier = base64.RawURLEncoding.EncodeToString(b) - sum := sha256.Sum256([]byte(verifier)) - challenge = base64.RawURLEncoding.EncodeToString(sum[:]) - return verifier, challenge, nil + ui.URL(authURL) + ui.Line() + ui.Dim("Waiting for authorization... (Press Ctrl+C to cancel)") + + var code string + select { + case code = <-codeCh: + case <-time.After(500 * time.Second): + return fmt.Errorf("timeout waiting for authorization") + case <-ctx.Done(): + return ctx.Err() + } + + ui.Dim("Saving credentials...") + tokenSet, err := oauth.ExchangeAuthorizationCode(ctx, nil, h.EnvironmentSet, code, verifier, oauthClientIDForExchange, oauthIssuerBase) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + if err := credentials.SaveCredentials(tokenSet); err != nil { + return fmt.Errorf("failed to save credentials: %w", err) + } + + ui.Success("Authorization completed successfully.") + return nil } diff --git a/cmd/secrets/common/browser_flow_test.go b/cmd/secrets/common/browser_flow_test.go index b8c42429..5c391f63 100644 --- a/cmd/secrets/common/browser_flow_test.go +++ b/cmd/secrets/common/browser_flow_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" + + "github.com/smartcontractkit/cre-cli/internal/oauth" ) func TestVaultPermissionForMethod(t *testing.T) { @@ -30,9 +32,9 @@ func TestDigestHexString(t *testing.T) { assert.Equal(t, "0x0102030000000000000000000000000000000000000000000000000000000000", digestHexString(d)) } -// TestGeneratePKCES256 checks PKCE S256 (RFC 7636) used by the browser secrets authorization step. -func TestGeneratePKCES256(t *testing.T) { - verifier, challenge, err := generatePKCES256() +// TestBrowserFlowPKCE checks PKCE S256 (RFC 7636) used by the browser secrets authorization step. +func TestBrowserFlowPKCE(t *testing.T) { + verifier, challenge, err := oauth.GeneratePKCE() require.NoError(t, err) require.NotEmpty(t, verifier) require.NotEmpty(t, challenge) diff --git a/internal/environments/environments_test.go b/internal/environments/environments_test.go index e82c6d8e..527da558 100644 --- a/internal/environments/environments_test.go +++ b/internal/environments/environments_test.go @@ -82,10 +82,10 @@ func TestNewEnvironmentSet_FallbackAndOverrides(t *testing.T) { WorkflowRegistryChainName: "ethereum-testnet-sepolia", }, "STAGING": { - AuthBase: "g", - ClientID: "h", - GraphQLURL: "i", - Audience: "bb", + AuthBase: "g", + ClientID: "h", + GraphQLURL: "i", + Audience: "bb", WorkflowRegistryAddress: "0xstaging_wr", WorkflowRegistryChainName: "polygon-mainnet", diff --git a/internal/oauth/exchange.go b/internal/oauth/exchange.go index 3af69e3d..68904d05 100644 --- a/internal/oauth/exchange.go +++ b/internal/oauth/exchange.go @@ -18,20 +18,30 @@ import ( // DefaultHTTPClient is used for token exchange when no client is supplied. var DefaultHTTPClient = &http.Client{Timeout: 10 * time.Second} -// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens using -// environment credentials (AuthBase, ClientID) and PKCE code_verifier. -func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier string) (*credentials.CreLoginTokenSet, error) { +// ExchangeAuthorizationCode exchanges an OAuth authorization code for tokens (PKCE). +// If oauthClientID is non-empty, it is used as client_id (must match the authorize URL). +// If oauthAuthServerBase is non-empty (scheme + host only), it is used as the token endpoint host; +// otherwise env.AuthBase is used (e.g. cre login builds the authorize URL from env). +func ExchangeAuthorizationCode(ctx context.Context, httpClient *http.Client, env *environments.EnvironmentSet, code, codeVerifier, oauthClientID, oauthAuthServerBase string) (*credentials.CreLoginTokenSet, error) { if httpClient == nil { httpClient = DefaultHTTPClient } + clientID := env.ClientID + if oauthClientID != "" { + clientID = oauthClientID + } + authBase := env.AuthBase + if oauthAuthServerBase != "" { + authBase = oauthAuthServerBase + } form := url.Values{} form.Set("grant_type", "authorization_code") - form.Set("client_id", env.ClientID) + form.Set("client_id", clientID) form.Set("code", code) form.Set("redirect_uri", constants.AuthRedirectURI) form.Set("code_verifier", codeVerifier) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, env.AuthBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authBase+constants.AuthTokenPath, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("create request: %w", err) } diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go index a7923a36..f05efc2d 100644 --- a/internal/oauth/exchange_test.go +++ b/internal/oauth/exchange_test.go @@ -44,8 +44,33 @@ func TestExchangeAuthorizationCode(t *testing.T) { ClientID: "cid", } - tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier") + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "auth-code", "verifier", "", "") require.NoError(t, err) require.NotNil(t, tok) assert.Equal(t, "a", tok.AccessToken) } + +func TestExchangeAuthorizationCode_OAuthOverrides(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "override-cid", r.Form.Get("client_id")) + _ = json.NewEncoder(w).Encode(credentials.CreLoginTokenSet{ + AccessToken: "b", // #nosec G101 G117 -- test fixture + TokenType: "Bearer", + }) + })) + defer ts.Close() + + env := &environments.EnvironmentSet{ + AuthBase: "https://wrong.example", + ClientID: "wrong", + } + + tok, err := ExchangeAuthorizationCode(context.Background(), ts.Client(), env, "c", "v", "override-cid", ts.URL) + require.NoError(t, err) + assert.Equal(t, "b", tok.AccessToken) +} diff --git a/internal/oauth/htmlPages/secrets_error.html b/internal/oauth/htmlPages/secrets_error.html new file mode 100644 index 00000000..f0d43412 --- /dev/null +++ b/internal/oauth/htmlPages/secrets_error.html @@ -0,0 +1,63 @@ + + + + + Secrets authorization failed + + + + + +
+ + + + +

CRE

+
+
+ + + + + + + + + + +

+ Secrets authorization was unsuccessful +

+

+ Your vault sign-in step could not be completed. Close this window and try + again from your terminal. +

+
+ + diff --git a/internal/oauth/htmlPages/secrets_success.html b/internal/oauth/htmlPages/secrets_success.html new file mode 100644 index 00000000..0eb515a3 --- /dev/null +++ b/internal/oauth/htmlPages/secrets_success.html @@ -0,0 +1,59 @@ + + + + + Secrets authorization complete + + + + + +
+ + + + +

CRE

+
+
+ + + +

+ Your secrets request was signed successfully +

+

+ Vault authorization is complete. You can close this window; the CLI will + finish in your terminal. +

+
+ + diff --git a/internal/oauth/pages.go b/internal/oauth/pages.go index 31d07220..040a4ec3 100644 --- a/internal/oauth/pages.go +++ b/internal/oauth/pages.go @@ -10,10 +10,12 @@ import ( ) const ( - PageError = "htmlPages/error.html" - PageSuccess = "htmlPages/success.html" - PageWaiting = "htmlPages/waiting.html" - StylePage = "htmlPages/output.css" + PageError = "htmlPages/error.html" + PageSuccess = "htmlPages/success.html" + PageSecretsSuccess = "htmlPages/secrets_success.html" + PageSecretsError = "htmlPages/secrets_error.html" + PageWaiting = "htmlPages/waiting.html" + StylePage = "htmlPages/output.css" ) //go:embed htmlPages/*.html diff --git a/internal/oauth/secrets_callback.go b/internal/oauth/secrets_callback.go new file mode 100644 index 00000000..cb7f93af --- /dev/null +++ b/internal/oauth/secrets_callback.go @@ -0,0 +1,41 @@ +package oauth + +import ( + "net/http" + + "github.com/rs/zerolog" +) + +// SecretsCallbackHandler handles the OAuth redirect for the browser secrets flow. +// If expectedState is non-empty (parsed from the platform authorize URL), the callback +// must include the same state; otherwise only a non-empty authorization code is required. +func SecretsCallbackHandler(codeCh chan<- string, expectedState string, log *zerolog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + errorParam := r.URL.Query().Get("error") + errorDesc := r.URL.Query().Get("error_description") + + if errorParam != "" { + log.Error().Str("error", errorParam).Str("description", errorDesc).Msg("auth error in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + + if expectedState != "" { + if st := r.URL.Query().Get("state"); st != expectedState { + log.Error().Str("got", st).Str("want", expectedState).Msg("invalid state in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + } + + code := r.URL.Query().Get("code") + if code == "" { + log.Error().Msg("no code in secrets callback") + ServeEmbeddedHTML(log, w, PageSecretsError, http.StatusBadRequest) + return + } + + ServeEmbeddedHTML(log, w, PageSecretsSuccess, http.StatusOK) + codeCh <- code + } +} diff --git a/internal/oauth/secrets_callback_test.go b/internal/oauth/secrets_callback_test.go new file mode 100644 index 00000000..e7071dab --- /dev/null +++ b/internal/oauth/secrets_callback_test.go @@ -0,0 +1,66 @@ +package oauth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" +) + +func TestSecretsCallbackHandler_success(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "want-state", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=the-code&state=want-state", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "the-code", <-codeCh) +} + +func TestSecretsCallbackHandler_stateMismatch(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "want", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=c&state=wrong", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + select { + case <-codeCh: + t.Fatal("expected no code") + default: + } +} + +func TestSecretsCallbackHandler_oauthError(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?error=access_denied", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Len(t, codeCh, 0) +} + +func TestSecretsCallbackHandler_noStateRequired(t *testing.T) { + log := zerolog.Nop() + codeCh := make(chan string, 1) + h := SecretsCallbackHandler(codeCh, "", &log) + + req := httptest.NewRequest(http.MethodGet, "/callback?code=only-code", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "only-code", <-codeCh) +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go index 4019de61..bf0de0ec 100644 --- a/internal/oauth/state.go +++ b/internal/oauth/state.go @@ -4,14 +4,46 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "time" + "net/url" ) -// RandomState returns a random OAuth state value for CSRF protection. -func RandomState() string { +// RandomState returns a URL-safe random string suitable for OAuth "state". +func RandomState() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) + return "", fmt.Errorf("oauth: random state: %w", err) } - return base64.RawURLEncoding.EncodeToString(b) + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// StateFromAuthorizeURL returns the OAuth "state" query parameter from an authorize URL, if present. +func StateFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + return u.Query().Get("state"), nil +} + +// ClientIDFromAuthorizeURL returns the "client_id" query parameter from an authorize URL (if present). +// Token exchange must use the same client_id the IdP bound to the authorization code. +func ClientIDFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + return u.Query().Get("client_id"), nil +} + +// OAuthServerBaseFromAuthorizeURL returns the authorization server origin (scheme + host) for the +// given authorize URL. The token endpoint must be on the same host that issued the authorization code. +func OAuthServerBaseFromAuthorizeURL(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err + } + if u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("authorize URL missing scheme or host") + } + return u.Scheme + "://" + u.Host, nil } diff --git a/internal/oauth/state_test.go b/internal/oauth/state_test.go new file mode 100644 index 00000000..fba0e450 --- /dev/null +++ b/internal/oauth/state_test.go @@ -0,0 +1,42 @@ +package oauth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRandomState(t *testing.T) { + s, err := RandomState() + require.NoError(t, err) + require.NotEmpty(t, s) + s2, err := RandomState() + require.NoError(t, err) + assert.NotEqual(t, s, s2) +} + +func TestStateFromAuthorizeURL(t *testing.T) { + s, err := StateFromAuthorizeURL("https://id.example/authorize?state=abc&client_id=x") + require.NoError(t, err) + assert.Equal(t, "abc", s) + + s, err = StateFromAuthorizeURL("https://id.example/authorize") + require.NoError(t, err) + assert.Equal(t, "", s) +} + +func TestClientIDFromAuthorizeURL(t *testing.T) { + c, err := ClientIDFromAuthorizeURL("https://auth0.example/authorize?client_id=myapp&response_type=code") + require.NoError(t, err) + assert.Equal(t, "myapp", c) +} + +func TestOAuthServerBaseFromAuthorizeURL(t *testing.T) { + base, err := OAuthServerBaseFromAuthorizeURL("https://tenant.auth0.com/authorize?foo=1") + require.NoError(t, err) + assert.Equal(t, "https://tenant.auth0.com", base) + + _, err = OAuthServerBaseFromAuthorizeURL("/relative") + assert.Error(t, err) +} From ba1e70b6e94753b81f8c34e7ced5e958cc2a473e Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Thu, 26 Mar 2026 06:57:17 +0000 Subject: [PATCH 4/5] Lint --- internal/environments/environments_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/environments/environments_test.go b/internal/environments/environments_test.go index 527da558..e82c6d8e 100644 --- a/internal/environments/environments_test.go +++ b/internal/environments/environments_test.go @@ -82,10 +82,10 @@ func TestNewEnvironmentSet_FallbackAndOverrides(t *testing.T) { WorkflowRegistryChainName: "ethereum-testnet-sepolia", }, "STAGING": { - AuthBase: "g", - ClientID: "h", - GraphQLURL: "i", - Audience: "bb", + AuthBase: "g", + ClientID: "h", + GraphQLURL: "i", + Audience: "bb", WorkflowRegistryAddress: "0xstaging_wr", WorkflowRegistryChainName: "polygon-mainnet", From e95ff7d85857b35e03cd43a9a4d36ff8c2f1febc Mon Sep 17 00:00:00 2001 From: timothyF95 Date: Thu, 26 Mar 2026 13:05:35 +0000 Subject: [PATCH 5/5] Update token exchage flow --- cmd/secrets/common/browser_flow.go | 43 +++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/cmd/secrets/common/browser_flow.go b/cmd/secrets/common/browser_flow.go index 3888974f..53d60667 100644 --- a/cmd/secrets/common/browser_flow.go +++ b/cmd/secrets/common/browser_flow.go @@ -29,6 +29,13 @@ const createVaultAuthURLMutation = `mutation CreateVaultAuthorizationUrl($reques } }` +const exchangeAuthCodeToTokenMutation = `mutation ExchangeAuthCodeToToken($request: AuthCodeTokenExchangeRequest!) { + exchangeAuthCodeToToken(request: $request) { + accessToken + expiresIn + } +}` + // vaultPermissionForMethod returns the API permission name for the given vault operation. func vaultPermissionForMethod(method string) (string, error) { switch method { @@ -47,7 +54,8 @@ func digestHexString(digest [32]byte) string { // executeBrowserUpsert handles secrets create/update when the user signs in with their organization account. // It encrypts the payload, binds a digest, requests a platform authorization URL, completes OAuth in the browser, -// exchanges the code for tokens, and saves credentials. +// and exchanges the code via the platform for a short-lived vault JWT (for future DON gateway submission). +// Login tokens in ~/.cre/cre.yaml are not modified; that session stays separate from this vault-only token. func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecretsInputs, method string) error { if h.Credentials.AuthType == credentials.AuthTypeApiKey { return fmt.Errorf("this sign-in flow requires an interactive login; API keys are not supported") @@ -139,13 +147,7 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets return fmt.Errorf("could not complete the authorization request") } - oauthIssuerBase, err := oauth.OAuthServerBaseFromAuthorizeURL(authURL) - if err != nil { - return fmt.Errorf("invalid authorization URL from server: %w", err) - } platformState, _ := oauth.StateFromAuthorizeURL(authURL) - oauthClientIDFromURL, _ := oauth.ClientIDFromAuthorizeURL(authURL) - oauthClientIDForExchange := strings.TrimSpace(oauthClientIDFromURL) codeCh := make(chan string, 1) server, listener, err := oauth.NewCallbackHTTPServer(constants.AuthListenAddr, oauth.SecretsCallbackHandler(codeCh, platformState, h.Log)) @@ -182,15 +184,30 @@ func (h *Handler) executeBrowserUpsert(ctx context.Context, inputs UpsertSecrets return ctx.Err() } - ui.Dim("Saving credentials...") - tokenSet, err := oauth.ExchangeAuthorizationCode(ctx, nil, h.EnvironmentSet, code, verifier, oauthClientIDForExchange, oauthIssuerBase) - if err != nil { + ui.Dim("Completing vault authorization...") + exchangeReq := graphql.NewRequest(exchangeAuthCodeToTokenMutation) + exchangeReq.Var("request", map[string]any{ + "code": code, + "codeVerifier": verifier, + "redirectUri": constants.AuthRedirectURI, + }) + var exchangeResp struct { + ExchangeAuthCodeToToken struct { + AccessToken string `json:"accessToken"` + ExpiresIn int `json:"expiresIn"` + } `json:"exchangeAuthCodeToToken"` + } + if err := gqlClient.Execute(ctx, exchangeReq, &exchangeResp); err != nil { return fmt.Errorf("token exchange failed: %w", err) } - if err := credentials.SaveCredentials(tokenSet); err != nil { - return fmt.Errorf("failed to save credentials: %w", err) + tok := exchangeResp.ExchangeAuthCodeToToken + if tok.AccessToken == "" { + return fmt.Errorf("token exchange failed: empty access token") } + // Short-lived vault JWT for future DON secret submission; do not persist or replace cre login tokens. + _ = tok.AccessToken + _ = tok.ExpiresIn - ui.Success("Authorization completed successfully.") + ui.Success("Vault authorization completed.") return nil }