Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,41 @@ logger = logging.getLogger(__name__)
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
"""Returns MCP Toolsets for all configured gateways."""
Expand All @@ -1740,6 +1775,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
url=url,
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
)))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
{{else}}
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
{{/if}}
Expand Down Expand Up @@ -2012,6 +2051,41 @@ logger = logging.getLogger(__name__)
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
"""Returns an MCP Client connected to all configured gateways."""
Expand All @@ -2023,6 +2097,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
session = create_aws_session()
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
{{else}}
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
{{/if}}
Expand Down Expand Up @@ -2438,6 +2516,41 @@ logger = logging.getLogger(__name__)
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
"""Returns MCP servers for all configured gateways."""
Expand All @@ -2452,6 +2565,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
name="{{name}}",
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
{{else}}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
{{/if}}
Expand Down Expand Up @@ -2749,7 +2866,41 @@ logger = logging.getLogger(__name__)
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}
{{#each gatewayProviders}}
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
"""Returns an MCP Client connected to the {{name}} gateway."""
Expand All @@ -2759,6 +2910,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
return None
{{#if (eq authType "AWS_IAM")}}
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
{{else}}
return MCPClient(lambda: streamablehttp_client(url))
{{/if}}
Expand Down
39 changes: 39 additions & 0 deletions src/assets/python/googleadk/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,41 @@
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
"""Returns MCP Toolsets for all configured gateways."""
Expand All @@ -24,6 +59,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
url=url,
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
)))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
{{else}}
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
{{/if}}
Expand Down
39 changes: 39 additions & 0 deletions src/assets/python/langchain_langgraph/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,41 @@
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
"""Returns an MCP Client connected to all configured gateways."""
Expand All @@ -19,6 +54,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
session = create_aws_session()
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
{{else}}
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
{{/if}}
Expand Down
39 changes: 39 additions & 0 deletions src/assets/python/openaiagents/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,41 @@
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
{{#unless (includes gatewayAuthTypes "AWS_IAM")}}import httpx
{{/unless}}import time as _time
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
_token_cache_{{snakeCase name}} = {"token": None, "expires_at": 0}

def _get_bearer_token_{{snakeCase name}}():
"""Obtain OAuth access token via client_credentials grant for {{name}}."""
cache = _token_cache_{{snakeCase name}}
if cache["token"] and _time.time() < cache["expires_at"]:
return cache["token"]
client_id = os.environ.get("{{credentialEnvVarBase}}_CLIENT_ID")
client_secret = os.environ.get("{{credentialEnvVarBase}}_CLIENT_SECRET")
if not client_id or not client_secret:
logger.warning("Agent OAuth credentials not set — {{name}} CUSTOM_JWT auth unavailable")
return None
with httpx.Client() as c:
disc = c.get("{{discoveryUrl}}")
token_ep = disc.json()["token_endpoint"]
resp = c.post(token_ep, data={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
{{#if scopes}}"scope": "{{scopes}}",{{/if}}
})
data = resp.json()
cache["token"] = data["access_token"]
cache["expires_at"] = _time.time() + data.get("expires_in", 3600) - 60
return cache["token"]

{{/if}}
{{/each}}

def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
"""Returns MCP servers for all configured gateways."""
Expand All @@ -23,6 +58,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
name="{{name}}",
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
{{else}}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
{{/if}}
Expand Down
Loading
Loading