diff --git a/.airlock/lint.sh b/.airlock/lint.sh new file mode 100755 index 0000000000..f707806ca1 --- /dev/null +++ b/.airlock/lint.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel)" +cd "$REPO_ROOT" + +# Compute changed files between base and head +BASE="${AIRLOCK_BASE_SHA:-HEAD~1}" +HEAD="${AIRLOCK_HEAD_SHA:-HEAD}" +CHANGED_FILES=$(git diff --name-only --diff-filter=ACMR "$BASE" "$HEAD" 2>/dev/null || git diff --name-only --cached) + +# Filter by language +GO_FILES=$(echo "$CHANGED_FILES" | grep '\.go$' || true) +PY_FILES=$(echo "$CHANGED_FILES" | grep '\.py$' || true) + +ERRORS=0 + +# --- Go --- +if [[ -n "$GO_FILES" ]]; then + echo "=== Go: gofmt (auto-fix) ===" + echo "$GO_FILES" | xargs -I{} gofmt -w "{}" 2>/dev/null || true + + echo "=== Go: golangci-lint ===" + # Get unique directories containing changed Go files + GO_DIRS=$(echo "$GO_FILES" | xargs -I{} dirname "{}" | sort -u | sed 's|$|/...|') + # Run golangci-lint but only report issues in changed files + LINT_OUTPUT=$(golangci-lint run --out-format line-number $GO_DIRS 2>&1 || true) + if [[ -n "$LINT_OUTPUT" ]]; then + # Filter to only issues in changed files + FILTERED="" + while IFS= read -r file; do + MATCH=$(echo "$LINT_OUTPUT" | grep "^${file}:" || true) + if [[ -n "$MATCH" ]]; then + FILTERED="${FILTERED}${MATCH}"$'\n' + fi + done <<< "$GO_FILES" + if [[ -n "${FILTERED// /}" ]] && [[ "${FILTERED}" != $'\n' ]]; then + echo "$FILTERED" + echo "golangci-lint: issues found in changed files" + ERRORS=1 + else + echo "golangci-lint: OK (issues only in unchanged files, skipping)" + fi + else + echo "golangci-lint: OK" + fi +fi + +# --- Python --- +if [[ -n "$PY_FILES" ]]; then + echo "=== Python: ruff format (auto-fix) ===" + echo "$PY_FILES" | xargs ruff format 2>/dev/null || true + + echo "=== Python: ruff check --fix ===" + echo "$PY_FILES" | xargs ruff check --fix 2>/dev/null || true + + echo "=== Python: ruff check (verify) ===" + if echo 
"$PY_FILES" | xargs ruff check 2>&1; then + echo "ruff check: OK" + else + echo "ruff check: issues found" + ERRORS=1 + fi +fi + +if [[ -z "$GO_FILES" && -z "$PY_FILES" ]]; then + echo "No Go or Python files changed. Nothing to lint." +fi + +exit $ERRORS diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml new file mode 100644 index 0000000000..b8a831c87d --- /dev/null +++ b/.github/workflows/lint-test.yml @@ -0,0 +1,18 @@ +name: Lint & Test + +on: + pull_request: + types: [opened, synchronize, reopened] + +permissions: + contents: read + +jobs: + lint-test: + name: lint-test + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - uses: KooshaPari/phenotypeActions/actions/lint-test@main diff --git a/go.mod b/go.mod index 2c89cb28af..8bdb26b4c5 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kooshapari/cliproxyapi-plusplus/v6 go 1.26.0 require ( + github.com/KooshaPari/phenotype-go-auth v0.0.0 github.com/andybalholm/brotli v1.2.0 github.com/atotto/clipboard v0.1.4 github.com/charmbracelet/bubbles v1.0.0 @@ -117,3 +118,4 @@ require ( modernc.org/memory v1.11.0 // indirect ) +replace github.com/KooshaPari/phenotype-go-auth => ../../../template-commons/phenotype-go-auth diff --git a/internal/auth/claude/token.go b/internal/auth/claude/token.go index ff1332d880..d005117d3f 100644 --- a/internal/auth/claude/token.go +++ b/internal/auth/claude/token.go @@ -4,7 +4,7 @@ package claude import ( - "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/auth/base" + base "github.com/KooshaPari/phenotype-go-auth" "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/misc" ) diff --git a/internal/auth/copilot/token.go b/internal/auth/copilot/token.go index 89f284ca1d..f70340822b 100644 --- a/internal/auth/copilot/token.go +++ b/internal/auth/copilot/token.go @@ -4,7 +4,7 @@ package copilot import ( - "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/auth/base" + 
base "github.com/KooshaPari/phenotype-go-auth" "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/misc" ) diff --git a/internal/auth/gemini/gemini_token.go b/internal/auth/gemini/gemini_token.go index 1c6e18f37a..75f4ff70ad 100644 --- a/internal/auth/gemini/gemini_token.go +++ b/internal/auth/gemini/gemini_token.go @@ -7,7 +7,7 @@ import ( "fmt" "strings" - "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/auth/base" + base "github.com/KooshaPari/phenotype-go-auth" "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/misc" ) diff --git a/pkg/llmproxy/auth/base/token_storage.go b/pkg/llmproxy/auth/base/token_storage.go new file mode 100644 index 0000000000..83a03903f3 --- /dev/null +++ b/pkg/llmproxy/auth/base/token_storage.go @@ -0,0 +1,129 @@ +// Package base provides a shared foundation for OAuth2 token storage across all +// LLM proxy authentication providers. It centralises the common Save/Load/Clear +// file-I/O operations so that individual provider packages only need to embed +// BaseTokenStorage and add their own provider-specific fields. +package base + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" +) + +// BaseTokenStorage holds the fields and file-I/O methods that every provider +// token struct shares. Provider-specific structs embed a *BaseTokenStorage +// (or a copy by value) and extend it with their own fields. +type BaseTokenStorage struct { + // AccessToken is the OAuth2 bearer token used to authenticate API requests. + AccessToken string `json:"access_token"` + + // RefreshToken is used to obtain a new access token when the current one expires. + RefreshToken string `json:"refresh_token,omitempty"` + + // Email is the account e-mail address associated with this token. + Email string `json:"email,omitempty"` + + // Type is the provider identifier (e.g. "claude", "codex", "kimi"). 
+ // Each provider sets this before saving so that callers can identify + // which authentication provider a credential file belongs to. + Type string `json:"type"` + + // FilePath is the on-disk path used by Save/Load/Clear. It is not + // serialised to JSON; it is populated at runtime from the caller-supplied + // authFilePath argument. + FilePath string `json:"-"` +} + +// GetAccessToken returns the OAuth2 access token. +func (b *BaseTokenStorage) GetAccessToken() string { return b.AccessToken } + +// GetRefreshToken returns the OAuth2 refresh token. +func (b *BaseTokenStorage) GetRefreshToken() string { return b.RefreshToken } + +// GetEmail returns the e-mail address associated with the token. +func (b *BaseTokenStorage) GetEmail() string { return b.Email } + +// GetType returns the provider type string. +func (b *BaseTokenStorage) GetType() string { return b.Type } + +// Save serialises v (the outer provider struct that embeds BaseTokenStorage) +// to the file at authFilePath using an atomic write (write to a temp file, +// then rename). The directory is created if it does not already exist. +// +// v must be JSON-marshallable. Passing the provider struct rather than +// BaseTokenStorage itself ensures that all provider-specific fields are +// persisted alongside the base fields. +func (b *BaseTokenStorage) Save(authFilePath string, v any) error { + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("base token storage: invalid file path: %w", err) + } + misc.LogSavingCredentials(safePath) + + if err = os.MkdirAll(filepath.Dir(safePath), 0o700); err != nil { + return fmt.Errorf("base token storage: create directory: %w", err) + } + + // Write to a temporary file in the same directory, then rename so that + // a concurrent reader never observes a partially-written file. 
+ tmpFile, err := os.CreateTemp(filepath.Dir(safePath), ".tmp-token-*") + if err != nil { + return fmt.Errorf("base token storage: create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + writeErr := json.NewEncoder(tmpFile).Encode(v) + closeErr := tmpFile.Close() + + if writeErr != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("base token storage: encode token: %w", writeErr) + } + if closeErr != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("base token storage: close temp file: %w", closeErr) + } + + if err = os.Rename(tmpPath, safePath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("base token storage: rename temp file: %w", err) + } + return nil +} + +// Load reads the JSON file at authFilePath and unmarshals it into v. +// v should be a pointer to the outer provider struct so that all fields +// are populated. +func (b *BaseTokenStorage) Load(authFilePath string, v any) error { + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("base token storage: invalid file path: %w", err) + } + + data, err := os.ReadFile(safePath) + if err != nil { + return fmt.Errorf("base token storage: read token file: %w", err) + } + + if err = json.Unmarshal(data, v); err != nil { + return fmt.Errorf("base token storage: unmarshal token: %w", err) + } + return nil +} + +// Clear removes the token file at authFilePath. It returns nil if the file +// does not exist (idempotent delete). 
+func (b *BaseTokenStorage) Clear(authFilePath string) error { + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("base token storage: invalid file path: %w", err) + } + + if err = os.Remove(safePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("base token storage: remove token file: %w", err) + } + return nil +} diff --git a/pkg/llmproxy/auth/claude/anthropic_auth.go b/pkg/llmproxy/auth/claude/anthropic_auth.go index 78a889ff3c..ec06454aa1 100644 --- a/pkg/llmproxy/auth/claude/anthropic_auth.go +++ b/pkg/llmproxy/auth/claude/anthropic_auth.go @@ -13,7 +13,8 @@ import ( "strings" "time" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" log "github.com/sirupsen/logrus" ) @@ -293,11 +294,13 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C // - *ClaudeTokenStorage: A new token storage instance func (o *ClaudeAuth) CreateTokenStorage(bundle *ClaudeAuthBundle) *ClaudeTokenStorage { storage := &ClaudeTokenStorage{ - AccessToken: bundle.TokenData.AccessToken, - RefreshToken: bundle.TokenData.RefreshToken, - LastRefresh: bundle.LastRefresh, - Email: bundle.TokenData.Email, - Expire: bundle.TokenData.Expire, + BaseTokenStorage: base.BaseTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + RefreshToken: bundle.TokenData.RefreshToken, + Email: bundle.TokenData.Email, + }, + LastRefresh: bundle.LastRefresh, + Expire: bundle.TokenData.Expire, } return storage diff --git a/pkg/llmproxy/auth/claude/token.go b/pkg/llmproxy/auth/claude/token.go index 550331ed82..22bf50cbda 100644 --- a/pkg/llmproxy/auth/claude/token.go +++ b/pkg/llmproxy/auth/claude/token.go @@ -4,58 +4,23 @@ package claude import ( - "encoding/json" "fmt" - "os" - "path/filepath" - "strings" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" 
+ "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) -func sanitizeTokenFilePath(authFilePath string) (string, error) { - trimmed := strings.TrimSpace(authFilePath) - if trimmed == "" { - return "", fmt.Errorf("token file path is empty") - } - cleaned := filepath.Clean(trimmed) - parts := strings.FieldsFunc(cleaned, func(r rune) bool { - return r == '/' || r == '\\' - }) - for _, part := range parts { - if part == ".." { - return "", fmt.Errorf("invalid token file path") - } - } - absPath, err := filepath.Abs(cleaned) - if err != nil { - return "", fmt.Errorf("failed to resolve token file path: %w", err) - } - return absPath, nil -} - // ClaudeTokenStorage stores OAuth2 token information for Anthropic Claude API authentication. // It maintains compatibility with the existing auth system while adding Claude-specific fields // for managing access tokens, refresh tokens, and user account information. type ClaudeTokenStorage struct { + base.BaseTokenStorage + // IDToken is the JWT ID token containing user claims and identity information. IDToken string `json:"id_token"` - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` - - // RefreshToken is used to obtain new access tokens when the current one expires. - RefreshToken string `json:"refresh_token"` - // LastRefresh is the timestamp of the last token refresh operation. LastRefresh string `json:"last_refresh"` - // Email is the Anthropic account email address associated with this token. - Email string `json:"email"` - - // Type indicates the authentication provider type, always "claude" for this storage. - Type string `json:"type"` - // Expire is the timestamp when the current access token expires. 
Expire string `json:"expired"` } @@ -70,34 +35,9 @@ type ClaudeTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "claude" - safePath, err = sanitizeTokenFilePath(authFilePath) - if err != nil { - return err - } - - // Create directory structure if it doesn't exist - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - // Create the token file - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - // Encode and write the token data as JSON - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("claude token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/claude/utls_transport.go b/pkg/llmproxy/auth/claude/utls_transport.go index 5d1f7f1660..1f8f2c900b 100644 --- a/pkg/llmproxy/auth/claude/utls_transport.go +++ b/pkg/llmproxy/auth/claude/utls_transport.go @@ -8,8 +8,8 @@ import ( "strings" "sync" + pkgconfig "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" tls "github.com/refraction-networking/utls" - pkgconfig "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/proxy" diff --git a/pkg/llmproxy/auth/codex/openai_auth.go b/pkg/llmproxy/auth/codex/openai_auth.go index 9652caeba6..3adc4e469e 100644 --- a/pkg/llmproxy/auth/codex/openai_auth.go +++ b/pkg/llmproxy/auth/codex/openai_auth.go @@ -14,7 +14,8 @@ import ( "strings" "time" - 
"github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" log "github.com/sirupsen/logrus" ) @@ -257,13 +258,15 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co // It populates the storage struct with token data, user information, and timestamps. func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStorage { storage := &CodexTokenStorage{ - IDToken: bundle.TokenData.IDToken, - AccessToken: bundle.TokenData.AccessToken, - RefreshToken: bundle.TokenData.RefreshToken, - AccountID: bundle.TokenData.AccountID, - LastRefresh: bundle.LastRefresh, - Email: bundle.TokenData.Email, - Expire: bundle.TokenData.Expire, + BaseTokenStorage: base.BaseTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + RefreshToken: bundle.TokenData.RefreshToken, + Email: bundle.TokenData.Email, + }, + IDToken: bundle.TokenData.IDToken, + AccountID: bundle.TokenData.AccountID, + LastRefresh: bundle.LastRefresh, + Expire: bundle.TokenData.Expire, } return storage diff --git a/pkg/llmproxy/auth/codex/token.go b/pkg/llmproxy/auth/codex/token.go index 4e8c3fb2ac..297ffdd003 100644 --- a/pkg/llmproxy/auth/codex/token.go +++ b/pkg/llmproxy/auth/codex/token.go @@ -4,54 +4,23 @@ package codex import ( - "encoding/json" "fmt" - "os" - "path/filepath" - "strings" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) -func sanitizeTokenFilePath(authFilePath string) (string, error) { - trimmed := strings.TrimSpace(authFilePath) - if trimmed == "" { - return "", fmt.Errorf("token file path is empty") - } - cleaned := filepath.Clean(trimmed) - parts := strings.FieldsFunc(cleaned, func(r rune) bool { - return r == '/' || r == '\\' - }) - for _, part 
:= range parts { - if part == ".." { - return "", fmt.Errorf("invalid token file path") - } - } - absPath, err := filepath.Abs(cleaned) - if err != nil { - return "", fmt.Errorf("failed to resolve token file path: %w", err) - } - return absPath, nil -} - // CodexTokenStorage stores OAuth2 token information for OpenAI Codex API authentication. // It maintains compatibility with the existing auth system while adding Codex-specific fields // for managing access tokens, refresh tokens, and user account information. type CodexTokenStorage struct { + base.BaseTokenStorage + // IDToken is the JWT ID token containing user claims and identity information. IDToken string `json:"id_token"` - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` - // RefreshToken is used to obtain new access tokens when the current one expires. - RefreshToken string `json:"refresh_token"` // AccountID is the OpenAI account identifier associated with this token. AccountID string `json:"account_id"` // LastRefresh is the timestamp of the last token refresh operation. LastRefresh string `json:"last_refresh"` - // Email is the OpenAI account email address associated with this token. - Email string `json:"email"` - // Type indicates the authentication provider type, always "codex" for this storage. - Type string `json:"type"` // Expire is the timestamp when the current access token expires. 
Expire string `json:"expired"` } @@ -66,27 +35,9 @@ type CodexTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "codex" - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("codex token: %w", err) } return nil - } diff --git a/pkg/llmproxy/auth/copilot/copilot_auth.go b/pkg/llmproxy/auth/copilot/copilot_auth.go index 0fac429994..bff26bece4 100644 --- a/pkg/llmproxy/auth/copilot/copilot_auth.go +++ b/pkg/llmproxy/auth/copilot/copilot_auth.go @@ -10,7 +10,8 @@ import ( "net/http" "time" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" log "github.com/sirupsen/logrus" ) @@ -164,11 +165,13 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo // CreateTokenStorage creates a new CopilotTokenStorage from auth bundle. 
func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage { return &CopilotTokenStorage{ - AccessToken: bundle.TokenData.AccessToken, - TokenType: bundle.TokenData.TokenType, - Scope: bundle.TokenData.Scope, - Username: bundle.Username, - Type: "github-copilot", + BaseTokenStorage: base.BaseTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + Type: "github-copilot", + }, + TokenType: bundle.TokenData.TokenType, + Scope: bundle.TokenData.Scope, + Username: bundle.Username, } } diff --git a/pkg/llmproxy/auth/copilot/token.go b/pkg/llmproxy/auth/copilot/token.go index 657428a982..cecb2e4441 100644 --- a/pkg/llmproxy/auth/copilot/token.go +++ b/pkg/llmproxy/auth/copilot/token.go @@ -4,20 +4,17 @@ package copilot import ( - "encoding/json" "fmt" - "os" - "path/filepath" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) // CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication. // It maintains compatibility with the existing auth system while adding Copilot-specific fields // for managing access tokens and user account information. type CopilotTokenStorage struct { - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` + base.BaseTokenStorage + // TokenType is the type of token, typically "bearer". TokenType string `json:"token_type"` // Scope is the OAuth2 scope granted to the token. @@ -26,8 +23,6 @@ type CopilotTokenStorage struct { ExpiresAt string `json:"expires_at,omitempty"` // Username is the GitHub username associated with this token. Username string `json:"username"` - // Type indicates the authentication provider type, always "github-copilot" for this storage. - Type string `json:"type"` } // CopilotTokenData holds the raw OAuth token response from GitHub. 
@@ -72,26 +67,9 @@ type DeviceCodeResponse struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "github-copilot" - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("copilot token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/gemini/gemini_auth.go b/pkg/llmproxy/auth/gemini/gemini_auth.go index 6553cd26d2..08badb1283 100644 --- a/pkg/llmproxy/auth/gemini/gemini_auth.go +++ b/pkg/llmproxy/auth/gemini/gemini_auth.go @@ -14,9 +14,10 @@ import ( "net/url" "time" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/codex" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/browser" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" log "github.com/sirupsen/logrus" @@ -204,9 +205,11 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf ifToken["universe_domain"] = "googleapis.com" ts := GeminiTokenStorage{ + BaseTokenStorage: base.BaseTokenStorage{ + Email: emailResult.String(), + }, Token: ifToken, ProjectID: projectID, - Email: 
emailResult.String(), } return &ts, nil diff --git a/pkg/llmproxy/auth/gemini/gemini_token.go b/pkg/llmproxy/auth/gemini/gemini_token.go index 1fb90dae57..32b05729a1 100644 --- a/pkg/llmproxy/auth/gemini/gemini_token.go +++ b/pkg/llmproxy/auth/gemini/gemini_token.go @@ -4,37 +4,33 @@ package gemini import ( - "encoding/json" "fmt" - "os" - "path/filepath" "strings" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" - log "github.com/sirupsen/logrus" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) // GeminiTokenStorage stores OAuth2 token information for Google Gemini API authentication. // It maintains compatibility with the existing auth system while adding Gemini-specific fields // for managing access tokens, refresh tokens, and user account information. +// +// Note: Gemini wraps its raw OAuth2 token inside the Token field (type any) rather than +// storing access/refresh tokens as top-level strings, so BaseTokenStorage.AccessToken and +// BaseTokenStorage.RefreshToken remain empty for this provider. type GeminiTokenStorage struct { + base.BaseTokenStorage + // Token holds the raw OAuth2 token data, including access and refresh tokens. Token any `json:"token"` // ProjectID is the Google Cloud Project ID associated with this token. ProjectID string `json:"project_id"` - // Email is the email address of the authenticated user. - Email string `json:"email"` - // Auto indicates if the project ID was automatically selected. Auto bool `json:"auto"` // Checked indicates if the associated Cloud AI API has been verified as enabled. Checked bool `json:"checked"` - - // Type indicates the authentication provider type, always "gemini" for this storage. - Type string `json:"type"` } // SaveTokenToFile serializes the Gemini token storage to a JSON file. 
@@ -47,28 +43,9 @@ type GeminiTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "gemini" - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - if errClose := f.Close(); errClose != nil { - log.Errorf("failed to close file: %v", errClose) - } - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("gemini token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/iflow/iflow_auth.go b/pkg/llmproxy/auth/iflow/iflow_auth.go index d1c8fe26b2..a4ead0e04c 100644 --- a/pkg/llmproxy/auth/iflow/iflow_auth.go +++ b/pkg/llmproxy/auth/iflow/iflow_auth.go @@ -13,7 +13,8 @@ import ( "strings" "time" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" log "github.com/sirupsen/logrus" ) @@ -243,14 +244,16 @@ func (ia *IFlowAuth) CreateTokenStorage(data *IFlowTokenData) *IFlowTokenStorage return nil } return &IFlowTokenStorage{ - AccessToken: data.AccessToken, - RefreshToken: data.RefreshToken, - LastRefresh: time.Now().Format(time.RFC3339), - Expire: data.Expire, - APIKey: data.APIKey, - Email: data.Email, - TokenType: data.TokenType, - Scope: data.Scope, + BaseTokenStorage: base.BaseTokenStorage{ + AccessToken: 
data.AccessToken, + RefreshToken: data.RefreshToken, + Email: data.Email, + }, + LastRefresh: time.Now().Format(time.RFC3339), + Expire: data.Expire, + APIKey: data.APIKey, + TokenType: data.TokenType, + Scope: data.Scope, } } @@ -528,12 +531,14 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS } return &IFlowTokenStorage{ + BaseTokenStorage: base.BaseTokenStorage{ + Email: data.Email, + Type: "iflow", + }, APIKey: data.APIKey, - Email: data.Email, Expire: data.Expire, Cookie: cookieToSave, LastRefresh: time.Now().Format(time.RFC3339), - Type: "iflow", } } diff --git a/pkg/llmproxy/auth/iflow/iflow_token.go b/pkg/llmproxy/auth/iflow/iflow_token.go index b67f7148c0..fb925a2f1a 100644 --- a/pkg/llmproxy/auth/iflow/iflow_token.go +++ b/pkg/llmproxy/auth/iflow/iflow_token.go @@ -1,48 +1,28 @@ package iflow import ( - "encoding/json" "fmt" - "os" - "path/filepath" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) // IFlowTokenStorage persists iFlow OAuth credentials alongside the derived API key. type IFlowTokenStorage struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - LastRefresh string `json:"last_refresh"` - Expire string `json:"expired"` - APIKey string `json:"api_key"` - Email string `json:"email"` - TokenType string `json:"token_type"` - Scope string `json:"scope"` - Cookie string `json:"cookie"` - Type string `json:"type"` + base.BaseTokenStorage + + LastRefresh string `json:"last_refresh"` + Expire string `json:"expired"` + APIKey string `json:"api_key"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Cookie string `json:"cookie"` } // SaveTokenToFile serialises the token storage to disk. 
func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "iflow" - if err = os.MkdirAll(filepath.Dir(safePath), 0o700); err != nil { - return fmt.Errorf("iflow token: create directory failed: %w", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("iflow token: create file failed: %w", err) - } - defer func() { _ = f.Close() }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("iflow token: encode token failed: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("iflow token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/kilo/kilo_token.go b/pkg/llmproxy/auth/kilo/kilo_token.go index bb09922167..71c17e1dc4 100644 --- a/pkg/llmproxy/auth/kilo/kilo_token.go +++ b/pkg/llmproxy/auth/kilo/kilo_token.go @@ -3,18 +3,20 @@ package kilo import ( - "encoding/json" "fmt" - "os" - "path/filepath" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" - log "github.com/sirupsen/logrus" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) // KiloTokenStorage stores token information for Kilo AI authentication. +// +// Note: Kilo uses a proprietary token format stored under the "kilocodeToken" JSON key +// rather than the standard "access_token" key, so BaseTokenStorage.AccessToken is not +// populated for this provider. The Email and Type fields from BaseTokenStorage are used. type KiloTokenStorage struct { - // Token is the Kilo access token. + base.BaseTokenStorage + + // Token is the Kilo access token (serialised as "kilocodeToken" for Kilo compatibility). Token string `json:"kilocodeToken"` // OrganizationID is the Kilo organization ID. @@ -22,38 +24,13 @@ type KiloTokenStorage struct { // Model is the default model to use. 
Model string `json:"kilocodeModel"` - - // Email is the email address of the authenticated user. - Email string `json:"email"` - - // Type indicates the authentication provider type, always "kilo" for this storage. - Type string `json:"type"` } // SaveTokenToFile serializes the Kilo token storage to a JSON file. func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "kilo" - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - if errClose := f.Close(); errClose != nil { - log.Errorf("failed to close file: %v", errClose) - } - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("kilo token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/kimi/kimi.go b/pkg/llmproxy/auth/kimi/kimi.go index 42978882b7..2a5ebb6716 100644 --- a/pkg/llmproxy/auth/kimi/kimi.go +++ b/pkg/llmproxy/auth/kimi/kimi.go @@ -15,7 +15,8 @@ import ( "time" "github.com/google/uuid" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/internal/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" log "github.com/sirupsen/logrus" ) @@ -78,13 +79,15 @@ func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339) } return &KimiTokenStorage{ - AccessToken: bundle.TokenData.AccessToken, - RefreshToken: 
bundle.TokenData.RefreshToken, - TokenType: bundle.TokenData.TokenType, - Scope: bundle.TokenData.Scope, - DeviceID: strings.TrimSpace(bundle.DeviceID), - Expired: expired, - Type: "kimi", + BaseTokenStorage: base.BaseTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + RefreshToken: bundle.TokenData.RefreshToken, + Type: "kimi", + }, + TokenType: bundle.TokenData.TokenType, + Scope: bundle.TokenData.Scope, + DeviceID: strings.TrimSpace(bundle.DeviceID), + Expired: expired, } } diff --git a/pkg/llmproxy/auth/kimi/token.go b/pkg/llmproxy/auth/kimi/token.go index 983cdc306a..61e410c40e 100644 --- a/pkg/llmproxy/auth/kimi/token.go +++ b/pkg/llmproxy/auth/kimi/token.go @@ -4,21 +4,16 @@ package kimi import ( - "encoding/json" "fmt" - "os" - "path/filepath" "time" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/base" ) // KimiTokenStorage stores OAuth2 token information for Kimi API authentication. type KimiTokenStorage struct { - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` - // RefreshToken is the OAuth2 refresh token used to obtain new access tokens. - RefreshToken string `json:"refresh_token"` + base.BaseTokenStorage + // TokenType is the type of token, typically "Bearer". TokenType string `json:"token_type"` // Scope is the OAuth2 scope granted to the token. @@ -27,8 +22,6 @@ type KimiTokenStorage struct { DeviceID string `json:"device_id,omitempty"` // Expired is the RFC3339 timestamp when the access token expires. Expired string `json:"expired,omitempty"` - // Type indicates the authentication provider type, always "kimi" for this storage. - Type string `json:"type"` } // KimiTokenData holds the raw OAuth token response from Kimi. @@ -71,29 +64,9 @@ type DeviceCodeResponse struct { // SaveTokenToFile serializes the Kimi token storage to a JSON file. 
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error { - safePath, err := misc.ResolveSafeFilePath(authFilePath) - if err != nil { - return fmt.Errorf("invalid token file path: %w", err) - } - misc.LogSavingCredentials(safePath) ts.Type = "kimi" - - if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(safePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - encoder := json.NewEncoder(f) - encoder.SetIndent("", " ") - if err = encoder.Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) + if err := ts.Save(authFilePath, ts); err != nil { + return fmt.Errorf("kimi token: %w", err) } return nil } diff --git a/pkg/llmproxy/auth/qwen/qwen_token.go b/pkg/llmproxy/auth/qwen/qwen_token.go index 64c1890c03..1163895146 100644 --- a/pkg/llmproxy/auth/qwen/qwen_token.go +++ b/pkg/llmproxy/auth/qwen/qwen_token.go @@ -4,33 +4,35 @@ package qwen import ( - "encoding/json" "fmt" "os" "path/filepath" "strings" + "github.com/KooshaPari/phenotype-go-kit/pkg/auth" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/misc" ) -// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication. -// It maintains compatibility with the existing auth system while adding Qwen-specific fields -// for managing access tokens, refresh tokens, and user account information. +// QwenTokenStorage extends BaseTokenStorage with Qwen-specific fields for managing +// access tokens, refresh tokens, and user account information. +// It embeds auth.BaseTokenStorage to inherit shared token management functionality. type QwenTokenStorage struct { - // AccessToken is the OAuth2 access token used for authenticating API requests. - AccessToken string `json:"access_token"` - // RefreshToken is used to obtain new access tokens when the current one expires. 
- RefreshToken string `json:"refresh_token"` - // LastRefresh is the timestamp of the last token refresh operation. - LastRefresh string `json:"last_refresh"` + *auth.BaseTokenStorage + // ResourceURL is the base URL for API requests. ResourceURL string `json:"resource_url"` - // Email is the Qwen account email address associated with this token. - Email string `json:"email"` - // Type indicates the authentication provider type, always "qwen" for this storage. - Type string `json:"type"` - // Expire is the timestamp when the current access token expires. - Expire string `json:"expired"` +} + +// NewQwenTokenStorage creates a new QwenTokenStorage instance with the given file path. +// Parameters: +// - filePath: The full path where the token file should be saved/loaded +// +// Returns: +// - *QwenTokenStorage: A new QwenTokenStorage instance +func NewQwenTokenStorage(filePath string) *QwenTokenStorage { + return &QwenTokenStorage{ + BaseTokenStorage: auth.NewBaseTokenStorage(filePath), + } } // SaveTokenToFile serializes the Qwen token storage to a JSON file. 
@@ -44,27 +46,16 @@ type QwenTokenStorage struct { // - error: An error if the operation fails, nil otherwise func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { misc.LogSavingCredentials(authFilePath) - ts.Type = "qwen" - cleanPath, err := cleanTokenFilePath(authFilePath, "qwen token") - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(cleanPath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) + if ts.BaseTokenStorage == nil { + return fmt.Errorf("qwen token: base token storage is nil") } - f, err := os.Create(cleanPath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) + if _, err := cleanTokenFilePath(authFilePath, "qwen token"); err != nil { + return err } - defer func() { - _ = f.Close() - }() - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) - } - return nil + ts.BaseTokenStorage.Type = "qwen" + return ts.BaseTokenStorage.Save() } func cleanTokenFilePath(path, scope string) (string, error) { diff --git a/pkg/llmproxy/client/client.go b/pkg/llmproxy/client/client.go new file mode 100644 index 0000000000..f767634dea --- /dev/null +++ b/pkg/llmproxy/client/client.go @@ -0,0 +1,232 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// Client is an HTTP client for the cliproxyapi++ proxy server. +// +// It covers: +// - GET /v1/models — list available models +// - POST /v1/chat/completions — chat completions (non-streaming) +// - POST /v1/responses — OpenAI Responses API passthrough +// - GET / — health / reachability check +// +// Streaming variants are deliberately out of scope for this package; callers +// that need SSE should use [net/http] directly against [Client.BaseURL]. +type Client struct { + cfg clientConfig + http *http.Client +} + +// New creates a Client with the given options. 
+// +// Defaults: base URL http://127.0.0.1:8318, timeout 120 s, no auth. +func New(opts ...Option) *Client { + cfg := defaultConfig() + for _, o := range opts { + o(&cfg) + } + cfg.baseURL = strings.TrimRight(cfg.baseURL, "/") + return &Client{ + cfg: cfg, + http: &http.Client{Timeout: cfg.httpTimeout}, + } +} + +// BaseURL returns the proxy base URL this client is configured against. +func (c *Client) BaseURL() string { return c.cfg.baseURL } + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +func (c *Client) newRequest(ctx context.Context, method, path string, body any) (*http.Request, error) { + var bodyReader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: marshal request body: %w", err) + } + bodyReader = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, c.cfg.baseURL+path, bodyReader) + if err != nil { + return nil, err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + // LLM API key (Bearer token for /v1/* routes) + if c.cfg.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+c.cfg.apiKey) + } + return req, nil +} + +func (c *Client) do(req *http.Request) ([]byte, int, error) { + resp, err := c.http.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("cliproxy/client: HTTP %s %s: %w", req.Method, req.URL.Path, err) + } + defer func() { _ = resp.Body.Close() }() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("cliproxy/client: read response body: %w", err) + } + return data, resp.StatusCode, nil +} + +func (c *Client) doJSON(req *http.Request, out any) error { + data, code, err := c.do(req) + if err != nil { + return err + } + if code >= 400 { + return parseAPIError(code, data) + } + if 
out == nil { + return nil + } + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("cliproxy/client: decode response (HTTP %d): %w", code, err) + } + return nil +} + +// parseAPIError extracts a structured error from a non-2xx response body. +// It mirrors the error shape produced by _make_error_body in the Python adapter. +func parseAPIError(code int, body []byte) *APIError { + var envelope struct { + Error struct { + Message string `json:"message"` + Code any `json:"code"` + } `json:"error"` + } + msg := strings.TrimSpace(string(body)) + if err := json.Unmarshal(body, &envelope); err == nil && envelope.Error.Message != "" { + msg = envelope.Error.Message + } + if msg == "" { + msg = fmt.Sprintf("proxy returned HTTP %d", code) + } + return &APIError{StatusCode: code, Message: msg, Code: envelope.Error.Code} +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +// Health performs a lightweight GET / against the proxy and reports whether it +// is reachable. A nil error means the server responded with HTTP 2xx. +func (c *Client) Health(ctx context.Context) error { + req, err := c.newRequest(ctx, http.MethodGet, "/", nil) + if err != nil { + return err + } + _, code, err := c.do(req) + if err != nil { + return err + } + if code >= 400 { + return fmt.Errorf("cliproxy/client: health check failed with HTTP %d", code) + } + return nil +} + +// ListModels calls GET /v1/models and returns the normalised model list. +// +// cliproxyapi++ transforms the upstream OpenAI-compatible {"data":[...]} shape +// into {"models":[...]} for Codex compatibility. This method handles both +// shapes transparently. 
+func (c *Client) ListModels(ctx context.Context) (*ModelsResponse, error) { + req, err := c.newRequest(ctx, http.MethodGet, "/v1/models", nil) + if err != nil { + return nil, err + } + + // Use the underlying Do directly so we can read the response headers. + httpResp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: GET /v1/models: %w", err) + } + defer func() { _ = httpResp.Body.Close() }() + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: read /v1/models body: %w", err) + } + if httpResp.StatusCode >= 400 { + return nil, parseAPIError(httpResp.StatusCode, data) + } + + // The proxy normalises the response to {"models":[...]}. + // Fall back to the raw OpenAI {"data":[...], "object":"list"} shape for + // consumers that hit the upstream directly. + var result ModelsResponse + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode /v1/models: %w", err) + } + + if modelsJSON, ok := raw["models"]; ok { + if err := json.Unmarshal(modelsJSON, &result.Models); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode models array: %w", err) + } + } else if dataJSON, ok := raw["data"]; ok { + if err := json.Unmarshal(dataJSON, &result.Models); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode data array: %w", err) + } + } + + // Capture ETag from response header (set by the proxy for cache validation). + result.ETag = httpResp.Header.Get("x-models-etag") + + return &result, nil +} + +// ChatCompletion sends a non-streaming POST /v1/chat/completions request. +// +// For streaming completions use net/http directly; this package does not wrap +// SSE streams in order to avoid pulling in additional dependencies. 
+func (c *Client) ChatCompletion(ctx context.Context, r ChatCompletionRequest) (*ChatCompletionResponse, error) { + r.Stream = false // enforce non-streaming + req, err := c.newRequest(ctx, http.MethodPost, "/v1/chat/completions", r) + if err != nil { + return nil, err + } + var out ChatCompletionResponse + if err := c.doJSON(req, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Responses sends a non-streaming POST /v1/responses request (OpenAI Responses +// API). The proxy transparently bridges this to /v1/chat/completions when the +// backend does not natively support the Responses endpoint. +// +// The raw decoded JSON is returned as map[string]any to remain forward- +// compatible as the Responses API schema evolves. +func (c *Client) Responses(ctx context.Context, r ResponsesRequest) (map[string]any, error) { + r.Stream = false + req, err := c.newRequest(ctx, http.MethodPost, "/v1/responses", r) + if err != nil { + return nil, err + } + var out map[string]any + if err := c.doJSON(req, &out); err != nil { + return nil, err + } + return out, nil +} diff --git a/pkg/llmproxy/client/client_test.go b/pkg/llmproxy/client/client_test.go new file mode 100644 index 0000000000..2c6da92194 --- /dev/null +++ b/pkg/llmproxy/client/client_test.go @@ -0,0 +1,339 @@ +package client_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/client" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func newTestServer(t *testing.T, handler http.Handler) (*httptest.Server, *client.Client) { + t.Helper() + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + c := client.New( + client.WithBaseURL(srv.URL), + client.WithTimeout(5*time.Second), + ) + return srv, c +} + +func writeJSON(w http.ResponseWriter, status int, v 
any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +// --------------------------------------------------------------------------- +// Health +// --------------------------------------------------------------------------- + +func TestHealth_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + writeJSON(w, 200, map[string]string{"status": "ok"}) + })) + + if err := c.Health(context.Background()); err != nil { + t.Fatalf("Health() unexpected error: %v", err) + } +} + +func TestHealth_Unreachable(t *testing.T) { + // Point at a port nothing is listening on. + c := client.New( + client.WithBaseURL("http://127.0.0.1:1"), + client.WithTimeout(500*time.Millisecond), + ) + if err := c.Health(context.Background()); err == nil { + t.Fatal("Health() expected error for unreachable server, got nil") + } +} + +func TestHealth_ServerError(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 503, map[string]any{ + "error": map[string]any{"message": "service unavailable", "code": 503}, + }) + })) + if err := c.Health(context.Background()); err == nil { + t.Fatal("Health() expected error for 503, got nil") + } +} + +// --------------------------------------------------------------------------- +// ListModels +// --------------------------------------------------------------------------- + +func TestListModels_ProxyShape(t *testing.T) { + // cliproxyapi++ normalised shape: {"models": [...]} + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/v1/models" { + http.NotFound(w, r) + return + } + w.Header().Set("x-models-etag", "abc123") + writeJSON(w, 200, map[string]any{ + "models": []map[string]any{ + {"id": "anthropic/claude-opus-4-6", 
"object": "model", "owned_by": "anthropic"}, + {"id": "openai/gpt-4o", "object": "model", "owned_by": "openai"}, + }, + }) + })) + + resp, err := c.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if len(resp.Models) != 2 { + t.Fatalf("expected 2 models, got %d", len(resp.Models)) + } + if resp.Models[0].ID != "anthropic/claude-opus-4-6" { + t.Errorf("unexpected first model ID: %s", resp.Models[0].ID) + } +} + +func TestListModels_OpenAIShape(t *testing.T) { + // Raw upstream OpenAI shape: {"data": [...], "object": "list"} + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 200, map[string]any{ + "object": "list", + "data": []map[string]any{ + {"id": "gpt-4o", "object": "model", "owned_by": "openai"}, + }, + }) + })) + + resp, err := c.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if len(resp.Models) != 1 || resp.Models[0].ID != "gpt-4o" { + t.Errorf("unexpected models: %+v", resp.Models) + } +} + +func TestListModels_Error(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 401, map[string]any{ + "error": map[string]any{"message": "unauthorized", "code": 401}, + }) + })) + + _, err := c.ListModels(context.Background()) + if err == nil { + t.Fatal("ListModels() expected error for 401, got nil") + } + if _, ok := err.(*client.APIError); !ok { + t.Logf("error type: %T — not an *client.APIError, that is acceptable", err) + } +} + +// --------------------------------------------------------------------------- +// ChatCompletion +// --------------------------------------------------------------------------- + +func TestChatCompletion_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/v1/chat/completions" { + 
http.NotFound(w, r) + return + } + // Decode and validate request body + var body client.ChatCompletionRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "bad request", 400) + return + } + if body.Stream { + http.Error(w, "client must not set stream=true", 400) + return + } + writeJSON(w, 200, map[string]any{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1700000000, + "model": body.Model, + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, + }, + }) + })) + + resp, err := c.ChatCompletion(context.Background(), client.ChatCompletionRequest{ + Model: "anthropic/claude-opus-4-6", + Messages: []client.ChatMessage{ + {Role: "user", Content: "Say hi"}, + }, + }) + if err != nil { + t.Fatalf("ChatCompletion() unexpected error: %v", err) + } + if len(resp.Choices) == 0 { + t.Fatal("expected at least one choice") + } + if resp.Choices[0].Message.Content != "Hello!" 
{ + t.Errorf("unexpected content: %q", resp.Choices[0].Message.Content) + } +} + +func TestChatCompletion_4xx(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 429, map[string]any{ + "error": map[string]any{"message": "rate limit exceeded", "code": 429}, + }) + })) + + _, err := c.ChatCompletion(context.Background(), client.ChatCompletionRequest{ + Model: "any", + Messages: []client.ChatMessage{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error for 429") + } +} + +// --------------------------------------------------------------------------- +// Responses +// --------------------------------------------------------------------------- + +func TestResponses_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + writeJSON(w, 200, map[string]any{ + "id": "resp_test", + "object": "response", + "output": []map[string]any{ + {"type": "message", "role": "assistant", "content": []map[string]any{ + {"type": "text", "text": "Hello from responses API"}, + }}, + }, + }) + })) + + out, err := c.Responses(context.Background(), client.ResponsesRequest{ + Model: "anthropic/claude-opus-4-6", + Input: "Say hi", + }) + if err != nil { + t.Fatalf("Responses() unexpected error: %v", err) + } + if out["id"] != "resp_test" { + t.Errorf("unexpected id: %v", out["id"]) + } +} + +// --------------------------------------------------------------------------- +// Options +// --------------------------------------------------------------------------- + +func TestWithAPIKey_SetsAuthorizationHeader(t *testing.T) { + var gotAuth string + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + // Rebuild 
with API key + _, c = newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + _ = c // silence unused warning; we rebuild below + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + t.Cleanup(srv.Close) + + c = client.New( + client.WithBaseURL(srv.URL), + client.WithAPIKey("sk-test-key"), + client.WithTimeout(5*time.Second), + ) + if _, err := c.ListModels(context.Background()); err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if gotAuth != "Bearer sk-test-key" { + t.Errorf("expected 'Bearer sk-test-key', got %q", gotAuth) + } +} + +func TestBaseURL(t *testing.T) { + c := client.New(client.WithBaseURL("http://localhost:9999")) + if c.BaseURL() != "http://localhost:9999" { + t.Errorf("BaseURL() = %q, want %q", c.BaseURL(), "http://localhost:9999") + } +} + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +func TestAPIError_Message(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 503, map[string]any{ + "error": map[string]any{ + "message": "service unavailable — no providers matched", + "code": 503, + }, + }) + })) + + _, err := c.ListModels(context.Background()) + if err == nil { + t.Fatal("expected error") + } + apiErr, ok := err.(*client.APIError) + if !ok { + t.Fatalf("expected *client.APIError, got %T", err) + } + if apiErr.StatusCode != 503 { + t.Errorf("StatusCode = %d, want 503", apiErr.StatusCode) + } + if apiErr.Message == "" { + t.Error("Message must not be empty") + } +} + +// --------------------------------------------------------------------------- +// Context 
cancellation +// --------------------------------------------------------------------------- + +func TestContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Block until client cancels + <-r.Context().Done() + w.WriteHeader(200) + })) + t.Cleanup(srv.Close) + + c := client.New(client.WithBaseURL(srv.URL), client.WithTimeout(5*time.Second)) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + if err := c.Health(ctx); err == nil { + t.Fatal("expected error due to context cancellation") + } +} diff --git a/pkg/llmproxy/client/types.go b/pkg/llmproxy/client/types.go new file mode 100644 index 0000000000..216dd69d71 --- /dev/null +++ b/pkg/llmproxy/client/types.go @@ -0,0 +1,147 @@ +// Package client provides a Go SDK for the cliproxyapi++ HTTP proxy API. +// +// It covers the core LLM proxy surface: model listing, chat completions, the +// Responses API, and the proxy process lifecycle (start/stop/health). +// +// # Migration note +// +// This package is the canonical Go replacement for the Python adapter code +// that previously lived in thegent/src/thegent/cliproxy_adapter.py and +// related helpers. Any new consumer should import this package rather than +// re-implementing raw HTTP calls. +package client + +import "time" + +// --------------------------------------------------------------------------- +// Model types +// --------------------------------------------------------------------------- + +// Model is a single entry from GET /v1/models. +type Model struct { + ID string `json:"id"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +// ModelsResponse is the envelope returned by GET /v1/models. +// cliproxyapi++ normalises the upstream shape into {"models": [...]}. 
+type ModelsResponse struct { + Models []Model `json:"models"` + // ETag is populated from the x-models-etag response header when present. + ETag string `json:"-"` +} + +// --------------------------------------------------------------------------- +// Chat completions types +// --------------------------------------------------------------------------- + +// ChatMessage is a single message in a chat conversation. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatCompletionRequest is the body for POST /v1/chat/completions. +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + // MaxTokens limits the number of tokens generated. + MaxTokens *int `json:"max_tokens,omitempty"` + // Temperature controls randomness (0–2). + Temperature *float64 `json:"temperature,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// Usage holds token counts reported by the backend. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + Cost float64 `json:"cost,omitempty"` +} + +// ChatCompletionResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +// --------------------------------------------------------------------------- +// Responses API types (POST /v1/responses) +// --------------------------------------------------------------------------- + +// ResponsesRequest is the body for POST /v1/responses (OpenAI Responses API). 
+type ResponsesRequest struct { + Model string `json:"model"` + Input any `json:"input"` + Stream bool `json:"stream,omitempty"` +} + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +// APIError is returned when the server responds with a non-2xx status code. +type APIError struct { + StatusCode int + Message string + Code any +} + +func (e *APIError) Error() string { + return e.Message +} + +// --------------------------------------------------------------------------- +// Client options +// --------------------------------------------------------------------------- + +// Option configures a [Client]. +type Option func(*clientConfig) + +type clientConfig struct { + baseURL string + apiKey string + secretKey string + httpTimeout time.Duration +} + +func defaultConfig() clientConfig { + return clientConfig{ + baseURL: "http://127.0.0.1:8318", + httpTimeout: 120 * time.Second, + } +} + +// WithBaseURL overrides the proxy base URL (default: http://127.0.0.1:8318). +func WithBaseURL(u string) Option { + return func(c *clientConfig) { c.baseURL = u } +} + +// WithAPIKey sets the Authorization: Bearer header for LLM API calls. +func WithAPIKey(key string) Option { + return func(c *clientConfig) { c.apiKey = key } +} + +// WithSecretKey sets the management API bearer token (used for /v0/management/* routes). +func WithSecretKey(key string) Option { + return func(c *clientConfig) { c.secretKey = key } +} + +// WithTimeout sets the HTTP client timeout (default: 120s). +func WithTimeout(d time.Duration) Option { + return func(c *clientConfig) { c.httpTimeout = d } +} diff --git a/pkg/llmproxy/executor/kiro_auth.go b/pkg/llmproxy/executor/kiro_auth.go new file mode 100644 index 0000000000..2adf85d76f --- /dev/null +++ b/pkg/llmproxy/executor/kiro_auth.go @@ -0,0 +1,397 @@ +// Package executor provides HTTP request execution for various AI providers. 
+// This file contains Kiro-specific authentication handling logic. +package executor + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + kiroauth "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/kiro" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" + cliproxyauth "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// kiroCredentials extracts access token and profile ARN from auth object. +func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { + if auth == nil { + return "", "" + } + if auth.Metadata != nil { + if token, ok := auth.Metadata["access_token"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profile_arn"].(string); ok { + profileArn = arn + } + } + if accessToken == "" && auth.Attributes != nil { + accessToken = auth.Attributes["access_token"] + profileArn = auth.Attributes["profile_arn"] + } + if accessToken == "" && auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profileArn"].(string); ok { + profileArn = arn + } + } + return accessToken, profileArn +} + +// getTokenKey returns a unique key for rate limiting based on auth credentials. +func getTokenKey(auth *cliproxyauth.Auth) string { + if auth != nil && auth.ID != "" { + return auth.ID + } + accessToken, _ := kiroCredentials(auth) + if len(accessToken) > 16 { + return accessToken[:16] + } + return accessToken +} + +// isIDCAuth checks if this auth object uses IDC authentication. 
+func isIDCAuth(auth *cliproxyauth.Auth) bool { + if auth == nil || auth.Metadata == nil { + return false + } + authMethod := getMetadataString(auth.Metadata, "auth_method", "authMethod") + return strings.ToLower(authMethod) == "idc" || + strings.ToLower(authMethod) == "builder-id" +} + +// applyDynamicFingerprint applies fingerprint-based headers to requests. +func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) { + if isIDCAuth(auth) { + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + req.Header.Set("User-Agent", fp.BuildUserAgent()) + req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) + req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) + log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)", + tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion) + } else { + req.Header.Set("User-Agent", kiroUserAgent) + req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } +} + +// PrepareRequest prepares the HTTP request before execution. +func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if req == nil { + return nil + } + accessToken, _ := kiroCredentials(auth) + if strings.TrimSpace(accessToken) == "" { + return statusErr{code: http.StatusUnauthorized, msg: "missing access token"} + } + applyDynamicFingerprint(req, auth) + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + req.Header.Set("Authorization", "Bearer "+accessToken) + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(req, attrs) + return nil +} + +// HttpRequest injects Kiro credentials into the request and executes it. 
+func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("kiro executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil { + return nil, errPrepare + } + httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0) + return httpClient.Do(httpReq) +} + +// Refresh performs token refresh using appropriate OAuth2 flow. +func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + e.refreshMu.Lock() + defer e.refreshMu.Unlock() + + var authID string + if auth != nil { + authID = auth.ID + } else { + authID = "" + } + log.Debugf("kiro executor: refresh called for auth %s", authID) + if auth == nil { + return nil, fmt.Errorf("kiro executor: auth is nil") + } + + if auth.Metadata != nil { + if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { + if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { + if time.Since(refreshTime) < 30*time.Second { + log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") + return auth, nil + } + } + } + if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { + if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { + if time.Until(expTime) > 20*time.Minute { + log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) + updated := auth.Clone() + nextRefresh := expTime.Add(-20 * time.Minute) + minNextRefresh := time.Now().Add(30 * time.Second) + if nextRefresh.Before(minNextRefresh) { + nextRefresh = minNextRefresh + } + updated.NextRefreshAfter = nextRefresh + log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) + return updated, nil + } + } + } + } + + var 
refreshToken string + var clientID, clientSecret string + var authMethod string + var region, startURL string + + if auth.Metadata != nil { + refreshToken = getMetadataString(auth.Metadata, "refresh_token", "refreshToken") + clientID = getMetadataString(auth.Metadata, "client_id", "clientId") + clientSecret = getMetadataString(auth.Metadata, "client_secret", "clientSecret") + authMethod = strings.ToLower(getMetadataString(auth.Metadata, "auth_method", "authMethod")) + region = getMetadataString(auth.Metadata, "region") + startURL = getMetadataString(auth.Metadata, "start_url", "startUrl") + } + + if refreshToken == "" { + return nil, fmt.Errorf("kiro executor: refresh token not found") + } + + var tokenData *kiroauth.KiroTokenData + var err error + + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + default: + log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") + oauth := kiroauth.NewKiroOAuth(e.cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err) + } + + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + 
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) + if tokenData.ProfileArn != "" { + updated.Metadata["profile_arn"] = tokenData.ProfileArn + } + if tokenData.AuthMethod != "" { + updated.Metadata["auth_method"] = tokenData.AuthMethod + } + if tokenData.Provider != "" { + updated.Metadata["provider"] = tokenData.Provider + } + if tokenData.ClientID != "" { + updated.Metadata["client_id"] = tokenData.ClientID + } + if tokenData.ClientSecret != "" { + updated.Metadata["client_secret"] = tokenData.ClientSecret + } + if tokenData.Region != "" { + updated.Metadata["region"] = tokenData.Region + } + if tokenData.StartURL != "" { + updated.Metadata["start_url"] = tokenData.StartURL + } + + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + updated.Attributes["access_token"] = tokenData.AccessToken + if tokenData.ProfileArn != "" { + updated.Attributes["profile_arn"] = tokenData.ProfileArn + } + + if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { + updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute) + } + + log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) + return updated, nil +} + +// persistRefreshedAuth persists a refreshed auth record to disk. 
+func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") + } + + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return fmt.Errorf("kiro executor: auth has no file path or filename") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return fmt.Errorf("kiro executor: cannot determine auth file path") + } + } + + raw, err := json.Marshal(auth.Metadata) + if err != nil { + return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) + } + + tmp := authPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) + } + if err := os.Rename(tmp, authPath); err != nil { + return fmt.Errorf("kiro executor: rename auth file failed: %w", err) + } + + log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) + return nil +} + +// reloadAuthFromFile reloads the auth object from its persistent storage. 
+func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, fmt.Errorf("kiro executor: cannot reload nil auth") + } + + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload") + } + } + + raw, err := os.ReadFile(authPath) + if err != nil { + return nil, fmt.Errorf("kiro executor: read auth file failed: %w", err) + } + + var metadata map[string]any + if err := json.Unmarshal(raw, &metadata); err != nil { + return nil, fmt.Errorf("kiro executor: unmarshal auth metadata failed: %w", err) + } + + reloaded := auth.Clone() + reloaded.Metadata = metadata + log.Debugf("kiro executor: reloaded auth from %s", authPath) + return reloaded, nil +} + +// isTokenExpired checks if the access token has expired by decoding JWT. 
+func (e *KiroExecutor) isTokenExpired(accessToken string) bool { + if accessToken == "" { + return true + } + + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return false + } + + payload := parts[1] + switch len(payload) % 4 { + case 1: + payload += "===" + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64.RawStdEncoding.DecodeString(payload) + if err != nil { + log.Debugf("kiro: failed to decode JWT payload: %v", err) + return false + } + + var claims map[string]any + if err := json.Unmarshal(decoded, &claims); err != nil { + log.Debugf("kiro: failed to parse JWT claims: %v", err) + return false + } + + if exp, ok := claims["exp"]; ok { + var expiresAt int64 + switch v := exp.(type) { + case float64: + expiresAt = int64(v) + case int64: + expiresAt = v + default: + return false + } + + now := time.Now().Unix() + if now > expiresAt { + log.Debugf("kiro: token expired at %d (now: %d)", expiresAt, now) + return true + } + } + + return false +} diff --git a/pkg/llmproxy/executor/kiro_executor.go b/pkg/llmproxy/executor/kiro_executor.go index d674ab3db1..0a25f4e99c 100644 --- a/pkg/llmproxy/executor/kiro_executor.go +++ b/pkg/llmproxy/executor/kiro_executor.go @@ -23,11 +23,11 @@ import ( "time" "github.com/google/uuid" + kiroauth "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/kiro" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" kiroclaude "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/claude" kirocommon "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/common" kiroopenai "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/openai" - kiroauth "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/kiro" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/config" "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" cliproxyauth 
"github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/executor" @@ -324,238 +324,12 @@ func newKiroHTTPClientWithPooling(ctx context.Context, cfg *config.Config, auth return pooledClient } -// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. -// This solves the "triple mismatch" problem where different endpoints require matching -// Origin and X-Amz-Target header values. -// -// Based on reference implementations: -// - amq2api-main: Uses Amazon Q endpoint with CLI origin and AmazonQDeveloperStreamingService target -// - AIClient-2-API: Uses CodeWhisperer endpoint with AI_EDITOR origin and AmazonCodeWhispererStreamingService target -type kiroEndpointConfig struct { - URL string // Endpoint URL - Origin string // Request Origin: "CLI" for Amazon Q quota, "AI_EDITOR" for Kiro IDE quota - AmzTarget string // X-Amz-Target header value - Name string // Endpoint name for logging -} - -// kiroDefaultRegion is the default AWS region for Kiro API endpoints. -// Used when no region is specified in auth metadata. -const kiroDefaultRegion = "us-east-1" - -// extractRegionFromProfileARN extracts the AWS region from a ProfileARN. -// ARN format: arn:aws:codewhisperer:REGION:ACCOUNT:profile/PROFILE_ID -// Returns empty string if region cannot be extracted. -func extractRegionFromProfileARN(profileArn string) string { - if profileArn == "" { - return "" - } - parts := strings.Split(profileArn, ":") - if len(parts) >= 4 && parts[3] != "" { - return parts[3] - } - return "" -} - -// buildKiroEndpointConfigs creates endpoint configurations for the specified region. -// This enables dynamic region support for Enterprise/IdC users in non-us-east-1 regions. 
-// -// Uses Q endpoint (q.{region}.amazonaws.com) as primary for ALL auth types: -// - Works universally across all AWS regions (CodeWhisperer endpoint only exists in us-east-1) -// - Uses /generateAssistantResponse path with AI_EDITOR origin -// - Does NOT require X-Amz-Target header -// -// The AmzTarget field is kept for backward compatibility but should be empty -// to indicate that the header should NOT be set. -func buildKiroEndpointConfigs(region string) []kiroEndpointConfig { - if region == "" { - region = kiroDefaultRegion - } - return []kiroEndpointConfig{ - { - // Primary: Q endpoint - works for all regions and auth types - URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region), - Origin: "AI_EDITOR", - AmzTarget: "", // Empty = don't set X-Amz-Target header - Name: "AmazonQ", - }, - { - // Fallback: CodeWhisperer endpoint (legacy, only works in us-east-1) - URL: fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", region), - Origin: "AI_EDITOR", - AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", - Name: "CodeWhisperer", - }, - } -} - -// resolveKiroAPIRegion determines the AWS region for Kiro API calls. -// Region priority: -// 1. auth.Metadata["api_region"] - explicit API region override -// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource -// 3. 
kiroDefaultRegion (us-east-1) - fallback -// Note: OIDC "region" is NOT used - it's for token refresh, not API calls -func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string { - if auth == nil || auth.Metadata == nil { - return kiroDefaultRegion - } - // Priority 1: Explicit api_region override - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - log.Debugf("kiro: using region %s (source: api_region)", r) - return r - } - // Priority 2: Extract from ProfileARN - if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { - if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { - log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion) - return arnRegion - } - } - // Note: OIDC "region" field is NOT used for API endpoint - // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) - // Using OIDC region for API calls causes DNS failures - log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion) - return kiroDefaultRegion -} - -// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region. -// Prefer using buildKiroEndpointConfigs(region) for dynamic region support. -var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion) - -// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. -// Supports dynamic region based on auth metadata "api_region", "profile_arn", or "region" field. -// Supports reordering based on "preferred_endpoint" in auth metadata/attributes. -// -// Region priority: -// 1. auth.Metadata["api_region"] - explicit API region override -// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource -// 3. 
kiroDefaultRegion (us-east-1) - fallback -// Note: OIDC "region" is NOT used - it's for token refresh, not API calls -func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { - if auth == nil { - return kiroEndpointConfigs - } - - // Determine API region using shared resolution logic - region := resolveKiroAPIRegion(auth) - - // Build endpoint configs for the specified region - endpointConfigs := buildKiroEndpointConfigs(region) - - // For IDC auth, use Q endpoint with AI_EDITOR origin - // IDC tokens work with Q endpoint using Bearer auth - // The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC) - // NOT in how API calls are made - both Social and IDC use the same endpoint/origin - if auth.Metadata != nil { - authMethod, _ := auth.Metadata["auth_method"].(string) - if strings.ToLower(authMethod) == "idc" { - log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region) - return endpointConfigs - } - } - - // Check for preference - var preference string - if auth.Metadata != nil { - if p, ok := auth.Metadata["preferred_endpoint"].(string); ok { - preference = p - } - } - // Check attributes as fallback (e.g. 
from HTTP headers) - if preference == "" && auth.Attributes != nil { - preference = auth.Attributes["preferred_endpoint"] - } - - if preference == "" { - return endpointConfigs - } - - preference = strings.ToLower(strings.TrimSpace(preference)) - - // Create new slice to avoid modifying global state - var sorted []kiroEndpointConfig - var remaining []kiroEndpointConfig - - for _, cfg := range endpointConfigs { - name := strings.ToLower(cfg.Name) - // Check for matches - // CodeWhisperer aliases: codewhisperer, ide - // AmazonQ aliases: amazonq, q, cli - isMatch := false - if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { - isMatch = true - } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { - isMatch = true - } - - if isMatch { - sorted = append(sorted, cfg) - } else { - remaining = append(remaining, cfg) - } - } - - // If preference didn't match anything, return default - if len(sorted) == 0 { - return endpointConfigs - } - - // Combine: preferred first, then others - return append(sorted, remaining...) -} - // KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. type KiroExecutor struct { cfg *config.Config refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions } -// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. -func isIDCAuth(auth *cliproxyauth.Auth) bool { - if auth == nil || auth.Metadata == nil { - return false - } - authMethod, _ := auth.Metadata["auth_method"].(string) - return strings.ToLower(authMethod) == "idc" -} - -// buildKiroPayloadForFormat builds the Kiro API payload based on the source format. -// This is critical because OpenAI and Claude formats have different tool structures: -// - OpenAI: tools[].function.name, tools[].function.description -// - Claude: tools[].name, tools[].description -// headers parameter allows checking Anthropic-Beta header for thinking mode detection. 
-// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected. -func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format, headers http.Header) ([]byte, bool) { - switch sourceFormat.String() { - case "openai": - log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) - return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) - case "kiro": - // Body is already in Kiro format — pass through directly - log.Debugf("kiro: body already in Kiro format, passing through directly") - return sanitizeKiroPayload(body), false - default: - // Default to Claude format - log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) - return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) - } -} - -func sanitizeKiroPayload(body []byte) []byte { - var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { - return body - } - if _, exists := payload["user"]; !exists { - return body - } - delete(payload, "user") - sanitized, err := json.Marshal(payload) - if err != nil { - return body - } - return sanitized -} - // NewKiroExecutor creates a new Kiro executor instance. func NewKiroExecutor(cfg *config.Config) *KiroExecutor { return &KiroExecutor{cfg: cfg} @@ -1073,2669 +847,136 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. return resp, fmt.Errorf("kiro: all endpoints exhausted") } -// ExecuteStream handles streaming requests to Kiro API. -// Supports automatic token refresh on 401/403 errors and quota fallback on 429. 
-func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { - accessToken, profileArn := kiroCredentials(auth) - if accessToken == "" { - return nil, fmt.Errorf("kiro: access token not found in auth") - } +// kiroCredentials extracts access token and profile ARN from auth. - // Rate limiting: get token key for tracking - tokenKey := getTokenKey(auth) - rateLimiter := kiroauth.GetGlobalRateLimiter() - cooldownMgr := kiroauth.GetGlobalCooldownManager() +// NOTE: Claude SSE event builders moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_stream.go +// The executor now uses kiroclaude.BuildClaude*Event() functions instead - // Check if token is in cooldown period - if cooldownMgr.IsInCooldown(tokenKey) { - remaining := cooldownMgr.GetRemainingCooldown(tokenKey) - reason := cooldownMgr.GetCooldownReason(tokenKey) - log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining) - return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason) +// CountTokens counts tokens locally using tiktoken since Kiro API doesn't expose a token counting endpoint. +// This provides approximate token counts for client requests. 
+func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Use tiktoken for local token counting + enc, err := getTokenizer(req.Model) + if err != nil { + log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err) + // Fallback: estimate from payload size (roughly 4 chars per token) + estimatedTokens := len(req.Payload) / 4 + if estimatedTokens == 0 && len(req.Payload) > 0 { + estimatedTokens = 1 + } + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)), + }, nil } - // Wait for rate limiter before proceeding - log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey) - rateLimiter.WaitForToken(tokenKey) - log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) - - // Check if token is expired before making request (covers both normal and web_search paths) - if e.isTokenExpired(accessToken) { - log.Infof("kiro: access token expired, attempting recovery before stream request") + // Try to count tokens from the request payload + var totalTokens int64 - // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) - reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) - if reloadErr == nil && reloadedAuth != nil { - // 文件中有更新的 token,使用它 - auth = reloadedAuth - accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"]) + // Try OpenAI chat format first + if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 { + totalTokens = tokens + log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens) + } else { + // Fallback: count raw payload tokens + if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil { + totalTokens = int64(tokenCount) + log.Debugf("kiro: 
CountTokens counted %d tokens from raw payload", totalTokens) } else { - // 文件中的 token 也过期了,执行主动刷新 - log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr) - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) - } else if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - } - accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: token refreshed successfully before stream request") + // Final fallback: estimate from payload size + totalTokens = int64(len(req.Payload) / 4) + if totalTokens == 0 && len(req.Payload) > 0 { + totalTokens = 1 } + log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens) } } - // Check for pure web_search request - // Route to MCP endpoint instead of normal Kiro API - if kiroclaude.HasWebSearchTool(req.Payload) { - log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") - streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) - if errWebSearch != nil { - return nil, errWebSearch - } - return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil - } - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - defer reporter.trackFailure(ctx, &err) - - from := opts.SourceFormat - to := sdktranslator.FromString("kiro") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - - kiroModelID := e.mapModelToKiro(req.Model) + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)), + }, nil +} - // Determine agentic mode and effective profile ARN using helper functions - isAgentic, isChatOnly := 
determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) +// Refresh refreshes the Kiro OAuth token. +// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login). +// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh. +func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + // Serialize token refresh operations to prevent race conditions + e.refreshMu.Lock() + defer e.refreshMu.Unlock() - // Execute stream with retry on 401/403 and 429 (quota exhausted) - // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint - streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, body, from, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey) - if errStreamKiro != nil { - return nil, errStreamKiro + var authID string + if auth != nil { + authID = auth.ID + } else { + authID = "" + } + log.Debugf("kiro executor: refresh called for auth %s", authID) + if auth == nil { + return nil, fmt.Errorf("kiro executor: auth is nil") } - return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil -} -// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. -// Supports automatic fallback between endpoints with different quotas: -// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota -// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota -// Also supports multi-endpoint fallback similar to Antigravity implementation. -// tokenKey is used for rate limiting and cooldown tracking. 
-func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, body []byte, from sdktranslator.Format, reporter *usageReporter, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) { - var currentOrigin string - maxRetries := 2 // Allow retries for token refresh + endpoint fallback - rateLimiter := kiroauth.GetGlobalRateLimiter() - cooldownMgr := kiroauth.GetGlobalCooldownManager() - endpointConfigs := getKiroEndpointConfigs(auth) - var last429Err error + // Double-check: After acquiring lock, verify token still needs refresh + // Another goroutine may have already refreshed while we were waiting + // NOTE: This check has a design limitation - it reads from the auth object passed in, + // not from persistent storage. If another goroutine returns a new Auth object (via Clone), + // this check won't see those updates. The mutex still prevents truly concurrent refreshes, + // but queued goroutines may still attempt redundant refreshes. This is acceptable as + // the refresh operation is idempotent and the extra API calls are infrequent. 
+ if auth.Metadata != nil { + if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { + if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { + // If token was refreshed within the last 30 seconds, skip refresh + if time.Since(refreshTime) < 30*time.Second { + log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") + return auth, nil + } + } + } + // Also check if expires_at is now in the future with sufficient buffer + if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { + if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { + // If token expires more than 20 minutes from now, it's still valid + if time.Until(expTime) > 20*time.Minute { + log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) + // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks + // Without this, shouldRefresh() will return true again in 30 seconds + updated := auth.Clone() + // Set next refresh to 20 minutes before expiry, or at least 30 seconds from now + nextRefresh := expTime.Add(-20 * time.Minute) + minNextRefresh := time.Now().Add(30 * time.Second) + if nextRefresh.Before(minNextRefresh) { + nextRefresh = minNextRefresh + } + updated.NextRefreshAfter = nextRefresh + log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) + return updated, nil + } + } + } + } - for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { - endpointConfig := endpointConfigs[endpointIdx] - url := endpointConfig.URL - // Use this endpoint's compatible Origin (critical for avoiding 403 errors) - currentOrigin = endpointConfig.Origin + var refreshToken string + var clientID, clientSecret string + var authMethod string + var region, startURL string - // Rebuild payload with the correct origin for this endpoint - // Each endpoint requires its matching Origin value in the 
request body - kiroPayload, thinkingEnabled := buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) + if auth.Metadata != nil { + refreshToken = getMetadataString(auth.Metadata, "refresh_token", "refreshToken") + clientID = getMetadataString(auth.Metadata, "client_id", "clientId") + clientSecret = getMetadataString(auth.Metadata, "client_secret", "clientSecret") + authMethod = strings.ToLower(getMetadataString(auth.Metadata, "auth_method", "authMethod")) + region = getMetadataString(auth.Metadata, "region") + startURL = getMetadataString(auth.Metadata, "start_url", "startUrl") + } - log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", - endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) + if refreshToken == "" { + return nil, fmt.Errorf("kiro executor: refresh token not found") + } - for attempt := 0; attempt <= maxRetries; attempt++ { - // Apply human-like delay before first streaming request (not on retries) - // This mimics natural user behavior patterns - // Note: Delay is NOT applied during streaming response - only before initial request - if attempt == 0 && endpointIdx == 0 { - kiroauth.ApplyHumanLikeDelay() - } + var tokenData *kiroauth.KiroTokenData + var err error - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Accept", kiroAcceptStream) - // Only set X-Amz-Target if specified (Q endpoint doesn't require it) - if endpointConfig.AmzTarget != "" { - httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - } - // Kiro-specific headers - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) - httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") - - // Apply dynamic fingerprint-based headers - applyDynamicFingerprint(httpReq, auth) - - 
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: httpReq.Header.Clone(), - Body: kiroPayload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0) - httpResp, err := httpClient.Do(httpReq) - if err != nil { - recordAPIResponseError(ctx, e.cfg, err) - - // Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.) 
- retryCfg := defaultRetryConfig() - if isRetryableError(err) && attempt < retryCfg.MaxRetries { - delay := calculateRetryDelay(attempt, retryCfg) - logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name) - time.Sleep(delay) - continue - } - - return nil, err - } - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - - // Handle 429 errors (quota exhausted) - try next endpoint - // Each endpoint has its own quota pool, so we can try different endpoints - if httpResp.StatusCode == 429 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - // Record failure and set cooldown for 429 - rateLimiter.MarkTokenFailed(tokenKey) - cooldownDuration := kiroauth.CalculateCooldownFor429(attempt) - cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429) - log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration) - - // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted - last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} - - log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s", - endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - - // Break inner retry loop to try next endpoint (which has different quota) - break - } - - // Handle 5xx server errors with exponential backoff retry - // Enhanced: Use retryConfig for consistent retry behavior - if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - retryCfg := defaultRetryConfig() - // Check if this specific 5xx code is retryable (502, 503, 504) - if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries { - delay := 
calculateRetryDelay(attempt, retryCfg) - logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name) - time.Sleep(delay) - continue - } else if attempt < maxRetries { - // Fallback for other 5xx errors (500, 501, etc.) - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second - } - log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) - time.Sleep(backoff) - continue - } - log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 400 errors - Credential/Validation issues - // Do NOT switch endpoints - return error immediately - if httpResp.StatusCode == 400 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) - - // 400 errors indicate request validation issues - return immediately without retry - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 401 errors with token refresh and retry - // 401 = Unauthorized (token expired/invalid) - refresh token - if httpResp.StatusCode == 401 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - log.Warnf("kiro: stream received 401 error, attempting token refresh") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := 
e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - // Continue anyway - the token is valid for this request - } - accessToken, profileArn = kiroCredentials(auth) - // Rebuild payload with new profile ARN if changed - kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) - if attempt < maxRetries { - log.Infof("kiro: token refreshed successfully, retrying stream request (attempt %d/%d)", attempt+1, maxRetries+1) - continue - } - log.Infof("kiro: token refreshed successfully, no retries remaining") - } - - log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 402 errors - Monthly Limit Reached - if httpResp.StatusCode == 402 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - log.Warnf("kiro: stream received 402 (monthly limit). 
Upstream body: %s", string(respBody)) - - // Return upstream error body directly - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - // Handle 403 errors - Access Denied / Token Expired - // Do NOT switch endpoints for 403 errors - if httpResp.StatusCode == 403 { - respBody, _ := io.ReadAll(httpResp.Body) - _ = httpResp.Body.Close() - appendAPIResponseChunk(ctx, e.cfg, respBody) - - // Log the 403 error details for debugging - log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) - - respBodyStr := string(respBody) - - // Check for SUSPENDED status - return immediately without retry - if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { - // Set long cooldown for suspended accounts - rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr) - cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended) - log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown) - return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} - } - - // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) - isTokenRelated := strings.Contains(respBodyStr, "token") || - strings.Contains(respBodyStr, "expired") || - strings.Contains(respBodyStr, "invalid") || - strings.Contains(respBodyStr, "unauthorized") - - if isTokenRelated && attempt < maxRetries { - log.Warnf("kiro: 403 appears token-related, attempting token refresh") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Errorf("kiro: token refresh failed: %v", refreshErr) - // Token refresh failed - return error immediately - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent 
requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - // Continue anyway - the token is valid for this request - } - accessToken, profileArn = kiroCredentials(auth) - kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) - log.Infof("kiro: token refreshed for 403, retrying stream request") - continue - } - } - - // For non-token 403 or after max retries, return error immediately - // Do NOT switch endpoints for 403 errors - log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") - return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} - } - - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) - appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - - out := make(chan cliproxyexecutor.StreamChunk) - - // Record success immediately since connection was established successfully - // Streaming errors will be handled separately - rateLimiter.MarkTokenSuccess(tokenKey) - log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey) - - go func(resp *http.Response, thinkingEnabled bool) { - defer close(out) - defer func() { - if r := recover(); r != nil { - log.Errorf("kiro: panic in stream handler: %v", r) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} - } - }() - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() - - // Kiro API always returns tags regardless of request parameters - // So we always enable thinking 
parsing for Kiro responses - log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) - - e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled) - }(httpResp, thinkingEnabled) - - return out, nil - } - // Inner retry loop exhausted for this endpoint, try next endpoint - // Note: This code is unreachable because all paths in the inner loop - // either return or continue. Kept as comment for documentation. - } - - // All endpoints exhausted - if last429Err != nil { - return nil, last429Err - } - return nil, fmt.Errorf("kiro: stream all endpoints exhausted") -} - -// kiroCredentials extracts access token and profile ARN from auth. -func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { - if auth == nil { - return "", "" - } - - // Try Metadata first (wrapper format) - if auth.Metadata != nil { - if token, ok := auth.Metadata["access_token"].(string); ok { - accessToken = token - } - if arn, ok := auth.Metadata["profile_arn"].(string); ok { - profileArn = arn - } - } - - // Try Attributes - if accessToken == "" && auth.Attributes != nil { - accessToken = auth.Attributes["access_token"] - profileArn = auth.Attributes["profile_arn"] - } - - // Try direct fields from flat JSON format (new AWS Builder ID format) - if accessToken == "" && auth.Metadata != nil { - if token, ok := auth.Metadata["accessToken"].(string); ok { - accessToken = token - } - if arn, ok := auth.Metadata["profileArn"].(string); ok { - profileArn = arn - } - } - - return accessToken, profileArn -} - -// findRealThinkingEndTag finds the real end tag, skipping false positives. -// Returns -1 if no real end tag is found. 
//
// Real tags from Kiro API have specific characteristics:
// - Usually preceded by newline (.\n)
// - Usually followed by newline (\n\n)
// - Not inside code blocks or inline code
//
// False positives (discussion text) have characteristics:
// - In the middle of a sentence
// - Preceded by discussion words like "标签", "tag", "returns"
// - Inside code blocks or inline code
//
// Parameters:
// - content: the content to search in
// - alreadyInCodeBlock: whether we're already inside a code block from previous chunks
// - alreadyInInlineCode: whether we're already inside inline code from previous chunks

