From c9bed3fb7e8c668766274fea99bc81d4158daa06 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20M=C3=BCller?=
Date: Thu, 1 Jan 2026 13:25:52 +0100
Subject: [PATCH 1/2] feat: add `--seed` and `--top-p` options for output
 reproducibility and variability

- Enhanced all LLM commands (`idea`, `brief`, `write`, `meta`) with `--seed` and `--top-p` options to support fine-grained output control.
- Introduced `--postedit-seed` and `--postedit-top-p` for the `translate` command's LLM post-edit pass.
- Updated documentation and usage examples in README to highlight new options.
- Added unit tests to ensure proper propagation and handling of `--seed` and `--top-p` in each command.
---
 CHANGELOG.md                      |  5 ++
 README.md                         | 22 ++++++++
 src/scribae/brief.py              | 19 ++++++-
 src/scribae/brief_cli.py          | 16 ++++++
 src/scribae/idea.py               | 19 ++++++-
 src/scribae/idea_cli.py           | 14 +++++
 src/scribae/meta.py               | 16 +++++-
 src/scribae/meta_cli.py           | 14 +++++
 src/scribae/translate/postedit.py | 17 +++++-
 src/scribae/translate_cli.py      | 16 ++++++
 src/scribae/write.py              | 33 ++++++++++--
 src/scribae/write_cli.py          | 14 +++++
 tests/unit/idea_cli_test.py       | 37 +++++++++++++
 tests/unit/main_test.py           | 46 +++++++++++++++++
 tests/unit/meta_cli_test.py       | 86 +++++++++++++++++++++++++++++++
 tests/unit/translate_cli_test.py  | 51 ++++++++++++++++++
 tests/unit/write_cli_test.py      | 58 ++++++++++++++++++++-
 tests/unit/write_language_test.py |  9 +++-
 18 files changed, 477 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3048d0..edf876d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+
+- `--seed` and `--top-p` CLI options for all LLM commands (`idea`, `brief`, `write`, `meta`) to control output reproducibility
+- `--postedit-seed` and `--postedit-top-p` CLI options for the `translate` command's LLM post-edit pass
+
 ## 0.1.0 - 2025-12-29
 
 ### Added
