Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prich/cli/dynamic_command_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def create_dynamic_command(config, template: TemplateModel) -> click.Command:
click.Option(["-p", "--provider"], type=click.Choice(config.providers.keys()), show_default=True,
help="Override LLM provider"),
click.Option(["-v", "--verbose"], is_flag=True, default=False, help="Verbose mode"),
click.Option(["-d", "--debug"], is_flag=True, default=False, help="Debug mode"),
click.Option(["-q", "--quiet"], is_flag=True, default=False, help="Suppress all output"),
click.Option(["-f", "--only-final-output"], is_flag=True, default=False,
help="Suppress output and show only the last step output")
Expand Down
2 changes: 1 addition & 1 deletion prich/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# and are not allowed to be user/defined in the template variables cli_option
RESERVED_RUN_TEMPLATE_CLI_OPTIONS = [
"-g", "--global", "-q", "--quiet", "-o", "--output", "-p", "--provider",
"-f", "--only-final-output", "-v", "--verbose"
"-f", "--only-final-output", "-v", "--verbose", "-d", "--debug"
]

# .prich folder name
Expand Down
381 changes: 260 additions & 121 deletions prich/core/engine.py

Large diffs are not rendered by default.

27 changes: 24 additions & 3 deletions prich/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,18 @@ def is_print_enabled() -> bool:
def is_piped() -> bool:
""" Check if prich executed with a piped command (should work only when not executed from pytest) """
# TODO: revisit, we need to allow executions from templates for example
return not console.is_terminal and not os.getenv("PYTEST_CURRENT_TEST")
# return False
# return not console.is_terminal and not os.getenv("PYTEST_CURRENT_TEST")
return False

def console_print(message: str = "", end: str = "\n", markup = None, flush: bool = None):
def console_print(message: str = "", end: str = "\n", markup = None):
""" Print to console wrapper """
if is_print_enabled():
console.print(message, end=end, markup=markup, crop=False)

def console_print_debug(message: str = "", end: str = "\n", markup = None):
""" Print to console wrapper with debug prefix """
console_print(f"[[yellow]debug[/yellow]] {message}", end=end, markup=markup)

def is_valid_template_id(template_id) -> bool:
""" Validate Name Pattern: lowercase letters, numbers, hyphen, optional underscores, and no other characters"""
pattern = r'^[a-z0-9-]+(_[a-z0-9-]+)*$'
Expand Down Expand Up @@ -133,3 +137,20 @@ def is_just_filename(filename: Path | str):
if s in ("", ".", "..") or ("/" in s) or ("\\" in s):
return False
return True

def models_equal(a, b, *,
exclude_unset=False,
exclude_none=False,
by_alias=False,
strict_type=True):
if strict_type and type(a) is not type(b):
return False
return (
a.model_dump(exclude_unset=exclude_unset,
exclude_none=exclude_none,
by_alias=by_alias)
==
b.model_dump(exclude_unset=exclude_unset,
exclude_none=exclude_none,
by_alias=by_alias)
)
14 changes: 14 additions & 0 deletions prich/models/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,20 @@ def str_presenter(dumper, data):
with open(template_file, "w") as f:
f.write(yaml.safe_dump(model_dict, sort_keys=False, width=float("inf")))

def prepare_variables(self, cli_options, env_vars) -> dict:
from prich.core.variable_utils import replace_env_vars
variables = {}
for var in self.variables:
cli_option = var.cli_option
if cli_option:
option_name = cli_option.lstrip("-").replace("-", "_")
variables[var.name] = replace_env_vars(cli_options.get(option_name, var.default), env_vars)
else:
variables[var.name] = replace_env_vars(cli_options.get(var.name, var.default), env_vars)
if var.required and variables.get(var.name) is None:
raise click.ClickException(f"Missing required variable {var.name}")
return variables

def describe(self):
return f"""
Template: {self.id}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_validate_step_output(mock_paths, basic_config, case):
},
]
@pytest.mark.parametrize("case", get_run_command_step_CASES, ids=[c["id"] for c in get_run_command_step_CASES])
def test_run_command_step(case, monkeypatch):
def test_run_command_step(case, mock_paths, monkeypatch):
from prich.core.steps.step_run_command import run_command_step

if case.get("mock_output"):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_run_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,9 +928,40 @@ def test_run_template(case, monkeypatch, basic_config):
"expected_regex_output": ["^### System(?:.|\\n)+### Assistant:\\n$"]},
{"id": "run_local_template_id_quiet", "add_template": True, "args": ["template-local", "--quiet"],
"expected_regex_output": ["^$"]},
{"id": "run_local_template_id_debug_continue", "add_template": True, "args": ["template-local", "--debug"],
"expected_regex_output": [
"\\[debug\\] Initial variables:",
"name = Assistant",
"test_output = This is my text",
"Stash variables before step 1",
"--- Step llm llm step #1",
"\\[debug\\] \\(c\\)ontinue, \\(r\\)epeat step",
], "key_press": ["c"]},
{"id": "run_local_template_id_debug_repeat_continue", "add_template": True, "args": ["template-local", "--debug"],
"expected_regex_output": [
": r\n\\[debug\\] Template change detected - reloading",
"stas\\(h\\): c"
], "key_press": ["r", "c"]},
{"id": "run_local_template_id_debug_list_vars", "add_template": True, "args": ["template-local", "--debug"],
"expected_regex_output": [
"stas\\(h\\): l\n\\[debug\\] Variables:\n\\[debug\\]"
], "key_press": ["l"]},
{"id": "run_local_template_id_debug_list_var_hashes", "add_template": True, "args": ["template-local", "--debug"],
"expected_regex_output": [
"stas\\(h\\): h\n\\[debug\\] Stashed variables\n\\[debug\\]"
], "key_press": ["h"]},
{"id": "run_local_template_id_debug_show_step", "add_template": True, "args": ["template-local", "--debug"],
"expected_regex_output": [
"stas\\(h\\): s\n\\[debug\\] Step #1:\nname: llm step\n"
], "key_press": ["s"]},
]
@pytest.mark.parametrize("case", get_run_template_cli_CASES, ids=[c["id"] for c in get_run_template_cli_CASES])
def test_run_template_cli(mock_paths, monkeypatch, case, template, basic_config):
def get_key_press(count):
res = case.get("key_press")[count["count"]]
count["count"] += 1
return res

global_dir = mock_paths.home_dir
local_dir = mock_paths.cwd_dir

Expand All @@ -939,6 +970,9 @@ def test_run_template_cli(mock_paths, monkeypatch, case, template, basic_config)

monkeypatch.setattr("prich.core.loaders.get_cwd_dir", lambda: local_dir)
monkeypatch.setattr("prich.core.loaders.get_home_dir", lambda: global_dir)
if case.get("key_press"):
count = {"count": 0}
monkeypatch.setattr("click.getchar", lambda: get_key_press(count))

local_config = basic_config.model_copy(deep=True)
global_config = basic_config.model_copy(deep=True)
Expand Down