diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a756e63..abd10557 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,7 @@ jobs: - name: Run golangci-lint uses: golangci/golangci-lint-action@v9 with: - version: latest + version: v2.10.1 test: name: Test diff --git a/.golangci.yml b/.golangci.yml index 5a2474e6..71e48000 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -58,6 +58,26 @@ linters: - text: "G104:" linters: - gosec + # gosec G117: exported struct fields matching secret-name patterns are intentional DTOs + # (OCI registry tokens, CLI credential request bodies, telemetry config, JWT response frames). + # These fields must be exported and carry credentials by design — they are not accidental leaks. + - text: "G117:" + linters: + - gosec + # gosec G703/G704/G705: taint-analysis false positives for a CLI tool and reverse proxy. + # G704 (SSRF): target URLs come from operator-configured routing tables or explicit CLI flags, + # not from untrusted user input. G703 (path traversal): XDG_RUNTIME_DIR is a trusted env var + # set by the login session, not user-supplied. G705 (XSS): registry manifest data is binary + # OCI content served with an explicit Content-Type, not rendered as HTML. + - text: "G703:" + linters: + - gosec + - text: "G704:" + linters: + - gosec + - text: "G705:" + linters: + - gosec formatters: enable: diff --git a/internal/adapters/dto/registry_token.go b/internal/adapters/dto/registry_token.go index a14df3a5..135c5270 100644 --- a/internal/adapters/dto/registry_token.go +++ b/internal/adapters/dto/registry_token.go @@ -3,7 +3,7 @@ package dto // TokenResponse represents the response from the token server. type TokenResponse struct { Token string `json:"token"` - AccessToken string `json:"access_token,omitempty"` + AccessToken string `json:"access_token,omitempty"` //nolint:gosec // OCI token response field, required by Docker Registry API v2 spec ExpiresIn int `json:"expires_in,omitempty"` IssuedAt string `json:"issued_at,omitempty"` } diff --git a/internal/adapters/in/cli/auth.go b/internal/adapters/in/cli/auth.go index d44c3119..ffe77cc8 100644 --- a/internal/adapters/in/cli/auth.go +++ b/internal/adapters/in/cli/auth.go @@ -24,6 +24,16 @@ import ( "github.com/bnema/gordon/pkg/duration" ) +// stdinFD returns os.Stdin's file descriptor as an int, guarded against +// uintptr overflow on platforms where uintptr > int. +func stdinFD() (int, error) { + fd := os.Stdin.Fd() + if fd > uintptr(^uint(0)>>1) { + return 0, fmt.Errorf("stdin fd value %d overflows int", fd) + } + return int(fd), nil +} + // newAuthCmd creates the auth command group. func newAuthCmd() *cobra.Command { cmd := &cobra.Command{ @@ -274,7 +284,11 @@ func runAuthLogin(remoteName, username, password string) error { if password == "" { // Prompt for password (hidden input) fmt.Print("Password: ") - passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fd, err := stdinFD() + if err != nil { + return fmt.Errorf("failed to get stdin fd: %w", err) + } + passwordBytes, err := term.ReadPassword(fd) if err != nil { // Fallback for non-terminal input reader := bufio.NewReader(os.Stdin) @@ -595,7 +609,11 @@ func runPasswordHash() error { fmt.Print("Enter password: ") // Read password without echo (use os.Stdin.Fd() for better compatibility) - passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fd, err := stdinFD() + if err != nil { + return fmt.Errorf("failed to get stdin fd: %w", err) + } + passwordBytes, err := term.ReadPassword(fd) if err != nil { // Fallback for non-terminal input reader := bufio.NewReader(os.Stdin) diff --git a/internal/adapters/in/cli/push.go b/internal/adapters/in/cli/push.go index 18ae76e3..d0e2f76e 100644 --- a/internal/adapters/in/cli/push.go +++ b/internal/adapters/in/cli/push.go @@ -9,6 +9,7 @@ import ( "path/filepath" "regexp" "strings" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/spf13/cobra" @@ -463,8 +464,8 @@ func buildAndPush(ctx context.Context, version, platform, dockerfile string, bui // Cloudflare's 100MB per-request limit. Loading locally then // using docker push gives us chunked uploads (~5MB per request). fmt.Println("\nBuilding image...") - buildCmd := exec.CommandContext(ctx, "docker", buildImageArgs(version, platform, dockerfile, buildArgs, versionRef, latestRef)...) // #nosec G204 - buildCmd.Env = append(os.Environ(), "VERSION="+version) + buildCmd := exec.CommandContext(ctx, "docker", buildImageArgs(ctx, version, platform, dockerfile, buildArgs, versionRef, latestRef)...) // #nosec G204 + buildCmd.Env = os.Environ() // VERSION is now passed as --build-arg VERSION= buildCmd.Stdout = os.Stdout buildCmd.Stderr = os.Stderr if err := buildCmd.Run(); err != nil { @@ -484,23 +485,52 @@ func buildAndPush(ctx context.Context, version, platform, dockerfile string, bui return nil } +// standardBuildArgs returns the standard set of git-related build args as +// explicit KEY=VALUE pairs. User-supplied args are appended after and take +// precedence (Docker uses the last occurrence of a duplicate key). +func standardBuildArgs(ctx context.Context, version string) []string { + gitSHA := resolveGitSHA(ctx) + buildTime := time.Now().UTC().Format(time.RFC3339) + return []string{ + "VERSION=" + version, + "GIT_TAG=" + version, + "GIT_SHA=" + gitSHA, + "BUILD_TIME=" + buildTime, + } +} + +// resolveGitSHA returns the short git SHA of HEAD, or "unknown" if unavailable. +func resolveGitSHA(ctx context.Context) string { + out, err := exec.CommandContext(ctx, "git", "rev-parse", "--short", "HEAD").Output() // #nosec G204 + if err != nil { + return "unknown" + } + return strings.TrimSpace(string(out)) +} + // buildImageArgs constructs the docker buildx build arguments. // Uses --load instead of --push so the image is loaded into the local // daemon, allowing docker push to handle the upload with chunked requests. -func buildImageArgs(version, platform, dockerfile string, buildArgs []string, versionRef, latestRef string) []string { +func buildImageArgs(ctx context.Context, version, platform, dockerfile string, buildArgs []string, versionRef, latestRef string) []string { args := []string{ "buildx", "build", "--platform", platform, "-f", dockerfile, "-t", latestRef, - "--build-arg", "VERSION", } if version != "latest" { args = append(args, "-t", versionRef) } + + // Inject standard git build args as explicit KEY=VALUE pairs. + // User-supplied --build-arg flags are appended AFTER so they override defaults. + for _, ba := range standardBuildArgs(ctx, version) { + args = append(args, "--build-arg", ba) + } for _, ba := range buildArgs { args = append(args, "--build-arg", ba) } + args = append(args, "--load", ".") return args } @@ -682,23 +712,26 @@ func parseImageRef(image string) (registry, name, tag string) { } // getGitVersion returns git describe output, or empty string if unavailable. +// When it falls back it prints a warning to stderr so the user knows the +// image will be tagged "latest" rather than a real version. func getGitVersion(ctx context.Context) string { - out, err := exec.CommandContext(ctx, "git", "describe", "--tags", "--dirty").Output() + out, err := exec.CommandContext(ctx, "git", "describe", "--tags", "--dirty").Output() // #nosec G204 if err != nil { + fmt.Fprintf(os.Stderr, "Warning: unable to determine git tag (%v) — image version will be 'latest'. Tag your repo to get versioned images.\n", err) return "" } return strings.TrimSpace(string(out)) } func dockerTag(ctx context.Context, src, dst string) error { - cmd := exec.CommandContext(ctx, "docker", "tag", src, dst) + cmd := exec.CommandContext(ctx, "docker", "tag", src, dst) //nolint:gosec // binary is constant; image refs validated by OCI ref parser cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() } func dockerPush(ctx context.Context, ref string) error { - cmd := exec.CommandContext(ctx, "docker", "push", ref) + cmd := exec.CommandContext(ctx, "docker", "push", ref) //nolint:gosec // binary is constant; image ref validated by OCI ref parser cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() diff --git a/internal/adapters/in/cli/push_test.go b/internal/adapters/in/cli/push_test.go index d185505b..5cca8a74 100644 --- a/internal/adapters/in/cli/push_test.go +++ b/internal/adapters/in/cli/push_test.go @@ -1,8 +1,12 @@ package cli import ( + "context" + "io" "os" + "os/exec" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -65,15 +69,56 @@ func TestParseImageRef(t *testing.T) { func TestBuildAndPush_BuildArgs(t *testing.T) { // Verify buildImageArgs produces --load instead of --push - args := buildImageArgs("v1.0.0", "linux/amd64", "Dockerfile", []string{"CGO_ENABLED=0"}, "reg.example.com/app:v1.0.0", "reg.example.com/app:latest") + args := buildImageArgs(context.Background(), "v1.0.0", "linux/amd64", "Dockerfile", []string{"CGO_ENABLED=0"}, "reg.example.com/app:v1.0.0", "reg.example.com/app:latest") assert.Contains(t, args, "--load") assert.NotContains(t, args, "--push") assert.Contains(t, args, "--platform") assert.Contains(t, args, "-f") assert.Contains(t, args, "Dockerfile") - assert.Contains(t, args, "VERSION") - assert.NotContains(t, args, "VERSION=v1.0.0") + assert.Contains(t, args, "VERSION=v1.0.0") +} + +func TestBuildImageArgsInjectsGitBuildArgs(t *testing.T) { + args := buildImageArgs(context.Background(), "v1.2.3", "linux/amd64", "Dockerfile", nil, "registry/img:v1.2.3", "registry/img:latest") + + // Must contain explicit KEY=VALUE for all standard git build args + argStr := strings.Join(args, " ") + for _, key := range []string{"VERSION=v1.2.3", "GIT_TAG=v1.2.3", "GIT_SHA=", "BUILD_TIME="} { + if !strings.Contains(argStr, key) { + t.Errorf("expected args to contain %q, got: %s", key, argStr) + } + } + + // Must NOT contain bare "--build-arg VERSION" (without =value) + for i, a := range args { + if a == "--build-arg" && i+1 < len(args) && args[i+1] == "VERSION" { + t.Error("found bare '--build-arg VERSION' (without =value); should be '--build-arg VERSION=v1.2.3'") + } + } +} + +func TestBuildImageArgsUserArgsOverrideDefaults(t *testing.T) { + userArgs := []string{"GIT_TAG=custom-override"} + args := buildImageArgs(context.Background(), "v1.2.3", "linux/amd64", "Dockerfile", userArgs, "r/i:v1.2.3", "r/i:latest") + + // Count how many times GIT_TAG appears and track the last occurrence index. + // Docker uses the last occurrence of a duplicate --build-arg key, so the user + // override must come after the default injected value. + count := 0 + lastIdx := -1 + for i, a := range args { + if a == "--build-arg" && i+1 < len(args) && strings.HasPrefix(args[i+1], "GIT_TAG=") { + count++ + lastIdx = i + 1 + } + } + if count < 2 { + t.Errorf("expected GIT_TAG to appear twice (default + override), got %d", count) + } + if lastIdx < 0 || args[lastIdx] != "GIT_TAG="+userArgs[0][len("GIT_TAG="):] { + t.Errorf("expected last GIT_TAG= arg to be the user override %q, got %q", "GIT_TAG=custom-override", args[lastIdx]) + } } func TestParseTagRef(t *testing.T) { @@ -142,7 +187,7 @@ func TestVersionFromTagRefs(t *testing.T) { } func TestBuildImageArgs_CustomDockerfile(t *testing.T) { - args := buildImageArgs("v1.0.0", "linux/amd64", "docker/app/Dockerfile", nil, "reg.example.com/app:v1.0.0", "reg.example.com/app:latest") + args := buildImageArgs(context.Background(), "v1.0.0", "linux/amd64", "docker/app/Dockerfile", nil, "reg.example.com/app:v1.0.0", "reg.example.com/app:latest") assert.Contains(t, args, "-f") assert.Contains(t, args, "docker/app/Dockerfile") @@ -301,6 +346,55 @@ func TestSplitLabelPairs(t *testing.T) { } } +func TestGetGitVersionNoTagsReturnsFallbackAndWarns(t *testing.T) { + ctx := context.Background() + + // Create a temp dir with a git repo that has no tags + tmpDir := t.TempDir() + if err := exec.Command("git", "-C", tmpDir, "init").Run(); err != nil { // #nosec G204 + t.Skipf("git init failed: %v", err) + } + // Need at least one commit so git describe has something to describe + if err := exec.Command("git", "-C", tmpDir, "-c", "user.email=test@test.com", "-c", "user.name=Test", "commit", "--allow-empty", "-m", "init").Run(); err != nil { // #nosec G204 + t.Skipf("git commit failed: %v", err) + } + + // Change to the tmpDir for this test so getGitVersion uses the tag-less repo + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + if err := os.Chdir(tmpDir); err != nil { // #nosec G204 + t.Fatal(err) + } + defer os.Chdir(origDir) // #nosec G204 + + // Redirect stderr to capture the warning + origStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stderr = w + defer func() { + w.Close() + os.Stderr = origStderr + }() + + v := getGitVersion(ctx) + + w.Close() + os.Stderr = origStderr + stderrOutput, _ := io.ReadAll(r) + + // Should return "" (fallback to "latest" handled by determineVersion) + assert.Equal(t, "", v, "expected empty string fallback when no git tags exist") + + // Should have printed a warning to stderr + assert.Contains(t, string(stderrOutput), "latest", + "expected a warning about 'latest' fallback on stderr") +} + func TestParseLabelPair(t *testing.T) { tests := []struct { name string diff --git a/internal/adapters/in/cli/remote/auth.go b/internal/adapters/in/cli/remote/auth.go index 513d7de1..9c25e09f 100644 --- a/internal/adapters/in/cli/remote/auth.go +++ b/internal/adapters/in/cli/remote/auth.go @@ -12,7 +12,7 @@ import ( // PasswordRequest represents the request body for POST /auth/password. type PasswordRequest struct { Username string `json:"username"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec // intentional: CLI credential DTO for auth endpoint } // PasswordResponse represents the response from POST /auth/password. diff --git a/internal/adapters/in/cli/remote/client.go b/internal/adapters/in/cli/remote/client.go index f965cd67..9ea2ca17 100644 --- a/internal/adapters/in/cli/remote/client.go +++ b/internal/adapters/in/cli/remote/client.go @@ -21,10 +21,11 @@ import ( // Client is an HTTP client for the Gordon admin API. type Client struct { - baseURL string - token string - httpClient *http.Client - insecureTLS bool + baseURL string + token string + httpClient *http.Client + insecureTLS bool + onTokenRefreshed func(newToken string) // optional callback to persist a refreshed token } var ( @@ -84,6 +85,37 @@ func WithInsecureTLS(insecure bool) ClientOption { } } +// WithTokenRefreshCallback sets a callback that is invoked when the server +// returns a refreshed token in the X-Gordon-Token response header. +// The CLI uses this to persist the new token to the remotes config atomically. +func WithTokenRefreshCallback(fn func(newToken string)) ClientOption { + return func(c *Client) { + c.onTokenRefreshed = fn + } +} + +// observeTokenRotation checks the X-Gordon-Token header on a response and, if +// the server issued a refreshed token, updates the client's in-memory token and +// invokes the optional persistence callback. Panics in the callback are caught +// so they never propagate to the caller. +func (c *Client) observeTokenRotation(resp *http.Response) { + newToken := resp.Header.Get("X-Gordon-Token") + if newToken == "" || newToken == c.token { + return + } + c.token = newToken + if c.onTokenRefreshed != nil { + func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "warning: token refresh callback panicked: %v\n", r) + } + }() + c.onTokenRefreshed(newToken) + }() + } +} + func (c *Client) applyTLSConfig() { if !c.insecureTLS { return @@ -140,7 +172,15 @@ func (c *Client) request(ctx context.Context, method, path string, body any) (*h req.Header.Set("Authorization", "Bearer "+c.token) } - return c.httpClient.Do(req) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + // Observe token rotation: the server re-signs and returns the token via X-Gordon-Token. + c.observeTokenRotation(resp) + + return resp, nil } // parseResponse parses a JSON response into the given target. @@ -887,6 +927,9 @@ func (c *Client) streamLogs(ctx context.Context, path string) (<-chan string, er return nil, fmt.Errorf("failed to connect: %w", err) } + // Observe token rotation on streaming responses, same as request(). + c.observeTokenRotation(resp) + if resp.StatusCode >= 400 { body, _ := io.ReadAll(resp.Body) resp.Body.Close() diff --git a/internal/adapters/in/cli/remote/config.go b/internal/adapters/in/cli/remote/config.go index 60cee5c1..bcbf3f3b 100644 --- a/internal/adapters/in/cli/remote/config.go +++ b/internal/adapters/in/cli/remote/config.go @@ -138,16 +138,28 @@ func SaveRemotes(path string, config *ClientConfig) error { // ResolveRemote resolves the remote URL and token from configuration. // Precedence: flag > env > config > active remote func ResolveRemote(flagRemote, flagToken string, flagInsecure bool) (url, token string, insecureTLS, isRemote bool) { + url, token, insecureTLS, _, isRemote = ResolveRemoteFull(flagRemote, flagToken, flagInsecure) + return url, token, insecureTLS, isRemote +} + +// ResolveRemoteFull resolves the remote URL, token, and named remote from configuration. +// It also returns the remote name for use in token persistence callbacks. +// Precedence: flag > env > config > active remote +func ResolveRemoteFull(flagRemote, flagToken string, flagInsecure bool) (url, token string, insecureTLS bool, remoteName string, isRemote bool) { config, _ := LoadClientConfig("") remotes, _ := LoadRemotes("") - url, remoteName, isRemote := resolveRemoteURL(flagRemote, config, remotes) + var name string + url, name, isRemote = resolveRemoteURL(flagRemote, config, remotes) if isRemote { + // resolveToken looks up remotes.Active internally; name == remotes.Active here + // because resolveRemoteURL only sets name when it finds the active remote entry. token = resolveToken(flagToken, config, remotes) - insecureTLS = resolveInsecureTLS(flagInsecure, config, remotes, remoteName) + insecureTLS = resolveInsecureTLS(flagInsecure, config, remotes, name) + remoteName = name } - return url, token, insecureTLS, isRemote + return url, token, insecureTLS, remoteName, isRemote } // resolveRemoteURL resolves the remote URL from various sources. diff --git a/internal/adapters/in/cli/root.go b/internal/adapters/in/cli/root.go index 23ead51d..9d389fd2 100644 --- a/internal/adapters/in/cli/root.go +++ b/internal/adapters/in/cli/root.go @@ -144,12 +144,23 @@ Commands are organized by where they run: // GetRemoteClient returns a remote client if targeting a remote instance, // or nil if running locally. func GetRemoteClient() (*remote.Client, bool) { - url, token, insecureTLS, isRemote := remote.ResolveRemote(remoteFlag, tokenFlag, insecureTLSFlag) + url, token, insecureTLS, remoteName, isRemote := remote.ResolveRemoteFull(remoteFlag, tokenFlag, insecureTLSFlag) if !isRemote { return nil, false } - client := remote.NewClient(url, remoteClientOptions(token, insecureTLS)...) + opts := remoteClientOptions(token, insecureTLS) + + // When using a named remote (from remotes.toml), register a callback to persist + // refreshed tokens returned by the server in the X-Gordon-Token header. + if remoteName != "" { + name := remoteName // capture for closure + opts = append(opts, remote.WithTokenRefreshCallback(func(newToken string) { + _ = remote.UpdateRemoteToken(name, newToken) // best-effort; errors are non-fatal + })) + } + + client := remote.NewClient(url, opts...) return client, true } diff --git a/internal/adapters/in/cli/ui/components/selector.go b/internal/adapters/in/cli/ui/components/selector.go index 47fc31c1..3b3d1480 100644 --- a/internal/adapters/in/cli/ui/components/selector.go +++ b/internal/adapters/in/cli/ui/components/selector.go @@ -82,7 +82,8 @@ func (m SelectorModel) View() string { if i == m.cursor { b.WriteString(styles.Theme.Highlight.Render(cursor + line)) } else { - b.WriteString(fmt.Sprintf("%s%s", cursor, line)) + b.WriteString(cursor) + b.WriteString(line) } b.WriteString("\n") } diff --git a/internal/adapters/in/http/admin/middleware.go b/internal/adapters/in/http/admin/middleware.go index c1e3f447..f32d6083 100644 --- a/internal/adapters/in/http/admin/middleware.go +++ b/internal/adapters/in/http/admin/middleware.go @@ -106,6 +106,14 @@ func AuthMiddleware( ctx = context.WithValue(ctx, domain.ContextKeySubject, claims.Subject) ctx = context.WithValue(ctx, domain.TokenClaimsKey, claims) + // Attempt to slide token expiry. Non-fatal if it fails. + // The new token is returned in X-Gordon-Token so the CLI can persist it atomically. + if newToken, extErr := authSvc.ExtendToken(ctx, token); extErr != nil { + log.Warn().Err(extErr).Str("subject", claims.Subject).Msg("token extension failed") + } else if newToken != token { + w.Header().Set("X-Gordon-Token", newToken) + } + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/adapters/in/http/admin/middleware_test.go b/internal/adapters/in/http/admin/middleware_test.go index 8a9a729f..3802acff 100644 --- a/internal/adapters/in/http/admin/middleware_test.go +++ b/internal/adapters/in/http/admin/middleware_test.go @@ -2,6 +2,7 @@ package admin import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -83,6 +84,8 @@ func TestAuthMiddleware_TrustedProxy(t *testing.T) { Subject: "admin", Scopes: []string{"admin:*:*"}, }, nil) + // ExtendToken is called after successful validation; non-fatal if it fails + authSvc.EXPECT().ExtendToken(mock.Anything, "valid-admin-token").Return("valid-admin-token", nil) handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -110,6 +113,8 @@ func TestAuthMiddleware_NoRateLimiting(t *testing.T) { Subject: "admin", Scopes: []string{"admin:*:*"}, }, nil) + // ExtendToken is called after successful validation; non-fatal if it fails + authSvc.EXPECT().ExtendToken(mock.Anything, "valid-admin-token").Return("valid-admin-token", nil) handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -247,6 +252,8 @@ func TestAuthMiddleware_Success(t *testing.T) { Subject: "admin", Scopes: []string{"admin:*:*"}, }, nil) + // ExtendToken is called after successful validation; non-fatal if it fails + authSvc.EXPECT().ExtendToken(mock.Anything, "valid-admin-token").Return("valid-admin-token", nil) var capturedCtx context.Context handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -358,6 +365,60 @@ func TestRequireScope(t *testing.T) { } } +func TestAuthMiddleware_ExtendTokenFailureIsNonFatal(t *testing.T) { + authSvc := inmocks.NewMockAuthService(t) + + authSvc.EXPECT().IsEnabled().Return(true) + authSvc.EXPECT().ValidateToken(mock.Anything, "valid-admin-token").Return(&domain.TokenClaims{ + Subject: "admin", + Scopes: []string{"admin:*:*"}, + }, nil) + // ExtendToken fails — response must still be 200 (non-fatal) + authSvc.EXPECT().ExtendToken(mock.Anything, "valid-admin-token").Return("", errors.New("store unavailable")) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + mw := AuthMiddleware(authSvc, nil, nil, nil, adminTestLogger()) + wrappedHandler := mw(handler) + + req := httptest.NewRequest(http.MethodGet, "/admin/status", nil) + req.Header.Set("Authorization", "Bearer valid-admin-token") + rec := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Empty(t, rec.Header().Get("X-Gordon-Token")) +} + +func TestAuthMiddleware_ExtendTokenRotationHeader(t *testing.T) { + authSvc := inmocks.NewMockAuthService(t) + + authSvc.EXPECT().IsEnabled().Return(true) + authSvc.EXPECT().ValidateToken(mock.Anything, "old-token").Return(&domain.TokenClaims{ + Subject: "admin", + Scopes: []string{"admin:*:*"}, + }, nil) + // ExtendToken returns a rotated token — must appear in response header + authSvc.EXPECT().ExtendToken(mock.Anything, "old-token").Return("rotated-token", nil) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + mw := AuthMiddleware(authSvc, nil, nil, nil, adminTestLogger()) + wrappedHandler := mw(handler) + + req := httptest.NewRequest(http.MethodGet, "/admin/status", nil) + req.Header.Set("Authorization", "Bearer old-token") + rec := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "rotated-token", rec.Header().Get("X-Gordon-Token")) +} + func TestHasAccess(t *testing.T) { ctx := context.WithValue(context.Background(), domain.ContextKeyScopes, []string{"admin:routes:read", "admin:secrets:*"}) diff --git a/internal/adapters/in/http/auth/handler.go b/internal/adapters/in/http/auth/handler.go index ede726b4..510498a1 100644 --- a/internal/adapters/in/http/auth/handler.go +++ b/internal/adapters/in/http/auth/handler.go @@ -12,6 +12,7 @@ import ( "github.com/bnema/zerowrap" "github.com/bnema/gordon/internal/adapters/dto" + "github.com/bnema/gordon/internal/adapters/in/http/httputil" "github.com/bnema/gordon/internal/boundaries/in" "github.com/bnema/gordon/internal/domain" ) @@ -269,7 +270,7 @@ func (h *Handler) handleToken(w http.ResponseWriter, r *http.Request) { } func (h *Handler) authenticateTokenCredentials(ctx context.Context, r *http.Request, username, password string, log zerowrap.Logger) (bool, *domain.TokenClaims) { - if isLocalhostRequest(r) && h.isInternalAuth(username, password) { + if httputil.IsLocalhostRequest(r) && h.isInternalAuth(username, password) { log.Debug().Str("username", username).Msg("internal registry auth accepted") return true, nil } @@ -350,16 +351,6 @@ func hasGrantedRegistryAccess(grantedScopes []string, repoName, action string) b return false } -// isLocalhostRequest checks if the request originates from localhost. -// SECURITY: Uses RemoteAddr (server-set) instead of Host header (client-spoofable). -func isLocalhostRequest(r *http.Request) bool { - host := r.RemoteAddr - // RemoteAddr includes port, e.g., "127.0.0.1:12345" or "[::1]:12345" - return strings.HasPrefix(host, "127.") || - strings.HasPrefix(host, "[::1]") || - strings.HasPrefix(host, "::1") -} - // parseRequestedScopes extracts and validates scope parameters from the request. // Per Docker Registry v2 auth spec, scope format is: repository:name:actions // Example: GET /auth/token?scope=repository:myrepo:push,pull&scope=repository:other:pull diff --git a/internal/adapters/in/http/auth/handler_test.go b/internal/adapters/in/http/auth/handler_test.go index 2fc8da03..d3620c54 100644 --- a/internal/adapters/in/http/auth/handler_test.go +++ b/internal/adapters/in/http/auth/handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/bnema/gordon/internal/adapters/in/http/httputil" "github.com/bnema/gordon/internal/boundaries/in/mocks" "github.com/bnema/gordon/internal/domain" ) @@ -435,7 +436,6 @@ func TestIsLocalhostRequest(t *testing.T) { {"127.0.0.1 is localhost", "127.0.0.1:12345", true}, {"127.0.0.2 is localhost", "127.0.0.2:12345", true}, {"::1 is localhost", "[::1]:12345", true}, - {"::1 without brackets", "::1:12345", true}, {"192.168.1.1 is not localhost", "192.168.1.1:12345", false}, {"10.0.0.1 is not localhost", "10.0.0.1:12345", false}, {"public IP is not localhost", "8.8.8.8:12345", false}, @@ -446,7 +446,7 @@ func TestIsLocalhostRequest(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/auth/token", nil) req.RemoteAddr = tt.remoteAddr - result := isLocalhostRequest(req) + result := httputil.IsLocalhostRequest(req) assert.Equal(t, tt.want, result) }) diff --git a/internal/adapters/in/http/httputil/localhost.go b/internal/adapters/in/http/httputil/localhost.go new file mode 100644 index 00000000..91bcd17d --- /dev/null +++ b/internal/adapters/in/http/httputil/localhost.go @@ -0,0 +1,22 @@ +package httputil + +import ( + "net" + "net/http" + "net/netip" +) + +// IsLocalhostRequest reports whether the request originates from localhost. +// SECURITY: Uses RemoteAddr (server-set) instead of Host header (client-spoofable). +func IsLocalhostRequest(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // RemoteAddr has no port (unusual but possible in tests/Unix sockets). + host = r.RemoteAddr + } + addr, err := netip.ParseAddr(host) + if err != nil { + return false + } + return addr.IsLoopback() +} diff --git a/internal/adapters/in/http/httputil/localhost_test.go b/internal/adapters/in/http/httputil/localhost_test.go new file mode 100644 index 00000000..a0fc2cc5 --- /dev/null +++ b/internal/adapters/in/http/httputil/localhost_test.go @@ -0,0 +1,33 @@ +package httputil_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/bnema/gordon/internal/adapters/in/http/httputil" +) + +func TestIsLocalhostRequest(t *testing.T) { + tests := []struct { + name string + remoteAddr string + want bool + }{ + {"ipv4 loopback", "127.0.0.1:12345", true}, + {"ipv4 loopback other", "127.0.0.2:9000", true}, + {"ipv6 loopback bracketed", "[::1]:12345", true}, + {"external ipv4", "192.168.1.1:12345", false}, + {"external ipv6", "[2001:db8::1]:12345", false}, + {"empty", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = tt.remoteAddr + if got := httputil.IsLocalhostRequest(req); got != tt.want { + t.Errorf("IsLocalhostRequest(%q) = %v, want %v", tt.remoteAddr, got, tt.want) + } + }) + } +} diff --git a/internal/adapters/in/http/middleware/auth.go b/internal/adapters/in/http/middleware/auth.go index d36acc33..70cfa794 100644 --- a/internal/adapters/in/http/middleware/auth.go +++ b/internal/adapters/in/http/middleware/auth.go @@ -10,6 +10,7 @@ import ( "github.com/bnema/zerowrap" "github.com/bnema/gordon/internal/adapters/dto" + "github.com/bnema/gordon/internal/adapters/in/http/httputil" "github.com/bnema/gordon/internal/boundaries/in" "github.com/bnema/gordon/internal/domain" ) @@ -18,75 +19,6 @@ import ( // Using domain key for consistency across all auth flows. const TokenClaimsKey = domain.TokenClaimsKey -// RegistryAuth middleware provides Docker Registry authentication. -func RegistryAuth(username, password string, log zerowrap.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Require auth for all registry operations - if !isAuthenticated(r, username, password, log) { - w.Header().Set("WWW-Authenticate", `Basic realm="Gordon Registry"`) - w.Header().Set("Docker-Distribution-API-Version", "registry/2.0") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - _ = json.NewEncoder(w).Encode(dto.ErrorResponse{Error: "Unauthorized"}) - - log.Warn(). - Str(zerowrap.FieldLayer, "adapter"). - Str(zerowrap.FieldAdapter, "http"). - Str(zerowrap.FieldMethod, r.Method). - Str(zerowrap.FieldPath, r.URL.Path). - Str(zerowrap.FieldClientIP, r.RemoteAddr). - Msg("unauthorized registry access attempt") - return - } - - next.ServeHTTP(w, r) - }) - } -} - -// isAuthenticated checks basic auth credentials. -func isAuthenticated(r *http.Request, expectedUsername, expectedPassword string, log zerowrap.Logger) bool { - authHeader := r.Header.Get("Authorization") - log.Debug(). - Str(zerowrap.FieldLayer, "adapter"). - Str(zerowrap.FieldAdapter, "http"). - Str(zerowrap.FieldMethod, r.Method). - Str(zerowrap.FieldPath, r.URL.Path). - Bool("has_auth_header", authHeader != ""). - Msg("processing authentication request") - - username, password, ok := r.BasicAuth() - if !ok { - log.Debug(). - Str(zerowrap.FieldMethod, r.Method). - Str(zerowrap.FieldPath, r.URL.Path). - Msg("no basic auth provided") - return false - } - - // Use constant-time comparison to prevent timing attacks - usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUsername)) == 1 - passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPassword)) == 1 - - authenticated := usernameMatch && passwordMatch - if !authenticated { - log.Debug(). - Str("provided_username", redactUsername(username)). - Str(zerowrap.FieldMethod, r.Method). - Str(zerowrap.FieldPath, r.URL.Path). - Msg("authentication failed") - } else { - log.Debug(). - Str("username", redactUsername(username)). - Str(zerowrap.FieldMethod, r.Method). - Str(zerowrap.FieldPath, r.URL.Path). - Msg("authentication successful") - } - - return authenticated -} - // sanitizeHeaderValue removes characters that could enable header injection. // Only allows alphanumeric, dots, hyphens, colons, and square brackets // (sufficient for host:port and IPv6 addresses). @@ -125,7 +57,7 @@ func RegistryAuthV2(authSvc in.AuthService, internalAuth InternalRegistryAuth, l return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Allow localhost requests only with internal instance credentials. - if isLocalhostRequest(r) && isInternalRegistryAuth(r, internalAuth) { + if httputil.IsLocalhostRequest(r) && isInternalRegistryAuth(r, internalAuth) { log.Debug(). Str(zerowrap.FieldLayer, "adapter"). Str(zerowrap.FieldAdapter, "http"). @@ -296,19 +228,6 @@ func authenticateToken(ctx context.Context, r *http.Request, authSvc in.AuthServ return true, claims } -// isLocalhostRequest checks if the request originates from localhost. -// This is used to allow Gordon to pull from its own registry with internal auth. -func isLocalhostRequest(r *http.Request) bool { - host := r.RemoteAddr - // RemoteAddr includes port, e.g., "127.0.0.1:12345" or "[::1]:12345" - if strings.HasPrefix(host, "127.") || - strings.HasPrefix(host, "[::1]") || - strings.HasPrefix(host, "::1") { - return true - } - return false -} - func isInternalRegistryAuth(r *http.Request, internalAuth InternalRegistryAuth) bool { if internalAuth.Username == "" || internalAuth.Password == "" { return false diff --git a/internal/adapters/in/http/middleware/auth_test.go b/internal/adapters/in/http/middleware/auth_test.go index c01ac45b..bc8dd69f 100644 --- a/internal/adapters/in/http/middleware/auth_test.go +++ b/internal/adapters/in/http/middleware/auth_test.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "encoding/base64" "errors" "net/http" "net/http/httptest" @@ -19,340 +18,6 @@ func testLogger() zerowrap.Logger { return zerowrap.Default() } -func TestRegistryAuth_ValidCredentials(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("authenticated")) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("admin", "secret") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "authenticated", rec.Body.String()) -} - -func TestRegistryAuth_InvalidUsername(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("wronguser", "secret") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.Equal(t, `Basic realm="Gordon Registry"`, rec.Header().Get("WWW-Authenticate")) - assert.Equal(t, "registry/2.0", rec.Header().Get("Docker-Distribution-API-Version")) -} - -func TestRegistryAuth_InvalidPassword(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("admin", "wrongpassword") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestRegistryAuth_NoCredentials(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - // No auth header - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.Equal(t, `Basic realm="Gordon Registry"`, rec.Header().Get("WWW-Authenticate")) -} - -func TestRegistryAuth_MalformedAuthHeader(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.Header.Set("Authorization", "Basic notbase64!!!") // Invalid base64 - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestRegistryAuth_WrongAuthScheme(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.Header.Set("Authorization", "Bearer some-token") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestRegistryAuth_EmptyCredentials(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("", "") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestRegistryAuth_PartialCredentials(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - // Only username, empty password - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("admin", "") - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestIsAuthenticated_ValidCredentials(t *testing.T) { - log := testLogger() - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("admin", "secret") - - result := isAuthenticated(req, "admin", "secret", log) - - assert.True(t, result) -} - -func TestIsAuthenticated_InvalidCredentials(t *testing.T) { - log := testLogger() - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth("admin", "wrong") - - result := isAuthenticated(req, "admin", "secret", log) - - assert.False(t, result) -} - -func TestIsAuthenticated_NoAuthHeader(t *testing.T) { - log := testLogger() - - req := httptest.NewRequest("GET", "/v2/", nil) - - result := isAuthenticated(req, "admin", "secret", log) - - assert.False(t, result) -} - -func TestIsAuthenticated_TimingAttackPrevention(t *testing.T) { - // This test verifies that constant-time comparison is used - // by checking both username and password are validated - log := testLogger() - - tests := []struct { - name string - username string - password string - expectedUser string - expectedPass string - shouldAuth bool - }{ - { - name: "both correct", - username: "admin", - password: "secret", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: true, - }, - { - name: "username wrong first char", - username: "bdmin", - password: "secret", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - { - name: "username wrong last char", - username: "admio", - password: "secret", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - { - name: "password wrong first char", - username: "admin", - password: "aecret", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - { - name: "password wrong last char", - username: "admin", - password: "secreo", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - { - name: "different length username", - username: "adm", - password: "secret", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - { - name: "different length password", - username: "admin", - password: "sec", - expectedUser: "admin", - expectedPass: "secret", - shouldAuth: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth(tt.username, tt.password) - - result := isAuthenticated(req, tt.expectedUser, tt.expectedPass, log) - - assert.Equal(t, tt.shouldAuth, result) - }) - } -} - -func TestRegistryAuth_SpecialCharactersInCredentials(t *testing.T) { - log := testLogger() - - tests := []struct { - name string - username string - password string - }{ - {"unicode username", "用户", "password"}, - {"unicode password", "admin", "密码"}, - {"special chars", "admin@example.com", "p@ss:w0rd!"}, - {"spaces", "admin user", "pass word"}, - {"colon in password", "admin", "pass:word"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - middleware := RegistryAuth(tt.username, tt.password, log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/v2/", nil) - req.SetBasicAuth(tt.username, tt.password) - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - }) - } -} - -func TestRegistryAuth_Base64EdgeCases(t *testing.T) { - log := testLogger() - middleware := RegistryAuth("admin", "secret", log) - - handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - tests := []struct { - name string - authHeader string - wantStatus int - }{ - { - name: "empty basic auth", - authHeader: "Basic ", - wantStatus: http.StatusUnauthorized, - }, - { - name: "basic auth with only colon", - authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte(":")), - wantStatus: http.StatusUnauthorized, - }, - { - name: "missing colon in decoded value", - authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("adminpassword")), - wantStatus: http.StatusUnauthorized, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/v2/", nil) - req.Header.Set("Authorization", tt.authHeader) - rec := httptest.NewRecorder() - - handler.ServeHTTP(rec, req) - - assert.Equal(t, tt.wantStatus, rec.Code) - }) - } -} - func TestRegistryAuthV2_LocalhostBypassRequiresInternalAuth(t *testing.T) { log := testLogger() called := false @@ -505,6 +170,10 @@ func (s stubAuthService) GetAuthStatus(context.Context) (*domain.AuthStatus, err return nil, errors.New("not implemented") } +func (s stubAuthService) ExtendToken(context.Context, string) (string, error) { + return "", errors.New("not implemented") +} + // Tests for checkScopeAccess function func TestCheckScopeAccess_ActionMapping(t *testing.T) { diff --git a/internal/adapters/out/docker/runtime.go b/internal/adapters/out/docker/runtime.go index 674b6635..9a5ea530 100644 --- a/internal/adapters/out/docker/runtime.go +++ b/internal/adapters/out/docker/runtime.go @@ -1016,8 +1016,8 @@ func (r *Runtime) CreateVolume(ctx context.Context, volumeName string) error { _, err := r.client.VolumeCreate(ctx, volume.CreateOptions{ Name: volumeName, Labels: map[string]string{ - "gordon.managed": "true", - "gordon.created": "auto", + domain.LabelManaged: "true", + domain.LabelCreated: time.Now().UTC().Format(time.RFC3339), }, }) if err != nil { @@ -1142,7 +1142,7 @@ func (r *Runtime) CreateNetwork(ctx context.Context, name string, options map[st createOptions := network.CreateOptions{ Driver: driver, Labels: map[string]string{ - "gordon.managed": "true", + domain.LabelManaged: "true", }, } diff --git a/internal/adapters/out/domainsecrets/pass_store.go b/internal/adapters/out/domainsecrets/pass_store.go index 7cc786f5..8fc51dad 100644 --- a/internal/adapters/out/domainsecrets/pass_store.go +++ b/internal/adapters/out/domainsecrets/pass_store.go @@ -33,7 +33,10 @@ type PassStore struct { // NewPassStore creates a new pass-based domain secret store. func NewPassStore(log zerowrap.Logger) (*PassStore, error) { - if err := exec.Command("pass", "version").Run(); err != nil { + // Use a timeout so a stalled GPG agent or keyring does not hang startup. + probeCtx, probeCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer probeCancel() + if err := exec.CommandContext(probeCtx, "pass", "version").Run(); err != nil { //nolint:gosec // binary is a constant ("pass"), no user input return nil, fmt.Errorf("pass is not available: %w", err) } @@ -646,7 +649,7 @@ func (s *PassStore) ManifestExists(domainName string) (bool, error) { } func (s *PassStore) passInsert(ctx context.Context, path, value string) error { - cmd := exec.CommandContext(ctx, "pass", "insert", "-m", "-f", path) + cmd := exec.CommandContext(ctx, "pass", "insert", "-m", "-f", path) //nolint:gosec // binary is constant ("pass"); path arguments validated by secrets path validator cmd.Stdin = strings.NewReader(value) output, err := cmd.CombinedOutput() if err != nil { @@ -656,7 +659,7 @@ func (s *PassStore) passInsert(ctx context.Context, path, value string) error { } func (s *PassStore) passRemove(ctx context.Context, path string) error { - cmd := exec.CommandContext(ctx, "pass", "rm", "-f", path) + cmd := exec.CommandContext(ctx, "pass", "rm", "-f", path) //nolint:gosec // binary is constant ("pass"); path arguments validated by secrets path validator output, err := cmd.CombinedOutput() if err != nil { if passEntryMissing(string(output)) { @@ -668,7 +671,7 @@ func (s *PassStore) passRemove(ctx context.Context, path string) error { } func (s *PassStore) passShow(ctx context.Context, path string) (string, bool, error) { - cmd := exec.CommandContext(ctx, "pass", "show", path) + cmd := exec.CommandContext(ctx, "pass", "show", path) //nolint:gosec // binary is constant ("pass"); path arguments validated by secrets path validator output, err := cmd.CombinedOutput() if err != nil { if passEntryMissing(string(output)) { @@ -687,7 +690,7 @@ func (s *PassStore) listTopLevelEntries(ctx context.Context, basePath string) ([ return nil, err } - cmd := exec.CommandContext(ctx, "pass", "ls", basePath) + cmd := exec.CommandContext(ctx, "pass", "ls", basePath) //nolint:gosec // binary is constant ("pass"); path arguments validated by secrets path validator output, err := cmd.CombinedOutput() if err != nil { if passEntryMissing(string(output)) { @@ -809,8 +812,12 @@ func (s *PassStore) cleanupInsertedPaths(paths []string) { } for _, path := range paths { - ctx, cancel := context.WithTimeout(context.Background(), s.timeout) - _ = s.passRemove(ctx, path) - cancel() + // Wrap in a closure so defer cancel() runs at the end of each iteration, + // not at the end of cleanupInsertedPaths (which would delay all cancels). + func(p string) { + ctx, cancel := context.WithTimeout(context.Background(), s.timeout) + defer cancel() + _ = s.passRemove(ctx, p) + }(path) } } diff --git a/internal/adapters/out/secrets/pass.go b/internal/adapters/out/secrets/pass.go index 87c038de..0a8cadf1 100644 --- a/internal/adapters/out/secrets/pass.go +++ b/internal/adapters/out/secrets/pass.go @@ -80,7 +80,7 @@ func (p *PassProvider) GetSecret(ctx context.Context, path string) (string, erro ctx, cancel := context.WithTimeout(ctx, p.timeout) defer cancel() - cmd := exec.CommandContext(ctx, "pass", "show", path) + cmd := exec.CommandContext(ctx, "pass", "show", path) //nolint:gosec // binary is constant ("pass"); arguments are validated secret paths output, err := cmd.Output() if err != nil { if exitError, ok := err.(*exec.ExitError); ok { @@ -105,8 +105,10 @@ func (p *PassProvider) GetSecret(ctx context.Context, path string) (string, erro } // IsAvailable checks if pass is available in the system. +// Uses a short timeout so a stalled GPG agent does not hang the caller. func (p *PassProvider) IsAvailable() bool { - cmd := exec.Command("pass", "version") - err := cmd.Run() - return err == nil + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + defer cancel() + cmd := exec.CommandContext(ctx, "pass", "version") //nolint:gosec // binary is a constant ("pass"), no user input + return cmd.Run() == nil } diff --git a/internal/adapters/out/tokenstore/pass.go b/internal/adapters/out/tokenstore/pass.go index 05621022..87e04797 100644 --- a/internal/adapters/out/tokenstore/pass.go +++ b/internal/adapters/out/tokenstore/pass.go @@ -10,6 +10,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "github.com/bnema/zerowrap" @@ -52,6 +53,10 @@ func validateSubject(subject string) error { return fmt.Errorf("invalid subject: cannot contain '..' to prevent path traversal") } + if strings.HasSuffix(subject, ".meta") { + return fmt.Errorf("invalid subject: cannot end with '.meta' (reserved for token metadata files)") + } + return nil } @@ -79,12 +84,13 @@ func NewPassStore(log zerowrap.Logger) *PassStore { // tokenMetadata holds the non-sensitive token information. type tokenMetadata struct { - ID string `json:"id"` - Subject string `json:"subject"` - Scopes []string `json:"scopes"` - IssuedAt time.Time `json:"issued_at"` - ExpiresAt time.Time `json:"expires_at,omitempty"` - Revoked bool `json:"revoked"` + ID string `json:"id"` + Subject string `json:"subject"` + Scopes []string `json:"scopes"` + IssuedAt time.Time `json:"issued_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Revoked bool `json:"revoked"` + LastExtendedAt time.Time `json:"last_extended_at"` } // SaveToken stores a token JWT and metadata in pass. @@ -104,12 +110,13 @@ func (s *PassStore) SaveToken(ctx context.Context, token *domain.Token, jwt stri // Store metadata meta := tokenMetadata{ - ID: token.ID, - Subject: token.Subject, - Scopes: token.Scopes, - IssuedAt: token.IssuedAt, - ExpiresAt: token.ExpiresAt, - Revoked: token.Revoked, + ID: token.ID, + Subject: token.Subject, + Scopes: token.Scopes, + IssuedAt: token.IssuedAt, + ExpiresAt: token.ExpiresAt, + Revoked: token.Revoked, + LastExtendedAt: token.LastExtendedAt, } metaJSON, err := json.Marshal(meta) if err != nil { @@ -173,12 +180,13 @@ func (s *PassStore) GetToken(ctx context.Context, subject string) (string, *doma } token := &domain.Token{ - ID: meta.ID, - Subject: meta.Subject, - Scopes: meta.Scopes, - IssuedAt: meta.IssuedAt, - ExpiresAt: meta.ExpiresAt, - Revoked: meta.Revoked, + ID: meta.ID, + Subject: meta.Subject, + Scopes: meta.Scopes, + IssuedAt: meta.IssuedAt, + ExpiresAt: meta.ExpiresAt, + Revoked: meta.Revoked, + LastExtendedAt: meta.LastExtendedAt, } // Cache the token @@ -189,49 +197,170 @@ func (s *PassStore) GetToken(ctx context.Context, subject string) (string, *doma return jwt, token, nil } +// treeEntry holds a single parsed entry from pass ls tree output. +type treeEntry struct { + depth int + name string +} + +// parsePassLsEntries parses raw pass ls output into a slice of tree entries. +// Each entry carries its nesting depth (0 = direct child of the listed root) +// and the bare file/directory name on that line. +func parsePassLsEntries(output string) []treeEntry { + var entries []treeEntry + + lines := strings.Split(output, "\n") + // Skip the first line (the directory header, e.g. "gordon/registry/tokens"). + for _, line := range lines[1:] { + // Strip ANSI escape sequences (pass may output coloured text on a TTY). + line = ansiRegex.ReplaceAllString(line, "") + + if strings.TrimSpace(line) == "" { + continue + } + + // Locate the branch marker. pass uses Unicode box-drawing characters: + // ├── or └── (U+251C / U+2514, then two U+2500 dashes, then a space) + // Some terminals may fall back to ASCII equivalents: + // |-- or `-- (1-byte ASCII chars) + // + // Regardless of Unicode vs ASCII, each indent level is exactly 4 characters + // wide (e.g. "│ " = pipe + 3 spaces, " " = 4 spaces, "| " = pipe + 3). + // We count runes before the branch character and divide by 4 to get depth. + + var nameOffset int // byte offset where the entry name begins + var branchRune int // rune count before the branch character (├ └ | `) + + if idx := strings.Index(line, "── "); idx != -1 { + // Unicode variant: the two '─' dashes (U+2500) start at byte idx. + // The branch character (├ or └, 3 bytes each) immediately precedes them. + branchByteIdx := idx - 3 + if branchByteIdx < 0 { + // Marker is too close to the start — malformed line, skip. + continue + } + // Count runes in the indentation prefix before the branch char. + branchRune = utf8.RuneCountInString(line[:branchByteIdx]) + // Name starts after: branchChar(3B) + ─(3B) + ─(3B) + SP(1B) = 10B from branch. + nameOffset = branchByteIdx + 10 + } else if idx := strings.Index(line, "-- "); idx != -1 { + // ASCII variant: branch char (| or `, 1 byte) is immediately before the dashes. + branchByteIdx := idx - 1 + if branchByteIdx < 0 { + // Marker is too close to the start — malformed line, skip. + continue + } + branchRune = utf8.RuneCountInString(line[:branchByteIdx]) + // Name starts after: branchChar(1B) + -(1B) + -(1B) + SP(1B) = 4B from branch. + nameOffset = branchByteIdx + 4 + } else { + // No recognisable tree marker on this line — skip it. + continue + } + + if nameOffset > len(line) { + continue + } + name := strings.TrimSpace(line[nameOffset:]) + if name == "" { + continue + } + + // Each indent level is 4 characters wide. + depth := branchRune / 4 + entries = append(entries, treeEntry{depth: depth, name: name}) + } + + return entries +} + +// parsePassLsOutput parses the tree-formatted output of `pass ls` and returns +// a flat list of full subject paths. It reconstructs slash-separated paths for +// nested entries (e.g. "team/alice") by tracking the indentation depth. +// Only leaf nodes (actual pass entries, not intermediate directories) are returned. +// +// pass ls outputs a tree like: +// +// gordon/registry/tokens +// ├── admin +// └── team +// ├── alice +// └── bob +// +// This function returns ["admin", "team/alice", "team/bob"]. +func parsePassLsOutput(output string) []string { + type frame struct { + depth int + name string + } + + entries := parsePassLsEntries(output) + + var subjects []string + var stack []frame + + for i, entry := range entries { + // Pop ancestors that are at the same depth or deeper than the current entry. + for len(stack) > 0 && stack[len(stack)-1].depth >= entry.depth { + stack = stack[:len(stack)-1] + } + + // Build the full path for this entry by joining ancestors with this name. + parts := make([]string, 0, len(stack)+1) + for _, f := range stack { + parts = append(parts, f.name) + } + parts = append(parts, entry.name) + subject := strings.Join(parts, "/") + + // Determine whether this entry is an intermediate directory node. + // A node is a directory if the immediately following entry is deeper. + isDir := i+1 < len(entries) && entries[i+1].depth > entry.depth + + // Always push so deeper entries can use this as ancestor context. + stack = append(stack, frame(entry)) + + if !isDir { + subjects = append(subjects, subject) + } + } + + if subjects == nil { + return []string{} + } + return subjects +} + // ListTokens returns all stored tokens from pass. func (s *PassStore) ListTokens(ctx context.Context) ([]domain.Token, error) { ctx, cancel := context.WithTimeout(ctx, s.timeout) defer cancel() // List all entries under the token path - cmd := exec.CommandContext(ctx, "pass", "ls", passTokenPath) + cmd := exec.CommandContext(ctx, "pass", "ls", passTokenPath) //nolint:gosec // passTokenPath is a fixed constant output, err := cmd.Output() if err != nil { // If the path doesn't exist, return empty list return []domain.Token{}, nil } - var tokens []domain.Token - lines := strings.Split(string(output), "\n") - - for _, line := range lines { - line = strings.TrimSpace(line) + subjects := parsePassLsOutput(string(output)) - // Strip ANSI escape sequences (pass outputs colored text on TTY) - line = ansiRegex.ReplaceAllString(line, "") - - // Skip tree formatting characters (both Unicode and ASCII variants) - // Unicode: ├── └── │ - line = strings.TrimPrefix(line, "├── ") - line = strings.TrimPrefix(line, "└── ") - line = strings.TrimPrefix(line, "│ ") - // ASCII: |-- `-- | - line = strings.TrimPrefix(line, "|-- ") - line = strings.TrimPrefix(line, "`-- ") - line = strings.TrimPrefix(line, "| ") - - line = strings.TrimSpace(line) - - // Skip .meta files, empty lines, and the header line - if line == "" || strings.HasSuffix(line, ".meta") || line == passTokenPath { + var tokens []domain.Token + for _, subject := range subjects { + // Skip .meta files — they are metadata companions, not tokens themselves. + if strings.HasSuffix(subject, ".meta") { continue } - // Try to get the token metadata - _, token, err := s.GetToken(ctx, line) + // Try to get the token metadata. + // Use a fresh per-call context so each GetToken gets the full timeout, + // not the shared (potentially exhausted) deadline from the outer ctx. + callCtx, callCancel := context.WithTimeout(context.Background(), s.timeout) + _, token, err := s.GetToken(callCtx, subject) + callCancel() if err != nil { - s.log.Warn().Err(err).Str("subject", line).Msg("failed to get token") + s.log.Warn().Err(err).Str("subject", subject).Msg("failed to get token") continue } @@ -321,6 +450,25 @@ func (s *PassStore) IsRevoked(ctx context.Context, tokenID string) (bool, error) return revoked, nil } +// UpdateTokenExpiry updates the JWT and expiry metadata for an existing token. +// LastExtendedAt is also updated to track debounce timing. +// UpdateTokenExpiry enforces update-only semantics: it returns an error if +// token is nil or if no existing record is found for token.Subject. +func (s *PassStore) UpdateTokenExpiry(ctx context.Context, token *domain.Token, newJWT string) error { + if token == nil { + return fmt.Errorf("UpdateTokenExpiry: token must not be nil") + } + // Confirm the token exists before overwriting to avoid silent creation. + _, existing, err := s.GetToken(ctx, token.Subject) + if err != nil { + return fmt.Errorf("UpdateTokenExpiry: token not found for subject %q: %w", token.Subject, err) + } + if existing == nil { + return fmt.Errorf("UpdateTokenExpiry: no existing token for subject %q", token.Subject) + } + return s.SaveToken(ctx, token, newJWT) +} + // DeleteToken removes token from pass. func (s *PassStore) DeleteToken(ctx context.Context, subject string) error { if err := validateSubject(subject); err != nil { @@ -363,7 +511,7 @@ func (s *PassStore) DeleteToken(ctx context.Context, subject string) error { // passInsert inserts a value into pass. func (s *PassStore) passInsert(ctx context.Context, path, value string) error { - cmd := exec.CommandContext(ctx, "pass", "insert", "-m", "-f", path) + cmd := exec.CommandContext(ctx, "pass", "insert", "-m", "-f", path) //nolint:gosec // binary is constant ("pass"); path arguments are sanitized token subjects cmd.Stdin = strings.NewReader(value) output, err := cmd.CombinedOutput() if err != nil { @@ -374,7 +522,7 @@ func (s *PassStore) passInsert(ctx context.Context, path, value string) error { // passShow retrieves a value from pass. func (s *PassStore) passShow(ctx context.Context, path string) (string, error) { - cmd := exec.CommandContext(ctx, "pass", "show", path) + cmd := exec.CommandContext(ctx, "pass", "show", path) //nolint:gosec // binary is constant ("pass"); path arguments are sanitized token subjects output, err := cmd.Output() if err != nil { return "", fmt.Errorf("pass show failed: %w", err) diff --git a/internal/adapters/out/tokenstore/pass_test.go b/internal/adapters/out/tokenstore/pass_test.go index 2309e826..2c8e1baf 100644 --- a/internal/adapters/out/tokenstore/pass_test.go +++ b/internal/adapters/out/tokenstore/pass_test.go @@ -12,6 +12,74 @@ import ( "github.com/bnema/gordon/internal/domain" ) +// TestParsePassLsOutput verifies that pass ls tree output is parsed correctly, +// including subjects that contain "/" (multi-level paths). +func TestParsePassLsOutput(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "flat entries only", + input: `gordon/registry/tokens +├── admin +└── bob +`, + want: []string{"admin", "bob"}, + }, + { + name: "nested entries with slash in subject", + input: `gordon/registry/tokens +├── admin +└── team + ├── alice + └── bob +`, + want: []string{"admin", "team/alice", "team/bob"}, + }, + { + name: "deeply nested entries", + input: `gordon/registry/tokens +└── org + └── team + └── alice +`, + want: []string{"org/team/alice"}, + }, + { + name: "mixed flat and nested", + input: `gordon/registry/tokens +├── admin +├── team +│ ├── alice +│ └── bob +└── standalone +`, + want: []string{"admin", "team/alice", "team/bob", "standalone"}, + }, + { + name: "empty output (no entries)", + input: "gordon/registry/tokens\n", + want: []string{}, + }, + { + name: "ascii tree decoration variant", + input: `gordon/registry/tokens +|-- admin +` + "`-- team\n `-- alice\n", + want: []string{"admin", "team/alice"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parsePassLsOutput(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + // These tests verify the in-memory caching behavior of PassStore. // They directly manipulate the cache to test cache hit/miss logic // without requiring the actual pass binary. diff --git a/internal/adapters/out/tokenstore/unsafe.go b/internal/adapters/out/tokenstore/unsafe.go index 9908d4a2..7e0e0610 100644 --- a/internal/adapters/out/tokenstore/unsafe.go +++ b/internal/adapters/out/tokenstore/unsafe.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/bnema/zerowrap" @@ -32,6 +33,7 @@ const ( // UnsafeStore implements TokenStore using plain text files. // WARNING: This store does not encrypt secrets. Only use when pass/sops are unavailable. type UnsafeStore struct { + mu sync.RWMutex dataDir string log zerowrap.Logger } @@ -73,12 +75,13 @@ func (s *UnsafeStore) SaveToken(_ context.Context, token *domain.Token, jwt stri data := unsafeTokenData{ JWT: jwt, Metadata: tokenMetadata{ - ID: token.ID, - Subject: token.Subject, // Store original subject in metadata - Scopes: token.Scopes, - IssuedAt: token.IssuedAt, - ExpiresAt: token.ExpiresAt, - Revoked: token.Revoked, + ID: token.ID, + Subject: token.Subject, // Store original subject in metadata + Scopes: token.Scopes, + IssuedAt: token.IssuedAt, + ExpiresAt: token.ExpiresAt, + Revoked: token.Revoked, + LastExtendedAt: token.LastExtendedAt, }, } @@ -124,12 +127,13 @@ func (s *UnsafeStore) GetToken(_ context.Context, subject string) (string, *doma } token := &domain.Token{ - ID: data.Metadata.ID, - Subject: data.Metadata.Subject, - Scopes: data.Metadata.Scopes, - IssuedAt: data.Metadata.IssuedAt, - ExpiresAt: data.Metadata.ExpiresAt, - Revoked: data.Metadata.Revoked, + ID: data.Metadata.ID, + Subject: data.Metadata.Subject, + Scopes: data.Metadata.Scopes, + IssuedAt: data.Metadata.IssuedAt, + ExpiresAt: data.Metadata.ExpiresAt, + Revoked: data.Metadata.Revoked, + LastExtendedAt: data.Metadata.LastExtendedAt, } return data.JWT, token, nil @@ -168,12 +172,13 @@ func (s *UnsafeStore) ListTokens(_ context.Context) ([]domain.Token, error) { } token := domain.Token{ - ID: data.Metadata.ID, - Subject: data.Metadata.Subject, - Scopes: data.Metadata.Scopes, - IssuedAt: data.Metadata.IssuedAt, - ExpiresAt: data.Metadata.ExpiresAt, - Revoked: data.Metadata.Revoked, + ID: data.Metadata.ID, + Subject: data.Metadata.Subject, + Scopes: data.Metadata.Scopes, + IssuedAt: data.Metadata.IssuedAt, + ExpiresAt: data.Metadata.ExpiresAt, + Revoked: data.Metadata.Revoked, + LastExtendedAt: data.Metadata.LastExtendedAt, } tokens = append(tokens, token) @@ -184,6 +189,9 @@ func (s *UnsafeStore) ListTokens(_ context.Context) ([]domain.Token, error) { // Revoke adds token ID to revocation list file. func (s *UnsafeStore) Revoke(_ context.Context, tokenID string) error { + s.mu.Lock() + defer s.mu.Unlock() + revokedFile := filepath.Join(s.dataDir, unsafeRevokedFile) // Ensure parent directory exists @@ -226,6 +234,9 @@ func (s *UnsafeStore) Revoke(_ context.Context, tokenID string) error { // IsRevoked checks if token ID is in revocation list. func (s *UnsafeStore) IsRevoked(_ context.Context, tokenID string) (bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + revokedList, err := s.getRevokedList() if err != nil { return false, err @@ -240,6 +251,12 @@ func (s *UnsafeStore) IsRevoked(_ context.Context, tokenID string) (bool, error) return false, nil } +// UpdateTokenExpiry updates the JWT and expiry metadata for an existing token. +// LastExtendedAt is also updated to track debounce timing. +func (s *UnsafeStore) UpdateTokenExpiry(ctx context.Context, token *domain.Token, newJWT string) error { + return s.SaveToken(ctx, token, newJWT) +} + // DeleteToken removes token file. func (s *UnsafeStore) DeleteToken(_ context.Context, subject string) error { // SECURITY: Use sanitized filename to prevent path traversal diff --git a/internal/adapters/out/tokenstore/unsafe_test.go b/internal/adapters/out/tokenstore/unsafe_test.go new file mode 100644 index 00000000..b0206843 --- /dev/null +++ b/internal/adapters/out/tokenstore/unsafe_test.go @@ -0,0 +1,78 @@ +package tokenstore + +import ( + "context" + "fmt" + "io" + "sync" + "testing" + + "github.com/bnema/zerowrap" + + "github.com/bnema/gordon/internal/domain" +) + +// newTestUnsafeStore creates a UnsafeStore backed by a temp directory. +func newTestUnsafeStore(t *testing.T) *UnsafeStore { + t.Helper() + dir := t.TempDir() + log := zerowrap.New(zerowrap.Config{Level: "disabled", Output: io.Discard}) + store, err := NewUnsafeStore(dir, log) + if err != nil { + t.Fatalf("NewUnsafeStore: %v", err) + } + return store +} + +func TestUnsafeStoreRevokeConcurrentNoLoss(t *testing.T) { + // Run with: go test -race ./internal/adapters/out/tokenstore/... -run TestUnsafeStoreRevokeConcurrentNoLoss -v + store := newTestUnsafeStore(t) + + ctx := context.Background() + const n = 10 + + // Save n tokens + tokenIDs := make([]string, n) + for i := 0; i < n; i++ { + id := fmt.Sprintf("token-%d", i) + tokenIDs[i] = id + tok := &domain.Token{ + ID: id, + Subject: fmt.Sprintf("user-%d", i), + Scopes: []string{"admin"}, + } + if err := store.SaveToken(ctx, tok, "jwt-"+id); err != nil { + t.Fatal(err) + } + } + + // Revoke all concurrently; collect errors via a buffered channel. + var wg sync.WaitGroup + errs := make(chan error, len(tokenIDs)) + for _, id := range tokenIDs { + wg.Add(1) + go func(id string) { + defer wg.Done() + if err := store.Revoke(ctx, id); err != nil { + errs <- fmt.Errorf("Revoke(%q): %w", id, err) + } + }(id) + } + wg.Wait() + close(errs) + for err := range errs { + t.Error(err) + } + + // Verify all n token IDs are in the revoked list + for _, id := range tokenIDs { + revoked, err := store.IsRevoked(ctx, id) + if err != nil { + t.Errorf("IsRevoked(%q): %v", id, err) + continue + } + if !revoked { + t.Errorf("token %q should be revoked but IsRevoked returned false", id) + } + } +} diff --git a/internal/app/run.go b/internal/app/run.go index dbacead0..f5569c98 100644 --- a/internal/app/run.go +++ b/internal/app/run.go @@ -819,19 +819,73 @@ func cleanupInternalCredentials() { _ = os.Remove(getInternalCredentialsFile()) } -// GetInternalCredentials reads the internal registry credentials from file. -// This is used by the CLI to display credentials for manual recovery. -func GetInternalCredentials() (*InternalCredentials, error) { - credFile := getInternalCredentialsFile() - data, err := os.ReadFile(credFile) - if err != nil { - return nil, fmt.Errorf("failed to read credentials file (is Gordon running?): %w", err) +// getInternalCredentialsCandidates returns candidate file paths in priority order: +// 1. XDG_RUNTIME_DIR/gordon/ (set by systemd for the daemon) +// 2. /run/user//gordon/ (well-known systemd default, for CLI in shells without XDG_RUNTIME_DIR) +// 3. ~/.gordon/run/ (fallback for non-systemd environments) +// 4. os.TempDir() (last resort, matches getInternalCredentialsFile fallback path) +func getInternalCredentialsCandidates() []string { + var candidates []string + + // 1. XDG_RUNTIME_DIR (set in daemon's environment) + if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { + candidates = append(candidates, filepath.Join(runtimeDir, "gordon", "internal-creds.json")) + } + + // 2. /run/user//gordon/ (systemd default, may not be in CLI's env) + uid := os.Getuid() + sysRuntime := filepath.Join("/run/user", fmt.Sprintf("%d", uid), "gordon", "internal-creds.json") + // Avoid duplicate if XDG_RUNTIME_DIR already points here + if len(candidates) == 0 || candidates[0] != sysRuntime { + candidates = append(candidates, sysRuntime) + } + + // 3. ~/.gordon/run/ fallback + if homeDir, err := os.UserHomeDir(); err == nil { + candidates = append(candidates, filepath.Join(homeDir, ".gordon", "run", "internal-creds.json")) + } + + // 4. os.TempDir() last resort — matches the fallback path in getInternalCredentialsFile, + // ensuring GetInternalCredentials can find credentials even when getSecureRuntimeDir fails. + candidates = append(candidates, filepath.Join(os.TempDir(), "gordon-internal-creds.json")) + + return candidates +} + +// GetInternalCredentialsFromCandidates reads credentials from the first candidate file that exists. +// Exported for testing. +func GetInternalCredentialsFromCandidates(candidates []string) (*InternalCredentials, error) { + var lastErr error + for _, path := range candidates { + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + continue + } + if err != nil { + // Non-permission errors (e.g. EACCES) may be transient or path-specific; + // record and try the next candidate rather than failing immediately. + lastErr = fmt.Errorf("failed to read credentials file %s: %w", path, err) + continue + } + var creds InternalCredentials + if err := json.Unmarshal(data, &creds); err != nil { + // Corrupt file — record and fall through to lower-priority candidates. + lastErr = fmt.Errorf("failed to parse credentials at %s: %w", path, err) + continue + } + return &creds, nil } - var creds InternalCredentials - if err := json.Unmarshal(data, &creds); err != nil { - return nil, fmt.Errorf("failed to parse credentials: %w", err) + if lastErr != nil { + return nil, lastErr } - return &creds, nil + return nil, fmt.Errorf("no credentials file found (is Gordon running?): checked %v", candidates) +} + +// GetInternalCredentials reads the internal registry credentials from file. +// Probes all candidate runtime directories so CLI works regardless of whether +// XDG_RUNTIME_DIR is set in the current shell environment. +func GetInternalCredentials() (*InternalCredentials, error) { + return GetInternalCredentialsFromCandidates(getInternalCredentialsCandidates()) } // createAuthService creates the authentication service and token store. @@ -1645,12 +1699,13 @@ func runServers(ctx context.Context, v *viper.Viper, cfg Config, registryHandler registrySrv, registryReady := startServer(fmt.Sprintf(":%d", cfg.Server.RegistryPort), registryHandler, "registry", errChan, log) proxySrv, proxyReady := startServer(fmt.Sprintf(":%d", cfg.Server.Port), proxyHandler, "proxy", errChan, log) + var tlsSrv *http.Server var tlsReady <-chan struct{} if cfg.Server.TLSEnabled { if cfg.Server.TLSCertFile == "" || cfg.Server.TLSKeyFile == "" { return fmt.Errorf("server.tls_enabled=true requires both server.tls_cert_file and server.tls_key_file") } - tlsReady = startTLSServer( + tlsSrv, tlsReady = startTLSServer( fmt.Sprintf(":%d", cfg.Server.TLSPort), proxyHandler, "proxy-tls", @@ -1695,7 +1750,7 @@ func runServers(ctx context.Context, v *viper.Viper, cfg Config, registryHandler syncAndAutoStart(ctx, svc, log) waitForShutdown(ctx, errChan, reloadChan, deployChan, eventBus, log) - gracefulShutdown(registrySrv, proxySrv, containerSvc, log) + gracefulShutdown(registrySrv, proxySrv, tlsSrv, containerSvc, log) return nil } @@ -1895,17 +1950,19 @@ func waitForShutdown(ctx context.Context, errChan <-chan error, reloadChan, depl // gracefulShutdown stops HTTP servers with a 30s timeout, then shuts down // the container service and cleans up runtime files. -func gracefulShutdown(registrySrv, proxySrv *http.Server, containerSvc *container.Service, log zerowrap.Logger) { +func gracefulShutdown(registrySrv, proxySrv, tlsSrv *http.Server, containerSvc *container.Service, log zerowrap.Logger) { log.Info().Msg("shutting down Gordon...") shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) defer shutdownCancel() - if err := registrySrv.Shutdown(shutdownCtx); err != nil { - log.Warn().Err(err).Msg("registry server shutdown error") - } - if err := proxySrv.Shutdown(shutdownCtx); err != nil { - log.Warn().Err(err).Msg("proxy server shutdown error") + for _, srv := range []*http.Server{registrySrv, proxySrv, tlsSrv} { + if srv == nil { + continue + } + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Warn().Err(err).Str("addr", srv.Addr).Msg("server shutdown error") + } } containerSvc.StopMonitor() @@ -1955,9 +2012,20 @@ func startServer(addr string, handler http.Handler, name string, errChan chan<- } // startTLSServer starts an HTTPS server with the provided certificate and key. -func startTLSServer(addr string, handler http.Handler, name, certFile, keyFile string, errChan chan<- error, log zerowrap.Logger) <-chan struct{} { +// It returns the server instance (for graceful shutdown) and a channel that closes once the port is bound. +func startTLSServer(addr string, handler http.Handler, name, certFile, keyFile string, errChan chan<- error, log zerowrap.Logger) (*http.Server, <-chan struct{}) { ready := make(chan struct{}) + server := &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 5 * time.Minute, + WriteTimeout: 5 * time.Minute, + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, + } + go func() { log.Info(). Str("address", addr). @@ -1965,16 +2033,6 @@ func startTLSServer(addr string, handler http.Handler, name, certFile, keyFile s Str("key_file", keyFile). Msgf("%s server starting", name) - server := &http.Server{ - Addr: addr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - ReadTimeout: 5 * time.Minute, - WriteTimeout: 5 * time.Minute, - IdleTimeout: 120 * time.Second, - MaxHeaderBytes: 1 << 20, - } - ln, err := net.Listen("tcp", addr) if err != nil { errChan <- fmt.Errorf("%s server error: %w", name, err) @@ -1987,7 +2045,7 @@ func startTLSServer(addr string, handler http.Handler, name, certFile, keyFile s } }() - return ready + return server, ready } // SendReloadSignal sends SIGUSR1 to the running Gordon process. diff --git a/internal/app/run_internal_creds_test.go b/internal/app/run_internal_creds_test.go new file mode 100644 index 00000000..7ecbbea8 --- /dev/null +++ b/internal/app/run_internal_creds_test.go @@ -0,0 +1,78 @@ +package app + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGetInternalCredentialsFindsRuntimeFile(t *testing.T) { + // Simulate daemon writing to XDG_RUNTIME_DIR/gordon/ while + // CLI has no XDG_RUNTIME_DIR set. + tmpRuntime := t.TempDir() + gordonRuntime := filepath.Join(tmpRuntime, "gordon") + if err := os.MkdirAll(gordonRuntime, 0700); err != nil { + t.Fatal(err) + } + + // Write "live" creds to the runtime dir (as daemon would) + liveCreds := `{"username":"gordon-internal","password":"livepassword"}` + if err := os.WriteFile(filepath.Join(gordonRuntime, "internal-creds.json"), []byte(liveCreds), 0600); err != nil { + t.Fatal(err) + } + + // Also write a stale creds file to home fallback + tmpHome := t.TempDir() + gordonHome := filepath.Join(tmpHome, ".gordon", "run") + if err := os.MkdirAll(gordonHome, 0700); err != nil { + t.Fatal(err) + } + staleCreds := `{"username":"gordon-internal","password":"stalepassword"}` + if err := os.WriteFile(filepath.Join(gordonHome, "internal-creds.json"), []byte(staleCreds), 0600); err != nil { + t.Fatal(err) + } + + creds, err := GetInternalCredentialsFromCandidates([]string{ + filepath.Join(gordonRuntime, "internal-creds.json"), + filepath.Join(gordonHome, "internal-creds.json"), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if creds.Password != "livepassword" { + t.Errorf("got password %q, want %q", creds.Password, "livepassword") + } +} + +func TestGetInternalCredentialsFallsBackToStale(t *testing.T) { + // If only the stale/home file exists, it should still work + tmpHome := t.TempDir() + gordonHome := filepath.Join(tmpHome, ".gordon", "run") + if err := os.MkdirAll(gordonHome, 0700); err != nil { + t.Fatal(err) + } + staleCreds := `{"username":"gordon-internal","password":"stalepassword"}` + if err := os.WriteFile(filepath.Join(gordonHome, "internal-creds.json"), []byte(staleCreds), 0600); err != nil { + t.Fatal(err) + } + + creds, err := GetInternalCredentialsFromCandidates([]string{ + filepath.Join(t.TempDir(), "gordon", "internal-creds.json"), // doesn't exist + filepath.Join(gordonHome, "internal-creds.json"), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if creds.Password != "stalepassword" { + t.Errorf("got password %q, want %q", creds.Password, "stalepassword") + } +} + +func TestGetInternalCredentialsNoCandidatesReturnsError(t *testing.T) { + _, err := GetInternalCredentialsFromCandidates([]string{ + filepath.Join(t.TempDir(), "nonexistent", "internal-creds.json"), + }) + if err == nil { + t.Fatal("expected error when no credentials found, got nil") + } +} diff --git a/internal/boundaries/in/auth.go b/internal/boundaries/in/auth.go index dc34d4ff..affd95e0 100644 --- a/internal/boundaries/in/auth.go +++ b/internal/boundaries/in/auth.go @@ -46,4 +46,9 @@ type AuthService interface { // GetAuthStatus returns authentication status from context. // Extracts and validates token claims that were set by auth middleware. GetAuthStatus(ctx context.Context) (*domain.AuthStatus, error) + + // ExtendToken re-issues the token with expiry slid forward by 24h. + // Returns the same token string if it was already extended within the debounce window (1h). + // Skips extension for ephemeral access tokens (≤5min) and service tokens. + ExtendToken(ctx context.Context, tokenString string) (string, error) } diff --git a/internal/boundaries/in/container.go b/internal/boundaries/in/container.go index ec6c62cf..53bf6fc9 100644 --- a/internal/boundaries/in/container.go +++ b/internal/boundaries/in/container.go @@ -45,6 +45,10 @@ type ContainerService interface { // SyncContainers synchronizes containers with configured routes. SyncContainers(ctx context.Context) error + // UpdateAttachments updates the attachment configuration in the container service. + // This is called after a config reload to propagate attachment changes without restart. + UpdateAttachments(attachments map[string][]string) + // AutoStart starts containers for the provided routes that aren't running. AutoStart(ctx context.Context, routes []domain.Route) error diff --git a/internal/boundaries/in/mocks/mock_auth_service.go b/internal/boundaries/in/mocks/mock_auth_service.go index d75d12ec..e845f234 100644 --- a/internal/boundaries/in/mocks/mock_auth_service.go +++ b/internal/boundaries/in/mocks/mock_auth_service.go @@ -39,6 +39,72 @@ func (_m *MockAuthService) EXPECT() *MockAuthService_Expecter { return &MockAuthService_Expecter{mock: &_m.Mock} } +// ExtendToken provides a mock function for the type MockAuthService +func (_mock *MockAuthService) ExtendToken(ctx context.Context, tokenString string) (string, error) { + ret := _mock.Called(ctx, tokenString) + + if len(ret) == 0 { + panic("no return value specified for ExtendToken") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return returnFunc(ctx, tokenString) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = returnFunc(ctx, tokenString) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, tokenString) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockAuthService_ExtendToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExtendToken' +type MockAuthService_ExtendToken_Call struct { + *mock.Call +} + +// ExtendToken is a helper method to define mock.On call +// - ctx context.Context +// - tokenString string +func (_e *MockAuthService_Expecter) ExtendToken(ctx interface{}, tokenString interface{}) *MockAuthService_ExtendToken_Call { + return &MockAuthService_ExtendToken_Call{Call: _e.mock.On("ExtendToken", ctx, tokenString)} +} + +func (_c *MockAuthService_ExtendToken_Call) Run(run func(ctx context.Context, tokenString string)) *MockAuthService_ExtendToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthService_ExtendToken_Call) Return(s string, err error) *MockAuthService_ExtendToken_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *MockAuthService_ExtendToken_Call) RunAndReturn(run func(ctx context.Context, tokenString string) (string, error)) *MockAuthService_ExtendToken_Call { + _c.Call.Return(run) + return _c +} + // GenerateAccessToken provides a mock function for the type MockAuthService func (_mock *MockAuthService) GenerateAccessToken(ctx context.Context, subject string, scopes []string, expiry time.Duration) (string, error) { ret := _mock.Called(ctx, subject, scopes, expiry) diff --git a/internal/boundaries/in/mocks/mock_container_service.go b/internal/boundaries/in/mocks/mock_container_service.go index 9879312d..cd296b54 100644 --- a/internal/boundaries/in/mocks/mock_container_service.go +++ b/internal/boundaries/in/mocks/mock_container_service.go @@ -795,3 +795,43 @@ func (_c *MockContainerService_SyncContainers_Call) RunAndReturn(run func(ctx co _c.Call.Return(run) return _c } + +// UpdateAttachments provides a mock function for the type MockContainerService +func (_mock *MockContainerService) UpdateAttachments(attachments map[string][]string) { + _mock.Called(attachments) + return +} + +// MockContainerService_UpdateAttachments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAttachments' +type MockContainerService_UpdateAttachments_Call struct { + *mock.Call +} + +// UpdateAttachments is a helper method to define mock.On call +// - attachments map[string][]string +func (_e *MockContainerService_Expecter) UpdateAttachments(attachments interface{}) *MockContainerService_UpdateAttachments_Call { + return &MockContainerService_UpdateAttachments_Call{Call: _e.mock.On("UpdateAttachments", attachments)} +} + +func (_c *MockContainerService_UpdateAttachments_Call) Run(run func(attachments map[string][]string)) *MockContainerService_UpdateAttachments_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 map[string][]string + if args[0] != nil { + arg0 = args[0].(map[string][]string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockContainerService_UpdateAttachments_Call) Return() *MockContainerService_UpdateAttachments_Call { + _c.Call.Return() + return _c +} + +func (_c *MockContainerService_UpdateAttachments_Call) RunAndReturn(run func(attachments map[string][]string)) *MockContainerService_UpdateAttachments_Call { + _c.Run(run) + return _c +} diff --git a/internal/boundaries/out/mocks/mock_token_store.go b/internal/boundaries/out/mocks/mock_token_store.go index 4a7d97a3..38478d73 100644 --- a/internal/boundaries/out/mocks/mock_token_store.go +++ b/internal/boundaries/out/mocks/mock_token_store.go @@ -416,3 +416,66 @@ func (_c *MockTokenStore_SaveToken_Call) RunAndReturn(run func(ctx context.Conte _c.Call.Return(run) return _c } + +// UpdateTokenExpiry provides a mock function for the type MockTokenStore +func (_mock *MockTokenStore) UpdateTokenExpiry(ctx context.Context, token *domain.Token, newJWT string) error { + ret := _mock.Called(ctx, token, newJWT) + + if len(ret) == 0 { + panic("no return value specified for UpdateTokenExpiry") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Token, string) error); ok { + r0 = returnFunc(ctx, token, newJWT) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockTokenStore_UpdateTokenExpiry_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTokenExpiry' +type MockTokenStore_UpdateTokenExpiry_Call struct { + *mock.Call +} + +// UpdateTokenExpiry is a helper method to define mock.On call +// - ctx context.Context +// - token *domain.Token +// - newJWT string +func (_e *MockTokenStore_Expecter) UpdateTokenExpiry(ctx interface{}, token interface{}, newJWT interface{}) *MockTokenStore_UpdateTokenExpiry_Call { + return &MockTokenStore_UpdateTokenExpiry_Call{Call: _e.mock.On("UpdateTokenExpiry", ctx, token, newJWT)} +} + +func (_c *MockTokenStore_UpdateTokenExpiry_Call) Run(run func(ctx context.Context, token *domain.Token, newJWT string)) *MockTokenStore_UpdateTokenExpiry_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *domain.Token + if args[1] != nil { + arg1 = args[1].(*domain.Token) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockTokenStore_UpdateTokenExpiry_Call) Return(err error) *MockTokenStore_UpdateTokenExpiry_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockTokenStore_UpdateTokenExpiry_Call) RunAndReturn(run func(ctx context.Context, token *domain.Token, newJWT string) error) *MockTokenStore_UpdateTokenExpiry_Call { + _c.Call.Return(run) + return _c +} diff --git a/internal/boundaries/out/token_store.go b/internal/boundaries/out/token_store.go index 4b363ae8..d7a55fa1 100644 --- a/internal/boundaries/out/token_store.go +++ b/internal/boundaries/out/token_store.go @@ -26,4 +26,8 @@ type TokenStore interface { // DeleteToken removes token from store. DeleteToken(ctx context.Context, subject string) error + + // UpdateTokenExpiry updates the JWT and expiry/LastExtendedAt metadata for an existing token. + // Used by token sliding expiry to re-sign tokens without changing the JTI. + UpdateTokenExpiry(ctx context.Context, token *domain.Token, newJWT string) error } diff --git a/internal/domain/auth.go b/internal/domain/auth.go index be068166..6be17c86 100644 --- a/internal/domain/auth.go +++ b/internal/domain/auth.go @@ -20,12 +20,13 @@ const ( // Token represents a generated authentication token stored in the secrets backend. type Token struct { - ID string - Subject string - Scopes []string - IssuedAt time.Time - ExpiresAt time.Time // Zero value means never expires - Revoked bool + ID string + Subject string + Scopes []string + IssuedAt time.Time + ExpiresAt time.Time // Zero value means never expires + Revoked bool + LastExtendedAt time.Time // Zero value means never extended } // TokenClaims represents the JWT claims for a token. diff --git a/internal/domain/labels_test.go b/internal/domain/labels_test.go new file mode 100644 index 00000000..6d0c9c5e --- /dev/null +++ b/internal/domain/labels_test.go @@ -0,0 +1,31 @@ +package domain_test + +import ( + "testing" + + "github.com/bnema/gordon/internal/domain" +) + +// TestLabelConstantsValues guards against accidental value changes that +// would silently break container discovery. +func TestLabelConstantsValues(t *testing.T) { + tests := []struct { + constant string + expected string + }{ + {domain.LabelDomain, "gordon.domain"}, + {domain.LabelImage, "gordon.image"}, + {domain.LabelManaged, "gordon.managed"}, + {domain.LabelRoute, "gordon.route"}, + {domain.LabelAttachment, "gordon.attachment"}, + {domain.LabelAttachedTo, "gordon.attached-to"}, + {domain.LabelCreated, "gordon.created"}, + } + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + if tt.constant != tt.expected { + t.Errorf("constant value changed: got %q, want %q", tt.constant, tt.expected) + } + }) + } +} diff --git a/internal/usecase/auth/service.go b/internal/usecase/auth/service.go index c14a0678..63670116 100644 --- a/internal/usecase/auth/service.go +++ b/internal/usecase/auth/service.go @@ -27,6 +27,13 @@ const ( MaxAccessTokenLifetime = 5 * time.Minute // maxAccessTokenLifetimeSecs is MaxAccessTokenLifetime in seconds for JWT comparisons. maxAccessTokenLifetimeSecs = int64(MaxAccessTokenLifetime / time.Second) + // tokenExtensionTTL is the amount of time added to a token's expiry when extended. + tokenExtensionTTL = 24 * time.Hour + // tokenExtensionDebounce is the minimum time between token extensions. + tokenExtensionDebounce = time.Hour + // serviceTokenSubject is the subject used for internal service tokens. + // Service tokens are not extended to avoid churn on the token store. + serviceTokenSubject = "gordon-service" ) // Config holds the authentication configuration. @@ -161,7 +168,10 @@ func (s *Service) parseTokenClaims(tokenString string) (jwt.MapClaims, error) { return nil, domain.ErrInvalidToken } return s.config.TokenSecret, nil - }, jwt.WithIssuer(TokenIssuer)) // SECURITY: Enforce issuer validation + }, + jwt.WithIssuer(TokenIssuer), // SECURITY: Enforce issuer validation + jwt.WithIssuedAt(), // SECURITY: Reject tokens with iat in the future + ) if err != nil { return nil, domain.ErrInvalidToken } @@ -258,6 +268,10 @@ func (s *Service) isEphemeralAccessToken(claims *domain.TokenClaims) bool { // stolen secrets from creating arbitrary short-lived tokens now := time.Now().UTC().Unix() tokenAge := now - claims.IssuedAt + if tokenAge < 0 { + // iat is in the future — do not classify as ephemeral (defense-in-depth: prevents bypass) + return false + } if tokenAge > maxAccessTokenLifetimeSecs { return false // Old tokens require store validation } @@ -433,6 +447,119 @@ func (s *Service) ListTokens(ctx context.Context) ([]domain.Token, error) { return tokens, nil } +// ExtendToken re-issues the token with expiry slid forward by 24h. +// Returns the same token string if it was already extended within the debounce window (1h). +// Skips extension for ephemeral access tokens (≤5min) and service tokens. +func (s *Service) ExtendToken(ctx context.Context, tokenString string) (string, error) { + ctx = zerowrap.CtxWithFields(ctx, map[string]any{ + zerowrap.FieldLayer: "usecase", + zerowrap.FieldUseCase: "ExtendToken", + }) + log := zerowrap.FromCtx(ctx) + + // Parse claims without store validation first (to check skip conditions cheaply) + rawClaims, err := s.parseTokenClaims(tokenString) + if err != nil { + return "", domain.ErrInvalidToken + } + + tokenClaims := buildTokenClaims(rawClaims) + + // Skip ephemeral access tokens (≤5min lifetime) + if s.isEphemeralAccessToken(tokenClaims) { + log.Debug().Msg("skipping extension for ephemeral access token") + return tokenString, nil + } + + // Skip service tokens + if tokenClaims.Subject == serviceTokenSubject { + log.Debug().Msg("skipping extension for service token") + return tokenString, nil + } + + // Skip tokens with no expiry (permanent tokens must not be silently converted to 24h tokens) + if tokenClaims.ExpiresAt <= 0 { + log.Debug().Str("subject", tokenClaims.Subject).Msg("skipping extension for non-expiring token") + return tokenString, nil + } + + // Full validation including store checks (not expired, not revoked, exists in store) + if _, err := s.ValidateToken(ctx, tokenString); err != nil { + return "", fmt.Errorf("token validation failed: %w", err) + } + + // Fetch stored token to check debounce and get metadata + _, storedToken, err := s.tokenStore.GetToken(ctx, tokenClaims.Subject) + if err != nil { + return "", fmt.Errorf("failed to fetch stored token: %w", err) + } + + // TOCTOU guard: confirm the store's current JTI matches the JWT we validated. + // A mismatch means the token was rotated between ValidateToken and GetToken; + // in that case return the original (still-valid) token without re-signing. + if storedToken.ID != tokenClaims.ID { + log.Debug(). + Str("subject", tokenClaims.Subject). + Str("jwt_jti", tokenClaims.ID). + Str("store_id", storedToken.ID). + Msg("token ID mismatch after rotation, skipping extension") + return tokenString, nil + } + + // Debounce: skip if already extended within the last hour + if !storedToken.LastExtendedAt.IsZero() && time.Since(storedToken.LastExtendedAt) < tokenExtensionDebounce { + log.Debug(). + Str("subject", tokenClaims.Subject). + Time("last_extended_at", storedToken.LastExtendedAt). + Msg("token extension debounced") + return tokenString, nil + } + + // Re-sign with same JTI but new expiry + now := time.Now().UTC() + newExpiresAt := now.Add(tokenExtensionTTL) + + newClaims := jwt.MapClaims{ + "jti": tokenClaims.ID, // Reuse existing JTI to avoid invalidating concurrent requests + "sub": tokenClaims.Subject, + "iss": TokenIssuer, + "iat": now.Unix(), // Update to current time (token is being re-issued) + "nbf": now.Unix(), // SECURITY: not-before matches issuance time + "scopes": tokenClaims.Scopes, + "exp": newExpiresAt.Unix(), + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims) + newTokenString, err := jwtToken.SignedString(s.config.TokenSecret) + if err != nil { + return "", log.WrapErr(err, "failed to re-sign token") + } + + // Update stored token with new JWT, expiry, and LastExtendedAt + updatedToken := &domain.Token{ + ID: storedToken.ID, + Subject: storedToken.Subject, + Scopes: storedToken.Scopes, + IssuedAt: storedToken.IssuedAt, + ExpiresAt: newExpiresAt, + Revoked: storedToken.Revoked, + LastExtendedAt: now, + } + + if err := s.tokenStore.UpdateTokenExpiry(ctx, updatedToken, newTokenString); err != nil { + // Non-fatal: log and return original token rather than failing the request + log.Warn().Err(err).Str("subject", tokenClaims.Subject).Msg("failed to persist extended token, returning original") + return tokenString, nil + } + + log.Debug(). + Str("subject", tokenClaims.Subject). + Time("new_expires_at", newExpiresAt). + Msg("token expiry extended") + + return newTokenString, nil +} + // GeneratePasswordHash generates a bcrypt hash for a password. func (s *Service) GeneratePasswordHash(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), DefaultBcryptCost) diff --git a/internal/usecase/auth/service_test.go b/internal/usecase/auth/service_test.go index 9432716e..3b803326 100644 --- a/internal/usecase/auth/service_test.go +++ b/internal/usecase/auth/service_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/bnema/zerowrap" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -816,3 +817,236 @@ func TestService_GenerateAccessToken_HasNbfClaim(t *testing.T) { assert.Equal(t, iat, nbf, "nbf should equal iat for access tokens") } + +func TestExtendTokenSlidesExpiry(t *testing.T) { + tokenStore := mocks.NewMockTokenStore(t) + + svc := NewService(Config{ + Enabled: true, + AuthType: domain.AuthTypeToken, + TokenSecret: []byte("test-secret-key-for-jwt-signing"), + }, tokenStore, zerowrap.Default()) + + ctx := testContext() + + // Generate a token with a short remaining life (1h) so ExtendToken produces a different exp + var capturedToken *domain.Token + tokenStore.EXPECT(). + SaveToken(mock.Anything, mock.Anything, mock.Anything). + Run(func(_ context.Context, t *domain.Token, _ string) { + capturedToken = t + }). + Return(nil) + + tokenStr, err := svc.GenerateToken(ctx, "testuser", []string{"admin:*:*"}, time.Hour) + require.NoError(t, err) + + // ExtendToken flow: + // 1. ValidateToken → GetToken (ensureTokenExists) + IsRevoked + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(tokenStr, capturedToken, nil).Once() + tokenStore.EXPECT(). + IsRevoked(mock.Anything, capturedToken.ID). + Return(false, nil).Once() + + // 2. GetToken again for debounce check + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(tokenStr, capturedToken, nil).Once() + + // 3. UpdateTokenExpiry + var updatedToken *domain.Token + tokenStore.EXPECT(). + UpdateTokenExpiry(mock.Anything, mock.Anything, mock.Anything). + Run(func(_ context.Context, t *domain.Token, _ string) { + updatedToken = t + }). + Return(nil) + + // Extend it + newTokenStr, err := svc.ExtendToken(ctx, tokenStr) + require.NoError(t, err, "ExtendToken should not fail") + + // Since the original token expires in 1h and the new one expires in 24h, they must differ + assert.NotEqual(t, tokenStr, newTokenStr, "expected a new token string after extension") + assert.NotNil(t, updatedToken) + assert.False(t, updatedToken.LastExtendedAt.IsZero(), "LastExtendedAt should be set") + + // New token must be valid — set up mock expectations for ValidateToken + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(newTokenStr, updatedToken, nil) + tokenStore.EXPECT(). + IsRevoked(mock.Anything, capturedToken.ID). + Return(false, nil) + + claims, err := svc.ValidateToken(ctx, newTokenStr) + require.NoError(t, err, "new token must be valid") + + // New token should have ~24h expiry from now + expectedExpiry := time.Now().Add(24 * time.Hour) + actualExpiry := time.Unix(claims.ExpiresAt, 0) + assert.WithinDuration(t, expectedExpiry, actualExpiry, 5*time.Minute, "expiry should be ~24h from now") +} + +func TestExtendTokenDebounce(t *testing.T) { + tokenStore := mocks.NewMockTokenStore(t) + + svc := NewService(Config{ + Enabled: true, + AuthType: domain.AuthTypeToken, + TokenSecret: []byte("test-secret-key-for-jwt-signing"), + }, tokenStore, zerowrap.Default()) + + ctx := testContext() + + // Generate a short-lived token (1h) so the first extension produces a different JWT + var capturedToken *domain.Token + tokenStore.EXPECT(). + SaveToken(mock.Anything, mock.Anything, mock.Anything). + Run(func(_ context.Context, t *domain.Token, _ string) { + capturedToken = t + }). + Return(nil) + + tokenStr, err := svc.GenerateToken(ctx, "testuser", []string{"admin:*:*"}, time.Hour) + require.NoError(t, err) + + // First extend: flow is ValidateToken(GetToken+IsRevoked) + GetToken(debounce) + UpdateTokenExpiry + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(tokenStr, capturedToken, nil).Once() + tokenStore.EXPECT(). + IsRevoked(mock.Anything, capturedToken.ID). + Return(false, nil).Once() + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(tokenStr, capturedToken, nil).Once() + + var extendedToken *domain.Token + var extendedJWT string + tokenStore.EXPECT(). + UpdateTokenExpiry(mock.Anything, mock.Anything, mock.Anything). + Run(func(_ context.Context, t *domain.Token, jwt string) { + extendedToken = t + extendedJWT = jwt + }). + Return(nil).Once() + + newToken1, err := svc.ExtendToken(ctx, tokenStr) + require.NoError(t, err) + assert.NotEqual(t, tokenStr, newToken1, "first extend should produce a new token") + assert.NotNil(t, extendedToken) + + // Second extend immediately — debounce should kick in (LastExtendedAt just set to now) + // Flow: ValidateToken(GetToken+IsRevoked) + GetToken(debounce) → debounced, no UpdateTokenExpiry + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(extendedJWT, extendedToken, nil).Once() + tokenStore.EXPECT(). + IsRevoked(mock.Anything, capturedToken.ID). + Return(false, nil).Once() + tokenStore.EXPECT(). + GetToken(mock.Anything, "testuser"). + Return(extendedJWT, extendedToken, nil).Once() + + // No UpdateTokenExpiry should be called — debounced + newToken2, err := svc.ExtendToken(ctx, newToken1) + require.NoError(t, err) + + assert.Equal(t, newToken1, newToken2, "expected debounce to return same token within 1h window") +} + +func TestExtendTokenSkipsEphemeral(t *testing.T) { + svc := NewService(Config{ + Enabled: true, + AuthType: domain.AuthTypeToken, + TokenSecret: []byte("test-secret-key-for-jwt-signing"), + }, nil, zerowrap.Default()) + + ctx := testContext() + + // Generate an ephemeral access token (≤5min) + tokenStr, err := svc.GenerateAccessToken(ctx, "testuser", []string{"pull"}, 5*time.Minute) + require.NoError(t, err) + + // ExtendToken should skip ephemeral tokens and return the same string + result, err := svc.ExtendToken(ctx, tokenStr) + require.NoError(t, err) + assert.Equal(t, tokenStr, result, "ephemeral tokens should not be extended") +} + +func newTestAuthService(t *testing.T) (*Service, *mocks.MockTokenStore) { + t.Helper() + tokenStore := mocks.NewMockTokenStore(t) + svc := NewService(Config{ + Enabled: true, + AuthType: domain.AuthTypeToken, + TokenSecret: []byte("test-secret-key-for-jwt-signing"), + }, tokenStore, zerowrap.Default()) + return svc, tokenStore +} + +func TestIsEphemeralAccessTokenRejectsFutureIat(t *testing.T) { + // A token with iat in the future should NOT be treated as ephemeral + // (otherwise it bypasses revocation checks) + svc, _ := newTestAuthService(t) + + // Craft a token with iat = now+10min, exp = now+15min (age appears as -10min → negative → < 5min). + // nbf is set to current time (not future) so the JWT library's nbf check does not reject it — + // we need to isolate the iat bypass specifically. + now := time.Now().UTC() + claims := jwt.MapClaims{ + "jti": "test-jti", + "sub": "testuser", + "iss": TokenIssuer, + "iat": now.Add(10 * time.Minute).Unix(), // future iat — the attack vector + "nbf": now.Add(-1 * time.Second).Unix(), // current nbf so nbf check passes + "exp": now.Add(15 * time.Minute).Unix(), // exp - iat = 5min → looks ephemeral + "scopes": []string{"admin:*:*"}, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := token.SignedString(svc.config.TokenSecret) + if err != nil { + t.Fatal(err) + } + + // Parsing should either reject the token outright (future iat) or not classify it as ephemeral + rawClaims, err := svc.parseTokenClaims(tokenStr) + if err != nil { + // Acceptable: future iat rejected at parse time + t.Logf("token with future iat rejected at parse: %v", err) + return + } + tokenClaims := buildTokenClaims(rawClaims) + if svc.isEphemeralAccessToken(tokenClaims) { + t.Error("token with future iat must NOT be classified as ephemeral — it bypasses revocation") + } +} + +func TestExtendTokenSkipsServiceToken(t *testing.T) { + tokenStore := mocks.NewMockTokenStore(t) + + svc := NewService(Config{ + Enabled: true, + AuthType: domain.AuthTypeToken, + TokenSecret: []byte("test-secret-key-for-jwt-signing"), + }, tokenStore, zerowrap.Default()) + + ctx := testContext() + + // Generate a service token — service tokens skip extension before store checks + tokenStore.EXPECT(). + SaveToken(mock.Anything, mock.Anything, mock.Anything). + Return(nil) + + const svcSubject = "gordon-service" + tokenStr, err := svc.GenerateToken(ctx, svcSubject, []string{"pull"}, 24*time.Hour) + require.NoError(t, err) + + // ExtendToken should skip service tokens WITHOUT touching the store + result, err := svc.ExtendToken(ctx, tokenStr) + require.NoError(t, err) + assert.Equal(t, tokenStr, result, "service tokens should not be extended") +} diff --git a/internal/usecase/config/service.go b/internal/usecase/config/service.go index d6bc05bb..ff9a22a2 100644 --- a/internal/usecase/config/service.go +++ b/internal/usecase/config/service.go @@ -418,9 +418,14 @@ func (s *Service) Save(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - // Update viper with current config + // Sync ALL mutable config sections back to Viper before writing. + // Using explicit Set() ensures the values are stored as proper flat maps, + // preventing viper from splitting dotted keys (e.g. "reg.example.com") into + // nested TOML subtrees which corrupts the data on re-read. s.viper.Set("routes", s.config.Routes) + s.viper.Set("external_routes", s.config.ExternalRoutes) s.viper.Set("attachments", s.config.Attachments) + s.viper.Set("network_groups", s.config.NetworkGroups) // Record save time to debounce file watcher events atomic.StoreInt64(&s.lastSaveTime, time.Now().UnixNano()) @@ -534,13 +539,6 @@ func (s *Service) GetConfig() Config { return s.config } -// GetRegistryAuthConfig returns registry authentication configuration. -func (s *Service) GetRegistryAuthConfig() (enabled bool, username, password string) { - s.mu.RLock() - defer s.mu.RUnlock() - return s.config.RegistryAuthEnabled, s.config.RegistryAuthUsername, s.config.RegistryAuthPassword -} - // GetVolumeConfig returns volume configuration. func (s *Service) GetVolumeConfig() (autoCreate bool, prefix string, preserve bool) { s.mu.RLock() diff --git a/internal/usecase/config/service_test.go b/internal/usecase/config/service_test.go index 9cd30145..f74da13f 100644 --- a/internal/usecase/config/service_test.go +++ b/internal/usecase/config/service_test.go @@ -53,11 +53,6 @@ func TestService_Load(t *testing.T) { assert.True(t, svc.IsNetworkIsolationEnabled()) assert.Equal(t, "gordon", svc.GetNetworkPrefix()) - enabled, username, password := svc.GetRegistryAuthConfig() - assert.True(t, enabled) - assert.Equal(t, "admin", username) - assert.Equal(t, "secret", password) - autoCreate, prefix, preserve := svc.GetVolumeConfig() assert.True(t, autoCreate) assert.Equal(t, "gordon", prefix) @@ -1091,6 +1086,174 @@ func TestService_FindRoutesByImage(t *testing.T) { }) } +func TestSavePreservesAllConfigFields(t *testing.T) { + t.Run("network_groups are persisted to disk", func(t *testing.T) { + // Setup: config file with auth, server, and network_groups + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "gordon.toml") + initialConfig := ` +[server] +port = 9999 + +[auth] +username = "admin" +token_secret = "supersecretvalue" + +[routes] +"app.example.com" = "myapp:latest" + +[network_groups] +frontend = ["app1.example.com", "app2.example.com"] +` + err := os.WriteFile(configFile, []byte(initialConfig), 0600) + require.NoError(t, err) + + v := viper.New() + v.SetConfigFile(configFile) + err = v.ReadInConfig() + require.NoError(t, err) + + eventBus := mocks.NewMockEventPublisher(t) + svc := NewService(v, eventBus) + ctx := testContext() + + err = svc.Load(ctx) + require.NoError(t, err) + + // Verify network_groups loaded correctly in memory + groups := svc.GetNetworkGroups() + require.Len(t, groups, 1) + require.ElementsMatch(t, []string{"app1.example.com", "app2.example.com"}, groups["frontend"]) + + // Trigger a Save via AddRoute (which calls Save internally) + err = svc.AddRoute(ctx, domain.Route{Domain: "new.example.com", Image: "newapp:latest"}) + require.NoError(t, err) + + // Re-read the file to verify network_groups was persisted + v2 := viper.New() + v2.SetConfigFile(configFile) + err = v2.ReadInConfig() + require.NoError(t, err) + + // Verify network_groups is still intact and correct after save + savedGroups := loadStringArrayMap(v2.Get("network_groups")) + assert.Len(t, savedGroups, 1, "network_groups should still be present in saved config") + assert.ElementsMatch(t, []string{"app1.example.com", "app2.example.com"}, savedGroups["frontend"]) + + // Verify auth fields are preserved + assert.Equal(t, "admin", v2.GetString("auth.username")) + assert.Equal(t, "supersecretvalue", v2.GetString("auth.token_secret")) + + // Verify server fields are preserved + assert.Equal(t, 9999, v2.GetInt("server.port")) + }) + + // Regression test: external_routes keys contain dots (e.g. "reg.example.com"). + // Without explicitly calling viper.Set("external_routes", ...) before WriteConfig, + // viper splits dotted keys into nested subtrees, corrupting the data on re-read. + t.Run("external_routes with dotted domain keys are not corrupted on Save", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "gordon.toml") + initialConfig := ` +[server] +port = 8080 + +[routes] +"app.example.com" = "myapp:latest" + +[external_routes] +"reg.example.com" = "localhost:5000" +` + err := os.WriteFile(configFile, []byte(initialConfig), 0600) + require.NoError(t, err) + + v := viper.New() + v.SetConfigFile(configFile) + err = v.ReadInConfig() + require.NoError(t, err) + + eventBus := mocks.NewMockEventPublisher(t) + svc := NewService(v, eventBus) + ctx := testContext() + + err = svc.Load(ctx) + require.NoError(t, err) + + // Verify external_routes loaded correctly in memory + extRoutes := svc.GetExternalRoutes() + require.Equal(t, "localhost:5000", extRoutes["reg.example.com"]) + + // Trigger Save + err = svc.AddRoute(ctx, domain.Route{Domain: "new.example.com", Image: "newapp:latest"}) + require.NoError(t, err) + + // Re-read the file with a fresh viper instance + v2 := viper.New() + v2.SetConfigFile(configFile) + err = v2.ReadInConfig() + require.NoError(t, err) + + // external_routes must survive Save with correct flat map structure. + // Without the fix, viper splits "reg.example.com" into nested TOML subtrees + // and GetString("external_routes.reg\\.example\\.com") returns empty string. + savedExtRoutes := loadStringMap(v2.Get("external_routes")) + assert.Equal(t, "localhost:5000", savedExtRoutes["reg.example.com"], + "external_routes must not be corrupted by viper's dot-path splitting on WriteConfig") + }) + + t.Run("network_groups added in memory are written to disk on Save", func(t *testing.T) { + // Setup: config file WITHOUT network_groups + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "gordon.toml") + initialConfig := ` +[server] +port = 8080 + +[routes] +"app.example.com" = "myapp:latest" +` + err := os.WriteFile(configFile, []byte(initialConfig), 0600) + require.NoError(t, err) + + v := viper.New() + v.SetConfigFile(configFile) + err = v.ReadInConfig() + require.NoError(t, err) + + // Inject network_groups into viper as if they were loaded from some other source + // This simulates a case where network_groups exist in memory but not on disk + v.Set("network_groups", map[string]interface{}{ + "backend": []interface{}{"api.example.com"}, + }) + + eventBus := mocks.NewMockEventPublisher(t) + svc := NewService(v, eventBus) + ctx := testContext() + + err = svc.Load(ctx) + require.NoError(t, err) + + // Verify network_groups is loaded in memory + groups := svc.GetNetworkGroups() + require.Len(t, groups, 1) + + // Trigger Save + err = svc.AddRoute(ctx, domain.Route{Domain: "new.example.com", Image: "newapp:latest"}) + require.NoError(t, err) + + // Re-read the file + v2 := viper.New() + v2.SetConfigFile(configFile) + err = v2.ReadInConfig() + require.NoError(t, err) + + // network_groups must be present in the saved file + savedGroups := loadStringArrayMap(v2.Get("network_groups")) + assert.Len(t, savedGroups, 1, "network_groups added in memory must be written on Save") + assert.ElementsMatch(t, []string{"api.example.com"}, savedGroups["backend"]) + }) +} + func TestExtractDomainFromImageName(t *testing.T) { tests := []struct { name string diff --git a/internal/usecase/container/autoroute.go b/internal/usecase/container/autoroute.go index 535309e0..5b70ac9a 100644 --- a/internal/usecase/container/autoroute.go +++ b/internal/usecase/container/autoroute.go @@ -493,7 +493,7 @@ func writeEnvFile(path string, env map[string]string) error { if strings.ContainsAny(v, " \t\n\"'$\\") { v = fmt.Sprintf("\"%s\"", strings.ReplaceAll(v, "\"", "\\\"")) } - buf.WriteString(fmt.Sprintf("%s=%s\n", k, v)) + fmt.Fprintf(&buf, "%s=%s\n", k, v) } return os.WriteFile(path, buf.Bytes(), 0600) diff --git a/internal/usecase/container/events.go b/internal/usecase/container/events.go index 88975274..c207a548 100644 --- a/internal/usecase/container/events.go +++ b/internal/usecase/container/events.go @@ -107,11 +107,18 @@ func (h *ConfigReloadHandler) Handle(ctx context.Context, event domain.Event) er log.Warn().Err(err).Msg("failed to sync containers before reload, proceeding with current state") } + // Propagate updated attachment configuration so new attachments take effect + // without requiring a Gordon restart (fixes issue #87 part 2). + attachments := h.configSvc.GetAllAttachments(ctx) + log.Debug().Int("attachment_groups", len(attachments)).Msg("propagating updated attachment configuration") + h.containerSvc.UpdateAttachments(attachments) + log.Debug().Int("attachment_groups", len(attachments)).Msg("attachment configuration propagated") + currentContainers := h.containerSvc.List(ctx) activeRoutes := make(map[string]*domain.Container) for _, container := range currentContainers { - if route, exists := container.Labels["gordon.route"]; exists { + if route, exists := container.Labels[domain.LabelRoute]; exists { activeRoutes[route] = container } } @@ -119,7 +126,7 @@ func (h *ConfigReloadHandler) Handle(ctx context.Context, event domain.Event) er routes := h.configSvc.GetRoutes(ctx) for _, route := range routes { if container, exists := activeRoutes[route.Domain]; exists { - currentImage := container.Labels["gordon.image"] + currentImage := container.Labels[domain.LabelImage] if currentImage != route.Image { log.Info(). Str("domain", route.Domain). diff --git a/internal/usecase/container/events_test.go b/internal/usecase/container/events_test.go index 535999c7..cc638c69 100644 --- a/internal/usecase/container/events_test.go +++ b/internal/usecase/container/events_test.go @@ -209,6 +209,10 @@ func TestConfigReloadHandler_Handle_DeploysNewRoutes(t *testing.T) { // SyncContainers is called first containerSvc.EXPECT().SyncContainers(mock.Anything).Return(nil) + // Attachment config is propagated after sync + configSvc.EXPECT().GetAllAttachments(mock.Anything).Return(map[string][]string{}) + containerSvc.EXPECT().UpdateAttachments(map[string][]string{}).Return() + // No existing containers containerSvc.EXPECT().List(mock.Anything).Return(map[string]*domain.Container{}) @@ -241,6 +245,10 @@ func TestConfigReloadHandler_Handle_StopsRemovedRoutes(t *testing.T) { // SyncContainers is called first containerSvc.EXPECT().SyncContainers(mock.Anything).Return(nil) + // Attachment config is propagated after sync + configSvc.EXPECT().GetAllAttachments(mock.Anything).Return(map[string][]string{}) + containerSvc.EXPECT().UpdateAttachments(map[string][]string{}).Return() + // Existing container for a route that's no longer configured containerSvc.EXPECT().List(mock.Anything).Return(map[string]*domain.Container{ "removed.example.com": { @@ -277,6 +285,10 @@ func TestConfigReloadHandler_Handle_RedeploysChangedImage(t *testing.T) { // SyncContainers is called first containerSvc.EXPECT().SyncContainers(mock.Anything).Return(nil) + // Attachment config is propagated after sync + configSvc.EXPECT().GetAllAttachments(mock.Anything).Return(map[string][]string{}) + containerSvc.EXPECT().UpdateAttachments(map[string][]string{}).Return() + // Existing container with old image containerSvc.EXPECT().List(mock.Anything).Return(map[string]*domain.Container{ "app.example.com": { @@ -317,6 +329,10 @@ func TestConfigReloadHandler_Handle_NoChanges(t *testing.T) { // SyncContainers is called first containerSvc.EXPECT().SyncContainers(mock.Anything).Return(nil) + // Attachment config is propagated after sync + configSvc.EXPECT().GetAllAttachments(mock.Anything).Return(map[string][]string{}) + containerSvc.EXPECT().UpdateAttachments(map[string][]string{}).Return() + // Existing container matches config containerSvc.EXPECT().List(mock.Anything).Return(map[string]*domain.Container{ "app.example.com": { @@ -494,6 +510,40 @@ func TestManualReloadHandler_Handle_DoesNotRestartRunningContainers(t *testing.T // ManualDeployHandler tests +func TestConfigReloadHandlerUpdatesContainerConfig(t *testing.T) { + containerSvc := inmocks.NewMockContainerService(t) + configSvc := inmocks.NewMockConfigService(t) + + handler := NewConfigReloadHandler(testCtx(), containerSvc, configSvc) + + // SyncContainers is called first + containerSvc.EXPECT().SyncContainers(mock.Anything).Return(nil) + + // No existing containers + containerSvc.EXPECT().List(mock.Anything).Return(map[string]*domain.Container{}) + + // New attachments from config + newAttachments := map[string][]string{ + "app.example.com": {"postgres:15", "redis:7"}, + } + configSvc.EXPECT().GetAllAttachments(mock.Anything).Return(newAttachments) + + // Expect UpdateAttachments to be called with new attachment config + containerSvc.EXPECT().UpdateAttachments(newAttachments).Return() + + // No routes configured + configSvc.EXPECT().GetRoutes(mock.Anything).Return([]domain.Route{}) + + event := domain.Event{ + ID: "event-123", + Type: domain.EventConfigReload, + } + + err := handler.Handle(context.Background(), event) + + assert.NoError(t, err) +} + func TestManualDeployHandler_CanHandle(t *testing.T) { containerSvc := inmocks.NewMockContainerService(t) configSvc := inmocks.NewMockConfigService(t) diff --git a/internal/usecase/container/service.go b/internal/usecase/container/service.go index 7197d9e5..0e5cf71f 100644 --- a/internal/usecase/container/service.go +++ b/internal/usecase/container/service.go @@ -219,10 +219,10 @@ func (s *Service) buildContainerConfig(containerDomain, image, actualImageRef st NetworkMode: networkName, Hostname: containerDomain, Labels: map[string]string{ - "gordon.domain": containerDomain, - "gordon.image": image, - "gordon.managed": "true", - "gordon.route": containerDomain, + domain.LabelDomain: containerDomain, + domain.LabelImage: image, + domain.LabelManaged: "true", + domain.LabelRoute: containerDomain, }, AutoRemove: false, } @@ -903,16 +903,26 @@ func (s *Service) SyncContainers(ctx context.Context) error { } managed := make(map[string]*domain.Container) + attachments := make(map[string][]string) for _, c := range allContainers { - if c.Labels != nil { - if d, ok := c.Labels["gordon.domain"]; ok && c.Labels["gordon.managed"] == "true" { - managed[d] = c + if c.Labels == nil || c.Labels[domain.LabelManaged] != "true" { + continue + } + if c.Labels[domain.LabelAttachment] == "true" { + owner := c.Labels[domain.LabelAttachedTo] + if owner != "" { + attachments[owner] = append(attachments[owner], c.ID) } + continue + } + if d, ok := c.Labels[domain.LabelDomain]; ok { + managed[d] = c } } s.mu.Lock() s.containers = managed + s.attachments = attachments newCount := int64(len(managed)) delta := newCount - s.managedCount s.managedCount = newCount @@ -1061,6 +1071,24 @@ func (s *Service) UpdateConfig(config Config) { s.mu.Unlock() } +// UpdateAttachments updates only the attachment configuration in the service. +// This is called after a config reload to propagate attachment changes without restart. +// The incoming map is deep-copied so external callers cannot mutate service state. +func (s *Service) UpdateAttachments(attachments map[string][]string) { + var copied map[string][]string + if attachments != nil { + copied = make(map[string][]string, len(attachments)) + for k, v := range attachments { + sl := make([]string, len(v)) + copy(sl, v) + copied[k] = sl + } + } + s.mu.Lock() + s.config.Attachments = copied + s.mu.Unlock() +} + // Helper methods // cleanupFailedContainer stops and removes a container that failed to start properly. @@ -1875,10 +1903,10 @@ func (s *Service) deployAttachedService(ctx context.Context, ownerDomain, servic Volumes: volumes, NetworkMode: networkName, // Same network as main app Labels: map[string]string{ - "gordon.managed": "true", - "gordon.attachment": "true", - "gordon.attached-to": ownerDomain, - "gordon.image": serviceImage, + domain.LabelManaged: "true", + domain.LabelAttachment: "true", + domain.LabelAttachedTo: ownerDomain, + domain.LabelImage: serviceImage, }, } diff --git a/internal/usecase/container/service_test.go b/internal/usecase/container/service_test.go index cee84566..fae2eeb8 100644 --- a/internal/usecase/container/service_test.go +++ b/internal/usecase/container/service_test.go @@ -2332,6 +2332,49 @@ func TestService_Deploy_DoesNotSkipWhenImageIDDiffers(t *testing.T) { assert.Equal(t, "new-container", result.ID) } +func TestSyncContainersRebuildsAttachmentMap(t *testing.T) { + runtime := mocks.NewMockContainerRuntime(t) + envLoader := mocks.NewMockEnvLoader(t) + eventBus := mocks.NewMockEventPublisher(t) + + svc := NewService(runtime, envLoader, eventBus, nil, Config{}) + ctx := testContext() + + // ListContainers returns one main container and one attachment container + runtime.EXPECT().ListContainers(mock.Anything, false).Return([]*domain.Container{ + { + ID: "main-container-1", + Labels: map[string]string{ + domain.LabelDomain: "app.example.com", + domain.LabelManaged: "true", + }, + }, + { + ID: "attachment-container-1", + Labels: map[string]string{ + domain.LabelManaged: "true", + domain.LabelAttachment: "true", + domain.LabelAttachedTo: "app.example.com", + }, + }, + }, nil).Once() + + require.NoError(t, svc.SyncContainers(ctx)) + + // Main container should be in s.containers + tracked, exists := svc.Get(ctx, "app.example.com") + require.True(t, exists, "main container should be tracked after SyncContainers") + assert.Equal(t, "main-container-1", tracked.ID) + + // Attachment container should be in s.attachments under the owner domain + svc.mu.RLock() + attachIDs := svc.attachments["app.example.com"] + svc.mu.RUnlock() + + require.Len(t, attachIDs, 1, "attachment map should contain one entry for app.example.com after SyncContainers") + assert.Equal(t, "attachment-container-1", attachIDs[0]) +} + func TestService_Deploy_SkipRedundantDeploy_ContainerNotRunning(t *testing.T) { runtime := mocks.NewMockContainerRuntime(t) envLoader := mocks.NewMockEnvLoader(t)