// determineAgenticMode reports whether the model name selects the agentic or
// the chat-only variant of a model, decided purely by its name suffix.
func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
	return strings.HasSuffix(model, "-agentic"), strings.HasSuffix(model, "-chat")
}

// getMetadataString returns the first value in metadata stored under any of
// keys that is a string and non-empty after trimming whitespace; it returns
// "" when no key matches. A nil metadata map simply yields "".
func getMetadataString(metadata map[string]any, keys ...string) string {
	for _, candidate := range keys {
		raw, ok := metadata[candidate].(string) // indexing a nil map is safe in Go
		if !ok {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}

// getEffectiveProfileArn determines if profileArn should be included based on auth method.
// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
//
// Detection logic (matching kiro-openai-gateway):
// 1. Check auth_method field: "builder-id" or "idc"
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)

// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
// and logs a warning if profileArn is missing for non-builder-id auth.
// This consolidates the auth_method check that was previously done separately.
-// -// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors. -// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn. -// -// Detection logic (matching kiro-openai-gateway): -// 1. Check auth_method field: "builder-id" or "idc" -// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) -// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) -func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { - if auth != nil && auth.Metadata != nil { - // Check 1: auth_method field (from CLIProxyAPI tokens) - authMethod := strings.ToLower(getMetadataString(auth.Metadata, "auth_method", "authMethod")) - if authMethod == "builder-id" || authMethod == "idc" { - return "" // AWS SSO OIDC - don't include profileArn - } - // Check 2: auth_type field (from kiro-cli tokens) - if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { - return "" // AWS SSO OIDC - don't include profileArn - } - // Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway) - clientID := getMetadataString(auth.Metadata, "client_id", "clientId") - clientSecret := getMetadataString(auth.Metadata, "client_secret", "clientSecret") - if clientID != "" && clientSecret != "" { - return "" // AWS SSO OIDC - don't include profileArn - } - } - // For social auth (Kiro Desktop), profileArn is required - if profileArn == "" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") - } - return profileArn -} - -// mapModelToKiro maps external model names to Kiro model IDs. -// Supports both Kiro and Amazon Q prefixes since they use the same API. -// Agentic variants (-agentic suffix) map to the same backend model IDs. 
-func (e *KiroExecutor) mapModelToKiro(model string) string { - modelMap := map[string]string{ - // Amazon Q format (amazonq- prefix) - same API as Kiro - "amazonq-auto": "auto", - "amazonq-claude-opus-4-6": "claude-opus-4.6", - "amazonq-claude-sonnet-4-6": "claude-sonnet-4.6", - "amazonq-claude-opus-4-5": "claude-opus-4.5", - "amazonq-claude-sonnet-4-5": "claude-sonnet-4.5", - "amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", - "amazonq-claude-sonnet-4": "claude-sonnet-4", - "amazonq-claude-sonnet-4-20250514": "claude-sonnet-4", - "amazonq-claude-haiku-4-5": "claude-haiku-4.5", - // Kiro format (kiro- prefix) - valid model names that should be preserved - "kiro-claude-opus-4-6": "claude-opus-4.6", - "kiro-claude-sonnet-4-6": "claude-sonnet-4.6", - "kiro-claude-opus-4-5": "claude-opus-4.5", - "kiro-claude-sonnet-4-5": "claude-sonnet-4.5", - "kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", - "kiro-claude-sonnet-4": "claude-sonnet-4", - "kiro-claude-sonnet-4-20250514": "claude-sonnet-4", - "kiro-claude-haiku-4-5": "claude-haiku-4.5", - "kiro-auto": "auto", - // Native format (no prefix) - used by Kiro IDE directly - "claude-opus-4-6": "claude-opus-4.6", - "claude-opus-4.6": "claude-opus-4.6", - "claude-sonnet-4-6": "claude-sonnet-4.6", - "claude-sonnet-4.6": "claude-sonnet-4.6", - "claude-opus-4-5": "claude-opus-4.5", - "claude-opus-4.5": "claude-opus-4.5", - "claude-haiku-4-5": "claude-haiku-4.5", - "claude-haiku-4.5": "claude-haiku-4.5", - "claude-sonnet-4-5": "claude-sonnet-4.5", - "claude-sonnet-4-5-20250929": "claude-sonnet-4.5", - "claude-sonnet-4.5": "claude-sonnet-4.5", - "claude-sonnet-4": "claude-sonnet-4", - "claude-sonnet-4-20250514": "claude-sonnet-4", - "auto": "auto", - // Agentic variants (same backend model IDs, but with special system prompt) - "claude-opus-4.6-agentic": "claude-opus-4.6", - "claude-sonnet-4.6-agentic": "claude-sonnet-4.6", - "claude-opus-4.5-agentic": "claude-opus-4.5", - "claude-sonnet-4.5-agentic": 
"claude-sonnet-4.5", - "claude-sonnet-4-agentic": "claude-sonnet-4", - "claude-haiku-4.5-agentic": "claude-haiku-4.5", - "kiro-claude-opus-4-6-agentic": "claude-opus-4.6", - "kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6", - "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", - "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", - "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", - "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", - } - if kiroID, ok := modelMap[model]; ok { - return kiroID - } - - // Smart fallback: try to infer model type from name patterns - modelLower := strings.ToLower(model) - - // Check for Haiku variants - if strings.Contains(modelLower, "haiku") { - log.Debug("kiro: unknown haiku variant, mapping to claude-haiku-4.5") - return "claude-haiku-4.5" - } - - // Check for Sonnet variants - if strings.Contains(modelLower, "sonnet") { - // Check for specific version patterns - if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { - log.Debug("kiro: unknown sonnet 3.7 variant, mapping to claude-3-7-sonnet-20250219") - return "claude-3-7-sonnet-20250219" - } - if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { - log.Debug("kiro: unknown sonnet 4.6 variant, mapping to claude-sonnet-4.6") - return "claude-sonnet-4.6" - } - if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") { - log.Debug("kiro: unknown Sonnet 4.5 model, mapping to claude-sonnet-4.5") - return "claude-sonnet-4.5" - } - } - - // Check for Opus variants - if strings.Contains(modelLower, "opus") { - if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { - log.Debug("kiro: unknown Opus 4.6 model, mapping to claude-opus-4.6") - return "claude-opus-4.6" - } - log.Debug("kiro: unknown opus variant, mapping to claude-opus-4.5") - return "claude-opus-4.5" - } - - // Final fallback to Sonnet 4.5 (most commonly used model) - log.Warn("kiro: unknown model variant, falling back to 
claude-sonnet-4.5") - return "claude-sonnet-4.5" -} - -func kiroModelFingerprint(model string) string { - trimmed := strings.TrimSpace(model) - if trimmed == "" { - return "" - } - sum := sha256.Sum256([]byte(trimmed)) - return hex.EncodeToString(sum[:8]) -} - -// EventStreamError represents an Event Stream processing error -type EventStreamError struct { - Type string // "fatal", "malformed" - Message string - Cause error -} - -func (e *EventStreamError) Error() string { - if e.Cause != nil { - return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause) - } - return fmt.Sprintf("event stream %s: %s", e.Type, e.Message) -} - -// eventStreamMessage represents a parsed AWS Event Stream message -type eventStreamMessage struct { - EventType string // Event type from headers (e.g., "assistantResponseEvent") - Payload []byte // JSON payload of the message -} - -// NOTE: Request building functions moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_request.go -// The executor now uses kiroclaude.BuildKiroPayload() instead - -// parseEventStream parses AWS Event Stream binary format. -// Extracts text content, tool uses, and stop_reason from the response. -// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. 
// Returns: content, toolUses, usageInfo, stopReason, error
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) {
	var content strings.Builder
	var toolUses []kiroclaude.KiroToolUse
	var usageInfo usage.Detail
	var stopReason string // Extracted from upstream response
	reader := bufio.NewReader(body)

	// Tool use state tracking for input buffering and deduplication
	processedIDs := make(map[string]bool)
	var currentToolUse *kiroclaude.ToolUseState

	// Upstream usage tracking - Kiro API returns credit usage and context percentage
	var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56)

	// Main read loop: one iteration per Event Stream message until EOF or error.
	for {
		msg, eventErr := e.readEventStreamMessage(reader)
		if eventErr != nil {
			// Partial results accumulated so far are returned alongside the error.
			log.Errorf("kiro: parseEventStream error: %v", eventErr)
			return content.String(), toolUses, usageInfo, stopReason, eventErr
		}
		if msg == nil {
			// Normal end of stream (EOF)
			break
		}

		eventType := msg.EventType
		payload := msg.Payload
		if len(payload) == 0 {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal(payload, &event); err != nil {
			log.Debugf("kiro: skipping malformed event: %v", err)
			continue
		}

		// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
		// These can appear as top-level fields or nested within the event
		if errType, hasErrType := event["_type"].(string); hasErrType {
			// AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."}
			errMsg := ""
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			}
			log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg)
			return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
		}
		if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
			// Generic error event
			errMsg := ""
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			} else if errObj, ok := event["error"].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
			}
			log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
			return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg)
		}

		// Extract stop_reason from various event formats
		// Kiro/Amazon Q API may include stop_reason in different locations
		if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
			stopReason = sr
			log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason)
		}
		if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
			stopReason = sr
			log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason)
		}

		// Handle different event types
		switch eventType {
		case "followupPromptEvent":
			// Filter out followupPrompt events - these are UI suggestions, not content
			log.Debugf("kiro: parseEventStream ignoring followupPrompt event")
			continue

		case "assistantResponseEvent":
			if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok {
				if contentText, ok := assistantResp["content"].(string); ok {
					content.WriteString(contentText)
				}
				// Extract stop_reason from assistantResponseEvent
				if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" {
					stopReason = sr
					log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason)
				}
				if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" {
					stopReason = sr
					log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason)
				}
				// Extract tool uses from response
				if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
					for _, tuRaw := range toolUsesRaw {
						if tu, ok := tuRaw.(map[string]interface{}); ok {
							toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
							// Check for duplicate
							if processedIDs[toolUseID] {
								log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID)
								continue
							}
							processedIDs[toolUseID] = true

							toolUse := kiroclaude.KiroToolUse{
								ToolUseID: toolUseID,
								Name:      kirocommon.GetStringValue(tu, "name"),
							}
							if input, ok := tu["input"].(map[string]interface{}); ok {
								toolUse.Input = input
							}
							toolUses = append(toolUses, toolUse)
						}
					}
				}
			}
			// Also try direct format
			if contentText, ok := event["content"].(string); ok {
				content.WriteString(contentText)
			}
			// Direct tool uses
			if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok {
				for _, tuRaw := range toolUsesRaw {
					if tu, ok := tuRaw.(map[string]interface{}); ok {
						toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
						// Check for duplicate
						if processedIDs[toolUseID] {
							log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID)
							continue
						}
						processedIDs[toolUseID] = true

						toolUse := kiroclaude.KiroToolUse{
							ToolUseID: toolUseID,
							Name:      kirocommon.GetStringValue(tu, "name"),
						}
						if input, ok := tu["input"].(map[string]interface{}); ok {
							toolUse.Input = input
						}
						toolUses = append(toolUses, toolUse)
					}
				}
			}

		case "toolUseEvent":
			// Handle dedicated tool use events with input buffering
			completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs)
			currentToolUse = newState
			toolUses = append(toolUses, completedToolUses...)

		case "supplementaryWebLinksEvent":
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}

		case "messageStopEvent", "message_stop":
			// Handle message stop events which may contain stop_reason
			if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
				stopReason = sr
				log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason)
			}
			if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
				stopReason = sr
				log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason)
			}

		case "messageMetadataEvent", "metadataEvent":
			// Handle message metadata events which contain token counts
			// Official format: { tokenUsage: { outputTokens, totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } }
			var metadata map[string]interface{}
			if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok {
				metadata = m
			} else if m, ok := event["metadataEvent"].(map[string]interface{}); ok {
				metadata = m
			} else {
				metadata = event // event itself might be the metadata
			}

			// Check for nested tokenUsage object (official format)
			if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok {
				// outputTokens - precise output token count
				if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
					log.Infof("kiro: parseEventStream found precise outputTokens in tokenUsage: %d", usageInfo.OutputTokens)
				}
				// totalTokens - precise total token count
				if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
					log.Infof("kiro: parseEventStream found precise totalTokens in tokenUsage: %d", usageInfo.TotalTokens)
				}
				// uncachedInputTokens - input tokens not from cache
				if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(uncachedInputTokens)
					log.Infof("kiro: parseEventStream found uncachedInputTokens in tokenUsage: %d", usageInfo.InputTokens)
				}
				// cacheReadInputTokens - tokens read from cache
				if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok {
					// Add to input tokens if we have uncached tokens, otherwise use as input
					if usageInfo.InputTokens > 0 {
						usageInfo.InputTokens += int64(cacheReadTokens)
					} else {
						usageInfo.InputTokens = int64(cacheReadTokens)
					}
					log.Debugf("kiro: parseEventStream found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens))
				}
				// contextUsagePercentage - can be used as fallback for input token estimation
				if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct)
				}
			}

			// Fallback: check for direct fields in metadata (legacy format)
			if usageInfo.InputTokens == 0 {
				if inputTokens, ok := metadata["inputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
					log.Debugf("kiro: parseEventStream found inputTokens in messageMetadataEvent: %d", usageInfo.InputTokens)
				}
			}
			if usageInfo.OutputTokens == 0 {
				if outputTokens, ok := metadata["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
					log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens)
				}
			}
			if usageInfo.TotalTokens == 0 {
				if totalTokens, ok := metadata["totalTokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
					log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens)
				}
			}

		case "usageEvent", "usage":
			// Handle dedicated usage events
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
				log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens)
			}
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
				log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens)
			}
			if totalTokens, ok := event["totalTokens"].(float64); ok {
				usageInfo.TotalTokens = int64(totalTokens)
				log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens)
			}
			// Also check nested usage object
			if usageObj, ok := event["usage"].(map[string]interface{}); ok {
				if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				}
				if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				}
				if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
				}
				log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
			}

		case "metricsEvent":
			// Handle metrics events which may contain usage data
			if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok {
				if inputTokens, ok := metrics["inputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				}
				if outputTokens, ok := metrics["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				}
				log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens)
			}

		case "meteringEvent":
			// Handle metering events from Kiro API (usage billing information)
			// Official format: { unit: string, unitPlural: string, usage: number }
			if metering, ok := event["meteringEvent"].(map[string]interface{}); ok {
				unit := ""
				if u, ok := metering["unit"].(string); ok {
					unit = u
				}
				usageVal := 0.0
				if u, ok := metering["usage"].(float64); ok {
					usageVal = u
				}
				log.Infof("kiro: parseEventStream received meteringEvent: usage=%.2f %s", usageVal, unit)
				// Store metering info for potential billing/statistics purposes
				// Note: This is separate from token counts - it's AWS billing units
			} else {
				// Try direct fields
				unit := ""
				if u, ok := event["unit"].(string); ok {
					unit = u
				}
				usageVal := 0.0
				if u, ok := event["usage"].(float64); ok {
					usageVal = u
				}
				if unit != "" || usageVal > 0 {
					log.Infof("kiro: parseEventStream received meteringEvent (direct): usage=%.2f %s", usageVal, unit)
				}
			}

		case "contextUsageEvent":
			// Handle context usage events from Kiro API
			// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
			// NOTE(review): here the value is logged as ctxPct*100 (fraction),
			// while the tokenUsage branch above logs ctxPct directly (percent);
			// the two upstream formats appear to use different units - confirm
			// against real payloads before relying on the final token estimate.
			if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
				if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
				}
			} else {
				// Try direct field (fallback)
				if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
				}
			}

		case "error", "exception", "internalServerException", "invalidStateEvent":
			// Handle error events from Kiro API stream
			errMsg := ""
			errType := eventType

			// Try to extract error message from various formats
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			} else if errObj, ok := event[eventType].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
				if t, ok := errObj["type"].(string); ok {
					errType = t
				}
			} else if errObj, ok := event["error"].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
				if t, ok := errObj["type"].(string); ok {
					errType = t
				}
			}

			// Check for specific error reasons
			if reason, ok := event["reason"].(string); ok {
				errMsg = fmt.Sprintf("%s (reason: %s)", errMsg, reason)
			}

			log.Errorf("kiro: parseEventStream received error event: type=%s, message=%s", errType, errMsg)

			// For invalidStateEvent, we may want to continue processing other events
			if eventType == "invalidStateEvent" {
				log.Warnf("kiro: invalidStateEvent received, continuing stream processing")
				continue
			}

			// For other errors, return the error
			if errMsg != "" {
				return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error (%s): %s", errType, errMsg)
			}

		default:
			// Check for contextUsagePercentage in any event
			if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
				upstreamContextPercentage = ctxPct
				log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage)
			}
			// Log unknown event types for debugging (to discover new event formats)
			log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload))
		}

		// Check for direct token fields in any event (fallback)
		// These run after the switch for every non-continue event; the == 0
		// guards make sure a more specific source set earlier wins.
		if usageInfo.InputTokens == 0 {
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
				log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens)
			}
		}
		if usageInfo.OutputTokens == 0 {
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
				log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens)
			}
		}

		// Check for usage object in any event (OpenAI format)
		if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 {
			if usageObj, ok := event["usage"].(map[string]interface{}); ok {
				if usageInfo.InputTokens == 0 {
					if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
						usageInfo.InputTokens = int64(inputTokens)
					} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
						usageInfo.InputTokens = int64(inputTokens)
					}
				}
				if usageInfo.OutputTokens == 0 {
					if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
						usageInfo.OutputTokens = int64(outputTokens)
					} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
						usageInfo.OutputTokens = int64(outputTokens)
					}
				}
				if usageInfo.TotalTokens == 0 {
					if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
						usageInfo.TotalTokens = int64(totalTokens)
					}
				}
				log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
			}
		}

		// Also check nested supplementaryWebLinksEvent
		if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok {
			if inputTokens, ok := usageEvent["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := usageEvent["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}
		}
	}

	// Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}])
	contentStr := content.String()
	cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs)
	toolUses = append(toolUses, embeddedToolUses...)

	// Deduplicate all tool uses
	toolUses = kiroclaude.DeduplicateToolUses(toolUses)

	// Apply fallback logic for stop_reason if not provided by upstream
	// Priority: upstream stopReason > tool_use detection > end_turn default
	if stopReason == "" {
		if len(toolUses) > 0 {
			stopReason = "tool_use"
			log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses))
		} else {
			stopReason = "end_turn"
			log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn")
		}
	}

	// Log warning if response was truncated due to max_tokens
	if stopReason == "max_tokens" {
		log.Warnf("kiro: response truncated due to max_tokens limit")
	}

	// Use contextUsagePercentage to calculate more accurate input tokens
	// Kiro model has 200k max context, contextUsagePercentage represents the percentage used
	// Formula: input_tokens = contextUsagePercentage * 200000 / 100
	if upstreamContextPercentage > 0 {
		calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100)
		if calculatedInputTokens > 0 {
			localEstimate := usageInfo.InputTokens
			usageInfo.InputTokens = calculatedInputTokens
			usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
			log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
				upstreamContextPercentage, calculatedInputTokens, localEstimate)
		}
	}

	return cleanedContent, toolUses, usageInfo, stopReason, nil
}

