From 00e60f3c5b4e3a31b84a99247ea4cb38084c166f Mon Sep 17 00:00:00 2001
From: psaegert
Date: Thu, 23 Oct 2025 23:28:40 +0200
Subject: [PATCH] Improve docstrings and documentation

---
 README.md              |   4 +-
 docs/index.md          | 105 +++++++++++++++++-
 pyproject.toml         |   2 +-
 src/simplipy/engine.py |  75 +++++++++----
 src/simplipy/utils.py  | 238 +++++++++++++++++++++++++----------------
 tests/test_utils.py    |  14 +++
 6 files changed, 322 insertions(+), 116 deletions(-)

diff --git a/README.md b/README.md
index f3b7301..cf6cb50 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ engine.simplify(('/', '<num>', '*', '/', '*', 'x3', '<num>', 'x3', 'lo
 
 # Simplify infix expressions
 engine.simplify('x3 * sin(<num> + 1) / (x3 * x3)')
-# > '(<num> / x3)'
+# > '<num> / x3'
 ```
 
 More examples can be found in the [documentation](https://simplipy.readthedocs.io/).
@@ -88,7 +88,7 @@ pytest tests --cov src --cov-report html -m "not integration"
   title = {Efficient Simplification of Mathematical Expressions},
   year = 2025,
   publisher = {GitHub},
-  version = {0.2.8},
+  version = {0.2.9},
   url = {https://github.com/psaegert/simplipy}
 }
 ```
diff --git a/docs/index.md b/docs/index.md
index 8e904eb..4c8dd99 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,4 +1,103 @@
-# Home
-
-This page is under construction.
-Check out the [API Reference](api.md) in the meantime.
\ No newline at end of file
+# SimpliPy Documentation
+
+SimpliPy is a high-throughput symbolic simplifier built for workloads where
+classic tools like SymPy struggle, such as simplifying millions of expressions
+in the pre-training of Flash-ANSR's prefix-based transformer models. Instead
+of converting tokens into heavyweight objects and back again, SimpliPy keeps
+expressions as lightweight prefix lists, enabling rapid rewriting and direct
+integration with machine learning pipelines.
+
+
+## Why SimpliPy Exists
+
+SymPy excels at exact algebra, but its object graph and string parsing introduce
+costs that dominate at scale. SimpliPy was created to remove those bottlenecks:
+
+- **Prefix-first representation** – Expressions stay as token lists the entire
+  time, so there's no repeated parsing or AST allocation.
+- **Deterministic pipelines** – Rule application, operand sorting, and literal
+  masking always produce the same layout, which keeps downstream caches warm.
+- **GPU-friendly integration** – Outputs map directly into Flash-ANSR's input
+  space without any conversion step, making it practical to simplify millions of
+  candidates per minute.
+
+
+## Simplification Pipeline (Pseudo-Algorithm)
+
+```text
+function simplify(expr, max_iter=5):
+    tokens = parse(expr)        # infix→prefix or validate existing prefix
+    tokens = normalize(tokens)  # power folding, unary handling
+
+    for _ in range(max_iter):
+        tokens = cancel_terms(tokens)   # additive/multiplicative multiplicities
+        tokens = apply_rules(tokens)    # compiled rewrite patterns
+        tokens = sort_operands(tokens)  # canonical order for commutative ops
+        tokens = mask_literals(tokens)  # collapse trivial numerics to <num>
+
+        if converged(tokens):
+            break
+
+    return finalize(tokens)     # prefix list or infix string, caller’s choice
+```
+
+This loop is intentionally lightweight: each pass performs a handful of pure
+list transformations, giving you predictable performance even on nested or noisy
+expressions.
+
+
+## Key Components
+
+- **Parsing & normalization** – `SimpliPyEngine.parse` and
+  `convert_expression` convert infix input, harmonize power operators, and
+  propagate unary negation without losing prefix fidelity, as sketched below.
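+
+  A minimal sketch of a parse call (assuming `import simplipy as sp` and the
+  `dev_7-3` engine from the quickstart below; the exact token layout depends
+  on the engine's operator configuration):
+
+  ```python
+  engine = sp.SimpliPyEngine.load("dev_7-3", install=True)
+  engine.parse("x1 + x2 * x3")
+  # e.g. ['+', 'x1', '*', 'x2', 'x3']
+  ```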
+- **Term cancellation** – `collect_multiplicities` and `cancel_terms` identify
+  subtrees that appear with opposite parity or redundant factors, pruning them
+  before any rules run.
+- **Rule execution** – `compile_rules` turns machine-discovered or human-authored
+  simplifications into tree patterns. `apply_simplifcation_rules` then performs
+  fast top-down matching in each iteration.
+- **Canonical ordering** – `sort_operands` imposes a stable ordering for
+  commutative operators, ensuring identical expressions share identical token
+  layouts.
+- **Rule discovery workflow** – `find_rules` explores expression space in
+  parallel worker processes, confirms identities with numeric sampling, and
+  writes back deduplicated rulesets that future engines can load instantly.
+
+
+## Quickstart
+
+```bash
+pip install simplipy
+```
+
+```python
+import simplipy as sp
+
+engine = sp.SimpliPyEngine.load("dev_7-3", install=True)
+
+# Simplify prefix expressions
+engine.simplify(['/', '<num>', '*', '/', '*', 'x3', '<num>', 'x3', 'log', 'x3'])
+# -> ['/', '<num>', 'log', 'x3']
+
+# Simplify infix expressions
+engine.simplify('x3 * sin(<num> + 1) / (x3 * x3)')
+# -> '<num> / x3'
+```
+
+Available engines can be browsed and downloaded from Hugging Face.
+The SimpliPy Asset Manager handles listing, installing, and uninstalling assets:
+
+```python
+sp.list_assets("engine")
+# --- Available Assets ---
+# - dev_7-3 [installed]  Development engine 7-3 for mathematical expression simplification.
+# - dev_7-2              Development engine 7-2 for mathematical expression simplification.
+```
+
+## Where to Go Next
+
+- Explore the [API reference](api.md) for function-level details.
+- Read the [rule authoring guide](rules.md) to build simplification rule sets.
+
+Happy simplifying!
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0dc8518..d67a21d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.11"
 dynamic = ["dependencies"]
-version = "0.2.8"
+version = "0.2.9"
 license = "MIT"
 license-files = ["LICEN[CS]E*"]
 
diff --git a/src/simplipy/engine.py b/src/simplipy/engine.py
index 94f6349..934f893 100644
--- a/src/simplipy/engine.py
+++ b/src/simplipy/engine.py
@@ -69,7 +69,7 @@ class SimpliPyEngine:
         A compiled version of explicit rules without pattern variables.
     """
     def __init__(self, operators: dict[str, dict[str, Any]], rules: list[tuple] | None = None) -> None:
-        # This part, which sets up all the operator properties, is unchanged.
+        # Cache operator metadata for quick access during parsing and evaluation.
         self.operator_tokens = list(operators.keys())
         self.operator_aliases = {alias: operator for operator, properties in operators.items() for alias in properties['alias']}
         self.operator_inverses = {k: v["inverse"] for k, v in operators.items() if v.get("inverse") is not None}
@@ -105,16 +105,14 @@ def __init__(self, operators: dict[str, dict[str, Any]], rules: list[tuple] | No
         self.connection_classes_hyper = {'add': "mult", 'mult': "pow"}
         self.binary_connectable_operators = {'+', '-', '*', '/'}
 
-        # This is the simplified rules handling logic.
-        # It no longer checks if `rules` is a string or performs any file I/O.
-        # It only accepts a list of rules or None.
+        # Normalize the incoming rule list and eliminate duplicate patterns.
         dummy_variables = [f'x{i}' for i in range(100)]
         if rules is None:
            self.simplification_rules = []
        else:
            self.simplification_rules = deduplicate_rules(rules, dummy_variables=dummy_variables)
 
-        # This part is also unchanged.
+ # Build the compiled lookup tables that power rule application. self.compile_rules() self.rule_application_statistics: defaultdict[tuple, int] = defaultdict(int) @@ -268,7 +266,31 @@ def is_valid(self, prefix_expression: list[str], verbose: bool = False) -> bool: return True def prefix_to_infix(self, tokens: list[str], power: Literal['func', '**'] = 'func', realization: bool = False) -> str: - """Converts a prefix expression to a human-readable infix string with minimal parentheses.""" + """Converts a prefix expression to an infix string with minimal parentheses. + + Parameters + ---------- + tokens : list[str] + The prefix expression to render. + power : {'func', '**'}, optional + Controls how power operators are emitted. ``'func'`` keeps canonical + engine names such as ``pow3(x)``, while ``'**'`` renders Python-style + exponentiation. + realization : bool, optional + If True, operator tokens are replaced with their runtime + realizations (for example, ``'sin'`` becomes ``'np.sin'``), so the + output can be compiled directly. + + Returns + ------- + str + The formatted infix expression. + + Raises + ------ + ValueError + If the provided tokens do not form a well-formed prefix expression. + """ if not tokens: return '' @@ -688,7 +710,9 @@ def parse( """Parses an infix string into a standardized prefix expression. This is a high-level parsing utility that combines `infix_to_prefix` - with optional conversion and number masking steps. + with optional canonicalization and number masking. The resulting token + list is additionally cleaned up via `remove_pow1` to drop redundant + ``pow1_1`` occurrences. Parameters ---------- @@ -704,7 +728,8 @@ def parse( Returns ------- list[str] - The final processed prefix expression. + The processed prefix expression after conversion, masking (if + enabled), and `remove_pow1` cleanup. """ parsed_expression = self.infix_to_prefix(infix_expression) @@ -1023,11 +1048,15 @@ def collect_multiplicities(self, expression: list[str] | tuple[str, ...], verbos Returns ------- expression_tree : list - The expression represented as a tree. + A stack-based representation of the expression tree. Each entry is a + nested list of the form ``[operator, operands]`` mirroring the + structure consumed by `cancel_terms`. annotations_tree : list - A parallel tree containing the multiplicity counts for each subtree. + A parallel stack holding multiplicity annotations for each subtree, + organized by connection class. labels_tree : list - A parallel tree containing unique identifiers for each subtree. + A parallel stack containing stable identifiers for every subtree, + used to detect duplicates during cancellation. """ stack: list = [] stack_annotations: list = [] @@ -1133,18 +1162,22 @@ def cancel_terms(self, expression_tree: list, expression_annotations_tree: list, Parameters ---------- expression_tree : list - The nested list representation of the expression. + The stack produced by `collect_multiplicities`, containing the + nested expression structure. expression_annotations_tree : list - The corresponding tree of multiplicity annotations. + The parallel stack of multiplicity annotations returned by + `collect_multiplicities`. stack_labels : list - The corresponding tree of subtree labels. + The parallel stack of subtree labels returned by + `collect_multiplicities`. verbose : bool, optional If True, prints detailed debugging information. Defaults to False. Returns ------- list[str] - A new prefix expression with terms cancelled. 
+ A simplified prefix expression with the detected duplicates merged + or removed. """ stack = expression_tree stack_annotations = expression_annotations_tree @@ -1637,11 +1670,14 @@ def find_rule_worker( constants_fit_retries: int) -> None: """A worker process for discovering simplification rules in parallel. - This function runs in a separate process. It fetches an expression from - the `work_queue`, evaluates it on a set of random numerical data, and + This function runs in a separate process. It fetches work items of the + form ``(expression, simplified_length, allowed_candidate_lengths)`` from + `work_queue`, evaluates the expression on shared random data, and compares the result against a library of simpler candidate expressions. If a numerical equivalence is found, it is considered a potential new - simplification rule and is placed on the `result_queue`. + simplification rule and is placed on the `result_queue`; otherwise ``None`` + is queued to signal that no rule was discovered. A sentinel ``None`` work + item triggers a graceful shutdown. Notes ----- @@ -1803,7 +1839,8 @@ def find_rules( Equivalences are found by evaluating both expressions on random numerical data. - Discovered rules are added to the engine and can be saved to a file. + Discovered rules are deduplicated, compiled into the running engine, and + can optionally be saved to disk. Parameters ---------- diff --git a/src/simplipy/utils.py b/src/simplipy/utils.py index 52943d1..9701b44 100644 --- a/src/simplipy/utils.py +++ b/src/simplipy/utils.py @@ -10,11 +10,12 @@ def apply_on_nested(structure: list | dict, func: Callable) -> list | dict: - """Recursively apply a function to all non-dict/list values in a nested structure. + """Recursively apply a function to all non-structural values in a nested container. - This function traverses a nested dictionary or list and applies the provided - function `func` to every value that is not a dictionary or a list itself. - The modification is done in-place. + This function traverses a nested dictionary or list and applies ``func`` to + every value that is not itself a ``dict`` or ``list``. The original + ``structure`` is mutated; the same instance is returned for convenience. If + ``structure`` is neither a list nor a dictionary, it is returned unchanged. Parameters ---------- @@ -26,13 +27,16 @@ def apply_on_nested(structure: list | dict, func: Callable) -> list | dict: Returns ------- list or dict - The modified nested structure with the function applied to its values. + The input ``structure`` with ``func`` applied to all terminal values. Examples -------- >>> data = {'a': 1, 'b': {'c': 2, 'd': [{'e': 3}, {'f': 4}, 3]}} - >>> sp.utils.apply_on_nested(data, lambda x: x * 10) + >>> result = apply_on_nested(data, lambda x: x * 10) + >>> result {'a': 10, 'b': {'c': 20, 'd': [{'e': 30}, {'f': 40}, 30]}} + >>> data is result + True """ if isinstance(structure, list): for i, value in enumerate(structure): @@ -120,11 +124,12 @@ def codify(code_string: str, variables: list[str] | None = None) -> CodeType: def get_used_modules(infix_expression: str) -> list[str]: - """Extract top-level Python modules used in an infix expression string. + """Return the names of top-level Python modules referenced in an infix expression. - Parses a string to find all occurrences of module-like function calls - (e.g., `numpy.sin(...)`, `math.cos(...)`) and returns a unique list of the - top-level modules. The 'numpy' module is always included by default. 
+    The function scans for dotted attribute accesses that look like module
+    usages (for example ``numpy.sin(...)`` or ``math.cos(...)``) and collects
+    their leading module names. The module ``numpy`` is always included so that
+    downstream evaluation logic can rely on it being available.
 
     Parameters
     ----------
@@ -134,11 +139,12 @@
     Returns
     -------
     list[str]
-        A list of unique top-level module names found in the expression.
+        Unique module names referenced in ``infix_expression``. The order is
+        derived from the underlying ``set`` and should be treated as arbitrary.
 
     Examples
     --------
-    >>> get_used_modules("numpy.sin(x) + math.exp(y)")
+    >>> sorted(get_used_modules("numpy.sin(x) + math.exp(y)"))
     ['math', 'numpy']
     """
     # Match the expression against `module.submodule. ... .function(`
@@ -158,41 +164,47 @@
 def substitude_constants(prefix_expression: list[str], values: list | np.ndarray, constants: list[str] | None = None, inplace: bool = False) -> list[str]:
     """Substitute placeholders in a prefix expression with numeric values.
 
-    This function replaces constant placeholders like `<num>` or `C_i`
-    in a prefix-notated expression with the provided numerical values in order.
+    This helper replaces constant placeholders such as ``"<num>"`` or the
+    tokens listed in ``constants`` with the values supplied in ``values``. Values
+    are consumed from left to right as matching tokens are encountered.
 
     Parameters
     ----------
     prefix_expression : list[str]
         The prefix expression containing constant placeholders.
     values : list or np.ndarray
-        The numerical values to substitute into the expression.
+        The numeric values to substitute into the expression.
     constants : list[str] or None, optional
-        An explicit list of constant names to be replaced, by default None.
+        An explicit list of placeholder names to be replaced. When ``None``,
+        the function considers ``"<num>"`` and ``C_i`` tokens. Defaults to
+        ``None``.
     inplace : bool, optional
-        If True, modifies the list in-place; otherwise, returns a new list.
-        Defaults to False.
+        If ``True``, modifies ``prefix_expression`` in-place; otherwise, works on
+        a shallow copy. Defaults to ``False``.
 
     Returns
     -------
     list[str]
-        The prefix expression with placeholders replaced by values.
+        The prefix expression with placeholders replaced by strings holding the
+        given numeric values.
+
+    Raises
+    ------
+    IndexError
+        If there are more placeholders than supplied ``values``.
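+
+    Notes
+    -----
+    Tokens that do not match a recognized placeholder are left untouched; only
+    placeholder tokens consume entries of ``values``.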
 
     Examples
     --------
-    With default constant placeholders:
     >>> expr = ['*', '<num>', '+', 'x', '<num>']
-    >>> substitude_constants(expr, [3.14, 2.71], constants=None)
+    >>> substitude_constants(expr, [3.14, 2.71])
     ['*', '3.14', '+', 'x', '2.71']
 
-    With default constant names:
     >>> expr = ['*', 'C_0', '+', 'x', 'C_1']
-    >>> substitute_constants(expr, [3.14, 2.71], constants=['C_0', 'C_1'])
+    >>> substitude_constants(expr, [3.14, 2.71], constants=['C_0', 'C_1'])
     ['*', '3.14', '+', 'x', '2.71']
 
-    With custom constant names:
     >>> expr = ['*', 'k1', '+', 'x', 'k2']
-    >>> substitute_constants(expr, [3.14, 2.71], constants=['k1', 'k2'])
+    >>> substitude_constants(expr, [3.14, 2.71], constants=['k1', 'k2'])
     ['*', '3.14', '+', 'x', '2.71']
     """
     if inplace:
@@ -285,66 +297,74 @@ def numbers_to_constant(prefix_expression: list[str], inplace: bool = False) ->
 def explicit_constant_placeholders(prefix_expression: list[str], constants: list[str] | None = None, inplace: bool = False, convert_numbers_to_constant: bool = True) -> tuple[list[str], list[str]]:
-    """Convert numeric placeholders to indexed constant names (e.g., C_0, C_1).
+    """Convert placeholder tokens to explicit constant names (for example ``C_0``, ``C_1``).
 
-    Replaces `<num>` tokens and optionally numeric strings with unique,
-    indexed constant names. This prepares the expression for compilation into a
-    function where constants are passed as named arguments.
+    ``"<num>"`` tokens are replaced with explicit constant identifiers; when
+    ``convert_numbers_to_constant`` is ``True``, integer-like numeric strings
+    and existing ``C_i`` tokens are replaced as well. This is useful for
+    generating call signatures where constants are passed as named arguments.
 
     Parameters
     ----------
     prefix_expression : list[str]
         The prefix expression to process.
     constants : list[str] or None, optional
-        An initial list of constants to use for naming, by default None.
+        Initial constant names to reuse before generating new ones. Only the
+        names that are actually consumed appear in the returned list.
     inplace : bool, optional
-        If True, modifies the list in-place; otherwise, returns a new list.
-        Defaults to False.
+        If ``True``, modifies the input list; otherwise, works on a shallow copy.
+        Defaults to ``False``.
     convert_numbers_to_constant : bool, optional
-        If True, also convert numeric strings to indexed constants.
-        Defaults to True.
+        If ``True``, numeric strings consisting only of digits are also replaced.
+        Defaults to ``True``.
 
     Returns
     -------
     tuple[list[str], list[str]]
-        A tuple containing:
-        - The modified prefix expression.
-        - The list of constant names used.
+        Two items: the modified prefix expression and the list of constant
+        names used in order of appearance.
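+
+    Notes
+    -----
+    Provided ``constants`` are consumed in order before new ``C_i`` identifiers
+    are generated; provided names that are never consumed do not appear in the
+    returned list.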
 
     Examples
     --------
-    >>> expr = ['*', '<num>', '+', 'x', '2.5']
+    >>> expr = ['*', '<num>', '+', 'x', '2']
     >>> explicit_constant_placeholders(expr)
     (['*', 'C_0', '+', 'x', 'C_1'], ['C_0', 'C_1'])
+
+    >>> explicit_constant_placeholders(['+', 'C_3', '<num>'], constants=['K'])
+    (['+', 'K', 'C_0'], ['K', 'C_0'])
     """
     if inplace:
         modified_prefix_expression = prefix_expression
     else:
         modified_prefix_expression = prefix_expression.copy()
 
-    constant_index = 0
-    if constants is None:
-        constants = []
-    else:
-        constants = list(constants)
+    provided_constants = list(constants) if constants is not None else []
+    used_constants: list[str] = []
+    provided_index = 0
+    generated_index = 0
 
     for i, token in enumerate(prefix_expression):
         if token == "<num>" or (convert_numbers_to_constant and (re.match(r"C_\d+", token) or token.isnumeric())):
-            if constants is not None and len(constants) > constant_index:
-                modified_prefix_expression[i] = constants[constant_index]
+            if provided_index < len(provided_constants):
+                constant_name = provided_constants[provided_index]
+                provided_index += 1
             else:
-                modified_prefix_expression[i] = f"C_{constant_index}"
-                constants.append(f"C_{constant_index}")
-            constant_index += 1
+                constant_name = f"C_{generated_index}"
+                generated_index += 1
 
-    return modified_prefix_expression, constants
+            modified_prefix_expression[i] = constant_name
+            used_constants.append(constant_name)
+
+    return modified_prefix_expression, used_constants
 
 
 def flatten_nested_list(nested_list: list) -> list[str]:
-    """Flatten an arbitrarily nested list into a single list.
+    """Flatten an arbitrarily nested list into a single list of leaf values.
 
-    This function uses a non-recursive, stack-based approach to efficiently
-    flatten a nested list structure into a single flat list of elements.
+    A stack-based traversal is used to avoid recursion limits. Because a LIFO
+    stack is employed, values appear in reverse depth-first order relative to
+    the original nesting. ``list(reversed(...))`` can be used to restore a
+    left-to-right ordering if required.
 
     Parameters
     ----------
@@ -354,7 +374,7 @@
     Returns
     -------
     list[str]
-        The flattened list.
+        The flattened list of elements encountered during traversal.
 
     Examples
     --------
@@ -404,23 +424,37 @@ def is_prime(n: int) -> bool:
 def safe_f(f: Callable, X: np.ndarray, constants: np.ndarray | None = None) -> np.ndarray:
     """Safely evaluate a compiled function on an array of inputs.
 
-    This wrapper executes a function `f`, handling optional constants and
-    ensuring the output is always a NumPy array of the correct shape, even if
-    the function returns a scalar.
+    The callable ``f`` is invoked with the columns of ``X`` unpacked as separate
+    arguments, followed by any optional ``constants``. Scalar results are
+    broadcast to all samples to guarantee a one-dimensional NumPy array of
+    length ``X.shape[0]``.
 
     Parameters
     ----------
     f : Callable
         The function to evaluate.
     X : np.ndarray
-        The input data array, where rows are samples and columns are features.
+        Two-dimensional array of input samples. Each column is passed as a
+        positional argument to ``f``.
     constants : np.ndarray or None, optional
-        An array of constant values to pass to the function, by default None.
+        Extra constant values appended when calling ``f``. Defaults to ``None``.
 
     Returns
     -------
     np.ndarray
-        The result of the function evaluation as a NumPy array.
+        A one-dimensional array with the evaluation results for each row of
+        ``X``.
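+
+    Notes
+    -----
+    Because the columns of ``X`` are unpacked as positional arguments, ``f``
+    must accept exactly ``X.shape[1]`` arguments, plus one additional argument
+    per entry of ``constants`` when it is provided.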
+ + Examples + -------- + >>> import numpy as np + >>> f = lambda x, y: x + y + >>> safe_f(f, np.array([[1, 2], [3, 4]])) + array([3, 7]) + + >>> g = lambda x, y, c0: c0 + >>> safe_f(g, np.array([[1, 2], [3, 4]]), constants=np.array([5])) + array([5, 5]) """ if constants is None: y = f(*X.T) @@ -569,7 +603,8 @@ def factorize_to_at_most(p: int, max_factor: int, max_iter: int = 1000) -> list[ ------- list[int] The factors of ``p``. Their product is equal to ``p`` and each factor is - less than or equal to ``max_factor``. + less than or equal to ``max_factor``. The factors are yielded in the + order they are discovered and are not sorted. Raises ------ @@ -580,9 +615,9 @@ def factorize_to_at_most(p: int, max_factor: int, max_iter: int = 1000) -> list[ Examples -------- >>> factorize_to_at_most(100, 10) - [10, 10] + [4, 5, 5] >>> factorize_to_at_most(18, 5) - [3, 3, 2] + [2, 3, 3] """ if p < 1: @@ -673,29 +708,37 @@ def mask_elementary_literals(prefix_expression: list[str], inplace: bool = False def construct_expressions(expressions_of_length: dict[int, set[tuple[str, ...]]], non_leaf_nodes: dict[str, int], must_have_sizes: list | set | None = None) -> Generator[tuple[str, ...], None, None]: - """Generate new, larger expressions by combining existing smaller ones. + """Generate new prefix expressions by combining existing building blocks. - This generator function builds complex mathematical expressions by taking a - set of existing expressions (grouped by length) and combining them using - a given set of operators (non-leaf nodes). It systematically creates all - possible new valid expressions. + Expressions are grouped by length in ``expressions_of_length``. For each + operator in ``non_leaf_nodes`` the generator enumerates every compatible + tuple of child expressions and yields the resulting prefix encoding. When + ``must_have_sizes`` is provided, at least one operand must have a length + contained in that collection before the expression is yielded. Parameters ---------- expressions_of_length : dict[int, set[tuple[str, ...]]] - A dictionary mapping expression length to a set of expressions of that - length. These are the building blocks. + Mapping from expression length to the set of expressions with that + length. non_leaf_nodes : dict[str, int] - A dictionary of operators, mapping the operator token to its arity. + Mapping from operator tokens to their arity. must_have_sizes : list or set or None, optional - If provided, only generates combinations where at least one child - expression has a length present in this set. This is an optimization - to avoid redundant constructions. Defaults to None. + If provided, filters generated combinations so that at least one child + expression has a length contained in this collection. Defaults to + ``None``. Yields ------ tuple[str, ...] - A new, valid prefix expression constructed from the inputs. + Newly constructed prefix expressions. + + Examples + -------- + >>> expressions = {1: {('x',), ('y',)}} + >>> operators = {'+': 2} + >>> sorted(construct_expressions(expressions, operators)) + [('+', 'x', 'x'), ('+', 'x', 'y'), ('+', 'y', 'x'), ('+', 'y', 'y')] """ expressions_of_length_with_lists = {k: list(v) for k, v in expressions_of_length.items()} @@ -716,24 +759,32 @@ def construct_expressions(expressions_of_length: dict[int, set[tuple[str, ...]]] def apply_mapping(tree: list, mapping: dict[str, Any]) -> list: - """Apply a variable mapping to a target expression tree. + """Apply a placeholder-to-subtree mapping to a target expression tree. 
- This function is used after a successful pattern match. It takes a target - expression tree (which may contain placeholders like `_0`, `_1`) and a - mapping from those placeholders to actual subtrees. It returns a new tree - where all placeholders have been replaced by their corresponding subtrees. + Trees are represented as ``[operator, [operands...]]`` where each operand is + itself a tree. Leaves are encoded as one-element lists, for example + ``['x']``. Placeholders such as ``'_0'`` are replaced with the corresponding + subtree provided in ``mapping``. Parameters ---------- tree : list The target expression tree containing placeholders. mapping : dict[str, Any] - The dictionary mapping placeholders to subtrees. + Dictionary mapping placeholder names to the subtrees that should + replace them. Returns ------- list - The new expression tree with placeholders substituted. + A new expression tree with placeholders substituted. + + Examples + -------- + >>> template = ['mul', [['_0'], ['_1']]] + >>> mapping = {'_0': ['x'], '_1': ['add', [['y'], ['z']]]} + >>> apply_mapping(template, mapping) + ['mul', [['x'], ['add', [['y'], ['z']]]]] """ # If the tree is a leaf node, replace the placeholder with the actual subtree defined in the mapping if len(tree) == 1 and isinstance(tree[0], str): @@ -748,11 +799,10 @@ def apply_mapping(tree: list, mapping: dict[str, Any]) -> list: def match_pattern(tree: list, pattern: list, mapping: dict[str, Any] | None = None) -> tuple[bool, dict[str, Any]]: """Recursively match an expression tree against a pattern tree. - This function performs structural pattern matching. It checks if `tree` - conforms to the structure of `pattern`. The pattern can contain - placeholders (e.g., `_0`, `_1`) which match any subtree. If a match is - found, it returns True and a dictionary mapping the placeholders to the - subtrees they matched. + ``tree`` and ``pattern`` use the same representation as described in + :func:`apply_mapping`. Placeholders in ``pattern`` (for example ``'_0'``) + match any subtree. When a match succeeds the mapping is populated with the + subtrees that correspond to each placeholder. Parameters ---------- @@ -761,15 +811,21 @@ def match_pattern(tree: list, pattern: list, mapping: dict[str, Any] | None = No pattern : list The pattern tree to match against. mapping : dict[str, Any] or None, optional - An initial mapping dictionary. If None, an empty one is created. - Defaults to None. + Initial mapping dictionary. If ``None``, an empty one is created. Returns ------- tuple[bool, dict[str, Any]] - A tuple containing: - - A boolean indicating if the match was successful. - - The dictionary mapping placeholders to the matched subtrees. + ``(True, mapping)`` when the structures align; otherwise ``(False, mapping)``. + The returned mapping may contain partial assignments even when the match + fails. 
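+
+    Notes
+    -----
+    Since the mapping may be partially populated on a failed match, callers
+    should only rely on its contents when the returned boolean is ``True``.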
+
+    Examples
+    --------
+    >>> tree = ['mul', [['x'], ['add', [['y'], ['z']]]]]
+    >>> pattern = ['mul', [['_a'], ['_b']]]
+    >>> match_pattern(tree, pattern)
+    (True, {'_a': ['x'], '_b': ['add', [['y'], ['z']]]})
     """
     if mapping is None:
         mapping = {}
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b284833..d3b1952 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -108,6 +108,20 @@ def test_explicit_constant_placeholders():
     assert sorted(result_constants) == sorted(expected_constants)
 
 
+def test_explicit_constant_placeholders_reuses_provided_constants():
+    expr = ['+', 'C_3', '<num>']
+    result_expr, result_constants = utils.explicit_constant_placeholders(expr, constants=['K'])
+    assert result_expr == ['+', 'K', 'C_0']
+    assert result_constants == ['K', 'C_0']
+
+
+def test_explicit_constant_placeholders_discards_unused_constants():
+    expr = ['+', '<num>']
+    result_expr, result_constants = utils.explicit_constant_placeholders(expr, constants=['K', 'L'])
+    assert result_expr == ['+', 'K']
+    assert result_constants == ['K']
+
+
 def test_flatten_nested_list():
     """Tests flattening of a nested list."""
     # Note: The implementation reverses the list, so the test reflects that.