diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a01d175..0ddffa8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,7 +22,7 @@ jobs: with: python-version-file: '.python-version' - name: Install dependencies - run: uv sync --locked --all-extras --dev + run: uv sync --locked --all-extras --dev --index pytorch-cpu - name: Run ruff checks run: uv run ruff check - name: Run mypy for type checking @@ -48,6 +48,6 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: uv sync --locked --all-extras --dev + run: uv sync --locked --all-extras --dev --index pytorch-cpu - name: Run unit tests run: uv run pytest -v diff --git a/CLAUDE.md b/CLAUDE.md index b933412..ba2677a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,7 +11,7 @@ Scribae is a CLI tool that transforms local Markdown notes into structured SEO c ## Build & Development Commands ```bash -uv sync --locked --all-extras --dev # Install dependencies +uv sync --locked --all-extras --dev # Install dependencies (includes PyTorch with CUDA) uv run scribae --help # Run CLI uv run ruff check # Lint (auto-fix: --fix) uv run mypy # Type check @@ -20,6 +20,11 @@ uv run pytest tests/unit/foo_test.py # Run single test file uv run pytest -k "test_name" # Run tests matching pattern ``` +For a lighter install (~200MB vs ~2GB), use the CPU-only PyTorch index: +```bash +uv sync --locked --all-extras --dev --index pytorch-cpu +``` + **Important:** Always run tests, mypy, and ruff at the end of your task and fix any issues. ## Architecture diff --git a/README.md b/README.md index 1cac39b..041994e 100644 --- a/README.md +++ b/README.md @@ -38,18 +38,37 @@ keeping placeholders, links, and numbers intact. - **NLLB fallback.** When pivoting fails, the pipeline falls back to NLLB. ISO codes like `en`/`de`/`es` are mapped to NLLB codes (e.g., `eng_Latn`, `deu_Latn`, `spa_Latn`). You can also pass NLLB codes directly via `--src`/`--tgt`. +### Translation dependencies +Translation uses PyTorch and Hugging Face Transformers. Install the translation extra before running +`scribae translate`: +```bash +uv sync --locked --dev --extra translation +``` +To avoid downloading CUDA libraries (~2GB), use the CPU-only PyTorch index instead: +```bash +uv sync --locked --dev --extra translation --index pytorch-cpu +``` + ## Quick start 1. Install [uv](https://github.com/astral-sh/uv) and sync dependencies (Python 3.12 is managed by uv): ```bash - uv sync --locked --all-extras --dev + uv sync --locked --dev + ``` +2. (Optional) Install translation dependencies: + ```bash + uv sync --locked --dev --extra translation + ``` + Use the CPU-only index if you want to avoid CUDA downloads: + ```bash + uv sync --locked --dev --extra translation --index pytorch-cpu ``` -2. (Optional) Point Scribae at your model endpoint: +3. (Optional) Point Scribae at your model endpoint: ```bash export OPENAI_BASE_URL="http://localhost:11434/v1" export OPENAI_API_KEY="no-key" # or use OPENAI_API_BASE if you prefer ``` -3. Run the CLI: +4. Run the CLI: ```bash uv run scribae --help ``` diff --git a/pyproject.toml b/pyproject.toml index 9712709..8ea0b3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "pyyaml>=6.0.2", "python-frontmatter>=1.1.0", "transformers>=4.46.3", - "torch>=2.5.1", "sentencepiece>=0.2.0", "sacremoses>=0.1.1", "fast-langdetect>=1.0.0", @@ -44,6 +43,11 @@ dependencies = [ "tomli>=2.0.0;python_version<'3.11'", ] +[project.optional-dependencies] +translation = [ + "torch>=2.5.1", +] + [project.scripts] scribae = "scribae.main:app" @@ -56,6 +60,11 @@ Changelog = "https://github.com/fmueller/scribae/blob/main/CHANGELOG.md" [tool.uv] package = true +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + [dependency-groups] dev = [ "build>=1.2.2.post1", diff --git a/src/scribae/translate/mt.py b/src/scribae/translate/mt.py index a740926..23f4c79 100644 --- a/src/scribae/translate/mt.py +++ b/src/scribae/translate/mt.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from types import ModuleType from typing import TYPE_CHECKING, Any from .model_registry import ModelRegistry, RouteStep @@ -51,17 +52,25 @@ def _pipeline_for(self, model_id: str) -> Pipeline: from transformers import pipeline if model_id not in self._pipelines: + torch = self._require_torch() if self.device is None or self.device == "auto": - import torch - device = 0 if torch.cuda.is_available() else -1 self._pipelines[model_id] = pipeline("translation", model=model_id, device=device) else: - self._pipelines[model_id] = pipeline( - "translation", model=model_id, device=self.device - ) + self._pipelines[model_id] = pipeline("translation", model=model_id, device=self.device) return self._pipelines[model_id] + def _require_torch(self) -> ModuleType: + try: + import torch + except ImportError as exc: + raise RuntimeError( + "Translation requires PyTorch. Install it with " + "`uv sync --extra translation` or " + "`uv sync --extra translation --index pytorch-cpu` (CPU-only)." + ) from exc + return torch + def prefetch(self, steps: Iterable[RouteStep]) -> None: """Warm translation pipelines for the provided route steps.""" for step in steps: diff --git a/uv.lock b/uv.lock index 842a6ab..acaa118 100644 --- a/uv.lock +++ b/uv.lock @@ -2,9 +2,13 @@ version = 1 revision = 3 requires-python = ">=3.10" resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version < '3.11'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'x86_64') or (python_full_version == '3.11.*' and sys_platform != 'linux')", + "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'x86_64') or (python_full_version < '3.11' and sys_platform != 'linux')", ] [[package]] @@ -294,7 +298,7 @@ name = "build" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "os_name == 'nt'" }, + { name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'x86_64') or (os_name == 'nt' and sys_platform != 'linux')" }, { name = "importlib-metadata", marker = "python_full_version < '3.10.2'" }, { name = "packaging" }, { name = "pyproject-hooks" }, @@ -1918,7 +1922,8 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'x86_64') or (python_full_version < '3.11' and sys_platform != 'linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -1930,8 +1935,11 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'x86_64') or (python_full_version == '3.11.*' and sys_platform != 'linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1997,7 +2005,8 @@ name = "numpy" version = "2.2.6" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", + "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'x86_64') or (python_full_version < '3.11' and sys_platform != 'linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } wheels = [ @@ -2062,8 +2071,11 @@ name = "numpy" version = "2.3.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'x86_64') or (python_full_version == '3.11.*' and sys_platform != 'linux')", ] sdist = { url = "https://files.pythonhosted.org/packages/76/65/21b3bc86aac7b8f2862db1e808f1ea22b028e30a225a34a5ede9bf8678f2/numpy-2.3.5.tar.gz", hash = "sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0", size = 20584950, upload-time = "2025-11-16T22:52:42.067Z" } wheels = [ @@ -2179,7 +2191,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -2190,7 +2202,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -2217,9 +2229,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -2230,7 +2242,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -3618,11 +3630,15 @@ dependencies = [ { name = "sacremoses" }, { name = "sentencepiece" }, { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "torch" }, { name = "transformers" }, { name = "typer" }, ] +[package.optional-dependencies] +translation = [ + { name = "torch" }, +] + [package.dev-dependencies] dev = [ { name = "build" }, @@ -3645,10 +3661,11 @@ requires-dist = [ { name = "sacremoses", specifier = ">=0.1.1" }, { name = "sentencepiece", specifier = ">=0.2.0" }, { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0" }, - { name = "torch", specifier = ">=2.5.1" }, + { name = "torch", marker = "extra == 'translation'", specifier = ">=2.5.1" }, { name = "transformers", specifier = ">=4.46.3" }, { name = "typer", specifier = ">=0.20.0" }, ] +provides-extras = ["translation"] [package.metadata.requires-dev] dev = [