diff --git a/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap b/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap index d58f06b0..c4273fa2 100644 --- a/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap +++ b/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap @@ -1726,6 +1726,41 @@ logger = logging.getLogger(__name__) import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: """Returns MCP Toolsets for all configured gateways.""" @@ -1740,6 +1775,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: url=url, httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs) ))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers))) {{else}} toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url))) {{/if}} @@ -2012,6 +2051,41 @@ logger = logging.getLogger(__name__) {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: """Returns an MCP Client connected to all configured gateways.""" @@ -2023,6 +2097,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: session = create_aws_session() auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name) servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth} + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers} {{else}} servers["{{name}}"] = {"transport": "streamable_http", "url": url} {{/if}} @@ -2438,6 +2516,41 @@ logger = logging.getLogger(__name__) import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: """Returns MCP servers for all configured gateways.""" @@ -2452,6 +2565,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: name="{{name}}", params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)} )) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers})) {{else}} servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url})) {{/if}} @@ -2749,7 +2866,41 @@ logger = logging.getLogger(__name__) {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] +{{/if}} +{{/each}} {{#each gatewayProviders}} def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: """Returns an MCP Client connected to the {{name}} gateway.""" @@ -2759,6 +2910,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: return None {{#if (eq authType "AWS_IAM")}} return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + return MCPClient(lambda: streamablehttp_client(url, headers=headers)) {{else}} return MCPClient(lambda: streamablehttp_client(url)) {{/if}} diff --git a/src/assets/python/googleadk/base/mcp_client/client.py b/src/assets/python/googleadk/base/mcp_client/client.py index f2c1a39c..d3cc6cab 100644 --- a/src/assets/python/googleadk/base/mcp_client/client.py +++ b/src/assets/python/googleadk/base/mcp_client/client.py @@ -10,6 +10,41 @@ import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: """Returns MCP Toolsets for all configured gateways.""" @@ -24,6 +59,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: url=url, httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs) ))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers))) {{else}} toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url))) {{/if}} diff --git a/src/assets/python/langchain_langgraph/base/mcp_client/client.py b/src/assets/python/langchain_langgraph/base/mcp_client/client.py index adcb478a..bb78656d 100644 --- a/src/assets/python/langchain_langgraph/base/mcp_client/client.py +++ b/src/assets/python/langchain_langgraph/base/mcp_client/client.py @@ -8,6 +8,41 @@ {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: """Returns an MCP Client connected to all configured gateways.""" @@ -19,6 +54,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: session = create_aws_session() auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name) servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth} + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers} {{else}} servers["{{name}}"] = {"transport": "streamable_http", "url": url} {{/if}} diff --git a/src/assets/python/openaiagents/base/mcp_client/client.py b/src/assets/python/openaiagents/base/mcp_client/client.py index 39612c38..762bd11a 100644 --- a/src/assets/python/openaiagents/base/mcp_client/client.py +++ b/src/assets/python/openaiagents/base/mcp_client/client.py @@ -9,6 +9,41 @@ import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} + +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: """Returns MCP servers for all configured gateways.""" @@ -23,6 +58,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: name="{{name}}", params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)} )) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers})) {{else}} servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url})) {{/if}} diff --git a/src/assets/python/strands/base/mcp_client/client.py b/src/assets/python/strands/base/mcp_client/client.py index 3b77cdac..8e71d5a0 100644 --- a/src/assets/python/strands/base/mcp_client/client.py +++ b/src/assets/python/strands/base/mcp_client/client.py @@ -9,7 +9,41 @@ {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx +{{/unless}}import time as _time +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0} +def _get_bearer_token_{{snakeCase name}}(): + """Obtain OAuth access token via client_credentials grant for {{name}}.""" + cache = _token_cache_{{snakeCase name}} + if cache["token"] and _time.time() < cache["expires_at"]: + return cache["token"] + client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID") + client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET") + if not client_id or not client_secret: + logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable") + return None + with httpx.Client() as c: + disc = c.get("{{discoveryUrl}}") + token_ep = disc.json()["token_endpoint"] + resp = c.post(token_ep, data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + {{#if scopes}}"scope": "{{scopes}}",{{/if}} + }) + data = resp.json() + cache["token"] = data["access_token"] + cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60 + return cache["token"] + +{{/if}} +{{/each}} {{#each gatewayProviders}} def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: """Returns an MCP Client connected to the {{name}} gateway.""" @@ -19,6 +53,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: return None {{#if (eq authType "AWS_IAM")}} return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + return MCPClient(lambda: streamablehttp_client(url, headers=headers)) {{else}} return MCPClient(lambda: streamablehttp_client(url)) {{/if}} diff --git a/src/cli/commands/add/__tests__/validate.test.ts b/src/cli/commands/add/__tests__/validate.test.ts index 40f5dfec..9b495a09 100644 --- a/src/cli/commands/add/__tests__/validate.test.ts +++ b/src/cli/commands/add/__tests__/validate.test.ts @@ -235,6 +235,47 @@ describe('validate', () => { expect(validateAddGatewayOptions(validGatewayOptionsNone)).toEqual({ valid: true }); expect(validateAddGatewayOptions(validGatewayOptionsJwt)).toEqual({ valid: true }); }); + + // AC15: agentClientId and agentClientSecret must be provided together + it('returns error when agentClientId provided without agentClientSecret', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientId: 'my-client-id', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together'); + }); + + it('returns error when agentClientSecret provided without agentClientId', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientSecret: 'my-secret', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together'); + }); + + // AC16: agent credentials only valid with CUSTOM_JWT + it('returns error when agent credentials used with non-CUSTOM_JWT authorizer', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsNone, + agentClientId: 'my-client-id', + agentClientSecret: 'my-secret', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Agent OAuth credentials are only valid with CUSTOM_JWT authorizer'); + }); + + // AC17: valid CUSTOM_JWT with agent credentials passes + it('passes for CUSTOM_JWT with agent credentials', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientId: 'my-client-id', + agentClientSecret: 'my-secret', + allowedScopes: 'scope1,scope2', + }); + expect(result.valid).toBe(true); + }); }); describe('validateAddGatewayTargetOptions', () => { diff --git a/src/cli/commands/add/actions.ts b/src/cli/commands/add/actions.ts index 675c52e8..7232f7c7 100644 --- a/src/cli/commands/add/actions.ts +++ b/src/cli/commands/add/actions.ts @@ -64,6 +64,9 @@ export interface ValidatedAddGatewayOptions { discoveryUrl?: string; allowedAudience?: string; allowedClients?: string; + allowedScopes?: string; + agentClientId?: string; + agentClientSecret?: string; agents?: string; } @@ -267,6 +270,14 @@ function buildGatewayConfig(options: ValidatedAddGatewayOptions): AddGatewayConf .allowedClients!.split(',') .map(s => s.trim()) .filter(Boolean), + allowedScopes: options.allowedScopes + ? options.allowedScopes + .split(',') + .map(s => s.trim()) + .filter(Boolean) + : undefined, + agentClientId: options.agentClientId, + agentClientSecret: options.agentClientSecret, }; } diff --git a/src/cli/commands/add/command.tsx b/src/cli/commands/add/command.tsx index ebc0ff44..22e89dc5 100644 --- a/src/cli/commands/add/command.tsx +++ b/src/cli/commands/add/command.tsx @@ -82,6 +82,9 @@ async function handleAddGatewayCLI(options: AddGatewayOptions): Promise { discoveryUrl: options.discoveryUrl, allowedAudience: options.allowedAudience, allowedClients: options.allowedClients, + allowedScopes: options.allowedScopes, + agentClientId: options.agentClientId, + agentClientSecret: options.agentClientSecret, agents: options.agents, }); @@ -265,13 +268,16 @@ export function registerAdd(program: Command) { // Subcommand: add gateway addCmd .command('gateway') - .description('Add an MCP gateway to the project') + .description('Add a gateway to the project') .option('--name ', 'Gateway name') .option('--description ', 'Gateway description') .option('--authorizer-type ', 'Authorizer type: NONE or CUSTOM_JWT', 'NONE') .option('--discovery-url ', 'OIDC discovery URL (required for CUSTOM_JWT)') .option('--allowed-audience ', 'Comma-separated allowed audience values (required for CUSTOM_JWT)') .option('--allowed-clients ', 'Comma-separated allowed client IDs (required for CUSTOM_JWT)') + .option('--allowed-scopes ', 'Comma-separated allowed scopes (optional for CUSTOM_JWT)') + .option('--agent-client-id ', 'Agent OAuth client ID for Bearer token auth (CUSTOM_JWT)') + .option('--agent-client-secret ', 'Agent OAuth client secret (CUSTOM_JWT)') .option('--json', 'Output as JSON') .action(async options => { requireProject(); diff --git a/src/cli/commands/add/types.ts b/src/cli/commands/add/types.ts index c83db76d..46757121 100644 --- a/src/cli/commands/add/types.ts +++ b/src/cli/commands/add/types.ts @@ -31,6 +31,9 @@ export interface AddGatewayOptions { discoveryUrl?: string; allowedAudience?: string; allowedClients?: string; + allowedScopes?: string; + agentClientId?: string; + agentClientSecret?: string; agents?: string; json?: boolean; } diff --git a/src/cli/commands/add/validate.ts b/src/cli/commands/add/validate.ts index b11e2aa7..aa804c36 100644 --- a/src/cli/commands/add/validate.ts +++ b/src/cli/commands/add/validate.ts @@ -180,6 +180,17 @@ export function validateAddGatewayOptions(options: AddGatewayOptions): Validatio } } + // Validate agent OAuth credentials + if (options.agentClientId && !options.agentClientSecret) { + return { valid: false, error: 'Both --agent-client-id and --agent-client-secret must be provided together' }; + } + if (options.agentClientSecret && !options.agentClientId) { + return { valid: false, error: 'Both --agent-client-id and --agent-client-secret must be provided together' }; + } + if (options.agentClientId && options.authorizerType !== 'CUSTOM_JWT') { + return { valid: false, error: 'Agent OAuth credentials are only valid with CUSTOM_JWT authorizer' }; + } + return { valid: true }; } diff --git a/src/cli/commands/remove/actions.ts b/src/cli/commands/remove/actions.ts index 35681c69..74604ea2 100644 --- a/src/cli/commands/remove/actions.ts +++ b/src/cli/commands/remove/actions.ts @@ -72,7 +72,7 @@ export async function handleRemove(options: ValidatedRemoveOptions): Promise ({ - name: gateway.name, - envVarName: computeDefaultGatewayEnvVarName(gateway.name), - authType: gateway.authorizerType, - })); + const project = await configIO.readProjectSpec(); + + return mcpSpec.agentCoreGateways.map(gateway => { + const config: GatewayProviderRenderConfig = { + name: gateway.name, + envVarName: computeDefaultGatewayEnvVarName(gateway.name), + authType: gateway.authorizerType, + }; + + if (gateway.authorizerType === 'CUSTOM_JWT' && gateway.authorizerConfiguration?.customJwtAuthorizer) { + const jwtConfig = gateway.authorizerConfiguration.customJwtAuthorizer; + const credName = `${gateway.name}-agent-oauth`; + const credential = project.credentials.find(c => c.name === credName); + + if (credential) { + config.credentialEnvVarBase = computeDefaultCredentialEnvVarName(credName); + config.discoveryUrl = jwtConfig.discoveryUrl; + const scopes = 'allowedScopes' in jwtConfig ? (jwtConfig as { allowedScopes?: string[] }).allowedScopes : undefined; + if (scopes?.length) { + config.scopes = scopes.join(' '); + } + } + } + + return config; + }); } catch { return []; } diff --git a/src/cli/operations/identity/create-identity.ts b/src/cli/operations/identity/create-identity.ts index 6c6705bb..f42bee61 100644 --- a/src/cli/operations/identity/create-identity.ts +++ b/src/cli/operations/identity/create-identity.ts @@ -108,6 +108,19 @@ export async function getAllCredentialNames(): Promise { } } +/** + * Get list of existing credentials with full type information from the project. + */ +export async function getAllCredentials(): Promise { + try { + const configIO = new ConfigIO(); + const project = await configIO.readProjectSpec(); + return project.credentials; + } catch { + return []; + } +} + /** * Create a credential resource and add it to the project. * Writes the credential config to agentcore.json and secrets to .env.local. diff --git a/src/cli/operations/mcp/create-mcp.ts b/src/cli/operations/mcp/create-mcp.ts index aefd731c..46ed9c5e 100644 --- a/src/cli/operations/mcp/create-mcp.ts +++ b/src/cli/operations/mcp/create-mcp.ts @@ -1,4 +1,4 @@ -import { ConfigIO, requireConfigRoot } from '../../../lib'; +import { ConfigIO, requireConfigRoot, setEnvVar } from '../../../lib'; import type { AgentCoreCliMcpDefs, AgentCoreGateway, @@ -10,6 +10,7 @@ import type { import { AgentCoreCliMcpDefsSchema, ToolDefinitionSchema } from '../../../schema'; import { getTemplateToolDefinitions, renderGatewayTargetTemplate } from '../../templates/GatewayTargetRenderer'; import type { AddGatewayConfig, AddGatewayTargetConfig } from '../../tui/screens/mcp/types'; +import { computeDefaultCredentialEnvVarName } from '../identity/create-identity'; import { DEFAULT_HANDLER, DEFAULT_NODE_VERSION, @@ -76,6 +77,7 @@ function buildAuthorizerConfiguration(config: AddGatewayConfig): AgentCoreGatewa discoveryUrl: config.jwtConfig.discoveryUrl, allowedAudience: config.jwtConfig.allowedAudience, allowedClients: config.jwtConfig.allowedClients, + ...(config.jwtConfig.allowedScopes?.length && { allowedScopes: config.jwtConfig.allowedScopes }), }, }; } @@ -206,6 +208,28 @@ export async function createGatewayFromWizard(config: AddGatewayConfig): Promise await configIO.writeMcpSpec(mcpSpec); + // Auto-create managed credential if agent OAuth credentials provided + if (config.jwtConfig?.agentClientId && config.jwtConfig?.agentClientSecret) { + const credName = `${config.name}-agent-oauth`; + const project = await configIO.readProjectSpec(); + + const credential = { + type: 'OAuthCredentialProvider' as const, + name: credName, + discoveryUrl: config.jwtConfig.discoveryUrl, + vendor: 'CustomOauth2', + managed: true, + usage: 'inbound' as const, + }; + + project.credentials.push(credential); + await configIO.writeProjectSpec(project); + + const envBase = computeDefaultCredentialEnvVarName(credName); + await setEnvVar(`${envBase}_CLIENT_ID`, config.jwtConfig.agentClientId); + await setEnvVar(`${envBase}_CLIENT_SECRET`, config.jwtConfig.agentClientSecret); + } + return { name: config.name }; } diff --git a/src/cli/operations/remove/__tests__/remove-identity.test.ts b/src/cli/operations/remove/__tests__/remove-identity.test.ts index b6172a33..d5f97e90 100644 --- a/src/cli/operations/remove/__tests__/remove-identity.test.ts +++ b/src/cli/operations/remove/__tests__/remove-identity.test.ts @@ -1,8 +1,9 @@ -import { previewRemoveCredential } from '../remove-identity.js'; +import { previewRemoveCredential, removeCredential } from '../remove-identity.js'; import { describe, expect, it, vi } from 'vitest'; -const { mockReadProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() => ({ +const { mockReadProjectSpec, mockWriteProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() => ({ mockReadProjectSpec: vi.fn(), + mockWriteProjectSpec: vi.fn(), mockConfigExists: vi.fn(), mockReadMcpSpec: vi.fn(), })); @@ -10,6 +11,7 @@ const { mockReadProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() vi.mock('../../../../lib/index.js', () => ({ ConfigIO: class { readProjectSpec = mockReadProjectSpec; + writeProjectSpec = mockWriteProjectSpec; configExists = mockConfigExists; readMcpSpec = mockReadMcpSpec; }, @@ -118,4 +120,55 @@ describe('previewRemoveCredential', () => { 'Warning: Credential "test-cred" is referenced by gateway targets: gateway2/target2. Removing it may break these targets.' ); }); + + it('shows managed credential warning in preview', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + + const result = await previewRemoveCredential('gw-agent-oauth'); + + const warning = result.summary.find(s => s.includes('auto-created')); + expect(warning).toBeTruthy(); + }); +}); + +describe('removeCredential', () => { + it('blocks removal of managed credential without force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + + const result = await removeCredential('gw-agent-oauth'); + + expect(result.ok).toBe(false); + expect(result.error).toContain('auto-created'); + expect(result.error).toContain('--force'); + }); + + it('allows removal of managed credential with force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + mockWriteProjectSpec.mockResolvedValue(undefined); + + const result = await removeCredential('gw-agent-oauth', { force: true }); + + expect(result.ok).toBe(true); + }); + + it('allows removal of non-managed credential without force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'regular-cred', type: 'OAuthCredentialProvider' }], + }); + mockConfigExists.mockReturnValue(false); + mockWriteProjectSpec.mockResolvedValue(undefined); + + const result = await removeCredential('regular-cred'); + + expect(result.ok).toBe(true); + }); }); diff --git a/src/cli/operations/remove/remove-identity.ts b/src/cli/operations/remove/remove-identity.ts index 68c9e417..6c560c64 100644 --- a/src/cli/operations/remove/remove-identity.ts +++ b/src/cli/operations/remove/remove-identity.ts @@ -43,6 +43,12 @@ export async function previewRemoveCredential(credentialName: string): Promise { +export async function removeCredential(credentialName: string, options?: { force?: boolean }): Promise { try { const configIO = new ConfigIO(); const project = await configIO.readProjectSpec(); @@ -95,6 +101,16 @@ export async function removeCredential(credentialName: string): Promise Identity Provider Setup - {missingCredentials.length} identity provider{missingCredentials.length > 1 ? 's' : ''} configured: + {new Set(missingCredentials.map(c => c.providerName)).size} identity provider + {new Set(missingCredentials.map(c => c.providerName)).size > 1 ? 's' : ''} configured: - {missingCredentials.map(cred => ( - - • {cred.providerName} + {[...new Set(missingCredentials.map(c => c.providerName))].map(name => ( + + • {name} ))} - How would you like to provide the API keys? + How would you like to provide the credentials? diff --git a/src/cli/tui/hooks/useCdkPreflight.ts b/src/cli/tui/hooks/useCdkPreflight.ts index c669adaa..89ed181e 100644 --- a/src/cli/tui/hooks/useCdkPreflight.ts +++ b/src/cli/tui/hooks/useCdkPreflight.ts @@ -91,9 +91,8 @@ export interface PreflightResult { const STEP_VALIDATE = 0; const STEP_DEPS = 1; const STEP_BUILD = 2; -const STEP_SYNTH = 3; -const STEP_STACK_STATUS = 4; -// Note: Identity and Bootstrap steps are dynamically appended, use steps.length - 1 to find them +// Note: Identity steps are inserted at index 3+ when needed, shifting synth and stack status down. +// Use findStepIndex() to locate synth and stack status dynamically. const BASE_PREFLIGHT_STEPS: Step[] = [ { label: 'Validate project', status: 'pending' }, @@ -103,7 +102,12 @@ const BASE_PREFLIGHT_STEPS: Step[] = [ { label: 'Check stack status', status: 'pending' }, ]; -const IDENTITY_STEP: Step = { label: 'Set up API key providers', status: 'pending' }; +const LABEL_SYNTH = 'Synthesize CloudFormation'; +const LABEL_STACK_STATUS = 'Check stack status'; +const LABEL_API_KEY = 'Set up API key providers'; +const LABEL_OAUTH = 'Set up OAuth providers'; + +const IDENTITY_STEP: Step = { label: LABEL_API_KEY, status: 'pending' }; const BOOTSTRAP_STEP: Step = { label: 'Bootstrap AWS environment', status: 'pending' }; export function useCdkPreflight(options: PreflightOptions): PreflightResult { @@ -138,6 +142,10 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { setSteps(prev => prev.map((s, i) => (i === index ? { ...s, ...update } : s))); }; + const updateStepByLabel = (label: string, update: Partial) => { + setSteps(prev => prev.map(s => (s.label === label ? { ...s, ...update } : s))); + }; + const resetSteps = () => { setSteps(BASE_PREFLIGHT_STEPS.map(s => ({ ...s, status: 'pending' as const }))); }; @@ -380,7 +388,7 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { } // Step: Synthesize CloudFormation - updateStep(STEP_SYNTH, { status: 'running' }); + updateStepByLabel(LABEL_SYNTH, { status: 'running' }); logger.startStep('Synthesize CloudFormation'); let synthStackNames: string[]; try { @@ -394,14 +402,17 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { synthStackNames = synthResult.stackNames; logger.log(`Stacks: ${synthResult.stackNames.join(', ')}`); logger.endStep('success'); - updateStep(STEP_SYNTH, { status: 'success' }); + updateStepByLabel(LABEL_SYNTH, { status: 'success' }); } catch (err) { const errorMsg = formatError(err); logger.endStep('error', errorMsg); if (isExpiredTokenError(err)) { setHasTokenExpiredError(true); } - updateStep(STEP_SYNTH, { status: 'error', error: logger.getFailureMessage('Synthesize CloudFormation') }); + updateStepByLabel(LABEL_SYNTH, { + status: 'error', + error: logger.getFailureMessage('Synthesize CloudFormation'), + }); setPhase('error'); isRunningRef.current = false; return; @@ -410,34 +421,37 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { // Step: Check stack status (ensure stacks are not in UPDATE_IN_PROGRESS etc.) const target = preflightContext.awsTargets[0]; if (target && synthStackNames.length > 0) { - updateStep(STEP_STACK_STATUS, { status: 'running' }); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'running' }); logger.startStep('Check stack status'); try { const stackStatus = await checkStackDeployability(target.region, synthStackNames); if (!stackStatus.canDeploy) { const errorMsg = stackStatus.message ?? `Stack ${stackStatus.blockingStack} is not in a deployable state`; logger.endStep('error', errorMsg); - updateStep(STEP_STACK_STATUS, { status: 'error', error: errorMsg }); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'error', error: errorMsg }); setPhase('error'); isRunningRef.current = false; return; } logger.endStep('success'); - updateStep(STEP_STACK_STATUS, { status: 'success' }); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); } catch (err) { const errorMsg = formatError(err); logger.endStep('error', errorMsg); if (isExpiredTokenError(err)) { setHasTokenExpiredError(true); } - updateStep(STEP_STACK_STATUS, { status: 'error', error: logger.getFailureMessage('Check stack status') }); + updateStepByLabel(LABEL_STACK_STATUS, { + status: 'error', + error: logger.getFailureMessage('Check stack status'), + }); setPhase('error'); isRunningRef.current = false; return; } } else { // Skip stack status check if no target or no stacks - updateStep(STEP_STACK_STATUS, { status: 'success' }); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); } // Check if bootstrap is needed @@ -488,16 +502,78 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { isRunningRef.current = true; const runIdentitySetup = async () => { - // If user chose to skip, go directly to bootstrap check + // If user chose to skip, go directly to synth if (skipIdentitySetup) { - logger.log('Skipping API key provider setup (user choice)'); + logger.log('Skipping identity provider setup (user choice)'); setSkipIdentitySetup(false); // Reset for next run + // Synthesize CloudFormation + updateStepByLabel(LABEL_SYNTH, { status: 'running' }); + logger.startStep('Synthesize CloudFormation'); + let synthStackNames: string[]; + try { + const synthResult = await synthesizeCdk(context.cdkProject, { + ioHost: switchableIoHost.ioHost, + previousWrapper: wrapperRef.current, + }); + wrapperRef.current = synthResult.toolkitWrapper; + setCdkToolkitWrapper(synthResult.toolkitWrapper); + setStackNames(synthResult.stackNames); + synthStackNames = synthResult.stackNames; + logger.endStep('success'); + updateStepByLabel(LABEL_SYNTH, { status: 'success' }); + } catch (err) { + const errorMsg = formatError(err); + logger.endStep('error', errorMsg); + updateStepByLabel(LABEL_SYNTH, { + status: 'error', + error: logger.getFailureMessage('Synthesize CloudFormation'), + }); + setPhase('error'); + isRunningRef.current = false; + return; + } + + // Check stack status + const target = context.awsTargets[0]; + if (target && synthStackNames.length > 0) { + updateStepByLabel(LABEL_STACK_STATUS, { status: 'running' }); + logger.startStep('Check stack status'); + try { + const stackStatus = await checkStackDeployability(target.region, synthStackNames); + if (!stackStatus.canDeploy) { + const errorMsg = stackStatus.message ?? `Stack ${stackStatus.blockingStack} is not in a deployable state`; + logger.endStep('error', errorMsg); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'error', error: errorMsg }); + setPhase('error'); + isRunningRef.current = false; + return; + } + logger.endStep('success'); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); + } catch (err) { + const errorMsg = formatError(err); + logger.endStep('error', errorMsg); + if (isExpiredTokenError(err)) { + setHasTokenExpiredError(true); + } + updateStepByLabel(LABEL_STACK_STATUS, { + status: 'error', + error: logger.getFailureMessage('Check stack status'), + }); + setPhase('error'); + isRunningRef.current = false; + return; + } + } else { + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); + } + // Check if bootstrap is needed const bootstrapCheck = await checkBootstrapNeeded(context.awsTargets); if (bootstrapCheck.needsBootstrap && bootstrapCheck.target) { setBootstrapContext({ - toolkitWrapper: wrapperRef.current!, + toolkitWrapper: wrapperRef.current, target: bootstrapCheck.target, }); setPhase('bootstrap-confirm'); @@ -510,15 +586,30 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { } // Run identity setup with runtime credentials - setSteps(prev => [...prev, { ...IDENTITY_STEP, status: 'running' }]); - logger.startStep('Set up API key providers'); + // Insert identity steps before synthesize in the step list + const hasApiKeys = hasIdentityApiProviders(context.projectSpec); + const hasOAuth = hasIdentityOAuthProviders(context.projectSpec); + setSteps(prev => { + const synthIndex = prev.findIndex(s => s.label === LABEL_SYNTH); + const identitySteps: Step[] = []; + if (hasApiKeys) identitySteps.push({ ...IDENTITY_STEP, status: 'running' }); + if (hasOAuth) identitySteps.push({ label: LABEL_OAUTH, status: hasApiKeys ? 'pending' : 'running' }); + return [...prev.slice(0, synthIndex), ...identitySteps, ...prev.slice(synthIndex)]; + }); + + if (hasApiKeys) { + logger.startStep('Set up API key providers'); + } const target = context.awsTargets[0]; if (!target) { - logger.endStep('error', 'No AWS target configured'); - setSteps(prev => - prev.map((s, i) => (i === prev.length - 1 ? { ...s, status: 'error', error: 'No AWS target configured' } : s)) - ); + const errorMsg = 'No AWS target configured'; + if (hasApiKeys) { + logger.endStep('error', errorMsg); + updateStepByLabel(LABEL_API_KEY, { status: 'error', error: errorMsg }); + } else if (hasOAuth) { + updateStepByLabel(LABEL_OAUTH, { status: 'error', error: errorMsg }); + } setPhase('error'); isRunningRef.current = false; return; @@ -526,66 +617,69 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { try { const configBaseDir = path.dirname(context.cdkProject.projectDir); - const identityResult = await setupApiKeyProviders({ - projectSpec: context.projectSpec, - configBaseDir, - region: target.region, - runtimeCredentials: runtimeCredentials ?? undefined, - enableKmsEncryption: true, - }); - // Log KMS setup - if (identityResult.kmsKeyArn) { - logger.log(`Token vault encrypted with KMS key: ${identityResult.kmsKeyArn}`); - setIdentityKmsKeyArn(identityResult.kmsKeyArn); - } + // Collect credential ARNs for deployed state + const deployedCredentials: Record< + string, + { credentialProviderArn: string; clientSecretArn?: string; callbackUrl?: string } + > = {}; + let kmsKeyArn: string | undefined; - // Log results - for (const result of identityResult.results) { - if (result.status === 'created') { - logger.log(`Created API key provider: ${result.providerName}`); - } else if (result.status === 'updated') { - logger.log(`Updated API key provider: ${result.providerName}`); - } else if (result.status === 'exists') { - logger.log(`API key provider exists: ${result.providerName}`); - } else if (result.status === 'skipped') { - logger.log(`Skipped ${result.providerName}: ${result.error}`); - } else if (result.status === 'error') { - logger.log(`Error for ${result.providerName}: ${result.error}`); + // Set up API key providers if needed + if (hasApiKeys) { + const identityResult = await setupApiKeyProviders({ + projectSpec: context.projectSpec, + configBaseDir, + region: target.region, + runtimeCredentials: runtimeCredentials ?? undefined, + enableKmsEncryption: true, + }); + + // Log KMS setup + if (identityResult.kmsKeyArn) { + logger.log(`Token vault encrypted with KMS key: ${identityResult.kmsKeyArn}`); + kmsKeyArn = identityResult.kmsKeyArn; + setIdentityKmsKeyArn(identityResult.kmsKeyArn); } - } - if (identityResult.hasErrors) { - logger.endStep('error', 'Some API key providers failed to set up'); - setSteps(prev => - prev.map((s, i) => - i === prev.length - 1 ? { ...s, status: 'error', error: 'Some API key providers failed' } : s - ) - ); - setPhase('error'); - isRunningRef.current = false; - return; - } + // Log results + for (const result of identityResult.results) { + if (result.status === 'created') { + logger.log(`Created API key provider: ${result.providerName}`); + } else if (result.status === 'updated') { + logger.log(`Updated API key provider: ${result.providerName}`); + } else if (result.status === 'exists') { + logger.log(`API key provider exists: ${result.providerName}`); + } else if (result.status === 'skipped') { + logger.log(`Skipped ${result.providerName}: ${result.error}`); + } else if (result.status === 'error') { + logger.log(`Error for ${result.providerName}: ${result.error}`); + } + } - logger.endStep('success'); - setSteps(prev => prev.map((s, i) => (i === prev.length - 1 ? { ...s, status: 'success' } : s))); + if (identityResult.hasErrors) { + logger.endStep('error', 'Some API key providers failed to set up'); + updateStepByLabel(LABEL_API_KEY, { status: 'error', error: 'Some API key providers failed' }); + setPhase('error'); + isRunningRef.current = false; + return; + } - // Collect API Key credential ARNs for deployed state - const deployedCredentials: Record< - string, - { credentialProviderArn: string; clientSecretArn?: string; callbackUrl?: string } - > = {}; - for (const result of identityResult.results) { - if (result.credentialProviderArn) { - deployedCredentials[result.providerName] = { - credentialProviderArn: result.credentialProviderArn, - }; + logger.endStep('success'); + updateStepByLabel(LABEL_API_KEY, { status: 'success' }); + + for (const result of identityResult.results) { + if (result.credentialProviderArn) { + deployedCredentials[result.providerName] = { + credentialProviderArn: result.credentialProviderArn, + }; + } } } // Set up OAuth credential providers if needed - if (hasIdentityOAuthProviders(context.projectSpec)) { - setSteps(prev => [...prev, { label: 'Set up OAuth providers', status: 'running' }]); + if (hasOAuth) { + updateStepByLabel(LABEL_OAUTH, { status: 'running' }); logger.startStep('Set up OAuth providers'); const oauthResult = await setupOAuth2Providers({ @@ -609,11 +703,7 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { if (oauthResult.hasErrors) { logger.endStep('error', 'Some OAuth providers failed to set up'); - setSteps(prev => - prev.map((s, i) => - i === prev.length - 1 ? { ...s, status: 'error', error: 'Some OAuth providers failed' } : s - ) - ); + updateStepByLabel(LABEL_OAUTH, { status: 'error', error: 'Some OAuth providers failed' }); setPhase('error'); isRunningRef.current = false; return; @@ -637,7 +727,7 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { Object.assign(deployedCredentials, creds); logger.endStep('success'); - setSteps(prev => prev.map((s, i) => (i === prev.length - 1 ? { ...s, status: 'success' } : s))); + updateStepByLabel(LABEL_OAUTH, { status: 'success' }); } // Write partial deployed state with credential ARNs before CDK synth @@ -648,7 +738,7 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { const targetState = existingState.targets?.[target!.name] ?? { resources: {} }; targetState.resources ??= {}; targetState.resources.credentials = deployedCredentials; - if (identityResult.kmsKeyArn) targetState.resources.identityKmsKeyArn = identityResult.kmsKeyArn; + if (kmsKeyArn) targetState.resources.identityKmsKeyArn = kmsKeyArn; await configIO.writeDeployedState({ ...existingState, targets: { ...existingState.targets, [target!.name]: targetState }, @@ -658,9 +748,10 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { // Clear runtime credentials setRuntimeCredentials(null); - // Re-synth now that credentials are in deployed state - updateStep(STEP_SYNTH, { status: 'running' }); + // Synthesize CloudFormation now that credentials are in deployed state + updateStepByLabel(LABEL_SYNTH, { status: 'running' }); logger.startStep('Synthesize CloudFormation'); + let synthStackNames: string[]; try { const synthResult = await synthesizeCdk(context.cdkProject, { ioHost: switchableIoHost.ioHost, @@ -669,17 +760,55 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { wrapperRef.current = synthResult.toolkitWrapper; setCdkToolkitWrapper(synthResult.toolkitWrapper); setStackNames(synthResult.stackNames); + synthStackNames = synthResult.stackNames; logger.endStep('success'); - updateStep(STEP_SYNTH, { status: 'success' }); + updateStepByLabel(LABEL_SYNTH, { status: 'success' }); } catch (err) { const errorMsg = formatError(err); logger.endStep('error', errorMsg); - updateStep(STEP_SYNTH, { status: 'error', error: logger.getFailureMessage('Synthesize CloudFormation') }); + updateStepByLabel(LABEL_SYNTH, { + status: 'error', + error: logger.getFailureMessage('Synthesize CloudFormation'), + }); setPhase('error'); isRunningRef.current = false; return; } + // Check stack status + if (target && synthStackNames.length > 0) { + updateStepByLabel(LABEL_STACK_STATUS, { status: 'running' }); + logger.startStep('Check stack status'); + try { + const stackStatus = await checkStackDeployability(target.region, synthStackNames); + if (!stackStatus.canDeploy) { + const errorMsg = stackStatus.message ?? `Stack ${stackStatus.blockingStack} is not in a deployable state`; + logger.endStep('error', errorMsg); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'error', error: errorMsg }); + setPhase('error'); + isRunningRef.current = false; + return; + } + logger.endStep('success'); + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); + } catch (err) { + const errorMsg = formatError(err); + logger.endStep('error', errorMsg); + if (isExpiredTokenError(err)) { + setHasTokenExpiredError(true); + } + updateStepByLabel(LABEL_STACK_STATUS, { + status: 'error', + error: logger.getFailureMessage('Check stack status'), + }); + setPhase('error'); + isRunningRef.current = false; + return; + } + } else { + updateStepByLabel(LABEL_STACK_STATUS, { status: 'success' }); + } + // Check if bootstrap is needed const bootstrapCheck = await checkBootstrapNeeded(context.awsTargets); if (bootstrapCheck.needsBootstrap && bootstrapCheck.target) { @@ -712,7 +841,7 @@ export function useCdkPreflight(options: PreflightOptions): PreflightResult { }; void runIdentitySetup(); - }, [phase, context, skipIdentitySetup, runtimeCredentials, logger]); + }, [phase, context, skipIdentitySetup, runtimeCredentials, logger, switchableIoHost.ioHost]); // Handle bootstrapping phase useEffect(() => { diff --git a/src/cli/tui/hooks/useRemove.ts b/src/cli/tui/hooks/useRemove.ts index dccdb79c..15d047e0 100644 --- a/src/cli/tui/hooks/useRemove.ts +++ b/src/cli/tui/hooks/useRemove.ts @@ -351,7 +351,7 @@ export function useRemoveIdentity() { const remove = useCallback(async (identityName: string, preview?: RemovalPreview): Promise => { setState({ isLoading: true, result: null }); - const result = await removeIdentity(identityName); + const result = await removeIdentity(identityName, { force: true }); setState({ isLoading: false, result }); let logPath: string | undefined; diff --git a/src/cli/tui/screens/add/AddFlow.tsx b/src/cli/tui/screens/add/AddFlow.tsx index fbead7de..313b439f 100644 --- a/src/cli/tui/screens/add/AddFlow.tsx +++ b/src/cli/tui/screens/add/AddFlow.tsx @@ -361,10 +361,6 @@ export function AddFlow(props: AddFlowProps) { // Identity wizard - now uses AddIdentityFlow with mode selection if (flow.name === 'identity-wizard') { - // Wait for agents to load before rendering wizard - if (agents.length === 0) { - return null; - } return ( ADD_RESOURCES.map(r => ({ ...r, - disabled: Boolean('disabled' in r && r.disabled) || ((r.id === 'memory' || r.id === 'identity') && !hasAgents), - description: (r.id === 'memory' || r.id === 'identity') && !hasAgents ? 'Add an agent first' : r.description, + disabled: Boolean('disabled' in r && r.disabled) || (r.id === 'memory' && !hasAgents), + description: r.id === 'memory' && !hasAgents ? 'Add an agent first' : r.description, })), [hasAgents] ); diff --git a/src/cli/tui/screens/identity/AddIdentityFlow.tsx b/src/cli/tui/screens/identity/AddIdentityFlow.tsx index 093331dc..5240dfb7 100644 --- a/src/cli/tui/screens/identity/AddIdentityFlow.tsx +++ b/src/cli/tui/screens/identity/AddIdentityFlow.tsx @@ -35,11 +35,26 @@ export function AddIdentityFlow({ isInteractive = true, onExit, onBack, onDev, o const handleCreateComplete = useCallback( (config: AddIdentityConfig) => { - void createIdentity({ - type: 'ApiKeyCredentialProvider', - name: config.name, - apiKey: config.apiKey, - }).then(result => { + const createConfig = + config.identityType === 'OAuthCredentialProvider' + ? { + type: 'OAuthCredentialProvider' as const, + name: config.name, + discoveryUrl: config.discoveryUrl!, + clientId: config.clientId!, + clientSecret: config.clientSecret!, + scopes: config.scopes + ?.split(',') + .map(s => s.trim()) + .filter(Boolean), + } + : { + type: 'ApiKeyCredentialProvider' as const, + name: config.name, + apiKey: config.apiKey, + }; + + void createIdentity(createConfig).then(result => { if (result.ok) { setFlow({ name: 'create-success', identityName: result.result.name }); return; @@ -63,7 +78,7 @@ export function AddIdentityFlow({ isInteractive = true, onExit, onBack, onDev, o void; onExit: () => void; existingIdentityNames: string[]; + initialType?: CredentialType; } -export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames }: AddIdentityScreenProps) { - const wizard = useAddIdentityWizard(); +export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames, initialType }: AddIdentityScreenProps) { + const wizard = useAddIdentityWizard(initialType); const typeItems: SelectableItem[] = useMemo( () => IDENTITY_TYPE_OPTIONS.map(opt => ({ id: opt.id, title: opt.title, description: opt.description })), @@ -27,7 +28,12 @@ export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames }: const isTypeStep = wizard.step === 'type'; const isNameStep = wizard.step === 'name'; const isApiKeyStep = wizard.step === 'apiKey'; + const isDiscoveryUrlStep = wizard.step === 'discoveryUrl'; + const isClientIdStep = wizard.step === 'clientId'; + const isClientSecretStep = wizard.step === 'clientSecret'; + const isScopesStep = wizard.step === 'scopes'; const isConfirmStep = wizard.step === 'confirm'; + const isOAuth = wizard.config.identityType === 'OAuthCredentialProvider'; const typeNav = useListNavigation({ items: typeItems, @@ -51,6 +57,10 @@ export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames }: const headerContent = ; + const defaultName = isOAuth + ? generateUniqueName('MyOAuth', existingIdentityNames) + : generateUniqueName('MyApiKey', existingIdentityNames); + return ( @@ -67,10 +77,11 @@ export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames }: wizard.goBack()} schema={CredentialNameSchema} + customValidation={value => !existingIdentityNames.includes(value) || 'Credential name already exists'} /> )} @@ -85,13 +96,81 @@ export function AddIdentityScreen({ onComplete, onExit, existingIdentityNames }: /> )} + {isDiscoveryUrlStep && ( + wizard.goBack()} + customValidation={value => { + try { + new URL(value); + } catch { + return 'Must be a valid URL'; + } + if (!value.endsWith('/.well-known/openid-configuration')) { + return "URL must end with '/.well-known/openid-configuration'"; + } + return true; + }} + /> + )} + + {isClientIdStep && ( + wizard.goBack()} + customValidation={value => value.trim().length > 0 || 'Client ID is required'} + revealChars={4} + /> + )} + + {isClientSecretStep && ( + wizard.goBack()} + customValidation={value => value.trim().length > 0 || 'Client secret is required'} + revealChars={4} + /> + )} + + {isScopesStep && ( + wizard.goBack()} + allowEmpty + /> + )} + {isConfirmStep && ( )} diff --git a/src/cli/tui/screens/identity/types.ts b/src/cli/tui/screens/identity/types.ts index b936a1e2..49bdf6bf 100644 --- a/src/cli/tui/screens/identity/types.ts +++ b/src/cli/tui/screens/identity/types.ts @@ -4,18 +4,36 @@ import type { CredentialType } from '../../../../schema'; // Identity Flow Types // ───────────────────────────────────────────────────────────────────────────── -export type AddIdentityStep = 'type' | 'name' | 'apiKey' | 'confirm'; +export type AddIdentityStep = + | 'type' + | 'name' + | 'apiKey' + | 'discoveryUrl' + | 'clientId' + | 'clientSecret' + | 'scopes' + | 'confirm'; export interface AddIdentityConfig { identityType: CredentialType; name: string; + /** API Key (when type is ApiKeyCredentialProvider) */ apiKey: string; + /** OAuth fields (when type is OAuthCredentialProvider) */ + discoveryUrl?: string; + clientId?: string; + clientSecret?: string; + scopes?: string; } export const IDENTITY_STEP_LABELS: Record = { type: 'Type', name: 'Name', apiKey: 'API Key', + discoveryUrl: 'Discovery URL', + clientId: 'Client ID', + clientSecret: 'Client Secret', + scopes: 'Scopes', confirm: 'Confirm', }; @@ -25,4 +43,5 @@ export const IDENTITY_STEP_LABELS: Record = { export const IDENTITY_TYPE_OPTIONS = [ { id: 'ApiKeyCredentialProvider' as const, title: 'API Key', description: 'Store and manage API key credentials' }, + { id: 'OAuthCredentialProvider' as const, title: 'OAuth', description: 'OAuth 2.0 client credentials' }, ] as const; diff --git a/src/cli/tui/screens/identity/useAddIdentityWizard.ts b/src/cli/tui/screens/identity/useAddIdentityWizard.ts index b870091c..ea1271f1 100644 --- a/src/cli/tui/screens/identity/useAddIdentityWizard.ts +++ b/src/cli/tui/screens/identity/useAddIdentityWizard.ts @@ -1,74 +1,126 @@ import type { CredentialType } from '../../../../schema'; import type { AddIdentityConfig, AddIdentityStep } from './types'; -import { useCallback, useState } from 'react'; +import { useCallback, useMemo, useState } from 'react'; -const ALL_STEPS: AddIdentityStep[] = ['type', 'name', 'apiKey', 'confirm']; +function getSteps(identityType: CredentialType, skipTypeStep: boolean): AddIdentityStep[] { + const steps: AddIdentityStep[] = + identityType === 'OAuthCredentialProvider' + ? ['type', 'name', 'discoveryUrl', 'clientId', 'clientSecret', 'scopes', 'confirm'] + : ['type', 'name', 'apiKey', 'confirm']; -function getDefaultConfig(): AddIdentityConfig { + return skipTypeStep ? steps.filter(s => s !== 'type') : steps; +} + +function getDefaultConfig(initialType?: CredentialType): AddIdentityConfig { return { - identityType: 'ApiKeyCredentialProvider', + identityType: initialType ?? 'ApiKeyCredentialProvider', name: '', apiKey: '', }; } -export function useAddIdentityWizard() { - const [config, setConfig] = useState(getDefaultConfig); - const [step, setStep] = useState('type'); +export function useAddIdentityWizard(initialType?: CredentialType) { + const hasInitialType = initialType !== undefined; + const [config, setConfig] = useState(() => getDefaultConfig(initialType)); + const [step, setStep] = useState(hasInitialType ? 'name' : 'type'); - const currentIndex = ALL_STEPS.indexOf(step); + const steps = useMemo(() => getSteps(config.identityType, hasInitialType), [config.identityType, hasInitialType]); + const currentIndex = steps.indexOf(step); const goBack = useCallback(() => { - const prevStep = ALL_STEPS[currentIndex - 1]; + const prevStep = steps[currentIndex - 1]; if (prevStep) setStep(prevStep); - }, [currentIndex]); - - const nextStep = useCallback((currentStep: AddIdentityStep): AddIdentityStep | undefined => { - const idx = ALL_STEPS.indexOf(currentStep); - return ALL_STEPS[idx + 1]; - }, []); + }, [currentIndex, steps]); - const setIdentityType = useCallback( - (identityType: CredentialType) => { - setConfig(c => ({ ...c, identityType })); - const next = nextStep('type'); + const advanceFrom = useCallback( + (currentStep: AddIdentityStep) => { + const currentSteps = getSteps(config.identityType, hasInitialType); + const idx = currentSteps.indexOf(currentStep); + const next = currentSteps[idx + 1]; if (next) setStep(next); }, - [nextStep] + [config.identityType, hasInitialType] ); + const setIdentityType = useCallback((identityType: CredentialType) => { + setConfig(c => ({ + ...c, + identityType, + apiKey: '', + discoveryUrl: undefined, + clientId: undefined, + clientSecret: undefined, + scopes: undefined, + })); + setStep('name'); + }, []); + const setName = useCallback( (name: string) => { setConfig(c => ({ ...c, name })); - const next = nextStep('name'); - if (next) setStep(next); + advanceFrom('name'); }, - [nextStep] + [advanceFrom] ); const setApiKey = useCallback( (apiKey: string) => { setConfig(c => ({ ...c, apiKey })); - const next = nextStep('apiKey'); - if (next) setStep(next); + advanceFrom('apiKey'); + }, + [advanceFrom] + ); + + const setDiscoveryUrl = useCallback( + (discoveryUrl: string) => { + setConfig(c => ({ ...c, discoveryUrl })); + advanceFrom('discoveryUrl'); }, - [nextStep] + [advanceFrom] + ); + + const setClientId = useCallback( + (clientId: string) => { + setConfig(c => ({ ...c, clientId })); + advanceFrom('clientId'); + }, + [advanceFrom] + ); + + const setClientSecret = useCallback( + (clientSecret: string) => { + setConfig(c => ({ ...c, clientSecret })); + advanceFrom('clientSecret'); + }, + [advanceFrom] + ); + + const setScopes = useCallback( + (scopes: string) => { + setConfig(c => ({ ...c, scopes: scopes || undefined })); + advanceFrom('scopes'); + }, + [advanceFrom] ); const reset = useCallback(() => { - setConfig(getDefaultConfig()); - setStep('type'); - }, []); + setConfig(getDefaultConfig(initialType)); + setStep(hasInitialType ? 'name' : 'type'); + }, [initialType, hasInitialType]); return { config, step, - steps: ALL_STEPS, + steps, currentIndex, goBack, setIdentityType, setName, setApiKey, + setDiscoveryUrl, + setClientId, + setClientSecret, + setScopes, reset, }; } diff --git a/src/cli/tui/screens/identity/useCreateIdentity.ts b/src/cli/tui/screens/identity/useCreateIdentity.ts index f53f73db..1dee9e37 100644 --- a/src/cli/tui/screens/identity/useCreateIdentity.ts +++ b/src/cli/tui/screens/identity/useCreateIdentity.ts @@ -3,6 +3,7 @@ import { type CreateCredentialConfig, createCredential, getAllCredentialNames, + getAllCredentials, } from '../../../operations/identity/create-identity'; import { useCallback, useEffect, useState } from 'react'; @@ -50,5 +51,20 @@ export function useExistingCredentialNames() { return { names, refresh }; } +export function useExistingCredentials() { + const [credentials, setCredentials] = useState([]); + + useEffect(() => { + void getAllCredentials().then(setCredentials); + }, []); + + const refresh = useCallback(async () => { + const result = await getAllCredentials(); + setCredentials(result); + }, []); + + return { credentials, refresh }; +} + // Alias for old name export const useExistingIdentityNames = useExistingCredentialNames; diff --git a/src/cli/tui/screens/mcp/AddGatewayScreen.tsx b/src/cli/tui/screens/mcp/AddGatewayScreen.tsx index 13269eef..dca25086 100644 --- a/src/cli/tui/screens/mcp/AddGatewayScreen.tsx +++ b/src/cli/tui/screens/mcp/AddGatewayScreen.tsx @@ -4,6 +4,7 @@ import { ConfirmReview, Panel, Screen, + SecretInput, StepIndicator, TextInput, WizardMultiSelect, @@ -29,10 +30,13 @@ interface AddGatewayScreenProps { export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassignedTargets }: AddGatewayScreenProps) { const wizard = useAddGatewayWizard(unassignedTargets.length); - // JWT config sub-step tracking (0 = discoveryUrl, 1 = audience, 2 = clients) + // JWT config sub-step tracking (0=discoveryUrl, 1=audience, 2=clients, 3=scopes, 4=agentClientId, 5=agentClientSecret) const [jwtSubStep, setJwtSubStep] = useState(0); const [jwtDiscoveryUrl, setJwtDiscoveryUrl] = useState(''); const [jwtAudience, setJwtAudience] = useState(''); + const [jwtClients, setJwtClients] = useState(''); + const [jwtScopes, setJwtScopes] = useState(''); + const [jwtAgentClientId, setJwtAgentClientId] = useState(''); const unassignedTargetItems: SelectableItem[] = useMemo( () => unassignedTargets.map(name => ({ id: name, title: name })), @@ -85,12 +89,30 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig }; const handleJwtClients = (clients: string) => { - // Parse comma-separated values + setJwtClients(clients); + setJwtSubStep(3); + }; + + const handleJwtScopes = (scopes: string) => { + setJwtScopes(scopes); + setJwtSubStep(4); + }; + + const handleJwtAgentClientId = (clientId: string) => { + setJwtAgentClientId(clientId); + setJwtSubStep(5); + }; + + const handleJwtAgentClientSecret = (clientSecret: string) => { const audienceList = jwtAudience .split(',') .map(s => s.trim()) .filter(Boolean); - const clientsList = clients + const clientsList = jwtClients + .split(',') + .map(s => s.trim()) + .filter(Boolean); + const scopesList = jwtScopes .split(',') .map(s => s.trim()) .filter(Boolean); @@ -99,9 +121,10 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig discoveryUrl: jwtDiscoveryUrl, allowedAudience: audienceList, allowedClients: clientsList, + ...(scopesList.length > 0 ? { allowedScopes: scopesList } : {}), + ...(jwtAgentClientId ? { agentClientId: jwtAgentClientId, agentClientSecret: clientSecret } : {}), }); - // Reset sub-step counter only - preserve values for potential back navigation setJwtSubStep(0); }; @@ -160,6 +183,9 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig onDiscoveryUrl={handleJwtDiscoveryUrl} onAudience={handleJwtAudience} onClients={handleJwtClients} + onScopes={handleJwtScopes} + onAgentClientId={handleJwtAgentClientId} + onAgentClientSecret={handleJwtAgentClientSecret} onCancel={handleJwtCancel} /> )} @@ -187,6 +213,12 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig { label: 'Discovery URL', value: wizard.config.jwtConfig.discoveryUrl }, { label: 'Allowed Audience', value: wizard.config.jwtConfig.allowedAudience.join(', ') }, { label: 'Allowed Clients', value: wizard.config.jwtConfig.allowedClients.join(', ') }, + ...(wizard.config.jwtConfig.allowedScopes?.length + ? [{ label: 'Allowed Scopes', value: wizard.config.jwtConfig.allowedScopes.join(', ') }] + : []), + ...(wizard.config.jwtConfig.agentClientId + ? [{ label: 'Agent Credential', value: `${wizard.config.name}-agent-oauth` }] + : []), ] : []), { @@ -209,6 +241,9 @@ interface JwtConfigInputProps { onDiscoveryUrl: (url: string) => void; onAudience: (audience: string) => void; onClients: (clients: string) => void; + onScopes: (scopes: string) => void; + onAgentClientId: (clientId: string) => void; + onAgentClientSecret: (clientSecret: string) => void; onCancel: () => void; } @@ -227,16 +262,28 @@ function validateCommaSeparatedList(value: string, fieldName: string): true | st return true; } -function JwtConfigInput({ subStep, onDiscoveryUrl, onAudience, onClients, onCancel }: JwtConfigInputProps) { +function JwtConfigInput({ + subStep, + onDiscoveryUrl, + onAudience, + onClients, + onScopes, + onAgentClientId, + onAgentClientSecret, + onCancel, +}: JwtConfigInputProps) { + const totalSteps = 6; return ( Configure Custom JWT Authorizer - Step {subStep + 1} of 3 + + Step {subStep + 1} of {totalSteps} + {subStep === 0 && ( { @@ -271,6 +318,33 @@ function JwtConfigInput({ subStep, onDiscoveryUrl, onAudience, onClients, onCanc customValidation={value => validateCommaSeparatedList(value, 'client')} /> )} + {subStep === 3 && ( + + )} + {subStep === 4 && ( + + )} + {subStep === 5 && ( + value.trim().length > 0 || 'Client secret is required'} + revealChars={4} + /> + )} ); diff --git a/src/cli/tui/screens/mcp/AddGatewayTargetFlow.tsx b/src/cli/tui/screens/mcp/AddGatewayTargetFlow.tsx index c6cce11d..a840d68e 100644 --- a/src/cli/tui/screens/mcp/AddGatewayTargetFlow.tsx +++ b/src/cli/tui/screens/mcp/AddGatewayTargetFlow.tsx @@ -2,12 +2,16 @@ import { createExternalGatewayTarget } from '../../../operations/mcp/create-mcp' import { ErrorPrompt } from '../../components'; import { useCreateGatewayTarget, useExistingGateways, useExistingToolNames } from '../../hooks/useCreateMcp'; import { AddSuccessScreen } from '../add/AddSuccessScreen'; +import { AddIdentityScreen } from '../identity/AddIdentityScreen'; +import type { AddIdentityConfig } from '../identity/types'; +import { useCreateIdentity, useExistingCredentials, useExistingIdentityNames } from '../identity/useCreateIdentity'; import { AddGatewayTargetScreen } from './AddGatewayTargetScreen'; import type { AddGatewayTargetConfig } from './types'; -import React, { useCallback, useEffect, useState } from 'react'; +import React, { useCallback, useEffect, useMemo, useState } from 'react'; type FlowState = | { name: 'create-wizard' } + | { name: 'creating-credential'; pendingConfig: AddGatewayTargetConfig } | { name: 'create-success'; toolName: string; projectPath: string; loading?: boolean; loadingMessage?: string } | { name: 'error'; message: string }; @@ -32,8 +36,16 @@ export function AddGatewayTargetFlow({ const { createTool, reset: resetCreate } = useCreateGatewayTarget(); const { gateways: existingGateways } = useExistingGateways(); const { toolNames: existingToolNames } = useExistingToolNames(); + const { credentials } = useExistingCredentials(); + const { names: existingIdentityNames } = useExistingIdentityNames(); + const { createIdentity } = useCreateIdentity(); const [flow, setFlow] = useState({ name: 'create-wizard' }); + const oauthCredentialNames = useMemo( + () => credentials.filter(c => c.type === 'OAuthCredentialProvider').map(c => c.name), + [credentials] + ); + // In non-interactive mode, exit after success (but not while loading) useEffect(() => { if (!isInteractive && flow.name === 'create-success' && !flow.loading) { @@ -73,18 +85,72 @@ export function AddGatewayTargetFlow({ [createTool] ); + const handleCreateCredential = useCallback((pendingConfig: AddGatewayTargetConfig) => { + setFlow({ name: 'creating-credential', pendingConfig }); + }, []); + + const handleIdentityComplete = useCallback( + (identityConfig: AddIdentityConfig) => { + const createConfig = + identityConfig.identityType === 'OAuthCredentialProvider' + ? { + type: 'OAuthCredentialProvider' as const, + name: identityConfig.name, + discoveryUrl: identityConfig.discoveryUrl!, + clientId: identityConfig.clientId!, + clientSecret: identityConfig.clientSecret!, + scopes: identityConfig.scopes + ?.split(',') + .map(s => s.trim()) + .filter(Boolean), + } + : { + type: 'ApiKeyCredentialProvider' as const, + name: identityConfig.name, + apiKey: identityConfig.apiKey, + }; + + void createIdentity(createConfig).then(result => { + if (result.ok && flow.name === 'creating-credential') { + const finalConfig: AddGatewayTargetConfig = { + ...flow.pendingConfig, + outboundAuth: { type: 'OAUTH', credentialName: result.result.name }, + }; + handleCreateComplete(finalConfig); + } else if (!result.ok) { + setFlow({ name: 'error', message: result.error }); + } + }); + }, + [flow, createIdentity, handleCreateComplete] + ); + // Create wizard if (flow.name === 'create-wizard') { return ( ); } + // Creating credential via identity screen + if (flow.name === 'creating-credential') { + return ( + setFlow({ name: 'create-wizard' })} + initialType="OAuthCredentialProvider" + /> + ); + } + // Create success if (flow.name === 'create-success') { return ( diff --git a/src/cli/tui/screens/mcp/AddGatewayTargetScreen.tsx b/src/cli/tui/screens/mcp/AddGatewayTargetScreen.tsx index 30ece187..d62f1088 100644 --- a/src/cli/tui/screens/mcp/AddGatewayTargetScreen.tsx +++ b/src/cli/tui/screens/mcp/AddGatewayTargetScreen.tsx @@ -1,10 +1,9 @@ import { ToolNameSchema } from '../../../../schema'; -import { ConfirmReview, Panel, Screen, SecretInput, StepIndicator, TextInput, WizardSelect } from '../../components'; +import { ConfirmReview, Panel, Screen, StepIndicator, TextInput, WizardSelect } from '../../components'; import type { SelectableItem } from '../../components'; import { HELP_TEXT } from '../../constants'; import { useListNavigation } from '../../hooks'; import { generateUniqueName } from '../../utils'; -import { useCreateIdentity, useExistingCredentialNames } from '../identity/useCreateIdentity.js'; import type { AddGatewayTargetConfig } from './types'; import { MCP_TOOL_STEP_LABELS, OUTBOUND_AUTH_OPTIONS, SKIP_FOR_NOW } from './types'; import { useAddGatewayTargetWizard } from './useAddGatewayTargetWizard'; @@ -14,28 +13,23 @@ import React, { useMemo, useState } from 'react'; interface AddGatewayTargetScreenProps { existingGateways: string[]; existingToolNames: string[]; + existingOAuthCredentialNames: string[]; onComplete: (config: AddGatewayTargetConfig) => void; + onCreateCredential: (pendingConfig: AddGatewayTargetConfig) => void; onExit: () => void; } export function AddGatewayTargetScreen({ existingGateways, existingToolNames, + existingOAuthCredentialNames, onComplete, + onCreateCredential, onExit, }: AddGatewayTargetScreenProps) { const wizard = useAddGatewayTargetWizard(existingGateways); - const { names: existingCredentialNames } = useExistingCredentialNames(); - const { createIdentity } = useCreateIdentity(); - // Outbound auth sub-step state const [outboundAuthType, setOutboundAuthTypeLocal] = useState(null); - const [credentialName, setCredentialNameLocal] = useState(null); - const [isCreatingCredential, setIsCreatingCredential] = useState(false); - const [oauthSubStep, setOauthSubStep] = useState<'name' | 'client-id' | 'client-secret' | 'discovery-url'>('name'); - const [oauthFields, setOauthFields] = useState({ name: '', clientId: '', clientSecret: '', discoveryUrl: '' }); - const [apiKeySubStep, setApiKeySubStep] = useState<'name' | 'api-key'>('name'); - const [apiKeyFields, setApiKeyFields] = useState({ name: '', apiKey: '' }); const gatewayItems: SelectableItem[] = useMemo( () => [ @@ -51,14 +45,14 @@ export function AddGatewayTargetScreen({ ); const credentialItems: SelectableItem[] = useMemo(() => { - const items: SelectableItem[] = [ - { id: 'create-new', title: 'Create new credential', description: 'Create a new credential inline' }, - ]; - existingCredentialNames.forEach(name => { - items.push({ id: name, title: name, description: 'Use existing credential' }); - }); + const items: SelectableItem[] = existingOAuthCredentialNames.map(name => ({ + id: name, + title: name, + description: 'Use existing OAuth credential', + })); + items.push({ id: 'create-new', title: 'Create new credential', description: 'Create a new OAuth credential' }); return items; - }, [existingCredentialNames]); + }, [existingOAuthCredentialNames]); const isGatewayStep = wizard.step === 'gateway'; const isOutboundAuthStep = wizard.step === 'outbound-auth'; @@ -76,10 +70,14 @@ export function AddGatewayTargetScreen({ const outboundAuthNav = useListNavigation({ items: outboundAuthItems, onSelect: item => { - const authType = item.id as 'OAUTH' | 'API_KEY' | 'NONE'; - setOutboundAuthTypeLocal(authType); + const authType = item.id as 'OAUTH' | 'NONE'; if (authType === 'NONE') { wizard.setOutboundAuth({ type: 'NONE' }); + } else if (existingOAuthCredentialNames.length === 0) { + // No existing OAuth credentials — go straight to creation + onCreateCredential(wizard.config); + } else { + setOutboundAuthTypeLocal(authType); } }, onExit: () => wizard.goBack(), @@ -90,28 +88,15 @@ export function AddGatewayTargetScreen({ items: credentialItems, onSelect: item => { if (item.id === 'create-new') { - setIsCreatingCredential(true); - if (outboundAuthType === 'OAUTH') { - setOauthSubStep('name'); - } else { - setApiKeySubStep('name'); - } + onCreateCredential(wizard.config); } else { - setCredentialNameLocal(item.id); - wizard.setOutboundAuth({ type: outboundAuthType as 'OAUTH' | 'API_KEY', credentialName: item.id }); + wizard.setOutboundAuth({ type: 'OAUTH', credentialName: item.id }); } }, onExit: () => { setOutboundAuthTypeLocal(null); - setCredentialNameLocal(null); - setIsCreatingCredential(false); }, - isActive: - isOutboundAuthStep && - !!outboundAuthType && - outboundAuthType !== 'NONE' && - !credentialName && - !isCreatingCredential, + isActive: isOutboundAuthStep && outboundAuthType === 'OAUTH', }); useListNavigation({ @@ -119,121 +104,14 @@ export function AddGatewayTargetScreen({ onSelect: () => onComplete(wizard.config), onExit: () => { setOutboundAuthTypeLocal(null); - setCredentialNameLocal(null); - setIsCreatingCredential(false); - setOauthSubStep('name'); - setOauthFields({ name: '', clientId: '', clientSecret: '', discoveryUrl: '' }); - setApiKeySubStep('name'); - setApiKeyFields({ name: '', apiKey: '' }); wizard.goBack(); }, isActive: isConfirmStep, }); - // OAuth creation handlers - const handleOauthFieldSubmit = (value: string) => { - const newFields = { ...oauthFields }; - - if (oauthSubStep === 'name') { - newFields.name = value; - setOauthFields(newFields); - setOauthSubStep('client-id'); - } else if (oauthSubStep === 'client-id') { - newFields.clientId = value; - setOauthFields(newFields); - setOauthSubStep('client-secret'); - } else if (oauthSubStep === 'client-secret') { - newFields.clientSecret = value; - setOauthFields(newFields); - setOauthSubStep('discovery-url'); - } else if (oauthSubStep === 'discovery-url') { - newFields.discoveryUrl = value; - setOauthFields(newFields); - - // Create the credential - void createIdentity({ - type: 'OAuthCredentialProvider', - name: newFields.name, - clientId: newFields.clientId, - clientSecret: newFields.clientSecret, - discoveryUrl: newFields.discoveryUrl, - }) - .then(result => { - if (result.ok) { - wizard.setOutboundAuth({ type: 'OAUTH', credentialName: newFields.name }); - } else { - setIsCreatingCredential(false); - setOauthSubStep('name'); - setOauthFields({ name: '', clientId: '', clientSecret: '', discoveryUrl: '' }); - } - }) - .catch(() => { - setIsCreatingCredential(false); - setOauthSubStep('name'); - setOauthFields({ name: '', clientId: '', clientSecret: '', discoveryUrl: '' }); - }); - } - }; - - const handleOauthFieldCancel = () => { - if (oauthSubStep === 'name') { - setIsCreatingCredential(false); - setOauthFields({ name: '', clientId: '', clientSecret: '', discoveryUrl: '' }); - } else if (oauthSubStep === 'client-id') { - setOauthSubStep('name'); - } else if (oauthSubStep === 'client-secret') { - setOauthSubStep('client-id'); - } else if (oauthSubStep === 'discovery-url') { - setOauthSubStep('client-secret'); - } - }; - - // API Key creation handlers - const handleApiKeyFieldSubmit = (value: string) => { - const newFields = { ...apiKeyFields }; - - if (apiKeySubStep === 'name') { - newFields.name = value; - setApiKeyFields(newFields); - setApiKeySubStep('api-key'); - } else if (apiKeySubStep === 'api-key') { - newFields.apiKey = value; - setApiKeyFields(newFields); - - void createIdentity({ - type: 'ApiKeyCredentialProvider', - name: newFields.name, - apiKey: newFields.apiKey, - }) - .then(result => { - if (result.ok) { - wizard.setOutboundAuth({ type: 'API_KEY', credentialName: newFields.name }); - } else { - setIsCreatingCredential(false); - setApiKeySubStep('name'); - setApiKeyFields({ name: '', apiKey: '' }); - } - }) - .catch(() => { - setIsCreatingCredential(false); - setApiKeySubStep('name'); - setApiKeyFields({ name: '', apiKey: '' }); - }); - } - }; - - const handleApiKeyFieldCancel = () => { - if (apiKeySubStep === 'name') { - setIsCreatingCredential(false); - setApiKeyFields({ name: '', apiKey: '' }); - } else if (apiKeySubStep === 'api-key') { - setApiKeySubStep('name'); - } - }; - const helpText = isConfirmStep ? HELP_TEXT.CONFIRM_CANCEL - : isTextStep || isCreatingCredential + : isTextStep ? HELP_TEXT.TEXT_INPUT : HELP_TEXT.NAVIGATE_SELECT; @@ -262,96 +140,13 @@ export function AddGatewayTargetScreen({ /> )} - {isOutboundAuthStep && - outboundAuthType && - outboundAuthType !== 'NONE' && - !credentialName && - !isCreatingCredential && ( - - )} - - {isOutboundAuthStep && isCreatingCredential && outboundAuthType === 'OAUTH' && ( - <> - {oauthSubStep === 'name' && ( - !existingCredentialNames.includes(value) || 'Credential name already exists'} - /> - )} - {oauthSubStep === 'client-id' && ( - value.trim().length > 0 || 'Client ID is required'} - /> - )} - {oauthSubStep === 'client-secret' && ( - value.trim().length > 0 || 'Client secret is required'} - revealChars={4} - /> - )} - {oauthSubStep === 'discovery-url' && ( - { - try { - const url = new URL(value); - if (url.protocol !== 'http:' && url.protocol !== 'https:') { - return 'Discovery URL must use http:// or https:// protocol'; - } - return true; - } catch { - return 'Must be a valid URL'; - } - }} - /> - )} - - )} - - {isOutboundAuthStep && isCreatingCredential && outboundAuthType === 'API_KEY' && ( - <> - {apiKeySubStep === 'name' && ( - !existingCredentialNames.includes(value) || 'Credential name already exists'} - /> - )} - {apiKeySubStep === 'api-key' && ( - value.trim().length > 0 || 'API key is required'} - revealChars={4} - /> - )} - + {isOutboundAuthStep && outboundAuthType === 'OAUTH' && ( + )} {isTextStep && ( diff --git a/src/cli/tui/screens/mcp/types.ts b/src/cli/tui/screens/mcp/types.ts index fcf7d593..f24aeed5 100644 --- a/src/cli/tui/screens/mcp/types.ts +++ b/src/cli/tui/screens/mcp/types.ts @@ -16,6 +16,9 @@ export interface AddGatewayConfig { discoveryUrl: string; allowedAudience: string[]; allowedClients: string[]; + allowedScopes?: string[]; + agentClientId?: string; + agentClientSecret?: string; }; /** Selected unassigned targets to include in this gateway */ selectedTargets?: string[]; diff --git a/src/cli/tui/screens/mcp/useAddGatewayWizard.ts b/src/cli/tui/screens/mcp/useAddGatewayWizard.ts index 2bd24b75..90265bca 100644 --- a/src/cli/tui/screens/mcp/useAddGatewayWizard.ts +++ b/src/cli/tui/screens/mcp/useAddGatewayWizard.ts @@ -68,7 +68,14 @@ export function useAddGatewayWizard(unassignedTargetsCount = 0) { }, []); const setJwtConfig = useCallback( - (jwtConfig: { discoveryUrl: string; allowedAudience: string[]; allowedClients: string[] }) => { + (jwtConfig: { + discoveryUrl: string; + allowedAudience: string[]; + allowedClients: string[]; + allowedScopes?: string[]; + agentClientId?: string; + agentClientSecret?: string; + }) => { setConfig(c => ({ ...c, jwtConfig, diff --git a/src/cli/tui/screens/remove/RemoveScreen.tsx b/src/cli/tui/screens/remove/RemoveScreen.tsx index f64ddc8b..bcb7307c 100644 --- a/src/cli/tui/screens/remove/RemoveScreen.tsx +++ b/src/cli/tui/screens/remove/RemoveScreen.tsx @@ -6,7 +6,7 @@ const REMOVE_RESOURCES = [ { id: 'agent', title: 'Agent', description: 'Remove an agent from the project' }, { id: 'memory', title: 'Memory', description: 'Remove a memory provider' }, { id: 'identity', title: 'Identity', description: 'Remove an identity provider' }, - { id: 'gateway', title: 'Gateway', description: 'Remove an MCP gateway' }, + { id: 'gateway', title: 'Gateway', description: 'Remove a gateway' }, { id: 'gateway-target', title: 'Gateway Target', description: 'Remove a gateway target' }, { id: 'all', title: 'All', description: 'Reset entire agentcore project' }, ] as const; diff --git a/src/cli/tui/screens/schema/EditSchemaScreen.tsx b/src/cli/tui/screens/schema/EditSchemaScreen.tsx index 42846e09..1721f5f8 100644 --- a/src/cli/tui/screens/schema/EditSchemaScreen.tsx +++ b/src/cli/tui/screens/schema/EditSchemaScreen.tsx @@ -44,7 +44,7 @@ export function EditSchemaScreen(props: EditSchemaScreenProps) { { id: 'mcp', title: 'mcp.json', - description: `MCP gateways and tools${mcpMissing}`, + description: `Gateways and tools${mcpMissing}`, filePath: mcpPath, schema: AgentCoreMcpSpecSchema, }, diff --git a/src/cli/tui/screens/schema/McpGuidedEditor.tsx b/src/cli/tui/screens/schema/McpGuidedEditor.tsx index 30760bd0..28534403 100644 --- a/src/cli/tui/screens/schema/McpGuidedEditor.tsx +++ b/src/cli/tui/screens/schema/McpGuidedEditor.tsx @@ -631,7 +631,7 @@ function McpEditorBody(props: { - + {gateways.length === 0 ? ( No gateways configured. Press A to add one. ) : ( diff --git a/src/schema/schemas/agentcore-project.ts b/src/schema/schemas/agentcore-project.ts index 13f8241f..fda34160 100644 --- a/src/schema/schemas/agentcore-project.ts +++ b/src/schema/schemas/agentcore-project.ts @@ -101,6 +101,8 @@ export const OAuthCredentialSchema = z.object({ vendor: z.string().default('CustomOauth2'), /** Whether this credential was auto-created by the CLI (e.g., for CUSTOM_JWT inbound auth) */ managed: z.boolean().optional(), + /** Whether this credential is used for inbound or outbound auth */ + usage: z.enum(['inbound', 'outbound']).optional(), }); export type OAuthCredential = z.infer;