diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000..5c0758d --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,15 @@ +{ + "hooks": { + "SessionStart": [ + { + "matcher": "", + "hooks": [ + { + "type": "command", + "command": "pip install -e '.[api,mcp,dev]' --quiet 2>/dev/null || pip install -e '.[dev]' --quiet 2>/dev/null || true" + } + ] + } + ] + } +} diff --git a/.env.example b/.env.example index b3efb34..190c738 100644 --- a/.env.example +++ b/.env.example @@ -13,5 +13,8 @@ KALSHI_PRIVATE_KEY_PATH=kalshi_private_key.pem MANIFOLD_API_KEY=manifold_your_api_key_here # FastAPI settings (for the local web app) -SECRET_KEY=change-this-secret-key-in-production +# SECRET_KEY is auto-generated in dev mode. For production, generate a strong key: +# python -c "import secrets; print(secrets.token_urlsafe(64))" +# SECRET_KEY=your-secure-random-string-here DATABASE_URL=sqlite:///./data/prediction_analyzer.db +# ENVIRONMENT=production # Uncomment to enforce SECRET_KEY requirement diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..46ae1f0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +max-line-length = 100 +extend-ignore = E203, W503 +exclude = + .git, + __pycache__, + build, + dist, + *.egg-info, + venv, + env +per-file-ignores = + __init__.py:F401 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d3f95c5 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,73 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ 
matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + + - name: Install dependencies + run: | + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pip install -e ".[api,dev]" + else + pip install -e ".[api,mcp,dev]" + fi + + - name: Run tests + run: | + if [[ "${{ matrix.python-version }}" == "3.9" ]]; then + pytest --cov=prediction_analyzer --cov-report=xml -q --ignore=tests/mcp + else + pytest --cov=prediction_analyzer --cov=prediction_mcp --cov-report=xml -q + fi + + - name: Upload coverage + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Check formatting (black) + run: black --check prediction_analyzer prediction_mcp tests + + - name: Lint (flake8) + run: flake8 prediction_analyzer prediction_mcp diff --git a/.gitignore b/.gitignore index 438f02d..d5cd6bc 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,10 @@ chart_*.png pro_chart_*.html enhanced_chart_*.html global_dashboard.html +charts_output/ + +# Type checking +.mypy_cache/ # SQLite database *.db diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c5f2967..a996ed1 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -317,7 +317,7 @@ class MarketProvider(ABC): |----------|------|-----------|------------| | Limitless | `X-API-Key` header | Page-based (`page` param) | Native (API provides PnL) | | Polymarket | None (public API, wallet as query param) | Timestamp window narrowing | FIFO calculator | -| Kalshi | RSA-PSS per-request signing | Cursor-based (`cursor` param) | Position endpoint + distribution | +| Kalshi | RSA-PSS per-request signing (key cleared from memory after use) | Cursor-based (`cursor` 
param) | Position endpoint + distribution | | Manifold | `Authorization: Key ...` header | Cursor-based (`before` param) | FIFO calculator | ### FIFO PnL Calculator (`providers/pnl_calculator.py`) @@ -331,19 +331,24 @@ Responsible for loading and normalizing trade data from multiple sources: - **Supported formats**: JSON, CSV, XLSX - **Auto-detection**: Uses ProviderRegistry to detect file format from field signatures - **Timestamp parsing**: Handles Unix epochs (seconds/milliseconds), RFC 3339, ISO 8601 -- **Unit conversion**: Converts API micro-units (6 decimals) to standard units +- **Unit conversion**: Converts API micro-units using `USDC_DECIMALS` (1,000,000) constant - **Field mapping**: Maps various API field names to internal format +- **NaN/Infinity sanitization**: `sanitize_numeric()` replaces NaN → 0.0, Inf → ±`INF_CAP` Key functions: - `load_trades(file_path)`: Main entry point -- auto-detects provider format - `save_trades(trades, file_path)`: Save trades to JSON - `_parse_timestamp(value)`: Robust timestamp parsing +- `sanitize_numeric(value)`: Guards against NaN/Infinity for JSON serialization + +Key constants: +- `INF_CAP = 999999.99` — shared ceiling for infinite values across the codebase ### PnL Calculator (`pnl.py`) Calculates profit/loss metrics: -- `calculate_pnl(trades)`: Returns DataFrame with cumulative PnL +- `calculate_pnl(trades)`: Returns DataFrame with cumulative PnL (uses `Decimal` accumulation to avoid float drift) - `calculate_global_pnl_summary(trades)`: Aggregate statistics with currency separation -- top-level totals use real-money currencies (USD/USDC) only; play-money (MANA) reported separately under `by_currency`; also includes `by_source` breakdown - `calculate_market_pnl_summary(trades)`: Per-market statistics - `calculate_market_pnl(trades)`: Breakdown by market @@ -356,6 +361,11 @@ Metrics calculated: - Total invested/returned - Per-currency and per-source breakdowns +**Numeric precision notes:** +- Cumulative PnL is 
computed using `decimal.Decimal` accumulation, then stored back as `float`. +- Infinite values (e.g. profit factor with zero losses) are capped at `INF_CAP` (999999.99), defined in `trade_loader.py` and shared across the codebase. +- DB monetary columns use `Numeric(18,8)` to reduce rounding in storage. + ### Filters (`filters.py` + `trade_filter.py`) Advanced filtering capabilities: @@ -404,12 +414,14 @@ Four chart types with different use cases: ### MCP Server (`prediction_mcp/`) -Model Context Protocol server providing 18 tools across 7 modules: +Model Context Protocol server implementing all three MCP primitives: - **Transport**: stdio (Claude Code) or HTTP/SSE (web agents) - **State**: In-memory session with optional SQLite persistence - **Multi-source**: Session tracks multiple provider sources simultaneously -- **Tools**: data (4), analysis (5), filter (1), chart (2), export (1), portfolio (4), tax (1) +- **Tools**: 18 tools across 7 modules — data (4), analysis (5), filter (1), chart (2), export (1), portfolio (4), tax (1) +- **Resources**: Dynamic resources exposing session state — `prediction://trades/summary`, `prediction://trades/markets`, `prediction://trades/filters` +- **Prompts**: 3 prompt templates — `analyze_portfolio` (with risk/performance/tax focus), `compare_periods`, `daily_report` Key features: - `fetch_trades` tool accepts `provider` parameter with auto-detection @@ -421,11 +433,15 @@ Key features: REST API with JWT authentication: -- Trade upload with auto-detection of provider format +- Trade upload with auto-detection of provider format (10 MB upload limit) - Source-based filtering (`?source=polymarket`) -- `/trades/providers` endpoint listing available providers +- `/trades/providers` endpoint listing available providers (requires authentication) - CSV/JSON export with source and currency fields -- SQLAlchemy models include `source` and `currency` columns +- SQLAlchemy models use `Numeric(18,8)` for monetary columns (price, shares, 
cost, pnl) +- Security headers middleware (X-Frame-Options, X-Content-Type-Options, HSTS, etc.) +- Per-IP rate limiting with key eviction (bounded memory; single-process only) +- SECRET_KEY auto-generated in dev mode; must be explicitly set for production +- Minimum password length: 8 characters ## User Interfaces @@ -625,6 +641,58 @@ pytest # Run all tests pytest --cov=prediction_analyzer # With coverage ``` +## Security Architecture + +### Web API Security Layers + +``` +Request → Rate Limiter → Security Headers → CORS → Auth (JWT) → Route Handler +``` + +1. **Rate Limiting** (per-IP, in-memory sliding window) + - Auth endpoints: 5 req/60s + - General endpoints: 60 req/60s + - Key eviction at 10,000 keys to bound memory + - **Limitation**: Single-process only. For multi-worker deployments, replace with Redis-backed solution. + +2. **Security Headers** (middleware on all responses) + - `X-Content-Type-Options: nosniff` + - `X-Frame-Options: DENY` + - `Referrer-Policy: strict-origin-when-cross-origin` + - `X-XSS-Protection: 1; mode=block` + - `Permissions-Policy: geolocation=(), camera=(), microphone=()` + - `Strict-Transport-Security` (HTTPS only) + +3. **CORS** + - Explicit origins list with credentials; wildcard without credentials + - Methods restricted to `GET, POST, PUT, PATCH, DELETE, OPTIONS` + - Headers restricted to `Authorization, Content-Type, Accept` + +4. **Authentication** + - JWT tokens with HS256 (symmetric), includes `exp`, `iat`, `iss`, `aud` claims + - Passwords hashed with Argon2 (minimum 8 characters) + - SECRET_KEY: auto-generated random key in dev; must be set via env var in production + +5. 
**Upload Protection** + - 10 MB file size limit enforced before processing + - SHA-256 dedup prevents duplicate uploads + - Temporary files cleaned up in `finally` block + +### Provider Credential Security + +- API keys are never logged, printed, or serialized +- Kalshi RSA private key cleared from memory (`self._private_key = None`) after each `fetch_trades` call +- `.env` and `*.pem` files excluded from version control via `.gitignore` +- Polymarket wallet address is not logged (removed in security audit) + +### Numeric Precision Invariants + +- **DB storage**: `Numeric(18,8)` for all monetary columns (price, shares, cost, pnl) +- **Cumulative PnL**: Computed via `decimal.Decimal` accumulation, not float `cumsum()` +- **Infinity cap**: Unified to `INF_CAP = 999999.99` (defined in `trade_loader.py`) +- **USDC conversion**: Uses named constant `USDC_DECIMALS = 1_000_000` +- **NaN/Infinity sanitization**: Applied at serialization boundaries (JSON export, MCP responses) + ## Design Principles 1. **Modularity**: Each component has a single responsibility diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..ebdad55 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,39 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [1.0.0] - 2026-03-10 + +### Added +- Multi-provider support: Limitless Exchange, Polymarket, Kalshi, Manifold Markets +- Provider auto-detection from API key prefix and file field signatures +- FIFO PnL calculator for providers without native PnL (Polymarket, Manifold) +- MCP server with 18 tools across 7 modules (stdio + SSE transports) +- FastAPI web application with JWT authentication +- SQLite session persistence for the MCP server +- Interactive CLI menu system for novice users +- Tkinter desktop GUI with provider selection +- Four chart types: simple (matplotlib), pro (Plotly), enhanced, global dashboard +- Advanced trading metrics: Sharpe, Sortino, drawdown, profit factor, streaks +- Portfolio tools: open positions, concentration risk, drawdown analysis, period comparison +- Tax reporting with FIFO/LIFO/average cost basis methods +- CSV, XLSX, and JSON export +- LLM-friendly error handling with recovery hints +- Input validation with case normalization for LLM agents +- NaN/Infinity sanitization across all serialization boundaries + +### Security +- Security headers middleware (X-Frame-Options, X-Content-Type-Options, HSTS, etc.) +- Per-IP rate limiting with key eviction (5 req/min auth, 60 req/min general) +- 10 MB file upload limit with SHA-256 deduplication +- Argon2 password hashing (minimum 8 characters) +- SECRET_KEY auto-generated in dev, required in production +- Kalshi RSA private key cleared from memory after use +- `Numeric(18,8)` for all DB monetary columns (replacing Float) +- `decimal.Decimal` accumulation for cumulative PnL +- CORS with restricted methods/headers +- All endpoints authenticated (including `/trades/providers`) +- API keys never logged; only env var names in error messages diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ca65e00 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,107 @@ +# CLAUDE.md — Coding Agent Guide + +This file provides context for AI coding agents (Claude Code, Cursor, etc.) 
working on this repository. + +## Quick Reference + +```bash +# Install (all extras) +pip install -e ".[api,mcp,dev]" + +# Run tests +pytest + +# Run tests with coverage +pytest --cov=prediction_analyzer --cov=prediction_mcp + +# Lint +flake8 prediction_analyzer prediction_mcp + +# Format +black prediction_analyzer prediction_mcp tests + +# Type check +mypy prediction_analyzer prediction_mcp + +# Start MCP server (stdio) +python -m prediction_mcp + +# Start web API +python run_api.py +``` + +## Project Structure + +- `prediction_analyzer/` — Core library: trade loading, PnL, charts, filters, providers, FastAPI web app +- `prediction_mcp/` — MCP server: 18 tools across 7 modules, stdio + SSE transports +- `tests/` — pytest suite: `static_patterns/` (unit), `mcp/` (integration), `api/` (API) +- `gui.py` — Tkinter desktop GUI +- `run.py` / `run_gui.py` / `run_api.py` — Entry point scripts + +## Critical Invariants + +These invariants MUST be preserved. Breaking them causes cascading failures: + +1. **Trade dataclass has exactly 13 fields** (in `trade_loader.py`): + `market, market_slug, timestamp, price, shares, cost, type, side, pnl, pnl_is_set, tx_hash, source, currency` + +2. **`pnl_is_set` semantics**: `True` means provider explicitly set PnL (including legitimate zero/breakeven). `False` means unset — FIFO calculator may update it. Never overwrite `pnl_is_set=True` trades. + +3. **Currency separation in global summaries**: Top-level totals use real-money (USD/USDC) only. Play-money (MANA) is reported separately under `by_currency`. See `calculate_global_pnl_summary()`. + +4. **`INF_CAP = 999999.99`** (in `trade_loader.py`): Shared ceiling for infinite values (profit factor, etc.). Import from `trade_loader`, don't hardcode. + +5. **`sanitize_numeric()`**: Must be called on all float values before JSON serialization. Converts NaN → 0.0, Inf → ±INF_CAP. + +6. 
**Provider auto-detection**: Key prefix determines provider — `lmts_` → Limitless, `0x` → Polymarket, `kalshi_` → Kalshi, `manifold_` → Manifold. File format detection uses field signatures. + +7. **MCP stdio transport**: ALL logging MUST go to stderr. Any stdout output breaks the JSON-RPC protocol. + +8. **DB monetary columns**: Use `Numeric(18, 8)`, never `Float`, for price/shares/cost/pnl in SQLAlchemy models. + +## Common Pitfalls + +- **Never log API keys or credentials**. Only log env var *names* in error messages. +- **Never use `float` cumsum** for PnL accumulation — use `decimal.Decimal` loop (see `pnl.py`). +- **Kalshi private key** must be cleared from memory after use (`self._private_key = None` in `finally` block). +- **MCP tool handlers** must be wrapped with `@safe_tool` decorator (from `prediction_mcp/errors.py`). Don't add try/except in tool handlers. +- **Filter parameters** must be validated through `prediction_mcp/validators.py` before use. LLMs send wrong cases, NaN, and garbage — the validators handle normalization. +- **Don't import from `prediction_mcp` inside `prediction_analyzer`**. The dependency is one-way: `prediction_mcp` → `prediction_analyzer`. + +## Adding a New Provider + +1. Create `prediction_analyzer/providers/<name>.py` implementing `MarketProvider` ABC +2. Register in `prediction_analyzer/providers/__init__.py` +3. Add config to `PROVIDER_CONFIGS` in `config.py` +4. Add env var to `utils/auth.py` mapping and `.env.example` +5. Add sample data to `data/samples/` +6. Update tests in `tests/static_patterns/test_config_integrity.py` + +## Adding a New MCP Tool + +1. Add tool definition to the appropriate `prediction_mcp/tools/<module>_tools.py` +2. Add handler function with `@safe_tool` decorator +3. Wire it into the module's `handle_tool()` dispatcher +4. If the tool modifies session state, add its name to `_STATE_MODIFYING_TOOLS` in `server.py` +5. 
Add tests in `tests/mcp/test_<module>_tools.py` + +## Test Conventions + +- Fixtures live in `tests/conftest.py` (shared) and `tests/mcp/conftest.py` (MCP-specific) +- Use `sample_trade_factory` fixture for creating trades with custom attributes +- MCP tool tests call `handle_tool(name, args)` directly — no network needed +- API tests use `TestClient` from FastAPI with an in-memory SQLite database +- All tests must pass with `pytest` from the project root + +## Environment Variables + +| Variable | Purpose | +|----------|---------| +| `LIMITLESS_API_KEY` | Limitless Exchange API key (prefix: `lmts_`) | +| `POLYMARKET_WALLET` | Polymarket wallet address (prefix: `0x`) | +| `KALSHI_API_KEY_ID` | Kalshi API key ID | +| `KALSHI_PRIVATE_KEY_PATH` | Path to Kalshi RSA private key PEM file | +| `MANIFOLD_API_KEY` | Manifold Markets API key (prefix: `manifold_`) | +| `SECRET_KEY` | JWT signing key (auto-generated in dev, required in production) | +| `DATABASE_URL` | SQLAlchemy database URL (default: SQLite) | +| `PREDICTION_MCP_DB` | SQLite path for MCP session persistence | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..eb0ee7e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,70 @@ +# Contributing to Prediction Analyzer + +Thank you for your interest in contributing! This guide covers setup, standards, and the PR process. + +## Development Setup + +```bash +# Clone and install in editable mode with all extras +git clone https://github.com/Frostbite1536/Prediction_Analyzer.git +cd Prediction_Analyzer +python -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate +pip install -e ".[api,mcp,dev]" +``` + +## Running Checks + +```bash +make test # Run tests +make lint # Lint with flake8 +make fmt # Format with black +make typecheck # Type check with mypy +``` + +All checks must pass before submitting a PR. The CI pipeline runs tests on Python 3.9-3.12. + +## Code Standards + +- **Formatting**: black (line length 100). 
Run `make fmt` before committing. +- **Linting**: flake8 with the project `.flake8` config. +- **Type hints**: Required on all public function signatures. Run `make typecheck`. +- **Docstrings**: Required on all public functions and classes. +- **Tests**: New features must include tests. Bug fixes should include a regression test. + +## Architecture Rules + +- `prediction_mcp` may import from `prediction_analyzer`, but not the reverse. +- All MCP tool handlers must use the `@safe_tool` decorator — no manual try/except. +- Never log API keys. Only reference environment variable *names* in error messages. +- Use `sanitize_numeric()` on all floats before JSON serialization. +- DB monetary columns must use `Numeric(18, 8)`, never `Float`. + +## Branch Strategy + +- `main` — stable release branch +- Feature branches: `feature/<short-description>` +- Bug fixes: `fix/<short-description>` + +## Pull Request Process + +1. Create a feature/fix branch from `main` +2. Make your changes with clear, atomic commits +3. Ensure all checks pass (`make test lint fmt-check`) +4. Open a PR against `main` with: + - A clear title (under 70 chars) + - Description of what and why + - Test plan +5. Address review feedback + +## Adding a New Provider + +See the detailed guide in [CLAUDE.md](CLAUDE.md#adding-a-new-provider). + +## Adding a New MCP Tool + +See the detailed guide in [CLAUDE.md](CLAUDE.md#adding-a-new-mcp-tool). + +## License + +By contributing, you agree that your contributions will be licensed under the AGPL-3.0 license. 
diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d181266 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +FROM python:3.12-slim AS base + +WORKDIR /app + +# Install system deps (needed for cryptography wheel) +RUN apt-get update && \ + apt-get install -y --no-install-recommends gcc libffi-dev && \ + rm -rf /var/lib/apt/lists/* + +COPY pyproject.toml setup.py requirements.txt ./ +COPY prediction_analyzer/ prediction_analyzer/ +COPY prediction_mcp/ prediction_mcp/ + +RUN pip install --no-cache-dir -e ".[api,mcp]" + +# --------------------------------------------------------------------------- +# Web API target +# --------------------------------------------------------------------------- +FROM base AS api + +EXPOSE 8000 + +# Non-root user for security +RUN useradd --create-home appuser + +# Ensure data directory exists and is writable by appuser (must run as root, before USER) +RUN mkdir -p /app/data && chown appuser /app/data +USER appuser + +CMD ["uvicorn", "prediction_analyzer.api.main:app", "--host", "0.0.0.0", "--port", "8000"] + +# --------------------------------------------------------------------------- +# MCP SSE server target +# --------------------------------------------------------------------------- +FROM base AS mcp + +EXPOSE 8001 + +RUN useradd --create-home appuser +USER appuser + +CMD ["python", "-m", "prediction_mcp", "--sse", "--port", "8001"] diff --git a/INSTALL.md b/INSTALL.md index 3678bed..cbeba9d 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -108,8 +108,13 @@ POLYMARKET_WALLET=0xYourWalletAddress KALSHI_API_KEY_ID=your_key_id KALSHI_PRIVATE_KEY_PATH=kalshi_private_key.pem MANIFOLD_API_KEY=manifold_your_key_here + +# FastAPI web app (required for production, auto-generated in dev) +SECRET_KEY=your-secure-random-string-here ``` +**Note:** The `SECRET_KEY` is used to sign JWT tokens for the web API. In development, a random ephemeral key is auto-generated on each startup. For production, you **must** set this to a stable, cryptographically random string (e.g. 
`python -c "import secrets; print(secrets.token_urlsafe(64))"`). The app will refuse to start in production mode without it. + ## Troubleshooting ### "ModuleNotFoundError: No module named 'pandas'" (or other packages) diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..faf3d08 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +.PHONY: install install-dev test lint fmt typecheck serve mcp gui clean help + +help: ## Show this help message + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \ + awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' + +install: ## Install core dependencies + pip install -e . + +install-dev: ## Install all dependencies (api + mcp + dev) + pip install -e ".[api,mcp,dev]" + +test: ## Run test suite + pytest -q + +test-cov: ## Run tests with coverage report + pytest --cov=prediction_analyzer --cov=prediction_mcp --cov-report=term-missing + +lint: ## Run flake8 linter + flake8 prediction_analyzer prediction_mcp + +fmt: ## Format code with black + black prediction_analyzer prediction_mcp tests + +fmt-check: ## Check formatting without modifying files + black --check prediction_analyzer prediction_mcp tests + +typecheck: ## Run mypy type checker + mypy prediction_analyzer prediction_mcp + +serve: ## Start the FastAPI web server + python run_api.py + +mcp: ## Start the MCP server (stdio) + python -m prediction_mcp + +mcp-sse: ## Start the MCP server (HTTP/SSE) + python -m prediction_mcp --sse + +gui: ## Launch the desktop GUI + python run_gui.py + +clean: ## Remove build artifacts and caches + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true + rm -rf build dist .pytest_cache htmlcov .coverage coverage.xml .mypy_cache diff --git a/README.md b/README.md index e201bc8..b021aed 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ A complete modular analysis tool for prediction market traders. 
Analyze past tra - Currency-separated PnL aggregation (real-money USD/USDC vs play-money MANA) - FIFO PnL computation for providers without native PnL - MCP server integration for Claude Code / Claude Desktop -- FastAPI web server with JWT authentication +- FastAPI web server with JWT authentication and security headers - Command-line interface for automation ## Installation @@ -282,11 +282,11 @@ prediction_analyzer/ │ ├── time_utils.py # Time utilities │ ├── math_utils.py # Math utilities │ └── export.py # Export utilities -├── api/ # FastAPI web application -│ ├── models/ # SQLAlchemy ORM models -│ ├── routers/ # API route handlers -│ ├── schemas/ # Pydantic request/response schemas -│ └── services/ # Business logic services +├── api/ # FastAPI web application (JWT auth, security headers) +│ ├── models/ # SQLAlchemy ORM models (Numeric precision for money) +│ ├── routers/ # API route handlers (all authenticated) +│ ├── schemas/ # Pydantic request/response schemas (with field constraints) +│ └── services/ # Business logic services (10 MB upload limit) └── core/ # Core modules └── interactive.py # Interactive CLI menu diff --git a/SECURITY_AUDIT.md b/SECURITY_AUDIT.md new file mode 100644 index 0000000..15bb811 --- /dev/null +++ b/SECURITY_AUDIT.md @@ -0,0 +1,211 @@ +# Security & Code Quality Audit Report + +**Date:** 2026-03-10 +**Scope:** Full repository audit addressing authentication, data parsing, math precision, API security, and dependency management. + +--- + +## Executive Summary + +The codebase is well-structured with good separation of concerns, comprehensive test coverage (36 test files), and several security best practices already in place. This audit identified **5 high-severity**, **9 medium-severity**, and **6 low-severity** findings across 6 audit categories. + +**Status: All findings resolved** (as of 2026-03-10). Two items deferred: MATH-4 (`SHARE_EPSILON` extraction) and DEP-3 (upper version bounds). + +--- + +## 1. 
Authentication & Key Management + +### 1.1 Findings + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| AUTH-1 | ~~**HIGH**~~ ✅ | `api/config.py:17` | ~~Default `SECRET_KEY` is a readable string.~~ **FIXED**: Dev mode now auto-generates `secrets.token_urlsafe(64)`. Production still requires explicit env var. | +| AUTH-2 | ~~**LOW**~~ ✅ | `providers/polymarket.py:44` | ~~Wallet address prefix logged.~~ **FIXED**: Log message no longer includes any part of the wallet address. | +| AUTH-3 | **GOOD** | `.gitignore:69` | `.env` is properly gitignored. | +| AUTH-4 | **GOOD** | `.gitignore:66` | `*.pem` files (Kalshi RSA keys) are properly gitignored. | +| AUTH-5 | **GOOD** | `utils/auth.py` | API keys are never logged, printed, or written to files. Error messages in `__main__.py` only show env var *names*, not values. | +| AUTH-6 | **GOOD** | `api/services/auth_service.py:19` | Uses Argon2 for password hashing (modern, memory-hard algorithm). | + +### 1.2 JWT Implementation Review + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| JWT-1 | **GOOD** | `api/services/auth_service.py:55-59` | Tokens include `exp`, `iat`, `iss`, and `aud` claims. | +| JWT-2 | **GOOD** | `api/services/auth_service.py:79-84` | Token decoding validates `issuer` and `audience`. | +| JWT-3 | **GOOD** | `api/services/auth_service.py:90` | All JWT exceptions are caught (`InvalidTokenError`, `DecodeError`, `ExpiredSignatureError`). | +| JWT-4 | **MEDIUM** | `api/services/auth_service.py:19` | Algorithm is `HS256` (symmetric). This is fine for a single-server deployment but doesn't support key rotation or asymmetric verification. Consider documenting this trade-off. | +| JWT-5 | **GOOD** | `api/dependencies.py:41-59` | `get_current_user` validates token, checks user exists, and verifies `is_active` status. 
| + +### 1.3 Kalshi RSA Key Handling + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| KALSHI-1 | **GOOD** | `providers/kalshi.py:59-62` | Private key loaded from file, never logged or serialized. | +| KALSHI-2 | **GOOD** | `providers/kalshi.py:79-88` | RSA-PSS signing uses `DIGEST_LENGTH` salt per Kalshi docs. | +| KALSHI-3 | ~~**LOW**~~ ✅ | `providers/kalshi.py:36` | ~~Private key stored as instance attribute.~~ **FIXED**: `try/finally` clears `_private_key = None` after each `fetch_trades` call. | + +--- + +## 2. FastAPI & Web Security + +### 2.1 CORS Configuration + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| CORS-1 | **GOOD** | `api/main.py:70-85` | CORS properly handles wildcard vs. explicit origins. When `ALLOWED_ORIGINS="*"`, credentials are disabled (per CORS spec). | +| CORS-2 | ~~**LOW**~~ ✅ | `api/main.py:83-84` | ~~`allow_methods=["*"]` and `allow_headers=["*"]` overly permissive.~~ **FIXED**: Restricted to explicit method/header lists. | + +### 2.2 Missing Security Headers + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| HDR-1 | ~~**MEDIUM**~~ ✅ | `api/main.py` | ~~No HTTP security headers middleware.~~ **FIXED**: `SecurityHeadersMiddleware` added with X-Frame-Options, X-Content-Type-Options, HSTS, Referrer-Policy, X-XSS-Protection, Permissions-Policy. | + +### 2.3 Rate Limiting + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| RATE-1 | **GOOD** | `api/main.py:38-39` | Auth endpoints limited to 5 req/min, general to 60 req/min. | +| RATE-2 | ~~**MEDIUM**~~ ✅ | `api/main.py:37` | ~~In-memory rate store not shared across workers.~~ **DOCUMENTED**: Single-process limitation noted in code comments and ARCHITECTURE.md. | +| RATE-3 | ~~**MEDIUM**~~ ✅ | `api/main.py:37` | ~~Rate store grows unbounded.~~ **FIXED**: Key eviction at `_RATE_MAX_KEYS = 10_000` removes stale entries. 
| + +### 2.4 Endpoint Security + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| API-1 | **GOOD** | `api/routers/trades.py` | All trade endpoints require `get_current_user` authentication. | +| API-2 | **GOOD** | `api/services/trade_service.py:217-220` | Trade access is scoped by `user_id` — users can only access their own trades. | +| API-3 | **GOOD** | `api/routers/trades.py:174-181` | Filename sanitization uses regex to prevent path traversal in export filenames. | +| API-4 | **GOOD** | All DB queries | Use SQLAlchemy ORM parameterized queries — no raw SQL injection risk. | +| API-5 | ~~**HIGH**~~ ✅ | `api/routers/trades.py:76` | ~~No file upload size limit.~~ **FIXED**: 10 MB limit enforced in `trade_service.py` before processing. | +| API-6 | ~~**MEDIUM**~~ ✅ | `api/routers/trades.py:100-111` | ~~`/trades/providers` has no authentication.~~ **FIXED**: `Depends(get_current_user)` added. | + +### 2.5 Input Validation + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| VAL-1 | **GOOD** | `api/schemas/user.py:13-14` | Username: 3-50 chars; password: 8-100 chars; email validated via `EmailStr`. | +| VAL-2 | ~~**MEDIUM**~~ ✅ | `api/schemas/user.py:14` | ~~Minimum password length of 6.~~ **FIXED**: Increased to 8 characters. | + +--- + +## 3. Data Parsing & Type Safety + +### 3.1 NaN/Infinity Handling + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| NAN-1 | **GOOD** | `trade_loader.py:18-33` | `sanitize_numeric()` properly converts NaN → 0.0, Inf → ±999999.99. | +| NAN-2 | **GOOD** | `prediction_mcp/serializers.py:20-42` | Recursive JSON sanitization handles nested structures. | +| NAN-3 | **GOOD** | `prediction_mcp/validators.py:107-115` | Filter parameters are validated for NaN/Infinity. | +| NAN-4 | ~~**LOW**~~ ✅ | `trade_loader.py:33` vs `metrics.py:78` | ~~Infinity capped inconsistently.~~ **FIXED**: Unified to shared `INF_CAP = 999999.99` constant. 
| + +### 3.2 External API Response Parsing + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| PARSE-1 | **GOOD** | All providers | Use `.get()` with defaults for optional fields — graceful degradation. | +| PARSE-2 | ~~**MEDIUM**~~ ✅ | `providers/manifold.py:34` | ~~`resp.json()["id"]` — uncaught `KeyError`.~~ **FIXED**: Uses `.get("id")` with descriptive `ValueError`. | +| PARSE-3 | ~~**LOW**~~ ✅ | `providers/limitless.py:81-90` | ~~Hardcoded divisor of 1,000,000.~~ **FIXED**: Extracted to `USDC_DECIMALS = 1_000_000` constant. | +| PARSE-4 | ~~**MEDIUM**~~ ✅ | `providers/kalshi.py:248-268` | ~~Silent `except` blocks.~~ **FIXED**: Each `except` now logs `logger.warning()` with fill ID and failed field. | +| PARSE-5 | **GOOD** | `trade_loader.py:74-142` | Timestamp parsing has multi-format fallback chain with logging on failure. | + +### 3.3 Pydantic & Database Models + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| MODEL-1 | ~~**HIGH**~~ ✅ | `api/models/trade.py:44-46` | ~~SQLAlchemy `Float` for monetary values.~~ **FIXED**: Changed to `Numeric(precision=18, scale=8)`. | +| MODEL-2 | ~~**MEDIUM**~~ ✅ | `api/schemas/trade.py:15-20` | ~~No validation constraints on numeric fields.~~ **FIXED**: Added `ge`/`le`/`min_length`/`max_length` constraints on all fields. | + +--- + +## 4. Math & PnL Calculations + +### 4.1 Floating-Point Precision + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| MATH-1 | ~~**HIGH**~~ ✅ | `pnl.py:32` | ~~`cumsum()` float drift.~~ **FIXED**: Uses `Decimal` accumulation loop. | +| MATH-2 | ~~**HIGH**~~ ✅ | Entire codebase | ~~No `Decimal` usage.~~ **FIXED**: `Decimal` used for cumulative PnL; `Numeric(18,8)` for DB columns. Documented in ARCHITECTURE.md. | +| MATH-3 | ~~**LOW**~~ ✅ | `metrics.py:86-91` | ~~Inconsistent rounding, no documented strategy.~~ **DOCUMENTED**: Rounding strategy and precision invariants in ARCHITECTURE.md "Numeric Precision Invariants". 
| +| MATH-4 | **MEDIUM** | `providers/pnl_calculator.py:58` | FIFO share matching uses `<= 1e-10` as "zero" threshold — magic number without documentation or named constant. | + +### 4.2 Specific Calculation Issues + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| CALC-1 | **GOOD** | `utils/math_utils.py:38-50` | `safe_divide()` properly handles zero denominator. | +| CALC-2 | **GOOD** | `metrics.py:105-113` | Drawdown handles edge cases (zero peak, never-positive equity). | +| CALC-3 | ~~**LOW**~~ ✅ | `metrics.py:78` | ~~Profit factor capped at 999.99, undocumented.~~ **FIXED**: Now uses shared `INF_CAP = 999999.99` constant from `trade_loader.py`. | + +--- + +## 5. Dependency Management + +### 5.1 Sync Between Files + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| DEP-1 | ~~**MEDIUM**~~ ✅ | `pyproject.toml` vs `requirements.txt` | ~~`argon2-cffi` and `email-validator` missing from `pyproject.toml`.~~ **VERIFIED**: Both packages are already listed in `pyproject.toml[api]` at lines 41-44. Original finding was incorrect. | +| DEP-2 | **GOOD** | Both files | Version lower bounds are consistent where packages overlap (e.g., `pandas>=1.5.0`, `fastapi>=0.109.0`). | +| DEP-3 | **LOW** | `requirements.txt` | No upper version bounds. While this avoids dependency hell, a major version bump in a dependency could break the project silently. | + +--- + +## 6. Logging & Information Disclosure + +| ID | Severity | File | Finding | +|----|----------|------|---------| +| LOG-1 | **GOOD** | `logging_config.py` | Logging goes to stderr, keeping stdout clean for MCP stdio transport. | +| LOG-2 | **GOOD** | `api/routers/auth.py:77-81` | Login failure returns generic "Incorrect email or password" — no information leakage about which field was wrong. | +| LOG-3 | **GOOD** | `api/dependencies.py:41-49` | Token validation failure returns generic "Could not validate credentials". 
| +| LOG-4 | **GOOD** | `utils/auth.py` | API keys never logged. Only environment variable *names* mentioned in error messages. | + +--- + +## Summary of Recommendations + +### High Priority — ALL RESOLVED ✅ + +1. ~~**API-5**: Add file upload size limit~~ → **FIXED**: 10 MB limit in `trade_service.py` +2. ~~**MODEL-1**: Change `Float` → `Numeric(18, 8)`~~ → **FIXED**: All monetary columns in `api/models/trade.py` +3. ~~**MATH-2**: Use `decimal.Decimal` for PnL summation~~ → **FIXED**: `Decimal` accumulation in `pnl.py:calculate_pnl()`, documented in `ARCHITECTURE.md` +4. ~~**AUTH-1**: Random SECRET_KEY in dev~~ → **FIXED**: `secrets.token_urlsafe(64)` generated per-startup in `config.py` +5. ~~**MATH-1**: Cumulative PnL float drift~~ → **FIXED**: `Decimal` accumulation replaces `cumsum()` + +### Medium Priority — RESOLVED ✅ (except MATH-4, deferred) + +1. ~~**HDR-1**: Security headers middleware~~ → **FIXED**: `SecurityHeadersMiddleware` in `main.py` +2. ~~**RATE-2/3**: Rate limiter eviction + docs~~ → **FIXED**: Key eviction at 10k keys, single-process limitation documented +3. ~~**DEP-1**: Sync deps~~ → **VERIFIED**: `argon2-cffi` and `email-validator` already in `pyproject.toml[api]` +4. ~~**MODEL-2**: Pydantic field constraints~~ → **FIXED**: `ge`/`le`/`min_length`/`max_length` on all `TradeBase` fields +5. ~~**PARSE-2**: Manifold `KeyError`~~ → **FIXED**: `.get("id")` with descriptive `ValueError` +6. ~~**PARSE-4**: Kalshi silent parse errors~~ → **FIXED**: `logger.warning()` on each fallback +7. ~~**VAL-2**: Password length~~ → **FIXED**: Minimum 8 characters +8. ~~**API-6**: Providers endpoint auth~~ → **FIXED**: `Depends(get_current_user)` added +9. **MATH-4**: Extract `1e-10` to `SHARE_EPSILON` — deferred (pnl_calculator.py) + +### Low Priority — RESOLVED ✅ (except DEP-3, deferred) + +1. ~~**NAN-4**: Infinity cap inconsistency~~ → **FIXED**: Unified to `INF_CAP = 999999.99` in `trade_loader.py`, used in `metrics.py` +2.
~~**CORS-2**: Restrict CORS methods/headers~~ → **FIXED**: Explicit method and header lists in `main.py` +3. ~~**PARSE-3**: Hardcoded micro-unit divisor~~ → **FIXED**: `USDC_DECIMALS = 1_000_000` constant in `limitless.py` and `trade_loader.py` +4. ~~**KALSHI-3**: Private key in instance attr~~ → **FIXED**: `try/finally` clears `_private_key` after `fetch_trades` +5. **DEP-3**: Upper version bounds — deferred (trade-off: stability vs. flexibility) +6. ~~**CALC-3/MATH-3**: Document rounding~~ → **FIXED**: Documented in `ARCHITECTURE.md` "Numeric Precision Invariants" section + +--- + +## What's Already Done Well + +- **Password hashing**: Argon2 (best-in-class) +- **JWT**: Full claim validation with issuer/audience +- **SQL injection**: Zero risk — all queries use SQLAlchemy ORM +- **Path traversal**: Export filenames are sanitized +- **Secret management**: `.env` and `.pem` files gitignored, keys never logged +- **Error messages**: No information leakage on auth failures +- **NaN/Infinity**: Comprehensive sanitization in serializers and loaders +- **Rate limiting**: Auth endpoints have stricter limits +- **Test coverage**: 36 test files with regression tests for known bugs +- **Active user checks**: JWT validation includes `is_active` status check +- **CORS**: Correctly disables credentials with wildcard origins diff --git a/TUTORIAL.md b/TUTORIAL.md index 6fb4049..eda5fec 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -698,12 +698,12 @@ The server starts at `http://localhost:8000`. 
API docs are available at `http:// | Method | Endpoint | Description | |--------|----------|-------------| -| POST | `/api/v1/auth/register` | Create an account | +| POST | `/api/v1/auth/signup` | Create an account | | POST | `/api/v1/auth/login` | Get JWT token | | GET | `/api/v1/trades/` | List your trades | | GET | `/api/v1/trades/?source=polymarket` | Filter by provider | -| GET | `/api/v1/trades/providers` | List available providers | -| POST | `/api/v1/trades/upload` | Upload trade file (auto-detects format) | +| GET | `/api/v1/trades/providers` | List available providers (auth required) | +| POST | `/api/v1/trades/upload` | Upload trade file (auto-detects format, 10 MB limit) | | GET | `/api/v1/trades/export/csv` | Export trades as CSV | | GET | `/api/v1/trades/export/json` | Export trades as JSON | | GET | `/api/v1/analysis/global-summary` | Global PnL | @@ -712,20 +712,24 @@ The server starts at `http://localhost:8000`. API docs are available at `http:// ### Authentication -The web API uses its own JWT-based auth (separate from prediction market API keys): +The web API uses its own JWT-based auth (separate from prediction market API keys). + +**Password requirements:** minimum 8 characters, maximum 100 characters. ```bash # Register -curl -X POST http://localhost:8000/api/v1/auth/register \ +curl -X POST http://localhost:8000/api/v1/auth/signup \ -H "Content-Type: application/json" \ - -d '{"email": "you@example.com", "username": "trader", "password": "secure123"}' + -d '{"email": "you@example.com", "username": "trader", "password": "secure1234"}' -# Login -curl -X POST http://localhost:8000/api/v1/auth/login \ +# Login (JSON endpoint) +curl -X POST http://localhost:8000/api/v1/auth/login/json \ -H "Content-Type: application/json" \ - -d '{"email": "you@example.com", "password": "secure123"}' + -d '{"email": "you@example.com", "password": "secure1234"}' ``` +**Production deployment:** Set the `SECRET_KEY` environment variable to a strong random string. 
In development, a random ephemeral key is auto-generated on startup. + --- ## 12. Using the MCP Server diff --git a/data/samples/kalshi_trades.json b/data/samples/kalshi_trades.json new file mode 100644 index 0000000..9830a59 --- /dev/null +++ b/data/samples/kalshi_trades.json @@ -0,0 +1,35 @@ +[ + { + "ticker": "FED-24MAR-T5.50", + "title": "Fed funds rate above 5.50% after March meeting", + "action": "buy", + "side": "yes", + "count": 50, + "yes_price": 82, + "no_price": 18, + "created_time": "2024-03-10T08:00:00Z", + "order_id": "kalshi-order-001" + }, + { + "ticker": "FED-24MAR-T5.50", + "title": "Fed funds rate above 5.50% after March meeting", + "action": "sell", + "side": "yes", + "count": 50, + "yes_price": 91, + "no_price": 9, + "created_time": "2024-03-19T15:30:00Z", + "order_id": "kalshi-order-002" + }, + { + "ticker": "INXD-24MAR22-T5250", + "title": "S&P 500 above 5250 on March 22?", + "action": "buy", + "side": "no", + "count": 100, + "yes_price": 65, + "no_price": 35, + "created_time": "2024-03-18T10:00:00Z", + "order_id": "kalshi-order-003" + } +] diff --git a/data/samples/limitless_trades.json b/data/samples/limitless_trades.json new file mode 100644 index 0000000..a516d13 --- /dev/null +++ b/data/samples/limitless_trades.json @@ -0,0 +1,41 @@ +[ + { + "market": { + "title": "Will Ethereum merge to PoS by Q2 2024?", + "slug": "eth-pos-q2-2024" + }, + "timestamp": 1706140800, + "strategy": "Buy", + "outcomeIndex": 0, + "outcomeTokenAmount": 200, + "collateralAmount": 120, + "pnl": 0, + "blockTimestamp": 1706140800 + }, + { + "market": { + "title": "Will Ethereum merge to PoS by Q2 2024?", + "slug": "eth-pos-q2-2024" + }, + "timestamp": 1711324800, + "strategy": "Sell", + "outcomeIndex": 0, + "outcomeTokenAmount": 200, + "collateralAmount": 160, + "pnl": 40, + "blockTimestamp": 1711324800 + }, + { + "market": { + "title": "US inflation below 3% in March 2024?", + "slug": "us-inflation-march-2024" + }, + "timestamp": 1708819200, + "strategy": "Buy", + 
"outcomeIndex": 1, + "outcomeTokenAmount": 100, + "collateralAmount": 35, + "pnl": -35, + "blockTimestamp": 1708819200 + } +] diff --git a/data/samples/manifold_trades.json b/data/samples/manifold_trades.json new file mode 100644 index 0000000..48251cc --- /dev/null +++ b/data/samples/manifold_trades.json @@ -0,0 +1,38 @@ +[ + { + "contractId": "manifold-abc123", + "question": "Will GPT-5 be released before July 2024?", + "slug": "will-gpt5-be-released-before-july-2024", + "outcome": "YES", + "shares": 250.0, + "amount": 100.0, + "probBefore": 0.35, + "probAfter": 0.38, + "createdTime": 1709251200000, + "isSell": false + }, + { + "contractId": "manifold-abc123", + "question": "Will GPT-5 be released before July 2024?", + "slug": "will-gpt5-be-released-before-july-2024", + "outcome": "YES", + "shares": 250.0, + "amount": 120.0, + "probBefore": 0.38, + "probAfter": 0.35, + "createdTime": 1714521600000, + "isSell": true + }, + { + "contractId": "manifold-def456", + "question": "Will Bitcoin hit $100k in 2024?", + "slug": "will-bitcoin-hit-100k-in-2024", + "outcome": "NO", + "shares": 500.0, + "amount": 200.0, + "probBefore": 0.25, + "probAfter": 0.22, + "createdTime": 1706745600000, + "isSell": false + } +] diff --git a/data/samples/polymarket_trades.json b/data/samples/polymarket_trades.json new file mode 100644 index 0000000..e29df52 --- /dev/null +++ b/data/samples/polymarket_trades.json @@ -0,0 +1,46 @@ +[ + { + "market": "Will the US enter a recession in 2024?", + "slug": "us-recession-2024", + "side": "No", + "type": "BUY", + "shares": 150.0, + "price": 0.72, + "amount": 108.0, + "timestamp": "2024-02-15T10:30:00Z", + "transactionHash": "0xabc123def456789012345678901234567890abcd" + }, + { + "market": "Will the US enter a recession in 2024?", + "slug": "us-recession-2024", + "side": "No", + "type": "SELL", + "shares": 150.0, + "price": 0.85, + "amount": 127.5, + "timestamp": "2024-06-20T14:15:00Z", + "transactionHash": 
"0xdef456789012345678901234567890abcdef1234" + }, + { + "market": "Who will win the 2024 US Presidential Election?", + "slug": "2024-us-presidential-election", + "side": "Yes", + "type": "BUY", + "shares": 200.0, + "price": 0.55, + "amount": 110.0, + "timestamp": "2024-03-01T09:00:00Z", + "transactionHash": "0x1234567890abcdef1234567890abcdef12345678" + }, + { + "market": "Who will win the 2024 US Presidential Election?", + "slug": "2024-us-presidential-election", + "side": "Yes", + "type": "SELL", + "shares": 100.0, + "price": 0.62, + "amount": 62.0, + "timestamp": "2024-07-10T11:45:00Z", + "transactionHash": "0xabcdef1234567890abcdef1234567890abcdef12" + } +] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1c89882 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,25 @@ +services: + api: + build: + context: . + target: api + ports: + - "8000:8000" + env_file: .env + environment: + - DATABASE_URL=sqlite:///./data/prediction_analyzer.db + volumes: + - app-data:/app/data + restart: unless-stopped + + mcp: + build: + context: . 
+ target: mcp + ports: + - "8001:8001" + env_file: .env + restart: unless-stopped + +volumes: + app-data: diff --git a/prediction_analyzer/__init__.py b/prediction_analyzer/__init__.py index 8514eec..2f060eb 100644 --- a/prediction_analyzer/__init__.py +++ b/prediction_analyzer/__init__.py @@ -5,7 +5,7 @@ """ __version__ = "1.0.0" -__author__ = "Your Name" +__author__ = "Frostbite1536" __license__ = "AGPL-3.0" # Initialize logging to stderr (safe for MCP stdio transport) diff --git a/prediction_analyzer/__main__.py b/prediction_analyzer/__main__.py index 9d67a0b..c9f645a 100644 --- a/prediction_analyzer/__main__.py +++ b/prediction_analyzer/__main__.py @@ -2,6 +2,7 @@ """ Main CLI entry point for the prediction analyzer """ + import argparse import sys from pathlib import Path @@ -50,42 +51,65 @@ def main(): # Filter and export python -m prediction_analyzer --file trades.json --start-date 2024-01-01 --export trades.csv - """ + """, ) # Data source options - data_group = parser.add_argument_group('Data Source') - data_group.add_argument('--file', '-f', type=str, help='Path to trades JSON/CSV/XLSX file') - data_group.add_argument('--fetch', action='store_true', help='Fetch live trades from API') - data_group.add_argument('--key', '-k', type=str, - help='API key/credential (format depends on provider)') - data_group.add_argument('--provider', '-p', choices=VALID_PROVIDERS, default='auto', - help='Prediction market provider (default: auto-detect from key)') + data_group = parser.add_argument_group("Data Source") + data_group.add_argument("--file", "-f", type=str, help="Path to trades JSON/CSV/XLSX file") + data_group.add_argument("--fetch", action="store_true", help="Fetch live trades from API") + data_group.add_argument( + "--key", "-k", type=str, help="API key/credential (format depends on provider)" + ) + data_group.add_argument( + "--provider", + "-p", + choices=VALID_PROVIDERS, + default="auto", + help="Prediction market provider (default: auto-detect from key)", 
+ ) # Analysis options - analysis_group = parser.add_argument_group('Analysis') - analysis_group.add_argument('--market', '-m', type=str, help='Analyze specific market (slug or name)') - analysis_group.add_argument('--global', dest='global_view', action='store_true', help='Show global PnL summary') - analysis_group.add_argument('--chart', '-c', choices=['simple', 'pro', 'enhanced'], default='simple', - help='Chart type: simple, pro, or enhanced (default: simple)') - analysis_group.add_argument('--dashboard', action='store_true', help='Generate multi-market dashboard') - analysis_group.add_argument('--metrics', action='store_true', help='Show advanced trading metrics (Sharpe, drawdown, streaks)') + analysis_group = parser.add_argument_group("Analysis") + analysis_group.add_argument( + "--market", "-m", type=str, help="Analyze specific market (slug or name)" + ) + analysis_group.add_argument( + "--global", dest="global_view", action="store_true", help="Show global PnL summary" + ) + analysis_group.add_argument( + "--chart", + "-c", + choices=["simple", "pro", "enhanced"], + default="simple", + help="Chart type: simple, pro, or enhanced (default: simple)", + ) + analysis_group.add_argument( + "--dashboard", action="store_true", help="Generate multi-market dashboard" + ) + analysis_group.add_argument( + "--metrics", + action="store_true", + help="Show advanced trading metrics (Sharpe, drawdown, streaks)", + ) # Filter options - filter_group = parser.add_argument_group('Filters') - filter_group.add_argument('--start-date', type=str, help='Filter from date (YYYY-MM-DD)') - filter_group.add_argument('--end-date', type=str, help='Filter to date (YYYY-MM-DD)') - filter_group.add_argument('--type', nargs='+', choices=['Buy', 'Sell'], help='Filter by trade type') - filter_group.add_argument('--min-pnl', type=float, help='Minimum PnL threshold') - filter_group.add_argument('--max-pnl', type=float, help='Maximum PnL threshold') + filter_group = 
parser.add_argument_group("Filters") + filter_group.add_argument("--start-date", type=str, help="Filter from date (YYYY-MM-DD)") + filter_group.add_argument("--end-date", type=str, help="Filter to date (YYYY-MM-DD)") + filter_group.add_argument( + "--type", nargs="+", choices=["Buy", "Sell"], help="Filter by trade type" + ) + filter_group.add_argument("--min-pnl", type=float, help="Minimum PnL threshold") + filter_group.add_argument("--max-pnl", type=float, help="Maximum PnL threshold") # Export options - export_group = parser.add_argument_group('Export') - export_group.add_argument('--export', type=str, help='Export filtered trades (CSV or XLSX)') - export_group.add_argument('--report', action='store_true', help='Generate text report') + export_group = parser.add_argument_group("Export") + export_group.add_argument("--export", type=str, help="Export filtered trades (CSV or XLSX)") + export_group.add_argument("--report", action="store_true", help="Generate text report") # Other options - parser.add_argument('--no-interactive', action='store_true', help='Disable interactive mode') + parser.add_argument("--no-interactive", action="store_true", help="Disable interactive mode") args = parser.parse_args() @@ -97,7 +121,9 @@ def main(): provider_name = args.provider # Resolve API key from args or env - api_key = get_api_key(args.key, provider=provider_name if provider_name != "auto" else "limitless") + api_key = get_api_key( + args.key, provider=provider_name if provider_name != "auto" else "limitless" + ) if not api_key: print("Error: API key required. 
Pass --key or set the appropriate env var:") print(" Limitless: LIMITLESS_API_KEY=lmts_...") @@ -120,12 +146,14 @@ def main(): else: # Use provider system from .providers import ProviderRegistry + provider = ProviderRegistry.get(provider_name) trades = provider.fetch_trades(api_key) # Apply PnL computation if provider_name in ("kalshi", "manifold", "polymarket"): from .providers.pnl_calculator import compute_realized_pnl + trades = compute_realized_pnl(trades) elif args.file: @@ -184,9 +212,9 @@ def main(): generate_text_report(trades) if args.export: - if args.export.endswith('.csv'): + if args.export.endswith(".csv"): export_to_csv(trades, args.export) - elif args.export.endswith('.xlsx'): + elif args.export.endswith(".xlsx"): export_to_excel(trades, args.export) else: print("Error: Export file must be .csv or .xlsx") @@ -210,18 +238,24 @@ def main(): market_name = market_trades[0].market - if args.chart == 'simple': + if args.chart == "simple": generate_simple_chart(market_trades, market_name) - elif args.chart == 'pro': + elif args.chart == "pro": generate_pro_chart(market_trades, market_name) - elif args.chart == 'enhanced': + elif args.chart == "enhanced": generate_enhanced_chart(market_trades, market_name) else: generate_simple_chart(market_trades, market_name) # Interactive mode (if no other actions specified) - if not any([args.global_view, args.metrics, args.report, args.export, args.dashboard, args.market]) and not args.no_interactive: + if ( + not any( + [args.global_view, args.metrics, args.report, args.export, args.dashboard, args.market] + ) + and not args.no_interactive + ): interactive_menu(trades) + if __name__ == "__main__": main() diff --git a/prediction_analyzer/api/config.py b/prediction_analyzer/api/config.py index 5288aea..f9d37e5 100644 --- a/prediction_analyzer/api/config.py +++ b/prediction_analyzer/api/config.py @@ -2,10 +2,12 @@ """ API configuration settings """ + from pydantic_settings import BaseSettings from functools import 
lru_cache import logging import os +import secrets logger = logging.getLogger(__name__) @@ -45,8 +47,10 @@ def get_settings() -> Settings: "SECRET_KEY must be set to a secure random string in production. " "Set the SECRET_KEY environment variable before starting the server." ) + # Generate a random key for dev so the hardcoded default is never used + settings.SECRET_KEY = secrets.token_urlsafe(64) logger.warning( - "SECRET_KEY is using the default value! " - "Set the SECRET_KEY environment variable to a secure random string in production." + "SECRET_KEY was not set — generated a random ephemeral key for development. " + "Set the SECRET_KEY environment variable to a stable value in production." ) return settings diff --git a/prediction_analyzer/api/database.py b/prediction_analyzer/api/database.py index e5d7513..cdfa249 100644 --- a/prediction_analyzer/api/database.py +++ b/prediction_analyzer/api/database.py @@ -2,6 +2,7 @@ """ Database configuration and session management """ + from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, declarative_base from pathlib import Path @@ -27,10 +28,7 @@ _db_url = f"sqlite:///{_abs_path}" # Create SQLAlchemy engine -engine = create_engine( - _db_url, - connect_args={"check_same_thread": False} # SQLite specific -) +engine = create_engine(_db_url, connect_args={"check_same_thread": False}) # SQLite specific # Create session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -42,6 +40,7 @@ def init_db(): """Initialize database - create all tables""" from .models import user, trade, analysis # noqa: F401 - Import models to register them + Base.metadata.create_all(bind=engine) diff --git a/prediction_analyzer/api/dependencies.py b/prediction_analyzer/api/dependencies.py index 15dfa40..a9a8b50 100644 --- a/prediction_analyzer/api/dependencies.py +++ b/prediction_analyzer/api/dependencies.py @@ -2,6 +2,7 @@ """ FastAPI dependency injection functions """ + from typing import 
Generator from fastapi import Depends, HTTPException, status @@ -29,8 +30,7 @@ def get_db() -> Generator[Session, None, None]: async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: """ Dependency that validates the JWT token and returns the current user. @@ -53,23 +53,15 @@ async def get_current_user( raise credentials_exception if not user.is_active: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user account" - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account") return user -async def get_current_active_user( - current_user: User = Depends(get_current_user) -) -> User: +async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User: """ Dependency that ensures the current user is active. """ if not current_user.is_active: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user account" - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account") return current_user diff --git a/prediction_analyzer/api/main.py b/prediction_analyzer/api/main.py index de8564b..36d2137 100644 --- a/prediction_analyzer/api/main.py +++ b/prediction_analyzer/api/main.py @@ -2,6 +2,7 @@ """ FastAPI application - main entry point """ + import time from collections import defaultdict from contextlib import asynccontextmanager @@ -9,6 +10,8 @@ from fastapi import FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response from .config import get_settings from .database import init_db @@ -34,10 +37,15 @@ async def lifespan(app: FastAPI): # In-memory rate limiter (per-IP, sliding window) # 
--------------------------------------------------------------------------- # Two tiers: a strict limit for auth endpoints and a general limit for all others. -_rate_store: dict = defaultdict(list) # ip -> list of timestamps -_RATE_LIMIT_AUTH = 5 # max requests per window on /auth/* -_RATE_LIMIT_GENERAL = 60 # max requests per window on all other endpoints -_RATE_WINDOW = 60 # window size in seconds +# +# NOTE: This rate limiter is in-memory and per-process. It does NOT share +# state across multiple workers/servers. For multi-instance deployments, +# replace with a Redis-backed solution (e.g. fastapi-limiter). +_rate_store: dict = defaultdict(list) # key -> list of timestamps +_RATE_LIMIT_AUTH = 5 # max requests per window on /auth/* +_RATE_LIMIT_GENERAL = 60 # max requests per window on all other endpoints +_RATE_WINDOW = 60 # window size in seconds +_RATE_MAX_KEYS = 10_000 # max tracked IPs before evicting oldest entries # Create FastAPI application @@ -66,6 +74,29 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) + +# --------------------------------------------------------------------------- +# Security headers middleware +# --------------------------------------------------------------------------- +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Inject standard security headers on every response.""" + + async def dispatch(self, request: Request, call_next): + response: Response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Permissions-Policy"] = "geolocation=(), camera=(), microphone=()" + # HSTS — only enable when actually serving over TLS + if request.url.scheme == "https": + response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains" + return response + + 
+app.add_middleware(SecurityHeadersMiddleware) + + # CORS configuration _raw_origins = settings.ALLOWED_ORIGINS.strip() if _raw_origins == "*": @@ -80,10 +111,11 @@ async def lifespan(app: FastAPI): CORSMiddleware, allow_origins=origins, allow_credentials=_allow_credentials, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "Accept"], ) + @app.middleware("http") async def rate_limit_middleware(request: Request, call_next): """Enforce per-IP rate limits (stricter on auth endpoints).""" @@ -97,6 +129,12 @@ async def rate_limit_middleware(request: Request, call_next): # Prune timestamps outside the window _rate_store[key] = [t for t in _rate_store[key] if now - t < _RATE_WINDOW] + # Evict stale keys to bound memory usage + if len(_rate_store) > _RATE_MAX_KEYS: + stale = [k for k, v in _rate_store.items() if not v or (now - v[-1]) >= _RATE_WINDOW] + for k in stale: + del _rate_store[k] + if len(_rate_store[key]) >= limit: return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, @@ -126,7 +164,7 @@ async def root(): "version": settings.APP_VERSION, "docs": "/docs", "redoc": "/redoc", - "api_prefix": API_PREFIX + "api_prefix": API_PREFIX, } diff --git a/prediction_analyzer/api/models/__init__.py b/prediction_analyzer/api/models/__init__.py index 3bdaad5..116bd50 100644 --- a/prediction_analyzer/api/models/__init__.py +++ b/prediction_analyzer/api/models/__init__.py @@ -2,6 +2,7 @@ """ SQLAlchemy ORM models """ + from .user import User from .trade import Trade, TradeUpload from .analysis import SavedAnalysis diff --git a/prediction_analyzer/api/models/analysis.py b/prediction_analyzer/api/models/analysis.py index 475bfe9..7fc80d4 100644 --- a/prediction_analyzer/api/models/analysis.py +++ b/prediction_analyzer/api/models/analysis.py @@ -2,6 +2,7 @@ """ SavedAnalysis model for persisting user analysis results """ + from sqlalchemy import Column, 
Integer, String, Text, DateTime, ForeignKey from sqlalchemy.orm import relationship from datetime import datetime, timezone @@ -11,6 +12,7 @@ class SavedAnalysis(Base): """Saved analysis results""" + __tablename__ = "saved_analyses" id = Column(Integer, primary_key=True, index=True) @@ -26,7 +28,11 @@ class SavedAnalysis(Base): results = Column(Text, nullable=False) # JSON: summary stats, market breakdown, etc. created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) # Relationships user = relationship("User", back_populates="saved_analyses") diff --git a/prediction_analyzer/api/models/trade.py b/prediction_analyzer/api/models/trade.py index fbcb7ba..91d2748 100644 --- a/prediction_analyzer/api/models/trade.py +++ b/prediction_analyzer/api/models/trade.py @@ -2,7 +2,8 @@ """ Trade and TradeUpload models for storing user trading data """ -from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, Index + +from sqlalchemy import Column, Integer, String, Float, Numeric, DateTime, ForeignKey, Index from sqlalchemy.orm import relationship from datetime import datetime, timezone @@ -11,6 +12,7 @@ class TradeUpload(Base): """Metadata for trade file uploads""" + __tablename__ = "trade_uploads" id = Column(Integer, primary_key=True, index=True) @@ -31,6 +33,7 @@ def __repr__(self): class Trade(Base): """Individual trade record""" + __tablename__ = "trades" id = Column(Integer, primary_key=True, index=True) @@ -41,12 +44,12 @@ class Trade(Base): market = Column(String(500), nullable=False) market_slug = Column(String(255), nullable=False, index=True) timestamp = Column(DateTime, nullable=False) - price = Column(Float, default=0.0) - shares = Column(Float, default=0.0) - cost = Column(Float, 
default=0.0) + price = Column(Numeric(precision=18, scale=8), default=0.0) + shares = Column(Numeric(precision=18, scale=8), default=0.0) + cost = Column(Numeric(precision=18, scale=8), default=0.0) type = Column(String(50), nullable=False) # Buy, Sell, Market Buy, Limit Sell, etc. side = Column(String(10), nullable=False) # YES or NO - pnl = Column(Float, default=0.0) + pnl = Column(Numeric(precision=18, scale=8), default=0.0) tx_hash = Column(String(100), nullable=True) source = Column(String(50), nullable=False, default="limitless", index=True) currency = Column(String(10), nullable=False, default="USD") diff --git a/prediction_analyzer/api/models/user.py b/prediction_analyzer/api/models/user.py index 85e43cb..5345ed3 100644 --- a/prediction_analyzer/api/models/user.py +++ b/prediction_analyzer/api/models/user.py @@ -2,6 +2,7 @@ """ User model for authentication """ + from sqlalchemy import Column, Integer, String, Boolean, DateTime from sqlalchemy.orm import relationship from datetime import datetime, timezone @@ -11,6 +12,7 @@ class User(Base): """User account model""" + __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) @@ -19,12 +21,18 @@ class User(Base): hashed_password = Column(String(255), nullable=False) is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) # Relationships trades = relationship("Trade", back_populates="user", cascade="all, delete-orphan") uploads = relationship("TradeUpload", back_populates="user", cascade="all, delete-orphan") - saved_analyses = relationship("SavedAnalysis", back_populates="user", cascade="all, delete-orphan") + saved_analyses = relationship( + "SavedAnalysis", back_populates="user", 
cascade="all, delete-orphan" + ) def __repr__(self): return f"" diff --git a/prediction_analyzer/api/routers/__init__.py b/prediction_analyzer/api/routers/__init__.py index 803e230..884fd45 100644 --- a/prediction_analyzer/api/routers/__init__.py +++ b/prediction_analyzer/api/routers/__init__.py @@ -2,6 +2,7 @@ """ API route handlers """ + from .auth import router as auth_router from .users import router as users_router from .trades import router as trades_router diff --git a/prediction_analyzer/api/routers/analysis.py b/prediction_analyzer/api/routers/analysis.py index 318fb49..0b4e140 100644 --- a/prediction_analyzer/api/routers/analysis.py +++ b/prediction_analyzer/api/routers/analysis.py @@ -2,6 +2,7 @@ """ Analysis endpoints - PnL calculations, saved analyses """ + from typing import Optional, List from fastapi import APIRouter, Depends, HTTPException, status @@ -27,7 +28,7 @@ async def get_global_analysis( filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Calculate global PnL summary across all trades. @@ -39,14 +40,10 @@ async def get_global_analysis( if total == 0: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="No trades found. Upload some trades first." + detail="No trades found. Upload some trades first.", ) - summary = analysis_service.get_global_summary( - db, - user_id=current_user.id, - filters=filters - ) + summary = analysis_service.get_global_summary(db, user_id=current_user.id, filters=filters) return GlobalSummaryResponse(**summary) @@ -56,29 +53,28 @@ async def get_market_analysis( market_slug: str, filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Calculate PnL summary for a specific market. 
""" # Check if market exists for user from ..models.trade import Trade - exists = db.query(Trade).filter( - Trade.user_id == current_user.id, - Trade.market_slug == market_slug - ).first() + + exists = ( + db.query(Trade) + .filter(Trade.user_id == current_user.id, Trade.market_slug == market_slug) + .first() + ) if not exists: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"No trades found for market: {market_slug}" + detail=f"No trades found for market: {market_slug}", ) summary = analysis_service.get_market_summary( - db, - user_id=current_user.id, - market_slug=market_slug, - filters=filters + db, user_id=current_user.id, market_slug=market_slug, filters=filters ) return MarketSummaryResponse(**summary) @@ -88,7 +84,7 @@ async def get_market_analysis( async def get_advanced_metrics( filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Calculate advanced trading metrics: Sharpe ratio, Sortino ratio, @@ -96,14 +92,12 @@ async def get_advanced_metrics( """ from prediction_analyzer.metrics import calculate_advanced_metrics - trades = analysis_service.get_filtered_trades( - db, user_id=current_user.id, filters=filters - ) + trades = analysis_service.get_filtered_trades(db, user_id=current_user.id, filters=filters) if not trades: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="No trades found. Upload some trades first." + detail="No trades found. Upload some trades first.", ) metrics = calculate_advanced_metrics(trades) @@ -114,25 +108,21 @@ async def get_advanced_metrics( async def get_market_breakdown( filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get PnL breakdown by market. Returns a list of markets with their trade counts and PnL. 
""" - breakdown = analysis_service.get_market_breakdown( - db, - user_id=current_user.id, - filters=filters - ) + breakdown = analysis_service.get_market_breakdown(db, user_id=current_user.id, filters=filters) return [ MarketBreakdownItem( market=data["market_name"], market_slug=slug, trade_count=data["trade_count"], - pnl=data["total_pnl"] + pnl=data["total_pnl"], ) for slug, data in breakdown.items() ] @@ -143,7 +133,7 @@ async def get_pnl_timeseries( market_slug: Optional[str] = None, filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get time-series PnL data for charting. @@ -151,10 +141,7 @@ async def get_pnl_timeseries( Returns trade-by-trade data with cumulative PnL and exposure. """ data = analysis_service.get_pnl_timeseries( - db, - user_id=current_user.id, - market_slug=market_slug, - filters=filters + db, user_id=current_user.id, market_slug=market_slug, filters=filters ) if not data: @@ -163,76 +150,70 @@ async def get_pnl_timeseries( # Format for frontend consumption formatted = [] for row in data: - formatted.append({ - "timestamp": row["timestamp"].isoformat() if hasattr(row["timestamp"], "isoformat") else str(row["timestamp"]), - "market": row.get("market"), - "type": row.get("type"), - "side": row.get("side"), - "price": row.get("price"), - "cost": row.get("cost"), - "pnl": row.get("trade_pnl"), - "cumulative_pnl": row.get("cumulative_pnl"), - "exposure": row.get("exposure") - }) + formatted.append( + { + "timestamp": ( + row["timestamp"].isoformat() + if hasattr(row["timestamp"], "isoformat") + else str(row["timestamp"]) + ), + "market": row.get("market"), + "type": row.get("type"), + "side": row.get("side"), + "price": row.get("price"), + "cost": row.get("cost"), + "pnl": row.get("trade_pnl"), + "cumulative_pnl": row.get("cumulative_pnl"), + "exposure": row.get("exposure"), + } + ) return {"data": formatted} # Saved Analysis CRUD + 
@router.get("/saved", response_model=List[SavedAnalysisResponse]) async def list_saved_analyses( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ List all saved analyses for the current user. """ analyses = analysis_service.get_saved_analyses(db, current_user.id) - return [ - SavedAnalysisResponse(**analysis_service.parse_saved_analysis(a)) - for a in analyses - ] + return [SavedAnalysisResponse(**analysis_service.parse_saved_analysis(a)) for a in analyses] @router.post("/saved", response_model=SavedAnalysisResponse, status_code=status.HTTP_201_CREATED) async def save_analysis( analysis_data: SavedAnalysisCreate, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Save an analysis result for later reference. """ - saved = analysis_service.save_analysis( - db, - user_id=current_user.id, - analysis_data=analysis_data - ) + saved = analysis_service.save_analysis(db, user_id=current_user.id, analysis_data=analysis_data) return SavedAnalysisResponse(**analysis_service.parse_saved_analysis(saved)) @router.get("/saved/{analysis_id}", response_model=SavedAnalysisResponse) async def get_saved_analysis( - analysis_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + analysis_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get a specific saved analysis. 
""" analysis = analysis_service.get_saved_analysis( - db, - user_id=current_user.id, - analysis_id=analysis_id + db, user_id=current_user.id, analysis_id=analysis_id ) if not analysis: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Saved analysis not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Saved analysis not found" ) return SavedAnalysisResponse(**analysis_service.parse_saved_analysis(analysis)) @@ -240,23 +221,18 @@ async def get_saved_analysis( @router.delete("/saved/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_saved_analysis( - analysis_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + analysis_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Delete a saved analysis. """ analysis = analysis_service.get_saved_analysis( - db, - user_id=current_user.id, - analysis_id=analysis_id + db, user_id=current_user.id, analysis_id=analysis_id ) if not analysis: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Saved analysis not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Saved analysis not found" ) analysis_service.delete_saved_analysis(db, analysis) diff --git a/prediction_analyzer/api/routers/auth.py b/prediction_analyzer/api/routers/auth.py index 55a3a1b..51ef5a1 100644 --- a/prediction_analyzer/api/routers/auth.py +++ b/prediction_analyzer/api/routers/auth.py @@ -2,6 +2,7 @@ """ Authentication endpoints - signup, login """ + from fastapi import APIRouter, Body, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.orm import Session @@ -15,10 +16,7 @@ @router.post("/signup", response_model=dict, status_code=status.HTTP_201_CREATED) -async def signup( - user_data: UserCreate, - db: Session = Depends(get_db) -): +async def signup(user_data: UserCreate, db: Session = Depends(get_db)): """ Register a new user account. 
@@ -27,23 +25,18 @@ async def signup( # Check if email already exists if auth_service.get_user_by_email(db, user_data.email): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) # Check if username already exists if auth_service.get_user_by_username(db, user_data.username): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already taken" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken" ) # Create user user = auth_service.create_user( - db, - email=user_data.email, - username=user_data.username, - password=user_data.password + db, email=user_data.email, username=user_data.username, password=user_data.password ) # Create access token @@ -53,24 +46,19 @@ async def signup( "user": UserResponse.model_validate(user), "access_token": access_token, "token_type": "bearer", - "message": "Account created successfully" + "message": "Account created successfully", } @router.post("/login", response_model=Token) -async def login( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) -): +async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): """ Authenticate user and return access token. Uses OAuth2 password flow - username field should contain email. 
""" user = auth_service.authenticate_user( - db, - email=form_data.username, # OAuth2 uses 'username' field - password=form_data.password + db, email=form_data.username, password=form_data.password # OAuth2 uses 'username' field ) if not user: @@ -81,10 +69,7 @@ async def login( ) if not user.is_active: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Account is inactive" - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account is inactive") access_token = auth_service.create_access_token(data={"sub": user.id}) @@ -93,9 +78,7 @@ async def login( @router.post("/login/json", response_model=Token) async def login_json( - email: str = Body(...), - password: str = Body(...), - db: Session = Depends(get_db) + email: str = Body(...), password: str = Body(...), db: Session = Depends(get_db) ): """ Alternative JSON-based login endpoint. @@ -106,15 +89,11 @@ async def login_json( if not user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password" ) if not user.is_active: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Account is inactive" - ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account is inactive") access_token = auth_service.create_access_token(data={"sub": user.id}) diff --git a/prediction_analyzer/api/routers/charts.py b/prediction_analyzer/api/routers/charts.py index 957ff4c..0d0dade 100644 --- a/prediction_analyzer/api/routers/charts.py +++ b/prediction_analyzer/api/routers/charts.py @@ -2,6 +2,7 @@ """ Chart data endpoints - returns JSON data for frontend rendering """ + from typing import Optional from fastapi import APIRouter, Depends @@ -21,7 +22,7 @@ async def get_price_chart_data( market_slug: Optional[str] = None, filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = 
Depends(get_db) + db: Session = Depends(get_db), ): """ Get price history chart data. @@ -29,10 +30,7 @@ async def get_price_chart_data( Returns trade prices over time with styling information. """ return chart_service.get_price_chart_data( - db, - user_id=current_user.id, - market_slug=market_slug, - filters=filters + db, user_id=current_user.id, market_slug=market_slug, filters=filters ) @@ -41,7 +39,7 @@ async def get_pnl_chart_data( market_slug: Optional[str] = None, filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get cumulative PnL chart data. @@ -49,10 +47,7 @@ async def get_pnl_chart_data( Returns cumulative PnL over time. """ return chart_service.get_pnl_chart_data( - db, - user_id=current_user.id, - market_slug=market_slug, - filters=filters + db, user_id=current_user.id, market_slug=market_slug, filters=filters ) @@ -61,7 +56,7 @@ async def get_exposure_chart_data( market_slug: Optional[str] = None, filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get net exposure chart data. @@ -69,10 +64,7 @@ async def get_exposure_chart_data( Returns net share exposure over time. """ return chart_service.get_exposure_chart_data( - db, - user_id=current_user.id, - market_slug=market_slug, - filters=filters + db, user_id=current_user.id, market_slug=market_slug, filters=filters ) @@ -80,15 +72,11 @@ async def get_exposure_chart_data( async def get_dashboard_data( filters: Optional[FilterParams] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get multi-market dashboard data. Returns per-market PnL data and overall summary. 
""" - return chart_service.get_dashboard_data( - db, - user_id=current_user.id, - filters=filters - ) + return chart_service.get_dashboard_data(db, user_id=current_user.id, filters=filters) diff --git a/prediction_analyzer/api/routers/trades.py b/prediction_analyzer/api/routers/trades.py index a67e14f..796949c 100644 --- a/prediction_analyzer/api/routers/trades.py +++ b/prediction_analyzer/api/routers/trades.py @@ -31,9 +31,11 @@ async def list_trades( limit: int = Query(100, le=1000, ge=1), offset: int = Query(0, ge=0), market_slug: Optional[str] = None, - source: Optional[str] = Query(None, description="Filter by provider: limitless, polymarket, kalshi, manifold"), + source: Optional[str] = Query( + None, description="Filter by provider: limitless, polymarket, kalshi, manifold" + ), current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ List trades for the current user with pagination. @@ -56,7 +58,7 @@ async def list_trades( trades=[TradeResponse.model_validate(t) for t in trades], total=total, limit=limit, - offset=offset + offset=offset, ) @@ -64,7 +66,7 @@ async def list_trades( async def upload_trades( file: UploadFile = File(...), current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Upload a trade file (JSON, CSV, or XLSX). @@ -73,34 +75,32 @@ async def upload_trades( """ filename = file.filename or "trades.json" - if not filename.lower().endswith(('.json', '.csv', '.xlsx')): + if not filename.lower().endswith((".json", ".csv", ".xlsx")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Unsupported file type. Use JSON, CSV, or XLSX." + detail="Unsupported file type. 
Use JSON, CSV, or XLSX.", ) try: - upload_id, trade_count = await trade_service.process_upload( - db, current_user.id, file - ) + upload_id, trade_count = await trade_service.process_upload(db, current_user.id, file) except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e) - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) return TradeUploadResponse( upload_id=upload_id, filename=filename, trade_count=trade_count, - message=f"Successfully imported {trade_count} trades" + message=f"Successfully imported {trade_count} trades", ) @router.get("/providers") -async def list_providers(): +async def list_providers( + current_user: User = Depends(get_current_user), +): """List all available prediction market providers.""" from prediction_analyzer.config import PROVIDER_CONFIGS + return [ { "name": name, @@ -113,8 +113,7 @@ async def list_providers(): @router.get("/markets", response_model=list[MarketInfo]) async def list_markets( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ List all unique markets for the current user. @@ -128,7 +127,7 @@ async def list_markets( async def export_trades_csv( market_slug: Optional[str] = None, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Export trades as CSV file. 
@@ -138,33 +137,34 @@ async def export_trades_csv( user_id=current_user.id, limit=100000, # Get all trades offset=0, - market_slug=market_slug + market_slug=market_slug, ) if not trades: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No trades found to export" + status_code=status.HTTP_404_NOT_FOUND, detail="No trades found to export" ) # Convert to DataFrame - df = pd.DataFrame([ - { - "market": t.market, - "market_slug": t.market_slug, - "timestamp": t.timestamp.isoformat(), - "price": t.price, - "shares": t.shares, - "cost": t.cost, - "type": t.type, - "side": t.side, - "pnl": t.pnl, - "tx_hash": t.tx_hash, - "source": getattr(t, "source", "limitless"), - "currency": getattr(t, "currency", "USD"), - } - for t in trades - ]) + df = pd.DataFrame( + [ + { + "market": t.market, + "market_slug": t.market_slug, + "timestamp": t.timestamp.isoformat(), + "price": t.price, + "shares": t.shares, + "cost": t.cost, + "type": t.type, + "side": t.side, + "pnl": t.pnl, + "tx_hash": t.tx_hash, + "source": getattr(t, "source", "limitless"), + "currency": getattr(t, "currency", "USD"), + } + for t in trades + ] + ) # Write to buffer buffer = io.StringIO() @@ -173,17 +173,18 @@ async def export_trades_csv( # Sanitize filename components to prevent path traversal import re - safe_user = re.sub(r'[^\w\-.]', '_', current_user.username) + + safe_user = re.sub(r"[^\w\-.]", "_", current_user.username) filename = f"trades_{safe_user}" if market_slug: - safe_slug = re.sub(r'[^\w\-.]', '_', market_slug) + safe_slug = re.sub(r"[^\w\-.]", "_", market_slug) filename += f"_{safe_slug}" filename += ".csv" return StreamingResponse( iter([buffer.getvalue()]), media_type="text/csv", - headers={"Content-Disposition": f'attachment; filename="{filename}"'} + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) @@ -191,23 +192,18 @@ async def export_trades_csv( async def export_trades_json( market_slug: Optional[str] = None, current_user: User = 
Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Export trades as JSON file. """ trades, _ = trade_service.get_user_trades( - db, - user_id=current_user.id, - limit=100000, - offset=0, - market_slug=market_slug + db, user_id=current_user.id, limit=100000, offset=0, market_slug=market_slug ) if not trades: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No trades found to export" + status_code=status.HTTP_404_NOT_FOUND, detail="No trades found to export" ) # Convert to list of dicts @@ -230,25 +226,24 @@ async def export_trades_json( ] import re - safe_user = re.sub(r'[^\w\-.]', '_', current_user.username) + + safe_user = re.sub(r"[^\w\-.]", "_", current_user.username) filename = f"trades_{safe_user}" if market_slug: - safe_slug = re.sub(r'[^\w\-.]', '_', market_slug) + safe_slug = re.sub(r"[^\w\-.]", "_", market_slug) filename += f"_{safe_slug}" filename += ".json" return StreamingResponse( - iter([json.dumps(trades_data, indent=2)]), + iter([json.dumps(trades_data, indent=2, default=str)]), media_type="application/json", - headers={"Content-Disposition": f'attachment; filename="{filename}"'} + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) @router.get("/{trade_id}", response_model=TradeResponse) async def get_trade( - trade_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + trade_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get a specific trade by ID. 
@@ -256,19 +251,14 @@ async def get_trade( trade = trade_service.get_trade_by_id(db, current_user.id, trade_id) if not trade: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Trade not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Trade not found") return TradeResponse.model_validate(trade) @router.delete("/{trade_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_trade( - trade_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + trade_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Delete a specific trade. @@ -276,10 +266,7 @@ async def delete_trade( trade = trade_service.get_trade_by_id(db, current_user.id, trade_id) if not trade: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Trade not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Trade not found") trade_service.delete_trade(db, trade) return None @@ -287,8 +274,7 @@ async def delete_trade( @router.delete("", status_code=status.HTTP_200_OK) async def delete_all_trades( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Delete all trades for the current user. 
@@ -297,7 +283,4 @@ async def delete_all_trades( """ count = trade_service.delete_all_user_trades(db, current_user.id) - return { - "message": f"Deleted {count} trades", - "deleted_count": count - } + return {"message": f"Deleted {count} trades", "deleted_count": count} diff --git a/prediction_analyzer/api/routers/users.py b/prediction_analyzer/api/routers/users.py index fd1687f..08a4638 100644 --- a/prediction_analyzer/api/routers/users.py +++ b/prediction_analyzer/api/routers/users.py @@ -2,6 +2,7 @@ """ User profile endpoints """ + from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session @@ -13,9 +14,7 @@ @router.get("/me", response_model=UserResponse) -async def get_current_user_profile( - current_user: User = Depends(get_current_user) -): +async def get_current_user_profile(current_user: User = Depends(get_current_user)): """ Get the current authenticated user's profile. """ @@ -26,7 +25,7 @@ async def get_current_user_profile( async def update_current_user_profile( user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Update the current user's profile. 
@@ -40,8 +39,7 @@ async def update_current_user_profile( existing = auth_service.get_user_by_email(db, user_update.email) if existing: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) current_user.email = user_update.email @@ -50,8 +48,7 @@ async def update_current_user_profile( existing = auth_service.get_user_by_username(db, user_update.username) if existing: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already taken" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken" ) current_user.username = user_update.username @@ -63,8 +60,7 @@ async def update_current_user_profile( @router.delete("/me", status_code=status.HTTP_204_NO_CONTENT) async def delete_current_user( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Delete the current user's account. @@ -78,8 +74,7 @@ async def delete_current_user( @router.get("/me/stats") async def get_current_user_stats( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get statistics for the current user. 
@@ -88,25 +83,27 @@ async def get_current_user_stats( from ..models.analysis import SavedAnalysis from sqlalchemy import func - trade_count = db.query(func.count(Trade.id)).filter( - Trade.user_id == current_user.id - ).scalar() + trade_count = db.query(func.count(Trade.id)).filter(Trade.user_id == current_user.id).scalar() - upload_count = db.query(func.count(TradeUpload.id)).filter( - TradeUpload.user_id == current_user.id - ).scalar() + upload_count = ( + db.query(func.count(TradeUpload.id)).filter(TradeUpload.user_id == current_user.id).scalar() + ) - analysis_count = db.query(func.count(SavedAnalysis.id)).filter( - SavedAnalysis.user_id == current_user.id - ).scalar() + analysis_count = ( + db.query(func.count(SavedAnalysis.id)) + .filter(SavedAnalysis.user_id == current_user.id) + .scalar() + ) - market_count = db.query(func.count(func.distinct(Trade.market_slug))).filter( - Trade.user_id == current_user.id - ).scalar() + market_count = ( + db.query(func.count(func.distinct(Trade.market_slug))) + .filter(Trade.user_id == current_user.id) + .scalar() + ) - total_pnl = db.query(func.sum(Trade.pnl)).filter( - Trade.user_id == current_user.id - ).scalar() or 0.0 + total_pnl = ( + db.query(func.sum(Trade.pnl)).filter(Trade.user_id == current_user.id).scalar() or 0.0 + ) return { "trade_count": trade_count, @@ -114,5 +111,5 @@ async def get_current_user_stats( "saved_analysis_count": analysis_count, "market_count": market_count, "total_pnl": total_pnl, - "member_since": current_user.created_at.isoformat() + "member_since": current_user.created_at.isoformat(), } diff --git a/prediction_analyzer/api/schemas/__init__.py b/prediction_analyzer/api/schemas/__init__.py index 5a0547b..2c6ef79 100644 --- a/prediction_analyzer/api/schemas/__init__.py +++ b/prediction_analyzer/api/schemas/__init__.py @@ -2,6 +2,7 @@ """ Pydantic schemas for request/response validation """ + from .user import UserCreate, UserLogin, UserResponse, UserUpdate from .auth import Token, TokenData 
from .trade import ( diff --git a/prediction_analyzer/api/schemas/analysis.py b/prediction_analyzer/api/schemas/analysis.py index f560794..aa0da3c 100644 --- a/prediction_analyzer/api/schemas/analysis.py +++ b/prediction_analyzer/api/schemas/analysis.py @@ -2,6 +2,7 @@ """ Analysis-related Pydantic schemas """ + from pydantic import BaseModel from datetime import datetime from typing import Optional, List, Dict, Any @@ -9,6 +10,7 @@ class FilterParams(BaseModel): """Parameters for filtering trades""" + start_date: Optional[str] = None # YYYY-MM-DD format end_date: Optional[str] = None types: Optional[List[str]] = None # ["Buy", "Sell"] @@ -20,6 +22,7 @@ class FilterParams(BaseModel): class GlobalSummaryResponse(BaseModel): """Global portfolio summary""" + total_trades: int total_volume: float total_pnl: float @@ -39,6 +42,7 @@ class GlobalSummaryResponse(BaseModel): class MarketSummaryResponse(BaseModel): """Per-market analysis summary""" + market_title: str market_slug: str total_trades: int @@ -56,6 +60,7 @@ class MarketSummaryResponse(BaseModel): class MarketBreakdownItem(BaseModel): """Single market in breakdown""" + market: str market_slug: str trade_count: int @@ -64,6 +69,7 @@ class MarketBreakdownItem(BaseModel): class SavedAnalysisCreate(BaseModel): """Schema for saving an analysis""" + name: str description: Optional[str] = None filter_params: Optional[FilterParams] = None @@ -73,6 +79,7 @@ class SavedAnalysisCreate(BaseModel): class SavedAnalysisResponse(BaseModel): """Schema for saved analysis response""" + id: int name: str description: Optional[str] diff --git a/prediction_analyzer/api/schemas/auth.py b/prediction_analyzer/api/schemas/auth.py index 3aeac90..08aa212 100644 --- a/prediction_analyzer/api/schemas/auth.py +++ b/prediction_analyzer/api/schemas/auth.py @@ -2,16 +2,19 @@ """ Authentication-related Pydantic schemas """ + from pydantic import BaseModel from typing import Optional class Token(BaseModel): """JWT token response""" + access_token: 
str token_type: str = "bearer" class TokenData(BaseModel): """Data extracted from JWT token""" + user_id: Optional[int] = None diff --git a/prediction_analyzer/api/schemas/charts.py b/prediction_analyzer/api/schemas/charts.py index 859861f..f9894cb 100644 --- a/prediction_analyzer/api/schemas/charts.py +++ b/prediction_analyzer/api/schemas/charts.py @@ -2,12 +2,14 @@ """ Chart data Pydantic schemas """ + from pydantic import BaseModel from typing import List, Dict, Any, Optional class ChartDataResponse(BaseModel): """Generic chart data response""" + times: List[str] # ISO format timestamps values: List[float] labels: Optional[List[str]] = None @@ -17,6 +19,7 @@ class ChartDataResponse(BaseModel): class PriceChartData(BaseModel): """Price history chart data""" + times: List[str] prices: List[float] colors: List[str] @@ -28,6 +31,7 @@ class PriceChartData(BaseModel): class PnLChartData(BaseModel): """Cumulative PnL chart data""" + times: List[str] cumulative_pnl: List[float] final_pnl: float @@ -35,6 +39,7 @@ class PnLChartData(BaseModel): class ExposureChartData(BaseModel): """Net exposure chart data""" + times: List[str] exposure: List[float] max_exposure: float @@ -42,5 +47,6 @@ class ExposureChartData(BaseModel): class DashboardDataResponse(BaseModel): """Multi-market dashboard data""" + markets: Dict[str, Dict[str, Any]] summary: Dict[str, Any] diff --git a/prediction_analyzer/api/schemas/trade.py b/prediction_analyzer/api/schemas/trade.py index d7475be..7fefbc0 100644 --- a/prediction_analyzer/api/schemas/trade.py +++ b/prediction_analyzer/api/schemas/trade.py @@ -2,6 +2,7 @@ """ Trade-related Pydantic schemas """ + from pydantic import BaseModel, Field from datetime import datetime from typing import Optional, List @@ -9,27 +10,32 @@ class TradeBase(BaseModel): """Base trade schema with common fields""" - market: str - market_slug: str + + market: str = Field(..., min_length=1, max_length=500) + market_slug: str = Field(..., min_length=1, max_length=255) 
timestamp: datetime - price: float = 0.0 - shares: float = 0.0 - cost: float = 0.0 - type: str # Buy, Sell, etc. - side: str # YES or NO - pnl: float = 0.0 - tx_hash: Optional[str] = None - source: str = "limitless" # "limitless", "polymarket", "kalshi", "manifold" - currency: str = "USD" # "USD", "USDC", "MANA" + price: float = Field(0.0, ge=0.0, le=1_000_000.0) + shares: float = Field(0.0, ge=0.0, le=1_000_000_000.0) + cost: float = Field(0.0, ge=0.0, le=1_000_000_000.0) + type: str = Field(..., min_length=1, max_length=50) # Buy, Sell, etc. + side: str = Field(..., min_length=1, max_length=10) # YES or NO + pnl: float = Field(0.0, ge=-1_000_000_000.0, le=1_000_000_000.0) + tx_hash: Optional[str] = Field(None, max_length=100) + source: str = Field( + "limitless", max_length=50 + ) # "limitless", "polymarket", "kalshi", "manifold" + currency: str = Field("USD", max_length=10) # "USD", "USDC", "MANA" class TradeCreate(TradeBase): """Schema for creating a trade manually""" + pass class TradeResponse(TradeBase): """Schema for trade response""" + id: int user_id: int upload_id: Optional[int] = None @@ -41,6 +47,7 @@ class Config: class TradeListResponse(BaseModel): """Schema for paginated trade list""" + trades: List[TradeResponse] total: int limit: int @@ -49,6 +56,7 @@ class TradeListResponse(BaseModel): class TradeUploadResponse(BaseModel): """Response after uploading trades""" + upload_id: int filename: str trade_count: int @@ -57,6 +65,7 @@ class TradeUploadResponse(BaseModel): class MarketInfo(BaseModel): """Market information for listing""" + slug: str title: str trade_count: int diff --git a/prediction_analyzer/api/schemas/user.py b/prediction_analyzer/api/schemas/user.py index 2d5b3bb..4cdfc76 100644 --- a/prediction_analyzer/api/schemas/user.py +++ b/prediction_analyzer/api/schemas/user.py @@ -2,6 +2,7 @@ """ User-related Pydantic schemas """ + from pydantic import BaseModel, EmailStr, Field from datetime import datetime from typing import Optional @@ -9,19 
+10,22 @@ class UserCreate(BaseModel): """Schema for creating a new user""" + email: EmailStr username: str = Field(..., min_length=3, max_length=50) - password: str = Field(..., min_length=6, max_length=100) + password: str = Field(..., min_length=8, max_length=100) class UserLogin(BaseModel): """Schema for user login""" + email: EmailStr password: str class UserResponse(BaseModel): """Schema for user response (no sensitive data)""" + id: int email: str username: str @@ -34,5 +38,6 @@ class Config: class UserUpdate(BaseModel): """Schema for updating user profile""" + username: Optional[str] = Field(None, min_length=3, max_length=50) email: Optional[EmailStr] = None diff --git a/prediction_analyzer/api/services/__init__.py b/prediction_analyzer/api/services/__init__.py index b215792..a96dd0c 100644 --- a/prediction_analyzer/api/services/__init__.py +++ b/prediction_analyzer/api/services/__init__.py @@ -2,6 +2,7 @@ """ Business logic services """ + from .auth_service import AuthService, auth_service from .trade_service import TradeService, trade_service from .analysis_service import AnalysisService, analysis_service diff --git a/prediction_analyzer/api/services/analysis_service.py b/prediction_analyzer/api/services/analysis_service.py index 5d64ab6..78ed8f3 100644 --- a/prediction_analyzer/api/services/analysis_service.py +++ b/prediction_analyzer/api/services/analysis_service.py @@ -2,6 +2,7 @@ """ Analysis service - wraps existing pnl.py and filters.py functions """ + import json from typing import List, Dict, Any, Optional @@ -15,14 +16,9 @@ calculate_global_pnl_summary, calculate_market_pnl_summary, calculate_market_pnl, - calculate_pnl -) -from ...filters import ( - filter_by_date, - filter_by_trade_type, - filter_by_side, - filter_by_pnl + calculate_pnl, ) +from ...filters import filter_by_date, filter_by_trade_type, filter_by_side, filter_by_pnl from .trade_service import trade_service @@ -30,9 +26,7 @@ class AnalysisService: """Service for running analyses 
on trade data""" def apply_filters( - self, - trades: List[TradeDataclass], - filters: FilterParams + self, trades: List[TradeDataclass], filters: FilterParams ) -> List[TradeDataclass]: """ Apply filter parameters to a list of trades @@ -62,10 +56,7 @@ def apply_filters( return trades def get_global_summary( - self, - db: Session, - user_id: int, - filters: Optional[FilterParams] = None + self, db: Session, user_id: int, filters: Optional[FilterParams] = None ) -> Dict[str, Any]: """ Calculate global PnL summary for a user @@ -90,11 +81,7 @@ def get_global_summary( return calculate_global_pnl_summary(trades) def get_market_summary( - self, - db: Session, - user_id: int, - market_slug: str, - filters: Optional[FilterParams] = None + self, db: Session, user_id: int, market_slug: str, filters: Optional[FilterParams] = None ) -> Dict[str, Any]: """ Calculate PnL summary for a specific market @@ -109,10 +96,12 @@ def get_market_summary( Dictionary with market summary statistics """ # Get trades for specific market - db_trades = db.query(TradeModel).filter( - TradeModel.user_id == user_id, - TradeModel.market_slug == market_slug - ).order_by(TradeModel.timestamp.asc()).all() + db_trades = ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id, TradeModel.market_slug == market_slug) + .order_by(TradeModel.timestamp.asc()) + .all() + ) trades = trade_service.db_trades_to_dataclass(db_trades) @@ -126,10 +115,7 @@ def get_market_summary( return summary def get_market_breakdown( - self, - db: Session, - user_id: int, - filters: Optional[FilterParams] = None + self, db: Session, user_id: int, filters: Optional[FilterParams] = None ) -> Dict[str, Dict]: """ Get PnL breakdown by market @@ -150,7 +136,7 @@ def get_pnl_timeseries( db: Session, user_id: int, market_slug: Optional[str] = None, - filters: Optional[FilterParams] = None + filters: Optional[FilterParams] = None, ) -> List[Dict]: """ Get time-series PnL data for charting @@ -159,10 +145,12 @@ def 
get_pnl_timeseries( List of dictionaries with timestamp, cumulative_pnl, exposure """ if market_slug: - db_trades = db.query(TradeModel).filter( - TradeModel.user_id == user_id, - TradeModel.market_slug == market_slug - ).order_by(TradeModel.timestamp.asc()).all() + db_trades = ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id, TradeModel.market_slug == market_slug) + .order_by(TradeModel.timestamp.asc()) + .all() + ) else: db_trades = trade_service.get_all_user_trades(db, user_id) @@ -182,10 +170,7 @@ def get_pnl_timeseries( # Saved Analysis CRUD def save_analysis( - self, - db: Session, - user_id: int, - analysis_data: SavedAnalysisCreate + self, db: Session, user_id: int, analysis_data: SavedAnalysisCreate ) -> SavedAnalysis: """Save an analysis result""" saved = SavedAnalysis( @@ -196,7 +181,7 @@ def save_analysis( analysis_data.filter_params.model_dump() if analysis_data.filter_params else None ), market_slug=analysis_data.market_slug, - results=json.dumps(analysis_data.results) + results=json.dumps(analysis_data.results), ) db.add(saved) db.commit() @@ -205,21 +190,22 @@ def save_analysis( def get_saved_analyses(self, db: Session, user_id: int) -> List[SavedAnalysis]: """Get all saved analyses for a user""" - return db.query(SavedAnalysis).filter( - SavedAnalysis.user_id == user_id - ).order_by(SavedAnalysis.created_at.desc()).all() + return ( + db.query(SavedAnalysis) + .filter(SavedAnalysis.user_id == user_id) + .order_by(SavedAnalysis.created_at.desc()) + .all() + ) def get_saved_analysis( - self, - db: Session, - user_id: int, - analysis_id: int + self, db: Session, user_id: int, analysis_id: int ) -> Optional[SavedAnalysis]: """Get a specific saved analysis""" - return db.query(SavedAnalysis).filter( - SavedAnalysis.id == analysis_id, - SavedAnalysis.user_id == user_id - ).first() + return ( + db.query(SavedAnalysis) + .filter(SavedAnalysis.id == analysis_id, SavedAnalysis.user_id == user_id) + .first() + ) def delete_saved_analysis(self, 
db: Session, analysis: SavedAnalysis) -> None: """Delete a saved analysis""" @@ -227,10 +213,7 @@ def delete_saved_analysis(self, db: Session, analysis: SavedAnalysis) -> None: db.commit() def get_filtered_trades( - self, - db: Session, - user_id: int, - filters: Optional[FilterParams] = None + self, db: Session, user_id: int, filters: Optional[FilterParams] = None ) -> List[TradeDataclass]: """ Get all trades for a user with optional filters applied. @@ -261,7 +244,7 @@ def parse_saved_analysis(self, analysis: SavedAnalysis) -> Dict[str, Any]: "market_slug": analysis.market_slug, "results": json.loads(analysis.results), "created_at": analysis.created_at, - "updated_at": analysis.updated_at + "updated_at": analysis.updated_at, } diff --git a/prediction_analyzer/api/services/auth_service.py b/prediction_analyzer/api/services/auth_service.py index 82154b7..35a9664 100644 --- a/prediction_analyzer/api/services/auth_service.py +++ b/prediction_analyzer/api/services/auth_service.py @@ -2,6 +2,7 @@ """ Authentication service - JWT tokens and password hashing """ + from datetime import datetime, timedelta, timezone from typing import Optional @@ -30,11 +31,7 @@ def get_password_hash(self, password: str) -> str: """Hash a password""" return pwd_context.hash(password) - def create_access_token( - self, - data: dict, - expires_delta: Optional[timedelta] = None - ) -> str: + def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str: """ Create a JWT access token @@ -46,23 +43,24 @@ def create_access_token( Encoded JWT token string """ to_encode = data.copy() + # JWT 'sub' claim must be a string per RFC 7519 / PyJWT >=2.9 + if "sub" in to_encode: + to_encode["sub"] = str(to_encode["sub"]) if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: expire = datetime.now(timezone.utc) + timedelta( minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES ) - to_encode.update({ - "exp": expire, - "iat": datetime.now(timezone.utc), - "iss": 
"prediction-analyzer", - "aud": "prediction-analyzer-api", - }) - encoded_jwt = jwt.encode( - to_encode, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM + to_encode.update( + { + "exp": expire, + "iat": datetime.now(timezone.utc), + "iss": "prediction-analyzer", + "aud": "prediction-analyzer-api", + } ) + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) return encoded_jwt def decode_token(self, token: str) -> Optional[TokenData]: @@ -83,19 +81,15 @@ def decode_token(self, token: str) -> Optional[TokenData]: issuer="prediction-analyzer", audience="prediction-analyzer-api", ) - user_id: int = payload.get("sub") - if user_id is None: + raw_sub = payload.get("sub") + if raw_sub is None: return None + user_id = int(raw_sub) return TokenData(user_id=user_id) except (jwt.InvalidTokenError, jwt.DecodeError, jwt.ExpiredSignatureError): return None - def authenticate_user( - self, - db: Session, - email: str, - password: str - ) -> Optional[User]: + def authenticate_user(self, db: Session, email: str, password: str) -> Optional[User]: """ Authenticate a user by email and password @@ -126,13 +120,7 @@ def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]: """Get user by ID""" return db.query(User).filter(User.id == user_id).first() - def create_user( - self, - db: Session, - email: str, - username: str, - password: str - ) -> User: + def create_user(self, db: Session, email: str, username: str, password: str) -> User: """ Create a new user @@ -146,11 +134,7 @@ def create_user( Created User object """ hashed_password = self.get_password_hash(password) - user = User( - email=email, - username=username, - hashed_password=hashed_password - ) + user = User(email=email, username=username, hashed_password=hashed_password) db.add(user) db.commit() db.refresh(user) diff --git a/prediction_analyzer/api/services/chart_service.py b/prediction_analyzer/api/services/chart_service.py index 4133491..d2d19cd 100644 --- 
a/prediction_analyzer/api/services/chart_service.py +++ b/prediction_analyzer/api/services/chart_service.py @@ -2,6 +2,7 @@ """ Chart service - generates chart data for frontend rendering """ + from typing import List, Dict, Any, Optional from sqlalchemy.orm import Session @@ -30,7 +31,9 @@ def get_trade_style(self, trade_type: str, side: str) -> tuple: color = "#00C853" if side == "YES" else "#FF1744" # Green for YES, Red for NO marker = "triangle-up" else: # Sell - color = "#FFD600" if side == "YES" else "#AA00FF" # Yellow for YES sell, Purple for NO sell + color = ( + "#FFD600" if side == "YES" else "#AA00FF" + ) # Yellow for YES sell, Purple for NO sell marker = "triangle-down" label = f"{trade_type} {side}" @@ -41,7 +44,7 @@ def get_price_chart_data( db: Session, user_id: int, market_slug: Optional[str] = None, - filters: Optional[FilterParams] = None + filters: Optional[FilterParams] = None, ) -> PriceChartData: """ Generate price chart data @@ -50,10 +53,12 @@ def get_price_chart_data( PriceChartData with times, prices, colors, etc. 
""" if market_slug: - db_trades = db.query(TradeModel).filter( - TradeModel.user_id == user_id, - TradeModel.market_slug == market_slug - ).order_by(TradeModel.timestamp.asc()).all() + db_trades = ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id, TradeModel.market_slug == market_slug) + .order_by(TradeModel.timestamp.asc()) + .all() + ) else: db_trades = trade_service.get_all_user_trades(db, user_id) @@ -64,13 +69,7 @@ def get_price_chart_data( if not trades: return PriceChartData( - times=[], - prices=[], - colors=[], - markers=[], - types=[], - sides=[], - costs=[] + times=[], prices=[], colors=[], markers=[], types=[], sides=[], costs=[] ) # Sort by timestamp @@ -101,7 +100,7 @@ def get_price_chart_data( markers=markers, types=types, sides=sides, - costs=costs + costs=costs, ) def get_pnl_chart_data( @@ -109,7 +108,7 @@ def get_pnl_chart_data( db: Session, user_id: int, market_slug: Optional[str] = None, - filters: Optional[FilterParams] = None + filters: Optional[FilterParams] = None, ) -> PnLChartData: """ Generate cumulative PnL chart data @@ -118,10 +117,12 @@ def get_pnl_chart_data( PnLChartData with times, cumulative_pnl, final_pnl """ if market_slug: - db_trades = db.query(TradeModel).filter( - TradeModel.user_id == user_id, - TradeModel.market_slug == market_slug - ).order_by(TradeModel.timestamp.asc()).all() + db_trades = ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id, TradeModel.market_slug == market_slug) + .order_by(TradeModel.timestamp.asc()) + .all() + ) else: db_trades = trade_service.get_all_user_trades(db, user_id) @@ -145,18 +146,14 @@ def get_pnl_chart_data( times.append(t.timestamp.isoformat()) cumulative_pnl.append(total) - return PnLChartData( - times=times, - cumulative_pnl=cumulative_pnl, - final_pnl=total - ) + return PnLChartData(times=times, cumulative_pnl=cumulative_pnl, final_pnl=total) def get_exposure_chart_data( self, db: Session, user_id: int, market_slug: Optional[str] = None, - filters: 
Optional[FilterParams] = None + filters: Optional[FilterParams] = None, ) -> ExposureChartData: """ Generate net exposure chart data @@ -165,10 +162,12 @@ def get_exposure_chart_data( ExposureChartData with times, exposure, max_exposure """ if market_slug: - db_trades = db.query(TradeModel).filter( - TradeModel.user_id == user_id, - TradeModel.market_slug == market_slug - ).order_by(TradeModel.timestamp.asc()).all() + db_trades = ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id, TradeModel.market_slug == market_slug) + .order_by(TradeModel.timestamp.asc()) + .all() + ) else: db_trades = trade_service.get_all_user_trades(db, user_id) @@ -187,17 +186,10 @@ def get_exposure_chart_data( exposure = df["exposure"].tolist() max_exposure = max(abs(e) for e in exposure) if exposure else 0.0 - return ExposureChartData( - times=times, - exposure=exposure, - max_exposure=max_exposure - ) + return ExposureChartData(times=times, exposure=exposure, max_exposure=max_exposure) def get_dashboard_data( - self, - db: Session, - user_id: int, - filters: Optional[FilterParams] = None + self, db: Session, user_id: int, filters: Optional[FilterParams] = None ) -> Dict[str, Any]: """ Generate multi-market dashboard data @@ -222,7 +214,7 @@ def get_dashboard_data( "title": t.market, "trades": [], "total_pnl": 0.0, - "trade_count": 0 + "trade_count": 0, } markets_data[t.market_slug]["trades"].append(t) markets_data[t.market_slug]["total_pnl"] += t.pnl @@ -248,7 +240,7 @@ def get_dashboard_data( "times": times, "cumulative_pnl": cumulative, "total_pnl": data["total_pnl"], - "trade_count": data["trade_count"] + "trade_count": data["trade_count"], } # Overall summary @@ -259,8 +251,16 @@ def get_dashboard_data( "total_markets": len(markets_data), "total_trades": total_trades, "total_pnl": total_pnl, - "best_market": max(markets_data.items(), key=lambda x: x[1]["total_pnl"])[0] if markets_data else None, - "worst_market": min(markets_data.items(), key=lambda x: 
x[1]["total_pnl"])[0] if markets_data else None + "best_market": ( + max(markets_data.items(), key=lambda x: x[1]["total_pnl"])[0] + if markets_data + else None + ), + "worst_market": ( + min(markets_data.items(), key=lambda x: x[1]["total_pnl"])[0] + if markets_data + else None + ), } return result diff --git a/prediction_analyzer/api/services/trade_service.py b/prediction_analyzer/api/services/trade_service.py index 84de2fe..cd7aba7 100644 --- a/prediction_analyzer/api/services/trade_service.py +++ b/prediction_analyzer/api/services/trade_service.py @@ -2,6 +2,7 @@ """ Trade service - file upload, CRUD operations """ + import tempfile import hashlib from pathlib import Path @@ -48,12 +49,7 @@ def db_trades_to_dataclass(self, db_trades: List[TradeModel]) -> List[TradeDatac """Convert a list of SQLAlchemy Trade models to Trade dataclasses""" return [self.db_trade_to_dataclass(t) for t in db_trades] - async def process_upload( - self, - db: Session, - user_id: int, - file: UploadFile - ) -> Tuple[int, int]: + async def process_upload(self, db: Session, user_id: int, file: UploadFile) -> Tuple[int, int]: """ Process an uploaded trade file @@ -72,17 +68,24 @@ async def process_upload( if suffix not in [".json", ".csv", ".xlsx"]: raise ValueError(f"Unsupported file type: {suffix}") - # Read file content - content = await file.read() + # Read file content with size limit (10 MB) to prevent OOM + MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB + content = await file.read(MAX_UPLOAD_BYTES + 1) + if len(content) > MAX_UPLOAD_BYTES: + raise ValueError( + f"File too large (>{MAX_UPLOAD_BYTES // (1024 * 1024)} MB). " + "Please split into smaller files." 
+ ) # Calculate file hash for deduplication file_hash = hashlib.sha256(content).hexdigest() # Check for duplicate upload - existing = db.query(TradeUpload).filter( - TradeUpload.user_id == user_id, - TradeUpload.file_hash == file_hash - ).first() + existing = ( + db.query(TradeUpload) + .filter(TradeUpload.user_id == user_id, TradeUpload.file_hash == file_hash) + .first() + ) if existing: raise ValueError(f"This file was already uploaded (ID: {existing.id})") @@ -105,7 +108,7 @@ async def process_upload( filename=filename, file_type=suffix[1:], # Remove the dot trade_count=len(trades), - file_hash=file_hash + file_hash=file_hash, ) db.add(upload) db.flush() # Get the upload ID @@ -174,9 +177,12 @@ def get_user_trades( def get_all_user_trades(self, db: Session, user_id: int) -> List[TradeModel]: """Get all trades for a user (no pagination)""" - return db.query(TradeModel).filter( - TradeModel.user_id == user_id - ).order_by(TradeModel.timestamp.asc()).all() + return ( + db.query(TradeModel) + .filter(TradeModel.user_id == user_id) + .order_by(TradeModel.timestamp.asc()) + .all() + ) def get_user_markets(self, db: Session, user_id: int) -> List[MarketInfo]: """ @@ -185,39 +191,35 @@ def get_user_markets(self, db: Session, user_id: int) -> List[MarketInfo]: Returns: List of MarketInfo objects """ - results = db.query( - TradeModel.market_slug, - TradeModel.market, - func.count(TradeModel.id).label("trade_count"), - func.sum(TradeModel.pnl).label("total_pnl") - ).filter( - TradeModel.user_id == user_id - ).group_by( - TradeModel.market_slug, - TradeModel.market - ).all() + results = ( + db.query( + TradeModel.market_slug, + TradeModel.market, + func.count(TradeModel.id).label("trade_count"), + func.sum(TradeModel.pnl).label("total_pnl"), + ) + .filter(TradeModel.user_id == user_id) + .group_by(TradeModel.market_slug, TradeModel.market) + .all() + ) return [ MarketInfo( slug=r.market_slug, title=r.market, trade_count=r.trade_count, - total_pnl=r.total_pnl or 0.0 + 
total_pnl=r.total_pnl or 0.0, ) for r in results ] - def get_trade_by_id( - self, - db: Session, - user_id: int, - trade_id: int - ) -> Optional[TradeModel]: + def get_trade_by_id(self, db: Session, user_id: int, trade_id: int) -> Optional[TradeModel]: """Get a specific trade by ID (must belong to user)""" - return db.query(TradeModel).filter( - TradeModel.id == trade_id, - TradeModel.user_id == user_id - ).first() + return ( + db.query(TradeModel) + .filter(TradeModel.id == trade_id, TradeModel.user_id == user_id) + .first() + ) def delete_trade(self, db: Session, trade: TradeModel) -> None: """Delete a trade""" @@ -234,9 +236,12 @@ def delete_all_user_trades(self, db: Session, user_id: int) -> int: def get_user_uploads(self, db: Session, user_id: int) -> List[TradeUpload]: """Get all uploads for a user""" - return db.query(TradeUpload).filter( - TradeUpload.user_id == user_id - ).order_by(TradeUpload.uploaded_at.desc()).all() + return ( + db.query(TradeUpload) + .filter(TradeUpload.user_id == user_id) + .order_by(TradeUpload.uploaded_at.desc()) + .all() + ) # Singleton instance diff --git a/prediction_analyzer/charts/__init__.py b/prediction_analyzer/charts/__init__.py index 781ed41..2ef30dd 100644 --- a/prediction_analyzer/charts/__init__.py +++ b/prediction_analyzer/charts/__init__.py @@ -2,9 +2,15 @@ """ Chart generation modules """ + from .simple import generate_simple_chart from .pro import generate_pro_chart from .enhanced import generate_enhanced_chart from .global_chart import generate_global_dashboard -__all__ = ['generate_simple_chart', 'generate_pro_chart', 'generate_enhanced_chart', 'generate_global_dashboard'] +__all__ = [ + "generate_simple_chart", + "generate_pro_chart", + "generate_enhanced_chart", + "generate_global_dashboard", +] diff --git a/prediction_analyzer/charts/enhanced.py b/prediction_analyzer/charts/enhanced.py index d6635c6..3315208 100644 --- a/prediction_analyzer/charts/enhanced.py +++ b/prediction_analyzer/charts/enhanced.py @@ 
-2,6 +2,7 @@ """ Enhanced chart generation with battlefield visualization """ + import logging import plotly.graph_objects as go from plotly.subplots import make_subplots @@ -16,7 +17,13 @@ _DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "charts_output" -def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outcome: str = None, output_dir: Optional[str] = None, show: bool = True): +def generate_enhanced_chart( + trades: List[Trade], + market_name: str, + resolved_outcome: str = None, + output_dir: Optional[str] = None, + show: bool = True, +): """ Generate an enhanced battlefield chart using Plotly @@ -103,11 +110,11 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc if t.type in ["Buy", "Market Buy", "Limit Buy"]: # Buying YES = Long YES # Buying NO = Short YES - is_long_yes = (t.side == "YES") + is_long_yes = t.side == "YES" else: # Sell # Selling YES = Short YES # Selling NO = Long YES - is_long_yes = (t.side == "NO") + is_long_yes = t.side == "NO" # Color and symbol if is_long_yes: @@ -133,15 +140,16 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc # Create subplots fig = make_subplots( - rows=3, cols=1, + rows=3, + cols=1, shared_xaxes=True, vertical_spacing=0.08, subplot_titles=( "The Battlefield: Implied Probability of YES", "The Scoreboard: Running P&L (Mark-to-Market)", - "The Risk: Net Share Position" + "The Risk: Net Share Position", ), - row_heights=[0.4, 0.3, 0.3] + row_heights=[0.4, 0.3, 0.3], ) # ========================================== @@ -157,9 +165,10 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc line=dict(color="#1f77b4", width=3), name="Market Price", showlegend=True, - hovertemplate="Price: %{y:.1f}¢" + hovertemplate="Price: %{y:.1f}¢", ), - row=1, col=1 + row=1, + col=1, ) # Trade markers (triangles) @@ -172,14 +181,15 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, 
resolved_outc color=trade_colors, size=trade_sizes, symbol=trade_symbols, - line=dict(width=1, color="white") + line=dict(width=1, color="white"), ), name="Trades", text=hover_texts, hovertemplate="%{text}", - showlegend=True + showlegend=True, ), - row=1, col=1 + row=1, + col=1, ) # ========================================== @@ -187,7 +197,7 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc # ========================================== # Determine colors for P&L line segments - pnl_colors = ['green' if pnl >= 0 else 'red' for pnl in running_pnl] + pnl_colors = ["green" if pnl >= 0 else "red" for pnl in running_pnl] # P&L line with fill fig.add_trace( @@ -197,12 +207,13 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc mode="lines", line=dict(color="black", width=2), name="Running P&L", - fill='tozeroy', - fillcolor='rgba(0,255,0,0.2)', # Will be conditional + fill="tozeroy", + fillcolor="rgba(0,255,0,0.2)", # Will be conditional showlegend=True, - hovertemplate="P&L: $%{y:.2f}" + hovertemplate="P&L: $%{y:.2f}", ), - row=2, col=1 + row=2, + col=1, ) # Add positive/negative fill regions @@ -214,13 +225,14 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc x=times, y=positive_pnl, mode="none", - fill='tozeroy', - fillcolor='rgba(0,255,0,0.3)', + fill="tozeroy", + fillcolor="rgba(0,255,0,0.3)", name="Profit", showlegend=False, - hoverinfo='skip' + hoverinfo="skip", ), - row=2, col=1 + row=2, + col=1, ) fig.add_trace( @@ -228,13 +240,14 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc x=times, y=negative_pnl, mode="none", - fill='tozeroy', - fillcolor='rgba(255,0,0,0.3)', + fill="tozeroy", + fillcolor="rgba(255,0,0,0.3)", name="Loss", showlegend=False, - hoverinfo='skip' + hoverinfo="skip", ), - row=2, col=1 + row=2, + col=1, ) # Zero line @@ -245,7 +258,7 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, 
resolved_outc # ========================================== # Determine fill color based on position - fill_color = 'rgba(0,150,255,0.2)' if net_shares[-1] >= 0 else 'rgba(255,100,0,0.2)' + fill_color = "rgba(0,150,255,0.2)" if net_shares[-1] >= 0 else "rgba(255,100,0,0.2)" fig.add_trace( go.Scatter( @@ -254,12 +267,13 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc mode="lines", line=dict(color="purple", width=2), name="Net Shares", - fill='tozeroy', + fill="tozeroy", fillcolor=fill_color, showlegend=True, - hovertemplate="Shares: %{y:.1f}" + hovertemplate="Shares: %{y:.1f}", ), - row=3, col=1 + row=3, + col=1, ) # Zero line @@ -280,14 +294,8 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc title_text=title_text, title_font_size=16, showlegend=True, - hovermode='x unified', - legend=dict( - orientation="h", - yanchor="bottom", - y=1.02, - xanchor="right", - x=1 - ) + hovermode="x unified", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), ) # Update axes @@ -297,8 +305,8 @@ def generate_enhanced_chart(trades: List[Trade], market_name: str, resolved_outc fig.update_xaxes(title_text="Time", row=3, col=1) # Add gridlines - fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') - fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') + fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)") + fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128,128,128,0.2)") # Save as interactive HTML with sanitized filename to output directory out = Path(output_dir) if output_dir else _DEFAULT_OUTPUT_DIR diff --git a/prediction_analyzer/charts/global_chart.py b/prediction_analyzer/charts/global_chart.py index 43811ac..900a89f 100644 --- a/prediction_analyzer/charts/global_chart.py +++ b/prediction_analyzer/charts/global_chart.py @@ -2,6 +2,7 @@ """ Global multi-market dashboard """ + import logging 
import plotly.graph_objects as go from pathlib import Path @@ -14,7 +15,10 @@ # Default output directory: charts_output/ under project root _DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "charts_output" -def generate_global_dashboard(trades_by_market: Dict[str, List[Trade]], output_dir: Optional[str] = None, show: bool = True): + +def generate_global_dashboard( + trades_by_market: Dict[str, List[Trade]], output_dir: Optional[str] = None, show: bool = True +): """ Generate a global PnL dashboard across multiple markets @@ -47,14 +51,16 @@ def generate_global_dashboard(trades_by_market: Dict[str, List[Trade]], output_d cumulative.append(cum) # Add to plot - fig.add_trace(go.Scatter( - x=times, - y=cumulative, - mode='lines', - name=market_name, - line=dict(width=2), - hovertemplate=f'{market_name}
Time: %{{x}}
PnL: $%{{y:.2f}}' - )) + fig.add_trace( + go.Scatter( + x=times, + y=cumulative, + mode="lines", + name=market_name, + line=dict(width=2), + hovertemplate=f"{market_name}
Time: %{{x}}
PnL: $%{{y:.2f}}", + ) + ) # Collect all trades for total portfolio calculation all_trades.extend(sorted_trades) @@ -74,29 +80,25 @@ def generate_global_dashboard(trades_by_market: Dict[str, List[Trade]], output_d total_times.append(trade.timestamp) total_cumulative.append(cum) - fig.add_trace(go.Scatter( - x=total_times, - y=total_cumulative, - mode='lines', - name='Total Portfolio', - line=dict(color='black', width=4, dash='dash'), - hovertemplate='Total
Time: %{x}
PnL: $%{y:.2f}' - )) + fig.add_trace( + go.Scatter( + x=total_times, + y=total_cumulative, + mode="lines", + name="Total Portfolio", + line=dict(color="black", width=4, dash="dash"), + hovertemplate="Total
Time: %{x}
PnL: $%{y:.2f}", + ) + ) # Update layout fig.update_layout( title="Global Multi-Market PnL Dashboard", xaxis_title="Time", yaxis_title="Cumulative PnL ($)", - hovermode='x unified', + hovermode="x unified", height=700, - legend=dict( - orientation="v", - yanchor="top", - y=1, - xanchor="left", - x=1.02 - ) + legend=dict(orientation="v", yanchor="top", y=1, xanchor="left", x=1.02), ) # Save and display to output directory diff --git a/prediction_analyzer/charts/pro.py b/prediction_analyzer/charts/pro.py index 5289b60..2c240c6 100644 --- a/prediction_analyzer/charts/pro.py +++ b/prediction_analyzer/charts/pro.py @@ -2,6 +2,7 @@ """ Professional/advanced chart generation with Plotly """ + import logging import plotly.graph_objects as go from plotly.subplots import make_subplots @@ -15,7 +16,14 @@ # Default output directory: charts_output/ under project root _DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "charts_output" -def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: str = None, output_dir: Optional[str] = None, show: bool = True): + +def generate_pro_chart( + trades: List[Trade], + market_name: str, + resolved_outcome: str = None, + output_dir: Optional[str] = None, + show: bool = True, +): """ Generate an interactive professional chart using Plotly @@ -68,11 +76,12 @@ def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: # Create subplots fig = make_subplots( - rows=3, cols=1, + rows=3, + cols=1, shared_xaxes=True, vertical_spacing=0.08, subplot_titles=("Trade Prices", "Cumulative PnL", "Net Exposure"), - row_heights=[0.4, 0.3, 0.3] + row_heights=[0.4, 0.3, 0.3], ) # Plot 1: Price line with trade markers @@ -83,9 +92,10 @@ def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: mode="lines", line=dict(color="#1f77b4", width=2), name="Price", - showlegend=False + showlegend=False, ), - row=1, col=1 + row=1, + col=1, ) fig.add_trace( @@ -95,10 +105,14 @@ def 
generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: mode="markers", marker=dict(color=colors, size=10, line=dict(width=1, color="black")), name="Trades", - text=[f"{t}
{s}
${c:.2f}" for t, s, c in zip(types, sides, [tr.cost for tr in sorted_trades])], - hoverinfo="text+x+y" + text=[ + f"{t}
{s}
${c:.2f}" + for t, s, c in zip(types, sides, [tr.cost for tr in sorted_trades]) + ], + hoverinfo="text+x+y", ), - row=1, col=1 + row=1, + col=1, ) # Plot 2: Cumulative PnL @@ -110,10 +124,11 @@ def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: line=dict(color="green" if cumulative_pnl[-1] >= 0 else "red", width=3), marker=dict(size=6), name="Cumulative PnL", - fill='tozeroy', - fillcolor='rgba(0,255,0,0.1)' if cumulative_pnl[-1] >= 0 else 'rgba(255,0,0,0.1)' + fill="tozeroy", + fillcolor="rgba(0,255,0,0.1)" if cumulative_pnl[-1] >= 0 else "rgba(255,0,0,0.1)", ), - row=2, col=1 + row=2, + col=1, ) # Plot 3: Net Exposure @@ -124,10 +139,11 @@ def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: mode="lines", line=dict(color="orange", width=2), name="Net Exposure", - fill='tozeroy', - fillcolor='rgba(255,165,0,0.2)' + fill="tozeroy", + fillcolor="rgba(255,165,0,0.2)", ), - row=3, col=1 + row=3, + col=1, ) # Update layout @@ -135,12 +151,7 @@ def generate_pro_chart(trades: List[Trade], market_name: str, resolved_outcome: if resolved_outcome: title_text += f" (Resolved: {resolved_outcome})" - fig.update_layout( - height=900, - title_text=title_text, - showlegend=True, - hovermode='x unified' - ) + fig.update_layout(height=900, title_text=title_text, showlegend=True, hovermode="x unified") # Update axes fig.update_yaxes(title_text="Price (¢)", row=1, col=1) diff --git a/prediction_analyzer/charts/simple.py b/prediction_analyzer/charts/simple.py index 354c134..3058fb0 100644 --- a/prediction_analyzer/charts/simple.py +++ b/prediction_analyzer/charts/simple.py @@ -2,6 +2,7 @@ """ Simple chart generation for novice users """ + import logging import matplotlib.pyplot as plt import matplotlib.dates as mdates @@ -16,7 +17,14 @@ # Default output directory: charts_output/ under project root _DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "charts_output" -def generate_simple_chart(trades: List[Trade], 
market_name: str, resolved_outcome: str = None, output_dir: Optional[str] = None, show: bool = True): + +def generate_simple_chart( + trades: List[Trade], + market_name: str, + resolved_outcome: str = None, + output_dir: Optional[str] = None, + show: bool = True, +): """ Generate a simple 2-panel chart showing price and exposure @@ -74,21 +82,29 @@ def generate_simple_chart(trades: List[Trade], market_name: str, resolved_outcom fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True) # Clean title (remove emojis) - safe_title = market_name.encode('ascii', 'ignore').decode('ascii') + safe_title = market_name.encode("ascii", "ignore").decode("ascii") title_text = f"{safe_title}" if resolved_outcome: title_text += f"\nResolved: {resolved_outcome} | PnL: ${final_pnl:+.2f}" - fig.suptitle(title_text, fontsize=14, fontweight='bold') + fig.suptitle(title_text, fontsize=14, fontweight="bold") # Plot 1: Price with trade bubbles - ax1.plot(times, prices, color='#1f77b4', alpha=0.5, linewidth=2, label='Price') + ax1.plot(times, prices, color="#1f77b4", alpha=0.5, linewidth=2, label="Price") # Add trade markers for t in sorted_trades: color, marker, _ = get_trade_style(t.type, t.side) size = min(max(t.cost * 2, 20), 500) - ax1.scatter(t.timestamp, t.price, s=size, c=color, marker=marker, - alpha=0.8, edgecolors='black', linewidths=0.5) + ax1.scatter( + t.timestamp, + t.price, + s=size, + c=color, + marker=marker, + alpha=0.8, + edgecolors="black", + linewidths=0.5, + ) ax1.set_ylabel("Price (¢)", fontsize=11) ax1.set_ylim(-5, 105) @@ -96,15 +112,15 @@ def generate_simple_chart(trades: List[Trade], market_name: str, resolved_outcom ax1.legend() # Plot 2: Net exposure - ax2.fill_between(times, exposures, 0, color='orange', alpha=0.3) - ax2.plot(times, exposures, color='orange', linewidth=2, label='Net Cash Invested') + ax2.fill_between(times, exposures, 0, color="orange", alpha=0.3) + ax2.plot(times, exposures, color="orange", linewidth=2, label="Net Cash Invested") 
ax2.set_ylabel("Net Cash ($)", fontsize=11) ax2.set_xlabel("Time", fontsize=11) ax2.grid(True, alpha=0.3) ax2.legend() # Format x-axis - ax2.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M')) + ax2.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d %H:%M")) fig.autofmt_xdate() plt.tight_layout() @@ -114,7 +130,7 @@ def generate_simple_chart(trades: List[Trade], market_name: str, resolved_outcom out.mkdir(parents=True, exist_ok=True) safe_market_name = _sanitize_filename(market_name, max_length=30) filepath = out / f"chart_{safe_market_name}.png" - plt.savefig(str(filepath), dpi=150, bbox_inches='tight') + plt.savefig(str(filepath), dpi=150, bbox_inches="tight") logger.info("Chart saved: %s", filepath) if show: plt.show() diff --git a/prediction_analyzer/comparison.py b/prediction_analyzer/comparison.py index bd20a0a..17dd3b9 100644 --- a/prediction_analyzer/comparison.py +++ b/prediction_analyzer/comparison.py @@ -2,6 +2,7 @@ """ Period-over-period performance comparison. 
""" + import logging from typing import List, Dict from .trade_loader import Trade, sanitize_numeric diff --git a/prediction_analyzer/config.py b/prediction_analyzer/config.py index 4b92a8b..dfb232d 100644 --- a/prediction_analyzer/config.py +++ b/prediction_analyzer/config.py @@ -41,20 +41,20 @@ # Chart Styling - includes all trade type variants STYLES = { # Standard Buy/Sell - ("Buy", "YES"): ("#008f00", "x", "Buy YES"), # strong green - ("Sell", "YES"): ("#00c800", "o", "Sell YES"), # vivid lime - ("Buy", "NO"): ("#d000d0", "x", "Buy NO"), # strong magenta - ("Sell", "NO"): ("#d40000", "o", "Sell NO"), # strong red + ("Buy", "YES"): ("#008f00", "x", "Buy YES"), # strong green + ("Sell", "YES"): ("#00c800", "o", "Sell YES"), # vivid lime + ("Buy", "NO"): ("#d000d0", "x", "Buy NO"), # strong magenta + ("Sell", "NO"): ("#d40000", "o", "Sell NO"), # strong red # Market orders - ("Market Buy", "YES"): ("#008f00", "x", "Buy YES"), # same as Buy + ("Market Buy", "YES"): ("#008f00", "x", "Buy YES"), # same as Buy ("Market Sell", "YES"): ("#00c800", "o", "Sell YES"), # same as Sell - ("Market Buy", "NO"): ("#d000d0", "x", "Buy NO"), - ("Market Sell", "NO"): ("#d40000", "o", "Sell NO"), + ("Market Buy", "NO"): ("#d000d0", "x", "Buy NO"), + ("Market Sell", "NO"): ("#d40000", "o", "Sell NO"), # Limit orders - ("Limit Buy", "YES"): ("#008f00", "^", "Limit Buy YES"), - ("Limit Sell", "YES"): ("#00c800", "v", "Limit Sell YES"), - ("Limit Buy", "NO"): ("#d000d0", "^", "Limit Buy NO"), - ("Limit Sell", "NO"): ("#d40000", "v", "Limit Sell NO"), + ("Limit Buy", "YES"): ("#008f00", "^", "Limit Buy YES"), + ("Limit Sell", "YES"): ("#00c800", "v", "Limit Sell YES"), + ("Limit Buy", "NO"): ("#d000d0", "^", "Limit Buy NO"), + ("Limit Sell", "NO"): ("#d40000", "v", "Limit Sell NO"), } diff --git a/prediction_analyzer/core/interactive.py b/prediction_analyzer/core/interactive.py index 8507c72..0788285 100644 --- a/prediction_analyzer/core/interactive.py +++ 
b/prediction_analyzer/core/interactive.py @@ -2,6 +2,7 @@ """ Interactive CLI menu for novice users """ + import sys from typing import List from ..trade_loader import Trade @@ -13,6 +14,7 @@ from ..charts.pro import generate_pro_chart from ..charts.enhanced import generate_enhanced_chart + def interactive_menu(trades: List[Trade]): """ Interactive menu for exploring and analyzing trades @@ -20,10 +22,10 @@ def interactive_menu(trades: List[Trade]): Args: trades: List of all Trade objects """ - print("\n" + "="*60) + print("\n" + "=" * 60) print(" PREDICTION MARKET TRADE ANALYZER") print(" Interactive Mode") - print("="*60) + print("=" * 60) while True: print("\n📊 MAIN MENU") @@ -37,24 +39,24 @@ def interactive_menu(trades: List[Trade]): choice = input("\nSelect option: ").strip().upper() - if choice == 'Q': + if choice == "Q": print("\n👋 Goodbye!") break - elif choice == '1': + elif choice == "1": # Global summary print_global_summary(trades, stream=sys.stdout) input("\nPress Enter to continue...") - elif choice == '2': + elif choice == "2": # Market-specific analysis analyze_market_menu(trades) - elif choice == '3': + elif choice == "3": # Export trades export_menu(trades) - elif choice == '4': + elif choice == "4": # Full report generate_text_report(trades) input("\nPress Enter to continue...") @@ -62,6 +64,7 @@ def interactive_menu(trades: List[Trade]): else: print("❌ Invalid option. 
Please try again.") + def analyze_market_menu(trades: List[Trade]): """Submenu for analyzing a specific market""" markets = get_unique_markets(trades) @@ -82,7 +85,7 @@ def analyze_market_menu(trades: List[Trade]): choice = input("\nSelect market number: ").strip().upper() - if choice == 'B': + if choice == "B": return try: @@ -116,11 +119,11 @@ def analyze_market_menu(trades: List[Trade]): chart_choice = input("\nSelect chart type (1, 2, or 3): ").strip() - if chart_choice == '1': + if chart_choice == "1": generate_simple_chart(filtered_trades, selected_name) - elif chart_choice == '2': + elif chart_choice == "2": generate_pro_chart(filtered_trades, selected_name) - elif chart_choice == '3': + elif chart_choice == "3": generate_enhanced_chart(filtered_trades, selected_name) else: print("❌ Invalid choice.") @@ -128,6 +131,7 @@ def analyze_market_menu(trades: List[Trade]): except ValueError: print("❌ Invalid input.") + def apply_filters_menu(trades: List[Trade]) -> List[Trade]: """ Interactive filter application menu @@ -149,7 +153,7 @@ def apply_filters_menu(trades: List[Trade]) -> List[Trade]: choice = input("Select option: ").strip() - if choice == '1': + if choice == "1": start = input("Start date (YYYY-MM-DD) or Enter to skip: ").strip() end = input("End date (YYYY-MM-DD) or Enter to skip: ").strip() start = start if start else None @@ -157,15 +161,15 @@ def apply_filters_menu(trades: List[Trade]) -> List[Trade]: filtered = filter_by_date(filtered, start, end) print(f"✅ {len(filtered)} trades after date filter") - elif choice == '2': + elif choice == "2": print("Select types (comma-separated): Buy, Sell") types_str = input("> ").strip() if types_str: - types = [t.strip() for t in types_str.split(',')] + types = [t.strip() for t in types_str.split(",")] filtered = filter_by_trade_type(filtered, types) print(f"✅ {len(filtered)} trades after type filter") - elif choice == '3': + elif choice == "3": min_pnl = input("Minimum PnL (or Enter to skip): ").strip() 
max_pnl = input("Maximum PnL (or Enter to skip): ").strip() min_pnl = float(min_pnl) if min_pnl else None @@ -173,11 +177,11 @@ def apply_filters_menu(trades: List[Trade]) -> List[Trade]: filtered = filter_by_pnl(filtered, min_pnl, max_pnl) print(f"✅ {len(filtered)} trades after PnL filter") - elif choice == '4': + elif choice == "4": filtered = trades print("✅ Filters cleared") - elif choice == '5': + elif choice == "5": break else: @@ -185,6 +189,7 @@ def apply_filters_menu(trades: List[Trade]) -> List[Trade]: return filtered + def export_menu(trades: List[Trade]): """Export menu for various formats""" print("\n💾 EXPORT OPTIONS") @@ -196,12 +201,12 @@ def export_menu(trades: List[Trade]): choice = input("Select option: ").strip() - if choice == '1': + if choice == "1": filename = input("Filename (or Enter for default): ").strip() filename = filename if filename else "trades_export.csv" export_to_csv(trades, filename) - elif choice == '2': + elif choice == "2": filename = input("Filename (or Enter for default): ").strip() filename = filename if filename else "trades_export.xlsx" export_to_excel(trades, filename) diff --git a/prediction_analyzer/drawdown.py b/prediction_analyzer/drawdown.py index f5afd29..7aa66cc 100644 --- a/prediction_analyzer/drawdown.py +++ b/prediction_analyzer/drawdown.py @@ -5,6 +5,7 @@ Extends the basic drawdown metrics from metrics.py with full drawdown period identification and recovery analysis. 
""" + import logging from typing import List, Dict, Optional from datetime import datetime @@ -138,13 +139,15 @@ def _identify_drawdown_periods( pct = (amount / peak_val * 100) if peak_val > 0 else 0.0 duration = (timestamps[i] - timestamps[start_idx]).days - periods.append({ - "start": timestamps[start_idx].strftime("%Y-%m-%d"), - "end": timestamps[i].strftime("%Y-%m-%d"), - "amount": sanitize_numeric(amount), - "pct": sanitize_numeric(pct), - "duration_days": duration, - }) + periods.append( + { + "start": timestamps[start_idx].strftime("%Y-%m-%d"), + "end": timestamps[i].strftime("%Y-%m-%d"), + "amount": sanitize_numeric(amount), + "pct": sanitize_numeric(pct), + "duration_days": duration, + } + ) # Handle ongoing drawdown if in_dd: @@ -155,13 +158,15 @@ def _identify_drawdown_periods( pct = (amount / peak_val * 100) if peak_val > 0 else 0.0 duration = (timestamps[-1] - timestamps[start_idx]).days - periods.append({ - "start": timestamps[start_idx].strftime("%Y-%m-%d"), - "end": None, - "amount": sanitize_numeric(amount), - "pct": sanitize_numeric(pct), - "duration_days": duration, - }) + periods.append( + { + "start": timestamps[start_idx].strftime("%Y-%m-%d"), + "end": None, + "amount": sanitize_numeric(amount), + "pct": sanitize_numeric(pct), + "duration_days": duration, + } + ) return periods diff --git a/prediction_analyzer/exceptions.py b/prediction_analyzer/exceptions.py index e8bfa38..5b95d32 100644 --- a/prediction_analyzer/exceptions.py +++ b/prediction_analyzer/exceptions.py @@ -9,34 +9,41 @@ class PredictionAnalyzerError(Exception): """Base exception for all prediction analyzer errors.""" + pass class NoTradesError(PredictionAnalyzerError): """Raised when an operation requires trades but none are loaded.""" + pass class TradeLoadError(PredictionAnalyzerError): """Raised when trade data cannot be loaded from a file or API.""" + pass class InvalidFilterError(PredictionAnalyzerError): """Raised when filter parameters are invalid.""" + pass class 
MarketNotFoundError(PredictionAnalyzerError): """Raised when a market slug doesn't match any loaded trades.""" + pass class ExportError(PredictionAnalyzerError): """Raised when data export fails.""" + pass class ChartError(PredictionAnalyzerError): """Raised when chart generation fails.""" + pass diff --git a/prediction_analyzer/filters.py b/prediction_analyzer/filters.py index 6a349ef..c93c9b7 100644 --- a/prediction_analyzer/filters.py +++ b/prediction_analyzer/filters.py @@ -2,6 +2,7 @@ """ Advanced filtering functions for trades """ + import math from datetime import datetime, timedelta, timezone from typing import List, Optional @@ -28,7 +29,7 @@ def _normalize_datetime(dt) -> datetime: return datetime.fromtimestamp(dt, tz=timezone.utc).replace(tzinfo=None) # Handle pandas Timestamp - if hasattr(dt, 'to_pydatetime'): + if hasattr(dt, "to_pydatetime"): dt = dt.to_pydatetime() # Handle timezone-aware datetime - convert to UTC then strip tzinfo @@ -38,7 +39,9 @@ def _normalize_datetime(dt) -> datetime: return dt -def filter_by_date(trades: List[Trade], start: Optional[str] = None, end: Optional[str] = None) -> List[Trade]: +def filter_by_date( + trades: List[Trade], start: Optional[str] = None, end: Optional[str] = None +) -> List[Trade]: """ Filter trades between start and end dates @@ -82,6 +85,7 @@ def filter_by_date(trades: List[Trade], start: Optional[str] = None, end: Option filtered.append(t) return filtered + def filter_by_trade_type(trades: List[Trade], types: Optional[List[str]] = None) -> List[Trade]: """ Filter trades by type (Buy/Sell) @@ -95,6 +99,7 @@ def filter_by_trade_type(trades: List[Trade], types: Optional[List[str]] = None) """ if not types: return trades + # Match variant types: "Buy" also matches "Market Buy", "Limit Buy", etc. 
# Use word-boundary check to avoid matching "Buyback" when filtering for "Buy" def _matches(trade_type: str) -> bool: @@ -105,8 +110,10 @@ def _matches(trade_type: str) -> bool: if trade_type.endswith(" " + base) or trade_type.startswith(base + " "): return True return False + return [t for t in trades if _matches(t.type)] + def filter_by_side(trades: List[Trade], sides: Optional[List[str]] = None) -> List[Trade]: """ Filter trades by side (YES/NO) @@ -122,7 +129,10 @@ def filter_by_side(trades: List[Trade], sides: Optional[List[str]] = None) -> Li return trades return [t for t in trades if t.side in sides] -def filter_by_pnl(trades: List[Trade], min_pnl: Optional[float] = None, max_pnl: Optional[float] = None) -> List[Trade]: + +def filter_by_pnl( + trades: List[Trade], min_pnl: Optional[float] = None, max_pnl: Optional[float] = None +) -> List[Trade]: """ Filter trades by PnL thresholds diff --git a/prediction_analyzer/inference.py b/prediction_analyzer/inference.py index 525ad64..bf5e605 100644 --- a/prediction_analyzer/inference.py +++ b/prediction_analyzer/inference.py @@ -2,11 +2,15 @@ """ Market outcome inference logic """ + from typing import Optional, Tuple, List from .trade_loader import Trade from .config import PRICE_RESOLUTION_THRESHOLD -def infer_resolved_side_from_trades(trades: List[Trade], threshold: float = PRICE_RESOLUTION_THRESHOLD) -> Tuple[Optional[str], Optional[Trade]]: + +def infer_resolved_side_from_trades( + trades: List[Trade], threshold: float = PRICE_RESOLUTION_THRESHOLD +) -> Tuple[Optional[str], Optional[Trade]]: """ Infer the resolved outcome of a market from trade history @@ -38,6 +42,7 @@ def infer_resolved_side_from_trades(trades: List[Trade], threshold: float = PRIC return inferred, latest + def detect_market_resolution(trades: List[Trade]) -> Optional[str]: """ Detect if market is resolved by looking for resolution events diff --git a/prediction_analyzer/logging_config.py b/prediction_analyzer/logging_config.py index 
5f32756..012747d 100644 --- a/prediction_analyzer/logging_config.py +++ b/prediction_analyzer/logging_config.py @@ -9,6 +9,7 @@ Logging is configured to write to stderr so that stdout remains clean for MCP stdio transport and piped CLI output. """ + import logging import sys @@ -23,9 +24,7 @@ def configure_logging(level: int = logging.INFO): root_logger = logging.getLogger("prediction_analyzer") if not root_logger.handlers: handler = logging.StreamHandler(sys.stderr) - handler.setFormatter(logging.Formatter( - "%(levelname)s [%(name)s] %(message)s" - )) + handler.setFormatter(logging.Formatter("%(levelname)s [%(name)s] %(message)s")) root_logger.addHandler(handler) root_logger.setLevel(level) diff --git a/prediction_analyzer/metrics.py b/prediction_analyzer/metrics.py index 36bd84f..cf6ad4c 100644 --- a/prediction_analyzer/metrics.py +++ b/prediction_analyzer/metrics.py @@ -9,10 +9,11 @@ - Win/Loss Streak analysis - Period-over-period comparison """ + from typing import List, Dict, Optional from datetime import datetime, timedelta import numpy as np -from .trade_loader import Trade +from .trade_loader import Trade, INF_CAP def calculate_advanced_metrics(trades: List[Trade]) -> Dict: @@ -74,10 +75,12 @@ def _basic_stats(pnls: List[float]) -> Dict: total_wins = sum(wins) total_losses = abs(sum(losses)) - profit_factor = (total_wins / total_losses) if total_losses > 0 else float('inf') if total_wins > 0 else 0.0 - # Cap inf for serialization - if profit_factor == float('inf'): - profit_factor = 999.99 + profit_factor = ( + (total_wins / total_losses) if total_losses > 0 else float("inf") if total_wins > 0 else 0.0 + ) + # Cap inf for serialization (uses shared INF_CAP constant) + if profit_factor == float("inf"): + profit_factor = INF_CAP # Expectancy: average PnL per trade expectancy = float(np.mean(pnls)) if pnls else 0.0 @@ -143,7 +146,7 @@ def _risk_adjusted_metrics(pnls: List[float]) -> Dict: # Sortino ratio — downside deviation from target (0) downside_diffs = 
np.minimum(arr, 0.0) - downside_std = np.sqrt(np.sum(downside_diffs ** 2) / (len(arr) - 1)) if len(arr) > 1 else 0.0 + downside_std = np.sqrt(np.sum(downside_diffs**2) / (len(arr) - 1)) if len(arr) > 1 else 0.0 sortino = (mean_return / downside_std) if downside_std > 0 else 0.0 return { @@ -155,7 +158,12 @@ def _risk_adjusted_metrics(pnls: List[float]) -> Dict: def _streak_metrics(pnls: List[float]) -> Dict: """Calculate win/loss streak metrics.""" if not pnls: - return {"max_win_streak": 0, "max_loss_streak": 0, "current_streak": 0, "current_streak_type": None} + return { + "max_win_streak": 0, + "max_loss_streak": 0, + "current_streak": 0, + "current_streak_type": None, + } max_win = 0 max_loss = 0 @@ -231,7 +239,7 @@ def format_metrics_report(metrics: Dict) -> str: lines.append("\nStreaks:") lines.append(f" Max Win Streak: {metrics['max_win_streak']}") lines.append(f" Max Loss Streak: {metrics['max_loss_streak']}") - streak_type = metrics.get('current_streak_type', 'none') + streak_type = metrics.get("current_streak_type", "none") lines.append(f" Current Streak: {metrics['current_streak']} ({streak_type})") lines.append("\nVolume:") diff --git a/prediction_analyzer/pnl.py b/prediction_analyzer/pnl.py index 21a72b3..9e12d17 100644 --- a/prediction_analyzer/pnl.py +++ b/prediction_analyzer/pnl.py @@ -2,12 +2,15 @@ """ PnL calculation and analysis functions """ + +from decimal import Decimal from typing import List, Dict import pandas as pd import numpy as np from .trade_loader import Trade from .inference import detect_market_resolution + def calculate_pnl(trades: List[Trade]) -> pd.DataFrame: """ Calculate PnL metrics for a list of trades @@ -28,8 +31,13 @@ def calculate_pnl(trades: List[Trade]) -> pd.DataFrame: # Calculate individual trade PnL df["trade_pnl"] = df["pnl"] - # Calculate cumulative PnL - df["cumulative_pnl"] = df["trade_pnl"].cumsum() + # Calculate cumulative PnL using Decimal accumulation to avoid float drift + cumulative = [] + running = 
Decimal("0") + for pnl_val in df["trade_pnl"]: + running += Decimal(str(pnl_val)) + cumulative.append(float(running)) + df["cumulative_pnl"] = cumulative # Calculate exposure (net shares held) df["exposure"] = 0.0 @@ -44,14 +52,23 @@ def calculate_pnl(trades: List[Trade]) -> pd.DataFrame: return df + def _summarize_trades(trades: List[Trade]) -> Dict: """Compute summary stats for a list of trades (single currency group).""" if not trades: return { - "total_trades": 0, "total_volume": 0.0, "total_pnl": 0.0, - "win_rate": 0.0, "avg_pnl_per_trade": 0.0, "avg_pnl": 0.0, - "winning_trades": 0, "losing_trades": 0, "breakeven_trades": 0, - "total_invested": 0.0, "total_returned": 0.0, "roi": 0.0, + "total_trades": 0, + "total_volume": 0.0, + "total_pnl": 0.0, + "win_rate": 0.0, + "avg_pnl_per_trade": 0.0, + "avg_pnl": 0.0, + "winning_trades": 0, + "losing_trades": 0, + "breakeven_trades": 0, + "total_invested": 0.0, + "total_returned": 0.0, + "roi": 0.0, } df = pd.DataFrame([vars(t) for t in trades]) @@ -117,7 +134,7 @@ def calculate_global_pnl_summary(trades: List[Trade]) -> Dict: "total_invested": 0.0, "total_returned": 0.0, "roi": 0.0, - "avg_pnl": 0.0 + "avg_pnl": 0.0, } # Group trades by currency @@ -158,12 +175,15 @@ def calculate_global_pnl_summary(trades: List[Trade]) -> Dict: by_source[source] = { "total_trades": len(source_trades), "total_pnl": source_pnl, - "currency": getattr(source_trades[0], "currency", "USD") if source_trades else "USD", + "currency": ( + getattr(source_trades[0], "currency", "USD") if source_trades else "USD" + ), } result["by_source"] = by_source return result + def calculate_market_pnl(trades: List[Trade]) -> Dict[str, Dict]: """ Calculate PnL breakdown by market @@ -180,7 +200,7 @@ def calculate_market_pnl(trades: List[Trade]) -> Dict[str, Dict]: "market_name": trade.market, "total_volume": 0.0, "total_pnl": 0.0, - "trade_count": 0 + "trade_count": 0, } market_stats[slug]["total_volume"] += trade.cost @@ -189,6 +209,7 @@ def 
calculate_market_pnl(trades: List[Trade]) -> Dict[str, Dict]: return market_stats + def calculate_market_pnl_summary(trades: List[Trade]) -> Dict: """ Calculate detailed PnL summary for a specific market's trades @@ -211,7 +232,7 @@ def calculate_market_pnl_summary(trades: List[Trade]) -> Dict: "total_invested": 0.0, "total_returned": 0.0, "roi": 0.0, - "market_outcome": None + "market_outcome": None, } # Get market title from first trade @@ -256,5 +277,5 @@ def calculate_market_pnl_summary(trades: List[Trade]) -> Dict: "total_invested": total_invested, "total_returned": total_returned, "roi": roi, - "market_outcome": market_outcome + "market_outcome": market_outcome, } diff --git a/prediction_analyzer/positions.py b/prediction_analyzer/positions.py index 6aa68ab..10f3885 100644 --- a/prediction_analyzer/positions.py +++ b/prediction_analyzer/positions.py @@ -2,6 +2,7 @@ """ Portfolio position analysis: open positions, unrealized PnL, concentration risk. """ + import logging from typing import List, Dict, Optional from .trade_loader import Trade, sanitize_numeric @@ -75,16 +76,22 @@ def calculate_open_positions( if current_price is not None: unrealized_pnl = abs_shares * (current_price - avg_entry) - positions.append({ - "market": market_name, - "market_slug": slug, - "net_shares": sanitize_numeric(abs_shares), - "side": side, - "avg_entry_price": sanitize_numeric(avg_entry), - "current_price": sanitize_numeric(current_price) if current_price is not None else None, - "unrealized_pnl": sanitize_numeric(unrealized_pnl) if unrealized_pnl is not None else None, - "cost_basis": sanitize_numeric(abs(total_cost)), - }) + positions.append( + { + "market": market_name, + "market_slug": slug, + "net_shares": sanitize_numeric(abs_shares), + "side": side, + "avg_entry_price": sanitize_numeric(avg_entry), + "current_price": ( + sanitize_numeric(current_price) if current_price is not None else None + ), + "unrealized_pnl": ( + sanitize_numeric(unrealized_pnl) if unrealized_pnl 
is not None else None + ), + "cost_basis": sanitize_numeric(abs(total_cost)), + } + ) return positions @@ -121,13 +128,15 @@ def calculate_concentration_risk(trades: List[Trade]) -> Dict: markets = [] for slug, data in exposure_by_market.items(): pct = (data["exposure"] / total_exposure * 100) if total_exposure > 0 else 0.0 - markets.append({ - "market": data["market"], - "slug": data["slug"], - "exposure": sanitize_numeric(data["exposure"]), - "pct_of_total": sanitize_numeric(pct), - "trade_count": data["trade_count"], - }) + markets.append( + { + "market": data["market"], + "slug": data["slug"], + "exposure": sanitize_numeric(data["exposure"]), + "pct_of_total": sanitize_numeric(pct), + "trade_count": data["trade_count"], + } + ) # Sort by exposure descending markets.sort(key=lambda x: x["exposure"], reverse=True) diff --git a/prediction_analyzer/providers/__init__.py b/prediction_analyzer/providers/__init__.py index 84fefe3..dba97a8 100644 --- a/prediction_analyzer/providers/__init__.py +++ b/prediction_analyzer/providers/__init__.py @@ -4,6 +4,7 @@ Supports: Limitless Exchange, Polymarket, Kalshi, Manifold Markets. """ + from .base import MarketProvider, ProviderRegistry from .limitless import LimitlessProvider from .polymarket import PolymarketProvider diff --git a/prediction_analyzer/providers/base.py b/prediction_analyzer/providers/base.py index b420880..903f8e0 100644 --- a/prediction_analyzer/providers/base.py +++ b/prediction_analyzer/providers/base.py @@ -2,6 +2,7 @@ """ Abstract base class for prediction market data providers and provider registry. """ + import logging from abc import ABC, abstractmethod from typing import List, Optional, Dict, Any @@ -14,9 +15,9 @@ class MarketProvider(ABC): """Base class for prediction market data providers.""" - name: str # e.g. "polymarket" - display_name: str # e.g. "Polymarket" - api_key_prefix: str # e.g. "poly_", "kalshi_", "lmts_" + name: str # e.g. "polymarket" + display_name: str # e.g. 
"Polymarket" + api_key_prefix: str # e.g. "poly_", "kalshi_", "lmts_" currency: str = "USD" @abstractmethod @@ -56,9 +57,7 @@ def register(cls, provider: MarketProvider): @classmethod def get(cls, name: str) -> MarketProvider: if name not in cls._providers: - raise ValueError( - f"Unknown provider: {name}. Available: {list(cls._providers.keys())}" - ) + raise ValueError(f"Unknown provider: {name}. Available: {list(cls._providers.keys())}") return cls._providers[name] @classmethod diff --git a/prediction_analyzer/providers/kalshi.py b/prediction_analyzer/providers/kalshi.py index 8a1aa5c..8178c0a 100644 --- a/prediction_analyzer/providers/kalshi.py +++ b/prediction_analyzer/providers/kalshi.py @@ -10,6 +10,7 @@ IMPORTANT: Integer cent fields (price, yes_price, etc.) are deprecated and will be removed March 12, 2026. This provider uses the _fixed/_fp/_dollars string fields. """ + import base64 import datetime import logging @@ -52,9 +53,7 @@ def _load_credentials(self, api_key: str): self._api_key_id, pem_path = key_str.split(":", 1) else: self._api_key_id = key_str - pem_path = os.environ.get( - "KALSHI_PRIVATE_KEY_PATH", "kalshi_private_key.pem" - ) + pem_path = os.environ.get("KALSHI_PRIVATE_KEY_PATH", "kalshi_private_key.pem") with open(pem_path, "rb") as f: self._private_key = serialization.load_pem_private_key( @@ -70,9 +69,7 @@ def _sign_request(self, method: str, path: str) -> dict: from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding - timestamp_ms = str( - int(datetime.datetime.now().timestamp() * 1000) - ) + timestamp_ms = str(int(datetime.datetime.now().timestamp() * 1000)) path_without_query = path.split("?")[0] message = (timestamp_ms + method.upper() + path_without_query).encode("utf-8") @@ -96,6 +93,14 @@ def _sign_request(self, method: str, path: str) -> dict: def fetch_trades(self, api_key: str, page_limit: int = 100) -> List[Trade]: """Fetch user's fill history from Kalshi.""" 
self._load_credentials(api_key) + try: + return self._fetch_trades_inner(page_limit, api_key) + finally: + # Clear private key from memory after use + self._private_key = None + self._api_key_id = None + + def _fetch_trades_inner(self, page_limit: int, api_key: str) -> List[Trade]: all_trades: List[Trade] = [] cursor: Optional[str] = None limit = min(page_limit, 1000) @@ -109,9 +114,7 @@ def fetch_trades(self, api_key: str, page_limit: int = 100) -> List[Trade]: headers = self._sign_request("GET", path) try: - resp = requests.get( - f"{self._base_url}{path}", headers=headers, timeout=15 - ) + resp = requests.get(f"{self._base_url}{path}", headers=headers, timeout=15) resp.raise_for_status() data = resp.json() except requests.RequestException as exc: @@ -165,9 +168,7 @@ def _fetch_position_pnl(self, api_key: str) -> Dict[str, float]: headers = self._sign_request("GET", path) try: - resp = requests.get( - f"{self._base_url}{path}", headers=headers, timeout=15 - ) + resp = requests.get(f"{self._base_url}{path}", headers=headers, timeout=15) resp.raise_for_status() data = resp.json() except requests.RequestException: @@ -242,29 +243,41 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: fixed = raw.get("no_price_fixed") legacy = raw.get("no_price") + fill_id = raw.get("fill_id") or raw.get("order_id") or "?" 
+ if fixed is not None and str(fixed).strip(): try: price = float(fixed) except (ValueError, TypeError): - # Corrupted _fixed field — fall through to legacy + logger.warning( + "Kalshi fill %s: bad %s_price_fixed=%r, falling back to legacy", + fill_id, + side_str.lower(), + fixed, + ) fixed = None if fixed is None or not str(fixed).strip(): try: price = float(legacy or 0) / 100.0 except (ValueError, TypeError): + logger.warning( + "Kalshi fill %s: bad legacy price=%r, defaulting to 0", fill_id, legacy + ) price = 0.0 count_str = raw.get("count_fp", str(raw.get("count", 0))) try: count = float(count_str) except (ValueError, TypeError): + logger.warning("Kalshi fill %s: bad count=%r, defaulting to 0", fill_id, count_str) count = 0.0 fee_str = raw.get("fee_cost", "0") try: fee = float(fee_str) except (ValueError, TypeError): + logger.warning("Kalshi fill %s: bad fee_cost=%r, defaulting to 0", fill_id, fee_str) fee = 0.0 action = (raw.get("action") or "buy").title() @@ -276,9 +289,7 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: return Trade( market=raw.get("ticker") or raw.get("market_ticker") or "Unknown", market_slug=raw.get("ticker") or raw.get("market_ticker") or "unknown", - timestamp=_parse_timestamp( - raw.get("created_time") or raw.get("ts") or 0 - ), + timestamp=_parse_timestamp(raw.get("created_time") or raw.get("ts") or 0), price=price, shares=count, cost=cost, @@ -295,9 +306,7 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: def fetch_market_details(self, market_id: str) -> Optional[Dict[str, Any]]: """Fetch market by ticker (public, no auth).""" try: - resp = requests.get( - f"{PROD_BASE_URL}/trade-api/v2/markets/{market_id}", timeout=10 - ) + resp = requests.get(f"{PROD_BASE_URL}/trade-api/v2/markets/{market_id}", timeout=10) if resp.status_code == 200: return resp.json().get("market") except Exception as exc: diff --git a/prediction_analyzer/providers/limitless.py b/prediction_analyzer/providers/limitless.py index 
f0870d9..5c40275 100644 --- a/prediction_analyzer/providers/limitless.py +++ b/prediction_analyzer/providers/limitless.py @@ -2,6 +2,7 @@ """ Limitless Exchange provider — refactored from utils/data.py. """ + import logging import requests from typing import List, Optional, Dict, Any @@ -13,6 +14,9 @@ BASE_URL = "https://api.limitless.exchange" +# USDC uses 6 decimal places; on-chain amounts are in micro-units. +USDC_DECIMALS = 1_000_000 + class LimitlessProvider(MarketProvider): name = "limitless" @@ -68,9 +72,7 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: market_title = market_data.get("title") or "Unknown" market_slug = market_data.get("slug") or "unknown" else: - market_title = ( - raw.get("market") if isinstance(raw.get("market"), str) else "Unknown" - ) + market_title = raw.get("market") if isinstance(raw.get("market"), str) else "Unknown" market_slug = raw.get("market_slug") or "unknown" if not market_title: @@ -81,9 +83,9 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: # Convert from micro-units (USDC 6 decimals) if API format has_pnl = "pnl" in raw and raw["pnl"] is not None if "collateralAmount" in raw: - cost = float(raw.get("collateralAmount") or 0) / 1_000_000 - pnl = float(raw.get("pnl") or 0) / 1_000_000 - shares = float(raw.get("outcomeTokenAmount") or 0) / 1_000_000 + cost = float(raw.get("collateralAmount") or 0) / USDC_DECIMALS + pnl = float(raw.get("pnl") or 0) / USDC_DECIMALS + shares = float(raw.get("outcomeTokenAmount") or 0) / USDC_DECIMALS else: cost = float(raw.get("cost") or 0) pnl = float(raw.get("pnl") or 0) @@ -130,9 +132,7 @@ def normalize_trade(self, raw: dict, **kwargs) -> Trade: return Trade( market=market_title, market_slug=market_slug, - timestamp=_parse_timestamp( - raw.get("timestamp") or raw.get("blockTimestamp") or 0 - ), + timestamp=_parse_timestamp(raw.get("timestamp") or raw.get("blockTimestamp") or 0), price=raw_price, shares=shares, cost=cost, @@ -163,8 +163,5 @@ def detect_file_format(self, 
records: List[dict]) -> bool: return ( "collateralAmount" in first or "outcomeTokenAmount" in first - or ( - isinstance(first.get("market"), dict) - and "slug" in first["market"] - ) + or (isinstance(first.get("market"), dict) and "slug" in first["market"]) ) diff --git a/prediction_analyzer/providers/manifold.py b/prediction_analyzer/providers/manifold.py index 54b2dba..8069513 100644 --- a/prediction_analyzer/providers/manifold.py +++ b/prediction_analyzer/providers/manifold.py @@ -7,6 +7,7 @@ Auth: API key via "Authorization: Key " header. Currency: MANA (play money). """ + import logging import requests from typing import List, Optional, Dict, Any @@ -31,7 +32,13 @@ def _get_user_id(self, api_key: str) -> str: headers = {"Authorization": f"Key {raw_key}"} resp = requests.get(f"{BASE_URL}/v0/me", headers=headers, timeout=10) resp.raise_for_status() - return resp.json()["id"] + data = resp.json() + user_id = data.get("id") + if not user_id: + raise ValueError( + "Manifold /v0/me response missing 'id' field. " "Verify your API key is valid." 
+ ) + return user_id def _fetch_market_metadata(self, contract_ids: List[str]) -> Dict[str, dict]: """Batch-fetch market metadata for contractIds (bets lack question/slug).""" @@ -76,9 +83,7 @@ def fetch_trades(self, api_key: str, page_limit: int = 1000) -> List[Trade]: params["before"] = cursor try: - resp = requests.get( - f"{BASE_URL}/v0/bets", params=params, timeout=15 - ) + resp = requests.get(f"{BASE_URL}/v0/bets", params=params, timeout=15) resp.raise_for_status() bets = resp.json() except requests.RequestException as exc: @@ -96,9 +101,7 @@ def fetch_trades(self, api_key: str, page_limit: int = 1000) -> List[Trade]: cursor = bets[-1]["id"] # Step 2: Fetch market metadata for unique contractIds - contract_ids = list( - {b.get("contractId", "") for b in all_bets if b.get("contractId")} - ) + contract_ids = list({b.get("contractId", "") for b in all_bets if b.get("contractId")}) logger.info("Fetching metadata for %d Manifold markets...", len(contract_ids)) market_meta = self._fetch_market_metadata(contract_ids) @@ -151,7 +154,9 @@ def fetch_market_details(self, market_id: str) -> Optional[Dict[str, Any]]: if resp.status_code == 200: return resp.json() except Exception as exc: - logger.warning("Failed to fetch Manifold market %s via %s: %s", market_id, endpoint, exc) + logger.warning( + "Failed to fetch Manifold market %s via %s: %s", market_id, endpoint, exc + ) return None def detect_file_format(self, records: List[dict]) -> bool: diff --git a/prediction_analyzer/providers/pnl_calculator.py b/prediction_analyzer/providers/pnl_calculator.py index 9eb9400..136cbb2 100644 --- a/prediction_analyzer/providers/pnl_calculator.py +++ b/prediction_analyzer/providers/pnl_calculator.py @@ -3,6 +3,7 @@ FIFO PnL computation for providers that don't supply per-trade PnL (Kalshi, Manifold, Polymarket). 
""" + import logging from typing import List, Dict from collections import defaultdict, deque @@ -61,7 +62,10 @@ def compute_realized_pnl(trades: List[Trade]) -> List[Trade]: if remaining > 0: logger.warning( "Unmatched sell shares: %.6f shares for %s (market=%s, side=%s)", - remaining, trade.type, trade.market_slug, trade.side, + remaining, + trade.type, + trade.market_slug, + trade.side, ) # Only set PnL if trade doesn't already have one from the provider diff --git a/prediction_analyzer/providers/polymarket.py b/prediction_analyzer/providers/polymarket.py index a7b28ff..f1c8831 100644 --- a/prediction_analyzer/providers/polymarket.py +++ b/prediction_analyzer/providers/polymarket.py @@ -9,6 +9,7 @@ Auth for public Data API: none — just pass wallet address as 'user' query param. Currency: USDC on Polygon. """ + import json import logging import requests @@ -41,7 +42,7 @@ def fetch_trades(self, api_key: str, page_limit: int = 100) -> List[Trade]: limit = min(page_limit, 500) offset = 0 - logger.info("Downloading Polymarket trade history for %s...", api_key[:10]) + logger.info("Downloading Polymarket trade history...") while True: params: Dict[str, Any] = { @@ -52,9 +53,7 @@ def fetch_trades(self, api_key: str, page_limit: int = 100) -> List[Trade]: } try: - resp = requests.get( - f"{DATA_API_URL}/activity", params=params, timeout=15 - ) + resp = requests.get(f"{DATA_API_URL}/activity", params=params, timeout=15) resp.raise_for_status() data = resp.json() except requests.RequestException as exc: diff --git a/prediction_analyzer/py.typed b/prediction_analyzer/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/prediction_analyzer/reporting/__init__.py b/prediction_analyzer/reporting/__init__.py index 6100e68..2d6502b 100644 --- a/prediction_analyzer/reporting/__init__.py +++ b/prediction_analyzer/reporting/__init__.py @@ -2,7 +2,8 @@ """ Reporting modules for text and data exports """ + from .report_text import generate_text_report, 
print_global_summary from .report_data import export_to_csv, export_to_excel -__all__ = ['generate_text_report', 'print_global_summary', 'export_to_csv', 'export_to_excel'] +__all__ = ["generate_text_report", "print_global_summary", "export_to_csv", "export_to_excel"] diff --git a/prediction_analyzer/reporting/report_data.py b/prediction_analyzer/reporting/report_data.py index 97b312f..1fb4c49 100644 --- a/prediction_analyzer/reporting/report_data.py +++ b/prediction_analyzer/reporting/report_data.py @@ -2,6 +2,7 @@ """ Data export functionality (CSV, Excel, JSON) """ + import logging import pandas as pd from typing import List @@ -10,6 +11,7 @@ logger = logging.getLogger(__name__) + def export_to_csv(trades: List[Trade], filename: str = "trades_export.csv"): """ Export trades to CSV file @@ -30,6 +32,7 @@ def export_to_csv(trades: List[Trade], filename: str = "trades_export.csv"): logger.error("Error exporting to CSV: %s", e) raise ExportError(f"Error exporting to CSV {filename}: {e}") from e + def export_to_excel(trades: List[Trade], filename: str = "trades_export.xlsx"): """ Export trades to Excel file with multiple sheets @@ -44,18 +47,18 @@ def export_to_excel(trades: List[Trade], filename: str = "trades_export.xlsx"): try: df = pd.DataFrame([t.to_dict() for t in trades]) - with pd.ExcelWriter(filename, engine='openpyxl') as writer: + with pd.ExcelWriter(filename, engine="openpyxl") as writer: # Main trades sheet - df.to_excel(writer, sheet_name='All Trades', index=False) + df.to_excel(writer, sheet_name="All Trades", index=False) # Summary by market (group by slug for consistency with pnl.py) - summary = df.groupby('market_slug').agg({ - 'cost': 'sum', - 'pnl': 'sum', - 'market': 'first' - }).rename(columns={'market': 'market_name'}) - summary['trade_count'] = df.groupby('market_slug').size() - summary.to_excel(writer, sheet_name='Market Summary') + summary = ( + df.groupby("market_slug") + .agg({"cost": "sum", "pnl": "sum", "market": "first"}) + 
.rename(columns={"market": "market_name"}) + ) + summary["trade_count"] = df.groupby("market_slug").size() + summary.to_excel(writer, sheet_name="Market Summary") logger.info("Trades exported to: %s", filename) return True @@ -63,6 +66,7 @@ def export_to_excel(trades: List[Trade], filename: str = "trades_export.xlsx"): logger.error("Error exporting to Excel: %s", e) raise ExportError(f"Error exporting to Excel {filename}: {e}") from e + def export_to_json(trades: List[Trade], filename: str = "trades_export.json"): """ Export trades to JSON file @@ -79,7 +83,7 @@ def export_to_json(trades: List[Trade], filename: str = "trades_export.json"): try: trades_dict = [t.to_dict() for t in trades] - with open(filename, 'w', encoding='utf-8') as f: + with open(filename, "w", encoding="utf-8") as f: json.dump(trades_dict, f, indent=2) logger.info("Trades exported to: %s", filename) diff --git a/prediction_analyzer/reporting/report_text.py b/prediction_analyzer/reporting/report_text.py index 3124a73..bcf2800 100644 --- a/prediction_analyzer/reporting/report_text.py +++ b/prediction_analyzer/reporting/report_text.py @@ -2,6 +2,7 @@ """ Text-based report generation """ + import logging import sys from typing import List, TextIO @@ -11,6 +12,7 @@ logger = logging.getLogger(__name__) + def print_global_summary(trades: List[Trade], stream: TextIO = None): """ Print a formatted global PnL summary to a stream. 
@@ -29,9 +31,9 @@ def _print(text=""): cur_label = summary.get("currency", "USD") cur_symbol = "M$" if cur_label == "MANA" else "$" - _print("\n" + "="*60) + _print("\n" + "=" * 60) _print("GLOBAL PORTFOLIO SUMMARY") - _print("="*60) + _print("=" * 60) _print(f"Total Trades: {summary['total_trades']}") _print(f"Total Volume: {cur_symbol}{summary['total_volume']:,.2f} {cur_label}") _print(f"Total Realized PnL: {cur_symbol}{summary['total_pnl']:,.2f} {cur_label}") @@ -45,23 +47,30 @@ def _print(text=""): _print("\nPNL BY CURRENCY:") for cur, cur_stats in sorted(by_currency.items()): cs = "M$" if cur == "MANA" else "$" - _print(f" {cur:<8} {cur_stats['total_trades']:>5} trades " - f"PnL: {cs}{cur_stats['total_pnl']:>10,.2f} " - f"Win Rate: {cur_stats['win_rate']:.1f}%") + _print( + f" {cur:<8} {cur_stats['total_trades']:>5} trades " + f"PnL: {cs}{cur_stats['total_pnl']:>10,.2f} " + f"Win Rate: {cur_stats['win_rate']:.1f}%" + ) _print("-" * 60) # Top markets by PnL - sorted_markets = sorted(market_stats.items(), key=lambda x: x[1]['total_pnl'], reverse=True) + sorted_markets = sorted(market_stats.items(), key=lambda x: x[1]["total_pnl"], reverse=True) _print("\nTOP MARKETS BY PNL:") _print(f"{'Rank':<6} {'Market':<40} {'PnL':>12}") _print("-" * 60) for i, (slug, stats) in enumerate(sorted_markets[:10], 1): - market_name = stats['market_name'][:37] + "..." if len(stats['market_name']) > 40 else stats['market_name'] + market_name = ( + stats["market_name"][:37] + "..." 
+ if len(stats["market_name"]) > 40 + else stats["market_name"] + ) _print(f"{i:<6} {market_name:<40} {cur_symbol}{stats['total_pnl']:>10,.2f}") - _print("="*60 + "\n") + _print("=" * 60 + "\n") + def generate_text_report(trades: List[Trade], filename: str = None): """ @@ -79,10 +88,10 @@ def generate_text_report(trades: List[Trade], filename: str = None): market_stats = calculate_market_pnl(trades) lines = [] - lines.append("="*70) + lines.append("=" * 70) lines.append("PREDICTION MARKET TRADE ANALYSIS REPORT") lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - lines.append("="*70) + lines.append("=" * 70) lines.append("") # Global summary @@ -90,50 +99,54 @@ def generate_text_report(trades: List[Trade], filename: str = None): cur_symbol = "M$" if cur_label == "MANA" else "$" lines.append("GLOBAL SUMMARY") - lines.append("-"*70) + lines.append("-" * 70) lines.append(f"Total Trades: {summary['total_trades']}") lines.append(f"Winning Trades: {summary['winning_trades']}") lines.append(f"Losing Trades: {summary['losing_trades']}") lines.append(f"Win Rate: {summary['win_rate']:.2f}%") lines.append(f"Total Volume: {cur_symbol}{summary['total_volume']:,.2f} {cur_label}") lines.append(f"Total Realized PnL: {cur_symbol}{summary['total_pnl']:,.2f} {cur_label}") - lines.append(f"Average PnL per Trade: {cur_symbol}{summary['avg_pnl_per_trade']:.2f} {cur_label}") + lines.append( + f"Average PnL per Trade: {cur_symbol}{summary['avg_pnl_per_trade']:.2f} {cur_label}" + ) lines.append("") # Per-currency breakdown by_currency = summary.get("by_currency") if by_currency: lines.append("PNL BY CURRENCY") - lines.append("-"*70) + lines.append("-" * 70) for cur, cur_stats in sorted(by_currency.items()): cs = "M$" if cur == "MANA" else "$" - lines.append(f" {cur:<8} {cur_stats['total_trades']:>5} trades " - f"PnL: {cs}{cur_stats['total_pnl']:>10,.2f} " - f"Win Rate: {cur_stats['win_rate']:.1f}% " - f"ROI: {cur_stats['roi']:.1f}%") + lines.append( + f" {cur:<8} 
{cur_stats['total_trades']:>5} trades " + f"PnL: {cs}{cur_stats['total_pnl']:>10,.2f} " + f"Win Rate: {cur_stats['win_rate']:.1f}% " + f"ROI: {cur_stats['roi']:.1f}%" + ) lines.append("") # Market breakdown lines.append("MARKET BREAKDOWN") - lines.append("-"*70) + lines.append("-" * 70) lines.append(f"{'Market':<45} {'Trades':>8} {'Volume':>12} {'PnL':>12}") - lines.append("-"*70) + lines.append("-" * 70) - sorted_markets = sorted(market_stats.items(), key=lambda x: x[1]['total_pnl'], reverse=True) + sorted_markets = sorted(market_stats.items(), key=lambda x: x[1]["total_pnl"], reverse=True) for slug, stats in sorted_markets: - market_name = stats['market_name'][:42] + market_name = stats["market_name"][:42] lines.append( f"{market_name:<45} {stats['trade_count']:>8} " f"{cur_symbol}{stats['total_volume']:>10,.2f} {cur_symbol}{stats['total_pnl']:>10,.2f}" ) lines.append("") - lines.append("="*70) + lines.append("=" * 70) lines.append("END OF REPORT") - lines.append("="*70) + lines.append("=" * 70) # Write to file - with open(filename, 'w', encoding='utf-8') as f: - f.write('\n'.join(lines)) + with open(filename, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) logger.info("Report saved: %s", filename) diff --git a/prediction_analyzer/tax.py b/prediction_analyzer/tax.py index e146b56..d086e14 100644 --- a/prediction_analyzer/tax.py +++ b/prediction_analyzer/tax.py @@ -2,6 +2,7 @@ """ Tax reporting: capital gains/losses with FIFO, LIFO, and average cost basis methods. """ + import logging from typing import List, Dict, Optional from datetime import datetime, timedelta @@ -34,7 +35,9 @@ def calculate_capital_gains( Dict with tax summary and per-transaction breakdown """ if cost_basis_method not in VALID_METHODS: - raise ValueError(f"Invalid cost basis method: {cost_basis_method}. Valid: {sorted(VALID_METHODS)}") + raise ValueError( + f"Invalid cost basis method: {cost_basis_method}. 
Valid: {sorted(VALID_METHODS)}" + ) sorted_trades = sorted(trades, key=lambda t: t.timestamp) @@ -61,12 +64,14 @@ def calculate_capital_gains( # other providers bundle fees into cost implicitly) total_fees += getattr(trade, "fee", 0.0) # Add to buy lots - buy_lots.setdefault(slug, []).append({ - "date": trade.timestamp, - "shares": trade.shares, - "price": trade.price, - "cost_per_share": (trade.cost / trade.shares) if trade.shares > 0 else 0.0, - }) + buy_lots.setdefault(slug, []).append( + { + "date": trade.timestamp, + "shares": trade.shares, + "price": trade.price, + "cost_per_share": (trade.cost / trade.shares) if trade.shares > 0 else 0.0, + } + ) elif trade.type in _SELL_TYPES: # Track fees @@ -159,7 +164,9 @@ def calculate_capital_gains( logger.warning( "Tax report: %.4f shares of %s sold on %s have no matching buy lots " "(missing cost basis data)", - remaining_shares, slug, trade.timestamp.strftime("%Y-%m-%d"), + remaining_shares, + slug, + trade.timestamp.strftime("%Y-%m-%d"), ) else: @@ -169,13 +176,14 @@ def calculate_capital_gains( if skipped_types: logger.warning( "Tax report skipped %d trades with unrecognized types: %s", - sum(skipped_types.values()), skipped_types, + sum(skipped_types.values()), + skipped_types, ) # Detect wash sales wash_sales = _detect_wash_sales(transactions, sorted_trades) - net_gain_loss = (short_term_gains - short_term_losses + long_term_gains - long_term_losses) + net_gain_loss = short_term_gains - short_term_losses + long_term_gains - long_term_losses result = { "tax_year": tax_year, @@ -295,14 +303,16 @@ def _detect_wash_sales( delta = abs((buy_date - sell_date).days) if 0 < delta <= 30: - wash_sales.append({ - "market": tx["market"], - "market_slug": slug, - "date_sold": tx["date_sold"], - "date_repurchased": buy_date.strftime("%Y-%m-%d"), - "disallowed_loss": sanitize_numeric(abs(tx["gain_loss"])), - "shares": tx["shares"], - }) + wash_sales.append( + { + "market": tx["market"], + "market_slug": slug, + "date_sold": 
tx["date_sold"], + "date_repurchased": buy_date.strftime("%Y-%m-%d"), + "disallowed_loss": sanitize_numeric(abs(tx["gain_loss"])), + "shares": tx["shares"], + } + ) flagged_tx_ids.add(tx_id) break # One wash sale per loss transaction diff --git a/prediction_analyzer/trade_filter.py b/prediction_analyzer/trade_filter.py index 932bef3..53140ae 100644 --- a/prediction_analyzer/trade_filter.py +++ b/prediction_analyzer/trade_filter.py @@ -2,10 +2,12 @@ """ Trade filtering and deduplication utilities """ + from typing import List from difflib import get_close_matches from .trade_loader import Trade + def filter_trades(trades: List[Trade], market_name: str, fuzzy: bool = True) -> List[Trade]: """ Filter trades by market name with optional fuzzy matching @@ -39,14 +41,17 @@ def filter_trades(trades: List[Trade], market_name: str, fuzzy: bool = True) -> filtered = [t for t in trades if t.market == market_name or t.market_slug == market_name] return filtered + def filter_trades_by_market_slug(trades: List[Trade], market_slug: str) -> List[Trade]: """Filter trades by exact market slug match""" return [t for t in trades if t.market_slug == market_slug] + def filter_trades_by_source(trades: List[Trade], source: str) -> List[Trade]: """Filter trades by provider source (e.g. 'limitless', 'polymarket', 'kalshi', 'manifold')""" return [t for t in trades if t.source == source] + def deduplicate_trades(trades: List[Trade]) -> List[Trade]: """ Remove exact duplicate trades based on unique identifiers. 
@@ -59,19 +64,19 @@ def deduplicate_trades(trades: List[Trade]) -> List[Trade]: for t in trades: # Format timestamp consistently - ts_str = t.timestamp.isoformat() if hasattr(t.timestamp, 'isoformat') else str(t.timestamp) + ts_str = t.timestamp.isoformat() if hasattr(t.timestamp, "isoformat") else str(t.timestamp) # Create identifier from key fields # Include BOTH market name and slug to prevent false duplicates # when slug is "unknown" (default) for different markets identifier = ( - t.market, # Include market name + t.market, # Include market name t.market_slug, # Include slug ts_str, t.price, t.shares, t.type, - t.side + t.side, ) if identifier not in seen: @@ -80,6 +85,7 @@ def deduplicate_trades(trades: List[Trade]) -> List[Trade]: return unique_trades + def get_unique_markets(trades: List[Trade]) -> dict: """ Get a dictionary of unique markets from trades @@ -93,6 +99,7 @@ def get_unique_markets(trades: List[Trade]) -> dict: markets[t.market_slug] = t.market return markets + def group_trades_by_market(trades: List[Trade]) -> dict: """ Group trades by market slug for consistency with calculate_market_pnl. diff --git a/prediction_analyzer/trade_loader.py b/prediction_analyzer/trade_loader.py index b76f0ab..0ff9ae4 100644 --- a/prediction_analyzer/trade_loader.py +++ b/prediction_analyzer/trade_loader.py @@ -2,6 +2,7 @@ """ Trade loading functionality - supports JSON, CSV, XLSX """ + import json import logging import pandas as pd @@ -15,6 +16,11 @@ logger = logging.getLogger(__name__) +# Sentinel cap for infinite values — used throughout the codebase to replace +# float('inf') with a finite number safe for JSON serialization and display. +INF_CAP = 999999.99 + + def sanitize_numeric(value: float) -> float: """ Guard against NaN/Infinity in numeric values for JSON serialization. 
@@ -29,13 +35,14 @@ def sanitize_numeric(value: float) -> float: if math.isnan(value): return 0.0 if math.isinf(value): - return 999999.99 if value > 0 else -999999.99 + return INF_CAP if value > 0 else -INF_CAP return value @dataclass class Trade: """Data class representing a single trade""" + market: str market_slug: str timestamp: datetime @@ -56,7 +63,11 @@ def to_dict(self) -> Dict[str, Any]: return { "market": self.market, "market_slug": self.market_slug, - "timestamp": self.timestamp.isoformat() if hasattr(self.timestamp, "isoformat") else str(self.timestamp), + "timestamp": ( + self.timestamp.isoformat() + if hasattr(self.timestamp, "isoformat") + else str(self.timestamp) + ), "price": sanitize_numeric(self.price), "shares": sanitize_numeric(self.shares), "cost": sanitize_numeric(self.cost), @@ -91,7 +102,7 @@ def _parse_timestamp(value) -> datetime: return value # If it's a pandas Timestamp - if hasattr(value, 'to_pydatetime'): + if hasattr(value, "to_pydatetime"): dt = value.to_pydatetime() if dt.tzinfo is not None: return dt.astimezone(timezone.utc).replace(tzinfo=None) @@ -102,7 +113,7 @@ def _parse_timestamp(value) -> datetime: try: # Handle RFC 3339/ISO 8601 format (e.g., "2024-01-15T10:30:00Z") # Replace 'Z' with '+00:00' for fromisoformat compatibility - clean_value = value.replace('Z', '+00:00') + clean_value = value.replace("Z", "+00:00") dt = datetime.fromisoformat(clean_value) # Convert to naive UTC if dt.tzinfo is not None: @@ -116,7 +127,9 @@ def _parse_timestamp(value) -> datetime: numeric_value = float(value) # If it's a large number, assume milliseconds if numeric_value > 1e12: - return datetime.fromtimestamp(numeric_value / 1000, tz=timezone.utc).replace(tzinfo=None) + return datetime.fromtimestamp(numeric_value / 1000, tz=timezone.utc).replace( + tzinfo=None + ) return datetime.fromtimestamp(numeric_value, tz=timezone.utc).replace(tzinfo=None) except ValueError: pass @@ -131,7 +144,7 @@ def _parse_timestamp(value) -> datetime: # Fallback: 
try pandas parsing try: result = pd.to_datetime(value, utc=True) - if hasattr(result, 'to_pydatetime'): + if hasattr(result, "to_pydatetime"): dt = result.to_pydatetime() if dt.tzinfo is not None: return dt.astimezone(timezone.utc).replace(tzinfo=None) @@ -156,14 +169,14 @@ def _sanitize_filename(name: str, max_length: int = 50) -> str: # Remove or replace characters that are invalid in filenames # Invalid on Windows: < > : " / \ | ? * # Also remove control characters and other problematic chars - sanitized = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', name) + sanitized = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", name) # Replace multiple underscores with single - sanitized = re.sub(r'_+', '_', sanitized) + sanitized = re.sub(r"_+", "_", sanitized) # Remove leading/trailing underscores and spaces - sanitized = sanitized.strip('_ ') + sanitized = sanitized.strip("_ ") # Truncate to max length if len(sanitized) > max_length: - sanitized = sanitized[:max_length].rstrip('_') + sanitized = sanitized[:max_length].rstrip("_") # Ensure we have something if not sanitized: sanitized = "unnamed" @@ -196,6 +209,7 @@ def load_trades(file_path: str) -> List[Trade]: # Auto-detect provider from file format try: from .providers import ProviderRegistry + sample = raw_trades[:5] if raw_trades else [] provider = ProviderRegistry.detect_from_file(sample) if provider and provider.name != "limitless": @@ -225,14 +239,15 @@ def load_trades(file_path: str) -> List[Trade]: if not market_slug: market_slug = "unknown" - # Convert from micro-units (USDC has 6 decimals) + # Convert from micro-units (USDC uses 6 decimal places) # If data comes from API (has collateralAmount), values are in micro-units + USDC_DECIMALS = 1_000_000 has_pnl = "pnl" in t and t["pnl"] is not None if "collateralAmount" in t: # API data - convert from micro-units to regular units - raw_cost = float(t.get("collateralAmount") or 0) / 1_000_000 - raw_pnl = float(t.get("pnl") or 0) / 1_000_000 - raw_shares = 
float(t.get("outcomeTokenAmount") or 0) / 1_000_000 + raw_cost = float(t.get("collateralAmount") or 0) / USDC_DECIMALS + raw_pnl = float(t.get("pnl") or 0) / USDC_DECIMALS + raw_shares = float(t.get("outcomeTokenAmount") or 0) / USDC_DECIMALS else: # File data - already in regular units raw_cost = float(t.get("cost") or 0) @@ -266,7 +281,7 @@ def load_trades(file_path: str) -> List[Trade]: side=side, pnl=raw_pnl, pnl_is_set=has_pnl, - tx_hash=t.get("tx_hash") or t.get("transactionHash") + tx_hash=t.get("tx_hash") or t.get("transactionHash"), ) trades.append(trade) @@ -276,6 +291,7 @@ def load_trades(file_path: str) -> List[Trade]: return trades + def save_trades(trades: List[Union[Trade, dict]], file_path: str): """Save trades to JSON file""" # Handle both Trade objects and raw dictionaries @@ -289,8 +305,8 @@ def save_trades(trades: List[Union[Trade, dict]], file_path: str): # Convert datetime to string for JSON serialization for t in trades_dict: - if 'timestamp' in t and isinstance(t['timestamp'], datetime): - t['timestamp'] = t['timestamp'].isoformat() + if "timestamp" in t and isinstance(t["timestamp"], datetime): + t["timestamp"] = t["timestamp"].isoformat() - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: json.dump(trades_dict, f, indent=2) diff --git a/prediction_analyzer/utils/auth.py b/prediction_analyzer/utils/auth.py index 6b16ac6..dd1d11a 100644 --- a/prediction_analyzer/utils/auth.py +++ b/prediction_analyzer/utils/auth.py @@ -8,6 +8,7 @@ - Kalshi: RSA key pair — per-request RSA-PSS signing - Manifold Markets: API key (manifold_...) via "Authorization: Key ..." header """ + import os import logging from typing import Optional diff --git a/prediction_analyzer/utils/data.py b/prediction_analyzer/utils/data.py index 4055fab..3f8976e 100644 --- a/prediction_analyzer/utils/data.py +++ b/prediction_analyzer/utils/data.py @@ -2,6 +2,7 @@ """ Data fetching utilities for Limitless Exchange API. 
""" + import logging import requests from typing import List @@ -33,10 +34,7 @@ def fetch_trade_history(api_key: str, page_limit: int = 100) -> List[dict]: try: resp = requests.get( - f"{API_BASE_URL}/portfolio/history", - params=params, - headers=headers, - timeout=15 + f"{API_BASE_URL}/portfolio/history", params=params, headers=headers, timeout=15 ) resp.raise_for_status() data = resp.json() diff --git a/prediction_analyzer/utils/export.py b/prediction_analyzer/utils/export.py index f0969c6..72ee748 100644 --- a/prediction_analyzer/utils/export.py +++ b/prediction_analyzer/utils/export.py @@ -2,6 +2,7 @@ """ Export utility functions """ + import logging import matplotlib.pyplot as plt from typing import Any @@ -10,6 +11,7 @@ logger = logging.getLogger(__name__) + def export_chart(fig: Any, path: str): """ Export a matplotlib or plotly figure to file @@ -21,13 +23,13 @@ def export_chart(fig: Any, path: str): try: # Check if it's a matplotlib figure if isinstance(fig, plt.Figure): - fig.savefig(path, dpi=150, bbox_inches='tight') + fig.savefig(path, dpi=150, bbox_inches="tight") logger.info("Chart exported to: %s", path) # Check if it's a plotly figure - elif hasattr(fig, 'write_html'): + elif hasattr(fig, "write_html"): fig.write_html(path) logger.info("Interactive chart exported to: %s", path) - elif hasattr(fig, 'write_image'): + elif hasattr(fig, "write_image"): fig.write_image(path) logger.info("Chart exported to: %s", path) else: diff --git a/prediction_analyzer/utils/math_utils.py b/prediction_analyzer/utils/math_utils.py index 249bfa2..f7d126b 100644 --- a/prediction_analyzer/utils/math_utils.py +++ b/prediction_analyzer/utils/math_utils.py @@ -2,9 +2,11 @@ """ Mathematical utility functions """ + import numpy as np from typing import List + def moving_average(values: List[float], window: int = 5) -> np.ndarray: """ Calculate simple moving average @@ -18,7 +20,8 @@ def moving_average(values: List[float], window: int = 5) -> np.ndarray: """ if len(values) < 
window: window = len(values) - return np.convolve(values, np.ones(window)/window, mode='valid') + return np.convolve(values, np.ones(window) / window, mode="valid") + def weighted_average(values: List[float], weights: List[float]) -> float: """ @@ -35,6 +38,7 @@ def weighted_average(values: List[float], weights: List[float]) -> float: raise ValueError("Values and weights must have same length") return np.average(values, weights=weights) + def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float: """ Safe division that returns default value if denominator is zero @@ -49,6 +53,7 @@ def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> f """ return numerator / denominator if denominator != 0 else default + def calculate_roi(pnl: float, investment: float) -> float: """ Calculate return on investment percentage diff --git a/prediction_analyzer/utils/time_utils.py b/prediction_analyzer/utils/time_utils.py index c3b958c..ee57db0 100644 --- a/prediction_analyzer/utils/time_utils.py +++ b/prediction_analyzer/utils/time_utils.py @@ -2,9 +2,11 @@ """ Time and date utility functions """ + from datetime import datetime, timedelta from typing import Optional + def parse_date(date_str: str) -> datetime: """ Parse date string in various formats @@ -15,13 +17,7 @@ def parse_date(date_str: str) -> datetime: Returns: datetime object """ - formats = [ - "%Y-%m-%d", - "%Y/%m/%d", - "%m-%d-%Y", - "%m/%d/%Y", - "%Y-%m-%d %H:%M:%S" - ] + formats = ["%Y-%m-%d", "%Y/%m/%d", "%m-%d-%Y", "%m/%d/%Y", "%Y-%m-%d %H:%M:%S"] for fmt in formats: try: @@ -31,6 +27,7 @@ def parse_date(date_str: str) -> datetime: raise ValueError(f"Unable to parse date: {date_str}") + def format_timestamp(timestamp: datetime, fmt: str = "%Y-%m-%d %H:%M:%S") -> str: """ Format timestamp to string @@ -44,6 +41,7 @@ def format_timestamp(timestamp: datetime, fmt: str = "%Y-%m-%d %H:%M:%S") -> str """ return timestamp.strftime(fmt) + def get_date_range(days_back: int) -> 
tuple: """ Get a date range from N days ago to now diff --git a/prediction_mcp/__main__.py b/prediction_mcp/__main__.py index 8448316..5d6c5a1 100644 --- a/prediction_mcp/__main__.py +++ b/prediction_mcp/__main__.py @@ -2,6 +2,7 @@ """ Allows running the MCP server with: python -m prediction_mcp """ + from .server import main main() diff --git a/prediction_mcp/_apply_filters.py b/prediction_mcp/_apply_filters.py index 9c780d7..650dca5 100644 --- a/prediction_mcp/_apply_filters.py +++ b/prediction_mcp/_apply_filters.py @@ -5,11 +5,17 @@ Extracts common filter parameters from tool arguments and applies them using the core library filter functions. """ + from typing import List, Dict, Any from prediction_analyzer.trade_loader import Trade from prediction_analyzer.exceptions import InvalidFilterError -from prediction_analyzer.filters import filter_by_date, filter_by_trade_type, filter_by_side, filter_by_pnl +from prediction_analyzer.filters import ( + filter_by_date, + filter_by_trade_type, + filter_by_side, + filter_by_pnl, +) from prediction_analyzer.trade_filter import filter_trades_by_market_slug from .validators import validate_date, validate_trade_types, validate_sides, validate_numeric @@ -57,9 +63,7 @@ def apply_filters(trades: List[Trade], arguments: Dict[str, Any]) -> List[Trade] min_pnl = validate_numeric(arguments.get("min_pnl"), "min_pnl") max_pnl = validate_numeric(arguments.get("max_pnl"), "max_pnl") if min_pnl is not None and max_pnl is not None and min_pnl > max_pnl: - raise InvalidFilterError( - f"min_pnl ({min_pnl}) must not exceed max_pnl ({max_pnl})" - ) + raise InvalidFilterError(f"min_pnl ({min_pnl}) must not exceed max_pnl ({max_pnl})") if min_pnl is not None or max_pnl is not None: result = filter_by_pnl(result, min_pnl=min_pnl, max_pnl=max_pnl) diff --git a/prediction_mcp/errors.py b/prediction_mcp/errors.py index 0cd6dde..d506f02 100644 --- a/prediction_mcp/errors.py +++ b/prediction_mcp/errors.py @@ -5,6 +5,7 @@ Converts internal exceptions 
into user-friendly error messages with recovery hints that help the LLM agent self-correct. """ + import functools import logging @@ -65,6 +66,7 @@ def safe_tool(func): Eliminates try/except boilerplate from individual handlers. """ + @functools.wraps(func) async def wrapper(*args, **kwargs): try: @@ -74,4 +76,5 @@ async def wrapper(*args, **kwargs): except Exception as e: logger.exception("Unhandled error in MCP tool %s", func.__name__) return error_result(e).content + return wrapper diff --git a/prediction_mcp/persistence.py b/prediction_mcp/persistence.py index 6729e72..2e454da 100644 --- a/prediction_mcp/persistence.py +++ b/prediction_mcp/persistence.py @@ -11,6 +11,7 @@ store.restore(session) # reload from disk store.close() """ + import json import logging import sqlite3 @@ -85,16 +86,29 @@ def save(self, session) -> None: cur.execute("DELETE FROM session_meta") for trade in session.trades: - ts = trade.timestamp.isoformat() if hasattr(trade.timestamp, "isoformat") else str(trade.timestamp) + ts = ( + trade.timestamp.isoformat() + if hasattr(trade.timestamp, "isoformat") + else str(trade.timestamp) + ) cur.execute( "INSERT INTO trades (market, market_slug, timestamp, price, shares, cost, type, side, pnl, pnl_is_set, tx_hash, source, currency) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (trade.market, trade.market_slug, ts, trade.price, trade.shares, - trade.cost, trade.type, trade.side, trade.pnl, - 1 if trade.pnl_is_set else 0, - trade.tx_hash, - getattr(trade, "source", "limitless"), - getattr(trade, "currency", "USD")), + ( + trade.market, + trade.market_slug, + ts, + trade.price, + trade.shares, + trade.cost, + trade.type, + trade.side, + trade.pnl, + 1 if trade.pnl_is_set else 0, + trade.tx_hash, + getattr(trade, "source", "limitless"), + getattr(trade, "currency", "USD"), + ), ) # Save sources list @@ -124,28 +138,35 @@ def restore(self, session) -> bool: trades = [] row_keys = rows[0].keys() if rows else [] for row in rows: - pnl_is_set = 
bool(row["pnl_is_set"]) if "pnl_is_set" in row_keys else (row["pnl"] != 0.0) - trades.append(Trade( - market=row["market"], - market_slug=row["market_slug"], - timestamp=datetime.fromisoformat(row["timestamp"]), - price=row["price"], - shares=row["shares"], - cost=row["cost"], - type=row["type"], - side=row["side"], - pnl=row["pnl"], - pnl_is_set=pnl_is_set, - tx_hash=row["tx_hash"], - source=row["source"] if "source" in row_keys else "limitless", - currency=row["currency"] if "currency" in row_keys else "USD", - )) + pnl_is_set = ( + bool(row["pnl_is_set"]) if "pnl_is_set" in row_keys else (row["pnl"] != 0.0) + ) + trades.append( + Trade( + market=row["market"], + market_slug=row["market_slug"], + timestamp=datetime.fromisoformat(row["timestamp"]), + price=row["price"], + shares=row["shares"], + cost=row["cost"], + type=row["type"], + side=row["side"], + pnl=row["pnl"], + pnl_is_set=pnl_is_set, + tx_hash=row["tx_hash"], + source=row["source"] if "source" in row_keys else "limitless", + currency=row["currency"] if "currency" in row_keys else "USD", + ) + ) session.trades = trades session.filtered_trades = list(trades) # Restore metadata - meta = {r["key"]: r["value"] for r in cur.execute("SELECT key, value FROM session_meta").fetchall()} + meta = { + r["key"]: r["value"] + for r in cur.execute("SELECT key, value FROM session_meta").fetchall() + } sources_json = meta.get("sources") if sources_json: diff --git a/prediction_mcp/py.typed b/prediction_mcp/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/prediction_mcp/serializers.py b/prediction_mcp/serializers.py index 3d7c81b..9e8b0a1 100644 --- a/prediction_mcp/serializers.py +++ b/prediction_mcp/serializers.py @@ -5,6 +5,7 @@ Ensures all data returned from tools is JSON-serializable, handling NaN/Infinity, datetime objects, and dataclass conversion. 
""" + import json import math from typing import Any, Dict, List @@ -37,7 +38,7 @@ def _sanitize_value(value: Any) -> Any: return sanitize_dict(value) if isinstance(value, (list, tuple)): return [_sanitize_value(v) for v in value] - if hasattr(value, 'isoformat'): + if hasattr(value, "isoformat"): return value.isoformat() return value diff --git a/prediction_mcp/server.py b/prediction_mcp/server.py index 1f0a930..88679ce 100644 --- a/prediction_mcp/server.py +++ b/prediction_mcp/server.py @@ -35,14 +35,24 @@ # Import tool modules from .tools import ( - data_tools, analysis_tools, filter_tools, chart_tools, - export_tools, portfolio_tools, tax_tools, + data_tools, + analysis_tools, + filter_tools, + chart_tools, + export_tools, + portfolio_tools, + tax_tools, ) # Collect all tool modules for dispatch _TOOL_MODULES = [ - data_tools, analysis_tools, filter_tools, chart_tools, - export_tools, portfolio_tools, tax_tools, + data_tools, + analysis_tools, + filter_tools, + chart_tools, + export_tools, + portfolio_tools, + tax_tools, ] # Create the MCP server instance @@ -68,6 +78,248 @@ async def list_tools() -> list[types.Tool]: return tools +# --------------------------------------------------------------------------- +# MCP Resources — expose session data for direct LLM reading +# --------------------------------------------------------------------------- + + +@app.list_resources() +async def list_resources() -> list[types.Resource]: + """List available resources based on current session state.""" + from .state import session + + resources: list[types.Resource] = [] + + if session.has_trades: + resources.append( + types.Resource( + uri="prediction://trades/summary", + name="Trade Summary", + description=( + f"Summary of {session.trade_count} loaded trades " + f"from {', '.join(session.sources) or 'unknown'} providers." 
+ ), + mimeType="application/json", + ) + ) + resources.append( + types.Resource( + uri="prediction://trades/markets", + name="Market List", + description="List of all unique markets in the current session.", + mimeType="application/json", + ) + ) + if session.active_filters: + resources.append( + types.Resource( + uri="prediction://trades/filters", + name="Active Filters", + description="Currently applied trade filters.", + mimeType="application/json", + ) + ) + + return resources + + +@app.read_resource() +async def read_resource(uri: str) -> str: + """Read a resource by URI.""" + from .state import session + from .serializers import to_json_text, sanitize_dict + from prediction_analyzer.trade_filter import get_unique_markets, filter_trades_by_market_slug + from prediction_analyzer.pnl import calculate_global_pnl_summary + + if uri == "prediction://trades/summary": + if not session.has_trades: + return to_json_text({"error": "No trades loaded"}) + summary = calculate_global_pnl_summary(session.trades) + return to_json_text(sanitize_dict(summary)) + + elif uri == "prediction://trades/markets": + if not session.has_trades: + return to_json_text({"error": "No trades loaded"}) + markets = get_unique_markets(session.trades) + result = [] + for slug, title in sorted(markets.items()): + market_trades = filter_trades_by_market_slug(session.trades, slug) + result.append( + { + "slug": slug, + "title": title, + "trade_count": len(market_trades), + "sources": list({t.source for t in market_trades}), + } + ) + return to_json_text(result) + + elif uri == "prediction://trades/filters": + return to_json_text(session.active_filters or {}) + + raise ValueError(f"Unknown resource URI: {uri}") + + +# --------------------------------------------------------------------------- +# MCP Prompts — pre-built templates for common analysis workflows +# --------------------------------------------------------------------------- + + +@app.list_prompts() +async def list_prompts() -> 
list[types.Prompt]: + """List available prompt templates.""" + return [ + types.Prompt( + name="analyze_portfolio", + description=( + "Comprehensive portfolio analysis: global summary, " + "top/bottom markets, risk metrics, and actionable insights." + ), + arguments=[ + types.PromptArgument( + name="focus", + description="Optional focus area: 'risk', 'performance', or 'tax'", + required=False, + ), + ], + ), + types.Prompt( + name="compare_periods", + description=( + "Compare trading performance between two date ranges. " + "Useful for month-over-month or pre/post strategy change analysis." + ), + arguments=[ + types.PromptArgument( + name="period1_start", + description="Start date of first period (YYYY-MM-DD)", + required=True, + ), + types.PromptArgument( + name="period1_end", + description="End date of first period (YYYY-MM-DD)", + required=True, + ), + types.PromptArgument( + name="period2_start", + description="Start date of second period (YYYY-MM-DD)", + required=True, + ), + types.PromptArgument( + name="period2_end", + description="End date of second period (YYYY-MM-DD)", + required=True, + ), + ], + ), + types.Prompt( + name="daily_report", + description="Generate a daily trading summary for today or a specific date.", + arguments=[ + types.PromptArgument( + name="date", + description="Date to report on (YYYY-MM-DD). Defaults to today.", + required=False, + ), + ], + ), + ] + + +@app.get_prompt() +async def get_prompt(name: str, arguments: dict | None = None) -> types.GetPromptResult: + """Return a prompt template with user arguments filled in.""" + args = arguments or {} + + if name == "analyze_portfolio": + focus = args.get("focus", "performance") + focus_instructions = { + "risk": ( + "Focus on risk analysis: run get_advanced_metrics for Sharpe ratio, max drawdown, " + "and Sortino ratio. Then run get_concentration_risk and get_drawdown_analysis. " + "Flag any markets with outsized position sizes or unrealized losses." 
+ ), + "tax": ( + "Focus on tax implications: run get_tax_report with FIFO method. " + "Summarize short-term vs long-term gains. Note any wash sale concerns " + "and suggest tax-loss harvesting opportunities." + ), + "performance": ( + "Focus on trading performance: run get_global_summary for overall stats, " + "then get_market_breakdown to find best/worst markets. Run get_advanced_metrics " + "for risk-adjusted returns. Provide actionable recommendations." + ), + } + return types.GetPromptResult( + description=f"Portfolio analysis focused on {focus}", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent( + type="text", + text=( + "Analyze my prediction market portfolio. " + f"{focus_instructions.get(focus, focus_instructions['performance'])}\n\n" + "Structure your response with:\n" + "1. Executive Summary (2-3 sentences)\n" + "2. Key Metrics table\n" + "3. Top 3 and Bottom 3 markets\n" + "4. Risk Assessment\n" + "5. Recommendations" + ), + ), + ), + ], + ) + + elif name == "compare_periods": + return types.GetPromptResult( + description="Period comparison analysis", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent( + type="text", + text=( + f"Compare my trading performance between two periods:\n" + f"- Period 1: {args['period1_start']} to {args['period1_end']}\n" + f"- Period 2: {args['period2_start']} to {args['period2_end']}\n\n" + "For each period, run get_global_summary with the date filters, " + "then get_advanced_metrics. Compare win rate, total PnL, ROI, " + "Sharpe ratio, and max drawdown.\n\n" + "Present results in a side-by-side comparison table and explain " + "what changed between the periods." 
+ ), + ), + ), + ], + ) + + elif name == "daily_report": + date = args.get("date", "today") + return types.GetPromptResult( + description=f"Daily trading report for {date}", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent( + type="text", + text=( + f"Generate a daily trading report for {date}.\n\n" + "1. Use get_global_summary with start_date and end_date set to " + f"'{date}' to get the day's stats\n" + "2. Use get_trade_details to list individual trades from that day\n" + "3. Summarize: trades executed, net PnL, win rate\n" + "4. Note any notable wins or losses" + ), + ), + ), + ], + ) + + raise ValueError(f"Unknown prompt: {name}") + + # Optional session store for SQLite persistence (set by main()) _session_store = None @@ -85,15 +337,18 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: if _session_store and name in _STATE_MODIFYING_TOOLS: try: from .state import session + _session_store.save(session) except Exception: logger.exception("Failed to persist session after %s", name) return result - return [types.TextContent( - type="text", - text=f"Unknown tool: {name}", - )] + return [ + types.TextContent( + type="text", + text=f"Unknown tool: {name}", + ) + ] async def run_stdio(): @@ -123,16 +378,15 @@ def create_sse_app(sse_path: str = "/sse", message_path: str = "/messages"): async def handle_sse(request): """Handle SSE connection — long-lived event stream.""" - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as (read_stream, write_stream): + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as ( + read_stream, + write_stream, + ): await app.run(read_stream, write_stream, app.create_initialization_options()) async def handle_messages(request): """Handle client-to-server JSON-RPC messages.""" - await sse_transport.handle_post_message( - request.scope, request.receive, request._send - ) + await 
sse_transport.handle_post_message(request.scope, request.receive, request._send) async def health(request): return JSONResponse({"status": "ok", "server": "prediction-analyzer", "version": "1.0.0"}) @@ -150,6 +404,7 @@ async def health(request): def run_sse(host: str = "0.0.0.0", port: int = 8000): """Run the MCP server over HTTP/SSE transport.""" import uvicorn + logger.info("Starting Prediction Analyzer MCP server (SSE) on %s:%d", host, port) starlette_app = create_sse_app() uvicorn.run(starlette_app, host=host, port=port, log_level="info") @@ -171,19 +426,24 @@ def main(): """Entry point for the MCP server.""" parser = argparse.ArgumentParser(description="Prediction Analyzer MCP Server") parser.add_argument( - "--sse", action="store_true", + "--sse", + action="store_true", help="Use HTTP/SSE transport instead of stdio", ) parser.add_argument( - "--host", default="0.0.0.0", + "--host", + default="0.0.0.0", help="Host to bind SSE server (default: 0.0.0.0)", ) parser.add_argument( - "--port", type=int, default=8000, + "--port", + type=int, + default=8000, help="Port for SSE server (default: 8000)", ) parser.add_argument( - "--persist", metavar="DB_PATH", + "--persist", + metavar="DB_PATH", default=os.environ.get("PREDICTION_MCP_DB"), help="SQLite database path for session persistence (or set PREDICTION_MCP_DB env var)", ) diff --git a/prediction_mcp/state.py b/prediction_mcp/state.py index a3a9ace..1f47530 100644 --- a/prediction_mcp/state.py +++ b/prediction_mcp/state.py @@ -5,6 +5,7 @@ Maintains loaded trades, active filters, and filtered results so that the LLM can load trades once and run multiple analyses without re-loading. 
""" + from dataclasses import dataclass, field from typing import List, Dict, Any, Optional diff --git a/prediction_mcp/tools/analysis_tools.py b/prediction_mcp/tools/analysis_tools.py index 02a9a62..b49502f 100644 --- a/prediction_mcp/tools/analysis_tools.py +++ b/prediction_mcp/tools/analysis_tools.py @@ -206,13 +206,15 @@ async def _handle_market_breakdown(arguments: dict): result = [] for slug, stats in sorted(breakdown.items(), key=lambda x: x[1]["total_pnl"], reverse=True): - result.append({ - "market_slug": slug, - "market": stats["market_name"], - "trade_count": stats["trade_count"], - "pnl": stats["total_pnl"], - "volume": stats["total_volume"], - }) + result.append( + { + "market_slug": slug, + "market": stats["market_name"], + "trade_count": stats["trade_count"], + "pnl": stats["total_pnl"], + "volume": stats["total_volume"], + } + ) return [types.TextContent(type="text", text=to_json_text(result))] @@ -241,13 +243,15 @@ async def _handle_provider_breakdown(arguments: dict): result = [] for src, stats in sorted(sources.items(), key=lambda x: x[1]["total_pnl"], reverse=True): cfg = PROVIDER_CONFIGS.get(src, {}) - result.append({ - "provider": src, - "display_name": cfg.get("display_name", src.title()), - "total_trades": stats["total_trades"], - "total_pnl": stats["total_pnl"], - "total_volume": stats["total_volume"], - "currency": stats["currency"], - }) + result.append( + { + "provider": src, + "display_name": cfg.get("display_name", src.title()), + "total_trades": stats["total_trades"], + "total_pnl": stats["total_pnl"], + "total_volume": stats["total_volume"], + "currency": stats["currency"], + } + ) return [types.TextContent(type="text", text=to_json_text(result))] diff --git a/prediction_mcp/tools/chart_tools.py b/prediction_mcp/tools/chart_tools.py index fa9b9a0..b02cf33 100644 --- a/prediction_mcp/tools/chart_tools.py +++ b/prediction_mcp/tools/chart_tools.py @@ -14,7 +14,11 @@ from prediction_analyzer.charts.pro import generate_pro_chart from 
prediction_analyzer.charts.enhanced import generate_enhanced_chart from prediction_analyzer.charts.global_chart import generate_global_dashboard -from prediction_analyzer.trade_filter import filter_trades_by_market_slug, group_trades_by_market, get_unique_markets +from prediction_analyzer.trade_filter import ( + filter_trades_by_market_slug, + group_trades_by_market, + get_unique_markets, +) from prediction_analyzer.exceptions import NoTradesError from ..state import session diff --git a/prediction_mcp/tools/data_tools.py b/prediction_mcp/tools/data_tools.py index fda7b08..c085a68 100644 --- a/prediction_mcp/tools/data_tools.py +++ b/prediction_mcp/tools/data_tools.py @@ -18,7 +18,12 @@ from ..state import session from ..errors import error_result, safe_tool from ..serializers import to_json_text, serialize_trades -from ..validators import validate_sort_field, validate_sort_order, validate_positive_int, validate_market_slug +from ..validators import ( + validate_sort_field, + validate_sort_order, + validate_positive_int, + validate_market_slug, +) logger = logging.getLogger(__name__) @@ -205,6 +210,7 @@ async def _handle_fetch_trades(arguments: dict): # Apply PnL computation for providers that don't supply it if provider.name in ("kalshi", "manifold", "polymarket"): from prediction_analyzer.providers.pnl_calculator import compute_realized_pnl + trades = compute_realized_pnl(trades) # Deduplicate by tx_hash to prevent inflation on repeated fetches @@ -241,12 +247,14 @@ async def _handle_list_markets(arguments: dict): for slug, title in sorted(markets.items()): market_trades = filter_trades_by_market_slug(session.trades, slug) sources = list({t.source for t in market_trades}) - result.append({ - "slug": slug, - "title": title, - "trade_count": len(market_trades), - "sources": sources, - }) + result.append( + { + "slug": slug, + "title": title, + "trade_count": len(market_trades), + "sources": sources, + } + ) return [types.TextContent(type="text", 
text=to_json_text(result))] @@ -273,7 +281,7 @@ async def _handle_get_trade_details(arguments: dict): trades = sorted(trades, key=lambda t: getattr(t, sort_by, 0), reverse=reverse) total = len(trades) - trades = trades[offset:offset + limit] + trades = trades[offset : offset + limit] result = { "trades": serialize_trades(trades), diff --git a/prediction_mcp/tools/export_tools.py b/prediction_mcp/tools/export_tools.py index ec7cb71..10ca2e4 100644 --- a/prediction_mcp/tools/export_tools.py +++ b/prediction_mcp/tools/export_tools.py @@ -88,9 +88,7 @@ async def _handle_export_trades(arguments: dict): # allowed (the user explicitly chose where to write). # Check the raw path (before normpath resolves ..) to catch traversal. if ".." in output_path.replace("\\", "/").split("/"): - raise ValueError( - f"output_path must not contain '..': {output_path}" - ) + raise ValueError(f"output_path must not contain '..': {output_path}") validate_export_format(fmt) diff --git a/prediction_mcp/tools/filter_tools.py b/prediction_mcp/tools/filter_tools.py index 1b05823..4f9a5db 100644 --- a/prediction_mcp/tools/filter_tools.py +++ b/prediction_mcp/tools/filter_tools.py @@ -94,7 +94,8 @@ async def _handle_filter_trades(arguments: dict): session.filtered_trades = filtered session.active_filters = { - k: v for k, v in arguments.items() + k: v + for k, v in arguments.items() if v is not None and v != "" and v != [] and k != "clear" } diff --git a/prediction_mcp/validators.py b/prediction_mcp/validators.py index 2c460db..11c5cc3 100644 --- a/prediction_mcp/validators.py +++ b/prediction_mcp/validators.py @@ -5,13 +5,13 @@ Validates and normalizes tool inputs before passing them to the core library functions. Raises InvalidFilterError for bad inputs. 
""" + import math from datetime import datetime from typing import Optional, List, Dict from prediction_analyzer.exceptions import InvalidFilterError, MarketNotFoundError - VALID_TRADE_TYPES = {"Buy", "Sell"} VALID_SIDES = {"YES", "NO"} VALID_CHART_TYPES = {"simple", "pro", "enhanced", "global"} @@ -33,9 +33,7 @@ def validate_date(value: Optional[str], param_name: str) -> Optional[str]: datetime.strptime(value, "%Y-%m-%d") return value except ValueError: - raise InvalidFilterError( - f"Invalid {param_name}: '{value}'. Expected format: YYYY-MM-DD" - ) + raise InvalidFilterError(f"Invalid {param_name}: '{value}'. Expected format: YYYY-MM-DD") def validate_trade_types(types: Optional[List[str]]) -> Optional[List[str]]: @@ -45,7 +43,9 @@ def validate_trade_types(types: Optional[List[str]]) -> Optional[List[str]]: """ if types is None: return None - normalized = [_TRADE_TYPE_NORMALIZE.get(t.lower(), t) if isinstance(t, str) else t for t in types] + normalized = [ + _TRADE_TYPE_NORMALIZE.get(t.lower(), t) if isinstance(t, str) else t for t in types + ] invalid = [t for t in normalized if t not in VALID_TRADE_TYPES] if invalid: raise InvalidFilterError( @@ -64,9 +64,7 @@ def validate_sides(sides: Optional[List[str]]) -> Optional[List[str]]: normalized = [_SIDE_NORMALIZE.get(s.lower(), s) if isinstance(s, str) else s for s in sides] invalid = [s for s in normalized if s not in VALID_SIDES] if invalid: - raise InvalidFilterError( - f"Invalid sides: {invalid}. Valid values: {sorted(VALID_SIDES)}" - ) + raise InvalidFilterError(f"Invalid sides: {invalid}. Valid values: {sorted(VALID_SIDES)}") return normalized @@ -98,9 +96,7 @@ def validate_positive_int(value: Optional[int], param_name: str) -> Optional[int f"Invalid {param_name}: {value}. Must be a positive integer, not NaN/Infinity." ) if not isinstance(value, int) or value < 1: - raise InvalidFilterError( - f"Invalid {param_name}: {value}. Must be a positive integer." 
- ) + raise InvalidFilterError(f"Invalid {param_name}: {value}. Must be a positive integer.") return value diff --git a/pyproject.toml b/pyproject.toml index b33e382..1bc5732 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" description = "A complete modular analysis tool for prediction market traders" readme = "README.md" authors = [ - {name = "Your Name", email = "you@example.com"} + {name = "Frostbite1536"} ] license = "AGPL-3.0-or-later" classifiers = [ @@ -27,7 +27,8 @@ dependencies = [ "matplotlib>=3.7.0", "plotly>=5.20.0", "openpyxl>=3.1.0", - "requests>=2.28.0" + "requests>=2.28.0", + "cryptography>=41.0.0", ] [project.optional-dependencies] @@ -46,8 +47,13 @@ api = [ dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", + "pytest-asyncio>=0.21.0", + "httpx>=0.24.0", "black>=22.0.0", - "flake8>=5.0.0" + "flake8>=5.0.0", + "mypy>=1.0.0", + "pandas-stubs>=1.5.0", + "types-requests>=2.28.0", ] mcp = [ "mcp>=1.0.0", @@ -61,9 +67,25 @@ prediction-analyzer = "prediction_analyzer.__main__:main" prediction-mcp = "prediction_mcp.server:main" [project.urls] -Homepage = "https://github.com/yourusername/prediction_analyzer" -Documentation = "https://github.com/yourusername/prediction_analyzer/wiki" -Repository = "https://github.com/yourusername/prediction_analyzer" +Homepage = "https://github.com/Frostbite1536/Prediction_Analyzer" +Documentation = "https://github.com/Frostbite1536/Prediction_Analyzer/wiki" +Repository = "https://github.com/Frostbite1536/Prediction_Analyzer" [tool.setuptools.packages.find] include = ["prediction_analyzer*", "prediction_mcp*"] + +[tool.black] +line-length = 100 +target-version = ["py39"] + +[tool.mypy] +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +check_untyped_defs = true +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 
100644 index 0000000..e69de29 diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 0000000..6f8a285 --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,93 @@ +# tests/api/conftest.py +""" +Fixtures for FastAPI integration tests. + +Uses an in-memory SQLite database so tests are fast and isolated. +""" + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from prediction_analyzer.api.database import Base +from prediction_analyzer.api.dependencies import get_db +from prediction_analyzer.api.main import app, _rate_store + +# In-memory SQLite for tests +_TEST_ENGINE = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +_TestSession = sessionmaker(autocommit=False, autoflush=False, bind=_TEST_ENGINE) + + +def _override_get_db(): + db = _TestSession() + try: + yield db + finally: + db.close() + + +@pytest.fixture(autouse=True) +def _setup_db(): + """Create all tables before each test, drop after. 
Clear rate limiter.""" + Base.metadata.create_all(bind=_TEST_ENGINE) + _rate_store.clear() + yield + Base.metadata.drop_all(bind=_TEST_ENGINE) + + +@pytest.fixture() +def client(): + """FastAPI TestClient with overridden DB dependency.""" + app.dependency_overrides[get_db] = _override_get_db + with TestClient(app) as c: + yield c + app.dependency_overrides.clear() + + +@pytest.fixture() +def db_session(): + """Direct DB session for setting up test data.""" + db = _TestSession() + try: + yield db + finally: + db.close() + + +# ---- Helpers --------------------------------------------------------------- + + +def signup_user( + client: TestClient, email="test@example.com", username="testuser", password="password123" +): + """Register a user and return the response.""" + resp = client.post( + "/api/v1/auth/signup", + json={ + "email": email, + "username": username, + "password": password, + }, + ) + return resp + + +def auth_header(token: str) -> dict: + """Build an Authorization header from a bearer token.""" + return {"Authorization": f"Bearer {token}"} + + +def create_authenticated_user( + client: TestClient, email="test@example.com", username="testuser", password="password123" +): + """Register a user and return (response_json, auth_headers).""" + resp = signup_user(client, email, username, password) + assert resp.status_code == 201, f"Signup failed: {resp.status_code} {resp.text}" + data = resp.json() + return data, auth_header(data["access_token"]) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py new file mode 100644 index 0000000..7373a96 --- /dev/null +++ b/tests/api/test_auth.py @@ -0,0 +1,106 @@ +# tests/api/test_auth.py +"""Tests for authentication endpoints: signup, login, token validation.""" + +import pytest + +from .conftest import signup_user, auth_header, create_authenticated_user + + +class TestSignup: + def test_signup_success(self, client): + resp = signup_user(client) + assert resp.status_code == 201 + data = resp.json() + assert 
data["access_token"] + assert data["user"]["email"] == "test@example.com" + assert data["user"]["username"] == "testuser" + + def test_signup_duplicate_email(self, client): + signup_user(client) + resp = signup_user(client, username="other") + assert resp.status_code == 400 + assert "already registered" in resp.json()["detail"] + + def test_signup_duplicate_username(self, client): + signup_user(client) + resp = signup_user(client, email="other@example.com") + assert resp.status_code == 400 + assert "already taken" in resp.json()["detail"] + + def test_signup_short_password(self, client): + resp = signup_user(client, password="short") + assert resp.status_code == 422 # Pydantic validation + + def test_signup_short_username(self, client): + resp = signup_user(client, username="ab") + assert resp.status_code == 422 + + def test_signup_invalid_email(self, client): + resp = signup_user(client, email="not-an-email") + assert resp.status_code == 422 + + +class TestLoginJson: + def test_login_json_success(self, client): + signup_user(client) + resp = client.post( + "/api/v1/auth/login/json", + json={ + "email": "test@example.com", + "password": "password123", + }, + ) + assert resp.status_code == 200 + assert resp.json()["access_token"] + assert resp.json()["token_type"] == "bearer" + + def test_login_json_wrong_password(self, client): + signup_user(client) + resp = client.post( + "/api/v1/auth/login/json", + json={ + "email": "test@example.com", + "password": "wrong", + }, + ) + assert resp.status_code == 401 + assert "Incorrect email or password" in resp.json()["detail"] + + def test_login_json_nonexistent_user(self, client): + resp = client.post( + "/api/v1/auth/login/json", + json={ + "email": "nobody@example.com", + "password": "password123", + }, + ) + assert resp.status_code == 401 + + +class TestLoginOAuth2: + def test_login_form_success(self, client): + signup_user(client) + resp = client.post( + "/api/v1/auth/login", + data={ + "username": "test@example.com", + 
"password": "password123", + }, + ) + assert resp.status_code == 200 + assert resp.json()["access_token"] + + +class TestTokenValidation: + def test_invalid_token_rejected(self, client): + resp = client.get("/api/v1/trades", headers=auth_header("garbage.token.here")) + assert resp.status_code == 401 + + def test_missing_token_rejected(self, client): + resp = client.get("/api/v1/trades") + assert resp.status_code == 401 + + def test_valid_token_accepted(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades", headers=headers) + assert resp.status_code == 200 diff --git a/tests/api/test_security.py b/tests/api/test_security.py new file mode 100644 index 0000000..944b454 --- /dev/null +++ b/tests/api/test_security.py @@ -0,0 +1,44 @@ +# tests/api/test_security.py +"""Tests for security features: headers, rate limiting, CORS.""" + +import pytest + +from .conftest import create_authenticated_user + + +class TestSecurityHeaders: + """Verify security headers are present on all responses.""" + + def test_root_has_security_headers(self, client): + resp = client.get("/") + assert resp.headers["X-Content-Type-Options"] == "nosniff" + assert resp.headers["X-Frame-Options"] == "DENY" + assert resp.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert resp.headers["X-XSS-Protection"] == "1; mode=block" + assert "geolocation=()" in resp.headers["Permissions-Policy"] + + def test_api_endpoint_has_security_headers(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades", headers=headers) + assert resp.headers["X-Content-Type-Options"] == "nosniff" + assert resp.headers["X-Frame-Options"] == "DENY" + + def test_no_hsts_on_http(self, client): + """HSTS should not be set when not using HTTPS.""" + resp = client.get("/") + assert "Strict-Transport-Security" not in resp.headers + + +class TestHealthCheck: + def test_health_check(self, client): + resp = client.get("/health") + 
assert resp.status_code == 200 + assert resp.json()["status"] == "healthy" + + def test_root_endpoint(self, client): + resp = client.get("/") + assert resp.status_code == 200 + data = resp.json() + assert "name" in data + assert "version" in data + assert data["docs"] == "/docs" diff --git a/tests/api/test_trades.py b/tests/api/test_trades.py new file mode 100644 index 0000000..c6ff023 --- /dev/null +++ b/tests/api/test_trades.py @@ -0,0 +1,265 @@ +# tests/api/test_trades.py +"""Tests for trade CRUD, upload, export, and provider listing.""" + +import io +import json +import pytest + +from .conftest import create_authenticated_user, auth_header, signup_user + +# Minimal valid trade file (Limitless format) +_SAMPLE_TRADES_JSON = json.dumps( + [ + { + "market": {"title": "Test Market", "slug": "test-market"}, + "timestamp": 1704067200, + "strategy": "Buy", + "outcomeIndex": 0, + "outcomeTokenAmount": 100, + "collateralAmount": 50, + "pnl": 5, + "blockTimestamp": 1704067200, + }, + { + "market": {"title": "Test Market", "slug": "test-market"}, + "timestamp": 1704153600, + "strategy": "Sell", + "outcomeIndex": 0, + "outcomeTokenAmount": 50, + "collateralAmount": 30, + "pnl": -2, + "blockTimestamp": 1704153600, + }, + ] +) + + +def _upload_trades(client, headers, content=None, filename="trades.json"): + """Helper to upload a trade file.""" + content = content or _SAMPLE_TRADES_JSON.encode() + return client.post( + "/api/v1/trades/upload", + headers=headers, + files={"file": (filename, io.BytesIO(content), "application/json")}, + ) + + +class TestUpload: + def test_upload_json(self, client): + _, headers = create_authenticated_user(client) + resp = _upload_trades(client, headers) + assert resp.status_code == 200 + data = resp.json() + assert data["trade_count"] == 2 + assert data["upload_id"] >= 1 + + def test_upload_duplicate_rejected(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = _upload_trades(client, headers) 
+ assert resp.status_code == 400 + assert "already uploaded" in resp.json()["detail"] + + def test_upload_unsupported_extension(self, client): + _, headers = create_authenticated_user(client) + resp = client.post( + "/api/v1/trades/upload", + headers=headers, + files={"file": ("trades.txt", io.BytesIO(b"hello"), "text/plain")}, + ) + assert resp.status_code == 400 + + def test_upload_requires_auth(self, client): + resp = client.post( + "/api/v1/trades/upload", + files={"file": ("t.json", io.BytesIO(b"[]"), "application/json")}, + ) + assert resp.status_code == 401 + + def test_upload_oversized_file(self, client): + _, headers = create_authenticated_user(client) + # 11 MB of data + big = b"x" * (11 * 1024 * 1024) + resp = client.post( + "/api/v1/trades/upload", + headers=headers, + files={"file": ("big.json", io.BytesIO(big), "application/json")}, + ) + assert resp.status_code == 400 + assert ( + "too large" in resp.json()["detail"].lower() + or "File too large" in resp.json()["detail"] + ) + + +class TestListTrades: + def test_list_empty(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades", headers=headers) + assert resp.status_code == 200 + assert resp.json()["total"] == 0 + assert resp.json()["trades"] == [] + + def test_list_after_upload(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades", headers=headers) + assert resp.status_code == 200 + assert resp.json()["total"] == 2 + + def test_list_pagination(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades?limit=1&offset=0", headers=headers) + data = resp.json() + assert len(data["trades"]) == 1 + assert data["total"] == 2 + + def test_list_source_filter(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades?source=limitless", 
headers=headers) + assert resp.status_code == 200 + # Limitless format trades should show up + assert resp.json()["total"] >= 0 + + +class TestGetTrade: + def test_get_trade_by_id(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + # Get first trade + list_resp = client.get("/api/v1/trades?limit=1", headers=headers) + trade_id = list_resp.json()["trades"][0]["id"] + + resp = client.get(f"/api/v1/trades/{trade_id}", headers=headers) + assert resp.status_code == 200 + assert resp.json()["id"] == trade_id + + def test_get_trade_not_found(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades/99999", headers=headers) + assert resp.status_code == 404 + + +class TestDeleteTrades: + def test_delete_single_trade(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + list_resp = client.get("/api/v1/trades?limit=1", headers=headers) + trade_id = list_resp.json()["trades"][0]["id"] + + resp = client.delete(f"/api/v1/trades/{trade_id}", headers=headers) + assert resp.status_code == 204 + + # Verify deleted + resp = client.get(f"/api/v1/trades/{trade_id}", headers=headers) + assert resp.status_code == 404 + + def test_delete_all_trades(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.delete("/api/v1/trades", headers=headers) + assert resp.status_code == 200 + assert resp.json()["deleted_count"] == 2 + + # Verify all deleted + list_resp = client.get("/api/v1/trades", headers=headers) + assert list_resp.json()["total"] == 0 + + +class TestProviders: + def test_list_providers_requires_auth(self, client): + resp = client.get("/api/v1/trades/providers") + assert resp.status_code == 401 + + def test_list_providers_authenticated(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades/providers", headers=headers) + assert resp.status_code == 
200 + providers = resp.json() + names = [p["name"] for p in providers] + assert "limitless" in names + assert "polymarket" in names + assert "kalshi" in names + assert "manifold" in names + + +class TestMarkets: + def test_list_markets_empty(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades/markets", headers=headers) + assert resp.status_code == 200 + assert resp.json() == [] + + def test_list_markets_after_upload(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades/markets", headers=headers) + assert resp.status_code == 200 + markets = resp.json() + assert len(markets) >= 1 + assert markets[0]["slug"] == "test-market" + + +class TestExport: + def test_export_csv(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades/export/csv", headers=headers) + assert resp.status_code == 200 + assert "text/csv" in resp.headers["content-type"] + assert "attachment" in resp.headers["content-disposition"] + # Verify CSV content has expected columns + lines = resp.text.strip().split("\n") + header = lines[0] + assert "market" in header + assert "pnl" in header + assert "source" in header + + def test_export_json(self, client): + _, headers = create_authenticated_user(client) + _upload_trades(client, headers) + resp = client.get("/api/v1/trades/export/json", headers=headers) + assert resp.status_code == 200 + assert "application/json" in resp.headers["content-type"] + data = json.loads(resp.text) + assert len(data) == 2 + assert "source" in data[0] + assert "currency" in data[0] + + def test_export_empty(self, client): + _, headers = create_authenticated_user(client) + resp = client.get("/api/v1/trades/export/csv", headers=headers) + assert resp.status_code == 404 + + +class TestUserIsolation: + """Verify that users can only see their own trades.""" + + def 
test_user_cannot_see_other_trades(self, client): + _, headers_a = create_authenticated_user(client, "a@test.com", "user_a", "password123") + _, headers_b = create_authenticated_user(client, "b@test.com", "user_b", "password123") + + # User A uploads trades + _upload_trades(client, headers_a) + + # User B sees nothing + resp = client.get("/api/v1/trades", headers=headers_b) + assert resp.json()["total"] == 0 + + # User A sees their trades + resp = client.get("/api/v1/trades", headers=headers_a) + assert resp.json()["total"] == 2 + + def test_user_cannot_delete_other_trades(self, client): + _, headers_a = create_authenticated_user(client, "a@test.com", "user_a", "password123") + _, headers_b = create_authenticated_user(client, "b@test.com", "user_b", "password123") + + _upload_trades(client, headers_a) + list_resp = client.get("/api/v1/trades?limit=1", headers=headers_a) + trade_id = list_resp.json()["trades"][0]["id"] + + # User B cannot delete User A's trade + resp = client.delete(f"/api/v1/trades/{trade_id}", headers=headers_b) + assert resp.status_code == 404 diff --git a/tests/conftest.py b/tests/conftest.py index 71d9552..3dfceab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ These tests are designed to run BEFORE implementing new features to ensure existing patterns and contracts are maintained. 
""" + import pytest from datetime import datetime, timedelta from typing import List @@ -25,13 +26,14 @@ def sample_trade() -> Trade: type="Buy", side="YES", pnl=0.0, - tx_hash="0x123abc" + tx_hash="0x123abc", ) @pytest.fixture def sample_trade_factory(): """Factory function to create trades with custom attributes.""" + def _create_trade(**kwargs) -> Trade: defaults = { "market": "Test Market", @@ -43,13 +45,14 @@ def _create_trade(**kwargs) -> Trade: "type": "Buy", "side": "YES", "pnl": 0.0, - "tx_hash": None + "tx_hash": None, } defaults.update(kwargs) # Auto-set pnl_is_set if pnl was explicitly provided and not already set if "pnl_is_set" not in kwargs and "pnl" in kwargs: defaults["pnl_is_set"] = True return Trade(**defaults) + return _create_trade @@ -63,7 +66,7 @@ def sample_trades_list(sample_trade_factory) -> List[Trade]: side="YES", price=45.0, cost=4.5, - pnl=10.0 + pnl=10.0, ), sample_trade_factory( timestamp=datetime(2024, 3, 15, 14, 30, 0), @@ -71,7 +74,7 @@ def sample_trades_list(sample_trade_factory) -> List[Trade]: side="YES", price=55.0, cost=5.5, - pnl=-5.0 + pnl=-5.0, ), sample_trade_factory( timestamp=datetime(2024, 6, 1, 9, 0, 0), @@ -79,7 +82,7 @@ def sample_trades_list(sample_trade_factory) -> List[Trade]: side="NO", price=30.0, cost=3.0, - pnl=15.0 + pnl=15.0, ), sample_trade_factory( timestamp=datetime(2024, 9, 1, 16, 45, 0), @@ -87,7 +90,7 @@ def sample_trades_list(sample_trade_factory) -> List[Trade]: side="NO", price=70.0, cost=7.0, - pnl=0.0 + pnl=0.0, ), sample_trade_factory( timestamp=datetime(2024, 12, 1, 11, 15, 0), @@ -95,7 +98,7 @@ def sample_trades_list(sample_trade_factory) -> List[Trade]: side="YES", price=50.0, cost=5.0, - pnl=-8.0 + pnl=-8.0, ), ] @@ -157,11 +160,7 @@ def multi_market_trades(sample_trade_factory) -> List[Trade]: @pytest.fixture def all_trade_types(sample_trade_factory) -> List[Trade]: """Create trades with all possible trade types.""" - trade_types = [ - "Buy", "Sell", - "Market Buy", "Market Sell", - "Limit 
Buy", "Limit Sell" - ] + trade_types = ["Buy", "Sell", "Market Buy", "Market Sell", "Limit Buy", "Limit Sell"] return [sample_trade_factory(type=t) for t in trade_types] diff --git a/tests/mcp/conftest.py b/tests/mcp/conftest.py index 21c76ef..7486d2d 100644 --- a/tests/mcp/conftest.py +++ b/tests/mcp/conftest.py @@ -2,6 +2,7 @@ """ Shared fixtures for MCP tool tests. """ + import os import json import pytest @@ -11,10 +12,10 @@ from prediction_analyzer.trade_loader import Trade from prediction_mcp.state import session - EXAMPLE_TRADES_PATH = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - "data", "example_trades.json", + "data", + "example_trades.json", ) @@ -23,17 +24,19 @@ def make_trades(count=10): trades = [] for i in range(count): is_buy = i % 2 == 0 - trades.append(Trade( - market=f"Market {i // 3}", - market_slug=f"market-{i // 3}", - timestamp=datetime(2024, 1, 1 + i), - price=0.5 + (i * 0.01), - shares=10.0, - cost=5.0 + i, - type="Buy" if is_buy else "Sell", - side="YES" if i % 4 < 2 else "NO", - pnl=1.0 if is_buy else -0.5, - )) + trades.append( + Trade( + market=f"Market {i // 3}", + market_slug=f"market-{i // 3}", + timestamp=datetime(2024, 1, 1 + i), + price=0.5 + (i * 0.01), + shares=10.0, + cost=5.0 + i, + type="Buy" if is_buy else "Sell", + side="YES" if i % 4 < 2 else "NO", + pnl=1.0 if is_buy else -0.5, + ) + ) return trades diff --git a/tests/mcp/test_analysis_tools.py b/tests/mcp/test_analysis_tools.py index f1f8d86..d2998d5 100644 --- a/tests/mcp/test_analysis_tools.py +++ b/tests/mcp/test_analysis_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_analysis_tools.py """Tests for MCP analysis tools.""" + import json import asyncio @@ -23,19 +24,29 @@ def test_summary_with_trades(self, loaded_session): assert "win_rate" in data def test_summary_with_date_filter(self, loaded_session): - result = asyncio.run(analysis_tools.handle_tool("get_global_summary", { - "start_date": "2024-01-03", - "end_date": "2024-01-06", - })) + 
result = asyncio.run( + analysis_tools.handle_tool( + "get_global_summary", + { + "start_date": "2024-01-03", + "end_date": "2024-01-06", + }, + ) + ) data = json.loads(result[0].text) assert data["total_trades"] < 10 class TestMarketSummary: def test_no_trades_error(self): - result = asyncio.run(analysis_tools.handle_tool("get_market_summary", { - "market_slug": "test", - })) + result = asyncio.run( + analysis_tools.handle_tool( + "get_market_summary", + { + "market_slug": "test", + }, + ) + ) assert "No trades loaded" in result[0].text def test_missing_market_slug(self, loaded_session): @@ -43,15 +54,25 @@ def test_missing_market_slug(self, loaded_session): assert "market_slug is required" in result[0].text def test_market_not_found(self, loaded_session): - result = asyncio.run(analysis_tools.handle_tool("get_market_summary", { - "market_slug": "nonexistent-market", - })) + result = asyncio.run( + analysis_tools.handle_tool( + "get_market_summary", + { + "market_slug": "nonexistent-market", + }, + ) + ) assert "not found" in result[0].text def test_valid_market_summary(self, loaded_session): - result = asyncio.run(analysis_tools.handle_tool("get_market_summary", { - "market_slug": "market-0", - })) + result = asyncio.run( + analysis_tools.handle_tool( + "get_market_summary", + { + "market_slug": "market-0", + }, + ) + ) data = json.loads(result[0].text) assert "market_title" in data assert data["total_trades"] > 0 @@ -71,9 +92,14 @@ def test_metrics_returned(self, loaded_session): assert "profit_factor" in data def test_metrics_for_specific_market(self, loaded_session): - result = asyncio.run(analysis_tools.handle_tool("get_advanced_metrics", { - "market_slug": "market-0", - })) + result = asyncio.run( + analysis_tools.handle_tool( + "get_advanced_metrics", + { + "market_slug": "market-0", + }, + ) + ) data = json.loads(result[0].text) assert "sharpe_ratio" in data diff --git a/tests/mcp/test_chart_tools.py b/tests/mcp/test_chart_tools.py index 66314bb..cd68986 
100644 --- a/tests/mcp/test_chart_tools.py +++ b/tests/mcp/test_chart_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_chart_tools.py """Tests for MCP chart tools.""" + import asyncio import pytest @@ -11,36 +12,61 @@ class TestGenerateChart: def test_no_trades_error(self): - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "market-0", - "chart_type": "simple", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "market-0", + "chart_type": "simple", + }, + ) + ) assert "No trades loaded" in result[0].text def test_missing_market_slug(self, loaded_session): - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "chart_type": "simple", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "chart_type": "simple", + }, + ) + ) assert "market_slug" in result[0].text def test_missing_chart_type(self, loaded_session): - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "market-0", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "market-0", + }, + ) + ) assert "chart_type" in result[0].text def test_invalid_chart_type(self, loaded_session): - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "market-0", - "chart_type": "invalid", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "market-0", + "chart_type": "invalid", + }, + ) + ) assert "Invalid chart type" in result[0].text def test_unknown_market(self, loaded_session): - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "nonexistent-market", - "chart_type": "simple", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "nonexistent-market", + "chart_type": "simple", + }, + ) + ) assert "not found" in result[0].text diff --git a/tests/mcp/test_data_tools.py 
b/tests/mcp/test_data_tools.py index 8d5460e..a6c1d59 100644 --- a/tests/mcp/test_data_tools.py +++ b/tests/mcp/test_data_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_data_tools.py """Tests for MCP data tools.""" + import json import asyncio import os @@ -13,9 +14,14 @@ class TestLoadTrades: def test_load_from_file(self): - result = asyncio.run(data_tools.handle_tool("load_trades", { - "file_path": EXAMPLE_TRADES_PATH, - })) + result = asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "file_path": EXAMPLE_TRADES_PATH, + }, + ) + ) assert result is not None data = json.loads(result[0].text) assert data["trade_count"] > 0 @@ -27,15 +33,25 @@ def test_missing_file_path(self): assert "file_path is required" in result[0].text def test_nonexistent_file(self): - result = asyncio.run(data_tools.handle_tool("load_trades", { - "file_path": "/nonexistent/file.json", - })) + result = asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "file_path": "/nonexistent/file.json", + }, + ) + ) assert "File not found" in result[0].text def test_session_updated_after_load(self): - asyncio.run(data_tools.handle_tool("load_trades", { - "file_path": EXAMPLE_TRADES_PATH, - })) + asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "file_path": EXAMPLE_TRADES_PATH, + }, + ) + ) assert session.has_trades assert session.source.startswith("file:") assert len(session.filtered_trades) == len(session.trades) @@ -47,9 +63,14 @@ def test_no_trades_error(self): assert "No trades loaded" in result[0].text def test_list_markets_after_load(self): - asyncio.run(data_tools.handle_tool("load_trades", { - "file_path": EXAMPLE_TRADES_PATH, - })) + asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "file_path": EXAMPLE_TRADES_PATH, + }, + ) + ) result = asyncio.run(data_tools.handle_tool("list_markets", {})) data = json.loads(result[0].text) assert len(data) > 0 @@ -64,57 +85,92 @@ def test_no_trades_error(self): assert "No trades loaded" in result[0].text def 
test_basic_trade_details(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "limit": 3, - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "limit": 3, + }, + ) + ) data = json.loads(result[0].text) assert data["total"] == 10 assert len(data["trades"]) == 3 def test_pagination(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "limit": 2, - "offset": 5, - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "limit": 2, + "offset": 5, + }, + ) + ) data = json.loads(result[0].text) assert len(data["trades"]) == 2 assert data["offset"] == 5 def test_sort_by_pnl(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "sort_by": "pnl", - "sort_order": "desc", - "limit": 3, - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "sort_by": "pnl", + "sort_order": "desc", + "limit": 3, + }, + ) + ) data = json.loads(result[0].text) pnls = [t["pnl"] for t in data["trades"]] assert pnls == sorted(pnls, reverse=True) def test_filter_by_market(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "market_slug": "market-0", - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "market_slug": "market-0", + }, + ) + ) data = json.loads(result[0].text) assert all(t["market_slug"] == "market-0" for t in data["trades"]) class TestInputValidation: def test_invalid_sort_field(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "sort_by": "invalid_field", - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "sort_by": "invalid_field", + }, + ) + ) assert "Invalid sort field" in result[0].text def test_negative_limit(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "limit": -1, - })) + 
result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "limit": -1, + }, + ) + ) assert "Invalid limit" in result[0].text def test_negative_offset(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "offset": -5, - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "offset": -5, + }, + ) + ) assert "Invalid offset" in result[0].text diff --git a/tests/mcp/test_errors.py b/tests/mcp/test_errors.py index 1bd72e7..81e044a 100644 --- a/tests/mcp/test_errors.py +++ b/tests/mcp/test_errors.py @@ -1,5 +1,6 @@ # tests/mcp/test_errors.py """Tests for MCP error handling.""" + import pytest import asyncio diff --git a/tests/mcp/test_export_tools.py b/tests/mcp/test_export_tools.py index c925d77..80d9b9b 100644 --- a/tests/mcp/test_export_tools.py +++ b/tests/mcp/test_export_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_export_tools.py """Tests for MCP export tools.""" + import json import asyncio import os @@ -13,22 +14,37 @@ class TestExportTrades: def test_no_trades_error(self): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - "output_path": "/tmp/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + "output_path": "/tmp/test.csv", + }, + ) + ) assert "No trades loaded" in result[0].text def test_missing_format(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "output_path": "/tmp/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "output_path": "/tmp/test.csv", + }, + ) + ) assert "format is required" in result[0].text def test_missing_output_path(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + }, + ) + ) assert "output_path is required" in 
result[0].text def test_export_csv(self, loaded_session): @@ -36,10 +52,15 @@ def test_export_csv(self, loaded_session): path = f.name try: - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - "output_path": path, - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + "output_path": path, + }, + ) + ) data = json.loads(result[0].text) assert data["trade_count"] == 10 assert data["format"] == "csv" @@ -52,10 +73,15 @@ def test_export_json(self, loaded_session): path = f.name try: - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "json", - "output_path": path, - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "json", + "output_path": path, + }, + ) + ) data = json.loads(result[0].text) assert data["trade_count"] == 10 assert os.path.exists(path) @@ -63,8 +89,13 @@ def test_export_json(self, loaded_session): os.unlink(path) def test_invalid_format(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "pdf", - "output_path": "/tmp/test.pdf", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "pdf", + "output_path": "/tmp/test.pdf", + }, + ) + ) assert "Invalid export format" in result[0].text diff --git a/tests/mcp/test_filter_tools.py b/tests/mcp/test_filter_tools.py index ac746dc..0332af0 100644 --- a/tests/mcp/test_filter_tools.py +++ b/tests/mcp/test_filter_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_filter_tools.py """Tests for MCP filter tools.""" + import json import asyncio @@ -15,48 +16,85 @@ def test_no_trades_error(self): assert "No trades loaded" in result[0].text def test_filter_by_side(self, loaded_session): - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "sides": ["YES"], - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "sides": ["YES"], + }, + ) + ) data = 
json.loads(result[0].text) assert data["original_count"] == 10 assert data["filtered_count"] < 10 def test_filter_by_trade_type(self, loaded_session): - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "trade_types": ["Buy"], - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "trade_types": ["Buy"], + }, + ) + ) data = json.loads(result[0].text) assert data["filtered_count"] == 5 # half are buys def test_filter_by_date(self, loaded_session): - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "start_date": "2024-01-05", - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "start_date": "2024-01-05", + }, + ) + ) data = json.loads(result[0].text) assert data["filtered_count"] < 10 def test_clear_filters(self, loaded_session): # Apply a filter - asyncio.run(filter_tools.handle_tool("filter_trades", { - "trade_types": ["Buy"], - })) + asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "trade_types": ["Buy"], + }, + ) + ) assert len(session.filtered_trades) == 5 # Clear - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "clear": True, - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "clear": True, + }, + ) + ) data = json.loads(result[0].text) assert data["filtered_count"] == 10 assert session.active_filters == {} def test_active_filters_tracked(self, loaded_session): - asyncio.run(filter_tools.handle_tool("filter_trades", { - "sides": ["YES"], - "min_pnl": 0.5, - })) - data = json.loads(asyncio.run(filter_tools.handle_tool("filter_trades", { - "sides": ["YES"], - }))[0].text) + asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "sides": ["YES"], + "min_pnl": 0.5, + }, + ) + ) + data = json.loads( + asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "sides": ["YES"], + }, + ) + )[0].text + ) assert "sides" in data["active_filters"] diff --git a/tests/mcp/test_llm_inputs.py 
b/tests/mcp/test_llm_inputs.py index 92e7aa1..01acec1 100644 --- a/tests/mcp/test_llm_inputs.py +++ b/tests/mcp/test_llm_inputs.py @@ -9,6 +9,7 @@ - Empty strings, null values - NaN and Infinity as numeric inputs """ + import json import asyncio import math @@ -16,8 +17,13 @@ import pytest from prediction_mcp.tools import ( - data_tools, analysis_tools, filter_tools, - chart_tools, export_tools, portfolio_tools, tax_tools, + data_tools, + analysis_tools, + filter_tools, + chart_tools, + export_tools, + portfolio_tools, + tax_tools, ) from prediction_mcp.state import session from .conftest import make_trades @@ -28,40 +34,65 @@ class TestWrongParameterNames: def test_market_instead_of_market_slug(self, loaded_session): """LLM might say 'market' instead of 'market_slug'.""" - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "market": "market-0", # wrong key - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "market": "market-0", # wrong key + }, + ) + ) # Should return all trades (market param ignored, no market_slug filtering) data = json.loads(result[0].text) assert data["total"] == 10 def test_path_instead_of_file_path(self): """LLM might say 'path' instead of 'file_path'.""" - result = asyncio.run(data_tools.handle_tool("load_trades", { - "path": "/some/file.json", # wrong key - })) + result = asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "path": "/some/file.json", # wrong key + }, + ) + ) assert "file_path is required" in result[0].text def test_type_instead_of_chart_type(self, loaded_session): """LLM might say 'type' instead of 'chart_type'.""" - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "market-0", - "type": "simple", # wrong key - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "market-0", + "type": "simple", # wrong key + }, + ) + ) assert "chart_type" in result[0].text def 
test_year_instead_of_tax_year(self, loaded_session): """LLM might say 'year' instead of 'tax_year'.""" - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "year": 2024, # wrong key - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "year": 2024, # wrong key + }, + ) + ) assert "tax_year is required" in result[0].text def test_key_instead_of_api_key(self): """LLM might say 'key' instead of 'api_key'.""" - result = asyncio.run(data_tools.handle_tool("fetch_trades", { - "key": "lmts_test_key", # wrong key - })) + result = asyncio.run( + data_tools.handle_tool( + "fetch_trades", + { + "key": "lmts_test_key", # wrong key + }, + ) + ) assert "api_key is required" in result[0].text @@ -70,49 +101,79 @@ class TestLowercaseEnumValues: def test_lowercase_trade_type(self, loaded_session): """LLM sends 'buy' instead of 'Buy' — should be normalized.""" - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "trade_types": ["buy"], - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "trade_types": ["buy"], + }, + ) + ) data = json.loads(result[0].text) assert "filtered_count" in data # normalized to "Buy", filter applied def test_lowercase_side(self, loaded_session): """LLM sends 'yes' instead of 'YES' — should be normalized.""" - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "sides": ["yes"], - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "sides": ["yes"], + }, + ) + ) data = json.loads(result[0].text) assert "filtered_count" in data # normalized to "YES", filter applied def test_uppercase_format(self, loaded_session): """LLM sends 'CSV' instead of 'csv'.""" - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "CSV", - "output_path": "/tmp/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "CSV", + "output_path": "/tmp/test.csv", + }, + ) + ) assert 
"Invalid export format" in result[0].text def test_uppercase_chart_type(self, loaded_session): """LLM sends 'Simple' instead of 'simple'.""" - result = asyncio.run(chart_tools.handle_tool("generate_chart", { - "market_slug": "market-0", - "chart_type": "Simple", - })) + result = asyncio.run( + chart_tools.handle_tool( + "generate_chart", + { + "market_slug": "market-0", + "chart_type": "Simple", + }, + ) + ) assert "Invalid chart type" in result[0].text def test_uppercase_cost_basis(self, loaded_session): """LLM sends 'FIFO' instead of 'fifo'.""" - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - "cost_basis_method": "FIFO", - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + "cost_basis_method": "FIFO", + }, + ) + ) assert "Invalid cost basis method" in result[0].text def test_uppercase_sort_by(self, loaded_session): """LLM sends 'PNL' instead of 'pnl'.""" - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "sort_by": "PNL", - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "sort_by": "PNL", + }, + ) + ) assert "Invalid sort field" in result[0].text @@ -120,30 +181,50 @@ class TestEmptyAndNullValues: """LLMs sometimes send empty strings or null/None values.""" def test_empty_string_file_path(self): - result = asyncio.run(data_tools.handle_tool("load_trades", { - "file_path": "", - })) + result = asyncio.run( + data_tools.handle_tool( + "load_trades", + { + "file_path": "", + }, + ) + ) assert "file_path is required" in result[0].text def test_empty_string_api_key(self): - result = asyncio.run(data_tools.handle_tool("fetch_trades", { - "api_key": "", - })) + result = asyncio.run( + data_tools.handle_tool( + "fetch_trades", + { + "api_key": "", + }, + ) + ) assert "api_key is required" in result[0].text def test_empty_string_market_slug(self, loaded_session): """Empty string market_slug should not filter (treated as falsy).""" - 
result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "market_slug": "", - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "market_slug": "", + }, + ) + ) data = json.loads(result[0].text) assert data["total"] == 10 def test_null_format(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": None, - "output_path": "/tmp/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": None, + "output_path": "/tmp/test.csv", + }, + ) + ) assert "format is required" in result[0].text def test_empty_arguments(self, loaded_session): @@ -159,9 +240,14 @@ def test_empty_arguments(self, loaded_session): assert result is not None def test_none_tax_year(self, loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": None, - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": None, + }, + ) + ) assert "tax_year is required" in result[0].text @@ -170,40 +256,65 @@ class TestNaNAndInfinityInputs: def test_nan_min_pnl(self, loaded_session): """NaN as a filter threshold should return a validation error.""" - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "min_pnl": float("nan"), - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "min_pnl": float("nan"), + }, + ) + ) assert result is not None assert "NaN" in result[0].text or "Invalid" in result[0].text def test_infinity_max_pnl(self, loaded_session): """Infinity as max_pnl should return a validation error.""" - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "max_pnl": float("inf"), - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "max_pnl": float("inf"), + }, + ) + ) assert result is not None assert "NaN" in result[0].text or "Invalid" in result[0].text def test_negative_infinity_min_pnl(self, loaded_session): 
"""Negative infinity as min_pnl should return a validation error.""" - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "min_pnl": float("-inf"), - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "min_pnl": float("-inf"), + }, + ) + ) assert result is not None assert "NaN" in result[0].text or "Invalid" in result[0].text def test_nan_limit(self, loaded_session): """NaN as limit should be caught by validation.""" - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "limit": float("nan"), - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "limit": float("nan"), + }, + ) + ) assert "Invalid limit" in result[0].text def test_infinity_limit(self, loaded_session): """Infinity as limit should be caught by validation.""" - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "limit": float("inf"), - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "limit": float("inf"), + }, + ) + ) assert "Invalid limit" in result[0].text @@ -215,21 +326,36 @@ def test_generate_chart_no_params(self, loaded_session): assert "market_slug" in result[0].text or "required" in result[0].text def test_export_no_format(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "output_path": "/tmp/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "output_path": "/tmp/test.csv", + }, + ) + ) assert "format is required" in result[0].text def test_export_no_output_path(self, loaded_session): - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + }, + ) + ) assert "output_path is required" in result[0].text def test_compare_periods_partial(self, loaded_session): - result = asyncio.run(portfolio_tools.handle_tool("compare_periods", { - 
"period_1_start": "2024-01-01", - })) + result = asyncio.run( + portfolio_tools.handle_tool( + "compare_periods", + { + "period_1_start": "2024-01-01", + }, + ) + ) assert "required" in result[0].text def test_market_summary_no_slug(self, loaded_session): @@ -241,17 +367,27 @@ class TestAvailableMarketsInError: """Verify that market-not-found errors include available slugs.""" def test_unknown_market_shows_available(self, loaded_session): - result = asyncio.run(data_tools.handle_tool("get_trade_details", { - "market_slug": "nonexistent-market", - })) + result = asyncio.run( + data_tools.handle_tool( + "get_trade_details", + { + "market_slug": "nonexistent-market", + }, + ) + ) text = result[0].text assert "not found" in text assert "market-0" in text # should list available slugs def test_unknown_market_in_analysis(self, loaded_session): - result = asyncio.run(analysis_tools.handle_tool("get_market_summary", { - "market_slug": "fake-market", - })) + result = asyncio.run( + analysis_tools.handle_tool( + "get_market_summary", + { + "market_slug": "fake-market", + }, + ) + ) text = result[0].text assert "not found" in text assert "market-0" in text diff --git a/tests/mcp/test_persistence.py b/tests/mcp/test_persistence.py index aabc0b1..62636c2 100644 --- a/tests/mcp/test_persistence.py +++ b/tests/mcp/test_persistence.py @@ -1,5 +1,6 @@ # tests/mcp/test_persistence.py """Tests for SQLite session persistence.""" + import os import json import tempfile diff --git a/tests/mcp/test_portfolio_tools.py b/tests/mcp/test_portfolio_tools.py index ff9f296..35efa9b 100644 --- a/tests/mcp/test_portfolio_tools.py +++ b/tests/mcp/test_portfolio_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_portfolio_tools.py """Tests for MCP portfolio tools.""" + import json import asyncio @@ -49,37 +50,57 @@ def test_drawdown_returned(self, loaded_session): class TestComparePeriods: def test_no_trades_error(self): - result = asyncio.run(portfolio_tools.handle_tool("compare_periods", { - 
"period_1_start": "2024-01-01", - "period_1_end": "2024-01-05", - "period_2_start": "2024-01-06", - "period_2_end": "2024-01-10", - })) + result = asyncio.run( + portfolio_tools.handle_tool( + "compare_periods", + { + "period_1_start": "2024-01-01", + "period_1_end": "2024-01-05", + "period_2_start": "2024-01-06", + "period_2_end": "2024-01-10", + }, + ) + ) assert "No trades loaded" in result[0].text def test_comparison_returned(self, loaded_session): - result = asyncio.run(portfolio_tools.handle_tool("compare_periods", { - "period_1_start": "2024-01-01", - "period_1_end": "2024-01-05", - "period_2_start": "2024-01-06", - "period_2_end": "2024-01-10", - })) + result = asyncio.run( + portfolio_tools.handle_tool( + "compare_periods", + { + "period_1_start": "2024-01-01", + "period_1_end": "2024-01-05", + "period_2_start": "2024-01-06", + "period_2_end": "2024-01-10", + }, + ) + ) data = json.loads(result[0].text) assert "period_1" in data assert "period_2" in data assert "changes" in data def test_missing_dates_error(self, loaded_session): - result = asyncio.run(portfolio_tools.handle_tool("compare_periods", { - "period_1_start": "2024-01-01", - })) + result = asyncio.run( + portfolio_tools.handle_tool( + "compare_periods", + { + "period_1_start": "2024-01-01", + }, + ) + ) assert "required" in result[0].text def test_invalid_date_format(self, loaded_session): - result = asyncio.run(portfolio_tools.handle_tool("compare_periods", { - "period_1_start": "bad-date", - "period_1_end": "2024-01-05", - "period_2_start": "2024-01-06", - "period_2_end": "2024-01-10", - })) + result = asyncio.run( + portfolio_tools.handle_tool( + "compare_periods", + { + "period_1_start": "bad-date", + "period_1_end": "2024-01-05", + "period_2_start": "2024-01-06", + "period_2_end": "2024-01-10", + }, + ) + ) assert "YYYY-MM-DD" in result[0].text diff --git a/tests/mcp/test_serializers.py b/tests/mcp/test_serializers.py index 3553766..480b277 100644 --- a/tests/mcp/test_serializers.py +++ 
b/tests/mcp/test_serializers.py @@ -1,5 +1,6 @@ # tests/mcp/test_serializers.py """Tests for MCP serializers.""" + import math import json from datetime import datetime @@ -72,10 +73,15 @@ def test_empty_list(self): def test_single_trade(self): trade = Trade( - market="Test", market_slug="test", + market="Test", + market_slug="test", timestamp=datetime(2024, 1, 1), - price=0.5, shares=10.0, cost=5.0, - type="Buy", side="YES", pnl=1.0, + price=0.5, + shares=10.0, + cost=5.0, + type="Buy", + side="YES", + pnl=1.0, ) result = serialize_trades([trade]) assert len(result) == 1 @@ -85,10 +91,15 @@ def test_single_trade(self): def test_nan_pnl_sanitized(self): trade = Trade( - market="Test", market_slug="test", + market="Test", + market_slug="test", timestamp=datetime(2024, 1, 1), - price=0.5, shares=10.0, cost=5.0, - type="Buy", side="YES", pnl=float("nan"), + price=0.5, + shares=10.0, + cost=5.0, + type="Buy", + side="YES", + pnl=float("nan"), ) result = serialize_trades([trade]) assert result[0]["pnl"] == 0.0 diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index d5eb237..205ed45 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -1,5 +1,6 @@ # tests/mcp/test_server.py """Tests for MCP server dispatch logic.""" + import asyncio import pytest diff --git a/tests/mcp/test_sse_transport.py b/tests/mcp/test_sse_transport.py index a618dea..f757334 100644 --- a/tests/mcp/test_sse_transport.py +++ b/tests/mcp/test_sse_transport.py @@ -1,5 +1,6 @@ # tests/mcp/test_sse_transport.py """Tests for HTTP/SSE transport setup.""" + import pytest from starlette.testclient import TestClient diff --git a/tests/mcp/test_state.py b/tests/mcp/test_state.py index a404e13..4a67bd7 100644 --- a/tests/mcp/test_state.py +++ b/tests/mcp/test_state.py @@ -1,5 +1,6 @@ # tests/mcp/test_state.py """Tests for MCP session state.""" + from prediction_mcp.state import SessionState, session from .conftest import make_trades diff --git a/tests/mcp/test_tax_tools.py 
b/tests/mcp/test_tax_tools.py index cfcca47..f0cbe20 100644 --- a/tests/mcp/test_tax_tools.py +++ b/tests/mcp/test_tax_tools.py @@ -1,5 +1,6 @@ # tests/mcp/test_tax_tools.py """Tests for MCP tax tools.""" + import json import asyncio @@ -11,9 +12,14 @@ class TestTaxReport: def test_no_trades_error(self): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + }, + ) + ) assert "No trades loaded" in result[0].text def test_missing_tax_year(self, loaded_session): @@ -21,10 +27,15 @@ def test_missing_tax_year(self, loaded_session): assert "tax_year is required" in result[0].text def test_fifo_report(self, loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - "cost_basis_method": "fifo", - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + "cost_basis_method": "fifo", + }, + ) + ) data = json.loads(result[0].text) assert data["tax_year"] == 2024 assert data["method"] == "fifo" @@ -33,31 +44,51 @@ def test_fifo_report(self, loaded_session): assert "transactions" in data def test_lifo_report(self, loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - "cost_basis_method": "lifo", - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + "cost_basis_method": "lifo", + }, + ) + ) data = json.loads(result[0].text) assert data["method"] == "lifo" def test_average_report(self, loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - "cost_basis_method": "average", - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + "cost_basis_method": "average", + }, + ) + ) data = json.loads(result[0].text) assert data["method"] == "average" def test_default_method_is_fifo(self, 
loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + }, + ) + ) data = json.loads(result[0].text) assert data["method"] == "fifo" def test_invalid_cost_basis_method(self, loaded_session): - result = asyncio.run(tax_tools.handle_tool("get_tax_report", { - "tax_year": 2024, - "cost_basis_method": "invalid", - })) + result = asyncio.run( + tax_tools.handle_tool( + "get_tax_report", + { + "tax_year": 2024, + "cost_basis_method": "invalid", + }, + ) + ) assert "Invalid cost basis method" in result[0].text diff --git a/tests/mcp/test_transport.py b/tests/mcp/test_transport.py index 8ab6b9a..4355407 100644 --- a/tests/mcp/test_transport.py +++ b/tests/mcp/test_transport.py @@ -1,5 +1,6 @@ # tests/mcp/test_transport.py """Tests for MCP transport safety — no stdout pollution.""" + import io import sys import json @@ -9,8 +10,12 @@ import pytest from prediction_mcp.tools import ( - data_tools, analysis_tools, filter_tools, - export_tools, portfolio_tools, tax_tools, + data_tools, + analysis_tools, + filter_tools, + export_tools, + portfolio_tools, + tax_tools, ) from prediction_mcp.state import session from .conftest import make_trades, EXAMPLE_TRADES_PATH @@ -25,15 +30,17 @@ def _run_tool(self, module, name, args): with contextlib.redirect_stdout(stdout_capture): result = asyncio.run(module.handle_tool(name, args)) stdout_output = stdout_capture.getvalue() - assert stdout_output == "", ( - f"Tool '{name}' wrote to stdout: {stdout_output!r}" - ) + assert stdout_output == "", f"Tool '{name}' wrote to stdout: {stdout_output!r}" return result def test_load_trades_no_stdout(self): - self._run_tool(data_tools, "load_trades", { - "file_path": EXAMPLE_TRADES_PATH, - }) + self._run_tool( + data_tools, + "load_trades", + { + "file_path": EXAMPLE_TRADES_PATH, + }, + ) def test_list_markets_no_stdout(self): session.trades = make_trades(5) 
diff --git a/tests/mcp/test_validators.py b/tests/mcp/test_validators.py index a3f16fe..23ff715 100644 --- a/tests/mcp/test_validators.py +++ b/tests/mcp/test_validators.py @@ -1,5 +1,6 @@ # tests/mcp/test_validators.py """Tests for MCP input validators.""" + import math import pytest diff --git a/tests/static_patterns/test_api_contracts.py b/tests/static_patterns/test_api_contracts.py index d57ad2a..3e27aad 100644 --- a/tests/static_patterns/test_api_contracts.py +++ b/tests/static_patterns/test_api_contracts.py @@ -5,6 +5,7 @@ These tests verify that public function signatures and return types remain stable. This helps catch breaking changes to the API. """ + import pytest import inspect from typing import get_type_hints, List, Dict, Optional @@ -17,6 +18,7 @@ class TestTradeLoaderAPIContracts: def test_load_trades_signature(self): """load_trades should accept file_path and return List[Trade].""" from prediction_analyzer.trade_loader import load_trades + sig = inspect.signature(load_trades) params = list(sig.parameters.keys()) @@ -26,6 +28,7 @@ def test_load_trades_signature(self): def test_save_trades_signature(self): """save_trades should accept trades and file_path.""" from prediction_analyzer.trade_loader import save_trades + sig = inspect.signature(save_trades) params = list(sig.parameters.keys()) @@ -35,11 +38,13 @@ def test_save_trades_signature(self): def test_parse_timestamp_exists(self): """_parse_timestamp helper should exist.""" from prediction_analyzer.trade_loader import _parse_timestamp + assert callable(_parse_timestamp) def test_sanitize_filename_exists(self): """_sanitize_filename helper should exist.""" from prediction_analyzer.trade_loader import _sanitize_filename + assert callable(_sanitize_filename) @@ -93,10 +98,7 @@ def test_calculate_global_pnl_summary_keys(self, sample_trades_list): from prediction_analyzer.pnl import calculate_global_pnl_summary result = calculate_global_pnl_summary(sample_trades_list) - expected_keys = { - 
"total_trades", "total_pnl", "win_rate", - "winning_trades", "losing_trades" - } + expected_keys = {"total_trades", "total_pnl", "win_rate", "winning_trades", "losing_trades"} assert expected_keys.issubset(set(result.keys())) def test_calculate_market_pnl_signature(self): diff --git a/tests/static_patterns/test_config_integrity.py b/tests/static_patterns/test_config_integrity.py index 7c45cd7..68c209c 100644 --- a/tests/static_patterns/test_config_integrity.py +++ b/tests/static_patterns/test_config_integrity.py @@ -5,6 +5,7 @@ These tests verify that configuration values are valid and consistent. Invalid configuration can cause runtime errors or incorrect behavior. """ + import pytest import re @@ -17,18 +18,15 @@ def test_api_base_url_is_valid_url(self): from prediction_analyzer.config import API_BASE_URL assert isinstance(API_BASE_URL, str) - assert API_BASE_URL.startswith("https://"), \ - "API_BASE_URL should use HTTPS" - assert len(API_BASE_URL) > 10, \ - "API_BASE_URL seems too short" + assert API_BASE_URL.startswith("https://"), "API_BASE_URL should use HTTPS" + assert len(API_BASE_URL) > 10, "API_BASE_URL seems too short" def test_default_trade_file_is_valid_filename(self): """DEFAULT_TRADE_FILE should be a valid filename.""" from prediction_analyzer.config import DEFAULT_TRADE_FILE assert isinstance(DEFAULT_TRADE_FILE, str) - assert DEFAULT_TRADE_FILE.endswith(".json"), \ - "Default trade file should be JSON" + assert DEFAULT_TRADE_FILE.endswith(".json"), "Default trade file should be JSON" # Should not contain path separators assert "/" not in DEFAULT_TRADE_FILE assert "\\" not in DEFAULT_TRADE_FILE @@ -48,12 +46,18 @@ def test_styles_has_all_trade_type_combinations(self): from prediction_analyzer.config import STYLES required_combinations = [ - ("Buy", "YES"), ("Buy", "NO"), - ("Sell", "YES"), ("Sell", "NO"), - ("Market Buy", "YES"), ("Market Buy", "NO"), - ("Market Sell", "YES"), ("Market Sell", "NO"), - ("Limit Buy", "YES"), ("Limit Buy", "NO"), - 
("Limit Sell", "YES"), ("Limit Sell", "NO"), + ("Buy", "YES"), + ("Buy", "NO"), + ("Sell", "YES"), + ("Sell", "NO"), + ("Market Buy", "YES"), + ("Market Buy", "NO"), + ("Market Sell", "YES"), + ("Market Sell", "NO"), + ("Limit Buy", "YES"), + ("Limit Buy", "NO"), + ("Limit Sell", "YES"), + ("Limit Sell", "NO"), ] for combo in required_combinations: @@ -64,10 +68,8 @@ def test_styles_values_are_tuples(self): from prediction_analyzer.config import STYLES for key, value in STYLES.items(): - assert isinstance(value, tuple), \ - f"Style for {key} should be a tuple" - assert len(value) == 3, \ - f"Style for {key} should have 3 elements (color, marker, label)" + assert isinstance(value, tuple), f"Style for {key} should be a tuple" + assert len(value) == 3, f"Style for {key} should have 3 elements (color, marker, label)" def test_styles_colors_are_valid_hex(self): """Style colors should be valid hex color codes.""" @@ -76,8 +78,7 @@ def test_styles_colors_are_valid_hex(self): hex_pattern = re.compile(r"^#[0-9a-fA-F]{6}$") for key, (color, marker, label) in STYLES.items(): - assert hex_pattern.match(color), \ - f"Color '{color}' for {key} is not valid hex" + assert hex_pattern.match(color), f"Color '{color}' for {key} is not valid hex" def test_styles_markers_are_valid(self): """Style markers should be valid matplotlib markers.""" @@ -86,18 +87,17 @@ def test_styles_markers_are_valid(self): valid_markers = {"o", "x", "^", "v", "s", "d", "+", "*", ".", ","} for key, (color, marker, label) in STYLES.items(): - assert marker in valid_markers, \ - f"Marker '{marker}' for {key} is not a recognized marker" + assert ( + marker in valid_markers + ), f"Marker '{marker}' for {key} is not a recognized marker" def test_styles_labels_are_non_empty(self): """Style labels should be non-empty strings.""" from prediction_analyzer.config import STYLES for key, (color, marker, label) in STYLES.items(): - assert isinstance(label, str), \ - f"Label for {key} should be a string" - assert 
len(label) > 0, \ - f"Label for {key} should not be empty" + assert isinstance(label, str), f"Label for {key} should be a string" + assert len(label) > 0, f"Label for {key} should not be empty" class TestGetTradeStyleFunction: @@ -109,8 +109,7 @@ def test_get_trade_style_known_combination(self): for (trade_type, side), expected_style in STYLES.items(): result = get_trade_style(trade_type, side) - assert result == expected_style, \ - f"Unexpected style for ({trade_type}, {side})" + assert result == expected_style, f"Unexpected style for ({trade_type}, {side})" def test_get_trade_style_unknown_returns_fallback(self): """get_trade_style should return fallback for unknown combinations.""" @@ -152,8 +151,9 @@ def test_price_resolution_threshold_is_valid_range(self): """PRICE_RESOLUTION_THRESHOLD should be between 0 and 1.""" from prediction_analyzer.config import PRICE_RESOLUTION_THRESHOLD - assert 0 <= PRICE_RESOLUTION_THRESHOLD <= 1, \ - "Threshold should be between 0 and 1 (represents 0-100 cents)" + assert ( + 0 <= PRICE_RESOLUTION_THRESHOLD <= 1 + ), "Threshold should be between 0 and 1 (represents 0-100 cents)" class TestColorConsistency: @@ -175,8 +175,7 @@ def test_yes_buy_colors_are_consistent(self): r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) - assert g >= r and g >= b, \ - f"YES buy color {color} should be greenish" + assert g >= r and g >= b, f"YES buy color {color} should be greenish" def test_no_buy_colors_are_consistent(self): """NO buy colors should be consistent (magenta family).""" @@ -194,5 +193,4 @@ def test_no_buy_colors_are_consistent(self): g = int(color[3:5], 16) b = int(color[5:7], 16) # Magenta has high R and B, low G - assert r > g and b > g, \ - f"NO buy color {color} should be magenta-ish" + assert r > g and b > g, f"NO buy color {color} should be magenta-ish" diff --git a/tests/static_patterns/test_data_integrity.py b/tests/static_patterns/test_data_integrity.py index 08e10b0..45cd75b 100644 --- 
a/tests/static_patterns/test_data_integrity.py +++ b/tests/static_patterns/test_data_integrity.py @@ -6,6 +6,7 @@ maintains integrity. Data loss or corruption during I/O can cause subtle bugs that are hard to track down. """ + import pytest import json import tempfile @@ -57,9 +58,7 @@ def test_save_and_load_trades(self, sample_trades_list): """save_trades and load_trades should roundtrip correctly.""" from prediction_analyzer.trade_loader import save_trades, load_trades - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -84,9 +83,7 @@ def test_save_empty_trades_list(self): """Saving empty list should create valid JSON.""" from prediction_analyzer.trade_loader import save_trades, load_trades - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -100,9 +97,7 @@ def test_timestamp_serialization(self, sample_trade): """Timestamps should serialize to ISO format.""" from prediction_analyzer.trade_loader import save_trades - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -129,15 +124,9 @@ def test_float_precision_preserved(self, sample_trade_factory): precise_value = 123.456789012345 - trade = sample_trade_factory( - price=precise_value, - cost=precise_value, - pnl=precise_value - ) + trade = sample_trade_factory(price=precise_value, cost=precise_value, pnl=precise_value) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -155,16 +144,9 @@ def test_zero_values_preserved(self, sample_trade_factory): """Zero 
values should be preserved.""" from prediction_analyzer.trade_loader import save_trades, load_trades - trade = sample_trade_factory( - price=0.0, - cost=0.0, - shares=0.0, - pnl=0.0 - ) + trade = sample_trade_factory(price=0.0, cost=0.0, shares=0.0, pnl=0.0) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -184,9 +166,7 @@ def test_negative_values_preserved(self, sample_trade_factory): trade = sample_trade_factory(pnl=-123.456) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -209,9 +189,7 @@ def test_unicode_strings_preserved(self, sample_trade_factory): trade = sample_trade_factory(market=unicode_market) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -229,9 +207,7 @@ def test_empty_strings_preserved(self, sample_trade_factory): # Note: empty market might get replaced with "Unknown" by loader trade = sample_trade_factory(tx_hash="") - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -279,9 +255,7 @@ def test_export_to_json_creates_valid_json(self, sample_trades_list): """export_to_json should create valid, parseable JSON.""" from prediction_analyzer.reporting.report_data import export_to_json - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: temp_path = f.name try: @@ -302,9 +276,7 @@ def test_export_to_csv_creates_valid_csv(self, sample_trades_list): from 
prediction_analyzer.reporting.report_data import export_to_csv import pandas as pd - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: temp_path = f.name try: @@ -322,9 +294,7 @@ def test_export_to_excel_creates_valid_xlsx(self, sample_trades_list): from prediction_analyzer.reporting.report_data import export_to_excel import pandas as pd - with tempfile.NamedTemporaryFile( - mode="w", suffix=".xlsx", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".xlsx", delete=False) as f: temp_path = f.name try: diff --git a/tests/static_patterns/test_dataclass_contracts.py b/tests/static_patterns/test_dataclass_contracts.py index 79df2ca..97f056f 100644 --- a/tests/static_patterns/test_dataclass_contracts.py +++ b/tests/static_patterns/test_dataclass_contracts.py @@ -6,6 +6,7 @@ and behavior. Changes to the dataclass can break serialization, PnL calculations, and other dependent code. 
""" + import pytest from dataclasses import fields, is_dataclass from datetime import datetime @@ -18,6 +19,7 @@ class TestTradeDataclassStructure: def test_trade_is_dataclass(self): """Trade should be a dataclass.""" from prediction_analyzer.trade_loader import Trade + assert is_dataclass(Trade) def test_trade_required_fields(self): @@ -35,8 +37,9 @@ def test_trade_required_fields(self): "type", "side", } - assert required_fields.issubset(field_names), \ - f"Missing fields: {required_fields - field_names}" + assert required_fields.issubset( + field_names + ), f"Missing fields: {required_fields - field_names}" def test_trade_optional_fields(self): """Trade should have expected optional fields.""" @@ -44,16 +47,18 @@ def test_trade_optional_fields(self): field_names = {f.name for f in fields(Trade)} optional_fields = {"pnl", "tx_hash"} - assert optional_fields.issubset(field_names), \ - f"Missing optional fields: {optional_fields - field_names}" + assert optional_fields.issubset( + field_names + ), f"Missing optional fields: {optional_fields - field_names}" def test_trade_field_count(self): """Trade should have exactly 14 fields.""" from prediction_analyzer.trade_loader import Trade field_count = len(fields(Trade)) - assert field_count == 14, \ - f"Expected 14 fields, got {field_count}. Fields were added or removed." + assert ( + field_count == 14 + ), f"Expected 14 fields, got {field_count}. Fields were added or removed." 
class TestTradeFieldTypes: @@ -119,7 +124,7 @@ def test_pnl_default_is_zero(self): shares=10.0, cost=5.0, type="Buy", - side="YES" + side="YES", ) assert trade.pnl == 0.0 @@ -135,7 +140,7 @@ def test_tx_hash_default_is_none(self): shares=10.0, cost=5.0, type="Buy", - side="YES" + side="YES", ) assert trade.tx_hash is None @@ -157,7 +162,7 @@ def test_trade_with_all_fields(self): type="Market Buy", side="NO", pnl=10.25, - tx_hash="0xdeadbeef" + tx_hash="0xdeadbeef", ) assert trade.market == "Full Test Market" @@ -183,7 +188,7 @@ def test_trade_with_minimum_fields(self): shares=10.0, cost=5.0, type="Buy", - side="YES" + side="YES", ) assert trade.market == "Min Market" @@ -231,16 +236,8 @@ class TestTradeEquality: def test_identical_trades_are_equal(self, sample_trade_factory): """Two trades with same values should be equal.""" - trade1 = sample_trade_factory( - timestamp=datetime(2024, 1, 1), - price=50.0, - pnl=10.0 - ) - trade2 = sample_trade_factory( - timestamp=datetime(2024, 1, 1), - price=50.0, - pnl=10.0 - ) + trade1 = sample_trade_factory(timestamp=datetime(2024, 1, 1), price=50.0, pnl=10.0) + trade2 = sample_trade_factory(timestamp=datetime(2024, 1, 1), price=50.0, pnl=10.0) assert trade1 == trade2 @@ -257,11 +254,7 @@ class TestTradeTypeValues: def test_valid_trade_types(self, sample_trade_factory): """All valid trade types should be instantiable.""" - valid_types = [ - "Buy", "Sell", - "Market Buy", "Market Sell", - "Limit Buy", "Limit Sell" - ] + valid_types = ["Buy", "Sell", "Market Buy", "Market Sell", "Limit Buy", "Limit Sell"] for trade_type in valid_types: trade = sample_trade_factory(type=trade_type) diff --git a/tests/static_patterns/test_edge_cases.py b/tests/static_patterns/test_edge_cases.py index 5698531..d10eea4 100644 --- a/tests/static_patterns/test_edge_cases.py +++ b/tests/static_patterns/test_edge_cases.py @@ -5,6 +5,7 @@ These tests verify that the codebase handles edge cases gracefully. 
Edge cases often cause crashes or unexpected behavior when not handled. """ + import pytest from datetime import datetime, timezone import numpy as np @@ -93,12 +94,7 @@ def test_calculate_roi_zero_investment(self): def test_trade_with_zero_values(self, sample_trade_factory): """Trade with zero values should be valid.""" - trade = sample_trade_factory( - price=0.0, - shares=0.0, - cost=0.0, - pnl=0.0 - ) + trade = sample_trade_factory(price=0.0, shares=0.0, cost=0.0, pnl=0.0) assert trade.price == 0.0 assert trade.shares == 0.0 diff --git a/tests/static_patterns/test_filter_contracts.py b/tests/static_patterns/test_filter_contracts.py index 845be95..d6d3e5f 100644 --- a/tests/static_patterns/test_filter_contracts.py +++ b/tests/static_patterns/test_filter_contracts.py @@ -9,6 +9,7 @@ 3. Not modify the original trades 4. Handle edge cases gracefully """ + import pytest from datetime import datetime, timedelta @@ -74,11 +75,7 @@ def test_date_range_filtering(self, trades_spanning_year): """Date range should filter to trades within range.""" from prediction_analyzer.filters import filter_by_date - result = filter_by_date( - trades_spanning_year, - start="2024-03-01", - end="2024-08-31" - ) + result = filter_by_date(trades_spanning_year, start="2024-03-01", end="2024-08-31") for trade in result: assert trade.timestamp >= datetime(2024, 3, 1) assert trade.timestamp <= datetime(2024, 8, 31, 23, 59, 59) @@ -146,10 +143,7 @@ def test_nonexistent_type_returns_empty(self, sample_trades_list): """Non-existent type should return empty list.""" from prediction_analyzer.filters import filter_by_trade_type - result = filter_by_trade_type( - sample_trades_list, - types=["NonExistentType"] - ) + result = filter_by_trade_type(sample_trades_list, types=["NonExistentType"]) assert result == [] @@ -293,7 +287,7 @@ def test_chained_filters(self, sample_trades_list): filter_by_date, filter_by_trade_type, filter_by_side, - filter_by_pnl + filter_by_pnl, ) result = sample_trades_list diff 
--git a/tests/static_patterns/test_imports.py b/tests/static_patterns/test_imports.py index 6521d81..11816c7 100644 --- a/tests/static_patterns/test_imports.py +++ b/tests/static_patterns/test_imports.py @@ -6,6 +6,7 @@ Run these tests BEFORE implementing new features to ensure the codebase is in a stable state. """ + import pytest import sys @@ -16,12 +17,14 @@ class TestPackageImports: def test_main_package_import(self): """Main package should import without errors.""" import prediction_analyzer + assert prediction_analyzer is not None assert hasattr(prediction_analyzer, "__version__") def test_package_version_format(self): """Package version should be a valid semantic version string.""" import prediction_analyzer + version = prediction_analyzer.__version__ assert isinstance(version, str) # Basic semver check: should have at least major.minor format @@ -37,22 +40,26 @@ class TestCoreModuleImports: def test_trade_loader_import(self): """trade_loader module should import.""" from prediction_analyzer import trade_loader + assert trade_loader is not None def test_trade_loader_classes(self): """trade_loader should export Trade dataclass.""" from prediction_analyzer.trade_loader import Trade + assert Trade is not None def test_trade_loader_functions(self): """trade_loader should export core functions.""" from prediction_analyzer.trade_loader import load_trades, save_trades + assert callable(load_trades) assert callable(save_trades) def test_pnl_import(self): """pnl module should import.""" from prediction_analyzer import pnl + assert pnl is not None def test_pnl_functions(self): @@ -61,8 +68,9 @@ def test_pnl_functions(self): calculate_pnl, calculate_global_pnl_summary, calculate_market_pnl, - calculate_market_pnl_summary + calculate_market_pnl_summary, ) + assert callable(calculate_pnl) assert callable(calculate_global_pnl_summary) assert callable(calculate_market_pnl) @@ -71,6 +79,7 @@ def test_pnl_functions(self): def test_filters_import(self): """filters module 
should import.""" from prediction_analyzer import filters + assert filters is not None def test_filters_functions(self): @@ -79,8 +88,9 @@ def test_filters_functions(self): filter_by_date, filter_by_trade_type, filter_by_side, - filter_by_pnl + filter_by_pnl, ) + assert callable(filter_by_date) assert callable(filter_by_trade_type) assert callable(filter_by_side) @@ -89,6 +99,7 @@ def test_filters_functions(self): def test_config_import(self): """config module should import.""" from prediction_analyzer import config + assert config is not None def test_config_exports(self): @@ -98,8 +109,9 @@ def test_config_exports(self): DEFAULT_TRADE_FILE, STYLES, get_trade_style, - PRICE_RESOLUTION_THRESHOLD + PRICE_RESOLUTION_THRESHOLD, ) + assert isinstance(API_BASE_URL, str) assert isinstance(DEFAULT_TRADE_FILE, str) assert isinstance(STYLES, dict) @@ -109,11 +121,13 @@ def test_config_exports(self): def test_trade_filter_import(self): """trade_filter module should import.""" from prediction_analyzer import trade_filter + assert trade_filter is not None def test_inference_import(self): """inference module should import.""" from prediction_analyzer import inference + assert inference is not None @@ -123,31 +137,37 @@ class TestChartModuleImports: def test_charts_package_import(self): """charts package should import.""" from prediction_analyzer import charts + assert charts is not None def test_simple_chart_import(self): """simple chart module should import.""" from prediction_analyzer.charts import simple + assert simple is not None def test_simple_chart_function(self): """simple module should export generate_simple_chart.""" from prediction_analyzer.charts.simple import generate_simple_chart + assert callable(generate_simple_chart) def test_pro_chart_import(self): """pro chart module should import.""" from prediction_analyzer.charts import pro + assert pro is not None def test_enhanced_chart_import(self): """enhanced chart module should import.""" from 
prediction_analyzer.charts import enhanced + assert enhanced is not None def test_global_chart_import(self): """global_chart module should import.""" from prediction_analyzer.charts import global_chart + assert global_chart is not None @@ -157,11 +177,13 @@ class TestUtilityModuleImports: def test_utils_package_import(self): """utils package should import.""" from prediction_analyzer import utils + assert utils is not None def test_math_utils_import(self): """math_utils module should import.""" from prediction_analyzer.utils import math_utils + assert math_utils is not None def test_math_utils_functions(self): @@ -170,8 +192,9 @@ def test_math_utils_functions(self): moving_average, weighted_average, safe_divide, - calculate_roi + calculate_roi, ) + assert callable(moving_average) assert callable(weighted_average) assert callable(safe_divide) @@ -180,6 +203,7 @@ def test_math_utils_functions(self): def test_time_utils_import(self): """time_utils module should import.""" from prediction_analyzer.utils import time_utils + assert time_utils is not None def test_time_utils_functions(self): @@ -187,8 +211,9 @@ def test_time_utils_functions(self): from prediction_analyzer.utils.time_utils import ( parse_date, format_timestamp, - get_date_range + get_date_range, ) + assert callable(parse_date) assert callable(format_timestamp) assert callable(get_date_range) @@ -196,16 +221,19 @@ def test_time_utils_functions(self): def test_auth_import(self): """auth module should import.""" from prediction_analyzer.utils import auth + assert auth is not None def test_data_utils_import(self): """data utils module should import.""" from prediction_analyzer.utils import data + assert data is not None def test_export_utils_import(self): """export utils module should import.""" from prediction_analyzer.utils import export + assert export is not None @@ -215,16 +243,19 @@ class TestReportingModuleImports: def test_reporting_package_import(self): """reporting package should import.""" from 
prediction_analyzer import reporting + assert reporting is not None def test_report_text_import(self): """report_text module should import.""" from prediction_analyzer.reporting import report_text + assert report_text is not None def test_report_data_import(self): """report_data module should import.""" from prediction_analyzer.reporting import report_data + assert report_data is not None def test_report_data_functions(self): @@ -232,8 +263,9 @@ def test_report_data_functions(self): from prediction_analyzer.reporting.report_data import ( export_to_csv, export_to_excel, - export_to_json + export_to_json, ) + assert callable(export_to_csv) assert callable(export_to_excel) assert callable(export_to_json) @@ -245,11 +277,13 @@ class TestCoreSubpackageImports: def test_core_package_import(self): """core package should import.""" from prediction_analyzer import core + assert core is not None def test_interactive_import(self): """interactive module should import.""" from prediction_analyzer.core import interactive + assert interactive is not None @@ -259,6 +293,7 @@ class TestMainModuleImport: def test_main_module_import(self): """__main__ module should import.""" from prediction_analyzer import __main__ + assert __main__ is not None @@ -281,23 +316,9 @@ def test_all_modules_import_together(self): trade_filter, inference, ) - from prediction_analyzer.charts import ( - simple, - pro, - enhanced, - global_chart - ) - from prediction_analyzer.utils import ( - math_utils, - time_utils, - auth, - data, - export - ) - from prediction_analyzer.reporting import ( - report_text, - report_data - ) + from prediction_analyzer.charts import simple, pro, enhanced, global_chart + from prediction_analyzer.utils import math_utils, time_utils, auth, data, export + from prediction_analyzer.reporting import report_text, report_data from prediction_analyzer.core import interactive # If we get here, no circular import errors @@ -310,29 +331,35 @@ class TestDependencyImports: def 
test_pandas_import(self): """pandas should be importable.""" import pandas as pd + assert pd is not None def test_numpy_import(self): """numpy should be importable.""" import numpy as np + assert np is not None def test_matplotlib_import(self): """matplotlib should be importable.""" import matplotlib + assert matplotlib is not None def test_plotly_import(self): """plotly should be importable.""" import plotly + assert plotly is not None def test_requests_import(self): """requests should be importable.""" import requests + assert requests is not None def test_openpyxl_import(self): """openpyxl should be importable.""" import openpyxl + assert openpyxl is not None diff --git a/tests/static_patterns/test_pnl_contracts.py b/tests/static_patterns/test_pnl_contracts.py index b829629..6ab0f62 100644 --- a/tests/static_patterns/test_pnl_contracts.py +++ b/tests/static_patterns/test_pnl_contracts.py @@ -5,6 +5,7 @@ These tests verify that PnL calculation functions maintain their behavior contracts. PnL calculations must be accurate and consistent. 
""" + import pytest import pandas as pd from datetime import datetime @@ -62,11 +63,7 @@ def test_cumulative_pnl_is_cumsum_of_trade_pnl(self, sample_trades_list): result = calculate_pnl(sample_trades_list) expected_cumsum = result["trade_pnl"].cumsum() - pd.testing.assert_series_equal( - result["cumulative_pnl"], - expected_cumsum, - check_names=False - ) + pd.testing.assert_series_equal(result["cumulative_pnl"], expected_cumsum, check_names=False) def test_sorted_by_timestamp(self, sample_trades_list): """Result should be sorted by timestamp.""" @@ -145,9 +142,7 @@ def test_winning_plus_losing_plus_breakeven_lte_total(self, sample_trades_list): result = calculate_global_pnl_summary(sample_trades_list) count_sum = ( - result["winning_trades"] + - result["losing_trades"] + - result.get("breakeven_trades", 0) + result["winning_trades"] + result["losing_trades"] + result.get("breakeven_trades", 0) ) assert count_sum <= result["total_trades"] @@ -212,9 +207,7 @@ def test_trade_count_is_correct(self, multi_market_trades): result = calculate_market_pnl(multi_market_trades) for slug, stats in result.items(): - expected_count = len([ - t for t in multi_market_trades if t.market_slug == slug - ]) + expected_count = len([t for t in multi_market_trades if t.market_slug == slug]) assert stats["trade_count"] == expected_count @@ -249,8 +242,13 @@ def test_has_required_keys(self, sample_trades_list): result = calculate_market_pnl_summary(sample_trades_list) required_keys = { - "market_title", "total_trades", "total_pnl", - "avg_pnl", "winning_trades", "losing_trades", "win_rate" + "market_title", + "total_trades", + "total_pnl", + "avg_pnl", + "winning_trades", + "losing_trades", + "win_rate", } assert required_keys.issubset(set(result.keys())) diff --git a/tests/static_patterns/test_utility_functions.py b/tests/static_patterns/test_utility_functions.py index 82e93ec..753f088 100644 --- a/tests/static_patterns/test_utility_functions.py +++ 
b/tests/static_patterns/test_utility_functions.py @@ -6,6 +6,7 @@ Utility functions are used throughout the codebase, so any issues here will cascade into other modules. """ + import pytest import numpy as np from datetime import datetime, timedelta diff --git a/tests/test_bugfixes.py b/tests/test_bugfixes.py index 18efa1c..125770a 100644 --- a/tests/test_bugfixes.py +++ b/tests/test_bugfixes.py @@ -5,6 +5,7 @@ Each test class corresponds to a specific bug fix and verifies the fix works correctly without regressions. """ + import json import math import pytest @@ -14,11 +15,11 @@ from prediction_analyzer.trade_loader import Trade - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_trade(**kwargs): """Create a Trade with sensible defaults, overriding with kwargs.""" defaults = { @@ -44,6 +45,7 @@ def _make_trade(**kwargs): # sample variance (ddof=1). Fix: use ddof=1 for Sortino as well. # =========================================================================== + class TestSortinoDdofConsistency: """Sortino ratio should use ddof=1 (Bessel's correction), matching Sharpe.""" @@ -59,7 +61,7 @@ def test_sortino_uses_sample_variance(self): mean_ret = np.mean(arr) downside = np.minimum(arr, 0.0) # ddof=1: divide by (N - 1) - expected_dd = np.sqrt(np.sum(downside ** 2) / (len(arr) - 1)) + expected_dd = np.sqrt(np.sum(downside**2) / (len(arr) - 1)) expected_sortino = mean_ret / expected_dd assert abs(result["sortino_ratio"] - round(expected_sortino, 4)) < 1e-4 @@ -100,6 +102,7 @@ def test_single_trade_returns_zero(self): # Fix: use cur_symbol derived from the portfolio's currency. 
# =========================================================================== + class TestReportCurrencySymbol: """Per-market report sections should use the portfolio's currency symbol.""" @@ -110,12 +113,24 @@ def test_mana_trades_use_mana_symbol_in_top_markets(self): trades = [ _make_trade( - market="Q1", market_slug="q1", pnl=10.0, pnl_is_set=True, - cost=5.0, type="Buy", currency="MANA", source="manifold", + market="Q1", + market_slug="q1", + pnl=10.0, + pnl_is_set=True, + cost=5.0, + type="Buy", + currency="MANA", + source="manifold", ), _make_trade( - market="Q1", market_slug="q1", pnl=-3.0, pnl_is_set=True, - cost=3.0, type="Sell", currency="MANA", source="manifold", + market="Q1", + market_slug="q1", + pnl=-3.0, + pnl_is_set=True, + cost=3.0, + type="Sell", + currency="MANA", + source="manifold", ), ] buf = io.StringIO() @@ -124,9 +139,9 @@ def test_mana_trades_use_mana_symbol_in_top_markets(self): # The top-markets line should contain M$ not bare $ lines = [l for l in output.splitlines() if "q1" in l.lower() or "Q1" in l] - assert any("M$" in line for line in lines), ( - f"Expected M$ in top-markets lines, got: {lines}" - ) + assert any( + "M$" in line for line in lines + ), f"Expected M$ in top-markets lines, got: {lines}" def test_usd_trades_use_dollar_sign(self): """USD trades should still use $ in the per-market section.""" @@ -135,8 +150,13 @@ def test_usd_trades_use_dollar_sign(self): trades = [ _make_trade( - market="Election", market_slug="election", pnl=5.0, - pnl_is_set=True, cost=10.0, type="Buy", currency="USD", + market="Election", + market_slug="election", + pnl=5.0, + pnl_is_set=True, + cost=10.0, + type="Buy", + currency="USD", ), ] with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: @@ -156,6 +176,7 @@ def test_usd_trades_use_dollar_sign(self): # NaN/Inf sanitization and producing invalid JSON. 
# =========================================================================== + class TestJsonExportUsesToDict: """export_to_json should use to_dict() for NaN/Inf sanitization.""" @@ -163,7 +184,7 @@ def test_nan_pnl_produces_valid_json(self, tmp_path): """NaN in pnl should be sanitized to 0.0 in JSON output.""" from prediction_analyzer.reporting.report_data import export_to_json - trades = [_make_trade(pnl=float('nan'), pnl_is_set=True)] + trades = [_make_trade(pnl=float("nan"), pnl_is_set=True)] outfile = str(tmp_path / "test.json") export_to_json(trades, filename=outfile) @@ -176,7 +197,7 @@ def test_inf_cost_produces_valid_json(self, tmp_path): """Infinity in cost should be sanitized in JSON output.""" from prediction_analyzer.reporting.report_data import export_to_json - trades = [_make_trade(cost=float('inf'))] + trades = [_make_trade(cost=float("inf"))] outfile = str(tmp_path / "test.json") export_to_json(trades, filename=outfile) @@ -205,6 +226,7 @@ def test_timestamp_serialized_as_string(self, tmp_path): # calculate_market_pnl. Fix: both use trade.market_slug. # =========================================================================== + class TestGroupingKeyConsistency: """group_trades_by_market and calculate_market_pnl should use the same key.""" @@ -244,6 +266,7 @@ def test_consistent_with_calculate_market_pnl(self): # Fix: only add the extra day for string dates (midnight-based). # =========================================================================== + class TestFilterByDateDatetimeEnd: """filter_by_date should not add 24h when end is a datetime.""" @@ -280,6 +303,7 @@ def test_string_end_still_includes_full_day(self): # This is now explicit rather than an accidental fallthrough. 
# =========================================================================== + class TestManifoldZeroAmountTrade: """Manifold provider should handle zero-amount trades explicitly.""" @@ -322,6 +346,7 @@ def test_negative_amount_is_sell(self): # actual default is PRICE_RESOLUTION_THRESHOLD = 0.85. # =========================================================================== + class TestInferenceDocstring: """Verify inference function's default threshold matches documentation.""" diff --git a/tests/test_bugfixes_audit2.py b/tests/test_bugfixes_audit2.py index 60ffdb9..5512dce 100644 --- a/tests/test_bugfixes_audit2.py +++ b/tests/test_bugfixes_audit2.py @@ -6,6 +6,7 @@ Bug #2: _apply_filters silently returns empty list when min_pnl > max_pnl Bug #3: filter_trades stores empty string/list values in active_filters """ + import json import math import os @@ -19,11 +20,11 @@ from prediction_analyzer.exceptions import InvalidFilterError from prediction_mcp.state import session - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_trade(**kwargs): """Create a Trade with sensible defaults, overriding with kwargs.""" defaults = { @@ -48,6 +49,7 @@ def _make_trade(**kwargs): # Bug #1: export_tools path traversal check rejects valid absolute paths # =========================================================================== + class TestExportPathTraversal: """export_trades should allow valid absolute paths like /tmp/foo.csv.""" @@ -68,10 +70,15 @@ def test_absolute_tmp_path_allowed(self): path = f.name try: - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - "output_path": path, - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + "output_path": path, + }, + ) + ) data = json.loads(result[0].text) assert data["trade_count"] == 1 assert data["format"] == "csv" @@ -84,10 
+91,15 @@ def test_dotdot_path_rejected(self): """Paths with '..' components should still be rejected.""" from prediction_mcp.tools import export_tools - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "csv", - "output_path": "/tmp/../etc/test.csv", - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "csv", + "output_path": "/tmp/../etc/test.csv", + }, + ) + ) assert "'..' " in result[0].text or "error" in result[0].text.lower() def test_relative_path_allowed(self): @@ -96,10 +108,15 @@ def test_relative_path_allowed(self): path = os.path.join(tempfile.gettempdir(), "test_export_rel.json") try: - result = asyncio.run(export_tools.handle_tool("export_trades", { - "format": "json", - "output_path": path, - })) + result = asyncio.run( + export_tools.handle_tool( + "export_trades", + { + "format": "json", + "output_path": path, + }, + ) + ) data = json.loads(result[0].text) assert data["trade_count"] == 1 finally: @@ -111,6 +128,7 @@ def test_relative_path_allowed(self): # Bug #2: _apply_filters silently returns empty when min_pnl > max_pnl # =========================================================================== + class TestApplyFiltersMinMaxPnl: """apply_filters should raise InvalidFilterError when min_pnl > max_pnl.""" @@ -168,6 +186,7 @@ def test_only_max_pnl(self): # Bug #3: filter_trades stores empty string/list values in active_filters # =========================================================================== + class TestActiveFiltersNoEmpty: """filter_trades should not store empty strings or empty lists as active filters.""" @@ -184,10 +203,15 @@ def test_empty_string_not_stored(self): """Empty string filter values should not appear in active_filters.""" from prediction_mcp.tools import filter_tools - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "market_slug": "", - "start_date": "", - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { 
+ "market_slug": "", + "start_date": "", + }, + ) + ) data = json.loads(result[0].text) assert "market_slug" not in data["active_filters"] assert "start_date" not in data["active_filters"] @@ -196,10 +220,15 @@ def test_empty_list_not_stored(self): """Empty list filter values should not appear in active_filters.""" from prediction_mcp.tools import filter_tools - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "trade_types": [], - "sides": [], - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "trade_types": [], + "sides": [], + }, + ) + ) data = json.loads(result[0].text) assert "trade_types" not in data["active_filters"] assert "sides" not in data["active_filters"] @@ -208,10 +237,15 @@ def test_valid_values_still_stored(self): """Non-empty filter values should still be stored.""" from prediction_mcp.tools import filter_tools - result = asyncio.run(filter_tools.handle_tool("filter_trades", { - "sides": ["YES"], - "min_pnl": 0.0, - })) + result = asyncio.run( + filter_tools.handle_tool( + "filter_trades", + { + "sides": ["YES"], + "min_pnl": 0.0, + }, + ) + ) data = json.loads(result[0].text) assert "sides" in data["active_filters"] # 0.0 is a valid filter value (not empty) @@ -222,6 +256,7 @@ def test_valid_values_still_stored(self): # Additional: _apply_filters combined filter tests (coverage gap) # =========================================================================== + class TestApplyFiltersCombined: """Test multiple filters applied simultaneously.""" @@ -235,11 +270,14 @@ def test_date_and_side_combined(self): _make_trade(timestamp=datetime(2024, 1, 10), side="YES"), _make_trade(timestamp=datetime(2024, 2, 1), side="YES"), ] - result = apply_filters(trades, { - "start_date": "2024-01-01", - "end_date": "2024-01-15", - "sides": ["YES"], - }) + result = apply_filters( + trades, + { + "start_date": "2024-01-01", + "end_date": "2024-01-15", + "sides": ["YES"], + }, + ) assert len(result) == 2 def 
test_all_filters_at_once(self): @@ -250,27 +288,36 @@ def test_all_filters_at_once(self): _make_trade( market_slug="m1", timestamp=datetime(2024, 3, 1), - type="Buy", side="YES", pnl=5.0, + type="Buy", + side="YES", + pnl=5.0, ), _make_trade( market_slug="m1", timestamp=datetime(2024, 3, 2), - type="Sell", side="YES", pnl=-2.0, + type="Sell", + side="YES", + pnl=-2.0, ), _make_trade( market_slug="m2", timestamp=datetime(2024, 3, 1), - type="Buy", side="NO", pnl=10.0, + type="Buy", + side="NO", + pnl=10.0, ), ] - result = apply_filters(trades, { - "market_slug": "m1", - "start_date": "2024-03-01", - "end_date": "2024-03-03", - "trade_types": ["Buy"], - "sides": ["YES"], - "min_pnl": 0, - }) + result = apply_filters( + trades, + { + "market_slug": "m1", + "start_date": "2024-03-01", + "end_date": "2024-03-03", + "trade_types": ["Buy"], + "sides": ["YES"], + "min_pnl": 0, + }, + ) assert len(result) == 1 assert result[0].pnl == 5.0 @@ -294,11 +341,13 @@ def test_filters_on_empty_trades(self): # Additional: Kalshi normalize_trade regression tests # =========================================================================== + class TestKalshiNormalizeTrade: """Kalshi normalize_trade should handle both new and legacy price fields.""" def _provider(self): from prediction_analyzer.providers.kalshi import KalshiProvider + return KalshiProvider() def test_fixed_price_field_used_when_present(self): @@ -451,6 +500,7 @@ def test_invalid_fixed_falls_back_to_legacy(self): # Additional: PnL calculator FIFO tests # =========================================================================== + class TestPnlCalculatorFifo: """compute_realized_pnl FIFO matching correctness.""" @@ -459,10 +509,12 @@ def test_basic_buy_sell_pair(self): from prediction_analyzer.providers.pnl_calculator import compute_realized_pnl trades = [ - _make_trade(type="Buy", price=0.40, shares=10, cost=4.0, - timestamp=datetime(2024, 1, 1)), - _make_trade(type="Sell", price=0.60, shares=10, cost=6.0, - 
timestamp=datetime(2024, 1, 2)), + _make_trade( + type="Buy", price=0.40, shares=10, cost=4.0, timestamp=datetime(2024, 1, 1) + ), + _make_trade( + type="Sell", price=0.60, shares=10, cost=6.0, timestamp=datetime(2024, 1, 2) + ), ] result = compute_realized_pnl(trades) sell = [t for t in result if t.type == "Sell"][0] @@ -475,12 +527,11 @@ def test_fifo_order_matters(self): from prediction_analyzer.providers.pnl_calculator import compute_realized_pnl trades = [ - _make_trade(type="Buy", price=0.30, shares=5, cost=1.5, - timestamp=datetime(2024, 1, 1)), - _make_trade(type="Buy", price=0.50, shares=5, cost=2.5, - timestamp=datetime(2024, 1, 2)), - _make_trade(type="Sell", price=0.60, shares=5, cost=3.0, - timestamp=datetime(2024, 1, 3)), + _make_trade(type="Buy", price=0.30, shares=5, cost=1.5, timestamp=datetime(2024, 1, 1)), + _make_trade(type="Buy", price=0.50, shares=5, cost=2.5, timestamp=datetime(2024, 1, 2)), + _make_trade( + type="Sell", price=0.60, shares=5, cost=3.0, timestamp=datetime(2024, 1, 3) + ), ] result = compute_realized_pnl(trades) sell = [t for t in result if t.type == "Sell"][0] @@ -493,11 +544,18 @@ def test_provider_pnl_not_overwritten(self): from prediction_analyzer.providers.pnl_calculator import compute_realized_pnl trades = [ - _make_trade(type="Buy", price=0.40, shares=10, cost=4.0, - timestamp=datetime(2024, 1, 1)), - _make_trade(type="Sell", price=0.60, shares=10, cost=6.0, - pnl=99.0, pnl_is_set=True, - timestamp=datetime(2024, 1, 2)), + _make_trade( + type="Buy", price=0.40, shares=10, cost=4.0, timestamp=datetime(2024, 1, 1) + ), + _make_trade( + type="Sell", + price=0.60, + shares=10, + cost=6.0, + pnl=99.0, + pnl_is_set=True, + timestamp=datetime(2024, 1, 2), + ), ] result = compute_realized_pnl(trades) sell = [t for t in result if t.type == "Sell"][0] diff --git a/tests/test_fees_wash_sales.py b/tests/test_fees_wash_sales.py index 76b1940..7f0de29 100644 --- a/tests/test_fees_wash_sales.py +++ b/tests/test_fees_wash_sales.py @@ 
-5,6 +5,7 @@ - Wash sale detection (IRS §1091 for prediction markets) - Data completeness (total_trades_in_scope) """ + import pytest from datetime import datetime @@ -36,6 +37,7 @@ def _make_trade(**kwargs): # Fee tracking # =========================================================================== + class TestFeeTracking: """Fee field on Trade and fee accumulation in tax report.""" @@ -59,8 +61,14 @@ def test_fee_in_to_dict(self): def test_tax_report_total_fees_buy_only(self): """Tax report accumulates fees from buy trades.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.50, fee=0.50), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.50, + shares=10, + cost=5.50, + fee=0.50, + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert result["total_fees"] == pytest.approx(0.50) @@ -68,10 +76,22 @@ def test_tax_report_total_fees_buy_only(self): def test_tax_report_total_fees_buy_and_sell(self): """Tax report accumulates fees from both buy and sell trades.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.50, fee=0.50), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=6.70, fee=0.30), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.50, + shares=10, + cost=5.50, + fee=0.50, + ), + _make_trade( + timestamp=datetime(2024, 6, 1), + type="Sell", + price=0.70, + shares=10, + cost=6.70, + fee=0.30, + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert result["total_fees"] == pytest.approx(0.80) @@ -79,10 +99,17 @@ def test_tax_report_total_fees_buy_and_sell(self): def test_tax_transaction_includes_sell_fee(self): """Tax transaction output includes fee when sell has fee > 0.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0, fee=0.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - 
price=0.70, shares=10, cost=7.0, fee=0.25), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0, fee=0.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), + type="Sell", + price=0.70, + shares=10, + cost=7.0, + fee=0.25, + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert result["transaction_count"] == 1 @@ -93,10 +120,17 @@ def test_tax_transaction_includes_sell_fee(self): def test_tax_transaction_omits_zero_fee(self): """Tax transaction output omits fee when sell fee is 0.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0, fee=0.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=7.0, fee=0.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0, fee=0.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), + type="Sell", + price=0.70, + shares=10, + cost=7.0, + fee=0.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024) tx = result["transactions"][0] @@ -108,11 +142,17 @@ def test_large_fee_accumulation(self): for i in range(100): day = 1 + (i % 28) month = 1 + (i // 28) % 12 - trades.append(_make_trade( - timestamp=datetime(2024, month, day, 10, 0, 0), - type="Buy", price=0.50, shares=1000, cost=500.0 + 600.0, - fee=600.0, market_slug=f"market-{i}", - )) + trades.append( + _make_trade( + timestamp=datetime(2024, month, day, 10, 0, 0), + type="Buy", + price=0.50, + shares=1000, + cost=500.0 + 600.0, + fee=600.0, + market_slug=f"market-{i}", + ) + ) result = calculate_capital_gains(trades, tax_year=2024) assert result["total_fees"] == pytest.approx(60000.0) @@ -121,19 +161,38 @@ def test_large_fee_accumulation(self): # Wash sale detection # =========================================================================== + class TestWashSaleDetection: """Wash sale detection per IRS §1091 for prediction markets.""" def test_basic_wash_sale_detected(self): 
"""Sell at loss + repurchase within 30 days = wash sale.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), # Repurchase within 30 days - _make_trade(timestamp=datetime(2024, 3, 15), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES"), + _make_trade( + timestamp=datetime(2024, 3, 15), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -144,12 +203,30 @@ def test_basic_wash_sale_detected(self): def test_no_wash_sale_when_gain(self): """Sells at a gain should not trigger wash sale.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 15), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 15), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" not in result @@ -157,13 +234,31 @@ def test_no_wash_sale_when_gain(self): def test_no_wash_sale_outside_window(self): """Repurchase > 30 days after loss sale is not a wash sale.""" trades = [ - 
_make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), # Repurchase 45 days later — outside window - _make_trade(timestamp=datetime(2024, 4, 15), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES"), + _make_trade( + timestamp=datetime(2024, 4, 15), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" not in result @@ -171,13 +266,31 @@ def test_no_wash_sale_outside_window(self): def test_wash_sale_cross_side(self): """Buy on NO side within 30 days of YES loss = wash sale (same market).""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), # Buy NO side within 30 days - _make_trade(timestamp=datetime(2024, 3, 15), type="Buy", - price=0.45, shares=10, cost=4.5, side="NO"), + _make_trade( + timestamp=datetime(2024, 3, 15), + type="Buy", + price=0.45, + shares=10, + cost=4.5, + side="NO", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -186,13 +299,31 @@ def test_wash_sale_cross_side(self): def test_wash_sale_before_sell(self): """Buy within 30 days BEFORE the loss sale = wash sale.""" trades = [ - 
_make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), # Another buy 15 days before the sell - _make_trade(timestamp=datetime(2024, 2, 14), type="Buy", - price=0.55, shares=5, cost=2.75, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 2, 14), + type="Buy", + price=0.55, + shares=5, + cost=2.75, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -202,25 +333,64 @@ def test_wash_sale_disallowed_loss_total(self): """wash_sale_disallowed_loss sums all disallowed losses.""" trades = [ # Market A: loss + repurchase - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.80, shares=10, cost=8.0, side="YES", - market_slug="market-a"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES", - market_slug="market-a"), - _make_trade(timestamp=datetime(2024, 3, 10), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES", - market_slug="market-a"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.80, + shares=10, + cost=8.0, + side="YES", + market_slug="market-a", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + market_slug="market-a", + ), + _make_trade( + timestamp=datetime(2024, 3, 10), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + market_slug="market-a", + ), # Market B: loss + repurchase - _make_trade(timestamp=datetime(2024, 2, 1), type="Buy", - price=0.90, shares=20, cost=18.0, side="YES", - market_slug="market-b", market="Market 
B"), - _make_trade(timestamp=datetime(2024, 4, 1), type="Sell", - price=0.60, shares=20, cost=12.0, side="YES", - market_slug="market-b", market="Market B"), - _make_trade(timestamp=datetime(2024, 4, 20), type="Buy", - price=0.65, shares=20, cost=13.0, side="YES", - market_slug="market-b", market="Market B"), + _make_trade( + timestamp=datetime(2024, 2, 1), + type="Buy", + price=0.90, + shares=20, + cost=18.0, + side="YES", + market_slug="market-b", + market="Market B", + ), + _make_trade( + timestamp=datetime(2024, 4, 1), + type="Sell", + price=0.60, + shares=20, + cost=12.0, + side="YES", + market_slug="market-b", + market="Market B", + ), + _make_trade( + timestamp=datetime(2024, 4, 20), + type="Buy", + price=0.65, + shares=20, + cost=13.0, + side="YES", + market_slug="market-b", + market="Market B", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -230,10 +400,12 @@ def test_wash_sale_disallowed_loss_total(self): def test_no_wash_sales_key_when_none(self): """Result should not contain wash_sales key when there are none.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=7.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.70, shares=10, cost=7.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" not in result @@ -245,10 +417,22 @@ def test_original_buy_not_flagged_as_wash_sale(self): trades = [ # Buy on Feb 15, sell at loss on Mar 1 — only 14 days apart. # No repurchase. The original buy should NOT be flagged. 
- _make_trade(timestamp=datetime(2024, 2, 15), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 2, 15), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" not in result @@ -258,25 +442,64 @@ def test_same_date_multi_market_losses_independent(self): be independently checked for wash sales (not short-circuited).""" trades = [ # Market A: loss with repurchase - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.80, shares=10, cost=8.0, side="YES", - market_slug="market-a"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES", - market_slug="market-a"), - _make_trade(timestamp=datetime(2024, 3, 10), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES", - market_slug="market-a"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.80, + shares=10, + cost=8.0, + side="YES", + market_slug="market-a", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + market_slug="market-a", + ), + _make_trade( + timestamp=datetime(2024, 3, 10), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + market_slug="market-a", + ), # Market B: loss with repurchase, SAME SELL DATE - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.90, shares=10, cost=9.0, side="YES", - market_slug="market-b", market="Market B"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.60, shares=10, cost=6.0, side="YES", - market_slug="market-b", market="Market B"), - _make_trade(timestamp=datetime(2024, 3, 15), type="Buy", - 
price=0.65, shares=10, cost=6.5, side="YES", - market_slug="market-b", market="Market B"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.90, + shares=10, + cost=9.0, + side="YES", + market_slug="market-b", + market="Market B", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.60, + shares=10, + cost=6.0, + side="YES", + market_slug="market-b", + market="Market B", + ), + _make_trade( + timestamp=datetime(2024, 3, 15), + type="Buy", + price=0.65, + shares=10, + cost=6.5, + side="YES", + market_slug="market-b", + market="Market B", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -288,13 +511,31 @@ def test_same_date_multi_market_losses_independent(self): def test_wash_sale_at_exactly_30_days(self): """Repurchase at exactly 30 days should be flagged.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), # Exactly 30 days later - _make_trade(timestamp=datetime(2024, 3, 31), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES"), + _make_trade( + timestamp=datetime(2024, 3, 31), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" in result @@ -302,13 +543,31 @@ def test_wash_sale_at_exactly_30_days(self): def test_wash_sale_at_31_days_not_flagged(self): """Repurchase at 31 days should NOT be flagged.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=10, cost=7.0, side="YES"), - _make_trade(timestamp=datetime(2024, 3, 
1), type="Sell", - price=0.50, shares=10, cost=5.0, side="YES"), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.70, + shares=10, + cost=7.0, + side="YES", + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + price=0.50, + shares=10, + cost=5.0, + side="YES", + ), # 31 days later - _make_trade(timestamp=datetime(2024, 4, 1), type="Buy", - price=0.55, shares=10, cost=5.5, side="YES"), + _make_trade( + timestamp=datetime(2024, 4, 1), + type="Buy", + price=0.55, + shares=10, + cost=5.5, + side="YES", + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert "wash_sales" not in result @@ -318,6 +577,7 @@ def test_wash_sale_at_31_days_not_flagged(self): # Limitless provider: trade type mapping # =========================================================================== + class TestLimitlessTradeTypeMapping: """Limitless API may return category types (trade/split/merge) instead of direction types (Buy/Sell). The provider must map correctly.""" @@ -325,6 +585,7 @@ class TestLimitlessTradeTypeMapping: def test_strategy_field_preferred_over_category_type(self): """When API returns type='trade' and strategy='Buy', use strategy.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "type": "trade", @@ -341,6 +602,7 @@ def test_strategy_field_preferred_over_category_type(self): def test_strategy_sell(self): """strategy='Sell' should produce type='Sell'.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "type": "trade", @@ -357,6 +619,7 @@ def test_strategy_sell(self): def test_category_type_without_strategy_defaults_to_buy(self): """If only type='trade' with no strategy, default to Buy.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "type": "trade", @@ -372,6 +635,7 @@ def test_category_type_without_strategy_defaults_to_buy(self): def 
test_split_merge_types_default_to_buy(self): """Split/merge category types without strategy default to Buy.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() for cat_type in ("split", "merge", "conversion"): raw = { @@ -388,6 +652,7 @@ def test_split_merge_types_default_to_buy(self): def test_legacy_strategy_only_format(self): """Legacy file format with only strategy field (no type).""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "strategy": "Buy", @@ -403,6 +668,7 @@ def test_legacy_strategy_only_format(self): def test_underscore_strategy_normalized(self): """strategy='market_buy' should become 'Market Buy'.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "strategy": "market_buy", @@ -418,6 +684,7 @@ def test_underscore_strategy_normalized(self): def test_micro_unit_conversion(self): """API format amounts should be divided by 1_000_000.""" from prediction_analyzer.providers.limitless import LimitlessProvider + provider = LimitlessProvider() raw = { "strategy": "Buy", @@ -437,18 +704,20 @@ def test_micro_unit_conversion(self): # Data completeness: total_trades_in_scope # =========================================================================== + class TestDataCompleteness: """Tax report includes total_trades_in_scope for verification.""" def test_total_trades_in_scope(self): """total_trades_in_scope counts all trades passed in.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Buy", - price=0.60, shares=5, cost=3.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=7.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade(timestamp=datetime(2024, 3, 1), type="Buy", 
price=0.60, shares=5, cost=3.0), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.70, shares=10, cost=7.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert result["total_trades_in_scope"] == 3 @@ -456,11 +725,12 @@ def test_total_trades_in_scope(self): def test_excess_sell_logs_warning(self, caplog): """Selling more shares than available buy lots should log a warning.""" import logging + trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=5, cost=2.5), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=7.0), + _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=5, cost=2.5), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.70, shares=10, cost=7.0 + ), ] with caplog.at_level(logging.WARNING, logger="prediction_analyzer.tax"): result = calculate_capital_gains(trades, tax_year=2024) @@ -469,10 +739,12 @@ def test_excess_sell_logs_warning(self, caplog): def test_total_trades_includes_all_years(self): """total_trades_in_scope includes trades from all years.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.70, shares=10, cost=7.0), + _make_trade( + timestamp=datetime(2023, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.70, shares=10, cost=7.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024) assert result["total_trades_in_scope"] == 2 diff --git a/tests/test_package.py b/tests/test_package.py index 7ce51ce..c3dff39 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -2,76 +2,83 @@ """ Basic tests for the prediction analyzer package """ + import pytest from datetime import datetime from prediction_analyzer.trade_loader import Trade from prediction_analyzer.pnl import 
calculate_global_pnl_summary from prediction_analyzer.filters import filter_by_date, filter_by_trade_type + def create_sample_trade(**kwargs): """Helper to create sample trades""" defaults = { - 'market': 'Test Market', - 'market_slug': 'test-market', - 'timestamp': datetime.now(), - 'price': 50.0, - 'shares': 10.0, - 'cost': 5.0, - 'type': 'Buy', - 'side': 'YES', - 'pnl': 0.0 + "market": "Test Market", + "market_slug": "test-market", + "timestamp": datetime.now(), + "price": 50.0, + "shares": 10.0, + "cost": 5.0, + "type": "Buy", + "side": "YES", + "pnl": 0.0, } defaults.update(kwargs) - if 'pnl_is_set' not in kwargs and 'pnl' in kwargs: - defaults['pnl_is_set'] = True + if "pnl_is_set" not in kwargs and "pnl" in kwargs: + defaults["pnl_is_set"] = True return Trade(**defaults) + def test_trade_creation(): """Test Trade dataclass creation""" trade = create_sample_trade() - assert trade.market == 'Test Market' + assert trade.market == "Test Market" assert trade.price == 50.0 - assert trade.type == 'Buy' + assert trade.type == "Buy" + def test_global_pnl_calculation(): """Test global PnL summary calculation""" trades = [ create_sample_trade(pnl=10.0), create_sample_trade(pnl=-5.0), - create_sample_trade(pnl=15.0) + create_sample_trade(pnl=15.0), ] summary = calculate_global_pnl_summary(trades) - assert summary['total_trades'] == 3 - assert summary['total_pnl'] == 20.0 - assert summary['winning_trades'] == 2 - assert summary['losing_trades'] == 1 + assert summary["total_trades"] == 3 + assert summary["total_pnl"] == 20.0 + assert summary["winning_trades"] == 2 + assert summary["losing_trades"] == 1 + def test_filter_by_date(): """Test date filtering""" trades = [ create_sample_trade(timestamp=datetime(2024, 1, 1)), create_sample_trade(timestamp=datetime(2024, 6, 1)), - create_sample_trade(timestamp=datetime(2024, 12, 1)) + create_sample_trade(timestamp=datetime(2024, 12, 1)), ] - filtered = filter_by_date(trades, start='2024-05-01', end='2024-12-31') + filtered = 
filter_by_date(trades, start="2024-05-01", end="2024-12-31") assert len(filtered) == 2 + def test_filter_by_type(): """Test trade type filtering""" trades = [ - create_sample_trade(type='Buy'), - create_sample_trade(type='Sell'), - create_sample_trade(type='Buy') + create_sample_trade(type="Buy"), + create_sample_trade(type="Sell"), + create_sample_trade(type="Buy"), ] - filtered = filter_by_trade_type(trades, types=['Buy']) + filtered = filter_by_trade_type(trades, types=["Buy"]) assert len(filtered) == 2 - assert all(t.type == 'Buy' for t in filtered) + assert all(t.type == "Buy" for t in filtered) + -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_tax_bugfixes.py b/tests/test_tax_bugfixes.py index 140ad0b..c33ab18 100644 --- a/tests/test_tax_bugfixes.py +++ b/tests/test_tax_bugfixes.py @@ -5,6 +5,7 @@ These tests verify that capital gains calculations are correct for real-world scenarios that million-dollar traders would encounter. """ + import pytest from datetime import datetime @@ -40,20 +41,25 @@ def _make_trade(**kwargs): # Fix: All sells consume lots; only tax-year sells generate transactions. 
# =========================================================================== + class TestPriorYearSellsConsumeLots: """Sells in prior years MUST consume buy lots so later years are correct.""" def test_fifo_prior_year_sell_consumes_earliest_lot(self): """FIFO: a 2023 sell should consume the $0.30 lot, leaving only $0.70 for 2024.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 15), type="Buy", - price=0.30, shares=10, cost=3.0), - _make_trade(timestamp=datetime(2023, 6, 15), type="Buy", - price=0.70, shares=10, cost=7.0), - _make_trade(timestamp=datetime(2023, 12, 1), type="Sell", - price=0.90, shares=10, cost=9.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.30, shares=10, cost=3.0 + ), + _make_trade( + timestamp=datetime(2023, 6, 15), type="Buy", price=0.70, shares=10, cost=7.0 + ), + _make_trade( + timestamp=datetime(2023, 12, 1), type="Sell", price=0.90, shares=10, cost=9.0 + ), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Sell", price=0.50, shares=10, cost=5.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -63,20 +69,24 @@ def test_fifo_prior_year_sell_consumes_earliest_lot(self): # 2024 sell should match the $0.70 lot (the $0.30 lot was consumed in 2023) assert tx["cost_basis"] == pytest.approx(7.0) # 10 * $0.70 - assert tx["proceeds"] == pytest.approx(5.0) # 10 * $0.50 - assert tx["gain_loss"] == pytest.approx(-2.0) # loss, NOT a gain + assert tx["proceeds"] == pytest.approx(5.0) # 10 * $0.50 + assert tx["gain_loss"] == pytest.approx(-2.0) # loss, NOT a gain def test_lifo_prior_year_sell_consumes_latest_lot(self): """LIFO: a 2023 sell should consume the $0.70 lot, leaving $0.30 for 2024.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 15), type="Buy", - price=0.30, shares=10, cost=3.0), - _make_trade(timestamp=datetime(2023, 6, 15), type="Buy", - price=0.70, shares=10, 
cost=7.0), - _make_trade(timestamp=datetime(2023, 12, 1), type="Sell", - price=0.90, shares=10, cost=9.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.50, shares=10, cost=5.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.30, shares=10, cost=3.0 + ), + _make_trade( + timestamp=datetime(2023, 6, 15), type="Buy", price=0.70, shares=10, cost=7.0 + ), + _make_trade( + timestamp=datetime(2023, 12, 1), type="Sell", price=0.90, shares=10, cost=9.0 + ), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Sell", price=0.50, shares=10, cost=5.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="lifo") @@ -92,14 +102,18 @@ def test_lifo_prior_year_sell_consumes_latest_lot(self): def test_average_prior_year_sell_reduces_pool(self): """Average: a 2023 sell should reduce the share pool for 2024.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 15), type="Buy", - price=0.40, shares=10, cost=4.0), - _make_trade(timestamp=datetime(2023, 6, 15), type="Buy", - price=0.60, shares=10, cost=6.0), - _make_trade(timestamp=datetime(2023, 12, 1), type="Sell", - price=0.80, shares=10, cost=8.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.70, shares=10, cost=7.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.40, shares=10, cost=4.0 + ), + _make_trade( + timestamp=datetime(2023, 6, 15), type="Buy", price=0.60, shares=10, cost=6.0 + ), + _make_trade( + timestamp=datetime(2023, 12, 1), type="Sell", price=0.80, shares=10, cost=8.0 + ), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Sell", price=0.70, shares=10, cost=7.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="average") @@ -116,10 +130,12 @@ def test_average_prior_year_sell_reduces_pool(self): def test_no_transactions_outside_tax_year(self): """Sells outside the tax year should NOT appear in transactions.""" trades = [ - _make_trade(timestamp=datetime(2023, 
1, 15), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2023, 12, 1), type="Sell", - price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2023, 12, 1), type="Sell", price=0.80, shares=10, cost=8.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -130,14 +146,17 @@ def test_no_transactions_outside_tax_year(self): def test_prior_year_sell_prevents_double_counting(self): """Cannot sell more shares than bought — prior sells must deplete the pool.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 15), type="Buy", - price=0.50, shares=100, cost=50.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.50, shares=100, cost=50.0 + ), # Sell ALL 100 shares in 2023 - _make_trade(timestamp=datetime(2023, 6, 1), type="Sell", - price=0.80, shares=100, cost=80.0), + _make_trade( + timestamp=datetime(2023, 6, 1), type="Sell", price=0.80, shares=100, cost=80.0 + ), # Try to sell 50 more in 2024 — no lots should remain - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.60, shares=50, cost=30.0), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Sell", price=0.60, shares=50, cost=30.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -149,14 +168,17 @@ def test_prior_year_sell_prevents_double_counting(self): def test_partial_prior_year_sell(self): """A partial sell in 2023 should only consume that many shares.""" trades = [ - _make_trade(timestamp=datetime(2023, 1, 15), type="Buy", - price=0.40, shares=100, cost=40.0), + _make_trade( + timestamp=datetime(2023, 1, 15), type="Buy", price=0.40, shares=100, cost=40.0 + ), # Sell 30 of 100 in 2023 - _make_trade(timestamp=datetime(2023, 6, 1), type="Sell", - price=0.70, shares=30, cost=21.0), + _make_trade( + timestamp=datetime(2023, 6, 1), type="Sell", 
price=0.70, shares=30, cost=21.0 + ), # Sell 70 in 2024 - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - price=0.60, shares=70, cost=42.0), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Sell", price=0.60, shares=70, cost=42.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -167,22 +189,46 @@ def test_partial_prior_year_sell(self): # 70 remaining shares at $0.40 cost assert tx["shares"] == pytest.approx(70.0) assert tx["cost_basis"] == pytest.approx(28.0) # 70 * $0.40 - assert tx["proceeds"] == pytest.approx(42.0) # 70 * $0.60 + assert tx["proceeds"] == pytest.approx(42.0) # 70 * $0.60 assert tx["gain_loss"] == pytest.approx(14.0) def test_multi_market_isolation(self): """Prior-year sells in market A should not affect market B's lots.""" trades = [ # Market A: buy and sell in 2023 - _make_trade(timestamp=datetime(2023, 1, 1), type="Buy", - market_slug="market-a", price=0.30, shares=10, cost=3.0), - _make_trade(timestamp=datetime(2023, 6, 1), type="Sell", - market_slug="market-a", price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2023, 1, 1), + type="Buy", + market_slug="market-a", + price=0.30, + shares=10, + cost=3.0, + ), + _make_trade( + timestamp=datetime(2023, 6, 1), + type="Sell", + market_slug="market-a", + price=0.80, + shares=10, + cost=8.0, + ), # Market B: buy in 2023, sell in 2024 - _make_trade(timestamp=datetime(2023, 1, 1), type="Buy", - market_slug="market-b", price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Sell", - market_slug="market-b", price=0.70, shares=10, cost=7.0), + _make_trade( + timestamp=datetime(2023, 1, 1), + type="Buy", + market_slug="market-b", + price=0.50, + shares=10, + cost=5.0, + ), + _make_trade( + timestamp=datetime(2024, 3, 1), + type="Sell", + market_slug="market-b", + price=0.70, + shares=10, + cost=7.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -199,16 
+245,23 @@ def test_multi_market_isolation(self): # with sub-second timestamps. Fix: use datetime(tax_year + 1, 1, 1) with <. # =========================================================================== + class TestYearBoundaryPrecision: """Year boundary should handle sub-second timestamps correctly.""" def test_trade_at_last_microsecond_of_year_included(self): """A trade at 23:59:59.999999 on Dec 31 should be included.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 12, 31, 23, 59, 59, 999999), - type="Sell", price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 12, 31, 23, 59, 59, 999999), + type="Sell", + price=0.80, + shares=10, + cost=8.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -217,10 +270,16 @@ def test_trade_at_last_microsecond_of_year_included(self): def test_trade_at_midnight_jan1_next_year_excluded(self): """A trade at exactly midnight Jan 1 next year should be excluded.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2025, 1, 1, 0, 0, 0), - type="Sell", price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2025, 1, 1, 0, 0, 0), + type="Sell", + price=0.80, + shares=10, + cost=8.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -229,10 +288,16 @@ def test_trade_at_midnight_jan1_next_year_excluded(self): def test_trade_at_midnight_jan1_same_year_included(self): """A trade at midnight Jan 1 of the tax year should be included.""" trades = [ - _make_trade(timestamp=datetime(2023, 6, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - 
_make_trade(timestamp=datetime(2024, 1, 1, 0, 0, 0), - type="Sell", price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2023, 6, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 1, 1, 0, 0, 0), + type="Sell", + price=0.80, + shares=10, + cost=8.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -243,16 +308,23 @@ def test_trade_at_midnight_jan1_same_year_included(self): # Holding period classification (short-term vs long-term) # =========================================================================== + class TestHoldingPeriodClassification: """Verify short-term vs long-term is correct at the 365-day boundary.""" def test_exactly_365_days_is_long_term(self): """Holding for exactly 365 days should be long-term.""" trades = [ - _make_trade(timestamp=datetime(2023, 3, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 2, 29), type="Sell", # 365 days later (2024 is leap year) - price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2023, 3, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 2, 29), + type="Sell", # 365 days later (2024 is leap year) + price=0.80, + shares=10, + cost=8.0, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -262,10 +334,16 @@ def test_exactly_365_days_is_long_term(self): def test_364_days_is_short_term(self): """Holding for 364 days should be short-term.""" trades = [ - _make_trade(timestamp=datetime(2023, 3, 2), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 2, 29), type="Sell", # 364 days - price=0.80, shares=10, cost=8.0), + _make_trade( + timestamp=datetime(2023, 3, 2), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 2, 29), + type="Sell", # 364 days + price=0.80, + shares=10, + cost=8.0, + ), ] result = 
calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -276,21 +354,24 @@ def test_short_and_long_term_separation(self): """Gains should be correctly split between short-term and long-term.""" trades = [ # Long-term lot: bought > 365 days before sell - _make_trade(timestamp=datetime(2022, 6, 1), type="Buy", - price=0.30, shares=10, cost=3.0), + _make_trade( + timestamp=datetime(2022, 6, 1), type="Buy", price=0.30, shares=10, cost=3.0 + ), # Short-term lot: bought < 365 days before sell - _make_trade(timestamp=datetime(2024, 1, 15), type="Buy", - price=0.60, shares=10, cost=6.0), + _make_trade( + timestamp=datetime(2024, 1, 15), type="Buy", price=0.60, shares=10, cost=6.0 + ), # Sell all 20 shares - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.80, shares=20, cost=16.0), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.80, shares=20, cost=16.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") assert result["transaction_count"] == 2 # FIFO: first 10 from 2022 (long-term), next 10 from 2024 (short-term) - assert result["long_term_gains"] == pytest.approx(5.0) # 10 * (0.80 - 0.30) + assert result["long_term_gains"] == pytest.approx(5.0) # 10 * (0.80 - 0.30) assert result["short_term_gains"] == pytest.approx(2.0) # 10 * (0.80 - 0.60) @@ -298,6 +379,7 @@ def test_short_and_long_term_separation(self): # FIFO / LIFO / Average cost basis method correctness # =========================================================================== + class TestCostBasisMethods: """Verify each cost basis method produces correct results.""" @@ -305,14 +387,18 @@ class TestCostBasisMethods: def three_lot_trades(self): """Three buys at different prices, one sell of all.""" return [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.30, shares=10, cost=3.0), - _make_trade(timestamp=datetime(2024, 2, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - 
_make_trade(timestamp=datetime(2024, 3, 1), type="Buy", - price=0.70, shares=10, cost=7.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.60, shares=30, cost=18.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.30, shares=10, cost=3.0 + ), + _make_trade( + timestamp=datetime(2024, 2, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Buy", price=0.70, shares=10, cost=7.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.60, shares=30, cost=18.0 + ), ] def test_fifo_matches_oldest_first(self, three_lot_trades): @@ -334,7 +420,9 @@ def test_lifo_matches_newest_first(self, three_lot_trades): assert result["net_gain_loss"] == pytest.approx(3.0) def test_average_uses_weighted_mean(self, three_lot_trades): - result = calculate_capital_gains(three_lot_trades, tax_year=2024, cost_basis_method="average") + result = calculate_capital_gains( + three_lot_trades, tax_year=2024, cost_basis_method="average" + ) # Average: (3 + 5 + 7) / 30 = $0.50 per share assert result["transaction_count"] == 1 # average creates one synthetic lot @@ -351,16 +439,27 @@ def test_invalid_method_raises(self): # Large-value accuracy (million-dollar scenarios) # =========================================================================== + class TestLargeValueAccuracy: """Verify calculations don't lose precision at high dollar amounts.""" def test_million_dollar_trade_precision(self): """$1M+ trades should maintain cent-level precision.""" trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.5500, shares=2_000_000, cost=1_100_000.00), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.5501, shares=2_000_000, cost=1_100_200.00), + _make_trade( + timestamp=datetime(2024, 1, 1), + type="Buy", + price=0.5500, + shares=2_000_000, + cost=1_100_000.00, + ), + _make_trade( + timestamp=datetime(2024, 6, 1), + type="Sell", + price=0.5501, + 
shares=2_000_000, + cost=1_100_200.00, + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -374,14 +473,24 @@ def test_many_small_trades_accumulate_correctly(self): """1000 small trades should sum correctly.""" trades = [] for i in range(1000): - trades.append(_make_trade( - timestamp=datetime(2024, 1, 1, i // 60, i % 60), - type="Buy", price=0.50, shares=1.0, cost=0.50, - )) - trades.append(_make_trade( - timestamp=datetime(2024, 6, 1), type="Sell", - price=0.60, shares=1000.0, cost=600.0, - )) + trades.append( + _make_trade( + timestamp=datetime(2024, 1, 1, i // 60, i % 60), + type="Buy", + price=0.50, + shares=1.0, + cost=0.50, + ) + ) + trades.append( + _make_trade( + timestamp=datetime(2024, 6, 1), + type="Sell", + price=0.60, + shares=1000.0, + cost=600.0, + ) + ) result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") diff --git a/tests/test_trader_critical.py b/tests/test_trader_critical.py index f050570..4654393 100644 --- a/tests/test_trader_critical.py +++ b/tests/test_trader_critical.py @@ -6,6 +6,7 @@ - Timestamp parsing failure visibility - Tax report diagnostics for unrecognized trade types """ + import json import math import pytest @@ -38,6 +39,7 @@ def _make_trade(**kwargs): # CSV/Excel exports: NaN/Inf sanitization (previously used vars(t)) # =========================================================================== + class TestCsvExportSanitization: """CSV export should sanitize NaN/Inf values via to_dict().""" @@ -45,7 +47,7 @@ def test_nan_pnl_in_csv(self, tmp_path): from prediction_analyzer.reporting.report_data import export_to_csv import pandas as pd - trades = [_make_trade(pnl=float('nan'), pnl_is_set=True)] + trades = [_make_trade(pnl=float("nan"), pnl_is_set=True)] outfile = str(tmp_path / "test.csv") export_to_csv(trades, filename=outfile) @@ -57,7 +59,7 @@ def test_inf_cost_in_csv(self, tmp_path): from prediction_analyzer.reporting.report_data import export_to_csv 
import pandas as pd - trades = [_make_trade(cost=float('inf'))] + trades = [_make_trade(cost=float("inf"))] outfile = str(tmp_path / "test.csv") export_to_csv(trades, filename=outfile) @@ -72,7 +74,7 @@ def test_nan_pnl_in_excel(self, tmp_path): from prediction_analyzer.reporting.report_data import export_to_excel import pandas as pd - trades = [_make_trade(pnl=float('nan'), pnl_is_set=True)] + trades = [_make_trade(pnl=float("nan"), pnl_is_set=True)] outfile = str(tmp_path / "test.xlsx") export_to_excel(trades, filename=outfile) @@ -84,6 +86,7 @@ def test_nan_pnl_in_excel(self, tmp_path): # Tax: Claim/Won/Loss trade types are taxable settlement events # =========================================================================== + class TestTaxClaimWonLossCoverage: """Tax report must handle Claim/Won/Loss as sell-like dispositions.""" @@ -92,10 +95,12 @@ def test_claim_type_generates_transaction(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.40, shares=100, cost=40.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Claim", - price=1.00, shares=100, cost=100.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.40, shares=100, cost=40.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Claim", price=1.00, shares=100, cost=100.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -111,10 +116,12 @@ def test_won_type_generates_transaction(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.30, shares=50, cost=15.0), - _make_trade(timestamp=datetime(2024, 9, 1), type="Won", - price=1.00, shares=50, cost=50.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.30, shares=50, cost=15.0 + ), + _make_trade( + timestamp=datetime(2024, 9, 1), type="Won", price=1.00, shares=50, cost=50.0 + ), ] 
result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -126,10 +133,12 @@ def test_loss_type_generates_transaction(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.70, shares=100, cost=70.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Loss", - price=0.00, shares=100, cost=0.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.70, shares=100, cost=70.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Loss", price=0.00, shares=100, cost=0.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -142,6 +151,7 @@ def test_loss_type_generates_transaction(self): # Tax: Unrecognized trade types are reported in diagnostics # =========================================================================== + class TestTaxDiagnostics: """Tax report should surface unrecognized trade types so traders notice gaps.""" @@ -149,12 +159,15 @@ def test_skipped_types_included_in_result(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 3, 1), type="Dividend", - price=0.0, shares=0, cost=1.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Rebate", - price=0.0, shares=0, cost=0.5), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 3, 1), type="Dividend", price=0.0, shares=0, cost=1.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Rebate", price=0.0, shares=0, cost=0.5 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -167,10 +180,12 @@ def test_no_skipped_types_when_all_recognized(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - 
_make_trade(timestamp=datetime(2024, 1, 1), type="Buy", - price=0.50, shares=10, cost=5.0), - _make_trade(timestamp=datetime(2024, 6, 1), type="Sell", - price=0.60, shares=10, cost=6.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Buy", price=0.50, shares=10, cost=5.0 + ), + _make_trade( + timestamp=datetime(2024, 6, 1), type="Sell", price=0.60, shares=10, cost=6.0 + ), ] result = calculate_capital_gains(trades, tax_year=2024, cost_basis_method="fifo") @@ -180,8 +195,9 @@ def test_warning_logged_for_skipped_types(self): from prediction_analyzer.tax import calculate_capital_gains trades = [ - _make_trade(timestamp=datetime(2024, 1, 1), type="Unknown Type", - price=0.0, shares=0, cost=0.0), + _make_trade( + timestamp=datetime(2024, 1, 1), type="Unknown Type", price=0.0, shares=0, cost=0.0 + ), ] with patch("prediction_analyzer.tax.logger") as mock_logger: @@ -194,12 +210,14 @@ def test_warning_logged_for_skipped_types(self): # Timestamp parsing: failures should be visible, not silent # =========================================================================== + class TestTimestampParsingVisibility: """Unparseable timestamps should log a warning, not silently return epoch.""" def test_unparseable_value_logs_warning(self, caplog): """An unparseable timestamp should produce a warning.""" import logging + with caplog.at_level(logging.WARNING, logger="prediction_analyzer.trade_loader"): result = _parse_timestamp("completely-invalid-timestamp-xyz") # Should still return epoch fallback @@ -210,6 +228,7 @@ def test_unparseable_value_logs_warning(self, caplog): def test_valid_timestamps_no_warning(self, caplog): """Valid timestamps should NOT produce warnings.""" import logging + with caplog.at_level(logging.WARNING, logger="prediction_analyzer.trade_loader"): _parse_timestamp("2024-06-15T12:00:00Z") _parse_timestamp(1704067200)