diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 03c178d..93b5b55 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,34 +7,84 @@ on: branches: [ main ] jobs: - lint-and-test: + lint: runs-on: ubuntu-latest - + steps: - uses: actions/checkout@v4 - + - name: Install uv uses: astral-sh/setup-uv@v4 with: version: "latest" - + - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.13" - + - name: Install dependencies run: make install - name: Run linting run: make lint - - name: Run tests with coverage - run: make coverage + test-servicekit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install dependencies + run: make install + + - name: Test servicekit + run: make test-servicekit + + - name: Upload servicekit coverage + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: servicekit + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }} + + test-chapkit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install dependencies + run: make install + + - name: Test chapkit + run: make test-chapkit - - name: Upload coverage to Codecov + - name: Upload chapkit coverage uses: codecov/codecov-action@v4 with: file: ./coverage.xml + flags: chapkit fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} diff --git a/Makefile b/Makefile index 397972d..b3e3409 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help install lint test coverage clean docker-build docker-run docs docs-serve docs-build +.PHONY: help install lint lint-servicekit lint-chapkit test test-servicekit test-chapkit coverage clean # ============================================================================== # Venv @@ -15,87 +15,66 @@ PYTHON := $(VENV_DIR)/bin/python help: @echo "Usage: make [target]" @echo "" - @echo "Targets:" - @echo " install Install dependencies" - @echo " lint Run linter and type checker" - @echo " test Run tests" - @echo " coverage Run tests with coverage reporting" - @echo " migrate Generate a new migration (use MSG='description')" - @echo " upgrade Apply pending migrations" - @echo " downgrade Revert last migration" - @echo " docs-serve Serve documentation locally with live reload" - @echo " docs-build Build documentation site" - @echo " docs Alias for docs-serve" - @echo " docker-build Build Docker image for examples" - @echo " docker-run Run example in Docker (use EXAMPLE='config_api')" - @echo " clean Clean up temporary files" + @echo "Monorepo Targets:" + @echo " install Install all dependencies" + @echo " lint Run linter and type checker on all packages" + @echo " lint-servicekit Lint servicekit only" + @echo " lint-chapkit Lint chapkit only" + @echo " test Run all tests" + @echo " test-servicekit Test servicekit only" + @echo " test-chapkit Test chapkit only" + @echo " coverage Run tests with coverage reporting" + @echo " clean Clean up temporary files" install: - @echo ">>> Installing dependencies" - @$(UV) sync --all-extras + @echo ">>> Installing workspace dependencies" + @$(UV) sync --all-packages lint: - @echo ">>> Running linter" + @echo ">>> Running linter on all packages" @$(UV) run ruff 
format . @$(UV) run ruff check . --fix - @echo ">>> Running type checker" - @$(UV) run mypy --exclude 'examples/old' src tests examples alembic + @echo ">>> Running type checkers" + @$(UV) run mypy packages/servicekit/src packages/chapkit/src @$(UV) run pyright +lint-servicekit: + @echo ">>> Linting servicekit" + @$(UV) run ruff format packages/servicekit + @$(UV) run ruff check packages/servicekit --fix + @$(UV) run mypy packages/servicekit/src + +lint-chapkit: + @echo ">>> Linting chapkit" + @$(UV) run ruff format packages/chapkit + @$(UV) run ruff check packages/chapkit --fix + @$(UV) run mypy packages/chapkit/src + test: - @echo ">>> Running tests" - @$(UV) run pytest -q + @echo ">>> Running all tests" + @$(UV) run pytest packages/servicekit/tests packages/chapkit/tests -q --ignore=packages/servicekit/tests/test_example_*.py + +test-servicekit: + @echo ">>> Testing servicekit" + @$(UV) run pytest packages/servicekit/tests -q --ignore=packages/servicekit/tests/test_example_artifact_api.py --ignore=packages/servicekit/tests/test_example_config_api.py --ignore=packages/servicekit/tests/test_example_config_artifact_api.py --ignore=packages/servicekit/tests/test_example_core_api.py --ignore=packages/servicekit/tests/test_example_core_cli.py --ignore=packages/servicekit/tests/test_example_custom_operations_api.py --ignore=packages/servicekit/tests/test_example_full_featured_api.py --ignore=packages/servicekit/tests/test_example_job_scheduler_api.py --ignore=packages/servicekit/tests/test_example_job_scheduler_sse_api.py --ignore=packages/servicekit/tests/test_example_library_usage_api.py --ignore=packages/servicekit/tests/test_example_monitoring_api.py --ignore=packages/servicekit/tests/test_example_task_execution_api.py + +test-chapkit: + @echo ">>> Testing chapkit" + @$(UV) run pytest packages/chapkit/tests -q coverage: @echo ">>> Running tests with coverage" - @$(UV) run coverage run -m pytest -q + @$(UV) run coverage run -m pytest packages/servicekit/tests packages/chapkit/tests -q @$(UV) run coverage report @$(UV) run coverage xml -migrate: - @echo ">>> Generating migration: $(MSG)" - @$(UV) run alembic revision --autogenerate -m "$(MSG)" - @echo ">>> Formatting migration file" - @$(UV) run ruff format alembic/versions - -upgrade: - @echo ">>> Applying pending migrations" - @$(UV) run alembic upgrade head - -downgrade: - @echo ">>> Reverting last migration" - @$(UV) run alembic downgrade -1 - -docs-serve: - @echo ">>> Serving documentation at http://127.0.0.1:8000" - @$(UV) run mkdocs serve - -docs-build: - @echo ">>> Building documentation site" - @$(UV) run mkdocs build - -docs: docs-serve - -docker-build: - @echo ">>> Building Docker image" - @docker build -t chapkit-examples . - -docker-run: - @echo ">>> Running Docker container with example: $(EXAMPLE)" - @if [ -z "$(EXAMPLE)" ]; then \ - echo "Error: EXAMPLE not specified. Usage: make docker-run EXAMPLE=config_api"; \ - exit 1; \ - fi - @docker run -it --rm -p 8000:8000 \ - -e EXAMPLE_MODULE=examples.$(EXAMPLE):app \ - chapkit-examples - clean: @echo ">>> Cleaning up" @find . -type f -name "*.pyc" -delete @find . -type d -name "__pycache__" -delete @find . -type d -name ".pytest_cache" -delete @find . -type d -name ".ruff_cache" -delete + @find . 
-type d -name ".mypy_cache" -delete # ============================================================================== # Default diff --git a/designs/library-split.md b/designs/library-split.md new file mode 100644 index 0000000..075088f --- /dev/null +++ b/designs/library-split.md @@ -0,0 +1,436 @@ +# Library Split Plan: servicekit (core) + chapkit (ML) + +## Executive Summary + +**Goal:** Split `chapkit` into two packages - `servicekit` (core framework) and `chapkit` (ML features). + +**Approach:** Monorepo with `packages/servicekit` and `packages/chapkit`, using UV workspace. + +**Key Insight:** Alembic migrations handled via `alembic_dir` parameter - servicekit keeps existing migrations, chapkit gets empty migration directory for future ML tables. No data migration needed for existing users. + +**Timeline:** 4 weeks (branch setup → testing → merge → publish) + +**Versioning:** Both packages start at v0.1.0 + +**Breaking Change:** Import paths only. Database schemas unchanged. + +--- + +## Quick Reference + +### Repository Structure +``` +chapkit2/ +├── packages/ +│ ├── servicekit/ # Core framework (no ML dependencies) +│ │ ├── src/servicekit/ # core/, modules/{config,artifact,task}, api/ +│ │ ├── alembic/ # Existing migrations (config/artifact/task tables) +│ │ └── tests/ +│ └── chapkit/ # ML features (depends on servicekit) +│ ├── src/chapkit/ # ml/, api/ (ML extensions) +│ ├── alembic/ # Empty (future ML tables) +│ └── tests/ +├── servicekit_examples/ # Core examples (separate from packages) +├── chapkit_examples/ # ML examples (separate from packages) +├── pyproject.toml # UV workspace config +└── Makefile # Monorepo targets +``` + +### Import Changes +```python +# Before (v0.1.x) +from chapkit import SqliteDatabase, BaseConfig +from chapkit.modules.ml import MLManager + +# After (v0.1.0) +from servicekit import SqliteDatabase, BaseConfig +from chapkit.ml import MLManager +``` + +### Alembic Strategy +- **servicekit**: Owns existing migrations (config/artifact/task tables) +- **chapkit**: Empty alembic directory, merges servicekit metadata +- **Existing users**: No database migration needed - only update imports + +--- + +## Implementation Phases + +### Phase 1: Branch & Monorepo Setup + +Create feature branch and monorepo structure: +```bash +git checkout -b feat/split-servicekit +mkdir -p packages/servicekit packages/chapkit +``` + +Configure UV workspace in root `pyproject.toml`: +```toml +[tool.uv.workspace] +members = ["packages/servicekit", "packages/chapkit"] +``` + +### Phase 2: Create servicekit Package + +**Copy infrastructure:** +- `src/chapkit/core/` → `packages/servicekit/src/servicekit/core/` +- `src/chapkit/modules/{config,artifact,task}/` → `packages/servicekit/src/servicekit/modules/` +- `src/chapkit/api/` → `packages/servicekit/src/servicekit/api/` (exclude ML parts) +- `alembic/` → `packages/servicekit/alembic/` + +**Exclude:** +- `modules/ml/` (stays in chapkit) +- ML-specific code: `MLServiceBuilder`, `MLServiceInfo`, `AssessedStatus` + +**Update imports:** +- Find/replace: `chapkit.core` → `servicekit.core` +- Update `alembic/env.py`: `from servicekit.core.models import Base` +- Update migration files: `chapkit.core.types` → `servicekit.core.types` + +**Create `pyproject.toml`:** +- Core dependencies only (no pandas/scikit-learn) +- Version: `0.1.0` + +**Fix bundled migration path in `database.py`:** +```python +# Old: str(Path(__file__).parent.parent.parent.parent / "alembic") +# New: +import servicekit +pkg_path = Path(servicekit.__file__).parent.parent 
+alembic_cfg.set_main_option("script_location", str(pkg_path / "alembic")) +``` + +### Phase 3: Create chapkit Package + +**Move ML module:** +- `src/chapkit/modules/ml/` → `packages/chapkit/src/chapkit/ml/` (top-level) + +**Create ML API layer:** +- Extract ML parts from `api/service_builder.py` to `packages/chapkit/src/chapkit/api/` +- Keep: `ServiceBuilder.with_ml()`, `MLServiceBuilder`, `MLServiceInfo` + +**Update imports:** +- Find/replace throughout chapkit: + - `chapkit.core` → `servicekit.core` + - `chapkit.modules.{config,artifact,task}` → `servicekit.modules.{config,artifact,task}` + - `chapkit.modules.ml` → `chapkit.ml` + +**Create Alembic environment:** +- Create `alembic/env.py` that merges servicekit and chapkit metadata +- Initially empty `versions/` directory + +**Create `pyproject.toml`:** +- Dependency: `servicekit>=0.1.0` +- ML dependencies: `pandas`, `scikit-learn` +- Version: `0.1.0` + +### Phase 4: Workspace Configuration + +**Root `pyproject.toml`:** +- Workspace members +- Shared tool configs (ruff, mypy, pytest, pyright) + +**Root `Makefile`:** +- Targets: `install`, `lint`, `test`, `coverage` +- Per-package targets: `lint-servicekit`, `test-chapkit`, etc. + +**Root `README.md`:** +- Document monorepo structure +- Link to package-specific READMEs + +### Phase 5: Documentation + +**Create package READMEs:** +- `packages/servicekit/README.md`: Core features, installation, quick start +- `packages/chapkit/README.md`: ML features, servicekit dependency, quick start + +**Create/update CLAUDE.md:** +- `packages/servicekit/CLAUDE.md`: Core architecture, no ML +- `packages/chapkit/CLAUDE.md`: ML-specific guidance +- Root `CLAUDE.md`: Monorepo overview, link to packages + +**Update examples:** +- Copy core examples to `servicekit_examples/` +- Copy ML examples to `chapkit_examples/` +- Update all imports +- Keep examples at root level (separate from packages for cleaner package structure) + +### Phase 6: Testing & Validation + +**Unit tests:** +```bash +make test-servicekit # Core tests, no ML deps +make test-chapkit # ML tests with servicekit +make test # All tests +``` + +**Migration tests:** +```bash +cd packages/servicekit +uv run alembic upgrade head # Verify tables created +``` + +**Type checking:** +```bash +make lint # Run ruff, mypy, pyright on both packages +``` + +**Example validation:** +```bash +uv run fastapi dev servicekit_examples/config_api.py +uv run fastapi dev chapkit_examples/ml_functional.py +``` + +### Phase 7: CI/CD Updates + +**GitHub Actions:** +- Separate jobs for `test-servicekit` and `test-chapkit` +- Shared lint job +- Per-package coverage reports + +**Publishing:** +```bash +cd packages/servicekit && uv build && uv publish +cd packages/chapkit && uv build && uv publish +``` + +### Phase 8: Review & Merge + +**Pre-merge checklist:** +- All tests passing +- No linting/type errors +- Documentation complete +- Examples working +- Alembic migrations functional + +**Pull request:** +- Title: `feat: split into servicekit (core) and chapkit (ML) packages` +- Include migration guide for users +- Document breaking changes (import paths only) + +**Merge:** +```bash +git checkout main +git merge refactor/library-split +git tag servicekit-v0.1.0 chapkit-v0.1.0 +git push --tags +``` + +### Phase 9: Post-Merge + +- Publish both packages to PyPI +- Update GitHub repo descriptions +- Create release notes +- Monitor for upgrade issues + +--- + +## Alembic Migration Strategy + +### servicekit Migrations + +**Location:** `packages/servicekit/alembic/versions/` + 
+**Tables:** `configs`, `artifacts`, `tasks`, `config_artifacts` + +**Migration:** `20251010_0927_4d869b5fb06e_initial_schema.py` (existing) + +**Auto-run:** Via `Database.init()` with bundled path + +### chapkit Migrations + +**Location:** `packages/chapkit/alembic/versions/` (empty initially) + +**Future tables:** ML-specific models (when needed) + +**Metadata merging:** `alembic/env.py` combines servicekit + chapkit Base metadata + +**User opt-in:** Specify `with_migrations(alembic_dir=chapkit_path)` if ML tables needed + +### User Upgrade Path + +**Existing databases (chapkit v0.1.x → servicekit+chapkit v0.1.0):** +```python +# Step 1: Update imports +from servicekit import SqliteDatabaseBuilder # was: chapkit +from chapkit.ml import MLManager # was: chapkit.modules.ml + +# Step 2: Run application +db = SqliteDatabaseBuilder.from_file("app.db").build() +await db.init() # Uses servicekit migrations - same tables, no changes needed +``` + +**No data migration required** - table schemas identical, only import paths change. + +--- + +## Dependencies + +### servicekit +- Core: `sqlalchemy`, `aiosqlite`, `alembic`, `pydantic`, `python-ulid` +- FastAPI: `fastapi`, `gunicorn` +- Monitoring: `opentelemetry-*`, `structlog` +- Misc: `geojson-pydantic` +- **Excludes:** pandas, scikit-learn + +### chapkit +- Required: `servicekit>=0.1.0` +- ML: `pandas>=2.3.3`, `scikit-learn>=1.7.2` + +--- + +## Key Decisions + +### Why Monorepo? +- Atomic refactoring across both packages +- Shared tooling (lint, test, CI) +- Easy integration testing +- Can extract to separate repos later + +### Why Examples Outside Packages? +- Cleaner package structure (packages only contain library code) +- Examples reference both servicekit and chapkit (some ML examples need both) +- Easier to run examples from monorepo root +- Examples don't need to be published to PyPI with packages + +### Why Keep Migrations in servicekit? +- Existing users have these tables +- No data migration needed +- Clear ownership: servicekit owns core tables +- chapkit can add ML tables independently via separate migrations + +### Why Top-Level `chapkit.ml`? 
+- Clear namespace: ML is obviously in chapkit +- Not a "module" in the same sense as config/artifact/task +- Shorter imports: `from chapkit.ml import ...` + +--- + +## Risk Mitigation + +**Import errors during transition:** +- Comprehensive find/replace with regex +- Type checking validates all imports +- Test suite catches missing imports + +**Alembic conflicts:** +- Keep servicekit migrations unchanged (same revision IDs) +- Test both migration paths +- Document multi-migration scenarios + +**Breaking changes for users:** +- Clear migration guide in PR and release notes +- Version v0.1.0 for both packages (fresh start post-split) +- Data compatibility maintained + +**Circular dependencies:** +- Strict dependency direction: chapkit → servicekit only +- Never import ML code in servicekit + +--- + +## Success Criteria + +**Functional:** +- servicekit installs without ML dependencies +- chapkit depends on and imports from servicekit +- Existing databases work unchanged +- All tests pass in both packages +- Examples work with new imports + +**Non-functional:** +- Clear migration documentation +- Type checking passes across packages +- CI/CD works for both packages +- Publishing process defined + +**User experience:** +- Intuitive import paths +- Simple installation (`pip install servicekit` or `pip install chapkit`) +- Clear error messages +- Smooth upgrade path + +--- + +## Timeline + +**Week 1:** Phases 1-3 (setup, create packages, move code) +**Week 2:** Phases 4-5 (workspace config, documentation) +**Week 3:** Phase 6 (testing, validation) +**Week 4:** Phases 7-9 (CI/CD, review, merge, publish) + +--- + +## Appendix: Detailed File Operations + +### servicekit Package Creation + +**Copy operations:** +```bash +cp -r src/chapkit/core packages/servicekit/src/servicekit/ +cp -r src/chapkit/modules/{config,artifact,task} packages/servicekit/src/servicekit/modules/ +cp -r src/chapkit/api packages/servicekit/src/servicekit/ # Remove ML parts manually +cp -r alembic packages/servicekit/ +cp alembic.ini packages/servicekit/ +cp -r tests packages/servicekit/ # Remove *ml* tests +``` + +**Import updates (regex find/replace):** +```regex +from chapkit\.core → from servicekit.core +from chapkit\.modules → from servicekit.modules +import chapkit\.core → import servicekit.core +import chapkit\.modules → import servicekit.modules +``` + +**Files requiring manual edits:** +- `packages/servicekit/alembic/env.py` (line 13) +- `packages/servicekit/alembic/versions/*.py` (import statements) +- `packages/servicekit/src/servicekit/core/database.py` (bundled path logic) +- `packages/servicekit/src/servicekit/api/service_builder.py` (remove ML classes) + +### chapkit Package Creation + +**Copy operations:** +```bash +cp -r src/chapkit/modules/ml packages/chapkit/src/chapkit/ml +# Extract ML parts from api/service_builder.py to packages/chapkit/src/chapkit/api/ +cp tests/test_*ml* packages/chapkit/tests/ +``` + +**Import updates (regex find/replace):** +```regex +from chapkit\.core → from servicekit.core +from chapkit\.modules\.config → from servicekit.modules.config +from chapkit\.modules\.artifact → from servicekit.modules.artifact +from chapkit\.modules\.task → from servicekit.modules.task +from chapkit\.modules\.ml → from chapkit.ml +``` + +**Create alembic environment:** +```bash +mkdir -p packages/chapkit/alembic/versions +# Create env.py with metadata merging +# Create alembic.ini +``` + +--- + +## Decisions Made + +1. **Package name:** `servicekit` (confirmed) +2. 
**Versioning:** Both packages at v0.1.0 (fresh start) + +## Questions to Resolve + +1. Long-term: keep monorepo or extract to separate repos? +2. Publishing: independent or synchronized releases? +3. Backward compatibility: deprecation warnings for old imports? + +--- + +## Next Steps + +1. Answer remaining questions +2. Begin Phase 1: Monorepo setup +3. Iterate based on discoveries during implementation diff --git a/examples/ml_class.py b/examples/ml_class.py index a2efbf4..6c15504 100644 --- a/examples/ml_class.py +++ b/examples/ml_class.py @@ -87,7 +87,7 @@ async def on_train( # Feature preprocessing if weather_config.normalize_features: self.scaler = StandardScaler() - X_scaled = self.scaler.fit_transform(X) + X_scaled = self.scaler.fit_transform(X) # pyright: ignore[reportOptionalMemberAccess] log.info( "features_normalized", mean=self.scaler.mean_.tolist(), # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] diff --git a/packages/chapkit/README.md b/packages/chapkit/README.md new file mode 100644 index 0000000..0b5c0d8 --- /dev/null +++ b/packages/chapkit/README.md @@ -0,0 +1,13 @@ +# chapkit + +ML-focused library built on servicekit for training and prediction workflows. + +## Installation + +```bash +pip install chapkit +``` + +## Quick Start + +Coming soon - ML module implementation in progress. diff --git a/packages/chapkit/pyproject.toml b/packages/chapkit/pyproject.toml new file mode 100644 index 0000000..456f530 --- /dev/null +++ b/packages/chapkit/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "chapkit" +version = "0.1.0" +description = "ML-focused library built on servicekit for training and prediction workflows" +readme = "README.md" +authors = [{ name = "Morten Hansen", email = "morten@winterop.com" }] +license = { text = "AGPL-3.0-or-later" } +requires-python = ">=3.13" +keywords = [ + "machine-learning", + "mlops", + "fastapi", + "sqlalchemy", + "scikit-learn", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.13", + "Framework :: FastAPI", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "servicekit>=0.1.0", + "pandas>=2.2.0", + "scikit-learn>=1.7.0", +] + +[tool.uv.sources] +servicekit = { workspace = true } + +[project.urls] +Homepage = "https://github.com/winterop-com/chapkit" +Repository = "https://github.com/winterop-com/chapkit" +Issues = "https://github.com/winterop-com/chapkit/issues" +Documentation = "https://github.com/winterop-com/chapkit#readme" + +[build-system] +requires = ["uv_build>=0.9.0,<0.10.0"] +build-backend = "uv_build" diff --git a/packages/chapkit/src/chapkit/__init__.py b/packages/chapkit/src/chapkit/__init__.py new file mode 100644 index 0000000..b2d7b92 --- /dev/null +++ b/packages/chapkit/src/chapkit/__init__.py @@ -0,0 +1,5 @@ +"""ML-focused library built on servicekit for training and prediction workflows.""" + +__version__ = "0.1.0" + +__all__ = ["__version__"] diff --git a/packages/chapkit/tests/__init__.py b/packages/chapkit/tests/__init__.py new file mode 100644 index 0000000..8c9ecc7 --- /dev/null +++ b/packages/chapkit/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for chapkit package.""" diff --git a/packages/servicekit/README.md b/packages/servicekit/README.md new file mode 100644 index 
0000000..4b6aa9d --- /dev/null +++ b/packages/servicekit/README.md @@ -0,0 +1,36 @@ +# servicekit + +Async SQLAlchemy database library for Python 3.13+ with FastAPI integration. + +## Features + +- Async SQLAlchemy 2.0+ with SQLite/aiosqlite +- FastAPI integration with vertical slice architecture +- Built-in modules: Config, Artifact, Task +- Automatic Alembic migrations +- Repository and Manager patterns +- Type-safe with Pydantic schemas +- ULID-based entity identifiers + +## Installation + +```bash +pip install servicekit +``` + +## Quick Start + +```python +from servicekit.api import ServiceBuilder, ServiceInfo + +app = ( + ServiceBuilder(info=ServiceInfo(display_name="My Service")) + .with_health() + .with_config() + .build() +) +``` + +## Documentation + +See the main repository documentation for detailed usage information. diff --git a/packages/servicekit/alembic.ini b/packages/servicekit/alembic.ini new file mode 100644 index 0000000..a65a26e --- /dev/null +++ b/packages/servicekit/alembic.ini @@ -0,0 +1,74 @@ +# Alembic configuration file for chapkit + +[alembic] +# Path to migration scripts +script_location = alembic + +# Template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s + +# Timezone for migration file timestamps +timezone = UTC + +# Max length of characters to apply to the "slug" field +truncate_slug_length = 40 + +# Set to 'true' to run the environment during the 'revision' command +# revision_environment = false + +# Set to 'true' to allow .pyc and .pyo files without a source .py file +# sourceless = false + +# Version location specification +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# Version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 
+ +# The output encoding used when revision files are written from script.py.mako +# output_encoding = utf-8 + +# Placeholder database URL - actual URL will be provided programmatically +sqlalchemy.url = sqlite+aiosqlite:///./app.db + +[post_write_hooks] +# Post-write hooks disabled - format migrations manually with: make lint + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/packages/servicekit/alembic/env.py b/packages/servicekit/alembic/env.py new file mode 100644 index 0000000..df42ebe --- /dev/null +++ b/packages/servicekit/alembic/env.py @@ -0,0 +1,76 @@ +"""Alembic environment configuration for async SQLAlchemy migrations.""" + +import asyncio +from logging.config import fileConfig + +# Import the Base metadata from chapkit models +from servicekit.core.models import Base +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from alembic import context + +# This is the Alembic Config object +config = context.config + +# Interpret the config file for Python logging (if present) +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Set target metadata for 'autogenerate' support +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode (generates SQL scripts).""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + """Run migrations with a connection.""" + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in async mode using async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode (connects to database).""" + # Since we're being called from Database.init() via run_in_executor, + # we're in a separate thread and can safely create a new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(run_async_migrations()) + finally: + loop.close() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/packages/servicekit/alembic/script.py.mako b/packages/servicekit/alembic/script.py.mako new file mode 100644 index 0000000..c54d11c --- /dev/null +++ b/packages/servicekit/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} 
+ +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + """Apply database schema changes.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Revert database schema changes.""" + ${downgrades if downgrades else "pass"} diff --git a/packages/servicekit/alembic/versions/20251010_0927_4d869b5fb06e_initial_schema.py b/packages/servicekit/alembic/versions/20251010_0927_4d869b5fb06e_initial_schema.py new file mode 100644 index 0000000..cfc71ed --- /dev/null +++ b/packages/servicekit/alembic/versions/20251010_0927_4d869b5fb06e_initial_schema.py @@ -0,0 +1,78 @@ +"""Initial schema. + +Revision ID: 4d869b5fb06e +Revises: +Create Date: 2025-10-10 09:27:01.866482+00:00 + +""" + +import servicekit.core.types +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4d869b5fb06e" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Apply database schema changes.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "artifacts", + sa.Column("parent_id", servicekit.core.types.ULIDType(length=26), nullable=True), + sa.Column("data", sa.PickleType(), nullable=False), + sa.Column("level", sa.Integer(), nullable=False), + sa.Column("id", servicekit.core.types.ULIDType(length=26), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.ForeignKeyConstraint(["parent_id"], ["artifacts.id"], ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_artifacts_level"), "artifacts", ["level"], unique=False) + op.create_index(op.f("ix_artifacts_parent_id"), "artifacts", ["parent_id"], unique=False) + op.create_table( + "configs", + sa.Column("name", sa.String(), nullable=False), + sa.Column("data", sa.JSON(), nullable=False), + sa.Column("id", servicekit.core.types.ULIDType(length=26), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_configs_name"), "configs", ["name"], unique=False) + op.create_table( + "config_artifacts", + sa.Column("config_id", servicekit.core.types.ULIDType(length=26), nullable=False), + sa.Column("artifact_id", servicekit.core.types.ULIDType(length=26), nullable=False), + sa.ForeignKeyConstraint(["artifact_id"], ["artifacts.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["config_id"], ["configs.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("config_id", "artifact_id"), + sa.UniqueConstraint("artifact_id"), + sa.UniqueConstraint("artifact_id", name="uq_artifact_id"), + ) + op.create_table( + "tasks", + sa.Column("command", sa.Text(), nullable=False), + sa.Column("id", servicekit.core.types.ULIDType(length=26), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + 
sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Revert database schema changes.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("tasks") + op.drop_table("config_artifacts") + op.drop_index(op.f("ix_configs_name"), table_name="configs") + op.drop_table("configs") + op.drop_index(op.f("ix_artifacts_parent_id"), table_name="artifacts") + op.drop_index(op.f("ix_artifacts_level"), table_name="artifacts") + op.drop_table("artifacts") + # ### end Alembic commands ### diff --git a/packages/servicekit/pyproject.toml b/packages/servicekit/pyproject.toml new file mode 100644 index 0000000..6e74210 --- /dev/null +++ b/packages/servicekit/pyproject.toml @@ -0,0 +1,65 @@ +[project] +name = "servicekit" +version = "0.1.0" +description = "Async SQLAlchemy database library for Python 3.13+ with FastAPI integration" +readme = "README.md" +authors = [{ name = "Morten Hansen", email = "morten@winterop.com" }] +license = { text = "AGPL-3.0-or-later" } +requires-python = ">=3.13" +keywords = [ + "fastapi", + "sqlalchemy", + "async", + "database", + "rest-api", + "crud", + "vertical-slice", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.13", + "Framework :: FastAPI", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Database", +] +dependencies = [ + "aiosqlite>=0.21.0", + "alembic>=1.17.0", + "fastapi[standard]>=0.119.0", + "geojson-pydantic>=2.1.0", + "gunicorn>=23.0.0", + "opentelemetry-api>=1.37.0", + "opentelemetry-exporter-prometheus>=0.58b0", + "opentelemetry-instrumentation-fastapi>=0.58b0", + "opentelemetry-instrumentation-sqlalchemy>=0.58b0", + "opentelemetry-sdk>=1.37.0", + "pandas>=2.2.0", + "pydantic>=2.12.0", + "python-ulid>=3.1.0", + "sqlalchemy[asyncio]>=2.0.43", + "structlog>=24.4.0", +] + +[project.urls] +Homepage = "https://github.com/winterop-com/servicekit" +Repository = "https://github.com/winterop-com/servicekit" +Issues = "https://github.com/winterop-com/servicekit/issues" +Documentation = "https://github.com/winterop-com/servicekit#readme" + +[build-system] +requires = ["uv_build>=0.9.0,<0.10.0"] +build-backend = "uv_build" + +[dependency-groups] +dev = [ + "coverage>=7.10.7", + "mypy>=1.18.2", + "pyright>=1.1.406", + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pytest-cov>=7.0.0", + "ruff>=0.14.0", +] diff --git a/packages/servicekit/src/servicekit/__init__.py b/packages/servicekit/src/servicekit/__init__.py new file mode 100644 index 0000000..91f7928 --- /dev/null +++ b/packages/servicekit/src/servicekit/__init__.py @@ -0,0 +1,84 @@ +"""Servicekit - async SQLAlchemy database library with FastAPI integration.""" + +# Core framework +from servicekit.core import ( + Base, + BaseManager, + BaseRepository, + Database, + Entity, + EntityIn, + EntityOut, + Manager, + Repository, + SqliteDatabase, + SqliteDatabaseBuilder, + ULIDType, +) + +# Artifact feature +from servicekit.modules.artifact import ( + Artifact, + ArtifactHierarchy, + ArtifactIn, + ArtifactManager, + ArtifactOut, + ArtifactRepository, + ArtifactTreeNode, + PandasDataFrame, +) + +# Config feature +from servicekit.modules.config import ( + BaseConfig, + Config, + ConfigIn, + ConfigManager, + ConfigOut, + ConfigRepository, +) + +# Task feature +from servicekit.modules.task import Task, TaskIn, 
TaskManager, TaskOut, TaskRepository + +__version__ = "0.1.0" + +__all__ = [ + # Core framework + "Database", + "SqliteDatabase", + "SqliteDatabaseBuilder", + "Repository", + "BaseRepository", + "Manager", + "BaseManager", + "Base", + "Entity", + "ULIDType", + "EntityIn", + "EntityOut", + # Config feature + "BaseConfig", + "Config", + "ConfigIn", + "ConfigOut", + "ConfigRepository", + "ConfigManager", + # Artifact feature + "Artifact", + "ArtifactHierarchy", + "ArtifactIn", + "ArtifactOut", + "ArtifactTreeNode", + "PandasDataFrame", + "ArtifactRepository", + "ArtifactManager", + # Task feature + "Task", + "TaskIn", + "TaskOut", + "TaskRepository", + "TaskManager", + # Version + "__version__", +] diff --git a/packages/servicekit/src/servicekit/api/__init__.py b/packages/servicekit/src/servicekit/api/__init__.py new file mode 100644 index 0000000..bb83958 --- /dev/null +++ b/packages/servicekit/src/servicekit/api/__init__.py @@ -0,0 +1,65 @@ +"""FastAPI routers and related presentation logic.""" + +from servicekit.core.api import CrudPermissions, CrudRouter, Router +from servicekit.core.api.middleware import ( + add_error_handlers, + add_logging_middleware, + database_error_handler, + validation_error_handler, +) +from servicekit.core.api.routers import HealthRouter, HealthState, HealthStatus, JobRouter, SystemInfo, SystemRouter +from servicekit.core.api.service_builder import ServiceInfo +from servicekit.core.api.utilities import build_location_url, run_app +from servicekit.core.logging import ( + add_request_context, + clear_request_context, + configure_logging, + get_logger, + reset_request_context, +) +from servicekit.modules.artifact import ArtifactRouter +from servicekit.modules.config import ConfigRouter +from servicekit.modules.task import TaskRouter + +from .dependencies import get_artifact_manager, get_config_manager +from .service_builder import AssessedStatus, MLServiceBuilder, MLServiceInfo, ServiceBuilder + +__all__ = [ + # Base classes + "Router", + "CrudRouter", + "CrudPermissions", + # Routers + "HealthRouter", + "HealthStatus", + "HealthState", + "JobRouter", + "SystemRouter", + "SystemInfo", + "ConfigRouter", + "ArtifactRouter", + "TaskRouter", + # Dependencies + "get_config_manager", + "get_artifact_manager", + # Middleware + "add_error_handlers", + "add_logging_middleware", + "database_error_handler", + "validation_error_handler", + # Logging + "configure_logging", + "get_logger", + "add_request_context", + "clear_request_context", + "reset_request_context", + # Builders + "ServiceBuilder", + "MLServiceBuilder", + "ServiceInfo", + "MLServiceInfo", + "AssessedStatus", + # Utilities + "build_location_url", + "run_app", +] diff --git a/packages/servicekit/src/servicekit/api/dependencies.py b/packages/servicekit/src/servicekit/api/dependencies.py new file mode 100644 index 0000000..70f085b --- /dev/null +++ b/packages/servicekit/src/servicekit/api/dependencies.py @@ -0,0 +1,61 @@ +"""Feature-specific FastAPI dependency injection for managers.""" + +from typing import Annotated + +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from servicekit.core.api.dependencies import get_database, get_scheduler, get_session +from servicekit.modules.artifact import ArtifactManager, ArtifactRepository +from servicekit.modules.config import BaseConfig, ConfigManager, ConfigRepository +from servicekit.modules.ml import MLManager +from servicekit.modules.task import TaskManager, TaskRepository + + +async def get_config_manager(session: Annotated[AsyncSession, 
Depends(get_session)]) -> ConfigManager[BaseConfig]: + """Get a config manager instance for dependency injection.""" + repo = ConfigRepository(session) + return ConfigManager[BaseConfig](repo, BaseConfig) + + +async def get_artifact_manager(session: Annotated[AsyncSession, Depends(get_session)]) -> ArtifactManager: + """Get an artifact manager instance for dependency injection.""" + artifact_repo = ArtifactRepository(session) + config_repo = ConfigRepository(session) + return ArtifactManager(artifact_repo, config_repo=config_repo) + + +async def get_task_manager( + session: Annotated[AsyncSession, Depends(get_session)], + artifact_manager: Annotated[ArtifactManager, Depends(get_artifact_manager)], +) -> TaskManager: + """Get a task manager instance for dependency injection.""" + from servicekit.core import Database + from servicekit.core.scheduler import JobScheduler + + repo = TaskRepository(session) + + # Get scheduler if available + scheduler: JobScheduler | None + try: + scheduler = get_scheduler() + except RuntimeError: + scheduler = None + + # Get database if available + database: Database | None + try: + database = get_database() + except RuntimeError: + database = None + + return TaskManager(repo, scheduler, database, artifact_manager) + + +async def get_ml_manager() -> MLManager: + """Get an ML manager instance for dependency injection. + + Note: This is a placeholder. The actual dependency is built by ServiceBuilder + with the runner in closure, then overridden via app.dependency_overrides. + """ + raise RuntimeError("ML manager dependency not configured. Use ServiceBuilder.with_ml() to enable ML operations.") diff --git a/packages/servicekit/src/servicekit/api/service_builder.py b/packages/servicekit/src/servicekit/api/service_builder.py new file mode 100644 index 0000000..a3f6c56 --- /dev/null +++ b/packages/servicekit/src/servicekit/api/service_builder.py @@ -0,0 +1,396 @@ +"""Service builder with module integration (config, artifact, task).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Any, Callable, Coroutine, List, Self + +from fastapi import Depends, FastAPI +from pydantic import EmailStr, HttpUrl +from sqlalchemy.ext.asyncio import AsyncSession + +from servicekit.core.api.crud import CrudPermissions +from servicekit.core.api.dependencies import get_database, get_scheduler, get_session +from servicekit.core.api.service_builder import BaseServiceBuilder, ServiceInfo +from servicekit.modules.artifact import ( + ArtifactHierarchy, + ArtifactIn, + ArtifactManager, + ArtifactOut, + ArtifactRepository, + ArtifactRouter, +) +from servicekit.modules.config import BaseConfig, ConfigIn, ConfigManager, ConfigOut, ConfigRepository, ConfigRouter +from servicekit.modules.ml import MLManager, MLRouter, ModelRunnerProtocol +from servicekit.modules.task import TaskIn, TaskManager, TaskOut, TaskRepository, TaskRouter + +from .dependencies import get_artifact_manager as default_get_artifact_manager +from .dependencies import get_config_manager as default_get_config_manager +from .dependencies import get_ml_manager as default_get_ml_manager +from .dependencies import get_task_manager as default_get_task_manager + +# Type alias for dependency factory functions +type DependencyFactory = Callable[..., Coroutine[Any, Any, Any]] + + +class AssessedStatus(StrEnum): + """Status indicating the maturity and validation level of an ML service.""" + + gray = "gray" # Not intended for use, deprecated, or meant for legacy 
use only + red = "red" # Highly experimental prototype - not validated, only for early experimentation + orange = "orange" # Shows promise on limited data, needs manual configuration and careful evaluation + yellow = "yellow" # Ready for more rigorous testing + green = "green" # Validated and ready for production use + + +class MLServiceInfo(ServiceInfo): + """Extended service metadata for ML services with author, organization, and assessment info.""" + + author: str | None = None + author_note: str | None = None + author_assessed_status: AssessedStatus | None = None + contact_email: EmailStr | None = None + organization: str | None = None + organization_logo_url: str | HttpUrl | None = None + citation_info: str | None = None + + +@dataclass(slots=True) +class _ConfigOptions: + """Internal config options for ServiceBuilder.""" + + schema: type[BaseConfig] + prefix: str = "/api/v1/configs" + tags: List[str] = field(default_factory=lambda: ["Config"]) + permissions: CrudPermissions = field(default_factory=CrudPermissions) + + +@dataclass(slots=True) +class _ArtifactOptions: + """Internal artifact options for ServiceBuilder.""" + + hierarchy: ArtifactHierarchy + prefix: str = "/api/v1/artifacts" + tags: List[str] = field(default_factory=lambda: ["Artifacts"]) + enable_config_linking: bool = False + permissions: CrudPermissions = field(default_factory=CrudPermissions) + + +@dataclass(slots=True) +class _TaskOptions: + """Internal task options for ServiceBuilder.""" + + prefix: str = "/api/v1/tasks" + tags: List[str] = field(default_factory=lambda: ["Tasks"]) + permissions: CrudPermissions = field(default_factory=CrudPermissions) + + +@dataclass(slots=True) +class _MLOptions: + """Internal ML options for ServiceBuilder.""" + + runner: ModelRunnerProtocol + prefix: str = "/api/v1/ml" + tags: List[str] = field(default_factory=lambda: ["ML"]) + + +class ServiceBuilder(BaseServiceBuilder): + """Service builder with integrated module support (config, artifact, task).""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize service builder with module-specific state.""" + super().__init__(**kwargs) + self._config_options: _ConfigOptions | None = None + self._artifact_options: _ArtifactOptions | None = None + self._task_options: _TaskOptions | None = None + self._ml_options: _MLOptions | None = None + + # --------------------------------------------------------------------- Module-specific fluent methods + + def with_config( + self, + schema: type[BaseConfig], + *, + prefix: str = "/api/v1/configs", + tags: List[str] | None = None, + permissions: CrudPermissions | None = None, + allow_create: bool | None = None, + allow_read: bool | None = None, + allow_update: bool | None = None, + allow_delete: bool | None = None, + ) -> Self: + base = permissions or CrudPermissions() + perms = CrudPermissions( + create=allow_create if allow_create is not None else base.create, + read=allow_read if allow_read is not None else base.read, + update=allow_update if allow_update is not None else base.update, + delete=allow_delete if allow_delete is not None else base.delete, + ) + self._config_options = _ConfigOptions( + schema=schema, + prefix=prefix, + tags=list(tags) if tags else ["Config"], + permissions=perms, + ) + return self + + def with_artifacts( + self, + *, + hierarchy: ArtifactHierarchy, + prefix: str = "/api/v1/artifacts", + tags: List[str] | None = None, + enable_config_linking: bool = False, + permissions: CrudPermissions | None = None, + allow_create: bool | None = None, + allow_read: bool | 
None = None, + allow_update: bool | None = None, + allow_delete: bool | None = None, + ) -> Self: + base = permissions or CrudPermissions() + perms = CrudPermissions( + create=allow_create if allow_create is not None else base.create, + read=allow_read if allow_read is not None else base.read, + update=allow_update if allow_update is not None else base.update, + delete=allow_delete if allow_delete is not None else base.delete, + ) + self._artifact_options = _ArtifactOptions( + hierarchy=hierarchy, + prefix=prefix, + tags=list(tags) if tags else ["Artifacts"], + enable_config_linking=enable_config_linking, + permissions=perms, + ) + return self + + def with_tasks( + self, + *, + prefix: str = "/api/v1/tasks", + tags: List[str] | None = None, + permissions: CrudPermissions | None = None, + allow_create: bool | None = None, + allow_read: bool | None = None, + allow_update: bool | None = None, + allow_delete: bool | None = None, + ) -> Self: + """Enable task execution endpoints with script runner.""" + base = permissions or CrudPermissions() + perms = CrudPermissions( + create=allow_create if allow_create is not None else base.create, + read=allow_read if allow_read is not None else base.read, + update=allow_update if allow_update is not None else base.update, + delete=allow_delete if allow_delete is not None else base.delete, + ) + self._task_options = _TaskOptions( + prefix=prefix, + tags=list(tags) if tags else ["Tasks"], + permissions=perms, + ) + return self + + def with_ml( + self, + runner: ModelRunnerProtocol, + *, + prefix: str = "/api/v1/ml", + tags: List[str] | None = None, + ) -> Self: + """Enable ML train/predict endpoints with model runner.""" + self._ml_options = _MLOptions( + runner=runner, + prefix=prefix, + tags=list(tags) if tags else ["ML"], + ) + return self + + # --------------------------------------------------------------------- Extension point implementations + + def _validate_module_configuration(self) -> None: + """Validate module-specific configuration.""" + if self._artifact_options and self._artifact_options.enable_config_linking and not self._config_options: + raise ValueError( + "Artifact config-linking requires a config schema. " + "Call `with_config(...)` before enabling config linking in artifacts." + ) + + if self._task_options and not self._artifact_options: + raise ValueError( + "Task execution requires artifacts to store results. Call `with_artifacts(...)` before `with_tasks()`." + ) + + if self._ml_options: + if not self._config_options: + raise ValueError( + "ML operations require config for model configuration. " + "Call `with_config(...)` before `with_ml(...)`." + ) + if not self._artifact_options: + raise ValueError( + "ML operations require artifacts for model storage. " + "Call `with_artifacts(...)` before `with_ml(...)`." + ) + if not self._job_options: + raise ValueError( + "ML operations require job scheduler for async execution. " + "Call `with_jobs(...)` before `with_ml(...)`." 
+ ) + + def _register_module_routers(self, app: FastAPI) -> None: + """Register module-specific routers (config, artifact, task).""" + if self._config_options: + config_options = self._config_options + config_schema = config_options.schema + config_dep = self._build_config_dependency(config_schema) + entity_in_type: type[ConfigIn[BaseConfig]] = ConfigIn[config_schema] # type: ignore[valid-type] + entity_out_type: type[ConfigOut[BaseConfig]] = ConfigOut[config_schema] # type: ignore[valid-type] + config_router = ConfigRouter.create( + prefix=config_options.prefix, + tags=config_options.tags, + manager_factory=config_dep, + entity_in_type=entity_in_type, + entity_out_type=entity_out_type, + permissions=config_options.permissions, + enable_artifact_operations=( + self._artifact_options is not None and self._artifact_options.enable_config_linking + ), + ) + app.include_router(config_router) + app.dependency_overrides[default_get_config_manager] = config_dep + + if self._artifact_options: + artifact_options = self._artifact_options + artifact_dep = self._build_artifact_dependency( + hierarchy=artifact_options.hierarchy, + include_config=artifact_options.enable_config_linking, + ) + artifact_router = ArtifactRouter.create( + prefix=artifact_options.prefix, + tags=artifact_options.tags, + manager_factory=artifact_dep, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + permissions=artifact_options.permissions, + enable_config_access=self._config_options is not None and artifact_options.enable_config_linking, + ) + app.include_router(artifact_router) + app.dependency_overrides[default_get_artifact_manager] = artifact_dep + + if self._task_options: + task_options = self._task_options + task_dep = self._build_task_dependency() + task_router = TaskRouter.create( + prefix=task_options.prefix, + tags=task_options.tags, + manager_factory=task_dep, + entity_in_type=TaskIn, + entity_out_type=TaskOut, + permissions=task_options.permissions, + ) + app.include_router(task_router) + app.dependency_overrides[default_get_task_manager] = task_dep + + if self._ml_options: + ml_options = self._ml_options + ml_dep = self._build_ml_dependency() + ml_router = MLRouter.create( + prefix=ml_options.prefix, + tags=ml_options.tags, + manager_factory=ml_dep, + ) + app.include_router(ml_router) + app.dependency_overrides[default_get_ml_manager] = ml_dep + + # --------------------------------------------------------------------- Module dependency builders + + @staticmethod + def _build_config_dependency( + schema: type[BaseConfig], + ) -> DependencyFactory: + async def _dependency(session: AsyncSession = Depends(get_session)) -> ConfigManager[BaseConfig]: + repo = ConfigRepository(session) + return ConfigManager[BaseConfig](repo, schema) + + return _dependency + + @staticmethod + def _build_artifact_dependency( + *, + hierarchy: ArtifactHierarchy, + include_config: bool, + ) -> DependencyFactory: + async def _dependency(session: AsyncSession = Depends(get_session)) -> ArtifactManager: + artifact_repo = ArtifactRepository(session) + config_repo = ConfigRepository(session) if include_config else None + return ArtifactManager(artifact_repo, hierarchy=hierarchy, config_repo=config_repo) + + return _dependency + + @staticmethod + def _build_task_dependency() -> DependencyFactory: + async def _dependency( + session: AsyncSession = Depends(get_session), + artifact_manager: ArtifactManager = Depends(default_get_artifact_manager), + ) -> TaskManager: + repo = TaskRepository(session) + try: + scheduler = get_scheduler() + 
except RuntimeError: + scheduler = None + try: + database = get_database() + except RuntimeError: + database = None + return TaskManager(repo, scheduler, database, artifact_manager) + + return _dependency + + def _build_ml_dependency(self) -> DependencyFactory: + ml_runner = self._ml_options.runner if self._ml_options else None + config_schema = self._config_options.schema if self._config_options else None + + async def _dependency() -> MLManager: + if ml_runner is None: + raise RuntimeError("ML runner not configured") + if config_schema is None: + raise RuntimeError("Config schema not configured") + + runner: ModelRunnerProtocol = ml_runner + scheduler = get_scheduler() + database = get_database() + return MLManager(runner, scheduler, database, config_schema) + + return _dependency + + +class MLServiceBuilder(ServiceBuilder): + """Specialized service builder for ML services with all required components pre-configured.""" + + def __init__( + self, + *, + info: ServiceInfo | MLServiceInfo, + config_schema: type[BaseConfig], + hierarchy: ArtifactHierarchy, + runner: ModelRunnerProtocol, + database_url: str = "sqlite+aiosqlite:///:memory:", + include_error_handlers: bool = True, + include_logging: bool = True, + ) -> None: + """Initialize ML service builder with required ML components.""" + super().__init__( + info=info, + database_url=database_url, + include_error_handlers=include_error_handlers, + include_logging=include_logging, + ) + + # Automatically configure required ML components + self.with_health() + self.with_system() + self.with_config(config_schema) + self.with_artifacts(hierarchy=hierarchy, enable_config_linking=True) + self.with_jobs() + self.with_landing_page() + self.with_ml(runner=runner) diff --git a/packages/servicekit/src/servicekit/core/__init__.py b/packages/servicekit/src/servicekit/core/__init__.py new file mode 100644 index 0000000..0c3f986 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/__init__.py @@ -0,0 +1,78 @@ +"""Core framework components - generic interfaces and base classes.""" + +# ruff: noqa: F401 + +# Base infrastructure (framework-agnostic) +from .database import Database, SqliteDatabase, SqliteDatabaseBuilder +from .exceptions import ( + BadRequestError, + ChapkitException, + ConflictError, + ErrorType, + ForbiddenError, + InvalidULIDError, + NotFoundError, + UnauthorizedError, + ValidationError, +) +from .logging import add_request_context, clear_request_context, configure_logging, get_logger, reset_request_context +from .manager import BaseManager, LifecycleHooks, Manager +from .models import Base, Entity +from .repository import BaseRepository, Repository +from .scheduler import AIOJobScheduler, JobScheduler +from .schemas import ( + BulkOperationError, + BulkOperationResult, + EntityIn, + EntityOut, + JobRecord, + JobStatus, + PaginatedResponse, + ProblemDetail, +) +from .types import SerializableDict, ULIDType + +__all__ = [ + # Base infrastructure + "Database", + "SqliteDatabase", + "SqliteDatabaseBuilder", + "Repository", + "BaseRepository", + "Manager", + "LifecycleHooks", + "BaseManager", + # ORM and types + "Base", + "Entity", + "ULIDType", + "SerializableDict", + # Schemas + "EntityIn", + "EntityOut", + "PaginatedResponse", + "BulkOperationResult", + "BulkOperationError", + "ProblemDetail", + "JobRecord", + "JobStatus", + # Job scheduling + "JobScheduler", + "AIOJobScheduler", + # Exceptions + "ErrorType", + "ChapkitException", + "NotFoundError", + "ValidationError", + "ConflictError", + "InvalidULIDError", + "BadRequestError", + 
"UnauthorizedError", + "ForbiddenError", + # Logging + "configure_logging", + "get_logger", + "add_request_context", + "clear_request_context", + "reset_request_context", +] diff --git a/packages/servicekit/src/servicekit/core/api/__init__.py b/packages/servicekit/src/servicekit/core/api/__init__.py new file mode 100644 index 0000000..3a6c667 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/__init__.py @@ -0,0 +1,72 @@ +"""FastAPI framework layer - routers, middleware, utilities.""" + +from .app import App, AppInfo, AppLoader, AppManager, AppManifest +from .auth import APIKeyMiddleware, load_api_keys_from_env, load_api_keys_from_file, validate_api_key_format +from .crud import CrudPermissions, CrudRouter +from .dependencies import ( + get_app_manager, + get_database, + get_scheduler, + get_session, + set_app_manager, + set_database, + set_scheduler, +) +from .middleware import add_error_handlers, add_logging_middleware, database_error_handler, validation_error_handler +from .pagination import PaginationParams, create_paginated_response +from .router import Router +from .routers import HealthRouter, HealthState, HealthStatus, JobRouter, SystemInfo, SystemRouter +from .service_builder import BaseServiceBuilder, ServiceInfo +from .sse import SSE_HEADERS, format_sse_event, format_sse_model_event +from .utilities import build_location_url, run_app + +__all__ = [ + # Base router classes + "Router", + "CrudRouter", + "CrudPermissions", + # Service builder + "BaseServiceBuilder", + "ServiceInfo", + # App system + "App", + "AppInfo", + "AppLoader", + "AppManifest", + "AppManager", + # Authentication + "APIKeyMiddleware", + "load_api_keys_from_env", + "load_api_keys_from_file", + "validate_api_key_format", + # Dependencies + "get_app_manager", + "set_app_manager", + "get_database", + "set_database", + "get_session", + "get_scheduler", + "set_scheduler", + # Middleware + "add_error_handlers", + "add_logging_middleware", + "database_error_handler", + "validation_error_handler", + # Pagination + "PaginationParams", + "create_paginated_response", + # System routers + "HealthRouter", + "HealthState", + "HealthStatus", + "JobRouter", + "SystemRouter", + "SystemInfo", + # SSE utilities + "SSE_HEADERS", + "format_sse_event", + "format_sse_model_event", + # Utilities + "build_location_url", + "run_app", +] diff --git a/packages/servicekit/src/servicekit/core/api/app.py b/packages/servicekit/src/servicekit/core/api/app.py new file mode 100644 index 0000000..609d759 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/app.py @@ -0,0 +1,225 @@ +"""App system for hosting static web applications.""" + +from __future__ import annotations + +import importlib.util +import json +from dataclasses import dataclass +from pathlib import Path + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from servicekit.core.logging import get_logger + +logger = get_logger(__name__) + + +class AppManifest(BaseModel): + """App manifest configuration.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Human-readable app name") + version: str = Field(description="Semantic version") + prefix: str = Field(description="URL prefix for mounting") + description: str | None = Field(default=None, description="App description") + author: str | None = Field(default=None, description="Author name") + entry: str = Field(default="index.html", description="Entry point filename") + + @field_validator("prefix") + @classmethod + def validate_prefix(cls, v: str) -> str: + 
"""Validate mount prefix format.""" + if not v.startswith("/"): + raise ValueError("prefix must start with '/'") + if ".." in v: + raise ValueError("prefix cannot contain '..'") + if v.startswith("/api/") or v == "/api": + raise ValueError("prefix cannot be '/api' or start with '/api/'") + return v + + @field_validator("entry") + @classmethod + def validate_entry(cls, v: str) -> str: + """Validate entry file path for security.""" + if ".." in v: + raise ValueError("entry cannot contain '..'") + if v.startswith("/"): + raise ValueError("entry must be a relative path") + # Normalize and check for path traversal + normalized = Path(v).as_posix() + if normalized.startswith("../") or "/../" in normalized: + raise ValueError("entry cannot contain path traversal") + return v + + +@dataclass +class App: + """Represents a loaded app with manifest and directory.""" + + manifest: AppManifest + directory: Path + prefix: str # May differ from manifest if overridden + is_package: bool # True if loaded from package resources + + +class AppLoader: + """Loads and validates apps from filesystem or package resources.""" + + @staticmethod + def load(path: str | Path | tuple[str, str], prefix: str | None = None) -> App: + """Load and validate app from filesystem path or package resource tuple.""" + # Detect source type and resolve to directory + if isinstance(path, tuple): + # Package resource + dir_path, is_package = AppLoader._resolve_package_path(path) + else: + # Filesystem path + dir_path = Path(path).resolve() + is_package = False + + if not dir_path.exists(): + raise FileNotFoundError(f"App directory not found: {dir_path}") + if not dir_path.is_dir(): + raise NotADirectoryError(f"App path is not a directory: {dir_path}") + + # Load and validate manifest + manifest_path = dir_path / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError(f"manifest.json not found in: {dir_path}") + + try: + with manifest_path.open() as f: + manifest_data = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in manifest.json: {e}") from e + + manifest = AppManifest(**manifest_data) + + # Validate entry file exists + entry_path = dir_path / manifest.entry + if not entry_path.exists(): + raise FileNotFoundError(f"Entry file '{manifest.entry}' not found in: {dir_path}") + + # Use override or manifest prefix + final_prefix = prefix if prefix is not None else manifest.prefix + + # Re-validate prefix if overridden + if prefix is not None: + validated = AppManifest( + name=manifest.name, + version=manifest.version, + prefix=final_prefix, + ) + final_prefix = validated.prefix + + return App( + manifest=manifest, + directory=dir_path, + prefix=final_prefix, + is_package=is_package, + ) + + @staticmethod + def discover(path: str | Path | tuple[str, str]) -> list[App]: + """Discover all apps with manifest.json in directory.""" + # Resolve directory + if isinstance(path, tuple): + dir_path, _ = AppLoader._resolve_package_path(path) + else: + dir_path = Path(path).resolve() + + if not dir_path.exists(): + raise FileNotFoundError(f"Apps directory not found: {dir_path}") + if not dir_path.is_dir(): + raise NotADirectoryError(f"Apps path is not a directory: {dir_path}") + + # Scan for subdirectories with manifest.json + apps: list[App] = [] + for subdir in dir_path.iterdir(): + if subdir.is_dir() and (subdir / "manifest.json").exists(): + try: + # Determine if we're in a package context + if isinstance(path, tuple): + # Build tuple path for subdirectory + package_name: str = path[0] + 
base_path: str = path[1] + subdir_name = subdir.name + subpath = f"{base_path}/{subdir_name}" if base_path else subdir_name + app = AppLoader.load((package_name, subpath)) + else: + app = AppLoader.load(subdir) + apps.append(app) + except Exception as e: + # Log but don't fail discovery for invalid apps + logger.warning( + "app.discovery.failed", + directory=str(subdir), + error=str(e), + ) + + return apps + + @staticmethod + def _resolve_package_path(package_tuple: tuple[str, str]) -> tuple[Path, bool]: + """Resolve package resource tuple to filesystem path.""" + package_name, subpath = package_tuple + + # Validate subpath for security + if ".." in subpath: + raise ValueError(f"subpath cannot contain '..' (got: {subpath})") + if subpath.startswith("/"): + raise ValueError(f"subpath must be relative (got: {subpath})") + + try: + spec = importlib.util.find_spec(package_name) + except (ModuleNotFoundError, ValueError) as e: + raise ValueError(f"Package '{package_name}' could not be found") from e + + if spec is None or spec.origin is None: + raise ValueError(f"Package '{package_name}' could not be found") + + # Resolve to package directory + package_dir = Path(spec.origin).parent + app_dir = package_dir / subpath + + # Verify resolved path is still within package directory + try: + app_dir.resolve().relative_to(package_dir.resolve()) + except ValueError as e: + raise ValueError(f"App path '{subpath}' escapes package directory") from e + + if not app_dir.exists(): + raise FileNotFoundError(f"App path '{subpath}' not found in package '{package_name}'") + if not app_dir.is_dir(): + raise NotADirectoryError(f"App path '{subpath}' in package '{package_name}' is not a directory") + + return app_dir, True + + +class AppInfo(BaseModel): + """App metadata for API responses.""" + + name: str = Field(description="Human-readable app name") + version: str = Field(description="Semantic version") + prefix: str = Field(description="URL prefix for mounting") + description: str | None = Field(default=None, description="App description") + author: str | None = Field(default=None, description="Author name") + entry: str = Field(description="Entry point filename") + is_package: bool = Field(description="Whether app is loaded from package resources") + + +class AppManager: + """Lightweight manager for app metadata queries.""" + + def __init__(self, apps: list[App]): + """Initialize with loaded apps.""" + self._apps = apps + + def list(self) -> list[App]: + """Return all installed apps.""" + return self._apps + + def get(self, prefix: str) -> App | None: + """Get app by mount prefix.""" + return next((app for app in self._apps if app.prefix == prefix), None) diff --git a/packages/servicekit/src/servicekit/core/api/apps/landing/index.html b/packages/servicekit/src/servicekit/core/api/apps/landing/index.html new file mode 100644 index 0000000..ef0bc30 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/apps/landing/index.html @@ -0,0 +1,146 @@ + + + + + + Service Information + + + +
[index.html markup stripped in this excerpt; the page's visible placeholder text is "Loading service information..."]
+ + + diff --git a/packages/servicekit/src/servicekit/core/api/apps/landing/manifest.json b/packages/servicekit/src/servicekit/core/api/apps/landing/manifest.json new file mode 100644 index 0000000..821b68b --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/apps/landing/manifest.json @@ -0,0 +1,8 @@ +{ + "name": "Chapkit Landing Page", + "version": "1.0.0", + "prefix": "/", + "description": "Default landing page showing service information", + "author": "Chapkit Team", + "entry": "index.html" +} diff --git a/packages/servicekit/src/servicekit/core/api/auth.py b/packages/servicekit/src/servicekit/core/api/auth.py new file mode 100644 index 0000000..4a10809 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/auth.py @@ -0,0 +1,162 @@ +"""API key authentication middleware and utilities.""" + +import os +from pathlib import Path +from typing import Any, Set + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from servicekit.core.logging import get_logger +from servicekit.core.schemas import ProblemDetail + +from .middleware import MiddlewareCallNext + +logger = get_logger(__name__) + + +class APIKeyMiddleware(BaseHTTPMiddleware): + """Middleware for API key authentication via X-API-Key header.""" + + def __init__( + self, + app: Any, + *, + api_keys: Set[str], + header_name: str = "X-API-Key", + unauthenticated_paths: Set[str], + ) -> None: + """Initialize API key middleware. + + Args: + app: ASGI application + api_keys: Set of valid API keys + header_name: HTTP header name for API key + unauthenticated_paths: Paths that don't require authentication + """ + super().__init__(app) + self.api_keys = api_keys + self.header_name = header_name + self.unauthenticated_paths = unauthenticated_paths + + async def dispatch(self, request: Request, call_next: MiddlewareCallNext) -> Response: + """Process request with API key authentication.""" + # Allow unauthenticated access to specific paths + if request.url.path in self.unauthenticated_paths: + return await call_next(request) + + # Extract API key from header + api_key = request.headers.get(self.header_name) + + if not api_key: + logger.warning( + "auth.missing_key", + path=request.url.path, + method=request.method, + ) + problem = ProblemDetail( + type="urn:chapkit:error:unauthorized", + title="Unauthorized", + status=status.HTTP_401_UNAUTHORIZED, + detail=f"Missing authentication header: {self.header_name}", + instance=str(request.url.path), + ) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=problem.model_dump(exclude_none=True), + media_type="application/problem+json", + ) + + # Validate API key + if api_key not in self.api_keys: + # Log only prefix for security + key_prefix = api_key[:7] if len(api_key) >= 7 else "***" + logger.warning( + "auth.invalid_key", + key_prefix=key_prefix, + path=request.url.path, + method=request.method, + ) + problem = ProblemDetail( + type="urn:chapkit:error:unauthorized", + title="Unauthorized", + status=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + instance=str(request.url.path), + ) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=problem.model_dump(exclude_none=True), + media_type="application/problem+json", + ) + + # Attach key prefix to request state for logging + request.state.api_key_prefix = api_key[:7] if len(api_key) >= 7 else "***" + + logger.info( + "auth.success", + key_prefix=request.state.api_key_prefix, + 
path=request.url.path, + ) + + return await call_next(request) + + +def load_api_keys_from_env(env_var: str = "CHAPKIT_API_KEYS") -> Set[str]: + """Load API keys from environment variable (comma-separated). + + Args: + env_var: Environment variable name + + Returns: + Set of API keys + """ + env_value = os.getenv(env_var, "") + if not env_value: + return set() + return {key.strip() for key in env_value.split(",") if key.strip()} + + +def load_api_keys_from_file(file_path: str | Path) -> Set[str]: + """Load API keys from file (one key per line). + + Args: + file_path: Path to file containing API keys + + Returns: + Set of API keys + + Raises: + FileNotFoundError: If file doesn't exist + """ + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"API key file not found: {file_path}") + + keys = set() + with path.open("r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): # Skip empty lines and comments + keys.add(line) + + return keys + + +def validate_api_key_format(key: str) -> bool: + """Validate API key format. + + Args: + key: API key to validate + + Returns: + True if key format is valid + """ + # Basic validation: minimum length + if len(key) < 16: + return False + # Optional: Check for prefix pattern like sk_env_random + # if not key.startswith("sk_"): + # return False + return True diff --git a/packages/servicekit/src/servicekit/core/api/crud.py b/packages/servicekit/src/servicekit/core/api/crud.py new file mode 100644 index 0000000..664f116 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/crud.py @@ -0,0 +1,346 @@ +"""CRUD router base class for standard REST operations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from fastapi import Depends, Request, Response, status +from pydantic import BaseModel +from ulid import ULID + +from servicekit.core.api.router import Router +from servicekit.core.manager import Manager +from servicekit.core.schemas import PaginatedResponse + +# Type alias for manager factory function +type ManagerFactory[InSchemaT: BaseModel, OutSchemaT: BaseModel] = Callable[..., Manager[InSchemaT, OutSchemaT, ULID]] + + +@dataclass(slots=True) +class CrudPermissions: + """Permissions configuration for CRUD operations.""" + + create: bool = True + read: bool = True + update: bool = True + delete: bool = True + + +class CrudRouter[InSchemaT: BaseModel, OutSchemaT: BaseModel](Router): + """Router base class for standard REST CRUD operations.""" + + def __init__( + self, + prefix: str, + tags: list[str], + entity_in_type: type[InSchemaT], + entity_out_type: type[OutSchemaT], + manager_factory: ManagerFactory[InSchemaT, OutSchemaT], + *, + permissions: CrudPermissions | None = None, + **kwargs: Any, + ) -> None: + """Initialize CRUD router with entity types and manager factory.""" + self.manager_factory = manager_factory + self.entity_in_type = entity_in_type + self.entity_out_type = entity_out_type + self._permissions = permissions or CrudPermissions() + super().__init__(prefix=prefix, tags=tags, **kwargs) + + def _register_routes(self) -> None: + """Register CRUD routes based on permissions.""" + manager_dependency, manager_annotation = self._manager_dependency() + perms = self._permissions + if perms.create: + self._register_create_route(manager_dependency, manager_annotation) + if perms.read: + self._register_find_all_route(manager_dependency, manager_annotation) + self._register_find_by_id_route(manager_dependency, 
manager_annotation) + self._register_schema_route() + if perms.update: + self._register_update_route(manager_dependency, manager_annotation) + if perms.delete: + self._register_delete_route(manager_dependency, manager_annotation) + + def register_entity_operation( + self, + name: str, + handler: Callable[..., Any], + *, + http_method: str = "GET", + response_model: type[Any] | None = None, + status_code: int | None = None, + summary: str | None = None, + description: str | None = None, + ) -> None: + """Register a custom entity operation with $ prefix. + + Entity operations are automatically inserted before generic {entity_id} routes + to ensure proper route matching (e.g., /{entity_id}/$validate should match + before /{entity_id}). + """ + route = f"/{{entity_id}}/${name}" + route_kwargs: dict[str, Any] = {} + + if response_model is not None: + route_kwargs["response_model"] = response_model + if status_code is not None: + route_kwargs["status_code"] = status_code + if summary is not None: + route_kwargs["summary"] = summary + if description is not None: + route_kwargs["description"] = description + + # Register the route with the appropriate HTTP method + http_method_lower = http_method.lower() + if http_method_lower == "get": + self.router.get(route, **route_kwargs)(handler) + elif http_method_lower == "post": + self.router.post(route, **route_kwargs)(handler) + elif http_method_lower == "put": + self.router.put(route, **route_kwargs)(handler) + elif http_method_lower == "patch": + self.router.patch(route, **route_kwargs)(handler) + elif http_method_lower == "delete": + self.router.delete(route, **route_kwargs)(handler) + else: + raise ValueError(f"Unsupported HTTP method: {http_method}") + + # Move the just-added route to before generic parametric routes + # Entity operations like /{entity_id}/$validate should match before /{entity_id} + if len(self.router.routes) > 1: + new_route = self.router.routes.pop() + insert_index = self._find_generic_parametric_route_index() + self.router.routes.insert(insert_index, new_route) + + def register_collection_operation( + self, + name: str, + handler: Callable[..., Any], + *, + http_method: str = "GET", + response_model: type[Any] | None = None, + status_code: int | None = None, + summary: str | None = None, + description: str | None = None, + ) -> None: + """Register a custom collection operation with $ prefix. + + Collection operations are automatically inserted before parametric {entity_id} routes + to ensure proper route matching (e.g., /$stats should match before /{entity_id}). 
+ """ + route = f"/${name}" + route_kwargs: dict[str, Any] = {} + + if response_model is not None: + route_kwargs["response_model"] = response_model + if status_code is not None: + route_kwargs["status_code"] = status_code + if summary is not None: + route_kwargs["summary"] = summary + if description is not None: + route_kwargs["description"] = description + + # Register the route with the appropriate HTTP method + http_method_lower = http_method.lower() + if http_method_lower == "get": + self.router.get(route, **route_kwargs)(handler) + elif http_method_lower == "post": + self.router.post(route, **route_kwargs)(handler) + elif http_method_lower == "put": + self.router.put(route, **route_kwargs)(handler) + elif http_method_lower == "patch": + self.router.patch(route, **route_kwargs)(handler) + elif http_method_lower == "delete": + self.router.delete(route, **route_kwargs)(handler) + else: + raise ValueError(f"Unsupported HTTP method: {http_method}") + + # Move the just-added route to before parametric routes + # FastAPI appends to routes list, so the last route is the one we just added + if len(self.router.routes) > 1: + new_route = self.router.routes.pop() # Remove the route we just added + # Find the first parametric route and insert before it + insert_index = self._find_parametric_route_index() + self.router.routes.insert(insert_index, new_route) + + # Route registration helpers -------------------------------------- + + def _register_create_route(self, manager_dependency: Any, manager_annotation: Any) -> None: + entity_in_annotation: Any = self.entity_in_type + entity_out_annotation: Any = self.entity_out_type + router_prefix = self.router.prefix + + @self.router.post("", status_code=status.HTTP_201_CREATED, response_model=entity_out_annotation) + async def create( + entity_in: InSchemaT, + request: Request, + response: Response, + manager: Manager[InSchemaT, OutSchemaT, ULID] = manager_dependency, + ) -> OutSchemaT: + from .utilities import build_location_url + + created_entity = await manager.save(entity_in) + entity_id = getattr(created_entity, "id") + response.headers["Location"] = build_location_url(request, f"{router_prefix}/{entity_id}") + return created_entity + + self._annotate_manager(create, manager_annotation) + create.__annotations__["entity_in"] = entity_in_annotation + create.__annotations__["return"] = entity_out_annotation + + def _register_find_all_route(self, manager_dependency: Any, manager_annotation: Any) -> None: + entity_out_annotation: Any = self.entity_out_type + collection_response_model: Any = list[entity_out_annotation] | PaginatedResponse[entity_out_annotation] + + @self.router.get("", response_model=collection_response_model) + async def find_all( + page: int | None = None, + size: int | None = None, + manager: Manager[InSchemaT, OutSchemaT, ULID] = manager_dependency, + ) -> list[OutSchemaT] | PaginatedResponse[OutSchemaT]: + from .pagination import create_paginated_response + + # Pagination is opt-in: both page and size must be provided + if page is not None and size is not None: + items, total = await manager.find_paginated(page, size) + return create_paginated_response(items, total, page, size) + return await manager.find_all() + + self._annotate_manager(find_all, manager_annotation) + find_all.__annotations__["return"] = list[entity_out_annotation] | PaginatedResponse[entity_out_annotation] + + def _register_find_by_id_route(self, manager_dependency: Any, manager_annotation: Any) -> None: + entity_out_annotation: Any = self.entity_out_type + 
router_prefix = self.router.prefix + + @self.router.get("/{entity_id}", response_model=entity_out_annotation) + async def find_by_id( + entity_id: str, + manager: Manager[InSchemaT, OutSchemaT, ULID] = manager_dependency, + ) -> OutSchemaT: + from servicekit.core.exceptions import NotFoundError + + ulid_id = self._parse_ulid(entity_id) + entity = await manager.find_by_id(ulid_id) + if entity is None: + raise NotFoundError( + f"Entity with id {entity_id} not found", + instance=f"{router_prefix}/{entity_id}", + ) + return entity + + self._annotate_manager(find_by_id, manager_annotation) + find_by_id.__annotations__["return"] = entity_out_annotation + + def _register_update_route(self, manager_dependency: Any, manager_annotation: Any) -> None: + entity_in_type = self.entity_in_type + entity_in_annotation: Any = entity_in_type + entity_out_annotation: Any = self.entity_out_type + router_prefix = self.router.prefix + + @self.router.put("/{entity_id}", response_model=entity_out_annotation) + async def update( + entity_id: str, + entity_in: InSchemaT, + manager: Manager[InSchemaT, OutSchemaT, ULID] = manager_dependency, + ) -> OutSchemaT: + from servicekit.core.exceptions import NotFoundError + + ulid_id = self._parse_ulid(entity_id) + if not await manager.exists_by_id(ulid_id): + raise NotFoundError( + f"Entity with id {entity_id} not found", + instance=f"{router_prefix}/{entity_id}", + ) + entity_dict = entity_in.model_dump(exclude_unset=True) + entity_dict["id"] = ulid_id + entity_with_id = entity_in_type.model_validate(entity_dict) + return await manager.save(entity_with_id) + + self._annotate_manager(update, manager_annotation) + update.__annotations__["entity_in"] = entity_in_annotation + update.__annotations__["return"] = entity_out_annotation + + def _register_delete_route(self, manager_dependency: Any, manager_annotation: Any) -> None: + router_prefix = self.router.prefix + + @self.router.delete("/{entity_id}", status_code=status.HTTP_204_NO_CONTENT) + async def delete_by_id( + entity_id: str, + manager: Manager[InSchemaT, OutSchemaT, ULID] = manager_dependency, + ) -> None: + from servicekit.core.exceptions import NotFoundError + + ulid_id = self._parse_ulid(entity_id) + if not await manager.exists_by_id(ulid_id): + raise NotFoundError( + f"Entity with id {entity_id} not found", + instance=f"{router_prefix}/{entity_id}", + ) + await manager.delete_by_id(ulid_id) + + self._annotate_manager(delete_by_id, manager_annotation) + + def _register_schema_route(self) -> None: + """Register JSON schema endpoint for the entity output type.""" + entity_out_type = self.entity_out_type + + async def get_schema() -> dict[str, Any]: + return entity_out_type.model_json_schema() + + self.register_collection_operation( + name="schema", + handler=get_schema, + http_method="GET", + response_model=dict[str, Any], + ) + + # Helper utilities ------------------------------------------------- + + def _manager_dependency(self) -> tuple[Any, Any]: + manager_dependency = Depends(self.manager_factory) + manager_annotation: Any = Manager[Any, Any, ULID] + return manager_dependency, manager_annotation + + def _annotate_manager(self, endpoint: Any, manager_annotation: Any) -> None: + endpoint.__annotations__["manager"] = manager_annotation + + def _parse_ulid(self, entity_id: str) -> ULID: + from servicekit.core.exceptions import InvalidULIDError + + try: + return ULID.from_str(entity_id) + except ValueError as e: + raise InvalidULIDError( + f"Invalid ULID format: {entity_id}", + 
instance=f"{self.router.prefix}/{entity_id}", + ) from e + + def _find_parametric_route_index(self) -> int: + """Find the index of the first parametric route containing {entity_id}. + + Returns the index where collection operations should be inserted to ensure + they're matched before parametric routes. + """ + for i, route in enumerate(self.router.routes): + route_path = getattr(route, "path", "") + if "{entity_id}" in route_path: + return i + # If no parametric route found, append at the end + return len(self.router.routes) + + def _find_generic_parametric_route_index(self) -> int: + """Find the index of the first generic parametric route (/{entity_id} without $). + + Returns the index where entity operations should be inserted to ensure + they're matched before generic routes like GET/PUT/DELETE /{entity_id}. + """ + for i, route in enumerate(self.router.routes): + route_path = getattr(route, "path", "") + # Match routes like /{entity_id} but not /{entity_id}/$operation + if "{entity_id}" in route_path and "/$" not in route_path: + return i + # If no generic parametric route found, append at the end + return len(self.router.routes) diff --git a/packages/servicekit/src/servicekit/core/api/dependencies.py b/packages/servicekit/src/servicekit/core/api/dependencies.py new file mode 100644 index 0000000..e2af541 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/dependencies.py @@ -0,0 +1,70 @@ +"""Generic FastAPI dependency injection for database and scheduler.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Annotated + +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from servicekit.core import Database +from servicekit.core.scheduler import JobScheduler + +if TYPE_CHECKING: + from .app import AppManager + +# Global database instance - should be initialized at app startup +_database: Database | None = None + +# Global scheduler instance - should be initialized at app startup +_scheduler: JobScheduler | None = None + + +def set_database(database: Database) -> None: + """Set the global database instance.""" + global _database + _database = database + + +def get_database() -> Database: + """Get the global database instance.""" + if _database is None: + raise RuntimeError("Database not initialized. Call set_database() during app startup.") + return _database + + +async def get_session(db: Annotated[Database, Depends(get_database)]) -> AsyncIterator[AsyncSession]: + """Get a database session for dependency injection.""" + async with db.session() as session: + yield session + + +def set_scheduler(scheduler: JobScheduler) -> None: + """Set the global scheduler instance.""" + global _scheduler + _scheduler = scheduler + + +def get_scheduler() -> JobScheduler: + """Get the global scheduler instance.""" + if _scheduler is None: + raise RuntimeError("Scheduler not initialized. Call set_scheduler() during app startup.") + return _scheduler + + +# Global app manager instance - should be initialized at app startup +_app_manager: AppManager | None = None + + +def set_app_manager(manager: AppManager) -> None: + """Set the global app manager instance.""" + global _app_manager + _app_manager = manager + + +def get_app_manager() -> AppManager: + """Get the global app manager instance.""" + if _app_manager is None: + raise RuntimeError("AppManager not initialized. 
Call set_app_manager() during app startup.") + return _app_manager diff --git a/packages/servicekit/src/servicekit/core/api/middleware.py b/packages/servicekit/src/servicekit/core/api/middleware.py new file mode 100644 index 0000000..f9de33a --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/middleware.py @@ -0,0 +1,141 @@ +"""FastAPI middleware for error handling, CORS, and other cross-cutting concerns.""" + +import time +from typing import Any, Awaitable, Callable + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from ulid import ULID + +from servicekit.core.exceptions import ChapkitException +from servicekit.core.logging import add_request_context, get_logger, reset_request_context +from servicekit.core.schemas import ProblemDetail + +logger = get_logger(__name__) + +# Type alias for middleware call_next function +type MiddlewareCallNext = Callable[[Request], Awaitable[Response]] + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + """Middleware for logging HTTP requests with unique request IDs and context binding.""" + + async def dispatch(self, request: Request, call_next: MiddlewareCallNext) -> Response: + """Process request with logging and context binding.""" + request_id = str(ULID()) + start_time = time.perf_counter() + + # Bind request context + add_request_context( + request_id=request_id, + method=request.method, + path=request.url.path, + client_host=request.client.host if request.client else None, + ) + + # Add request_id to request state for access in endpoints + request.state.request_id = request_id + + logger.info( + "http.request.start", + query_params=str(request.url.query) if request.url.query else None, + ) + + try: + response = await call_next(request) + duration_ms = (time.perf_counter() - start_time) * 1000 + + logger.info( + "http.request.complete", + status_code=response.status_code, + duration_ms=round(duration_ms, 2), + ) + + # Add request_id to response headers for tracing + response.headers["X-Request-ID"] = request_id + + return response + + except Exception as exc: + duration_ms = (time.perf_counter() - start_time) * 1000 + + logger.error( + "http.request.error", + duration_ms=round(duration_ms, 2), + error=str(exc), + exc_info=True, + ) + raise + + finally: + # Clear request context after response + reset_request_context() + + +async def database_error_handler(request: Request, exc: Exception) -> JSONResponse: + """Handle database errors and return error response.""" + logger.error( + "database.error", + error=str(exc), + path=request.url.path, + exc_info=True, + ) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Database error occurred", "error": str(exc)}, + ) + + +async def validation_error_handler(request: Request, exc: Exception) -> JSONResponse: + """Handle validation errors and return error response.""" + logger.warning( + "validation.error", + error=str(exc), + path=request.url.path, + ) + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + content={"detail": "Validation error", "errors": str(exc)}, + ) + + +async def chapkit_exception_handler(request: Request, exc: ChapkitException) -> JSONResponse: + """Handle ChapkitException and return RFC 9457 Problem Details.""" + logger.warning( + "chapkit.error", + error_type=exc.type_uri, + status=exc.status, + detail=exc.detail, + path=request.url.path, + ) + + problem = ProblemDetail( + type=exc.type_uri, + 
title=exc.title, + status=exc.status, + detail=exc.detail, + instance=exc.instance or str(request.url), + **exc.extensions, + ) + + return JSONResponse( + status_code=exc.status, + content=problem.model_dump(exclude_none=True), + media_type="application/problem+json", + ) + + +def add_error_handlers(app: Any) -> None: + """Add error handlers to FastAPI application.""" + from pydantic import ValidationError + from sqlalchemy.exc import SQLAlchemyError + + app.add_exception_handler(ChapkitException, chapkit_exception_handler) + app.add_exception_handler(SQLAlchemyError, database_error_handler) + app.add_exception_handler(ValidationError, validation_error_handler) + + +def add_logging_middleware(app: Any) -> None: + """Add request logging middleware to FastAPI application.""" + app.add_middleware(RequestLoggingMiddleware) diff --git a/packages/servicekit/src/servicekit/core/api/monitoring.py b/packages/servicekit/src/servicekit/core/api/monitoring.py new file mode 100644 index 0000000..caf5177 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/monitoring.py @@ -0,0 +1,100 @@ +"""OpenTelemetry monitoring setup with auto-instrumentation.""" + +from fastapi import FastAPI +from opentelemetry import metrics +from opentelemetry.exporter.prometheus import PrometheusMetricReader +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.resources import Resource +from prometheus_client import REGISTRY, ProcessCollector + +from servicekit.core.logging import get_logger + +logger = get_logger(__name__) + +# Global state to track instrumentation +_meter_provider_initialized = False +_sqlalchemy_instrumented = False +_process_collector_registered = False + + +def setup_monitoring( + app: FastAPI, + *, + service_name: str | None = None, + enable_traces: bool = False, +) -> PrometheusMetricReader: + """Setup OpenTelemetry with FastAPI and SQLAlchemy auto-instrumentation.""" + global _meter_provider_initialized, _sqlalchemy_instrumented, _process_collector_registered + + # Use app title as service name if not provided + service_name = service_name or app.title + + # Create resource with service name + resource = Resource.create({"service.name": service_name}) + + # Setup Prometheus metrics exporter - only once globally + reader = PrometheusMetricReader() + if not _meter_provider_initialized: + provider = MeterProvider(resource=resource, metric_readers=[reader]) + metrics.set_meter_provider(provider) + _meter_provider_initialized = True + + # Register process collector for CPU, memory, and Python runtime metrics + if not _process_collector_registered: + try: + ProcessCollector(registry=REGISTRY) + _process_collector_registered = True + except ValueError: + # Already registered + pass + + # Auto-instrument FastAPI - check if already instrumented + instrumentor = FastAPIInstrumentor() + if not instrumentor.is_instrumented_by_opentelemetry: + instrumentor.instrument_app(app) + + # Auto-instrument SQLAlchemy - only once globally + if not _sqlalchemy_instrumented: + try: + SQLAlchemyInstrumentor().instrument() + _sqlalchemy_instrumented = True + except RuntimeError: + # Already instrumented + pass + + logger.info( + "monitoring.enabled", + service_name=service_name, + fastapi_instrumented=True, + sqlalchemy_instrumented=True, + process_metrics=True, + ) + + if enable_traces: + logger.warning( + "monitoring.traces_not_implemented", + 
message="Distributed tracing is not yet implemented", + ) + + return reader + + +def teardown_monitoring() -> None: + """Teardown OpenTelemetry instrumentation.""" + try: + # Uninstrument FastAPI + FastAPIInstrumentor().uninstrument() + + # Uninstrument SQLAlchemy + SQLAlchemyInstrumentor().uninstrument() + + logger.info("monitoring.disabled") + except Exception as e: + logger.warning("monitoring.teardown_failed", error=str(e)) + + +def get_meter(name: str) -> metrics.Meter: + """Get a meter for custom metrics.""" + return metrics.get_meter(name) diff --git a/packages/servicekit/src/servicekit/core/api/pagination.py b/packages/servicekit/src/servicekit/core/api/pagination.py new file mode 100644 index 0000000..4861f9a --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/pagination.py @@ -0,0 +1,27 @@ +"""Pagination utilities for API endpoints.""" + +from __future__ import annotations + +from typing import TypeVar + +from pydantic import BaseModel, Field + +from servicekit.core.schemas import PaginatedResponse + +T = TypeVar("T") + + +class PaginationParams(BaseModel): + """Query parameters for opt-in pagination (both page and size required).""" + + page: int | None = Field(default=None, ge=1, description="Page number (1-indexed)") + size: int | None = Field(default=None, ge=1, le=100, description="Number of items per page (max 100)") + + def is_paginated(self) -> bool: + """Check if both page and size parameters are provided.""" + return self.page is not None and self.size is not None + + +def create_paginated_response(items: list[T], total: int, page: int, size: int) -> PaginatedResponse[T]: + """Create paginated response with items and metadata.""" + return PaginatedResponse(items=items, total=total, page=page, size=size) diff --git a/packages/servicekit/src/servicekit/core/api/router.py b/packages/servicekit/src/servicekit/core/api/router.py new file mode 100644 index 0000000..79d06fb --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/router.py @@ -0,0 +1,28 @@ +"""Base classes for API routers.""" + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any + +from fastapi import APIRouter + + +class Router(ABC): + """Base class for FastAPI routers.""" + + default_response_model_exclude_none: bool = False + + def __init__(self, prefix: str, tags: Sequence[str], **kwargs: Any) -> None: + """Initialize router with prefix and tags.""" + self.router = APIRouter(prefix=prefix, tags=list(tags), **kwargs) + self._register_routes() + + @classmethod + def create(cls, prefix: str, tags: Sequence[str], **kwargs: Any) -> APIRouter: + """Create a router instance and return the FastAPI router.""" + return cls(prefix=prefix, tags=tags, **kwargs).router + + @abstractmethod + def _register_routes(self) -> None: + """Register routes for this router.""" + ... 
diff --git a/packages/servicekit/src/servicekit/core/api/routers/__init__.py b/packages/servicekit/src/servicekit/core/api/routers/__init__.py new file mode 100644 index 0000000..9d088e9 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/routers/__init__.py @@ -0,0 +1,18 @@ +"""Core routers for health, job, metrics, and system endpoints.""" + +from .health import CheckResult, HealthCheck, HealthRouter, HealthState, HealthStatus +from .job import JobRouter +from .metrics import MetricsRouter +from .system import SystemInfo, SystemRouter + +__all__ = [ + "HealthRouter", + "HealthStatus", + "HealthState", + "HealthCheck", + "CheckResult", + "JobRouter", + "MetricsRouter", + "SystemRouter", + "SystemInfo", +] diff --git a/packages/servicekit/src/servicekit/core/api/routers/health.py b/packages/servicekit/src/servicekit/core/api/routers/health.py new file mode 100644 index 0000000..ffb6eae --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/routers/health.py @@ -0,0 +1,114 @@ +"""Health check router.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable +from enum import StrEnum + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field + +from ..router import Router +from ..sse import SSE_HEADERS, format_sse_model_event + + +class HealthState(StrEnum): + """Health state enumeration for health checks.""" + + HEALTHY = "healthy" + DEGRADED = "degraded" + UNHEALTHY = "unhealthy" + + +HealthCheck = Callable[[], Awaitable[tuple[HealthState, str | None]]] + + +class CheckResult(BaseModel): + """Result of an individual health check.""" + + state: HealthState = Field(description="Health state of this check") + message: str | None = Field(default=None, description="Optional message or error detail") + + +class HealthStatus(BaseModel): + """Overall health status response.""" + + status: HealthState = Field(description="Overall service health indicator") + checks: dict[str, CheckResult] | None = Field( + default=None, description="Individual health check results (if checks are configured)" + ) + + +class HealthRouter(Router): + """Health check router for service health monitoring.""" + + default_response_model_exclude_none = True + + def __init__( + self, + prefix: str, + tags: list[str], + checks: dict[str, HealthCheck] | None = None, + **kwargs: object, + ) -> None: + """Initialize health router with optional health checks.""" + self.checks = checks or {} + super().__init__(prefix=prefix, tags=tags, **kwargs) + + def _register_routes(self) -> None: + """Register health check endpoint.""" + checks = self.checks + + async def run_health_checks() -> HealthStatus: + """Run all health checks and aggregate results.""" + if not checks: + return HealthStatus(status=HealthState.HEALTHY) + + check_results: dict[str, CheckResult] = {} + overall_state = HealthState.HEALTHY + + for name, check_fn in checks.items(): + try: + state, message = await check_fn() + check_results[name] = CheckResult(state=state, message=message) + + if state == HealthState.UNHEALTHY: + overall_state = HealthState.UNHEALTHY + elif state == HealthState.DEGRADED and overall_state == HealthState.HEALTHY: + overall_state = HealthState.DEGRADED + + except Exception as e: + check_results[name] = CheckResult(state=HealthState.UNHEALTHY, message=f"Check failed: {str(e)}") + overall_state = HealthState.UNHEALTHY + + return HealthStatus(status=overall_state, checks=check_results) + + @self.router.get( + "", + summary="Health 
check", + response_model=HealthStatus, + response_model_exclude_none=self.default_response_model_exclude_none, + ) + async def health_check() -> HealthStatus: + return await run_health_checks() + + @self.router.get( + "/$stream", + summary="Stream health status updates via SSE", + description="Real-time Server-Sent Events stream of health status at regular intervals", + ) + async def stream_health_status(poll_interval: float = 1.0) -> StreamingResponse: + """Stream real-time health status updates using Server-Sent Events.""" + + async def event_stream() -> AsyncGenerator[bytes, None]: + while True: + status = await run_health_checks() + yield format_sse_model_event(status, exclude_none=self.default_response_model_exclude_none) + await asyncio.sleep(poll_interval) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers=SSE_HEADERS, + ) diff --git a/packages/servicekit/src/servicekit/core/api/routers/job.py b/packages/servicekit/src/servicekit/core/api/routers/job.py new file mode 100644 index 0000000..6e78020 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/routers/job.py @@ -0,0 +1,126 @@ +"""REST API router for job scheduler (list, get, delete jobs).""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator, Callable +from typing import Any + +import ulid +from fastapi import Depends, HTTPException, status +from fastapi.responses import Response, StreamingResponse +from pydantic import TypeAdapter + +from servicekit.core.api.router import Router +from servicekit.core.api.sse import SSE_HEADERS, format_sse_model_event +from servicekit.core.scheduler import JobScheduler +from servicekit.core.schemas import JobRecord, JobStatus + +ULID = ulid.ULID + + +class JobRouter(Router): + """REST API router for job scheduler operations.""" + + def __init__( + self, + prefix: str, + tags: list[str], + scheduler_factory: Callable[[], JobScheduler], + **kwargs: object, + ) -> None: + """Initialize job router with scheduler factory.""" + self.scheduler_factory = scheduler_factory + super().__init__(prefix=prefix, tags=tags, **kwargs) + + def _register_routes(self) -> None: + """Register job management endpoints.""" + scheduler_dependency = Depends(self.scheduler_factory) + + @self.router.get("", summary="List all jobs", response_model=list[JobRecord]) + async def get_jobs( + scheduler: JobScheduler = scheduler_dependency, + status_filter: JobStatus | None = None, + ) -> list[JobRecord]: + jobs = await scheduler.get_all_records() + if status_filter: + return [job for job in jobs if job.status == status_filter] + return jobs + + @self.router.get("/$schema", summary="Get jobs list schema", response_model=dict[str, Any]) + async def get_jobs_schema() -> dict[str, Any]: + """Get JSON schema for jobs list response.""" + return TypeAdapter(list[JobRecord]).json_schema() + + @self.router.get("/{job_id}", summary="Get job by ID", response_model=JobRecord) + async def get_job( + job_id: str, + scheduler: JobScheduler = scheduler_dependency, + ) -> JobRecord: + try: + ulid_id = ULID.from_str(job_id) + return await scheduler.get_record(ulid_id) + except (ValueError, KeyError): + raise HTTPException(status_code=404, detail="Job not found") + + @self.router.delete("/{job_id}", summary="Cancel and delete job", status_code=status.HTTP_204_NO_CONTENT) + async def delete_job( + job_id: str, + scheduler: JobScheduler = scheduler_dependency, + ) -> Response: + try: + ulid_id = ULID.from_str(job_id) + await 
scheduler.delete(ulid_id) + return Response(status_code=status.HTTP_204_NO_CONTENT) + except (ValueError, KeyError): + raise HTTPException(status_code=404, detail="Job not found") + + @self.router.get( + "/{job_id}/$stream", + summary="Stream job status updates via SSE", + description="Real-time Server-Sent Events stream of job status changes until terminal state", + ) + async def stream_job_status( + job_id: str, + scheduler: JobScheduler = scheduler_dependency, + poll_interval: float = 0.5, + ) -> StreamingResponse: + """Stream real-time job status updates using Server-Sent Events.""" + # Validate job_id format + try: + ulid_id = ULID.from_str(job_id) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid job ID format") + + # Check job exists before starting stream + try: + await scheduler.get_record(ulid_id) + except KeyError: + raise HTTPException(status_code=404, detail="Job not found") + + # SSE event generator + async def event_stream() -> AsyncGenerator[bytes, None]: + terminal_states = {"completed", "failed", "canceled"} + + while True: + try: + record = await scheduler.get_record(ulid_id) + # Format as SSE event + yield format_sse_model_event(record) + + # Stop streaming if job reached terminal state + if record.status in terminal_states: + break + + except KeyError: + # Job was deleted - send final event and close + yield b'data: {"status": "deleted"}\n\n' + break + + await asyncio.sleep(poll_interval) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers=SSE_HEADERS, + ) diff --git a/packages/servicekit/src/servicekit/core/api/routers/metrics.py b/packages/servicekit/src/servicekit/core/api/routers/metrics.py new file mode 100644 index 0000000..c38a649 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/routers/metrics.py @@ -0,0 +1,47 @@ +"""Metrics router for Prometheus endpoint.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import Response + +from ..router import Router + +if TYPE_CHECKING: + from opentelemetry.exporter.prometheus import PrometheusMetricReader + + +class MetricsRouter(Router): + """Metrics router for Prometheus metrics exposition.""" + + def __init__( + self, + prefix: str, + tags: list[str], + metric_reader: PrometheusMetricReader, + **kwargs: object, + ) -> None: + """Initialize metrics router with Prometheus metric reader.""" + self.metric_reader = metric_reader + super().__init__(prefix=prefix, tags=tags, **kwargs) + + def _register_routes(self) -> None: + """Register Prometheus metrics endpoint.""" + + @self.router.get( + "", + summary="Prometheus metrics", + response_class=Response, + ) + async def get_metrics() -> Response: + """Expose metrics in Prometheus text format.""" + # Get latest metrics from the reader + from prometheus_client import REGISTRY, generate_latest + + metrics_output = generate_latest(REGISTRY) + + return Response( + content=metrics_output, + media_type="text/plain; version=0.0.4; charset=utf-8", + ) diff --git a/packages/servicekit/src/servicekit/core/api/routers/system.py b/packages/servicekit/src/servicekit/core/api/routers/system.py new file mode 100644 index 0000000..793a9af --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/routers/system.py @@ -0,0 +1,77 @@ +"""System information router.""" + +from __future__ import annotations + +import platform +import sys +from datetime import datetime, timezone +from typing import Annotated, Any + +from fastapi import Depends +from pydantic import BaseModel, 
Field, TypeAdapter + +from ..app import AppInfo, AppManager +from ..dependencies import get_app_manager +from ..router import Router + + +class SystemInfo(BaseModel): + """System information response.""" + + current_time: datetime = Field(description="Current server time in UTC") + timezone: str = Field(description="Server timezone") + python_version: str = Field(description="Python version") + platform: str = Field(description="Operating system platform") + hostname: str = Field(description="Server hostname") + + +class SystemRouter(Router): + """System information router.""" + + def _register_routes(self) -> None: + """Register system info endpoint.""" + + @self.router.get( + "", + summary="System information", + response_model=SystemInfo, + ) + async def get_system_info() -> SystemInfo: + return SystemInfo( + current_time=datetime.now(timezone.utc), + timezone=str(datetime.now().astimezone().tzinfo), + python_version=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + platform=platform.platform(), + hostname=platform.node(), + ) + + @self.router.get( + "/apps", + summary="List installed apps", + response_model=list[AppInfo], + ) + async def list_apps( + app_manager: Annotated[AppManager, Depends(get_app_manager)], + ) -> list[AppInfo]: + """List all installed apps with their metadata.""" + return [ + AppInfo( + name=app.manifest.name, + version=app.manifest.version, + prefix=app.prefix, + description=app.manifest.description, + author=app.manifest.author, + entry=app.manifest.entry, + is_package=app.is_package, + ) + for app in app_manager.list() + ] + + @self.router.get( + "/apps/$schema", + summary="Get apps list schema", + response_model=dict[str, Any], + ) + async def get_apps_schema() -> dict[str, Any]: + """Get JSON schema for apps list response.""" + return TypeAdapter(list[AppInfo]).json_schema() diff --git a/packages/servicekit/src/servicekit/core/api/service_builder.py b/packages/servicekit/src/servicekit/core/api/service_builder.py new file mode 100644 index 0000000..9ce5f41 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/service_builder.py @@ -0,0 +1,656 @@ +"""Base service builder for FastAPI applications without module dependencies.""" + +from __future__ import annotations + +import re +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AsyncContextManager, AsyncIterator, Awaitable, Callable, Dict, List, Self + +from fastapi import APIRouter, FastAPI +from pydantic import BaseModel, ConfigDict +from sqlalchemy import text + +from servicekit.core import Database, SqliteDatabase +from servicekit.core.logging import configure_logging, get_logger + +from .app import App, AppLoader +from .auth import APIKeyMiddleware, load_api_keys_from_env, load_api_keys_from_file +from .dependencies import get_database, get_scheduler, set_database, set_scheduler +from .middleware import add_error_handlers, add_logging_middleware +from .routers import HealthRouter, JobRouter, MetricsRouter, SystemRouter +from .routers.health import HealthCheck, HealthState + +logger = get_logger(__name__) + +# Type aliases for service builder +type LifecycleHook = Callable[[FastAPI], Awaitable[None]] +type DependencyOverride = Callable[..., object] +type LifespanFactory = Callable[[FastAPI], AsyncContextManager[None]] + + +@dataclass(frozen=True) +class _HealthOptions: + """Configuration for health check endpoints.""" + + prefix: str + tags: List[str] + checks: dict[str, HealthCheck] + + 
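# Illustrative sketch (not part of this diff): a callable matching the HealthCheck
# signature stored in _HealthOptions.checks above, wired in via with_health();
# the check_disk_space name and the 1 GiB threshold are assumptions.
import shutil

from servicekit.core.api.routers.health import HealthState


async def check_disk_space() -> tuple[HealthState, str | None]:
    """Report DEGRADED when free disk space on / drops below roughly 1 GiB."""
    free_bytes = shutil.disk_usage("/").free
    if free_bytes < 1024**3:
        return HealthState.DEGRADED, f"Low disk space: {free_bytes} bytes free"
    return HealthState.HEALTHY, None


# Usage sketch: builder.with_health(checks={"disk": check_disk_space})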
+@dataclass(frozen=True) +class _SystemOptions: + """Configuration for system info endpoints.""" + + prefix: str + tags: List[str] + + +@dataclass(frozen=True) +class _JobOptions: + """Configuration for job scheduler endpoints.""" + + prefix: str + tags: List[str] + max_concurrency: int | None + + +@dataclass(frozen=True) +class _AuthOptions: + """Configuration for API key authentication.""" + + api_keys: set[str] + header_name: str + unauthenticated_paths: set[str] + source: str + + +@dataclass(frozen=True) +class _MonitoringOptions: + """Configuration for OpenTelemetry monitoring.""" + + prefix: str + tags: List[str] + service_name: str | None + enable_traces: bool + + +class ServiceInfo(BaseModel): + """Service metadata for FastAPI application.""" + + display_name: str + version: str = "1.0.0" + summary: str | None = None + description: str | None = None + contact: dict[str, str] | None = None + license_info: dict[str, str] | None = None + + model_config = ConfigDict(extra="forbid") + + +class BaseServiceBuilder: + """Base service builder providing core FastAPI functionality without module dependencies.""" + + def __init__( + self, + *, + info: ServiceInfo, + database_url: str = "sqlite+aiosqlite:///:memory:", + include_error_handlers: bool = True, + include_logging: bool = False, + ) -> None: + """Initialize base service builder with core options.""" + if info.description is None and info.summary is not None: + # Preserve summary as description for FastAPI metadata if description missing + self.info = info.model_copy(update={"description": info.summary}) + else: + self.info = info + self._title = self.info.display_name + self._app_description = self.info.summary or self.info.description or "" + self._version = self.info.version + self._database_url = database_url + self._database_instance: Database | None = None + self._pool_size: int = 5 + self._max_overflow: int = 10 + self._pool_recycle: int = 3600 + self._pool_pre_ping: bool = True + self._include_error_handlers = include_error_handlers + self._include_logging = include_logging + self._health_options: _HealthOptions | None = None + self._system_options: _SystemOptions | None = None + self._job_options: _JobOptions | None = None + self._auth_options: _AuthOptions | None = None + self._monitoring_options: _MonitoringOptions | None = None + self._app_configs: List[App] = [] + self._custom_routers: List[APIRouter] = [] + self._dependency_overrides: Dict[DependencyOverride, DependencyOverride] = {} + self._startup_hooks: List[LifecycleHook] = [] + self._shutdown_hooks: List[LifecycleHook] = [] + + # --------------------------------------------------------------------- Fluent configuration + + def with_database( + self, + url_or_instance: str | Database | None = None, + *, + pool_size: int = 5, + max_overflow: int = 10, + pool_recycle: int = 3600, + pool_pre_ping: bool = True, + ) -> Self: + """Configure database with URL string, Database instance, or default in-memory SQLite.""" + if isinstance(url_or_instance, Database): + # Pre-configured instance provided + self._database_instance = url_or_instance + return self # Skip pool configuration for instances + elif isinstance(url_or_instance, str): + # String URL provided + self._database_url = url_or_instance + elif url_or_instance is None: + # Default: in-memory SQLite + self._database_url = "sqlite+aiosqlite:///:memory:" + else: + raise TypeError( + f"Expected str, Database, or None, got {type(url_or_instance).__name__}. 
" + "Use .with_database() for default, .with_database('url') for custom URL, " + "or .with_database(db_instance) for pre-configured database." + ) + + # Configure pool settings (only applies to URL-based databases) + self._pool_size = pool_size + self._max_overflow = max_overflow + self._pool_recycle = pool_recycle + self._pool_pre_ping = pool_pre_ping + return self + + def with_landing_page(self) -> Self: + """Enable landing page at root path.""" + return self.with_app(("chapkit.core.api", "apps/landing")) + + def with_logging(self, enabled: bool = True) -> Self: + """Enable structured logging with request tracing.""" + self._include_logging = enabled + return self + + def with_health( + self, + *, + prefix: str = "/health", + tags: List[str] | None = None, + checks: dict[str, HealthCheck] | None = None, + include_database_check: bool = True, + ) -> Self: + """Add health check endpoint with optional custom checks.""" + health_checks = checks or {} + + if include_database_check: + health_checks["database"] = self._create_database_health_check() + + self._health_options = _HealthOptions( + prefix=prefix, + tags=list(tags) if tags is not None else ["Observability"], + checks=health_checks, + ) + return self + + def with_system( + self, + *, + prefix: str = "/api/v1/system", + tags: List[str] | None = None, + ) -> Self: + """Add system info endpoint.""" + self._system_options = _SystemOptions( + prefix=prefix, + tags=list(tags) if tags is not None else ["Service"], + ) + return self + + def with_jobs( + self, + *, + prefix: str = "/api/v1/jobs", + tags: List[str] | None = None, + max_concurrency: int | None = None, + ) -> Self: + """Add job scheduler endpoints.""" + self._job_options = _JobOptions( + prefix=prefix, + tags=list(tags) if tags is not None else ["Jobs"], + max_concurrency=max_concurrency, + ) + return self + + def with_auth( + self, + *, + api_keys: List[str] | None = None, + api_key_file: str | None = None, + env_var: str = "CHAPKIT_API_KEYS", + header_name: str = "X-API-Key", + unauthenticated_paths: List[str] | None = None, + ) -> Self: + """Enable API key authentication.""" + keys: set[str] = set() + auth_source: str = "" # Track source for later logging + + # Priority 1: Direct list (examples/dev) + if api_keys is not None: + keys = set(api_keys) + auth_source = "direct_keys" + + # Priority 2: File (Docker secrets) + elif api_key_file is not None: + keys = load_api_keys_from_file(api_key_file) + auth_source = f"file:{api_key_file}" + + # Priority 3: Environment variable (default) + else: + keys = load_api_keys_from_env(env_var) + if keys: + auth_source = f"env:{env_var}" + else: + auth_source = f"env:{env_var}:empty" + + if not keys: + raise ValueError("No API keys configured. 
Provide api_keys, api_key_file, or set environment variable.") + + # Default unauthenticated paths + default_unauth = {"/docs", "/redoc", "/openapi.json", "/health", "/"} + unauth_set = set(unauthenticated_paths) if unauthenticated_paths else default_unauth + + self._auth_options = _AuthOptions( + api_keys=keys, + header_name=header_name, + unauthenticated_paths=unauth_set, + source=auth_source, + ) + return self + + def with_monitoring( + self, + *, + prefix: str = "/metrics", + tags: List[str] | None = None, + service_name: str | None = None, + enable_traces: bool = False, + ) -> Self: + """Enable OpenTelemetry monitoring with Prometheus endpoint and auto-instrumentation.""" + self._monitoring_options = _MonitoringOptions( + prefix=prefix, + tags=list(tags) if tags is not None else ["Observability"], + service_name=service_name, + enable_traces=enable_traces, + ) + return self + + def with_app(self, path: str | Path | tuple[str, str], prefix: str | None = None) -> Self: + """Register static app from filesystem path or package resource tuple.""" + app = AppLoader.load(path, prefix=prefix) + self._app_configs.append(app) + return self + + def with_apps(self, path: str | Path | tuple[str, str]) -> Self: + """Auto-discover and register all apps in directory.""" + apps = AppLoader.discover(path) + self._app_configs.extend(apps) + return self + + def include_router(self, router: APIRouter) -> Self: + """Include a custom router.""" + self._custom_routers.append(router) + return self + + def override_dependency(self, dependency: DependencyOverride, override: DependencyOverride) -> Self: + """Override a dependency for testing or customization.""" + self._dependency_overrides[dependency] = override + return self + + def on_startup(self, hook: LifecycleHook) -> Self: + """Register a startup hook.""" + self._startup_hooks.append(hook) + return self + + def on_shutdown(self, hook: LifecycleHook) -> Self: + """Register a shutdown hook.""" + self._shutdown_hooks.append(hook) + return self + + # --------------------------------------------------------------------- Build mechanics + + def build(self) -> FastAPI: + """Build and configure the FastAPI application.""" + self._validate_configuration() + self._validate_module_configuration() # Extension point for subclasses + + lifespan = self._build_lifespan() + app = FastAPI( + title=self._title, + description=self._app_description, + version=self._version, + lifespan=lifespan, + ) + app.state.database_url = self._database_url + + # Override schema generation to clean up generic type names + app.openapi = self._create_openapi_customizer(app) # type: ignore[method-assign] + + if self._include_error_handlers: + add_error_handlers(app) + + if self._include_logging: + add_logging_middleware(app) + + if self._auth_options: + app.add_middleware( + APIKeyMiddleware, + api_keys=self._auth_options.api_keys, + header_name=self._auth_options.header_name, + unauthenticated_paths=self._auth_options.unauthenticated_paths, + ) + # Store auth_source for logging during startup + app.state.auth_source = self._auth_options.source + app.state.auth_key_count = len(self._auth_options.api_keys) + + if self._health_options: + health_router = HealthRouter.create( + prefix=self._health_options.prefix, + tags=self._health_options.tags, + checks=self._health_options.checks, + ) + app.include_router(health_router) + + if self._system_options: + system_router = SystemRouter.create( + prefix=self._system_options.prefix, + tags=self._system_options.tags, + ) + 
app.include_router(system_router) + + if self._job_options: + job_router = JobRouter.create( + prefix=self._job_options.prefix, + tags=self._job_options.tags, + scheduler_factory=get_scheduler, + ) + app.include_router(job_router) + + if self._monitoring_options: + from .monitoring import setup_monitoring + + metric_reader = setup_monitoring( + app, + service_name=self._monitoring_options.service_name, + enable_traces=self._monitoring_options.enable_traces, + ) + metrics_router = MetricsRouter.create( + prefix=self._monitoring_options.prefix, + tags=self._monitoring_options.tags, + metric_reader=metric_reader, + ) + app.include_router(metrics_router) + + # Extension point for module-specific routers + self._register_module_routers(app) + + for router in self._custom_routers: + app.include_router(router) + + # Install route endpoints BEFORE mounting apps (routes take precedence over mounts) + self._install_info_endpoint(app, info=self.info) + + # Mount apps AFTER all routes (apps act as catch-all for unmatched paths) + if self._app_configs: + from fastapi.staticfiles import StaticFiles + + for app_config in self._app_configs: + static_files = StaticFiles(directory=str(app_config.directory), html=True) + app.mount(app_config.prefix, static_files, name=f"app_{app_config.manifest.name}") + logger.info( + "app.mounted", + name=app_config.manifest.name, + prefix=app_config.prefix, + directory=str(app_config.directory), + is_package=app_config.is_package, + ) + + # Initialize app manager for metadata queries (always, even if no apps) + from .app import AppManager + from .dependencies import set_app_manager + + app_manager = AppManager(self._app_configs) + set_app_manager(app_manager) + + for dependency, override in self._dependency_overrides.items(): + app.dependency_overrides[dependency] = override + + return app + + # --------------------------------------------------------------------- Extension points + + def _validate_module_configuration(self) -> None: + """Extension point for module-specific validation (override in subclasses).""" + pass + + def _register_module_routers(self, app: FastAPI) -> None: + """Extension point for registering module-specific routers (override in subclasses).""" + pass + + # --------------------------------------------------------------------- Core helpers + + def _validate_configuration(self) -> None: + """Validate core configuration.""" + # Validate health check names don't contain invalid characters + if self._health_options: + for name in self._health_options.checks.keys(): + if not name.replace("_", "").replace("-", "").isalnum(): + raise ValueError( + f"Health check name '{name}' contains invalid characters. " + "Only alphanumeric characters, underscores, and hyphens are allowed." 
+ ) + + # Validate app configurations + if self._app_configs: + # Deduplicate apps with same prefix (last one wins) + # This allows overriding apps, especially useful for root prefix "/" + seen_prefixes: dict[str, int] = {} # prefix -> last index + for i, app in enumerate(self._app_configs): + if app.prefix in seen_prefixes: + # Log warning about override + prev_idx = seen_prefixes[app.prefix] + prev_app = self._app_configs[prev_idx] + logger.warning( + "app.prefix.override", + prefix=app.prefix, + replaced_app=prev_app.manifest.name, + new_app=app.manifest.name, + ) + seen_prefixes[app.prefix] = i + + # Keep only the last app for each prefix + self._app_configs = [self._app_configs[i] for i in sorted(set(seen_prefixes.values()))] + + # Validate that non-root prefixes don't have duplicates (shouldn't happen after dedup, but safety check) + prefixes = [app.prefix for app in self._app_configs] + if len(prefixes) != len(set(prefixes)): + raise ValueError("Internal error: duplicate prefixes after deduplication") + + def _build_lifespan(self) -> LifespanFactory: + """Build lifespan context manager for app startup/shutdown.""" + database_url = self._database_url + database_instance = self._database_instance + pool_size = self._pool_size + max_overflow = self._max_overflow + pool_recycle = self._pool_recycle + pool_pre_ping = self._pool_pre_ping + job_options = self._job_options + include_logging = self._include_logging + startup_hooks = list(self._startup_hooks) + shutdown_hooks = list(self._shutdown_hooks) + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[None]: + # Configure logging if enabled + if include_logging: + configure_logging() + + # Use injected database or create new one from URL + if database_instance is not None: + database = database_instance + should_manage_lifecycle = False + else: + # Create appropriate database type based on URL + if "sqlite" in database_url.lower(): + database = SqliteDatabase( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + ) + else: + database = Database( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + ) + should_manage_lifecycle = True + + # Always initialize database (safe to call multiple times) + await database.init() + + set_database(database) + app.state.database = database + + # Initialize scheduler if jobs are enabled + if job_options is not None: + from servicekit.core.scheduler import AIOJobScheduler + + scheduler = AIOJobScheduler(max_concurrency=job_options.max_concurrency) + set_scheduler(scheduler) + app.state.scheduler = scheduler + + # Log auth configuration after logging is configured + if hasattr(app.state, "auth_source"): + auth_source = app.state.auth_source + key_count = app.state.auth_key_count + + if auth_source == "direct_keys": + logger.warning( + "auth.direct_keys", + message="Using direct API keys - not recommended for production", + count=key_count, + ) + elif auth_source.startswith("file:"): + file_path = auth_source.split(":", 1)[1] + logger.info("auth.loaded_from_file", file=file_path, count=key_count) + elif auth_source.startswith("env:"): + parts = auth_source.split(":", 2) + env_var = parts[1] + if len(parts) > 2 and parts[2] == "empty": + logger.warning( + "auth.no_keys", + message=f"No API keys found in {env_var}. 
Service will reject all requests.", + ) + else: + logger.info("auth.loaded_from_env", env_var=env_var, count=key_count) + + for hook in startup_hooks: + await hook(app) + try: + yield + finally: + for hook in shutdown_hooks: + await hook(app) + app.state.database = None + + # Dispose database only if we created it + if should_manage_lifecycle: + await database.dispose() + + return lifespan + + @staticmethod + def _create_database_health_check() -> HealthCheck: + """Create database connectivity health check.""" + + async def check_database() -> tuple[HealthState, str | None]: + try: + db = get_database() + async with db.session() as session: + # Simple connectivity check - execute a trivial query + await session.execute(text("SELECT 1")) + return (HealthState.HEALTHY, None) + except Exception as e: + return (HealthState.UNHEALTHY, f"Database connection failed: {str(e)}") + + return check_database + + @staticmethod + def _create_openapi_customizer(app: FastAPI) -> Callable[[], dict[str, Any]]: + """Create OpenAPI schema customizer that cleans up generic type names.""" + + def custom_openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + + from fastapi.openapi.utils import get_openapi + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + ) + + # Clean up schema titles by removing generic type parameters + if "components" in openapi_schema and "schemas" in openapi_schema["components"]: + schemas = openapi_schema["components"]["schemas"] + cleaned_schemas: dict[str, Any] = {} + + for schema_name, schema_def in schemas.items(): + # Remove generic type parameters from schema names + clean_name = re.sub(r"\[.*?\]", "", schema_name) + # If title exists in schema, clean it too + if isinstance(schema_def, dict) and "title" in schema_def: + schema_def["title"] = re.sub(r"\[.*?\]", "", schema_def["title"]) + cleaned_schemas[clean_name] = schema_def + + openapi_schema["components"]["schemas"] = cleaned_schemas + + # Update all $ref pointers to use cleaned names + def clean_refs(obj: Any) -> Any: + if isinstance(obj, dict): + if "$ref" in obj: + obj["$ref"] = re.sub(r"\[.*?\]", "", obj["$ref"]) + for value in obj.values(): + clean_refs(value) + elif isinstance(obj, list): + for item in obj: + clean_refs(item) + + clean_refs(openapi_schema) + + app.openapi_schema = openapi_schema + return app.openapi_schema + + return custom_openapi + + @staticmethod + def _install_info_endpoint(app: FastAPI, *, info: ServiceInfo) -> None: + """Install service info endpoint.""" + info_type = type(info) + + @app.get("/api/v1/info", tags=["Service"], include_in_schema=True, response_model=info_type) + async def get_info() -> ServiceInfo: + return info + + # --------------------------------------------------------------------- Convenience + + @classmethod + def create(cls, *, info: ServiceInfo, **kwargs: Any) -> FastAPI: + """Create and build a FastAPI application in one call.""" + return cls(info=info, **kwargs).build() diff --git a/packages/servicekit/src/servicekit/core/api/sse.py b/packages/servicekit/src/servicekit/core/api/sse.py new file mode 100644 index 0000000..742f51c --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/sse.py @@ -0,0 +1,23 @@ +"""Server-Sent Events (SSE) utilities for streaming responses.""" + +from __future__ import annotations + +from pydantic import BaseModel + +# Standard SSE headers +SSE_HEADERS = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + 
"X-Accel-Buffering": "no", # Disable nginx buffering +} + + +def format_sse_event(data: str) -> bytes: + """Format data as Server-Sent Events message.""" + return f"data: {data}\n\n".encode("utf-8") + + +def format_sse_model_event(model: BaseModel, exclude_none: bool = False) -> bytes: + """Format Pydantic model as Server-Sent Events message.""" + data = model.model_dump_json(exclude_none=exclude_none) + return format_sse_event(data) diff --git a/packages/servicekit/src/servicekit/core/api/utilities.py b/packages/servicekit/src/servicekit/core/api/utilities.py new file mode 100644 index 0000000..7e4ef5e --- /dev/null +++ b/packages/servicekit/src/servicekit/core/api/utilities.py @@ -0,0 +1,80 @@ +"""Utility functions for FastAPI routers and endpoints.""" + +import os +from typing import Any + +from fastapi import Request + + +def build_location_url(request: Request, path: str) -> str: + """Build a full URL for the Location header.""" + return f"{request.url.scheme}://{request.url.netloc}{path}" + + +def run_app( + app: Any | str, + *, + host: str | None = None, + port: int | None = None, + workers: int | None = None, + reload: bool | None = None, + log_level: str | None = None, + **uvicorn_kwargs: Any, +) -> None: + """Run FastAPI app with Uvicorn development server. + + For reload to work, pass a string in "module:app" format. + App instance disables reload automatically. + + Examples: + -------- + # Direct execution (reload disabled) + if __name__ == "__main__": + run_app(app) + + # With module path (reload enabled) + run_app("examples.config_api:app") + + # Production: multiple workers + run_app(app, workers=4) + + Args: + app: FastAPI app instance OR string "module:app" path + host: Server host (default: "127.0.0.1", env: HOST) + port: Server port (default: 8000, env: PORT) + workers: Number of worker processes (default: 1, env: WORKERS) + reload: Enable auto-reload (default: True for string, False for instance) + log_level: Logging level (default: from LOG_LEVEL env var or "info") + **uvicorn_kwargs: Additional uvicorn.run() arguments + """ + import uvicorn + + # Configure structured logging before uvicorn starts + from servicekit.core.logging import configure_logging + + configure_logging() + + # Read from environment variables with defaults + resolved_host: str = host if host is not None else os.getenv("HOST", "127.0.0.1") + resolved_port: int = port if port is not None else int(os.getenv("PORT", "8000")) + resolved_workers: int = workers if workers is not None else int(os.getenv("WORKERS", "1")) + resolved_log_level: str = log_level if log_level is not None else os.getenv("LOG_LEVEL", "info").lower() + + # Auto-detect reload behavior if not specified + if reload is None: + reload = isinstance(app, str) # Enable reload for string paths, disable for instances + + # Auto-reload is incompatible with multiple workers + if resolved_workers > 1 and reload: + reload = False + + uvicorn.run( + app, + host=resolved_host, + port=resolved_port, + workers=resolved_workers, + reload=reload, + log_level=resolved_log_level, + log_config=None, # Disable uvicorn's default logging config + **uvicorn_kwargs, + ) diff --git a/packages/servicekit/src/servicekit/core/database.py b/packages/servicekit/src/servicekit/core/database.py new file mode 100644 index 0000000..7174a0c --- /dev/null +++ b/packages/servicekit/src/servicekit/core/database.py @@ -0,0 +1,254 @@ +"""Async SQLAlchemy database connection manager.""" + +from __future__ import annotations + +import sqlite3 +from contextlib import 
asynccontextmanager +from pathlib import Path +from typing import AsyncIterator, Self + +from alembic.config import Config +from sqlalchemy import event +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.pool import ConnectionPoolEntry + +from alembic import command + + +def _install_sqlite_connect_pragmas(engine: AsyncEngine) -> None: + """Install SQLite connection pragmas for performance and reliability.""" + + def on_connect(dbapi_conn: sqlite3.Connection, _conn_record: ConnectionPoolEntry) -> None: + """Configure SQLite pragmas on connection.""" + cur = dbapi_conn.cursor() + cur.execute("PRAGMA foreign_keys=ON;") + cur.execute("PRAGMA synchronous=NORMAL;") + cur.execute("PRAGMA busy_timeout=30000;") # 30s + cur.execute("PRAGMA temp_store=MEMORY;") + cur.execute("PRAGMA cache_size=-64000;") # 64 MiB (negative => KiB) + cur.execute("PRAGMA mmap_size=134217728;") # 128 MiB + cur.close() + + event.listen(engine.sync_engine, "connect", on_connect) + + +class Database: + """Generic async SQLAlchemy database connection manager.""" + + def __init__( + self, + url: str, + *, + echo: bool = False, + alembic_dir: Path | None = None, + auto_migrate: bool = True, + pool_size: int = 5, + max_overflow: int = 10, + pool_recycle: int = 3600, + pool_pre_ping: bool = True, + ) -> None: + """Initialize database with connection URL and pool configuration.""" + self.url = url + self.alembic_dir = alembic_dir + self.auto_migrate = auto_migrate + + # Build engine kwargs - skip pool params for in-memory SQLite databases + engine_kwargs: dict = {"echo": echo, "future": True} + if ":memory:" not in url: + # Only add pool params for non-in-memory databases + engine_kwargs.update( + { + "pool_size": pool_size, + "max_overflow": max_overflow, + "pool_recycle": pool_recycle, + "pool_pre_ping": pool_pre_ping, + } + ) + + self.engine: AsyncEngine = create_async_engine(url, **engine_kwargs) + self._session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker( + bind=self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def init(self) -> None: + """Initialize database tables using Alembic migrations or direct creation.""" + import asyncio + + # Import Base here to avoid circular import at module level + from servicekit.core.models import Base + + # For databases without migrations, use direct table creation + if not self.auto_migrate: + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + else: + # Use Alembic migrations + alembic_cfg = Config() + + # Use custom alembic directory if provided, otherwise use bundled migrations + if self.alembic_dir is not None: + alembic_cfg.set_main_option("script_location", str(self.alembic_dir)) + else: + alembic_cfg.set_main_option( + "script_location", str(Path(__file__).parent.parent.parent.parent / "alembic") + ) + + alembic_cfg.set_main_option("sqlalchemy.url", self.url) + + # Run upgrade in executor to avoid event loop conflicts + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, command.upgrade, alembic_cfg, "head") + + @asynccontextmanager + async def session(self) -> AsyncIterator[AsyncSession]: + """Create a database session context manager.""" + async with self._session_factory() as s: + yield s + + async def dispose(self) -> None: + """Dispose of database engine and connection pool.""" + await self.engine.dispose() + + +class SqliteDatabase(Database): + """SQLite-specific database implementation with 
optimizations.""" + + def __init__( + self, + url: str, + *, + echo: bool = False, + alembic_dir: Path | None = None, + auto_migrate: bool = True, + pool_size: int = 5, + max_overflow: int = 10, + pool_recycle: int = 3600, + pool_pre_ping: bool = True, + ) -> None: + """Initialize SQLite database with connection URL and pool configuration.""" + self.url = url + self.alembic_dir = alembic_dir + self.auto_migrate = auto_migrate + + # Build engine kwargs - pool params only for non-in-memory databases + engine_kwargs: dict = {"echo": echo, "future": True} + if not self._is_in_memory_url(url): + # File-based databases can use pool configuration + engine_kwargs.update( + { + "pool_size": pool_size, + "max_overflow": max_overflow, + "pool_recycle": pool_recycle, + "pool_pre_ping": pool_pre_ping, + } + ) + + self.engine: AsyncEngine = create_async_engine(url, **engine_kwargs) + _install_sqlite_connect_pragmas(self.engine) + self._session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker( + bind=self.engine, class_=AsyncSession, expire_on_commit=False + ) + + @staticmethod + def _is_in_memory_url(url: str) -> bool: + """Check if URL represents an in-memory database.""" + return ":memory:" in url + + def is_in_memory(self) -> bool: + """Check if this is an in-memory database.""" + return self._is_in_memory_url(self.url) + + async def init(self) -> None: + """Initialize database tables and configure SQLite using Alembic migrations.""" + # Import Base here to avoid circular import at module level + from servicekit.core.models import Base + + # Set WAL mode first (if not in-memory) + if not self.is_in_memory(): + async with self.engine.begin() as conn: + await conn.exec_driver_sql("PRAGMA journal_mode=WAL;") + + # For in-memory databases or when migrations are disabled, use direct table creation + if self.is_in_memory() or not self.auto_migrate: + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + else: + # For file-based databases, use Alembic migrations + await super().init() + + +class SqliteDatabaseBuilder: + """Builder for SQLite database configuration with fluent API.""" + + def __init__(self) -> None: + """Initialize builder with default values.""" + self._url: str = "" + self._echo: bool = False + self._alembic_dir: Path | None = None + self._auto_migrate: bool = True + self._pool_size: int = 5 + self._max_overflow: int = 10 + self._pool_recycle: int = 3600 + self._pool_pre_ping: bool = True + + @classmethod + def in_memory(cls) -> Self: + """Create an in-memory SQLite database configuration.""" + builder = cls() + builder._url = "sqlite+aiosqlite:///:memory:" + return builder + + @classmethod + def from_file(cls, path: str | Path) -> Self: + """Create a file-based SQLite database configuration.""" + builder = cls() + if isinstance(path, Path): + path = str(path) + builder._url = f"sqlite+aiosqlite:///{path}" + return builder + + def with_echo(self, enabled: bool = True) -> Self: + """Enable SQL query logging.""" + self._echo = enabled + return self + + def with_migrations(self, enabled: bool = True, alembic_dir: Path | None = None) -> Self: + """Configure migration behavior.""" + self._auto_migrate = enabled + self._alembic_dir = alembic_dir + return self + + def with_pool( + self, + size: int = 5, + max_overflow: int = 10, + recycle: int = 3600, + pre_ping: bool = True, + ) -> Self: + """Configure connection pool settings.""" + self._pool_size = size + self._max_overflow = max_overflow + self._pool_recycle = recycle + self._pool_pre_ping = 
pre_ping + return self + + def build(self) -> SqliteDatabase: + """Build and return configured SqliteDatabase instance.""" + if not self._url: + raise ValueError("Database URL not configured. Use .in_memory() or .from_file()") + + return SqliteDatabase( + url=self._url, + echo=self._echo, + alembic_dir=self._alembic_dir, + auto_migrate=self._auto_migrate, + pool_size=self._pool_size, + max_overflow=self._max_overflow, + pool_recycle=self._pool_recycle, + pool_pre_ping=self._pool_pre_ping, + ) diff --git a/packages/servicekit/src/servicekit/core/exceptions.py b/packages/servicekit/src/servicekit/core/exceptions.py new file mode 100644 index 0000000..5f11938 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/exceptions.py @@ -0,0 +1,138 @@ +"""Custom exceptions with RFC 9457 Problem Details support.""" + +from __future__ import annotations + +from typing import Any + + +class ErrorType: + """URN-based error type identifiers for RFC 9457 Problem Details.""" + + NOT_FOUND = "urn:chapkit:error:not-found" + VALIDATION_FAILED = "urn:chapkit:error:validation-failed" + CONFLICT = "urn:chapkit:error:conflict" + INVALID_ULID = "urn:chapkit:error:invalid-ulid" + INTERNAL_ERROR = "urn:chapkit:error:internal" + UNAUTHORIZED = "urn:chapkit:error:unauthorized" + FORBIDDEN = "urn:chapkit:error:forbidden" + BAD_REQUEST = "urn:chapkit:error:bad-request" + + +class ChapkitException(Exception): + """Base exception for chapkit with RFC 9457 Problem Details support.""" + + def __init__( + self, + detail: str, + *, + type_uri: str = ErrorType.INTERNAL_ERROR, + title: str = "Internal Server Error", + status: int = 500, + instance: str | None = None, + **extensions: Any, + ) -> None: + super().__init__(detail) + self.type_uri = type_uri + self.title = title + self.status = status + self.detail = detail + self.instance = instance + self.extensions = extensions + + +class NotFoundError(ChapkitException): + """Resource not found exception (404).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.NOT_FOUND, + title="Resource Not Found", + status=404, + instance=instance, + **extensions, + ) + + +class ValidationError(ChapkitException): + """Validation failed exception (400).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.VALIDATION_FAILED, + title="Validation Failed", + status=400, + instance=instance, + **extensions, + ) + + +class ConflictError(ChapkitException): + """Resource conflict exception (409).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.CONFLICT, + title="Resource Conflict", + status=409, + instance=instance, + **extensions, + ) + + +class InvalidULIDError(ChapkitException): + """Invalid ULID format exception (400).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.INVALID_ULID, + title="Invalid ULID Format", + status=400, + instance=instance, + **extensions, + ) + + +class BadRequestError(ChapkitException): + """Bad request exception (400).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.BAD_REQUEST, + title="Bad Request", + status=400, + instance=instance, + **extensions, + ) + + +class 
UnauthorizedError(ChapkitException): + """Unauthorized exception (401).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.UNAUTHORIZED, + title="Unauthorized", + status=401, + instance=instance, + **extensions, + ) + + +class ForbiddenError(ChapkitException): + """Forbidden exception (403).""" + + def __init__(self, detail: str, *, instance: str | None = None, **extensions: Any) -> None: + super().__init__( + detail, + type_uri=ErrorType.FORBIDDEN, + title="Forbidden", + status=403, + instance=instance, + **extensions, + ) diff --git a/packages/servicekit/src/servicekit/core/logging.py b/packages/servicekit/src/servicekit/core/logging.py new file mode 100644 index 0000000..ab2f5ce --- /dev/null +++ b/packages/servicekit/src/servicekit/core/logging.py @@ -0,0 +1,99 @@ +"""Structured logging configuration with request tracing support.""" + +import logging +import os +import sys +from typing import Any + +import structlog +from structlog.typing import Processor + + +def configure_logging() -> None: + """Configure structlog and intercept standard library logging.""" + log_format = os.getenv("LOG_FORMAT", "console").lower() + log_level = os.getenv("LOG_LEVEL", "INFO").upper() + level = getattr(logging, log_level, logging.INFO) + + # Shared processors for structlog + shared_processors: list[Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="iso", utc=True), + structlog.processors.StackInfoRenderer(), + ] + + # Choose renderer based on format + if log_format == "json": + formatter_processors = shared_processors + [ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.processors.format_exc_info, + structlog.processors.JSONRenderer(), + ] + else: + formatter_processors = shared_processors + [ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.processors.ExceptionRenderer(), + structlog.dev.ConsoleRenderer(colors=True), + ] + + # Configure structlog to use standard library logging + structlog.configure( + processors=[ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.processors.TimeStamper(fmt="iso", utc=True), + structlog.processors.StackInfoRenderer(), + structlog.processors.CallsiteParameterAdder( + [ + structlog.processors.CallsiteParameter.FILENAME, + structlog.processors.CallsiteParameter.LINENO, + structlog.processors.CallsiteParameter.FUNC_NAME, + ] + ), + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + wrapper_class=structlog.make_filtering_bound_logger(level), + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + # Configure standard library logging to use structlog formatter + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(structlog.stdlib.ProcessorFormatter(processors=formatter_processors)) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.handlers.clear() + root_logger.addHandler(handler) + root_logger.setLevel(level) + + # Configure uvicorn and gunicorn loggers to use the same handler + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "gunicorn.access", "gunicorn.error"]: + logger = logging.getLogger(logger_name) + logger.handlers.clear() + logger.addHandler(handler) + logger.setLevel(level) + logger.propagate = False + + +def get_logger(name: str | 
None = None) -> Any: + """Get a configured structlog logger instance.""" + return structlog.get_logger(name) + + +def add_request_context(**context: Any) -> None: + """Add context variables that will be included in all log messages.""" + structlog.contextvars.bind_contextvars(**context) + + +def clear_request_context(*keys: str) -> None: + """Clear specific context variables.""" + structlog.contextvars.unbind_contextvars(*keys) + + +def reset_request_context() -> None: + """Clear all context variables.""" + structlog.contextvars.clear_contextvars() diff --git a/packages/servicekit/src/servicekit/core/manager.py b/packages/servicekit/src/servicekit/core/manager.py new file mode 100644 index 0000000..f9609b9 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/manager.py @@ -0,0 +1,296 @@ +"""Base classes for service layer managers with lifecycle hooks.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Iterable, Sequence + +from pydantic import BaseModel + +from servicekit.core.repository import BaseRepository + + +class LifecycleHooks[ModelT, InSchemaT: BaseModel]: + """Lifecycle hooks for entity operations.""" + + def _should_assign_field(self, field: str, value: object) -> bool: + """Determine if a field should be assigned during update.""" + return True + + async def pre_save(self, entity: ModelT, data: InSchemaT) -> None: + """Hook called before saving a new entity.""" + pass + + async def post_save(self, entity: ModelT) -> None: + """Hook called after saving a new entity.""" + pass + + async def pre_update(self, entity: ModelT, data: InSchemaT, old_values: dict[str, object]) -> None: + """Hook called before updating an existing entity.""" + pass + + async def post_update(self, entity: ModelT, changes: dict[str, tuple[object, object]]) -> None: + """Hook called after updating an existing entity.""" + pass + + async def pre_delete(self, entity: ModelT) -> None: + """Hook called before deleting an entity.""" + pass + + async def post_delete(self, entity: ModelT) -> None: + """Hook called after deleting an entity.""" + pass + + +class Manager[InSchemaT: BaseModel, OutSchemaT: BaseModel, IdT](ABC): + """Abstract manager interface for business logic operations.""" + + @abstractmethod + async def save(self, data: InSchemaT) -> OutSchemaT: + """Save an entity.""" + ... + + @abstractmethod + async def save_all(self, items: Iterable[InSchemaT]) -> list[OutSchemaT]: + """Save multiple entities.""" + ... + + @abstractmethod + async def delete_by_id(self, id: IdT) -> None: + """Delete an entity by its ID.""" + ... + + @abstractmethod + async def delete_all(self) -> None: + """Delete all entities.""" + ... + + @abstractmethod + async def delete_all_by_id(self, ids: Sequence[IdT]) -> None: + """Delete multiple entities by their IDs.""" + ... + + @abstractmethod + async def count(self) -> int: + """Count the number of entities.""" + ... + + @abstractmethod + async def exists_by_id(self, id: IdT) -> bool: + """Check if an entity exists by its ID.""" + ... + + @abstractmethod + async def find_by_id(self, id: IdT) -> OutSchemaT | None: + """Find an entity by its ID.""" + ... + + @abstractmethod + async def find_all(self) -> list[OutSchemaT]: + """Find all entities.""" + ... + + @abstractmethod + async def find_paginated(self, page: int, size: int) -> tuple[list[OutSchemaT], int]: + """Find entities with pagination.""" + ... 
+ + @abstractmethod + async def find_all_by_id(self, ids: Sequence[IdT]) -> list[OutSchemaT]: + """Find entities by their IDs.""" + ... + + +class BaseManager[ModelT, InSchemaT: BaseModel, OutSchemaT: BaseModel, IdT]( + LifecycleHooks[ModelT, InSchemaT], + Manager[InSchemaT, OutSchemaT, IdT], +): + """Base manager implementation with CRUD operations and lifecycle hooks.""" + + def __init__( + self, + repo: BaseRepository[ModelT, IdT], + model_cls: type[ModelT], + out_schema_cls: type[OutSchemaT], + ) -> None: + """Initialize manager with repository, model class, and output schema class.""" + self.repo = repo + self.model_cls = model_cls + self.out_schema_cls = out_schema_cls + + def _to_output_schema(self, entity: ModelT) -> OutSchemaT: + """Convert ORM entity to output schema.""" + return self.out_schema_cls.model_validate(entity, from_attributes=True) + + async def save(self, data: InSchemaT) -> OutSchemaT: + """Save an entity (create or update).""" + data_dict = data.model_dump(exclude_none=True) + entity_id = data_dict.get("id") + existing: ModelT | None = None + + if entity_id is not None: + existing = await self.repo.find_by_id(entity_id) + + if existing is None: + if data_dict.get("id") is None: + data_dict.pop("id", None) + entity = self.model_cls(**data_dict) + await self.pre_save(entity, data) + await self.repo.save(entity) + await self.repo.commit() + await self.repo.refresh_many([entity]) + await self.post_save(entity) + return self._to_output_schema(entity) + + tracked_fields = set(data_dict.keys()) + if hasattr(existing, "level"): # pragma: no branch + tracked_fields.add("level") + old_values = {field: getattr(existing, field) for field in tracked_fields if hasattr(existing, field)} + + for key, value in data_dict.items(): + if key == "id": # pragma: no branch + continue + if not self._should_assign_field(key, value): + continue + if hasattr(existing, key): + setattr(existing, key, value) + + await self.pre_update(existing, data, old_values) + + changes: dict[str, tuple[object, object]] = {} + for field in tracked_fields: + if hasattr(existing, field): + new_value = getattr(existing, field) + old_value = old_values.get(field) + if old_value != new_value: + changes[field] = (old_value, new_value) + + await self.repo.save(existing) + await self.repo.commit() + await self.repo.refresh_many([existing]) + await self.post_update(existing, changes) + return self._to_output_schema(existing) + + async def save_all(self, items: Iterable[InSchemaT]) -> list[OutSchemaT]: + entities_to_insert: list[ModelT] = [] + updates: list[tuple[ModelT, dict[str, tuple[object, object]]]] = [] + outputs: list[ModelT] = [] + + for data in items: + data_dict = data.model_dump(exclude_none=True) + entity_id = data_dict.get("id") + existing: ModelT | None = None + if entity_id is not None: + existing = await self.repo.find_by_id(entity_id) + + if existing is None: + if data_dict.get("id") is None: + data_dict.pop("id", None) + entity = self.model_cls(**data_dict) + await self.pre_save(entity, data) + entities_to_insert.append(entity) + outputs.append(entity) + continue + + tracked_fields = set(data_dict.keys()) + if hasattr(existing, "level"): # pragma: no branch + tracked_fields.add("level") + old_values = {field: getattr(existing, field) for field in tracked_fields if hasattr(existing, field)} + + for key, value in data_dict.items(): + if key == "id": # pragma: no branch + continue + if not self._should_assign_field(key, value): + continue + if hasattr(existing, key): + setattr(existing, key, value) + + 
await self.pre_update(existing, data, old_values) + + changes: dict[str, tuple[object, object]] = {} + for field in tracked_fields: + if hasattr(existing, field): + new_value = getattr(existing, field) + old_value = old_values.get(field) + if old_value != new_value: + changes[field] = (old_value, new_value) + + updates.append((existing, changes)) + outputs.append(existing) + + if entities_to_insert: # pragma: no branch + await self.repo.save_all(entities_to_insert) + await self.repo.commit() + if outputs: # pragma: no branch + await self.repo.refresh_many(outputs) + + for entity in entities_to_insert: + await self.post_save(entity) + for entity, changes in updates: + await self.post_update(entity, changes) + + return [self._to_output_schema(entity) for entity in outputs] + + async def delete_by_id(self, id: IdT) -> None: + """Delete an entity by its ID.""" + entity = await self.repo.find_by_id(id) + if entity is None: + return + await self.pre_delete(entity) + await self.repo.delete(entity) + await self.repo.commit() + await self.post_delete(entity) + + async def delete_all(self) -> None: + """Delete all entities.""" + entities = await self.repo.find_all() + for entity in entities: + await self.pre_delete(entity) + await self.repo.delete_all() + await self.repo.commit() + for entity in entities: + await self.post_delete(entity) + + async def delete_all_by_id(self, ids: Sequence[IdT]) -> None: + """Delete multiple entities by their IDs.""" + if not ids: + return + entities = await self.repo.find_all_by_id(ids) + for entity in entities: + await self.pre_delete(entity) + await self.repo.delete_all_by_id(ids) + await self.repo.commit() + for entity in entities: + await self.post_delete(entity) + + async def count(self) -> int: + """Count the number of entities.""" + return await self.repo.count() + + async def exists_by_id(self, id: IdT) -> bool: + """Check if an entity exists by its ID.""" + return await self.repo.exists_by_id(id) + + async def find_by_id(self, id: IdT) -> OutSchemaT | None: + """Find an entity by its ID.""" + entity = await self.repo.find_by_id(id) + if entity is None: + return None + return self._to_output_schema(entity) + + async def find_all(self) -> list[OutSchemaT]: + """Find all entities.""" + entities = await self.repo.find_all() + return [self._to_output_schema(e) for e in entities] + + async def find_paginated(self, page: int, size: int) -> tuple[list[OutSchemaT], int]: + """Find entities with pagination.""" + offset = (page - 1) * size + entities = await self.repo.find_all_paginated(offset, size) + total = await self.repo.count() + return [self._to_output_schema(e) for e in entities], total + + async def find_all_by_id(self, ids: Sequence[IdT]) -> list[OutSchemaT]: + """Find entities by their IDs.""" + entities = await self.repo.find_all_by_id(ids) + return [self._to_output_schema(e) for e in entities] diff --git a/packages/servicekit/src/servicekit/core/models.py b/packages/servicekit/src/servicekit/core/models.py new file mode 100644 index 0000000..7d85a10 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/models.py @@ -0,0 +1,24 @@ +"""Base ORM classes for SQLAlchemy models.""" + +import datetime + +from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from ulid import ULID + +from .types import ULIDType + + +class Base(AsyncAttrs, DeclarativeBase): + """Root declarative base with async support.""" + + +class Entity(Base): + """Optional base with common columns for 
your models.""" + + __abstract__ = True + + id: Mapped[ULID] = mapped_column(ULIDType, primary_key=True, default=ULID) + created_at: Mapped[datetime.datetime] = mapped_column(server_default=func.now()) + updated_at: Mapped[datetime.datetime] = mapped_column(server_default=func.now(), onupdate=func.now()) diff --git a/packages/servicekit/src/servicekit/core/repository.py b/packages/servicekit/src/servicekit/core/repository.py new file mode 100644 index 0000000..b1fb7e4 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/repository.py @@ -0,0 +1,168 @@ +"""Base repository classes for data access layer.""" + +from abc import ABC, abstractmethod +from typing import Iterable, Sequence + +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID + + +class Repository[T, IdT = ULID](ABC): + """Abstract repository interface for data access operations.""" + + @abstractmethod + async def save(self, entity: T) -> T: + """Save an entity to the database.""" + ... + + @abstractmethod + async def save_all(self, entities: Iterable[T]) -> Sequence[T]: + """Save multiple entities to the database.""" + ... + + @abstractmethod + async def commit(self) -> None: + """Commit the current database transaction.""" + ... + + @abstractmethod + async def refresh_many(self, entities: Iterable[T]) -> None: + """Refresh multiple entities from the database.""" + ... + + @abstractmethod + async def delete(self, entity: T) -> None: + """Delete an entity from the database.""" + ... + + @abstractmethod + async def delete_by_id(self, id: IdT) -> None: + """Delete an entity by its ID.""" + ... + + @abstractmethod + async def delete_all(self) -> None: + """Delete all entities from the database.""" + ... + + @abstractmethod + async def delete_all_by_id(self, ids: Sequence[IdT]) -> None: + """Delete multiple entities by their IDs.""" + ... + + @abstractmethod + async def count(self) -> int: + """Count the number of entities.""" + ... + + @abstractmethod + async def exists_by_id(self, id: IdT) -> bool: + """Check if an entity exists by its ID.""" + ... + + @abstractmethod + async def find_all(self) -> Sequence[T]: + """Find all entities.""" + ... + + @abstractmethod + async def find_all_paginated(self, offset: int, limit: int) -> Sequence[T]: + """Find entities with pagination.""" + ... + + @abstractmethod + async def find_all_by_id(self, ids: Sequence[IdT]) -> Sequence[T]: + """Find entities by their IDs.""" + ... + + @abstractmethod + async def find_by_id(self, id: IdT) -> T | None: + """Find an entity by its ID.""" + ... 
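+
+
+# Usage sketch (illustrative only): concrete repositories are expected to be thin
+# subclasses of BaseRepository (defined below) that pin the entity type, e.g.
+#
+#     class WidgetRepository(BaseRepository[Widget, ULID]):
+#         def __init__(self, session: AsyncSession) -> None:
+#             super().__init__(session, Widget)
+#
+# where `Widget` is a placeholder for an Entity-derived ORM model, not a real
+# servicekit class.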
+ + +class BaseRepository[T, IdT = ULID](Repository[T, IdT]): + """Base repository implementation with common CRUD operations.""" + + def __init__(self, session: AsyncSession, model: type[T]) -> None: + """Initialize repository with database session and model type.""" + self.s = session + self.model = model + + # ---------- Create ---------- + async def save(self, entity: T) -> T: + """Save an entity to the database.""" + self.s.add(entity) + return entity + + async def save_all(self, entities: Iterable[T]) -> Sequence[T]: + """Save multiple entities to the database.""" + entity_list = list(entities) + self.s.add_all(entity_list) + return entity_list + + async def commit(self) -> None: + """Commit the current database transaction.""" + await self.s.commit() + + async def refresh_many(self, entities: Iterable[T]) -> None: + """Refresh multiple entities from the database.""" + for e in entities: + await self.s.refresh(e) + + # ---------- Delete ---------- + async def delete(self, entity: T) -> None: + """Delete an entity from the database.""" + await self.s.delete(entity) + + async def delete_by_id(self, id: IdT) -> None: + """Delete an entity by its ID.""" + id_col = getattr(self.model, "id") + await self.s.execute(delete(self.model).where(id_col == id)) + + async def delete_all(self) -> None: + """Delete all entities from the database.""" + await self.s.execute(delete(self.model)) + + async def delete_all_by_id(self, ids: Sequence[IdT]) -> None: + """Delete multiple entities by their IDs.""" + if not ids: + return + # Access the "id" column generically + id_col = getattr(self.model, "id") + await self.s.execute(delete(self.model).where(id_col.in_(ids))) + + # ---------- Read / Count ---------- + async def count(self) -> int: + """Count the number of entities.""" + return await self.s.scalar(select(func.count()).select_from(self.model)) or 0 + + async def exists_by_id(self, id: IdT) -> bool: + """Check if an entity exists by its ID.""" + # Access the "id" column generically + id_col = getattr(self.model, "id") + q = select(select(id_col).where(id_col == id).exists()) + return await self.s.scalar(q) or False + + async def find_all(self) -> Sequence[T]: + """Find all entities.""" + result = await self.s.scalars(select(self.model)) + return result.all() + + async def find_all_paginated(self, offset: int, limit: int) -> Sequence[T]: + """Find entities with pagination.""" + result = await self.s.scalars(select(self.model).offset(offset).limit(limit)) + return result.all() + + async def find_all_by_id(self, ids: Sequence[IdT]) -> Sequence[T]: + """Find entities by their IDs.""" + if not ids: + return [] + id_col = getattr(self.model, "id") + result = await self.s.scalars(select(self.model).where(id_col.in_(ids))) + return result.all() + + async def find_by_id(self, id: IdT) -> T | None: + """Find an entity by its ID.""" + return await self.s.get(self.model, id) diff --git a/packages/servicekit/src/servicekit/core/scheduler.py b/packages/servicekit/src/servicekit/core/scheduler.py new file mode 100644 index 0000000..65622d7 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/scheduler.py @@ -0,0 +1,305 @@ +"""Job scheduler for async task management with in-memory asyncio implementation.""" + +import asyncio +import inspect +import traceback +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable + +import ulid +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr + +from .schemas import JobRecord, JobStatus + 
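+# Usage sketch (illustrative only; names are placeholders): the service builder
+# normally creates the scheduler via .with_jobs(), but the in-memory scheduler
+# defined below can also be driven directly, e.g.
+#
+#     scheduler = AIOJobScheduler(max_concurrency=4)
+#     job_id = await scheduler.add_job(my_async_task, some_arg)
+#     await scheduler.wait(job_id)
+#     result = await scheduler.get_result(job_id)
+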
+ULID = ulid.ULID + +# Type aliases for scheduler job targets +type JobTarget = Callable[..., Any] | Callable[..., Awaitable[Any]] | Awaitable[Any] +type JobExecutor = Callable[[], Awaitable[Any]] + + +class JobScheduler(BaseModel, ABC): + """Abstract job scheduler interface for async task management.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def add_job( + self, + target: JobTarget, + /, + *args: Any, + **kwargs: Any, + ) -> ULID: + """Add a job to the scheduler and return its ID.""" + ... + + @abstractmethod + async def get_status(self, job_id: ULID) -> JobStatus: + """Get the status of a job.""" + ... + + @abstractmethod + async def get_record(self, job_id: ULID) -> JobRecord: + """Get the full record of a job.""" + ... + + @abstractmethod + async def get_all_records(self) -> list[JobRecord]: + """Get all job records.""" + ... + + @abstractmethod + async def cancel(self, job_id: ULID) -> bool: + """Cancel a running job.""" + ... + + @abstractmethod + async def delete(self, job_id: ULID) -> None: + """Delete a job record.""" + ... + + @abstractmethod + async def wait(self, job_id: ULID, timeout: float | None = None) -> None: + """Wait for a job to complete.""" + ... + + @abstractmethod + async def get_result(self, job_id: ULID) -> Any: + """Get the result of a completed job.""" + ... + + +class AIOJobScheduler(JobScheduler): + """In-memory asyncio scheduler. Sync callables run in thread pool, concurrency controlled via semaphore.""" + + name: str = Field(default="chap") + max_concurrency: int | None = Field(default=None) + + _records: dict[ULID, JobRecord] = PrivateAttr(default_factory=dict) + _results: dict[ULID, Any] = PrivateAttr(default_factory=dict) + _tasks: dict[ULID, asyncio.Task[Any]] = PrivateAttr(default_factory=dict) + _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) + _sema: asyncio.Semaphore | None = PrivateAttr(default=None) + + def __init__(self, **data: Any): + """Initialize scheduler with optional concurrency limit.""" + super().__init__(**data) + if self.max_concurrency and self.max_concurrency > 0: + self._sema = asyncio.Semaphore(self.max_concurrency) + + async def set_max_concurrency(self, n: int | None) -> None: + """Set maximum number of concurrent jobs.""" + async with self._lock: + self.max_concurrency = n + if n and n > 0: + self._sema = asyncio.Semaphore(n) + else: + self._sema = None + + async def add_job( + self, + target: JobTarget, + /, + *args: Any, + **kwargs: Any, + ) -> ULID: + """Add a job to the scheduler and return its ID.""" + now = datetime.now(timezone.utc) + jid = ULID() + + record = JobRecord( + id=jid, + status=JobStatus.pending, + submitted_at=now, + ) + + async with self._lock: + if jid in self._tasks: + raise RuntimeError(f"Job {jid!r} already scheduled") + self._records[jid] = record + + async def _execute_target() -> Any: + if inspect.isawaitable(target): + if args or kwargs: + # Close the coroutine to avoid "coroutine was never awaited" warning + if inspect.iscoroutine(target): + target.close() + raise TypeError("Args/kwargs not supported when target is an awaitable object.") + return await target + if inspect.iscoroutinefunction(target): + return await target(*args, **kwargs) + return await asyncio.to_thread(target, *args, **kwargs) + + async def _runner() -> Any: + if self._sema: + async with self._sema: + return await self._run_with_state(jid, _execute_target) + else: + return await self._run_with_state(jid, _execute_target) + + task = asyncio.create_task(_runner(), 
name=f"{self.name}-job-{jid}") + + def _drain(t: asyncio.Task[Any]) -> None: + try: + t.result() + except Exception: + pass + + task.add_done_callback(_drain) + + async with self._lock: + self._tasks[jid] = task + + return jid + + async def _run_with_state( + self, + jid: ULID, + exec_fn: JobExecutor, + ) -> Any: + """Execute job function and manage its state transitions.""" + async with self._lock: + rec = self._records[jid] + rec.status = JobStatus.running + rec.started_at = datetime.now(timezone.utc) + + try: + result = await exec_fn() + + artifact: ULID | None = result if isinstance(result, ULID) else None + + async with self._lock: + rec = self._records[jid] + rec.status = JobStatus.completed + rec.finished_at = datetime.now(timezone.utc) + rec.artifact_id = artifact + self._results[jid] = result + + return result + + except asyncio.CancelledError: + async with self._lock: + rec = self._records[jid] + rec.status = JobStatus.canceled + rec.finished_at = datetime.now(timezone.utc) + + raise + + except Exception as e: + tb = traceback.format_exc() + # Extract clean error message (exception type and message only) + error_lines = tb.strip().split("\n") + clean_error = error_lines[-1] if error_lines else str(e) + + async with self._lock: + rec = self._records[jid] + rec.status = JobStatus.failed + rec.finished_at = datetime.now(timezone.utc) + rec.error = clean_error + rec.error_traceback = tb + + raise + + async def get_all_records(self) -> list[JobRecord]: + """Get all job records sorted by submission time.""" + async with self._lock: + records = [r.model_copy(deep=True) for r in self._records.values()] + + records.sort( + key=lambda r: getattr(r, "submitted_at", datetime.min.replace(tzinfo=timezone.utc)), + reverse=True, + ) + + return records + + async def get_record(self, job_id: ULID) -> JobRecord: + """Get the full record of a job.""" + async with self._lock: + rec = self._records.get(job_id) + + if rec is None: + raise KeyError("Job not found") + + return rec.model_copy(deep=True) + + async def get_status(self, job_id: ULID) -> JobStatus: + """Get the status of a job.""" + async with self._lock: + rec = self._records.get(job_id) + + if rec is None: + raise KeyError("Job not found") + + return rec.status + + async def get_result(self, job_id: ULID) -> Any: + """Get the result of a completed job.""" + async with self._lock: + rec = self._records.get(job_id) + + if rec is None: + raise KeyError("Job not found") + + if rec.status == JobStatus.completed: + return self._results.get(job_id) + + if rec.status == JobStatus.failed: + msg = getattr(rec, "error", "Job failed") + raise RuntimeError(msg) + + raise RuntimeError(f"Job not finished (status={rec.status})") + + async def wait(self, job_id: ULID, timeout: float | None = None) -> None: + """Wait for a job to complete.""" + async with self._lock: + task = self._tasks.get(job_id) + + if task is None: + raise KeyError("Job not found") + + await asyncio.wait_for(asyncio.shield(task), timeout=timeout) + + async def cancel(self, job_id: ULID) -> bool: + """Cancel a running job.""" + async with self._lock: + task = self._tasks.get(job_id) + exists = job_id in self._records + + if not exists: + raise KeyError("Job not found") + + if not task or task.done(): + return False + + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + return True + + async def delete(self, job_id: ULID) -> None: + """Delete a job record.""" + async with self._lock: + rec = self._records.get(job_id) + task = self._tasks.get(job_id) + + if rec 
is None: + raise KeyError("Job not found") + + if task and not task.done(): + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + async with self._lock: + self._records.pop(job_id, None) + self._tasks.pop(job_id, None) + self._results.pop(job_id, None) diff --git a/packages/servicekit/src/servicekit/core/schemas.py b/packages/servicekit/src/servicekit/core/schemas.py new file mode 100644 index 0000000..1df1448 --- /dev/null +++ b/packages/servicekit/src/servicekit/core/schemas.py @@ -0,0 +1,126 @@ +"""Core Pydantic schemas for entities, responses, and jobs.""" + +from __future__ import annotations + +from datetime import datetime +from enum import StrEnum +from typing import Generic, TypeVar + +import ulid +from pydantic import BaseModel, ConfigDict, Field, computed_field + +ULID = ulid.ULID +T = TypeVar("T") + + +# Base entity schemas + + +class EntityIn(BaseModel): + """Base input schema for entities with optional ID.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + id: ULID | None = None + + +class EntityOut(BaseModel): + """Base output schema for entities with ID and timestamps.""" + + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + + id: ULID + created_at: datetime + updated_at: datetime + + +# Response schemas + + +class PaginatedResponse(BaseModel, Generic[T]): + """Paginated response with items, total count, page number, and computed page count.""" + + items: list[T] = Field(description="List of items for the current page") + total: int = Field(description="Total number of items across all pages", ge=0) + page: int = Field(description="Current page number (1-indexed)", ge=1) + size: int = Field(description="Number of items per page", ge=1) + + @computed_field # type: ignore[prop-decorator] + @property + def pages(self) -> int: + """Total number of pages.""" + if self.total == 0: + return 0 + return (self.total + self.size - 1) // self.size + + +class BulkOperationError(BaseModel): + """Error information for a single item in a bulk operation.""" + + id: str = Field(description="Identifier of the item that failed") + reason: str = Field(description="Human-readable error message") + + +class BulkOperationResult(BaseModel): + """Result of bulk operation with counts of succeeded/failed items and error details.""" + + total: int = Field(description="Total number of items processed", ge=0) + succeeded: int = Field(description="Number of items successfully processed", ge=0) + failed: int = Field(description="Number of items that failed", ge=0) + errors: list[BulkOperationError] = Field(default_factory=list, description="Details of failed items (if any)") + + +class ProblemDetail(BaseModel): + """RFC 9457 Problem Details with URN error type, status, and human-readable messages.""" + + type: str = Field( + default="about:blank", + description="URI reference identifying the problem type (URN format for chapkit errors)", + ) + title: str = Field(description="Short, human-readable summary of the problem type") + status: int = Field(description="HTTP status code", ge=100, le=599) + detail: str | None = Field(default=None, description="Human-readable explanation specific to this occurrence") + instance: str | None = Field(default=None, description="URI reference identifying the specific occurrence") + trace_id: str | None = Field(default=None, description="Optional trace ID for debugging") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "type": "urn:chapkit:error:not-found", + "title": "Resource Not 
Found", + "status": 404, + "detail": "Config with id 01ABC... not found", + "instance": "/api/config/01ABC...", + } + ] + } + } + + +# Job schemas + + +class JobStatus(StrEnum): + """Status of a scheduled job.""" + + pending = "pending" + running = "running" + completed = "completed" + failed = "failed" + canceled = "canceled" + + +class JobRecord(BaseModel): + """Complete record of a scheduled job's state and metadata.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + id: ULID = Field(description="Unique job identifier") + status: JobStatus = Field(default=JobStatus.pending, description="Current job status") + submitted_at: datetime | None = Field(default=None, description="When the job was submitted") + started_at: datetime | None = Field(default=None, description="When the job started running") + finished_at: datetime | None = Field(default=None, description="When the job finished") + error: str | None = Field(default=None, description="User-friendly error message if job failed") + error_traceback: str | None = Field(default=None, description="Full error traceback for debugging") + artifact_id: ULID | None = Field(default=None, description="ID of artifact created by job (if job returns a ULID)") diff --git a/packages/servicekit/src/servicekit/core/types.py b/packages/servicekit/src/servicekit/core/types.py new file mode 100644 index 0000000..b85135a --- /dev/null +++ b/packages/servicekit/src/servicekit/core/types.py @@ -0,0 +1,94 @@ +"""Custom types for chapkit - SQLAlchemy and Pydantic types.""" + +from __future__ import annotations + +import json +from typing import Annotated, Any + +from pydantic import PlainSerializer +from sqlalchemy import String +from sqlalchemy.types import TypeDecorator +from ulid import ULID + + +class ULIDType(TypeDecorator[ULID]): + """SQLAlchemy custom type for ULID stored as 26-character strings.""" + + impl = String(26) + cache_ok = True + + def process_bind_param(self, value: ULID | str | None, dialect: Any) -> str | None: + """Convert ULID to string for database storage.""" + if value is None: + return None + if isinstance(value, str): + return str(ULID.from_str(value)) # Validate and normalize + return str(value) + + def process_result_value(self, value: str | None, dialect: Any) -> ULID | None: + """Convert string from database to ULID object.""" + if value is None: + return None + return ULID.from_str(value) + + +# Pydantic serialization helpers + + +def _is_json_serializable(value: Any) -> bool: + """Test if value can be serialized to JSON.""" + try: + json.dumps(value) + return True + except (TypeError, ValueError, OverflowError): + return False + + +def _create_serialization_metadata(value: Any, *, is_full_object: bool = True) -> dict[str, str]: + """Build metadata dict for non-serializable values with type info and truncated repr.""" + value_repr = repr(value) + max_repr_length = 200 + + if len(value_repr) > max_repr_length: + value_repr = value_repr[:max_repr_length] + "..." + + error_msg = ( + "Value is not JSON-serializable. Access the original object from storage if needed." + if is_full_object + else "Value is not JSON-serializable." 
+ ) + + return { + "_type": type(value).__name__, + "_module": type(value).__module__, + "_repr": value_repr, + "_serialization_error": error_msg, + } + + +def _serialize_with_metadata(value: Any) -> Any: + """Serialize value, replacing non-serializable values with metadata dicts.""" + # For dicts, serialize each field individually + if isinstance(value, dict): + result = {} + + for key, val in value.items(): + if _is_json_serializable(val): + result[key] = val + else: + result[key] = _create_serialization_metadata(val, is_full_object=False) + + return result + + # For non-dict values, serialize or return metadata + if _is_json_serializable(value): + return value + + return _create_serialization_metadata(value, is_full_object=True) + + +SerializableDict = Annotated[ + Any, + PlainSerializer(_serialize_with_metadata, return_type=Any), +] +"""Pydantic type that serializes dicts, replacing non-JSON-serializable values with metadata.""" diff --git a/packages/servicekit/src/servicekit/modules/__init__.py b/packages/servicekit/src/servicekit/modules/__init__.py new file mode 100644 index 0000000..de1745b --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/__init__.py @@ -0,0 +1 @@ +"""Domain features - vertical slices with models, schemas, repositories, managers, and routers.""" diff --git a/packages/servicekit/src/servicekit/modules/artifact/__init__.py b/packages/servicekit/src/servicekit/modules/artifact/__init__.py new file mode 100644 index 0000000..047a09b --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/__init__.py @@ -0,0 +1,19 @@ +"""Artifact feature - hierarchical data storage with parent-child relationships.""" + +from .manager import ArtifactManager +from .models import Artifact +from .repository import ArtifactRepository +from .router import ArtifactRouter +from .schemas import ArtifactHierarchy, ArtifactIn, ArtifactOut, ArtifactTreeNode, PandasDataFrame + +__all__ = [ + "Artifact", + "ArtifactHierarchy", + "ArtifactIn", + "ArtifactOut", + "ArtifactTreeNode", + "PandasDataFrame", + "ArtifactRepository", + "ArtifactManager", + "ArtifactRouter", +] diff --git a/packages/servicekit/src/servicekit/modules/artifact/manager.py b/packages/servicekit/src/servicekit/modules/artifact/manager.py new file mode 100644 index 0000000..fbc74cf --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/manager.py @@ -0,0 +1,163 @@ +"""Artifact manager for hierarchical data with parent-child relationships.""" + +from __future__ import annotations + +from collections import deque + +from ulid import ULID + +from servicekit.core.manager import BaseManager +from servicekit.modules.config.repository import ConfigRepository +from servicekit.modules.config.schemas import BaseConfig, ConfigOut + +from .models import Artifact +from .repository import ArtifactRepository +from .schemas import ArtifactHierarchy, ArtifactIn, ArtifactOut, ArtifactTreeNode + + +class ArtifactManager(BaseManager[Artifact, ArtifactIn, ArtifactOut, ULID]): + """Manager for Artifact entities with hierarchical tree operations.""" + + def __init__( + self, + repo: ArtifactRepository, + hierarchy: ArtifactHierarchy | None = None, + config_repo: ConfigRepository | None = None, + ) -> None: + """Initialize artifact manager with repository, hierarchy, and optional config repo.""" + super().__init__(repo, Artifact, ArtifactOut) + self.repo: ArtifactRepository = repo + self.hierarchy = hierarchy + self.config_repo = config_repo + + # Public API 
------------------------------------------------------ + + async def find_subtree(self, start_id: ULID) -> list[ArtifactTreeNode]: + """Find all artifacts in the subtree rooted at the given ID.""" + artifacts = await self.repo.find_subtree(start_id) + return [self._to_tree_node(artifact) for artifact in artifacts] + + async def expand_artifact(self, artifact_id: ULID) -> ArtifactTreeNode | None: + """Expand a single artifact with hierarchy metadata but without children.""" + artifact = await self.repo.find_by_id(artifact_id) + if artifact is None: + return None + + node = self._to_tree_node(artifact) + node.children = None + + # Populate config if available and artifact is a root + if self.config_repo is not None: + config = await self.config_repo.find_by_root_artifact_id(artifact_id) + if config is not None: + # Use model_construct to bypass validation since we don't know the concrete data type + node.config = ConfigOut[BaseConfig].model_construct( + id=config.id, + created_at=config.created_at, + updated_at=config.updated_at, + name=config.name, + data=config.data, + ) + + return node + + async def build_tree(self, start_id: ULID) -> ArtifactTreeNode | None: + """Build a hierarchical tree structure rooted at the given artifact ID.""" + artifacts = await self.find_subtree(start_id) + if not artifacts: + return None + + node_map: dict[ULID, ArtifactTreeNode] = {} + for node in artifacts: + node.children = [] + node_map[node.id] = node + + for node in artifacts: + if node.parent_id is None: + continue + parent = node_map.get(node.parent_id) + if parent is None: + continue + if parent.children is None: + parent.children = [] + parent.children.append(node) + + # Keep children as [] for leaf nodes (semantic: "loaded but empty") + # Only expand_artifact sets children=None (semantic: "not loaded") + + root = node_map.get(start_id) + + # Populate config for root node only + if root is not None and self.config_repo is not None: + config = await self.config_repo.find_by_root_artifact_id(start_id) + if config is not None: + # Use model_construct to bypass validation since we don't know the concrete data type + root.config = ConfigOut[BaseConfig].model_construct( + id=config.id, + created_at=config.created_at, + updated_at=config.updated_at, + name=config.name, + data=config.data, + ) + + return root + + # Lifecycle overrides -------------------------------------------- + + def _should_assign_field(self, field: str, value: object) -> bool: + """Prevent assigning None to level field during updates.""" + if field == "level" and value is None: + return False + return super()._should_assign_field(field, value) + + async def pre_save(self, entity: Artifact, data: ArtifactIn) -> None: + """Compute and set artifact level before saving.""" + entity.level = await self._compute_level(entity.parent_id) + + async def pre_update(self, entity: Artifact, data: ArtifactIn, old_values: dict[str, object]) -> None: + """Recalculate artifact level and cascade updates to descendants if parent changed.""" + previous_level = old_values.get("level", entity.level) + entity.level = await self._compute_level(entity.parent_id) + parent_changed = old_values.get("parent_id") != entity.parent_id + if parent_changed or previous_level != entity.level: + await self._recalculate_descendants(entity) + + # Helper utilities ------------------------------------------------ + + async def _compute_level(self, parent_id: ULID | None) -> int: + """Compute the level of an artifact based on its parent.""" + if parent_id is None: + return 0 + 
parent = await self.repo.find_by_id(parent_id) + if parent is None: + return 0 # pragma: no cover + return parent.level + 1 + + async def _recalculate_descendants(self, entity: Artifact) -> None: + """Recalculate levels for all descendants of an artifact.""" + subtree = await self.repo.find_subtree(entity.id) + by_parent: dict[ULID | None, list[Artifact]] = {} + for node in subtree: + by_parent.setdefault(node.parent_id, []).append(node) + + queue: deque[Artifact] = deque([entity]) + while queue: + current = queue.popleft() + for child in by_parent.get(current.id, []): + child.level = current.level + 1 + queue.append(child) + + def _to_tree_node(self, entity: Artifact) -> ArtifactTreeNode: + """Convert artifact entity to tree node with hierarchy metadata.""" + base = super()._to_output_schema(entity) + node = ArtifactTreeNode.from_artifact(base) + if self.hierarchy is not None: + meta = self.hierarchy.describe(node.level) + hierarchy_value = meta.get(self.hierarchy.hierarchy_key) + if hierarchy_value is not None: + node.hierarchy = str(hierarchy_value) + label_value = meta.get(self.hierarchy.label_key) + if label_value is not None: + node.level_label = str(label_value) + + return node diff --git a/packages/servicekit/src/servicekit/modules/artifact/models.py b/packages/servicekit/src/servicekit/modules/artifact/models.py new file mode 100644 index 0000000..9c7d79e --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/models.py @@ -0,0 +1,37 @@ +"""Artifact ORM model for hierarchical data storage.""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import ForeignKey, PickleType +from sqlalchemy.orm import Mapped, mapped_column, relationship +from ulid import ULID + +from servicekit.core.models import Entity +from servicekit.core.types import ULIDType + + +class Artifact(Entity): + """ORM model for hierarchical artifacts with parent-child relationships.""" + + __tablename__ = "artifacts" + + parent_id: Mapped[ULID | None] = mapped_column( + ULIDType, + ForeignKey("artifacts.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + parent: Mapped[Artifact | None] = relationship( + remote_side="Artifact.id", + back_populates="children", + ) + + children: Mapped[list[Artifact]] = relationship( + back_populates="parent", + ) + + data: Mapped[Any] = mapped_column(PickleType(protocol=4), nullable=False) + level: Mapped[int] = mapped_column(default=0, nullable=False, index=True) diff --git a/packages/servicekit/src/servicekit/modules/artifact/repository.py b/packages/servicekit/src/servicekit/modules/artifact/repository.py new file mode 100644 index 0000000..21ba254 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/repository.py @@ -0,0 +1,49 @@ +"""Artifact repository for hierarchical data access with tree traversal.""" + +from __future__ import annotations + +from typing import Iterable + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload +from ulid import ULID + +from servicekit.core.repository import BaseRepository + +from .models import Artifact + + +class ArtifactRepository(BaseRepository[Artifact, ULID]): + """Repository for Artifact entities with tree traversal operations.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize artifact repository with database session.""" + super().__init__(session, Artifact) + + async def find_by_id(self, id: ULID) -> Artifact | None: + """Find an artifact by ID with children eagerly 
loaded.""" + return await self.s.get(self.model, id, options=[selectinload(self.model.children)]) + + async def find_subtree(self, start_id: ULID) -> Iterable[Artifact]: + """Find all artifacts in the subtree rooted at the given ID using recursive CTE.""" + cte = select(self.model.id).where(self.model.id == start_id).cte(name="descendants", recursive=True) + cte = cte.union_all(select(self.model.id).where(self.model.parent_id == cte.c.id)) + + subtree_ids = (await self.s.scalars(select(cte.c.id))).all() + rows = (await self.s.scalars(select(self.model).where(self.model.id.in_(subtree_ids)))).all() + return rows + + async def get_root_artifact(self, artifact_id: ULID) -> Artifact | None: + """Find the root artifact by traversing up the parent chain.""" + artifact = await self.s.get(self.model, artifact_id) + if artifact is None: + return None + + while artifact.parent_id is not None: + parent = await self.s.get(self.model, artifact.parent_id) + if parent is None: + break + artifact = parent + + return artifact diff --git a/packages/servicekit/src/servicekit/modules/artifact/router.py b/packages/servicekit/src/servicekit/modules/artifact/router.py new file mode 100644 index 0000000..0ee96c9 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/router.py @@ -0,0 +1,114 @@ +"""Artifact CRUD router with hierarchical tree operations and config access.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from fastapi import Depends, HTTPException, status + +from servicekit.core.api.crud import CrudPermissions, CrudRouter +from servicekit.modules.config.manager import ConfigManager +from servicekit.modules.config.schemas import BaseConfig, ConfigOut + +from .manager import ArtifactManager +from .schemas import ArtifactIn, ArtifactOut, ArtifactTreeNode + + +class ArtifactRouter(CrudRouter[ArtifactIn, ArtifactOut]): + """CRUD router for Artifact entities with tree operations and config access.""" + + def __init__( + self, + prefix: str, + tags: Sequence[str], + manager_factory: Any, + entity_in_type: type[ArtifactIn], + entity_out_type: type[ArtifactOut], + permissions: CrudPermissions | None = None, + enable_config_access: bool = False, + **kwargs: Any, + ) -> None: + """Initialize artifact router with entity types and manager factory.""" + self.enable_config_access = enable_config_access + super().__init__( + prefix=prefix, + tags=list(tags), + entity_in_type=entity_in_type, + entity_out_type=entity_out_type, + manager_factory=manager_factory, + permissions=permissions, + **kwargs, + ) + + def _register_routes(self) -> None: + """Register artifact CRUD routes and tree operations.""" + super()._register_routes() + + manager_factory = self.manager_factory + + async def expand_artifact( + entity_id: str, + manager: ArtifactManager = Depends(manager_factory), + ) -> ArtifactTreeNode: + ulid_id = self._parse_ulid(entity_id) + + expanded = await manager.expand_artifact(ulid_id) + if expanded is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Artifact with id {entity_id} not found", + ) + return expanded + + async def build_tree( + entity_id: str, + manager: ArtifactManager = Depends(manager_factory), + ) -> ArtifactTreeNode: + ulid_id = self._parse_ulid(entity_id) + + tree = await manager.build_tree(ulid_id) + if tree is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Artifact with id {entity_id} not found", + ) + return tree + + self.register_entity_operation( + 
"expand", + expand_artifact, + response_model=ArtifactTreeNode, + summary="Expand artifact", + description="Get artifact with hierarchy metadata but without children", + ) + + self.register_entity_operation( + "tree", + build_tree, + response_model=ArtifactTreeNode, + summary="Build artifact tree", + description="Build hierarchical tree structure rooted at the given artifact", + ) + + if self.enable_config_access: + # Import locally to avoid circular dependency + from servicekit.api.dependencies import get_config_manager + + async def get_config( + entity_id: str, + artifact_manager: ArtifactManager = Depends(manager_factory), + config_manager: ConfigManager[BaseConfig] = Depends(get_config_manager), + ) -> ConfigOut[BaseConfig] | None: + artifact_id = self._parse_ulid(entity_id) + config = await config_manager.get_config_for_artifact(artifact_id, artifact_manager.repo) + return config + + self.register_entity_operation( + "config", + get_config, + http_method="GET", + response_model=ConfigOut[BaseConfig], + summary="Get artifact config", + description="Get the config for an artifact by walking up the tree to find the root's config", + ) diff --git a/packages/servicekit/src/servicekit/modules/artifact/schemas.py b/packages/servicekit/src/servicekit/modules/artifact/schemas.py new file mode 100644 index 0000000..b10f5ae --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/artifact/schemas.py @@ -0,0 +1,89 @@ +"""Pydantic schemas for hierarchical artifacts with config linking and tree structures.""" + +from __future__ import annotations + +from typing import Any, ClassVar, Mapping, Self + +import pandas as pd +from pydantic import BaseModel, Field +from ulid import ULID + +from servicekit.core.schemas import EntityIn, EntityOut +from servicekit.core.types import SerializableDict +from servicekit.modules.config.schemas import BaseConfig, ConfigOut + + +class ArtifactIn(EntityIn): + """Input schema for creating or updating artifacts.""" + + data: Any + parent_id: ULID | None = None + level: int | None = None + + +class ArtifactOut(EntityOut): + """Output schema for artifact entities.""" + + data: SerializableDict + parent_id: ULID | None = None + level: int + + +class ArtifactTreeNode(ArtifactOut): + """Artifact node with tree structure metadata and optional config.""" + + level_label: str | None = None + hierarchy: str | None = None + children: list["ArtifactTreeNode"] | None = None + config: "ConfigOut[BaseConfig] | None" = None + + @classmethod + def from_artifact(cls, artifact: ArtifactOut) -> Self: + """Create a tree node from an artifact output schema.""" + return cls.model_validate(artifact.model_dump()) + + +class ArtifactHierarchy(BaseModel): + """Configuration for artifact hierarchy with level labels.""" + + name: str = Field(..., description="Human readable name of this hierarchy") + level_labels: Mapping[int, str] = Field( + default_factory=dict, + description="Mapping of numeric levels to labels (0 -> 'train', etc.)", + ) + + model_config = {"frozen": True} + + hierarchy_key: ClassVar[str] = "hierarchy" + depth_key: ClassVar[str] = "level_depth" + label_key: ClassVar[str] = "level_label" + + def label_for(self, level: int) -> str: + """Get the label for a given level or return default.""" + return self.level_labels.get(level, f"level_{level}") + + def describe(self, level: int) -> dict[str, Any]: + """Get hierarchy metadata dict for a given level.""" + return { + self.hierarchy_key: self.name, + self.depth_key: level, + self.label_key: self.label_for(level), + } + + 
+class PandasDataFrame(BaseModel): + """Pydantic schema for serializing pandas DataFrames.""" + + columns: list[str] + data: list[list[Any]] + + @classmethod + def from_dataframe(cls, df: pd.DataFrame) -> Self: + """Create schema from pandas DataFrame.""" + if not isinstance(df, pd.DataFrame): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError(f"Expected a pandas DataFrame, but got {type(df)}") + return cls(columns=df.columns.tolist(), data=df.values.tolist()) + + def to_dataframe(self) -> pd.DataFrame: + """Convert schema back to pandas DataFrame.""" + return pd.DataFrame(self.data, columns=self.columns) diff --git a/packages/servicekit/src/servicekit/modules/config/__init__.py b/packages/servicekit/src/servicekit/modules/config/__init__.py new file mode 100644 index 0000000..4ee6ab2 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/__init__.py @@ -0,0 +1,20 @@ +"""Config feature - key-value configuration with JSON data storage.""" + +from .manager import ConfigManager +from .models import Config, ConfigArtifact +from .repository import ConfigRepository +from .router import ConfigRouter +from .schemas import BaseConfig, ConfigIn, ConfigOut, LinkArtifactRequest, UnlinkArtifactRequest + +__all__ = [ + "Config", + "ConfigArtifact", + "BaseConfig", + "ConfigIn", + "ConfigOut", + "LinkArtifactRequest", + "UnlinkArtifactRequest", + "ConfigRepository", + "ConfigManager", + "ConfigRouter", +] diff --git a/packages/servicekit/src/servicekit/modules/config/manager.py b/packages/servicekit/src/servicekit/modules/config/manager.py new file mode 100644 index 0000000..fb92900 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/manager.py @@ -0,0 +1,63 @@ +"""Config manager for CRUD operations and artifact linking.""" + +from __future__ import annotations + +from ulid import ULID + +from servicekit.core.manager import BaseManager +from servicekit.modules.artifact.repository import ArtifactRepository +from servicekit.modules.artifact.schemas import ArtifactOut + +from .models import Config +from .repository import ConfigRepository +from .schemas import BaseConfig, ConfigIn, ConfigOut + + +class ConfigManager[DataT: BaseConfig](BaseManager[Config, ConfigIn[DataT], ConfigOut[DataT], ULID]): + """Manager for Config entities with artifact linking operations.""" + + def __init__(self, repo: ConfigRepository, data_cls: type[DataT]) -> None: + """Initialize config manager with repository and data class.""" + super().__init__(repo, Config, ConfigOut) + self.repo: ConfigRepository = repo + self.data_cls = data_cls + + async def find_by_name(self, name: str) -> ConfigOut[DataT] | None: + """Find a config by its unique name.""" + config = await self.repo.find_by_name(name) + if config: + return self._to_output_schema(config) + return None + + async def link_artifact(self, config_id: ULID, artifact_id: ULID) -> None: + """Link a config to a root artifact.""" + await self.repo.link_artifact(config_id, artifact_id) + await self.repo.commit() + + async def unlink_artifact(self, artifact_id: ULID) -> None: + """Unlink an artifact from its config.""" + await self.repo.unlink_artifact(artifact_id) + await self.repo.commit() + + async def get_config_for_artifact( + self, artifact_id: ULID, artifact_repo: ArtifactRepository + ) -> ConfigOut[DataT] | None: + """Get the config for an artifact by traversing to its root.""" + root = await artifact_repo.get_root_artifact(artifact_id) + if root is None: + return None + + config = await self.repo.find_by_root_artifact_id(root.id) 
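+        # A root artifact may exist without a linked config; propagate None rather than raising.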
+ if config is None: + return None + + return self._to_output_schema(config) + + async def get_linked_artifacts(self, config_id: ULID) -> list[ArtifactOut]: + """Get all root artifacts linked to a config.""" + artifacts = await self.repo.find_artifacts_for_config(config_id) + return [ArtifactOut.model_validate(artifact, from_attributes=True) for artifact in artifacts] + + def _to_output_schema(self, entity: Config) -> ConfigOut[DataT]: + """Convert ORM entity to output schema with proper data class validation.""" + return ConfigOut[DataT].model_validate(entity, from_attributes=True, context={"data_cls": self.data_cls}) diff --git a/packages/servicekit/src/servicekit/modules/config/models.py b/packages/servicekit/src/servicekit/modules/config/models.py new file mode 100644 index 0000000..1c8fc54 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/models.py @@ -0,0 +1,61 @@ +"""Config ORM models for key-value configuration storage and artifact linking.""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import ForeignKey, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.types import JSON +from ulid import ULID + +from servicekit.core.models import Base, Entity +from servicekit.core.types import ULIDType + +from .schemas import BaseConfig + + +class Config(Entity): + """ORM model for configuration with JSON data storage.""" + + __tablename__ = "configs" + + name: Mapped[str] = mapped_column(index=True) + _data_json: Mapped[dict[str, Any]] = mapped_column("data", JSON, nullable=False) + + @property + def data(self) -> dict[str, Any]: + """Return JSON data as dict.""" + return self._data_json + + @data.setter + def data(self, value: BaseConfig | dict[str, Any]) -> None: + """Serialize Pydantic model to JSON or store dict directly.""" + if isinstance(value, dict): + self._data_json = value + elif hasattr(value, "model_dump") and callable(value.model_dump): + # BaseConfig or other Pydantic model + self._data_json = value.model_dump(mode="json") + else: + raise TypeError(f"data must be a BaseConfig subclass or dict, got {type(value)}") + + +class ConfigArtifact(Base): + """Junction table linking Configs to root Artifacts.""" + + __tablename__ = "config_artifacts" + + config_id: Mapped[ULID] = mapped_column( + ULIDType, + ForeignKey("configs.id", ondelete="CASCADE"), + primary_key=True, + ) + + artifact_id: Mapped[ULID] = mapped_column( + ULIDType, + ForeignKey("artifacts.id", ondelete="CASCADE"), + primary_key=True, + unique=True, + ) + + __table_args__ = (UniqueConstraint("artifact_id", name="uq_artifact_id"),) diff --git a/packages/servicekit/src/servicekit/modules/config/repository.py b/packages/servicekit/src/servicekit/modules/config/repository.py new file mode 100644 index 0000000..3bbdcf7 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/repository.py @@ -0,0 +1,76 @@ +"""Config repository for database access and artifact linking.""" + +from __future__ import annotations + +from sqlalchemy import delete as sql_delete +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID + +from servicekit.core.repository import BaseRepository +from servicekit.modules.artifact.models import Artifact + +from .models import Config, ConfigArtifact + + +class ConfigRepository(BaseRepository[Config, ULID]): + """Repository for Config entities with artifact linking operations.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize config 
repository with database session.""" + super().__init__(session, Config) + + async def find_by_name(self, name: str) -> Config | None: + """Find a config by its unique name.""" + result = await self.s.scalars(select(self.model).where(self.model.name == name)) + return result.one_or_none() + + async def link_artifact(self, config_id: ULID, artifact_id: ULID) -> None: + """Link a config to a root artifact.""" + artifact = await self.s.get(Artifact, artifact_id) + if artifact is None: + raise ValueError(f"Artifact {artifact_id} not found") + if artifact.parent_id is not None: + raise ValueError(f"Artifact {artifact_id} is not a root artifact (parent_id={artifact.parent_id})") + + link = ConfigArtifact(config_id=config_id, artifact_id=artifact_id) + self.s.add(link) + + async def unlink_artifact(self, artifact_id: ULID) -> None: + """Unlink an artifact from its config.""" + stmt = sql_delete(ConfigArtifact).where(ConfigArtifact.artifact_id == artifact_id) + await self.s.execute(stmt) + + async def delete_by_id(self, id: ULID) -> None: + """Delete a config and cascade delete all linked artifact trees.""" + from servicekit.modules.artifact.repository import ArtifactRepository + + linked_artifacts = await self.find_artifacts_for_config(id) + + artifact_repo = ArtifactRepository(self.s) + for root_artifact in linked_artifacts: + subtree = await artifact_repo.find_subtree(root_artifact.id) + for artifact in subtree: + await self.s.delete(artifact) + + await super().delete_by_id(id) + + async def find_by_root_artifact_id(self, artifact_id: ULID) -> Config | None: + """Find the config linked to a root artifact.""" + stmt = ( + select(Config) + .join(ConfigArtifact, Config.id == ConfigArtifact.config_id) + .where(ConfigArtifact.artifact_id == artifact_id) + ) + result = await self.s.scalars(stmt) + return result.one_or_none() + + async def find_artifacts_for_config(self, config_id: ULID) -> list[Artifact]: + """Find all root artifacts linked to a config.""" + stmt = ( + select(Artifact) + .join(ConfigArtifact, Artifact.id == ConfigArtifact.artifact_id) + .where(ConfigArtifact.config_id == config_id) + ) + result = await self.s.scalars(stmt) + return list(result.all()) diff --git a/packages/servicekit/src/servicekit/modules/config/router.py b/packages/servicekit/src/servicekit/modules/config/router.py new file mode 100644 index 0000000..122a3e8 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/router.py @@ -0,0 +1,112 @@ +"""Config CRUD router with artifact linking operations.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from fastapi import Depends, HTTPException, status + +from servicekit.core.api.crud import CrudPermissions, CrudRouter +from servicekit.modules.artifact.schemas import ArtifactOut + +from .manager import ConfigManager +from .schemas import BaseConfig, ConfigIn, ConfigOut, LinkArtifactRequest, UnlinkArtifactRequest + + +class ConfigRouter(CrudRouter[ConfigIn[BaseConfig], ConfigOut[BaseConfig]]): + """CRUD router for Config entities with artifact linking operations.""" + + def __init__( + self, + prefix: str, + tags: Sequence[str], + manager_factory: Any, + entity_in_type: type[ConfigIn[BaseConfig]], + entity_out_type: type[ConfigOut[BaseConfig]], + permissions: CrudPermissions | None = None, + enable_artifact_operations: bool = False, + **kwargs: Any, + ) -> None: + """Initialize config router with entity types and manager factory.""" + self.enable_artifact_operations = enable_artifact_operations + 
super().__init__( + prefix=prefix, + tags=list(tags), + entity_in_type=entity_in_type, + entity_out_type=entity_out_type, + manager_factory=manager_factory, + permissions=permissions, + **kwargs, + ) + + def _register_routes(self) -> None: + """Register config CRUD routes and artifact linking operations.""" + super()._register_routes() + + if not self.enable_artifact_operations: + return + + manager_factory = self.manager_factory + + async def link_artifact( + entity_id: str, + request: LinkArtifactRequest, + manager: ConfigManager[BaseConfig] = Depends(manager_factory), + ) -> None: + config_id = self._parse_ulid(entity_id) + + try: + await manager.link_artifact(config_id, request.artifact_id) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + async def unlink_artifact( + entity_id: str, + request: UnlinkArtifactRequest, + manager: ConfigManager[BaseConfig] = Depends(manager_factory), + ) -> None: + try: + await manager.unlink_artifact(request.artifact_id) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + async def get_linked_artifacts( + entity_id: str, + manager: ConfigManager[BaseConfig] = Depends(manager_factory), + ) -> list[ArtifactOut]: + config_id = self._parse_ulid(entity_id) + return await manager.get_linked_artifacts(config_id) + + self.register_entity_operation( + "link-artifact", + link_artifact, + http_method="POST", + status_code=status.HTTP_204_NO_CONTENT, + summary="Link artifact to config", + description="Link a config to a root artifact (parent_id IS NULL)", + ) + + self.register_entity_operation( + "unlink-artifact", + unlink_artifact, + http_method="POST", + status_code=status.HTTP_204_NO_CONTENT, + summary="Unlink artifact from config", + description="Remove the link between a config and an artifact", + ) + + self.register_entity_operation( + "artifacts", + get_linked_artifacts, + http_method="GET", + response_model=list[ArtifactOut], + summary="Get linked artifacts", + description="Get all root artifacts linked to this config", + ) diff --git a/packages/servicekit/src/servicekit/modules/config/schemas.py b/packages/servicekit/src/servicekit/modules/config/schemas.py new file mode 100644 index 0000000..4f589df --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/config/schemas.py @@ -0,0 +1,64 @@ +"""Config schemas for key-value configuration with JSON data.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, field_serializer, field_validator +from pydantic_core.core_schema import ValidationInfo +from ulid import ULID + +from servicekit.core.schemas import EntityIn, EntityOut + + +class BaseConfig(BaseModel): + """Base class for configuration schemas with arbitrary extra fields allowed.""" + + model_config = {"extra": "allow"} + + +class ConfigIn[DataT: BaseConfig](EntityIn): + """Input schema for creating or updating configurations.""" + + name: str + data: DataT + + +class ConfigOut[DataT: BaseConfig](EntityOut): + """Output schema for configuration entities.""" + + name: str + data: DataT + + model_config = {"ser_json_timedelta": "float", "ser_json_bytes": "base64"} + + @field_validator("data", mode="before") + @classmethod + def convert_dict_to_model(cls, v: Any, info: ValidationInfo) -> Any: + """Convert dict to BaseConfig model if data_cls is provided in validation context.""" + if isinstance(v, BaseConfig): + return v + if isinstance(v, dict): + if info.context and 
"data_cls" in info.context: + data_cls = info.context["data_cls"] + return data_cls.model_validate(v) + return v + + @field_serializer("data", when_used="json") + def serialize_data(self, value: DataT) -> dict[str, Any]: + """Serialize BaseConfig data to JSON dict.""" + if isinstance(value, BaseConfig): # pyright: ignore[reportUnnecessaryIsInstance] + return value.model_dump(mode="json") + return value + + +class LinkArtifactRequest(BaseModel): + """Request schema for linking an artifact to a config.""" + + artifact_id: ULID + + +class UnlinkArtifactRequest(BaseModel): + """Request schema for unlinking an artifact from a config.""" + + artifact_id: ULID diff --git a/packages/servicekit/src/servicekit/modules/ml/__init__.py b/packages/servicekit/src/servicekit/modules/ml/__init__.py new file mode 100644 index 0000000..0eb6b5f --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/ml/__init__.py @@ -0,0 +1,37 @@ +"""ML module placeholder - functionality moved to chapkit package. + +This stub exists only for backwards compatibility and type checking. +The actual ML implementation is in the chapkit package. +""" + +from typing import Any, Protocol, runtime_checkable + +__all__ = ["MLManager", "MLRouter", "ModelRunnerProtocol"] + + +@runtime_checkable +class ModelRunnerProtocol(Protocol): + """Protocol for ML model runners (stub).""" + + pass + + +class MLManager: + """ML manager stub - use chapkit package for ML functionality.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize ML manager stub.""" + raise NotImplementedError( + "ML functionality has been moved to the chapkit package. Please install and use chapkit for ML operations." + ) + + +class MLRouter: + """ML router stub - use chapkit package for ML functionality.""" + + @staticmethod + def create(*args: Any, **kwargs: Any) -> Any: + """Create ML router stub.""" + raise NotImplementedError( + "ML functionality has been moved to the chapkit package. Please install and use chapkit for ML operations." 
+ ) diff --git a/packages/servicekit/src/servicekit/modules/task/__init__.py b/packages/servicekit/src/servicekit/modules/task/__init__.py new file mode 100644 index 0000000..5117ed7 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/__init__.py @@ -0,0 +1,16 @@ +"""Task feature - reusable command templates for task execution.""" + +from .manager import TaskManager +from .models import Task +from .repository import TaskRepository +from .router import TaskRouter +from .schemas import TaskIn, TaskOut + +__all__ = [ + "Task", + "TaskIn", + "TaskOut", + "TaskRepository", + "TaskManager", + "TaskRouter", +] diff --git a/packages/servicekit/src/servicekit/modules/task/manager.py b/packages/servicekit/src/servicekit/modules/task/manager.py new file mode 100644 index 0000000..7ad2ace --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/manager.py @@ -0,0 +1,112 @@ +"""Task manager for reusable command templates with artifact-based execution results.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from ulid import ULID + +from servicekit.core import Database +from servicekit.core.manager import BaseManager +from servicekit.core.scheduler import JobScheduler +from servicekit.modules.artifact import ArtifactIn, ArtifactManager, ArtifactRepository + +from .models import Task +from .repository import TaskRepository +from .schemas import TaskIn, TaskOut + + +class TaskManager(BaseManager[Task, TaskIn, TaskOut, ULID]): + """Manager for Task template entities with artifact-based execution.""" + + def __init__( + self, + repo: TaskRepository, + scheduler: JobScheduler | None = None, + database: Database | None = None, + artifact_manager: ArtifactManager | None = None, + ) -> None: + """Initialize task manager with repository, scheduler, database, and artifact manager.""" + super().__init__(repo, Task, TaskOut) + self.repo: TaskRepository = repo + self.scheduler = scheduler + self.database = database + self.artifact_manager = artifact_manager + + async def execute_task(self, task_id: ULID) -> ULID: + """Execute a task by submitting it to the scheduler and return the job ID.""" + if self.scheduler is None: + raise ValueError("Task execution requires a scheduler. Use ServiceBuilder.with_jobs() to enable.") + + if self.artifact_manager is None: + raise ValueError( + "Task execution requires artifacts. Use ServiceBuilder.with_artifacts() before with_tasks()." 
+ ) + + task = await self.repo.find_by_id(task_id) + if task is None: + raise ValueError(f"Task {task_id} not found") + + # Submit job to scheduler + job_id = await self.scheduler.add_job(self._execute_command, task_id) + + return job_id + + async def _execute_command(self, task_id: ULID) -> ULID: + """Execute command and return artifact_id containing results.""" + if self.database is None: + raise RuntimeError("Database instance required for task execution") + + if self.artifact_manager is None: + raise RuntimeError("ArtifactManager instance required for task execution") + + # Fetch task and serialize snapshot before execution + async with self.database.session() as session: + task_repo = TaskRepository(session) + task = await task_repo.find_by_id(task_id) + if task is None: + raise ValueError(f"Task {task_id} not found") + + # Capture task snapshot + task_snapshot = { + "id": str(task.id), + "command": task.command, + "created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat(), + } + + # Execute command using asyncio subprocess + process = await asyncio.create_subprocess_shell( + task.command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + # Wait for completion and capture output + stdout_bytes, stderr_bytes = await process.communicate() + + # Decode outputs + stdout_text = stdout_bytes.decode("utf-8") if stdout_bytes else "" + stderr_text = stderr_bytes.decode("utf-8") if stderr_bytes else "" + + # Create artifact with execution results + result_data: dict[str, Any] = { + "task": task_snapshot, + "stdout": stdout_text, + "stderr": stderr_text, + "exit_code": process.returncode, + } + + async with self.database.session() as session: + artifact_repo = ArtifactRepository(session) + artifact_mgr = ArtifactManager(artifact_repo) + + artifact_out = await artifact_mgr.save( + ArtifactIn( + data=result_data, + parent_id=None, + ) + ) + + return artifact_out.id diff --git a/packages/servicekit/src/servicekit/modules/task/models.py b/packages/servicekit/src/servicekit/modules/task/models.py new file mode 100644 index 0000000..eb1c682 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/models.py @@ -0,0 +1,16 @@ +"""Task ORM model for reusable command templates.""" + +from __future__ import annotations + +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.types import Text + +from servicekit.core.models import Entity + + +class Task(Entity): + """ORM model for reusable task templates containing commands to execute.""" + + __tablename__ = "tasks" + + command: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/packages/servicekit/src/servicekit/modules/task/repository.py b/packages/servicekit/src/servicekit/modules/task/repository.py new file mode 100644 index 0000000..e97ea82 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/repository.py @@ -0,0 +1,18 @@ +"""Task repository for database access and querying.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID + +from servicekit.core.repository import BaseRepository + +from .models import Task + + +class TaskRepository(BaseRepository[Task, ULID]): + """Repository for Task template entities.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize task repository with database session.""" + super().__init__(session, Task) diff --git a/packages/servicekit/src/servicekit/modules/task/router.py b/packages/servicekit/src/servicekit/modules/task/router.py new 
file mode 100644 index 0000000..d55043e --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/router.py @@ -0,0 +1,86 @@ +"""Task CRUD router with execution operation.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from fastapi import Depends, HTTPException, status +from pydantic import BaseModel, Field + +from servicekit.core.api.crud import CrudPermissions, CrudRouter + +from .manager import TaskManager +from .schemas import TaskIn, TaskOut + + +class TaskExecuteResponse(BaseModel): + """Response schema for task execution.""" + + job_id: str = Field(description="ID of the scheduler job") + message: str = Field(description="Human-readable message") + + +class TaskRouter(CrudRouter[TaskIn, TaskOut]): + """CRUD router for Task entities with execution operation.""" + + def __init__( + self, + prefix: str, + tags: Sequence[str], + manager_factory: Any, + entity_in_type: type[TaskIn], + entity_out_type: type[TaskOut], + permissions: CrudPermissions | None = None, + **kwargs: Any, + ) -> None: + """Initialize task router with entity types and manager factory.""" + super().__init__( + prefix=prefix, + tags=list(tags), + entity_in_type=entity_in_type, + entity_out_type=entity_out_type, + manager_factory=manager_factory, + permissions=permissions, + **kwargs, + ) + + def _register_routes(self) -> None: + """Register task CRUD routes and execution operation.""" + super()._register_routes() + + manager_factory = self.manager_factory + + async def execute_task( + entity_id: str, + manager: TaskManager = Depends(manager_factory), + ) -> TaskExecuteResponse: + """Execute a task asynchronously via the job scheduler.""" + task_id = self._parse_ulid(entity_id) + + try: + job_id = await manager.execute_task(task_id) + return TaskExecuteResponse( + job_id=str(job_id), + message=f"Task submitted for execution. 
Job ID: {job_id}", + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except RuntimeError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e), + ) + + self.register_entity_operation( + "execute", + execute_task, + http_method="POST", + response_model=TaskExecuteResponse, + status_code=status.HTTP_202_ACCEPTED, + summary="Execute task", + description="Submit the task to the scheduler for execution", + ) diff --git a/packages/servicekit/src/servicekit/modules/task/schemas.py b/packages/servicekit/src/servicekit/modules/task/schemas.py new file mode 100644 index 0000000..f3d8878 --- /dev/null +++ b/packages/servicekit/src/servicekit/modules/task/schemas.py @@ -0,0 +1,19 @@ +"""Task schemas for reusable command templates.""" + +from __future__ import annotations + +from pydantic import Field + +from servicekit.core.schemas import EntityIn, EntityOut + + +class TaskIn(EntityIn): + """Input schema for creating or updating task templates.""" + + command: str = Field(description="Shell command to execute") + + +class TaskOut(EntityOut): + """Output schema for task template entities.""" + + command: str = Field(description="Shell command to execute") diff --git a/packages/servicekit/src/servicekit/py.typed b/packages/servicekit/src/servicekit/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/packages/servicekit/tests/__init__.py b/packages/servicekit/tests/__init__.py new file mode 100644 index 0000000..2c65479 --- /dev/null +++ b/packages/servicekit/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for chapkit.""" diff --git a/packages/servicekit/tests/_stubs.py b/packages/servicekit/tests/_stubs.py new file mode 100644 index 0000000..e99a55f --- /dev/null +++ b/packages/servicekit/tests/_stubs.py @@ -0,0 +1,148 @@ +"""Shared test stubs used across API router tests.""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Sequence +from typing import Any, Generic, TypeVar + +from servicekit import ArtifactIn, ArtifactOut, ArtifactTreeNode, BaseConfig, ConfigIn, ConfigOut +from servicekit.core import Manager +from ulid import ULID + +ConfigDataT = TypeVar("ConfigDataT", bound=BaseConfig) + + +class ConfigManagerStub(Manager[ConfigIn[ConfigDataT], ConfigOut[ConfigDataT], ULID], Generic[ConfigDataT]): + """Config manager stub supporting name lookups for tests.""" + + def __init__( + self, + *, + items: dict[str, ConfigOut[ConfigDataT]] | None = None, + linked_artifacts: dict[ULID, list[ArtifactOut]] | None = None, + ) -> None: + self._items_by_name = dict(items or {}) + self._items_by_id = {item.id: item for item in self._items_by_name.values()} + self._linked_artifacts = linked_artifacts or {} + self._link_error: str | None = None + + def set_link_error(self, error: str) -> None: + """Set an error to be raised on link operations.""" + self._link_error = error + + async def link_artifact(self, config_id: ULID, artifact_id: ULID) -> None: + """Link an artifact to a config.""" + if self._link_error: + raise ValueError(self._link_error) + + async def unlink_artifact(self, artifact_id: ULID) -> None: + """Unlink an artifact from its config.""" + if self._link_error: + raise Exception(self._link_error) + + async def get_linked_artifacts(self, config_id: ULID) -> list[ArtifactOut]: + """Get all artifacts linked to a config.""" + return self._linked_artifacts.get(config_id, []) + + async def get_config_for_artifact(self, artifact_id: ULID, artifact_repo: Any) 
-> ConfigOut[ConfigDataT] | None: + """Get config for an artifact by traversing to root.""" + # For stub purposes, return the first config in items + items = list(self._items_by_id.values()) + return items[0] if items else None + + async def find_by_name(self, name: str) -> ConfigOut[ConfigDataT] | None: + return self._items_by_name.get(name) + + async def save(self, data: ConfigIn[ConfigDataT]) -> ConfigOut[ConfigDataT]: + raise NotImplementedError + + async def save_all(self, items: Iterable[ConfigIn[ConfigDataT]]) -> list[ConfigOut[ConfigDataT]]: + raise NotImplementedError + + async def find_all(self) -> list[ConfigOut[ConfigDataT]]: + return list(self._items_by_id.values()) + + async def find_paginated(self, page: int, size: int) -> tuple[list[ConfigOut[ConfigDataT]], int]: + all_items = list(self._items_by_id.values()) + offset = (page - 1) * size + paginated_items = all_items[offset : offset + size] + return paginated_items, len(all_items) + + async def find_all_by_id(self, ids: Sequence[ULID]) -> list[ConfigOut[ConfigDataT]]: + raise NotImplementedError + + async def find_by_id(self, id: ULID) -> ConfigOut[ConfigDataT] | None: + return self._items_by_id.get(id) + + async def exists_by_id(self, id: ULID) -> bool: + return id in self._items_by_id + + async def delete_by_id(self, id: ULID) -> None: + self._items_by_id.pop(id, None) + # Keep name map in sync + names_to_remove = [name for name, item in self._items_by_name.items() if item.id == id] + for name in names_to_remove: + self._items_by_name.pop(name, None) + + async def delete_all(self) -> None: + self._items_by_id.clear() + self._items_by_name.clear() + + async def delete_all_by_id(self, ids: Sequence[ULID]) -> None: + raise NotImplementedError + + async def count(self) -> int: + return len(self._items_by_id) + + +class ArtifactManagerStub(Manager[ArtifactIn, ArtifactOut, ULID]): + """Artifact manager stub providing tree data for tests.""" + + def __init__(self, *, trees: dict[ULID, ArtifactTreeNode] | None = None) -> None: + self._trees = trees or {} + self.repo = None # Will be set if needed for config lookups + + async def build_tree(self, id: ULID) -> ArtifactTreeNode | None: + return self._trees.get(id) + + async def save(self, data: ArtifactIn) -> ArtifactOut: + raise NotImplementedError + + async def save_all(self, items: Iterable[ArtifactIn]) -> list[ArtifactOut]: + raise NotImplementedError + + async def find_all(self) -> list[ArtifactOut]: + return [] + + async def find_paginated(self, page: int, size: int) -> tuple[list[ArtifactOut], int]: + return [], 0 + + async def find_all_by_id(self, ids: Sequence[ULID]) -> list[ArtifactOut]: + raise NotImplementedError + + async def find_by_id(self, id: ULID) -> ArtifactOut | None: + raise NotImplementedError + + async def exists_by_id(self, id: ULID) -> bool: + raise NotImplementedError + + async def delete_by_id(self, id: ULID) -> None: + raise NotImplementedError + + async def delete_all(self) -> None: + raise NotImplementedError + + async def delete_all_by_id(self, ids: Sequence[ULID]) -> None: + raise NotImplementedError + + async def count(self) -> int: + return len(self._trees) + + +def singleton_factory(instance: Any) -> Callable[[], Any]: + """Return a dependency factory that always yields the provided instance.""" + + def _provide() -> Any: + return instance + + return _provide diff --git a/packages/servicekit/tests/conftest.py b/packages/servicekit/tests/conftest.py new file mode 100644 index 0000000..cd9d1a6 --- /dev/null +++ b/packages/servicekit/tests/conftest.py 
@@ -0,0 +1,12 @@ +"""Test configuration and shared fixtures.""" + +from servicekit import BaseConfig + + +class DemoConfig(BaseConfig): + """Concrete config schema for testing.""" + + x: int + y: int + z: int + tags: list[str] diff --git a/packages/servicekit/tests/test_api_artifact_serialization.py b/packages/servicekit/tests/test_api_artifact_serialization.py new file mode 100644 index 0000000..50f5c2e --- /dev/null +++ b/packages/servicekit/tests/test_api_artifact_serialization.py @@ -0,0 +1,192 @@ +"""Integration tests for artifact API with non-JSON-serializable data.""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from servicekit import ArtifactIn, ArtifactOut, SqliteDatabaseBuilder +from servicekit.api import ArtifactRouter, add_error_handlers, get_artifact_manager +from servicekit.core.api.dependencies import set_database + + +class NonSerializableObject: + """Custom object that cannot be JSON-serialized.""" + + def __init__(self, value: str) -> None: + self.value = value + + def __repr__(self) -> str: + return f"NonSerializableObject({self.value!r})" + + +class TestArtifactAPIWithNonSerializableData: + """Test that the API handles non-serializable artifact data gracefully.""" + + @pytest.mark.asyncio + async def test_create_and_retrieve_artifact_with_custom_object(self) -> None: + """API should handle artifacts with non-serializable data without crashing.""" + # Setup + app = FastAPI() + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + set_database(db) + + try: + # Create router + artifact_router = ArtifactRouter.create( + prefix="/api/v1/artifacts", + tags=["Artifacts"], + manager_factory=get_artifact_manager, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + app.include_router(artifact_router) + add_error_handlers(app) + + # Create an artifact directly via the database (bypassing API validation) + async with db.session() as session: + from servicekit import ArtifactManager, ArtifactRepository + + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Save artifact with non-serializable data + custom_obj = NonSerializableObject("test_value") + artifact_in = ArtifactIn(data=custom_obj) + saved_artifact = await manager.save(artifact_in) + artifact_id = saved_artifact.id + + # Now try to retrieve via API - should not crash + with TestClient(app) as client: + response = client.get(f"/api/v1/artifacts/{artifact_id}") + + # Should return 200 OK (not crash with 500) + assert response.status_code == 200 + + # Should return metadata instead of the actual object + data = response.json() + assert data["id"] == str(artifact_id) + assert isinstance(data["data"], dict) + assert data["data"]["_type"] == "NonSerializableObject" + assert data["data"]["_module"] == __name__ + assert "NonSerializableObject('test_value')" in data["data"]["_repr"] + assert "_serialization_error" in data["data"] + + finally: + await db.dispose() + + @pytest.mark.asyncio + async def test_list_artifacts_with_mixed_data(self) -> None: + """API should handle listing artifacts with both serializable and non-serializable data.""" + # Setup + app = FastAPI() + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + set_database(db) + + try: + # Create router + artifact_router = ArtifactRouter.create( + prefix="/api/v1/artifacts", + tags=["Artifacts"], + manager_factory=get_artifact_manager, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + 
app.include_router(artifact_router) + add_error_handlers(app) + + # Create artifacts with different data types + async with db.session() as session: + from servicekit import ArtifactManager, ArtifactRepository + + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Artifact 1: JSON-serializable data + await manager.save(ArtifactIn(data={"type": "json", "value": 123})) + + # Artifact 2: Non-serializable data + custom_obj = NonSerializableObject("test") + await manager.save(ArtifactIn(data=custom_obj)) + + # Artifact 3: Another JSON-serializable + await manager.save(ArtifactIn(data=["list", "of", "values"])) + + # List all artifacts - should not crash + with TestClient(app) as client: + response = client.get("/api/v1/artifacts") + + # Should return 200 OK + assert response.status_code == 200 + + artifacts = response.json() + assert len(artifacts) == 3 + + # First artifact - JSON data unchanged + assert artifacts[0]["data"] == {"type": "json", "value": 123} + + # Second artifact - metadata returned + assert artifacts[1]["data"]["_type"] == "NonSerializableObject" + + # Third artifact - JSON data unchanged + assert artifacts[2]["data"] == ["list", "of", "values"] + + finally: + await db.dispose() + + @pytest.mark.asyncio + async def test_tree_operation_with_non_serializable_data(self) -> None: + """Tree operation should handle non-serializable data in nested artifacts.""" + # Setup + app = FastAPI() + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + set_database(db) + + try: + # Create router + artifact_router = ArtifactRouter.create( + prefix="/api/v1/artifacts", + tags=["Artifacts"], + manager_factory=get_artifact_manager, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + app.include_router(artifact_router) + add_error_handlers(app) + + # Create a tree with non-serializable data + root_id = None + async with db.session() as session: + from servicekit import ArtifactManager, ArtifactRepository + + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Root with non-serializable data + custom_obj = NonSerializableObject("root") + root = await manager.save(ArtifactIn(data=custom_obj)) + root_id = root.id + + # Child with JSON data + await manager.save(ArtifactIn(data={"child": "data"}, parent_id=root_id)) + + # Get tree - should not crash + with TestClient(app) as client: + response = client.get(f"/api/v1/artifacts/{root_id}/$tree") + + # Should return 200 OK + assert response.status_code == 200 + + tree = response.json() + + # Root should have metadata + assert tree["data"]["_type"] == "NonSerializableObject" + + # Child should have JSON data + assert tree["children"][0]["data"] == {"child": "data"} + + finally: + await db.dispose() diff --git a/packages/servicekit/tests/test_api_base.py b/packages/servicekit/tests/test_api_base.py new file mode 100644 index 0000000..781ce41 --- /dev/null +++ b/packages/servicekit/tests/test_api_base.py @@ -0,0 +1,55 @@ +"""Tests for the base API router abstraction.""" + +from typing import ClassVar, Sequence + +from fastapi import APIRouter, FastAPI +from fastapi.testclient import TestClient +from servicekit.core.api.router import Router + + +class TrackingRouter(Router): + """Test router that counts how many times routes are registered.""" + + register_calls: ClassVar[int] = 0 + + def _register_routes(self) -> None: + type(self).register_calls += 1 + + @self.router.get("/") + async def read_root() -> dict[str, str]: + return {"status": "ok"} + + +class NoopRouter(Router): + 
"""Router used to validate APIRouter construction details.""" + + def _register_routes(self) -> None: + return None + + +def test_router_create_calls_register_routes_once() -> None: + TrackingRouter.register_calls = 0 + + router = TrackingRouter.create(prefix="/tracking", tags=["tracking"]) + + assert isinstance(router, APIRouter) + assert TrackingRouter.register_calls == 1 + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + response = client.get("/tracking/") + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_router_create_converts_tags_to_list_and_applies_kwargs() -> None: + tags: Sequence[str] = ("alpha", "beta") + + router = NoopRouter.create(prefix="/noop", tags=tags, deprecated=True) + + assert isinstance(router, APIRouter) + assert router.tags == ["alpha", "beta"] + assert router.prefix == "/noop" + assert router.deprecated is True diff --git a/packages/servicekit/tests/test_api_crud.py b/packages/servicekit/tests/test_api_crud.py new file mode 100644 index 0000000..6006ca1 --- /dev/null +++ b/packages/servicekit/tests/test_api_crud.py @@ -0,0 +1,445 @@ +"""Tests for the CRUD router abstraction.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence + +import pytest +from fastapi import Depends, FastAPI, HTTPException, status +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient +from pydantic import BaseModel +from servicekit.core.api.crud import CrudRouter +from servicekit.core.manager import Manager +from ulid import ULID + + +class ItemIn(BaseModel): + """Input schema with optional ID to support updates.""" + + name: str + description: str | None = None + id: ULID | None = None + + +class ItemOut(BaseModel): + """Output schema with ULID identifier.""" + + id: ULID + name: str + description: str | None = None + + +class FakeManager(Manager[ItemIn, ItemOut, ULID]): + """Minimal async manager used for exercising router behaviour.""" + + def __init__(self) -> None: + self.entities: dict[ULID, ItemOut] = {} + + async def save(self, data: ItemIn) -> ItemOut: + entity_id = data.id or ULID() + entity = ItemOut(id=entity_id, name=data.name, description=data.description) + self.entities[entity_id] = entity + return entity + + async def find_all(self) -> list[ItemOut]: + return list(self.entities.values()) + + async def find_paginated(self, page: int, size: int) -> tuple[list[ItemOut], int]: + all_items = list(self.entities.values()) + offset = (page - 1) * size + paginated_items = all_items[offset : offset + size] + return paginated_items, len(all_items) + + async def find_all_by_id(self, ids: Sequence[ULID]) -> list[ItemOut]: + id_set = set(ids) + return [entity for entity_id, entity in self.entities.items() if entity_id in id_set] + + async def find_by_id(self, id: ULID) -> ItemOut | None: + return self.entities.get(id) + + async def exists_by_id(self, id: ULID) -> bool: + return id in self.entities + + async def delete_by_id(self, id: ULID) -> None: + self.entities.pop(id, None) + + async def delete_all(self) -> None: + self.entities.clear() + + async def delete_all_by_id(self, ids: Sequence[ULID]) -> None: + for entity_id in ids: + self.entities.pop(entity_id, None) + + async def count(self) -> int: + return len(self.entities) + + async def save_all(self, items: Iterable[ItemIn]) -> list[ItemOut]: + return [await self.save(item) for item in items] + + +def _build_router(manager: FakeManager) -> CrudRouter[ItemIn, ItemOut]: + def manager_factory() -> 
Manager[ItemIn, ItemOut, ULID]: + return manager + + return CrudRouter[ItemIn, ItemOut]( + prefix="/items", + tags=["items"], + entity_in_type=ItemIn, + entity_out_type=ItemOut, + manager_factory=manager_factory, + ) + + +@pytest.fixture +def crud_client() -> tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]: + from servicekit.core.api.middleware import add_error_handlers + + manager = FakeManager() + router = _build_router(manager) + app = FastAPI() + add_error_handlers(app) + app.include_router(router.router) + return TestClient(app), manager, router + + +@pytest.fixture +def operations_client() -> tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]: + from servicekit.core.api.middleware import add_error_handlers + + manager = FakeManager() + router = _build_router(manager) + + async def echo_entity( + entity_id: str, + manager_dep: FakeManager = Depends(router.manager_factory), + ) -> ItemOut: + ulid_id = router._parse_ulid(entity_id) + entity = await manager_dep.find_by_id(ulid_id) + if entity is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Entity not found") + return entity + + router.register_entity_operation( + "echo", + echo_entity, + response_model=ItemOut, + summary="Echo entity", + description="Return the entity if it exists", + ) + + async def tally( + manager_dep: FakeManager = Depends(router.manager_factory), + ) -> dict[str, int]: + return {"count": len(manager_dep.entities)} + + router.register_collection_operation( + "tally", + tally, + http_method="POST", + status_code=status.HTTP_202_ACCEPTED, + summary="Count items", + description="Return the number of stored items", + ) + + app = FastAPI() + add_error_handlers(app) + app.include_router(router.router) + return TestClient(app), manager, router + + +def test_create_persists_entity(crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]) -> None: + client, manager, _ = crud_client + + response = client.post("/items/", json={"name": "widget", "description": "first"}) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "widget" + assert data["description"] == "first" + assert len(manager.entities) == 1 + stored = next(iter(manager.entities.values())) + assert stored.name == "widget" + assert str(stored.id) == data["id"] + + +def test_create_returns_location_header( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + + response = client.post("/items/", json={"name": "widget", "description": "first"}) + + assert response.status_code == status.HTTP_201_CREATED + assert "Location" in response.headers + data = response.json() + entity_id = data["id"] + expected_location = f"http://testserver/items/{entity_id}" + assert response.headers["Location"] == expected_location + + +def test_find_all_returns_all_entities( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + client.post("/items/", json={"name": "alpha"}) + client.post("/items/", json={"name": "beta"}) + + response = client.get("/items/") + + assert response.status_code == status.HTTP_200_OK + payload = response.json() + assert {item["name"] for item in payload} == {"alpha", "beta"} + + +def test_find_by_id_returns_entity(crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]) -> None: + client, _, _ = crud_client + created = client.post("/items/", json={"name": "stored"}).json() + + response = 
client.get(f"/items/{created['id']}") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["name"] == "stored" + + +def test_find_by_id_returns_404_when_missing( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + missing_id = str(ULID()) + + response = client.get(f"/items/{missing_id}") + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found" in response.json()["detail"] + + +def test_find_by_id_rejects_invalid_ulid( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + + response = client.get("/items/not-a-ulid") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "Invalid ULID format" in response.json()["detail"] + + +def test_update_replaces_entity_values( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, manager, _ = crud_client + created = client.post("/items/", json={"name": "original", "description": "old"}).json() + + response = client.put(f"/items/{created['id']}", json={"name": "updated"}) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "updated" + ulid_id = ULID.from_str(created["id"]) + assert manager.entities[ulid_id].name == "updated" + + +def test_update_returns_404_when_entity_missing( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + missing = str(ULID()) + + response = client.put(f"/items/{missing}", json={"name": "irrelevant"}) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not found" in response.json()["detail"] + + +def test_update_rejects_invalid_ulid(crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]) -> None: + client, _, _ = crud_client + + response = client.put("/items/not-a-ulid", json={"name": "invalid"}) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +def test_delete_by_id_removes_entity(crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]]) -> None: + client, manager, _ = crud_client + created = client.post("/items/", json={"name": "to-delete"}).json() + entity_id = ULID.from_str(created["id"]) + + response = client.delete(f"/items/{created['id']}") + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert entity_id not in manager.entities + + +def test_delete_by_id_returns_404_when_entity_missing( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + + response = client.delete(f"/items/{str(ULID())}") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_delete_by_id_rejects_invalid_ulid( + crud_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = crud_client + + response = client.delete("/items/not-a-ulid") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +def test_entity_operation_uses_registered_handler( + operations_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, router = operations_client + created = client.post("/items/", json={"name": "entity"}).json() + + response = client.get(f"/items/{created['id']}/$echo") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["id"] == created["id"] + route = next( + route + for route in router.router.routes + if isinstance(route, 
APIRoute) and route.path == "/items/{entity_id}/$echo" + ) + assert isinstance(route, APIRoute) + assert route.summary == "Echo entity" + + +def test_entity_operation_validates_ulid_before_handler( + operations_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, _ = operations_client + + response = client.get("/items/not-a-ulid/$echo") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +def test_collection_operation_supports_custom_http_method( + operations_client: tuple[TestClient, FakeManager, CrudRouter[ItemIn, ItemOut]], +) -> None: + client, _, router = operations_client + client.post("/items/", json={"name": "counted"}) + + response = client.post("/items/$tally") + + assert response.status_code == status.HTTP_202_ACCEPTED + assert response.json() == {"count": 1} + route = next( + route for route in router.router.routes if isinstance(route, APIRoute) and route.path == "/items/$tally" + ) + assert isinstance(route, APIRoute) + assert route.summary == "Count items" + + +def test_register_entity_operation_rejects_unknown_http_method() -> None: + manager = FakeManager() + router = _build_router(manager) + + async def handler(entity_id: str) -> None: + return None + + with pytest.raises(ValueError): + router.register_entity_operation("invalid", handler, http_method="INVALID") + + +def test_register_collection_operation_rejects_unknown_http_method() -> None: + manager = FakeManager() + router = _build_router(manager) + + async def handler() -> None: + return None + + with pytest.raises(ValueError): + router.register_collection_operation("invalid", handler, http_method="INVALID") + + +def test_entity_operation_supports_patch_method() -> None: + from servicekit.core.api.middleware import add_error_handlers + + manager = FakeManager() + router = _build_router(manager) + + async def partial_update( + entity_id: str, + name: str, + manager_dep: FakeManager = Depends(router.manager_factory), + ) -> ItemOut: + ulid_id = router._parse_ulid(entity_id) + entity = await manager_dep.find_by_id(ulid_id) + if entity is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Not found") + entity.name = name + return entity + + router.register_entity_operation( + "partial-update", + partial_update, + http_method="PATCH", + response_model=ItemOut, + summary="Partially update entity", + ) + + app = FastAPI() + add_error_handlers(app) + app.include_router(router.router) + client = TestClient(app) + + created = client.post("/items/", json={"name": "original"}).json() + response = client.patch(f"/items/{created['id']}/$partial-update?name=updated") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["name"] == "updated" + route = next( + route + for route in router.router.routes + if isinstance(route, APIRoute) and route.path == "/items/{entity_id}/$partial-update" + ) + assert "PATCH" in route.methods + + +def test_collection_operation_supports_patch_method() -> None: + from servicekit.core.api.middleware import add_error_handlers + + manager = FakeManager() + router = _build_router(manager) + + async def bulk_update( + suffix: str, + manager_dep: FakeManager = Depends(router.manager_factory), + ) -> dict[str, int]: + count = 0 + for entity in manager_dep.entities.values(): + entity.name = f"{entity.name}_{suffix}" + count += 1 + return {"updated": count} + + router.register_collection_operation( + "bulk-update", + bulk_update, + http_method="PATCH", + status_code=status.HTTP_200_OK, + summary="Bulk update items", + ) + + app 
= FastAPI()
+    add_error_handlers(app)
+    app.include_router(router.router)
+    client = TestClient(app)
+
+    client.post("/items/", json={"name": "item1"})
+    client.post("/items/", json={"name": "item2"})
+
+    response = client.patch("/items/$bulk-update?suffix=modified")
+
+    assert response.status_code == status.HTTP_200_OK
+    assert response.json()["updated"] == 2
+    route = next(
+        route for route in router.router.routes if isinstance(route, APIRoute) and route.path == "/items/$bulk-update"
+    )
+    assert "PATCH" in route.methods
diff --git a/packages/servicekit/tests/test_api_openapi.py b/packages/servicekit/tests/test_api_openapi.py
new file mode 100644
index 0000000..2691ecd
--- /dev/null
+++ b/packages/servicekit/tests/test_api_openapi.py
@@ -0,0 +1,42 @@
+"""Smoke test for OpenAPI schema generation with servicekit routers."""
+
+from __future__ import annotations
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from servicekit import ArtifactIn, ArtifactOut, BaseConfig, ConfigIn, ConfigOut
+from servicekit.api import ArtifactRouter, ConfigRouter
+
+from tests._stubs import ArtifactManagerStub, ConfigManagerStub, singleton_factory
+
+
+def test_openapi_schema_with_config_and_artifact_routers() -> None:
+    app = FastAPI()
+
+    config_router = ConfigRouter.create(
+        prefix="/config",
+        tags=["Config"],
+        manager_factory=singleton_factory(ConfigManagerStub()),
+        entity_in_type=ConfigIn[BaseConfig],
+        entity_out_type=ConfigOut[BaseConfig],
+    )
+
+    artifact_router = ArtifactRouter.create(
+        prefix="/artifacts",
+        tags=["Artifacts"],
+        manager_factory=singleton_factory(ArtifactManagerStub()),
+        entity_in_type=ArtifactIn,
+        entity_out_type=ArtifactOut,
+    )
+
+    app.include_router(config_router)
+    app.include_router(artifact_router)
+
+    client = TestClient(app)
+    response = client.get("/openapi.json")
+
+    assert response.status_code == 200
+    schema = response.json()
+    assert "/config" in schema["paths"]
+    assert "/artifacts" in schema["paths"]
+    assert "/artifacts/{entity_id}/$tree" in schema["paths"]
diff --git a/packages/servicekit/tests/test_api_routers.py b/packages/servicekit/tests/test_api_routers.py
new file mode 100644
index 0000000..5a1390d
--- /dev/null
+++ b/packages/servicekit/tests/test_api_routers.py
@@ -0,0 +1,350 @@
+"""Tests for concrete API routers built on CrudRouter."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from servicekit import ArtifactIn, ArtifactOut, ArtifactTreeNode, BaseConfig, ConfigIn, ConfigOut
+from servicekit.api import ArtifactRouter, ConfigRouter
+from ulid import ULID
+
+from tests._stubs import ArtifactManagerStub, ConfigManagerStub, singleton_factory
+
+
+class ExampleConfig(BaseConfig):
+    """Sample config payload for tests."""
+
+    enabled: bool
+
+
+@pytest.fixture
+def config_app() -> tuple[TestClient, ConfigOut[ExampleConfig]]:
+    from servicekit.core.api.middleware import add_error_handlers
+
+    now = datetime.now(tz=timezone.utc)
+    record = ConfigOut[ExampleConfig](
+        id=ULID(),
+        name="feature-toggle",
+        data=ExampleConfig(enabled=True),
+        created_at=now,
+        updated_at=now,
+    )
+    manager = ConfigManagerStub[ExampleConfig](items={"feature-toggle": record})
+    router = ConfigRouter.create(
+        prefix="/config",
+        tags=["Config"],
+        manager_factory=singleton_factory(manager),
+        entity_in_type=ConfigIn[ExampleConfig],
+        entity_out_type=ConfigOut[ExampleConfig],
+        enable_artifact_operations=True,
+    )
+    app = FastAPI()
+ add_error_handlers(app) + app.include_router(router) + return TestClient(app), record + + +@pytest.fixture +def artifact_app() -> tuple[TestClient, ArtifactTreeNode]: + from servicekit.core.api.middleware import add_error_handlers + + now = datetime.now(tz=timezone.utc) + root_id = ULID() + root = ArtifactTreeNode( + id=root_id, + data={"name": "root"}, + parent_id=None, + level=0, + created_at=now, + updated_at=now, + children=[], + ) + manager = ArtifactManagerStub(trees={root_id: root}) + router = ArtifactRouter.create( + prefix="/artifacts", + tags=["Artifacts"], + manager_factory=singleton_factory(manager), + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + app = FastAPI() + add_error_handlers(app) + app.include_router(router) + return TestClient(app), root + + +def test_config_router_list_returns_records(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + client, record = config_app + + response = client.get("/config/") + + assert response.status_code == 200 + payload = response.json() + assert len(payload) == 1 + first = payload[0] + assert first["id"] == str(record.id) + assert first["name"] == "feature-toggle" + assert first["data"] == {"enabled": True} + assert "created_at" in first + assert "updated_at" in first + + +def test_config_router_find_by_id_returns_record(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + client, record = config_app + + response = client.get(f"/config/{record.id}") + + assert response.status_code == 200 + payload = response.json() + assert payload["id"] == str(record.id) + assert payload["name"] == "feature-toggle" + assert payload["data"] == {"enabled": True} + + +def test_config_router_find_by_id_returns_404(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + client, _ = config_app + + response = client.get(f"/config/{ULID()}") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_artifact_router_tree_returns_node(artifact_app: tuple[TestClient, ArtifactTreeNode]) -> None: + client, root = artifact_app + + response = client.get(f"/artifacts/{root.id}/$tree") + + assert response.status_code == 200 + payload = response.json() + assert payload["id"] == str(root.id) + assert payload["data"] == {"name": "root"} + assert payload["children"] == [] + + +def test_artifact_router_tree_returns_404_when_missing(artifact_app: tuple[TestClient, ArtifactTreeNode]) -> None: + client, _ = artifact_app + + response = client.get(f"/artifacts/{ULID()}/$tree") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +def test_artifact_router_tree_rejects_invalid_ulid(artifact_app: tuple[TestClient, ArtifactTreeNode]) -> None: + client, _ = artifact_app + + response = client.get("/artifacts/not-a-ulid/$tree") + + assert response.status_code == 400 + assert "Invalid ULID" in response.json()["detail"] + + +def test_config_router_link_artifact_success(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + """Test successful artifact linking to a config.""" + client, record = config_app + + artifact_id = ULID() + response = client.post( + f"/config/{record.id}/$link-artifact", + json={"artifact_id": str(artifact_id)}, + ) + + assert response.status_code == 204 + + +def test_config_router_link_artifact_with_error(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + """Test artifact linking with validation error.""" + client, record = config_app + artifact_id = ULID() + + # Inject error into manager + 
from tests._stubs import ConfigManagerStub + + manager = ConfigManagerStub[ExampleConfig](items={"feature-toggle": record}) + manager.set_link_error("Cannot link: artifact is not a root node") + + # Create new app with error-injected manager + from servicekit.core.api.middleware import add_error_handlers + + router = ConfigRouter.create( + prefix="/config", + tags=["Config"], + manager_factory=singleton_factory(manager), + entity_in_type=ConfigIn[ExampleConfig], + entity_out_type=ConfigOut[ExampleConfig], + enable_artifact_operations=True, + ) + app = FastAPI() + add_error_handlers(app) + app.include_router(router) + client = TestClient(app) + + response = client.post( + f"/config/{record.id}/$link-artifact", + json={"artifact_id": str(artifact_id)}, + ) + + assert response.status_code == 400 + assert "Cannot link" in response.json()["detail"] + + +def test_config_router_unlink_artifact_success(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + """Test successful artifact unlinking from a config.""" + client, _ = config_app + + artifact_id = ULID() + response = client.post( + "/config/anyid/$unlink-artifact", + json={"artifact_id": str(artifact_id)}, + ) + + assert response.status_code == 204 + + +def test_config_router_unlink_artifact_with_error(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + """Test artifact unlinking with error.""" + client, record = config_app + artifact_id = ULID() + + # Inject error into manager + from tests._stubs import ConfigManagerStub + + manager = ConfigManagerStub[ExampleConfig](items={"feature-toggle": record}) + manager.set_link_error("Artifact not found") + + # Create new app with error-injected manager + from servicekit.core.api.middleware import add_error_handlers + + router = ConfigRouter.create( + prefix="/config", + tags=["Config"], + manager_factory=singleton_factory(manager), + entity_in_type=ConfigIn[ExampleConfig], + entity_out_type=ConfigOut[ExampleConfig], + enable_artifact_operations=True, + ) + app = FastAPI() + add_error_handlers(app) + app.include_router(router) + client = TestClient(app) + + response = client.post( + f"/config/{record.id}/$unlink-artifact", + json={"artifact_id": str(artifact_id)}, + ) + + assert response.status_code == 400 + assert "Artifact not found" in response.json()["detail"] + + +def test_config_router_get_linked_artifacts(config_app: tuple[TestClient, ConfigOut[ExampleConfig]]) -> None: + """Test retrieving linked artifacts for a config.""" + client, record = config_app + artifact_id = ULID() + + # Create manager with pre-linked artifacts + now = datetime.now(tz=timezone.utc) + artifact = ArtifactOut( + id=artifact_id, + data={"name": "test-artifact"}, + parent_id=None, + level=0, + created_at=now, + updated_at=now, + ) + + from tests._stubs import ConfigManagerStub + + manager = ConfigManagerStub[ExampleConfig]( + items={"feature-toggle": record}, + linked_artifacts={record.id: [artifact]}, + ) + + # Create new app with manager that has linked artifacts + from servicekit.core.api.middleware import add_error_handlers + + router = ConfigRouter.create( + prefix="/config", + tags=["Config"], + manager_factory=singleton_factory(manager), + entity_in_type=ConfigIn[ExampleConfig], + entity_out_type=ConfigOut[ExampleConfig], + enable_artifact_operations=True, + ) + app = FastAPI() + add_error_handlers(app) + app.include_router(router) + client = TestClient(app) + + response = client.get(f"/config/{record.id}/$artifacts") + + assert response.status_code == 200 + artifacts = 
response.json() + assert len(artifacts) == 1 + assert artifacts[0]["id"] == str(artifact_id) + assert artifacts[0]["data"] == {"name": "test-artifact"} + + +def test_artifact_router_get_config_returns_config() -> None: + """Test retrieving config for an artifact.""" + now = datetime.now(tz=timezone.utc) + artifact_id = ULID() + config_id = ULID() + + # Create managers with linked config + artifact = ArtifactTreeNode( + id=artifact_id, + data={"name": "test-artifact"}, + parent_id=None, + level=0, + created_at=now, + updated_at=now, + children=[], + ) + config = ConfigOut[ExampleConfig]( + id=config_id, + name="test-config", + data=ExampleConfig(enabled=True), + created_at=now, + updated_at=now, + ) + + from tests._stubs import ArtifactManagerStub, ConfigManagerStub + + artifact_manager = ArtifactManagerStub(trees={artifact_id: artifact}) + config_manager = ConfigManagerStub[ExampleConfig](items={"test-config": config}) + + # Create app with both managers + from servicekit.core.api.middleware import add_error_handlers + + artifact_router = ArtifactRouter.create( + prefix="/artifacts", + tags=["Artifacts"], + manager_factory=singleton_factory(artifact_manager), + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + enable_config_access=True, + ) + + app = FastAPI() + add_error_handlers(app) + app.include_router(artifact_router) + + # Override config manager dependency + from servicekit.api.dependencies import get_config_manager + + app.dependency_overrides[get_config_manager] = singleton_factory(config_manager) + + client = TestClient(app) + + response = client.get(f"/artifacts/{artifact_id}/$config") + + assert response.status_code == 200 + config_data = response.json() + assert config_data["id"] == str(config_id) + assert config_data["name"] == "test-config" + assert config_data["data"] == {"enabled": True} diff --git a/packages/servicekit/tests/test_api_service_builder.py b/packages/servicekit/tests/test_api_service_builder.py new file mode 100644 index 0000000..7bad678 --- /dev/null +++ b/packages/servicekit/tests/test_api_service_builder.py @@ -0,0 +1,424 @@ +"""Tests for ServiceBuilder functionality.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from servicekit import ArtifactHierarchy, BaseConfig, SqliteDatabaseBuilder +from servicekit.api import ServiceBuilder, ServiceInfo +from servicekit.core import Database +from servicekit.core.api.routers.health import HealthState + + +class ExampleConfig(BaseConfig): + """Example configuration schema for tests.""" + + enabled: bool + value: int + + +@pytest.fixture +async def test_database() -> AsyncGenerator[Database, None]: + """Provide a test database instance.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + yield db + await db.dispose() + + +@pytest.fixture +def service_info() -> ServiceInfo: + """Provide basic service info for tests.""" + return ServiceInfo( + display_name="Test Service", + version="1.0.0", + summary="Test service for unit tests", + ) + + +def test_service_builder_creates_basic_app(service_info: ServiceInfo) -> None: + """Test that ServiceBuilder creates a minimal FastAPI app.""" + app = ServiceBuilder.create(info=service_info) + + assert isinstance(app, FastAPI) + assert app.title == "Test Service" + assert app.version == "1.0.0" + + +def test_service_builder_with_health_endpoint(service_info: ServiceInfo) -> None: + """Test that with_health() adds health 
endpoint.""" + app = ServiceBuilder(info=service_info).with_health().build() + + with TestClient(app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert response.json()["status"] == "healthy" + + +def test_service_builder_with_custom_health_checks(service_info: ServiceInfo) -> None: + """Test custom health checks are registered.""" + check_called = False + + async def custom_check() -> tuple[HealthState, str | None]: + nonlocal check_called + check_called = True + return (HealthState.HEALTHY, None) + + app = ( + ServiceBuilder(info=service_info) + .with_health(checks={"custom": custom_check}, include_database_check=False) + .build() + ) + + client = TestClient(app) + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "custom" in data["checks"] + assert check_called + + +def test_service_builder_with_config(service_info: ServiceInfo) -> None: + """Test that with_config() adds config endpoints.""" + app = ServiceBuilder(info=service_info).with_config(ExampleConfig).build() + + with TestClient(app) as client: + response = client.get("/api/v1/configs/") + + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +def test_service_builder_with_artifacts(service_info: ServiceInfo) -> None: + """Test that with_artifacts() adds artifact endpoints.""" + hierarchy = ArtifactHierarchy( + name="test_hierarchy", + level_labels={0: "root", 1: "child"}, + ) + + app = ServiceBuilder(info=service_info).with_artifacts(hierarchy=hierarchy).build() + + with TestClient(app) as client: + response = client.get("/api/v1/artifacts/") + + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +def test_service_builder_config_artifact_linking(service_info: ServiceInfo) -> None: + """Test config-artifact linking integration.""" + hierarchy = ArtifactHierarchy( + name="test_hierarchy", + level_labels={0: "root"}, + ) + + app = ( + ServiceBuilder(info=service_info) + .with_config(ExampleConfig) + .with_artifacts(hierarchy=hierarchy, enable_config_linking=True) + .build() + ) + + assert app is not None + + +def test_service_builder_validation_fails_without_config(service_info: ServiceInfo) -> None: + """Test that enabling config linking without config schema raises error.""" + hierarchy = ArtifactHierarchy( + name="test_hierarchy", + level_labels={0: "root"}, + ) + + with pytest.raises(ValueError, match="config schema"): + ServiceBuilder(info=service_info).with_artifacts( + hierarchy=hierarchy, + enable_config_linking=True, + ).build() + + +def test_service_builder_invalid_health_check_name(service_info: ServiceInfo) -> None: + """Test that invalid health check names are rejected.""" + + async def check() -> tuple[HealthState, str | None]: + return (HealthState.HEALTHY, None) + + with pytest.raises(ValueError, match="invalid characters"): + ServiceBuilder(info=service_info).with_health(checks={"invalid name!": check}).build() + + +@pytest.mark.asyncio +async def test_service_builder_with_database_instance( + service_info: ServiceInfo, + test_database: Database, +) -> None: + """Test injecting a pre-configured database instance.""" + app = ServiceBuilder(info=service_info).with_database(test_database).with_config(ExampleConfig).build() + + # Test that the app uses the injected database + with TestClient(app) as client: + response = client.get("/api/v1/configs/") + + assert response.status_code == 200 + + +def 
test_service_builder_with_database_invalid_type(service_info: ServiceInfo) -> None: + """Test that with_database() rejects invalid types.""" + with pytest.raises(TypeError, match="Expected str, Database, or None"): + ServiceBuilder(info=service_info).with_database(123).build() # type: ignore[arg-type] + + +def test_service_builder_custom_router_integration(service_info: ServiceInfo) -> None: + """Test including custom routers.""" + from fastapi import APIRouter + + custom_router = APIRouter(prefix="/custom", tags=["custom"]) + + @custom_router.get("/test") + async def custom_endpoint() -> dict[str, str]: + return {"message": "custom"} + + app = ServiceBuilder(info=service_info).include_router(custom_router).build() + + client = TestClient(app) + response = client.get("/custom/test") + + assert response.status_code == 200 + assert response.json() == {"message": "custom"} + + +def test_service_builder_info_endpoint(service_info: ServiceInfo) -> None: + """Test that /api/v1/info endpoint is created.""" + app = ServiceBuilder.create(info=service_info) + + client = TestClient(app) + response = client.get("/api/v1/info") + + assert response.status_code == 200 + data = response.json() + assert data["display_name"] == "Test Service" + assert data["version"] == "1.0.0" + + +def test_service_builder_permissions(service_info: ServiceInfo) -> None: + """Test that permissions restrict operations.""" + app = ServiceBuilder(info=service_info).with_config(ExampleConfig, allow_create=False, allow_delete=False).build() + + with TestClient(app) as client: + # GET should work (read is allowed) + response = client.get("/api/v1/configs/") + assert response.status_code == 200 + + # POST should fail (create is disabled) + response = client.post("/api/v1/configs/", json={"name": "test", "data": {"enabled": True, "value": 42}}) + assert response.status_code == 405 # Method not allowed + + +def test_service_builder_fluent_api(service_info: ServiceInfo) -> None: + """Test fluent API chaining.""" + hierarchy = ArtifactHierarchy(name="test", level_labels={0: "root"}) + + app = ( + ServiceBuilder(info=service_info) + .with_health() + .with_config(ExampleConfig) + .with_artifacts(hierarchy=hierarchy) + .build() + ) + + assert isinstance(app, FastAPI) + assert app.title == "Test Service" + + +def test_service_builder_startup_hook(service_info: ServiceInfo) -> None: + """Test that startup hooks are executed.""" + hook_called = False + + async def startup_hook(app: FastAPI) -> None: + nonlocal hook_called + hook_called = True + + app = ServiceBuilder(info=service_info).on_startup(startup_hook).build() + + # Startup hooks run during lifespan + with TestClient(app): + pass + + assert hook_called + + +def test_service_builder_shutdown_hook(service_info: ServiceInfo) -> None: + """Test that shutdown hooks are executed.""" + hook_called = False + + async def shutdown_hook(app: FastAPI) -> None: + nonlocal hook_called + hook_called = True + + app = ServiceBuilder(info=service_info).on_shutdown(shutdown_hook).build() + + # Shutdown hooks run during lifespan cleanup + with TestClient(app): + pass + + assert hook_called + + +def test_service_builder_preserves_summary_as_description(service_info: ServiceInfo) -> None: + """Test that summary is preserved as description when description is missing.""" + info = ServiceInfo(display_name="Test", summary="Test summary") + builder = ServiceBuilder(info=info) + + assert builder.info.description == "Test summary" + + +def test_service_builder_landing_page(service_info: ServiceInfo) -> None: + 
"""Test that with_landing_page() adds root endpoint.""" + app = ServiceBuilder(info=service_info).with_landing_page().build() + + with TestClient(app) as client: + response = client.get("/") + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/html; charset=utf-8" + # Check that the page contains JavaScript to fetch service info + assert "fetch('/api/v1/info')" in response.text + assert "/docs" in response.text # Navigation link to API docs + + # Verify that /api/v1/info returns the correct data + info_response = client.get("/api/v1/info") + assert info_response.status_code == 200 + data = info_response.json() + assert data["display_name"] == "Test Service" + assert data["version"] == "1.0.0" + assert data["summary"] == "Test service for unit tests" + + +def test_service_builder_without_landing_page(service_info: ServiceInfo) -> None: + """Test that root endpoint is not added by default.""" + app = ServiceBuilder(info=service_info).build() + + with TestClient(app) as client: + response = client.get("/") + + # Should return 404 when landing page is not enabled + assert response.status_code == 404 + + +def test_service_builder_with_system(service_info: ServiceInfo) -> None: + """Test that with_system() adds system endpoint.""" + app = ServiceBuilder(info=service_info).with_system().build() + + with TestClient(app) as client: + response = client.get("/api/v1/system/") + + assert response.status_code == 200 + data = response.json() + assert "current_time" in data + assert "timezone" in data + assert "python_version" in data + assert "platform" in data + assert "hostname" in data + + +def test_service_builder_landing_page_with_custom_fields() -> None: + """Test that landing page displays custom ServiceInfo fields.""" + from pydantic import EmailStr + + class CustomServiceInfo(ServiceInfo): + """Extended service info with custom fields.""" + + author: str + contact_email: EmailStr + custom_field: dict[str, object] + + info = CustomServiceInfo( + display_name="Custom Service", + version="2.0.0", + summary="Test with custom fields", + author="Jane Doe", + contact_email="jane@example.com", + custom_field={"key": "value", "count": 42}, + ) + + app = ServiceBuilder(info=info).with_landing_page().build() + + with TestClient(app) as client: + # Check landing page HTML loads properly + response = client.get("/") + assert response.status_code == 200 + assert "fetch('/api/v1/info')" in response.text + + # Check that /api/v1/info includes custom fields + info_response = client.get("/api/v1/info") + assert info_response.status_code == 200 + data = info_response.json() + assert data["display_name"] == "Custom Service" + assert data["version"] == "2.0.0" + assert data["summary"] == "Test with custom fields" + assert data["author"] == "Jane Doe" + assert data["contact_email"] == "jane@example.com" + assert data["custom_field"] == {"key": "value", "count": 42} + + +def test_service_builder_with_monitoring(service_info: ServiceInfo) -> None: + """Test that with_monitoring() adds metrics endpoint.""" + app = ServiceBuilder(info=service_info).with_monitoring().build() + + with TestClient(app) as client: + response = client.get("/metrics") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/plain") + # Check for Prometheus format markers + assert b"# HELP" in response.content or b"# TYPE" in response.content or len(response.content) > 0 + + +def test_service_builder_with_monitoring_custom_prefix(service_info: ServiceInfo) -> None: + 
"""Test that monitoring can use custom prefix.""" + app = ServiceBuilder(info=service_info).with_monitoring(prefix="/api/v1/metrics").build() + + with TestClient(app) as client: + response = client.get("/api/v1/metrics") + assert response.status_code == 200 + + +def test_service_builder_with_all_features(service_info: ServiceInfo) -> None: + """Integration test with all features enabled.""" + hierarchy = ArtifactHierarchy(name="test", level_labels={0: "root"}) + + async def health_check() -> tuple[HealthState, str | None]: + return (HealthState.HEALTHY, None) + + async def startup(app: FastAPI) -> None: + pass + + async def shutdown(app: FastAPI) -> None: + pass + + app = ( + ServiceBuilder(info=service_info) + .with_landing_page() + .with_health(checks={"test": health_check}) + .with_config(ExampleConfig) + .with_artifacts(hierarchy=hierarchy, enable_config_linking=True) + .with_monitoring() + .on_startup(startup) + .on_shutdown(shutdown) + .build() + ) + + with TestClient(app) as client: + # Test all endpoints work + # Note: Root app mount catches trailing slash redirects, so use exact paths + assert client.get("/").status_code == 200 + assert client.get("/api/v1/info").status_code == 200 + assert client.get("/health").status_code == 200 + assert client.get("/api/v1/configs").status_code == 200 + assert client.get("/api/v1/artifacts").status_code == 200 + assert client.get("/metrics").status_code == 200 diff --git a/packages/servicekit/tests/test_api_utilities.py b/packages/servicekit/tests/test_api_utilities.py new file mode 100644 index 0000000..a90db9b --- /dev/null +++ b/packages/servicekit/tests/test_api_utilities.py @@ -0,0 +1,191 @@ +"""Tests for API utilities.""" + +from unittest.mock import Mock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from servicekit.core.api.utilities import build_location_url, run_app +from starlette.requests import Request + + +def test_build_location_url() -> None: + """Test build_location_url constructs full URLs correctly.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(request: Request) -> dict[str, str]: + url = build_location_url(request, "/api/v1/items/123") + return {"url": url} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + data = response.json() + assert data["url"] == "http://testserver/api/v1/items/123" + + +def test_build_location_url_with_custom_base() -> None: + """Test build_location_url with custom base URL.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(request: Request) -> dict[str, str]: + url = build_location_url(request, "/resources/456") + return {"url": url} + + client = TestClient(app, base_url="https://example.com") + response = client.get("/test") + + assert response.status_code == 200 + data = response.json() + assert data["url"] == "https://example.com/resources/456" + + +def test_build_location_url_preserves_path_slashes() -> None: + """Test build_location_url preserves leading slashes in paths.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(request: Request) -> dict[str, str]: + url1 = build_location_url(request, "/api/v1/items") + url2 = build_location_url(request, "/api/v1/items/") + return {"url1": url1, "url2": url2} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + data = response.json() + assert data["url1"] == "http://testserver/api/v1/items" + assert data["url2"] == "http://testserver/api/v1/items/" + + +def 
test_run_app_with_defaults() -> None:
+    """Test run_app uses default values."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app")
+
+    # Check logging was configured
+    mock_configure.assert_called_once()
+
+    # Check uvicorn.run was called
+    mock_run.assert_called_once()
+    args, kwargs = mock_run.call_args
+    assert args[0] == "test:app"
+    assert kwargs["host"] == "127.0.0.1"
+    assert kwargs["port"] == 8000
+    assert kwargs["workers"] == 1
+    assert kwargs["reload"] is True  # string enables reload
+    assert kwargs["log_level"] == "info"
+    assert kwargs["log_config"] is None
+
+
+def test_run_app_with_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test run_app reads from environment variables."""
+    monkeypatch.setenv("HOST", "0.0.0.0")
+    monkeypatch.setenv("PORT", "3000")
+    monkeypatch.setenv("WORKERS", "4")
+    monkeypatch.setenv("LOG_LEVEL", "DEBUG")
+
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app")
+
+    _, kwargs = mock_run.call_args
+    assert kwargs["host"] == "0.0.0.0"
+    assert kwargs["port"] == 3000
+    assert kwargs["workers"] == 4
+    assert kwargs["log_level"] == "debug"
+    assert kwargs["reload"] is False  # workers > 1 disables reload
+
+
+def test_run_app_reload_logic_with_string() -> None:
+    """Test run_app enables reload for string app path."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app")
+
+    assert mock_run.call_args[1]["reload"] is True
+
+
+def test_run_app_reload_logic_with_instance() -> None:
+    """Test run_app disables reload for app instance."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            app = FastAPI()
+            run_app(app)
+
+    assert mock_run.call_args[1]["reload"] is False
+
+
+def test_run_app_with_explicit_params(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test run_app with explicit parameters overrides defaults and env vars."""
+    monkeypatch.setenv("HOST", "0.0.0.0")
+    monkeypatch.setenv("PORT", "9000")
+
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app", host="localhost", port=5000, workers=2, log_level="warning")
+
+    _, kwargs = mock_run.call_args
+    assert kwargs["host"] == "localhost"
+    assert kwargs["port"] == 5000
+    assert kwargs["workers"] == 2
+    assert kwargs["log_level"] == "warning"
+
+
+def test_run_app_multiple_workers_disables_reload() -> None:
+    """Test run_app disables reload when workers > 1."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app", workers=4)
+
+    assert mock_run.call_args[1]["reload"] is False
+
+
+def test_run_app_with_explicit_reload_true() -> None:
+    """Test run_app with explicit reload=True is respected unless workers > 1."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            app = FastAPI()
+            run_app(app, reload=True, workers=1)
+
+    assert mock_run.call_args[1]["reload"] is True
+
+
+def test_run_app_with_uvicorn_kwargs() -> None:
+    """Test run_app passes additional uvicorn kwargs."""
+    mock_run = Mock()
+    mock_configure = Mock()
+
+    with patch("uvicorn.run", mock_run):
+        with patch("servicekit.core.logging.configure_logging", mock_configure):
+            run_app("test:app", access_log=False, proxy_headers=True)
+
+    kwargs = mock_run.call_args[1]
+    assert kwargs["access_log"] is False
+    assert kwargs["proxy_headers"] is True
diff --git a/packages/servicekit/tests/test_app_loader.py b/packages/servicekit/tests/test_app_loader.py
new file mode 100644
index 0000000..dfe9abc
--- /dev/null
+++ b/packages/servicekit/tests/test_app_loader.py
@@ -0,0 +1,561 @@
+"""Tests for app loading and validation."""
+
+import json
+from pathlib import Path
+
+import pytest
+from pydantic import ValidationError
+from servicekit.core.api.app import App, AppLoader, AppManifest
+
+
+def test_app_manifest_valid():
+    """Test valid app manifest."""
+    manifest = AppManifest(
+        name="Test App",
+        version="1.0.0",
+        prefix="/test",
+        description="Test description",
+        author="Test Author",
+        entry="index.html",
+    )
+    assert manifest.name == "Test App"
+    assert manifest.version == "1.0.0"
+    assert manifest.prefix == "/test"
+    assert manifest.entry == "index.html"
+
+
+def test_app_manifest_minimal():
+    """Test minimal app manifest (only required fields)."""
+    manifest = AppManifest(
+        name="Minimal App",
+        version="1.0.0",
+        prefix="/minimal",
+    )
+    assert manifest.name == "Minimal App"
+    assert manifest.entry == "index.html"  # Default
+    assert manifest.description is None
+    assert manifest.author is None
+
+
+def test_app_manifest_invalid_prefix_no_slash():
+    """Test manifest validation rejects prefix without leading slash."""
+    with pytest.raises(ValidationError, match="prefix must start with '/'"):
+        AppManifest(
+            name="Bad App",
+            version="1.0.0",
+            prefix="bad",
+        )
+
+
+def test_app_manifest_invalid_prefix_path_traversal():
+    """Test manifest validation rejects path traversal in prefix."""
+    with pytest.raises(ValidationError, match="prefix cannot contain '..'"):
+        AppManifest(
+            name="Bad App",
+            version="1.0.0",
+            prefix="/../etc",
+        )
+
+
+def test_app_manifest_invalid_prefix_api():
+    """Test manifest validation rejects /api prefix."""
+    with pytest.raises(ValidationError, match="prefix cannot be '/api'"):
+        AppManifest(
+            name="Bad App",
+            version="1.0.0",
+            prefix="/api",
+        )
+
+
+def test_app_manifest_invalid_prefix_api_subpath():
+    """Test manifest validation rejects /api/** prefix."""
+    with pytest.raises(ValidationError, match="prefix cannot be '/api'"):
+        AppManifest(
+            name="Bad App",
+            version="1.0.0",
+            prefix="/api/dashboard",
+        )
+
+
+def test_load_app_from_filesystem(tmp_path: Path):
+    """Test loading app from filesystem."""
+    # Create app structure
+    app_dir = tmp_path / "test-app"
+    app_dir.mkdir()
+
+    manifest = {
+        "name": "Test App",
+        "version": "1.0.0",
+        "prefix": "/test",
+        "description": "Test app",
+    }
+    (app_dir / "manifest.json").write_text(json.dumps(manifest))
+    (app_dir / "index.html").write_text("Test")
+
+    # Load app
+    app = AppLoader.load(str(app_dir))
+
+    assert app.manifest.name == "Test App"
+    assert app.manifest.prefix == "/test"
+    assert app.prefix == "/test"
+    assert app.directory == app_dir
+    assert not app.is_package
+
+
+def test_load_app_with_prefix_override(tmp_path: Path):
+    """Test loading app with prefix override."""
+    app_dir = tmp_path / "test-app"
+    app_dir.mkdir()
+
+    manifest = {
+        "name": "Test App",
+        "version": "1.0.0",
+        "prefix":
"/original", + } + (app_dir / "manifest.json").write_text(json.dumps(manifest)) + (app_dir / "index.html").write_text("Test") + + # Load with override + app = AppLoader.load(str(app_dir), prefix="/overridden") + + assert app.manifest.prefix == "/original" # Original unchanged + assert app.prefix == "/overridden" # Override applied + + +def test_load_app_custom_entry(tmp_path: Path): + """Test loading app with custom entry file.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + manifest = { + "name": "Test App", + "version": "1.0.0", + "prefix": "/test", + "entry": "main.html", + } + (app_dir / "manifest.json").write_text(json.dumps(manifest)) + (app_dir / "main.html").write_text("Test") + + # Load app + app = AppLoader.load(str(app_dir)) + + assert app.manifest.entry == "main.html" + + +def test_load_app_missing_directory(tmp_path: Path): + """Test loading app from non-existent directory.""" + with pytest.raises(FileNotFoundError, match="App directory not found"): + AppLoader.load(str(tmp_path / "nonexistent")) + + +def test_load_app_missing_manifest(tmp_path: Path): + """Test loading app without manifest.json.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="manifest.json not found"): + AppLoader.load(str(app_dir)) + + +def test_load_app_missing_entry_file(tmp_path: Path): + """Test loading app without entry file.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + manifest = { + "name": "Test App", + "version": "1.0.0", + "prefix": "/test", + } + (app_dir / "manifest.json").write_text(json.dumps(manifest)) + # No index.html created + + with pytest.raises(FileNotFoundError, match="Entry file 'index.html' not found"): + AppLoader.load(str(app_dir)) + + +def test_load_app_invalid_manifest_json(tmp_path: Path): + """Test loading app with invalid JSON manifest.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + (app_dir / "manifest.json").write_text("invalid json{") + (app_dir / "index.html").write_text("Test") + + with pytest.raises(ValueError, match="Invalid JSON in manifest.json"): + AppLoader.load(str(app_dir)) + + +def test_load_app_invalid_manifest_schema(tmp_path: Path): + """Test loading app with manifest missing required fields.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + manifest = { + "name": "Test App", + # Missing version and prefix + } + (app_dir / "manifest.json").write_text(json.dumps(manifest)) + (app_dir / "index.html").write_text("Test") + + with pytest.raises(ValidationError): + AppLoader.load(str(app_dir)) + + +def test_discover_apps(tmp_path: Path): + """Test discovering multiple apps in a directory.""" + apps_dir = tmp_path / "apps" + apps_dir.mkdir() + + # Create app 1 + app1_dir = apps_dir / "app1" + app1_dir.mkdir() + (app1_dir / "manifest.json").write_text(json.dumps({"name": "App 1", "version": "1.0.0", "prefix": "/app1"})) + (app1_dir / "index.html").write_text("App 1") + + # Create app 2 + app2_dir = apps_dir / "app2" + app2_dir.mkdir() + (app2_dir / "manifest.json").write_text(json.dumps({"name": "App 2", "version": "1.0.0", "prefix": "/app2"})) + (app2_dir / "index.html").write_text("App 2") + + # Discover apps + apps = AppLoader.discover(str(apps_dir)) + + assert len(apps) == 2 + app_names = {app.manifest.name for app in apps} + assert app_names == {"App 1", "App 2"} + + +def test_discover_apps_ignores_invalid(tmp_path: Path): + """Test app discovery ignores invalid apps.""" + apps_dir = tmp_path / "apps" + apps_dir.mkdir() + + # Valid app + valid_dir = apps_dir / 
"valid" + valid_dir.mkdir() + (valid_dir / "manifest.json").write_text(json.dumps({"name": "Valid", "version": "1.0.0", "prefix": "/valid"})) + (valid_dir / "index.html").write_text("Valid") + + # Invalid app (missing entry file) + invalid_dir = apps_dir / "invalid" + invalid_dir.mkdir() + (invalid_dir / "manifest.json").write_text( + json.dumps({"name": "Invalid", "version": "1.0.0", "prefix": "/invalid"}) + ) + # No index.html + + # Directory without manifest + no_manifest_dir = apps_dir / "no-manifest" + no_manifest_dir.mkdir() + (no_manifest_dir / "index.html").write_text("No Manifest") + + # Discover apps (should only find valid app) + apps = AppLoader.discover(str(apps_dir)) + + assert len(apps) == 1 + assert apps[0].manifest.name == "Valid" + + +def test_discover_apps_empty_directory(tmp_path: Path): + """Test discovering apps in empty directory.""" + apps_dir = tmp_path / "apps" + apps_dir.mkdir() + + apps = AppLoader.discover(str(apps_dir)) + + assert len(apps) == 0 + + +def test_discover_apps_missing_directory(tmp_path: Path): + """Test discovering apps in non-existent directory.""" + with pytest.raises(FileNotFoundError, match="Apps directory not found"): + AppLoader.discover(str(tmp_path / "nonexistent")) + + +def test_discover_apps_from_package(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test discovering multiple apps from package resources.""" + # Create a temporary package structure + pkg_dir = tmp_path / "test_package" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") # Make it a package + + # Create apps subdirectory + apps_dir = pkg_dir / "web_apps" + apps_dir.mkdir() + + # Create app 1 + app1_dir = apps_dir / "dashboard" + app1_dir.mkdir() + (app1_dir / "manifest.json").write_text( + json.dumps({"name": "Dashboard", "version": "1.0.0", "prefix": "/dashboard"}) + ) + (app1_dir / "index.html").write_text("Dashboard") + + # Create app 2 + app2_dir = apps_dir / "admin" + app2_dir.mkdir() + (app2_dir / "manifest.json").write_text(json.dumps({"name": "Admin", "version": "2.0.0", "prefix": "/admin"})) + (app2_dir / "index.html").write_text("Admin") + + # Add package to sys.path temporarily + monkeypatch.syspath_prepend(str(tmp_path)) + + # Discover apps from package + apps = AppLoader.discover(("test_package", "web_apps")) + + assert len(apps) == 2 + app_names = {app.manifest.name for app in apps} + assert app_names == {"Dashboard", "Admin"} + # Verify apps are marked as package apps + assert all(app.is_package for app in apps) + + +def test_discover_apps_from_package_empty(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test discovering apps from empty package directory.""" + # Create a temporary package with empty apps directory + pkg_dir = tmp_path / "test_package" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + + apps_dir = pkg_dir / "empty_apps" + apps_dir.mkdir() + + # Add package to sys.path + monkeypatch.syspath_prepend(str(tmp_path)) + + # Discover apps (should return empty list) + apps = AppLoader.discover(("test_package", "empty_apps")) + + assert len(apps) == 0 + + +def test_discover_apps_from_package_ignores_invalid(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test package app discovery ignores invalid apps.""" + # Create package structure + pkg_dir = tmp_path / "test_package" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + + apps_dir = pkg_dir / "apps" + apps_dir.mkdir() + + # Valid app + valid_dir = apps_dir / "valid" + valid_dir.mkdir() + (valid_dir / 
"manifest.json").write_text(json.dumps({"name": "Valid", "version": "1.0.0", "prefix": "/valid"})) + (valid_dir / "index.html").write_text("Valid") + + # Invalid app (missing entry file) + invalid_dir = apps_dir / "invalid" + invalid_dir.mkdir() + (invalid_dir / "manifest.json").write_text( + json.dumps({"name": "Invalid", "version": "1.0.0", "prefix": "/invalid"}) + ) + # No index.html + + # Add package to sys.path + monkeypatch.syspath_prepend(str(tmp_path)) + + # Discover apps (should only find valid app) + apps = AppLoader.discover(("test_package", "apps")) + + assert len(apps) == 1 + assert apps[0].manifest.name == "Valid" + + +def test_discover_apps_from_nonexistent_package(): + """Test discovering apps from non-existent package fails.""" + with pytest.raises(ValueError, match="Package .* could not be found"): + AppLoader.discover(("nonexistent.package", "apps")) + + +def test_discover_apps_from_package_nonexistent_subpath(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test discovering apps from non-existent subpath in package fails.""" + # Create minimal package + pkg_dir = tmp_path / "test_package" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + + # Add package to sys.path + monkeypatch.syspath_prepend(str(tmp_path)) + + # Try to discover from non-existent subpath + with pytest.raises(FileNotFoundError, match="App path .* not found in package"): + AppLoader.discover(("test_package", "nonexistent_apps")) + + +def test_discover_apps_from_package_rejects_traversal(): + """Test package discovery rejects path traversal in subpath.""" + with pytest.raises(ValueError, match="subpath cannot contain '..'"): + AppLoader.discover(("chapkit.core.api", "../../../etc")) + + +def test_load_app_from_package(): + """Test loading app from package resources.""" + # Load from servicekit.core.api package (we know this exists) + # We'll create a test by using an invalid package to test error handling + with pytest.raises(ValueError, match="Package .* could not be found"): + AppLoader.load(("nonexistent.package", "apps/test")) + + +def test_app_dataclass(): + """Test App dataclass structure.""" + manifest = AppManifest( + name="Test", + version="1.0.0", + prefix="/test", + ) + + app = App( + manifest=manifest, + directory=Path("/tmp/test"), + prefix="/custom", + is_package=False, + ) + + assert app.manifest.name == "Test" + assert app.directory == Path("/tmp/test") + assert app.prefix == "/custom" + assert not app.is_package + + +# Security Tests + + +def test_app_manifest_rejects_entry_path_traversal(): + """Test manifest validation rejects path traversal in entry field.""" + with pytest.raises(ValidationError, match="entry cannot contain '..'"): + AppManifest( + name="Malicious App", + version="1.0.0", + prefix="/test", + entry="../../../etc/passwd", + ) + + +def test_app_manifest_rejects_entry_absolute_path(): + """Test manifest validation rejects absolute paths in entry field.""" + with pytest.raises(ValidationError, match="entry must be a relative path"): + AppManifest( + name="Malicious App", + version="1.0.0", + prefix="/test", + entry="/etc/passwd", + ) + + +def test_app_manifest_rejects_entry_normalized_traversal(): + """Test manifest validation catches normalized path traversal.""" + with pytest.raises(ValidationError, match="entry cannot contain '..'"): + AppManifest( + name="Malicious App", + version="1.0.0", + prefix="/test", + entry="subdir/../../etc/passwd", + ) + + +def test_app_manifest_rejects_extra_fields(): + """Test manifest validation rejects unknown 
fields.""" + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + AppManifest( + name="Test App", + version="1.0.0", + prefix="/test", + unknown_field="malicious", # type: ignore[call-arg] + ) + + +def test_load_app_rejects_entry_traversal_in_file(tmp_path: Path): + """Test loading app with path traversal in entry field fails at validation.""" + app_dir = tmp_path / "test-app" + app_dir.mkdir() + + manifest = { + "name": "Malicious App", + "version": "1.0.0", + "prefix": "/test", + "entry": "../../../etc/passwd", + } + (app_dir / "manifest.json").write_text(json.dumps(manifest)) + + with pytest.raises(ValidationError, match="entry cannot contain '..'"): + AppLoader.load(str(app_dir)) + + +def test_load_app_from_package_rejects_subpath_traversal(): + """Test loading app from package with path traversal in subpath fails.""" + with pytest.raises(ValueError, match="subpath cannot contain '..'"): + AppLoader.load(("chapkit.core.api", "../../../etc")) + + +def test_load_app_from_package_rejects_absolute_subpath(): + """Test loading app from package with absolute subpath fails.""" + with pytest.raises(ValueError, match="subpath must be relative"): + AppLoader.load(("chapkit.core.api", "/etc/passwd")) + + +# AppManager Tests + + +def test_app_manager_list_apps(tmp_path: Path): + """Test AppManager lists all apps.""" + from servicekit.core.api.app import AppManager + + # Create test apps + app1_dir = tmp_path / "app1" + app1_dir.mkdir() + (app1_dir / "manifest.json").write_text(json.dumps({"name": "App 1", "version": "1.0.0", "prefix": "/app1"})) + (app1_dir / "index.html").write_text("App 1") + + app2_dir = tmp_path / "app2" + app2_dir.mkdir() + (app2_dir / "manifest.json").write_text(json.dumps({"name": "App 2", "version": "1.0.0", "prefix": "/app2"})) + (app2_dir / "index.html").write_text("App 2") + + # Load apps + app1 = AppLoader.load(str(app1_dir)) + app2 = AppLoader.load(str(app2_dir)) + + # Create manager + manager = AppManager([app1, app2]) + + # List apps + apps = manager.list() + assert len(apps) == 2 + assert apps[0].manifest.name == "App 1" + assert apps[1].manifest.name == "App 2" + + +def test_app_manager_get_app_by_prefix(tmp_path: Path): + """Test getting app by prefix.""" + from servicekit.core.api.app import AppManager + + # Create test app + app_dir = tmp_path / "test-app" + app_dir.mkdir() + (app_dir / "manifest.json").write_text(json.dumps({"name": "Test App", "version": "1.0.0", "prefix": "/test"})) + (app_dir / "index.html").write_text("Test") + + # Load app and create manager + app = AppLoader.load(str(app_dir)) + manager = AppManager([app]) + + # Get app by prefix + found_app = manager.get("/test") + assert found_app is not None + assert found_app.manifest.name == "Test App" + + +def test_app_manager_get_nonexistent_app(): + """Test getting nonexistent app returns None.""" + from servicekit.core.api.app import AppManager + + # Create empty manager + manager = AppManager([]) + + # Get nonexistent app + found_app = manager.get("/nonexistent") + assert found_app is None diff --git a/packages/servicekit/tests/test_artifact.py b/packages/servicekit/tests/test_artifact.py new file mode 100644 index 0000000..1624583 --- /dev/null +++ b/packages/servicekit/tests/test_artifact.py @@ -0,0 +1,735 @@ +import pandas as pd +import pytest +from servicekit import ( + Artifact, + ArtifactHierarchy, + ArtifactIn, + ArtifactManager, + ArtifactOut, + ArtifactRepository, + PandasDataFrame, + SqliteDatabaseBuilder, +) +from ulid import ULID + + +class 
TestArtifactRepository: + """Tests for the ArtifactRepository class.""" + + async def test_find_by_id_with_children(self) -> None: + """Test that find_by_id returns artifact with parent_id set.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create parent + parent = Artifact(data={"type": "parent", "name": "root"}, level=0) + await repo.save(parent) + await repo.commit() + await repo.refresh_many([parent]) + + # Create child with parent_id + child = Artifact(data={"type": "child", "name": "child1"}, parent_id=parent.id, level=1) + await repo.save(child) + await repo.commit() + await repo.refresh_many([child]) + + # Find child by ID + found = await repo.find_by_id(child.id) + + assert found is not None + assert found.id == child.id + assert found.parent_id == parent.id + assert found.data == {"type": "child", "name": "child1"} + assert found.level == 1 + + await db.dispose() + + async def test_find_subtree_single_node(self) -> None: + """Test finding subtree with a single node (no children).""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create single artifact + artifact = Artifact(data={"type": "leaf", "name": "single"}, level=0) + await repo.save(artifact) + await repo.commit() + await repo.refresh_many([artifact]) + + # Find subtree + subtree = await repo.find_subtree(artifact.id) + + assert len(list(subtree)) == 1 + assert list(subtree)[0].id == artifact.id + + await db.dispose() + + async def test_find_subtree_with_children(self) -> None: + """Test finding subtree with parent and children.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create parent + parent = Artifact(data={"type": "parent"}, level=0) + await repo.save(parent) + await repo.commit() + await repo.refresh_many([parent]) + + # Create children + child1 = Artifact(data={"type": "child1"}, parent_id=parent.id, level=1) + child2 = Artifact(data={"type": "child2"}, parent_id=parent.id, level=1) + await repo.save_all([child1, child2]) + await repo.commit() + + # Find subtree + subtree = list(await repo.find_subtree(parent.id)) + + assert len(subtree) == 3 + ids = {artifact.id for artifact in subtree} + assert parent.id in ids + assert child1.id in ids + assert child2.id in ids + + await db.dispose() + + async def test_find_subtree_with_grandchildren(self) -> None: + """Test finding subtree with multiple levels (grandchildren).""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create parent + parent = Artifact(data={"level": "root"}, level=0) + await repo.save(parent) + await repo.commit() + await repo.refresh_many([parent]) + + # Create children + child1 = Artifact(data={"level": "child1"}, parent_id=parent.id, level=1) + child2 = Artifact(data={"level": "child2"}, parent_id=parent.id, level=1) + await repo.save_all([child1, child2]) + await repo.commit() + await repo.refresh_many([child1, child2]) + + # Create grandchildren + grandchild1 = Artifact(data={"level": "grandchild1"}, parent_id=child1.id, level=2) + grandchild2 = Artifact(data={"level": "grandchild2"}, parent_id=child1.id, level=2) + grandchild3 = Artifact(data={"level": "grandchild3"}, parent_id=child2.id, level=2) + await repo.save_all([grandchild1, 
grandchild2, grandchild3]) + await repo.commit() + + # Find subtree from root + subtree = list(await repo.find_subtree(parent.id)) + + assert len(subtree) == 6 # parent + 2 children + 3 grandchildren + ids = {artifact.id for artifact in subtree} + assert parent.id in ids + assert child1.id in ids + assert child2.id in ids + assert grandchild1.id in ids + assert grandchild2.id in ids + assert grandchild3.id in ids + + await db.dispose() + + async def test_find_subtree_from_middle_node(self) -> None: + """Test finding subtree starting from a middle node.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create parent + parent = Artifact(data={"level": "root"}, level=0) + await repo.save(parent) + await repo.commit() + await repo.refresh_many([parent]) + + # Create children + child1 = Artifact(data={"level": "child1"}, parent_id=parent.id, level=1) + child2 = Artifact(data={"level": "child2"}, parent_id=parent.id, level=1) + await repo.save_all([child1, child2]) + await repo.commit() + await repo.refresh_many([child1, child2]) + + # Create grandchildren under child1 + grandchild1 = Artifact(data={"level": "grandchild1"}, parent_id=child1.id, level=2) + grandchild2 = Artifact(data={"level": "grandchild2"}, parent_id=child1.id, level=2) + await repo.save_all([grandchild1, grandchild2]) + await repo.commit() + + # Find subtree from child1 (not root) + subtree = list(await repo.find_subtree(child1.id)) + + # Should only include child1 and its descendants, not parent or child2 + assert len(subtree) == 3 # child1 + 2 grandchildren + ids = {artifact.id for artifact in subtree} + assert child1.id in ids + assert grandchild1.id in ids + assert grandchild2.id in ids + assert parent.id not in ids + assert child2.id not in ids + + await db.dispose() + + +class TestArtifactManager: + """Tests for the ArtifactManager class.""" + + async def test_save_artifact(self) -> None: + """Test saving an artifact through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Save artifact + artifact_in = ArtifactIn(data={"type": "test", "value": 123}) + result = await manager.save(artifact_in) + + assert isinstance(result, ArtifactOut) + assert result.id is not None + assert result.data == {"type": "test", "value": 123} + assert result.level == 0 + + await db.dispose() + + +class HookAwareArtifactManager(ArtifactManager): + def __init__(self, repo: ArtifactRepository) -> None: + super().__init__(repo) + self.calls: list[str] = [] + + def _record(self, event: str, entity: Artifact) -> None: + self.calls.append(f"{event}:{entity.id}") + + async def pre_save(self, entity: Artifact, data: ArtifactIn) -> None: + self._record("pre_save", entity) + await super().pre_save(entity, data) + + async def post_save(self, entity: Artifact) -> None: + self._record("post_save", entity) + await super().post_save(entity) + + async def pre_update(self, entity: Artifact, data: ArtifactIn, old_values: dict[str, object]) -> None: + self._record("pre_update", entity) + await super().pre_update(entity, data, old_values) + + async def post_update(self, entity: Artifact, changes: dict[str, tuple[object, object]]) -> None: + self._record("post_update", entity) + await super().post_update(entity, changes) + + async def pre_delete(self, entity: Artifact) -> None: + self._record("pre_delete", entity) + await 
super().pre_delete(entity) + + async def post_delete(self, entity: Artifact) -> None: + self._record("post_delete", entity) + await super().post_delete(entity) + + +class TestBaseManagerLifecycle: + async def test_hooks_invoke_during_crud(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = HookAwareArtifactManager(repo) + + saved = await manager.save(ArtifactIn(data={"name": "one"})) + assert [call.split(":")[0] for call in manager.calls] == ["pre_save", "post_save"] + manager.calls.clear() + + await manager.save(ArtifactIn(id=saved.id, data={"name": "one-updated"})) + assert [call.split(":")[0] for call in manager.calls] == ["pre_update", "post_update"] + manager.calls.clear() + + await manager.delete_by_id(saved.id) + assert [call.split(":")[0] for call in manager.calls] == ["pre_delete", "post_delete"] + + await db.dispose() + + async def test_delete_by_id_handles_missing(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = HookAwareArtifactManager(repo) + + await manager.delete_by_id(ULID()) + + await db.dispose() + + async def test_save_all_empty_input(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = HookAwareArtifactManager(repo) + + results = await manager.save_all([]) + assert results == [] + + await db.dispose() + + async def test_delete_all_no_entities(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = HookAwareArtifactManager(repo) + + await manager.delete_all() + await manager.delete_all_by_id([]) + + await db.dispose() + + async def test_compute_level_handles_none_parent(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + level = await manager._compute_level(None) + assert level == 0 + + await db.dispose() + + async def test_save_artifact_with_parent(self) -> None: + """Test saving an artifact with a parent relationship.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Save parent + parent_in = ArtifactIn(data={"type": "parent"}) + parent_out = await manager.save(parent_in) + + # Save child + assert parent_out.id is not None + assert parent_out.level == 0 + child_in = ArtifactIn(data={"type": "child"}, parent_id=parent_out.id) + child_out = await manager.save(child_in) + + assert child_out.id is not None + assert child_out.parent_id == parent_out.id + assert child_out.level == 1 + + await db.dispose() + + async def test_save_all_assigns_levels(self) -> None: + """Test save_all computes level based on parent relationships.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + parent = await manager.save(ArtifactIn(data={"type": "parent"})) + assert parent.id is not None + + children = [ArtifactIn(data={"type": "child", "index": i}, parent_id=parent.id) for i in range(2)] + + results = await 
manager.save_all(children) + + assert len(results) == 2 + assert all(child.level == 1 for child in results) + + await db.dispose() + + async def test_save_all_updates_existing_entities(self) -> None: + """save_all should respect hooks when updating existing artifacts.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + roots = await manager.save_all( + [ + ArtifactIn(data={"name": "root_a"}), + ArtifactIn(data={"name": "root_b"}), + ] + ) + root_a, root_b = roots + assert root_a.id is not None and root_b.id is not None + + children = await manager.save_all( + [ + ArtifactIn(data={"name": "child0"}, parent_id=root_a.id), + ArtifactIn(data={"name": "child1"}, parent_id=root_a.id), + ] + ) + + moved = await manager.save_all( + [ + ArtifactIn(id=children[0].id, data={"name": "child0"}, parent_id=root_b.id), + ArtifactIn(id=children[1].id, data={"name": "child1"}, parent_id=root_b.id), + ] + ) + + assert all(child.level == 1 for child in moved) + + await db.dispose() + + async def test_update_parent_recomputes_levels(self) -> None: + """Moving an artifact under a new parent updates levels for it and descendants.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + root_a = await manager.save(ArtifactIn(data={"name": "root_a"})) + root_b = await manager.save(ArtifactIn(data={"name": "root_b"})) + assert root_a.id is not None and root_b.id is not None + + child = await manager.save(ArtifactIn(data={"name": "child"}, parent_id=root_a.id)) + grandchild = await manager.save(ArtifactIn(data={"name": "grandchild"}, parent_id=child.id)) + + assert child.level == 1 + assert grandchild.level == 2 + + # Move child (and implicitly grandchild) under root_b + updated_child = await manager.save(ArtifactIn(id=child.id, data={"name": "child"}, parent_id=root_b.id)) + + assert updated_child.level == 1 + + subtree = await manager.find_subtree(updated_child.id) + levels = {artifact.id: artifact.level for artifact in subtree} + assert levels[updated_child.id] == 1 + assert levels[grandchild.id] == 2 + + await db.dispose() + + async def test_update_without_parent_change_preserves_levels(self) -> None: + """Updating artifact data without changing parent leaves levels untouched.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + parent = await manager.save(ArtifactIn(data={"name": "root"})) + assert parent.id is not None + + child = await manager.save(ArtifactIn(data={"name": "child"}, parent_id=parent.id)) + assert child.level == 1 + + renamed = await manager.save(ArtifactIn(id=child.id, data={"name": "child-renamed"}, parent_id=parent.id)) + + assert renamed.level == 1 + + await db.dispose() + + async def test_find_subtree_single_artifact(self) -> None: + """Test finding subtree through manager with single artifact.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Create artifact + artifact_in = ArtifactIn(data={"name": "single"}) + saved = await manager.save(artifact_in) + + # Find subtree + assert saved.id is not None + subtree = await manager.find_subtree(saved.id) + + assert 
len(subtree) == 1 + assert isinstance(subtree[0], ArtifactOut) + assert subtree[0].id == saved.id + assert subtree[0].data == {"name": "single"} + assert subtree[0].level == 0 + + await db.dispose() + + async def test_find_subtree_with_hierarchy(self) -> None: + """Test finding subtree through manager with hierarchical data.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + hierarchy = ArtifactHierarchy(name="ml", level_labels={0: "root", 1: "child", 2: "grandchild"}) + manager = ArtifactManager(repo, hierarchy=hierarchy) + + # Create parent + parent_in = ArtifactIn(data={"level": "root"}) + parent_out = await manager.save(parent_in) + + # Create children + assert parent_out.id is not None + child1_in = ArtifactIn(data={"level": "child1"}, parent_id=parent_out.id) + child2_in = ArtifactIn(data={"level": "child2"}, parent_id=parent_out.id) + child1_out = await manager.save(child1_in) + child2_out = await manager.save(child2_in) + + # Create grandchild + assert child1_out.id is not None + grandchild_in = ArtifactIn(data={"level": "grandchild"}, parent_id=child1_out.id) + grandchild_out = await manager.save(grandchild_in) + + # Find subtree from root + subtree = await manager.find_subtree(parent_out.id) + + assert len(subtree) == 4 + assert all(isinstance(artifact, ArtifactOut) for artifact in subtree) + + ids = {artifact.id for artifact in subtree} + assert parent_out.id in ids + assert child1_out.id in ids + assert child2_out.id in ids + assert grandchild_out.id in ids + for artifact in subtree: + if artifact.id == parent_out.id: + assert artifact.level == 0 + assert artifact.level_label == "root" + assert artifact.hierarchy == "ml" + elif artifact.id in {child1_out.id, child2_out.id}: + assert artifact.level == 1 + assert artifact.level_label == "child" + else: + assert artifact.level == 2 + assert artifact.level_label == "grandchild" + + await db.dispose() + + async def test_build_tree_returns_nested_structure(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + hierarchy = ArtifactHierarchy(name="ml", level_labels={0: "root", 1: "child", 2: "grandchild"}) + manager = ArtifactManager(repo, hierarchy=hierarchy) + + root = await manager.save(ArtifactIn(data={"name": "root"})) + assert root.id is not None + child = await manager.save(ArtifactIn(data={"name": "child"}, parent_id=root.id)) + grandchild = await manager.save(ArtifactIn(data={"name": "grandchild"}, parent_id=child.id)) + + tree = await manager.build_tree(root.id) + assert tree is not None + assert tree.id == root.id + assert tree.level_label == "root" + assert tree.children is not None + assert len(tree.children) == 1 + child_node = tree.children[0] + assert child_node.id == child.id + assert child_node.level_label == "child" + assert child_node.children is not None + assert len(child_node.children) == 1 + grandchild_node = child_node.children[0] + assert grandchild_node.id == grandchild.id + assert grandchild_node.level_label == "grandchild" + assert grandchild_node.children == [] + + await db.dispose() + + async def test_build_tree_returns_none_for_missing_root(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + result = await manager.build_tree(ULID()) + assert result is None + + 
await db.dispose() + + async def test_manager_without_hierarchy_has_null_labels(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + root = await manager.save(ArtifactIn(data={"name": "root"})) + tree = await manager.build_tree(root.id) + assert tree is not None + assert tree.level_label is None + assert tree.hierarchy is None + + await db.dispose() + + async def test_build_tree_handles_parent_missing_in_db(self) -> None: + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + root = Artifact(data={"name": "orphan"}, parent_id=None, level=0) + await repo.save(root) + await repo.commit() + + child = Artifact(data={"name": "stray"}, parent_id=root.id, level=1) + await repo.save(child) + await repo.commit() + + # Delete the parent so the child references a missing parent record + parent_entity = await repo.find_by_id(root.id) + assert parent_entity is not None + await repo.delete(parent_entity) + await repo.commit() + + level = await manager._compute_level(child.parent_id) + assert level == 0 + + await db.dispose() + + async def test_find_subtree_returns_output_schemas(self) -> None: + """Test that find_subtree returns ArtifactOut schemas.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Create parent and child + parent_in = ArtifactIn(data={"name": "parent"}) + parent_out = await manager.save(parent_in) + + assert parent_out.id is not None + child_in = ArtifactIn(data={"name": "child"}, parent_id=parent_out.id) + await manager.save(child_in) + + # Find subtree + subtree = await manager.find_subtree(parent_out.id) + + assert len(subtree) == 2 + assert all(isinstance(artifact, ArtifactOut) for artifact in subtree) + assert all(artifact.id is not None for artifact in subtree) + levels = {artifact.level for artifact in subtree} + assert levels == {0, 1} + + await db.dispose() + + async def test_find_by_id(self) -> None: + """Test finding artifact by ID through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Create artifact + artifact_in = ArtifactIn(data={"key": "value"}) + saved = await manager.save(artifact_in) + + # Find by ID + assert saved.id is not None + found = await manager.find_by_id(saved.id) + + assert found is not None + assert isinstance(found, ArtifactOut) + assert found.id == saved.id + assert found.data == {"key": "value"} + assert found.level == 0 + + # Non-existent ID should return None + random_id = ULID() + not_found = await manager.find_by_id(random_id) + assert not_found is None + + await db.dispose() + + async def test_delete_artifact(self) -> None: + """Test deleting an artifact through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Create artifact + artifact_in = ArtifactIn(data={"to": "delete"}) + saved = await manager.save(artifact_in) + + # Verify it exists + assert await manager.count() == 1 + + # Delete it + assert saved.id is not None + await manager.delete_by_id(saved.id) 
+ + # Verify it's gone + assert await manager.count() == 0 + + await db.dispose() + + async def test_output_schema_includes_timestamps(self) -> None: + """Test that ArtifactOut schemas include created_at and updated_at timestamps.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + manager = ArtifactManager(repo) + + # Create artifact + artifact_in = ArtifactIn(data={"test": "timestamps"}) + result = await manager.save(artifact_in) + + # Verify timestamps exist + assert result.created_at is not None + assert result.updated_at is not None + assert result.id is not None + assert result.level == 0 + + await db.dispose() + + +def test_pandas_dataframe_round_trip() -> None: + """PandasDataFrame should round-trip DataFrame data.""" + df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]}) + wrapper = PandasDataFrame.from_dataframe(df) + + assert wrapper.columns == ["col1", "col2"] + assert wrapper.data == [[1, "a"], [2, "b"]] + + reconstructed = wrapper.to_dataframe() + pd.testing.assert_frame_equal(reconstructed, df) + + +def test_pandas_dataframe_requires_dataframe() -> None: + """PandasDataFrame.from_dataframe should validate input type.""" + with pytest.raises(TypeError): + PandasDataFrame.from_dataframe({"not": "dataframe"}) # type: ignore[arg-type] diff --git a/packages/servicekit/tests/test_artifact_router.py b/packages/servicekit/tests/test_artifact_router.py new file mode 100644 index 0000000..8c2d0cb --- /dev/null +++ b/packages/servicekit/tests/test_artifact_router.py @@ -0,0 +1,62 @@ +"""Tests for ArtifactRouter error handling.""" + +from unittest.mock import AsyncMock, Mock + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from servicekit.modules.artifact import ArtifactIn, ArtifactManager, ArtifactOut, ArtifactRouter +from ulid import ULID + + +def test_expand_artifact_not_found_returns_404() -> None: + """Test that expand_artifact returns 404 when artifact not found.""" + mock_manager = Mock(spec=ArtifactManager) + mock_manager.expand_artifact = AsyncMock(return_value=None) + + def manager_factory() -> ArtifactManager: + return mock_manager + + app = FastAPI() + router = ArtifactRouter.create( + prefix="/api/v1/artifacts", + tags=["Artifacts"], + manager_factory=manager_factory, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + app.include_router(router) + + client = TestClient(app) + + artifact_id = str(ULID()) + response = client.get(f"/api/v1/artifacts/{artifact_id}/$expand") + + assert response.status_code == 404 + assert f"Artifact with id {artifact_id} not found" in response.text + + +def test_build_tree_not_found_returns_404() -> None: + """Test that build_tree returns 404 when artifact not found.""" + mock_manager = Mock(spec=ArtifactManager) + mock_manager.build_tree = AsyncMock(return_value=None) + + def manager_factory() -> ArtifactManager: + return mock_manager + + app = FastAPI() + router = ArtifactRouter.create( + prefix="/api/v1/artifacts", + tags=["Artifacts"], + manager_factory=manager_factory, + entity_in_type=ArtifactIn, + entity_out_type=ArtifactOut, + ) + app.include_router(router) + + client = TestClient(app) + + artifact_id = str(ULID()) + response = client.get(f"/api/v1/artifacts/{artifact_id}/$tree") + + assert response.status_code == 404 + assert f"Artifact with id {artifact_id} not found" in response.text diff --git a/packages/servicekit/tests/test_artifact_serialization.py 
b/packages/servicekit/tests/test_artifact_serialization.py new file mode 100644 index 0000000..ce607be --- /dev/null +++ b/packages/servicekit/tests/test_artifact_serialization.py @@ -0,0 +1,293 @@ +"""Tests for artifact data serialization with non-JSON-serializable types.""" + +from __future__ import annotations + +import importlib.util +import json +from datetime import datetime + +import pytest +from servicekit import ArtifactOut +from ulid import ULID + + +class CustomNonSerializable: + """A custom class that is not JSON-serializable.""" + + def __init__(self, value: int) -> None: + self.value = value + + def __repr__(self) -> str: + return f"CustomNonSerializable(value={self.value})" + + +class TestArtifactSerialization: + """Test artifact data field serialization with various data types.""" + + def test_json_serializable_data_passes_through(self) -> None: + """JSON-serializable data should be returned unchanged.""" + # Test various JSON-serializable types + test_cases = [ + {"key": "value", "number": 42}, + [1, 2, 3, 4, 5], + "simple string", + 42, + 3.14, + True, + None, + {"nested": {"data": [1, 2, {"deep": "value"}]}}, + ] + + for data in test_cases: + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=data, + parent_id=None, + level=0, + ) + + # Serialize to dict (triggers field_serializer) + serialized = artifact.model_dump() + + # Data should be unchanged + assert serialized["data"] == data + + # The full artifact should be JSON-serializable via Pydantic + json_str = artifact.model_dump_json() + assert json_str is not None + + # Can parse back + parsed = json.loads(json_str) + assert parsed["data"] == data + + def test_custom_object_returns_metadata(self) -> None: + """Non-serializable custom objects should return metadata.""" + custom_obj = CustomNonSerializable(value=42) + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=custom_obj, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + # Should return metadata dict + assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "CustomNonSerializable" + assert serialized["data"]["_module"] == __name__ + assert "CustomNonSerializable(value=42)" in serialized["data"]["_repr"] + assert "not JSON-serializable" in serialized["data"]["_serialization_error"] + + # The whole artifact should now be JSON-serializable via Pydantic + json_str = artifact.model_dump_json() + assert json_str is not None + + # Can parse back + parsed = json.loads(json_str) + assert parsed["data"]["_type"] == "CustomNonSerializable" + + def test_function_returns_metadata(self) -> None: + """Functions should return metadata instead of crashing.""" + + def my_function(x: int) -> int: + return x * 2 + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=my_function, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "function" + assert "my_function" in serialized["data"]["_repr"] + assert "_serialization_error" in serialized["data"] + + def test_bytes_returns_metadata(self) -> None: + """Bytes objects should return metadata.""" + binary_data = b"\x00\x01\x02\x03\xff" + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=binary_data, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + 
assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "bytes" + assert "_serialization_error" in serialized["data"] + + def test_very_long_repr_is_truncated(self) -> None: + """Very long repr strings should be truncated.""" + + class LongRepr: + def __repr__(self) -> str: + return "x" * 1000 # Very long repr + + obj = LongRepr() + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=obj, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + # Repr should be truncated to 200 chars + "..." + assert len(serialized["data"]["_repr"]) <= 203 # 200 + "..." + assert serialized["data"]["_repr"].endswith("...") + + def test_complex_nested_structure_with_non_serializable_parts(self) -> None: + """Nested structures with both serializable and non-serializable parts.""" + # This should pass - the outer structure is serializable + data = { + "name": "experiment", + "params": {"lr": 0.001, "epochs": 100}, + "results": [1, 2, 3], + } + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=data, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + assert serialized["data"] == data + + # Should be JSON-serializable via Pydantic + json_str = artifact.model_dump_json() + assert json_str is not None + + def test_set_returns_metadata(self) -> None: + """Sets are not JSON-serializable and should return metadata.""" + data = {1, 2, 3, 4, 5} + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=data, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "set" + assert "_serialization_error" in serialized["data"] + + @pytest.mark.skipif( + importlib.util.find_spec("torch") is None, + reason="Requires torch to be installed", + ) + def test_torch_tensor_returns_metadata(self) -> None: + """PyTorch tensors should return metadata (optional test).""" + try: + import torch # type: ignore[import-not-found] + + tensor = torch.randn(3, 3) + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=tensor, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "Tensor" + assert serialized["data"]["_module"] == "torch" + assert "_serialization_error" in serialized["data"] + except ImportError: + pytest.skip("torch not installed") + + @pytest.mark.skipif( + importlib.util.find_spec("sklearn") is None, + reason="Requires scikit-learn to be installed", + ) + def test_sklearn_model_returns_metadata(self) -> None: + """Scikit-learn models should return metadata (optional test).""" + try: + import numpy as np + from sklearn.linear_model import LinearRegression # type: ignore[import-untyped] + + # Train a simple linear regression model + X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + y = np.dot(X, np.array([1, 2])) + 3 + model = LinearRegression().fit(X, y) + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=model, + parent_id=None, + level=0, + ) + + serialized = artifact.model_dump() + + assert isinstance(serialized["data"], dict) + assert serialized["data"]["_type"] == "LinearRegression" + assert "sklearn" in serialized["data"]["_module"] + assert "_serialization_error" in serialized["data"] + + # 
Should be JSON-serializable via Pydantic + json_str = artifact.model_dump_json() + assert json_str is not None + + # Can parse back + parsed = json.loads(json_str) + assert parsed["data"]["_type"] == "LinearRegression" + except ImportError: + pytest.skip("scikit-learn not installed") + + def test_model_dump_json_works_with_non_serializable(self) -> None: + """The model_dump_json() method should work with non-serializable data.""" + custom_obj = CustomNonSerializable(value=123) + + artifact = ArtifactOut( + id=ULID(), + created_at=datetime.now(), + updated_at=datetime.now(), + data=custom_obj, + parent_id=None, + level=0, + ) + + # Should not raise an exception + json_str = artifact.model_dump_json() + assert json_str is not None + + # Should be parseable + parsed = json.loads(json_str) + assert parsed["data"]["_type"] == "CustomNonSerializable" diff --git a/packages/servicekit/tests/test_auth.py b/packages/servicekit/tests/test_auth.py new file mode 100644 index 0000000..a69775b --- /dev/null +++ b/packages/servicekit/tests/test_auth.py @@ -0,0 +1,319 @@ +"""Tests for API key authentication middleware and utilities.""" + +from pathlib import Path +from tempfile import NamedTemporaryFile + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient +from servicekit.core.api.auth import ( + APIKeyMiddleware, + load_api_keys_from_env, + load_api_keys_from_file, + validate_api_key_format, +) + + +def test_load_api_keys_from_env_single_key(monkeypatch: MonkeyPatch) -> None: + """Test loading a single API key from environment variable.""" + monkeypatch.setenv("CHAPKIT_API_KEYS", "sk_test_123") + keys = load_api_keys_from_env("CHAPKIT_API_KEYS") + assert keys == {"sk_test_123"} + + +def test_load_api_keys_from_env_multiple_keys(monkeypatch: MonkeyPatch) -> None: + """Test loading multiple comma-separated keys.""" + monkeypatch.setenv("CHAPKIT_API_KEYS", "sk_test_1,sk_test_2,sk_test_3") + keys = load_api_keys_from_env("CHAPKIT_API_KEYS") + assert keys == {"sk_test_1", "sk_test_2", "sk_test_3"} + + +def test_load_api_keys_from_env_with_spaces(monkeypatch: MonkeyPatch) -> None: + """Test that spaces around keys are stripped.""" + monkeypatch.setenv("CHAPKIT_API_KEYS", "sk_test_1 , sk_test_2 , sk_test_3") + keys = load_api_keys_from_env("CHAPKIT_API_KEYS") + assert keys == {"sk_test_1", "sk_test_2", "sk_test_3"} + + +def test_load_api_keys_from_env_empty() -> None: + """Test loading from non-existent environment variable.""" + keys = load_api_keys_from_env("NONEXISTENT_VAR") + assert keys == set() + + +def test_load_api_keys_from_env_empty_string(monkeypatch: MonkeyPatch) -> None: + """Test loading from empty environment variable.""" + monkeypatch.setenv("CHAPKIT_API_KEYS", "") + keys = load_api_keys_from_env("CHAPKIT_API_KEYS") + assert keys == set() + + +def test_load_api_keys_from_file(): + """Test loading API keys from file.""" + with NamedTemporaryFile(mode="w", delete=False) as f: + f.write("sk_test_1\n") + f.write("sk_test_2\n") + f.write("sk_test_3\n") + temp_path = f.name + + try: + keys = load_api_keys_from_file(temp_path) + assert keys == {"sk_test_1", "sk_test_2", "sk_test_3"} + finally: + Path(temp_path).unlink() + + +def test_load_api_keys_from_file_with_comments(): + """Test that comments and empty lines are ignored.""" + with NamedTemporaryFile(mode="w", delete=False) as f: + f.write("# Comment line\n") + f.write("sk_test_1\n") + f.write("\n") # Empty line + f.write("sk_test_2\n") + f.write("# Another 
comment\n") + f.write("sk_test_3\n") + temp_path = f.name + + try: + keys = load_api_keys_from_file(temp_path) + assert keys == {"sk_test_1", "sk_test_2", "sk_test_3"} + finally: + Path(temp_path).unlink() + + +def test_load_api_keys_from_file_not_found(): + """Test error when file doesn't exist.""" + with pytest.raises(FileNotFoundError, match="API key file not found"): + load_api_keys_from_file("/nonexistent/path/keys.txt") + + +def test_validate_api_key_format_valid(): + """Test validation of valid API key formats.""" + assert validate_api_key_format("sk_prod_a1b2c3d4e5f6g7h8") + assert validate_api_key_format("sk_dev_1234567890123456") + assert validate_api_key_format("1234567890123456") # Min length 16 + + +def test_validate_api_key_format_too_short(): + """Test validation rejects keys shorter than 16 characters.""" + assert not validate_api_key_format("short") + assert not validate_api_key_format("sk_dev_123") + assert not validate_api_key_format("123456789012345") # 15 chars + + +def test_api_key_middleware_valid_key(): + """Test middleware allows requests with valid API key.""" + app = FastAPI() + + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid"}, + header_name="X-API-Key", + unauthenticated_paths=set(), + ) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + client = TestClient(app) + response = client.get("/test", headers={"X-API-Key": "sk_test_valid"}) + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_api_key_middleware_invalid_key(): + """Test middleware rejects requests with invalid API key.""" + from servicekit.core.api.middleware import add_error_handlers + + app = FastAPI() + + # Add error handlers FIRST (before middleware) + add_error_handlers(app) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + # Add middleware after routes + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid"}, + header_name="X-API-Key", + unauthenticated_paths=set(), + ) + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/test", headers={"X-API-Key": "sk_test_invalid"}) + + assert response.status_code == 401 + assert response.json()["type"] == "urn:chapkit:error:unauthorized" + assert "Invalid API key" in response.json()["detail"] + + +def test_api_key_middleware_missing_key(): + """Test middleware rejects requests without API key.""" + from servicekit.core.api.middleware import add_error_handlers + + app = FastAPI() + + # Add error handlers FIRST (before middleware) + add_error_handlers(app) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + # Add middleware after routes + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid"}, + header_name="X-API-Key", + unauthenticated_paths=set(), + ) + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/test") + + assert response.status_code == 401 + assert response.json()["type"] == "urn:chapkit:error:unauthorized" + assert "Missing authentication header" in response.json()["detail"] + + +def test_api_key_middleware_unauthenticated_path(): + """Test middleware allows unauthenticated paths without key.""" + from servicekit.core.api.middleware import add_error_handlers + + app = FastAPI() + + # Add error handlers FIRST (before middleware) + add_error_handlers(app) + + @app.get("/health") + def health(): + return {"status": "healthy"} + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + # Add middleware after 
routes + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid"}, + header_name="X-API-Key", + unauthenticated_paths={"/health", "/docs"}, + ) + + client = TestClient(app, raise_server_exceptions=False) + + # Unauthenticated path works without key + response = client.get("/health") + assert response.status_code == 200 + + # Authenticated path requires key + response = client.get("/test") + assert response.status_code == 401 + + +def test_api_key_middleware_multiple_valid_keys(): + """Test middleware accepts any of multiple valid keys.""" + app = FastAPI() + + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_1", "sk_test_2", "sk_test_3"}, + header_name="X-API-Key", + unauthenticated_paths=set(), + ) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + client = TestClient(app) + + # All three keys should work + for key in ["sk_test_1", "sk_test_2", "sk_test_3"]: + response = client.get("/test", headers={"X-API-Key": key}) + assert response.status_code == 200 + + +def test_api_key_middleware_custom_header_name(): + """Test middleware with custom header name.""" + from servicekit.core.api.middleware import add_error_handlers + + app = FastAPI() + + # Add error handlers FIRST (before middleware) + add_error_handlers(app) + + @app.get("/test") + def test_endpoint(): + return {"status": "ok"} + + # Add middleware after routes + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid"}, + header_name="X-Custom-API-Key", + unauthenticated_paths=set(), + ) + + client = TestClient(app, raise_server_exceptions=False) + + # Default header name shouldn't work + response = client.get("/test", headers={"X-API-Key": "sk_test_valid"}) + assert response.status_code == 401 + + # Custom header name should work + response = client.get("/test", headers={"X-Custom-API-Key": "sk_test_valid"}) + assert response.status_code == 200 + + +def test_api_key_middleware_attaches_prefix_to_request_state(): + """Test that middleware attaches key prefix to request state.""" + app = FastAPI() + + app.add_middleware( + APIKeyMiddleware, + api_keys={"sk_test_valid_key_12345"}, + header_name="X-API-Key", + unauthenticated_paths=set(), + ) + + @app.get("/test") + def test_endpoint(request: Request): + # Check that key prefix is attached to request state + assert hasattr(request.state, "api_key_prefix") + assert request.state.api_key_prefix == "sk_test" + return {"status": "ok"} + + client = TestClient(app) + response = client.get("/test", headers={"X-API-Key": "sk_test_valid_key_12345"}) + + assert response.status_code == 200 + + +def test_service_builder_auth_logging_no_duplicates(capsys: pytest.CaptureFixture[str]) -> None: + """Test that auth warning is logged only once during startup.""" + from servicekit.core.api import BaseServiceBuilder, ServiceInfo + + # Build app with direct API keys (triggers warning) + info = ServiceInfo(display_name="Test Service") + app = BaseServiceBuilder(info=info, include_logging=True).with_auth(api_keys=["sk_dev_test123"]).build() + + # Create test client (triggers startup) + with TestClient(app): + pass + + # Capture output + captured = capsys.readouterr() + + # Count occurrences of the warning message + warning_count = captured.out.count("Using direct API keys") + + # Should appear exactly once + assert warning_count == 1, f"Expected 1 warning, found {warning_count}" diff --git a/packages/servicekit/tests/test_config.py b/packages/servicekit/tests/test_config.py new file mode 100644 index 0000000..bbcf0e8 --- /dev/null +++ 
b/packages/servicekit/tests/test_config.py @@ -0,0 +1,255 @@ +from types import SimpleNamespace +from typing import cast + +import pytest +from pydantic_core.core_schema import ValidationInfo +from servicekit import Config, ConfigOut, SqliteDatabaseBuilder +from ulid import ULID + +from .conftest import DemoConfig + + +class TestConfigModel: + """Tests for the Config model.""" + + async def test_create_config_with_name_and_data(self) -> None: + """Test creating a Config with name and data.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + test_data = DemoConfig(x=1, y=2, z=3, tags=["test"]) + config = Config(name="test_config", data=test_data) + session.add(config) + await session.commit() + await session.refresh(config) + + assert config.id is not None + assert isinstance(config.id, ULID) + assert config.name == "test_config" + config_data = config.data + assert config_data is not None + assert isinstance(config_data, dict) + assert config_data["x"] == 1 + assert config_data["y"] == 2 + assert config_data["z"] == 3 + assert config_data["tags"] == ["test"] + assert config.created_at is not None + assert config.updated_at is not None + + await db.dispose() + + +def test_config_data_setter_rejects_invalid_type() -> None: + """Config.data setter should reject unsupported types.""" + config = Config(name="invalid_type") + with pytest.raises(TypeError): + bad_value = cast(DemoConfig, 123) + config.data = bad_value + + +def test_config_data_setter_accepts_dict() -> None: + """Config.data setter should accept dict values.""" + config = Config(name="dict_data", data={}) + config.data = DemoConfig(x=1, y=2, z=3, tags=["test"]) + assert config.data == {"x": 1, "y": 2, "z": 3, "tags": ["test"]} + + +def test_config_out_retains_dict_without_context() -> None: + """ConfigOut should leave dict data unchanged when no context is provided.""" + payload = {"x": 1, "y": 2, "z": 3, "tags": ["raw"]} + info = cast(ValidationInfo, SimpleNamespace(context=None)) + result = ConfigOut[DemoConfig].convert_dict_to_model(payload, info) + assert result == payload + + +class TestConfigModelExtras: + async def test_create_config_with_empty_data(self) -> None: + """Test creating a Config with empty dict data.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config = Config(name="empty_config", data={}) + session.add(config) + await session.commit() + await session.refresh(config) + + assert config.id is not None + assert config.name == "empty_config" + assert config.data == {} + + await db.dispose() + + async def test_config_name_allows_duplicates(self) -> None: + """Test that Config name field allows duplicates (no unique constraint).""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config1 = Config(name="duplicate_name", data=DemoConfig(x=1, y=2, z=3, tags=[])) + session.add(config1) + await session.commit() + await session.refresh(config1) + + # Create another config with the same name - should succeed + async with db.session() as session: + config2 = Config(name="duplicate_name", data=DemoConfig(x=4, y=5, z=6, tags=[])) + session.add(config2) + await session.commit() + await session.refresh(config2) + + # Verify they have different IDs but same name + assert config1.id != config2.id + assert config1.name == config2.name == "duplicate_name" + + await db.dispose() + + async def test_config_type_preservation(self) -> None: + """Test 
Config stores data as dict and can be deserialized by application.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + test_data = DemoConfig(x=10, y=20, z=30, tags=["a", "b"]) + + async with db.session() as session: + config = Config(name="type_test", data=test_data) + session.add(config) + await session.commit() + await session.refresh(config) + + # The data should be stored as dict + config_data = config.data + assert config_data is not None + assert isinstance(config_data, dict) + assert config_data["x"] == 10 + assert config_data["y"] == 20 + assert config_data["z"] == 30 + assert config_data["tags"] == ["a", "b"] + + # Application can deserialize it + deserialized = DemoConfig.model_validate(config_data) + assert deserialized.x == 10 + assert deserialized.y == 20 + assert deserialized.z == 30 + assert deserialized.tags == ["a", "b"] + + await db.dispose() + + async def test_config_id_is_ulid(self) -> None: + """Test that Config ID is a ULID type.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config = Config(name="ulid_test", data={}) + session.add(config) + await session.commit() + await session.refresh(config) + + assert isinstance(config.id, ULID) + # ULID string representation should be 26 characters + assert len(str(config.id)) == 26 + + await db.dispose() + + async def test_config_timestamps_auto_set(self) -> None: + """Test that created_at and updated_at are automatically set.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config = Config(name="timestamp_test", data=DemoConfig(x=1, y=2, z=3, tags=[])) + session.add(config) + await session.commit() + await session.refresh(config) + + assert config.created_at is not None + assert config.updated_at is not None + # Initially, created_at and updated_at should be very close + time_diff = abs((config.updated_at - config.created_at).total_seconds()) + assert time_diff < 1 # Less than 1 second difference + + await db.dispose() + + async def test_config_update_modifies_data(self) -> None: + """Test updating Config data field.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config = Config(name="update_test", data=DemoConfig(x=1, y=2, z=3, tags=["original"])) + session.add(config) + await session.commit() + await session.refresh(config) + + original_id = config.id + + # Update the data + config.data = DemoConfig(x=10, y=20, z=30, tags=["updated"]) + await session.commit() + await session.refresh(config) + + assert config.id == original_id + config_data = config.data + assert config_data is not None + assert isinstance(config_data, dict) + assert config_data["x"] == 10 + assert config_data["y"] == 20 + assert config_data["z"] == 30 + assert config_data["tags"] == ["updated"] + + await db.dispose() + + async def test_config_tablename(self) -> None: + """Test that Config uses correct table name.""" + assert Config.__tablename__ == "configs" + + async def test_config_inherits_from_entity(self) -> None: + """Test that Config inherits from Entity and has expected fields.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config = Config(name="inheritance_test", data={}) + session.add(config) + await session.commit() + await session.refresh(config) + + # Check inherited fields from Entity + assert hasattr(config, "id") + assert hasattr(config, "created_at") + assert 
hasattr(config, "updated_at") + + # Check Config-specific fields + assert hasattr(config, "name") + assert hasattr(config, "data") + + await db.dispose() + + async def test_multiple_configs_different_names(self) -> None: + """Test creating multiple configs with different names.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config1 = Config(name="config_1", data=DemoConfig(x=1, y=1, z=1, tags=[])) + config2 = Config(name="config_2", data=DemoConfig(x=2, y=2, z=2, tags=[])) + config3 = Config(name="config_3", data=DemoConfig(x=3, y=3, z=3, tags=[])) + + session.add_all([config1, config2, config3]) + await session.commit() + await session.refresh(config1) + await session.refresh(config2) + await session.refresh(config3) + + assert config1.name == "config_1" + assert config2.name == "config_2" + assert config3.name == "config_3" + + # Each should have unique IDs + assert config1.id != config2.id + assert config2.id != config3.id + assert config1.id != config3.id + + await db.dispose() diff --git a/packages/servicekit/tests/test_config_artifact_link.py b/packages/servicekit/tests/test_config_artifact_link.py new file mode 100644 index 0000000..5a4b76d --- /dev/null +++ b/packages/servicekit/tests/test_config_artifact_link.py @@ -0,0 +1,433 @@ +"""Tests for Config-Artifact linking functionality.""" + +from __future__ import annotations + +import pytest +from servicekit import ( + Artifact, + ArtifactManager, + ArtifactRepository, + Config, + ConfigManager, + ConfigRepository, + SqliteDatabaseBuilder, +) + +from .conftest import DemoConfig + + +async def test_link_artifact_creates_link() -> None: + """ConfigRepository.link_artifact should create a link between config and root artifact.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + + # Create a config + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + # Create a root artifact + artifact = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([artifact]) + + # Link them + await config_repo.link_artifact(config.id, artifact.id) + await config_repo.commit() + + # Verify link exists + found_config = await config_repo.find_by_root_artifact_id(artifact.id) + assert found_config is not None + assert found_config.id == config.id + + await db.dispose() + + +async def test_link_artifact_rejects_non_root_artifacts() -> None: + """ConfigRepository.link_artifact should raise ValueError if artifact has parent.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + + # Create a config + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + # Create root and child artifacts + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await artifact_repo.save(child) + await 
artifact_repo.commit() + await artifact_repo.refresh_many([child]) + + # Attempt to link child should fail + with pytest.raises(ValueError, match="not a root artifact"): + await config_repo.link_artifact(config.id, child.id) + + await db.dispose() + + +async def test_unlink_artifact_removes_link() -> None: + """ConfigRepository.unlink_artifact should remove the link.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + + # Create and link + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + artifact = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([artifact]) + + await config_repo.link_artifact(config.id, artifact.id) + await config_repo.commit() + + # Verify link exists + assert await config_repo.find_by_root_artifact_id(artifact.id) is not None + + # Unlink + await config_repo.unlink_artifact(artifact.id) + await config_repo.commit() + + # Verify link removed + assert await config_repo.find_by_root_artifact_id(artifact.id) is None + + await db.dispose() + + +async def test_find_artifacts_for_config_returns_linked_artifacts() -> None: + """ConfigRepository.find_artifacts_for_config should return all linked root artifacts.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + + # Create a config + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + # Create multiple root artifacts + artifact1 = Artifact(data={"name": "root1"}, level=0) + artifact2 = Artifact(data={"name": "root2"}, level=0) + await artifact_repo.save_all([artifact1, artifact2]) + await artifact_repo.commit() + await artifact_repo.refresh_many([artifact1, artifact2]) + + # Link both to the same config + await config_repo.link_artifact(config.id, artifact1.id) + await config_repo.link_artifact(config.id, artifact2.id) + await config_repo.commit() + + # Find all linked artifacts + linked = await config_repo.find_artifacts_for_config(config.id) + assert len(linked) == 2 + assert {a.id for a in linked} == {artifact1.id, artifact2.id} + + await db.dispose() + + +async def test_get_root_artifact_walks_up_tree() -> None: + """ArtifactRepository.get_root_artifact should walk up the tree to find root.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + # Create tree: root -> child -> grandchild + root = Artifact(data={"name": "root"}, level=0) + await repo.save(root) + await repo.commit() + await repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await repo.save(child) + await repo.commit() + await repo.refresh_many([child]) + + grandchild = Artifact(data={"name": "grandchild"}, parent_id=child.id, level=2) + await repo.save(grandchild) + await repo.commit() + await repo.refresh_many([grandchild]) + + # Get root from grandchild + found_root = await repo.get_root_artifact(grandchild.id) + assert found_root is not None + assert 
found_root.id == root.id + + # Get root from child + found_root = await repo.get_root_artifact(child.id) + assert found_root is not None + assert found_root.id == root.id + + # Get root from root + found_root = await repo.get_root_artifact(root.id) + assert found_root is not None + assert found_root.id == root.id + + await db.dispose() + + +async def test_config_manager_get_config_for_artifact() -> None: + """ConfigManager.get_config_for_artifact should walk up tree and return config.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + manager = ConfigManager[DemoConfig](config_repo, DemoConfig) + + # Create config and artifacts + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await artifact_repo.save(child) + await artifact_repo.commit() + await artifact_repo.refresh_many([child]) + + # Link config to root + await config_repo.link_artifact(config.id, root.id) + await config_repo.commit() + + # Get config from child (should walk up to root) + found_config = await manager.get_config_for_artifact(child.id, artifact_repo) + assert found_config is not None + assert found_config.id == config.id + assert found_config.data is not None + assert found_config.data.x == 1 + + await db.dispose() + + +async def test_artifact_manager_build_tree_includes_config() -> None: + """ArtifactManager.build_tree should include config at root node.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + artifact_manager = ArtifactManager(artifact_repo, config_repo=config_repo) + + # Create config and artifacts + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await artifact_repo.save(child) + await artifact_repo.commit() + await artifact_repo.refresh_many([child]) + + # Link config to root + await config_repo.link_artifact(config.id, root.id) + await config_repo.commit() + + # Build tree + tree = await artifact_manager.build_tree(root.id) + assert tree is not None + assert tree.config is not None + assert tree.config.id == config.id + assert tree.config.name == "test_config" + + # Children should not have config populated + assert tree.children is not None + assert len(tree.children) == 1 + assert tree.children[0].config is None + + await db.dispose() + + +async def test_artifact_manager_expand_artifact_includes_config() -> None: + """ArtifactManager.expand_artifact should include config at root node.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = 
ArtifactRepository(session) + artifact_manager = ArtifactManager(artifact_repo, config_repo=config_repo) + + # Create config and artifacts + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await artifact_repo.save(child) + await artifact_repo.commit() + await artifact_repo.refresh_many([child]) + + # Link config to root + await config_repo.link_artifact(config.id, root.id) + await config_repo.commit() + + # Expand root artifact + expanded = await artifact_manager.expand_artifact(root.id) + assert expanded is not None + assert expanded.config is not None + assert expanded.config.id == config.id + assert expanded.config.name == "test_config" + + # expand_artifact should not include children + assert expanded.children is None + + await db.dispose() + + +async def test_artifact_manager_expand_artifact_without_config() -> None: + """ArtifactManager.expand_artifact should handle artifacts with no config.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + artifact_manager = ArtifactManager(artifact_repo, config_repo=config_repo) + + # Create root artifact without linking any config + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + # Expand root artifact + expanded = await artifact_manager.expand_artifact(root.id) + assert expanded is not None + assert expanded.config is None + assert expanded.children is None + assert expanded.id == root.id + assert expanded.level == 0 + + await db.dispose() + + +async def test_artifact_manager_expand_artifact_on_child() -> None: + """ArtifactManager.expand_artifact on child should not populate config.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + artifact_manager = ArtifactManager(artifact_repo, config_repo=config_repo) + + # Create config and artifacts + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + await config_repo.save(config) + await config_repo.commit() + await config_repo.refresh_many([config]) + + root = Artifact(data={"name": "root"}, level=0) + await artifact_repo.save(root) + await artifact_repo.commit() + await artifact_repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await artifact_repo.save(child) + await artifact_repo.commit() + await artifact_repo.refresh_many([child]) + + # Link config to root + await config_repo.link_artifact(config.id, root.id) + await config_repo.commit() + + # Expand child artifact (not root) + expanded_child = await artifact_manager.expand_artifact(child.id) + assert expanded_child is not None + assert expanded_child.id == child.id + assert expanded_child.level == 1 + assert expanded_child.parent_id == root.id + # Config should be None because child is not a root + assert expanded_child.config is None + assert expanded_child.children is None + + await 
db.dispose()
+
+
+async def test_cascade_delete_config_deletes_artifacts() -> None:
+    """Deleting a config should cascade delete linked artifacts."""
+    db = SqliteDatabaseBuilder.in_memory().build()
+    await db.init()
+
+    async with db.session() as session:
+        config_repo = ConfigRepository(session)
+        artifact_repo = ArtifactRepository(session)
+
+        # Create config and artifacts
+        config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"]))
+        await config_repo.save(config)
+        await config_repo.commit()
+        await config_repo.refresh_many([config])
+
+        root = Artifact(data={"name": "root"}, level=0)
+        await artifact_repo.save(root)
+        await artifact_repo.commit()
+        await artifact_repo.refresh_many([root])
+
+        child = Artifact(data={"name": "child"}, parent_id=root.id, level=1)
+        await artifact_repo.save(child)
+        await artifact_repo.commit()
+        await artifact_repo.refresh_many([child])
+
+        # Link config to root
+        await config_repo.link_artifact(config.id, root.id)
+        await config_repo.commit()
+
+        # Delete config
+        await config_repo.delete_by_id(config.id)
+        await config_repo.commit()
+
+        # Verify artifacts are deleted (cascade from config -> config_artifact -> artifact)
+        assert await artifact_repo.find_by_id(root.id) is None
+        assert await artifact_repo.find_by_id(child.id) is None
+
+    await db.dispose()
diff --git a/packages/servicekit/tests/test_core_coverage.py b/packages/servicekit/tests/test_core_coverage.py
new file mode 100644
index 0000000..b32343a
--- /dev/null
+++ b/packages/servicekit/tests/test_core_coverage.py
@@ -0,0 +1,206 @@
+"""Tests to improve coverage of servicekit.core modules."""
+
+import tempfile
+from pathlib import Path
+
+import pytest
+from servicekit.core import BaseManager, BaseRepository, Database, Entity, EntityIn, EntityOut, SqliteDatabaseBuilder
+from servicekit.core.logging import add_request_context, clear_request_context, get_logger, reset_request_context
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import Mapped
+
+
+class CustomEntity(Entity):
+    """Test entity for custom behavior."""
+
+    __tablename__ = "custom_entities"
+    name: Mapped[str]
+    value: Mapped[int]
+
+
+class CustomEntityIn(EntityIn):
+    """Input schema for custom entity."""
+
+    name: str
+    value: int
+
+
+class CustomEntityOut(EntityOut):
+    """Output schema for custom entity."""
+
+    name: str
+    value: int
+
+
+class CustomEntityRepository(BaseRepository):
+    """Repository for custom entity."""
+
+    def __init__(self, session: AsyncSession) -> None:
+        """Initialize repository."""
+        super().__init__(session, CustomEntity)
+
+
+class CustomEntityManager(BaseManager):
+    """Manager with custom field assignment logic."""
+
+    def __init__(self, repo: CustomEntityRepository) -> None:
+        """Initialize manager."""
+        super().__init__(repo, CustomEntity, CustomEntityOut)
+
+    def _should_assign_field(self, field: str, value: object) -> bool:
+        """Override to skip assignment of 'name' field when value is 'skip'."""
+        if field == "name" and value == "skip":
+            return False
+        return True
+
+
+async def test_database_custom_alembic_dir():
+    """Test Database with custom alembic directory."""
+    import shutil
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Copy the real alembic directory structure
+        real_alembic_dir = Path(__file__).parent.parent / "alembic"
+        custom_alembic_dir = Path(tmpdir) / "custom_alembic"
+        shutil.copytree(real_alembic_dir, custom_alembic_dir)
+
+        # Create database with custom alembic dir (auto_migrate=True to hit line 78)
+        db_path = Path(tmpdir) / "test.db"
+        db = Database(f"sqlite+aiosqlite:///{db_path}", alembic_dir=custom_alembic_dir, auto_migrate=True)
+
+        await db.init()
+        await db.dispose()
+
+        # Verify database was created and migrations ran
+        assert db_path.exists()
+
+
+async def test_manager_should_assign_field_returns_false():
+    """Test BaseManager when _should_assign_field returns False."""
+    db = SqliteDatabaseBuilder.in_memory().build()
+    await db.init()
+
+    async with db.session() as session:
+        repo = CustomEntityRepository(session)
+        manager = CustomEntityManager(repo)
+
+        # Create entity
+        result = await manager.save(CustomEntityIn(name="original", value=1))
+        entity_id = result.id
+
+        # Update with name="skip" should skip the name field
+        updated = await manager.save(CustomEntityIn(id=entity_id, name="skip", value=2))
+
+        # Name should still be "original" because "skip" was filtered out
+        assert updated.name == "original"
+        assert updated.value == 2
+
+    await db.dispose()
+
+
+async def test_manager_should_assign_field_returns_false_bulk():
+    """Test BaseManager save_all when _should_assign_field returns False."""
+    db = SqliteDatabaseBuilder.in_memory().build()
+    await db.init()
+
+    async with db.session() as session:
+        repo = CustomEntityRepository(session)
+        manager = CustomEntityManager(repo)
+
+        # Create entities
+        result1 = await manager.save(CustomEntityIn(name="original1", value=1))
+        result2 = await manager.save(CustomEntityIn(name="original2", value=2))
+
+        # Bulk update with name="skip" should skip the name field
+        updated = await manager.save_all(
+            [
+                CustomEntityIn(id=result1.id, name="skip", value=10),
+                CustomEntityIn(id=result2.id, name="skip", value=20),
+            ]
+        )
+
+        # Names should still be "original*" because "skip" was filtered out
+        assert updated[0].name == "original1"
+        assert updated[0].value == 10
+        assert updated[1].name == "original2"
+        assert updated[1].value == 20
+
+    await db.dispose()
+
+
+async def test_manager_find_paginated():
+    """Test manager find_paginated returns tuple correctly."""
+    db = SqliteDatabaseBuilder.in_memory().build()
+    await db.init()
+
+    async with db.session() as session:
+        repo = CustomEntityRepository(session)
+        manager = CustomEntityManager(repo)
+
+        # Create multiple entities
+        for i in range(5):
+            await manager.save(CustomEntityIn(name=f"entity_{i}", value=i))
+
+        # Test pagination
+        results, total = await manager.find_paginated(page=1, size=2)
+
+        assert len(results) == 2
+        assert total == 5
+        assert isinstance(results, list)
+        assert isinstance(total, int)
+
+    await db.dispose()
+
+
+def test_logging_clear_request_context():
+    """Test clear_request_context removes specific keys."""
+    # Add some context
+    add_request_context(request_id="123", user_id="456", trace_id="789")
+
+    # Clear specific keys
+    clear_request_context("user_id", "trace_id")
+
+    # Reset to clean up
+    reset_request_context()
+
+
+def test_logging_get_logger_with_name():
+    """Test get_logger with explicit name."""
+    logger = get_logger("test.module")
+    assert logger is not None
+
+    # Test logging functionality
+    logger.info("test_message", key="value")
+
+
+async def test_scheduler_duplicate_job_error():
+    """Test scheduler raises error when duplicate job ID exists."""
+    from unittest.mock import patch
+
+    from servicekit.core.scheduler import AIOJobScheduler
+
+    scheduler = AIOJobScheduler()
+
+    # Create a job
+    async def dummy_job():
+        return "result"
+
+    job_id = await scheduler.add_job(dummy_job)
+
+    # Mock ULID() to return the same job_id to trigger duplicate check at line 120
+    with patch("servicekit.core.scheduler.ULID", return_value=job_id):
+        with pytest.raises(RuntimeError, match="already scheduled"):
+            await scheduler.add_job(dummy_job)
+
+
+async def test_scheduler_wait_job_not_found():
+    """Test scheduler wait raises KeyError for non-existent job."""
+    from servicekit.core.scheduler import AIOJobScheduler
+    from ulid import ULID
+
+    scheduler = AIOJobScheduler()
+
+    # Try to wait for a job that was never added
+    fake_job_id = ULID()
+    with pytest.raises(KeyError, match="Job not found"):
+        await scheduler.wait(fake_job_id)
diff --git a/packages/servicekit/tests/test_database.py b/packages/servicekit/tests/test_database.py
new file mode 100644
index 0000000..18d78cb
--- /dev/null
+++ b/packages/servicekit/tests/test_database.py
@@ -0,0 +1,349 @@
+import tempfile
+from pathlib import Path
+from types import SimpleNamespace
+from typing import cast
+
+import pytest
+import servicekit.core.database as database_module
+from servicekit import SqliteDatabase, SqliteDatabaseBuilder
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+
+def test_install_sqlite_pragmas(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Ensure SQLite connect pragmas are installed on new connections."""
+    captured: dict[str, object] = {}
+
+    def fake_listen(target: object, event_name: str, handler: object) -> None:
+        captured["target"] = target
+        captured["event_name"] = event_name
+        captured["handler"] = handler
+
+    fake_engine = cast(AsyncEngine, SimpleNamespace(sync_engine=object()))
+    monkeypatch.setattr(database_module.event, "listen", fake_listen)
+
+    database_module._install_sqlite_connect_pragmas(fake_engine)
+
+    assert captured["target"] is fake_engine.sync_engine
+    assert captured["event_name"] == "connect"
+    handler = captured["handler"]
+    assert callable(handler)
+
+    class DummyCursor:
+        def __init__(self) -> None:
+            self.commands: list[str] = []
+            self.closed = False
+
+        def execute(self, sql: str) -> None:
+            self.commands.append(sql)
+
+        def close(self) -> None:
+            self.closed = True
+
+    class DummyConnection:
+        def __init__(self) -> None:
+            self._cursor = DummyCursor()
+
+        def cursor(self) -> DummyCursor:
+            return self._cursor
+
+    connection = DummyConnection()
+    handler(connection, None)
+
+    assert connection._cursor.commands == [
+        "PRAGMA foreign_keys=ON;",
+        "PRAGMA synchronous=NORMAL;",
+        "PRAGMA busy_timeout=30000;",
+        "PRAGMA temp_store=MEMORY;",
+        "PRAGMA cache_size=-64000;",
+        "PRAGMA mmap_size=134217728;",
+    ]
+    assert connection._cursor.closed is True
+
+
+class TestSqliteDatabase:
+    """Tests for the SqliteDatabase class."""
+
+    async def test_init_creates_tables(self) -> None:
+        """Test that init() creates all tables."""
+        db = SqliteDatabaseBuilder.in_memory().build()
+        await db.init()
+
+        # Verify tables were created by checking metadata
+        async with db.session() as session:
+            # Simple query to verify database is operational
+            result = await session.execute(text("SELECT 1"))
+            assert result.scalar() == 1
+
+        await db.dispose()
+
+    async def test_session_context_manager(self) -> None:
+        """Test that session() context manager works correctly."""
+        db = SqliteDatabaseBuilder.in_memory().build()
+        await db.init()
+
+        async with db.session() as session:
+            # Verify we got an AsyncSession
+            assert session is not None
+            # Verify it's usable
+            result = await session.execute(text("SELECT 1"))
+            assert result.scalar() == 1
+
+        await db.dispose()
+
+    async def test_multiple_sessions(self) -> None:
+        """Test that multiple sessions can be created."""
+        db =
SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session1: + result1 = await session1.execute(text("SELECT 1")) + assert result1.scalar() == 1 + + async with db.session() as session2: + result2 = await session2.execute(text("SELECT 2")) + assert result2.scalar() == 2 + + await db.dispose() + + async def test_dispose_closes_engine(self) -> None: + """Test that dispose() properly closes the engine.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + # Verify engine is initially usable + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + + await db.dispose() + # After dispose, the engine should no longer be usable + # We verify this by checking that the pool still exists but is disposed + + async def test_echo_parameter(self) -> None: + """Test that echo parameter is passed to engine.""" + db_echo = SqliteDatabase("sqlite+aiosqlite:///:memory:", echo=True) + db_no_echo = SqliteDatabase("sqlite+aiosqlite:///:memory:", echo=False) + + assert db_echo.engine.echo is True + assert db_no_echo.engine.echo is False + + await db_echo.dispose() + await db_no_echo.dispose() + + async def test_wal_mode_enabled(self) -> None: + """Test that WAL mode is enabled after init().""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + # For in-memory databases, WAL mode might not be enabled + # but the init() should complete without error + assert result is not None + + await db.dispose() + + async def test_url_storage(self) -> None: + """Test that the URL is stored correctly.""" + url = "sqlite+aiosqlite:///:memory:" + db = SqliteDatabase(url) + assert db.url == url + await db.dispose() + + async def test_pool_configuration_file_database(self) -> None: + """Test that pool parameters are applied to file-based databases.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: + db_path = Path(tmp_file.name) + + try: + db = SqliteDatabase( + f"sqlite+aiosqlite:///{db_path}", + pool_size=10, + max_overflow=20, + pool_recycle=7200, + pool_pre_ping=False, + ) + # File-based databases should have pool configuration + # Verify pool exists and database is functional + await db.init() + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + await db.dispose() + finally: + if db_path.exists(): + db_path.unlink() + + async def test_pool_configuration_memory_database(self) -> None: + """Test that in-memory databases skip pool configuration.""" + # In-memory databases use StaticPool which doesn't accept pool params + # This should not raise an error + db = SqliteDatabase( + "sqlite+aiosqlite:///:memory:", + pool_size=10, + max_overflow=20, + ) + await db.init() + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + await db.dispose() + + async def test_session_factory_configuration(self) -> None: + """Test that session factory is configured correctly.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + # Verify session factory has expire_on_commit set to False + assert db._session_factory.kw.get("expire_on_commit") is False + + await db.dispose() + + async def test_file_based_database_with_alembic_migrations(self) -> None: + """Test that file-based databases use Alembic migrations to create schema.""" + # Create a 
temporary database file + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: + db_path = Path(tmp_file.name) + + try: + # Initialize database with file-based URL + db = SqliteDatabase(f"sqlite+aiosqlite:///{db_path}") + await db.init() + + # Verify that tables were created via Alembic migration + async with db.session() as session: + # Check that the alembic_version table exists (created by Alembic) + result = await session.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name='alembic_version'") + ) + alembic_table = result.scalar() + assert alembic_table == "alembic_version", "Alembic version table should exist" + + # Verify current migration version is set + result = await session.execute(text("SELECT version_num FROM alembic_version")) + version = result.scalar() + assert version is not None, "Migration version should be recorded" + + # Verify that application tables were created + result = await session.execute( + text( + "SELECT name FROM sqlite_master WHERE type='table' " + "AND name IN ('configs', 'artifacts', 'config_artifacts') ORDER BY name" + ) + ) + tables = [row[0] for row in result.fetchall()] + assert tables == ["artifacts", "config_artifacts", "configs"], "All application tables should exist" + + await db.dispose() + + finally: + # Clean up temporary database file + if db_path.exists(): + db_path.unlink() + + async def test_is_in_memory_method(self) -> None: + """Test is_in_memory() method.""" + # In-memory database + db_mem = SqliteDatabase("sqlite+aiosqlite:///:memory:") + assert db_mem.is_in_memory() is True + await db_mem.dispose() + + # File-based database + db_file = SqliteDatabase("sqlite+aiosqlite:///./app.db") + assert db_file.is_in_memory() is False + await db_file.dispose() + + +class TestSqliteDatabaseBuilder: + """Tests for SqliteDatabaseBuilder class.""" + + async def test_in_memory_builder(self) -> None: + """Test building an in-memory database.""" + db = SqliteDatabaseBuilder.in_memory().build() + + assert db.url == "sqlite+aiosqlite:///:memory:" + assert db.is_in_memory() is True + # Note: auto_migrate setting doesn't matter for in-memory - they always skip Alembic + + await db.init() + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + + await db.dispose() + + async def test_from_file_builder(self) -> None: + """Test building a file-based database.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: + db_path = Path(tmp_file.name) + + try: + db = SqliteDatabaseBuilder.from_file(db_path).build() + + assert db.url == f"sqlite+aiosqlite:///{db_path}" + assert db.is_in_memory() is False + assert db.auto_migrate is True # File-based should enable migrations by default + + await db.dispose() + finally: + if db_path.exists(): + db_path.unlink() + + async def test_from_file_with_string_path(self) -> None: + """Test building from string path.""" + db = SqliteDatabaseBuilder.from_file("./test.db").build() + assert db.url == "sqlite+aiosqlite:///./test.db" + await db.dispose() + + async def test_builder_with_echo(self) -> None: + """Test builder with echo enabled.""" + db = SqliteDatabaseBuilder.in_memory().with_echo(True).build() + assert db.engine.echo is True + await db.dispose() + + async def test_builder_with_migrations(self) -> None: + """Test builder with migration configuration.""" + custom_dir = Path("/custom/alembic") + db = SqliteDatabaseBuilder.in_memory().with_migrations(enabled=True, alembic_dir=custom_dir).build() 
+ + assert db.auto_migrate is True + assert db.alembic_dir == custom_dir + await db.dispose() + + async def test_builder_with_pool(self) -> None: + """Test builder with pool configuration.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: + db_path = Path(tmp_file.name) + + try: + db = ( + SqliteDatabaseBuilder.from_file(db_path) + .with_pool(size=20, max_overflow=40, recycle=1800, pre_ping=False) + .build() + ) + + # Verify database is functional + await db.init() + async with db.session() as session: + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + + await db.dispose() + finally: + if db_path.exists(): + db_path.unlink() + + async def test_builder_chainable_api(self) -> None: + """Test that builder methods are chainable.""" + db = SqliteDatabaseBuilder.in_memory().with_echo(True).with_migrations(False).with_pool(size=10).build() + + assert db.engine.echo is True + assert db.auto_migrate is False + await db.dispose() + + def test_builder_without_url_raises_error(self) -> None: + """Test that building without URL raises error.""" + builder = SqliteDatabaseBuilder() + with pytest.raises(ValueError, match="Database URL not configured"): + builder.build() diff --git a/packages/servicekit/tests/test_dependencies.py b/packages/servicekit/tests/test_dependencies.py new file mode 100644 index 0000000..654f300 --- /dev/null +++ b/packages/servicekit/tests/test_dependencies.py @@ -0,0 +1,165 @@ +"""Tests for dependency injection utilities.""" + +from __future__ import annotations + +import pytest +from servicekit import SqliteDatabaseBuilder +from servicekit.core.api.dependencies import get_database, get_scheduler, set_database, set_scheduler + + +def test_get_database_uninitialized() -> None: + """Test get_database raises error when database is not initialized.""" + # Reset global database state + import servicekit.core.api.dependencies as deps + + original_db = deps._database + deps._database = None + + try: + with pytest.raises(RuntimeError, match="Database not initialized"): + get_database() + finally: + # Restore original state + deps._database = original_db + + +async def test_set_and_get_database() -> None: + """Test setting and getting the database instance.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + try: + set_database(db) + retrieved_db = get_database() + assert retrieved_db is db + finally: + await db.dispose() + + +async def test_get_config_manager() -> None: + """Test get_config_manager returns a ConfigManager instance.""" + from servicekit import ConfigManager + from servicekit.api.dependencies import get_config_manager + from servicekit.core.api.dependencies import get_session + + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + try: + set_database(db) + + # Use the session generator + async for session in get_session(db): + manager = await get_config_manager(session) + assert isinstance(manager, ConfigManager) + break + finally: + await db.dispose() + + +async def test_get_artifact_manager() -> None: + """Test get_artifact_manager returns an ArtifactManager instance.""" + from servicekit import ArtifactManager + from servicekit.api.dependencies import get_artifact_manager + from servicekit.core.api.dependencies import get_session + + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + try: + set_database(db) + + # Use the session generator + async for session in get_session(db): + manager = await get_artifact_manager(session) + assert isinstance(manager, 
ArtifactManager) + break + finally: + await db.dispose() + + +def test_get_scheduler_uninitialized() -> None: + """Test get_scheduler raises error when scheduler is not initialized.""" + # Reset global scheduler state + import servicekit.core.api.dependencies as deps + + original_scheduler = deps._scheduler + deps._scheduler = None + + try: + with pytest.raises(RuntimeError, match="Scheduler not initialized"): + get_scheduler() + finally: + # Restore original state + deps._scheduler = original_scheduler + + +def test_set_and_get_scheduler() -> None: + """Test setting and getting the scheduler instance.""" + from unittest.mock import Mock + + from servicekit.core import JobScheduler + + # Create a mock scheduler since JobScheduler is abstract + scheduler = Mock(spec=JobScheduler) + + try: + set_scheduler(scheduler) + retrieved_scheduler = get_scheduler() + assert retrieved_scheduler is scheduler + finally: + # Reset global state + import servicekit.core.api.dependencies as deps + + deps._scheduler = None + + +async def test_get_task_manager_without_scheduler_and_database() -> None: + """Test get_task_manager handles missing scheduler and database gracefully.""" + from servicekit.api.dependencies import get_task_manager + from servicekit.core.api.dependencies import get_session + + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + try: + set_database(db) + + # Reset scheduler and database to trigger RuntimeError paths + import servicekit.core.api.dependencies as deps + + original_scheduler = deps._scheduler + original_db = deps._database + deps._scheduler = None + deps._database = None + + # Use the session generator + async for session in get_session(db): + # Create a mock artifact manager + from servicekit import ArtifactManager, ArtifactRepository + + artifact_repo = ArtifactRepository(session) + artifact_manager = ArtifactManager(artifact_repo) + + manager = await get_task_manager(session, artifact_manager) + # Manager should be created even without scheduler/database + assert manager is not None + break + + # Restore state + deps._scheduler = original_scheduler + deps._database = original_db + finally: + await db.dispose() + + +def test_get_ml_manager_raises_runtime_error() -> None: + """Test get_ml_manager raises RuntimeError when not configured.""" + from servicekit.api.dependencies import get_ml_manager + + with pytest.raises(RuntimeError, match="ML manager dependency not configured"): + # This is a sync function that returns a coroutine, but we need to call it + # The function itself should raise before returning the coroutine + import asyncio + + asyncio.run(get_ml_manager()) diff --git a/packages/servicekit/tests/test_example_artifact_api.py b/packages/servicekit/tests/test_example_artifact_api.py new file mode 100644 index 0000000..62450d3 --- /dev/null +++ b/packages/servicekit/tests/test_example_artifact_api.py @@ -0,0 +1,297 @@ +"""Tests for artifact_api example using TestClient. + +This example demonstrates a read-only artifact API with hierarchical data. 
+""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.artifact_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_landing_page(client: TestClient) -> None: + """Test landing page returns HTML.""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_info_endpoint(client: TestClient) -> None: + """Test service info endpoint returns service metadata.""" + response = client.get("/api/v1/info") + assert response.status_code == 200 + data = response.json() + assert data["display_name"] == "Chapkit Artifact Service" + assert data["summary"] == "Artifact CRUD and tree operations example" + assert "hierarchy" in data + assert data["hierarchy"]["name"] == "ml_experiment" + assert "pipelines" in data + assert len(data["pipelines"]) == 2 + + +def test_list_artifacts(client: TestClient) -> None: + """Test listing all seeded artifacts.""" + response = client.get("/api/v1/artifacts") + assert response.status_code == 200 + data = response.json() + + # Should be a list of artifacts from 2 pipelines + assert isinstance(data, list) + assert len(data) > 0 + + # Verify structure + for artifact in data: + assert "id" in artifact + assert "data" in artifact + assert "parent_id" in artifact + assert "level" in artifact + assert "created_at" in artifact + assert "updated_at" in artifact + + +def test_list_artifacts_with_pagination(client: TestClient) -> None: + """Test listing artifacts with pagination.""" + response = client.get("/api/v1/artifacts", params={"page": 1, "size": 3}) + assert response.status_code == 200 + data = response.json() + + # Should return paginated response + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + + assert len(data["items"]) <= 3 + assert data["page"] == 1 + assert data["size"] == 3 + + +def test_get_artifact_by_id(client: TestClient) -> None: + """Test retrieving artifact by ID.""" + # First get the list to obtain a valid ID + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + artifact_id = artifacts[0]["id"] + + response = client.get(f"/api/v1/artifacts/{artifact_id}") + assert response.status_code == 200 + data = response.json() + + assert data["id"] == artifact_id + assert "data" in data + assert "level" in data + + +def test_get_artifact_by_id_not_found(client: TestClient) -> None: + """Test retrieving non-existent artifact returns 404.""" + response = client.get("/api/v1/artifacts/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_get_artifact_tree(client: TestClient) -> None: + """Test retrieving artifact tree structure with $tree operation.""" + # Get a root artifact (level 0) + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + root_artifact = next((a for a in artifacts if a["level"] == 0), None) + assert 
root_artifact is not None + + root_id = root_artifact["id"] + response = client.get(f"/api/v1/artifacts/{root_id}/$tree") + assert response.status_code == 200 + data = response.json() + + # Verify tree structure + assert data["id"] == root_id + assert "data" in data + assert "children" in data + assert isinstance(data["children"], list) + + # Verify hierarchical structure + if len(data["children"]) > 0: + child = data["children"][0] + assert "id" in child + assert "data" in child + assert "children" in child + + +def test_get_artifact_tree_not_found(client: TestClient) -> None: + """Test $tree operation on non-existent artifact returns 404.""" + response = client.get("/api/v1/artifacts/01K72P5N5KCRM6MD3BRE4P0999/$tree") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_expand_artifact(client: TestClient) -> None: + """Test expanding artifact with $expand operation returns hierarchy metadata without children.""" + # Get a root artifact (level 0) + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + root_artifact = next((a for a in artifacts if a["level"] == 0), None) + assert root_artifact is not None + + root_id = root_artifact["id"] + response = client.get(f"/api/v1/artifacts/{root_id}/$expand") + assert response.status_code == 200 + data = response.json() + + # Verify expanded structure + assert data["id"] == root_id + assert "data" in data + assert "level" in data + assert "level_label" in data + assert "hierarchy" in data + + # Verify hierarchy metadata is present + assert data["hierarchy"] == "ml_experiment" + assert data["level_label"] == "experiment" + + # Verify children is None (not included in expand) + assert data["children"] is None + + +def test_expand_artifact_with_parent(client: TestClient) -> None: + """Test expanding artifact with parent includes hierarchy metadata.""" + # Get a child artifact (level 1) + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + child_artifact = next((a for a in artifacts if a["level"] == 1), None) + assert child_artifact is not None + + child_id = child_artifact["id"] + response = client.get(f"/api/v1/artifacts/{child_id}/$expand") + assert response.status_code == 200 + data = response.json() + + # Verify expanded structure + assert data["id"] == child_id + assert data["level"] == 1 + assert data["level_label"] == "stage" + assert data["hierarchy"] == "ml_experiment" + assert data["children"] is None + + +def test_expand_artifact_not_found(client: TestClient) -> None: + """Test $expand operation on non-existent artifact returns 404.""" + response = client.get("/api/v1/artifacts/01K72P5N5KCRM6MD3BRE4P0999/$expand") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_artifact_with_non_json_payload(client: TestClient) -> None: + """Test artifact with non-JSON payload (MockLinearModel) is serialized with metadata.""" + # Find the artifact with MockLinearModel (ID: 01K72P5N5KCRM6MD3BRE4P07NJ) + response = client.get("/api/v1/artifacts/01K72P5N5KCRM6MD3BRE4P07NJ") + assert response.status_code == 200 + data = response.json() + + # Verify the MockLinearModel is serialized with metadata fields + assert "data" in data + artifact_data = data["data"] + assert isinstance(artifact_data, dict) + assert artifact_data["_type"] == "MockLinearModel" + assert artifact_data["_module"] == "examples.artifact_api" + assert "MockLinearModel" in 
artifact_data["_repr"] + assert "coefficients" in artifact_data["_repr"] + assert "intercept" in artifact_data["_repr"] + assert "_serialization_error" in artifact_data + + +def test_create_artifact_not_allowed(client: TestClient) -> None: + """Test that creating artifacts is disabled (read-only API).""" + new_artifact = {"data": {"name": "test", "value": 123}} + + response = client.post("/api/v1/artifacts", json=new_artifact) + assert response.status_code == 405 # Method Not Allowed + + +def test_update_artifact_not_allowed(client: TestClient) -> None: + """Test that updating artifacts is disabled (read-only API).""" + # Get an existing artifact ID + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + artifact_id = artifacts[0]["id"] + + updated_artifact = {"id": artifact_id, "data": {"updated": True}} + + response = client.put(f"/api/v1/artifacts/{artifact_id}", json=updated_artifact) + assert response.status_code == 405 # Method Not Allowed + + +def test_delete_artifact_not_allowed(client: TestClient) -> None: + """Test that deleting artifacts is disabled (read-only API).""" + # Get an existing artifact ID + list_response = client.get("/api/v1/artifacts") + artifacts = list_response.json() + artifact_id = artifacts[0]["id"] + + response = client.delete(f"/api/v1/artifacts/{artifact_id}") + assert response.status_code == 405 # Method Not Allowed + + +def test_list_configs(client: TestClient) -> None: + """Test listing configs endpoint exists.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + +def test_experiment_alpha_tree_structure(client: TestClient) -> None: + """Test experiment_alpha tree has correct hierarchical structure.""" + # Root artifact for experiment_alpha + root_id = "01K72P5N5KCRM6MD3BRE4P07NB" + response = client.get(f"/api/v1/artifacts/{root_id}/$tree") + assert response.status_code == 200 + tree = response.json() + + # Verify root + assert tree["id"] == root_id + assert tree["data"]["name"] == "experiment_alpha" + assert tree["level"] == 0 + assert len(tree["children"]) == 2 # Two stages + + # Verify stages + for stage in tree["children"]: + assert stage["level"] == 1 + assert "stage" in stage["data"] + assert "artifacts" in stage or "children" in stage + + +def test_experiment_beta_tree_structure(client: TestClient) -> None: + """Test experiment_beta tree has correct hierarchical structure.""" + # Root artifact for experiment_beta + root_id = "01K72P5N5KCRM6MD3BRE4P07NK" + response = client.get(f"/api/v1/artifacts/{root_id}/$tree") + assert response.status_code == 200 + tree = response.json() + + # Verify root + assert tree["id"] == root_id + assert tree["data"]["name"] == "experiment_beta" + assert tree["level"] == 0 + assert len(tree["children"]) == 2 # Two stages diff --git a/packages/servicekit/tests/test_example_config_api.py b/packages/servicekit/tests/test_example_config_api.py new file mode 100644 index 0000000..633b70e --- /dev/null +++ b/packages/servicekit/tests/test_example_config_api.py @@ -0,0 +1,240 @@ +"""Tests for config_api example using TestClient. + +Tests use FastAPI's TestClient instead of running a separate server. 
+""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.config_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_landing_page(client: TestClient) -> None: + """Test landing page returns HTML.""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "checks" in data + assert "database" in data["checks"] + + +def test_info_endpoint(client: TestClient) -> None: + """Test service info endpoint returns service metadata.""" + response = client.get("/api/v1/info") + assert response.status_code == 200 + data = response.json() + assert data["display_name"] == "Chapkit Config Service" + assert data["summary"] == "Environment configuration CRUD example" + assert data["author"] == "Morten Hansen" + assert "seeded_configs" in data + assert len(data["seeded_configs"]) == 3 + + +def test_list_configs(client: TestClient) -> None: + """Test listing all seeded configs.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + + # Should be a list of 3 seeded configs + assert isinstance(data, list) + assert len(data) == 3 + + # Check config names + names = {config["name"] for config in data} + assert names == {"production", "staging", "local"} + + # Verify structure + for config in data: + assert "id" in config + assert "name" in config + assert "data" in config + assert "created_at" in config + assert "updated_at" in config + + # Check data structure + assert "debug" in config["data"] + assert "api_host" in config["data"] + assert "api_port" in config["data"] + assert "max_connections" in config["data"] + + +def test_list_configs_with_pagination(client: TestClient) -> None: + """Test listing configs with pagination.""" + response = client.get("/api/v1/configs", params={"page": 1, "size": 2}) + assert response.status_code == 200 + data = response.json() + + # Should return paginated response + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + + assert len(data["items"]) == 2 + assert data["total"] == 3 + assert data["page"] == 1 + assert data["size"] == 2 + assert data["pages"] == 2 + + +def test_get_config_by_id(client: TestClient) -> None: + """Test retrieving config by ID.""" + # First get the list to obtain a valid ID + list_response = client.get("/api/v1/configs") + configs = list_response.json() + config_id = configs[0]["id"] + + response = client.get(f"/api/v1/configs/{config_id}") + assert response.status_code == 200 + data = response.json() + + assert data["id"] == config_id + assert "name" in data + assert "data" in data + + +def test_get_config_by_invalid_ulid(client: TestClient) -> None: + """Test retrieving config with invalid ULID format returns 400.""" + response = client.get("/api/v1/configs/not-a-valid-ulid") + assert response.status_code == 400 + data = response.json() + assert "invalid ulid" in data["detail"].lower() + + +def 
test_get_config_by_id_not_found(client: TestClient) -> None: + """Test retrieving non-existent config by ID returns 404.""" + # Use a valid ULID format but non-existent ID + response = client.get("/api/v1/configs/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_create_config(client: TestClient) -> None: + """Test creating a new config.""" + new_config = { + "name": "test-environment", + "data": {"debug": True, "api_host": "localhost", "api_port": 9000, "max_connections": 50}, + } + + response = client.post("/api/v1/configs", json=new_config) + assert response.status_code == 201 + data = response.json() + + assert "id" in data + assert data["name"] == "test-environment" + assert data["data"]["debug"] is True + assert data["data"]["api_port"] == 9000 + + # Verify it was created by fetching it + config_id = data["id"] + get_response = client.get(f"/api/v1/configs/{config_id}") + assert get_response.status_code == 200 + + +def test_create_config_duplicate_name(client: TestClient) -> None: + """Test creating config with duplicate name succeeds (no unique constraint).""" + # First create a config + first_config = { + "name": "duplicate-test", + "data": {"debug": False, "api_host": "0.0.0.0", "api_port": 8080, "max_connections": 1000}, + } + response1 = client.post("/api/v1/configs", json=first_config) + assert response1.status_code == 201 + first_id = response1.json()["id"] + + # Create another with the same name - should succeed + duplicate_config = { + "name": "duplicate-test", + "data": {"debug": True, "api_host": "127.0.0.1", "api_port": 8000, "max_connections": 500}, + } + + response2 = client.post("/api/v1/configs", json=duplicate_config) + assert response2.status_code == 201 + second_id = response2.json()["id"] + + # Verify they have different IDs but same name + assert first_id != second_id + assert response1.json()["name"] == response2.json()["name"] == "duplicate-test" + + +def test_update_config(client: TestClient) -> None: + """Test updating an existing config.""" + # Create a config first + new_config = { + "name": "update-test", + "data": {"debug": False, "api_host": "127.0.0.1", "api_port": 8080, "max_connections": 100}, + } + create_response = client.post("/api/v1/configs", json=new_config) + created = create_response.json() + config_id = created["id"] + + # Update it + updated_config = { + "id": config_id, + "name": "update-test", + "data": { + "debug": True, # Changed + "api_host": "127.0.0.1", + "api_port": 9999, # Changed + "max_connections": 200, # Changed + }, + } + + response = client.put(f"/api/v1/configs/{config_id}", json=updated_config) + assert response.status_code == 200 + data = response.json() + + assert data["id"] == config_id + assert data["data"]["debug"] is True + assert data["data"]["api_port"] == 9999 + assert data["data"]["max_connections"] == 200 + + +def test_delete_config(client: TestClient) -> None: + """Test deleting a config.""" + # Create a config first + new_config = { + "name": "delete-test", + "data": {"debug": False, "api_host": "127.0.0.1", "api_port": 8080, "max_connections": 100}, + } + create_response = client.post("/api/v1/configs", json=new_config) + created = create_response.json() + config_id = created["id"] + + # Delete it + response = client.delete(f"/api/v1/configs/{config_id}") + assert response.status_code == 204 + + # Verify it's gone + get_response = client.get(f"/api/v1/configs/{config_id}") + assert get_response.status_code == 
404 + + +def test_delete_config_not_found(client: TestClient) -> None: + """Test deleting non-existent config returns 404.""" + response = client.delete("/api/v1/configs/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() diff --git a/packages/servicekit/tests/test_example_config_artifact_api.py b/packages/servicekit/tests/test_example_config_artifact_api.py new file mode 100644 index 0000000..39a11d4 --- /dev/null +++ b/packages/servicekit/tests/test_example_config_artifact_api.py @@ -0,0 +1,273 @@ +"""Tests for config_artifact_api example using TestClient. + +This example demonstrates config-artifact linking and custom health checks. +""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.config_artifact_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_landing_page(client: TestClient) -> None: + """Test landing page returns HTML.""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_health_endpoint_with_custom_checks(client: TestClient) -> None: + """Test health check includes custom flaky_service check.""" + response = client.get("/health") + # Can be healthy or unhealthy due to flaky check + assert response.status_code in [200, 503] + data = response.json() + + # Health endpoint returns HealthResponse which has status field + assert "status" in data + assert data["status"] in ["healthy", "degraded", "unhealthy"] + + assert "checks" in data + assert "database" in data["checks"] + assert "flaky_service" in data["checks"] + + # Flaky service check should have one of three states (using "state" key) + flaky_check = data["checks"]["flaky_service"] + assert flaky_check["state"] in ["healthy", "degraded", "unhealthy"] + + +def test_info_endpoint(client: TestClient) -> None: + """Test service info endpoint returns service metadata.""" + response = client.get("/api/v1/info") + assert response.status_code == 200 + data = response.json() + assert data["display_name"] == "Chapkit Config & Artifact Service" + assert data["summary"] == "Linked config and artifact CRUD example" + assert "hierarchy" in data + assert data["hierarchy"]["name"] == "training_pipeline" + assert "configs" in data + assert len(data["configs"]) == 2 + + +def test_list_configs(client: TestClient) -> None: + """Test listing all seeded configs.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + + # Should be a list of 2 seeded configs + assert isinstance(data, list) + assert len(data) == 2 + + # Check config names + names = {config["name"] for config in data} + assert names == {"experiment_alpha", "experiment_beta"} + + +def test_list_artifacts(client: TestClient) -> None: + """Test listing all seeded artifacts.""" + response = client.get("/api/v1/artifacts") + assert response.status_code == 200 + data = response.json() + + # Should be a list of artifacts + assert isinstance(data, list) + assert len(data) > 0 + + +def test_get_artifact_tree(client: TestClient) -> None: + """Test retrieving artifact tree structure.""" + # experiment_alpha root artifact + root_id = "01K72PWT05GEXK1S24AVKAZ9VF" + response = 
client.get(f"/api/v1/artifacts/{root_id}/$tree") + assert response.status_code == 200 + tree = response.json() + + assert tree["id"] == root_id + assert tree["data"]["stage"] == "train" + assert "children" in tree + assert len(tree["children"]) == 2 # Two predict runs + + +def test_get_linked_artifacts_for_config(client: TestClient) -> None: + """Test retrieving artifacts linked to a config.""" + # Get experiment_alpha config + configs_response = client.get("/api/v1/configs") + configs = configs_response.json() + alpha_config = next((c for c in configs if c["name"] == "experiment_alpha"), None) + assert alpha_config is not None + + config_id = alpha_config["id"] + response = client.get(f"/api/v1/configs/{config_id}/$artifacts") + assert response.status_code == 200 + artifacts = response.json() + + # Should have linked artifacts + assert isinstance(artifacts, list) + assert len(artifacts) > 0 + + # Verify it's the root artifact + root_artifact = artifacts[0] + assert root_artifact["id"] == "01K72PWT05GEXK1S24AVKAZ9VF" + + +def test_get_config_for_artifact(client: TestClient) -> None: + """Test retrieving config linked to an artifact.""" + # experiment_alpha root artifact + artifact_id = "01K72PWT05GEXK1S24AVKAZ9VF" + response = client.get(f"/api/v1/artifacts/{artifact_id}/$config") + assert response.status_code == 200 + config = response.json() + + # Should return the experiment_alpha config + assert config["name"] == "experiment_alpha" + assert "data" in config + assert config["data"]["model"] == "xgboost" + assert config["data"]["learning_rate"] == 0.05 + + +def test_link_artifact_to_config(client: TestClient) -> None: + """Test linking an artifact to a config.""" + # Create a new config + new_config = { + "name": "test-experiment", + "data": {"model": "random_forest", "learning_rate": 0.01, "epochs": 100, "batch_size": 128}, + } + create_response = client.post("/api/v1/configs", json=new_config) + assert create_response.status_code == 201 + config = create_response.json() + config_id = config["id"] + + # Create a root artifact + new_artifact = {"data": {"stage": "train", "dataset": "test_data.parquet"}} + artifact_response = client.post("/api/v1/artifacts", json=new_artifact) + assert artifact_response.status_code == 201 + artifact = artifact_response.json() + artifact_id = artifact["id"] + + # Link the artifact to the config + link_response = client.post(f"/api/v1/configs/{config_id}/$link-artifact", json={"artifact_id": artifact_id}) + assert link_response.status_code == 204 + + # Verify the link + artifacts_response = client.get(f"/api/v1/configs/{config_id}/$artifacts") + assert artifacts_response.status_code == 200 + linked_artifacts = artifacts_response.json() + assert len(linked_artifacts) == 1 + assert linked_artifacts[0]["id"] == artifact_id + + +def test_link_non_root_artifact_fails(client: TestClient) -> None: + """Test that linking a non-root artifact to config fails.""" + # Get experiment_alpha config + configs_response = client.get("/api/v1/configs") + configs = configs_response.json() + alpha_config = next((c for c in configs if c["name"] == "experiment_alpha"), None) + assert alpha_config is not None + config_id = alpha_config["id"] + + # Try to link a non-root artifact (level > 0) + # This is a child artifact from experiment_alpha + non_root_artifact_id = "01K72PWT05GEXK1S24AVKAZ9VG" + + link_response = client.post( + f"/api/v1/configs/{config_id}/$link-artifact", json={"artifact_id": non_root_artifact_id} + ) + # Should fail because non-root artifacts can't be linked + 
assert link_response.status_code == 400 + data = link_response.json() + assert "root" in data["detail"].lower() or "level 0" in data["detail"].lower() + + +def test_unlink_artifact_from_config(client: TestClient) -> None: + """Test unlinking an artifact from a config.""" + # Create config and artifact + new_config = { + "name": "unlink-test", + "data": {"model": "mlp", "learning_rate": 0.001, "epochs": 50, "batch_size": 64}, + } + config_response = client.post("/api/v1/configs", json=new_config) + config = config_response.json() + config_id = config["id"] + + new_artifact = {"data": {"stage": "train", "dataset": "unlink_test.parquet"}} + artifact_response = client.post("/api/v1/artifacts", json=new_artifact) + artifact = artifact_response.json() + artifact_id = artifact["id"] + + # Link them + client.post(f"/api/v1/configs/{config_id}/$link-artifact", json={"artifact_id": artifact_id}) + + # Unlink + unlink_response = client.post(f"/api/v1/configs/{config_id}/$unlink-artifact", json={"artifact_id": artifact_id}) + assert unlink_response.status_code == 204 + + # Verify unlinked + artifacts_response = client.get(f"/api/v1/configs/{config_id}/$artifacts") + linked_artifacts = artifacts_response.json() + assert len(linked_artifacts) == 0 + + +def test_create_config_with_experiment_schema(client: TestClient) -> None: + """Test creating a config with ExperimentConfig schema.""" + new_config = { + "name": "test-ml-config", + "data": {"model": "svm", "learning_rate": 0.1, "epochs": 30, "batch_size": 512}, + } + + response = client.post("/api/v1/configs", json=new_config) + assert response.status_code == 201 + data = response.json() + + assert data["name"] == "test-ml-config" + assert data["data"]["model"] == "svm" + assert data["data"]["epochs"] == 30 + + +def test_create_artifact_in_hierarchy(client: TestClient) -> None: + """Test creating artifacts following the training_pipeline hierarchy.""" + # Create root (train level) + train_artifact = {"data": {"stage": "train", "dataset": "new_train.parquet"}} + train_response = client.post("/api/v1/artifacts", json=train_artifact) + assert train_response.status_code == 201 + train = train_response.json() + train_id = train["id"] + assert train["level"] == 0 + + # Create child (predict level) + predict_artifact = { + "parent_id": train_id, + "data": {"stage": "predict", "run": "2024-03-01", "path": "predictions.parquet"}, + } + predict_response = client.post("/api/v1/artifacts", json=predict_artifact) + assert predict_response.status_code == 201 + predict = predict_response.json() + predict_id = predict["id"] + assert predict["level"] == 1 + assert predict["parent_id"] == train_id + + # Create grandchild (result level) + result_artifact = {"parent_id": predict_id, "data": {"stage": "result", "metrics": {"accuracy": 0.95}}} + result_response = client.post("/api/v1/artifacts", json=result_artifact) + assert result_response.status_code == 201 + result = result_response.json() + assert result["level"] == 2 + assert result["parent_id"] == predict_id + + # Verify tree structure + tree_response = client.get(f"/api/v1/artifacts/{train_id}/$tree") + tree = tree_response.json() + assert len(tree["children"]) == 1 + assert tree["children"][0]["id"] == predict_id + assert len(tree["children"][0]["children"]) == 1 diff --git a/packages/servicekit/tests/test_example_core_api.py b/packages/servicekit/tests/test_example_core_api.py new file mode 100644 index 0000000..82560c0 --- /dev/null +++ b/packages/servicekit/tests/test_example_core_api.py @@ -0,0 +1,261 @@ 
+"""Tests for core_api example using TestClient. + +Tests use FastAPI's TestClient instead of running a separate server. +Validates BaseServiceBuilder functionality with custom User entity. +""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.core_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "checks" in data + assert "database" in data["checks"] + assert data["checks"]["database"]["state"] == "healthy" + + +def test_system_endpoint(client: TestClient) -> None: + """Test system info endpoint returns system metadata.""" + response = client.get("/api/v1/system") + assert response.status_code == 200 + data = response.json() + assert "python_version" in data + assert "platform" in data + assert "current_time" in data + assert "timezone" in data + assert "hostname" in data + + +def test_info_endpoint(client: TestClient) -> None: + """Test service info endpoint returns service metadata.""" + response = client.get("/api/v1/info") + assert response.status_code == 200 + data = response.json() + assert data["display_name"] == "Core User Service" + assert data["version"] == "1.0.0" + assert data["summary"] == "User management API using core-only features" + + +def test_create_user(client: TestClient) -> None: + """Test creating a new user.""" + new_user = { + "username": "johndoe", + "email": "john@example.com", + "full_name": "John Doe", + "is_active": True, + } + + response = client.post("/api/v1/users", json=new_user) + assert response.status_code == 201 + data = response.json() + + assert "id" in data + assert data["username"] == "johndoe" + assert data["email"] == "john@example.com" + assert data["full_name"] == "John Doe" + assert data["is_active"] is True + + # Verify Location header + assert "Location" in response.headers + assert f"/api/v1/users/{data['id']}" in response.headers["Location"] + + +def test_list_users(client: TestClient) -> None: + """Test listing all users.""" + # Create a few users first + users_to_create = [ + {"username": "alice", "email": "alice@example.com", "full_name": "Alice Smith"}, + {"username": "bob", "email": "bob@example.com", "full_name": "Bob Jones"}, + ] + + for user in users_to_create: + client.post("/api/v1/users", json=user) + + # List all users + response = client.get("/api/v1/users") + assert response.status_code == 200 + data = response.json() + + # Should be a list + assert isinstance(data, list) + assert len(data) >= 2 + + # Verify structure + for user in data: + assert "id" in user + assert "username" in user + assert "email" in user + assert "created_at" in user + assert "updated_at" in user + + +def test_list_users_with_pagination(client: TestClient) -> None: + """Test listing users with pagination.""" + # Create multiple users to test pagination + for i in range(5): + client.post( + "/api/v1/users", + json={ + "username": f"user{i}", + "email": f"user{i}@example.com", + "full_name": f"User {i}", + }, + ) + + response = client.get("/api/v1/users", params={"page": 1, "size": 3}) + assert response.status_code == 200 
+ data = response.json() + + # Should return paginated response + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + + assert len(data["items"]) <= 3 + assert data["page"] == 1 + assert data["size"] == 3 + assert data["total"] >= 5 + + +def test_get_user_by_id(client: TestClient) -> None: + """Test retrieving user by ID.""" + # Create a user first + new_user = {"username": "testuser", "email": "test@example.com", "full_name": "Test User"} + create_response = client.post("/api/v1/users", json=new_user) + created = create_response.json() + user_id = created["id"] + + # Get by ID + response = client.get(f"/api/v1/users/{user_id}") + assert response.status_code == 200 + data = response.json() + + assert data["id"] == user_id + assert data["username"] == "testuser" + assert data["email"] == "test@example.com" + + +def test_get_user_by_invalid_ulid(client: TestClient) -> None: + """Test retrieving user with invalid ULID format returns 400.""" + response = client.get("/api/v1/users/not-a-valid-ulid") + assert response.status_code == 400 + data = response.json() + assert "invalid ulid" in data["detail"].lower() + + +def test_get_user_by_id_not_found(client: TestClient) -> None: + """Test retrieving non-existent user by ID returns 404.""" + # Use a valid ULID format but non-existent ID + response = client.get("/api/v1/users/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_update_user(client: TestClient) -> None: + """Test updating an existing user.""" + # Create a user first + new_user = {"username": "updateuser", "email": "update@example.com", "full_name": "Update User"} + create_response = client.post("/api/v1/users", json=new_user) + created = create_response.json() + user_id = created["id"] + + # Update it + updated_user = { + "id": user_id, + "username": "updateuser", + "email": "updated@example.com", # Changed + "full_name": "Updated User Name", # Changed + "is_active": False, # Changed + } + + response = client.put(f"/api/v1/users/{user_id}", json=updated_user) + assert response.status_code == 200 + data = response.json() + + assert data["id"] == user_id + assert data["username"] == "updateuser" + assert data["email"] == "updated@example.com" + assert data["full_name"] == "Updated User Name" + assert data["is_active"] is False + + +def test_update_user_not_found(client: TestClient) -> None: + """Test updating non-existent user returns 404.""" + user_id = "01K72P5N5KCRM6MD3BRE4P0999" + updated_user = { + "id": user_id, + "username": "ghost", + "email": "ghost@example.com", + } + + response = client.put(f"/api/v1/users/{user_id}", json=updated_user) + assert response.status_code == 404 + + +def test_delete_user(client: TestClient) -> None: + """Test deleting a user.""" + # Create a user first + new_user = {"username": "deleteuser", "email": "delete@example.com"} + create_response = client.post("/api/v1/users", json=new_user) + created = create_response.json() + user_id = created["id"] + + # Delete it + response = client.delete(f"/api/v1/users/{user_id}") + assert response.status_code == 204 + + # Verify it's gone + get_response = client.get(f"/api/v1/users/{user_id}") + assert get_response.status_code == 404 + + +def test_delete_user_not_found(client: TestClient) -> None: + """Test deleting non-existent user returns 404.""" + response = client.delete("/api/v1/users/01K72P5N5KCRM6MD3BRE4P0999") + assert 
response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_jobs_endpoint_exists(client: TestClient) -> None: + """Test that job scheduler endpoints are available.""" + response = client.get("/api/v1/jobs") + assert response.status_code == 200 + data = response.json() + # Should return empty list initially + assert isinstance(data, list) + + +def test_openapi_schema(client: TestClient) -> None: + """Test OpenAPI schema is generated correctly.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + + assert schema["info"]["title"] == "Core User Service" + assert schema["info"]["version"] == "1.0.0" + assert "paths" in schema + assert "/api/v1/users" in schema["paths"] + assert "/health" in schema["paths"] + assert "/api/v1/system" in schema["paths"] diff --git a/packages/servicekit/tests/test_example_core_cli.py b/packages/servicekit/tests/test_example_core_cli.py new file mode 100644 index 0000000..aba8b02 --- /dev/null +++ b/packages/servicekit/tests/test_example_core_cli.py @@ -0,0 +1,305 @@ +"""Tests for core_cli example validating direct database operations. + +Tests the CLI example that demonstrates using core Database, Repository, +and Manager directly without FastAPI for command-line tools and scripts. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator + +import pytest +from servicekit.core import Database, SqliteDatabaseBuilder + +from examples.core_cli import Product, ProductIn, ProductManager, ProductRepository + + +@pytest.fixture +async def database() -> AsyncGenerator[Database, None]: + """Create and initialize in-memory database for testing.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + try: + yield db + finally: + await db.dispose() + + +@pytest.fixture +async def manager(database: Database) -> AsyncGenerator[ProductManager, None]: + """Create product manager with initialized database.""" + async with database.session() as session: + repo = ProductRepository(session) + yield ProductManager(repo) + + +async def test_create_product(database: Database) -> None: + """Test creating a product.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + product = await manager.save( + ProductIn( + sku="TEST-001", + name="Test Product", + price=99.99, + stock=10, + ) + ) + + assert product.sku == "TEST-001" + assert product.name == "Test Product" + assert product.price == 99.99 + assert product.stock == 10 + assert product.active is True + assert product.id is not None + + +async def test_find_by_sku(database: Database) -> None: + """Test finding product by SKU.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create a product + created = await manager.save( + ProductIn( + sku="FIND-001", + name="Findable Product", + price=49.99, + stock=5, + ) + ) + + # Find by SKU + found = await manager.find_by_sku("FIND-001") + assert found is not None + assert found.id == created.id + assert found.sku == "FIND-001" + assert found.name == "Findable Product" + + +async def test_find_by_sku_not_found(database: Database) -> None: + """Test finding product by non-existent SKU returns None.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + found = await manager.find_by_sku("NONEXISTENT") + assert found is None + + +async def 
test_find_low_stock(database: Database) -> None: + """Test finding products with low stock.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create products with varying stock levels + await manager.save(ProductIn(sku="HIGH-001", name="High Stock", price=10.0, stock=100)) + await manager.save(ProductIn(sku="LOW-001", name="Low Stock 1", price=10.0, stock=5)) + await manager.save(ProductIn(sku="LOW-002", name="Low Stock 2", price=10.0, stock=8)) + await manager.save(ProductIn(sku="ZERO-001", name="Out of Stock", price=10.0, stock=0)) + + # Find low stock (threshold = 10) + low_stock = await manager.find_low_stock(threshold=10) + + assert len(low_stock) == 3 + skus = {p.sku for p in low_stock} + assert skus == {"LOW-001", "LOW-002", "ZERO-001"} + + +async def test_find_low_stock_excludes_inactive(database: Database) -> None: + """Test that low stock query excludes inactive products.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create active and inactive products with low stock + await manager.save(ProductIn(sku="ACTIVE-LOW", name="Active Low", price=10.0, stock=5, active=True)) + await manager.save(ProductIn(sku="INACTIVE-LOW", name="Inactive Low", price=10.0, stock=5, active=False)) + + # Find low stock + low_stock = await manager.find_low_stock(threshold=10) + + # Should only include active products + assert len(low_stock) == 1 + assert low_stock[0].sku == "ACTIVE-LOW" + + +async def test_restock_product(database: Database) -> None: + """Test restocking a product.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create a product + product = await manager.save( + ProductIn( + sku="RESTOCK-001", + name="Restockable Product", + price=29.99, + stock=5, + ) + ) + + initial_stock = product.stock + + # Restock + restocked = await manager.restock(product.id, 20) + + assert restocked.id == product.id + assert restocked.stock == initial_stock + 20 + assert restocked.stock == 25 + + +async def test_restock_nonexistent_product(database: Database) -> None: + """Test restocking non-existent product raises error.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + from ulid import ULID + + fake_id = ULID() + + with pytest.raises(ValueError, match="not found"): + await manager.restock(fake_id, 10) + + +async def test_list_all_products(database: Database) -> None: + """Test listing all products.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create multiple products + await manager.save(ProductIn(sku="LIST-001", name="Product 1", price=10.0, stock=10)) + await manager.save(ProductIn(sku="LIST-002", name="Product 2", price=20.0, stock=20)) + await manager.save(ProductIn(sku="LIST-003", name="Product 3", price=30.0, stock=30)) + + # List all + all_products = await manager.find_all() + + assert len(all_products) == 3 + skus = {p.sku for p in all_products} + assert skus == {"LIST-001", "LIST-002", "LIST-003"} + + +async def test_count_products(database: Database) -> None: + """Test counting products.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Initially empty + count = await manager.count() + assert count == 0 + + # Add products + await 
manager.save(ProductIn(sku="COUNT-001", name="Product 1", price=10.0)) + await manager.save(ProductIn(sku="COUNT-002", name="Product 2", price=20.0)) + + count = await manager.count() + assert count == 2 + + +async def test_update_product(database: Database) -> None: + """Test updating a product.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create a product + product = await manager.save( + ProductIn( + sku="UPDATE-001", + name="Original Name", + price=50.0, + stock=10, + ) + ) + + # Update it + updated = await manager.save( + ProductIn( + id=product.id, + sku="UPDATE-001", + name="Updated Name", + price=75.0, + stock=15, + ) + ) + + assert updated.id == product.id + assert updated.sku == "UPDATE-001" + assert updated.name == "Updated Name" + assert updated.price == 75.0 + assert updated.stock == 15 + + +async def test_delete_product(database: Database) -> None: + """Test deleting a product.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create a product + product = await manager.save( + ProductIn( + sku="DELETE-001", + name="To Be Deleted", + price=10.0, + ) + ) + + # Delete it + await manager.delete_by_id(product.id) + + # Verify it's gone + found = await manager.find_by_id(product.id) + assert found is None + + +async def test_product_entity_defaults(database: Database) -> None: + """Test product entity has correct default values.""" + async with database.session() as session: + repo = ProductRepository(session) + manager = ProductManager(repo) + + # Create with minimal data + product = await manager.save( + ProductIn( + sku="DEFAULTS-001", + name="Product with Defaults", + price=10.0, + ) + ) + + # Check defaults + assert product.stock == 0 + assert product.active is True + + +async def test_repository_find_by_id(database: Database) -> None: + """Test repository find_by_id method.""" + async with database.session() as session: + repo = ProductRepository(session) + + # Create a product directly via ORM + product = Product( + sku="REPO-001", + name="Repository Test", + price=25.0, + stock=5, + active=True, + ) + await repo.save(product) + await repo.commit() + + # Find by ID + found = await repo.find_by_id(product.id) + assert found is not None + assert found.id == product.id + assert found.sku == "REPO-001" diff --git a/packages/servicekit/tests/test_example_custom_operations_api.py b/packages/servicekit/tests/test_example_custom_operations_api.py new file mode 100644 index 0000000..b9b5676 --- /dev/null +++ b/packages/servicekit/tests/test_example_custom_operations_api.py @@ -0,0 +1,366 @@ +"""Tests for custom_operations_api example using TestClient. + +This example demonstrates custom operations with various HTTP methods. 
+""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.custom_operations_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_list_configs(client: TestClient) -> None: + """Test listing all seeded feature configs.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + + # Should be a list of 3 seeded configs + assert isinstance(data, list) + assert len(data) == 3 + + # Check config names + names = {config["name"] for config in data} + assert names == {"api_rate_limiting", "cache_optimization", "experimental_features"} + + # Verify structure + for config in data: + assert "id" in config + assert "name" in config + assert "data" in config + assert "enabled" in config["data"] + assert "max_requests" in config["data"] + assert "timeout_seconds" in config["data"] + assert "tags" in config["data"] + + +def test_get_config_by_id(client: TestClient) -> None: + """Test retrieving config by ID.""" + # Get the list to obtain a valid ID + list_response = client.get("/api/v1/configs") + configs = list_response.json() + config_id = configs[0]["id"] + + response = client.get(f"/api/v1/configs/{config_id}") + assert response.status_code == 200 + data = response.json() + + assert data["id"] == config_id + assert "name" in data + assert "data" in data + + +def test_enable_operation(client: TestClient) -> None: + """Test PATCH operation to toggle enabled flag.""" + # Get experimental_features config (initially disabled) + list_response = client.get("/api/v1/configs") + configs = list_response.json() + experimental = next((c for c in configs if c["name"] == "experimental_features"), None) + assert experimental is not None + config_id = experimental["id"] + initial_enabled = experimental["data"]["enabled"] + + # Toggle enabled flag + response = client.patch(f"/api/v1/configs/{config_id}/$enable", params={"enabled": not initial_enabled}) + assert response.status_code == 200 + updated = response.json() + + assert updated["id"] == config_id + assert updated["data"]["enabled"] is not initial_enabled + + # Toggle back + response2 = client.patch(f"/api/v1/configs/{config_id}/$enable", params={"enabled": initial_enabled}) + assert response2.status_code == 200 + restored = response2.json() + assert restored["data"]["enabled"] is initial_enabled + + +def test_validate_operation(client: TestClient) -> None: + """Test GET operation to validate configuration.""" + # Get a valid config + list_response = client.get("/api/v1/configs") + configs = list_response.json() + config_id = configs[0]["id"] + + response = client.get(f"/api/v1/configs/{config_id}/$validate") + assert response.status_code == 200 + validation = response.json() + + # Verify validation result structure + assert "valid" in validation + assert "errors" in validation + assert "warnings" in validation + assert isinstance(validation["errors"], list) + assert isinstance(validation["warnings"], list) + + +def test_validate_with_errors(client: TestClient) -> None: + """Test 
validation detects errors in configuration.""" + # Create a config with invalid values + invalid_config = { + "name": "invalid-config", + "data": { + "name": "Invalid Config", + "enabled": True, + "max_requests": 20000, # Exceeds maximum of 10000 + "timeout_seconds": 0.5, # Below minimum of 1 + "tags": [], + }, + } + + create_response = client.post("/api/v1/configs", json=invalid_config) + assert create_response.status_code == 201 + created = create_response.json() + config_id = created["id"] + + # Validate it + response = client.get(f"/api/v1/configs/{config_id}/$validate") + assert response.status_code == 200 + validation = response.json() + + # Should have errors + assert validation["valid"] is False + assert len(validation["errors"]) > 0 + assert any("max_requests" in err for err in validation["errors"]) + assert any("timeout_seconds" in err for err in validation["errors"]) + + +def test_duplicate_operation(client: TestClient) -> None: + """Test POST operation to duplicate a configuration.""" + # Get a config to duplicate + list_response = client.get("/api/v1/configs") + configs = list_response.json() + original_config = configs[0] + config_id = original_config["id"] + + # Duplicate it + response = client.post(f"/api/v1/configs/{config_id}/$duplicate", params={"new_name": "duplicated-config"}) + assert response.status_code == 201 + duplicate = response.json() + + # Verify duplicate + assert duplicate["id"] != config_id + assert duplicate["name"] == "duplicated-config" + assert duplicate["data"] == original_config["data"] + + +def test_duplicate_with_existing_name_fails(client: TestClient) -> None: + """Test duplicating with existing name returns 409.""" + # Get a config to duplicate + list_response = client.get("/api/v1/configs") + configs = list_response.json() + config_id = configs[0]["id"] + + # Try to duplicate with name that already exists + response = client.post( + f"/api/v1/configs/{config_id}/$duplicate", + params={"new_name": "api_rate_limiting"}, # Already exists + ) + assert response.status_code == 409 + data = response.json() + assert "already exists" in data["detail"].lower() + + +def test_bulk_toggle_operation(client: TestClient) -> None: + """Test PATCH collection operation to bulk enable/disable configs.""" + # Disable all configs + response = client.patch("/api/v1/configs/$bulk-toggle", json={"enabled": False, "tag_filter": None}) + assert response.status_code == 200 + result = response.json() + + assert "updated" in result + assert result["updated"] >= 3 # At least 3 seeded configs updated + + # Verify all are disabled + list_response = client.get("/api/v1/configs") + configs = list_response.json() + assert all(not c["data"]["enabled"] for c in configs) + + # Re-enable all + response2 = client.patch("/api/v1/configs/$bulk-toggle", json={"enabled": True, "tag_filter": None}) + assert response2.status_code == 200 + + +def test_bulk_toggle_with_tag_filter(client: TestClient) -> None: + """Test bulk toggle with tag filter.""" + # Toggle only configs with "api" tag (should be api_rate_limiting) + response = client.patch("/api/v1/configs/$bulk-toggle", json={"enabled": False, "tag_filter": "api"}) + assert response.status_code == 200 + result = response.json() + + # Should update at least 1 config with "api" tag + assert result["updated"] >= 1 + + # Verify only api_rate_limiting is disabled + list_response = client.get("/api/v1/configs") + configs = list_response.json() + api_config = next((c for c in configs if c["name"] == "api_rate_limiting"), None) + assert api_config 
is not None + assert api_config["data"]["enabled"] is False + + +def test_stats_operation(client: TestClient) -> None: + """Test GET collection operation to get statistics.""" + response = client.get("/api/v1/configs/$stats") + assert response.status_code == 200 + stats = response.json() + + # Verify stats structure + assert "total" in stats + assert "enabled" in stats + assert "disabled" in stats + assert "avg_max_requests" in stats + assert "tags" in stats + + assert stats["total"] >= 3 # At least 3 seeded configs + assert stats["enabled"] + stats["disabled"] == stats["total"] + assert isinstance(stats["avg_max_requests"], (int, float)) + assert isinstance(stats["tags"], dict) + + # Verify tag counts + expected_tags = {"api", "security", "performance", "cache", "experimental", "beta", "new", "test", "updated"} + assert set(stats["tags"].keys()).issubset(expected_tags) + + +def test_reset_operation(client: TestClient) -> None: + """Test POST collection operation to reset all configurations.""" + # First, modify some configs + list_response = client.get("/api/v1/configs") + configs = list_response.json() + config_id = configs[0]["id"] + + # Update with different values + client.patch(f"/api/v1/configs/{config_id}/$enable", params={"enabled": False}) + + # Reset all + response = client.post("/api/v1/configs/$reset") + assert response.status_code == 200 + result = response.json() + + assert "reset" in result + assert result["reset"] >= 3 # At least 3 seeded configs reset + + # Verify configs are reset to defaults (check at least the seeded ones) + list_response2 = client.get("/api/v1/configs") + configs2 = list_response2.json() + + # Check that at least 3 configs have default values + default_configs = [ + c + for c in configs2 + if c["data"]["enabled"] is True + and c["data"]["max_requests"] == 1000 + and c["data"]["timeout_seconds"] == 30.0 + and c["data"]["tags"] == [] + ] + assert len(default_configs) >= 3 + + +def test_standard_crud_create(client: TestClient) -> None: + """Test standard POST to create a config.""" + new_config = { + "name": "new-feature", + "data": { + "name": "New Feature", + "enabled": True, + "max_requests": 500, + "timeout_seconds": 45.0, + "tags": ["new", "test"], + }, + } + + response = client.post("/api/v1/configs", json=new_config) + assert response.status_code == 201 + created = response.json() + + assert created["name"] == "new-feature" + assert created["data"]["max_requests"] == 500 + assert created["data"]["tags"] == ["new", "test"] + + +def test_standard_crud_update(client: TestClient) -> None: + """Test standard PUT to update a config.""" + # Create a config + new_config = { + "name": "update-test", + "data": {"name": "Update Test", "enabled": False, "max_requests": 100, "timeout_seconds": 10.0, "tags": []}, + } + create_response = client.post("/api/v1/configs", json=new_config) + created = create_response.json() + config_id = created["id"] + + # Update it + updated_config = { + "id": config_id, + "name": "update-test", + "data": { + "name": "Updated Test", + "enabled": True, + "max_requests": 200, + "timeout_seconds": 20.0, + "tags": ["updated"], + }, + } + + response = client.put(f"/api/v1/configs/{config_id}", json=updated_config) + assert response.status_code == 200 + updated = response.json() + + assert updated["data"]["enabled"] is True + assert updated["data"]["max_requests"] == 200 + assert updated["data"]["tags"] == ["updated"] + + +def test_standard_crud_delete(client: TestClient) -> None: + """Test standard DELETE to remove a config.""" + # 
Create a config + new_config = { + "name": "delete-test", + "data": {"name": "Delete Test", "enabled": True, "max_requests": 100, "timeout_seconds": 10.0, "tags": []}, + } + create_response = client.post("/api/v1/configs", json=new_config) + created = create_response.json() + config_id = created["id"] + + # Delete it + response = client.delete(f"/api/v1/configs/{config_id}") + assert response.status_code == 204 + + # Verify it's gone + get_response = client.get(f"/api/v1/configs/{config_id}") + assert get_response.status_code == 404 + + +def test_validate_not_found(client: TestClient) -> None: + """Test validate operation on non-existent config returns 404.""" + response = client.get("/api/v1/configs/01K72P5N5KCRM6MD3BRE4P0999/$validate") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_duplicate_not_found(client: TestClient) -> None: + """Test duplicate operation on non-existent config returns 404.""" + response = client.post("/api/v1/configs/01K72P5N5KCRM6MD3BRE4P0999/$duplicate", params={"new_name": "test"}) + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() diff --git a/packages/servicekit/tests/test_example_full_featured_api.py b/packages/servicekit/tests/test_example_full_featured_api.py new file mode 100644 index 0000000..708e001 --- /dev/null +++ b/packages/servicekit/tests/test_example_full_featured_api.py @@ -0,0 +1,509 @@ +"""Tests for full_featured_api example showcasing all chapkit features. + +This test suite validates the comprehensive example that demonstrates: +- Health checks (custom + database) +- System info +- Config management +- Artifacts with hierarchy +- Config-artifact linking +- Task execution +- Job scheduling +- Custom routers +- Landing page +""" + +from __future__ import annotations + +import time +from collections.abc import Generator +from typing import Any, cast + +import pytest +from fastapi.testclient import TestClient + +from examples.full_featured_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +# ==================== Basic Endpoints ==================== + + +def test_landing_page(client: TestClient) -> None: + """Test landing page returns HTML.""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status with custom checks.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "checks" in data + + # Verify both database and custom external_service checks + assert "database" in data["checks"] + assert data["checks"]["database"]["state"] == "healthy" + assert "external_service" in data["checks"] + assert data["checks"]["external_service"]["state"] == "healthy" + + +def test_system_endpoint(client: TestClient) -> None: + """Test system info endpoint returns metadata.""" + response = client.get("/api/v1/system") + assert response.status_code == 200 + data = response.json() + + assert "current_time" in data + assert "timezone" in data + assert "python_version" in data + assert "platform" in data + assert "hostname" in data + + +def test_info_endpoint(client: TestClient) -> None: + """Test service 
info endpoint returns service metadata.""" + response = client.get("/api/v1/info") + assert response.status_code == 200 + data = response.json() + + assert data["display_name"] == "Complete Feature Showcase" + assert data["version"] == "2.0.0" + assert data["summary"] == "Comprehensive example demonstrating ALL chapkit features" + assert data["contact"]["name"] == "Chapkit Team" + assert data["license_info"]["name"] == "MIT" + + +# ==================== Seeded Data Tests ==================== + + +def test_seeded_configs(client: TestClient) -> None: + """Test that startup hook seeded example config.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + configs = response.json() + + # Should have at least the seeded config + assert len(configs) >= 1 + + # Find the seeded config by name + production_config = next((c for c in configs if c["name"] == "production_pipeline"), None) + assert production_config is not None + assert production_config["id"] == "01JCSEED00C0NF1GEXAMP1E001" + + # Verify config data structure + data = production_config["data"] + assert data["model_type"] == "xgboost" + assert data["learning_rate"] == 0.01 + assert data["max_epochs"] == 500 + assert data["batch_size"] == 64 + assert data["early_stopping"] is True + assert data["random_seed"] == 42 + + +def test_seeded_artifacts(client: TestClient) -> None: + """Test that startup hook seeded example artifact.""" + response = client.get("/api/v1/artifacts") + assert response.status_code == 200 + artifacts = response.json() + + # Should have at least the seeded artifact + assert len(artifacts) >= 1 + + # Find the seeded artifact by ID + seeded_artifact_id = "01JCSEED00ART1FACTEXMP1001" + artifact_response = client.get(f"/api/v1/artifacts/{seeded_artifact_id}") + assert artifact_response.status_code == 200 + + artifact = artifact_response.json() + assert artifact["id"] == seeded_artifact_id + assert artifact["data"]["experiment_name"] == "baseline_experiment" + assert artifact["data"]["model_metrics"]["accuracy"] == 0.95 + assert artifact["data"]["model_metrics"]["f1_score"] == 0.93 + assert artifact["data"]["dataset_info"]["train_size"] == 10000 + + +def test_seeded_tasks(client: TestClient) -> None: + """Test that startup hook seeded example tasks.""" + response = client.get("/api/v1/tasks") + assert response.status_code == 200 + tasks = response.json() + + # Should have at least 2 seeded tasks + assert len(tasks) >= 2 + + # Find seeded tasks by ID + task_ids = {task["id"] for task in tasks} + assert "01JCSEED00TASKEXAMP1E00001" in task_ids + assert "01JCSEED00TASKEXAMP1E00002" in task_ids + + # Verify task commands + task1 = next((t for t in tasks if t["id"] == "01JCSEED00TASKEXAMP1E00001"), None) + assert task1 is not None + assert "Training model" in task1["command"] + + task2 = next((t for t in tasks if t["id"] == "01JCSEED00TASKEXAMP1E00002"), None) + assert task2 is not None + assert "Python" in task2["command"] + + +# ==================== Config Management Tests ==================== + + +def test_config_crud(client: TestClient) -> None: + """Test full config CRUD operations.""" + # Create + new_config = { + "name": "test_pipeline", + "data": { + "model_type": "random_forest", + "learning_rate": 0.001, + "max_epochs": 100, + "batch_size": 32, + "early_stopping": True, + "random_seed": 123, + }, + } + + create_response = client.post("/api/v1/configs", json=new_config) + assert create_response.status_code == 201 + created = create_response.json() + config_id = created["id"] + + assert 
created["name"] == "test_pipeline" + assert created["data"]["model_type"] == "random_forest" + + # Read + get_response = client.get(f"/api/v1/configs/{config_id}") + assert get_response.status_code == 200 + fetched = get_response.json() + assert fetched["id"] == config_id + + # Update + fetched["data"]["max_epochs"] = 200 + update_response = client.put(f"/api/v1/configs/{config_id}", json=fetched) + assert update_response.status_code == 200 + updated = update_response.json() + assert updated["data"]["max_epochs"] == 200 + + # Delete + delete_response = client.delete(f"/api/v1/configs/{config_id}") + assert delete_response.status_code == 204 + + # Verify deletion + get_after_delete = client.get(f"/api/v1/configs/{config_id}") + assert get_after_delete.status_code == 404 + + +def test_config_pagination(client: TestClient) -> None: + """Test config pagination.""" + response = client.get("/api/v1/configs", params={"page": 1, "size": 2}) + assert response.status_code == 200 + data = response.json() + + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + assert len(data["items"]) <= 2 + + +# ==================== Artifact Tests ==================== + + +def test_artifact_crud_with_hierarchy(client: TestClient) -> None: + """Test artifact CRUD with hierarchical relationships.""" + # Create root artifact (level 0: experiment) + root_artifact = { + "data": { + "experiment_name": "test_experiment", + "description": "Testing artifact hierarchy", + }, + } + + root_response = client.post("/api/v1/artifacts", json=root_artifact) + assert root_response.status_code == 201 + root = root_response.json() + root_id = root["id"] + assert root["level"] == 0 + + # Create child artifact (level 1: training) + child_artifact = { + "data": {"training_loss": 0.05, "epoch": 10}, + "parent_id": root_id, + } + + child_response = client.post("/api/v1/artifacts", json=child_artifact) + assert child_response.status_code == 201 + child = child_response.json() + child_id = child["id"] + assert child["level"] == 1 + assert child["parent_id"] == root_id + + # Get tree structure + tree_response = client.get(f"/api/v1/artifacts/{root_id}/$tree") + assert tree_response.status_code == 200 + tree = tree_response.json() + + assert tree["id"] == root_id + assert tree["level"] == 0 + assert len(tree["children"]) >= 1 + assert any(c["id"] == child_id for c in tree["children"]) + + # Cleanup + client.delete(f"/api/v1/artifacts/{child_id}") + client.delete(f"/api/v1/artifacts/{root_id}") + + +def test_artifact_tree_endpoint(client: TestClient) -> None: + """Test artifact tree operation with seeded data.""" + seeded_artifact_id = "01JCSEED00ART1FACTEXMP1001" + + tree_response = client.get(f"/api/v1/artifacts/{seeded_artifact_id}/$tree") + assert tree_response.status_code == 200 + tree = tree_response.json() + + assert tree["id"] == seeded_artifact_id + assert "level" in tree + assert "children" in tree + # Children can be None or empty list if no children exist + assert tree["children"] is None or isinstance(tree["children"], list) + + +# ==================== Config-Artifact Linking Tests ==================== + + +def test_config_artifact_linking(client: TestClient) -> None: + """Test linking configs to root artifacts.""" + # Create a config + config = { + "name": "linking_test", + "data": { + "model_type": "xgboost", + "learning_rate": 0.01, + "max_epochs": 100, + "batch_size": 32, + "early_stopping": True, + "random_seed": 42, + }, + } + config_response = 
client.post("/api/v1/configs", json=config) + config_id = config_response.json()["id"] + + # Create a root artifact (no parent_id means it's a root) + artifact = {"data": {"experiment": "linking_test"}, "parent_id": None} + artifact_response = client.post("/api/v1/artifacts", json=artifact) + assert artifact_response.status_code == 201 + artifact_id = artifact_response.json()["id"] + + # Verify it's a root artifact (level 0) + artifact_get = client.get(f"/api/v1/artifacts/{artifact_id}") + assert artifact_get.json()["level"] == 0 + + # Link them + link_response = client.post(f"/api/v1/configs/{config_id}/$link-artifact", json={"artifact_id": artifact_id}) + # Accept either 204 or 400 (in case linking not fully supported) + if link_response.status_code == 204: + # Verify link by getting artifacts for config + linked_response = client.get(f"/api/v1/configs/{config_id}/$artifacts") + assert linked_response.status_code == 200 + linked_artifacts = linked_response.json() + assert len(linked_artifacts) >= 1 + assert any(a["id"] == artifact_id for a in linked_artifacts) + + # Unlink + unlink_response = client.post( + f"/api/v1/configs/{config_id}/$unlink-artifact", json={"artifact_id": artifact_id} + ) + assert unlink_response.status_code == 204 + + # Cleanup + client.delete(f"/api/v1/artifacts/{artifact_id}") + client.delete(f"/api/v1/configs/{config_id}") + + +# ==================== Task Execution Tests ==================== + + +def test_task_crud(client: TestClient) -> None: + """Test task CRUD operations.""" + # Create + task = {"command": "echo 'test task'"} + create_response = client.post("/api/v1/tasks", json=task) + assert create_response.status_code == 201 + created = create_response.json() + task_id = created["id"] + assert created["command"] == "echo 'test task'" + + # Read + get_response = client.get(f"/api/v1/tasks/{task_id}") + assert get_response.status_code == 200 + + # Update + updated_task = {"command": "echo 'updated task'"} + update_response = client.put(f"/api/v1/tasks/{task_id}", json=updated_task) + assert update_response.status_code == 200 + assert update_response.json()["command"] == "echo 'updated task'" + + # Delete + delete_response = client.delete(f"/api/v1/tasks/{task_id}") + assert delete_response.status_code == 204 + + +def test_task_execution_creates_job(client: TestClient) -> None: + """Test that executing a task creates a job.""" + # Create a simple task + task = {"command": "echo 'Hello from task'"} + task_response = client.post("/api/v1/tasks", json=task) + task_id = task_response.json()["id"] + + # Execute the task + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + assert execute_response.status_code == 202 # Accepted + job_data = execute_response.json() + assert "job_id" in job_data + assert "message" in job_data + + job_id = job_data["job_id"] + + # Wait for job completion + job = wait_for_job_completion(client, job_id) + + assert job["status"] in ["completed", "failed"] + + # If completed, verify artifact was created + if job["status"] == "completed" and job["artifact_id"]: + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + assert artifact_response.status_code == 200 + artifact = artifact_response.json() + + # Verify artifact contains task snapshot and outputs + assert "task" in artifact["data"] + assert "stdout" in artifact["data"] + assert "stderr" in artifact["data"] + assert "exit_code" in artifact["data"] + assert artifact["data"]["task"]["id"] == task_id + + # Cleanup + 
client.delete(f"/api/v1/jobs/{job_id}") + client.delete(f"/api/v1/tasks/{task_id}") + + +# ==================== Job Tests ==================== + + +def test_list_jobs(client: TestClient) -> None: + """Test listing jobs.""" + response = client.get("/api/v1/jobs") + assert response.status_code == 200 + jobs = response.json() + assert isinstance(jobs, list) + + +def test_get_job_by_id(client: TestClient) -> None: + """Test getting job by ID.""" + # Create and execute a task to get a job + task = {"command": "echo 'job test'"} + task_response = client.post("/api/v1/tasks", json=task) + task_id = task_response.json()["id"] + + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Get job + job_response = client.get(f"/api/v1/jobs/{job_id}") + assert job_response.status_code == 200 + job = job_response.json() + + assert job["id"] == job_id + assert "status" in job + assert "submitted_at" in job + + # Cleanup + wait_for_job_completion(client, job_id) + client.delete(f"/api/v1/jobs/{job_id}") + client.delete(f"/api/v1/tasks/{task_id}") + + +def test_filter_jobs_by_status(client: TestClient) -> None: + """Test filtering jobs by status.""" + response = client.get("/api/v1/jobs", params={"status_filter": "completed"}) + assert response.status_code == 200 + jobs = cast(list[dict[str, Any]], response.json()) + assert isinstance(jobs, list) + + # All returned jobs should be completed + for job in jobs: + assert job["status"] == "completed" + + +# ==================== Custom Router Tests ==================== + + +def test_custom_stats_endpoint(client: TestClient) -> None: + """Test custom statistics router.""" + response = client.get("/api/v1/stats") + assert response.status_code == 200 + stats = response.json() + + assert "total_configs" in stats + assert "total_artifacts" in stats + assert "total_tasks" in stats + assert "service_version" in stats + + # Should have at least the seeded data + assert stats["total_configs"] >= 1 + assert stats["service_version"] == "2.0.0" + + +# ==================== OpenAPI Documentation Tests ==================== + + +def test_openapi_schema(client: TestClient) -> None: + """Test that OpenAPI schema includes all expected endpoints.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + + paths = schema["paths"] + + # Verify all major endpoint groups are present + assert "/health" in paths + assert "/api/v1/system" in paths + assert "/api/v1/configs" in paths + assert "/api/v1/artifacts" in paths + assert "/api/v1/tasks" in paths + assert "/api/v1/jobs" in paths + assert "/api/v1/stats" in paths + + # Verify operation endpoints + assert "/api/v1/artifacts/{entity_id}/$tree" in paths + assert "/api/v1/tasks/{entity_id}/$execute" in paths + assert "/api/v1/configs/{entity_id}/$artifacts" in paths + assert "/api/v1/configs/{entity_id}/$link-artifact" in paths + + +# ==================== Helper Functions ==================== + + +def wait_for_job_completion(client: TestClient, job_id: str, timeout: float = 5.0) -> dict[Any, Any]: + """Poll job status until completion or timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + job_response = client.get(f"/api/v1/jobs/{job_id}") + assert job_response.status_code == 200 + job = cast(dict[Any, Any], job_response.json()) + + if job["status"] in ["completed", "failed", "canceled"]: + return job + time.sleep(0.1) + raise TimeoutError(f"Job {job_id} did not complete within {timeout}s") 
diff --git a/packages/servicekit/tests/test_example_job_scheduler_api.py b/packages/servicekit/tests/test_example_job_scheduler_api.py new file mode 100644 index 0000000..1a1bcae --- /dev/null +++ b/packages/servicekit/tests/test_example_job_scheduler_api.py @@ -0,0 +1,287 @@ +"""Tests for job_scheduler_api example using TestClient. + +This example demonstrates the job scheduler for async long-running tasks. +""" + +from __future__ import annotations + +import time +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.job_scheduler_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_submit_computation_job(client: TestClient) -> None: + """Test submitting a computation job returns 202 with job ID and Location header.""" + compute_request = { + "duration": 0.5 # Short duration for testing + } + + response = client.post("/api/v1/compute", json=compute_request) + assert response.status_code == 202 + data = response.json() + + # Verify response structure + assert "job_id" in data + assert "message" in data + assert "Poll GET /api/v1/jobs/" in data["message"] + + # Verify Location header + assert "Location" in response.headers + location = response.headers["Location"] + assert f"/api/v1/jobs/{data['job_id']}" in location + + +def test_list_jobs(client: TestClient) -> None: + """Test listing all jobs.""" + # Submit a job first + compute_request = {"duration": 0.1} + client.post("/api/v1/compute", json=compute_request) + + # List jobs + response = client.get("/api/v1/jobs") + assert response.status_code == 200 + data = response.json() + + assert isinstance(data, list) + assert len(data) > 0 + + # Verify job structure + job = data[0] + assert "id" in job + assert "status" in job + assert "submitted_at" in job + assert job["status"] in ["pending", "running", "completed", "failed", "canceled"] + + +def test_list_jobs_with_status_filter(client: TestClient) -> None: + """Test listing jobs filtered by status.""" + # Submit and wait for completion + compute_request = {"duration": 0.1} + submit_response = client.post("/api/v1/compute", json=compute_request) + job_id = submit_response.json()["job_id"] + + # Wait for completion + time.sleep(0.5) + + # Filter for completed jobs + response = client.get("/api/v1/jobs", params={"status_filter": "completed"}) + assert response.status_code == 200 + jobs = response.json() + + # Should include our completed job + assert any(job["id"] == job_id for job in jobs) + assert all(job["status"] == "completed" for job in jobs) + + +def test_get_job_record(client: TestClient) -> None: + """Test retrieving a specific job record.""" + # Submit a job + compute_request = {"duration": 0.1} + submit_response = client.post("/api/v1/compute", json=compute_request) + job_id = submit_response.json()["job_id"] + + # Get job record + response = client.get(f"/api/v1/jobs/{job_id}") + assert response.status_code == 200 + job = response.json() + + assert job["id"] == job_id + assert "status" in job + assert "submitted_at" in job + assert job["status"] in ["pending", "running", "completed"] + + +def 
test_get_job_record_not_found(client: TestClient) -> None:
+    """Test retrieving non-existent job returns 404."""
+    response = client.get("/api/v1/jobs/01K72P5N5KCRM6MD3BRE4P0999")
+    assert response.status_code == 404
+    data = response.json()
+    assert "not found" in data["detail"].lower()
+
+
+def test_get_computation_result_pending(client: TestClient) -> None:
+    """Test getting result of pending/running job."""
+    # Submit a longer-running job
+    compute_request = {"duration": 5.0}
+    submit_response = client.post("/api/v1/compute", json=compute_request)
+    job_id = submit_response.json()["job_id"]
+
+    # Immediately check result (should be pending or running)
+    response = client.get(f"/api/v1/compute/{job_id}/result")
+    assert response.status_code == 200
+    result = response.json()
+
+    assert result["job_id"] == job_id
+    assert result["status"] in ["pending", "running"]
+    assert result["result"] is None
+    assert result["error"] is None
+
+
+def test_get_computation_result_completed(client: TestClient) -> None:
+    """Test getting result of completed job."""
+    # Submit a short job
+    compute_request = {"duration": 0.2}
+    submit_response = client.post("/api/v1/compute", json=compute_request)
+    job_id = submit_response.json()["job_id"]
+
+    # Wait for completion
+    time.sleep(0.5)
+
+    # Get result
+    response = client.get(f"/api/v1/compute/{job_id}/result")
+    assert response.status_code == 200
+    result = response.json()
+
+    assert result["job_id"] == job_id
+    assert result["status"] == "completed"
+    assert result["result"] == 42  # Expected result from long_running_computation
+    assert result["error"] is None
+    assert result["submitted_at"] is not None
+    assert result["started_at"] is not None
+    assert result["finished_at"] is not None
+
+
+def test_job_status_transitions(client: TestClient) -> None:
+    """Test job status transitions from pending -> running -> completed."""
+    # Submit a job
+    compute_request = {"duration": 0.3}
+    submit_response = client.post("/api/v1/compute", json=compute_request)
+    job_id = submit_response.json()["job_id"]
+
+    # Check initial status (should be pending or running)
+    response1 = client.get(f"/api/v1/jobs/{job_id}")
+    job1 = response1.json()
+    assert job1["status"] in ["pending", "running"]
+
+    # Wait a bit
+    time.sleep(0.1)
+
+    # Check status again (likely running)
+    response2 = client.get(f"/api/v1/jobs/{job_id}")
+    job2 = response2.json()
+    assert job2["status"] in ["running", "completed"]
+
+    # Wait for completion
+    time.sleep(0.5)
+
+    # Check final status (should be completed)
+    response3 = client.get(f"/api/v1/jobs/{job_id}")
+    job3 = response3.json()
+    assert job3["status"] == "completed"
+
+
+def test_cancel_job(client: TestClient) -> None:
+    """Test canceling a running job."""
+    # Submit a long-running job
+    compute_request = {"duration": 10.0}
+    submit_response = client.post("/api/v1/compute", json=compute_request)
+    job_id = submit_response.json()["job_id"]
+
+    # Wait a bit to ensure it starts
+    time.sleep(0.2)
+
+    # Cancel it
+    response = client.delete(f"/api/v1/jobs/{job_id}")
+    assert response.status_code == 204
+
+    # DELETE cancels the job and removes its record, so a follow-up GET returns 404.
+    # Note: the job might complete before cancellation; the record is removed either way.
+    job_response = client.get(f"/api/v1/jobs/{job_id}")
+    assert job_response.status_code == 404
+
+
+def test_delete_completed_job(client: TestClient) -> None:
+    """Test deleting a completed job record."""
+    # Submit and wait for completion
+    compute_request = {"duration": 0.1}
+    
submit_response = client.post("/api/v1/compute", json=compute_request) + job_id = submit_response.json()["job_id"] + + time.sleep(0.3) + + # Delete the completed job + response = client.delete(f"/api/v1/jobs/{job_id}") + assert response.status_code == 204 + + # Verify it's gone + get_response = client.get(f"/api/v1/jobs/{job_id}") + assert get_response.status_code == 404 + + +def test_delete_job_not_found(client: TestClient) -> None: + """Test deleting non-existent job returns 404.""" + response = client.delete("/api/v1/jobs/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_submit_multiple_concurrent_jobs(client: TestClient) -> None: + """Test submitting multiple jobs respects max_concurrency=5.""" + # Submit 7 jobs (max_concurrency is 5) + job_ids = [] + for _ in range(7): + compute_request = {"duration": 0.2} + response = client.post("/api/v1/compute", json=compute_request) + job_data = response.json() + job_ids.append(job_data["job_id"]) + + # All should be accepted + assert len(job_ids) == 7 + + # Wait for completion + time.sleep(0.5) + + # Check that all jobs completed + for job_id in job_ids: + response = client.get(f"/api/v1/jobs/{job_id}") + if response.status_code == 200: + job = response.json() + assert job["status"] in ["completed", "running"] + + +def test_invalid_duration_too_low(client: TestClient) -> None: + """Test submitting job with duration too low fails validation.""" + compute_request = {"duration": 0.05} # Below minimum of 0.1 + + response = client.post("/api/v1/compute", json=compute_request) + assert response.status_code == 422 # Validation error + data = response.json() + assert "detail" in data + + +def test_invalid_duration_too_high(client: TestClient) -> None: + """Test submitting job with duration too high fails validation.""" + compute_request = {"duration": 100.0} # Above maximum of 60 + + response = client.post("/api/v1/compute", json=compute_request) + assert response.status_code == 422 # Validation error + data = response.json() + assert "detail" in data + + +def test_get_result_for_nonexistent_job(client: TestClient) -> None: + """Test getting result for non-existent job returns error.""" + response = client.get("/api/v1/compute/01K72P5N5KCRM6MD3BRE4P0999/result") + # May return 404 or 500 depending on implementation + assert response.status_code in [404, 500] diff --git a/packages/servicekit/tests/test_example_job_scheduler_sse_api.py b/packages/servicekit/tests/test_example_job_scheduler_sse_api.py new file mode 100644 index 0000000..1f73a49 --- /dev/null +++ b/packages/servicekit/tests/test_example_job_scheduler_sse_api.py @@ -0,0 +1,129 @@ +"""Tests for job_scheduler_sse_api.py example.""" + +import asyncio +import json +from collections.abc import AsyncGenerator + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + + +@pytest.fixture +async def app() -> AsyncGenerator[FastAPI, None]: + """Load job_scheduler_sse_api.py app and trigger lifespan.""" + import sys + from pathlib import Path + + examples_dir = Path(__file__).parent.parent / "examples" + sys.path.insert(0, str(examples_dir)) + + from job_scheduler_sse_api import app as example_app # type: ignore[import-not-found] + + async with example_app.router.lifespan_context(example_app): + yield example_app + + sys.path.remove(str(examples_dir)) + + +@pytest.fixture +async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]: + """Create async test 
client.""" + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test", follow_redirects=True) as ac: + yield ac + + +@pytest.mark.asyncio +async def test_health_endpoint(client: AsyncClient): + """Test health endpoint is available.""" + response = await client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +@pytest.mark.asyncio +async def test_submit_slow_compute_job(client: AsyncClient): + """Test submitting slow computation job.""" + response = await client.post("/api/v1/slow-compute", json={"steps": 10}) + assert response.status_code == 202 + data = response.json() + assert "job_id" in data + assert "stream_url" in data + assert data["stream_url"].startswith("/api/v1/jobs/") + assert data["stream_url"].endswith("/$stream") + assert "Location" in response.headers + + +@pytest.mark.asyncio +async def test_slow_compute_validation(client: AsyncClient): + """Test request validation for slow compute.""" + # Too few steps + response = await client.post("/api/v1/slow-compute", json={"steps": 5}) + assert response.status_code == 422 + + # Too many steps + response = await client.post("/api/v1/slow-compute", json={"steps": 100}) + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_sse_streaming_slow_compute(client: AsyncClient): + """Test SSE streaming for slow computation job.""" + # Submit job + response = await client.post("/api/v1/slow-compute", json={"steps": 10}) + assert response.status_code == 202 + job_id = response.json()["job_id"] + + # Stream status updates + events = [] + async with client.stream("GET", f"/api/v1/jobs/{job_id}/$stream?poll_interval=0.1") as stream_response: + assert stream_response.status_code == 200 + assert stream_response.headers["content-type"] == "text/event-stream; charset=utf-8" + + async for line in stream_response.aiter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + if data["status"] == "completed": + break + + # Verify we got multiple events showing progress + assert len(events) >= 1 + assert events[-1]["status"] == "completed" + # artifact_id is null because task returns SlowComputeResult, not ULID + assert events[-1]["artifact_id"] is None + + +@pytest.mark.asyncio +async def test_job_completes_successfully(client: AsyncClient): + """Test that job completes successfully.""" + # Submit and wait + response = await client.post("/api/v1/slow-compute", json={"steps": 10}) + job_id = response.json()["job_id"] + + # Wait for completion via polling + job = None + for _ in range(30): + job_response = await client.get(f"/api/v1/jobs/{job_id}") + job = job_response.json() + if job["status"] == "completed": + break + await asyncio.sleep(0.5) + + # Verify job completed + assert job is not None + assert job["status"] == "completed" + assert job["finished_at"] is not None + + +@pytest.mark.asyncio +async def test_openapi_schema_includes_endpoints(client: AsyncClient): + """Test OpenAPI schema includes slow-compute and SSE endpoints.""" + response = await client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + + paths = schema["paths"] + assert "/api/v1/slow-compute" in paths + assert "/api/v1/jobs/{job_id}/$stream" in paths diff --git a/packages/servicekit/tests/test_example_library_usage_api.py b/packages/servicekit/tests/test_example_library_usage_api.py new file mode 100644 index 0000000..2a5e870 --- /dev/null +++ 
b/packages/servicekit/tests/test_example_library_usage_api.py @@ -0,0 +1,324 @@ +"""Tests for library_usage_api example using TestClient. + +This example demonstrates using chapkit as a library with custom models. +Tests use FastAPI's TestClient instead of running a separate server. +""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.library_usage_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_landing_page(client: TestClient) -> None: + """Test landing page returns HTML.""" + response = client.get("/") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_list_configs(client: TestClient) -> None: + """Test listing configs using chapkit's Config model.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + + # Should have at least the seeded production config + assert isinstance(data, list) + assert len(data) >= 1 + + # Find production config + production = next((c for c in data if c["name"] == "production"), None) + assert production is not None + assert "data" in production + assert "max_users" in production["data"] + assert "registration_enabled" in production["data"] + assert "default_theme" in production["data"] + + +def test_create_config(client: TestClient) -> None: + """Test creating a config with ApiConfig schema.""" + new_config = {"name": "staging", "data": {"max_users": 500, "registration_enabled": True, "default_theme": "light"}} + + response = client.post("/api/v1/configs", json=new_config) + assert response.status_code == 201 + created = response.json() + + assert created["name"] == "staging" + assert created["data"]["max_users"] == 500 + assert created["data"]["default_theme"] == "light" + + +def test_list_users(client: TestClient) -> None: + """Test listing users using custom User model.""" + response = client.get("/api/v1/users") + assert response.status_code == 200 + data = response.json() + + # Should have at least the seeded admin user + assert isinstance(data, list) + assert len(data) >= 1 + + # Find admin user + admin = next((u for u in data if u["username"] == "admin"), None) + assert admin is not None + assert admin["email"] == "admin@example.com" + assert admin["full_name"] == "Administrator" + assert "preferences" in admin + assert admin["preferences"]["theme"] == "dark" + + +def test_get_user_by_id(client: TestClient) -> None: + """Test retrieving user by ID.""" + # Get list to find admin user ID + list_response = client.get("/api/v1/users") + users = list_response.json() + admin = next((u for u in users if u["username"] == "admin"), None) + assert admin is not None + user_id = admin["id"] + + # Get user by ID + response = client.get(f"/api/v1/users/{user_id}") + assert response.status_code == 200 + user = response.json() + + assert user["id"] == user_id + assert user["username"] == "admin" + assert user["email"] == "admin@example.com" + + +def test_get_user_not_found(client: TestClient) -> None: + """Test 
retrieving non-existent user returns 404.""" + response = client.get("/api/v1/users/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_create_user(client: TestClient) -> None: + """Test creating a user with custom User model.""" + new_user = { + "username": "newuser", + "email": "newuser@example.com", + "full_name": "New User", + "preferences": {"theme": "light", "notifications": False, "language": "en"}, + } + + response = client.post("/api/v1/users", json=new_user) + assert response.status_code == 201 + created = response.json() + + assert "id" in created + assert created["username"] == "newuser" + assert created["email"] == "newuser@example.com" + assert created["full_name"] == "New User" + assert created["preferences"]["theme"] == "light" + assert created["preferences"]["notifications"] is False + + +def test_create_user_duplicate_username(client: TestClient) -> None: + """Test creating user with duplicate username fails.""" + duplicate_user = { + "username": "admin", # Already exists + "email": "different@example.com", + "full_name": "Different User", + } + + response = client.post("/api/v1/users", json=duplicate_user) + # Should fail due to unique constraint on username + assert response.status_code in [400, 409, 500] + + +def test_create_user_duplicate_email(client: TestClient) -> None: + """Test creating user with duplicate email fails.""" + duplicate_user = { + "username": "different", + "email": "admin@example.com", # Already exists + "full_name": "Different User", + } + + response = client.post("/api/v1/users", json=duplicate_user) + # Should fail due to unique constraint on email + assert response.status_code in [400, 409, 500] + + +def test_create_user_with_minimal_fields(client: TestClient) -> None: + """Test creating user with only required fields.""" + minimal_user = { + "username": "minimal", + "email": "minimal@example.com", + # full_name and preferences are optional + } + + response = client.post("/api/v1/users", json=minimal_user) + assert response.status_code == 201 + created = response.json() + + assert created["username"] == "minimal" + assert created["email"] == "minimal@example.com" + assert created["full_name"] is None + assert created["preferences"] == {} + + +def test_create_user_invalid_email(client: TestClient) -> None: + """Test creating user with invalid email format fails.""" + invalid_user = { + "username": "invalidtest", + "email": "not-an-email", + "full_name": "Invalid Test", + } + + response = client.post("/api/v1/users", json=invalid_user) + assert response.status_code == 422 + data = response.json() + assert "email" in str(data).lower() + + +def test_update_user(client: TestClient) -> None: + """Test updating a user.""" + # Create a user first + new_user = { + "username": "updatetest", + "email": "updatetest@example.com", + "full_name": "Update Test", + "preferences": {"theme": "dark"}, + } + create_response = client.post("/api/v1/users", json=new_user) + created = create_response.json() + user_id = created["id"] + + # Update the user + updated_user = { + "id": user_id, + "username": "updatetest", + "email": "updatetest@example.com", + "full_name": "Updated Name", # Changed + "preferences": {"theme": "light", "notifications": True}, # Changed + } + + response = client.put(f"/api/v1/users/{user_id}", json=updated_user) + assert response.status_code == 200 + updated = response.json() + + assert updated["id"] == user_id + assert updated["full_name"] == 
"Updated Name" + assert updated["preferences"]["theme"] == "light" + assert updated["preferences"]["notifications"] is True + + +def test_delete_user(client: TestClient) -> None: + """Test deleting a user.""" + # Create a user first + new_user = {"username": "deletetest", "email": "deletetest@example.com", "full_name": "Delete Test"} + create_response = client.post("/api/v1/users", json=new_user) + created = create_response.json() + user_id = created["id"] + + # Delete the user + response = client.delete(f"/api/v1/users/{user_id}") + assert response.status_code == 204 + + # Verify user is deleted + get_response = client.get(f"/api/v1/users/{user_id}") + assert get_response.status_code == 404 + + +def test_delete_user_not_found(client: TestClient) -> None: + """Test deleting non-existent user returns 404.""" + response = client.delete("/api/v1/users/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + +def test_user_preferences_json_field(client: TestClient) -> None: + """Test that user preferences are stored as JSON and can be queried.""" + # Create a user with complex preferences + new_user = { + "username": "jsontest", + "email": "jsontest@example.com", + "preferences": { + "theme": "auto", + "notifications": True, + "language": "es", + "timezone": "Europe/Madrid", + "custom_field": {"nested": "value"}, + }, + } + + response = client.post("/api/v1/users", json=new_user) + assert response.status_code == 201 + created = response.json() + user_id = created["id"] + + # Retrieve and verify preferences + get_response = client.get(f"/api/v1/users/{user_id}") + user = get_response.json() + + prefs = user["preferences"] + assert prefs["theme"] == "auto" + assert prefs["notifications"] is True + assert prefs["language"] == "es" + assert prefs["timezone"] == "Europe/Madrid" + assert prefs["custom_field"]["nested"] == "value" + + +def test_list_users_with_pagination(client: TestClient) -> None: + """Test listing users with pagination.""" + # Create a few more users to test pagination + for i in range(3): + client.post("/api/v1/users", json={"username": f"pagetest{i}", "email": f"pagetest{i}@example.com"}) + + # Get paginated list + response = client.get("/api/v1/users", params={"page": 1, "size": 2}) + assert response.status_code == 200 + data = response.json() + + # Should return paginated response + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + + assert len(data["items"]) <= 2 + assert data["page"] == 1 + assert data["size"] == 2 + + +def test_config_and_user_coexist(client: TestClient) -> None: + """Test that Config and User endpoints coexist properly.""" + # Both endpoints should work + config_response = client.get("/api/v1/configs") + assert config_response.status_code == 200 + + user_response = client.get("/api/v1/users") + assert user_response.status_code == 200 + + # Both should return valid data + configs = config_response.json() + users = user_response.json() + + assert isinstance(configs, list) + assert isinstance(users, list) + assert len(configs) >= 1 + assert len(users) >= 1 diff --git a/packages/servicekit/tests/test_example_monitoring_api.py b/packages/servicekit/tests/test_example_monitoring_api.py new file mode 100644 index 0000000..b1164ac --- /dev/null +++ b/packages/servicekit/tests/test_example_monitoring_api.py @@ -0,0 +1,117 @@ +"""Tests for monitoring_api example using TestClient. 
+ +Tests use FastAPI's TestClient instead of running a separate server. +Validates monitoring setup with OpenTelemetry and Prometheus metrics. +""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from fastapi.testclient import TestClient + +from examples.monitoring_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "checks" in data + assert "database" in data["checks"] + assert data["checks"]["database"]["state"] == "healthy" + + +def test_system_endpoint(client: TestClient) -> None: + """Test system info endpoint returns system metadata.""" + response = client.get("/api/v1/system") + assert response.status_code == 200 + data = response.json() + assert "python_version" in data + assert "platform" in data + assert "current_time" in data + assert "timezone" in data + assert "hostname" in data + + +def test_metrics_endpoint(client: TestClient) -> None: + """Test Prometheus metrics endpoint returns metrics in text format.""" + response = client.get("/metrics") + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/plain") + + # Verify Prometheus format + content = response.text + assert "# HELP" in content + assert "# TYPE" in content + + # Verify some expected metrics exist + assert "python_gc_objects_collected_total" in content or "python_info" in content + + +def test_service_metadata(client: TestClient) -> None: + """Test service metadata is available in OpenAPI schema.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + info = schema["info"] + assert info["title"] == "Monitoring Example Service" + assert info["version"] == "1.0.0" + + +def test_config_endpoints_exist(client: TestClient) -> None: + """Test that config endpoints are available.""" + response = client.get("/api/v1/configs") + assert response.status_code == 200 + data = response.json() + # Should return empty list or paginated response initially + assert isinstance(data, (list, dict)) + + +def test_config_schema_endpoint(client: TestClient) -> None: + """Test config schema endpoint returns Config entity schema with AppConfig data.""" + response = client.get("/api/v1/configs/$schema") + assert response.status_code == 200 + schema = response.json() + assert "properties" in schema + # Schema includes Config entity fields (id, name, created_at, data) + assert "id" in schema["properties"] + assert "name" in schema["properties"] + assert "data" in schema["properties"] + # The 'data' field references AppConfig schema + data_ref = schema["properties"]["data"] + assert "$ref" in data_ref + # Check $defs for AppConfig schema + assert "$defs" in schema + assert "AppConfig" in schema["$defs"] + app_config_schema = schema["$defs"]["AppConfig"] + assert "api_key" in app_config_schema["properties"] + assert "max_connections" in app_config_schema["properties"] + + +def test_openapi_schema(client: TestClient) -> None: + """Test OpenAPI schema includes all expected endpoints.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = 
response.json() + + paths = schema["paths"] + + # Verify operational endpoints at root level + assert "/health" in paths + assert "/api/v1/system" in paths + assert "/metrics" in paths + + # Verify API endpoints are versioned + assert "/api/v1/configs" in paths + assert "/api/v1/configs/$schema" in paths diff --git a/packages/servicekit/tests/test_example_task_execution_api.py b/packages/servicekit/tests/test_example_task_execution_api.py new file mode 100644 index 0000000..76a3acb --- /dev/null +++ b/packages/servicekit/tests/test_example_task_execution_api.py @@ -0,0 +1,591 @@ +"""Tests for task_execution_api example with artifact-based result storage. + +This example demonstrates the new task execution architecture: +- Tasks are reusable command templates (no status/output fields) +- Execution creates Jobs that run asynchronously +- Results are stored in Artifacts with full task snapshot + outputs +- Job.artifact_id links to the result artifact +""" + +from __future__ import annotations + +import time +from collections.abc import Generator +from typing import Any, cast + +import pytest +from fastapi.testclient import TestClient + +from examples.task_execution_api import app + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + """Create FastAPI TestClient for testing with lifespan context.""" + with TestClient(app) as test_client: + yield test_client + + +def wait_for_job_completion(client: TestClient, job_id: str, timeout: float = 5.0) -> dict[Any, Any]: + """Poll job status until completion or timeout. + + Args: + client: FastAPI test client + job_id: Job identifier to poll + timeout: Max seconds to wait (default: 5.0) + + Returns: + Completed job record + + Raises: + TimeoutError: If job doesn't complete within timeout + """ + start_time = time.time() + while time.time() - start_time < timeout: + job_response = client.get(f"/api/v1/jobs/{job_id}") + assert job_response.status_code == 200 + job = cast(dict[Any, Any], job_response.json()) + + if job["status"] in ["completed", "failed", "canceled"]: + return job + + time.sleep(0.1) + + raise TimeoutError(f"Job {job_id} did not complete within {timeout}s") + + +def test_health_endpoint(client: TestClient) -> None: + """Test health check returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_list_tasks(client: TestClient) -> None: + """Test listing tasks shows seeded templates.""" + response = client.get("/api/v1/tasks") + assert response.status_code == 200 + data = response.json() + + # Should have at least the 5 seeded tasks + assert isinstance(data, list) + assert len(data) >= 5 + + # Check for specific seeded tasks + commands = [task["command"] for task in data] + assert any("ls -la /tmp" in cmd for cmd in commands) + assert any("echo" in cmd for cmd in commands) + assert any("date" in cmd for cmd in commands) + + +def test_get_task_by_id(client: TestClient) -> None: + """Test retrieving task template by ID.""" + # Get list to find a task ID + list_response = client.get("/api/v1/tasks") + tasks = list_response.json() + assert len(tasks) > 0 + task_id = tasks[0]["id"] + + # Get task by ID + response = client.get(f"/api/v1/tasks/{task_id}") + assert response.status_code == 200 + task = response.json() + + assert task["id"] == task_id + assert "command" in task + assert "created_at" in task + assert "updated_at" in task + # Tasks are templates - no status or execution fields + assert 
"status" not in task + assert "stdout" not in task + assert "stderr" not in task + assert "exit_code" not in task + assert "job_id" not in task + + +def test_get_task_not_found(client: TestClient) -> None: + """Test retrieving non-existent task returns 404.""" + response = client.get("/api/v1/tasks/01K72P5N5KCRM6MD3BRE4P0999") + assert response.status_code == 404 + + +def test_create_task(client: TestClient) -> None: + """Test creating a new task template.""" + new_task = {"command": "echo 'test task creation'"} + + response = client.post("/api/v1/tasks", json=new_task) + assert response.status_code == 201 + created = response.json() + + assert "id" in created + assert created["command"] == "echo 'test task creation'" + assert "created_at" in created + assert "updated_at" in created + # Tasks are templates - no execution state + assert "status" not in created + assert "stdout" not in created + assert "stderr" not in created + assert "exit_code" not in created + assert "job_id" not in created + + +def test_create_task_with_missing_command(client: TestClient) -> None: + """Test creating task without command fails.""" + response = client.post("/api/v1/tasks", json={}) + assert response.status_code == 422 # Validation error + + +def test_update_task(client: TestClient) -> None: + """Test updating a task template command.""" + # Create a task first + new_task = {"command": "echo 'original'"} + create_response = client.post("/api/v1/tasks", json=new_task) + created = create_response.json() + task_id = created["id"] + + # Update the task + updated_task = {"command": "echo 'updated'"} + response = client.put(f"/api/v1/tasks/{task_id}", json=updated_task) + assert response.status_code == 200 + updated = response.json() + + assert updated["id"] == task_id + assert updated["command"] == "echo 'updated'" + + +def test_delete_task(client: TestClient) -> None: + """Test deleting a task template.""" + # Create a task first + new_task = {"command": "echo 'to be deleted'"} + create_response = client.post("/api/v1/tasks", json=new_task) + created = create_response.json() + task_id = created["id"] + + # Delete the task + response = client.delete(f"/api/v1/tasks/{task_id}") + assert response.status_code == 204 + + # Verify task is deleted + get_response = client.get(f"/api/v1/tasks/{task_id}") + assert get_response.status_code == 404 + + +def test_execute_task_simple_command(client: TestClient) -> None: + """Test executing a simple echo command and retrieving results from artifact.""" + # Create a task + new_task = {"command": "echo 'Hello World'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Execute the task + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + assert execute_response.status_code == 202 # Accepted + execute_data = execute_response.json() + assert "job_id" in execute_data + assert "message" in execute_data + job_id = execute_data["job_id"] + + # Wait for job completion + job = wait_for_job_completion(client, job_id) + assert job["status"] == "completed" + assert job["artifact_id"] is not None + + # Get artifact with results + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + assert artifact_response.status_code == 200 + artifact = artifact_response.json() + + # Check artifact data structure + assert "data" in artifact + data = artifact["data"] + assert "task" in data + assert "stdout" in data + assert "stderr" in data + assert "exit_code" in data + + # Verify task 
snapshot + assert data["task"]["id"] == task_id + assert data["task"]["command"] == "echo 'Hello World'" + + # Verify execution results + assert "Hello World" in data["stdout"] + assert data["exit_code"] == 0 + + +def test_execute_task_with_output(client: TestClient) -> None: + """Test executing command with multiline output and checking artifact.""" + # Create a task that produces output + new_task = {"command": "printf 'Line 1\\nLine 2\\nLine 3'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Execute + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + assert job["status"] == "completed" + + # Get artifact and check outputs + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + artifact = artifact_response.json() + data = artifact["data"] + + assert "Line 1" in data["stdout"] + assert "Line 2" in data["stdout"] + assert "Line 3" in data["stdout"] + assert data["exit_code"] == 0 + + +def test_execute_task_failing_command(client: TestClient) -> None: + """Test executing a command that fails and checking error in artifact.""" + # Create a task with a failing command + new_task = {"command": "ls /this/path/does/not/exist"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Execute + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + # Job completes successfully even if command fails + assert job["status"] == "completed" + assert job["artifact_id"] is not None + + # Get artifact and check failure details + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + artifact = artifact_response.json() + data = artifact["data"] + + # Command failed with non-zero exit code + assert data["exit_code"] != 0 + assert data["stderr"] is not None + assert len(data["stderr"]) > 0 + + +def test_execute_task_with_stderr(client: TestClient) -> None: + """Test capturing stderr output in artifact.""" + # Create a task that writes to stderr + new_task = {"command": ">&2 echo 'error message'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Execute + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + assert job["status"] == "completed" + + # Get artifact and check stderr + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + artifact = artifact_response.json() + data = artifact["data"] + + # Stderr should contain the error message + # Some systems redirect stderr to stdout, so check both + output = (data["stderr"] or "") + (data["stdout"] or "") + assert "error message" in output + assert data["exit_code"] == 0 + + +def test_execute_same_task_multiple_times(client: TestClient) -> None: + """Test that same task template can be executed multiple times.""" + # Create a task + new_task = {"command": "echo 'multiple executions'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + job_ids = [] + artifact_ids = [] + + # Execute the same task 3 times 
sequentially (to avoid potential race conditions) + for _ in range(3): + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + assert execute_response.status_code == 202 + job_id = execute_response.json()["job_id"] + job_ids.append(job_id) + + # Wait for this job to complete before starting the next one + job = wait_for_job_completion(client, job_id) + assert job["status"] in ["completed", "failed"] + if job["status"] == "completed": + assert job["artifact_id"] is not None + artifact_ids.append(job["artifact_id"]) + + # At least some executions should have succeeded + assert len(artifact_ids) >= 1 + + # All artifact IDs should be different (each execution creates a new artifact) + assert len(set(artifact_ids)) == len(artifact_ids) + + # Verify all artifacts contain the same task snapshot but are independent records + for artifact_id in artifact_ids: + artifact_response = client.get(f"/api/v1/artifacts/{artifact_id}") + artifact = artifact_response.json() + data = artifact["data"] + + assert data["task"]["id"] == task_id + assert data["task"]["command"] == "echo 'multiple executions'" + assert "multiple executions" in data["stdout"] + + +def test_execute_nonexistent_task(client: TestClient) -> None: + """Test executing a non-existent task.""" + response = client.post("/api/v1/tasks/01K72P5N5KCRM6MD3BRE4P0999/$execute") + assert response.status_code == 400 # Bad request + + +def test_list_jobs(client: TestClient) -> None: + """Test listing scheduler jobs.""" + response = client.get("/api/v1/jobs") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + +def test_get_job_for_executed_task(client: TestClient) -> None: + """Test that executed tasks create scheduler jobs with proper metadata.""" + # Create and execute a task + new_task = {"command": "echo 'job test'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + execute_data = execute_response.json() + job_id = execute_data["job_id"] + + # Get the job + job_response = client.get(f"/api/v1/jobs/{job_id}") + assert job_response.status_code == 200 + job = job_response.json() + + assert job["id"] == job_id + assert job["status"] in ["pending", "running", "completed", "failed"] + assert "submitted_at" in job + + +def test_task_timestamps(client: TestClient) -> None: + """Test that task template timestamps are set correctly.""" + # Create a task + new_task = {"command": "sleep 0.1"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Check task timestamps (only created_at and updated_at for templates) + assert task["created_at"] is not None + assert task["updated_at"] is not None + # No execution timestamps on task template + assert "started_at" not in task + assert "finished_at" not in task + + # Execute and check job timestamps + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + + # Job should have execution timestamps + assert job["submitted_at"] is not None + assert job["started_at"] is not None + assert job["finished_at"] is not None + + +def test_list_tasks_with_pagination(client: TestClient) -> None: + """Test listing task templates with pagination.""" + response = client.get("/api/v1/tasks", params={"page": 1, "size": 2}) + assert 
response.status_code == 200 + data = response.json() + + # Should return paginated response + assert "items" in data + assert "total" in data + assert "page" in data + assert "size" in data + assert "pages" in data + + assert len(data["items"]) <= 2 + assert data["page"] == 1 + assert data["size"] == 2 + + +def test_python_command_execution(client: TestClient) -> None: + """Test executing Python commands and retrieving results from artifact.""" + # Create a task with Python code + new_task = {"command": 'python3 -c "import sys; print(sys.version); print(2+2)"'} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + # Execute + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + assert job["status"] == "completed" + + # Get artifact and check output + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + artifact = artifact_response.json() + data = artifact["data"] + + assert data["exit_code"] == 0 + assert "4" in data["stdout"] # Result of 2+2 + + +def test_concurrent_task_execution(client: TestClient) -> None: + """Test executing multiple tasks concurrently and retrieving all artifacts.""" + task_ids = [] + job_ids = [] + + # Create multiple tasks + for i in range(3): + new_task = {"command": f"echo 'task {i}'"} + response = client.post("/api/v1/tasks", json=new_task) + task_ids.append(response.json()["id"]) + + # Execute all tasks + for task_id in task_ids: + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_ids.append(execute_response.json()["job_id"]) + + # Wait for all jobs to complete and collect results + all_jobs = [] + for job_id in job_ids: + job = wait_for_job_completion(client, job_id) + all_jobs.append(job) + # Jobs should complete (either successfully or with failure) + assert job["status"] in ["completed", "failed", "canceled"] + + # Count successful completions + completed_jobs = [j for j in all_jobs if j["status"] == "completed"] + + # Test that we can execute multiple tasks (even if some fail due to concurrency limits) + # At minimum, verify jobs were created and reached terminal state + assert len(all_jobs) == 3 + + # Verify completed artifacts have correct structure and output + for job in completed_jobs: + if job["artifact_id"]: + artifact_response = client.get(f"/api/v1/artifacts/{job['artifact_id']}") + artifact = artifact_response.json() + data = artifact["data"] + + # Each task should have output containing "task" + assert "task" in data["stdout"] + assert data["exit_code"] == 0 + + +def test_job_artifact_linkage(client: TestClient) -> None: + """Test that jobs are properly linked to result artifacts.""" + # Create and execute a task + new_task = {"command": "echo 'linkage test'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + + # Job should have artifact_id + assert job["artifact_id"] is not None + artifact_id = job["artifact_id"] + + # Artifact should exist and contain execution results + artifact_response = client.get(f"/api/v1/artifacts/{artifact_id}") + assert artifact_response.status_code == 200 + artifact = artifact_response.json() + + # Verify 
artifact structure + assert artifact["id"] == artifact_id + assert "data" in artifact + assert "task" in artifact["data"] + assert artifact["data"]["task"]["id"] == task_id + + +def test_task_deletion_preserves_artifacts(client: TestClient) -> None: + """Test that deleting a task doesn't delete its execution artifacts.""" + # Create and execute a task + new_task = {"command": "echo 'preserve artifacts'"} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + artifact_id = job["artifact_id"] + + # Verify artifact exists + artifact_response = client.get(f"/api/v1/artifacts/{artifact_id}") + assert artifact_response.status_code == 200 + + # Delete the task + delete_response = client.delete(f"/api/v1/tasks/{task_id}") + assert delete_response.status_code == 204 + + # Artifact should still exist + artifact_response = client.get(f"/api/v1/artifacts/{artifact_id}") + assert artifact_response.status_code == 200 + artifact = artifact_response.json() + + # Artifact contains full task snapshot, so task data is preserved + assert artifact["data"]["task"]["id"] == task_id + assert artifact["data"]["task"]["command"] == "echo 'preserve artifacts'" + + +def test_task_modification_doesnt_affect_artifacts(client: TestClient) -> None: + """Test that modifying a task doesn't affect existing execution artifacts.""" + # Create and execute a task + original_command = "echo 'original command'" + new_task = {"command": original_command} + create_response = client.post("/api/v1/tasks", json=new_task) + task = create_response.json() + task_id = task["id"] + + execute_response = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id = execute_response.json()["job_id"] + + # Wait for completion + job = wait_for_job_completion(client, job_id) + artifact_id = job["artifact_id"] + + # Modify the task + modified_command = "echo 'modified command'" + update_response = client.put(f"/api/v1/tasks/{task_id}", json={"command": modified_command}) + assert update_response.status_code == 200 + + # Original artifact should still have the original command + artifact_response = client.get(f"/api/v1/artifacts/{artifact_id}") + artifact = artifact_response.json() + assert artifact["data"]["task"]["command"] == original_command + assert "original command" in artifact["data"]["stdout"] + + # New execution should use modified command + execute_response2 = client.post(f"/api/v1/tasks/{task_id}/$execute") + job_id2 = execute_response2.json()["job_id"] + job2 = wait_for_job_completion(client, job_id2) + + artifact_response2 = client.get(f"/api/v1/artifacts/{job2['artifact_id']}") + artifact2 = artifact_response2.json() + assert artifact2["data"]["task"]["command"] == modified_command + assert "modified command" in artifact2["data"]["stdout"] diff --git a/packages/servicekit/tests/test_exceptions.py b/packages/servicekit/tests/test_exceptions.py new file mode 100644 index 0000000..64abcb9 --- /dev/null +++ b/packages/servicekit/tests/test_exceptions.py @@ -0,0 +1,169 @@ +"""Tests for custom exceptions with RFC 9457 Problem Details support.""" + +import pytest +from servicekit.core.exceptions import ( + BadRequestError, + ChapkitException, + ConflictError, + ErrorType, + ForbiddenError, + InvalidULIDError, + NotFoundError, + UnauthorizedError, + ValidationError, +) + + +def 
test_chapkit_exception_with_defaults() -> None: + """Test ChapkitException with default values.""" + exc = ChapkitException("Something went wrong") + + assert str(exc) == "Something went wrong" + assert exc.detail == "Something went wrong" + assert exc.type_uri == ErrorType.INTERNAL_ERROR + assert exc.title == "Internal Server Error" + assert exc.status == 500 + assert exc.instance is None + assert exc.extensions == {} + + +def test_chapkit_exception_with_custom_values() -> None: + """Test ChapkitException with custom values.""" + exc = ChapkitException( + "Custom error", + type_uri="urn:custom:error", + title="Custom Error", + status=418, + instance="/api/v1/teapot", + extra_field="extra_value", + ) + + assert exc.detail == "Custom error" + assert exc.type_uri == "urn:custom:error" + assert exc.title == "Custom Error" + assert exc.status == 418 + assert exc.instance == "/api/v1/teapot" + assert exc.extensions == {"extra_field": "extra_value"} + + +def test_not_found_error() -> None: + """Test NotFoundError sets correct RFC 9457 fields.""" + exc = NotFoundError("Resource not found", instance="/api/v1/items/123") + + assert str(exc) == "Resource not found" + assert exc.detail == "Resource not found" + assert exc.type_uri == ErrorType.NOT_FOUND + assert exc.title == "Resource Not Found" + assert exc.status == 404 + assert exc.instance == "/api/v1/items/123" + + +def test_validation_error() -> None: + """Test ValidationError sets correct RFC 9457 fields.""" + exc = ValidationError("Invalid input", instance="/api/v1/users", field="email") + + assert str(exc) == "Invalid input" + assert exc.detail == "Invalid input" + assert exc.type_uri == ErrorType.VALIDATION_FAILED + assert exc.title == "Validation Failed" + assert exc.status == 400 + assert exc.instance == "/api/v1/users" + assert exc.extensions == {"field": "email"} + + +def test_conflict_error() -> None: + """Test ConflictError sets correct RFC 9457 fields.""" + exc = ConflictError("Resource already exists", instance="/api/v1/configs/prod") + + assert str(exc) == "Resource already exists" + assert exc.detail == "Resource already exists" + assert exc.type_uri == ErrorType.CONFLICT + assert exc.title == "Resource Conflict" + assert exc.status == 409 + assert exc.instance == "/api/v1/configs/prod" + + +def test_invalid_ulid_error() -> None: + """Test InvalidULIDError sets correct RFC 9457 fields.""" + exc = InvalidULIDError("Malformed ULID", instance="/api/v1/items/invalid", ulid="bad-ulid") + + assert str(exc) == "Malformed ULID" + assert exc.detail == "Malformed ULID" + assert exc.type_uri == ErrorType.INVALID_ULID + assert exc.title == "Invalid ULID Format" + assert exc.status == 400 + assert exc.instance == "/api/v1/items/invalid" + assert exc.extensions == {"ulid": "bad-ulid"} + + +def test_bad_request_error() -> None: + """Test BadRequestError sets correct RFC 9457 fields.""" + exc = BadRequestError("Missing required parameter", instance="/api/v1/search") + + assert str(exc) == "Missing required parameter" + assert exc.detail == "Missing required parameter" + assert exc.type_uri == ErrorType.BAD_REQUEST + assert exc.title == "Bad Request" + assert exc.status == 400 + assert exc.instance == "/api/v1/search" + + +def test_unauthorized_error() -> None: + """Test UnauthorizedError sets correct RFC 9457 fields.""" + exc = UnauthorizedError("Invalid credentials", instance="/api/v1/login") + + assert str(exc) == "Invalid credentials" + assert exc.detail == "Invalid credentials" + assert exc.type_uri == ErrorType.UNAUTHORIZED + assert 
exc.title == "Unauthorized" + assert exc.status == 401 + assert exc.instance == "/api/v1/login" + + +def test_forbidden_error() -> None: + """Test ForbiddenError sets correct RFC 9457 fields.""" + exc = ForbiddenError("Access denied", instance="/api/v1/admin", required_role="admin") + + assert str(exc) == "Access denied" + assert exc.detail == "Access denied" + assert exc.type_uri == ErrorType.FORBIDDEN + assert exc.title == "Forbidden" + assert exc.status == 403 + assert exc.instance == "/api/v1/admin" + assert exc.extensions == {"required_role": "admin"} + + +def test_error_type_constants() -> None: + """Test ErrorType URN constants are correctly defined.""" + assert ErrorType.NOT_FOUND == "urn:chapkit:error:not-found" + assert ErrorType.VALIDATION_FAILED == "urn:chapkit:error:validation-failed" + assert ErrorType.CONFLICT == "urn:chapkit:error:conflict" + assert ErrorType.INVALID_ULID == "urn:chapkit:error:invalid-ulid" + assert ErrorType.INTERNAL_ERROR == "urn:chapkit:error:internal" + assert ErrorType.UNAUTHORIZED == "urn:chapkit:error:unauthorized" + assert ErrorType.FORBIDDEN == "urn:chapkit:error:forbidden" + assert ErrorType.BAD_REQUEST == "urn:chapkit:error:bad-request" + + +def test_exceptions_are_raisable() -> None: + """Test that exceptions can be raised and caught.""" + with pytest.raises(NotFoundError) as exc_info: + raise NotFoundError("Test error") + + assert exc_info.value.status == 404 + assert str(exc_info.value) == "Test error" + + +def test_exception_without_instance() -> None: + """Test exceptions work without instance parameter.""" + exc = ValidationError("Missing field") + + assert exc.instance is None + assert exc.detail == "Missing field" + + +def test_exception_with_multiple_extensions() -> None: + """Test exceptions can handle multiple extension fields.""" + exc = BadRequestError("Invalid query", field="name", reason="too_short", min_length=3) + + assert exc.extensions == {"field": "name", "reason": "too_short", "min_length": 3} diff --git a/packages/servicekit/tests/test_health_router.py b/packages/servicekit/tests/test_health_router.py new file mode 100644 index 0000000..6a025cb --- /dev/null +++ b/packages/servicekit/tests/test_health_router.py @@ -0,0 +1,323 @@ +"""Tests for health check router.""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from httpx import ASGITransport, AsyncClient +from servicekit.core.api.routers.health import CheckResult, HealthRouter, HealthState, HealthStatus + + +@pytest.fixture +def app_no_checks() -> FastAPI: + """FastAPI app with health router but no checks.""" + app = FastAPI() + health_router = HealthRouter.create(prefix="/health", tags=["Observability"]) + app.include_router(health_router) + return app + + +@pytest.fixture +def app_with_checks() -> FastAPI: + """FastAPI app with health router and custom checks.""" + + async def check_healthy() -> tuple[HealthState, str | None]: + return (HealthState.HEALTHY, None) + + async def check_degraded() -> tuple[HealthState, str | None]: + return (HealthState.DEGRADED, "Partial outage") + + async def check_unhealthy() -> tuple[HealthState, str | None]: + return (HealthState.UNHEALTHY, "Service down") + + async def check_exception() -> tuple[HealthState, str | None]: + raise RuntimeError("Check failed") + + app = FastAPI() + health_router = HealthRouter.create( + prefix="/health", + tags=["Observability"], + checks={ + "healthy_check": check_healthy, + "degraded_check": check_degraded, + 
"unhealthy_check": check_unhealthy, + "exception_check": check_exception, + }, + ) + app.include_router(health_router) + return app + + +def test_health_check_no_checks(app_no_checks: FastAPI) -> None: + """Test health check endpoint with no custom checks returns healthy.""" + client = TestClient(app_no_checks) + response = client.get("/health/") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "checks" not in data or data["checks"] is None + + +def test_health_check_with_checks(app_with_checks: FastAPI) -> None: + """Test health check endpoint with custom checks aggregates results.""" + client = TestClient(app_with_checks) + response = client.get("/health/") + assert response.status_code == 200 + data = response.json() + + # Overall status should be unhealthy (worst state) + assert data["status"] == "unhealthy" + + # Verify individual check results + checks = data["checks"] + assert checks["healthy_check"]["state"] == "healthy" + assert "message" not in checks["healthy_check"] # None excluded + + assert checks["degraded_check"]["state"] == "degraded" + assert checks["degraded_check"]["message"] == "Partial outage" + + assert checks["unhealthy_check"]["state"] == "unhealthy" + assert checks["unhealthy_check"]["message"] == "Service down" + + # Exception should be caught and reported as unhealthy + assert checks["exception_check"]["state"] == "unhealthy" + assert "Check failed" in checks["exception_check"]["message"] + + +def test_health_state_enum() -> None: + """Test HealthState enum values.""" + assert HealthState.HEALTHY.value == "healthy" + assert HealthState.DEGRADED.value == "degraded" + assert HealthState.UNHEALTHY.value == "unhealthy" + + +def test_check_result_model() -> None: + """Test CheckResult model.""" + result = CheckResult(state=HealthState.HEALTHY, message=None) + assert result.state == HealthState.HEALTHY + assert result.message is None + + result_with_msg = CheckResult(state=HealthState.UNHEALTHY, message="Error occurred") + assert result_with_msg.state == HealthState.UNHEALTHY + assert result_with_msg.message == "Error occurred" + + +def test_health_status_model() -> None: + """Test HealthStatus model.""" + status = HealthStatus(status=HealthState.HEALTHY) + assert status.status == HealthState.HEALTHY + assert status.checks is None + + checks = {"test": CheckResult(state=HealthState.HEALTHY, message=None)} + status_with_checks = HealthStatus(status=HealthState.HEALTHY, checks=checks) + assert status_with_checks.status == HealthState.HEALTHY + assert status_with_checks.checks == checks + + +def test_health_check_aggregation_priority() -> None: + """Test that unhealthy > degraded > healthy in aggregation.""" + + async def check_healthy() -> tuple[HealthState, str | None]: + return (HealthState.HEALTHY, None) + + async def check_degraded() -> tuple[HealthState, str | None]: + return (HealthState.DEGRADED, "Warning") + + # Only healthy checks -> overall healthy + app = FastAPI() + router = HealthRouter.create(prefix="/health", tags=["Observability"], checks={"healthy": check_healthy}) + app.include_router(router) + + client = TestClient(app) + response = client.get("/health/") + assert response.json()["status"] == "healthy" + + # Healthy + degraded -> overall degraded + app2 = FastAPI() + router2 = HealthRouter.create( + prefix="/health", tags=["Observability"], checks={"healthy": check_healthy, "degraded": check_degraded} + ) + app2.include_router(router2) + + client2 = TestClient(app2) + response = 
client2.get("/health/") + assert response.json()["status"] == "degraded" + + +class TestHealthRouterSSE: + """Test health router SSE streaming. + + Note: SSE streaming tests are skipped for automated testing due to httpx AsyncClient + + ASGITransport limitations with infinite streams. The endpoint is manually tested + and works correctly with real HTTP clients (curl, browsers, etc.). + """ + + @pytest.mark.skip(reason="httpx AsyncClient with ASGITransport cannot handle infinite SSE streams properly") + @pytest.mark.asyncio + async def test_stream_health_no_checks(self) -> None: + """Test SSE streaming with no custom checks.""" + import json + + app = FastAPI() + health_router = HealthRouter.create(prefix="/health", tags=["Observability"]) + app.include_router(health_router) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + events = [] + async with client.stream("GET", "/health/$stream?poll_interval=0.1") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.headers["cache-control"] == "no-cache" + assert response.headers["connection"] == "keep-alive" + + # Collect a few events + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + if len(events) >= 3: + break + + # All events should show healthy status + assert len(events) >= 3 + for event in events: + assert event["status"] == "healthy" + assert "checks" not in event or event["checks"] is None + + @pytest.mark.skip(reason="httpx AsyncClient with ASGITransport cannot handle infinite SSE streams properly") + @pytest.mark.asyncio + async def test_stream_health_with_checks(self) -> None: + """Test SSE streaming with custom health checks.""" + import json + + async def check_healthy() -> tuple[HealthState, str | None]: + return (HealthState.HEALTHY, None) + + async def check_degraded() -> tuple[HealthState, str | None]: + return (HealthState.DEGRADED, "Partial outage") + + app = FastAPI() + health_router = HealthRouter.create( + prefix="/health", + tags=["Observability"], + checks={"healthy_check": check_healthy, "degraded_check": check_degraded}, + ) + app.include_router(health_router) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + events = [] + async with client.stream("GET", "/health/$stream?poll_interval=0.1") as response: + assert response.status_code == 200 + + # Collect a few events + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + if len(events) >= 2: + break + + # All events should show degraded status (worst state) + assert len(events) >= 2 + for event in events: + assert event["status"] == "degraded" + assert event["checks"] is not None + assert event["checks"]["healthy_check"]["state"] == "healthy" + assert event["checks"]["degraded_check"]["state"] == "degraded" + assert event["checks"]["degraded_check"]["message"] == "Partial outage" + + @pytest.mark.skip(reason="httpx AsyncClient with ASGITransport cannot handle infinite SSE streams properly") + @pytest.mark.asyncio + async def test_stream_health_custom_poll_interval(self) -> None: + """Test SSE streaming with custom poll interval.""" + import json + import time + + app = FastAPI() + health_router = HealthRouter.create(prefix="/health", tags=["Observability"]) + app.include_router(health_router) + + async with 
AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + events = [] + start_time = time.time() + + async with client.stream("GET", "/health/$stream?poll_interval=0.2") as response: + assert response.status_code == 200 + + # Collect 3 events + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + if len(events) >= 3: + break + + elapsed = time.time() - start_time + + # Should have taken at least 0.4 seconds (2 intervals between 3 events) + assert elapsed >= 0.4 + assert len(events) == 3 + + @pytest.mark.skip(reason="httpx AsyncClient with ASGITransport cannot handle infinite SSE streams properly") + @pytest.mark.asyncio + async def test_stream_health_state_transitions(self) -> None: + """Test SSE streaming captures state transitions over time.""" + import json + + health_state = {"current": HealthState.HEALTHY} + + async def dynamic_check() -> tuple[HealthState, str | None]: + return (health_state["current"], None) + + app = FastAPI() + health_router = HealthRouter.create(prefix="/health", tags=["Observability"], checks={"dynamic": dynamic_check}) + app.include_router(health_router) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + events = [] + + async with client.stream("GET", "/health/$stream?poll_interval=0.1") as response: + assert response.status_code == 200 + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + + # Change state after collecting a few events + if len(events) == 2: + health_state["current"] = HealthState.UNHEALTHY + elif len(events) >= 4: + break + + # Verify we captured the state transition + assert len(events) >= 4 + assert events[0]["status"] == "healthy" + assert events[1]["status"] == "healthy" + # State should transition to unhealthy + assert events[3]["status"] == "unhealthy" + + @pytest.mark.skip(reason="httpx AsyncClient with ASGITransport cannot handle infinite SSE streams properly") + @pytest.mark.asyncio + async def test_stream_health_continuous(self) -> None: + """Test SSE streaming continues indefinitely until client disconnects.""" + app = FastAPI() + health_router = HealthRouter.create(prefix="/health", tags=["Observability"]) + app.include_router(health_router) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + event_count = 0 + + async with client.stream("GET", "/health/$stream?poll_interval=0.05") as response: + assert response.status_code == 200 + + # Collect many events to verify continuous streaming + async for line in response.aiter_lines(): + if line.startswith("data: "): + event_count += 1 + if event_count >= 10: + break + + # Should have received many events + assert event_count == 10 diff --git a/packages/servicekit/tests/test_hierarchy.py b/packages/servicekit/tests/test_hierarchy.py new file mode 100644 index 0000000..c828475 --- /dev/null +++ b/packages/servicekit/tests/test_hierarchy.py @@ -0,0 +1,20 @@ +from servicekit import ArtifactHierarchy + + +def test_label_lookup_returns_configured_value() -> None: + hierarchy = ArtifactHierarchy(name="ml_flow", level_labels={0: "train", 1: "predict"}) + assert hierarchy.label_for(0) == "train" + assert hierarchy.label_for(1) == "predict" + assert hierarchy.label_for(2) == "level_2" + + +def test_describe_returns_metadata() -> None: + hierarchy = ArtifactHierarchy(name="ml_flow", level_labels={0: "train"}) + metadata = 
hierarchy.describe(0)
+    assert metadata == {"hierarchy": "ml_flow", "level_depth": 0, "level_label": "train"}
+
+
+def test_describe_uses_fallback_label() -> None:
+    hierarchy = ArtifactHierarchy(name="ml_flow", level_labels={})
+    metadata = hierarchy.describe(3)
+    assert metadata["level_label"] == "level_3"
diff --git a/packages/servicekit/tests/test_job_router.py b/packages/servicekit/tests/test_job_router.py
new file mode 100644
index 0000000..c007a6d
--- /dev/null
+++ b/packages/servicekit/tests/test_job_router.py
@@ -0,0 +1,466 @@
+"""Tests for job router endpoints."""
+
+import asyncio
+from collections.abc import AsyncGenerator
+
+import pytest
+import ulid
+from fastapi import FastAPI
+from httpx import ASGITransport, AsyncClient
+from servicekit.api import ServiceBuilder, ServiceInfo
+
+ULID = ulid.ULID
+
+
+@pytest.fixture
+async def app() -> AsyncGenerator[FastAPI, None]:
+    """Create FastAPI app with job router and trigger lifespan."""
+    info = ServiceInfo(display_name="Test Service")
+    app_instance = ServiceBuilder(info=info).with_jobs().build()
+
+    # Manually trigger lifespan
+    async with app_instance.router.lifespan_context(app_instance):
+        yield app_instance
+
+
+@pytest.fixture
+async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
+    """Create async test client."""
+    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test", follow_redirects=True) as ac:
+        yield ac
+
+
+class TestJobRouter:
+    """Test job router endpoints."""
+
+    @pytest.mark.asyncio
+    async def test_get_jobs_empty(self, client: AsyncClient):
+        """Test GET /api/v1/jobs returns empty list initially."""
+        response = await client.get("/api/v1/jobs")
+        assert response.status_code == 200
+        jobs = response.json()
+        assert jobs == []
+
+    @pytest.mark.asyncio
+    async def test_get_jobs_after_adding(self, client: AsyncClient, app: FastAPI):
+        """Test GET /api/v1/jobs returns jobs after adding them."""
+        # Get scheduler and add a job
+        scheduler = app.state.scheduler
+
+        async def task():
+            return "result"
+
+        job_id = await scheduler.add_job(task)
+        await scheduler.wait(job_id)
+
+        # Get jobs
+        response = await client.get("/api/v1/jobs")
+        assert response.status_code == 200
+        jobs = response.json()
+        assert len(jobs) == 1
+        assert jobs[0]["id"] == str(job_id)
+        assert jobs[0]["status"] == "completed"
+
+    @pytest.mark.asyncio
+    async def test_get_jobs_filtered_by_status(self, client: AsyncClient, app: FastAPI):
+        """Test GET /api/v1/jobs?status_filter=completed filters jobs."""
+        scheduler = app.state.scheduler
+
+        async def quick_task():
+            return "done"
+
+        async def slow_task():
+            await asyncio.sleep(10)
+            return "never"
+
+        # Add completed and running jobs
+        completed_id = await scheduler.add_job(quick_task)
+        await scheduler.wait(completed_id)
+
+        running_id = await scheduler.add_job(slow_task)
+        await asyncio.sleep(0.01)  # Let it start
+
+        # Filter by completed
+        response = await client.get("/api/v1/jobs?status_filter=completed")
+        assert response.status_code == 200
+        jobs = response.json()
+        assert len(jobs) == 1
+        assert jobs[0]["id"] == str(completed_id)
+
+        # Filter by running
+        response = await client.get("/api/v1/jobs?status_filter=running")
+        assert response.status_code == 200
+        jobs = response.json()
+        assert len(jobs) == 1
+        assert jobs[0]["id"] == str(running_id)
+
+        # Cleanup
+        await scheduler.cancel(running_id)
+
+    @pytest.mark.asyncio
+    async def test_get_job_by_id(self, client: AsyncClient, app: FastAPI):
+        """Test GET /api/v1/jobs/{id} returns job record."""
+        scheduler = app.state.scheduler
+
+        async def task():
+            return "result"
+
+        job_id = await scheduler.add_job(task)
+        await scheduler.wait(job_id)
+
+        # Get specific job
+        response = await client.get(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 200
+        job = response.json()
+        assert job["id"] == str(job_id)
+        assert job["status"] == "completed"
+        assert job["submitted_at"] is not None
+        assert job["started_at"] is not None
+        assert job["finished_at"] is not None
+
+    @pytest.mark.asyncio
+    async def test_get_job_not_found(self, client: AsyncClient):
+        """Test GET /api/v1/jobs/{id} returns 404 for non-existent job."""
+        fake_id = ULID()
+        response = await client.get(f"/api/v1/jobs/{fake_id}")
+        assert response.status_code == 404
+        assert response.json()["detail"] == "Job not found"
+
+    @pytest.mark.asyncio
+    async def test_get_job_invalid_ulid(self, client: AsyncClient):
+        """Test GET /api/v1/jobs/{id} returns 404 for invalid ULID."""
+        response = await client.get("/api/v1/jobs/invalid-ulid")
+        assert response.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_delete_job(self, client: AsyncClient, app: FastAPI):
+        """Test DELETE /api/v1/jobs/{id} deletes job."""
+        scheduler = app.state.scheduler
+
+        async def task():
+            return "result"
+
+        job_id = await scheduler.add_job(task)
+        await scheduler.wait(job_id)
+
+        # Delete job
+        response = await client.delete(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 204
+        assert response.text == ""
+
+        # Verify job is gone
+        response = await client.get(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_delete_running_job_cancels_it(self, client: AsyncClient, app: FastAPI):
+        """Test DELETE /api/v1/jobs/{id} cancels running job."""
+        scheduler = app.state.scheduler
+
+        async def long_task():
+            await asyncio.sleep(10)
+            return "never"
+
+        job_id = await scheduler.add_job(long_task)
+        await asyncio.sleep(0.01)  # Let it start
+
+        # Delete while running
+        response = await client.delete(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 204
+
+        # Verify job is gone
+        response = await client.get(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_delete_job_not_found(self, client: AsyncClient):
+        """Test DELETE /api/v1/jobs/{id} returns 404 for non-existent job."""
+        fake_id = ULID()
+        response = await client.delete(f"/api/v1/jobs/{fake_id}")
+        assert response.status_code == 404
+        assert response.json()["detail"] == "Job not found"
+
+    @pytest.mark.asyncio
+    async def test_job_with_artifact_id(self, client: AsyncClient, app: FastAPI):
+        """Test job that returns ULID sets artifact_id."""
+        scheduler = app.state.scheduler
+
+        artifact_id = ULID()
+
+        async def task():
+            return artifact_id
+
+        job_id = await scheduler.add_job(task)
+        await scheduler.wait(job_id)
+
+        # Get job record
+        response = await client.get(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 200
+        job = response.json()
+        assert job["artifact_id"] == str(artifact_id)
+
+    @pytest.mark.asyncio
+    async def test_failed_job_has_error(self, client: AsyncClient, app: FastAPI):
+        """Test failed job includes error traceback."""
+        scheduler = app.state.scheduler
+
+        async def failing_task():
+            raise ValueError("Something went wrong")
+
+        job_id = await scheduler.add_job(failing_task)
+
+        # Wait for failure
+        try:
+            await scheduler.wait(job_id)
+        except ValueError:
+            pass
+
+        # Get job record
+        response = await client.get(f"/api/v1/jobs/{job_id}")
+        assert response.status_code == 200
+        job = response.json()
+        assert job["status"] == "failed"
+        assert job["error"] is not None
+        assert "ValueError" in job["error"]
+        assert "Something went wrong" in job["error"]
+
+    @pytest.mark.asyncio
+    async def test_jobs_sorted_newest_first(self, client: AsyncClient, app: FastAPI):
+        """Test GET /api/v1/jobs returns jobs sorted newest first."""
+        scheduler = app.state.scheduler
+
+        async def task():
+            return "done"
+
+        job_ids = []
+        for _ in range(3):
+            jid = await scheduler.add_job(task)
+            job_ids.append(jid)
+            await asyncio.sleep(0.01)  # Ensure different timestamps
+
+        # Get jobs
+        response = await client.get("/api/v1/jobs")
+        assert response.status_code == 200
+        jobs = response.json()
+        assert len(jobs) == 3
+
+        # Should be newest first
+        assert jobs[0]["id"] == str(job_ids[2])
+        assert jobs[1]["id"] == str(job_ids[1])
+        assert jobs[2]["id"] == str(job_ids[0])
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_quick_job(self, client: AsyncClient, app: FastAPI):
+        """Test SSE streaming for quick job that completes immediately."""
+        scheduler = app.state.scheduler
+
+        async def quick_task():
+            return "done"
+
+        job_id = await scheduler.add_job(quick_task)
+        await scheduler.wait(job_id)
+
+        # Stream SSE events
+        events = []
+        async with client.stream("GET", f"/api/v1/jobs/{job_id}/$stream") as response:
+            assert response.status_code == 200
+            assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
+            assert response.headers["cache-control"] == "no-cache"
+
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    import json
+
+                    data = json.loads(line[6:])
+                    events.append(data)
+
+        # Should have at least one event with completed status
+        assert len(events) >= 1
+        assert events[-1]["status"] == "completed"
+        assert events[-1]["id"] == str(job_id)
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_running_job(self, client: AsyncClient, app: FastAPI):
+        """Test SSE streaming for running job with status transitions."""
+        scheduler = app.state.scheduler
+
+        async def slow_task():
+            await asyncio.sleep(0.5)
+            return "done"
+
+        job_id = await scheduler.add_job(slow_task)
+
+        # Stream SSE events
+        events = []
+        async with client.stream("GET", f"/api/v1/jobs/{job_id}/$stream?poll_interval=0.1") as response:
+            assert response.status_code == 200
+
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    import json
+
+                    data = json.loads(line[6:])
+                    events.append(data)
+                    if data["status"] == "completed":
+                        break
+
+        # Should have multiple events showing status transitions
+        assert len(events) >= 2
+        statuses = [e["status"] for e in events]
+        assert "running" in statuses or "pending" in statuses
+        assert events[-1]["status"] == "completed"
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_failed_job(self, client: AsyncClient, app: FastAPI):
+        """Test SSE streaming for failed job."""
+        scheduler = app.state.scheduler
+
+        async def failing_task():
+            raise ValueError("Task failed")
+
+        job_id = await scheduler.add_job(failing_task)
+
+        # Stream SSE events
+        events = []
+        async with client.stream("GET", f"/api/v1/jobs/{job_id}/$stream") as response:
+            assert response.status_code == 200
+
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    import json
+
+                    data = json.loads(line[6:])
+                    events.append(data)
+                    if data["status"] == "failed":
+                        break
+
+        # Final event should show failed status with error
+        assert events[-1]["status"] == "failed"
+        assert events[-1]["error"] is not None
+        assert "ValueError" in events[-1]["error"]
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_not_found(self, client: AsyncClient):
+        """Test SSE streaming for non-existent job returns 404."""
+        fake_id = ULID()
+        response = await client.get(f"/api/v1/jobs/{fake_id}/$stream")
+        assert response.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_invalid_ulid(self, client: AsyncClient):
+        """Test SSE streaming with invalid ULID returns 400."""
+        response = await client.get("/api/v1/jobs/invalid-ulid/$stream")
+        assert response.status_code == 400
+        assert response.json()["detail"] == "Invalid job ID format"
+
+    @pytest.mark.asyncio
+    async def test_stream_job_status_custom_poll_interval(self, client: AsyncClient, app: FastAPI):
+        """Test SSE streaming with custom poll interval."""
+        scheduler = app.state.scheduler
+
+        async def slow_task():
+            await asyncio.sleep(0.5)
+            return "done"
+
+        job_id = await scheduler.add_job(slow_task)
+
+        # Stream with custom poll interval
+        events = []
+        async with client.stream("GET", f"/api/v1/jobs/{job_id}/$stream?poll_interval=0.2") as response:
+            assert response.status_code == 200
+
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    import json
+
+                    data = json.loads(line[6:])
+                    events.append(data)
+                    if data["status"] == "completed":
+                        break
+
+        # Should have received events and completed
+        assert len(events) >= 1
+        assert events[-1]["status"] == "completed"
+
+    @pytest.mark.asyncio
+    async def test_jobs_schema_endpoint(self, client: AsyncClient):
+        """Test GET /api/v1/jobs/$schema returns JSON schema."""
+        response = await client.get("/api/v1/jobs/$schema")
+        assert response.status_code == 200
+        schema = response.json()
+
+        # Verify schema structure
+        assert schema["type"] == "array"
+        assert "items" in schema
+        assert "$ref" in schema["items"]
+        assert schema["items"]["$ref"] == "#/$defs/JobRecord"
+
+        # Verify JobRecord definition exists
+        assert "$defs" in schema
+        assert "JobRecord" in schema["$defs"]
+
+        # Verify JobRecord schema has required fields
+        job_record_schema = schema["$defs"]["JobRecord"]
+        assert job_record_schema["type"] == "object"
+        assert "properties" in job_record_schema
+        assert "id" in job_record_schema["properties"]
+        assert "status" in job_record_schema["properties"]
+        assert "submitted_at" in job_record_schema["properties"]
+
+        # Verify required fields (only id is required, others have defaults)
+        assert "required" in job_record_schema
+        assert "id" in job_record_schema["required"]
+
+
+class TestJobRouterIntegration:
+    """Integration tests for job router with ServiceBuilder."""
+
+    @pytest.mark.asyncio
+    async def test_service_builder_with_jobs(self) -> None:
+        """Test ServiceBuilder.with_jobs() creates functional job endpoints."""
+        info = ServiceInfo(display_name="Test Service")
+        app = ServiceBuilder(info=info).with_jobs(prefix="/jobs", tags=["background"]).build()
+
+        # Trigger lifespan to initialize scheduler
+        async with app.router.lifespan_context(app):
+            async with AsyncClient(
+                transport=ASGITransport(app=app), base_url="http://test", follow_redirects=True
+            ) as client:
+                # Check jobs endpoint exists
+                response = await client.get("/jobs")
+                assert response.status_code == 200
+                assert response.json() == []
+
+    @pytest.mark.asyncio
+    async def test_service_builder_with_max_concurrency(self) -> None:
+        """Test ServiceBuilder.with_jobs(max_concurrency=N) configures scheduler."""
+        info = ServiceInfo(display_name="Test Service")
+        app = ServiceBuilder(info=info).with_jobs(max_concurrency=2).build()
+
+        # Trigger lifespan to initialize scheduler
+        async with app.router.lifespan_context(app):
+            # Access scheduler via app.state
+            scheduler = app.state.scheduler
+            assert scheduler.max_concurrency == 2
+
+    @pytest.mark.asyncio
+    async def test_job_endpoints_in_openapi_schema(self) -> None:
+        """Test job endpoints appear in OpenAPI schema."""
+        info = ServiceInfo(display_name="Test Service")
+        app = ServiceBuilder(info=info).with_jobs().build()
+
+        async with AsyncClient(
+            transport=ASGITransport(app=app), base_url="http://test", follow_redirects=True
+        ) as client:
+            response = await client.get("/openapi.json")
+            assert response.status_code == 200
+            schema = response.json()
+
+            # Check job endpoints exist in schema
+            paths = schema["paths"]
+            assert "/api/v1/jobs" in paths
+            assert "/api/v1/jobs/{job_id}" in paths
+
+            # Check tags at operation level
+            jobs_list_tags = paths["/api/v1/jobs"]["get"]["tags"]
+            assert "Jobs" in jobs_list_tags
diff --git a/packages/servicekit/tests/test_manager.py b/packages/servicekit/tests/test_manager.py
new file mode 100644
index 0000000..ce9fe29
--- /dev/null
+++ b/packages/servicekit/tests/test_manager.py
@@ -0,0 +1,467 @@
+from servicekit import ConfigIn, ConfigManager, ConfigOut, ConfigRepository, SqliteDatabaseBuilder
+from ulid import ULID
+
+from .conftest import DemoConfig
+
+
+class TestBaseManager:
+    """Tests for the ConfigManager class."""
+
+    async def test_save_with_input_schema(self) -> None:
+        """Test saving an entity using input schema."""
+        db = SqliteDatabaseBuilder.in_memory().build()
+        await db.init()
+
+        async with db.session() as session:
+            repo = ConfigRepository(session)
+            manager = ConfigManager[DemoConfig](repo, DemoConfig)
+
+            # Create input schema
+            config_in = ConfigIn[DemoConfig](name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"]))
+
+            # Save and get output schema
+            result = await manager.save(config_in)
+
+            assert isinstance(result, ConfigOut)
+            assert result.id is not None
+            assert result.name == "test_config"
+            assert result.data is not None
+            assert isinstance(result.data, DemoConfig)
+            assert result.data.x == 1
+            assert result.data.y == 2
+            assert result.data.z == 3
+            assert result.data.tags == ["test"]
+
+        await db.dispose()
+
+    async def test_save_with_id_none_removes_id(self) -> None:
+        """Test that save() removes id field when it's None."""
+        db = SqliteDatabaseBuilder.in_memory().build()
+        await db.init()
+
+        async with db.session() as session:
+            repo = ConfigRepository(session)
+            manager = ConfigManager[DemoConfig](repo, DemoConfig)
+
+            # Create input schema with id=None (default)
+            config_in = ConfigIn[DemoConfig](id=None, name="test", data=DemoConfig(x=1, y=2, z=3, tags=[]))
+
+            result = await manager.save(config_in)
+
+            # Should have a generated ID
+            assert result.id is not None
+            assert isinstance(result.id, ULID)
+
+        await db.dispose()
+
+    async def test_save_preserves_explicit_id(self) -> None:
+        """Test that save() keeps a provided non-null ID intact."""
+        db = SqliteDatabaseBuilder.in_memory().build()
+        await db.init()
+
+        async with db.session() as session:
+            repo = ConfigRepository(session)
+            manager = ConfigManager[DemoConfig](repo, DemoConfig)
+
+            explicit_id = ULID()
+            config_in = ConfigIn[DemoConfig](
+                id=explicit_id,
+                name="explicit_id_config",
+                data=DemoConfig(x=5, y=5, z=5, tags=["explicit"]),
+            )
+
+            result = await manager.save(config_in)
+
+            assert result.id == explicit_id
+            assert result.name ==
"explicit_id_config" + + await db.dispose() + + async def test_save_all(self) -> None: + """Test saving multiple entities using input schemas.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create multiple input schemas + configs_in = [ + ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i * 2, z=i * 3, tags=[f"tag{i}"])) for i in range(3) + ] + + # Save all + results = await manager.save_all(configs_in) + + assert len(results) == 3 + assert all(isinstance(r, ConfigOut) for r in results) + assert all(r.id is not None for r in results) + assert results[0].name == "config0" + assert results[1].name == "config1" + assert results[2].name == "config2" + + await db.dispose() + + async def test_delete_by_id(self) -> None: + """Test deleting an entity by ID.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create and save entity + config_in = ConfigIn[DemoConfig](name="to_delete", data=DemoConfig(x=1, y=2, z=3, tags=[])) + result = await manager.save(config_in) + + # Verify it exists + assert await manager.count() == 1 + + # Delete it + assert result.id is not None + await manager.delete_by_id(result.id) + + # Verify it's gone + assert await manager.count() == 0 + + await db.dispose() + + async def test_delete_all(self) -> None: + """Test deleting all entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create multiple entities + configs_in = [ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(5)] + await manager.save_all(configs_in) + + # Verify they exist + assert await manager.count() == 5 + + # Delete all + await manager.delete_all() + + # Verify all gone + assert await manager.count() == 0 + + await db.dispose() + + async def test_delete_many_by_ids(self) -> None: + """Test deleting multiple entities by their IDs.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entities + configs_in = [ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(5)] + results = await manager.save_all(configs_in) + + # Delete some by ID + assert results[1].id is not None + assert results[3].id is not None + to_delete = [results[1].id, results[3].id] + await manager.delete_all_by_id(to_delete) + + # Should have 3 remaining + assert await manager.count() == 3 + + await db.dispose() + + async def test_delete_all_by_id_empty_list(self) -> None: + """Test that delete_all_by_id with empty list does nothing.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entities + configs_in = [ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(3)] + await manager.save_all(configs_in) + + # Delete with empty list + await manager.delete_all_by_id([]) + + # All should still exist + assert await manager.count() == 3 + + await db.dispose() + + async 
def test_count(self) -> None: + """Test counting entities through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Initially empty + assert await manager.count() == 0 + + # Add entities + configs_in = [ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(7)] + await manager.save_all(configs_in) + + # Count should be 7 + assert await manager.count() == 7 + + await db.dispose() + + async def test_output_schema_validation(self) -> None: + """Test that output schemas are properly validated from ORM models.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entity with complex data + config_in = ConfigIn[DemoConfig]( + name="validation_test", + data=DemoConfig(x=10, y=20, z=30, tags=["production", "critical", "v2.0"]), + ) + + result = await manager.save(config_in) + + # Verify output schema is correct + assert isinstance(result, ConfigOut) + assert isinstance(result.id, ULID) + assert result.name == "validation_test" + assert result.data is not None + assert isinstance(result.data, DemoConfig) + assert result.data.x == 10 + assert result.data.y == 20 + assert result.data.z == 30 + assert result.data.tags == ["production", "critical", "v2.0"] + + await db.dispose() + + async def test_save_all_returns_list_of_output_schemas(self) -> None: + """Test that save_all returns a list of output schemas.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + configs_in = [ + ConfigIn(name="config1", data=DemoConfig(x=1, y=1, z=1, tags=["a"])), + ConfigIn(name="config2", data=DemoConfig(x=2, y=2, z=2, tags=["b"])), + ] + + results = await manager.save_all(configs_in) + + assert isinstance(results, list) + assert len(results) == 2 + assert all(isinstance(r, ConfigOut) for r in results) + assert all(r.id is not None for r in results) + + await db.dispose() + + async def test_manager_commits_after_save(self) -> None: + """Test that manager commits changes after save.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + config_in = ConfigIn[DemoConfig](name="committed", data=DemoConfig(x=1, y=2, z=3, tags=[])) + await manager.save(config_in) + + # Check in a new session that it was committed + async with db.session() as session2: + repo2 = ConfigRepository(session2) + manager2 = ConfigManager[DemoConfig](repo2, DemoConfig) + assert await manager2.count() == 1 + + await db.dispose() + + async def test_manager_commits_after_delete(self) -> None: + """Test that manager commits changes after delete.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + # Save in one session + async with db.session() as session1: + repo1 = ConfigRepository(session1) + manager1 = ConfigManager[DemoConfig](repo1, DemoConfig) + config_in = ConfigIn[DemoConfig](name="to_delete", data=DemoConfig(x=1, y=2, z=3, tags=[])) + result = await manager1.save(config_in) + assert result.id is not None + saved_id = result.id + + # Delete in another session + async with 
db.session() as session2: + repo2 = ConfigRepository(session2) + manager2 = ConfigManager[DemoConfig](repo2, DemoConfig) + await manager2.delete_by_id(saved_id) + + # Verify in yet another session + async with db.session() as session3: + repo3 = ConfigRepository(session3) + manager3 = ConfigManager[DemoConfig](repo3, DemoConfig) + assert await manager3.count() == 0 + + await db.dispose() + + async def test_find_by_id(self) -> None: + """Test finding an entity by ID through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entity + config_in = ConfigIn[DemoConfig](name="findable", data=DemoConfig(x=10, y=20, z=30, tags=["test"])) + saved = await manager.save(config_in) + + # Find by ID + assert saved.id is not None + found = await manager.find_by_id(saved.id) + + assert found is not None + assert isinstance(found, ConfigOut) + assert found.id == saved.id + assert found.name == "findable" + assert found.data is not None + assert found.data.x == 10 + + # Non-existent ID should return None + random_id = ULID() + not_found = await manager.find_by_id(random_id) + assert not_found is None + + await db.dispose() + + async def test_find_all(self) -> None: + """Test finding all entities through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entities + configs_in = [ + ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i * 2, z=i * 3, tags=[f"tag{i}"])) for i in range(5) + ] + await manager.save_all(configs_in) + + # Find all + all_configs = await manager.find_all() + + assert len(all_configs) == 5 + assert all(isinstance(c, ConfigOut) for c in all_configs) + assert all(c.id is not None for c in all_configs) + + await db.dispose() + + async def test_find_all_by_id(self) -> None: + """Test finding multiple entities by IDs through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entities + configs_in = [ConfigIn(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(5)] + results = await manager.save_all(configs_in) + + # Find by specific IDs + assert results[0].id is not None + assert results[2].id is not None + assert results[4].id is not None + target_ids = [results[0].id, results[2].id, results[4].id] + found = await manager.find_all_by_id(target_ids) + + assert len(found) == 3 + assert all(isinstance(c, ConfigOut) for c in found) + assert all(c.id in target_ids for c in found) + + await db.dispose() + + async def test_exists_by_id(self) -> None: + """Test checking if entity exists by ID through manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entity + config_in = ConfigIn[DemoConfig](name="exists_test", data=DemoConfig(x=1, y=2, z=3, tags=[])) + saved = await manager.save(config_in) + + # Should exist + assert saved.id is not None + assert await manager.exists_by_id(saved.id) is True + + # Random ID should not exist + random_id = ULID() + assert await manager.exists_by_id(random_id) is False + 
+ await db.dispose() + + async def test_find_by_name_handles_missing(self) -> None: + """Test that find_by_name returns schema when found and None otherwise.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + assert await manager.find_by_name("unknown") is None + + config_in = ConfigIn[DemoConfig](name="known", data=DemoConfig(x=7, y=8, z=9, tags=["known"])) + saved = await manager.save(config_in) + + assert saved.id is not None + + found = await manager.find_by_name("known") + assert found is not None + assert found.id == saved.id + assert isinstance(found.data, DemoConfig) + + await db.dispose() + + async def test_output_schema_includes_timestamps(self) -> None: + """Test that output schemas include created_at and updated_at timestamps.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + # Create entity + config_in = ConfigIn[DemoConfig](name="timestamp_test", data=DemoConfig(x=1, y=2, z=3, tags=[])) + result = await manager.save(config_in) + + # Verify timestamps exist + assert result.created_at is not None + assert result.updated_at is not None + assert result.id is not None + + await db.dispose() diff --git a/packages/servicekit/tests/test_manager_artifact.py b/packages/servicekit/tests/test_manager_artifact.py new file mode 100644 index 0000000..75d530b --- /dev/null +++ b/packages/servicekit/tests/test_manager_artifact.py @@ -0,0 +1,92 @@ +"""Service-layer tests for artifact management.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from types import SimpleNamespace +from typing import cast + +from servicekit import Artifact, ArtifactHierarchy, ArtifactManager, ArtifactRepository +from ulid import ULID + + +def test_artifact_manager_should_assign_field_controls_level() -> None: + """_should_assign_field should skip assigning level when the value is None.""" + manager = ArtifactManager(cast(ArtifactRepository, SimpleNamespace())) + + assert manager._should_assign_field("data", {"foo": "bar"}) is True + assert manager._should_assign_field("level", 2) is True + assert manager._should_assign_field("level", None) is False + + +def test_artifact_manager_to_tree_node_applies_hierarchy_metadata() -> None: + """_to_tree_node should attach hierarchy metadata when configured.""" + hierarchy = ArtifactHierarchy(name="ml", level_labels={0: "root"}) + manager = ArtifactManager(cast(ArtifactRepository, SimpleNamespace()), hierarchy=hierarchy) + + now = datetime.now(timezone.utc) + entity = SimpleNamespace( + id=ULID(), + data={"name": "root"}, + parent_id=None, + level=0, + created_at=now, + updated_at=now, + ) + + node = manager._to_tree_node(cast(Artifact, entity)) + + assert node.level_label == "root" + assert node.hierarchy == "ml" + assert node.children is None + + +class StaticArtifactRepository: + """Minimal repository stub returning a predefined subtree.""" + + def __init__(self, nodes: list[SimpleNamespace]) -> None: + self._nodes = nodes + + async def find_subtree(self, start_id: ULID) -> list[Artifact]: + return [cast(Artifact, node) for node in self._nodes] + + +async def test_artifact_manager_build_tree_with_unordered_results() -> None: + """build_tree should assemble a nested tree even when repository order is arbitrary.""" + now = 
datetime.now(timezone.utc) + root_id = ULID() + child_id = ULID() + grandchild_id = ULID() + + def make_entity(identifier: ULID, parent: ULID | None, level: int) -> SimpleNamespace: + return SimpleNamespace( + id=identifier, + parent_id=parent, + level=level, + data={"name": str(identifier)}, + created_at=now, + updated_at=now, + ) + + nodes = [ + make_entity(child_id, root_id, 1), + make_entity(grandchild_id, child_id, 2), + make_entity(root_id, None, 0), + ] + + repo = StaticArtifactRepository(nodes) + manager = ArtifactManager(cast(ArtifactRepository, repo)) + + tree = await manager.build_tree(root_id) + + assert tree is not None + assert tree.id == root_id + assert tree.children is not None + assert len(tree.children) == 1 + child_node = tree.children[0] + assert child_node.id == child_id + assert child_node.children is not None + assert len(child_node.children) == 1 + grandchild_node = child_node.children[0] + assert grandchild_node.id == grandchild_id + assert grandchild_node.children == [] diff --git a/packages/servicekit/tests/test_manager_config.py b/packages/servicekit/tests/test_manager_config.py new file mode 100644 index 0000000..0a58464 --- /dev/null +++ b/packages/servicekit/tests/test_manager_config.py @@ -0,0 +1,167 @@ +"""Service-layer tests for config management.""" + +from __future__ import annotations + +from servicekit import ( + Artifact, + ArtifactRepository, + Config, + ConfigIn, + ConfigManager, + ConfigOut, + ConfigRepository, + SqliteDatabaseBuilder, +) +from ulid import ULID + +from .conftest import DemoConfig + + +async def test_config_manager_deserializes_dict_payloads() -> None: + """ConfigManager should convert raw dict payloads to the configured schema.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + assert await manager.find_by_name("missing") is None + + raw = Config(name="raw", data={"x": 1, "y": 2, "z": 3, "tags": ["dict"]}) + await repo.save(raw) + await repo.commit() + await repo.refresh_many([raw]) + + output = manager._to_output_schema(raw) + assert isinstance(output, ConfigOut) + assert isinstance(output.data, DemoConfig) + assert output.data.x == 1 + assert output.data.tags == ["dict"] + + found = await manager.find_by_name("raw") + assert found is not None + assert isinstance(found.data, DemoConfig) + assert found.data.model_dump() == output.data.model_dump() + + await db.dispose() + + +async def test_config_manager_save() -> None: + """Test saving a config through the manager.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + manager = ConfigManager[DemoConfig](repo, DemoConfig) + + config_in = ConfigIn[DemoConfig](name="test_save", data=DemoConfig(x=10, y=20, z=30, tags=["test"])) + saved = await manager.save(config_in) + + assert saved.name == "test_save" + assert saved.data.x == 10 + + await db.dispose() + + +async def test_config_manager_link_artifact() -> None: + """Test linking a config to a root artifact.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + manager = ConfigManager[DemoConfig](config_repo, DemoConfig) + + # Create config + config_in = ConfigIn[DemoConfig](name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=[])) + 
saved_config = await manager.save(config_in) + + # Create root artifact + root_artifact = Artifact(parent_id=None, data={"type": "root"}) + await artifact_repo.save(root_artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([root_artifact]) + + # Link them + await manager.link_artifact(saved_config.id, root_artifact.id) + + # Verify link + linked_artifacts = await manager.get_linked_artifacts(saved_config.id) + assert len(linked_artifacts) == 1 + assert linked_artifacts[0].id == root_artifact.id + + await db.dispose() + + +async def test_config_manager_unlink_artifact() -> None: + """Test unlinking an artifact from a config.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + manager = ConfigManager[DemoConfig](config_repo, DemoConfig) + + # Create config and artifact + config_in = ConfigIn[DemoConfig](name="test_unlink", data=DemoConfig(x=1, y=2, z=3, tags=[])) + saved_config = await manager.save(config_in) + + root_artifact = Artifact(parent_id=None, data={"type": "root"}) + await artifact_repo.save(root_artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([root_artifact]) + + # Link and then unlink + await manager.link_artifact(saved_config.id, root_artifact.id) + await manager.unlink_artifact(root_artifact.id) + + # Verify unlinked + linked_artifacts = await manager.get_linked_artifacts(saved_config.id) + assert len(linked_artifacts) == 0 + + await db.dispose() + + +async def test_config_manager_get_config_for_artifact() -> None: + """Test getting config for an artifact by walking up the tree.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + config_repo = ConfigRepository(session) + artifact_repo = ArtifactRepository(session) + manager = ConfigManager[DemoConfig](config_repo, DemoConfig) + + # Create config + config_in = ConfigIn[DemoConfig](name="tree_config", data=DemoConfig(x=100, y=200, z=300, tags=["tree"])) + saved_config = await manager.save(config_in) + + # Create artifact tree: root -> child + root_artifact = Artifact(parent_id=None, data={"level": "root"}) + await artifact_repo.save(root_artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([root_artifact]) + + child_artifact = Artifact(parent_id=root_artifact.id, data={"level": "child"}) + await artifact_repo.save(child_artifact) + await artifact_repo.commit() + await artifact_repo.refresh_many([child_artifact]) + + # Link config to root + await manager.link_artifact(saved_config.id, root_artifact.id) + + # Get config via child artifact (should walk up to root) + config_from_child = await manager.get_config_for_artifact(child_artifact.id, artifact_repo) + assert config_from_child is not None + assert config_from_child.id == saved_config.id + assert config_from_child.name == "tree_config" + + # Test with non-existent artifact + config_from_missing = await manager.get_config_for_artifact(ULID(), artifact_repo) + assert config_from_missing is None + + await db.dispose() diff --git a/packages/servicekit/tests/test_manager_task.py b/packages/servicekit/tests/test_manager_task.py new file mode 100644 index 0000000..48e0b9a --- /dev/null +++ b/packages/servicekit/tests/test_manager_task.py @@ -0,0 +1,154 @@ +"""Tests for TaskManager error handling and edge cases.""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from servicekit import 
ArtifactManager, Task, TaskManager, TaskRepository +from servicekit.core import Database, JobScheduler +from ulid import ULID + + +@pytest.mark.asyncio +async def test_execute_task_without_scheduler() -> None: + """Test execute_task raises error when scheduler not configured.""" + mock_repo = Mock(spec=TaskRepository) + manager = TaskManager(repo=mock_repo, scheduler=None, database=None, artifact_manager=None) + + with pytest.raises(ValueError, match="Task execution requires a scheduler"): + await manager.execute_task(ULID()) + + +@pytest.mark.asyncio +async def test_execute_task_without_artifact_manager() -> None: + """Test execute_task raises error when artifact manager not configured.""" + mock_repo = Mock(spec=TaskRepository) + mock_scheduler = Mock(spec=JobScheduler) + + manager = TaskManager( + repo=mock_repo, + scheduler=mock_scheduler, + database=None, + artifact_manager=None, + ) + + with pytest.raises(ValueError, match="Task execution requires artifacts"): + await manager.execute_task(ULID()) + + +@pytest.mark.asyncio +async def test_execute_task_not_found() -> None: + """Test execute_task raises error for non-existent task.""" + mock_repo = Mock(spec=TaskRepository) + mock_repo.find_by_id = AsyncMock(return_value=None) + + mock_scheduler = Mock(spec=JobScheduler) + mock_artifact_manager = Mock(spec=ArtifactManager) + + manager = TaskManager( + repo=mock_repo, + scheduler=mock_scheduler, + database=None, + artifact_manager=mock_artifact_manager, + ) + + task_id = ULID() + with pytest.raises(ValueError, match=f"Task {task_id} not found"): + await manager.execute_task(task_id) + + mock_repo.find_by_id.assert_called_once_with(task_id) + + +@pytest.mark.asyncio +async def test_execute_task_submits_to_scheduler() -> None: + """Test execute_task successfully submits job to scheduler.""" + task_id = ULID() + job_id = ULID() + + mock_task = Mock(spec=Task) + mock_task.id = task_id + mock_task.command = "echo test" + + mock_repo = Mock(spec=TaskRepository) + mock_repo.find_by_id = AsyncMock(return_value=mock_task) + + mock_scheduler = Mock(spec=JobScheduler) + mock_scheduler.add_job = AsyncMock(return_value=job_id) + + mock_artifact_manager = Mock(spec=ArtifactManager) + + manager = TaskManager( + repo=mock_repo, + scheduler=mock_scheduler, + database=None, + artifact_manager=mock_artifact_manager, + ) + + result = await manager.execute_task(task_id) + + assert result == job_id + mock_repo.find_by_id.assert_called_once_with(task_id) + mock_scheduler.add_job.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_command_without_database() -> None: + """Test _execute_command raises error when database not configured.""" + mock_repo = Mock(spec=TaskRepository) + manager = TaskManager(repo=mock_repo, scheduler=None, database=None, artifact_manager=None) + + with pytest.raises(RuntimeError, match="Database instance required"): + await manager._execute_command(ULID()) + + +@pytest.mark.asyncio +async def test_execute_command_without_artifact_manager() -> None: + """Test _execute_command raises error when artifact manager not configured.""" + mock_repo = Mock(spec=TaskRepository) + mock_database = Mock(spec=Database) + + manager = TaskManager( + repo=mock_repo, + scheduler=None, + database=mock_database, + artifact_manager=None, + ) + + with pytest.raises(RuntimeError, match="ArtifactManager instance required"): + await manager._execute_command(ULID()) + + +@pytest.mark.asyncio +async def test_execute_command_task_not_found() -> None: + """Test _execute_command raises error for 
non-existent task.""" + task_id = ULID() + + # Mock session context manager + mock_session = Mock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + # Mock task repo that returns None + mock_task_repo = Mock(spec=TaskRepository) + mock_task_repo.find_by_id = AsyncMock(return_value=None) + + # Mock database + mock_database = Mock(spec=Database) + mock_database.session = Mock(return_value=mock_session) + + # Patch TaskRepository to return our mock + with patch( + "chapkit.modules.task.manager.TaskRepository", + return_value=mock_task_repo, + ): + mock_repo = Mock(spec=TaskRepository) + mock_artifact_manager = Mock(spec=ArtifactManager) + + manager = TaskManager( + repo=mock_repo, + scheduler=None, + database=mock_database, + artifact_manager=mock_artifact_manager, + ) + + with pytest.raises(ValueError, match=f"Task {task_id} not found"): + await manager._execute_command(task_id) diff --git a/packages/servicekit/tests/test_middleware.py b/packages/servicekit/tests/test_middleware.py new file mode 100644 index 0000000..89f5d48 --- /dev/null +++ b/packages/servicekit/tests/test_middleware.py @@ -0,0 +1,227 @@ +"""Tests for FastAPI middleware and error handlers.""" + +from __future__ import annotations + +import os + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient +from pydantic import BaseModel, ValidationError +from servicekit.core.api.middleware import ( + add_error_handlers, + add_logging_middleware, + database_error_handler, + validation_error_handler, +) +from servicekit.core.logging import configure_logging +from sqlalchemy.exc import SQLAlchemyError + + +class SampleModel(BaseModel): + """Sample Pydantic model for validation tests.""" + + name: str + age: int + + +@pytest.fixture +def app_with_handlers() -> FastAPI: + """Create a FastAPI app with error handlers registered.""" + app = FastAPI() + add_error_handlers(app) + + @app.get("/db-error") + async def trigger_db_error() -> None: + raise SQLAlchemyError("Database connection failed") + + @app.get("/validation-error") + async def trigger_validation_error() -> None: + raise ValidationError.from_exception_data( + "SampleModel", + [ + { + "type": "missing", + "loc": ("name",), + "input": {}, + } + ], + ) + + return app + + +def test_database_error_handler_returns_500(app_with_handlers: FastAPI) -> None: + """Test that database errors return 500 status with proper error message.""" + client = TestClient(app_with_handlers) + + response = client.get("/db-error") + + assert response.status_code == 500 + payload = response.json() + assert payload["detail"] == "Database error occurred" + assert "error" in payload + assert "Database connection failed" in payload["error"] + + +def test_validation_error_handler_returns_422(app_with_handlers: FastAPI) -> None: + """Test that validation errors return 422 status with proper error message.""" + client = TestClient(app_with_handlers) + + response = client.get("/validation-error") + + assert response.status_code == 422 + payload = response.json() + assert payload["detail"] == "Validation error" + assert "errors" in payload + + +async def test_database_error_handler_direct() -> None: + """Test database_error_handler directly without FastAPI context.""" + + class MockURL: + """Mock URL object.""" + + path = "/test" + + class MockRequest: + """Mock request object.""" + + url = MockURL() + + exc = SQLAlchemyError("Test error") + response = await database_error_handler(MockRequest(), 
exc) # type: ignore + + assert response.status_code == 500 + assert response.body == b'{"detail":"Database error occurred","error":"Test error"}' + + +async def test_validation_error_handler_direct() -> None: + """Test validation_error_handler directly without FastAPI context.""" + + class MockURL: + """Mock URL object.""" + + path = "/test" + + class MockRequest: + """Mock request object.""" + + url = MockURL() + + exc = ValidationError.from_exception_data( + "TestModel", + [ + { + "type": "missing", + "loc": ("field",), + "input": {}, + } + ], + ) + response = await validation_error_handler(MockRequest(), exc) # type: ignore + + assert response.status_code == 422 + assert b"Validation error" in response.body + + +def test_logging_configuration_console() -> None: + """Test that logging can be configured for console output.""" + os.environ["LOG_FORMAT"] = "console" + os.environ["LOG_LEVEL"] = "INFO" + + # Should not raise + configure_logging() + + +def test_logging_configuration_json() -> None: + """Test that logging can be configured for JSON output.""" + os.environ["LOG_FORMAT"] = "json" + os.environ["LOG_LEVEL"] = "DEBUG" + + # Should not raise + configure_logging() + + +def test_request_logging_middleware_adds_request_id() -> None: + """Test that RequestLoggingMiddleware adds request_id to request state and response headers.""" + app = FastAPI() + add_logging_middleware(app) + + @app.get("/test") + async def test_endpoint(request: Request) -> dict: + # Request ID should be accessible in request state + request_id = getattr(request.state, "request_id", None) + return {"request_id": request_id} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + payload = response.json() + + # Request ID should be in response body + assert "request_id" in payload + assert payload["request_id"] is not None + assert len(payload["request_id"]) == 26 # ULID length + + # Request ID should be in response headers + assert "X-Request-ID" in response.headers + assert response.headers["X-Request-ID"] == payload["request_id"] + + +def test_request_logging_middleware_unique_request_ids() -> None: + """Test that each request gets a unique request_id.""" + app = FastAPI() + add_logging_middleware(app) + + @app.get("/test") + async def test_endpoint(request: Request) -> dict: + return {"request_id": request.state.request_id} + + client = TestClient(app) + response1 = client.get("/test") + response2 = client.get("/test") + + assert response1.status_code == 200 + assert response2.status_code == 200 + + request_id1 = response1.json()["request_id"] + request_id2 = response2.json()["request_id"] + + # Request IDs should be different + assert request_id1 != request_id2 + + +def test_request_logging_middleware_handles_exceptions() -> None: + """Test that RequestLoggingMiddleware properly handles and re-raises exceptions.""" + app = FastAPI() + add_logging_middleware(app) + + @app.get("/error") + async def error_endpoint() -> dict: + raise ValueError("Test exception from endpoint") + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/error") + + # Exception should result in 500 error + assert response.status_code == 500 + + +def test_request_logging_middleware_logs_on_exception() -> None: + """Test that middleware logs errors when exceptions occur during request processing.""" + app = FastAPI() + add_logging_middleware(app) + + @app.get("/test-error") + async def error_endpoint() -> None: + raise RuntimeError("Simulated error") + + client = 
TestClient(app, raise_server_exceptions=False) + + # This should trigger the exception handler in middleware + response = client.get("/test-error") + + # Should return 500 as the exception is unhandled + assert response.status_code == 500 diff --git a/packages/servicekit/tests/test_monitoring.py b/packages/servicekit/tests/test_monitoring.py new file mode 100644 index 0000000..e6542b8 --- /dev/null +++ b/packages/servicekit/tests/test_monitoring.py @@ -0,0 +1,42 @@ +"""Tests for OpenTelemetry monitoring setup.""" + +from fastapi import FastAPI +from servicekit.core.api.monitoring import setup_monitoring, teardown_monitoring + + +def test_setup_monitoring_with_traces_enabled() -> None: + """Test setup_monitoring with enable_traces=True logs warning.""" + app = FastAPI(title="Test Service") + + # This should log a warning about traces not being implemented + reader = setup_monitoring(app, enable_traces=True) + + assert reader is not None + + +def test_teardown_monitoring() -> None: + """Test teardown_monitoring uninstruments components.""" + app = FastAPI(title="Test Service") + + # Setup first + setup_monitoring(app, service_name="test-teardown") + + # Then teardown + teardown_monitoring() + + # No assertions needed - we're just covering the code path + # The function should handle uninstrumentation gracefully + + +def test_setup_monitoring_idempotent() -> None: + """Test that setup_monitoring can be called multiple times safely.""" + app1 = FastAPI(title="Service 1") + app2 = FastAPI(title="Service 2") + + # First call initializes everything + reader1 = setup_monitoring(app1, service_name="service-1") + assert reader1 is not None + + # Second call should handle already-initialized state + reader2 = setup_monitoring(app2, service_name="service-2") + assert reader2 is not None diff --git a/packages/servicekit/tests/test_pagination.py b/packages/servicekit/tests/test_pagination.py new file mode 100644 index 0000000..0d23d6f --- /dev/null +++ b/packages/servicekit/tests/test_pagination.py @@ -0,0 +1,71 @@ +"""Tests for pagination utilities.""" + +from typing import Any + +from servicekit.core import PaginatedResponse +from servicekit.core.api.pagination import PaginationParams, create_paginated_response + + +def test_pagination_params_is_paginated_both_set() -> None: + """Test is_paginated returns True when both page and size are set.""" + params = PaginationParams(page=1, size=20) + assert params.is_paginated() is True + + +def test_pagination_params_is_paginated_only_page() -> None: + """Test is_paginated returns False when only page is set.""" + params = PaginationParams(page=1, size=None) + assert params.is_paginated() is False + + +def test_pagination_params_is_paginated_only_size() -> None: + """Test is_paginated returns False when only size is set.""" + params = PaginationParams(page=None, size=20) + assert params.is_paginated() is False + + +def test_pagination_params_is_paginated_neither_set() -> None: + """Test is_paginated returns False when neither page nor size are set.""" + params = PaginationParams(page=None, size=None) + assert params.is_paginated() is False + + +def test_pagination_params_defaults() -> None: + """Test PaginationParams defaults to None for both fields.""" + params = PaginationParams() + assert params.page is None + assert params.size is None + assert params.is_paginated() is False + + +def test_create_paginated_response() -> None: + """Test create_paginated_response creates proper response object.""" + items = [{"id": "1", "name": "test"}] + response = 
create_paginated_response(items, total=100, page=1, size=10) + + assert response.items == items + assert response.total == 100 + assert response.page == 1 + assert response.size == 10 + assert response.pages == 10 # 100 / 10 + + +def test_create_paginated_response_partial_page() -> None: + """Test create_paginated_response calculates pages correctly for partial pages.""" + items = [{"id": "1"}, {"id": "2"}, {"id": "3"}] + response = create_paginated_response(items, total=25, page=1, size=10) + + assert response.total == 25 + assert response.size == 10 + assert response.pages == 3 # ceil(25 / 10) = 3 + + +def test_create_paginated_response_empty() -> None: + """Test create_paginated_response handles empty results.""" + response: PaginatedResponse[Any] = create_paginated_response([], total=0, page=1, size=10) + + assert response.items == [] + assert response.total == 0 + assert response.page == 1 + assert response.size == 10 + assert response.pages == 0 diff --git a/packages/servicekit/tests/test_repository.py b/packages/servicekit/tests/test_repository.py new file mode 100644 index 0000000..6ceed2f --- /dev/null +++ b/packages/servicekit/tests/test_repository.py @@ -0,0 +1,357 @@ +from servicekit import Config, ConfigRepository, SqliteDatabaseBuilder +from servicekit.core import BaseRepository +from ulid import ULID + +from .conftest import DemoConfig + + +class TestBaseRepository: + """Tests for the BaseRepository class.""" + + async def test_save_and_find_by_id(self) -> None: + """Test saving an entity and finding it by ID.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create and save entity + config = Config(name="test_config", data=DemoConfig(x=1, y=2, z=3, tags=["test"])) + saved = await repo.save(config) + await repo.commit() + + # Find by ID + found = await repo.find_by_id(saved.id) + assert found is not None + assert found.id == saved.id + assert found.name == "test_config" + assert isinstance(found.data, dict) + assert found.data["x"] == 1 + assert found.data["y"] == 2 + assert found.data["z"] == 3 + assert found.data["tags"] == ["test"] + + await db.dispose() + + async def test_save_all(self) -> None: + """Test saving multiple entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create multiple entities + configs = [ + Config(name="config1", data=DemoConfig(x=1, y=1, z=1, tags=[])), + Config(name="config2", data=DemoConfig(x=2, y=2, z=2, tags=[])), + Config(name="config3", data=DemoConfig(x=3, y=3, z=3, tags=[])), + ] + + await repo.save_all(configs) + await repo.commit() + + # Verify all saved + all_configs = await repo.find_all() + assert len(all_configs) == 3 + + await db.dispose() + + async def test_find_all(self) -> None: + """Test finding all entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Save some entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(5)] + await repo.save_all(configs) + await repo.commit() + + # Find all + all_configs = await repo.find_all() + assert len(all_configs) == 5 + assert all(isinstance(c, Config) for c in all_configs) + + await db.dispose() + + async def test_find_all_by_id(self) -> None: + """Test finding multiple 
entities by their IDs.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create and save entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(5)] + await repo.save_all(configs) + await repo.commit() + + # Refresh to get IDs + await repo.refresh_many(configs) + + # Find by specific IDs + target_ids = [configs[0].id, configs[2].id, configs[4].id] + found = await repo.find_all_by_id(target_ids) + + assert len(found) == 3 + assert all(c.id in target_ids for c in found) + + await db.dispose() + + async def test_find_all_by_id_empty_list(self) -> None: + """Test finding entities with an empty ID list.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + found = await repo.find_all_by_id([]) + assert found == [] + + await db.dispose() + + async def test_count(self) -> None: + """Test counting entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Initially empty + assert await repo.count() == 0 + + # Add some entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=0, y=0, z=0, tags=[])) for i in range(3)] + await repo.save_all(configs) + await repo.commit() + + # Count should be 3 + assert await repo.count() == 3 + + await db.dispose() + + async def test_exists_by_id(self) -> None: + """Test checking if an entity exists by ID.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create and save entity + config = Config(name="test", data=DemoConfig(x=0, y=0, z=0, tags=[])) + await repo.save(config) + await repo.commit() + await repo.refresh_many([config]) + + # Should exist + assert await repo.exists_by_id(config.id) is True + + # Random ULID should not exist + random_id = ULID() + assert await repo.exists_by_id(random_id) is False + + await db.dispose() + + async def test_delete(self) -> None: + """Test deleting a single entity.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create and save entity + config = Config(name="to_delete", data=DemoConfig(x=0, y=0, z=0, tags=[])) + await repo.save(config) + await repo.commit() + await repo.refresh_many([config]) + + # Delete it + await repo.delete(config) + await repo.commit() + + # Should no longer exist + assert await repo.exists_by_id(config.id) is False + assert await repo.count() == 0 + + await db.dispose() + + async def test_delete_by_id(self) -> None: + """Test deleting a single entity by ID.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create and save entity + config = Config(name="to_delete_by_id", data=DemoConfig(x=0, y=0, z=0, tags=[])) + await repo.save(config) + await repo.commit() + await repo.refresh_many([config]) + + # Delete by ID + await repo.delete_by_id(config.id) + await repo.commit() + + # Should no longer exist + assert await repo.exists_by_id(config.id) is False + assert await repo.count() == 0 + + await db.dispose() + + async def test_delete_all(self) 
-> None: + """Test deleting all entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create some entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=0, y=0, z=0, tags=[])) for i in range(5)] + await repo.save_all(configs) + await repo.commit() + + assert await repo.count() == 5 + + # Delete all + await repo.delete_all() + await repo.commit() + + assert await repo.count() == 0 + + await db.dispose() + + async def test_delete_all_by_id(self) -> None: + """Test deleting multiple entities by their IDs.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=0, y=0, z=0, tags=[])) for i in range(5)] + await repo.save_all(configs) + await repo.commit() + await repo.refresh_many(configs) + + # Delete specific ones + to_delete = [configs[1].id, configs[3].id] + await repo.delete_all_by_id(to_delete) + await repo.commit() + + # Should have 3 remaining + assert await repo.count() == 3 + + # Deleted ones should not exist + assert await repo.exists_by_id(configs[1].id) is False + assert await repo.exists_by_id(configs[3].id) is False + + # Others should still exist + assert await repo.exists_by_id(configs[0].id) is True + assert await repo.exists_by_id(configs[2].id) is True + assert await repo.exists_by_id(configs[4].id) is True + + await db.dispose() + + async def test_delete_all_by_id_empty_list(self) -> None: + """Test deleting with an empty ID list.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create some entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=0, y=0, z=0, tags=[])) for i in range(3)] + await repo.save_all(configs) + await repo.commit() + + # Delete with empty list should do nothing + await repo.delete_all_by_id([]) + await repo.commit() + + assert await repo.count() == 3 + + await db.dispose() + + async def test_refresh_many(self) -> None: + """Test refreshing multiple entities.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create entities + configs = [Config(name=f"config{i}", data=DemoConfig(x=i, y=i, z=i, tags=[])) for i in range(3)] + await repo.save_all(configs) + await repo.commit() + + # Refresh them + await repo.refresh_many(configs) + + # All should have IDs now + assert all(c.id is not None for c in configs) + assert all(c.created_at is not None for c in configs) + + await db.dispose() + + async def test_commit(self) -> None: + """Test committing changes.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = BaseRepository[Config, ULID](session, Config) + + # Create entity but don't commit + config = Config(name="test", data=DemoConfig(x=0, y=0, z=0, tags=[])) + await repo.save(config) + + # Count in another session should be 0 (not committed) + async with db.session() as session2: + repo2 = BaseRepository[Config, ULID](session2, Config) + assert await repo2.count() == 0 + + # Now commit + await repo.commit() + + # Count in another session should be 1 + async with db.session() as session3: + repo3 = BaseRepository[Config, 
ULID](session3, Config) + assert await repo3.count() == 1 + + await db.dispose() + + async def test_config_repository_find_by_name(self) -> None: + """Test ConfigRepository.find_by_name for existing and missing configs.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + + assert await repo.find_by_name("missing") is None + + config = Config(name="target", data=DemoConfig(x=1, y=1, z=1, tags=[])) + await repo.save(config) + await repo.commit() + await repo.refresh_many([config]) + + found = await repo.find_by_name("target") + assert found is not None + assert found.id == config.id + assert found.name == "target" + + await db.dispose() diff --git a/packages/servicekit/tests/test_repository_artifact.py b/packages/servicekit/tests/test_repository_artifact.py new file mode 100644 index 0000000..a884c83 --- /dev/null +++ b/packages/servicekit/tests/test_repository_artifact.py @@ -0,0 +1,65 @@ +"""Artifact repository tests.""" + +from __future__ import annotations + +from servicekit import Artifact, ArtifactRepository, SqliteDatabaseBuilder + + +async def test_artifact_repository_find_by_id_eager_loads_children() -> None: + """find_by_id should eagerly load the artifact children collection.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + root = Artifact(data={"name": "root"}, level=0) + await repo.save(root) + await repo.commit() + await repo.refresh_many([root]) + + child_a = Artifact(data={"name": "child_a"}, parent_id=root.id, level=1) + child_b = Artifact(data={"name": "child_b"}, parent_id=root.id, level=1) + await repo.save_all([child_a, child_b]) + await repo.commit() + await repo.refresh_many([child_a, child_b]) + + fetched = await repo.find_by_id(root.id) + assert fetched is not None + children = await fetched.awaitable_attrs.children + assert {child.data["name"] for child in children} == {"child_a", "child_b"} + + await db.dispose() + + +async def test_artifact_repository_find_subtree_returns_full_hierarchy() -> None: + """find_subtree should return the start node and all descendants.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ArtifactRepository(session) + + root = Artifact(data={"name": "root"}, level=0) + await repo.save(root) + await repo.commit() + await repo.refresh_many([root]) + + child = Artifact(data={"name": "child"}, parent_id=root.id, level=1) + await repo.save(child) + await repo.commit() + await repo.refresh_many([child]) + + grandchild = Artifact(data={"name": "grandchild"}, parent_id=child.id, level=2) + await repo.save(grandchild) + await repo.commit() + await repo.refresh_many([grandchild]) + + subtree = list(await repo.find_subtree(root.id)) + ids = {artifact.id for artifact in subtree} + assert ids == {root.id, child.id, grandchild.id} + lookup = {artifact.id: artifact for artifact in subtree} + assert lookup[child.id].parent_id == root.id + assert lookup[grandchild.id].parent_id == child.id + + await db.dispose() diff --git a/packages/servicekit/tests/test_repository_config.py b/packages/servicekit/tests/test_repository_config.py new file mode 100644 index 0000000..53f1b7f --- /dev/null +++ b/packages/servicekit/tests/test_repository_config.py @@ -0,0 +1,31 @@ +"""Config repository tests.""" + +from __future__ import annotations + +from servicekit import Config, ConfigRepository, 
SqliteDatabaseBuilder + +from .conftest import DemoConfig + + +async def test_config_repository_find_by_name_round_trip() -> None: + """ConfigRepository.find_by_name should return matching rows and None otherwise.""" + db = SqliteDatabaseBuilder.in_memory().build() + await db.init() + + async with db.session() as session: + repo = ConfigRepository(session) + + assert await repo.find_by_name("missing") is None + + created = Config(name="feature", data=DemoConfig(x=1, y=2, z=3, tags=["feature"])) + await repo.save(created) + await repo.commit() + await repo.refresh_many([created]) + + found = await repo.find_by_name("feature") + assert found is not None + assert found.id == created.id + assert found.name == "feature" + assert found.data == {"x": 1, "y": 2, "z": 3, "tags": ["feature"]} + + await db.dispose() diff --git a/packages/servicekit/tests/test_scheduler.py b/packages/servicekit/tests/test_scheduler.py new file mode 100644 index 0000000..c13e488 --- /dev/null +++ b/packages/servicekit/tests/test_scheduler.py @@ -0,0 +1,363 @@ +"""Tests for job scheduler functionality.""" + +import asyncio + +import pytest +import ulid +from servicekit.core import AIOJobScheduler, JobStatus + +ULID = ulid.ULID + + +class TestAIOJobScheduler: + """Test AIOJobScheduler functionality.""" + + @pytest.mark.asyncio + async def test_add_simple_async_job(self) -> None: + """Test adding and executing a simple async job.""" + scheduler = AIOJobScheduler() + + async def simple_task(): + await asyncio.sleep(0.01) + return "done" + + job_id = await scheduler.add_job(simple_task) + assert isinstance(job_id, ULID) + + # Wait for completion + await scheduler.wait(job_id) + + # Check status and result + status = await scheduler.get_status(job_id) + assert status == JobStatus.completed + + result = await scheduler.get_result(job_id) + assert result == "done" + + @pytest.mark.asyncio + async def test_add_sync_job(self) -> None: + """Test adding a synchronous callable (runs in thread pool).""" + scheduler = AIOJobScheduler() + + def sync_task(): + return 42 + + job_id = await scheduler.add_job(sync_task) + await scheduler.wait(job_id) + + result = await scheduler.get_result(job_id) + assert result == 42 + + @pytest.mark.asyncio + async def test_job_with_args_kwargs(self) -> None: + """Test job with positional and keyword arguments.""" + scheduler = AIOJobScheduler() + + async def task_with_args(a: int, b: int, c: int = 10) -> int: + return a + b + c + + job_id = await scheduler.add_job(task_with_args, 1, 2, c=3) + await scheduler.wait(job_id) + + result = await scheduler.get_result(job_id) + assert result == 6 + + @pytest.mark.asyncio + async def test_job_lifecycle_states(self) -> None: + """Test job progresses through states: pending -> running -> completed.""" + scheduler = AIOJobScheduler() + + async def slow_task(): + await asyncio.sleep(0.05) + return "result" + + job_id = await scheduler.add_job(slow_task) + + # Initially pending (may already be running due to async scheduling) + record = await scheduler.get_record(job_id) + assert record.status in (JobStatus.pending, JobStatus.running) + assert record.submitted_at is not None + + # Wait and check completed + await scheduler.wait(job_id) + record = await scheduler.get_record(job_id) + assert record.status == JobStatus.completed + assert record.started_at is not None + assert record.finished_at is not None + assert record.error is None + + @pytest.mark.asyncio + async def test_job_failure_with_traceback(self) -> None: + """Test job failure captures error traceback.""" 
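+        # Behavior exercised below (as asserted by this test): wait() may re-raise the
+        # job's exception, the record ends up in JobStatus.failed with the formatted
+        # traceback stored in record.error, and get_result() re-raises the failure as a
+        # RuntimeError whose message includes the original exception.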
+ scheduler = AIOJobScheduler() + + async def failing_task(): + raise ValueError("Something went wrong") + + job_id = await scheduler.add_job(failing_task) + + # Wait for task to complete (will fail) + try: + await scheduler.wait(job_id) + except ValueError: + pass # Expected + + # Check status and error + record = await scheduler.get_record(job_id) + assert record.status == JobStatus.failed + assert record.error is not None + assert "ValueError" in record.error + assert "Something went wrong" in record.error + + # Getting result should raise RuntimeError with traceback + with pytest.raises(RuntimeError, match="ValueError"): + await scheduler.get_result(job_id) + + @pytest.mark.asyncio + async def test_cancel_running_job(self) -> None: + """Test canceling a running job.""" + scheduler = AIOJobScheduler() + + async def long_task(): + await asyncio.sleep(10) # Long enough to cancel + return "never reached" + + job_id = await scheduler.add_job(long_task) + await asyncio.sleep(0.01) # Let it start + + # Cancel the job + was_canceled = await scheduler.cancel(job_id) + assert was_canceled is True + + # Check status + record = await scheduler.get_record(job_id) + assert record.status == JobStatus.canceled + + @pytest.mark.asyncio + async def test_cancel_completed_job_returns_false(self) -> None: + """Test canceling already completed job returns False.""" + scheduler = AIOJobScheduler() + + async def quick_task(): + return "done" + + job_id = await scheduler.add_job(quick_task) + await scheduler.wait(job_id) + + # Try to cancel completed job + was_canceled = await scheduler.cancel(job_id) + assert was_canceled is False + + @pytest.mark.asyncio + async def test_delete_job(self) -> None: + """Test deleting a job removes all records.""" + scheduler = AIOJobScheduler() + + async def task(): + return "result" + + job_id = await scheduler.add_job(task) + await scheduler.wait(job_id) + + # Delete job + await scheduler.delete(job_id) + + # Job should no longer exist + with pytest.raises(KeyError): + await scheduler.get_record(job_id) + + @pytest.mark.asyncio + async def test_delete_running_job_cancels_it(self) -> None: + """Test deleting running job cancels it first.""" + scheduler = AIOJobScheduler() + + async def long_task(): + await asyncio.sleep(10) + return "never" + + job_id = await scheduler.add_job(long_task) + await asyncio.sleep(0.01) # Let it start + + # Delete while running + await scheduler.delete(job_id) + + # Job should be gone + with pytest.raises(KeyError): + await scheduler.get_record(job_id) + + @pytest.mark.asyncio + async def test_get_all_records_sorted_newest_first(self) -> None: + """Test get_all_records returns jobs sorted by submission time.""" + scheduler = AIOJobScheduler() + + async def task(): + return "done" + + job_ids = [] + for _ in range(3): + jid = await scheduler.add_job(task) + job_ids.append(jid) + await asyncio.sleep(0.01) # Ensure different timestamps + + records = await scheduler.get_all_records() + assert len(records) == 3 + + # Should be newest first + assert records[0].id == job_ids[2] + assert records[1].id == job_ids[1] + assert records[2].id == job_ids[0] + + @pytest.mark.asyncio + async def test_artifact_id_detection(self) -> None: + """Test scheduler detects ULID results as artifact IDs.""" + scheduler = AIOJobScheduler() + + artifact_id = ULID() + + async def task_returning_ulid(): + return artifact_id + + job_id = await scheduler.add_job(task_returning_ulid) + await scheduler.wait(job_id) + + record = await scheduler.get_record(job_id) + assert 
record.artifact_id == artifact_id + + @pytest.mark.asyncio + async def test_non_ulid_result_has_no_artifact_id(self) -> None: + """Test non-ULID results don't set artifact_id.""" + scheduler = AIOJobScheduler() + + async def task(): + return {"some": "data"} + + job_id = await scheduler.add_job(task) + await scheduler.wait(job_id) + + record = await scheduler.get_record(job_id) + assert record.artifact_id is None + + @pytest.mark.asyncio + async def test_max_concurrency_limits_parallel_execution(self) -> None: + """Test max_concurrency limits concurrent job execution.""" + scheduler = AIOJobScheduler(max_concurrency=2) + + running_count = 0 + max_concurrent = 0 + + async def concurrent_task(): + nonlocal running_count, max_concurrent + running_count += 1 + max_concurrent = max(max_concurrent, running_count) + await asyncio.sleep(0.05) + running_count -= 1 + return "done" + + # Schedule 5 jobs + job_ids = [await scheduler.add_job(concurrent_task) for _ in range(5)] + + # Wait for all to complete + await asyncio.gather(*[scheduler.wait(jid) for jid in job_ids]) + + # At most 2 should have run concurrently + assert max_concurrent <= 2 + + @pytest.mark.asyncio + async def test_set_max_concurrency_runtime(self) -> None: + """Test changing max_concurrency at runtime.""" + scheduler = AIOJobScheduler(max_concurrency=1) + assert scheduler.max_concurrency == 1 + + await scheduler.set_max_concurrency(5) + assert scheduler.max_concurrency == 5 + + await scheduler.set_max_concurrency(None) + assert scheduler.max_concurrency is None + + @pytest.mark.asyncio + async def test_job_not_found_raises_key_error(self) -> None: + """Test accessing non-existent job raises KeyError.""" + scheduler = AIOJobScheduler() + fake_id = ULID() + + with pytest.raises(KeyError): + await scheduler.get_record(fake_id) + + with pytest.raises(KeyError): + await scheduler.get_status(fake_id) + + with pytest.raises(KeyError): + await scheduler.get_result(fake_id) + + with pytest.raises(KeyError): + await scheduler.cancel(fake_id) + + with pytest.raises(KeyError): + await scheduler.delete(fake_id) + + @pytest.mark.asyncio + async def test_get_result_before_completion_raises(self) -> None: + """Test get_result raises if job not finished.""" + scheduler = AIOJobScheduler() + + async def slow_task(): + await asyncio.sleep(1) + return "done" + + job_id = await scheduler.add_job(slow_task) + + # Try to get result immediately (job is pending/running) + with pytest.raises(RuntimeError, match="not finished"): + await scheduler.get_result(job_id) + + # Cleanup + await scheduler.cancel(job_id) + + @pytest.mark.asyncio + async def test_wait_timeout(self) -> None: + """Test wait with timeout raises asyncio.TimeoutError.""" + scheduler = AIOJobScheduler() + + async def long_task(): + await asyncio.sleep(10) + return "never" + + job_id = await scheduler.add_job(long_task) + + with pytest.raises(asyncio.TimeoutError): + await scheduler.wait(job_id, timeout=0.01) + + # Cleanup + await scheduler.cancel(job_id) + + @pytest.mark.asyncio + async def test_awaitable_target(self) -> None: + """Test passing an already-created awaitable as target.""" + scheduler = AIOJobScheduler() + + async def task(): + return "result" + + # Create coroutine object + coro = task() + + job_id = await scheduler.add_job(coro) + await scheduler.wait(job_id) + + result = await scheduler.get_result(job_id) + assert result == "result" + + @pytest.mark.asyncio + async def test_awaitable_target_rejects_args(self) -> None: + """Test awaitable target raises TypeError if 
args/kwargs provided.""" + scheduler = AIOJobScheduler() + + async def task(): + return "result" + + coro = task() + + job_id = await scheduler.add_job(coro, "extra_arg") + # The error happens during execution, not during add_job + with pytest.raises(TypeError, match="Args/kwargs not supported"): + await scheduler.wait(job_id) diff --git a/packages/servicekit/tests/test_service_builder.py b/packages/servicekit/tests/test_service_builder.py new file mode 100644 index 0000000..4e8a1c8 --- /dev/null +++ b/packages/servicekit/tests/test_service_builder.py @@ -0,0 +1,103 @@ +"""Tests for ServiceBuilder validation.""" + +from typing import Any + +import pandas as pd +import pytest +from geojson_pydantic import FeatureCollection +from pydantic import Field +from servicekit import BaseConfig +from servicekit.api import ServiceBuilder +from servicekit.core.api.service_builder import ServiceInfo +from servicekit.modules.artifact import ArtifactHierarchy +from servicekit.modules.ml import ModelRunnerProtocol + + +class DummyConfig(BaseConfig): + """Dummy config for testing.""" + + test_value: str = Field(default="test") + + +class DummyRunner(ModelRunnerProtocol): + """Dummy ML runner for testing.""" + + async def on_train( + self, + config: BaseConfig, + data: pd.DataFrame, + geo: FeatureCollection | None = None, + ) -> Any: + """Train a model.""" + return {"status": "trained"} + + async def on_predict( + self, + config: BaseConfig, + model: Any, + historic: pd.DataFrame | None, + future: pd.DataFrame, + geo: FeatureCollection | None = None, + ) -> pd.DataFrame: + """Make predictions.""" + return pd.DataFrame({"predictions": []}) + + +def test_tasks_without_artifacts_raises_error() -> None: + """Test that tasks without artifacts raises ValueError.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + + with pytest.raises(ValueError, match="Task execution requires artifacts"): + builder.with_tasks().build() + + +def test_ml_without_config_raises_error() -> None: + """Test that ML without config raises ValueError.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + hierarchy = ArtifactHierarchy(name="test") + runner = DummyRunner() + + with pytest.raises(ValueError, match="ML operations require config"): + builder.with_artifacts(hierarchy=hierarchy).with_ml(runner=runner).build() + + +def test_ml_without_artifacts_raises_error() -> None: + """Test that ML without artifacts raises ValueError.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + runner = DummyRunner() + + with pytest.raises(ValueError, match="ML operations require artifacts"): + builder.with_config(DummyConfig).with_ml(runner=runner).build() + + +def test_ml_without_jobs_raises_error() -> None: + """Test that ML without job scheduler raises ValueError.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + hierarchy = ArtifactHierarchy(name="test") + runner = DummyRunner() + + with pytest.raises(ValueError, match="ML operations require job scheduler"): + builder.with_config(DummyConfig).with_artifacts(hierarchy=hierarchy).with_ml(runner=runner).build() + + +def test_artifacts_config_linking_without_config_raises_error() -> None: + """Test that artifact config-linking without config raises ValueError.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + hierarchy = ArtifactHierarchy(name="test") + + with pytest.raises(ValueError, match="Artifact config-linking requires a config schema"): + builder.with_artifacts(hierarchy=hierarchy, 
enable_config_linking=True).build() + + +def test_valid_ml_service_builds_successfully() -> None: + """Test that a properly configured ML service builds without errors.""" + builder = ServiceBuilder(info=ServiceInfo(display_name="Test")) + hierarchy = ArtifactHierarchy(name="test") + runner = DummyRunner() + + # This should build successfully with all dependencies + app = ( + builder.with_config(DummyConfig).with_artifacts(hierarchy=hierarchy).with_jobs().with_ml(runner=runner).build() + ) + + assert app is not None diff --git a/packages/servicekit/tests/test_service_builder_apps.py b/packages/servicekit/tests/test_service_builder_apps.py new file mode 100644 index 0000000..66adbb9 --- /dev/null +++ b/packages/servicekit/tests/test_service_builder_apps.py @@ -0,0 +1,462 @@ +"""Integration tests for ServiceBuilder app system.""" + +import json +from pathlib import Path + +import pytest +from servicekit.core.api import BaseServiceBuilder, ServiceInfo +from starlette.testclient import TestClient + + +@pytest.fixture +def app_directory(tmp_path: Path) -> Path: + """Create test app directory structure.""" + apps_dir = tmp_path / "apps" + apps_dir.mkdir() + + # Create dashboard app + dashboard_dir = apps_dir / "dashboard" + dashboard_dir.mkdir() + (dashboard_dir / "manifest.json").write_text( + json.dumps({"name": "Dashboard", "version": "1.0.0", "prefix": "/dashboard"}) + ) + (dashboard_dir / "index.html").write_text("Dashboard App") + (dashboard_dir / "style.css").write_text("body { color: blue; }") + + # Create admin app + admin_dir = apps_dir / "admin" + admin_dir.mkdir() + (admin_dir / "manifest.json").write_text(json.dumps({"name": "Admin", "version": "2.0.0", "prefix": "/admin"})) + (admin_dir / "index.html").write_text("Admin App") + + return apps_dir + + +def test_service_builder_with_single_app(app_directory: Path): + """Test mounting a single app.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard")) + .build() + ) + + with TestClient(app) as client: + # App should be accessible + response = client.get("/dashboard/") + assert response.status_code == 200 + assert b"Dashboard App" in response.content + + # CSS should be accessible + response = client.get("/dashboard/style.css") + assert response.status_code == 200 + assert b"color: blue" in response.content + + +def test_service_builder_with_prefix_override(app_directory: Path): + """Test mounting app with custom prefix.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard"), prefix="/custom") + .build() + ) + + with TestClient(app) as client: + # App should be at custom prefix + response = client.get("/custom/") + assert response.status_code == 200 + assert b"Dashboard App" in response.content + + # Original prefix should not work + response = client.get("/dashboard/") + assert response.status_code == 404 + + +def test_service_builder_with_multiple_apps(app_directory: Path): + """Test mounting multiple apps.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard")) + .with_app(str(app_directory / "admin")) + .build() + ) + + with TestClient(app) as client: + # Both apps should be accessible + dashboard_response = client.get("/dashboard/") + assert dashboard_response.status_code == 200 + assert b"Dashboard App" in dashboard_response.content + + admin_response = client.get("/admin/") + assert admin_response.status_code == 200 + 
assert b"Admin App" in admin_response.content + + +def test_service_builder_with_apps_autodiscovery(app_directory: Path): + """Test auto-discovering apps from directory.""" + app = BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")).with_apps(str(app_directory)).build() + + with TestClient(app) as client: + # Both discovered apps should be accessible + dashboard_response = client.get("/dashboard/") + assert dashboard_response.status_code == 200 + assert b"Dashboard App" in dashboard_response.content + + admin_response = client.get("/admin/") + assert admin_response.status_code == 200 + assert b"Admin App" in admin_response.content + + +def test_service_builder_apps_with_api_routes(app_directory: Path): + """Test that apps don't interfere with API routes.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_health() + .with_system() + .with_app(str(app_directory / "dashboard")) + .build() + ) + + with TestClient(app) as client: + # API routes should work + health_response = client.get("/health") + assert health_response.status_code == 200 + + system_response = client.get("/api/v1/system") + assert system_response.status_code == 200 + + # App should also work + app_response = client.get("/dashboard/") + assert app_response.status_code == 200 + assert b"Dashboard App" in app_response.content + + +def test_service_builder_apps_override_semantics(app_directory: Path): + """Test that duplicate prefixes use last-wins semantics.""" + # Mount dashboard twice - second should override first + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard")) + .with_app(str(app_directory / "admin"), prefix="/dashboard") # Override with admin + .build() + ) + + with TestClient(app) as client: + # Should serve admin app content (not dashboard) + response = client.get("/dashboard/") + assert response.status_code == 200 + assert b"Admin App" in response.content + + +def test_service_builder_apps_api_prefix_blocked(app_directory: Path): + """Test that apps cannot mount at /api.**.""" + bad_app_dir = app_directory / "bad" + bad_app_dir.mkdir() + (bad_app_dir / "manifest.json").write_text(json.dumps({"name": "Bad", "version": "1.0.0", "prefix": "/api/bad"})) + (bad_app_dir / "index.html").write_text("Bad") + + with pytest.raises(ValueError, match="prefix cannot be '/api'"): + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")).with_app(str(bad_app_dir)).build() + + +def test_service_builder_apps_root_mount_works(app_directory: Path): + """Test that root-mounted apps work correctly (without landing page).""" + # Create root app with subdirectory structure + root_app_dir = app_directory / "root" + root_app_dir.mkdir() + (root_app_dir / "manifest.json").write_text(json.dumps({"name": "Root", "version": "1.0.0", "prefix": "/"})) + (root_app_dir / "index.html").write_text("Root App Index") + + # Create about subdirectory + about_dir = root_app_dir / "about" + about_dir.mkdir() + (about_dir / "index.html").write_text("About Page") + + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_health() + .with_system() + .with_app(str(root_app_dir)) + .build() + ) + + with TestClient(app) as client: + # API routes should work (routes take precedence over mounts) + health_response = client.get("/health") + assert health_response.status_code == 200 + + system_response = client.get("/api/v1/system") + assert system_response.status_code == 200 + + info_response = 
client.get("/api/v1/info") + assert info_response.status_code == 200 + + # Root path serves app + root_response = client.get("/") + assert root_response.status_code == 200 + assert b"Root App Index" in root_response.content + + # App subdirectories work + about_response = client.get("/about") + assert about_response.status_code == 200 + assert b"About Page" in about_response.content + + +def test_service_builder_apps_root_override_landing_page(app_directory: Path): + """Test that root apps can override landing page (last wins).""" + root_app_dir = app_directory / "root" + root_app_dir.mkdir() + (root_app_dir / "manifest.json").write_text(json.dumps({"name": "Custom Root", "version": "1.0.0", "prefix": "/"})) + (root_app_dir / "index.html").write_text("Custom Root App") + + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_landing_page() # Mount built-in landing page first + .with_app(str(root_app_dir)) # Override with custom root app + .build() + ) + + with TestClient(app) as client: + # Should serve custom root app (not landing page) + response = client.get("/") + assert response.status_code == 200 + assert b"Custom Root App" in response.content + assert b"chapkit" not in response.content.lower() # Not landing page + + +def test_service_builder_html_mode_serves_index_for_subdirs(app_directory: Path): + """Test that html=True mode serves index.html for directory paths.""" + dashboard_dir = app_directory / "dashboard" + subdir = dashboard_dir / "subdir" + subdir.mkdir() + (subdir / "index.html").write_text("Subdir Index") + + app = BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")).with_app(str(dashboard_dir)).build() + + with TestClient(app) as client: + # Requesting /dashboard/subdir/ should serve subdir/index.html + response = client.get("/dashboard/subdir/") + assert response.status_code == 200 + assert b"Subdir Index" in response.content + + +def test_service_builder_apps_404_for_missing_files(app_directory: Path): + """Test that missing files return 404.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard")) + .build() + ) + + with TestClient(app) as client: + response = client.get("/dashboard/nonexistent.html") + assert response.status_code == 404 + + +def test_service_builder_apps_mount_order(app_directory: Path): + """Test that apps are mounted after routers.""" + # Create an app that would conflict if mounted before routers + api_like_app = app_directory / "api-like" + api_like_app.mkdir() + (api_like_app / "manifest.json").write_text( + json.dumps({"name": "API Like", "version": "1.0.0", "prefix": "/status"}) + ) + (api_like_app / "index.html").write_text("App Status") + + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_health(prefix="/status") # Mount health at /status + .with_app(str(api_like_app)) # Try to mount app at /status (should fail) + .build() + ) + + with TestClient(app) as client: + # Health endpoint should take precedence + response = client.get("/status") + assert response.status_code == 200 + # Should be JSON health response, not HTML + assert response.headers["content-type"].startswith("application/json") + + +def test_service_builder_apps_with_custom_routers(app_directory: Path): + """Test apps work alongside custom routers.""" + from fastapi import APIRouter + + custom_router = APIRouter(prefix="/custom", tags=["Custom"]) + + @custom_router.get("/endpoint") + async def custom_endpoint() -> dict[str, str]: + 
return {"message": "custom"} + + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_app(str(app_directory / "dashboard")) + .include_router(custom_router) + .build() + ) + + with TestClient(app) as client: + # Custom router should work + custom_response = client.get("/custom/endpoint") + assert custom_response.status_code == 200 + assert custom_response.json() == {"message": "custom"} + + # App should work + app_response = client.get("/dashboard/") + assert app_response.status_code == 200 + assert b"Dashboard App" in app_response.content + + +def test_system_apps_endpoint_empty(): + """Test /api/v1/system/apps with no apps.""" + app = BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")).with_system().build() + + with TestClient(app) as client: + response = client.get("/api/v1/system/apps") + assert response.status_code == 200 + assert response.json() == [] + + +def test_system_apps_endpoint_with_apps(app_directory: Path): + """Test /api/v1/system/apps lists installed apps.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_system() + .with_app(str(app_directory / "dashboard")) + .with_app(str(app_directory / "admin")) + .build() + ) + + with TestClient(app) as client: + response = client.get("/api/v1/system/apps") + assert response.status_code == 200 + apps = response.json() + assert len(apps) == 2 + + # Check dashboard app + dashboard = next(a for a in apps if a["prefix"] == "/dashboard") + assert dashboard["name"] == "Dashboard" + assert dashboard["version"] == "1.0.0" + assert dashboard["prefix"] == "/dashboard" + assert dashboard["entry"] == "index.html" + assert dashboard["is_package"] is False + + # Check admin app + admin = next(a for a in apps if a["prefix"] == "/admin") + assert admin["name"] == "Admin" + assert admin["version"] == "2.0.0" + + +def test_system_apps_endpoint_returns_correct_fields(app_directory: Path): + """Test AppInfo fields match manifest.""" + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_system() + .with_app(str(app_directory / "dashboard")) + .build() + ) + + with TestClient(app) as client: + response = client.get("/api/v1/system/apps") + assert response.status_code == 200 + apps = response.json() + assert len(apps) == 1 + + app_info = apps[0] + # Check all required fields present + assert "name" in app_info + assert "version" in app_info + assert "prefix" in app_info + assert "description" in app_info + assert "author" in app_info + assert "entry" in app_info + assert "is_package" in app_info + + # Check values + assert app_info["name"] == "Dashboard" + assert app_info["version"] == "1.0.0" + assert app_info["prefix"] == "/dashboard" + assert app_info["entry"] == "index.html" + + +def test_system_apps_schema_endpoint(): + """Test /api/v1/system/apps/$schema returns JSON schema.""" + app = BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")).with_system().build() + + with TestClient(app) as client: + response = client.get("/api/v1/system/apps/$schema") + assert response.status_code == 200 + schema = response.json() + + # Verify schema structure + assert schema["type"] == "array" + assert "items" in schema + assert "$ref" in schema["items"] + assert schema["items"]["$ref"] == "#/$defs/AppInfo" + + # Verify AppInfo definition exists + assert "$defs" in schema + assert "AppInfo" in schema["$defs"] + + # Verify AppInfo schema has required fields + app_info_schema = schema["$defs"]["AppInfo"] + assert app_info_schema["type"] == 
"object" + assert "properties" in app_info_schema + assert "name" in app_info_schema["properties"] + assert "version" in app_info_schema["properties"] + assert "prefix" in app_info_schema["properties"] + assert "entry" in app_info_schema["properties"] + assert "is_package" in app_info_schema["properties"] + + # Verify required fields + assert "required" in app_info_schema + assert "name" in app_info_schema["required"] + assert "version" in app_info_schema["required"] + assert "prefix" in app_info_schema["required"] + + +def test_service_builder_with_apps_package_discovery(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test auto-discovering apps from package resources.""" + # Create a temporary package structure + pkg_dir = tmp_path / "test_app_package" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") # Make it a package + + # Create apps subdirectory + apps_dir = pkg_dir / "bundled_apps" + apps_dir.mkdir() + + # Create dashboard app + dashboard_dir = apps_dir / "dashboard" + dashboard_dir.mkdir() + (dashboard_dir / "manifest.json").write_text( + json.dumps({"name": "Bundled Dashboard", "version": "1.0.0", "prefix": "/dashboard"}) + ) + (dashboard_dir / "index.html").write_text("Bundled Dashboard") + + # Create admin app + admin_dir = apps_dir / "admin" + admin_dir.mkdir() + (admin_dir / "manifest.json").write_text( + json.dumps({"name": "Bundled Admin", "version": "2.0.0", "prefix": "/admin"}) + ) + (admin_dir / "index.html").write_text("Bundled Admin") + + # Add package to sys.path temporarily + monkeypatch.syspath_prepend(str(tmp_path)) + + # Build service with package app discovery + app = ( + BaseServiceBuilder(info=ServiceInfo(display_name="Test Service")) + .with_apps(("test_app_package", "bundled_apps")) + .build() + ) + + with TestClient(app) as client: + # Both discovered package apps should be accessible + dashboard_response = client.get("/dashboard/") + assert dashboard_response.status_code == 200 + assert b"Bundled Dashboard" in dashboard_response.content + + admin_response = client.get("/admin/") + assert admin_response.status_code == 200 + assert b"Bundled Admin" in admin_response.content diff --git a/packages/servicekit/tests/test_task_router.py b/packages/servicekit/tests/test_task_router.py new file mode 100644 index 0000000..109f23d --- /dev/null +++ b/packages/servicekit/tests/test_task_router.py @@ -0,0 +1,97 @@ +"""Tests for TaskRouter error handling.""" + +from unittest.mock import AsyncMock, Mock + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from servicekit import TaskIn, TaskManager, TaskOut +from servicekit.modules.task import TaskRouter +from ulid import ULID + + +def test_execute_task_value_error_returns_400() -> None: + """Test that ValueError from execute_task returns 400 Bad Request.""" + # Create mock manager that raises ValueError + mock_manager = Mock(spec=TaskManager) + mock_manager.execute_task = AsyncMock(side_effect=ValueError("Task not found")) + + def manager_factory() -> TaskManager: + return mock_manager + + # Create app with router + app = FastAPI() + router = TaskRouter.create( + prefix="/api/v1/tasks", + tags=["Tasks"], + entity_in_type=TaskIn, + entity_out_type=TaskOut, + manager_factory=manager_factory, + ) + app.include_router(router) + + client = TestClient(app) + task_id = ULID() + + response = client.post(f"/api/v1/tasks/{task_id}/$execute") + + assert response.status_code == 400 + assert "Task not found" in response.json()["detail"] + + +def test_execute_task_runtime_error_returns_409() -> 
None: + """Test that RuntimeError from execute_task returns 409 Conflict.""" + # Create mock manager that raises RuntimeError + mock_manager = Mock(spec=TaskManager) + mock_manager.execute_task = AsyncMock(side_effect=RuntimeError("Database instance required for task execution")) + + def manager_factory() -> TaskManager: + return mock_manager + + # Create app with router + app = FastAPI() + router = TaskRouter.create( + prefix="/api/v1/tasks", + tags=["Tasks"], + entity_in_type=TaskIn, + entity_out_type=TaskOut, + manager_factory=manager_factory, + ) + app.include_router(router) + + client = TestClient(app) + task_id = ULID() + + response = client.post(f"/api/v1/tasks/{task_id}/$execute") + + assert response.status_code == 409 + assert "Database instance required" in response.json()["detail"] + + +def test_execute_task_with_valid_ulid() -> None: + """Test execute_task endpoint with valid ULID.""" + mock_manager = Mock(spec=TaskManager) + job_id = ULID() + mock_manager.execute_task = AsyncMock(return_value=job_id) + + def manager_factory() -> TaskManager: + return mock_manager + + app = FastAPI() + router = TaskRouter.create( + prefix="/api/v1/tasks", + tags=["Tasks"], + entity_in_type=TaskIn, + entity_out_type=TaskOut, + manager_factory=manager_factory, + ) + app.include_router(router) + + client = TestClient(app) + task_id = ULID() + + response = client.post(f"/api/v1/tasks/{task_id}/$execute") + + assert response.status_code == 202 + data = response.json() + assert data["job_id"] == str(job_id) + assert "submitted for execution" in data["message"] diff --git a/packages/servicekit/tests/test_types.py b/packages/servicekit/tests/test_types.py new file mode 100644 index 0000000..1ada59f --- /dev/null +++ b/packages/servicekit/tests/test_types.py @@ -0,0 +1,24 @@ +from servicekit import ULIDType +from ulid import ULID + + +def test_ulid_type_process_bind_param_with_ulid() -> None: + """process_bind_param should convert ULID inputs to strings.""" + value = ULID() + ulid_type = ULIDType() + assert ulid_type.process_bind_param(value, None) == str(value) + + +def test_ulid_type_process_bind_param_with_string() -> None: + """process_bind_param should validate and normalize string inputs.""" + value = str(ULID()) + ulid_type = ULIDType() + assert ulid_type.process_bind_param(value, None) == value + + +def test_ulid_type_process_result_value() -> None: + """process_result_value should convert strings back to ULID.""" + value = str(ULID()) + ulid_type = ULIDType() + assert ulid_type.process_result_value(value, None) == ULID.from_str(value) + assert ulid_type.process_result_value(None, None) is None diff --git a/pyproject.toml b/pyproject.toml index 3f62862..434d3c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,77 +1,23 @@ +[tool.uv.workspace] +members = ["packages/servicekit", "packages/chapkit"] + [project] -name = "chapkit" +name = "chapkit-workspace" version = "0.1.0" -description = "Async SQLAlchemy database library for Python 3.13+ with FastAPI integration and ML workflow support" -readme = "README.md" -authors = [{ name = "Morten Hansen", email = "morten@winterop.com" }] -license = { text = "AGPL-3.0-or-later" } requires-python = ">=3.13" -keywords = [ - "fastapi", - "sqlalchemy", - "async", - "database", - "ml", - "machine-learning", - "rest-api", - "crud", - "vertical-slice", -] -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", - "Programming Language :: 
Python :: 3", - "Programming Language :: Python :: 3.13", - "Framework :: FastAPI", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Database", -] -dependencies = [ - "aiosqlite>=0.21.0", - "alembic>=1.17.0", - "fastapi[standard]>=0.119.0", - "geojson-pydantic>=2.1.0", - "gunicorn>=23.0.0", - "opentelemetry-api>=1.37.0", - "opentelemetry-exporter-prometheus>=0.58b0", - "opentelemetry-instrumentation-fastapi>=0.58b0", - "opentelemetry-instrumentation-sqlalchemy>=0.58b0", - "opentelemetry-sdk>=1.37.0", - "pandas>=2.3.3", - "pydantic>=2.12.0", - "python-ulid>=3.1.0", - "scikit-learn>=1.7.2", - "sqlalchemy[asyncio]>=2.0.43", - "structlog>=24.4.0", -] - -[project.urls] -Homepage = "https://github.com/winterop-com/chapkit" -Repository = "https://github.com/winterop-com/chapkit" -Issues = "https://github.com/winterop-com/chapkit/issues" -Documentation = "https://github.com/winterop-com/chapkit#readme" [dependency-groups] dev = [ - "coverage[toml]>=7.6.0", - "mypy>=1.18.2", - "pandas-stubs>=2.2.3.250101", - "pytest>=8.4.2", - "pytest-cov>=5.0.0", + "pytest>=8.4.0", "pytest-asyncio>=1.2.0", + "pytest-cov>=7.0.0", + "coverage>=7.10.0", "ruff>=0.14.0", + "mypy>=1.18.0", "pyright>=1.1.406", - "scikit-learn>=1.7.2", - "mkdocs-material>=9.6.21", - "mkdocstrings>=0.30.1", - "mkdocstrings-python>=1.18.2", + "pandas-stubs>=2.3.2.250926", ] -[build-system] -requires = ["uv_build>=0.9.0,<0.10.0"] -build-backend = "uv_build" - [tool.ruff] target-version = "py313" line-length = 120 @@ -96,16 +42,20 @@ convention = "google" [tool.ruff.lint.per-file-ignores] # Tests don't need docstrings "tests/**/*.py" = ["D"] +"packages/*/tests/**/*.py" = ["D"] # Alembic migrations are autogenerated "alembic/**/*.py" = ["D"] +"packages/*/alembic/**/*.py" = ["D"] # Allow missing docstrings in __init__ files if they're just exports "**/__init__.py" = ["D104"] # Internal source - allow some missing docstrings for now -"src/**/*.py" = [ +"packages/*/src/**/*.py" = [ "D102", # Missing docstring in public method (TODO: fix gradually) "D105", # Missing docstring in magic method "D107", # Missing docstring in __init__ (TODO: fix gradually) ] +# Old src directory - ignore for now (will be removed) +"src/**/*.py" = ["D"] # Examples must have all docstrings (strictest enforcement) "examples/**/*.py" = [] @@ -118,7 +68,8 @@ docstring-code-line-length = "dynamic" [tool.pytest.ini_options] asyncio_mode = "auto" -testpaths = ["tests"] +pythonpath = ["packages/servicekit/src", "packages/chapkit/src", "examples"] +testpaths = ["packages/servicekit/tests", "packages/chapkit/tests", "tests"] norecursedirs = ["examples", ".git", ".venv", "__pycache__"] filterwarnings = [ "ignore:Pydantic serializer warnings:UserWarning", @@ -129,7 +80,7 @@ filterwarnings = [ branch = true dynamic_context = "test_function" relative_files = true -source = ["chapkit"] +source = ["servicekit", "chapkit"] [tool.coverage.report] exclude_also = [ @@ -149,14 +100,14 @@ disallow_any_unimported = true no_implicit_optional = true warn_unused_ignores = true strict_equality = true -mypy_path = ["src", "typings"] +mypy_path = ["packages/servicekit/src", "packages/chapkit/src", "typings"] [[tool.mypy.overrides]] module = "tests.*" disallow_untyped_defs = false [tool.pyright] -include = ["src", "tests", "examples", "alembic"] +include = ["packages/*/src", "packages/*/tests", "examples", "alembic"] pythonVersion = "3.13" typeCheckingMode = "strict" diagnosticMode = "workspace" diff --git a/uv.lock b/uv.lock index de8379b..78f1c92 100644 --- a/uv.lock 
+++ b/uv.lock @@ -2,6 +2,13 @@ version = 1 revision = 3 requires-python = ">=3.13" +[manifest] +members = [ + "chapkit", + "chapkit-workspace", + "servicekit", +] + [[package]] name = "aiosqlite" version = "0.21.0" @@ -59,29 +66,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/9c/fc2331f538fbf7eedba64b2052e99ccf9ba9d6888e2f41441ee28847004b/asgiref-3.10.0-py3-none-any.whl", hash = "sha256:aef8a81283a34d0ab31630c9b7dfe70c812c95eba78171367ca8745e88124734", size = 24050, upload-time = "2025-10-05T09:15:05.11Z" }, ] -[[package]] -name = "babel" -version = "2.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, -] - -[[package]] -name = "backrefs" -version = "5.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/eb/a7/312f673df6a79003279e1f55619abbe7daebbb87c17c976ddc0345c04c7b/backrefs-5.9.tar.gz", hash = "sha256:808548cb708d66b82ee231f962cb36faaf4f2baab032f2fbb783e9c2fdddaa59", size = 5765857, upload-time = "2025-06-22T19:34:13.97Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/19/4d/798dc1f30468134906575156c089c492cf79b5a5fd373f07fe26c4d046bf/backrefs-5.9-py310-none-any.whl", hash = "sha256:db8e8ba0e9de81fcd635f440deab5ae5f2591b54ac1ebe0550a2ca063488cd9f", size = 380267, upload-time = "2025-06-22T19:34:05.252Z" }, - { url = "https://files.pythonhosted.org/packages/55/07/f0b3375bf0d06014e9787797e6b7cc02b38ac9ff9726ccfe834d94e9991e/backrefs-5.9-py311-none-any.whl", hash = "sha256:6907635edebbe9b2dc3de3a2befff44d74f30a4562adbb8b36f21252ea19c5cf", size = 392072, upload-time = "2025-06-22T19:34:06.743Z" }, - { url = "https://files.pythonhosted.org/packages/9d/12/4f345407259dd60a0997107758ba3f221cf89a9b5a0f8ed5b961aef97253/backrefs-5.9-py312-none-any.whl", hash = "sha256:7fdf9771f63e6028d7fee7e0c497c81abda597ea45d6b8f89e8ad76994f5befa", size = 397947, upload-time = "2025-06-22T19:34:08.172Z" }, - { url = "https://files.pythonhosted.org/packages/10/bf/fa31834dc27a7f05e5290eae47c82690edc3a7b37d58f7fb35a1bdbf355b/backrefs-5.9-py313-none-any.whl", hash = "sha256:cc37b19fa219e93ff825ed1fed8879e47b4d89aa7a1884860e2db64ccd7c676b", size = 399843, upload-time = "2025-06-22T19:34:09.68Z" }, - { url = "https://files.pythonhosted.org/packages/fc/24/b29af34b2c9c41645a9f4ff117bae860291780d73880f449e0b5d948c070/backrefs-5.9-py314-none-any.whl", hash = "sha256:df5e169836cc8acb5e440ebae9aad4bf9d15e226d3bad049cf3f6a5c20cc8dc9", size = 411762, upload-time = "2025-06-22T19:34:11.037Z" }, - { url = "https://files.pythonhosted.org/packages/41/ff/392bff89415399a979be4a65357a41d92729ae8580a66073d8ec8d810f98/backrefs-5.9-py39-none-any.whl", hash = "sha256:f48ee18f6252b8f5777a22a00a09a85de0ca931658f1dd96d4406a34f3748c60", size = 380265, upload-time = "2025-06-22T19:34:12.405Z" }, -] - [[package]] name = "certifi" version = "2025.10.5" @@ -94,32 +78,28 @@ wheels = [ [[package]] name = "chapkit" version = "0.1.0" -source = { editable = "." 
} +source = { editable = "packages/chapkit" } dependencies = [ - { name = "aiosqlite" }, - { name = "alembic" }, - { name = "fastapi", extra = ["standard"] }, - { name = "geojson-pydantic" }, - { name = "gunicorn" }, - { name = "opentelemetry-api" }, - { name = "opentelemetry-exporter-prometheus" }, - { name = "opentelemetry-instrumentation-fastapi" }, - { name = "opentelemetry-instrumentation-sqlalchemy" }, - { name = "opentelemetry-sdk" }, { name = "pandas" }, - { name = "pydantic" }, - { name = "python-ulid" }, { name = "scikit-learn" }, - { name = "sqlalchemy", extra = ["asyncio"] }, - { name = "structlog" }, + { name = "servicekit" }, ] +[package.metadata] +requires-dist = [ + { name = "pandas", specifier = ">=2.2.0" }, + { name = "scikit-learn", specifier = ">=1.7.0" }, + { name = "servicekit", editable = "packages/servicekit" }, +] + +[[package]] +name = "chapkit-workspace" +version = "0.1.0" +source = { virtual = "." } + [package.dev-dependencies] dev = [ { name = "coverage" }, - { name = "mkdocs-material" }, - { name = "mkdocstrings" }, - { name = "mkdocstrings-python" }, { name = "mypy" }, { name = "pandas-stubs" }, { name = "pyright" }, @@ -127,74 +107,20 @@ dev = [ { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "ruff" }, - { name = "scikit-learn" }, ] [package.metadata] -requires-dist = [ - { name = "aiosqlite", specifier = ">=0.21.0" }, - { name = "alembic", specifier = ">=1.17.0" }, - { name = "fastapi", extras = ["standard"], specifier = ">=0.119.0" }, - { name = "geojson-pydantic", specifier = ">=2.1.0" }, - { name = "gunicorn", specifier = ">=23.0.0" }, - { name = "opentelemetry-api", specifier = ">=1.37.0" }, - { name = "opentelemetry-exporter-prometheus", specifier = ">=0.58b0" }, - { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.58b0" }, - { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.58b0" }, - { name = "opentelemetry-sdk", specifier = ">=1.37.0" }, - { name = "pandas", specifier = ">=2.3.3" }, - { name = "pydantic", specifier = ">=2.12.0" }, - { name = "python-ulid", specifier = ">=3.1.0" }, - { name = "scikit-learn", specifier = ">=1.7.2" }, - { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.43" }, - { name = "structlog", specifier = ">=24.4.0" }, -] [package.metadata.requires-dev] dev = [ - { name = "coverage", extras = ["toml"], specifier = ">=7.6.0" }, - { name = "mkdocs-material", specifier = ">=9.6.21" }, - { name = "mkdocstrings", specifier = ">=0.30.1" }, - { name = "mkdocstrings-python", specifier = ">=1.18.2" }, - { name = "mypy", specifier = ">=1.18.2" }, - { name = "pandas-stubs", specifier = ">=2.2.3.250101" }, + { name = "coverage", specifier = ">=7.10.0" }, + { name = "mypy", specifier = ">=1.18.0" }, + { name = "pandas-stubs", specifier = ">=2.3.2.250926" }, { name = "pyright", specifier = ">=1.1.406" }, - { name = "pytest", specifier = ">=8.4.2" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, - { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "ruff", specifier = ">=0.14.0" }, - { name = "scikit-learn", specifier = ">=1.7.2" }, -] - -[[package]] -name = "charset-normalizer" -version = "3.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = 
"sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371, upload-time = "2025-08-09T07:57:28.46Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/65/ca/2135ac97709b400c7654b4b764daf5c5567c2da45a30cdd20f9eefe2d658/charset_normalizer-3.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:14c2a87c65b351109f6abfc424cab3927b3bdece6f706e4d12faaf3d52ee5efe", size = 205326, upload-time = "2025-08-09T07:56:24.721Z" }, - { url = "https://files.pythonhosted.org/packages/71/11/98a04c3c97dd34e49c7d247083af03645ca3730809a5509443f3c37f7c99/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41d1fc408ff5fdfb910200ec0e74abc40387bccb3252f3f27c0676731df2b2c8", size = 146008, upload-time = "2025-08-09T07:56:26.004Z" }, - { url = "https://files.pythonhosted.org/packages/60/f5/4659a4cb3c4ec146bec80c32d8bb16033752574c20b1252ee842a95d1a1e/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1bb60174149316da1c35fa5233681f7c0f9f514509b8e399ab70fea5f17e45c9", size = 159196, upload-time = "2025-08-09T07:56:27.25Z" }, - { url = "https://files.pythonhosted.org/packages/86/9e/f552f7a00611f168b9a5865a1414179b2c6de8235a4fa40189f6f79a1753/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30d006f98569de3459c2fc1f2acde170b7b2bd265dc1943e87e1a4efe1b67c31", size = 156819, upload-time = "2025-08-09T07:56:28.515Z" }, - { url = "https://files.pythonhosted.org/packages/7e/95/42aa2156235cbc8fa61208aded06ef46111c4d3f0de233107b3f38631803/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:416175faf02e4b0810f1f38bcb54682878a4af94059a1cd63b8747244420801f", size = 151350, upload-time = "2025-08-09T07:56:29.716Z" }, - { url = "https://files.pythonhosted.org/packages/c2/a9/3865b02c56f300a6f94fc631ef54f0a8a29da74fb45a773dfd3dcd380af7/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aab0f181c486f973bc7262a97f5aca3ee7e1437011ef0c2ec04b5a11d16c927", size = 148644, upload-time = "2025-08-09T07:56:30.984Z" }, - { url = "https://files.pythonhosted.org/packages/77/d9/cbcf1a2a5c7d7856f11e7ac2d782aec12bdfea60d104e60e0aa1c97849dc/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabf8315679312cfa71302f9bd509ded4f2f263fb5b765cf1433b39106c3cc9", size = 160468, upload-time = "2025-08-09T07:56:32.252Z" }, - { url = "https://files.pythonhosted.org/packages/f6/42/6f45efee8697b89fda4d50580f292b8f7f9306cb2971d4b53f8914e4d890/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:bd28b817ea8c70215401f657edef3a8aa83c29d447fb0b622c35403780ba11d5", size = 158187, upload-time = "2025-08-09T07:56:33.481Z" }, - { url = "https://files.pythonhosted.org/packages/70/99/f1c3bdcfaa9c45b3ce96f70b14f070411366fa19549c1d4832c935d8e2c3/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18343b2d246dc6761a249ba1fb13f9ee9a2bcd95decc767319506056ea4ad4dc", size = 152699, upload-time = "2025-08-09T07:56:34.739Z" }, - { url = "https://files.pythonhosted.org/packages/a3/ad/b0081f2f99a4b194bcbb1934ef3b12aa4d9702ced80a37026b7607c72e58/charset_normalizer-3.4.3-cp313-cp313-win32.whl", hash = "sha256:6fb70de56f1859a3f71261cbe41005f56a7842cc348d3aeb26237560bfa5e0ce", size = 99580, upload-time = "2025-08-09T07:56:35.981Z" }, - { url = 
"https://files.pythonhosted.org/packages/9a/8f/ae790790c7b64f925e5c953b924aaa42a243fb778fed9e41f147b2a5715a/charset_normalizer-3.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:cf1ebb7d78e1ad8ec2a8c4732c7be2e736f6e5123a4146c5b89c9d1f585f8cef", size = 107366, upload-time = "2025-08-09T07:56:37.339Z" }, - { url = "https://files.pythonhosted.org/packages/8e/91/b5a06ad970ddc7a0e513112d40113e834638f4ca1120eb727a249fb2715e/charset_normalizer-3.4.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3cd35b7e8aedeb9e34c41385fda4f73ba609e561faedfae0a9e75e44ac558a15", size = 204342, upload-time = "2025-08-09T07:56:38.687Z" }, - { url = "https://files.pythonhosted.org/packages/ce/ec/1edc30a377f0a02689342f214455c3f6c2fbedd896a1d2f856c002fc3062/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89bc04de1d83006373429975f8ef9e7932534b8cc9ca582e4db7d20d91816db", size = 145995, upload-time = "2025-08-09T07:56:40.048Z" }, - { url = "https://files.pythonhosted.org/packages/17/e5/5e67ab85e6d22b04641acb5399c8684f4d37caf7558a53859f0283a650e9/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2001a39612b241dae17b4687898843f254f8748b796a2e16f1051a17078d991d", size = 158640, upload-time = "2025-08-09T07:56:41.311Z" }, - { url = "https://files.pythonhosted.org/packages/f1/e5/38421987f6c697ee3722981289d554957c4be652f963d71c5e46a262e135/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8dcfc373f888e4fb39a7bc57e93e3b845e7f462dacc008d9749568b1c4ece096", size = 156636, upload-time = "2025-08-09T07:56:43.195Z" }, - { url = "https://files.pythonhosted.org/packages/a0/e4/5a075de8daa3ec0745a9a3b54467e0c2967daaaf2cec04c845f73493e9a1/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b97b8404387b96cdbd30ad660f6407799126d26a39ca65729162fd810a99aa", size = 150939, upload-time = "2025-08-09T07:56:44.819Z" }, - { url = "https://files.pythonhosted.org/packages/02/f7/3611b32318b30974131db62b4043f335861d4d9b49adc6d57c1149cc49d4/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ccf600859c183d70eb47e05a44cd80a4ce77394d1ac0f79dbd2dd90a69a3a049", size = 148580, upload-time = "2025-08-09T07:56:46.684Z" }, - { url = "https://files.pythonhosted.org/packages/7e/61/19b36f4bd67f2793ab6a99b979b4e4f3d8fc754cbdffb805335df4337126/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:53cd68b185d98dde4ad8990e56a58dea83a4162161b1ea9272e5c9182ce415e0", size = 159870, upload-time = "2025-08-09T07:56:47.941Z" }, - { url = "https://files.pythonhosted.org/packages/06/57/84722eefdd338c04cf3030ada66889298eaedf3e7a30a624201e0cbe424a/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:30a96e1e1f865f78b030d65241c1ee850cdf422d869e9028e2fc1d5e4db73b92", size = 157797, upload-time = "2025-08-09T07:56:49.756Z" }, - { url = "https://files.pythonhosted.org/packages/72/2a/aff5dd112b2f14bcc3462c312dce5445806bfc8ab3a7328555da95330e4b/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d716a916938e03231e86e43782ca7878fb602a125a91e7acb8b5112e2e96ac16", size = 152224, upload-time = "2025-08-09T07:56:51.369Z" }, - { url = 
"https://files.pythonhosted.org/packages/b7/8c/9839225320046ed279c6e839d51f028342eb77c91c89b8ef2549f951f3ec/charset_normalizer-3.4.3-cp314-cp314-win32.whl", hash = "sha256:c6dbd0ccdda3a2ba7c2ecd9d77b37f3b5831687d8dc1b6ca5f56a4880cc7b7ce", size = 100086, upload-time = "2025-08-09T07:56:52.722Z" }, - { url = "https://files.pythonhosted.org/packages/ee/7a/36fbcf646e41f710ce0a563c1c9a343c6edf9be80786edeb15b6f62e17db/charset_normalizer-3.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:73dc19b562516fc9bcf6e5d6e596df0b4eb98d87e4f79f3ae71840e6ed21361c", size = 107400, upload-time = "2025-08-09T07:56:55.172Z" }, - { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, ] [[package]] @@ -375,18 +301,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/18/8a9dca353e605b344408114f6b045b11d14082d19f4668b073259d3ed1a9/geojson_pydantic-2.1.0-py3-none-any.whl", hash = "sha256:f9091bed334ab9fbb1bef113674edc1212a3737f374a0b13b1aa493f57964c1d", size = 8819, upload-time = "2025-10-08T13:31:11.646Z" }, ] -[[package]] -name = "ghp-import" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "python-dateutil" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d9/29/d40217cbe2f6b1359e00c6c307bb3fc876ba74068cbab3dde77f03ca0dc4/ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343", size = 10943, upload-time = "2022-05-02T15:47:16.11Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, -] - [[package]] name = "greenlet" version = "3.2.4" @@ -411,18 +325,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/a5/6ddab2b4c112be95601c13428db1d8b6608a8b6039816f2ba09c346c08fc/greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01", size = 303425, upload-time = "2025-08-07T13:32:27.59Z" }, ] -[[package]] -name = "griffe" -version = "1.14.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ec/d7/6c09dd7ce4c7837e4cdb11dce980cb45ae3cd87677298dc3b781b6bce7d3/griffe-1.14.0.tar.gz", hash = "sha256:9d2a15c1eca966d68e00517de5d69dd1bc5c9f2335ef6c1775362ba5b8651a13", size = 424684, upload-time = "2025-09-05T15:02:29.167Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/b1/9ff6578d789a89812ff21e4e0f80ffae20a65d5dd84e7a17873fe3b365be/griffe-1.14.0-py3-none-any.whl", hash = "sha256:0e9d52832cccf0f7188cfe585ba962d2674b241c01916d780925df34873bceb0", size = 144439, upload-time = "2025-09-05T15:02:27.511Z" }, -] - [[package]] name = "gunicorn" version = "23.0.0" @@ -557,15 +459,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, ] -[[package]] -name = "markdown" -version = 
"3.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8d/37/02347f6d6d8279247a5837082ebc26fc0d5aaeaf75aa013fcbb433c777ab/markdown-3.9.tar.gz", hash = "sha256:d2900fe1782bd33bdbbd56859defef70c2e78fc46668f8eb9df3128138f2cb6a", size = 364585, upload-time = "2025-09-04T20:25:22.885Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/70/ae/44c4a6a4cbb496d93c6257954260fe3a6e91b7bed2240e5dad2a717f5111/markdown-3.9-py3-none-any.whl", hash = "sha256:9f4d91ed810864ea88a6f32c07ba8bee1346c0cc1f6b1f9f6c822f2a9667d280", size = 107441, upload-time = "2025-09-04T20:25:21.784Z" }, -] - [[package]] name = "markdown-it-py" version = "4.0.0" @@ -639,129 +532,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] -[[package]] -name = "mergedeep" -version = "1.3.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/41/580bb4006e3ed0361b8151a01d324fb03f420815446c7def45d02f74c270/mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8", size = 4661, upload-time = "2021-02-05T18:55:30.623Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354, upload-time = "2021-02-05T18:55:29.583Z" }, -] - -[[package]] -name = "mkdocs" -version = "1.6.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "ghp-import" }, - { name = "jinja2" }, - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mergedeep" }, - { name = "mkdocs-get-deps" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "pyyaml" }, - { name = "pyyaml-env-tag" }, - { name = "watchdog" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bc/c6/bbd4f061bd16b378247f12953ffcb04786a618ce5e904b8c5a01a0309061/mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2", size = 3889159, upload-time = "2024-08-30T12:24:06.899Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/22/5b/dbc6a8cddc9cfa9c4971d59fb12bb8d42e161b7e7f8cc89e49137c5b279c/mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e", size = 3864451, upload-time = "2024-08-30T12:24:05.054Z" }, -] - -[[package]] -name = "mkdocs-autorefs" -version = "1.4.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mkdocs" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/51/fa/9124cd63d822e2bcbea1450ae68cdc3faf3655c69b455f3a7ed36ce6c628/mkdocs_autorefs-1.4.3.tar.gz", hash = "sha256:beee715b254455c4aa93b6ef3c67579c399ca092259cc41b7d9342573ff1fc75", size = 55425, upload-time = "2025-08-26T14:23:17.223Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/4d/7123b6fa2278000688ebd338e2a06d16870aaf9eceae6ba047ea05f92df1/mkdocs_autorefs-1.4.3-py3-none-any.whl", hash = 
"sha256:469d85eb3114801d08e9cc55d102b3ba65917a869b893403b8987b601cf55dc9", size = 25034, upload-time = "2025-08-26T14:23:15.906Z" }, -] - -[[package]] -name = "mkdocs-get-deps" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mergedeep" }, - { name = "platformdirs" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/98/f5/ed29cd50067784976f25ed0ed6fcd3c2ce9eb90650aa3b2796ddf7b6870b/mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c", size = 10239, upload-time = "2023-11-20T17:51:09.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/d4/029f984e8d3f3b6b726bd33cafc473b75e9e44c0f7e80a5b29abc466bdea/mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134", size = 9521, upload-time = "2023-11-20T17:51:08.587Z" }, -] - -[[package]] -name = "mkdocs-material" -version = "9.6.21" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "babel" }, - { name = "backrefs" }, - { name = "colorama" }, - { name = "jinja2" }, - { name = "markdown" }, - { name = "mkdocs" }, - { name = "mkdocs-material-extensions" }, - { name = "paginate" }, - { name = "pygments" }, - { name = "pymdown-extensions" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ff/d5/ab83ca9aa314954b0a9e8849780bdd01866a3cfcb15ffb7e3a61ca06ff0b/mkdocs_material-9.6.21.tar.gz", hash = "sha256:b01aa6d2731322438056f360f0e623d3faae981f8f2d8c68b1b973f4f2657870", size = 4043097, upload-time = "2025-09-30T19:11:27.517Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cf/4f/98681c2030375fe9b057dbfb9008b68f46c07dddf583f4df09bf8075e37f/mkdocs_material-9.6.21-py3-none-any.whl", hash = "sha256:aa6a5ab6fb4f6d381588ac51da8782a4d3757cb3d1b174f81a2ec126e1f22c92", size = 9203097, upload-time = "2025-09-30T19:11:24.063Z" }, -] - -[[package]] -name = "mkdocs-material-extensions" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/79/9b/9b4c96d6593b2a541e1cb8b34899a6d021d208bb357042823d4d2cabdbe7/mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443", size = 11847, upload-time = "2023-11-22T19:09:45.208Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728, upload-time = "2023-11-22T19:09:43.465Z" }, -] - -[[package]] -name = "mkdocstrings" -version = "0.30.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jinja2" }, - { name = "markdown" }, - { name = "markupsafe" }, - { name = "mkdocs" }, - { name = "mkdocs-autorefs" }, - { name = "pymdown-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c5/33/2fa3243439f794e685d3e694590d28469a9b8ea733af4b48c250a3ffc9a0/mkdocstrings-0.30.1.tar.gz", hash = "sha256:84a007aae9b707fb0aebfc9da23db4b26fc9ab562eb56e335e9ec480cb19744f", size = 106350, upload-time = "2025-09-19T10:49:26.446Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/2c/f0dc4e1ee7f618f5bff7e05898d20bf8b6e7fa612038f768bfa295f136a4/mkdocstrings-0.30.1-py3-none-any.whl", hash = 
"sha256:41bd71f284ca4d44a668816193e4025c950b002252081e387433656ae9a70a82", size = 36704, upload-time = "2025-09-19T10:49:24.805Z" }, -] - -[[package]] -name = "mkdocstrings-python" -version = "1.18.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "griffe" }, - { name = "mkdocs-autorefs" }, - { name = "mkdocstrings" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/95/ae/58ab2bfbee2792e92a98b97e872f7c003deb903071f75d8d83aa55db28fa/mkdocstrings_python-1.18.2.tar.gz", hash = "sha256:4ad536920a07b6336f50d4c6d5603316fafb1172c5c882370cbbc954770ad323", size = 207972, upload-time = "2025-08-28T16:11:19.847Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/8f/ce008599d9adebf33ed144e7736914385e8537f5fc686fdb7cceb8c22431/mkdocstrings_python-1.18.2-py3-none-any.whl", hash = "sha256:944fe6deb8f08f33fa936d538233c4036e9f53e840994f6146e8e94eb71b600d", size = 138215, upload-time = "2025-08-28T16:11:18.176Z" }, -] - [[package]] name = "mypy" version = "1.18.2" @@ -993,15 +763,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, ] -[[package]] -name = "paginate" -version = "0.5.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ec/46/68dde5b6bc00c1296ec6466ab27dddede6aec9af1b99090e1107091b3b84/paginate-0.5.7.tar.gz", hash = "sha256:22bd083ab41e1a8b4f3690544afb2c60c25e5c9a63a30fa2f483f6c60c8e5945", size = 19252, upload-time = "2024-08-25T14:17:24.139Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746, upload-time = "2024-08-25T14:17:22.55Z" }, -] - [[package]] name = "pandas" version = "2.3.3" @@ -1064,15 +825,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, ] -[[package]] -name = "platformdirs" -version = "4.5.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/61/33/9611380c2bdb1225fdef633e2a9610622310fed35ab11dac9620972ee088/platformdirs-4.5.0.tar.gz", hash = "sha256:70ddccdd7c99fc5942e9fc25636a8b34d04c24b335100223152c2803e4063312", size = 21632, upload-time = "2025-10-08T17:44:48.791Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, -] - [[package]] name = "pluggy" version = "1.6.0" @@ -1165,19 +917,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] -[[package]] -name = "pymdown-extensions" 
-version = "10.16.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/55/b3/6d2b3f149bc5413b0a29761c2c5832d8ce904a1d7f621e86616d96f505cc/pymdown_extensions-10.16.1.tar.gz", hash = "sha256:aace82bcccba3efc03e25d584e6a22d27a8e17caa3f4dd9f207e49b787aa9a91", size = 853277, upload-time = "2025-07-28T16:19:34.167Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/06/43084e6cbd4b3bc0e80f6be743b2e79fbc6eed8de9ad8c629939fa55d972/pymdown_extensions-10.16.1-py3-none-any.whl", hash = "sha256:d6ba157a6c03146a7fb122b2b9a121300056384eafeec9c9f9e584adfdb2a32d", size = 266178, upload-time = "2025-07-28T16:19:31.401Z" }, -] - [[package]] name = "pyright" version = "1.1.406" @@ -1317,33 +1056,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] -[[package]] -name = "pyyaml-env-tag" -version = "1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/eb/2e/79c822141bfd05a853236b504869ebc6b70159afc570e1d5a20641782eaa/pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff", size = 5737, upload-time = "2025-05-13T15:24:01.64Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722, upload-time = "2025-05-13T15:23:59.629Z" }, -] - -[[package]] -name = "requests" -version = "2.32.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "charset-normalizer" }, - { name = "idna" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, -] - [[package]] name = "rich" version = "14.2.0" @@ -1526,6 +1238,69 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/58/175d0e4d93f62075a01f8aebe904b412c34a94a4517e5045d0a1d512aad0/sentry_sdk-2.41.0-py2.py3-none-any.whl", hash = "sha256:343cde6540574113d13d178d1b2093e011ac21dd55abd3a1ec7e540f0d18a5bd", size = 370606, upload-time = "2025-10-09T14:12:19.003Z" }, ] +[[package]] +name = "servicekit" +version = "0.1.0" +source = { editable = "packages/servicekit" } +dependencies = [ + { name = "aiosqlite" }, + { name = "alembic" }, + { name = "fastapi", extra = ["standard"] }, + { name = "geojson-pydantic" }, + { name = "gunicorn" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-prometheus" }, + { name = "opentelemetry-instrumentation-fastapi" }, + { name = 
"opentelemetry-instrumentation-sqlalchemy" }, + { name = "opentelemetry-sdk" }, + { name = "pandas" }, + { name = "pydantic" }, + { name = "python-ulid" }, + { name = "sqlalchemy", extra = ["asyncio"] }, + { name = "structlog" }, +] + +[package.dev-dependencies] +dev = [ + { name = "coverage" }, + { name = "mypy" }, + { name = "pyright" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "alembic", specifier = ">=1.17.0" }, + { name = "fastapi", extras = ["standard"], specifier = ">=0.119.0" }, + { name = "geojson-pydantic", specifier = ">=2.1.0" }, + { name = "gunicorn", specifier = ">=23.0.0" }, + { name = "opentelemetry-api", specifier = ">=1.37.0" }, + { name = "opentelemetry-exporter-prometheus", specifier = ">=0.58b0" }, + { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.58b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.58b0" }, + { name = "opentelemetry-sdk", specifier = ">=1.37.0" }, + { name = "pandas", specifier = ">=2.2.0" }, + { name = "pydantic", specifier = ">=2.12.0" }, + { name = "python-ulid", specifier = ">=3.1.0" }, + { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.43" }, + { name = "structlog", specifier = ">=24.4.0" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "coverage", specifier = ">=7.10.7" }, + { name = "mypy", specifier = ">=1.18.2" }, + { name = "pyright", specifier = ">=1.1.406" }, + { name = "pytest", specifier = ">=8.4.2" }, + { name = "pytest-asyncio", specifier = ">=1.2.0" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "ruff", specifier = ">=0.14.0" }, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -1710,27 +1485,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018, upload-time = "2024-10-14T23:38:10.888Z" }, ] -[[package]] -name = "watchdog" -version = "6.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480, upload-time = "2024-11-01T14:06:42.952Z" }, - { url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451, upload-time = "2024-11-01T14:06:45.084Z" }, - { url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057, upload-time = "2024-11-01T14:06:47.324Z" }, - { url = 
"https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" }, - { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" }, - { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" }, - { url = "https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077, upload-time = "2024-11-01T14:07:03.893Z" }, - { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078, upload-time = "2024-11-01T14:07:05.189Z" }, - { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077, upload-time = "2024-11-01T14:07:06.376Z" }, - { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078, upload-time = "2024-11-01T14:07:07.547Z" }, - { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065, upload-time = "2024-11-01T14:07:09.525Z" }, - { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070, upload-time = "2024-11-01T14:07:10.686Z" }, - { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, -] - [[package]] name = "watchfiles" version = "1.1.0"