diff --git a/.agents/skills/strandsagents/SKILL.md b/.agents/skills/strandsagents/SKILL.md new file mode 100644 index 0000000..2f2f3de --- /dev/null +++ b/.agents/skills/strandsagents/SKILL.md @@ -0,0 +1,502 @@ +--- +name: strandsagents +description: Comprehensive guide for Strands Agents SDK development with Python. Use when building AI agents with Strands, including OpenAI integration, custom tools, multi-agent systems, security, Docker deployment, testing, observability, and MCP server integration. Covers installation, configuration, best practices, and production deployment patterns. +--- + +# Strands Agents SDK + +Strands Agents is a lightweight, model-driven Python SDK for building AI agents. It scales from simple conversational assistants to complex autonomous workflows with support for multiple LLM providers, custom tools, multi-agent systems, and production deployment. + +## When to Use This Skill + +Use this skill when: +- Building AI agents with Strands SDK +- Integrating OpenAI or other LLM providers +- Creating custom tools for agents +- Implementing multi-agent systems +- Deploying agents with Docker +- Setting up security and observability +- Integrating MCP servers +- Testing agent applications + +## Quick Start + +See [references/quickstart.md](references/quickstart.md) for installation and basic usage. 
+ +```python +from strands import Agent +from strands.models import OpenAIModel + +# Create agent with OpenAI +model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": "your-api-key"} +) +agent = Agent(model=model) +response = agent("Hello, world!") +``` + +## Core Concepts + +### Agent + +The `Agent` class is the core component that orchestrates conversations with LLMs: + +```python +from strands import Agent + +agent = Agent( + model=model, # LLM model provider + system_prompt="...", # System instructions + tools=[...], # Available tools + messages=[...], # Initial conversation history + state={...}, # Persistent state + hooks=[...], # Lifecycle hooks + conversation_manager=..., # Context window management + trace_attributes={...} # OpenTelemetry attributes +) +``` + +### Models + +Strands supports multiple LLM providers: +- **OpenAI**: `OpenAIModel` - See [references/openai.md](references/openai.md) +- **Anthropic**: `AnthropicModel` +- **Amazon Bedrock**: `BedrockModel` (default) +- **Google Gemini**: `GeminiModel` +- **Ollama**: `OllamaModel` (local) +- **LiteLLM**: `LiteLLMModel` (100+ providers) + +### Tools + +Custom tools extend agent capabilities. See [references/tools.md](references/tools.md) for complete guide. + +```python +from strands import Agent, tool + +@tool +def get_weather(city: str) -> dict: + """Get weather for a city.""" + return {"status": "success", "content": [{"text": f"Sunny in {city}"}]} + +agent = Agent(tools=[get_weather]) +``` + +## Key Features + +### 1. OpenAI Integration + +See [references/openai.md](references/openai.md) for: +- API key configuration +- Model selection (gpt-4o, gpt-4o-mini, etc.) +- Parameters (temperature, max_tokens, etc.) +- Streaming responses +- Error handling +- Best practices + +### 2. 
Custom Tools + +See [references/tools.md](references/tools.md) for: +- Creating tools with `@tool` decorator +- Tool response format +- Context-aware tools +- Async tools +- Dynamic tool loading +- Built-in tools +- Best practices + +### 3. Security + +See [references/security.md](references/security.md) for: +- API key management (environment variables, vaults) +- Input validation +- Guardrails (AWS Bedrock, custom filters) +- Rate limiting +- Secure tool execution +- Logging best practices +- Vulnerability reporting + +### 4. Docker Deployment + +See [references/docker.md](references/docker.md) for: +- Dockerfile examples (basic, multi-stage, production) +- Docker Compose configurations +- Environment variables +- Secrets management +- Health checks +- Monitoring +- Resource limits +- Production checklist + +### 5. Multi-Agent Systems + +See [references/multi-agent.md](references/multi-agent.md) for: +- Agent as tool pattern +- Sequential workflows +- Parallel workflows +- Swarm pattern +- Hierarchical agents +- State sharing +- Communication patterns +- Best practices + +### 6. Observability + +See [references/observability.md](references/observability.md) for: +- OpenTelemetry integration +- Metrics (tokens, latency, etc.) +- Custom tracing +- Structured logging +- Hooks for monitoring +- Prometheus/CloudWatch integration +- Debugging techniques + +### 7. Testing + +See [references/testing.md](references/testing.md) for: +- Development environment setup +- Unit testing patterns +- Integration testing +- Testing hooks +- Mock models +- Async testing +- Code quality tools +- CI/CD integration +- Coverage reporting + +### 8. MCP Integration + +See [references/mcp.md](references/mcp.md) for: +- Connecting to MCP servers +- Tool filtering and prefixes +- Multiple servers +- MCP prompts and resources +- Popular MCP servers +- Error handling +- Best practices +- Custom MCP servers + +## Development Workflow + +### 1. 
Setup Environment + +```bash +# Create virtual environment +python -m venv .venv +source .venv/bin/activate + +# Install Strands +pip install strands-agents strands-agents-tools + +# For development +hatch shell +pre-commit install -t pre-commit -t commit-msg +``` + +### 2. Create Agent + +```python +from strands import Agent +from strands.models import OpenAIModel +import os + +# Configure model +model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": os.environ.get("OPENAI_API_KEY")} +) + +# Create agent +agent = Agent( + model=model, + system_prompt="You are a helpful assistant." +) +``` + +### 3. Add Tools + +```python +from strands import tool + +@tool +def my_tool(param: str) -> dict: + """Tool description for LLM. + + Args: + param: Parameter description. + """ + return {"status": "success", "content": [{"text": f"Result: {param}"}]} + +agent = Agent(model=model, tools=[my_tool]) +``` + +### 4. Test + +```bash +# Run tests +hatch test + +# With coverage +hatch test -c + +# Integration tests +hatch run test-integ +``` + +### 5. Deploy + +See [references/docker.md](references/docker.md) for production deployment. 
+ +## Best Practices + +### Security +- Never hardcode API keys - use environment variables +- Validate all tool inputs +- Implement rate limiting +- Use guardrails for content filtering +- Log securely (mask sensitive data) +- See [references/security.md](references/security.md) + +### Tools +- Use clear docstrings (Google-style) +- Include type hints +- Return structured responses +- Handle errors gracefully +- Design for idempotency +- See [references/tools.md](references/tools.md) + +### Multi-Agent +- Give agents clear responsibilities +- Avoid circular dependencies +- Limit delegation depth +- Handle failures gracefully +- Monitor performance +- See [references/multi-agent.md](references/multi-agent.md) + +### Observability +- Use structured logging +- Set trace attributes +- Monitor token usage +- Implement health checks +- Export to monitoring systems +- See [references/observability.md](references/observability.md) + +### Testing +- Write unit tests for tools +- Test with real models (integration) +- Mock external dependencies +- Test error handling +- Maintain high coverage +- See [references/testing.md](references/testing.md) + +### Deployment +- Use multi-stage Docker builds +- Run as non-root user +- Set resource limits +- Configure health checks +- Use secrets management +- See [references/docker.md](references/docker.md) + +## Common Patterns + +### Conversational Agent + +```python +from strands import Agent +from strands.models import OpenAIModel + +model = OpenAIModel(model_id="gpt-4o") +agent = Agent(model=model, system_prompt="You are a helpful assistant.") + +while True: + user_input = input("You: ") + if user_input.lower() == "exit": + break + response = agent(user_input) + print(f"Agent: {response}") +``` + +### Agent with Tools + +```python +from strands import Agent, tool +from strands.models import OpenAIModel + +@tool +def calculate(expression: str) -> dict: + """Evaluate mathematical expression.""" + result = eval(expression) + return 
{"status": "success", "content": [{"text": str(result)}]} + +model = OpenAIModel(model_id="gpt-4o") +agent = Agent(model=model, tools=[calculate]) +response = agent("What is 15 * 23?") +``` + +### Streaming Agent + +```python +from strands import Agent +from strands.models import OpenAIModel +import asyncio + +async def stream(): + model = OpenAIModel(model_id="gpt-4o") + agent = Agent(model=model) + + async for event in agent.stream_async("Tell me a story"): + if "data" in event: + print(event["data"], end="", flush=True) + +asyncio.run(stream()) +``` + +### Multi-Agent System + +```python +from strands import Agent, tool +from strands.models import OpenAIModel + +# Create specialist +specialist = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a research specialist." +) + +# Wrap as tool +@tool +def research(query: str) -> dict: + """Research a topic.""" + result = specialist(query) + return {"status": "success", "content": [{"text": str(result)}]} + +# Create coordinator +coordinator = Agent( + model=OpenAIModel(model_id="gpt-4o"), + tools=[research] +) + +response = coordinator("Research AI trends") +``` + +## Troubleshooting + +### API Key Issues +- Check environment variable: `echo $OPENAI_API_KEY` +- Verify key is valid and has credits +- Use `.env` file for local development + +### Tool Not Called +- Ensure clear docstring describing tool purpose +- Check tool name doesn't conflict +- Verify tool returns correct format + +### Memory Issues +- Use conversation manager to limit context +- Implement sliding window or summarization +- Clear messages periodically + +### Performance Issues +- Monitor token usage via `result.metrics` +- Use smaller models for simple tasks +- Implement caching for repeated queries +- Use async tools for I/O operations + +### Connection Errors +- Check network connectivity +- Verify API endpoint is accessible +- Implement retry logic with backoff + +## Resources + +### Documentation +- [Strands 
GitHub](https://github.com/strands-agents/sdk-python) +- [OpenAI API Docs](https://platform.openai.com/docs) +- [MCP Documentation](https://modelcontextprotocol.io/) + +### Reference Files +- [quickstart.md](references/quickstart.md) - Installation and basic usage +- [openai.md](references/openai.md) - OpenAI integration +- [tools.md](references/tools.md) - Custom tools +- [security.md](references/security.md) - Security best practices +- [docker.md](references/docker.md) - Docker deployment +- [multi-agent.md](references/multi-agent.md) - Multi-agent systems +- [observability.md](references/observability.md) - Monitoring and telemetry +- [testing.md](references/testing.md) - Testing strategies +- [mcp.md](references/mcp.md) - MCP server integration + +## Examples + +### Production Agent + +```python +import os +import logging +from strands import Agent, tool +from strands.models import OpenAIModel +from strands.telemetry import StrandsTelemetry +from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Setup telemetry +StrandsTelemetry().setup_console_exporter().setup_otlp_exporter() + +# Metrics hook +class MetricsHooks(HookProvider): + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AfterInvocationEvent, self.log_metrics) + + def log_metrics(self, event: AfterInvocationEvent) -> None: + if event.result: + logger.info( + "tokens_in=<%d>, tokens_out=<%d>, latency_ms=<%d> | invocation completed", + event.result.metrics.input_tokens, + event.result.metrics.output_tokens, + event.result.metrics.total_time + ) + +# Custom tool +@tool +def process_data(data: str) -> dict: + """Process data with validation.""" + if not data or len(data) > 1000: + return {"status": "error", "content": [{"text": "Invalid data"}]} + + # Process + result = data.upper() + return {"status": "success", "content": [{"text": result}]} 
+ +# Create production agent +def create_agent(): + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY not set") + + model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": api_key}, + params={"temperature": 0.7, "max_tokens": 2048} + ) + + return Agent( + model=model, + tools=[process_data], + hooks=[MetricsHooks()], + trace_attributes={ + "environment": "production", + "version": "1.0.0" + } + ) + +if __name__ == "__main__": + agent = create_agent() + response = agent("Process this data: hello world") + print(response) +``` + +## Version Information + +This skill is based on Strands Agents SDK Python documentation (latest as of January 2026). Check the [GitHub repository](https://github.com/strands-agents/sdk-python) for updates. diff --git a/.agents/skills/strandsagents/references/docker.md b/.agents/skills/strandsagents/references/docker.md new file mode 100644 index 0000000..f050cea --- /dev/null +++ b/.agents/skills/strandsagents/references/docker.md @@ -0,0 +1,434 @@ +# Docker Deployment + +## Basic Dockerfile + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application +COPY . . + +# Run agent +CMD ["python", "agent.py"] +``` + +## Multi-Stage Build + +Optimize image size: + +```dockerfile +# Build stage +FROM python:3.11-slim as builder + +WORKDIR /app + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir --user -r requirements.txt + +# Runtime stage +FROM python:3.11-slim + +WORKDIR /app + +# Copy Python dependencies from builder +COPY --from=builder /root/.local /root/.local + +# Copy application +COPY . . 
+ +# Update PATH +ENV PATH=/root/.local/bin:$PATH + +# Run agent +CMD ["python", "agent.py"] +``` + +## Production Dockerfile + +With security and optimization: + +```dockerfile +FROM python:3.11-slim + +# Create non-root user +RUN useradd -m -u 1000 agent && \ + mkdir -p /app && \ + chown -R agent:agent /app + +WORKDIR /app + +# Install dependencies as root +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt && \ + rm -rf /root/.cache + +# Copy application +COPY --chown=agent:agent . . + +# Switch to non-root user +USER agent + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD python -c "import sys; sys.exit(0)" + +# Run agent +CMD ["python", "agent.py"] +``` + +## Docker Compose + +### Basic Setup + +```yaml +version: '3.8' + +services: + agent: + build: . + environment: + - OPENAI_API_KEY=${OPENAI_API_KEY} + volumes: + - ./data:/app/data + restart: unless-stopped +``` + +### With Environment File + +```yaml +version: '3.8' + +services: + agent: + build: . + env_file: + - .env + volumes: + - ./data:/app/data + - ./logs:/app/logs + restart: unless-stopped + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" +``` + +### Multi-Agent Setup + +```yaml +version: '3.8' + +services: + coordinator: + build: . + environment: + - AGENT_ROLE=coordinator + - OPENAI_API_KEY=${OPENAI_API_KEY} + depends_on: + - worker1 + - worker2 + networks: + - agent-network + + worker1: + build: . + environment: + - AGENT_ROLE=worker + - WORKER_ID=1 + - OPENAI_API_KEY=${OPENAI_API_KEY} + networks: + - agent-network + + worker2: + build: . 
+ environment: + - AGENT_ROLE=worker + - WORKER_ID=2 + - OPENAI_API_KEY=${OPENAI_API_KEY} + networks: + - agent-network + +networks: + agent-network: + driver: bridge +``` + +## Example Application + +### agent.py + +```python +import os +from strands import Agent +from strands.models import OpenAIModel + +def main(): + # Get configuration from environment + openai_key = os.environ.get("OPENAI_API_KEY") + if not openai_key: + raise ValueError("OPENAI_API_KEY not set") + + # Create agent + model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": openai_key} + ) + agent = Agent( + model=model, + system_prompt="You are a helpful assistant." + ) + + # Run agent loop + print("Agent started. Type 'exit' to quit.") + while True: + user_input = input("You: ") + if user_input.lower() == "exit": + break + + response = agent(user_input) + print(f"Agent: {response}") + +if __name__ == "__main__": + main() +``` + +### requirements.txt + +``` +strands-agents==0.1.0 +strands-agents-tools==0.1.0 +python-dotenv==1.0.0 +``` + +## Building and Running + +### Build Image + +```bash +docker build -t my-agent . +``` + +### Run Container + +```bash +docker run -e OPENAI_API_KEY=$OPENAI_API_KEY my-agent +``` + +### With Docker Compose + +```bash +# Build and start +docker-compose up --build + +# Run in background +docker-compose up -d + +# View logs +docker-compose logs -f + +# Stop +docker-compose down +``` + +## Environment Variables + +### .env File + +```bash +# API Keys +OPENAI_API_KEY=sk-proj-abc123... +ANTHROPIC_API_KEY=sk-ant-abc123... + +# Agent Configuration +AGENT_NAME=MyAgent +AGENT_ROLE=coordinator +LOG_LEVEL=INFO + +# Model Configuration +MODEL_ID=gpt-4o +TEMPERATURE=0.7 +MAX_TOKENS=2048 + +# OpenTelemetry +OTEL_SERVICE_NAME=my-agent-service +OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 +``` + +## Secrets Management + +### Docker Secrets + +```yaml +version: '3.8' + +services: + agent: + build: . 
+ secrets: + - openai_api_key + environment: + - OPENAI_API_KEY_FILE=/run/secrets/openai_api_key + +secrets: + openai_api_key: + file: ./secrets/openai_api_key.txt +``` + +Read secret in Python: + +```python +import os + +def get_secret(secret_name): + secret_file = os.environ.get(f"{secret_name.upper()}_FILE") + if secret_file: + with open(secret_file) as f: + return f.read().strip() + return os.environ.get(secret_name.upper()) + +openai_key = get_secret("openai_api_key") +``` + +## Monitoring + +### Health Check + +```python +# health.py +from strands import Agent +from strands.models import OpenAIModel +import os +import sys + +def health_check(): + try: + openai_key = os.environ.get("OPENAI_API_KEY") + if not openai_key: + return False + + model = OpenAIModel(model_id="gpt-4o", client_args={"api_key": openai_key}) + agent = Agent(model=model) + + # Simple test + response = agent("ping") + return bool(response) + except Exception: + return False + +if __name__ == "__main__": + sys.exit(0 if health_check() else 1) +``` + +### Dockerfile with Health Check + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python health.py + +CMD ["python", "agent.py"] +``` + +## Logging + +### Configure Logging + +```python +import logging +import sys + +# Configure logging for container +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) # Log to stdout for Docker + ] +) + +logger = logging.getLogger(__name__) +``` + +### Docker Compose Logging + +```yaml +version: '3.8' + +services: + agent: + build: . + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "5" + labels: "agent,production" +``` + +## Best Practices + +1. **Use Multi-Stage Builds**: Reduce image size +2. 
**Non-Root User**: Run as non-root for security +3. **Health Checks**: Implement health check endpoints +4. **Environment Variables**: Use for configuration +5. **Secrets Management**: Never hardcode secrets +6. **Logging**: Log to stdout/stderr for Docker +7. **Resource Limits**: Set memory and CPU limits +8. **Restart Policy**: Use `unless-stopped` or `on-failure` +9. **Volume Mounts**: Persist data outside container +10. **Network Isolation**: Use Docker networks + +## Resource Limits + +```yaml +version: '3.8' + +services: + agent: + build: . + deploy: + resources: + limits: + cpus: '2' + memory: 4G + reservations: + cpus: '1' + memory: 2G + restart: unless-stopped +``` + +## Production Checklist + +- [ ] Multi-stage build for smaller images +- [ ] Non-root user configured +- [ ] Health checks implemented +- [ ] Secrets via environment or Docker secrets +- [ ] Logging to stdout/stderr +- [ ] Resource limits set +- [ ] Restart policy configured +- [ ] Volumes for persistent data +- [ ] Network isolation configured +- [ ] Security scanning (e.g., Trivy) +- [ ] Image tagged with version +- [ ] Documentation updated diff --git a/.agents/skills/strandsagents/references/mcp.md b/.agents/skills/strandsagents/references/mcp.md new file mode 100644 index 0000000..978c89a --- /dev/null +++ b/.agents/skills/strandsagents/references/mcp.md @@ -0,0 +1,430 @@ +# Model Context Protocol (MCP) Integration + +## Overview + +MCP enables integration with thousands of pre-built tools from MCP servers. Strands provides seamless integration via `MCPClient`. 
+ +## Basic Usage + +### Connect to MCP Server + +```python +from strands import Agent +from strands.tools.mcp import MCPClient +from mcp import stdio_client, StdioServerParameters + +# Connect to AWS documentation server +aws_docs_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )) +) + +with aws_docs_client: + agent = Agent(tools=aws_docs_client.list_tools_sync()) + response = agent("Tell me about Amazon Bedrock and how to use it with Python") +``` + +## MCPClient Configuration + +### Startup Timeout + +```python +aws_docs_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )), + startup_timeout=30 # Timeout for server initialization +) +``` + +### Tool Filtering + +```python +# Only include specific tools +filesystem_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@anthropic/mcp-server-filesystem", "/tmp"] + )), + tool_filters={ + "allowed": ["read_file", "list_directory"], # Only these tools + "rejected": ["delete_file"] # Exclude these + }, + prefix="fs" # Prefix tool names: fs_read_file, fs_list_directory +) +``` + +### Tool Prefix + +```python +# Add prefix to avoid name conflicts +client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@anthropic/mcp-server-filesystem", "/tmp"] + )), + prefix="fs" # Tools become: fs_read_file, fs_list_directory, etc. 
+) +``` + +## Multiple MCP Servers + +Combine tools from multiple servers: + +```python +from strands import Agent +from strands.tools.mcp import MCPClient +from mcp import stdio_client, StdioServerParameters + +# AWS documentation server +aws_docs_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )) +) + +# Filesystem server +filesystem_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@anthropic/mcp-server-filesystem", "/tmp"] + )), + prefix="fs" +) + +# Use both servers +with aws_docs_client, filesystem_client: + all_tools = aws_docs_client.list_tools_sync() + filesystem_client.list_tools_sync() + agent = Agent(tools=all_tools) + response = agent("Read the README file and summarize it") +``` + +## MCP Prompts + +Access pre-defined prompts from MCP servers: + +```python +from strands.tools.mcp import MCPClient +from mcp import stdio_client, StdioServerParameters + +client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )) +) + +with client: + # List available prompts + prompts = client.list_prompts_sync() + print(f"Available prompts: {[p.name for p in prompts]}") + + # Get a specific prompt + prompt_result = client.get_prompt_sync("my-prompt", {"arg1": "value"}) + print(f"Prompt: {prompt_result}") +``` + +## MCP Resources + +Access resources from MCP servers: + +```python +from strands.tools.mcp import MCPClient +from mcp import stdio_client, StdioServerParameters + +client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@anthropic/mcp-server-filesystem", "/tmp"] + )) +) + +with client: + # List available resources + resources = client.list_resources_sync() + print(f"Available resources: {[r.uri for r in resources]}") + + # Read a specific resource + content = client.read_resource_sync("file:///path/to/resource") 
+ print(f"Resource content: {content}") +``` + +## Popular MCP Servers + +### AWS Documentation + +```python +aws_docs_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )) +) +``` + +### Filesystem + +```python +filesystem_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@anthropic/mcp-server-filesystem", "/path/to/directory"] + )) +) +``` + +### GitHub + +```python +github_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@modelcontextprotocol/server-github"] + )) +) +``` + +### Brave Search + +```python +brave_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="npx", + args=["@modelcontextprotocol/server-brave-search"] + )) +) +``` + +## Error Handling + +### Connection Errors + +```python +from strands.tools.mcp import MCPClient +from mcp import stdio_client, StdioServerParameters + +try: + client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="uvx", + args=["awslabs.aws-documentation-mcp-server@latest"] + )), + startup_timeout=30 + ) + + with client: + tools = client.list_tools_sync() + print(f"Connected: {len(tools)} tools available") +except TimeoutError: + print("MCP server startup timeout") +except Exception as e: + print(f"Connection error: {e}") +``` + +### Tool Execution Errors + +```python +from strands import Agent +from strands.tools.mcp import MCPClient + +client = MCPClient(...) + +with client: + agent = Agent(tools=client.list_tools_sync()) + + try: + response = agent("Execute risky operation") + except RuntimeError as e: + if "Connection to the MCP server was closed" in str(e): + print("MCP server connection lost") + else: + raise +``` + +## Best Practices + +### 1. 
Use Context Managers + +Always use `with` statement for automatic cleanup: + +```python +# ✅ Good - Automatic cleanup +with mcp_client: + agent = Agent(tools=mcp_client.list_tools_sync()) + response = agent("Use MCP tools") + +# ❌ Bad - Manual cleanup required +mcp_client = MCPClient(...) +agent = Agent(tools=mcp_client.list_tools_sync()) +# Cleanup not guaranteed +``` + +### 2. Filter Tools + +Only expose necessary tools to reduce context: + +```python +# ✅ Good - Only necessary tools +client = MCPClient( + ..., + tool_filters={"allowed": ["read_file", "list_directory"]} +) + +# ❌ Bad - All tools exposed +client = MCPClient(...) # Exposes all tools including dangerous ones +``` + +### 3. Use Prefixes + +Avoid name conflicts with prefixes: + +```python +# ✅ Good - Prefixed tools +fs_client = MCPClient(..., prefix="fs") +db_client = MCPClient(..., prefix="db") + +# Both have "list" tool, but become: fs_list, db_list +``` + +### 4. Set Timeouts + +Configure appropriate timeouts: + +```python +# ✅ Good - Reasonable timeout +client = MCPClient(..., startup_timeout=30) + +# ❌ Bad - No timeout (may hang) +client = MCPClient(...) +``` + +### 5. 
Handle Errors + +Implement error handling for robustness: + +```python +try: + with mcp_client: + agent = Agent(tools=mcp_client.list_tools_sync()) + response = agent("Task") +except TimeoutError: + print("MCP server timeout") +except RuntimeError as e: + print(f"MCP error: {e}") +``` + +## Advanced Usage + +### Custom MCP Server + +Create your own MCP server: + +```python +# server.py +from mcp.server import Server +from mcp.types import Tool + +server = Server("my-custom-server") + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="custom_tool", + description="A custom tool", + inputSchema={ + "type": "object", + "properties": { + "param": {"type": "string"} + } + } + ) + ] + +@server.call_tool() +async def call_tool(name: str, arguments: dict): + if name == "custom_tool": + return {"result": f"Processed: {arguments['param']}"} + +if __name__ == "__main__": + server.run() +``` + +Connect to custom server: + +```python +custom_client = MCPClient( + lambda: stdio_client(StdioServerParameters( + command="python", + args=["server.py"] + )) +) +``` + +### Async MCP Operations + +```python +import asyncio +from strands.tools.mcp import MCPClient + +async def async_mcp(): + client = MCPClient(...) 
+ + async with client: + tools = await client.list_tools() + prompts = await client.list_prompts() + resources = await client.list_resources() + + print(f"Tools: {len(tools)}") + print(f"Prompts: {len(prompts)}") + print(f"Resources: {len(resources)}") + +asyncio.run(async_mcp()) +``` + +## Troubleshooting + +### Server Not Starting + +```bash +# Check if command is available +which uvx +which npx + +# Test server manually +uvx awslabs.aws-documentation-mcp-server@latest + +# Check logs +python -c "import logging; logging.basicConfig(level=logging.DEBUG)" +``` + +### Connection Hanging + +```python +# Increase timeout +client = MCPClient(..., startup_timeout=60) + +# Check for error messages +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +### Tool Not Found + +```python +# List all available tools +with client: + tools = client.list_tools_sync() + print(f"Available tools: {[t.tool_name for t in tools]}") +``` + +## MCP Resources + +- [MCP Documentation](https://modelcontextprotocol.io/) +- [MCP Servers Registry](https://github.com/modelcontextprotocol/servers) +- [Creating MCP Servers](https://modelcontextprotocol.io/docs/creating-servers) +- [MCP Python SDK](https://github.com/modelcontextprotocol/python-sdk) diff --git a/.agents/skills/strandsagents/references/multi-agent.md b/.agents/skills/strandsagents/references/multi-agent.md new file mode 100644 index 0000000..0d354c2 --- /dev/null +++ b/.agents/skills/strandsagents/references/multi-agent.md @@ -0,0 +1,405 @@ +# Multi-Agent Systems + +## Agent as Tool + +Use one agent as a tool for another: + +```python +from strands import Agent, tool +from strands.models import OpenAIModel + +# Create specialist agent +research_agent = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a research specialist. Provide detailed, factual information." +) + +# Wrap agent as tool +@tool +def research(query: str) -> dict: + """Research a topic using the research agent. 
+ + Args: + query: The research question or topic. + """ + result = research_agent(query) + return {"status": "success", "content": [{"text": str(result)}]} + +# Create coordinator agent +coordinator = Agent( + model=OpenAIModel(model_id="gpt-4o"), + tools=[research], + system_prompt="You are a coordinator. Delegate research tasks to the research tool." +) + +# Use multi-agent system +response = coordinator("Research the history of AI and summarize it") +``` + +## Multi-Agent Orchestration + +### Sequential Workflow + +```python +from strands import Agent +from strands.models import OpenAIModel + +# Create specialized agents +planner = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a planning specialist. Break down tasks into steps." +) + +executor = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are an execution specialist. Implement detailed plans." +) + +reviewer = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a review specialist. Critique and improve outputs." +) + +# Sequential workflow +def sequential_workflow(task: str) -> str: + # Step 1: Plan + plan = planner(f"Create a plan for: {task}") + print(f"Plan: {plan}") + + # Step 2: Execute + execution = executor(f"Execute this plan: {plan}") + print(f"Execution: {execution}") + + # Step 3: Review + review = reviewer(f"Review this execution: {execution}") + print(f"Review: {review}") + + return str(review) + +result = sequential_workflow("Build a todo app") +``` + +### Parallel Workflow + +```python +from strands import Agent +from strands.models import OpenAIModel +import asyncio + +# Create specialized agents +code_agent = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a coding specialist." +) + +docs_agent = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a documentation specialist." 
+) + +test_agent = Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a testing specialist." +) + +# Parallel workflow +async def parallel_workflow(task: str) -> dict: + # Run agents in parallel + code_task = asyncio.create_task(asyncio.to_thread(code_agent, f"Write code for: {task}")) + docs_task = asyncio.create_task(asyncio.to_thread(docs_agent, f"Write docs for: {task}")) + test_task = asyncio.create_task(asyncio.to_thread(test_agent, f"Write tests for: {task}")) + + # Wait for all to complete + code, docs, tests = await asyncio.gather(code_task, docs_task, test_task) + + return { + "code": str(code), + "docs": str(docs), + "tests": str(tests) + } + +result = asyncio.run(parallel_workflow("Build a calculator")) +``` + +## Swarm Pattern + +Multiple agents collaborate: + +```python +from strands import Agent, tool +from strands.models import OpenAIModel +from typing import List + +class AgentSwarm: + def __init__(self, agents: List[Agent]): + self.agents = agents + + def collaborate(self, task: str) -> str: + """All agents contribute to solving the task.""" + results = [] + + for i, agent in enumerate(self.agents): + # Each agent sees previous results + context = "\n".join([f"Agent {j}: {r}" for j, r in enumerate(results)]) + prompt = f"Task: {task}\n\nPrevious contributions:\n{context}\n\nYour contribution:" + + result = agent(prompt) + results.append(str(result)) + + return "\n\n".join([f"Agent {i}: {r}" for i, r in enumerate(results)]) + +# Create swarm +swarm = AgentSwarm([ + Agent(model=OpenAIModel(model_id="gpt-4o"), system_prompt="You focus on architecture."), + Agent(model=OpenAIModel(model_id="gpt-4o"), system_prompt="You focus on implementation."), + Agent(model=OpenAIModel(model_id="gpt-4o"), system_prompt="You focus on optimization."), +]) + +result = swarm.collaborate("Design a scalable web service") +``` + +## Hierarchical Agents + +Manager-worker pattern: + +```python +from strands import Agent, tool +from strands.models import 
OpenAIModel + +# Create worker agents +workers = { + "data": Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a data processing specialist." + ), + "analysis": Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a data analysis specialist." + ), + "visualization": Agent( + model=OpenAIModel(model_id="gpt-4o"), + system_prompt="You are a data visualization specialist." + ) +} + +# Create tools for manager to delegate +@tool +def delegate_to_data(task: str) -> dict: + """Delegate data processing tasks.""" + result = workers["data"](task) + return {"status": "success", "content": [{"text": str(result)}]} + +@tool +def delegate_to_analysis(task: str) -> dict: + """Delegate data analysis tasks.""" + result = workers["analysis"](task) + return {"status": "success", "content": [{"text": str(result)}]} + +@tool +def delegate_to_visualization(task: str) -> dict: + """Delegate data visualization tasks.""" + result = workers["visualization"](task) + return {"status": "success", "content": [{"text": str(result)}]} + +# Create manager agent +manager = Agent( + model=OpenAIModel(model_id="gpt-4o"), + tools=[delegate_to_data, delegate_to_analysis, delegate_to_visualization], + system_prompt="""You are a manager. Delegate tasks to specialists: + - delegate_to_data: For data processing + - delegate_to_analysis: For data analysis + - delegate_to_visualization: For creating visualizations + """ +) + +# Manager delegates automatically +response = manager("Analyze sales data and create a chart") +``` + +## State Sharing + +Share state between agents: + +```python +from strands import Agent +from strands.models import OpenAIModel + +# Shared state +shared_state = { + "context": [], + "decisions": [] +} + +# Agent 1: Gather information +agent1 = Agent( + model=OpenAIModel(model_id="gpt-4o"), + state=shared_state, + system_prompt="You gather information. Save findings to state['context']." 
+) + +# Agent 2: Make decisions +agent2 = Agent( + model=OpenAIModel(model_id="gpt-4o"), + state=shared_state, + system_prompt="You make decisions based on state['context']. Save to state['decisions']." +) + +# Workflow with shared state +result1 = agent1("Research market trends") +shared_state["context"].append(str(result1)) + +result2 = agent2("Based on the research, what should we do?") +shared_state["decisions"].append(str(result2)) + +print(f"Context: {shared_state['context']}") +print(f"Decisions: {shared_state['decisions']}") +``` + +## Communication Patterns + +### Message Passing + +```python +from strands import Agent +from strands.models import OpenAIModel +from queue import Queue + +class MessageBus: + def __init__(self): + self.queues = {} + + def create_queue(self, agent_id: str): + self.queues[agent_id] = Queue() + + def send(self, to_agent: str, message: str): + if to_agent in self.queues: + self.queues[to_agent].put(message) + + def receive(self, agent_id: str) -> str: + if agent_id in self.queues and not self.queues[agent_id].empty(): + return self.queues[agent_id].get() + return None + +# Create message bus +bus = MessageBus() +bus.create_queue("agent1") +bus.create_queue("agent2") + +# Agents communicate via bus +agent1 = Agent(model=OpenAIModel(model_id="gpt-4o")) +agent2 = Agent(model=OpenAIModel(model_id="gpt-4o")) + +# Agent 1 sends message +result1 = agent1("Analyze this data: [1,2,3,4,5]") +bus.send("agent2", str(result1)) + +# Agent 2 receives and responds +message = bus.receive("agent2") +result2 = agent2(f"Based on this analysis: {message}, what's the next step?") +``` + +## Best Practices + +### 1. Clear Responsibilities + +Give each agent a specific role: + +```python +# ✅ Good - Clear specialization +researcher = Agent(system_prompt="You are a research specialist. Only provide factual information.") +writer = Agent(system_prompt="You are a writing specialist. 
Only create content.") + +# ❌ Bad - Vague role +agent = Agent(system_prompt="You are a helpful assistant.") +``` + +### 2. Avoid Circular Dependencies + +```python +# ❌ Bad - Circular delegation +@tool +def delegate_to_b(task: str) -> dict: + return {"status": "success", "content": [{"text": str(agent_b(task))}]} + +@tool +def delegate_to_a(task: str) -> dict: + return {"status": "success", "content": [{"text": str(agent_a(task))}]} + +agent_a = Agent(tools=[delegate_to_b]) +agent_b = Agent(tools=[delegate_to_a]) # Circular! + +# ✅ Good - Hierarchical structure +manager = Agent(tools=[delegate_to_worker1, delegate_to_worker2]) +worker1 = Agent() +worker2 = Agent() +``` + +### 3. Limit Delegation Depth + +Prevent infinite loops: + +```python +class DelegationTracker: + def __init__(self, max_depth: int = 3): + self.depth = 0 + self.max_depth = max_depth + + def can_delegate(self) -> bool: + return self.depth < self.max_depth + + def enter(self): + self.depth += 1 + + def exit(self): + self.depth -= 1 + +tracker = DelegationTracker(max_depth=3) + +@tool +def safe_delegate(task: str) -> dict: + """Delegate with depth limit.""" + if not tracker.can_delegate(): + return {"status": "error", "content": [{"text": "Max delegation depth reached"}]} + + tracker.enter() + try: + result = worker_agent(task) + return {"status": "success", "content": [{"text": str(result)}]} + finally: + tracker.exit() +``` + +### 4. Error Handling + +Handle agent failures gracefully: + +```python +def robust_delegation(agents: List[Agent], task: str) -> str: + """Try multiple agents until one succeeds.""" + for agent in agents: + try: + result = agent(task) + return str(result) + except Exception as e: + print(f"Agent failed: {e}") + continue + + raise RuntimeError("All agents failed") +``` + +### 5. 
Monitor Performance + +Track agent metrics: + +```python +from strands import Agent +from strands.models import OpenAIModel + +agent = Agent(model=OpenAIModel(model_id="gpt-4o")) +result = agent("Process this task") + +# Access metrics +print(f"Input tokens: {result.metrics.input_tokens}") +print(f"Output tokens: {result.metrics.output_tokens}") +print(f"Total time: {result.metrics.total_time}ms") +``` diff --git a/.agents/skills/strandsagents/references/observability.md b/.agents/skills/strandsagents/references/observability.md new file mode 100644 index 0000000..37d274e --- /dev/null +++ b/.agents/skills/strandsagents/references/observability.md @@ -0,0 +1,422 @@ +# Observability and Telemetry + +## OpenTelemetry Integration + +### Basic Setup + +```python +from strands import Agent +from strands.telemetry import StrandsTelemetry + +# Quick setup with console and OTLP exporters +StrandsTelemetry().setup_console_exporter().setup_otlp_exporter() + +# Create agent +agent = Agent() +response = agent("Hello, world!") +``` + +### Environment Variables + +Configure via environment: + +```bash +# Service identification +export OTEL_SERVICE_NAME=my-agent-service + +# OTLP endpoint +export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 + +# Sampling +export OTEL_TRACES_SAMPLER=always_on +``` + +### Custom Trace Attributes + +```python +from strands import Agent + +agent = Agent( + trace_attributes={ + "environment": "production", + "version": "1.0.0", + "user.id": "user-123", + "team": "platform" + } +) +response = agent("Process this data") +``` + +## Metrics + +### Access Agent Metrics + +```python +from strands import Agent + +agent = Agent() +result = agent("Process this data") + +# Access metrics +metrics = result.metrics +print(f"Input tokens: {metrics.input_tokens}") +print(f"Output tokens: {metrics.output_tokens}") +print(f"Total latency: {metrics.total_time}ms") +print(f"Model calls: {metrics.model_calls}") +``` + +### Setup Metrics Exporter + +```python +from 
strands.telemetry import StrandsTelemetry + +telemetry = StrandsTelemetry() +telemetry.setup_console_exporter() +telemetry.setup_otlp_exporter() +telemetry.setup_meter(enable_console_exporter=True, enable_otlp_exporter=True) +``` + +## Custom Tracing + +### Custom Spans + +```python +from strands import Agent +from strands.telemetry import get_tracer + +tracer = get_tracer() + +with tracer.tracer.start_as_current_span("custom_operation") as span: + span.set_attribute("custom.key", "value") + span.set_attribute("operation.type", "data_processing") + + agent = Agent() + result = agent("Perform traced operation") + + span.set_attribute("result.tokens", result.metrics.output_tokens) +``` + +### Nested Spans + +```python +from strands.telemetry import get_tracer + +tracer = get_tracer() + +with tracer.tracer.start_as_current_span("workflow") as workflow_span: + workflow_span.set_attribute("workflow.id", "wf-123") + + with tracer.tracer.start_as_current_span("step_1") as step1_span: + step1_span.set_attribute("step.name", "data_collection") + # Step 1 logic + + with tracer.tracer.start_as_current_span("step_2") as step2_span: + step2_span.set_attribute("step.name", "data_processing") + # Step 2 logic +``` + +## Structured Logging + +### Logging Format + +Use structured logging with field-value pairs: + +```python +import logging + +logger = logging.getLogger(__name__) + +# ✅ Good - Structured format +logger.debug("field1=<%s>, field2=<%s> | human readable message", field1, field2) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) + +# ❌ Bad - Unstructured +logger.info(f"Request {request_id} completed in {duration}ms") +``` + +### Logging Best Practices + +```python +import logging + +logger = logging.getLogger(__name__) + +# Use %s for interpolation (performance) +logger.debug("user_id=<%s>, action=<%s> | user performed action", 
user_id, action) + +# Separate fields and message with | +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) + +# Lowercase messages, no punctuation +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) + +# Multiple statements with | +logger.info("user_id=<%s> | processing request | starting validation", user_id) +``` + +## Hooks for Observability + +### Logging Hooks + +```python +from strands import Agent +from strands.hooks import ( + HookProvider, HookRegistry, + AgentInitializedEvent, BeforeInvocationEvent, AfterInvocationEvent, + BeforeToolCallEvent, AfterToolCallEvent +) +import time +import logging + +logger = logging.getLogger(__name__) + +class LoggingHooks(HookProvider): + """Custom hooks for logging and monitoring.""" + + def __init__(self): + self.call_count = 0 + self.total_duration = 0 + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AgentInitializedEvent, self.on_init) + registry.add_callback(BeforeInvocationEvent, self.on_before_invoke) + registry.add_callback(AfterInvocationEvent, self.on_after_invoke) + registry.add_callback(BeforeToolCallEvent, self.on_before_tool) + registry.add_callback(AfterToolCallEvent, self.on_after_tool) + + def on_init(self, event: AgentInitializedEvent) -> None: + logger.info("agent_name=<%s> | agent initialized", event.agent.name) + + def on_before_invoke(self, event: BeforeInvocationEvent) -> None: + self.call_count += 1 + self.start_time = time.time() + logger.info("invocation=<%d> | starting invocation", self.call_count) + + def on_after_invoke(self, event: AfterInvocationEvent) -> None: + duration = time.time() - self.start_time + self.total_duration += duration + logger.info( + "invocation=<%d>, duration_ms=<%d>, stop_reason=<%s> | invocation completed", + self.call_count, int(duration * 1000), event.result.stop_reason if event.result else "unknown" + ) + + def on_before_tool(self, 
event: BeforeToolCallEvent) -> None: + logger.info("tool=<%s> | calling tool", event.tool.tool_name) + + def on_after_tool(self, event: AfterToolCallEvent) -> None: + logger.info("tool=<%s> | tool completed", event.tool.tool_name) + +# Use hooks +logging_hooks = LoggingHooks() +agent = Agent(hooks=[logging_hooks], name="MyAssistant") +response = agent("Hello!") + +print(f"Total calls: {logging_hooks.call_count}") +print(f"Total duration: {logging_hooks.total_duration:.2f}s") +``` + +### Metrics Hooks + +```python +from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent +from collections import defaultdict + +class MetricsHooks(HookProvider): + """Track custom metrics.""" + + def __init__(self): + self.metrics = defaultdict(int) + self.token_usage = {"input": 0, "output": 0} + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AfterInvocationEvent, self.track_metrics) + + def track_metrics(self, event: AfterInvocationEvent) -> None: + if event.result: + self.metrics["invocations"] += 1 + self.token_usage["input"] += event.result.metrics.input_tokens + self.token_usage["output"] += event.result.metrics.output_tokens + + def get_summary(self) -> dict: + return { + "total_invocations": self.metrics["invocations"], + "total_input_tokens": self.token_usage["input"], + "total_output_tokens": self.token_usage["output"], + "total_tokens": self.token_usage["input"] + self.token_usage["output"] + } + +metrics_hooks = MetricsHooks() +agent = Agent(hooks=[metrics_hooks]) + +# Use agent +agent("Task 1") +agent("Task 2") + +# Get metrics +summary = metrics_hooks.get_summary() +print(f"Summary: {summary}") +``` + +## Integration with Monitoring Systems + +### Prometheus Metrics + +```python +from prometheus_client import Counter, Histogram, start_http_server +from strands import Agent +from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent + +# Define metrics +invocations_total = 
Counter('agent_invocations_total', 'Total agent invocations') +invocation_duration = Histogram('agent_invocation_duration_seconds', 'Invocation duration') +tokens_total = Counter('agent_tokens_total', 'Total tokens used', ['type']) + +class PrometheusHooks(HookProvider): + """Export metrics to Prometheus.""" + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AfterInvocationEvent, self.export_metrics) + + def export_metrics(self, event: AfterInvocationEvent) -> None: + if event.result: + invocations_total.inc() + invocation_duration.observe(event.result.metrics.total_time / 1000) + tokens_total.labels(type='input').inc(event.result.metrics.input_tokens) + tokens_total.labels(type='output').inc(event.result.metrics.output_tokens) + +# Start Prometheus server +start_http_server(8000) + +# Use agent with Prometheus hooks +agent = Agent(hooks=[PrometheusHooks()]) +``` + +### CloudWatch Metrics + +```python +import boto3 +from strands import Agent +from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent + +cloudwatch = boto3.client('cloudwatch') + +class CloudWatchHooks(HookProvider): + """Export metrics to CloudWatch.""" + + def __init__(self, namespace: str): + self.namespace = namespace + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AfterInvocationEvent, self.export_metrics) + + def export_metrics(self, event: AfterInvocationEvent) -> None: + if event.result: + cloudwatch.put_metric_data( + Namespace=self.namespace, + MetricData=[ + { + 'MetricName': 'Invocations', + 'Value': 1, + 'Unit': 'Count' + }, + { + 'MetricName': 'InputTokens', + 'Value': event.result.metrics.input_tokens, + 'Unit': 'Count' + }, + { + 'MetricName': 'OutputTokens', + 'Value': event.result.metrics.output_tokens, + 'Unit': 'Count' + }, + { + 'MetricName': 'Duration', + 'Value': event.result.metrics.total_time, + 'Unit': 'Milliseconds' + } + ] + ) + +agent = 
Agent(hooks=[CloudWatchHooks(namespace='MyAgent')]) +``` + +## Debugging + +### Enable Debug Logging + +```python +import logging + +# Enable debug logging +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +from strands import Agent + +agent = Agent() +response = agent("Debug this") +``` + +### Inspect Messages + +```python +from strands import Agent + +agent = Agent() +agent("Hello, my name is Alice") +agent("I like pizza") + +# Inspect conversation history +print(f"Message count: {len(agent.messages)}") +for msg in agent.messages: + print(f"{msg['role']}: {msg['content']}") +``` + +### Trace Tool Calls + +```python +from strands import Agent, tool +from strands.hooks import HookProvider, HookRegistry, BeforeToolCallEvent, AfterToolCallEvent +import logging + +logger = logging.getLogger(__name__) + +class ToolTraceHooks(HookProvider): + """Trace all tool calls.""" + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(BeforeToolCallEvent, self.before_tool) + registry.add_callback(AfterToolCallEvent, self.after_tool) + + def before_tool(self, event: BeforeToolCallEvent) -> None: + logger.debug( + "tool=<%s>, args=<%s> | calling tool", + event.tool.tool_name, + event.tool_input + ) + + def after_tool(self, event: AfterToolCallEvent) -> None: + logger.debug( + "tool=<%s>, result=<%s> | tool completed", + event.tool.tool_name, + str(event.result)[:100] # Truncate for logging + ) + +agent = Agent(hooks=[ToolTraceHooks()]) +``` + +## Best Practices + +1. **Use Structured Logging**: Field-value pairs for easy parsing +2. **Set Trace Attributes**: Add context to traces (user_id, version, etc.) +3. **Monitor Token Usage**: Track costs via metrics +4. **Implement Health Checks**: Verify agent availability +5. **Use Hooks for Observability**: Centralize monitoring logic +6. **Export to Monitoring Systems**: Integrate with Prometheus, CloudWatch, etc. +7. 
**Debug with Logging**: Enable DEBUG level for troubleshooting +8. **Track Performance**: Monitor latency and throughput +9. **Alert on Anomalies**: Set up alerts for errors and performance issues +10. **Retain Traces**: Keep traces for debugging and analysis diff --git a/.agents/skills/strandsagents/references/openai.md b/.agents/skills/strandsagents/references/openai.md new file mode 100644 index 0000000..506b05d --- /dev/null +++ b/.agents/skills/strandsagents/references/openai.md @@ -0,0 +1,128 @@ +# OpenAI Integration + +## Basic Configuration + +```python +from strands import Agent +from strands.models import OpenAIModel + +# OpenAI with API key +openai_model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": "your-openai-api-key"}, + params={"temperature": 0.7, "max_tokens": 2048} +) +agent = Agent(model=openai_model) +response = agent("Hello, how are you?") +``` + +## Environment Variables + +Set your API key via environment variable: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +``` + +Then use without explicit key: + +```python +from strands import Agent +from strands.models import OpenAIModel + +openai_model = OpenAIModel( + model_id="gpt-4o", + params={"temperature": 0.7, "max_tokens": 2048} +) +agent = Agent(model=openai_model) +``` + +## Available Models + +- `gpt-4o` - Latest GPT-4 Optimized +- `gpt-4o-mini` - Faster, cost-effective +- `gpt-4-turbo` - GPT-4 Turbo +- `gpt-3.5-turbo` - Fast and efficient + +## Model Parameters + +```python +openai_model = OpenAIModel( + model_id="gpt-4o", + params={ + "temperature": 0.7, # 0.0-2.0, controls randomness + "max_tokens": 2048, # Maximum response length + "top_p": 0.9, # Nucleus sampling + "frequency_penalty": 0, # -2.0 to 2.0 + "presence_penalty": 0 # -2.0 to 2.0 + } +) +``` + +## Streaming Responses + +```python +from strands import Agent +from strands.models import OpenAIModel +import asyncio + +async def stream_openai(): + openai_model = OpenAIModel( + model_id="gpt-4o", + 
client_args={"api_key": "your-openai-api-key"} + ) + agent = Agent(model=openai_model) + + async for event in agent.stream_async("Tell me a story"): + if "data" in event: + print(event["data"], end="", flush=True) + +asyncio.run(stream_openai()) +``` + +## With Custom Tools + +```python +from strands import Agent, tool +from strands.models import OpenAIModel + +@tool +def get_weather(city: str) -> dict: + """Get weather for a city.""" + return { + "status": "success", + "content": [{"text": f"Weather in {city}: Sunny, 22°C"}] + } + +openai_model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": "your-openai-api-key"} +) +agent = Agent(model=openai_model, tools=[get_weather]) +response = agent("What's the weather in Paris?") +``` + +## Error Handling + +```python +from strands import Agent +from strands.models import OpenAIModel + +try: + openai_model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": "your-openai-api-key"} + ) + agent = Agent(model=openai_model) + response = agent("Hello!") +except Exception as e: + print(f"Error: {e}") +``` + +## Best Practices + +1. **API Key Security**: Never hardcode API keys. Use environment variables or secure vaults. +2. **Rate Limits**: OpenAI has rate limits. Implement retry logic for production. +3. **Cost Management**: Monitor token usage via `result.metrics` to control costs. +4. **Model Selection**: Use `gpt-4o-mini` for cost-effective tasks, `gpt-4o` for complex reasoning. +5. **Temperature**: Lower (0.0-0.3) for deterministic outputs, higher (0.7-1.0) for creative tasks. 
diff --git a/.agents/skills/strandsagents/references/quickstart.md b/.agents/skills/strandsagents/references/quickstart.md new file mode 100644 index 0000000..5909788 --- /dev/null +++ b/.agents/skills/strandsagents/references/quickstart.md @@ -0,0 +1,79 @@ +# Strands Agents - Quick Start + +## Installation + +```bash +# Create and activate virtual environment +python -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install Strands and tools +pip install strands-agents strands-agents-tools +``` + +## Basic Agent + +```python +from strands import Agent + +# Default Bedrock model +agent = Agent() +response = agent("What is the capital of France?") +print(response) +``` + +## Agent with Custom Model + +```python +from strands import Agent +from strands.models import BedrockModel + +agent = Agent( + model=BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0"), + system_prompt="You are a helpful coding assistant. Be concise and provide examples." +) +result = agent("How do I read a JSON file in Python?") +print(result.message) +print(result.stop_reason) # "end_turn", "max_tokens", etc. 
+print(result.metrics) # Performance metrics +``` + +## Agent with Built-in Tools + +```python +from strands import Agent +from strands_tools import calculator + +agent = Agent(tools=[calculator]) +response = agent("What is the square root of 1764?") +``` + +## Agent with Initial State + +```python +from strands import Agent + +agent = Agent( + messages=[ + {"role": "user", "content": [{"text": "My name is Alice"}]}, + {"role": "assistant", "content": [{"text": "Nice to meet you, Alice!"}]} + ], + state={"user_preference": "dark_mode"} +) +response = agent("What's my name?") # Agent remembers: "Your name is Alice" +``` + +## Async Streaming + +```python +from strands import Agent +import asyncio + +async def stream_response(): + agent = Agent() + async for event in agent.stream_async("Tell me a story"): + if "data" in event: + print(event["data"], end="", flush=True) + +asyncio.run(stream_response()) +``` diff --git a/.agents/skills/strandsagents/references/security.md b/.agents/skills/strandsagents/references/security.md new file mode 100644 index 0000000..6d0beea --- /dev/null +++ b/.agents/skills/strandsagents/references/security.md @@ -0,0 +1,312 @@ +# Security Best Practices + +## API Key Management + +### Environment Variables + +**NEVER hardcode API keys in source code.** Always use environment variables: + +```python +import os +from strands import Agent +from strands.models import OpenAIModel + +# ✅ Good - Use environment variable +openai_model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": os.environ.get("OPENAI_API_KEY")} +) + +# ❌ Bad - Hardcoded API key +openai_model = OpenAIModel( + model_id="gpt-4o", + client_args={"api_key": "sk-proj-abc123..."} # NEVER DO THIS +) +``` + +### .env Files + +Use `.env` files for local development: + +```bash +# .env +OPENAI_API_KEY=sk-proj-abc123... +ANTHROPIC_API_KEY=sk-ant-abc123... 
+``` + +Load with python-dotenv: + +```python +from dotenv import load_dotenv +import os + +load_dotenv() + +openai_key = os.environ.get("OPENAI_API_KEY") +``` + +**Important**: Add `.env` to `.gitignore`: + +```gitignore +.env +.env.local +*.env +``` + +### AWS Secrets Manager + +For production, use secure vaults: + +```python +import boto3 +import json + +def get_secret(secret_name): + client = boto3.client('secretsmanager') + response = client.get_secret_value(SecretId=secret_name) + return json.loads(response['SecretString']) + +secrets = get_secret("my-app/api-keys") +openai_key = secrets["OPENAI_API_KEY"] +``` + +## Input Validation + +### Validate Tool Inputs + +Always validate and sanitize inputs: + +```python +from strands import tool + +@tool +def execute_query(query: str) -> dict: + """Execute a database query.""" + # Validate input + if not query or len(query) > 1000: + return {"status": "error", "content": [{"text": "Invalid query"}]} + + # Sanitize - prevent SQL injection + if any(keyword in query.lower() for keyword in ["drop", "delete", "truncate"]): + return {"status": "error", "content": [{"text": "Forbidden operation"}]} + + # Execute safely + result = safe_execute(query) + return {"status": "success", "content": [{"text": result}]} +``` + +### Type Checking + +Use type hints and runtime validation: + +```python +from strands import tool +from typing import Union + +@tool +def process_data(data: str, max_length: int = 100) -> dict: + """Process data with length limit.""" + if not isinstance(data, str): + return {"status": "error", "content": [{"text": "Data must be string"}]} + + if not isinstance(max_length, int) or max_length <= 0: + return {"status": "error", "content": [{"text": "Invalid max_length"}]} + + if len(data) > max_length: + return {"status": "error", "content": [{"text": f"Data exceeds {max_length} chars"}]} + + return {"status": "success", "content": [{"text": f"Processed: {data}"}]} +``` + +## Guardrails + +### AWS Bedrock 
Guardrails + +Use Bedrock guardrails for content filtering: + +```python +from strands import Agent +from strands.models import BedrockModel + +bedrock_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + guardrail_id="my-guardrail-id", + guardrail_version="1" +) +agent = Agent(model=bedrock_model) +``` + +### Custom Content Filtering + +Implement custom filters: + +```python +from strands import Agent +from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent + +class ContentFilterHooks(HookProvider): + """Filter sensitive content.""" + + BLOCKED_PATTERNS = ["password", "api_key", "secret"] + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(BeforeInvocationEvent, self.filter_content) + + def filter_content(self, event: BeforeInvocationEvent) -> None: + # Check user input for sensitive patterns + user_input = str(event.invocation_state.get("input", "")).lower() + for pattern in self.BLOCKED_PATTERNS: + if pattern in user_input: + raise ValueError(f"Blocked pattern detected: {pattern}") + +agent = Agent(hooks=[ContentFilterHooks()]) +``` + +## Rate Limiting + +### Implement Rate Limits + +Protect against abuse: + +```python +from strands import Agent, tool +from strands.types.tools import ToolContext +import time +from collections import defaultdict + +class RateLimiter: + def __init__(self, max_calls: int, window_seconds: int): + self.max_calls = max_calls + self.window_seconds = window_seconds + self.calls = defaultdict(list) + + def check(self, user_id: str) -> bool: + now = time.time() + # Remove old calls + self.calls[user_id] = [t for t in self.calls[user_id] if now - t < self.window_seconds] + + if len(self.calls[user_id]) >= self.max_calls: + return False + + self.calls[user_id].append(now) + return True + +rate_limiter = RateLimiter(max_calls=10, window_seconds=60) + +@tool(context=True) +def rate_limited_operation(data: str, tool_context: ToolContext) -> dict: + 
"""Operation with rate limiting.""" + user_id = tool_context.agent.state.get("user_id", "anonymous") + + if not rate_limiter.check(user_id): + return {"status": "error", "content": [{"text": "Rate limit exceeded"}]} + + # Process operation + return {"status": "success", "content": [{"text": f"Processed: {data}"}]} +``` + +## Secure Tool Execution + +### Sandbox Tool Execution + +Isolate tool execution: + +```python +import subprocess +from strands import tool + +@tool +def execute_code(code: str) -> dict: + """Execute code in sandboxed environment.""" + # Validate code + if not code or len(code) > 1000: + return {"status": "error", "content": [{"text": "Invalid code"}]} + + # Block dangerous imports + blocked = ["os", "sys", "subprocess", "eval", "exec"] + if any(module in code for module in blocked): + return {"status": "error", "content": [{"text": "Forbidden imports"}]} + + try: + # Execute in restricted environment + result = subprocess.run( + ["python", "-c", code], + capture_output=True, + text=True, + timeout=5, + env={"PYTHONPATH": ""} # Restricted environment + ) + return {"status": "success", "content": [{"text": result.stdout}]} + except subprocess.TimeoutExpired: + return {"status": "error", "content": [{"text": "Execution timeout"}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Error: {str(e)}"}]} +``` + +## Logging and Monitoring + +### Secure Logging + +Never log sensitive information: + +```python +import logging + +logger = logging.getLogger(__name__) + +# ✅ Good - No sensitive data +logger.info("user_id=<%s> | user authenticated", user_id) + +# ❌ Bad - Logs API key +logger.info("api_key=<%s> | authentication successful", api_key) # NEVER DO THIS + +# ✅ Good - Mask sensitive data +logger.info("api_key=<%s> | authentication successful", api_key[:8] + "***") +``` + +### Audit Trail + +Track agent operations: + +```python +from strands import Agent +from strands.hooks import HookProvider, HookRegistry, 
AfterToolCallEvent +import logging + +class AuditHooks(HookProvider): + """Audit trail for tool calls.""" + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(AfterToolCallEvent, self.log_tool_call) + + def log_tool_call(self, event: AfterToolCallEvent) -> None: + logger.info( + "tool=<%s>, user_id=<%s>, status=<%s> | tool executed", + event.tool.tool_name, + event.tool_context.agent.state.get("user_id", "anonymous"), + "success" if event.result else "error" + ) + +agent = Agent(hooks=[AuditHooks()]) +``` + +## Vulnerability Reporting + +If you discover a security vulnerability: + +1. **DO NOT** create a public GitHub issue +2. Report via [AWS Security Vulnerability Reporting](http://aws.amazon.com/security/vulnerability-reporting/) +3. Provide detailed description and reproduction steps + +## Security Checklist + +- [ ] API keys stored in environment variables or secure vaults +- [ ] `.env` files in `.gitignore` +- [ ] Input validation on all tools +- [ ] Rate limiting implemented +- [ ] Content filtering for sensitive data +- [ ] Secure logging (no API keys/secrets) +- [ ] Audit trail for operations +- [ ] Sandboxed tool execution for code +- [ ] Error messages don't leak sensitive info +- [ ] Dependencies regularly updated diff --git a/.agents/skills/strandsagents/references/testing.md b/.agents/skills/strandsagents/references/testing.md new file mode 100644 index 0000000..f29fc3f --- /dev/null +++ b/.agents/skills/strandsagents/references/testing.md @@ -0,0 +1,487 @@ +# Testing Strands Agents + +## Development Environment + +### Setup + +```bash +# Enter dev environment +hatch shell + +# Install pre-commit hooks +pre-commit install -t pre-commit -t commit-msg +``` + +### Run Tests + +```bash +# Run unit tests +hatch test + +# Run with coverage +hatch test -c + +# Run integration tests +hatch run test-integ + +# Test specific directory +hatch test tests/strands/agent/ + +# Test across all Python versions +hatch test 
--all +``` + +## Unit Testing + +### Test File Structure + +``` +tests/strands/ +├── agent/ +│ ├── test_agent.py +│ └── test_conversation_manager.py +└── tools/ + └── test_tools.py +``` + +### Basic Agent Test + +```python +# tests/strands/agent/test_agent.py +import pytest +from strands import Agent +from strands.models import BedrockModel + +def test_agent_creation(): + """Test basic agent creation.""" + agent = Agent() + assert agent is not None + assert agent.messages == [] + +def test_agent_with_system_prompt(): + """Test agent with custom system prompt.""" + system_prompt = "You are a helpful assistant." + agent = Agent(system_prompt=system_prompt) + assert agent.system_prompt == system_prompt + +def test_agent_with_state(): + """Test agent with initial state.""" + state = {"user_id": "123"} + agent = Agent(state=state) + assert agent.state["user_id"] == "123" +``` + +### Tool Testing + +```python +# tests/strands/tools/test_tools.py +import pytest +from strands import tool + +@tool +def test_tool(text: str) -> dict: + """A test tool.""" + return {"status": "success", "content": [{"text": f"Processed: {text}"}]} + +def test_tool_execution(): + """Test tool execution.""" + result = test_tool("hello") + assert result["status"] == "success" + assert "Processed: hello" in result["content"][0]["text"] + +def test_tool_with_invalid_input(): + """Test tool with invalid input.""" + with pytest.raises(TypeError): + test_tool(123) # Should be string +``` + +### Async Tool Testing + +```python +import pytest +import asyncio +from strands import tool + +@tool +async def async_test_tool(data: str) -> dict: + """An async test tool.""" + await asyncio.sleep(0.1) + return {"status": "success", "content": [{"text": f"Async: {data}"}]} + +@pytest.mark.asyncio +async def test_async_tool(): + """Test async tool execution.""" + result = await async_test_tool("test") + assert result["status"] == "success" + assert "Async: test" in result["content"][0]["text"] +``` + +### Mock 
Model Testing + +```python +import pytest +from unittest.mock import Mock, patch +from strands import Agent + +def test_agent_with_mock_model(): + """Test agent with mocked model.""" + mock_model = Mock() + mock_model.invoke.return_value = { + "role": "assistant", + "content": [{"text": "Mocked response"}] + } + + agent = Agent(model=mock_model) + # Test agent behavior without real API calls +``` + +## Integration Testing + +### Test with Real Models + +```python +# tests_integ/test_openai_integration.py +import pytest +import os +from strands import Agent +from strands.models import OpenAIModel + +@pytest.fixture +def openai_agent(): + """Create OpenAI agent for testing.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY not set") + + model = OpenAIModel( + model_id="gpt-4o-mini", + client_args={"api_key": api_key} + ) + return Agent(model=model) + +def test_basic_conversation(openai_agent): + """Test basic conversation with OpenAI.""" + response = openai_agent("What is 2+2?") + assert response is not None + assert "4" in str(response).lower() + +def test_tool_usage(openai_agent): + """Test tool usage with OpenAI.""" + from strands import tool + + @tool + def add(a: int, b: int) -> dict: + """Add two numbers.""" + return {"status": "success", "content": [{"text": str(a + b)}]} + + agent = Agent( + model=openai_agent.model, + tools=[add] + ) + response = agent("Use the add tool to calculate 5 + 3") + assert "8" in str(response) +``` + +### Test Conversation Flow + +```python +def test_conversation_memory(openai_agent): + """Test that agent remembers conversation.""" + openai_agent("My name is Alice") + response = openai_agent("What is my name?") + assert "alice" in str(response).lower() +``` + +## Testing Hooks + +```python +import pytest +from strands import Agent +from strands.hooks import ( + HookProvider, HookRegistry, + BeforeInvocationEvent, AfterInvocationEvent +) + +class TestHooks(HookProvider): + """Test hooks 
for verification.""" + + def __init__(self): + self.before_called = False + self.after_called = False + + def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + registry.add_callback(BeforeInvocationEvent, self.on_before) + registry.add_callback(AfterInvocationEvent, self.on_after) + + def on_before(self, event: BeforeInvocationEvent) -> None: + self.before_called = True + + def on_after(self, event: AfterInvocationEvent) -> None: + self.after_called = True + +def test_hooks_execution(): + """Test that hooks are called.""" + test_hooks = TestHooks() + agent = Agent(hooks=[test_hooks]) + + # Mock invocation + assert test_hooks.before_called is False + assert test_hooks.after_called is False + + # After invocation, hooks should be called + # (requires mock or real model) +``` + +## Testing Best Practices + +### 1. Use Fixtures + +```python +import pytest +from strands import Agent + +@pytest.fixture +def basic_agent(): + """Basic agent fixture.""" + return Agent() + +@pytest.fixture +def agent_with_tools(): + """Agent with tools fixture.""" + from strands_tools import calculator + return Agent(tools=[calculator]) + +def test_with_fixture(basic_agent): + """Test using fixture.""" + assert basic_agent is not None +``` + +### 2. Parametrize Tests + +```python +import pytest +from strands import tool + +@tool +def process(text: str) -> dict: + """Process text.""" + return {"status": "success", "content": [{"text": text.upper()}]} + +@pytest.mark.parametrize("input_text,expected", [ + ("hello", "HELLO"), + ("world", "WORLD"), + ("test", "TEST"), +]) +def test_process_parametrized(input_text, expected): + """Test with multiple inputs.""" + result = process(input_text) + assert expected in result["content"][0]["text"] +``` + +### 3. 
Test Error Handling + +```python +import pytest +from strands import tool + +@tool +def divide(a: int, b: int) -> dict: + """Divide two numbers.""" + if b == 0: + return {"status": "error", "content": [{"text": "Division by zero"}]} + return {"status": "success", "content": [{"text": str(a / b)}]} + +def test_divide_by_zero(): + """Test error handling.""" + result = divide(10, 0) + assert result["status"] == "error" + assert "Division by zero" in result["content"][0]["text"] +``` + +### 4. Test Async Code + +```python +import pytest +import asyncio +from strands import Agent + +@pytest.mark.asyncio +async def test_async_streaming(): + """Test async streaming.""" + agent = Agent() + + chunks = [] + async for event in agent.stream_async("Tell me a short story"): + if "data" in event: + chunks.append(event["data"]) + + assert len(chunks) > 0 +``` + +### 5. Clean Up Resources + +```python +import pytest +from strands import Agent + +@pytest.fixture +def agent(): + """Agent fixture with cleanup.""" + agent = Agent() + yield agent + # Cleanup + agent.messages.clear() + +def test_with_cleanup(agent): + """Test with automatic cleanup.""" + agent("Hello") + assert len(agent.messages) > 0 + # Cleanup happens automatically after test +``` + +## Code Quality + +### Formatting + +```bash +# Check formatting +hatch fmt --formatter + +# Auto-fix formatting +hatch fmt +``` + +### Linting + +```bash +# Run linter +hatch fmt --linter + +# Type checking with mypy +mypy src/ +``` + +### Pre-commit Hooks + +```bash +# Run all hooks manually +pre-commit run --all-files + +# Install hooks +pre-commit install +``` + +## Continuous Integration + +### GitHub Actions Example + +```yaml +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - 
name: Install dependencies + run: | + pip install hatch + hatch env create + + - name: Run tests + run: hatch test -c + + - name: Run linting + run: hatch fmt --linter +``` + +## Test Coverage + +### Generate Coverage Report + +```bash +# Run tests with coverage +hatch test -c + +# Generate HTML report +coverage html + +# View report +open htmlcov/index.html +``` + +### Coverage Configuration + +```toml +# pyproject.toml +[tool.coverage.run] +branch = true +source = ["src/strands"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", +] +``` + +## Debugging Tests + +### Run Single Test + +```bash +# Run specific test +pytest tests/strands/agent/test_agent.py::test_agent_creation + +# Run with verbose output +pytest -v tests/strands/agent/test_agent.py + +# Run with print statements +pytest -s tests/strands/agent/test_agent.py +``` + +### Debug with pdb + +```python +import pytest + +def test_with_debugger(): + """Test with debugger.""" + agent = Agent() + + # Set breakpoint + import pdb; pdb.set_trace() + + response = agent("Hello") + assert response is not None +``` + +## Testing Checklist + +- [ ] Unit tests for all core functionality +- [ ] Integration tests with real models +- [ ] Test error handling and edge cases +- [ ] Test async functionality +- [ ] Mock external dependencies +- [ ] Use fixtures for common setup +- [ ] Parametrize tests for multiple inputs +- [ ] Clean up resources after tests +- [ ] Maintain high test coverage (>80%) +- [ ] Run tests in CI/CD pipeline +- [ ] Document test requirements +- [ ] Test hooks and lifecycle events diff --git a/.agents/skills/strandsagents/references/tools.md b/.agents/skills/strandsagents/references/tools.md new file mode 100644 index 0000000..f0f2e7d --- /dev/null +++ b/.agents/skills/strandsagents/references/tools.md @@ -0,0 +1,215 @@ +# Custom Tools + +## Creating Tools with @tool Decorator + 
+### Basic Tool + +```python +from strands import Agent, tool + +@tool +def word_count(text: str) -> int: + """Count words in text. + + This docstring is used by the LLM to understand the tool's purpose. + """ + return len(text.split()) + +agent = Agent(tools=[word_count]) +response = agent("How many words are in this sentence?") +``` + +### Tool with Multiple Parameters + +```python +from strands import tool + +@tool +def get_weather(city: str, units: str = "celsius") -> dict: + """Get current weather for a city. + + Args: + city: Name of the city to get weather for. + units: Temperature units, either 'celsius' or 'fahrenheit'. + + Returns: + Weather data including temperature and conditions. + """ + # Simulated weather API call + weather_data = {"city": city, "temperature": 22, "units": units, "conditions": "sunny"} + return { + "status": "success", + "content": [{"text": f"Weather in {city}: {weather_data['temperature']}°{units[0].upper()}, {weather_data['conditions']}"}] + } +``` + +### Tool Response Format + +Tools should return a dict with this structure: + +```python +{ + "status": "success", # or "error" + "content": [ + {"text": "Response text here"} + ] +} +``` + +## Context-Aware Tools + +Access agent state and context: + +```python +from strands import Agent, tool +from strands.types.tools import ToolContext + +@tool(context=True) +def save_note(note: str, tool_context: ToolContext) -> dict: + """Save a note to the agent's state. + + Args: + note: The note content to save. 
+ """ + agent = tool_context.agent + if "notes" not in agent.state: + agent.state["notes"] = [] + agent.state["notes"].append(note) + return {"status": "success", "content": [{"text": f"Note saved: {note}"}]} + +agent = Agent(tools=[save_note]) +agent("Save a note: Remember to call mom") +print(agent.state["notes"]) # ['Remember to call mom'] +``` + +## Async Tools + +For non-blocking operations: + +```python +from strands import tool +import aiohttp + +@tool +async def fetch_url(url: str) -> dict: + """Fetch content from a URL asynchronously. + + Args: + url: The URL to fetch. + """ + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + content = await response.text() + return {"status": "success", "content": [{"text": content[:500]}]} +``` + +## Direct Tool Invocation + +Call tools directly without agent: + +```python +from strands import Agent, tool + +@tool +def calculate(expression: str) -> dict: + """Evaluate a mathematical expression.""" + result = eval(expression) + return {"status": "success", "content": [{"text": str(result)}]} + +agent = Agent(tools=[calculate]) + +# Direct invocation +result = agent.tool.calculate(expression="2 + 2") +print(result) +``` + +## Dynamic Tool Loading + +Load tools from directory with hot reloading: + +```python +from strands import Agent + +# Enable hot reloading from ./tools/ directory +agent = Agent(load_tools_from_directory=True) + +# Create tools/my_tool.py: +# from strands import tool +# +# @tool +# def my_dynamic_tool(param: str) -> dict: +# """A dynamically loaded tool.""" +# return {"status": "success", "content": [{"text": f"Processed: {param}"}]} + +# Agent automatically detects and loads tools +response = agent("Use my_dynamic_tool with param='hello'") + +# List available tools +print(f"Available tools: {agent.tool_names}") +``` + +## Built-in Tools + +Use pre-built tools from strands-agents-tools: + +```python +from strands import Agent +from strands_tools import calculator 
+ +agent = Agent(tools=[calculator]) +response = agent("What is the square root of 1764?") +``` + +## Tool Best Practices + +### 1. Clear Docstrings + +Use Google-style docstrings for LLM understanding: + +```python +@tool +def example_function(param1: str, param2: int) -> dict: + """Brief description of function. + + Longer description if needed. This docstring is used by LLMs + to understand the function's purpose when used as a tool. + + Args: + param1: Description of param1 + param2: Description of param2 + + Returns: + Description of return value + + Raises: + ValueError: When invalid input is provided + """ + pass +``` + +### 2. Type Hints + +Always include type hints for parameters and return values. + +### 3. Error Handling + +Return error status for failures: + +```python +@tool +def risky_operation(data: str) -> dict: + """Perform a risky operation.""" + try: + result = process(data) + return {"status": "success", "content": [{"text": result}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Error: {str(e)}"}]} +``` + +### 4. Idempotency + +Design tools to be safely retried without side effects. + +### 5. Performance + +Use async tools for I/O-bound operations to avoid blocking. diff --git a/.gitignore b/.gitignore index 9a1b727..dc1a41a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,43 @@ +# Python __pycache__/ -node_modules/ -test-results/ -playwright-report/ venv/ .venv/ uv.lock -.env -.vscode/ .pytest_cache/ +.ruff_cache/ +*egg-info/ + +# JS / Node +node_modules/ +dist/ +dist-ssr +*.local +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +# Test output +test-results/ +playwright-report/ + +# Frontend generated +frontend_omni/public/modules.json + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? 
+ +# Other +.env +*.db diff --git a/AGENTS.md b/AGENTS.md index bf08d6e..c8f1d3a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -123,8 +123,8 @@ Before any backend work, read relevant architecture documents: ### Testing - **Framework**: pytest -- **Command**: `uv run pytest` or `uv run pytest tests/test_*.py` -- **Location**: `backend/tests/` +- **Command**: `uv run pytest` +- **Location**: Tests live alongside source code in `__tests__/` directories under `backend/src/modai/` and `backend/src/modai/modules/*/` - **Test Coverage**: Always add unit tests for new features or bug fixes - **Test Isolation**: Use mocking for external dependencies - **Atomic Tests**: Each test function should test one specific behavior diff --git a/backend/.gitignore b/backend/.gitignore deleted file mode 100644 index fec00f2..0000000 --- a/backend/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*egg-info/ -*.db diff --git a/backend/config.yaml b/backend/config.yaml index 722149d..1d2087f 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -10,7 +10,7 @@ modules: chat_openai: chat_openai session: "session" chat_openai: - class: modai.modules.chat.openai_llm_chat.OpenAILLMChatModule + class: modai.modules.chat.openai_agent_chat.StrandsAgentChatModule module_dependencies: llm_provider_module: openai_model_provider model_provider_store: diff --git a/backend/docs/architecture/auth.md b/backend/docs/architecture/auth.md index 9109289..e800daf 100644 --- a/backend/docs/architecture/auth.md +++ b/backend/docs/architecture/auth.md @@ -38,8 +38,8 @@ flowchart TD - Integration with session management **API Endpoints**: -- `POST /api/v1/auth/login` - User authentication (200 OK / 401 Unauthorized / 422 Unprocessable Entity) -- `POST /api/v1/auth/logout` - Session termination (200 OK / 401 Unauthorized) +- `POST /api/auth/login` - User authentication (200 OK / 401 Unauthorized / 422 Unprocessable Entity) +- `POST /api/auth/logout` - Session termination (200 OK / 401 Unauthorized) **Dependencies**: - 
Session Module (for session creation/destruction) @@ -166,7 +166,7 @@ Because of the variety of options, the session module interface is rather generi - REST API for permission discovery **API Endpoints**: -- `GET /api/v1/auth/permissions` - List all registered permissions for client discovery (200 OK / 401 Unauthorized if auth required) +- `GET /api/auth/permissions` - List all registered permissions for client discovery (200 OK / 401 Unauthorized if auth required) **Key Functions**: ```python @@ -192,7 +192,7 @@ async def validate_permission(self, user_id: str, resource: str, action: str) -> @dataclass class PermissionDefinition: """Definition of a permission for registration purposes""" - resource: str # e.g., "/api/v1/documents", "/api/v1/user/*" + resource: str # e.g., "/api/documents", "/api/user/*" actions: list[str] # e.g., ["read", "write", "delete"] resource_name: str # Human-readable name, e.g., "Document Library", "User Management" description: str | None = None # Optional detailed description @@ -204,13 +204,13 @@ class PermissionDefinition: { "permissions": [ { - "resource": "/api/v1/provider/*/models", + "resource": "/api/provider/*/models", "actions": ["read"], "resource_name": "Large Language Models", "description": "Models available through the AI provider" }, { - "resource": "/api/v1/file/*", + "resource": "/api/file/*", "actions": ["read", "write", "delete"], "resource_name": "User Files", "description": "Access individual uploated files outside a document library" @@ -233,18 +233,18 @@ The resource identifier is a string, so arbitrary content can be put in, but it advisable to use a **pseudo endpoint notation** which reflects the endpoints of the module exactly or at least to a certain extent. -If a module e.g. has an endpoint `/api/v1/files` then this is also a good candidate +If a module e.g. has an endpoint `/api/files` then this is also a good candidate for the resource identifier. 
-If a module has several endpoints like `/api/v1/file/{id}/title`, -`/api/v1/file/{id}/name`, ... then it is not advisable to create permissions for -each single endpoint, but instead use a more generic one like `/api/v1/file/*` as +If a module has several endpoints like `/api/file/{id}/title`, +`/api/file/{id}/name`, ... then it is not advisable to create permissions for +each single endpoint, but instead use a more generic one like `/api/file/*` as resource name. In some cases it can even be interesting to share resource identifiers across modules. E.g. if there are several LLM Provider modules which should follow the endpoint pattern -`/api/v1/provider`, then usually we don't want to have permissions for each provider -individually. Here a resource identifier of `/api/v1/provider/` could be shared amongst +`/api/provider`, then usually we don't want to have permissions for each provider +individually. Here a resource identifier of `/api/provider/` could be shared amongst all provider modules. Benefits of Pseudo-Endpoint-based Permissions: @@ -289,7 +289,7 @@ class SomeWebModule(ModaiModule, ABC): """Register all permissions used by this module""" self.authorization_module.register_permission( PermissionDefinition( - resource="/api/v1/some", + resource="/api/some", actions=["read", "write", "delete"], resource_name="Some Resources" ) @@ -314,7 +314,7 @@ class SomeWebModule(ModaiModule): ... # Add routes - self.router.add_api_route("/api/v1/some", self.get_some, methods=["GET"]) + self.router.add_api_route("/api/some", self.get_some, methods=["GET"]) ... async def get_some(self, request: Request): @@ -323,7 +323,7 @@ class SomeWebModule(ModaiModule): # 2. Validate endpoint permissions await self.authorization_module.validate_permission( - session.user_id, "/api/v1/some", "read" + session.user_id, "/api/some", "read" ) # 3. 
Process request @@ -353,7 +353,7 @@ sequenceDiagram Auth->>Session: start_new_session() Session-->>Auth: Session created Auth-->>Client: 200 OK (session cookie) - Client->>SomeWebModule: GET /api/v1/file/1 + Client->>SomeWebModule: GET /api/file/1 SomeWebModule->>Session: validate_session() Session-->>SomeWebModule: Session data SomeWebModule->>Authorization: validate_permission() diff --git a/backend/docs/architecture/core.md b/backend/docs/architecture/core.md index ff88193..2771f54 100644 --- a/backend/docs/architecture/core.md +++ b/backend/docs/architecture/core.md @@ -84,7 +84,7 @@ class HealthModule(ModaiModule, ABC): def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) self.router = APIRouter() # This makes it a web module - self.router.add_api_route("/api/v1/health", self.get_health, methods=["GET"]) + self.router.add_api_route("/api/health", self.get_health, methods=["GET"]) @abstractmethod def get_health(self) -> dict[str, Any]: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 40f99cb..1626236 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "pyjwt", "python-dotenv", "sqlalchemy", + "strands-agents", + "strands-agents-tools", ] [dependency-groups] @@ -22,5 +24,11 @@ dev = [ "ruff", ] +[tool.pytest.ini_options] +testpaths = ["src"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + [tool.uv] package = true diff --git a/backend/tests/__init__.py b/backend/src/modai/__tests__/__init__.py similarity index 100% rename from backend/tests/__init__.py rename to backend/src/modai/__tests__/__init__.py diff --git a/backend/tests/test_module_loader.py b/backend/src/modai/__tests__/test_module_loader.py similarity index 85% rename from backend/tests/test_module_loader.py rename to backend/src/modai/__tests__/test_module_loader.py index 09619e2..9a4543e 100644 --- 
a/backend/tests/test_module_loader.py +++ b/backend/src/modai/__tests__/test_module_loader.py @@ -1,8 +1,5 @@ -import sys -import os from typing import Any -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModaiModule, ModuleDependencies from modai.module_loader import ModuleLoader @@ -29,7 +26,7 @@ def test_import_class_success(): startup_config = { "modules": { "foo": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", "config": {"some": "nicevalue"}, } } @@ -47,7 +44,10 @@ def test_load_module_disabled(caplog): """Test loading a disabled module.""" startup_config = { "modules": { - "foo": {"class": "tests.test_module_loader.DummyModule", "enabled": False} + "foo": { + "class": "modai.__tests__.test_module_loader.DummyModule", + "enabled": False, + } } } loader = ModuleLoader(startup_config) @@ -79,11 +79,11 @@ def test_module_dependencies(): startup_config = { "modules": { "bar": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", "module_dependencies": {"foo": "foo"}, }, "foo": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", }, } } @@ -107,15 +107,15 @@ def test_module_dependencies_chain(): startup_config = { "modules": { "baz": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", "module_dependencies": {"bar": "bar"}, }, "bar": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", "module_dependencies": {"foo": "foo"}, }, "foo": { - "class": "tests.test_module_loader.DummyModule", + "class": "modai.__tests__.test_module_loader.DummyModule", }, } } @@ -144,7 +144,7 @@ def test_module_dependencies_unresolvable(): startup_config = { "modules": { "bar": { - "class": "tests.test_module_loader.DummyModule", + 
"class": "modai.__tests__.test_module_loader.DummyModule", "module_dependencies": {"foo": "nonexistent"}, }, } diff --git a/backend/src/modai/default_config.yaml b/backend/src/modai/default_config.yaml index 9e233f2..24aa6c7 100644 --- a/backend/src/modai/default_config.yaml +++ b/backend/src/modai/default_config.yaml @@ -9,7 +9,7 @@ modules: module_dependencies: chat_openai: chat_openai chat_openai: - class: modai.modules.chat.openai_llm_chat.OpenAILLMChatModule + class: modai.modules.chat.openai_agent_chat.StrandsAgentChatModule module_dependencies: llm_provider_module: openai_model_provider model_provider_store: diff --git a/backend/src/modai/modules/authentication/__tests__/__init__.py b/backend/src/modai/modules/authentication/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_authentication.py b/backend/src/modai/modules/authentication/__tests__/test_authentication.py similarity index 92% rename from backend/tests/test_authentication.py rename to backend/src/modai/modules/authentication/__tests__/test_authentication.py index 2a0fb20..a85a59f 100644 --- a/backend/tests/test_authentication.py +++ b/backend/src/modai/modules/authentication/__tests__/test_authentication.py @@ -1,11 +1,7 @@ -import sys -import os import pytest from unittest.mock import Mock, MagicMock, AsyncMock from fastapi.testclient import TestClient from fastapi import FastAPI - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.authentication.password_authentication_module import ( PasswordAuthenticationModule, @@ -60,7 +56,7 @@ def test_login_success(client): session_module.start_new_session.return_value = None payload = {"email": "admin@example.com", "password": "admin"} - response = test_client.post("/api/v1/auth/login", json=payload) + response = test_client.post("/api/auth/login", json=payload) assert response.status_code == 200 response_data = 
response.json() @@ -87,7 +83,7 @@ def test_login_invalid_credentials(client): user_store.get_user_credentials.return_value = test_credentials payload = {"email": "admin@example.com", "password": "wrong-password"} - response = test_client.post("/api/v1/auth/login", json=payload) + response = test_client.post("/api/auth/login", json=payload) assert response.status_code == 401 assert response.json()["detail"] == "Invalid email or password" @@ -102,7 +98,7 @@ def test_login_nonexistent_user(client): user_store.get_user_by_email.return_value = None payload = {"email": "nonexistent@example.com", "password": "password"} - response = test_client.post("/api/v1/auth/login", json=payload) + response = test_client.post("/api/auth/login", json=payload) assert response.status_code == 401 assert response.json()["detail"] == "Invalid email or password" @@ -118,7 +114,7 @@ def test_logout_with_valid_session_cookie(client): session_module.end_session.return_value = None # Test logout - logout_response = test_client.post("/api/v1/auth/logout") + logout_response = test_client.post("/api/auth/logout") assert logout_response.status_code == 200 assert logout_response.json()["message"] == "Successfully logged out" @@ -133,7 +129,7 @@ def test_logout_with_invalid_session_cookie(client): # Mock that session module doesn't raise any exception session_module.end_session.return_value = None - response = test_client.post("/api/v1/auth/logout") + response = test_client.post("/api/auth/logout") assert response.status_code == 200 assert response.json()["message"] == "Successfully logged out" @@ -148,7 +144,7 @@ def test_logout_without_session_cookie(client): # Mock that session module doesn't raise any exception session_module.end_session.return_value = None - response = test_client.post("/api/v1/auth/logout") + response = test_client.post("/api/auth/logout") assert response.status_code == 200 assert response.json()["message"] == "Successfully logged out" @@ -168,7 +164,7 @@ def 
test_login_user_without_credentials(client): user_store.get_user_credentials.return_value = None # No credentials payload = {"email": "admin@example.com", "password": "admin"} - response = test_client.post("/api/v1/auth/login", json=payload) + response = test_client.post("/api/auth/login", json=payload) assert response.status_code == 401 assert response.json()["detail"] == "Invalid email or password" @@ -191,7 +187,7 @@ def test_signup_success(client): "password": "password123", "full_name": "New User", } - response = test_client.post("/api/v1/auth/signup", json=payload) + response = test_client.post("/api/auth/signup", json=payload) assert response.status_code == 200 response_data = response.json() @@ -220,7 +216,7 @@ def test_signup_existing_user(client): "password": "password123", "full_name": "New User", } - response = test_client.post("/api/v1/auth/signup", json=payload) + response = test_client.post("/api/auth/signup", json=payload) assert response.status_code == 400 assert response.json()["detail"] == "User with this email already exists" @@ -246,7 +242,7 @@ def test_signup_password_creation_failure(client): "password": "password123", "full_name": "Test User", } - response = test_client.post("/api/v1/auth/signup", json=payload) + response = test_client.post("/api/auth/signup", json=payload) assert response.status_code == 500 assert response.json()["detail"] == "Failed to create user account" diff --git a/backend/src/modai/modules/authentication/module.py b/backend/src/modai/modules/authentication/module.py index a7f3506..d828b33 100644 --- a/backend/src/modai/modules/authentication/module.py +++ b/backend/src/modai/modules/authentication/module.py @@ -44,8 +44,8 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): self.security = HTTPBearer() # Add authentication routes - self.router.add_api_route("/api/v1/auth/login", self.login, methods=["POST"]) - self.router.add_api_route("/api/v1/auth/logout", self.logout, methods=["POST"]) + 
self.router.add_api_route("/api/auth/login", self.login, methods=["POST"]) + self.router.add_api_route("/api/auth/logout", self.logout, methods=["POST"]) @abstractmethod async def login( diff --git a/backend/src/modai/modules/authentication/password_authentication_module.py b/backend/src/modai/modules/authentication/password_authentication_module.py index 6a7f873..a8eeace 100644 --- a/backend/src/modai/modules/authentication/password_authentication_module.py +++ b/backend/src/modai/modules/authentication/password_authentication_module.py @@ -47,7 +47,7 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): self.user_store: UserStore = dependencies.modules.get("user_store") # Add password authentication specific routes - self.router.add_api_route("/api/v1/auth/signup", self.signup, methods=["POST"]) + self.router.add_api_route("/api/auth/signup", self.signup, methods=["POST"]) async def login( self, diff --git a/backend/src/modai/modules/chat/__tests__/__init__.py b/backend/src/modai/modules/chat/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_chat.py b/backend/src/modai/modules/chat/__tests__/test_openai_raw_chat.py similarity index 95% rename from backend/tests/test_chat.py rename to backend/src/modai/modules/chat/__tests__/test_openai_raw_chat.py index 6f7c13a..88db28d 100644 --- a/backend/tests/test_chat.py +++ b/backend/src/modai/modules/chat/__tests__/test_openai_raw_chat.py @@ -1,5 +1,4 @@ from pathlib import Path -import sys import os from dotenv import find_dotenv, load_dotenv import pytest @@ -7,10 +6,8 @@ from openai import AsyncOpenAI from unittest.mock import Mock, MagicMock, AsyncMock from fastapi.testclient import TestClient - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies -from modai.modules.chat.openai_llm_chat import OpenAILLMChatModule +from modai.modules.chat.openai_raw_chat import OpenAILLMChatModule 
from modai.modules.chat.web_chat_router import ChatWebModule from modai.modules.chat.module import ChatLLMModule from modai.modules.session.module import SessionModule, Session @@ -60,6 +57,7 @@ async def openai_client(request): if client_type == "direct": openai_client = AsyncOpenAI( api_key=os.environ["OPENAI_API_KEY"], + base_url=os.environ.get("OPENAI_BASE_URL"), ) yield openai_client else: @@ -83,7 +81,7 @@ async def test_llm_generate_response(): id="test_provider", type="openai", name="myopenai", - base_url="https://api.openai.com/v1", + base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"), api_key=os.environ["OPENAI_API_KEY"], properties={}, created_at=None, @@ -109,8 +107,9 @@ async def test_llm_generate_response(): request = Mock(spec=Request) # Test non-streaming + model = os.environ.get("OPENAI_MODEL", "gpt-5") body_json = { - "model": "myopenai/gpt-4o", + "model": f"myopenai/{model}", "input": [{"role": "user", "content": "Just echo the word 'Hello'"}], } @@ -152,7 +151,7 @@ async def test_llm_generate_response_streaming(): id="test_provider", type="openai", name="myopenai", - base_url="https://api.openai.com/v1", + base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"), api_key=os.environ["OPENAI_API_KEY"], properties={}, created_at=None, @@ -178,8 +177,9 @@ async def test_llm_generate_response_streaming(): request = Mock(spec=Request) # Test streaming + model = os.environ.get("OPENAI_MODEL", "gpt-5") body_json = { - "model": "myopenai/gpt-4o", + "model": f"myopenai/{model}", "input": [{"role": "user", "content": "Just echo the word 'Hello'"}], "stream": True, } @@ -212,7 +212,7 @@ async def test_chat_responses_api(openai_client: AsyncOpenAI, request): """Test chat responses API.""" request.node.callspec.params["openai_client"] - model = "gpt-4o" # No backend_proxy + model = os.environ.get("OPENAI_MODEL", "gpt-5") # Make the request response = await openai_client.responses.create( @@ -244,7 +244,7 @@ async def 
test_chat_responses_api_streaming(openai_client: AsyncOpenAI, request) """Test streaming chat responses API.""" request.node.callspec.params["openai_client"] - model = "gpt-4o" # No backend_proxy + model = os.environ.get("OPENAI_MODEL", "gpt-5") # Make the streaming request stream = await openai_client.responses.create( @@ -446,7 +446,7 @@ async def test_openai_llm_provider_not_found(): # Test with non-existent provider body_json = { - "model": "nonexistent/gpt-4", + "model": "nonexistent/gpt-5", "input": [{"role": "user", "content": "Hello"}], } @@ -482,7 +482,7 @@ def test_responses_endpoint_rejects_unauthenticated_request(): client = TestClient(app) response = client.post( - "/api/v1/responses", + "/api/responses", json={"model": "dummy/test_model", "input": "hello"}, ) assert response.status_code == 401 diff --git a/backend/src/modai/modules/chat/__tests__/test_strands_agent_chat.py b/backend/src/modai/modules/chat/__tests__/test_strands_agent_chat.py new file mode 100644 index 0000000..fee2cc1 --- /dev/null +++ b/backend/src/modai/modules/chat/__tests__/test_strands_agent_chat.py @@ -0,0 +1,489 @@ +"""Tests for the StrandsAgentChatModule.""" + +import os +from pathlib import Path + +import pytest +from dotenv import find_dotenv, load_dotenv +from unittest.mock import AsyncMock, Mock, patch +from dataclasses import dataclass, field + +from fastapi import Request + +from modai.module import ModuleDependencies +from modai.modules.chat.openai_agent_chat import ( + StrandsAgentChatModule, + _parse_model, + _extract_last_user_message, + _build_conversation_history, + _to_strands_message, + _message_text, + _build_openai_response, +) +from modai.modules.model_provider.module import ( + ModelProviderResponse, + ModelProvidersListResponse, +) +import openai + +working_dir = Path.cwd() +load_dotenv(find_dotenv(str(working_dir / ".env"))) + + +# --------------------------------------------------------------------------- +# Fixtures +# 
--------------------------------------------------------------------------- + + +def _make_provider(name: str = "myprovider") -> ModelProviderResponse: + return ModelProviderResponse( + id="provider_1", + type="openai", + name=name, + base_url="https://api.openai.com/v1", + api_key="sk-test-key", + properties={}, + created_at=None, + updated_at=None, + ) + + +def _make_mock_provider_module(providers: list[ModelProviderResponse] | None = None): + providers = providers or [_make_provider()] + mock = Mock() + mock.get_providers = AsyncMock( + return_value=ModelProvidersListResponse( + providers=providers, + total=len(providers), + limit=None, + offset=None, + ) + ) + return mock + + +def _make_dependencies(provider_module=None): + provider_module = provider_module or _make_mock_provider_module() + deps = ModuleDependencies({"llm_provider_module": provider_module}) + return deps + + +# --------------------------------------------------------------------------- +# _parse_model +# --------------------------------------------------------------------------- + + +class TestParseModel: + def test_valid_model(self): + provider, model = _parse_model("myprovider/gpt-4o") + assert provider == "myprovider" + assert model == "gpt-4o" + + def test_valid_model_with_slash_in_name(self): + provider, model = _parse_model("myprovider/azure/gpt-5") + assert provider == "myprovider" + assert model == "azure/gpt-5" + + def test_invalid_no_slash(self): + with pytest.raises(ValueError, match="Invalid model format"): + _parse_model("gpt-4o") + + +# --------------------------------------------------------------------------- +# _extract_last_user_message +# --------------------------------------------------------------------------- + + +class TestExtractLastUserMessage: + def test_string_input(self): + body = {"input": "Hello there"} + assert _extract_last_user_message(body) == "Hello there" + + def test_list_simple_content(self): + body = {"input": [{"role": "user", "content": "Hi"}]} + assert 
_extract_last_user_message(body) == "Hi" + + def test_list_structured_content(self): + body = { + "input": [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]} + ] + } + assert _extract_last_user_message(body) == "Hello" + + def test_multiple_messages_returns_last(self): + body = { + "input": [ + {"role": "user", "content": "First"}, + {"role": "assistant", "content": "Response"}, + {"role": "user", "content": "Second"}, + ] + } + assert _extract_last_user_message(body) == "Second" + + def test_empty_input(self): + assert _extract_last_user_message({"input": ""}) == "" + assert _extract_last_user_message({"input": []}) == "" + assert _extract_last_user_message({}) == "" + + +# --------------------------------------------------------------------------- +# _build_conversation_history +# --------------------------------------------------------------------------- + + +class TestBuildConversationHistory: + def test_string_input_returns_empty(self): + assert _build_conversation_history({"input": "Hello"}) == [] + + def test_single_message_returns_empty(self): + body = {"input": [{"role": "user", "content": "Hi"}]} + assert _build_conversation_history(body) == [] + + def test_multi_turn_excludes_last(self): + body = { + "input": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "How are you?"}, + ] + } + history = _build_conversation_history(body) + assert len(history) == 2 + assert history[0]["role"] == "user" + assert history[0]["content"] == [{"text": "Hello"}] + assert history[1]["role"] == "assistant" + assert history[1]["content"] == [{"text": "Hi"}] + + +# --------------------------------------------------------------------------- +# _to_strands_message / _message_text +# --------------------------------------------------------------------------- + + +class TestMessageConversion: + def test_to_strands_message_simple(self): + msg = _to_strands_message({"role": "user", "content": 
"Hello"}) + assert msg == {"role": "user", "content": [{"text": "Hello"}]} + + def test_to_strands_message_structured(self): + msg = _to_strands_message( + {"role": "assistant", "content": [{"type": "output_text", "text": "Hey"}]} + ) + assert msg == {"role": "assistant", "content": [{"text": "Hey"}]} + + def test_message_text_string(self): + assert _message_text("Hello") == "Hello" + + def test_message_text_dict_string_content(self): + assert _message_text({"content": "Hello"}) == "Hello" + + def test_message_text_dict_list_content(self): + assert ( + _message_text({"content": [{"type": "input_text", "text": "Hi"}]}) == "Hi" + ) + + def test_message_text_none(self): + assert _message_text(None) == "" + + +# --------------------------------------------------------------------------- +# _build_openai_response +# --------------------------------------------------------------------------- + + +class TestBuildOpenAIResponse: + def test_builds_valid_response(self): + resp = _build_openai_response( + text="Hello!", + model="gpt-4o", + response_id="resp_test123", + msg_id="msg_test456", + input_tokens=10, + output_tokens=5, + ) + assert isinstance(resp, openai.types.responses.Response) + assert resp.id == "resp_test123" + assert resp.model == "gpt-4o" + assert resp.status == "completed" + assert resp.output[0].content[0].text == "Hello!" 
+ assert resp.usage.input_tokens == 10 + assert resp.usage.output_tokens == 5 + assert resp.usage.total_tokens == 15 + + +# --------------------------------------------------------------------------- +# StrandsAgentChatModule.__init__ +# --------------------------------------------------------------------------- + + +class TestStrandsAgentChatModuleInit: + def test_raises_without_provider(self): + deps = ModuleDependencies({}) + with pytest.raises(ValueError, match="llm_provider_module"): + StrandsAgentChatModule(dependencies=deps, config={}) + + def test_creates_with_provider(self): + deps = _make_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + assert module.provider_module is not None + + +# --------------------------------------------------------------------------- +# StrandsAgentChatModule.generate_response (mocked) +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeUsage: + inputTokens: int = 10 + outputTokens: int = 20 + totalTokens: int = 30 + + def get(self, key, default=0): + return getattr(self, key, default) + + +@dataclass +class _FakeMetrics: + accumulated_usage: dict = field( + default_factory=lambda: { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30, + } + ) + + +@dataclass +class _FakeAgentResult: + text: str = "Mocked response" + metrics: _FakeMetrics = field(default_factory=_FakeMetrics) + stop_reason: str = "end_turn" + message: dict = field( + default_factory=lambda: { + "role": "assistant", + "content": [{"text": "Mocked response"}], + } + ) + + def __str__(self) -> str: + return self.text + + +@pytest.mark.asyncio +async def test_generate_response_non_streaming(): + """Non-streaming generate_response returns an OpenAI Response.""" + deps = _make_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + fake_result = _FakeAgentResult() + + with ( + patch( + 
"modai.modules.chat.openai_agent_chat._create_agent" + ) as mock_create_agent, + patch("asyncio.to_thread", new_callable=AsyncMock, return_value=fake_result), + ): + mock_agent = Mock() + mock_create_agent.return_value = mock_agent + + body = { + "model": "myprovider/gpt-4o", + "input": [{"role": "user", "content": "Hello"}], + } + + result = await module.generate_response(request, body) + + assert isinstance(result, openai.types.responses.Response) + assert result.status == "completed" + assert result.output[0].content[0].text == "Mocked response" + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 20 + + +@pytest.mark.asyncio +async def test_generate_response_streaming(): + """Streaming generate_response returns an async generator of events.""" + deps = _make_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + async def fake_stream_async(prompt): + yield {"data": "Hello"} + yield {"data": " world"} + + with patch( + "modai.modules.chat.openai_agent_chat._create_agent" + ) as mock_create_agent: + mock_agent = Mock() + mock_agent.stream_async = fake_stream_async + mock_create_agent.return_value = mock_agent + + body = { + "model": "myprovider/gpt-4o", + "input": [{"role": "user", "content": "Hi"}], + "stream": True, + } + + result = await module.generate_response(request, body) + + # Result should be an async generator + assert hasattr(result, "__aiter__") + + events = [] + async for event in result: + events.append(event) + + # Expected: created, 2 text deltas, text done, completed + assert len(events) == 5 + + # First event is response.created + assert events[0].type == "response.created" + + # Delta events + assert events[1].type == "response.output_text.delta" + assert events[1].delta == "Hello" + assert events[2].type == "response.output_text.delta" + assert events[2].delta == " world" + + # Text done + assert events[3].type == "response.output_text.done" + assert 
events[3].text == "Hello world" + + # Completed + assert events[4].type == "response.completed" + assert events[4].response.output[0].content[0].text == "Hello world" + + +# --------------------------------------------------------------------------- +# Provider resolution +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_invalid_model_format(): + """Raises ValueError for an invalid model string.""" + deps = _make_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + body = { + "model": "no_slash_model", + "input": "Hello", + } + with pytest.raises(ValueError, match="Invalid model format"): + await module.generate_response(request, body) + + +@pytest.mark.asyncio +async def test_provider_not_found(): + """Raises ValueError when provider name is unknown.""" + deps = _make_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + body = { + "model": "unknown/gpt-4o", + "input": "Hello", + } + with pytest.raises(ValueError, match="Provider 'unknown' not found"): + await module.generate_response(request, body) + + +# --------------------------------------------------------------------------- +# Integration tests (require OPENAI_API_KEY in .env) +# --------------------------------------------------------------------------- + + +def _make_real_provider() -> ModelProviderResponse: + """Create a provider backed by the env-var credentials.""" + return ModelProviderResponse( + id="test_provider", + type="openai", + name="myopenai", + base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"), + api_key=os.environ.get("OPENAI_API_KEY", ""), + properties={}, + created_at=None, + updated_at=None, + ) + + +def _make_real_dependencies() -> ModuleDependencies: + """Dependencies wired to the real provider from env vars.""" + provider = _make_real_provider() + provider_module = 
_make_mock_provider_module([provider]) + return ModuleDependencies({"llm_provider_module": provider_module}) + + +def _real_model() -> str: + """Return 'myopenai/<model>' using OPENAI_MODEL from env.""" + model = os.environ.get("OPENAI_MODEL", "gpt-4o") + return f"myopenai/{model}" + + +@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") +@pytest.mark.asyncio +async def test_strands_generate_response_non_streaming_integration(): + """Integration: non-streaming response via Strands Agent + real LLM.""" + deps = _make_real_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + body = { + "model": _real_model(), + "input": [{"role": "user", "content": "Just echo the word 'Hello'"}], + } + + result = await module.generate_response(request, body) + + assert isinstance(result, openai.types.responses.Response) + assert result.status == "completed" + assert result.output + assert len(result.output) > 0 + text = result.output[0].content[0].text + assert "Hello" in text + assert result.usage.input_tokens > 0 + assert result.usage.output_tokens > 0 + + +@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") +@pytest.mark.asyncio +async def test_strands_generate_response_streaming_integration(): + """Integration: streaming response via Strands Agent + real LLM.""" + deps = _make_real_dependencies() + module = StrandsAgentChatModule(dependencies=deps, config={}) + request = Mock(spec=Request) + + body = { + "model": _real_model(), + "input": [{"role": "user", "content": "Just echo the word 'Hello'"}], + "stream": True, + } + + result = await module.generate_response(request, body) + assert hasattr(result, "__aiter__") + + events = [] + async for event in result: + events.append(event) + + # Must have at least created + text done + completed + assert len(events) >= 3 + + # First is response.created + assert events[0].type == "response.created" + + # Collect 
text deltas + full_text = "" + for evt in events: + if hasattr(evt, "type") and evt.type == "response.output_text.delta": + full_text += evt.delta + + assert "Hello" in full_text + + # Last is response.completed + assert events[-1].type == "response.completed" + assert events[-1].response.output[0].content[0].text == full_text diff --git a/backend/src/modai/modules/chat/module.py b/backend/src/modai/modules/chat/module.py index c366239..5db8d2d 100644 --- a/backend/src/modai/modules/chat/module.py +++ b/backend/src/modai/modules/chat/module.py @@ -8,14 +8,20 @@ from fastapi.responses import StreamingResponse from typing import Any, AsyncGenerator from modai.module import ModaiModule, ModuleDependencies -import openai +from openai.types.responses import ( + Response as OpenAIResponse, + ResponseStreamEvent as OpenAIResponseStreamEvent, +) +from openai.types.responses.response_create_params import ( + ResponseCreateParams as OpenAICreateResponse, +) class ChatWebModule(ModaiModule, ABC): """ Module Declaration for: Chat Responses (Web Module) - Provides the /api/v1/responses endpoint for chat completions. + Provides the /api/responses endpoint for chat completions. Fully OpenAI /responses API compatible. 
""" @@ -23,7 +29,7 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) self.router = APIRouter() self.router.add_api_route( - "/api/v1/responses", + "/api/responses", self.responses_endpoint, methods=["POST"], response_model=None, # Disable response model since we return either Response or StreamingResponse @@ -33,8 +39,8 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): async def responses_endpoint( self, request: Request, - body_json: openai.types.responses.ResponseCreateParams = Body(...), - ) -> openai.types.responses.Response | StreamingResponse: + body_json: OpenAICreateResponse = Body(...), + ) -> OpenAIResponse | StreamingResponse: """ Handles responses requests. Must be implemented by concrete implementations. Fully OpenAI /responses API compatible. @@ -61,11 +67,8 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): @abstractmethod async def generate_response( - self, request: Request, body_json: openai.types.responses.ResponseCreateParams - ) -> ( - openai.types.responses.Response - | AsyncGenerator[openai.types.responses.ResponseStreamEvent, None] - ): + self, request: Request, body_json: OpenAICreateResponse + ) -> OpenAIResponse | AsyncGenerator[OpenAIResponseStreamEvent, None]: """ Generate a streaming or non-streaming chat response. diff --git a/backend/src/modai/modules/chat/openai_agent_chat.py b/backend/src/modai/modules/chat/openai_agent_chat.py new file mode 100644 index 0000000..428f904 --- /dev/null +++ b/backend/src/modai/modules/chat/openai_agent_chat.py @@ -0,0 +1,341 @@ +""" +Strands Agent Chat Module: ChatLLMModule implementation using Strands Agents SDK. + +Routes OpenAI-compatible requests through the Strands Agent framework with +OpenAI model provider. Tool support is planned for later — this module currently +only serves model requests via the framework. 
+""" + +import asyncio +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, AsyncGenerator + +from fastapi import Request +from openai.types.responses import ( + Response as OpenAIResponse, + ResponseCompletedEvent, + ResponseCreatedEvent, + ResponseStreamEvent as OpenAIResponseStreamEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, +) +from openai.types.responses.response_create_params import ( + ResponseCreateParams as OpenAICreateResponse, +) +from strands import Agent +from strands.models import OpenAIModel + +from modai.module import ModuleDependencies +from modai.modules.chat.module import ChatLLMModule +from modai.modules.model_provider.module import ( + ModelProviderModule, + ModelProviderResponse, +) + +logger = logging.getLogger(__name__) + +DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant." + + +class StrandsAgentChatModule(ChatLLMModule): + """Strands Agent LLM Provider for Chat Responses. + + Implements the ChatLLMModule interface using the Strands Agents SDK + with OpenAI model provider. No tool support yet. 
+ """ + + def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + + self.provider_module: ModelProviderModule = dependencies.get_module( + "llm_provider_module" + ) + if not self.provider_module: + raise ValueError( + "StrandsAgentChatModule requires 'llm_provider_module' module dependency" + ) + + async def generate_response( + self, request: Request, body_json: OpenAICreateResponse + ) -> OpenAIResponse | AsyncGenerator[OpenAIResponseStreamEvent, None]: + provider_name, actual_model = _parse_model(body_json.get("model", "")) + provider = await self._resolve_provider(request, provider_name) + agent = _create_agent(provider, actual_model, body_json) + user_message = _extract_last_user_message(body_json) + + if body_json.get("stream", False): + return _generate_streaming_response(agent, user_message, actual_model) + else: + return await _generate_non_streaming_response( + agent, user_message, actual_model + ) + + async def _resolve_provider( + self, request: Request, provider_name: str + ) -> ModelProviderResponse: + """Look up the provider by name from the provider module.""" + providers_response = await self.provider_module.get_providers( + request=request, limit=None, offset=None + ) + for p in providers_response.providers: + if p.name == provider_name: + return p + raise ValueError(f"Provider '{provider_name}' not found") + + +# --------------------------------------------------------------------------- +# Pure helper functions (module-private) +# --------------------------------------------------------------------------- + + +def _parse_model(model: str) -> tuple[str, str]: + """Parse ``provider_name/model_name`` into its components.""" + parts = model.split("/", maxsplit=1) + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError( + f"Invalid model format: {model}. 
Expected 'provider_name/model_name'" + ) + return parts[0], parts[1] + + +def _create_agent( + provider: ModelProviderResponse, + model_id: str, + body_json: OpenAICreateResponse, +) -> Agent: + """Build a fresh Strands ``Agent`` for this request.""" + client_args: dict[str, Any] = {"api_key": provider.api_key} + if provider.base_url: + client_args["base_url"] = provider.base_url + + model = OpenAIModel(model_id=model_id, client_args=client_args) + + system_prompt = body_json.get("instructions") or DEFAULT_SYSTEM_PROMPT + prior_messages = _build_conversation_history(body_json) + + return Agent( + model=model, + system_prompt=system_prompt, + messages=prior_messages or None, + callback_handler=None, # suppress default stdout printing + ) + + +def _build_conversation_history( + body_json: OpenAICreateResponse, +) -> list[dict[str, Any]]: + """Convert the ``input`` field into Strands-style messages. + + All messages *except* the last user message are returned as prior + history — the last user message is passed as the ``prompt`` argument + when invoking the agent. 
+ """ + input_data = body_json.get("input", "") + if isinstance(input_data, str) or not isinstance(input_data, list): + return [] + if len(input_data) <= 1: + return [] + return [ + _to_strands_message(msg) for msg in input_data[:-1] if isinstance(msg, dict) + ] + + +def _extract_last_user_message(body_json: OpenAICreateResponse) -> str: + """Return the text of the last user message from ``input``.""" + input_data = body_json.get("input", "") + if isinstance(input_data, str): + return input_data + if isinstance(input_data, list) and input_data: + return _message_text(input_data[-1]) + return "" + + +def _to_strands_message(msg: dict[str, Any]) -> dict[str, Any]: + """Convert a single OpenAI Responses API message to Strands format.""" + role = msg.get("role", "user") + text = _message_text(msg) + return {"role": role, "content": [{"text": text}]} + + +def _message_text(msg: Any) -> str: + """Extract plain text from a message dict.""" + if isinstance(msg, str): + return msg + if isinstance(msg, dict): + content = msg.get("content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [ + c.get("text", "") + for c in content + if isinstance(c, dict) + and c.get("type") in ("input_text", "text", "output_text") + ] + return " ".join(texts) + return "" + + +# --------------------------------------------------------------------------- +# Response builders +# --------------------------------------------------------------------------- + + +def _response_id() -> str: + return f"resp_{uuid.uuid4().hex[:24]}" + + +def _item_id() -> str: + return f"msg_{uuid.uuid4().hex[:24]}" + + +def _build_openai_response( + text: str, + model: str, + response_id: str, + msg_id: str, + input_tokens: int = 0, + output_tokens: int = 0, +) -> OpenAIResponse: + """Construct a fully-formed ``openai.types.responses.Response``.""" + return OpenAIResponse.model_validate( + { + "id": response_id, + "object": "response", + "created_at": 
datetime.now(timezone.utc).timestamp(), + "model": model, + "status": "completed", + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + "output": [ + { + "type": "message", + "id": msg_id, + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": text, + "annotations": [], + } + ], + "status": "completed", + } + ], + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + } + ) + + +# --------------------------------------------------------------------------- +# Non-streaming +# --------------------------------------------------------------------------- + + +async def _generate_non_streaming_response( + agent: Agent, user_message: str, model: str +) -> OpenAIResponse: + """Run the agent synchronously (in a thread) and return an OpenAI Response.""" + result = await asyncio.to_thread(agent, user_message) + + text_output = str(result).strip() + + input_tokens = 0 + output_tokens = 0 + if hasattr(result, "metrics") and hasattr(result.metrics, "accumulated_usage"): + usage = result.metrics.accumulated_usage + input_tokens = usage.get("inputTokens", 0) + output_tokens = usage.get("outputTokens", 0) + + return _build_openai_response( + text=text_output, + model=model, + response_id=_response_id(), + msg_id=_item_id(), + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + +# --------------------------------------------------------------------------- +# Streaming +# --------------------------------------------------------------------------- + + +async def _generate_streaming_response( + agent: Agent, user_message: str, model: str +) -> AsyncGenerator[OpenAIResponseStreamEvent, None]: + """Stream text-delta events from the agent, book-ended by created/completed.""" + resp_id = _response_id() + msg_id = _item_id() + seq = 0 + + # --- response.created 
--------------------------------------------------- + stub_response = OpenAIResponse( + id=resp_id, + created_at=datetime.now(timezone.utc).timestamp(), + model=model, + object="response", + output=[], + parallel_tool_calls=True, + tool_choice="auto", + tools=[], + status="in_progress", + ) + yield ResponseCreatedEvent( + response=stub_response, sequence_number=seq, type="response.created" + ) + seq += 1 + + # --- text deltas --------------------------------------------------------- + full_text = "" + + async for event in agent.stream_async(user_message): + chunk = event.get("data", "") if isinstance(event, dict) else "" + if not chunk: + continue + full_text += chunk + yield ResponseTextDeltaEvent( + content_index=0, + delta=chunk, + item_id=msg_id, + logprobs=[], + output_index=0, + sequence_number=seq, + type="response.output_text.delta", + ) + seq += 1 + + # --- response.output_text.done ------------------------------------------ + yield ResponseTextDoneEvent( + content_index=0, + item_id=msg_id, + logprobs=[], + output_index=0, + sequence_number=seq, + text=full_text, + type="response.output_text.done", + ) + seq += 1 + + # --- response.completed -------------------------------------------------- + completed_response = _build_openai_response( + text=full_text, + model=model, + response_id=resp_id, + msg_id=msg_id, + ) + yield ResponseCompletedEvent( + response=completed_response, + sequence_number=seq, + type="response.completed", + ) diff --git a/backend/src/modai/modules/chat/openai_llm_chat.py b/backend/src/modai/modules/chat/openai_raw_chat.py similarity index 69% rename from backend/src/modai/modules/chat/openai_llm_chat.py rename to backend/src/modai/modules/chat/openai_raw_chat.py index 6f02d9e..e3d520d 100644 --- a/backend/src/modai/modules/chat/openai_llm_chat.py +++ b/backend/src/modai/modules/chat/openai_raw_chat.py @@ -8,9 +8,14 @@ Response as OpenAIResponse, ResponseStreamEvent as OpenAIResponseStreamEvent, ) -from 
modai.modules.model_provider.module import ModelProviderModule +from modai.modules.model_provider.module import ( + ModelProviderModule, + ModelProviderResponse, +) +# The module is at the moment unused, as I'm unsure if we should support raw LLM access as the agentic +# chat is more what a chat backend needs class OpenAILLMChatModule(ChatLLMModule): """ OpenAI LLM Provider for Chat Responses. @@ -34,45 +39,45 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): async def generate_response( self, request: Request, body_json: OpenAICreateResponse ) -> OpenAIResponse | AsyncGenerator[OpenAIResponseStreamEvent, None]: - # Parse model: format is "provider_name/model_name" - model = body_json.get("model", "") - model_parts = model.split("/") - if len(model_parts) != 2: + provider_name, actual_model = self._parse_model(body_json.get("model", "")) + provider = await self._resolve_provider(request, provider_name) + client = self._create_client(provider) + + body_json["model"] = actual_model + + if body_json.get("stream", False): + return self._generate_streaming_response(client, body_json) + else: + return await self._generate_non_streaming_response(client, body_json) + + def _parse_model(self, model: str) -> tuple[str, str]: + """Parse 'provider_name/model_name' into its components.""" + model_parts = model.split("/", maxsplit=1) + if len(model_parts) != 2 or not model_parts[0] or not model_parts[1]: raise ValueError( f"Invalid model format: {model}. 
Expected 'provider_name/model_name'" ) + return model_parts[0], model_parts[1] - provider_name, actual_model = model_parts - - # Get all providers from the provider module + async def _resolve_provider( + self, request: Request, provider_name: str + ) -> ModelProviderResponse: + """Look up the provider by name from the provider module.""" providers_response = await self.provider_module.get_providers( - limit=None, offset=None + request=request, limit=None, offset=None ) - provider = None for p in providers_response.providers: if p.name == provider_name: - provider = p - break + return p + raise ValueError(f"Provider '{provider_name}' not found") - if not provider: - raise ValueError(f"Provider '{provider_name}' not found") - - # Create OpenAI client with the provider's API key and base URL - client = AsyncOpenAI( + def _create_client(self, provider: ModelProviderResponse) -> AsyncOpenAI: + """Create an AsyncOpenAI client from a provider configuration.""" + return AsyncOpenAI( api_key=provider.api_key, base_url=provider.base_url if provider.base_url else None, ) - # Update body_json with the actual model - body_json["model"] = actual_model - - stream = body_json.get("stream", False) - - if stream: - return self._generate_streaming_response(client, body_json) - else: - return await self._generate_non_streaming_response(client, body_json) - async def _generate_streaming_response( self, client: AsyncOpenAI, body_json: OpenAICreateResponse ) -> AsyncGenerator[OpenAIResponseStreamEvent, None]: diff --git a/backend/src/modai/modules/health/__tests__/__init__.py b/backend/src/modai/modules/health/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_health.py b/backend/src/modai/modules/health/__tests__/test_health.py similarity index 81% rename from backend/tests/test_health.py rename to backend/src/modai/modules/health/__tests__/test_health.py index 5fd4056..fa17fdc 100644 --- a/backend/tests/test_health.py +++ 
b/backend/src/modai/modules/health/__tests__/test_health.py @@ -1,10 +1,7 @@ -import sys -import os import pytest from fastapi.testclient import TestClient from fastapi import FastAPI -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.modules.health.simple_health_module import SimpleHealthModule from modai.module import ModuleDependencies @@ -19,13 +16,13 @@ def client(): def test_health_endpoint_returns_healthy_status(client): - response = client.get("/api/v1/health") + response = client.get("/api/health") assert response.status_code == 200 assert response.json() == {"status": "healthy"} def test_health_endpoint_requires_no_authentication(client): """Health endpoint must be accessible without any session or credentials.""" - response = client.get("/api/v1/health") + response = client.get("/api/health") assert response.status_code == 200 assert response.json()["status"] == "healthy" diff --git a/backend/src/modai/modules/health/module.py b/backend/src/modai/modules/health/module.py index 3b8d76e..f8d724f 100644 --- a/backend/src/modai/modules/health/module.py +++ b/backend/src/modai/modules/health/module.py @@ -13,7 +13,7 @@ class HealthModule(ModaiModule, ABC): def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) self.router = APIRouter() # This makes it a web module - self.router.add_api_route("/api/v1/health", self.get_health, methods=["GET"]) + self.router.add_api_route("/api/health", self.get_health, methods=["GET"]) @abstractmethod def get_health(self) -> dict[str, Any]: diff --git a/backend/src/modai/modules/model_provider/__tests__/__init__.py b/backend/src/modai/modules/model_provider/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_central_model_provider_router.py b/backend/src/modai/modules/model_provider/__tests__/test_central_model_provider_router.py similarity index 94% rename from 
backend/tests/test_central_model_provider_router.py rename to backend/src/modai/modules/model_provider/__tests__/test_central_model_provider_router.py index ad36d68..c72626e 100644 --- a/backend/tests/test_central_model_provider_router.py +++ b/backend/src/modai/modules/model_provider/__tests__/test_central_model_provider_router.py @@ -2,15 +2,11 @@ Tests for Central Model Provider Router. """ -import sys -import os import pytest from unittest.mock import MagicMock from fastapi.testclient import TestClient from fastapi import FastAPI -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from modai.modules.model_provider.central_router import CentralModelProviderRouter from modai.modules.model_provider.module import ( ModelProviderModule, @@ -163,7 +159,7 @@ def test_client(self, central_router): def test_get_all_providers_endpoint(self, test_client): """Test GET /models/providers endpoint""" - response = test_client.get("/api/v1/models/providers") + response = test_client.get("/api/models/providers") assert response.status_code == 200 data = response.json() @@ -181,7 +177,7 @@ def test_get_all_providers_endpoint(self, test_client): def test_get_all_models_endpoint(self, test_client): """Test GET /models endpoint""" - response = test_client.get("/api/v1/models") + response = test_client.get("/api/models") assert response.status_code == 200 data = response.json() @@ -208,8 +204,8 @@ def test_get_all_models_endpoint(self, test_client): assert "owned_by" in model def test_get_all_providers_with_pagination(self, test_client): - """Test GET /api/v1/models/providers with pagination""" - response = test_client.get("/api/v1/models/providers?limit=1&offset=0") + """Test GET /api/models/providers with pagination""" + response = test_client.get("/api/models/providers?limit=1&offset=0") assert response.status_code == 200 data = response.json() @@ -229,7 +225,7 @@ def test_get_all_providers_empty(self, mock_session_module): app.include_router(router.router) 
client = TestClient(app) - response = client.get("/api/v1/models/providers") + response = client.get("/api/models/providers") assert response.status_code == 200 data = response.json() @@ -247,7 +243,7 @@ def test_get_all_models_empty(self, mock_session_module): app.include_router(router.router) client = TestClient(app) - response = client.get("/api/v1/models") + response = client.get("/api/models") assert response.status_code == 200 data = response.json() @@ -272,8 +268,8 @@ def test_all_endpoints_reject_unauthenticated_requests(self): client = TestClient(app) endpoints = [ - ("GET", "/api/v1/models/providers"), - ("GET", "/api/v1/models"), + ("GET", "/api/models/providers"), + ("GET", "/api/models"), ] for method, path in endpoints: diff --git a/backend/tests/test_model_provider.py b/backend/src/modai/modules/model_provider/__tests__/test_model_provider.py similarity index 87% rename from backend/tests/test_model_provider.py rename to backend/src/modai/modules/model_provider/__tests__/test_model_provider.py index f416d76..87239e5 100644 --- a/backend/tests/test_model_provider.py +++ b/backend/src/modai/modules/model_provider/__tests__/test_model_provider.py @@ -2,7 +2,6 @@ Tests for LLM Provider Module REST API endpoints. 
""" -import sys import os import pytest from pathlib import Path @@ -11,8 +10,6 @@ from fastapi.testclient import TestClient from fastapi import FastAPI -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from modai.modules.model_provider.openai_provider import OpenAIProviderModule from modai.modules.model_provider_store.module import ModelProvider, ModelProviderStore from modai.modules.session.module import SessionModule, Session @@ -23,6 +20,8 @@ working_dir = Path.cwd() load_dotenv(find_dotenv(str(working_dir / ".env"))) +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + # Force anyio to use asyncio backend only anyio_backend = pytest.fixture(scope="session")(lambda: "asyncio") @@ -44,7 +43,7 @@ def mock_provider_store(self) -> ModelProviderStore: sample_provider = ModelProvider( id="test-id-123", name="TestProvider", - url="https://api.openai.com/v1", + url=OPENAI_BASE_URL, properties=properties, created_at=datetime(2024, 1, 1, 12, 0, 0), updated_at=datetime(2024, 1, 1, 12, 0, 0), @@ -101,8 +100,8 @@ def test_web_module_missing_dependency(self) -> None: def test_get_providers_endpoint( self, test_client: TestClient, mock_provider_store: ModelProviderStore ) -> None: - """Test GET /api/v1/models/providers/openai endpoint""" - response = test_client.get("/api/v1/models/providers/openai") + """Test GET /api/models/providers/openai endpoint""" + response = test_client.get("/api/models/providers/openai") assert response.status_code == 200 data = response.json() @@ -116,7 +115,7 @@ def test_get_providers_endpoint( provider = data["providers"][0] assert provider["id"] == "test-id-123" assert provider["name"] == "TestProvider" - assert provider["base_url"] == "https://api.openai.com/v1" + assert provider["base_url"] == OPENAI_BASE_URL assert provider["properties"]["key"] == "value" # Verify that api_key is available as a direct field assert "api_key" in provider @@ -132,7 +131,7 @@ def 
test_get_providers_with_pagination( self, test_client: TestClient, mock_provider_store: ModelProviderStore ) -> None: """Test GET /models/providers/openai with pagination parameters""" - response = test_client.get("/api/v1/models/providers/openai?limit=10&offset=5") + response = test_client.get("/api/models/providers/openai?limit=10&offset=5") assert response.status_code == 200 data = response.json() @@ -146,29 +145,29 @@ def test_get_providers_with_pagination( def test_get_providers_invalid_pagination(self, test_client): """Test GET /models/providers/openai with invalid pagination parameters""" # Test negative offset - response = test_client.get("/api/v1/models/providers/openai?offset=-1") + response = test_client.get("/api/models/providers/openai?offset=-1") assert response.status_code == 422 # Test limit too large - response = test_client.get("/api/v1/models/providers/openai?limit=2000") + response = test_client.get("/api/models/providers/openai?limit=2000") assert response.status_code == 422 # Test limit too small - response = test_client.get("/api/v1/models/providers/openai?limit=0") + response = test_client.get("/api/models/providers/openai?limit=0") assert response.status_code == 422 def test_get_provider_by_id( self, test_client: TestClient, mock_provider_store: ModelProviderStore ) -> None: """Test GET /models/providers/openai/{id} endpoint""" - response = test_client.get("/api/v1/models/providers/openai/test-id-123") + response = test_client.get("/api/models/providers/openai/test-id-123") assert response.status_code == 200 data = response.json() assert data["id"] == "test-id-123" assert data["name"] == "TestProvider" - assert data["base_url"] == "https://api.openai.com/v1" + assert data["base_url"] == OPENAI_BASE_URL assert data["properties"]["key"] == "value" # Verify that api_key is available as a direct field assert "api_key" in data @@ -185,7 +184,7 @@ def test_get_provider_not_found( """Test GET /models/providers/openai/{id} for non-existent 
provider""" mock_provider_store.get_provider.return_value = None - response = test_client.get("/api/v1/models/providers/openai/nonexistent") + response = test_client.get("/api/models/providers/openai/nonexistent") assert response.status_code == 404 data = response.json() @@ -202,9 +201,7 @@ def test_create_provider( "properties": {"model": "new-model", "temperature": 0.8}, } - response = test_client.post( - "/api/v1/models/providers/openai", json=request_data - ) + response = test_client.post("/api/models/providers/openai", json=request_data) assert response.status_code == 201 data = response.json() @@ -244,9 +241,7 @@ def test_create_provider_validation_error( "properties": {}, } - response = test_client.post( - "/api/v1/models/providers/openai", json=request_data - ) + response = test_client.post("/api/models/providers/openai", json=request_data) assert response.status_code == 400 data = response.json() @@ -256,21 +251,21 @@ def test_create_provider_missing_fields(self, test_client: TestClient) -> None: """Test POST /models/providers/openai with missing required fields""" # Missing name response = test_client.post( - "/api/v1/models/providers/openai", + "/api/models/providers/openai", json={"base_url": "https://api.test.com", "api_key": "test-key"}, ) assert response.status_code == 422 # Missing base_url response = test_client.post( - "/api/v1/models/providers/openai", + "/api/models/providers/openai", json={"name": "TestProvider", "api_key": "test-key"}, ) assert response.status_code == 422 # Missing api_key response = test_client.post( - "/api/v1/models/providers/openai", + "/api/models/providers/openai", json={"name": "TestProvider", "base_url": "https://api.test.com"}, ) assert response.status_code == 422 @@ -287,7 +282,7 @@ def test_update_provider( } response = test_client.put( - "/api/v1/models/providers/openai/existing-id", json=request_data + "/api/models/providers/openai/existing-id", json=request_data ) assert response.status_code == 200 @@ -322,7 
+317,7 @@ def test_update_provider_not_found( "properties": {}, } response = test_client.put( - "/api/v1/models/providers/openai/nonexistent-id", json=request_data + "/api/models/providers/openai/nonexistent-id", json=request_data ) assert response.status_code == 404 @@ -333,7 +328,7 @@ def test_delete_provider( self, test_client: TestClient, mock_provider_store: ModelProviderStore ) -> None: """Test DELETE /models/providers/openai/{id} endpoint""" - response = test_client.delete("/api/v1/models/providers/openai/test-id-123") + response = test_client.delete("/api/models/providers/openai/test-id-123") assert response.status_code == 204 assert response.content == b"" # No content for 204 @@ -346,7 +341,7 @@ def test_delete_provider_idempotent( ) -> None: """Test DELETE /models/providers/openai/{id} is idempotent""" # Even if provider doesn't exist, should return 204 - response = test_client.delete("/api/v1/models/providers/openai/nonexistent") + response = test_client.delete("/api/models/providers/openai/nonexistent") assert response.status_code == 204 mock_provider_store.delete_provider.assert_called_once_with("nonexistent") @@ -361,7 +356,7 @@ def test_endpoint_error_handling( # Since we removed try-catch, exceptions now bubble up and get raised by test client with pytest.raises(Exception, match="Database connection failed"): - test_client.get("/api/v1/models/providers/openai") + test_client.get("/api/models/providers/openai") def test_complex_properties_handling( self, test_client: TestClient, mock_provider_store: ModelProviderStore @@ -380,9 +375,7 @@ def test_complex_properties_handling( "properties": complex_properties, } - response = test_client.post( - "/api/v1/models/providers/openai", json=request_data - ) + response = test_client.post("/api/models/providers/openai", json=request_data) assert response.status_code == 201 @@ -402,7 +395,7 @@ def test_get_models_endpoint( self, test_client: TestClient, mock_provider_store: ModelProviderStore ) -> None: """Test 
GET /models/providers/openai/{provider_id}/models endpoint""" - response = test_client.get("/api/v1/models/providers/openai/test-id-123/models") + response = test_client.get("/api/models/providers/openai/test-id-123/models") assert response.status_code == 200 data = response.json() @@ -430,7 +423,7 @@ def test_get_models_provider_not_found( """Test GET /models/providers/openai/{provider_id}/models for non-existent provider""" mock_provider_store.get_provider.return_value = None - response = test_client.get("/api/v1/models/providers/openai/nonexistent/models") + response = test_client.get("/api/models/providers/openai/nonexistent/models") assert response.status_code == 404 data = response.json() @@ -466,12 +459,12 @@ def test_all_endpoints_reject_unauthenticated_requests( } endpoints = [ - ("GET", "/api/v1/models/providers/openai"), - ("POST", "/api/v1/models/providers/openai", provider_body), - ("GET", "/api/v1/models/providers/openai/some-id"), - ("PUT", "/api/v1/models/providers/openai/some-id", provider_body), - ("DELETE", "/api/v1/models/providers/openai/some-id"), - ("GET", "/api/v1/models/providers/openai/some-id/models"), + ("GET", "/api/models/providers/openai"), + ("POST", "/api/models/providers/openai", provider_body), + ("GET", "/api/models/providers/openai/some-id"), + ("PUT", "/api/models/providers/openai/some-id", provider_body), + ("DELETE", "/api/models/providers/openai/some-id"), + ("GET", "/api/models/providers/openai/some-id/models"), ] for entry in endpoints: diff --git a/backend/src/modai/modules/model_provider/central_router.py b/backend/src/modai/modules/model_provider/central_router.py index 6e8da9a..101f40d 100644 --- a/backend/src/modai/modules/model_provider/central_router.py +++ b/backend/src/modai/modules/model_provider/central_router.py @@ -49,14 +49,14 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): # Add the central route for getting all providers self.router.add_api_route( - 
"/api/v1/models/providers", + "/api/models/providers", self.get_all_providers, methods=["GET"], ) # Add the central route for getting all models self.router.add_api_route( - "/api/v1/models", + "/api/models", self.get_all_models, methods=["GET"], ) diff --git a/backend/src/modai/modules/model_provider/module.py b/backend/src/modai/modules/model_provider/module.py index cd5b86e..05ea69b 100644 --- a/backend/src/modai/modules/model_provider/module.py +++ b/backend/src/modai/modules/model_provider/module.py @@ -81,33 +81,33 @@ def __init__( # Add model provider routes self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}", + f"/api/models/providers/{provider_type_name}", self.get_providers, methods=["GET"], ) self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}", + f"/api/models/providers/{provider_type_name}", self.create_provider, methods=["POST"], status_code=201, ) self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}/{{provider_id}}", + f"/api/models/providers/{provider_type_name}/{{provider_id}}", self.get_provider, methods=["GET"], ) self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}/{{provider_id}}", + f"/api/models/providers/{provider_type_name}/{{provider_id}}", self.update_provider, methods=["PUT"], ) self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}/{{provider_id}}/models", + f"/api/models/providers/{provider_type_name}/{{provider_id}}/models", self.get_models, methods=["GET"], ) self.router.add_api_route( - f"/api/v1/models/providers/{provider_type_name}/{{provider_id}}", + f"/api/models/providers/{provider_type_name}/{{provider_id}}", self.delete_provider, methods=["DELETE"], status_code=204, diff --git a/backend/src/modai/modules/model_provider_store/__tests__/__init__.py b/backend/src/modai/modules/model_provider_store/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/backend/tests/abstract_model_provider_store_test.py b/backend/src/modai/modules/model_provider_store/__tests__/abstract_model_provider_store_test.py similarity index 100% rename from backend/tests/abstract_model_provider_store_test.py rename to backend/src/modai/modules/model_provider_store/__tests__/abstract_model_provider_store_test.py diff --git a/backend/tests/test_sql_model_provider_store.py b/backend/src/modai/modules/model_provider_store/__tests__/test_sql_model_provider_store.py similarity index 95% rename from backend/tests/test_sql_model_provider_store.py rename to backend/src/modai/modules/model_provider_store/__tests__/test_sql_model_provider_store.py index ceeb9a8..d62c00e 100644 --- a/backend/tests/test_sql_model_provider_store.py +++ b/backend/src/modai/modules/model_provider_store/__tests__/test_sql_model_provider_store.py @@ -1,13 +1,12 @@ -import sys -import os import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.model_provider_store.sql_model_provider_store import ( SQLAlchemyModelProviderStore, ) -from tests.abstract_model_provider_store_test import AbstractModelProviderStoreTestBase +from modai.modules.model_provider_store.__tests__.abstract_model_provider_store_test import ( + AbstractModelProviderStoreTestBase, +) # Force anyio to use asyncio backend only anyio_backend = pytest.fixture(scope="session")(lambda: "asyncio") diff --git a/backend/src/modai/modules/session/__tests__/__init__.py b/backend/src/modai/modules/session/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_session.py b/backend/src/modai/modules/session/__tests__/test_session.py similarity index 98% rename from backend/tests/test_session.py rename to backend/src/modai/modules/session/__tests__/test_session.py index 2dba2c7..b151faa 100644 --- a/backend/tests/test_session.py +++ 
b/backend/src/modai/modules/session/__tests__/test_session.py @@ -1,11 +1,7 @@ -import sys -import os import pytest from unittest.mock import MagicMock import jwt from datetime import datetime, timedelta, timezone - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.session.jwt_session_module import JwtSessionModule from modai.modules.session.module import Session diff --git a/backend/src/modai/modules/startup_config/__tests__/__init__.py b/backend/src/modai/modules/startup_config/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_config_loader.py b/backend/src/modai/modules/startup_config/__tests__/test_config_loader.py similarity index 95% rename from backend/tests/test_config_loader.py rename to backend/src/modai/modules/startup_config/__tests__/test_config_loader.py index f108d0d..8eaad55 100644 --- a/backend/tests/test_config_loader.py +++ b/backend/src/modai/modules/startup_config/__tests__/test_config_loader.py @@ -1,11 +1,7 @@ -import sys -import os import yaml import pytest from pathlib import Path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from modai.module import ModuleDependencies from modai.modules.startup_config.yaml_config_module import YamlConfigModule diff --git a/backend/src/modai/modules/user/__tests__/__init__.py b/backend/src/modai/modules/user/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_user.py b/backend/src/modai/modules/user/__tests__/test_user.py similarity index 95% rename from backend/tests/test_user.py rename to backend/src/modai/modules/user/__tests__/test_user.py index fd8254c..460cfab 100644 --- a/backend/tests/test_user.py +++ b/backend/src/modai/modules/user/__tests__/test_user.py @@ -1,12 +1,8 @@ -import sys -import os import pytest from unittest.mock import Mock, MagicMock, AsyncMock from fastapi.testclient import 
TestClient from fastapi import FastAPI from datetime import datetime - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.user.simple_user_module import SimpleUserModule from modai.modules.session.module import SessionModule, Session @@ -93,7 +89,7 @@ def test_get_current_user_success(client): user_store.get_user_by_id.return_value = test_user # Call the endpoint - response = test_client.get("/api/v1/user") + response = test_client.get("/api/user") assert response.status_code == 200 response_data = response.json() @@ -118,7 +114,7 @@ def test_get_current_user_user_not_found(client): user_store.get_user_by_id.return_value = None # User not found # Call the endpoint - response = test_client.get("/api/v1/user") + response = test_client.get("/api/user") assert response.status_code == 404 assert response.json()["detail"] == "User not found" @@ -140,7 +136,7 @@ def test_get_current_user_invalid_session(client): ) # Call the endpoint - response = test_client.get("/api/v1/user") + response = test_client.get("/api/user") assert response.status_code == 401 assert response.json()["detail"] == "Missing, invalid or expired session" diff --git a/backend/src/modai/modules/user/module.py b/backend/src/modai/modules/user/module.py index 6ab7692..874041f 100644 --- a/backend/src/modai/modules/user/module.py +++ b/backend/src/modai/modules/user/module.py @@ -27,9 +27,7 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): self.router = APIRouter() # This makes it a web module # Add user routes - self.router.add_api_route( - "/api/v1/user", self.get_current_user, methods=["GET"] - ) + self.router.add_api_route("/api/user", self.get_current_user, methods=["GET"]) @abstractmethod async def get_current_user(self, request: Request) -> UserResponse: diff --git a/backend/src/modai/modules/user_settings/README.md b/backend/src/modai/modules/user_settings/README.md index 
0bae094..1cf2d83 100644 --- a/backend/src/modai/modules/user_settings/README.md +++ b/backend/src/modai/modules/user_settings/README.md @@ -10,10 +10,10 @@ The User Settings module provides REST API endpoints for managing user-specific The module provides four endpoints for managing user settings: -1. **Bulk operations** (`/api/v1/user/{user_id}/settings`): Get all settings or update multiple setting types at once -2. **Single type operations** (`/api/v1/user/{user_id}/settings/{setting_type}`): Get or update a specific setting type +1. **Bulk operations** (`/api/user/{user_id}/settings`): Get all settings or update multiple setting types at once +2. **Single type operations** (`/api/user/{user_id}/settings/{setting_type}`): Get or update a specific setting type -### GET /api/v1/user/{user_id}/settings +### GET /api/user/{user_id}/settings Retrieves all settings for a specific user. **Authentication**: Required (session-based) @@ -40,7 +40,7 @@ Retrieves all settings for a specific user. - `403 Forbidden`: User doesn't have permission to access these settings - `404 Not Found`: User not found -### PUT /api/v1/user/{user_id}/settings +### PUT /api/user/{user_id}/settings Updates settings for a specific user. Only provided setting types are updated, existing settings for other types are preserved. **Authentication**: Required (session-based) @@ -81,7 +81,7 @@ Updates settings for a specific user. Only provided setting types are updated, e - `404 Not Found`: User not found - `422 Unprocessable Entity`: Invalid settings data format -### GET /api/v1/user/{user_id}/settings/{setting_type} +### GET /api/user/{user_id}/settings/{setting_type} Retrieves a specific setting type for a user. **Authentication**: Required (session-based) @@ -103,7 +103,7 @@ Retrieves a specific setting type for a user. 
- `403 Forbidden`: User doesn't have permission to access these settings - `404 Not Found`: User not found (returns empty settings if setting type doesn't exist) -### PUT /api/v1/user/{user_id}/settings/{setting_type} +### PUT /api/user/{user_id}/settings/{setting_type} Updates a specific setting type for a user. This completely replaces the setting type data. **Authentication**: Required (session-based) @@ -190,17 +190,17 @@ user_settings: ### Frontend Integration ```typescript // Get all user settings -const allSettings = await fetch('/api/v1/user/123/settings', { +const allSettings = await fetch('/api/user/123/settings', { credentials: 'include' }).then(r => r.json()); // Get specific setting type -const themeSettings = await fetch('/api/v1/user/123/settings/theme', { +const themeSettings = await fetch('/api/user/123/settings/theme', { credentials: 'include' }).then(r => r.json()); // Update all settings (merge) -await fetch('/api/v1/user/123/settings', { +await fetch('/api/user/123/settings', { method: 'PUT', credentials: 'include', headers: { 'Content-Type': 'application/json' }, @@ -212,7 +212,7 @@ await fetch('/api/v1/user/123/settings', { }); // Update specific setting type (replace) -await fetch('/api/v1/user/123/settings/theme', { +await fetch('/api/user/123/settings/theme', { method: 'PUT', credentials: 'include', headers: { 'Content-Type': 'application/json' }, diff --git a/backend/src/modai/modules/user_settings/__tests__/__init__.py b/backend/src/modai/modules/user_settings/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_user_settings.py b/backend/src/modai/modules/user_settings/__tests__/test_user_settings.py similarity index 100% rename from backend/tests/test_user_settings.py rename to backend/src/modai/modules/user_settings/__tests__/test_user_settings.py diff --git a/backend/src/modai/modules/user_settings/module.py b/backend/src/modai/modules/user_settings/module.py index 002d585..716c7e4 100644 
--- a/backend/src/modai/modules/user_settings/module.py +++ b/backend/src/modai/modules/user_settings/module.py @@ -50,20 +50,20 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): # Add user settings routes self.router.add_api_route( - "/api/v1/user/{user_id}/settings", self.get_user_settings, methods=["GET"] + "/api/user/{user_id}/settings", self.get_user_settings, methods=["GET"] ) self.router.add_api_route( - "/api/v1/user/{user_id}/settings", + "/api/user/{user_id}/settings", self.update_user_settings, methods=["PUT"], ) self.router.add_api_route( - "/api/v1/user/{user_id}/settings/{setting_type}", + "/api/user/{user_id}/settings/{setting_type}", self.get_user_setting_type, methods=["GET"], ) self.router.add_api_route( - "/api/v1/user/{user_id}/settings/{setting_type}", + "/api/user/{user_id}/settings/{setting_type}", self.update_user_setting_type, methods=["PUT"], ) diff --git a/backend/src/modai/modules/user_settings_store/__tests__/__init__.py b/backend/src/modai/modules/user_settings_store/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_user_settings_store.py b/backend/src/modai/modules/user_settings_store/__tests__/test_user_settings_store.py similarity index 100% rename from backend/tests/test_user_settings_store.py rename to backend/src/modai/modules/user_settings_store/__tests__/test_user_settings_store.py diff --git a/backend/src/modai/modules/user_store/__tests__/__init__.py b/backend/src/modai/modules/user_store/__tests__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/abstract_user_store_test.py b/backend/src/modai/modules/user_store/__tests__/abstract_user_store_test.py similarity index 100% rename from backend/tests/abstract_user_store_test.py rename to backend/src/modai/modules/user_store/__tests__/abstract_user_store_test.py diff --git a/backend/tests/test_inmemory_user_store.py 
b/backend/src/modai/modules/user_store/__tests__/test_inmemory_user_store.py similarity index 76% rename from backend/tests/test_inmemory_user_store.py rename to backend/src/modai/modules/user_store/__tests__/test_inmemory_user_store.py index baa8bf3..63ca852 100644 --- a/backend/tests/test_inmemory_user_store.py +++ b/backend/src/modai/modules/user_store/__tests__/test_inmemory_user_store.py @@ -1,11 +1,10 @@ -import sys -import os import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.user_store.inmemory_user_store import InMemoryUserStore -from tests.abstract_user_store_test import AbstractUserStoreTestBase +from modai.modules.user_store.__tests__.abstract_user_store_test import ( + AbstractUserStoreTestBase, +) # Force anyio to use asyncio backend only anyio_backend = pytest.fixture(scope="session")(lambda: "asyncio") diff --git a/backend/tests/test_sql_model_user_store.py b/backend/src/modai/modules/user_store/__tests__/test_sql_model_user_store.py similarity index 93% rename from backend/tests/test_sql_model_user_store.py rename to backend/src/modai/modules/user_store/__tests__/test_sql_model_user_store.py index cbffc18..bfb31b6 100644 --- a/backend/tests/test_sql_model_user_store.py +++ b/backend/src/modai/modules/user_store/__tests__/test_sql_model_user_store.py @@ -1,11 +1,10 @@ -import sys -import os import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from modai.module import ModuleDependencies from modai.modules.user_store.sql_model_user_store import SQLAlchemyUserStore -from tests.abstract_user_store_test import AbstractUserStoreTestBase +from modai.modules.user_store.__tests__.abstract_user_store_test import ( + AbstractUserStoreTestBase, +) # Force anyio to use asyncio backend only anyio_backend = pytest.fixture(scope="session")(lambda: "asyncio") diff --git a/e2e_tests/tests_omni_full/.gitignore 
b/e2e_tests/tests_omni_full/.gitignore deleted file mode 100644 index 039fa90..0000000 --- a/e2e_tests/tests_omni_full/.gitignore +++ /dev/null @@ -1,27 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -pnpm-debug.log* -lerna-debug.log* -public/modules.json - -node_modules/ -dist/ -dist-ssr -*.local -playwright-report/ -test-results/ - -# Editor directories and files -.vscode/* -!.vscode/extensions.json -.idea -.DS_Store -*.suo -*.ntvs* -*.njsproj -*.sln -*.sw? diff --git a/e2e_tests/tests_omni_full/.gitignore b/e2e_tests/tests_omni_full/.gitignore new file mode 120000 index 0000000..6ef08f9 --- /dev/null +++ b/e2e_tests/tests_omni_full/.gitignore @@ -0,0 +1 @@ +../../.gitignore \ No newline at end of file diff --git a/e2e_tests/tests_omni_full/playwright.config.ts b/e2e_tests/tests_omni_full/playwright.config.ts index 37518f2..ead6215 100644 --- a/e2e_tests/tests_omni_full/playwright.config.ts +++ b/e2e_tests/tests_omni_full/playwright.config.ts @@ -37,7 +37,7 @@ export default defineConfig({ }, { command: "cd ../../backend && rm -f *.db && uv run uvicorn modai.main:app", - url: "http://localhost:8000/api/v1/health", + url: "http://localhost:8000/api/health", reuseExistingServer: !process.env.CI, }, { diff --git a/e2e_tests/tests_omni_light/.gitignore b/e2e_tests/tests_omni_light/.gitignore deleted file mode 100644 index 039fa90..0000000 --- a/e2e_tests/tests_omni_light/.gitignore +++ /dev/null @@ -1,27 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -pnpm-debug.log* -lerna-debug.log* -public/modules.json - -node_modules/ -dist/ -dist-ssr -*.local -playwright-report/ -test-results/ - -# Editor directories and files -.vscode/* -!.vscode/extensions.json -.idea -.DS_Store -*.suo -*.ntvs* -*.njsproj -*.sln -*.sw? 
diff --git a/e2e_tests/tests_omni_light/.gitignore b/e2e_tests/tests_omni_light/.gitignore new file mode 120000 index 0000000..6ef08f9 --- /dev/null +++ b/e2e_tests/tests_omni_light/.gitignore @@ -0,0 +1 @@ +../../.gitignore \ No newline at end of file diff --git a/e2e_tests/tests_omni_light/src/mock-openai-server.js b/e2e_tests/tests_omni_light/src/mock-openai-server.js deleted file mode 120000 index 2d98f53..0000000 --- a/e2e_tests/tests_omni_light/src/mock-openai-server.js +++ /dev/null @@ -1 +0,0 @@ -../../tests_omni_full/src/mock-openai-server.js \ No newline at end of file diff --git a/frontend_omni/.gitignore b/frontend_omni/.gitignore deleted file mode 100644 index 039fa90..0000000 --- a/frontend_omni/.gitignore +++ /dev/null @@ -1,27 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -pnpm-debug.log* -lerna-debug.log* -public/modules.json - -node_modules/ -dist/ -dist-ssr -*.local -playwright-report/ -test-results/ - -# Editor directories and files -.vscode/* -!.vscode/extensions.json -.idea -.DS_Store -*.suo -*.ntvs* -*.njsproj -*.sln -*.sw? 
diff --git a/frontend_omni/.gitignore b/frontend_omni/.gitignore new file mode 120000 index 0000000..5a19b83 --- /dev/null +++ b/frontend_omni/.gitignore @@ -0,0 +1 @@ +../.gitignore \ No newline at end of file diff --git a/frontend_omni/src/modules/authentication-service/AuthenticationService.ts b/frontend_omni/src/modules/authentication-service/AuthenticationService.ts index 8cffa81..c524d4e 100644 --- a/frontend_omni/src/modules/authentication-service/AuthenticationService.ts +++ b/frontend_omni/src/modules/authentication-service/AuthenticationService.ts @@ -23,7 +23,7 @@ export class AuthenticationService implements AuthService { * @throws Error if login fails */ async login(credentials: LoginRequest): Promise { - const response = await fetch("/api/v1/auth/login", { + const response = await fetch("/api/auth/login", { method: "POST", credentials: "include", // Include cookies for session authentication headers: { @@ -48,7 +48,7 @@ export class AuthenticationService implements AuthService { * @throws Error if signup fails */ async signup(credentials: SignupRequest): Promise { - const response = await fetch("/api/v1/auth/signup", { + const response = await fetch("/api/auth/signup", { method: "POST", credentials: "include", // Include cookies for session authentication headers: { @@ -72,7 +72,7 @@ export class AuthenticationService implements AuthService { * @throws Error if logout fails */ async logout(): Promise { - const response = await fetch("/api/v1/auth/logout", { + const response = await fetch("/api/auth/logout", { method: "POST", credentials: "include", // Include cookies for session authentication headers: { diff --git a/frontend_omni/src/modules/chat-service/OpenAIService.ts b/frontend_omni/src/modules/chat-service/OpenAIService.ts index 698d6a7..d9e22d0 100644 --- a/frontend_omni/src/modules/chat-service/OpenAIService.ts +++ b/frontend_omni/src/modules/chat-service/OpenAIService.ts @@ -100,14 +100,14 @@ export abstract class OpenAIChatService implements 
ChatService { /** * OpenAI Chat Service for use with a backend. - * Routes requests through the backend /api/v1 endpoint which handles + * Routes requests through the backend /api endpoint which handles * provider routing and authentication. */ export class WithBackendOpenAIChatService extends OpenAIChatService { protected createOpenAI(_modelId: string): OpenAI { return new OpenAI({ apiKey: "not-needed-backend-handles-auth", - baseURL: `${window.location.origin}/api/v1`, + baseURL: `${window.location.origin}/api`, dangerouslyAllowBrowser: true, }); } diff --git a/frontend_omni/src/modules/llm-provider-service/LLMRestProviderService.tsx b/frontend_omni/src/modules/llm-provider-service/LLMRestProviderService.tsx index 4e9ae12..f19c10d 100644 --- a/frontend_omni/src/modules/llm-provider-service/LLMRestProviderService.tsx +++ b/frontend_omni/src/modules/llm-provider-service/LLMRestProviderService.tsx @@ -45,7 +45,7 @@ class LLMRestProviderService implements ProviderService { * Get all models from all providers (via /models endpoint) */ async getAllModels(): Promise { - const response = await fetch(`/api/v1/models`); + const response = await fetch(`/api/models`); if (!response.ok) { throw new Error( @@ -61,7 +61,7 @@ class LLMRestProviderService implements ProviderService { * Get all providers from all types */ async getAllProviders(): Promise { - const response = await fetch(`/api/v1/models/providers`); + const response = await fetch(`/api/models/providers`); if (!response.ok) { throw new Error( @@ -83,7 +83,7 @@ class LLMRestProviderService implements ProviderService { typeof providerType === "string" ? providerType : providerType.value; - const response = await fetch(`/api/v1/models/providers/${typeValue}`); + const response = await fetch(`/api/models/providers/${typeValue}`); if (!response.ok) { throw new Error( @@ -107,7 +107,7 @@ class LLMRestProviderService implements ProviderService { ? 
providerType : providerType.value; const response = await fetch( - `/api/v1/models/providers/${typeValue}/${providerId}`, + `/api/models/providers/${typeValue}/${providerId}`, ); if (!response.ok) { @@ -131,7 +131,7 @@ class LLMRestProviderService implements ProviderService { ? providerType : providerType.value; const response = await fetch( - `/api/v1/models/providers/${typeValue}/${providerId}/models`, + `/api/models/providers/${typeValue}/${providerId}/models`, ); if (!response.ok) { @@ -155,7 +155,7 @@ class LLMRestProviderService implements ProviderService { typeof providerType === "string" ? providerType : providerType.value; - const response = await fetch(`/api/v1/models/providers/${typeValue}`, { + const response = await fetch(`/api/models/providers/${typeValue}`, { method: "POST", headers: { "Content-Type": "application/json", @@ -179,7 +179,7 @@ class LLMRestProviderService implements ProviderService { ? providerType : providerType.value; const response = await fetch( - `/api/v1/models/providers/${typeValue}/${providerId}`, + `/api/models/providers/${typeValue}/${providerId}`, { method: "PUT", headers: { @@ -204,7 +204,7 @@ class LLMRestProviderService implements ProviderService { ? 
providerType : providerType.value; const response = await fetch( - `/api/v1/models/providers/${typeValue}/${providerId}`, + `/api/models/providers/${typeValue}/${providerId}`, { method: "DELETE", }, diff --git a/frontend_omni/src/modules/user-service/HttpUserService.ts b/frontend_omni/src/modules/user-service/HttpUserService.ts index 4b221e6..bf88e46 100644 --- a/frontend_omni/src/modules/user-service/HttpUserService.ts +++ b/frontend_omni/src/modules/user-service/HttpUserService.ts @@ -11,7 +11,7 @@ export class HttpUserService implements UserService { * @throws Error if the request fails or user is not authenticated */ async fetchCurrentUser(): Promise { - const response = await fetch("/api/v1/user", { + const response = await fetch("/api/user", { method: "GET", credentials: "include", // Include cookies for session authentication headers: {