diff --git a/README.md b/README.md
index 0ce57c5..ddd1b7c 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,28 @@ keywords:
   - programming
 ```
 
+### LLM settings
+
+All commands that use LLMs support `--temperature`, `--top-p`, and `--seed` for controlling output variability:
+
+| Option | Default | Description |
+|--------|---------|-------------|
+| `--temperature` | 0.2–0.4 | Sampling temperature. Lower values produce more deterministic output. |
+| `--top-p` | (none) | Nucleus sampling threshold. Set to `1.0` for full distribution. |
+| `--seed` | (none) | Random seed for reproducible outputs. |
+
+For fully deterministic output, set `--temperature 0`. For reproducible creative output, combine `--seed` with `--top-p 1.0`.
+
+```bash
+# Deterministic mode
+scribae brief --note notes.md --temperature 0 --out brief.json
+
+# Reproducible creative mode
+scribae write --note notes.md --brief brief.json --seed 42 --top-p 1.0 --out draft.md
+```
+
+The `translate` command uses `--postedit-temperature`, `--postedit-top-p`, and `--postedit-seed` for the LLM post-edit pass.
+
 ## Usage examples
 
 ### Idea discovery
diff --git a/src/scribae/brief.py b/src/scribae/brief.py
index 6a80daa..87cfd8c 100644
--- a/src/scribae/brief.py
+++ b/src/scribae/brief.py
@@ -227,6 +227,8 @@ def generate_brief(
     *,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     reporter: Reporter = None,
     settings: OpenAISettings | None = None,
     agent: Agent[None, SeoBrief] | None = None,
@@ -236,7 +238,9 @@
     """Run the LLM call and return a validated SeoBrief."""
     resolved_settings = settings or OpenAISettings.from_env()
     llm_agent: Agent[None, SeoBrief] = (
-        _create_agent(model_name, resolved_settings, temperature=temperature) if agent is None else agent
+        _create_agent(model_name, resolved_settings, temperature=temperature, top_p=top_p, seed=seed)
+        if agent is None
+        else agent
     )
 
     _report(
@@ -322,10 +326,21 @@ def save_prompt_artifacts(
     return prompt_path, note_path
 
 
-def _create_agent(model_name: str, settings: OpenAISettings, *, temperature: float) -> Agent[None, SeoBrief]:
+def _create_agent(
+    model_name: str,
+    settings: OpenAISettings,
+    *,
+    temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
+) -> Agent[None, SeoBrief]:
     """Instantiate the Pydantic AI agent for generating briefs."""
     settings.configure_environment()
     model_settings = ModelSettings(temperature=temperature)
+    if top_p is not None:
+        model_settings["top_p"] = top_p
+    if seed is not None:
+        model_settings["seed"] = seed
     model = make_model(model_name, model_settings=model_settings, settings=settings)
     return Agent[None, SeoBrief](
         model=model,
diff --git a/src/scribae/brief_cli.py b/src/scribae/brief_cli.py
index 7f5f766..871e6d2 100644
--- a/src/scribae/brief_cli.py
+++ b/src/scribae/brief_cli.py
@@ -135,6 +135,18 @@ def brief_command(
         max=2.0,
         help="Temperature for the LLM request.",
     ),
+    top_p: float | None = typer.Option(  # noqa: B008
+        None,
+        "--top-p",
+        min=0.0,
+        max=1.0,
+        help="Nucleus sampling threshold (1.0 = full distribution). For reproducibility, set to 1.0.",
+    ),
+    seed: int | None = typer.Option(  # noqa: B008
+        None,
+        "--seed",
+        help="Random seed for reproducible outputs. For full determinism, combine with --temperature 0.",
+    ),
     dry_run: bool = typer.Option(  # noqa: B008
         False,
         "--dry-run",
@@ -223,6 +235,8 @@ def brief_command(
             context,
             model_name=model,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             reporter=reporter,
         )
     except KeyboardInterrupt:
@@ -273,6 +287,8 @@ def brief_command(
             context,
             model_name=model,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             reporter=reporter,
         )
     except KeyboardInterrupt:
diff --git a/src/scribae/idea.py b/src/scribae/idea.py
index dad2d08..065cf62 100644
--- a/src/scribae/idea.py
+++ b/src/scribae/idea.py
@@ -146,6 +146,8 @@ def generate_ideas(
     *,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     reporter: Reporter = None,
     settings: OpenAISettings | None = None,
     agent: Agent[None, IdeaList] | None = None,
@@ -156,7 +158,9 @@
     resolved_settings = settings or OpenAISettings.from_env()
 
     llm_agent: Agent[None, IdeaList] = (
-        _create_agent(model_name, resolved_settings, temperature=temperature) if agent is None else agent
+        _create_agent(model_name, resolved_settings, temperature=temperature, top_p=top_p, seed=seed)
+        if agent is None
+        else agent
     )
 
     _report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")
@@ -223,11 +227,22 @@ def save_prompt_artifacts(
     return prompt_path, note_path
 
 
-def _create_agent(model_name: str, settings: OpenAISettings, *, temperature: float) -> Agent[None, IdeaList]:
+def _create_agent(
+    model_name: str,
+    settings: OpenAISettings,
+    *,
+    temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
+) -> Agent[None, IdeaList]:
     """Instantiate the Pydantic AI agent for generating ideas."""
     settings.configure_environment()
 
     model_settings = ModelSettings(temperature=temperature)
+    if top_p is not None:
+        model_settings["top_p"] = top_p
+    if seed is not None:
+        model_settings["seed"] = seed
     model = make_model(model_name, model_settings=model_settings, settings=settings)
     return Agent[None, IdeaList](
         model=model,
diff --git a/src/scribae/idea_cli.py b/src/scribae/idea_cli.py
index 7f49d24..5c2ec87 100644
--- a/src/scribae/idea_cli.py
+++ b/src/scribae/idea_cli.py
@@ -87,6 +87,18 @@ def idea_command(
         max=2.0,
         help="Temperature for the LLM request.",
     ),
+    top_p: float | None = typer.Option(  # noqa: B008
+        None,
+        "--top-p",
+        min=0.0,
+        max=1.0,
+        help="Nucleus sampling threshold (1.0 = full distribution). For reproducibility, set to 1.0.",
+    ),
+    seed: int | None = typer.Option(  # noqa: B008
+        None,
+        "--seed",
+        help="Random seed for reproducible outputs. For full determinism, combine with --temperature 0.",
+    ),
     dry_run: bool = typer.Option(  # noqa: B008
         False,
         "--dry-run",
@@ -165,6 +177,8 @@ def idea_command(
             context,
             model_name=model,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             reporter=reporter,
         )
     except KeyboardInterrupt:
diff --git a/src/scribae/meta.py b/src/scribae/meta.py
index cead274..0c03fe5 100644
--- a/src/scribae/meta.py
+++ b/src/scribae/meta.py
@@ -228,6 +228,8 @@ def generate_metadata(
     *,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     reporter: Reporter = None,
     agent: Agent[None, ArticleMeta] | None = None,
     prompts: PromptBundle | None = None,
@@ -247,7 +249,7 @@
     resolved_settings = OpenAISettings.from_env()
 
     llm_agent: Agent[None, ArticleMeta] = (
-        agent if agent is not None else _create_agent(model_name, temperature)
+        agent if agent is not None else _create_agent(model_name, temperature, top_p=top_p, seed=seed)
     )
 
     _report(
@@ -516,8 +518,18 @@ def _merge_frontmatter(meta: ArticleMeta, original: dict[str, Any], *, overwrite
     return merged
 
 
-def _create_agent(model_name: str, temperature: float) -> Agent[None, ArticleMeta]:
+def _create_agent(
+    model_name: str,
+    temperature: float,
+    *,
+    top_p: float | None = None,
+    seed: int | None = None,
+) -> Agent[None, ArticleMeta]:
     model_settings = ModelSettings(temperature=temperature)
+    if top_p is not None:
+        model_settings["top_p"] = top_p
+    if seed is not None:
+        model_settings["seed"] = seed
     model = make_model(model_name, model_settings=model_settings)
     return Agent[None, ArticleMeta](
         model=model,
diff --git a/src/scribae/meta_cli.py b/src/scribae/meta_cli.py
index e7e163a..169b4ce 100644
--- a/src/scribae/meta_cli.py
+++ b/src/scribae/meta_cli.py
@@ -83,6 +83,18 @@ def meta_command(
         max=2.0,
         help="Temperature for the LLM request.",
     ),
+    top_p: float | None = typer.Option(  # noqa: B008
+        None,
+        "--top-p",
+        min=0.0,
+        max=1.0,
+        help="Nucleus sampling threshold (1.0 = full distribution). For reproducibility, set to 1.0.",
+    ),
+    seed: int | None = typer.Option(  # noqa: B008
+        None,
+        "--seed",
+        help="Random seed for reproducible outputs. For full determinism, combine with --temperature 0.",
+    ),
     dry_run: bool = typer.Option(  # noqa: B008
         False,
         "--dry-run",
@@ -182,6 +194,8 @@ def meta_command(
             context,
             model_name=model,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             reporter=reporter,
             prompts=prompts,
             force_llm_on_missing=force_llm_on_missing,
diff --git a/src/scribae/translate/postedit.py b/src/scribae/translate/postedit.py
index e7e293d..22cea02 100644
--- a/src/scribae/translate/postedit.py
+++ b/src/scribae/translate/postedit.py
@@ -45,6 +45,8 @@ def __init__(
         *,
         model_name: str = DEFAULT_MODEL_NAME,
         temperature: float = 0.2,
+        top_p: float | None = None,
+        seed: int | None = None,
         create_agent: bool = True,
         max_chars: int | None = 4_000,
         timeout_seconds: float | None = 60.0,
@@ -57,7 +59,7 @@
         if agent is not None:
             self.agent = agent
         elif create_agent:
-            self.agent = self._create_agent(model_name, temperature=temperature)
+            self.agent = self._create_agent(model_name, temperature=temperature, top_p=top_p, seed=seed)
 
     def post_edit(
         self,
@@ -437,10 +439,21 @@ def _validate_output(self, text: str, placeholders: Iterable[str], glossary: dic
         elif target not in text:
             raise PostEditValidationError(f"Glossary target not enforced: {target}")
 
-    def _create_agent(self, model_name: str, *, temperature: float) -> Agent[None, str] | None:
+    def _create_agent(
+        self,
+        model_name: str,
+        *,
+        temperature: float,
+        top_p: float | None = None,
+        seed: int | None = None,
+    ) -> Agent[None, str] | None:
        settings = OpenAISettings.from_env()
        settings.configure_environment()
        model_settings = ModelSettings(temperature=temperature)
+        if top_p is not None:
+            model_settings["top_p"] = top_p
+        if seed is not None:
+            model_settings["seed"] = seed
        model = make_model(model_name, model_settings=model_settings, settings=settings)
        return Agent[None, str](
            model=model,
diff --git a/src/scribae/translate_cli.py b/src/scribae/translate_cli.py
index ae872ee..d1c5c8f 100644
--- a/src/scribae/translate_cli.py
+++ b/src/scribae/translate_cli.py
@@ -197,6 +197,20 @@ def translate(
         "--pe-temp",
         help="Temperature for post-edit LLM pass.",
     ),
+    postedit_top_p: float | None = typer.Option(  # noqa: B008
+        None,
+        "--postedit-top-p",
+        "--pe-top-p",
+        min=0.0,
+        max=1.0,
+        help="Nucleus sampling threshold for post-edit (1.0 = full distribution).",
+    ),
+    postedit_seed: int | None = typer.Option(  # noqa: B008
+        None,
+        "--postedit-seed",
+        "--pe-seed",
+        help="Random seed for reproducible post-edit outputs.",
+    ),
     device: str = typer.Option(  # noqa: B008
         "auto",
         "--device",
@@ -276,6 +290,8 @@ def translate(
         posteditor = LLMPostEditor(
             model_name=postedit_model,
             temperature=postedit_temperature,
+            top_p=postedit_top_p,
+            seed=postedit_seed,
             create_agent=postedit,
             max_chars=postedit_max_chars,
         )
diff --git a/src/scribae/write.py b/src/scribae/write.py
index 4acc72d..7875fb8 100644
--- a/src/scribae/write.py
+++ b/src/scribae/write.py
@@ -218,6 +218,8 @@ def generate_article(
     *,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     evidence_required: bool,
     section_range: tuple[int, int] | None = None,
     include_faq: bool = False,
@@ -242,6 +244,8 @@
             expected_language=context.language,
             model_name=model_name,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             language_detector=language_detector,
             reporter=reporter,
         )
@@ -257,7 +261,7 @@
             prompt=prompt,
             expected_language=context.language,
             invoke=lambda prompt: _invoke_model(
-                prompt, model_name=model_name, temperature=temperature
+                prompt, model_name=model_name, temperature=temperature, top_p=top_p, seed=seed
             ),
             extract_text=lambda text: text,
             reporter=reporter,
@@ -280,6 +284,8 @@
         context,
         model_name=model_name,
         temperature=temperature,
+        top_p=top_p,
+        seed=seed,
         reporter=reporter,
         language_detector=language_detector,
         write_faq=write_faq,
@@ -320,9 +326,20 @@ def _load_brief(path: Path) -> SeoBrief:
     return brief
 
 
-def _invoke_model(prompt: str, *, model_name: str, temperature: float) -> str:
+def _invoke_model(
+    prompt: str,
+    *,
+    model_name: str,
+    temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
+) -> str:
     """Call the writer model and return Markdown text."""
     model_settings = ModelSettings(temperature=temperature)
+    if top_p is not None:
+        model_settings["top_p"] = top_p
+    if seed is not None:
+        model_settings["seed"] = seed
     model = make_model(model_name, model_settings=model_settings)
     agent = Agent(model=model, instructions=SYSTEM_PROMPT)
 
@@ -364,6 +381,8 @@ def _build_faq_body(
     *,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     reporter: Reporter,
     language_detector: Callable[[str], str] | None,
     write_faq: bool,
@@ -391,7 +410,9 @@
     body = ensure_language_output(
         prompt=prompt,
         expected_language=context.language,
-        invoke=lambda prompt: _invoke_model(prompt, model_name=model_name, temperature=temperature),
+        invoke=lambda prompt: _invoke_model(
+            prompt, model_name=model_name, temperature=temperature, top_p=top_p, seed=seed
+        ),
         extract_text=lambda text: text,
         reporter=reporter,
@@ -419,6 +440,8 @@ def _ensure_section_title_language(
     expected_language: str,
     model_name: str,
     temperature: float,
+    top_p: float | None = None,
+    seed: int | None = None,
     language_detector: Callable[[str], str] | None,
     reporter: Reporter,
 ) -> str:
@@ -450,7 +473,9 @@
     corrected = ensure_language_output(
         prompt=prompt,
         expected_language=expected_language,
-        invoke=lambda prompt: _invoke_model(prompt, model_name=model_name, temperature=temperature),
+        invoke=lambda prompt: _invoke_model(
+            prompt, model_name=model_name, temperature=temperature, top_p=top_p, seed=seed
+        ),
         extract_text=lambda text: text,
         reporter=reporter,
         language_detector=language_detector,
diff --git a/src/scribae/write_cli.py b/src/scribae/write_cli.py
index 9053ae6..ec85e62 100644
--- a/src/scribae/write_cli.py
+++ b/src/scribae/write_cli.py
@@ -72,6 +72,18 @@ def write_command(
         max=2.0,
         help="Temperature for the LLM request.",
     ),
+    top_p: float | None = typer.Option(  # noqa: B008
+        None,
+        "--top-p",
+        min=0.0,
+        max=1.0,
+        help="Nucleus sampling threshold (1.0 = full distribution). For reproducibility, set to 1.0.",
+    ),
+    seed: int | None = typer.Option(  # noqa: B008
+        None,
+        "--seed",
+        help="Random seed for reproducible outputs. For full determinism, combine with --temperature 0.",
+    ),
     dry_run: bool = typer.Option(  # noqa: B008
         False,
         "--dry-run",
@@ -178,6 +190,8 @@ def write_command(
             context,
             model_name=model,
             temperature=temperature,
+            top_p=top_p,
+            seed=seed,
             evidence_required=evidence_required,
             section_range=section_range,
             include_faq=include_faq or write_faq,
diff --git a/tests/unit/idea_cli_test.py b/tests/unit/idea_cli_test.py
index c7f6887..d025b7d 100644
--- a/tests/unit/idea_cli_test.py
+++ b/tests/unit/idea_cli_test.py
@@ -117,3 +117,40 @@ def test_idea_save_prompt_creates_files(
     assert note_snapshot.exists()
     assert "SYSTEM PROMPT" in prompt_file.read_text(encoding="utf-8")
     assert note_body in note_snapshot.read_text(encoding="utf-8")
+
+
+def test_idea_passes_seed_and_top_p(monkeypatch: pytest.MonkeyPatch, note_file: Path, fake: Faker) -> None:
+    ideas = _fake_ideas(fake)
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> IdeaList:
+        captured_kwargs.update(kwargs)
+        return ideas
+
+    monkeypatch.setattr("scribae.idea_cli.generate_ideas", capture_generate)
+
+    result = runner.invoke(
+        app,
+        ["idea", "--note", str(note_file), "--json", "--seed", "42", "--top-p", "0.9"],
+    )
+
+    assert result.exit_code == 0
+    assert captured_kwargs.get("seed") == 42
+    assert captured_kwargs.get("top_p") == 0.9
+
+
+def test_idea_seed_and_top_p_optional(monkeypatch: pytest.MonkeyPatch, note_file: Path, fake: Faker) -> None:
+    ideas = _fake_ideas(fake)
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> IdeaList:
+        captured_kwargs.update(kwargs)
+        return ideas
+
+    monkeypatch.setattr("scribae.idea_cli.generate_ideas", capture_generate)
+
+    result = runner.invoke(app, ["idea", "--note", str(note_file), "--json"])
+
+    assert result.exit_code == 0
+    assert captured_kwargs.get("seed") is None
+    assert captured_kwargs.get("top_p") is None
diff --git a/tests/unit/main_test.py b/tests/unit/main_test.py
index 6936dfc..510e243 100644
--- a/tests/unit/main_test.py
+++ b/tests/unit/main_test.py
@@ -186,3 +186,49 @@ def test_help_flag_outputs_help() -> None:
 
     assert result.exit_code == 0
     assert "Scribae — turn local Markdown notes into ideas, SEO briefs" in result.stdout
+
+
+def test_brief_passes_seed_and_top_p(monkeypatch: pytest.MonkeyPatch, note_file: Path, fake: Faker) -> None:
+    brief_obj = _fake_brief(fake)
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> SeoBrief:
+        captured_kwargs.update(kwargs)
+        return brief_obj
+
+    monkeypatch.setattr("scribae.brief_cli.brief.generate_brief", capture_generate)
+
+    result = runner.invoke(
+        app,
+        [
+            "brief",
+            "--note",
+            str(note_file),
+            "--json",
+            "--seed",
+            "42",
+            "--top-p",
+            "0.9",
+        ],
+    )
+
+    assert result.exit_code == 0
+    assert captured_kwargs.get("seed") == 42
+    assert captured_kwargs.get("top_p") == 0.9
+
+
+def test_brief_seed_and_top_p_optional(monkeypatch: pytest.MonkeyPatch, note_file: Path, fake: Faker) -> None:
+    brief_obj = _fake_brief(fake)
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> SeoBrief:
+        captured_kwargs.update(kwargs)
+        return brief_obj
+
+    monkeypatch.setattr("scribae.brief_cli.brief.generate_brief", capture_generate)
+
+    result = runner.invoke(app, ["brief", "--note", str(note_file), "--json"])
+
+    assert result.exit_code == 0
+    assert captured_kwargs.get("seed") is None
+    assert captured_kwargs.get("top_p") is None
diff --git a/tests/unit/meta_cli_test.py b/tests/unit/meta_cli_test.py
index 0755bc3..0116a27 100644
--- a/tests/unit/meta_cli_test.py
+++ b/tests/unit/meta_cli_test.py
@@ -392,3 +392,89 @@ def test_meta_reports_fabricated_fields_reason(
     assert result.exit_code == 0, result.stderr
     assert stub.prompts
     assert "reason: overwrite=missing with force_llm_on_missing" in result.stderr
+
+
+def test_meta_passes_seed_and_top_p(
+    monkeypatch: pytest.MonkeyPatch, body_without_frontmatter: Path, brief_path: Path, tmp_path: Path
+) -> None:
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> ArticleMeta:
+        captured_kwargs.update(kwargs)
+        return ArticleMeta(
+            title="LLM Title",
+            slug="llm-slug",
+            excerpt="Summary",
+            tags=["tag"],
+            reading_time=4,
+            language="en",
+        )
+
+    monkeypatch.setattr("scribae.meta_cli.generate_metadata", capture_generate)
+
+    output_path = tmp_path / "meta.json"
+    result = runner.invoke(
+        app,
+        [
+            "meta",
+            "--body",
+            str(body_without_frontmatter),
+            "--brief",
+            str(brief_path),
+            "--format",
+            "json",
+            "--overwrite",
+            "all",
+            "--out",
+            str(output_path),
+            "--seed",
+            "42",
+            "--top-p",
+            "0.9",
+        ],
+    )
+
+    assert result.exit_code == 0, result.stderr
+    assert captured_kwargs.get("seed") == 42
+    assert captured_kwargs.get("top_p") == 0.9
+
+
+def test_meta_seed_and_top_p_optional(
+    monkeypatch: pytest.MonkeyPatch, body_without_frontmatter: Path, brief_path: Path, tmp_path: Path
+) -> None:
+    captured_kwargs: dict[str, object] = {}
+
+    def capture_generate(*args: object, **kwargs: object) -> ArticleMeta:
+        captured_kwargs.update(kwargs)
+        return ArticleMeta(
+            title="LLM Title",
+            slug="llm-slug",
+            excerpt="Summary",
+            tags=["tag"],
+            reading_time=4,
+            language="en",
+        )
+
+    monkeypatch.setattr("scribae.meta_cli.generate_metadata", capture_generate)
+
+    output_path = tmp_path / "meta.json"
+    result = runner.invoke(
+        app,
+        [
+            "meta",
+            "--body",
+            str(body_without_frontmatter),
+            "--brief",
+            str(brief_path),
+            "--format",
+            "json",
+            "--overwrite",
+            "all",
+            "--out",
+            str(output_path),
+        ],
+    )
+
+    assert result.exit_code == 0, result.stderr
+    assert captured_kwargs.get("seed") is None
+    assert captured_kwargs.get("top_p") is None
diff --git a/tests/unit/translate_cli_test.py b/tests/unit/translate_cli_test.py
index 287e767..890a63e 100644
--- a/tests/unit/translate_cli_test.py
+++ b/tests/unit/translate_cli_test.py
@@ -48,6 +48,7 @@ class DummyPostEditor:
     def __init__(self, *args: object, **kwargs: object) -> None:
         self.args = args
         self.kwargs = kwargs
+        calls["postedit_kwargs"] = dict(kwargs)
 
     def prefetch_language_model(self) -> None:
         calls["postedit_prefetch"] = True
@@ -527,3 +528,53 @@ def test_translate_configures_library_logging_respects_verbose(monkeypatch: pyte
     assert os.environ["TOKENIZERS_PARALLELISM"] == "false"
     assert os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] == "1"
     assert os.environ["TRANSFORMERS_VERBOSITY"] == "error"
+
+
+def test_translate_passes_postedit_seed_and_top_p(
+    stub_translation_components: dict[str, Any],
+    input_markdown: Path,
+) -> None:
+    result = runner.invoke(
+        app,
+        [
+            "translate",
+            "--src",
+            "en",
+            "--tgt",
+            "de",
+            "--in",
+            str(input_markdown),
+            "--postedit-seed",
+            "42",
+            "--postedit-top-p",
+            "0.9",
+        ],
+    )
+
+    assert result.exit_code == 0
+    postedit_kwargs = stub_translation_components["postedit_kwargs"]
+    assert postedit_kwargs.get("seed") == 42
+    assert postedit_kwargs.get("top_p") == 0.9
+
+
+def test_translate_postedit_seed_and_top_p_optional(
+    stub_translation_components: dict[str, Any],
+    input_markdown: Path,
+) -> None:
+    result = runner.invoke(
+        app,
+        [
+            "translate",
+            "--src",
+            "en",
+            "--tgt",
+            "de",
+            "--in",
+            str(input_markdown),
+        ],
+    )
+
+    assert result.exit_code == 0
+    postedit_kwargs = stub_translation_components["postedit_kwargs"]
+    assert postedit_kwargs.get("seed") is None
+    assert postedit_kwargs.get("top_p") is None
diff --git a/tests/unit/write_cli_test.py b/tests/unit/write_cli_test.py
index c43482e..b57bc3c 100644
--- a/tests/unit/write_cli_test.py
+++ b/tests/unit/write_cli_test.py
@@ -29,9 +29,19 @@ def brief_path(fixtures_dir: Path) -> Path:
 class RecordingLLM:
     def __init__(self) -> None:
         self.prompts: list[str] = []
-
-    def __call__(self, prompt: str, *, model_name: str, temperature: float) -> str:
+        self.kwargs_log: list[dict[str, object]] = []
+
+    def __call__(
+        self,
+        prompt: str,
+        *,
+        model_name: str,
+        temperature: float,
+        top_p: float | None = None,
+        seed: int | None = None,
+    ) -> str:
         self.prompts.append(prompt)
+        self.kwargs_log.append({"model_name": model_name, "temperature": temperature, "top_p": top_p, "seed": seed})
         section_title = ""
         for line in prompt.splitlines():
             if line.startswith("Current Section:"):
@@ -251,3 +261,47 @@ def test_write_faq_generates_section(recording_llm: RecordingLLM, note_path: Pat
     assert "## FAQ" in body
     assert "FAQ body." in body
     assert len(recording_llm.prompts) == 7
+
+
+def test_write_passes_seed_and_top_p(recording_llm: RecordingLLM, note_path: Path, brief_path: Path) -> None:
+    result = runner.invoke(
+        app,
+        [
+            "write",
+            "--note",
+            str(note_path),
+            "--brief",
+            str(brief_path),
+            "--section",
+            "1..1",
+            "--seed",
+            "42",
+            "--top-p",
+            "0.9",
+        ],
+    )
+
+    assert result.exit_code == 0, result.stderr
+    assert len(recording_llm.kwargs_log) == 1
+    assert recording_llm.kwargs_log[0]["seed"] == 42
+    assert recording_llm.kwargs_log[0]["top_p"] == 0.9
+
+
+def test_write_seed_and_top_p_optional(recording_llm: RecordingLLM, note_path: Path, brief_path: Path) -> None:
+    result = runner.invoke(
+        app,
+        [
+            "write",
+            "--note",
+            str(note_path),
+            "--brief",
+            str(brief_path),
+            "--section",
+            "1..1",
+        ],
+    )
+
+    assert result.exit_code == 0, result.stderr
+    assert len(recording_llm.kwargs_log) == 1
+    assert recording_llm.kwargs_log[0]["seed"] is None
+    assert recording_llm.kwargs_log[0]["top_p"] is None
diff --git a/tests/unit/write_language_test.py b/tests/unit/write_language_test.py
index d98ddc2..4ad0563 100644
--- a/tests/unit/write_language_test.py
+++ b/tests/unit/write_language_test.py
@@ -28,7 +28,14 @@ def test_section_title_language_correction(monkeypatch: pytest.MonkeyPatch, note
     title_calls = iter(["Introduction to Observability", "Einführung in Observability"])
     invoked_prompts: list[str] = []
 
-    def fake_invoke(prompt: str, *, model_name: str, temperature: float) -> str:
+    def fake_invoke(
+        prompt: str,
+        *,
+        model_name: str,
+        temperature: float,
+        top_p: float | None = None,
+        seed: int | None = None,
+    ) -> str:
         invoked_prompts.append(prompt)
         if "Rewrite the following section title" in prompt:
             return next(title_calls)

From 8618669e8ff56fe9378c08c657815faa515fe66d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20M=C3=BCller?=
Date: Thu, 1 Jan 2026 13:36:50 +0100
Subject: [PATCH 2/2] refactor: add keyword-only argument for
 `language_detector` and reorder arguments in `prepare_context`

- Made `language_detector` a keyword-only argument in `_validate_language` for improved clarity and API consistency.
- Adjusted argument order in `prepare_context` to group keyword-only arguments together in `brief` and `idea` modules.
---
 src/scribae/brief.py    | 2 +-
 src/scribae/idea.py     | 2 +-
 src/scribae/language.py | 5 +++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/scribae/brief.py b/src/scribae/brief.py
index 87cfd8c..518e780 100644
--- a/src/scribae/brief.py
+++ b/src/scribae/brief.py
@@ -146,8 +146,8 @@ class BriefingContext:
 
 
 def prepare_context(
-    note_path: Path,
     *,
+    note_path: Path,
     project: ProjectConfig,
     max_chars: int,
     language: str | None = None,
diff --git a/src/scribae/idea.py b/src/scribae/idea.py
index 065cf62..e67c0de 100644
--- a/src/scribae/idea.py
+++ b/src/scribae/idea.py
@@ -88,8 +88,8 @@ class IdeaContext:
 
 
 def prepare_context(
-    note_path: Path,
     *,
+    note_path: Path,
     project: ProjectConfig,
     max_chars: int,
     language: str | None = None,
diff --git a/src/scribae/language.py b/src/scribae/language.py
index 8adb85b..fb5f9e5 100644
--- a/src/scribae/language.py
+++ b/src/scribae/language.py
@@ -93,14 +93,14 @@ def ensure_language_output(
     first_result = invoke(prompt)
 
     try:
-        _validate_language(extract_text(first_result), expected_language, language_detector)
+        _validate_language(extract_text(first_result), expected_language, language_detector=language_detector)
         return first_result
     except LanguageMismatchError as first_error:
         _report(reporter, str(first_error) + " Retrying with language correction.")
 
     corrective_prompt = _append_language_correction(prompt, expected_language)
     second_result = invoke(corrective_prompt)
-    _validate_language(extract_text(second_result), expected_language, language_detector)
+    _validate_language(extract_text(second_result), expected_language, language_detector=language_detector)
     return second_result
 
 
@@ -115,6 +115,7 @@ def _append_language_correction(prompt: str, expected_language: str) -> str:
 def _validate_language(
     text: str,
     expected_language: str,
+    *,
     language_detector: Callable[[str], str] | None = None,
 ) -> None:
     detected = _detect_language(text, language_detector)