// readEventStreamMessage reads and validates a single AWS Event Stream message.
// Returns the parsed message or a structured error for different failure modes.
// This function implements boundary protection and detailed error classification.
-// -// AWS Event Stream binary format: -// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4) -// - Headers (variable): header entries -// - Payload (variable): JSON data -// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped) -func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { - // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc) - prelude := make([]byte, 12) - _, err := io.ReadFull(reader, prelude) - if err == io.EOF { - return nil, nil // Normal end of stream - } - if err != nil { - return nil, &EventStreamError{ - Type: ErrStreamFatal, - Message: "failed to read prelude", - Cause: err, - } - } - - totalLength := binary.BigEndian.Uint32(prelude[0:4]) - headersLength := binary.BigEndian.Uint32(prelude[4:8]) - // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements) - - // Boundary check: minimum frame size - if totalLength < minEventStreamFrameSize { - return nil, &EventStreamError{ - Type: ErrStreamMalformed, - Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), - } - } - - // Boundary check: maximum message size - if totalLength > maxEventStreamMsgSize { - return nil, &EventStreamError{ - Type: ErrStreamMalformed, - Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), - } - } - - // Boundary check: headers length within message bounds - // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4) - // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc) - if headersLength > totalLength-16 { - return nil, &EventStreamError{ - Type: ErrStreamMalformed, - Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), - } - } - - // Read the rest of the message (total - 12 bytes already 
read) - remaining := make([]byte, totalLength-12) - _, err = io.ReadFull(reader, remaining) - if err != nil { - return nil, &EventStreamError{ - Type: ErrStreamFatal, - Message: "failed to read message body", - Cause: err, - } - } - - // Extract event type from headers - // Headers start at beginning of 'remaining', length is headersLength - var eventType string - if headersLength > 0 && headersLength <= uint32(len(remaining)) { - eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) - } - - // Calculate payload boundaries - // Payload starts after headers, ends before message_crc (last 4 bytes) - payloadStart := headersLength - payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end - - // Validate payload boundaries - if payloadStart >= payloadEnd { - // No payload, return empty message - return &eventStreamMessage{ - EventType: eventType, - Payload: nil, - }, nil - } - - payload := remaining[payloadStart:payloadEnd] - - return &eventStreamMessage{ - EventType: eventType, - Payload: payload, - }, nil -} - -func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) { - switch valueType { - case 0, 1: // bool true / bool false - return offset, true - case 2: // byte - if offset+1 > len(headers) { - return offset, false - } - return offset + 1, true - case 3: // short - if offset+2 > len(headers) { - return offset, false - } - return offset + 2, true - case 4: // int - if offset+4 > len(headers) { - return offset, false - } - return offset + 4, true - case 5: // long - if offset+8 > len(headers) { - return offset, false - } - return offset + 8, true - case 6: // byte array (2-byte length + data) - if offset+2 > len(headers) { - return offset, false - } - valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) - offset += 2 - if offset+valueLen > len(headers) { - return offset, false - } - return offset + valueLen, true - case 8: // timestamp - if offset+8 > len(headers) { - return offset, false - } - 
return offset + 8, true - case 9: // uuid - if offset+16 > len(headers) { - return offset, false - } - return offset + 16, true - default: - return offset, false - } -} - -// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix) -func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { - offset := 0 - for offset < len(headers) { - nameLen := int(headers[offset]) - offset++ - if offset+nameLen > len(headers) { - break - } - name := string(headers[offset : offset+nameLen]) - offset += nameLen - - if offset >= len(headers) { - break - } - valueType := headers[offset] - offset++ - - if valueType == 7 { // String type - if offset+2 > len(headers) { - break - } - valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) - offset += 2 - if offset+valueLen > len(headers) { - break - } - value := string(headers[offset : offset+valueLen]) - offset += valueLen - - if name == ":event-type" { - return value - } - continue - } - - nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType) - if !ok { - break - } - offset = nextOffset - } - return "" -} - -// NOTE: Response building functions moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_response.go -// The executor now uses kiroclaude.BuildClaudeResponse() and kiroclaude.ExtractThinkingFromContent() instead - -// streamToChannel converts AWS Event Stream to channel-based streaming. -// Supports tool calling - emits tool_use content blocks when tools are used. -// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. -// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). -// Extracts stop_reason from upstream events when available. -// thinkingEnabled controls whether tags are parsed - only parse when request enabled thinking. 
-func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) { - reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers - var totalUsage usage.Detail - var hasToolUses bool // Track if any tool uses were emitted - var hasTruncatedTools bool // Track if any tool uses were truncated - var upstreamStopReason string // Track stop_reason from upstream events - - // Tool use state tracking for input buffering and deduplication - processedIDs := make(map[string]bool) - var currentToolUse *kiroclaude.ToolUseState - - // NOTE: Duplicate content filtering removed - it was causing legitimate repeated - // content (like consecutive newlines) to be incorrectly filtered out. - // The previous implementation compared lastContentEvent == contentDelta which - // is too aggressive for streaming scenarios. 
- - // Streaming token calculation - accumulate content for real-time token counting - // Based on AIClient-2-API implementation - var accumulatedContent strings.Builder - accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations - - // Real-time usage estimation state - // These track when to send periodic usage updates during streaming - var lastUsageUpdateLen int // Last accumulated content length when usage was sent - var lastUsageUpdateTime = time.Now() // Last time usage update was sent - var lastReportedOutputTokens int64 // Last reported output token count - - // Upstream usage tracking - Kiro API returns credit usage and context percentage - var upstreamCreditUsage float64 // Credit usage from upstream (e.g., 1.458) - var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) - var hasUpstreamUsage bool // Whether we received usage from upstream - - // Translator param for maintaining tool call state across streaming events - // IMPORTANT: This must persist across all TranslateStream calls - var translatorParam any - - // Thinking mode state tracking - tag-based parsing for tags in content - inThinkBlock := false // Whether we're currently inside a block - isThinkingBlockOpen := false // Track if thinking content block SSE event is open - thinkingBlockIndex := -1 // Index of the thinking content block - var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting - - // Buffer for handling partial tag matches at chunk boundaries - var pendingContent strings.Builder // Buffer content that might be part of a tag - - // Pre-calculate input tokens from request if possible - // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback - if enc, err := getTokenizer(model); err == nil { - var inputTokens int64 - var countMethod string - - // Try Claude format first (Kiro uses Claude API format) - if inp, err := countClaudeChatTokens(enc, 
claudeBody); err == nil && inp > 0 { - inputTokens = inp - countMethod = "claude" - } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { - // Fallback to OpenAI format (for OpenAI-compatible requests) - inputTokens = inp - countMethod = "openai" - } else { - // Final fallback: estimate from raw request size (roughly 4 chars per token) - inputTokens = int64(len(claudeBody) / 4) - if inputTokens == 0 && len(claudeBody) > 0 { - inputTokens = 1 - } - countMethod = "estimate" - } - - totalUsage.InputTokens = inputTokens - log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", - totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) - } - - contentBlockIndex := -1 - messageStartSent := false - isTextBlockOpen := false - var outputLen int - - // Ensure usage is published even on early return - defer func() { - reporter.publish(ctx, totalUsage) - }() - - for { - select { - case <-ctx.Done(): - return - default: - } - - msg, eventErr := e.readEventStreamMessage(reader) - if eventErr != nil { - // Log the error - log.Errorf("kiro: streamToChannel error: %v", eventErr) - - // Send error to channel for client notification - out <- cliproxyexecutor.StreamChunk{Err: eventErr} - return - } - if msg == nil { - // Normal end of stream (EOF) - // Flush any incomplete tool use before ending stream - if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] { - log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) - fullInput := currentToolUse.InputBuffer.String() - repairedJSON := kiroclaude.RepairJSON(fullInput) - var finalInput map[string]interface{} - if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { - log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) - finalInput = make(map[string]interface{}) - } - - processedIDs[currentToolUse.ToolUseID] = 
true - contentBlockIndex++ - - // Send tool_use content block - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Send tool input as delta - inputBytes, _ := json.Marshal(finalInput) - inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Close block - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - hasToolUses = true - currentToolUse = nil - } - - // DISABLED: Tag-based pending character flushing - // This code block was used for tag-based thinking detection which has been - // replaced by reasoningContentEvent handling. No pending tag chars to flush. - // Original code preserved in git history. 
- break - } - - eventType := msg.EventType - payload := msg.Payload - if len(payload) == 0 { - continue - } - appendAPIResponseChunk(ctx, e.cfg, payload) - - var event map[string]interface{} - if err := json.Unmarshal(payload, &event); err != nil { - log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) - continue - } - - // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) - // These can appear as top-level fields or nested within the event - if errType, hasErrType := event["_type"].(string); hasErrType { - // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} - errMsg := "" - if msg, ok := event["message"].(string); ok { - errMsg = msg - } - log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} - return - } - if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { - // Generic error event - errMsg := "" - if msg, ok := event["message"].(string); ok { - errMsg = msg - } else if errObj, ok := event["error"].(map[string]interface{}); ok { - if msg, ok := errObj["message"].(string); ok { - errMsg = msg - } - } - log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} - return - } - - // Extract stop_reason from various event formats (streaming) - // Kiro/Amazon Q API may include stop_reason in different locations - if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) - } - if sr := kirocommon.GetString(event, "stopReason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stopReason 
(top-level): %s", upstreamStopReason) - } - - // Send message_start on first event - if !messageStartSent { - msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - messageStartSent = true - } - - switch eventType { - case "followupPromptEvent": - // Filter out followupPrompt events - these are UI suggestions, not content - log.Debugf("kiro: streamToChannel ignoring followupPrompt event") - continue - - case "messageStopEvent", "message_stop": - // Handle message stop events which may contain stop_reason - if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) - } - if sr := kirocommon.GetString(event, "stopReason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) - } - - case "meteringEvent": - // Handle metering events from Kiro API (usage billing information) - // Official format: { unit: string, unitPlural: string, usage: number } - if metering, ok := event["meteringEvent"].(map[string]interface{}); ok { - unit := "" - if u, ok := metering["unit"].(string); ok { - unit = u - } - usageVal := 0.0 - if u, ok := metering["usage"].(float64); ok { - usageVal = u - } - upstreamCreditUsage = usageVal - hasUpstreamUsage = true - log.Infof("kiro: streamToChannel received meteringEvent: usage=%.4f %s", usageVal, unit) - } else { - // Try direct fields (event is meteringEvent itself) - if unit, ok := event["unit"].(string); ok { - if usage, ok := event["usage"].(float64); ok { - upstreamCreditUsage = usage - hasUpstreamUsage = true - 
log.Infof("kiro: streamToChannel received meteringEvent (direct): usage=%.4f %s", usage, unit) - } - } - } - - case "contextUsageEvent": - // Handle context usage events from Kiro API - // Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}} - if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok { - if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok { - upstreamContextPercentage = ctxPct - log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100) - } - } else { - // Try direct field (fallback) - if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { - upstreamContextPercentage = ctxPct - log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100) - } - } - - case "error", "exception", "internalServerException": - // Handle error events from Kiro API stream - errMsg := "" - errType := eventType - - // Try to extract error message from various formats - if msg, ok := event["message"].(string); ok { - errMsg = msg - } else if errObj, ok := event[eventType].(map[string]interface{}); ok { - if msg, ok := errObj["message"].(string); ok { - errMsg = msg - } - if t, ok := errObj["type"].(string); ok { - errType = t - } - } else if errObj, ok := event["error"].(map[string]interface{}); ok { - if msg, ok := errObj["message"].(string); ok { - errMsg = msg - } - } - - log.Errorf("kiro: streamToChannel received error event: type=%s, message=%s", errType, errMsg) - - // Send error to the stream and exit - if errMsg != "" { - out <- cliproxyexecutor.StreamChunk{ - Err: fmt.Errorf("kiro API error (%s): %s", errType, errMsg), - } - return - } - - case "invalidStateEvent": - // Handle invalid state events - log and continue (non-fatal) - errMsg := "" - if msg, ok := event["message"].(string); ok { - errMsg = msg - } else if stateEvent, ok := event["invalidStateEvent"].(map[string]interface{}); ok { - if msg, ok := stateEvent["message"].(string); ok { - errMsg = msg 
- } - } - log.Warnf("kiro: streamToChannel received invalidStateEvent: %s, continuing", errMsg) - continue - - case "assistantResponseEvent": - var contentDelta string - var toolUses []map[string]interface{} - - if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { - if c, ok := assistantResp["content"].(string); ok { - contentDelta = c - } - // Extract stop_reason from assistantResponseEvent - if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) - } - if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { - upstreamStopReason = sr - log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) - } - // Extract tool uses from response - if tus, ok := assistantResp["toolUses"].([]interface{}); ok { - for _, tuRaw := range tus { - if tu, ok := tuRaw.(map[string]interface{}); ok { - toolUses = append(toolUses, tu) - } - } - } - } - if contentDelta == "" { - if c, ok := event["content"].(string); ok { - contentDelta = c - } - } - // Direct tool uses - if tus, ok := event["toolUses"].([]interface{}); ok { - for _, tuRaw := range tus { - if tu, ok := tuRaw.(map[string]interface{}); ok { - toolUses = append(toolUses, tu) - } - } - } - - // Handle text content with thinking mode support - if contentDelta != "" { - // NOTE: Duplicate content filtering was removed because it incorrectly - // filtered out legitimate repeated content (like consecutive newlines "\n\n"). - // Streaming naturally can have identical chunks that are valid content. 
- - outputLen += len(contentDelta) - // Accumulate content for streaming token calculation - accumulatedContent.WriteString(contentDelta) - - // Real-time usage estimation: Check if we should send a usage update - // This helps clients track context usage during long thinking sessions - shouldSendUsageUpdate := false - if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { - shouldSendUsageUpdate = true - } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { - shouldSendUsageUpdate = true - } - - if shouldSendUsageUpdate { - // Calculate current output tokens using tiktoken - var currentOutputTokens int64 - if enc, encErr := getTokenizer(model); encErr == nil { - if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { - currentOutputTokens = int64(tokenCount) - } - } - // Fallback to character estimation if tiktoken fails - if currentOutputTokens == 0 { - currentOutputTokens = int64(accumulatedContent.Len() / 4) - if currentOutputTokens == 0 { - currentOutputTokens = 1 - } - } - - // Only send update if token count has changed significantly (at least 10 tokens) - if currentOutputTokens > lastReportedOutputTokens+10 { - // Send ping event with usage information - // This is a non-blocking update that clients can optionally process - pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - lastReportedOutputTokens = currentOutputTokens - log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", - totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) - } - - 
lastUsageUpdateLen = accumulatedContent.Len() - lastUsageUpdateTime = time.Now() - } - - // TAG-BASED THINKING PARSING: Parse tags from content - // Combine pending content with new content for processing - pendingContent.WriteString(contentDelta) - processContent := pendingContent.String() - pendingContent.Reset() - - // Process content looking for thinking tags - for len(processContent) > 0 { - if inThinkBlock { - // We're inside a thinking block, look for - endIdx := strings.Index(processContent, kirocommon.ThinkingEndTag) - if endIdx >= 0 { - // Found end tag - emit thinking content before the tag - thinkingText := processContent[:endIdx] - if thinkingText != "" { - // Ensure thinking block is open - if !isThinkingBlockOpen { - contentBlockIndex++ - thinkingBlockIndex = contentBlockIndex - isThinkingBlockOpen = true - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - // Send thinking delta - thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - accumulatedThinkingContent.WriteString(thinkingText) - } - // Close thinking block - if isThinkingBlockOpen { - blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, 
chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isThinkingBlockOpen = false - } - inThinkBlock = false - processContent = processContent[endIdx+len(kirocommon.ThinkingEndTag):] - log.Debugf("kiro: closed thinking block, remaining content: %d chars", len(processContent)) - } else { - // No end tag found - check for partial match at end - partialMatch := false - for i := 1; i < len(kirocommon.ThinkingEndTag) && i <= len(processContent); i++ { - if strings.HasSuffix(processContent, kirocommon.ThinkingEndTag[:i]) { - // Possible partial tag at end, buffer it - pendingContent.WriteString(processContent[len(processContent)-i:]) - processContent = processContent[:len(processContent)-i] - partialMatch = true - break - } - } - if !partialMatch || len(processContent) > 0 { - // Emit all as thinking content - if processContent != "" { - if !isThinkingBlockOpen { - contentBlockIndex++ - thinkingBlockIndex = contentBlockIndex - isThinkingBlockOpen = true - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - accumulatedThinkingContent.WriteString(processContent) - } - } - processContent = "" - } - } else { - // Not in thinking block, look for - startIdx := strings.Index(processContent, 
kirocommon.ThinkingStartTag) - if startIdx >= 0 { - // Found start tag - emit text content before the tag - textBefore := processContent[:startIdx] - if textBefore != "" { - // Close thinking block if open - if isThinkingBlockOpen { - blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isThinkingBlockOpen = false - } - // Ensure text block is open - if !isTextBlockOpen { - contentBlockIndex++ - isTextBlockOpen = true - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - // Send text delta - claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - // Close text block before entering thinking - if isTextBlockOpen { - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isTextBlockOpen = false - } - 
inThinkBlock = true - processContent = processContent[startIdx+len(kirocommon.ThinkingStartTag):] - log.Debugf("kiro: entered thinking block") - } else { - // No start tag found - check for partial match at end - partialMatch := false - for i := 1; i < len(kirocommon.ThinkingStartTag) && i <= len(processContent); i++ { - if strings.HasSuffix(processContent, kirocommon.ThinkingStartTag[:i]) { - // Possible partial tag at end, buffer it - pendingContent.WriteString(processContent[len(processContent)-i:]) - processContent = processContent[:len(processContent)-i] - partialMatch = true - break - } - } - if !partialMatch || len(processContent) > 0 { - // Emit all as text content - if processContent != "" { - if !isTextBlockOpen { - contentBlockIndex++ - isTextBlockOpen = true - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - } - processContent = "" - } - } - } - } - - // Handle tool uses in response (with deduplication) - for _, tu := range toolUses { - toolUseID := kirocommon.GetString(tu, "toolUseId") - toolName := kirocommon.GetString(tu, "name") - - // Check for duplicate - if processedIDs[toolUseID] { - log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) - continue - } - processedIDs[toolUseID] = true - - hasToolUses = true - // Close 
text block if open before starting tool_use block - if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isTextBlockOpen = false - } - - // Emit tool_use content block - contentBlockIndex++ - - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Send input_json_delta with the tool input - if input, ok := tu["input"].(map[string]interface{}); ok { - inputJSON, err := json.Marshal(input) - if err != nil { - log.Debugf("kiro: failed to marshal tool input: %v", err) - // Don't continue - still need to close the block - } else { - inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - } - - // Close tool_use block (always close even if input marshal failed) - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" 
{ - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - case "reasoningContentEvent": - // Handle official reasoningContentEvent from Kiro API - // This replaces tag-based thinking detection with the proper event type - // Official format: { text: string, signature?: string, redactedContent?: base64 } - var thinkingText string - var signature string - - if re, ok := event["reasoningContentEvent"].(map[string]interface{}); ok { - if text, ok := re["text"].(string); ok { - thinkingText = text - } - if sig, ok := re["signature"].(string); ok { - signature = sig - if len(sig) > 20 { - log.Debugf("kiro: reasoningContentEvent has signature: %s...", sig[:20]) - } else { - log.Debugf("kiro: reasoningContentEvent has signature: %s", sig) - } - } - } else { - // Try direct fields - if text, ok := event["text"].(string); ok { - thinkingText = text - } - if sig, ok := event["signature"].(string); ok { - signature = sig - } - } - - if thinkingText != "" { - // Close text block if open before starting thinking block - if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isTextBlockOpen = false - } - - // Start thinking block if not already open - if !isThinkingBlockOpen { - contentBlockIndex++ - thinkingBlockIndex = contentBlockIndex - isThinkingBlockOpen = true - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- 
cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - // Send thinking content - thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Accumulate for token counting - accumulatedThinkingContent.WriteString(thinkingText) - log.Debugf("kiro: received reasoningContentEvent, text length: %d, has signature: %v", len(thinkingText), signature != "") - } - - // Note: We don't close the thinking block here - it will be closed when we see - // the next assistantResponseEvent or at the end of the stream - _ = signature // Signature can be used for verification if needed - - case "toolUseEvent": - // Handle dedicated tool use events with input buffering - completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) - currentToolUse = newState - - // Emit completed tool uses - for _, tu := range completedToolUses { - // Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker - if tu.IsTruncated { - hasTruncatedTools = true - log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID) - - // Close text block if open - if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isTextBlockOpen = false - } - - contentBlockIndex++ - - // Emit tool_use with 
SOFT_LIMIT_REACHED marker input - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Build SOFT_LIMIT_REACHED marker input - markerInput := map[string]interface{}{ - "_status": "SOFT_LIMIT_REACHED", - "_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.", - } - - markerJSON, _ := json.Marshal(markerInput) - inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Close tool_use block - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - hasToolUses = true // Keep this so stop_reason = tool_use - continue - } - - hasToolUses = true - - // Close text block if open - if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, 
blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - isTextBlockOpen = false - } - - contentBlockIndex++ - - blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - if tu.Input != nil { - inputJSON, err := json.Marshal(tu.Input) - if err != nil { - log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) - } else { - inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - } - - blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - case "supplementaryWebLinksEvent": - if inputTokens, ok := event["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := event["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - - case "messageMetadataEvent", "metadataEvent": - // Handle message metadata events which contain token counts - // Official format: { tokenUsage: { outputTokens, 
totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } } - var metadata map[string]interface{} - if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { - metadata = m - } else if m, ok := event["metadataEvent"].(map[string]interface{}); ok { - metadata = m - } else { - metadata = event // event itself might be the metadata - } - - // Check for nested tokenUsage object (official format) - if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok { - // outputTokens - precise output token count - if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - hasUpstreamUsage = true - log.Infof("kiro: streamToChannel found precise outputTokens in tokenUsage: %d", totalUsage.OutputTokens) - } - // totalTokens - precise total token count - if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - log.Infof("kiro: streamToChannel found precise totalTokens in tokenUsage: %d", totalUsage.TotalTokens) - } - // uncachedInputTokens - input tokens not from cache - if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok { - totalUsage.InputTokens = int64(uncachedInputTokens) - hasUpstreamUsage = true - log.Infof("kiro: streamToChannel found uncachedInputTokens in tokenUsage: %d", totalUsage.InputTokens) - } - // cacheReadInputTokens - tokens read from cache - if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok { - // Add to input tokens if we have uncached tokens, otherwise use as input - if totalUsage.InputTokens > 0 { - totalUsage.InputTokens += int64(cacheReadTokens) - } else { - totalUsage.InputTokens = int64(cacheReadTokens) - } - hasUpstreamUsage = true - log.Debugf("kiro: streamToChannel found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens)) - } - // contextUsagePercentage - can be used as fallback for input token estimation 
- if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok { - upstreamContextPercentage = ctxPct - log.Debugf("kiro: streamToChannel found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct) - } - } - - // Fallback: check for direct fields in metadata (legacy format) - if totalUsage.InputTokens == 0 { - if inputTokens, ok := metadata["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - hasUpstreamUsage = true - log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens) - } - } - if totalUsage.OutputTokens == 0 { - if outputTokens, ok := metadata["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - hasUpstreamUsage = true - log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens) - } - } - if totalUsage.TotalTokens == 0 { - if totalTokens, ok := metadata["totalTokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens) - } - } - - case "usageEvent", "usage": - // Handle dedicated usage events - if inputTokens, ok := event["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens) - } - if outputTokens, ok := event["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens) - } - if totalTokens, ok := event["totalTokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens) - } - // Also check nested usage object - if usageObj, ok := event["usage"].(map[string]interface{}); ok { - if inputTokens, ok := usageObj["input_tokens"].(float64); ok { - 
totalUsage.InputTokens = int64(inputTokens) - } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := usageObj["output_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - if totalTokens, ok := usageObj["total_tokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - } - log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d", - totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) - } - - case "metricsEvent": - // Handle metrics events which may contain usage data - if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { - if inputTokens, ok := metrics["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := metrics["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d", - totalUsage.InputTokens, totalUsage.OutputTokens) - - } - default: - // Check for upstream usage events from Kiro API - // Format: {"unit":"credit","unitPlural":"credits","usage":1.458} - if unit, ok := event["unit"].(string); ok && unit == "credit" { - if usage, ok := event["usage"].(float64); ok { - upstreamCreditUsage = usage - hasUpstreamUsage = true - log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage) - } - } - // Format: {"contextUsagePercentage":78.56} - if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { - upstreamContextPercentage = ctxPct - log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage) - } - - // Check for token counts in unknown events - if inputTokens, ok := event["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) 
- hasUpstreamUsage = true - log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens) - } - if outputTokens, ok := event["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - hasUpstreamUsage = true - log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens) - } - if totalTokens, ok := event["totalTokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens) - } - - // Check for usage object in unknown events (OpenAI/Claude format) - if usageObj, ok := event["usage"].(map[string]interface{}); ok { - if inputTokens, ok := usageObj["input_tokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - hasUpstreamUsage = true - } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - hasUpstreamUsage = true - } - if outputTokens, ok := usageObj["output_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - hasUpstreamUsage = true - } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - hasUpstreamUsage = true - } - if totalTokens, ok := usageObj["total_tokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - } - log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d", - eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) - } - - // Log unknown event types for debugging (to discover new event formats) - if eventType != "" { - log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload)) - } - - } - - // Check nested usage event - if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { - if inputTokens, ok := 
usageEvent["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - } - - // Check for direct token fields in any event (fallback) - if totalUsage.InputTokens == 0 { - if inputTokens, ok := event["inputTokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens) - } - } - if totalUsage.OutputTokens == 0 { - if outputTokens, ok := event["outputTokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens) - } - } - - // Check for usage object in any event (OpenAI format) - if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 { - if usageObj, ok := event["usage"].(map[string]interface{}); ok { - if totalUsage.InputTokens == 0 { - if inputTokens, ok := usageObj["input_tokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { - totalUsage.InputTokens = int64(inputTokens) - } - } - if totalUsage.OutputTokens == 0 { - if outputTokens, ok := usageObj["output_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { - totalUsage.OutputTokens = int64(outputTokens) - } - } - if totalUsage.TotalTokens == 0 { - if totalTokens, ok := usageObj["total_tokens"].(float64); ok { - totalUsage.TotalTokens = int64(totalTokens) - } - } - log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d", - totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) - } - } - } - - // Close content block if open - if isTextBlockOpen && contentBlockIndex >= 0 { - blockStop := 
kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - } - - // Streaming token calculation - calculate output tokens from accumulated content - // Only use local estimation if server didn't provide usage (server-side usage takes priority) - if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { - // Try to use tiktoken for accurate counting - if enc, err := getTokenizer(model); err == nil { - if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { - totalUsage.OutputTokens = int64(tokenCount) - log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) - } else { - // Fallback on count error: estimate from character count - totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) - if totalUsage.OutputTokens == 0 { - totalUsage.OutputTokens = 1 - } - log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) - } - } else { - // Fallback: estimate from character count (roughly 4 chars per token) - totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) - if totalUsage.OutputTokens == 0 { - totalUsage.OutputTokens = 1 - } - log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) - } - } else if totalUsage.OutputTokens == 0 && outputLen > 0 { - // Legacy fallback using outputLen - totalUsage.OutputTokens = int64(outputLen / 4) - if totalUsage.OutputTokens == 0 { - totalUsage.OutputTokens = 1 - } - } - - // Use contextUsagePercentage to calculate more accurate input tokens - // Kiro model has 200k max context, contextUsagePercentage represents 
the percentage used - // Formula: input_tokens = contextUsagePercentage * 200000 / 100 - // Note: The effective input context is ~170k (200k - 30k reserved for output) - if upstreamContextPercentage > 0 { - // Calculate input tokens from context percentage - // Using 200k as the base since that's what Kiro reports against - calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) - - // Only use calculated value if it's significantly different from local estimate - // This provides more accurate token counts based on upstream data - if calculatedInputTokens > 0 { - localEstimate := totalUsage.InputTokens - totalUsage.InputTokens = calculatedInputTokens - log.Debugf("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", - upstreamContextPercentage, calculatedInputTokens, localEstimate) - } - } - - totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens - - // Log upstream usage information if received - if hasUpstreamUsage { - log.Debugf("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d", - upstreamCreditUsage, upstreamContextPercentage, - totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) - } - - // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn - // SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop - stopReason := upstreamStopReason - if hasTruncatedTools { - // Log that we're using SOFT_LIMIT_REACHED approach - log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools") - } - if stopReason == "" { - if hasToolUses { - stopReason = "tool_use" - log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") - } else { - stopReason = "end_turn" - log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") - } - } - - // Log warning if response was truncated due to max_tokens - 
if stopReason == "max_tokens" { - log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") - } - - // Send message_delta event - msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - - // Send message_stop event separately - msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() - sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) - for _, chunk := range sseData { - if chunk != "" { - out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} - } - } - // reporter.publish is called via defer -} - -// NOTE: Claude SSE event builders moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_stream.go -// The executor now uses kiroclaude.BuildClaude*Event() functions instead - -// CountTokens counts tokens locally using tiktoken since Kiro API doesn't expose a token counting endpoint. -// This provides approximate token counts for client requests. 
-func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - // Use tiktoken for local token counting - enc, err := getTokenizer(req.Model) - if err != nil { - log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err) - // Fallback: estimate from payload size (roughly 4 chars per token) - estimatedTokens := len(req.Payload) / 4 - if estimatedTokens == 0 && len(req.Payload) > 0 { - estimatedTokens = 1 - } - return cliproxyexecutor.Response{ - Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)), - }, nil - } - - // Try to count tokens from the request payload - var totalTokens int64 - - // Try OpenAI chat format first - if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 { - totalTokens = tokens - log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens) - } else { - // Fallback: count raw payload tokens - if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil { - totalTokens = int64(tokenCount) - log.Debugf("kiro: CountTokens counted %d tokens from raw payload", totalTokens) - } else { - // Final fallback: estimate from payload size - totalTokens = int64(len(req.Payload) / 4) - if totalTokens == 0 && len(req.Payload) > 0 { - totalTokens = 1 - } - log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens) - } - } - - return cliproxyexecutor.Response{ - Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)), - }, nil -} - -// Refresh refreshes the Kiro OAuth token. -// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login). -// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh. 
-func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - // Serialize token refresh operations to prevent race conditions - e.refreshMu.Lock() - defer e.refreshMu.Unlock() - - var authID string - if auth != nil { - authID = auth.ID - } else { - authID = "" - } - log.Debugf("kiro executor: refresh called for auth %s", authID) - if auth == nil { - return nil, fmt.Errorf("kiro executor: auth is nil") - } - - // Double-check: After acquiring lock, verify token still needs refresh - // Another goroutine may have already refreshed while we were waiting - // NOTE: This check has a design limitation - it reads from the auth object passed in, - // not from persistent storage. If another goroutine returns a new Auth object (via Clone), - // this check won't see those updates. The mutex still prevents truly concurrent refreshes, - // but queued goroutines may still attempt redundant refreshes. This is acceptable as - // the refresh operation is idempotent and the extra API calls are infrequent. 
- if auth.Metadata != nil { - if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { - if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { - // If token was refreshed within the last 30 seconds, skip refresh - if time.Since(refreshTime) < 30*time.Second { - log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") - return auth, nil - } - } - } - // Also check if expires_at is now in the future with sufficient buffer - if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { - if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { - // If token expires more than 20 minutes from now, it's still valid - if time.Until(expTime) > 20*time.Minute { - log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) - // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks - // Without this, shouldRefresh() will return true again in 30 seconds - updated := auth.Clone() - // Set next refresh to 20 minutes before expiry, or at least 30 seconds from now - nextRefresh := expTime.Add(-20 * time.Minute) - minNextRefresh := time.Now().Add(30 * time.Second) - if nextRefresh.Before(minNextRefresh) { - nextRefresh = minNextRefresh - } - updated.NextRefreshAfter = nextRefresh - log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) - return updated, nil - } - } - } - } - - var refreshToken string - var clientID, clientSecret string - var authMethod string - var region, startURL string - - if auth.Metadata != nil { - refreshToken = getMetadataString(auth.Metadata, "refresh_token", "refreshToken") - clientID = getMetadataString(auth.Metadata, "client_id", "clientId") - clientSecret = getMetadataString(auth.Metadata, "client_secret", "clientSecret") - authMethod = strings.ToLower(getMetadataString(auth.Metadata, "auth_method", "authMethod")) - region = 
getMetadataString(auth.Metadata, "region") - startURL = getMetadataString(auth.Metadata, "start_url", "startUrl") - } - - if refreshToken == "" { - return nil, fmt.Errorf("kiro executor: refresh token not found") - } - - var tokenData *kiroauth.KiroTokenData - var err error - - ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint switch { @@ -4030,662 +1271,3 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } - -// ══════════════════════════════════════════════════════════════════════════════ -// Web Search Handler (MCP API) -// ══════════════════════════════════════════════════════════════════════════════ - -// fetchToolDescription caching: -// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time, -// with automatic retry on failure: -// - On failure, fetched stays false so subsequent calls will retry -// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path) -// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(), -// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests. -var ( - toolDescMu sync.Mutex - toolDescFetched atomic.Bool -) - -// fetchToolDescription calls MCP tools/list to get the web_search tool description -// and caches it. Safe to call concurrently — only one goroutine fetches at a time. -// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. -// The httpClient parameter allows reusing a shared pooled HTTP client. 
-func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) { - // Fast path: already fetched successfully, no lock needed - if toolDescFetched.Load() { - return - } - - toolDescMu.Lock() - defer toolDescMu.Unlock() - - // Double-check after acquiring lock - if toolDescFetched.Load() { - return - } - - handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs) - reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) - log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) - - req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody)) - if err != nil { - log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) - return - } - - // Reuse same headers as callMcpAPI - handler.setMcpHeaders(req) - - resp, err := handler.httpClient.Do(req) - if err != nil { - log.Warnf("kiro/websearch: tools/list request failed: %v", err) - return - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil || resp.StatusCode != http.StatusOK { - log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) - return - } - log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) - - // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} - var result struct { - Result *struct { - Tools []struct { - Name string `json:"name"` - Description string `json:"description"` - } `json:"tools"` - } `json:"result"` - } - if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { - log.Warnf("kiro/websearch: failed to parse tools/list response") - return - } - - for _, tool := range result.Result.Tools { - if tool.Name == "web_search" && tool.Description != "" { - kiroclaude.SetWebSearchDescription(tool.Description) - toolDescFetched.Store(true) // 
success — no more fetches - log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) - return - } - } - - // web_search tool not found in response - log.Warnf("kiro/websearch: web_search tool not found in tools/list response") -} - -// webSearchHandler handles web search requests via Kiro MCP API -type webSearchHandler struct { - ctx context.Context - mcpEndpoint string - httpClient *http.Client - authToken string - auth *cliproxyauth.Auth // for applyDynamicFingerprint - authAttrs map[string]string // optional, for custom headers from auth.Attributes -} - -// newWebSearchHandler creates a new webSearchHandler. -// If httpClient is nil, a default client with 30s timeout is used. -// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. -func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler { - if httpClient == nil { - httpClient = &http.Client{ - Timeout: 30 * time.Second, - } - } - return &webSearchHandler{ - ctx: ctx, - mcpEndpoint: mcpEndpoint, - httpClient: httpClient, - authToken: authToken, - auth: auth, - authAttrs: authAttrs, - } -} - -// setMcpHeaders sets standard MCP API headers on the request, -// aligned with the GAR request pattern. -func (h *webSearchHandler) setMcpHeaders(req *http.Request) { - // 1. Content-Type & Accept (aligned with GAR) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "*/*") - - // 2. Kiro-specific headers (aligned with GAR) - req.Header.Set("x-amzn-kiro-agent-mode", "vibe") - req.Header.Set("x-amzn-codewhisperer-optout", "true") - - // 3. User-Agent: Reuse applyDynamicFingerprint for consistency - applyDynamicFingerprint(req, h.auth) - - // 4. AWS SDK identifiers - req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - // 5. 
Authentication - req.Header.Set("Authorization", "Bearer "+h.authToken) - - // 6. Custom headers from auth attributes - util.ApplyCustomHeadersFromAttrs(req, h.authAttrs) -} - -// mcpMaxRetries is the maximum number of retries for MCP API calls. -const mcpMaxRetries = 2 - -// callMcpAPI calls the Kiro MCP API with the given request. -// Includes retry logic with exponential backoff for retryable errors. -func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) { - requestBody, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal MCP request: %w", err) - } - log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody)) - - var lastErr error - for attempt := 0; attempt <= mcpMaxRetries; attempt++ { - if attempt > 0 { - backoff := time.Duration(1< 10*time.Second { - backoff = 10 * time.Second - } - log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) - select { - case <-h.ctx.Done(): - return nil, h.ctx.Err() - case <-time.After(backoff): - } - } - - req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - h.setMcpHeaders(req) - - resp, err := h.httpClient.Do(req) - if err != nil { - lastErr = fmt.Errorf("MCP API request failed: %w", err) - continue // network error → retry - } - - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - lastErr = fmt.Errorf("failed to read MCP response: %w", err) - continue // read error → retry - } - log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) - - // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) - if resp.StatusCode >= 502 && resp.StatusCode <= 504 { - lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) - 
continue - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) - } - - var mcpResponse kiroclaude.McpResponse - if err := json.Unmarshal(body, &mcpResponse); err != nil { - return nil, fmt.Errorf("failed to parse MCP response: %w", err) - } - - if mcpResponse.Error != nil { - code := -1 - if mcpResponse.Error.Code != nil { - code = *mcpResponse.Error.Code - } - msg := "Unknown error" - if mcpResponse.Error.Message != nil { - msg = *mcpResponse.Error.Message - } - return nil, fmt.Errorf("MCP error %d: %s", code, msg) - } - - return &mcpResponse, nil - } - - return nil, lastErr -} - -// webSearchAuthAttrs extracts auth attributes for MCP calls. -// Used by handleWebSearch and handleWebSearchStream to pass custom headers. -func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string { - if auth != nil { - return auth.Attributes - } - return nil -} - -const maxWebSearchIterations = 5 - -// handleWebSearchStream handles web_search requests: -// Step 1: tools/list (sync) → fetch/cache tool description -// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop -// Note: We skip the "model decides to search" step because Claude Code already -// decided to use web_search. The Kiro tool description restricts non-coding -// topics, so asking the model again would cause it to refuse valid searches. 
-func (e *KiroExecutor) handleWebSearchStream( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (<-chan cliproxyexecutor.StreamChunk, error) { - // Extract search query from Claude Code's web_search tool_use - query := kiroclaude.ExtractSearchQuery(req.Payload) - if query == "" { - log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow") - return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn) - } - - // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) - region := resolveKiroAPIRegion(auth) - mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) - - // ── Step 1: tools/list (SYNC) — cache tool description ── - { - authAttrs := webSearchAuthAttrs(auth) - fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) - } - - // Create output channel - out := make(chan cliproxyexecutor.StreamChunk) - - // Usage reporting: track web search requests like normal streaming requests - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - - go func() { - var wsErr error - defer reporter.trackFailure(ctx, &wsErr) - defer close(out) - - // Estimate input tokens using tokenizer (matching streamToChannel pattern) - var totalUsage usage.Detail - if enc, tokErr := getTokenizer(req.Model); tokErr == nil { - if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 { - totalUsage.InputTokens = inp - } else { - totalUsage.InputTokens = int64(len(req.Payload) / 4) - } - } else { - totalUsage.InputTokens = int64(len(req.Payload) / 4) - } - if totalUsage.InputTokens == 0 && len(req.Payload) > 0 { - totalUsage.InputTokens = 1 - } - var accumulatedOutputLen int - defer func() { - if wsErr != nil { - return // let trackFailure handle failure reporting - } - totalUsage.OutputTokens = 
int64(accumulatedOutputLen / 4) - if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 { - totalUsage.OutputTokens = 1 - } - reporter.publish(ctx, totalUsage) - }() - - // Send message_start event to client (aligned with streamToChannel pattern) - // Use payloadRequestedModel to return user's original model alias - msgStart := kiroclaude.BuildClaudeMessageStartEvent( - payloadRequestedModel(opts, req.Model), - totalUsage.InputTokens, - ) - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}: - } - - // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ── - contentBlockIndex := 0 - currentQuery := query - - // Replace web_search tool description with a minimal one that allows re-search. - // The original tools/list description from Kiro restricts non-coding topics, - // but we've already decided to search. We keep the tool so the model can - // request additional searches when results are insufficient. 
- simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) - if simplifyErr != nil { - log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr) - simplifiedPayload = bytes.Clone(req.Payload) - } - - currentClaudePayload := simplifiedPayload - totalSearches := 0 - - // Generate toolUseId for the first iteration (Claude Code already decided to search) - currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) - - for iteration := 0; iteration < maxWebSearchIterations; iteration++ { - log.Infof("kiro/websearch: search iteration %d/%d", - iteration+1, maxWebSearchIterations) - - // MCP search - _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) - - authAttrs := webSearchAuthAttrs(auth) - handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) - mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) - - var searchResults *kiroclaude.WebSearchResults - if mcpErr != nil { - log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr) - } else { - searchResults = kiroclaude.ParseSearchResults(mcpResponse) - } - - resultCount := 0 - if searchResults != nil { - resultCount = len(searchResults.Results) - } - totalSearches++ - log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount) - - // Send search indicator events to client - searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex) - for _, event := range searchEvents { - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: event}: - } - } - contentBlockIndex += 2 - - // Inject tool_use + tool_result into Claude payload, then call GAR - var err error - currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, 
currentToolUseId, currentQuery, searchResults) - if err != nil { - log.Warnf("kiro/websearch: failed to inject tool results: %v", err) - wsErr = fmt.Errorf("failed to inject tool results: %w", err) - e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - return - } - - // Call GAR with modified Claude payload (full translation pipeline) - modifiedReq := req - modifiedReq.Payload = currentClaudePayload - kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn) - if kiroErr != nil { - log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr) - wsErr = fmt.Errorf("kiro API failed at iteration %d: %w", iteration+1, kiroErr) - e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - return - } - - // Analyze response - analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks) - log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v", - iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse) - - if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations { - // Model wants another search - filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex) - for _, chunk := range filteredChunks { - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: - } - } - - currentQuery = analysis.WebSearchQuery - currentToolUseId = analysis.WebSearchToolUseId - continue - } - - // Model returned final response — stream to client - for _, chunk := range kiroChunks { - if contentBlockIndex > 0 && len(chunk) > 0 { - adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex) - if !shouldForward { - continue - } - accumulatedOutputLen += len(adjusted) - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}: - } - } else { - accumulatedOutputLen 
+= len(chunk) - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: - } - } - } - log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches) - return - } - - log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations) - }() - - return out, nil -} - -// handleWebSearch handles web_search requests for non-streaming Execute path. -// Performs MCP search synchronously, injects results into the request payload, -// then calls the normal non-streaming Kiro API path which returns a proper -// Claude JSON response (not SSE chunks). -func (e *KiroExecutor) handleWebSearch( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (cliproxyexecutor.Response, error) { - // Extract search query from Claude Code's web_search tool_use - query := kiroclaude.ExtractSearchQuery(req.Payload) - if query == "" { - log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") - // Fall through to normal non-streaming path - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) - region := resolveKiroAPIRegion(auth) - mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) - - // Step 1: Fetch/cache tool description (sync) - { - authAttrs := webSearchAuthAttrs(auth) - fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) - } - - // Step 2: Perform MCP search - _, mcpRequest := kiroclaude.CreateMcpRequest(query) - - authAttrs := webSearchAuthAttrs(auth) - handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) - 
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) - - var searchResults *kiroclaude.WebSearchResults - if mcpErr != nil { - log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) - } else { - searchResults = kiroclaude.ParseSearchResults(mcpResponse) - } - - resultCount := 0 - if searchResults != nil { - resultCount = len(searchResults.Results) - } - log.Infof("kiro/websearch: non-stream: got %d search results", resultCount) - - // Step 3: Replace restrictive web_search tool description (align with streaming path) - simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) - if simplifyErr != nil { - log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr) - simplifiedPayload = bytes.Clone(req.Payload) - } - - // Step 4: Inject search tool_use + tool_result into Claude payload - currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) - modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults) - if err != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry) - // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream - // to produce a proper Claude JSON response - modifiedReq := req - modifiedReq.Payload = modifiedPayload - - resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) - if err != nil { - return resp, err - } - - // Step 6: Inject server_tool_use + web_search_tool_result into response - // so Claude Code can display "Did X searches in Ys" - indicators := []kiroclaude.SearchIndicator{ - { - ToolUseID: currentToolUseId, - Query: query, - Results: 
searchResults, - }, - } - injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) - if injErr != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) - } else { - resp.Payload = injectedPayload - } - - return resp, nil -} - -// callKiroAndBuffer calls the Kiro API and buffers all response chunks. -// Returns the buffered chunks for analysis before forwarding to client. -// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter. -func (e *KiroExecutor) callKiroAndBuffer( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) ([][]byte, error) { - from := opts.SourceFormat - to := sdktranslator.FromString("kiro") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - log.Debugf("kiro/websearch GAR request: %d bytes", len(body)) - - kiroModelID := e.mapModelToKiro(req.Model) - isAgentic, isChatOnly := determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - - tokenKey := getTokenKey(auth) - - kiroStream, err := e.executeStreamWithRetry( - ctx, auth, req, opts, accessToken, effectiveProfileArn, - body, from, nil, kiroModelID, isAgentic, isChatOnly, tokenKey, - ) - if err != nil { - return nil, err - } - - // Buffer all chunks - var chunks [][]byte - for chunk := range kiroStream { - if chunk.Err != nil { - return chunks, chunk.Err - } - if len(chunk.Payload) > 0 { - chunks = append(chunks, bytes.Clone(chunk.Payload)) - } - } - - log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks)) - - return chunks, nil -} - -// callKiroDirectStream creates a direct streaming channel to Kiro API without search. 
-func (e *KiroExecutor) callKiroDirectStream( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (<-chan cliproxyexecutor.StreamChunk, error) { - from := opts.SourceFormat - to := sdktranslator.FromString("kiro") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - - kiroModelID := e.mapModelToKiro(req.Model) - isAgentic, isChatOnly := determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - - tokenKey := getTokenKey(auth) - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - var streamErr error - defer reporter.trackFailure(ctx, &streamErr) - - stream, streamErr := e.executeStreamWithRetry( - ctx, auth, req, opts, accessToken, effectiveProfileArn, - body, from, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey, - ) - return stream, streamErr -} - -// sendFallbackText sends a simple text response when the Kiro API fails during the search loop. -// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment -// with how streamToChannel() uses BuildClaude*Event() functions. -func (e *KiroExecutor) sendFallbackText( - ctx context.Context, - out chan<- cliproxyexecutor.StreamChunk, - contentBlockIndex int, - query string, - searchResults *kiroclaude.WebSearchResults, -) { - events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults) - for _, event := range events { - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}: - } - } -} - -// executeNonStreamFallback runs the standard non-streaming Execute path for a request. -// Used by handleWebSearch after injecting search results, or as a fallback. 
-func (e *KiroExecutor) executeNonStreamFallback( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (cliproxyexecutor.Response, error) { - from := opts.SourceFormat - to := sdktranslator.FromString("kiro") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - - kiroModelID := e.mapModelToKiro(req.Model) - isAgentic, isChatOnly := determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := getTokenKey(auth) - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - var err error - defer reporter.trackFailure(ctx, &err) - - resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, body, from, to, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey) - return resp, err -} - -func (e *KiroExecutor) CloseExecutionSession(sessionID string) {} diff --git a/pkg/llmproxy/executor/kiro_streaming.go b/pkg/llmproxy/executor/kiro_streaming.go new file mode 100644 index 0000000000..2e3ea70162 --- /dev/null +++ b/pkg/llmproxy/executor/kiro_streaming.go @@ -0,0 +1,2993 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + kiroauth "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/auth/kiro" + kiroclaude "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/claude" + kirocommon "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/common" + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" + clipproxyauth "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/auth" + clipproxyexecutor "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/executor" + 
"github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/usage" + sdktranslator "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/translator" + log "github.com/sirupsen/logrus" +) + +// ExecuteStream handles streaming requests to Kiro API. +// Supports automatic token refresh on 401/403 errors and quota fallback on 429. +func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return nil, fmt.Errorf("kiro: access token not found in auth") + } + + // Rate limiting: get token key for tracking + tokenKey := getTokenKey(auth) + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() + + // Check if token is in cooldown period + if cooldownMgr.IsInCooldown(tokenKey) { + remaining := cooldownMgr.GetRemainingCooldown(tokenKey) + reason := cooldownMgr.GetCooldownReason(tokenKey) + log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining) + return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason) + } + + // Wait for rate limiter before proceeding + log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey) + rateLimiter.WaitForToken(tokenKey) + log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) + + // Check if token is expired before making request (covers both normal and web_search paths) + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting recovery before stream request") + + // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) + reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) + if reloadErr == nil && reloadedAuth != nil { + // 文件中有更新的 token,使用它 + auth = reloadedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: recovered token from file (background 
refresh) for stream, expires_at: %v", auth.Metadata["expires_at"]) + } else { + // 文件中的 token 也过期了,执行主动刷新 + log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before stream request") + } + } + } + + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") + streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) + if errWebSearch != nil { + return nil, errWebSearch + } + return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Determine agentic mode and effective profile ARN using helper functions + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + // Execute stream with retry on 401/403 and 429 (quota exhausted) + // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint + streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, 
accessToken, effectiveProfileArn, body, from, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey) + if errStreamKiro != nil { + return nil, errStreamKiro + } + return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil +} + +// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. +// Supports automatic fallback between endpoints with different quotas: +// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota +// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota +// Also supports multi-endpoint fallback similar to Antigravity implementation. +// tokenKey is used for rate limiting and cooldown tracking. +func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, body []byte, from sdktranslator.Format, reporter *usageReporter, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) { + var currentOrigin string + maxRetries := 2 // Allow retries for token refresh + endpoint fallback + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() + endpointConfigs := getKiroEndpointConfigs(auth) + var last429Err error + + for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { + endpointConfig := endpointConfigs[endpointIdx] + url := endpointConfig.URL + // Use this endpoint's compatible Origin (critical for avoiding 403 errors) + currentOrigin = endpointConfig.Origin + + // Rebuild payload with the correct origin for this endpoint + // Each endpoint requires its matching Origin value in the request body + kiroPayload, thinkingEnabled := buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) + + log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", + endpointIdx+1, 
len(endpointConfigs), url, endpointConfig.Name, currentOrigin) + + for attempt := 0; attempt <= maxRetries; attempt++ { + // Apply human-like delay before first streaming request (not on retries) + // This mimics natural user behavior patterns + // Note: Delay is NOT applied during streaming response - only before initial request + if attempt == 0 && endpointIdx == 0 { + kiroauth.ApplyHumanLikeDelay() + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("Accept", kiroAcceptStream) + // Only set X-Amz-Target if specified (Q endpoint doesn't require it) + if endpointConfig.AmzTarget != "" { + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + } + // Kiro-specific headers + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) + httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") + + // Apply dynamic fingerprint-based headers + applyDynamicFingerprint(httpReq, auth) + + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) 
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + + // Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.) + retryCfg := defaultRetryConfig() + if isRetryableError(err) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } + + return nil, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Record failure and set cooldown for 429 + rateLimiter.MarkTokenFailed(tokenKey) + cooldownDuration := kiroauth.CalculateCooldownFor429(attempt) + cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429) + log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration) + + // Preserve last 429 so callers can 
correctly backoff when all endpoints are exhausted + last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} + + log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s", + endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // Break inner retry loop to try next endpoint (which has different quota) + break + } + + // Handle 5xx server errors with exponential backoff retry + // Enhanced: Use retryConfig for consistent retry behavior + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + retryCfg := defaultRetryConfig() + // Check if this specific 5xx code is retryable (502, 503, 504) + if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } else if attempt < maxRetries { + // Fallback for other 5xx errors (500, 501, etc.) 
+ backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) + continue + } + log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 400 errors - Credential/Validation issues + // Do NOT switch endpoints - return error immediately + if httpResp.StatusCode == 400 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // 400 errors indicate request validation issues - return immediately without retry + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: stream received 401 error, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + 
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) + if attempt < maxRetries { + log.Infof("kiro: token refreshed successfully, retrying stream request (attempt %d/%d)", attempt+1, maxRetries+1) + continue + } + log.Infof("kiro: token refreshed successfully, no retries remaining") + } + + log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: stream received 402 (monthly limit). Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired + // Do NOT switch endpoints for 403 errors + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Log the 403 error details for debugging + log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) + + respBodyStr := string(respBody) + + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + // Set long cooldown for suspended accounts + rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr) + cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended) + log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown) + return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + } + + // Check if 
this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } + accessToken, profileArn = kiroCredentials(auth) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) + log.Infof("kiro: token refreshed for 403, retrying stream request") + continue + } + } + + // For non-token 403 or after max retries, return error immediately + // Do NOT switch endpoints for 403 errors + log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + + out := make(chan 
cliproxyexecutor.StreamChunk) + + // Record success immediately since connection was established successfully + // Streaming errors will be handled separately + rateLimiter.MarkTokenSuccess(tokenKey) + log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey) + + go func(resp *http.Response, thinkingEnabled bool) { + defer close(out) + defer func() { + if r := recover(); r != nil { + log.Errorf("kiro: panic in stream handler: %v", r) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + } + }() + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + // Kiro API always returns tags regardless of request parameters + // So we always enable thinking parsing for Kiro responses + log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) + + e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled) + }(httpResp, thinkingEnabled) + + return out, nil + } + // Inner retry loop exhausted for this endpoint, try next endpoint + // Note: This code is unreachable because all paths in the inner loop + // either return or continue. Kept as comment for documentation. 
+ } + + // All endpoints exhausted + if last429Err != nil { + return nil, last429Err + } + return nil, fmt.Errorf("kiro: stream all endpoints exhausted") +} + +// EventStreamError represents an Event Stream processing error +type EventStreamError struct { + Type string // "fatal", "malformed" + Message string + Cause error +} + +func (e *EventStreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("event stream %s: %s", e.Type, e.Message) +} + +// eventStreamMessage represents a parsed AWS Event Stream message +type eventStreamMessage struct { + EventType string // Event type from headers (e.g., "assistantResponseEvent") + Payload []byte // JSON payload of the message +} + +// NOTE: Request building functions moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_request.go +// The executor now uses kiroclaude.BuildKiroPayload() instead + +// parseEventStream parses AWS Event Stream binary format. +// Extracts text content, tool uses, and stop_reason from the response. +// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. 
// Returns: content, toolUses, usageInfo, stopReason, error
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) {
	var content strings.Builder
	var toolUses []kiroclaude.KiroToolUse
	var usageInfo usage.Detail
	var stopReason string // Extracted from upstream response
	reader := bufio.NewReader(body)

	// Tool use state tracking for input buffering and deduplication.
	// processedIDs guards against the same toolUseId arriving via multiple
	// event shapes (nested, direct, and embedded [Called ...] text).
	processedIDs := make(map[string]bool)
	var currentToolUse *kiroclaude.ToolUseState

	// Upstream usage tracking - Kiro API returns credit usage and context percentage
	var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56)

	for {
		msg, eventErr := e.readEventStreamMessage(reader)
		if eventErr != nil {
			log.Errorf("kiro: parseEventStream error: %v", eventErr)
			return content.String(), toolUses, usageInfo, stopReason, eventErr
		}
		if msg == nil {
			// Normal end of stream (EOF)
			break
		}

		eventType := msg.EventType
		payload := msg.Payload
		if len(payload) == 0 {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal(payload, &event); err != nil {
			// Malformed frames are skipped, not fatal - the stream may still
			// carry valid events after a bad one.
			log.Debugf("kiro: skipping malformed event: %v", err)
			continue
		}

		// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
		// These can appear as top-level fields or nested within the event
		if errType, hasErrType := event["_type"].(string); hasErrType {
			// AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."}
			errMsg := ""
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			}
			log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg)
			return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
		}
		if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
			// Generic error event
			errMsg := ""
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			} else if errObj, ok := event["error"].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
			}
			log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
			return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg)
		}

		// Extract stop_reason from various event formats.
		// Kiro/Amazon Q API may include stop_reason in different locations;
		// later events overwrite earlier values (last one wins).
		if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
			stopReason = sr
			log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason)
		}
		if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
			stopReason = sr
			log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason)
		}

		// Handle different event types
		switch eventType {
		case "followupPromptEvent":
			// Filter out followupPrompt events - these are UI suggestions, not content
			log.Debugf("kiro: parseEventStream ignoring followupPrompt event")
			continue

		case "assistantResponseEvent":
			if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok {
				if contentText, ok := assistantResp["content"].(string); ok {
					content.WriteString(contentText)
				}
				// Extract stop_reason from assistantResponseEvent
				if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" {
					stopReason = sr
					log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason)
				}
				if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" {
					stopReason = sr
					log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason)
				}
				// Extract tool uses from response
				if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
					for _, tuRaw := range toolUsesRaw {
						if tu, ok := tuRaw.(map[string]interface{}); ok {
							toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
							// Check for duplicate
							if processedIDs[toolUseID] {
								log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID)
								continue
							}
							processedIDs[toolUseID] = true

							toolUse := kiroclaude.KiroToolUse{
								ToolUseID: toolUseID,
								Name:      kirocommon.GetStringValue(tu, "name"),
							}
							if input, ok := tu["input"].(map[string]interface{}); ok {
								toolUse.Input = input
							}
							toolUses = append(toolUses, toolUse)
						}
					}
				}
			}
			// Also try direct format (content/toolUses at the event top level)
			if contentText, ok := event["content"].(string); ok {
				content.WriteString(contentText)
			}
			// Direct tool uses
			if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok {
				for _, tuRaw := range toolUsesRaw {
					if tu, ok := tuRaw.(map[string]interface{}); ok {
						toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
						// Check for duplicate
						if processedIDs[toolUseID] {
							log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID)
							continue
						}
						processedIDs[toolUseID] = true

						toolUse := kiroclaude.KiroToolUse{
							ToolUseID: toolUseID,
							Name:      kirocommon.GetStringValue(tu, "name"),
						}
						if input, ok := tu["input"].(map[string]interface{}); ok {
							toolUse.Input = input
						}
						toolUses = append(toolUses, toolUse)
					}
				}
			}

		case "toolUseEvent":
			// Handle dedicated tool use events with input buffering
			completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs)
			currentToolUse = newState
			toolUses = append(toolUses, completedToolUses...)

		case "supplementaryWebLinksEvent":
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}

		case "messageStopEvent", "message_stop":
			// Handle message stop events which may contain stop_reason
			if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
				stopReason = sr
				log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason)
			}
			if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
				stopReason = sr
				log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason)
			}

		case "messageMetadataEvent", "metadataEvent":
			// Handle message metadata events which contain token counts
			// Official format: { tokenUsage: { outputTokens, totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } }
			var metadata map[string]interface{}
			if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok {
				metadata = m
			} else if m, ok := event["metadataEvent"].(map[string]interface{}); ok {
				metadata = m
			} else {
				metadata = event // event itself might be the metadata
			}

			// Check for nested tokenUsage object (official format)
			if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok {
				// outputTokens - precise output token count
				if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
					log.Infof("kiro: parseEventStream found precise outputTokens in tokenUsage: %d", usageInfo.OutputTokens)
				}
				// totalTokens - precise total token count
				if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
					log.Infof("kiro: parseEventStream found precise totalTokens in tokenUsage: %d", usageInfo.TotalTokens)
				}
				// uncachedInputTokens - input tokens not from cache
				if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(uncachedInputTokens)
					log.Infof("kiro: parseEventStream found uncachedInputTokens in tokenUsage: %d", usageInfo.InputTokens)
				}
				// cacheReadInputTokens - tokens read from cache
				if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok {
					// Add to input tokens if we have uncached tokens, otherwise use as input
					if usageInfo.InputTokens > 0 {
						usageInfo.InputTokens += int64(cacheReadTokens)
					} else {
						usageInfo.InputTokens = int64(cacheReadTokens)
					}
					log.Debugf("kiro: parseEventStream found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens))
				}
				// contextUsagePercentage - can be used as fallback for input token estimation
				if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct)
				}
			}

			// Fallback: check for direct fields in metadata (legacy format).
			// Only fill fields that are still zero so tokenUsage values win.
			if usageInfo.InputTokens == 0 {
				if inputTokens, ok := metadata["inputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
					log.Debugf("kiro: parseEventStream found inputTokens in messageMetadataEvent: %d", usageInfo.InputTokens)
				}
			}
			if usageInfo.OutputTokens == 0 {
				if outputTokens, ok := metadata["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
					log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens)
				}
			}
			if usageInfo.TotalTokens == 0 {
				if totalTokens, ok := metadata["totalTokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
					log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens)
				}
			}

		case "usageEvent", "usage":
			// Handle dedicated usage events
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
				log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens)
			}
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
				log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens)
			}
			if totalTokens, ok := event["totalTokens"].(float64); ok {
				usageInfo.TotalTokens = int64(totalTokens)
				log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens)
			}
			// Also check nested usage object (accepts both Claude-style
			// input_tokens/output_tokens and OpenAI-style prompt_tokens/completion_tokens)
			if usageObj, ok := event["usage"].(map[string]interface{}); ok {
				if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				}
				if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				}
				if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
					usageInfo.TotalTokens = int64(totalTokens)
				}
				log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
			}

		case "metricsEvent":
			// Handle metrics events which may contain usage data
			if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok {
				if inputTokens, ok := metrics["inputTokens"].(float64); ok {
					usageInfo.InputTokens = int64(inputTokens)
				}
				if outputTokens, ok := metrics["outputTokens"].(float64); ok {
					usageInfo.OutputTokens = int64(outputTokens)
				}
				log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens)
			}

		case "meteringEvent":
			// Handle metering events from Kiro API (usage billing information)
			// Official format: { unit: string, unitPlural: string, usage: number }
			if metering, ok := event["meteringEvent"].(map[string]interface{}); ok {
				unit := ""
				if u, ok := metering["unit"].(string); ok {
					unit = u
				}
				usageVal := 0.0
				if u, ok := metering["usage"].(float64); ok {
					usageVal = u
				}
				log.Infof("kiro: parseEventStream received meteringEvent: usage=%.2f %s", usageVal, unit)
				// Store metering info for potential billing/statistics purposes
				// Note: This is separate from token counts - it's AWS billing units
			} else {
				// Try direct fields
				unit := ""
				if u, ok := event["unit"].(string); ok {
					unit = u
				}
				usageVal := 0.0
				if u, ok := event["usage"].(float64); ok {
					usageVal = u
				}
				if unit != "" || usageVal > 0 {
					log.Infof("kiro: parseEventStream received meteringEvent (direct): usage=%.2f %s", usageVal, unit)
				}
			}

		case "contextUsageEvent":
			// Handle context usage events from Kiro API
			// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
			// NOTE(review): the example above (0.53, logged as ctxPct*100) looks
			// like a fraction, yet the final input-token calculation after the
			// loop treats upstreamContextPercentage as a percent (x*200000/100).
			// Confirm upstream units - the two paths may disagree by 100x.
			if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
				if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
				}
			} else {
				// Try direct field (fallback)
				if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
					upstreamContextPercentage = ctxPct
					log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
				}
			}

		case "error", "exception", "internalServerException", "invalidStateEvent":
			// Handle error events from Kiro API stream
			errMsg := ""
			errType := eventType

			// Try to extract error message from various formats
			if msg, ok := event["message"].(string); ok {
				errMsg = msg
			} else if errObj, ok := event[eventType].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
				if t, ok := errObj["type"].(string); ok {
					errType = t
				}
			} else if errObj, ok := event["error"].(map[string]interface{}); ok {
				if msg, ok := errObj["message"].(string); ok {
					errMsg = msg
				}
				if t, ok := errObj["type"].(string); ok {
					errType = t
				}
			}

			// Check for specific error reasons
			if reason, ok := event["reason"].(string); ok {
				errMsg = fmt.Sprintf("%s (reason: %s)", errMsg, reason)
			}

			log.Errorf("kiro: parseEventStream received error event: type=%s, message=%s", errType, errMsg)

			// For invalidStateEvent, we may want to continue processing other events
			if eventType == "invalidStateEvent" {
				log.Warnf("kiro: invalidStateEvent received, continuing stream processing")
				continue
			}

			// For other errors, return the error
			if errMsg != "" {
				return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error (%s): %s", errType, errMsg)
			}

		default:
			// Check for contextUsagePercentage in any event
			if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
				upstreamContextPercentage = ctxPct
				log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage)
			}
			// Log unknown event types for debugging (to discover new event formats)
			log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload))
		}

		// Check for direct token fields in any event (fallback) - runs after
		// the switch so typed handlers above take precedence.
		if usageInfo.InputTokens == 0 {
			if inputTokens, ok := event["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
				log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens)
			}
		}
		if usageInfo.OutputTokens == 0 {
			if outputTokens, ok := event["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
				log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens)
			}
		}

		// Check for usage object in any event (OpenAI format)
		if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 {
			if usageObj, ok := event["usage"].(map[string]interface{}); ok {
				if usageInfo.InputTokens == 0 {
					if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
						usageInfo.InputTokens = int64(inputTokens)
					} else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
						usageInfo.InputTokens = int64(inputTokens)
					}
				}
				if usageInfo.OutputTokens == 0 {
					if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
						usageInfo.OutputTokens = int64(outputTokens)
					} else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
						usageInfo.OutputTokens = int64(outputTokens)
					}
				}
				if usageInfo.TotalTokens == 0 {
					if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
						usageInfo.TotalTokens = int64(totalTokens)
					}
				}
				log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d",
					usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
			}
		}

		// Also check nested supplementaryWebLinksEvent
		if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok {
			if inputTokens, ok := usageEvent["inputTokens"].(float64); ok {
				usageInfo.InputTokens = int64(inputTokens)
			}
			if outputTokens, ok := usageEvent["outputTokens"].(float64); ok {
				usageInfo.OutputTokens = int64(outputTokens)
			}
		}
	}

	// Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}])
	contentStr := content.String()
	cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs)
	toolUses = append(toolUses, embeddedToolUses...)

	// Deduplicate all tool uses
	toolUses = kiroclaude.DeduplicateToolUses(toolUses)

	// Apply fallback logic for stop_reason if not provided by upstream
	// Priority: upstream stopReason > tool_use detection > end_turn default
	if stopReason == "" {
		if len(toolUses) > 0 {
			stopReason = "tool_use"
			log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses))
		} else {
			stopReason = "end_turn"
			log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn")
		}
	}

	// Log warning if response was truncated due to max_tokens
	if stopReason == "max_tokens" {
		log.Warnf("kiro: response truncated due to max_tokens limit")
	}

	// Use contextUsagePercentage to calculate more accurate input tokens
	// Kiro model has 200k max context, contextUsagePercentage represents the percentage used
	// Formula: input_tokens = contextUsagePercentage * 200000 / 100
	if upstreamContextPercentage > 0 {
		calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100)
		if calculatedInputTokens > 0 {
			localEstimate := usageInfo.InputTokens
			usageInfo.InputTokens = calculatedInputTokens
			usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
			log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
				upstreamContextPercentage, calculatedInputTokens, localEstimate)
		}
	}

	return cleanedContent, toolUses, usageInfo, stopReason, nil
}

