Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion albert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,17 @@ def _matches_filter(node: Base, type_filter: TypeOrFilter[Base]) -> bool:


def _sign_penalty(base: Base) -> int:
"""Return a penalty for the sign in scalars in a base object.
"""Return a penalty for the sign in scalars in a `Base` object.

Args:
base: Base object to check.

Returns:
Penalty for the sign.
"""
# TODO: Improve check for Scalar
if hasattr(base, "value"):
return 1 if getattr(base, "value") < 0 else -1
if not base.children:
return 0
penalty = 1
Expand Down
7 changes: 6 additions & 1 deletion albert/opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def optimise(
Returns:
The optimised expressions, as tuples of the output tensor and the expression.
"""
if method == "gristmill" or method == "auto":
if method == "auto":
try:
return optimise_gristmill(exprs, **kwargs)
except ImportError:
return optimise_albert(exprs, **kwargs)
elif method == "gristmill":
return optimise_gristmill(exprs, **kwargs)
elif method == "albert":
return optimise_albert(exprs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions albert/opt/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ def _optimise(
return expressions

# Canonicalise the terms in the expressions
expressions = list(expressions)
for i, expression in enumerate(expressions):
expressions[i] = _canonicalise_expression(expression, indices)

Expand Down
95 changes: 95 additions & 0 deletions albert/qc/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,96 @@
"""Functionality specific to quantum chemistry applications."""

from __future__ import annotations

from typing import TYPE_CHECKING

from albert.qc._pdaggerq import import_from_pdaggerq
from albert.qc._wick import import_from_wick
from albert.qc.spin import ghf_to_rhf, ghf_to_uhf
from albert.tensor import Tensor
from albert.expression import Expression

if TYPE_CHECKING:
from typing import Any, Iterable, Literal


def import_expression(
    external: Any,
    package: Literal["pdaggerq", "wick"] = "pdaggerq",
    index_order: Iterable[str] | None = None,
    name: str | None = None,
    **kwargs: Any,
) -> Expression:
    """Import an expression from a third-party quantum chemistry package.

    Args:
        external: The external expression to import. The exact format is specified by the
            individual importers.
        package: The code from which to import the expression.
        index_order: The desired order of external index labels in the imported expression. The
            indices of the left-hand side of the expression will be sorted such that their labels
            match this order.
        name: The name to assign to the LHS tensor.
        **kwargs: Additional keyword arguments to pass to the importer.

    Returns:
        The imported expression.

    Raises:
        ValueError: If `package` is not recognised.
    """
    # Dispatch to the importer for the requested package and build the RHS
    importers = {
        "pdaggerq": import_from_pdaggerq,
        "wick": import_from_wick,
    }
    if package not in importers:
        raise ValueError(f"Unknown package: {package}")
    rhs = importers[package](external, **kwargs).canonicalise(indices=True)

    # Build the LHS tensor over the (optionally reordered) external indices
    if index_order is None:
        indices = rhs.external_indices
    else:
        order = list(index_order)
        indices = tuple(sorted(rhs.external_indices, key=lambda idx: order.index(idx.name)))
    lhs = Tensor(*indices, name=name)

    return Expression(lhs, rhs)


def adapt_spin(
    expr: Expression | Iterable[Expression],
    target_spin: Literal["rhf", "uhf"],
) -> tuple[Expression, ...]:
    """Adapt the spin representation of a quantum chemistry expression.

    Args:
        expr: The expression(s) to adapt.
        target_spin: The target spin representation.

    Returns:
        The adapted expressions. For `"rhf"`, this is a tuple with a single expression. For `"uhf"`,
        this is a tuple with one expression per spin case.

    Raises:
        ValueError: If `target_spin` is not recognised.
    """
    # Materialise the input ONCE: the previous implementation iterated `expr` in
    # two separate comprehensions for "rhf", which silently produced wrong
    # results for generators and other one-shot iterables.
    if isinstance(expr, Expression):
        expressions: tuple[Expression, ...] = (expr,)
    else:
        expressions = tuple(expr)

    # Convert each RHS, pairing it with the (still spin-free) LHS it came from
    if target_spin == "rhf":
        pairs = [(e.lhs, ghf_to_rhf(e.rhs)) for e in expressions]
    elif target_spin == "uhf":
        # ghf_to_uhf returns one RHS per spin case; repeat the LHS accordingly
        pairs = [(e.lhs, rhs) for e in expressions for rhs in ghf_to_uhf(e.rhs)]
    else:
        raise ValueError(f"Unknown target spin: {target_spin}")

    # Propagate the spins of the RHS external indices onto each LHS
    exprs = []
    for lhs, rhs in pairs:
        spins = {index.name: index.spin for index in rhs.external_indices}
        index_map = {index: index.copy(spin=spins[index.name]) for index in lhs.external_indices}
        exprs.append(Expression(lhs.map_indices(index_map), rhs))

    return tuple(exprs)
4 changes: 3 additions & 1 deletion albert/qc/_pdaggerq.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _convert_symbol(
index_spins: Optional[dict[str, str]] = None,
index_spaces: Optional[dict[str, str]] = None,
l_is_lambda: bool = True,
name: str | None = None,
) -> Base:
"""Convert a symbol to a subclass of `Base`.

Expand All @@ -160,6 +161,7 @@ def _convert_symbol(
index_spaces: The index spaces.
l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM
operator.
name: The name of the tensor.

Returns:
The converted symbol.
Expand Down Expand Up @@ -386,7 +388,7 @@ def _convert_symbol(
for index in index_strs
)

return tensor_symbol(*indices)
return tensor_symbol.factory(*indices, name=name)


def remove_reference_energy(terms: list[list[str]]) -> list[list[str]]:
Expand Down
215 changes: 215 additions & 0 deletions albert/qc/_wick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
"""Interface to `wick`."""

from __future__ import annotations

import re
from typing import TYPE_CHECKING

from albert.algebra import Mul
from albert.index import Index
from albert.qc import ghf
from albert.qc._pdaggerq import _guess_space, _is_number
from albert.qc.tensor import QTensor
from albert.scalar import Scalar

if TYPE_CHECKING:
from typing import Optional

from albert.base import Base


def import_from_wick(
    terms: list[str],
    index_spins: Optional[dict[str, str]] = None,
    index_spaces: Optional[dict[str, str]] = None,
    l_is_lambda: bool = True,
    symbol_aliases: Optional[dict[str, str]] = None,
) -> Base:
    r"""Import an expression from `wick`.

    Tensors in the return expression are `GHF` tensors.

    Args:
        terms: The terms of the expression. Should be the lines of the `repr` of the output
            `AExpression` in `wick`, i.e. `str(AExpression(Ex=...)).split("\n")`.
        index_spins: The index spins.
        index_spaces: The index spaces.
        l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM
            operator.
        symbol_aliases: Aliases for symbols.

    Returns:
        The imported expression.
    """
    spins = {} if index_spins is None else index_spins
    spaces = {} if index_spaces is None else index_spaces

    # Accumulate the converted terms into a single expression
    expr: Base = Scalar.factory(0.0)
    for line in terms:
        # Split the term and rewrite each wick symbol into pdaggerq-style syntax,
        # remembering the pre-alias name for each tensor
        formatted = [_format_symbol(raw, aliases=symbol_aliases) for raw in _split_term(line)]
        factors = [
            _convert_symbol(
                sym,
                index_spins=spins,
                index_spaces=spaces,
                l_is_lambda=l_is_lambda,
                name=tensor_name,
            )
            for sym, tensor_name in formatted
        ]

        # wick doesn't guarantee the same external indices between terms
        expr += Mul.factory(*factors).canonicalise(indices=True)

    return expr


def _split_term(term: str) -> list[str]:
"""Split a term into its symbols."""
term = term.lstrip(" ")
term = term.replace(" ", "")
term = term.replace("}", "} ").rstrip(" ")
if r"\sum_{" in term:
term = re.sub(r"\\sum_\{[^\}]*\}", "", term)
else:
i = 0
while term[i] in "-+0123456789.":
i += 1
if i > 0:
term = term[:i] + " " + term[i:]
return term.split(" ")


def _format_symbol(symbol: str, aliases: dict[str, str] | None = None) -> tuple[str, str]:
"""Rewrite a `wick` symbol to look like a `pdaggerq` symbol."""
symbol = re.sub(
r"([a-zA-Z0-9]+)_\{([^\}]*)\}", lambda m: f"{m.group(1)}({','.join(m.group(2))})", symbol
)
symbol_name, indices = symbol.split("(", 1) if "(" in symbol else (symbol, None)
if aliases is not None:
symbol_alias = aliases.get(symbol_name, symbol_name)
symbol = f"{symbol_alias}({indices}" if indices is not None else symbol_alias
return symbol, symbol_name


def _convert_symbol(
    symbol: str,
    index_spins: Optional[dict[str, str]] = None,
    index_spaces: Optional[dict[str, str]] = None,
    l_is_lambda: bool = True,
    name: str | None = None,
) -> Base:
    """Convert a symbol to a subclass of `Base`.

    The symbol is expected in `pdaggerq`-style call syntax (e.g. ``f(i,j)``),
    as produced by `_format_symbol`. Numeric symbols become `Scalar`s; known
    tensor names are dispatched to the corresponding `ghf` tensor class, with
    the index order permuted from `wick`'s convention into `albert`'s.

    Args:
        symbol: The symbol.
        index_spins: The index spins.
        index_spaces: The index spaces.
        l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM
            operator.
        name: The name of the tensor.

    Returns:
        The converted symbol.

    Raises:
        ValueError: If the symbol is not recognised.
    """
    if index_spins is None:
        index_spins = {}
    if index_spaces is None:
        index_spaces = {}

    if re.match(r".*_[0-9]+$", symbol):
        # Symbol has spaces attached, separate them
        # NOTE(review): `spaces` is parsed but never used -- index spaces come
        # from `index_spaces` / `_guess_space` below instead. Confirm this is
        # intentional.
        symbol, spaces = symbol.rsplit("_", 1)

    if _is_number(symbol):
        # It's the factor
        return Scalar.factory(float(symbol))

    # Dispatch on the symbol name/arity; `index_strs` is permuted so that the
    # wick index order matches the ghf tensor's expected (albert) order.
    tensor_symbol: type[QTensor]
    index_strs: tuple[str, ...]
    if symbol in ("r0", "l0"):
        # r0 or l0
        index_strs = ()
        tensor_symbol = ghf.R0
    elif re.match(r"f\((?i:[a-z]),(?i:[a-z])\)", symbol):
        # f(i,j)
        index_strs = tuple(symbol[2:-1].split(","))
        tensor_symbol = ghf.Fock
    elif re.match(r"v\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol):
        # v(i,j,k,l)
        index_strs = tuple(symbol[2:-1].split(","))
        # (i,j,k,l) -> (k,l,i,j): swap the bra/ket index pairs
        index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1])
        tensor_symbol = ghf.ERI
    elif re.match(r"t1\((?i:[a-z]),(?i:[a-z])\)", symbol):
        # t1(i,j)
        index_strs = tuple(symbol[3:-1].split(","))
        # (a,i) -> (i,a): amplitudes are stored occupied-first in albert
        index_strs = (index_strs[1], index_strs[0])
        tensor_symbol = ghf.T1
    elif re.match(r"t2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol):
        # t2(i,j,k,l)
        index_strs = tuple(symbol[3:-1].split(","))
        # (a,b,i,j) -> (i,j,a,b): swap the virtual/occupied index groups
        index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1])
        tensor_symbol = ghf.T2
    elif re.match(
        r"t3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol
    ):
        # t3(i,j,k,l,m,n)
        # NOTE(review): reversing and then taking (3,4,5,0,1,2) reverses the
        # order WITHIN each index group, unlike the t2 case which preserves it.
        # For an antisymmetric amplitude this is a parity-sensitive choice --
        # presumably compensated by canonicalisation; confirm.
        index_strs = tuple(symbol[3:-1].split(","))[::-1]
        index_strs = (
            index_strs[3],
            index_strs[4],
            index_strs[5],
            index_strs[0],
            index_strs[1],
            index_strs[2],
        )
        tensor_symbol = ghf.T3
    elif re.match(r"l1\((?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda:
        # l1(i,j)
        index_strs = tuple(symbol[3:-1].split(","))
        index_strs = (index_strs[1], index_strs[0])
        tensor_symbol = ghf.L1
    elif re.match(r"l2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda:
        # l2(i,j,k,l)
        index_strs = tuple(symbol[3:-1].split(","))
        index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1])
        tensor_symbol = ghf.L2
    elif (
        re.match(r"l3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol)
        and l_is_lambda
    ):
        # l3(i,j,k,l,m,n) -- same within-group reversal as the t3 case above
        index_strs = tuple(symbol[3:-1].split(","))[::-1]
        index_strs = (
            index_strs[3],
            index_strs[4],
            index_strs[5],
            index_strs[0],
            index_strs[1],
            index_strs[2],
        )
        tensor_symbol = ghf.L3
    elif re.match(r"delta\((?i:[a-z]),(?i:[a-z])\)", symbol):
        # delta(i,j)
        index_strs = tuple(symbol[6:-1].split(","))
        tensor_symbol = ghf.Delta
    else:
        raise ValueError(f"Unknown symbol {symbol}")

    # Convert the indices, falling back to a guessed space when none is given
    indices = tuple(
        Index(
            index,
            spin=index_spins.get(index, None),
            space=index_spaces.get(index, _guess_space(index)),
        )
        for index in index_strs
    )

    return tensor_symbol.factory(*indices, name=name)
Loading