Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion albert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,17 @@ def _matches_filter(node: Base, type_filter: TypeOrFilter[Base]) -> bool:


def _sign_penalty(base: Base) -> int:
"""Return a penalty for the sign in scalars in a base object.
"""Return a penalty for the sign in scalars in a `Base` object.

Args:
base: Base object to check.

Returns:
Penalty for the sign.
"""
# TODO: Improve check for Scalar
if hasattr(base, "value"):
return 1 if getattr(base, "value") < 0 else -1
if not base.children:
return 0
penalty = 1
Expand Down
19 changes: 15 additions & 4 deletions albert/opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,41 @@

from albert.opt._gristmill import optimise_gristmill
from albert.opt.cse import optimise as optimise_albert
from albert.opt._brute import eliminate_and_factorise_common_subexpressions

if TYPE_CHECKING:
from typing import Any
from typing import Any, Literal
from albert.expression import Expression


def optimise(
exprs: list[Expression],
method: str = "auto",
method: Literal["auto", "gristmill", "albert", "legacy"] = "auto",
**kwargs: Any,
) -> list[Expression]:
"""Perform common subexpression elimination on the given expression.

Args:
exprs: The expressions to be optimised.
method: The optimisation method to use. Options are `"auto"`, `"gristmill"`.
method: The optimisation method to use.
**kwargs: Additional keyword arguments to pass to the optimisation method.

Returns:
The optimised expressions, as tuples of the output tensor and the expression.
"""
if method == "gristmill" or method == "auto":
if method == "auto":
try:
return optimise_gristmill(exprs, **kwargs)
except ImportError:
return optimise_albert(exprs, **kwargs)
elif method == "gristmill":
return optimise_gristmill(exprs, **kwargs)
elif method == "albert":
return optimise_albert(exprs, **kwargs)
elif method == "legacy":
return sum(
[eliminate_and_factorise_common_subexpressions(expr, **kwargs) for expr in exprs],
[],
)
else:
raise ValueError(f"Unknown optimisation method: {method!r}")
Loading