From bd7ad71282cfb8a22ac0f3c027079313b065af83 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Thu, 6 Nov 2025 16:35:02 +0000 Subject: [PATCH 1/4] Add wick interface --- albert/base.py | 27 +++++- albert/qc/_pdaggerq.py | 3 +- albert/qc/_wick.py | 214 +++++++++++++++++++++++++++++++++++++++++ tests/test_pdaggerq.py | 18 ++-- 4 files changed, 251 insertions(+), 11 deletions(-) create mode 100644 albert/qc/_wick.py diff --git a/albert/base.py b/albert/base.py index 8fbd77f..89ff567 100644 --- a/albert/base.py +++ b/albert/base.py @@ -123,7 +123,7 @@ def _matches_filter(node: Base, type_filter: TypeOrFilter[Base]) -> bool: def _sign_penalty(base: Base) -> int: - """Return a penalty for the sign in scalars in a base object. + """Return a penalty for the sign in scalars in a `Base` object. Args: base: Base object to check. @@ -131,6 +131,9 @@ def _sign_penalty(base: Base) -> int: Returns: Penalty for the sign. """ + # TODO: Improve check for Scalar + if hasattr(base, "value"): + return 1 if getattr(base, "value") < 0 else -1 if not base.children: return 0 penalty = 1 @@ -141,6 +144,28 @@ def _sign_penalty(base: Base) -> int: return penalty +def _factor_penalty(base: Base) -> int: + """Return a penalty for the absolute factor in scalars in a `Base` object. + + Args: + base: Base object to check. + + Returns: + Penalty for the absolute factor. + """ + # TODO: Improve check for Scalar + if hasattr(base, "value"): + return abs(getattr(base, "value")) + if not base.children: + return 1 + penalty = 1 + if base.children: + for child in base.children: + if hasattr(child, "value"): + penalty *= abs(int(getattr(child, "value"))) + return penalty + + class Base(Serialisable): """Base class for algebraic types.""" diff --git a/albert/qc/_pdaggerq.py b/albert/qc/_pdaggerq.py index c440548..df7e80a 100644 --- a/albert/qc/_pdaggerq.py +++ b/albert/qc/_pdaggerq.py @@ -151,6 +151,7 @@ def _convert_symbol( index_spins: Optional[dict[str, str]] = None, index_spaces: Optional[dict[str, str]] = None, l_is_lambda: bool = True, + name: str | None = None, ) -> Base: """Convert a symbol to a subclass of `Base`. @@ -386,7 +387,7 @@ def _convert_symbol( for index in index_strs ) - return tensor_symbol(*indices) + return tensor_symbol.factory(*indices, name=name) def remove_reference_energy(terms: list[list[str]]) -> list[list[str]]: diff --git a/albert/qc/_wick.py b/albert/qc/_wick.py new file mode 100644 index 0000000..ef00243 --- /dev/null +++ b/albert/qc/_wick.py @@ -0,0 +1,214 @@ +"""Interface to `wick`.""" + +from __future__ import annotations + +import re +from numbers import Number +from typing import TYPE_CHECKING + +from albert.algebra import Mul +from albert.index import Index +from albert.qc import ghf +from albert.qc.tensor import QTensor +from albert.scalar import Scalar +from albert.symmetry import symmetric_group +from albert.qc._pdaggerq import _is_number, _guess_space + +if TYPE_CHECKING: + from typing import Any, Literal, Optional + + from albert.base import Base + from albert.symmetry import Symmetry + + +def import_from_wick( + terms: list[str], + index_spins: Optional[dict[str, str]] = None, + index_spaces: Optional[dict[str, str]] = None, + l_is_lambda: bool = True, + symbol_aliases: Optional[dict[str, str]] = None, +) -> Base: + r"""Import an expression from `wick`. + + Tensors in the return expression are `GHF` tensors. + + Args: + terms: The terms of the expression. Should be the lines of the `repr` of the output + `AExpression` in `wick`, i.e. `str(AExpression(Ex=...)).split("\n")`. + index_spins: The index spins. + index_spaces: The index spaces. + l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM + operator. + symbol_aliases: Aliases for symbols. + + Returns: + The imported expression. + """ + if index_spins is None: + index_spins = {} + if index_spaces is None: + index_spaces = {} + + # Build the expression + expr: Base = Scalar.factory(0.0) + for term in terms: + # Convert the symbols + term = _split_term(term) + term, names = zip(*[_format_symbol(symbol, aliases=symbol_aliases) for symbol in term]) + symbols = [ + _convert_symbol( + symbol, + index_spins=index_spins, + index_spaces=index_spaces, + l_is_lambda=l_is_lambda, + name=name, + ) + for symbol, name in zip(term, names) + ] + part = Mul.factory(*symbols) + + # Add the term to the expression + expr += part.canonicalise(indices=True) # wick doesn't guarantee same external indices + + return expr + + +def _split_term(term: str) -> list[str]: + """Split a term into its symbols.""" + term = term.lstrip(" ") + term = term.replace(" ", "") + term = term.replace("}", "} ").rstrip(" ") + if r"\sum_{" in term: + term = re.sub(r"\\sum_\{[^\}]*\}", "", term) + else: + i = 0 + while term[i] in "-+0123456789.": + i += 1 + if i > 0: + term = term[:i] + " " + term[i:] + return term.split(" ") + + +def _format_symbol(symbol: str, aliases: dict[str, str] | None = None) -> tuple[str, str]: + """Rewrite a `wick` symbol to look like a `pdaggerq` symbol.""" + symbol = re.sub(r"([a-zA-Z0-9]+)_\{([^\}]*)\}", lambda m: f"{m.group(1)}({','.join(m.group(2))})", symbol) + symbol_name, indices = symbol.split("(", 1) if "(" in symbol else (symbol, None) + if aliases is not None: + symbol_alias = aliases.get(symbol_name, symbol_name) + symbol = f"{symbol_alias}({indices}" if indices is not None else symbol_alias + return symbol, symbol_name + + +def _convert_symbol( + symbol: str, + index_spins: Optional[dict[str, str]] = None, + index_spaces: Optional[dict[str, str]] = None, + l_is_lambda: bool = True, + name: str | None = None, +) -> Base: + """Convert a symbol to a subclass of `Base`. + + Args: + symbol: The symbol. + index_spins: The index spins. + index_spaces: The index spaces. + l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM + operator. + + Returns: + The converted symbol. + """ + if index_spins is None: + index_spins = {} + if index_spaces is None: + index_spaces = {} + + if re.match(r".*_[0-9]+$", symbol): + # Symbol has spaces attached, separate them + symbol, spaces = symbol.rsplit("_", 1) + + if _is_number(symbol): + # It's the factor + return Scalar.factory(float(symbol)) + + tensor_symbol: type[QTensor] + index_strs: tuple[str, ...] + if symbol in ("r0", "l0"): + # r0 or l0 + index_strs = () + tensor_symbol = ghf.R0 + elif re.match(r"f\((?i:[a-z]),(?i:[a-z])\)", symbol): + # f(i,j) + index_strs = tuple(symbol[2:-1].split(",")) + tensor_symbol = ghf.Fock + elif re.match(r"v\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol): + # v(i,j,k,l) + index_strs = tuple(symbol[2:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.ERI + elif re.match(r"t1\((?i:[a-z]),(?i:[a-z])\)", symbol): + # t1(i,j) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[1], index_strs[0]) + tensor_symbol = ghf.T1 + elif re.match(r"t2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol): + # t2(i,j,k,l) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.T2 + elif re.match( + r"t3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol + ): + # t3(i,j,k,l,m,n) + index_strs = tuple(symbol[3:-1].split(","))[::-1] + index_strs = ( + index_strs[3], + index_strs[4], + index_strs[5], + index_strs[0], + index_strs[1], + index_strs[2], + ) + tensor_symbol = ghf.T3 + elif re.match(r"l1\((?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda: + # l1(i,j) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[1], index_strs[0]) + tensor_symbol = ghf.L1 + elif re.match(r"l2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda: + # l2(i,j,k,l) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.L2 + elif re.match( + r"l3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol + ) and l_is_lambda: + # l3(i,j,k,l,m,n) + index_strs = tuple(symbol[3:-1].split(","))[::-1] + index_strs = ( + index_strs[3], + index_strs[4], + index_strs[5], + index_strs[0], + index_strs[1], + index_strs[2], + ) + tensor_symbol = ghf.L3 + elif re.match(r"delta\((?i:[a-z]),(?i:[a-z])\)", symbol): + # delta(i,j) + index_strs = tuple(symbol[6:-1].split(",")) + tensor_symbol = ghf.Delta + else: + raise ValueError(f"Unknown symbol {symbol}") + + # Convert the indices + indices = tuple( + Index( + index, + spin=index_spins.get(index, None), + space=index_spaces.get(index, _guess_space(index)), + ) + for index in index_strs + ) + + return tensor_symbol.factory(*indices, name=name) diff --git a/tests/test_pdaggerq.py b/tests/test_pdaggerq.py index 26d52d6..8fa73ff 100644 --- a/tests/test_pdaggerq.py +++ b/tests/test_pdaggerq.py @@ -26,7 +26,7 @@ def test_ccsd_energy(): pq.simplify() terms = pq.strings() expr_ghf = import_from_pdaggerq(terms) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.5 * v(i,j,j,i)) + f(i,i) + (0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (f(i,a) * t1(i,a)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -36,7 +36,7 @@ def test_ccsd_energy(): terms = remove_reference_energy(terms) expr_ghf = import_from_pdaggerq(terms) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (f(i,a) * t1(i,a)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -55,7 +55,7 @@ def _filter_fock_terms(mul: Mul) -> Mul | Scalar: expr_ghf = expr_ghf.expand() expr_ghf = expr_ghf.apply(_filter_fock_terms, Mul) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -87,11 +87,11 @@ def _project(mul: Mul) -> Mul | Scalar: ), Mul, ) - expr_uhf_aaaa = expr_uhf_aaaa.canonicalise() - #assert ( - # repr(expr_uhf_aaaa) - # == "(0.5 * t2(iα,jα,aα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t2(iα,jα,aα,bα) * v(iα,bα,jα,aα)) + (0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,bα,jα,aα))" - #) + expr_uhf_aaaa = expr_uhf_aaaa.canonicalise(indices=True) + assert ( + repr(expr_uhf_aaaa) + == "(0.5 * t2(iα,jα,aα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t2(iα,jα,aα,bα) * v(iα,bα,jα,aα)) + (0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,bα,jα,aα))" + ) expr_uhf_abab = expr_uhf.apply( _project_onto_indices( @@ -104,7 +104,7 @@ def _project(mul: Mul) -> Mul | Scalar: ), Mul, ) - expr_uhf_abab = expr_uhf_abab.canonicalise() + expr_uhf_abab = expr_uhf_abab.canonicalise(indices=True) assert ( repr(expr_uhf_abab) == "(0.25 * t2(iα,jβ,aα,bβ) * v(iα,aα,jβ,bβ)) + (0.5 * t1(iα,aα) * t1(jβ,bβ) * v(iα,aα,jβ,bβ))" From e96b0ba58f823030a7005928d30293a8396d458a Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Thu, 6 Nov 2025 16:37:46 +0000 Subject: [PATCH 2/4] linting --- albert/base.py | 22 ---------------------- albert/qc/_pdaggerq.py | 1 + albert/qc/_wick.py | 23 ++++++++++++----------- 3 files changed, 13 insertions(+), 33 deletions(-) diff --git a/albert/base.py b/albert/base.py index 89ff567..953c852 100644 --- a/albert/base.py +++ b/albert/base.py @@ -144,28 +144,6 @@ def _sign_penalty(base: Base) -> int: return penalty -def _factor_penalty(base: Base) -> int: - """Return a penalty for the absolute factor in scalars in a `Base` object. - - Args: - base: Base object to check. - - Returns: - Penalty for the absolute factor. - """ - # TODO: Improve check for Scalar - if hasattr(base, "value"): - return abs(getattr(base, "value")) - if not base.children: - return 1 - penalty = 1 - if base.children: - for child in base.children: - if hasattr(child, "value"): - penalty *= abs(int(getattr(child, "value"))) - return penalty - - class Base(Serialisable): """Base class for algebraic types.""" diff --git a/albert/qc/_pdaggerq.py b/albert/qc/_pdaggerq.py index df7e80a..845acb7 100644 --- a/albert/qc/_pdaggerq.py +++ b/albert/qc/_pdaggerq.py @@ -161,6 +161,7 @@ def _convert_symbol( index_spaces: The index spaces. l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM operator. + name: The name of the tensor. Returns: The converted symbol. diff --git a/albert/qc/_wick.py b/albert/qc/_wick.py index ef00243..19d9a91 100644 --- a/albert/qc/_wick.py +++ b/albert/qc/_wick.py @@ -3,22 +3,19 @@ from __future__ import annotations import re -from numbers import Number from typing import TYPE_CHECKING from albert.algebra import Mul from albert.index import Index from albert.qc import ghf +from albert.qc._pdaggerq import _guess_space, _is_number from albert.qc.tensor import QTensor from albert.scalar import Scalar -from albert.symmetry import symmetric_group -from albert.qc._pdaggerq import _is_number, _guess_space if TYPE_CHECKING: - from typing import Any, Literal, Optional + from typing import Optional from albert.base import Base - from albert.symmetry import Symmetry def import_from_wick( @@ -51,9 +48,9 @@ def import_from_wick( # Build the expression expr: Base = Scalar.factory(0.0) - for term in terms: + for term_str in terms: # Convert the symbols - term = _split_term(term) + term = _split_term(term_str) term, names = zip(*[_format_symbol(symbol, aliases=symbol_aliases) for symbol in term]) symbols = [ _convert_symbol( @@ -91,7 +88,9 @@ def _split_term(term: str) -> list[str]: def _format_symbol(symbol: str, aliases: dict[str, str] | None = None) -> tuple[str, str]: """Rewrite a `wick` symbol to look like a `pdaggerq` symbol.""" - symbol = re.sub(r"([a-zA-Z0-9]+)_\{([^\}]*)\}", lambda m: f"{m.group(1)}({','.join(m.group(2))})", symbol) + symbol = re.sub( + r"([a-zA-Z0-9]+)_\{([^\}]*)\}", lambda m: f"{m.group(1)}({','.join(m.group(2))})", symbol + ) symbol_name, indices = symbol.split("(", 1) if "(" in symbol else (symbol, None) if aliases is not None: symbol_alias = aliases.get(symbol_name, symbol_name) @@ -114,6 +113,7 @@ def _convert_symbol( index_spaces: The index spaces. l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM operator. + name: The name of the tensor. Returns: The converted symbol. @@ -180,9 +180,10 @@ def _convert_symbol( index_strs = tuple(symbol[3:-1].split(",")) index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) tensor_symbol = ghf.L2 - elif re.match( - r"l3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol - ) and l_is_lambda: + elif ( + re.match(r"l3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol) + and l_is_lambda + ): # l3(i,j,k,l,m,n) index_strs = tuple(symbol[3:-1].split(","))[::-1] index_strs = ( From 8eba92fdf6e6d805e1b06f127a82804bc47c9146 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Tue, 25 Nov 2025 23:16:44 +0000 Subject: [PATCH 3/4] Simplify some interfaces --- albert/opt/__init__.py | 7 ++- albert/opt/cse.py | 1 + albert/qc/__init__.py | 95 +++++++++++++++++++++++++++++++++++++++ examples/codegen_rccsd.py | 31 ++++++------- tests/test_rccsd.py | 58 +++++++++--------------- tests/test_uccsd.py | 63 ++++++++++---------------- 6 files changed, 161 insertions(+), 94 deletions(-) diff --git a/albert/opt/__init__.py b/albert/opt/__init__.py index 8d08094..adc2739 100644 --- a/albert/opt/__init__.py +++ b/albert/opt/__init__.py @@ -27,7 +27,12 @@ def optimise( Returns: The optimised expressions, as tuples of the output tensor and the expression. """ - if method == "gristmill" or method == "auto": + if method == "auto": + try: + return optimise_gristmill(exprs, **kwargs) + except ImportError: + return optimise_albert(exprs, **kwargs) + elif method == "gristmill": return optimise_gristmill(exprs, **kwargs) elif method == "albert": return optimise_albert(exprs, **kwargs) diff --git a/albert/opt/cse.py b/albert/opt/cse.py index 0f3f268..cf776b6 100644 --- a/albert/opt/cse.py +++ b/albert/opt/cse.py @@ -786,6 +786,7 @@ def _optimise( return expressions # Canonicalise the terms in the expressions + expressions = list(expressions) for i, expression in enumerate(expressions): expressions[i] = _canonicalise_expression(expression, indices) diff --git a/albert/qc/__init__.py b/albert/qc/__init__.py index bf5aa55..208d9bd 100644 --- a/albert/qc/__init__.py +++ b/albert/qc/__init__.py @@ -1 +1,96 @@ """Functionality specific to quantum chemistry applications.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from albert.qc._pdaggerq import import_from_pdaggerq +from albert.qc._wick import import_from_wick +from albert.qc.spin import ghf_to_rhf, ghf_to_uhf +from albert.tensor import Tensor +from albert.expression import Expression + +if TYPE_CHECKING: + from typing import Any, Iterable, Literal + + +def import_expression( + external: Any, + package: Literal["pdaggerq", "wick"] = "pdaggerq", + index_order: Iterable[str] | None = None, + name: str | None = None, + **kwargs: Any, +) -> Expression: + """Import an expression from a third-party quantum chemistry package. + + Args: + external: The external expression to import. The exact format is specified by the + individual importers. + package: The code from which to import the expression. + index_order: The desired order of external index labels in the imported expression. The + indices of the left-hand side of the expression will be sorted such that their labels + match this order. + name: The name to assign to the LHS tensor. + **kwargs: Additional keyword arguments to pass to the importer. + + Returns: + The imported expression. + """ + # Import the RHS + if package == "pdaggerq": + rhs = import_from_pdaggerq(external, **kwargs) + elif package == "wick": + rhs = import_from_wick(external, **kwargs) + else: + raise ValueError(f"Unknown package: {package}") + rhs = rhs.canonicalise(indices=True) + + # Get the LHS + if index_order is None: + indices = rhs.external_indices + else: + indices = tuple(sorted(rhs.external_indices, key=lambda i: list(index_order).index(i.name))) + lhs = Tensor(*indices, name=name) + + return Expression(lhs, rhs) + + +def adapt_spin( + expr: Expression | Iterable[Expression], + target_spin: Literal["rhf", "uhf"], +) -> tuple[Expression, ...]: + """Adapt the spin representation of a quantum chemistry expression. + + Args: + expr: The expression(s) to adapt. + target_spin: The target spin representation. + + Returns: + The adapted expressions. For `"rhf"`, this is a tuple with a single expression. For `"uhf"`, + this is a tuple with one expression per spin case. + """ + if isinstance(expr, Expression): + expr = (expr,) + + # Convert the RHS + if target_spin == "rhf": + rhs_list = [ghf_to_rhf(e.rhs) for e in expr] + lhs_list = [e.lhs for e in expr] + elif target_spin == "uhf": + rhs_list = [] + lhs_list = [] + for e in expr: + rhs_parts = ghf_to_uhf(e.rhs) + rhs_list.extend(rhs_parts) + lhs_list.extend([e.lhs for _ in rhs_parts]) + else: + raise ValueError(f"Unknown target spin: {target_spin}") + + # Get the LHS for each case + exprs = [] + for lhs, rhs in zip(lhs_list, rhs_list): + spins = {index.name: index.spin for index in rhs.external_indices} + index_map = {index: index.copy(spin=spins[index.name]) for index in lhs.external_indices} + exprs.append(Expression(lhs.map_indices(index_map), rhs)) + + return tuple(exprs) diff --git a/examples/codegen_rccsd.py b/examples/codegen_rccsd.py index 17fcbe1..c0328d4 100644 --- a/examples/codegen_rccsd.py +++ b/examples/codegen_rccsd.py @@ -7,9 +7,9 @@ from albert.code.einsum import EinsumCodeGenerator from albert.expression import Expression -from albert.opt._gristmill import optimise_gristmill -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_rhf +from albert.opt import optimise +from albert.qc._pdaggerq import remove_reference_energy +from albert.qc import import_expression, adapt_spin from albert.tensor import Tensor # Suppress warnings since we're outputting the code to stdout @@ -30,15 +30,14 @@ pq.simplify() expr = pq.strings() expr = remove_reference_energy(expr) -expr = import_from_pdaggerq(expr) -expr = ghf_to_rhf(expr).collect() -output = Tensor(name="e_cc") +expr = import_expression(expr, name="e_cc") +exprs = adapt_spin(expr, target_spin="rhf") # Optimise the energy expression -exprs = optimise_gristmill([Expression(output, expr)], strategy="exhaust") +exprs = optimise(exprs, strategy="exhaust") # Generate the code for the energy expression -codegen("energy", [output], exprs) +codegen("energy", [expr.lhs for expr in exprs], exprs) # Find the T1 expression pq.clear() @@ -47,9 +46,8 @@ pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() expr_t1 = pq.strings() -expr_t1 = import_from_pdaggerq(expr_t1) -expr_t1 = ghf_to_rhf(expr_t1).collect() -output_t1 = Tensor(*expr_t1.external_indices, name="t1new") +expr_t1 = import_expression(expr_t1, name="t1new") +exprs_t1 = adapt_spin(expr_t1, target_spin="rhf") # Find the T2 expression pq.clear() @@ -58,20 +56,19 @@ pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() expr_t2 = pq.strings() -expr_t2 = import_from_pdaggerq(expr_t2) -expr_t2 = ghf_to_rhf(expr_t2).collect() -output_t2 = Tensor(*expr_t2.external_indices, name="t2new") +expr_t2 = import_expression(expr_t2, name="t2new") +exprs_t2 = adapt_spin(expr_t2, target_spin="rhf") # Optimise the T1 and T2 expressions -exprs = optimise_gristmill( - [Expression(output_t1, expr_t1), Expression(output_t2, expr_t2)], +exprs = optimise( + exprs_t1 + exprs_t2, strategy="trav", ) # Generate the code for the T1 and T2 expressions codegen( "update_amplitudes", - [output_t1, output_t2], + [expr.lhs for expr in (exprs_t1 + exprs_t2)], exprs, as_dict=True, ) diff --git a/tests/test_rccsd.py b/tests/test_rccsd.py index c8a6dda..ad1679a 100644 --- a/tests/test_rccsd.py +++ b/tests/test_rccsd.py @@ -10,8 +10,8 @@ from albert.code.einsum import EinsumCodeGenerator from albert.opt import optimise as _optimise -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_rhf +from albert.qc._pdaggerq import remove_reference_energy +from albert.qc import import_expression, adapt_spin from albert.tensor import Tensor from albert.expression import Expression @@ -26,27 +26,26 @@ def _kwargs(strategy, transposes, greedy_cutoff, drop_cutoff): @pytest.mark.parametrize( - "optimise, method, canonicalise, kwargs", + "optimise, method, kwargs", [ - (False, None, False, _kwargs(None, None, None, None)), - (True, "gristmill", True, _kwargs("trav", "natural", -1, -1)), - (True, "gristmill", True, _kwargs("opt", "natural", -1, -1)), - (True, "gristmill", False, _kwargs("greedy", "ignore", -1, 2)), - (True, "gristmill", True, _kwargs("greedy", "ignore", 2, 2)), - (True, "albert", True, {}), + (False, None, _kwargs(None, None, None, None)), + (True, "gristmill", _kwargs("trav", "natural", -1, -1)), + (True, "gristmill", _kwargs("greedy", "ignore", -1, 2)), + (True, "gristmill", _kwargs("greedy", "ignore", 2, 2)), + (True, "albert", {}), ], ) -def test_rccsd_einsum(helper, optimise, method, canonicalise, kwargs): +def test_rccsd_einsum(helper, optimise, method, kwargs): with open(f"{os.path.dirname(__file__)}/_test_rccsd.py", "w") as file: try: - _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs) + _test_rccsd_einsum(helper, file, optimise, method, kwargs) except Exception as e: raise e finally: os.remove(f"{os.path.dirname(__file__)}/_test_rccsd.py") -def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): +def _test_rccsd_einsum(helper, file, optimise, method, kwargs): codegen = EinsumCodeGenerator(stdout=file) codegen.preamble() @@ -59,17 +58,12 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.simplify() energy = pq.strings() energy = remove_reference_energy(energy) - energy = import_from_pdaggerq(energy) - energy = ghf_to_rhf(energy) - if canonicalise: - energy = energy.canonicalise(indices=True).collect() - output = Tensor(name="e_cc") - - exprs = [Expression(output, energy)] + energy = import_expression(energy, package="pdaggerq", name="e_cc") + exprs = adapt_spin(energy, target_spin="rhf") if optimise: exprs = _optimise(exprs, method=method, **kwargs) - codegen("energy", [output], exprs) + codegen("energy", [expr.lhs for expr in exprs], exprs) pq.clear() pq.set_left_operators([["e1(i,a)"]]) @@ -77,13 +71,8 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t1 = pq.strings() - t1 = import_from_pdaggerq(t1, index_spins=dict(i="a", a="a")) - t1 = ghf_to_rhf(t1) - if canonicalise: - t1 = t1.canonicalise(indices=True).collect() - output_t1 = Tensor( - *sorted(t1.external_indices, key=lambda i: "ijab".index(i.name)), name="t1new" - ) + t1 = import_expression(t1, package="pdaggerq", index_spins=dict(i="a", a="a"), name="t1new") + t1 = adapt_spin(t1, target_spin="rhf") pq.clear() pq.set_left_operators([["e2(i,j,b,a)"]]) @@ -91,19 +80,14 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t2 = pq.strings() - t2 = import_from_pdaggerq(t2, index_spins=dict(i="a", j="b", a="a", b="b")) - t2 = ghf_to_rhf(t2) - if canonicalise: - t2 = t2.canonicalise(indices=True).collect() - output_t2 = Tensor( - *sorted(t2.external_indices, key=lambda i: "ijab".index(i.name)), name="t2new" - ) - - exprs = [Expression(output_t1, t1), Expression(output_t2, t2)] + t2 = import_expression(t2, package="pdaggerq", index_spins=dict(i="a", j="b", a="a", b="b"), name="t2new") + t2 = adapt_spin(t2, target_spin="rhf") + + exprs = t1 + t2 if optimise: exprs = _optimise(exprs, method=method, **kwargs) - codegen("update_amplitudes", [output_t1, output_t2], exprs, as_dict=True) + codegen("update_amplitudes", [expr.lhs for expr in exprs], exprs, as_dict=True) module = importlib.import_module(f"_test_rccsd") energy = module.energy diff --git a/tests/test_uccsd.py b/tests/test_uccsd.py index e892cb2..67ad7ef 100644 --- a/tests/test_uccsd.py +++ b/tests/test_uccsd.py @@ -10,8 +10,8 @@ from albert.code.einsum import EinsumCodeGenerator from albert.opt import optimise as _optimise -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_uhf +from albert.qc._pdaggerq import remove_reference_energy +from albert.qc import import_expression, adapt_spin from albert.tensor import Tensor from albert.expression import Expression @@ -26,24 +26,24 @@ def _kwargs(strategy, transposes, greedy_cutoff, drop_cutoff): @pytest.mark.parametrize( - "optimise, canonicalise, kwargs", + "optimise, kwargs", [ - (False, False, _kwargs(None, None, None, None)), - (True, False, _kwargs("greedy", "ignore", -1, 2)), - (True, True, _kwargs("greedy", "ignore", 2, 2)), + (False, _kwargs(None, None, None, None)), + (True, _kwargs("greedy", "ignore", -1, 2)), + (True, _kwargs("greedy", "ignore", 2, 2)), ], ) -def test_uccsd_einsum(helper, optimise, canonicalise, kwargs): +def test_uccsd_einsum(helper, optimise, kwargs): with open(f"{os.path.dirname(__file__)}/_test_uccsd.py", "w") as file: try: - _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs) + _test_uccsd_einsum(helper, file, optimise, kwargs) except Exception as e: raise e finally: os.remove(f"{os.path.dirname(__file__)}/_test_uccsd.py") -def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): +def _test_uccsd_einsum(helper, file, optimise, kwargs): codegen = EinsumCodeGenerator(stdout=file) codegen.preamble() @@ -56,17 +56,12 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.simplify() energy = pq.strings() energy = remove_reference_energy(energy) - energy = import_from_pdaggerq(energy) - energy = ghf_to_uhf(energy) - if canonicalise: - energy = tuple(e.canonicalise(indices=True).collect() for e in energy) - output = tuple(Tensor(name="e_cc") for _ in energy) - - exprs = [Expression(o, e) for o, e in zip(output, energy)] + energy = import_expression(energy, package="pdaggerq", name="e_cc") + exprs = adapt_spin(energy, target_spin="uhf") if optimise: exprs = _optimise(exprs, **kwargs) - codegen("energy", output, exprs) + codegen("energy", [expr.lhs for expr in exprs], exprs) pq.clear() pq.set_left_operators([["e1(i,a)"]]) @@ -74,14 +69,8 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t1 = pq.strings() - t1 = import_from_pdaggerq(t1) - t1 = ghf_to_uhf(t1) - if canonicalise: - t1 = tuple(t.canonicalise(indices=True).collect() for t in t1) - output_t1 = tuple( - Tensor(*sorted(t.external_indices, key=lambda i: "ijab".index(i.name)), name=f"t1new") - for i, t in enumerate(t1) - ) + t1 = import_expression(t1, package="pdaggerq", name="t1new") + t1 = adapt_spin(t1, target_spin="uhf") pq.clear() pq.set_left_operators([["e2(i,j,b,a)"]]) @@ -89,23 +78,19 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t2 = pq.strings() - t2_expr = tuple() - for spins in ("aaaa", "abab", "baba", "bbbb"): - index_spins = dict(zip("ijab", spins)) - t2_expr += ghf_to_uhf(import_from_pdaggerq(t2, index_spins=index_spins)) - t2 = t2_expr - if canonicalise: - t2 = tuple(t.canonicalise(indices=True).collect() for t in t2) - output_t2 = tuple( - Tensor(*sorted(t.external_indices, key=lambda i: "ijab".index(i.name)), name=f"t2new") - for i, t in enumerate(t2) - ) - - exprs = [Expression(o, t) for o, t in zip(output_t1 + output_t2, t1 + t2)] + t2 = [ + import_expression( + t2, package="pdaggerq", index_spins=dict(zip("ijab", spins)), name="t2new" + ) + for spins in ("aaaa", "abab", "baba", "bbbb") + ] + t2 = adapt_spin(t2, target_spin="uhf") + + exprs = t1 + t2 if optimise: exprs = _optimise(exprs, **kwargs) - codegen("update_amplitudes", output_t1 + output_t2, exprs, as_dict=True) + codegen("update_amplitudes", [expr.lhs for expr in exprs], exprs, as_dict=True) module = importlib.import_module(f"_test_uccsd") energy = module.energy From 64a3e32f52166448fcb9b8ca18f13af0f3bd95b8 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Wed, 26 Nov 2025 09:09:37 +0000 Subject: [PATCH 4/4] ruff --- examples/codegen_rccsd.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/codegen_rccsd.py b/examples/codegen_rccsd.py index c0328d4..ce88b79 100644 --- a/examples/codegen_rccsd.py +++ b/examples/codegen_rccsd.py @@ -6,11 +6,9 @@ from pdaggerq import pq_helper from albert.code.einsum import EinsumCodeGenerator -from albert.expression import Expression from albert.opt import optimise +from albert.qc import adapt_spin, import_expression from albert.qc._pdaggerq import remove_reference_energy -from albert.qc import import_expression, adapt_spin -from albert.tensor import Tensor # Suppress warnings since we're outputting the code to stdout warnings.filterwarnings("ignore")