// readEventStreamMessage reads and validates a single AWS Event Stream message.
// Returns the parsed message or a structured error for different failure modes.
// This function implements boundary protection and detailed error classification.
+// +// AWS Event Stream binary format: +// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4) +// - Headers (variable): header entries +// - Payload (variable): JSON data +// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped) +func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { + // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc) + prelude := make([]byte, 12) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + return nil, nil // Normal end of stream + } + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read prelude", + Cause: err, + } + } + + totalLength := binary.BigEndian.Uint32(prelude[0:4]) + headersLength := binary.BigEndian.Uint32(prelude[4:8]) + // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements) + + // Boundary check: minimum frame size + if totalLength < minEventStreamFrameSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), + } + } + + // Boundary check: maximum message size + if totalLength > maxEventStreamMsgSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), + } + } + + // Boundary check: headers length within message bounds + // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4) + // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc) + if headersLength > totalLength-16 { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), + } + } + + // Read the rest of the message (total - 12 bytes already 
read) + remaining := make([]byte, totalLength-12) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read message body", + Cause: err, + } + } + + // Extract event type from headers + // Headers start at beginning of 'remaining', length is headersLength + var eventType string + if headersLength > 0 && headersLength <= uint32(len(remaining)) { + eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) + } + + // Calculate payload boundaries + // Payload starts after headers, ends before message_crc (last 4 bytes) + payloadStart := headersLength + payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end + + // Validate payload boundaries + if payloadStart >= payloadEnd { + // No payload, return empty message + return &eventStreamMessage{ + EventType: eventType, + Payload: nil, + }, nil + } + + payload := remaining[payloadStart:payloadEnd] + + return &eventStreamMessage{ + EventType: eventType, + Payload: payload, + }, nil +} + +func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) { + switch valueType { + case 0, 1: // bool true / bool false + return offset, true + case 2: // byte + if offset+1 > len(headers) { + return offset, false + } + return offset + 1, true + case 3: // short + if offset+2 > len(headers) { + return offset, false + } + return offset + 2, true + case 4: // int + if offset+4 > len(headers) { + return offset, false + } + return offset + 4, true + case 5: // long + if offset+8 > len(headers) { + return offset, false + } + return offset + 8, true + case 6: // byte array (2-byte length + data) + if offset+2 > len(headers) { + return offset, false + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + return offset, false + } + return offset + valueLen, true + case 8: // timestamp + if offset+8 > len(headers) { + return offset, false + } + 
return offset + 8, true + case 9: // uuid + if offset+16 > len(headers) { + return offset, false + } + return offset + 16, true + default: + return offset, false + } +} + +// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix) +func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { + offset := 0 + for offset < len(headers) { + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String type + if offset+2 > len(headers) { + break + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + + if name == ":event-type" { + return value + } + continue + } + + nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType) + if !ok { + break + } + offset = nextOffset + } + return "" +} + +// NOTE: Response building functions moved to pkg/llmproxy/translator/kiro/claude/kiro_claude_response.go +// The executor now uses kiroclaude.BuildClaudeResponse() and kiroclaude.ExtractThinkingFromContent() instead + +// streamToChannel converts AWS Event Stream to channel-based streaming. +// Supports tool calling - emits tool_use content blocks when tools are used. +// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. +// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). +// Extracts stop_reason from upstream events when available. +// thinkingEnabled controls whether tags are parsed - only parse when request enabled thinking. 
+func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) { + reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers + var totalUsage usage.Detail + var hasToolUses bool // Track if any tool uses were emitted + var hasTruncatedTools bool // Track if any tool uses were truncated + var upstreamStopReason string // Track stop_reason from upstream events + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *kiroclaude.ToolUseState + + // NOTE: Duplicate content filtering removed - it was causing legitimate repeated + // content (like consecutive newlines) to be incorrectly filtered out. + // The previous implementation compared lastContentEvent == contentDelta which + // is too aggressive for streaming scenarios. 
+ + // Streaming token calculation - accumulate content for real-time token counting + // Based on AIClient-2-API implementation + var accumulatedContent strings.Builder + accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + + // Real-time usage estimation state + // These track when to send periodic usage updates during streaming + var lastUsageUpdateLen int // Last accumulated content length when usage was sent + var lastUsageUpdateTime = time.Now() // Last time usage update was sent + var lastReportedOutputTokens int64 // Last reported output token count + + // Upstream usage tracking - Kiro API returns credit usage and context percentage + var upstreamCreditUsage float64 // Credit usage from upstream (e.g., 1.458) + var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) + var hasUpstreamUsage bool // Whether we received usage from upstream + + // Translator param for maintaining tool call state across streaming events + // IMPORTANT: This must persist across all TranslateStream calls + var translatorParam any + + // Thinking mode state tracking - tag-based parsing for tags in content + inThinkBlock := false // Whether we're currently inside a block + isThinkingBlockOpen := false // Track if thinking content block SSE event is open + thinkingBlockIndex := -1 // Index of the thinking content block + var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting + + // Buffer for handling partial tag matches at chunk boundaries + var pendingContent strings.Builder // Buffer content that might be part of a tag + + // Pre-calculate input tokens from request if possible + // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback + if enc, err := getTokenizer(model); err == nil { + var inputTokens int64 + var countMethod string + + // Try Claude format first (Kiro uses Claude API format) + if inp, err := countClaudeChatTokens(enc, 
claudeBody); err == nil && inp > 0 { + inputTokens = inp + countMethod = "claude" + } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + // Fallback to OpenAI format (for OpenAI-compatible requests) + inputTokens = inp + countMethod = "openai" + } else { + // Final fallback: estimate from raw request size (roughly 4 chars per token) + inputTokens = int64(len(claudeBody) / 4) + if inputTokens == 0 && len(claudeBody) > 0 { + inputTokens = 1 + } + countMethod = "estimate" + } + + totalUsage.InputTokens = inputTokens + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", + totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isTextBlockOpen := false + var outputLen int + + // Ensure usage is published even on early return + defer func() { + reporter.publish(ctx, totalUsage) + }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + // Log the error + log.Errorf("kiro: streamToChannel error: %v", eventErr) + + // Send error to channel for client notification + out <- cliproxyexecutor.StreamChunk{Err: eventErr} + return + } + if msg == nil { + // Normal end of stream (EOF) + // Flush any incomplete tool use before ending stream + if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] { + log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) + fullInput := currentToolUse.InputBuffer.String() + repairedJSON := kiroclaude.RepairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) + finalInput = make(map[string]interface{}) + } + + processedIDs[currentToolUse.ToolUseID] = 
true + contentBlockIndex++ + + // Send tool_use content block + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send tool input as delta + inputBytes, _ := json.Marshal(finalInput) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Close block + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + hasToolUses = true + currentToolUse = nil + } + + // DISABLED: Tag-based pending character flushing + // This code block was used for tag-based thinking detection which has been + // replaced by reasoningContentEvent handling. No pending tag chars to flush. + // Original code preserved in git history. 
+ break + } + + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { + continue + } + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) + continue + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} + return + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} + return + } + + // Extract stop_reason from various event formats (streaming) + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason 
(top-level): %s", upstreamStopReason) + } + + // Send message_start on first event + if !messageStartSent { + msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + messageStartSent = true + } + + switch eventType { + case "followupPromptEvent": + // Filter out followupPrompt events - these are UI suggestions, not content + log.Debugf("kiro: streamToChannel ignoring followupPrompt event") + continue + + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) + } + + case "meteringEvent": + // Handle metering events from Kiro API (usage billing information) + // Official format: { unit: string, unitPlural: string, usage: number } + if metering, ok := event["meteringEvent"].(map[string]interface{}); ok { + unit := "" + if u, ok := metering["unit"].(string); ok { + unit = u + } + usageVal := 0.0 + if u, ok := metering["usage"].(float64); ok { + usageVal = u + } + upstreamCreditUsage = usageVal + hasUpstreamUsage = true + log.Infof("kiro: streamToChannel received meteringEvent: usage=%.4f %s", usageVal, unit) + } else { + // Try direct fields (event is meteringEvent itself) + if unit, ok := event["unit"].(string); ok { + if usage, ok := event["usage"].(float64); ok { + upstreamCreditUsage = usage + hasUpstreamUsage = true + 
log.Infof("kiro: streamToChannel received meteringEvent (direct): usage=%.4f %s", usage, unit) + } + } + } + + case "contextUsageEvent": + // Handle context usage events from Kiro API + // Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}} + if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok { + if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100) + } + } else { + // Try direct field (fallback) + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100) + } + } + + case "error", "exception", "internalServerException": + // Handle error events from Kiro API stream + errMsg := "" + errType := eventType + + // Try to extract error message from various formats + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event[eventType].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + if t, ok := errObj["type"].(string); ok { + errType = t + } + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + + log.Errorf("kiro: streamToChannel received error event: type=%s, message=%s", errType, errMsg) + + // Send error to the stream and exit + if errMsg != "" { + out <- cliproxyexecutor.StreamChunk{ + Err: fmt.Errorf("kiro API error (%s): %s", errType, errMsg), + } + return + } + + case "invalidStateEvent": + // Handle invalid state events - log and continue (non-fatal) + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if stateEvent, ok := event["invalidStateEvent"].(map[string]interface{}); ok { + if msg, ok := stateEvent["message"].(string); ok { + errMsg = msg 
+ } + } + log.Warnf("kiro: streamToChannel received invalidStateEvent: %s, continuing", errMsg) + continue + + case "assistantResponseEvent": + var contentDelta string + var toolUses []map[string]interface{} + + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if c, ok := assistantResp["content"].(string); ok { + contentDelta = c + } + // Extract stop_reason from assistantResponseEvent + if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) + } + if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) + } + // Extract tool uses from response + if tus, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + } + if contentDelta == "" { + if c, ok := event["content"].(string); ok { + contentDelta = c + } + } + // Direct tool uses + if tus, ok := event["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + + // Handle text content with thinking mode support + if contentDelta != "" { + // NOTE: Duplicate content filtering was removed because it incorrectly + // filtered out legitimate repeated content (like consecutive newlines "\n\n"). + // Streaming naturally can have identical chunks that are valid content. 
+ + outputLen += len(contentDelta) + // Accumulate content for streaming token calculation + accumulatedContent.WriteString(contentDelta) + + // Real-time usage estimation: Check if we should send a usage update + // This helps clients track context usage during long thinking sessions + shouldSendUsageUpdate := false + if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { + shouldSendUsageUpdate = true + } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { + shouldSendUsageUpdate = true + } + + if shouldSendUsageUpdate { + // Calculate current output tokens using tiktoken + var currentOutputTokens int64 + if enc, encErr := getTokenizer(model); encErr == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + currentOutputTokens = int64(tokenCount) + } + } + // Fallback to character estimation if tiktoken fails + if currentOutputTokens == 0 { + currentOutputTokens = int64(accumulatedContent.Len() / 4) + if currentOutputTokens == 0 { + currentOutputTokens = 1 + } + } + + // Only send update if token count has changed significantly (at least 10 tokens) + if currentOutputTokens > lastReportedOutputTokens+10 { + // Send ping event with usage information + // This is a non-blocking update that clients can optionally process + pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + lastReportedOutputTokens = currentOutputTokens + log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", + totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) + } + + 
lastUsageUpdateLen = accumulatedContent.Len() + lastUsageUpdateTime = time.Now() + } + + // TAG-BASED THINKING PARSING: Parse tags from content + // Combine pending content with new content for processing + pendingContent.WriteString(contentDelta) + processContent := pendingContent.String() + pendingContent.Reset() + + // Process content looking for thinking tags + for len(processContent) > 0 { + if inThinkBlock { + // We're inside a thinking block, look for + endIdx := strings.Index(processContent, kirocommon.ThinkingEndTag) + if endIdx >= 0 { + // Found end tag - emit thinking content before the tag + thinkingText := processContent[:endIdx] + if thinkingText != "" { + // Ensure thinking block is open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + // Send thinking delta + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + accumulatedThinkingContent.WriteString(thinkingText) + } + // Close thinking block + if isThinkingBlockOpen { + blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, 
chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isThinkingBlockOpen = false + } + inThinkBlock = false + processContent = processContent[endIdx+len(kirocommon.ThinkingEndTag):] + log.Debugf("kiro: closed thinking block, remaining content: %d chars", len(processContent)) + } else { + // No end tag found - check for partial match at end + partialMatch := false + for i := 1; i < len(kirocommon.ThinkingEndTag) && i <= len(processContent); i++ { + if strings.HasSuffix(processContent, kirocommon.ThinkingEndTag[:i]) { + // Possible partial tag at end, buffer it + pendingContent.WriteString(processContent[len(processContent)-i:]) + processContent = processContent[:len(processContent)-i] + partialMatch = true + break + } + } + if !partialMatch || len(processContent) > 0 { + // Emit all as thinking content + if processContent != "" { + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + accumulatedThinkingContent.WriteString(processContent) + } + } + processContent = "" + } + } else { + // Not in thinking block, look for + startIdx := strings.Index(processContent, 
kirocommon.ThinkingStartTag) + if startIdx >= 0 { + // Found start tag - emit text content before the tag + textBefore := processContent[:startIdx] + if textBefore != "" { + // Close thinking block if open + if isThinkingBlockOpen { + blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isThinkingBlockOpen = false + } + // Ensure text block is open + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + // Send text delta + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + // Close text block before entering thinking + if isTextBlockOpen { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + 
inThinkBlock = true + processContent = processContent[startIdx+len(kirocommon.ThinkingStartTag):] + log.Debugf("kiro: entered thinking block") + } else { + // No start tag found - check for partial match at end + partialMatch := false + for i := 1; i < len(kirocommon.ThinkingStartTag) && i <= len(processContent); i++ { + if strings.HasSuffix(processContent, kirocommon.ThinkingStartTag[:i]) { + // Possible partial tag at end, buffer it + pendingContent.WriteString(processContent[len(processContent)-i:]) + processContent = processContent[:len(processContent)-i] + partialMatch = true + break + } + } + if !partialMatch || len(processContent) > 0 { + // Emit all as text content + if processContent != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + processContent = "" + } + } + } + } + + // Handle tool uses in response (with deduplication) + for _, tu := range toolUses { + toolUseID := kirocommon.GetString(tu, "toolUseId") + toolName := kirocommon.GetString(tu, "name") + + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + hasToolUses = true + // Close 
text block if open before starting tool_use block + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + // Emit tool_use content block + contentBlockIndex++ + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send input_json_delta with the tool input + if input, ok := tu["input"].(map[string]interface{}); ok { + inputJSON, err := json.Marshal(input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input: %v", err) + // Don't continue - still need to close the block + } else { + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + // Close tool_use block (always close even if input marshal failed) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" 
{ + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "reasoningContentEvent": + // Handle official reasoningContentEvent from Kiro API + // This replaces tag-based thinking detection with the proper event type + // Official format: { text: string, signature?: string, redactedContent?: base64 } + var thinkingText string + var signature string + + if re, ok := event["reasoningContentEvent"].(map[string]interface{}); ok { + if text, ok := re["text"].(string); ok { + thinkingText = text + } + if sig, ok := re["signature"].(string); ok { + signature = sig + if len(sig) > 20 { + log.Debugf("kiro: reasoningContentEvent has signature: %s...", sig[:20]) + } else { + log.Debugf("kiro: reasoningContentEvent has signature: %s", sig) + } + } + } else { + // Try direct fields + if text, ok := event["text"].(string); ok { + thinkingText = text + } + if sig, ok := event["signature"].(string); ok { + signature = sig + } + } + + if thinkingText != "" { + // Close text block if open before starting thinking block + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + // Start thinking block if not already open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- 
cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Send thinking content + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Accumulate for token counting + accumulatedThinkingContent.WriteString(thinkingText) + log.Debugf("kiro: received reasoningContentEvent, text length: %d, has signature: %v", len(thinkingText), signature != "") + } + + // Note: We don't close the thinking block here - it will be closed when we see + // the next assistantResponseEvent or at the end of the stream + _ = signature // Signature can be used for verification if needed + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + + // Emit completed tool uses + for _, tu := range completedToolUses { + // Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker + if tu.IsTruncated { + hasTruncatedTools = true + log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID) + + // Close text block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + contentBlockIndex++ + + // Emit tool_use with 
SOFT_LIMIT_REACHED marker input + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Build SOFT_LIMIT_REACHED marker input + markerInput := map[string]interface{}{ + "_status": "SOFT_LIMIT_REACHED", + "_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.", + } + + markerJSON, _ := json.Marshal(markerInput) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Close tool_use block + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + hasToolUses = true // Keep this so stop_reason = tool_use + continue + } + + hasToolUses = true + + // Close text block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, 
blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + contentBlockIndex++ + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + if tu.Input != nil { + inputJSON, err := json.Marshal(tu.Input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) + } else { + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + + case "messageMetadataEvent", "metadataEvent": + // Handle message metadata events which contain token counts + // Official format: { tokenUsage: { outputTokens, 
totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } } + var metadata map[string]interface{} + if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { + metadata = m + } else if m, ok := event["metadataEvent"].(map[string]interface{}); ok { + metadata = m + } else { + metadata = event // event itself might be the metadata + } + + // Check for nested tokenUsage object (official format) + if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok { + // outputTokens - precise output token count + if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + log.Infof("kiro: streamToChannel found precise outputTokens in tokenUsage: %d", totalUsage.OutputTokens) + } + // totalTokens - precise total token count + if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Infof("kiro: streamToChannel found precise totalTokens in tokenUsage: %d", totalUsage.TotalTokens) + } + // uncachedInputTokens - input tokens not from cache + if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok { + totalUsage.InputTokens = int64(uncachedInputTokens) + hasUpstreamUsage = true + log.Infof("kiro: streamToChannel found uncachedInputTokens in tokenUsage: %d", totalUsage.InputTokens) + } + // cacheReadInputTokens - tokens read from cache + if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok { + // Add to input tokens if we have uncached tokens, otherwise use as input + if totalUsage.InputTokens > 0 { + totalUsage.InputTokens += int64(cacheReadTokens) + } else { + totalUsage.InputTokens = int64(cacheReadTokens) + } + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens)) + } + // contextUsagePercentage - can be used as fallback for input token estimation 
+ if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: streamToChannel found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct) + } + } + + // Fallback: check for direct fields in metadata (legacy format) + if totalUsage.InputTokens == 0 { + if inputTokens, ok := metadata["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := metadata["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens) + } + } + if totalUsage.TotalTokens == 0 { + if totalTokens, ok := metadata["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens) + } + } + + case "usageEvent", "usage": + // Handle dedicated usage events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens) + } + // Also check nested usage object + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + 
totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + case "metricsEvent": + // Handle metrics events which may contain usage data + if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metrics["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := metrics["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d", + totalUsage.InputTokens, totalUsage.OutputTokens) + + } + default: + // Check for upstream usage events from Kiro API + // Format: {"unit":"credit","unitPlural":"credits","usage":1.458} + if unit, ok := event["unit"].(string); ok && unit == "credit" { + if usage, ok := event["usage"].(float64); ok { + upstreamCreditUsage = usage + hasUpstreamUsage = true + log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage) + } + } + // Format: {"contextUsagePercentage":78.56} + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage) + } + + // Check for token counts in unknown events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) 
+ hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens) + } + + // Check for usage object in unknown events (OpenAI/Claude format) + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d", + eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + // Log unknown event types for debugging (to discover new event formats) + if eventType != "" { + log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload)) + } + + } + + // Check nested usage event + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := 
usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + // Check for direct token fields in any event (fallback) + if totalUsage.InputTokens == 0 { + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens) + } + } + + // Check for usage object in any event (OpenAI format) + if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 { + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if totalUsage.InputTokens == 0 { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + if totalUsage.TotalTokens == 0 { + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + } + log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + } + } + + // Close content block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := 
kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Streaming token calculation - calculate output tokens from accumulated content + // Only use local estimation if server didn't provide usage (server-side usage takes priority) + if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { + // Try to use tiktoken for accurate counting + if enc, err := getTokenizer(model); err == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + totalUsage.OutputTokens = int64(tokenCount) + log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) + } else { + // Fallback on count error: estimate from character count + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) + } + } else { + // Fallback: estimate from character count (roughly 4 chars per token) + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) + } + } else if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Legacy fallback using outputLen + totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + + // Use contextUsagePercentage to calculate more accurate input tokens + // Kiro model has 200k max context, contextUsagePercentage represents 
the percentage used + // Formula: input_tokens = contextUsagePercentage * 200000 / 100 + // Note: The effective input context is ~170k (200k - 30k reserved for output) + if upstreamContextPercentage > 0 { + // Calculate input tokens from context percentage + // Using 200k as the base since that's what Kiro reports against + calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) + + // Only use calculated value if it's significantly different from local estimate + // This provides more accurate token counts based on upstream data + if calculatedInputTokens > 0 { + localEstimate := totalUsage.InputTokens + totalUsage.InputTokens = calculatedInputTokens + log.Debugf("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", + upstreamContextPercentage, calculatedInputTokens, localEstimate) + } + } + + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Log upstream usage information if received + if hasUpstreamUsage { + log.Debugf("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d", + upstreamCreditUsage, upstreamContextPercentage, + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn + // SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop + stopReason := upstreamStopReason + if hasTruncatedTools { + // Log that we're using SOFT_LIMIT_REACHED approach + log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools") + } + if stopReason == "" { + if hasToolUses { + stopReason = "tool_use" + log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") + } else { + stopReason = "end_turn" + log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + 
if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") + } + + // Send message_delta event + msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send message_stop event separately + msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + // reporter.publish is called via defer +} + +// ══════════════════════════════════════════════════════════════════════════════ +// Web Search Handler (MCP API) +// ══════════════════════════════════════════════════════════════════════════════ + +// fetchToolDescription caching: +// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time, +// with automatic retry on failure: +// - On failure, fetched stays false so subsequent calls will retry +// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path) +// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(), +// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests. +var ( + toolDescMu sync.Mutex + toolDescFetched atomic.Bool +) + +// fetchToolDescription calls MCP tools/list to get the web_search tool description +// and caches it. Safe to call concurrently — only one goroutine fetches at a time. +// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. 
+// The httpClient parameter allows reusing a shared pooled HTTP client. +func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) { + // Fast path: already fetched successfully, no lock needed + if toolDescFetched.Load() { + return + } + + toolDescMu.Lock() + defer toolDescMu.Unlock() + + // Double-check after acquiring lock + if toolDescFetched.Load() { + return + } + + handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs) + reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) + log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) + + req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody)) + if err != nil { + log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) + return + } + + // Reuse same headers as callMcpAPI + handler.setMcpHeaders(req) + + resp, err := handler.httpClient.Do(req) + if err != nil { + log.Warnf("kiro/websearch: tools/list request failed: %v", err) + return + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil || resp.StatusCode != http.StatusOK { + log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) + return + } + log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) + + // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} + var result struct { + Result *struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { + log.Warnf("kiro/websearch: failed to parse tools/list response") + return + } + + for _, tool := range result.Result.Tools { + if tool.Name == "web_search" && tool.Description != "" { + 
kiroclaude.SetWebSearchDescription(tool.Description) + toolDescFetched.Store(true) // success — no more fetches + log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) + return + } + } + + // web_search tool not found in response + log.Warnf("kiro/websearch: web_search tool not found in tools/list response") +} + +// webSearchHandler handles web search requests via Kiro MCP API +type webSearchHandler struct { + ctx context.Context + mcpEndpoint string + httpClient *http.Client + authToken string + auth *cliproxyauth.Auth // for applyDynamicFingerprint + authAttrs map[string]string // optional, for custom headers from auth.Attributes +} + +// newWebSearchHandler creates a new webSearchHandler. +// If httpClient is nil, a default client with 30s timeout is used. +// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. +func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler { + if httpClient == nil { + httpClient = &http.Client{ + Timeout: 30 * time.Second, + } + } + return &webSearchHandler{ + ctx: ctx, + mcpEndpoint: mcpEndpoint, + httpClient: httpClient, + authToken: authToken, + auth: auth, + authAttrs: authAttrs, + } +} + +// setMcpHeaders sets standard MCP API headers on the request, +// aligned with the GAR request pattern. +func (h *webSearchHandler) setMcpHeaders(req *http.Request) { + // 1. Content-Type & Accept (aligned with GAR) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + + // 2. Kiro-specific headers (aligned with GAR) + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + + // 3. User-Agent: Reuse applyDynamicFingerprint for consistency + applyDynamicFingerprint(req, h.auth) + + // 4. 
AWS SDK identifiers + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + // 5. Authentication + req.Header.Set("Authorization", "Bearer "+h.authToken) + + // 6. Custom headers from auth attributes + util.ApplyCustomHeadersFromAttrs(req, h.authAttrs) +} + +// mcpMaxRetries is the maximum number of retries for MCP API calls. +const mcpMaxRetries = 2 + +// callMcpAPI calls the Kiro MCP API with the given request. +// Includes retry logic with exponential backoff for retryable errors. +func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) { + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCP request: %w", err) + } + log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody)) + + var lastErr error + for attempt := 0; attempt <= mcpMaxRetries; attempt++ { + if attempt > 0 { + backoff := time.Duration(1<<(attempt-1)) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) + select { + case <-h.ctx.Done(): + return nil, h.ctx.Err() + case <-time.After(backoff): + } + } + + req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + h.setMcpHeaders(req) + + resp, err := h.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("MCP API request failed: %w", err) + continue // network error → retry + } + + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("failed to read MCP response: %w", err) + continue // read error → retry + } + log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) + + // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) + if 
resp.StatusCode >= 502 && resp.StatusCode <= 504 { + lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) + } + + var mcpResponse kiroclaude.McpResponse + if err := json.Unmarshal(body, &mcpResponse); err != nil { + return nil, fmt.Errorf("failed to parse MCP response: %w", err) + } + + if mcpResponse.Error != nil { + code := -1 + if mcpResponse.Error.Code != nil { + code = *mcpResponse.Error.Code + } + msg := "Unknown error" + if mcpResponse.Error.Message != nil { + msg = *mcpResponse.Error.Message + } + return nil, fmt.Errorf("MCP error %d: %s", code, msg) + } + + return &mcpResponse, nil + } + + return nil, lastErr +} + +// webSearchAuthAttrs extracts auth attributes for MCP calls. +// Used by handleWebSearch and handleWebSearchStream to pass custom headers. +func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string { + if auth != nil { + return auth.Attributes + } + return nil +} + +const maxWebSearchIterations = 5 + +// handleWebSearchStream handles web_search requests: +// Step 1: tools/list (sync) → fetch/cache tool description +// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop +// Note: We skip the "model decides to search" step because Claude Code already +// decided to use web_search. The Kiro tool description restricts non-coding +// topics, so asking the model again would cause it to refuse valid searches. 
+func (e *KiroExecutor) handleWebSearchStream( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (<-chan cliproxyexecutor.StreamChunk, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow") + return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) + + // ── Step 1: tools/list (SYNC) — cache tool description ── + { + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + } + + // Create output channel + out := make(chan cliproxyexecutor.StreamChunk) + + // Usage reporting: track web search requests like normal streaming requests + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + + go func() { + var wsErr error + defer reporter.trackFailure(ctx, &wsErr) + defer close(out) + + // Estimate input tokens using tokenizer (matching streamToChannel pattern) + var totalUsage usage.Detail + if enc, tokErr := getTokenizer(req.Model); tokErr == nil { + if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) + } + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) + } + if totalUsage.InputTokens == 0 && len(req.Payload) > 0 { + totalUsage.InputTokens = 1 + } + var accumulatedOutputLen int + defer func() { + if wsErr != nil { + return // let trackFailure handle failure reporting + } + totalUsage.OutputTokens = 
int64(accumulatedOutputLen / 4) + if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + reporter.publish(ctx, totalUsage) + }() + + // Send message_start event to client (aligned with streamToChannel pattern) + // Use payloadRequestedModel to return user's original model alias + msgStart := kiroclaude.BuildClaudeMessageStartEvent( + payloadRequestedModel(opts, req.Model), + totalUsage.InputTokens, + ) + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}: + } + + // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ── + contentBlockIndex := 0 + currentQuery := query + + // Replace web_search tool description with a minimal one that allows re-search. + // The original tools/list description from Kiro restricts non-coding topics, + // but we've already decided to search. We keep the tool so the model can + // request additional searches when results are insufficient. 
+ simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) + if simplifyErr != nil { + log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr) + simplifiedPayload = bytes.Clone(req.Payload) + } + + currentClaudePayload := simplifiedPayload + totalSearches := 0 + + // Generate toolUseId for the first iteration (Claude Code already decided to search) + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + + for iteration := 0; iteration < maxWebSearchIterations; iteration++ { + log.Infof("kiro/websearch: search iteration %d/%d", + iteration+1, maxWebSearchIterations) + + // MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 + if searchResults != nil { + resultCount = len(searchResults.Results) + } + totalSearches++ + log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount) + + // Send search indicator events to client + searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex) + for _, event := range searchEvents { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: event}: + } + } + contentBlockIndex += 2 + + // Inject tool_use + tool_result into Claude payload, then call GAR + var err error + currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, 
currentToolUseId, currentQuery, searchResults) + if err != nil { + log.Warnf("kiro/websearch: failed to inject tool results: %v", err) + wsErr = fmt.Errorf("failed to inject tool results: %w", err) + e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) + return + } + + // Call GAR with modified Claude payload (full translation pipeline) + modifiedReq := req + modifiedReq.Payload = currentClaudePayload + kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if kiroErr != nil { + log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr) + wsErr = fmt.Errorf("kiro API failed at iteration %d: %w", iteration+1, kiroErr) + e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) + return + } + + // Analyze response + analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks) + log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v", + iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse) + + if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations { + // Model wants another search + filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex) + for _, chunk := range filteredChunks { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: + } + } + + currentQuery = analysis.WebSearchQuery + currentToolUseId = analysis.WebSearchToolUseId + continue + } + + // Model returned final response — stream to client + for _, chunk := range kiroChunks { + if contentBlockIndex > 0 && len(chunk) > 0 { + adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex) + if !shouldForward { + continue + } + accumulatedOutputLen += len(adjusted) + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}: + } + } else { + accumulatedOutputLen 
+= len(chunk) + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: + } + } + } + log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches) + return + } + + log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations) + }() + + return out, nil +} + +// handleWebSearch handles web_search requests for non-streaming Execute path. +// Performs MCP search synchronously, injects results into the request payload, +// then calls the normal non-streaming Kiro API path which returns a proper +// Claude JSON response (not SSE chunks). +func (e *KiroExecutor) handleWebSearch( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") + // Fall through to normal non-streaming path + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) + + // Step 1: Fetch/cache tool description (sync) + { + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + } + + // Step 2: Perform MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(query) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + 
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 + if searchResults != nil { + resultCount = len(searchResults.Results) + } + log.Infof("kiro/websearch: non-stream: got %d search results", resultCount) + + // Step 3: Replace restrictive web_search tool description (align with streaming path) + simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) + if simplifyErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr) + simplifiedPayload = bytes.Clone(req.Payload) + } + + // Step 4: Inject search tool_use + tool_result into Claude payload + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults) + if err != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry) + // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream + // to produce a proper Claude JSON response + modifiedReq := req + modifiedReq.Payload = modifiedPayload + + resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if err != nil { + return resp, err + } + + // Step 6: Inject server_tool_use + web_search_tool_result into response + // so Claude Code can display "Did X searches in Ys" + indicators := []kiroclaude.SearchIndicator{ + { + ToolUseID: currentToolUseId, + Query: query, + Results: 
searchResults, + }, + } + injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) + if injErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) + } else { + resp.Payload = injectedPayload + } + + return resp, nil +} + +// callKiroAndBuffer calls the Kiro API and buffers all response chunks. +// Returns the buffered chunks for analysis before forwarding to client. +// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter. +func (e *KiroExecutor) callKiroAndBuffer( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) ([][]byte, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + log.Debugf("kiro/websearch GAR request: %d bytes", len(body)) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + tokenKey := getTokenKey(auth) + + kiroStream, err := e.executeStreamWithRetry( + ctx, auth, req, opts, accessToken, effectiveProfileArn, + body, from, nil, kiroModelID, isAgentic, isChatOnly, tokenKey, + ) + if err != nil { + return nil, err + } + + // Buffer all chunks + var chunks [][]byte + for chunk := range kiroStream { + if chunk.Err != nil { + return chunks, chunk.Err + } + if len(chunk.Payload) > 0 { + chunks = append(chunks, bytes.Clone(chunk.Payload)) + } + } + + log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks)) + + return chunks, nil +} + +// callKiroDirectStream creates a direct streaming channel to Kiro API without search. 
+func (e *KiroExecutor) callKiroDirectStream( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (<-chan cliproxyexecutor.StreamChunk, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + tokenKey := getTokenKey(auth) + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + var streamErr error + defer reporter.trackFailure(ctx, &streamErr) + + stream, streamErr := e.executeStreamWithRetry( + ctx, auth, req, opts, accessToken, effectiveProfileArn, + body, from, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey, + ) + return stream, streamErr +} + +// sendFallbackText sends a simple text response when the Kiro API fails during the search loop. +// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment +// with how streamToChannel() uses BuildClaude*Event() functions. +func (e *KiroExecutor) sendFallbackText( + ctx context.Context, + out chan<- cliproxyexecutor.StreamChunk, + contentBlockIndex int, + query string, + searchResults *kiroclaude.WebSearchResults, +) { + events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults) + for _, event := range events { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}: + } + } +} + +// executeNonStreamFallback runs the standard non-streaming Execute path for a request. +// Used by handleWebSearch after injecting search results, or as a fallback. 
+func (e *KiroExecutor) executeNonStreamFallback( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + tokenKey := getTokenKey(auth) + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + var err error + defer reporter.trackFailure(ctx, &err) + + resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, body, from, to, reporter, kiroModelID, isAgentic, isChatOnly, tokenKey) + return resp, err +} + +func (e *KiroExecutor) CloseExecutionSession(sessionID string) {} diff --git a/pkg/llmproxy/executor/kiro_transform.go b/pkg/llmproxy/executor/kiro_transform.go new file mode 100644 index 0000000000..78c235edfc --- /dev/null +++ b/pkg/llmproxy/executor/kiro_transform.go @@ -0,0 +1,469 @@ +package executor + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strings" + + kiroclaude "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/claude" + kiroopenai "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/translator/kiro/openai" + clipproxyauth "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/cliproxy/auth" + sdktranslator "github.com/kooshapari/cliproxyapi-plusplus/v6/sdk/translator" + log "github.com/sirupsen/logrus" +) + +// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. +// This solves the "triple mismatch" problem where different endpoints require matching +// Origin and X-Amz-Target header values. 
+// +// Based on reference implementations: +// - amq2api-main: Uses Amazon Q endpoint with CLI origin and AmazonQDeveloperStreamingService target +// - AIClient-2-API: Uses CodeWhisperer endpoint with AI_EDITOR origin and AmazonCodeWhispererStreamingService target +type kiroEndpointConfig struct { + URL string // Endpoint URL + Origin string // Request Origin: "CLI" for Amazon Q quota, "AI_EDITOR" for Kiro IDE quota + AmzTarget string // X-Amz-Target header value + Name string // Endpoint name for logging +} + +// kiroDefaultRegion is the default AWS region for Kiro API endpoints. +// Used when no region is specified in auth metadata. +const kiroDefaultRegion = "us-east-1" + +// extractRegionFromProfileARN extracts the AWS region from a ProfileARN. +// ARN format: arn:aws:codewhisperer:REGION:ACCOUNT:profile/PROFILE_ID +// Returns empty string if region cannot be extracted. +func extractRegionFromProfileARN(profileArn string) string { + if profileArn == "" { + return "" + } + parts := strings.Split(profileArn, ":") + if len(parts) >= 4 && parts[3] != "" { + return parts[3] + } + return "" +} + +// buildKiroEndpointConfigs creates endpoint configurations for the specified region. +// This enables dynamic region support for Enterprise/IdC users in non-us-east-1 regions. +// +// Uses Q endpoint (q.{region}.amazonaws.com) as primary for ALL auth types: +// - Works universally across all AWS regions (CodeWhisperer endpoint only exists in us-east-1) +// - Uses /generateAssistantResponse path with AI_EDITOR origin +// - Does NOT require X-Amz-Target header +// +// The AmzTarget field is kept for backward compatibility but should be empty +// to indicate that the header should NOT be set. 
+func buildKiroEndpointConfigs(region string) []kiroEndpointConfig { + if region == "" { + region = kiroDefaultRegion + } + return []kiroEndpointConfig{ + { + // Primary: Q endpoint - works for all regions and auth types + URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region), + Origin: "AI_EDITOR", + AmzTarget: "", // Empty = don't set X-Amz-Target header + Name: "AmazonQ", + }, + { + // Fallback: CodeWhisperer endpoint (legacy, only works in us-east-1) + URL: fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", region), + Origin: "AI_EDITOR", + AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", + Name: "CodeWhisperer", + }, + } +} + +// resolveKiroAPIRegion determines the AWS region for Kiro API calls. +// Region priority: +// 1. auth.Metadata["api_region"] - explicit API region override +// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource +// 3. kiroDefaultRegion (us-east-1) - fallback +// Note: OIDC "region" is NOT used - it's for token refresh, not API calls +func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return kiroDefaultRegion + } + // Priority 1: Explicit api_region override + if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { + log.Debugf("kiro: using region %s (source: api_region)", r) + return r + } + // Priority 2: Extract from ProfileARN + if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { + if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { + log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion) + return arnRegion + } + } + // Note: OIDC "region" field is NOT used for API endpoint + // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) + // Using OIDC region for API calls causes DNS failures + log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion) + 
	return kiroDefaultRegion
+}
+
+// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
+// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
+var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
+
+// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
+// Supports dynamic region based on auth metadata "api_region" or the ProfileARN region.
+// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
+//
+// Region priority:
+// 1. auth.Metadata["api_region"] - explicit API region override
+// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
+// 3. kiroDefaultRegion (us-east-1) - fallback
+// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
+func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
+	if auth == nil {
+		return kiroEndpointConfigs
+	}
+
+	// Determine API region using shared resolution logic
+	region := resolveKiroAPIRegion(auth)
+
+	// Build endpoint configs for the specified region
+	endpointConfigs := buildKiroEndpointConfigs(region)
+
+	// For IDC auth, use Q endpoint with AI_EDITOR origin
+	// IDC tokens work with Q endpoint using Bearer auth
+	// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
+	// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
+	if auth.Metadata != nil {
+		authMethod, _ := auth.Metadata["auth_method"].(string)
+		if strings.ToLower(authMethod) == "idc" {
+			log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region)
+			return endpointConfigs
+		}
+	}
+
+	// Check for preference
+	var preference string
+	if auth.Metadata != nil {
+		if p, ok := auth.Metadata["preferred_endpoint"].(string); ok {
+			preference = p
+		}
+	}
+	// Check attributes as fallback (e.g.
from HTTP headers) + if preference == "" && auth.Attributes != nil { + preference = auth.Attributes["preferred_endpoint"] + } + + if preference == "" { + return endpointConfigs + } + + preference = strings.ToLower(strings.TrimSpace(preference)) + + // Create new slice to avoid modifying global state + var sorted []kiroEndpointConfig + var remaining []kiroEndpointConfig + + for _, cfg := range endpointConfigs { + name := strings.ToLower(cfg.Name) + // Check for matches + // CodeWhisperer aliases: codewhisperer, ide + // AmazonQ aliases: amazonq, q, cli + isMatch := false + if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { + isMatch = true + } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { + isMatch = true + } + + if isMatch { + sorted = append(sorted, cfg) + } else { + remaining = append(remaining, cfg) + } + } + + // If preference didn't match anything, return default + if len(sorted) == 0 { + return endpointConfigs + } + + // Combine: preferred first, then others + return append(sorted, remaining...) +} + +// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. +func isIDCAuth(auth *cliproxyauth.Auth) bool { + if auth == nil || auth.Metadata == nil { + return false + } + authMethod, _ := auth.Metadata["auth_method"].(string) + return strings.ToLower(authMethod) == "idc" +} + +// buildKiroPayloadForFormat builds the Kiro API payload based on the source format. +// This is critical because OpenAI and Claude formats have different tool structures: +// - OpenAI: tools[].function.name, tools[].function.description +// - Claude: tools[].name, tools[].description +// headers parameter allows checking Anthropic-Beta header for thinking mode detection. +// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected. 
+func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format, headers http.Header) ([]byte, bool) { + switch sourceFormat.String() { + case "openai": + log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) + return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) + case "kiro": + // Body is already in Kiro format — pass through directly + log.Debugf("kiro: body already in Kiro format, passing through directly") + return sanitizeKiroPayload(body), false + default: + // Default to Claude format + log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) + return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) + } +} + +func sanitizeKiroPayload(body []byte) []byte { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + if _, exists := payload["user"]; !exists { + return body + } + delete(payload, "user") + sanitized, err := json.Marshal(payload) + if err != nil { + return body + } + return sanitized +} + +func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { + if auth == nil { + return "", "" + } + + // Try Metadata first (wrapper format) + if auth.Metadata != nil { + if token, ok := auth.Metadata["access_token"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profile_arn"].(string); ok { + profileArn = arn + } + } + + // Try Attributes + if accessToken == "" && auth.Attributes != nil { + accessToken = auth.Attributes["access_token"] + profileArn = auth.Attributes["profile_arn"] + } + + // Try direct fields from flat JSON format (new AWS Builder ID format) + if accessToken == "" && auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok { + accessToken = token + } + if arn, 
ok := auth.Metadata["profileArn"].(string); ok { + profileArn = arn + } + } + + return accessToken, profileArn +} + +// findRealThinkingEndTag finds the real end tag, skipping false positives. +// Returns -1 if no real end tag is found. +// +// Real tags from Kiro API have specific characteristics: +// - Usually preceded by newline (.\n) +// - Usually followed by newline (\n\n) +// - Not inside code blocks or inline code +// +// False positives (discussion text) have characteristics: +// - In the middle of a sentence +// - Preceded by discussion words like "标签", "tag", "returns" +// - Inside code blocks or inline code +// +// Parameters: +// - content: the content to search in +// - alreadyInCodeBlock: whether we're already inside a code block from previous chunks +// - alreadyInInlineCode: whether we're already inside inline code from previous chunks + +// determineAgenticMode determines if the model is an agentic or chat-only variant. +// Returns (isAgentic, isChatOnly) based on model name suffixes. +func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { + isAgentic = strings.HasSuffix(model, "-agentic") + isChatOnly = strings.HasSuffix(model, "-chat") + return isAgentic, isChatOnly +} + +func getMetadataString(metadata map[string]any, keys ...string) string { + if metadata == nil { + return "" + } + for _, key := range keys { + if value, ok := metadata[key].(string); ok { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + return trimmed + } + } + } + return "" +} + +// getEffectiveProfileArn determines if profileArn should be included based on auth method. +// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC). +// +// Detection logic (matching kiro-openai-gateway): +// 1. Check auth_method field: "builder-id" or "idc" +// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) +// 3. 
Check for client_id + client_secret presence (AWS SSO OIDC signature) + +// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method, +// and logs a warning if profileArn is missing for non-builder-id auth. +// This consolidates the auth_method check that was previously done separately. +// +// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors. +// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn. +// +// Detection logic (matching kiro-openai-gateway): +// 1. Check auth_method field: "builder-id" or "idc" +// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) +// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) +func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { + if auth != nil && auth.Metadata != nil { + // Check 1: auth_method field (from CLIProxyAPI tokens) + authMethod := strings.ToLower(getMetadataString(auth.Metadata, "auth_method", "authMethod")) + if authMethod == "builder-id" || authMethod == "idc" { + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 2: auth_type field (from kiro-cli tokens) + if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway) + clientID := getMetadataString(auth.Metadata, "client_id", "clientId") + clientSecret := getMetadataString(auth.Metadata, "client_secret", "clientSecret") + if clientID != "" && clientSecret != "" { + return "" // AWS SSO OIDC - don't include profileArn + } + } + // For social auth (Kiro Desktop), profileArn is required + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + return profileArn +} + +// mapModelToKiro maps external model names to Kiro model IDs. 
+// Supports both Kiro and Amazon Q prefixes since they use the same API. +// Agentic variants (-agentic suffix) map to the same backend model IDs. +func (e *KiroExecutor) mapModelToKiro(model string) string { + modelMap := map[string]string{ + // Amazon Q format (amazonq- prefix) - same API as Kiro + "amazonq-auto": "auto", + "amazonq-claude-opus-4-6": "claude-opus-4.6", + "amazonq-claude-sonnet-4-6": "claude-sonnet-4.6", + "amazonq-claude-opus-4-5": "claude-opus-4.5", + "amazonq-claude-sonnet-4-5": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4": "claude-sonnet-4", + "amazonq-claude-sonnet-4-20250514": "claude-sonnet-4", + "amazonq-claude-haiku-4-5": "claude-haiku-4.5", + // Kiro format (kiro- prefix) - valid model names that should be preserved + "kiro-claude-opus-4-6": "claude-opus-4.6", + "kiro-claude-sonnet-4-6": "claude-sonnet-4.6", + "kiro-claude-opus-4-5": "claude-opus-4.5", + "kiro-claude-sonnet-4-5": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "kiro-claude-sonnet-4": "claude-sonnet-4", + "kiro-claude-sonnet-4-20250514": "claude-sonnet-4", + "kiro-claude-haiku-4-5": "claude-haiku-4.5", + "kiro-auto": "auto", + // Native format (no prefix) - used by Kiro IDE directly + "claude-opus-4-6": "claude-opus-4.6", + "claude-opus-4.6": "claude-opus-4.6", + "claude-sonnet-4-6": "claude-sonnet-4.6", + "claude-sonnet-4.6": "claude-sonnet-4.6", + "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4.5": "claude-opus-4.5", + "claude-haiku-4-5": "claude-haiku-4.5", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4-20250514": "claude-sonnet-4", + "auto": "auto", + // Agentic variants (same backend model IDs, but with special system prompt) + "claude-opus-4.6-agentic": 
"claude-opus-4.6", + "claude-sonnet-4.6-agentic": "claude-sonnet-4.6", + "claude-opus-4.5-agentic": "claude-opus-4.5", + "claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "claude-sonnet-4-agentic": "claude-sonnet-4", + "claude-haiku-4.5-agentic": "claude-haiku-4.5", + "kiro-claude-opus-4-6-agentic": "claude-opus-4.6", + "kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6", + "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", + "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", + "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", + } + if kiroID, ok := modelMap[model]; ok { + return kiroID + } + + // Smart fallback: try to infer model type from name patterns + modelLower := strings.ToLower(model) + + // Check for Haiku variants + if strings.Contains(modelLower, "haiku") { + log.Debug("kiro: unknown haiku variant, mapping to claude-haiku-4.5") + return "claude-haiku-4.5" + } + + // Check for Sonnet variants + if strings.Contains(modelLower, "sonnet") { + // Check for specific version patterns + if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { + log.Debug("kiro: unknown sonnet 3.7 variant, mapping to claude-3-7-sonnet-20250219") + return "claude-3-7-sonnet-20250219" + } + if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { + log.Debug("kiro: unknown sonnet 4.6 variant, mapping to claude-sonnet-4.6") + return "claude-sonnet-4.6" + } + if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") { + log.Debug("kiro: unknown Sonnet 4.5 model, mapping to claude-sonnet-4.5") + return "claude-sonnet-4.5" + } + } + + // Check for Opus variants + if strings.Contains(modelLower, "opus") { + if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { + log.Debug("kiro: unknown Opus 4.6 model, mapping to claude-opus-4.6") + return "claude-opus-4.6" + } + log.Debug("kiro: unknown opus variant, mapping to claude-opus-4.5") + 
return "claude-opus-4.5" + } + + // Final fallback to Sonnet 4.5 (most commonly used model) + log.Warn("kiro: unknown model variant, falling back to claude-sonnet-4.5") + return "claude-sonnet-4.5" +} + +func kiroModelFingerprint(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return hex.EncodeToString(sum[:8]) +} diff --git a/pkg/llmproxy/usage/metrics.go b/pkg/llmproxy/usage/metrics.go index f41dc58ad6..f4b157872c 100644 --- a/pkg/llmproxy/usage/metrics.go +++ b/pkg/llmproxy/usage/metrics.go @@ -4,7 +4,7 @@ package usage import ( "strings" - "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/util" + "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/util" ) func normalizeProvider(apiKey string) string {