def optimise(
    exprs: list[Expression],
    method: Literal["auto", "gristmill", "albert", "legacy"] = "auto",
    **kwargs: Any,
) -> list[Expression]:
    """Perform common subexpression elimination on the given expression.

    Args:
        exprs: The expressions to be optimised.
        method: The optimisation method to use. `"auto"` tries `"gristmill"` first and falls
            back to `"albert"` if that raises an `ImportError`. `"legacy"` runs the deprecated
            brute-force optimiser on each expression separately.
        **kwargs: Additional keyword arguments to pass to the optimisation method.

    Returns:
        The optimised expressions, as tuples of the output tensor and the expression.

    Raises:
        ValueError: If `method` is not one of the supported options.
    """
    if method == "auto":
        try:
            return optimise_gristmill(exprs, **kwargs)
        except ImportError:
            # The gristmill backend is optional, fall back to the built-in optimiser
            return optimise_albert(exprs, **kwargs)
    elif method == "gristmill":
        return optimise_gristmill(exprs, **kwargs)
    elif method == "albert":
        return optimise_albert(exprs, **kwargs)
    elif method == "legacy":
        # Flatten with `extend` rather than `sum(lists, [])`, which copies the accumulated list
        # for every expression.
        # NOTE(review): each expression is optimised independently, so the `tmp` intermediate
        # names produced for different expressions may coincide -- confirm this is acceptable.
        optimised: list[Expression] = []
        for expr in exprs:
            optimised.extend(eliminate_and_factorise_common_subexpressions(expr, **kwargs))
        return optimised
    else:
        raise ValueError(f"Unknown optimisation method: {method!r}")
_identify_subexpressions( + exprs: list[Expression], indices: Optional[set[Index]] = None +) -> dict[tuple[Base, tuple[Index, ...]], int]: + """Identify candidate common subexpressions and count their occurrences.""" + if indices is None: + indices = set() + for expr in exprs: + for tensor in expr.rhs.search(Tensor): + indices.update(set(tensor.indices)) + + candidates: dict[tuple[Base, tuple[Index, ...]], int] = {} + for expr in exprs: + for mul in expr.rhs.search(Mul): + # Loop over all combinations of >1 children to find subexpressions + children = [child for child in mul._children if not isinstance(child, Scalar)] + for r in range(2, len(children) + 1): + for combo in itertools.combinations(children, r): + # Get the candidate subexpression + candidate: Base = Mul(*combo) + + # Find the external indices of the candidate -- for Einstein summation + # compliant expressions, this is just candidate.external_indices, but we want + # to support more general expressions + other_indices = set() + for child in children: + if child not in combo: + other_indices.update(set(child.external_indices)) + other_indices.update(set(child.internal_indices)) + other_indices.update(set(expr.lhs.indices)) + candidate_indices = set(candidate.external_indices + candidate.internal_indices) + candidate_indices = set.intersection(candidate_indices, other_indices) + + # Canonicalise the candidate + index_map = _get_canonicalise_intermediate_map(candidate, indices) + _, candidate = _canonicalise_intermediate(None, candidate, indices) + canon_indices = tuple(index_map[i] for i in candidate_indices) + + # Increment the count for this candidate + candidates[candidate, canon_indices] = ( + candidates.get((candidate, canon_indices), 0) + 1 + ) + + return candidates + + +def parenthesise_mul( + mul: Mul, + sizes: Optional[dict[str | None, int]] = None, + scaling_limit_cpu: dict[tuple[str, ...], int] | None = None, + scaling_limit_ram: dict[tuple[str, ...], int] | None = None, + 
intermediate_counter: int = 0, +) -> tuple[Mul, list[Expression]]: + """Parenthesise a product. + + Converts the `Mul` of given children into a nested `Mul` of groups of said children. + + Args: + mul: The contraction to parenthesise. + sizes: The sizes of the spaces in the expression. + scaling_limit_cpu: The scaling limits for CPU. Keys should be tuples of index space names, + and values are the maximum allowed scaling for that combination of spaces. + scaling_limit_ram: The scaling limits for RAM. Keys should be tuples of index space names, + and values are the maximum allowed scaling for that combination of spaces. + intermediate_counter: The starting counter for naming intermediate tensors. + + Returns: + The parenthesised contraction represented by a non-nested product, and a list of tensor + expressions defining the intermediates to resolve the nested product. + """ + import opt_einsum + + if sizes is None: + sizes = _default_sizes + if scaling_limit_cpu is None: + scaling_limit_cpu = {} + if scaling_limit_ram is None: + scaling_limit_ram = {} + + # Get dummy sizes for the cost function + sizes_dummy = {space: ord(space) for space in sizes if isinstance(space, str)} + dummy_map = {value: key for key, value in sizes_dummy.items()} + sizes_map = {sizes_dummy[space]: sizes[space] for space in sizes_dummy} + + def cost( + cost1: int, + cost2: int, + i1_union_i2: set[int], + size_dict: list[int], + cost_cap: int, + s1: int, + s2: int, + xn: dict[int, Any], + g: int, + all_tensors: int, + inputs: list[set[int]], + i1_cut_i2_wo_output: set[int], + memory_limit: Optional[int], + contract1: int | tuple[int], + contract2: int | tuple[int], + ) -> None: + """Cost function for `opt_einsum`.""" + # Get the cost scaling + scaling: dict[str, int] = {} + for i in i1_union_i2: + c = dummy_map[size_dict[i]] + scaling[c] = scaling.get(c, 0) + 1 + + # Check the cost scaling + if scaling_limit_cpu is not None: + for cs, n in scaling_limit_cpu.items(): + if sum(scaling.get(c, 0) 
for c in cs) > n: + return + + # Get the real cost + size_dict_real = [sizes_map[i] for i in size_dict] + cost = cost1 + cost2 + opt_einsum.paths.compute_size_by_dict(i1_union_i2, size_dict_real) + + # Check the real cost + if cost <= cost_cap: + s = s1 | s2 + if s not in xn or cost < xn[s][1]: + i_mem = opt_einsum.paths._dp_calc_legs( + g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2 + ) + + # Get the memory scaling + scaling = {} + for i in i_mem: + c = dummy_map[size_dict[i]] + scaling[c] = scaling.get(c, 0) + 1 + + # Check the memory scaling + if scaling_limit_ram is not None: + for cs, n in scaling_limit_ram.items(): + if sum(scaling.get(c, 0) for c in cs) > n: + return + + # Get the real memory + mem = opt_einsum.paths.compute_size_by_dict(i_mem, size_dict_real) + + # Check the real memory + if memory_limit is None or mem <= memory_limit: + # Accept this contraction + xn[s] = (i_mem, cost, (contract1, contract2)) + + # Separate the children into tensors and scalars + tensors = list(mul.search(Tensor, depth=1)) + scalars = list(mul.search(Scalar, depth=1)) + assert len(mul._children) == len(list(tensors)) + len(list(scalars)) + + # Get the optimal contraction path + optimiser = opt_einsum.DynamicProgramming( + minimize=cost, + cost_cap=True, + ) + + # Map index names to unique characters for opt_einsum + _index_map: dict[Index, str] = {} + + def _assign_index(index: Index) -> str: + if index not in _index_map: + if len(_index_map) >= 26: + raise ValueError("Too many unique indices.") + _index_map[index] = chr(97 + len(_index_map)) + return _index_map[index] + + # Make fake arrays to get the contraction path + arrays = [lambda: None for _ in tensors] + for i, t in enumerate(tensors): + arrays[i].shape = tuple(sizes_dummy[i.space] for i in t.indices) # type: ignore + inputs = ["".join(_assign_index(i) for i in t.indices) for t in tensors] + output = "".join(_assign_index(i) for i in mul.external_indices) + subscript = ",".join(inputs) + "->" + output 
def factorise(exprs: list[Expression]) -> list[Expression]:
    """Factorise expressions that differ by at most one tensor and the scalar factor.

    Expressions whose right-hand side is a `Mul` with exactly two non-scalar children are
    grouped by their most common child. Each group sharing an output is rewritten as
    `factor * (sum of the remainders)`. All other expressions are passed through unchanged.

    Args:
        exprs: The tensor expressions to identify common subexpressions in.

    Returns:
        The factorised tensor expressions. Untouched expressions come first, followed by the
        factorised ones in the order their outputs were first encountered.

    Raises:
        ValueError: If a `Mul` has more than two non-scalar children.
    """
    # Check that each expression is either:
    #  a) a Mul with at most two non-scalar children
    #  b) a non-scalar
    new_exprs: list[Expression] = []
    to_factorise: list[tuple[Tensor, Base]] = []
    for expr in exprs:
        if not isinstance(expr.rhs, Mul):
            new_exprs.append(expr)
            continue
        children = [child for child in expr.rhs._children if not isinstance(child, Scalar)]
        if len(children) > 2:
            raise ValueError(
                "Each expression must be a Mul with two non-scalar children. Try "
                "parenthesising the expressions first.",
            )
        if len(children) == 2:
            to_factorise.append((expr.lhs, expr.rhs))
        else:
            new_exprs.append(expr)

    while to_factorise:
        # Count, for every possible factor, the number of expressions it appears in. A factor
        # repeated within one product (e.g. `x * x`) only counts once for that product.
        factors: dict[Base, int] = {}
        for _, rhs in to_factorise:
            children = [child for child in rhs._children if not isinstance(child, Scalar)]
            for child in dict.fromkeys(children):
                factors[child] = factors.get(child, 0) + 1

        # Find the factor that appears the most (first seen wins ties)
        factor = max(factors, key=factors.__getitem__)

        # For each expression that contains this factor, remove it and group them
        group: list[tuple[Tensor, Base]] = []
        remaining: list[tuple[Tensor, Base]] = []
        for lhs, rhs in to_factorise:
            if factor in rhs.children:
                # Remove a single occurrence only, so that `x * x` leaves `x` behind rather
                # than dropping both copies of the tensor.
                rest = list(rhs.children)
                rest.remove(factor)
                group.append((lhs, Mul(*rest)))
            else:
                remaining.append((lhs, rhs))
        to_factorise = remaining

        # Combine the group into sums for each unique output, in first-seen order so that the
        # result does not depend on set iteration order.
        for output in dict.fromkeys(lhs for lhs, _ in group):
            terms = [rest for lhs, rest in group if lhs == output]
            new_exprs.append(Expression(output, Mul(factor, Add(*terms))))

    return new_exprs
+ """ + if sizes is None: + sizes = _default_sizes + + # Get all indices in the expressions + indices: set[Index] = set() + for expr in exprs: + for tensor in expr.rhs.search(Tensor): + indices.update(set(tensor.indices)) + + # Check if there are any existing intermediates we should avoid clashing with + counter = 0 + for expr in exprs: + for tensor in itertools.chain([expr.lhs], expr.rhs.search(Tensor)): + if tensor.name.startswith("tmp") and tensor.name[3:].isdigit(): + counter = max(counter, int(tensor.name[3:]) + 1) + + while True: + # Find candidate subexpressions and count their occurrences + # TODO: write update function to avoid repeating work + candidates = _identify_subexpressions(exprs, indices=indices) + candidates = {k: v for k, v in candidates.items() if v > 1} + + # If no candidates, we're done + if not candidates: + break + + def _cost(c: tuple[Base, tuple[Index, ...]]) -> float: + """Estimate the cost (benefit) of a candidate intermediate.""" + count = candidates[c] # noqa: B023 + flops = count_flops(c[0], sizes=sizes) + return count * flops + + # Favour the best candidate according to the cost function + candidate, candidate_indices = max(candidates, key=_cost) + + # Initialise the intermediate + interm = Tensor( + *candidate_indices, + name=f"tmp{counter}", + ) + + # Find all instances of the candidate + # TODO: track addresses when searching for candidates to avoid repeating work + new_exprs: list[Expression] = [] + touched = False + for i, expr in enumerate(exprs): + # Find the substitutions + substs: dict[Base, Base] = {} + for mul in expr.rhs.search(Mul): + # Loop over combinations of children to find subexpressions + children = [child for child in mul._children if not isinstance(child, Scalar)] + assert candidate._children is not None + for combo in itertools.combinations(children, len(candidate._children)): + mul_check = Mul(*combo) + _, mul_check_canon = _canonicalise_intermediate(None, mul_check, indices) + if mul_check_canon == 
def absorb_intermediate_factors(exprs: list[Expression]) -> list[Expression]:
    """Absorb factors from intermediates back into the expressions where possible.

    An intermediate (a tensor whose name starts with `"tmp"`) defined as a single scalar times a
    single non-scalar has the scalar moved out of its definition and into every use of the
    intermediate in the other expressions.

    Args:
        exprs: The tensor expressions to update.

    Returns:
        The updated tensor expressions, in the same order as the input.
    """
    # An intermediate defined by several expressions cannot have a factor pulled out of just one
    # of its definitions, so count the definitions first.
    n_definitions: dict[str, int] = {}
    for expr in exprs:
        n_definitions[expr.lhs.name] = n_definitions.get(expr.lhs.name, 0) + 1

    new_exprs: list[Expression] = list(exprs)
    for i, expr in enumerate(new_exprs):
        name = expr.lhs.name
        if not name.startswith("tmp") or n_definitions[name] != 1:
            continue
        children = expr.rhs._children or []
        scalars = [child for child in children if isinstance(child, Scalar)]
        others = [child for child in children if not isinstance(child, Scalar)]
        if not (len(others) == len(scalars) == 1):
            continue

        def _scale(node: Tensor, _name: str = name, _scalars: tuple[Base, ...] = tuple(scalars)):
            """Attach the absorbed factor to uses of this intermediate."""
            # Match on the name of the intermediate being absorbed; defaults bind the current
            # loop values to avoid late-binding surprises.
            return Mul(*_scalars, node) if node.name == _name else node

        # Rewrite every other expression, whether it comes before or after the definition
        for j, other in enumerate(new_exprs):
            if j != i:
                new_exprs[j] = Expression(other.lhs, other.rhs.apply(_scale, Tensor))

        # The definition itself keeps only the non-scalar part
        new_exprs[i] = Expression(expr.lhs, others[0])

    return new_exprs
def unused_intermediates(exprs: list[Expression]) -> list[Tensor]:
    """Find intermediates that have a definition but are never referenced.

    An intermediate is any tensor whose name starts with `"tmp"`.

    Args:
        exprs: The tensor expressions to check.

    Returns:
        The output tensors of intermediates whose name does not appear in any right-hand side.
    """
    # Outputs that define an intermediate
    definitions: set[Tensor] = {expr.lhs for expr in exprs if expr.lhs.name.startswith("tmp")}

    # Names of the intermediates referenced on any right-hand side
    referenced: set[str] = {
        tensor.name
        for expr in exprs
        for tensor in expr.rhs.search(Tensor)
        if tensor.name.startswith("tmp")
    }

    return [tensor for tensor in definitions if tensor.name not in referenced]
+ + Args: + exprs: The tensor expressions to renumber. + + Returns: + The renumbered tensor expressions. + """ + # Sort the expressions so the renumbering looks sensible after code generation + exprs = sort_expressions(exprs) + + # Map old intermediate names to new ones + counter = 0 + name_map: dict[str, str] = {} + for expr in exprs: + for tensor in itertools.chain([expr.lhs], expr.rhs.search(Tensor)): + if tensor.name.startswith("tmp"): + if tensor.name not in name_map: + name_map[tensor.name] = f"tmp{counter}" + counter += 1 + + def _apply(node: Tensor) -> Tensor: + if node.name.startswith("tmp"): + return node.__class__(*node.indices, name=name_map[node.name]) + return node + + exprs = [ + Expression( + expr.lhs.__class__(*expr.lhs.indices, name=name_map.get(expr.lhs.name, expr.lhs.name)), + expr.rhs.apply(_apply, Tensor), + ) + for expr in exprs + ] + + return exprs + + +@functools.lru_cache(maxsize=32) +def _get_index_groups( + indices: frozenset[Index], +) -> dict[tuple[str | None, str | None], list[Index]]: + """Group indices by their (space, spin) pairs.""" + index_groups: dict[tuple[str | None, str | None], list[Index]] = {} + for index in indices: + key = (index.space, index.spin) + if key not in index_groups: + index_groups[key] = [] + index_groups[key].append(index) + return {key: sorted(value) for key, value in index_groups.items()} + + +def _get_canonicalise_intermediate_map(expr: Base, indices: set[Index]) -> dict[Index, Index]: + """Get the index mapping to canonicalise the indices of an intermediate.""" + index_groups = _get_index_groups(frozenset(indices)) + indices_i = { + key: [ + index + for index in (expr.external_indices + expr.internal_indices) + if (index.space, index.spin) == key + ] + for key in index_groups + } + index_map = {} + for key in index_groups: + for old, new in zip(index_groups[key], indices_i[key]): + index_map[new] = old + return index_map + + +def _canonicalise_intermediate( + output: Tensor | None, expr: Base, indices: 
def eliminate_and_factorise_common_subexpressions(
    expr: Expression,
    sizes: Optional[dict[str | None, int]] = None,
    scaling_limit_cpu: dict[tuple[str, ...], int] | None = None,
    scaling_limit_ram: dict[tuple[str, ...], int] | None = None,
    max_passes: int = 3,
) -> list[Expression]:
    """Identify common subexpressions in an expression, with parenthesisation and factorisation.

    Expression should be canonicalised for this to work well.

    Args:
        expr: The tensor expression to identify common subexpressions in.
        sizes: The sizes of the spaces in the expression.
        scaling_limit_cpu: The scaling limits for CPU. Keys should be tuples of index space names,
            and values are the maximum allowed scaling for that combination of spaces.
        scaling_limit_ram: The scaling limits for RAM. Keys should be tuples of index space names,
            and values are the maximum allowed scaling for that combination of spaces.
        max_passes: The maximum number of passes to perform. More passes may find more common
            subexpressions, but will take longer.

    Returns:
        List of tensor expressions, which may correspond to the original output or intermediates.
        A warning is emitted if any intermediate is defined but unused, or used but undefined.
    """
    # Collect all indices in the expression
    indices: set[Index] = set()
    for tensor in expr.rhs.search(Tensor):
        indices.update(tensor.indices)

    def _canonicalise(exprs: list[Expression]) -> list[Expression]:
        """Canonicalise the expressions in place and return the same list."""
        for i, ex in enumerate(exprs):
            if ex.lhs.name.startswith("tmp"):
                lhs, rhs = _canonicalise_intermediate(ex.lhs, ex.rhs, indices)
                ex = Expression(lhs, rhs)
            # NOTE(review): for non-intermediate outputs the original also computed
            # `canonicalise_indices(rhs, extra_indices=..., which="internal")` and a remapped
            # LHS, but discarded both. The dead computation is dropped here and the behaviour
            # kept; confirm whether the canonicalised result was meant to be used.
            exprs[i] = Expression(ex.lhs, ex.rhs.squeeze().canonicalise())
        return exprs

    # Parenthesise each multiplication
    exprs: list[Expression] = []
    counter = 0
    for mul in expr.rhs.expand().children:
        rhs, ints = parenthesise_mul(
            cast(Mul, mul),
            sizes=sizes,
            scaling_limit_cpu=scaling_limit_cpu,
            scaling_limit_ram=scaling_limit_ram,
            intermediate_counter=counter,
        )
        exprs.extend(ints)
        exprs.append(Expression(expr.lhs, rhs))
        counter += len(ints)

    # Eliminate common subexpressions, stopping early once a pass changes nothing
    for i in range(max_passes):
        exprs_prev = exprs.copy()
        if i != 0:
            exprs = factorise(exprs)
        exprs = eliminate_common_subexpressions(exprs, sizes=sizes)
        exprs = _canonicalise(exprs)
        exprs = absorb_trivial_intermediates(exprs)
        exprs = merge_identical_intermediates(exprs)
        if exprs == exprs_prev:
            break

    # Renumber intermediates, also sorts the expressions
    exprs = renumber_intermediates(exprs)

    # Sum expressions with the same output
    groups: dict[Tensor, Base] = {}
    for ex in exprs:
        if ex.lhs not in groups:
            groups[ex.lhs] = ex.rhs
        else:
            groups[ex.lhs] = groups[ex.lhs] + ex.rhs
    exprs = [Expression(lhs, rhs) for lhs, rhs in groups.items()]

    unused = set(interm.name for interm in unused_intermediates(exprs))
    undefined = set(interm.name for interm in undefined_intermediates(exprs))
    if unused:
        warnings.warn(f"Intermediates defined but not used: {unused}.", stacklevel=2)
    if undefined:
        warnings.warn(f"Intermediates used but not defined: {undefined}.", stacklevel=2)

    return exprs
def import_expression(
    external: Any,
    package: Literal["pdaggerq", "wick"] = "pdaggerq",
    index_order: Iterable[str] | None = None,
    name: str | None = None,
    **kwargs: Any,
) -> Expression:
    """Import an expression from a third-party quantum chemistry package.

    Args:
        external: The external expression to import. The exact format is specified by the
            individual importers.
        package: The code from which to import the expression.
        index_order: The desired order of external index labels in the imported expression. The
            indices of the left-hand side of the expression will be sorted such that their labels
            match this order. Any iterable is accepted, including one-shot iterators.
        name: The name to assign to the LHS tensor.
        **kwargs: Additional keyword arguments to pass to the importer.

    Returns:
        The imported expression.

    Raises:
        ValueError: If `package` is not recognised, or if an external index of the imported
            expression has a label that is missing from `index_order`.
    """
    # Import the RHS
    if package == "pdaggerq":
        rhs = import_from_pdaggerq(external, **kwargs)
    elif package == "wick":
        rhs = import_from_wick(external, **kwargs)
    else:
        raise ValueError(f"Unknown package: {package}")
    rhs = rhs.canonicalise(indices=True)

    # Get the LHS
    if index_order is None:
        indices = rhs.external_indices
    else:
        # Materialise the order once: rebuilding the list inside the sort key would exhaust a
        # one-shot iterable after the first index and costs a linear scan per index.
        positions: dict[str, int] = {}
        for position, label in enumerate(index_order):
            positions.setdefault(label, position)  # first occurrence wins, like `list.index`
        missing = [index.name for index in rhs.external_indices if index.name not in positions]
        if missing:
            raise ValueError(f"External indices {missing} are not in `index_order`.")
        indices = tuple(sorted(rhs.external_indices, key=lambda index: positions[index.name]))
    lhs = Tensor(*indices, name=name)

    return Expression(lhs, rhs)
Optional[dict[str, str]] = None, l_is_lambda: bool = True, + name: str | None = None, ) -> Base: """Convert a symbol to a subclass of `Base`. @@ -160,6 +161,7 @@ def _convert_symbol( index_spaces: The index spaces. l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM operator. + name: The name of the tensor. Returns: The converted symbol. @@ -386,7 +388,7 @@ def _convert_symbol( for index in index_strs ) - return tensor_symbol(*indices) + return tensor_symbol.factory(*indices, name=name) def remove_reference_energy(terms: list[list[str]]) -> list[list[str]]: diff --git a/albert/qc/_wick.py b/albert/qc/_wick.py new file mode 100644 index 0000000..19d9a91 --- /dev/null +++ b/albert/qc/_wick.py @@ -0,0 +1,215 @@ +"""Interface to `wick`.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from albert.algebra import Mul +from albert.index import Index +from albert.qc import ghf +from albert.qc._pdaggerq import _guess_space, _is_number +from albert.qc.tensor import QTensor +from albert.scalar import Scalar + +if TYPE_CHECKING: + from typing import Optional + + from albert.base import Base + + +def import_from_wick( + terms: list[str], + index_spins: Optional[dict[str, str]] = None, + index_spaces: Optional[dict[str, str]] = None, + l_is_lambda: bool = True, + symbol_aliases: Optional[dict[str, str]] = None, +) -> Base: + r"""Import an expression from `wick`. + + Tensors in the return expression are `GHF` tensors. + + Args: + terms: The terms of the expression. Should be the lines of the `repr` of the output + `AExpression` in `wick`, i.e. `str(AExpression(Ex=...)).split("\n")`. + index_spins: The index spins. + index_spaces: The index spaces. + l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM + operator. + symbol_aliases: Aliases for symbols. + + Returns: + The imported expression. 
+ """ + if index_spins is None: + index_spins = {} + if index_spaces is None: + index_spaces = {} + + # Build the expression + expr: Base = Scalar.factory(0.0) + for term_str in terms: + # Convert the symbols + term = _split_term(term_str) + term, names = zip(*[_format_symbol(symbol, aliases=symbol_aliases) for symbol in term]) + symbols = [ + _convert_symbol( + symbol, + index_spins=index_spins, + index_spaces=index_spaces, + l_is_lambda=l_is_lambda, + name=name, + ) + for symbol, name in zip(term, names) + ] + part = Mul.factory(*symbols) + + # Add the term to the expression + expr += part.canonicalise(indices=True) # wick doesn't guarantee same external indices + + return expr + + +def _split_term(term: str) -> list[str]: + """Split a term into its symbols.""" + term = term.lstrip(" ") + term = term.replace(" ", "") + term = term.replace("}", "} ").rstrip(" ") + if r"\sum_{" in term: + term = re.sub(r"\\sum_\{[^\}]*\}", "", term) + else: + i = 0 + while term[i] in "-+0123456789.": + i += 1 + if i > 0: + term = term[:i] + " " + term[i:] + return term.split(" ") + + +def _format_symbol(symbol: str, aliases: dict[str, str] | None = None) -> tuple[str, str]: + """Rewrite a `wick` symbol to look like a `pdaggerq` symbol.""" + symbol = re.sub( + r"([a-zA-Z0-9]+)_\{([^\}]*)\}", lambda m: f"{m.group(1)}({','.join(m.group(2))})", symbol + ) + symbol_name, indices = symbol.split("(", 1) if "(" in symbol else (symbol, None) + if aliases is not None: + symbol_alias = aliases.get(symbol_name, symbol_name) + symbol = f"{symbol_alias}({indices}" if indices is not None else symbol_alias + return symbol, symbol_name + + +def _convert_symbol( + symbol: str, + index_spins: Optional[dict[str, str]] = None, + index_spaces: Optional[dict[str, str]] = None, + l_is_lambda: bool = True, + name: str | None = None, +) -> Base: + """Convert a symbol to a subclass of `Base`. + + Args: + symbol: The symbol. + index_spins: The index spins. + index_spaces: The index spaces. 
+ l_is_lambda: Whether `l` corresponds to the Lambda operator, rather than the left-hand EOM + operator. + name: The name of the tensor. + + Returns: + The converted symbol. + """ + if index_spins is None: + index_spins = {} + if index_spaces is None: + index_spaces = {} + + if re.match(r".*_[0-9]+$", symbol): + # Symbol has spaces attached, separate them + symbol, spaces = symbol.rsplit("_", 1) + + if _is_number(symbol): + # It's the factor + return Scalar.factory(float(symbol)) + + tensor_symbol: type[QTensor] + index_strs: tuple[str, ...] + if symbol in ("r0", "l0"): + # r0 or l0 + index_strs = () + tensor_symbol = ghf.R0 + elif re.match(r"f\((?i:[a-z]),(?i:[a-z])\)", symbol): + # f(i,j) + index_strs = tuple(symbol[2:-1].split(",")) + tensor_symbol = ghf.Fock + elif re.match(r"v\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol): + # v(i,j,k,l) + index_strs = tuple(symbol[2:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.ERI + elif re.match(r"t1\((?i:[a-z]),(?i:[a-z])\)", symbol): + # t1(i,j) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[1], index_strs[0]) + tensor_symbol = ghf.T1 + elif re.match(r"t2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol): + # t2(i,j,k,l) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.T2 + elif re.match( + r"t3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol + ): + # t3(i,j,k,l,m,n) + index_strs = tuple(symbol[3:-1].split(","))[::-1] + index_strs = ( + index_strs[3], + index_strs[4], + index_strs[5], + index_strs[0], + index_strs[1], + index_strs[2], + ) + tensor_symbol = ghf.T3 + elif re.match(r"l1\((?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda: + # l1(i,j) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[1], index_strs[0]) + tensor_symbol = ghf.L1 + elif 
re.match(r"l2\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol) and l_is_lambda: + # l2(i,j,k,l) + index_strs = tuple(symbol[3:-1].split(",")) + index_strs = (index_strs[2], index_strs[3], index_strs[0], index_strs[1]) + tensor_symbol = ghf.L2 + elif ( + re.match(r"l3\((?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z]),(?i:[a-z])\)", symbol) + and l_is_lambda + ): + # l3(i,j,k,l,m,n) + index_strs = tuple(symbol[3:-1].split(","))[::-1] + index_strs = ( + index_strs[3], + index_strs[4], + index_strs[5], + index_strs[0], + index_strs[1], + index_strs[2], + ) + tensor_symbol = ghf.L3 + elif re.match(r"delta\((?i:[a-z]),(?i:[a-z])\)", symbol): + # delta(i,j) + index_strs = tuple(symbol[6:-1].split(",")) + tensor_symbol = ghf.Delta + else: + raise ValueError(f"Unknown symbol {symbol}") + + # Convert the indices + indices = tuple( + Index( + index, + spin=index_spins.get(index, None), + space=index_spaces.get(index, _guess_space(index)), + ) + for index in index_strs + ) + + return tensor_symbol.factory(*indices, name=name) diff --git a/examples/codegen_rccsd.py b/examples/codegen_rccsd.py index 17fcbe1..ce88b79 100644 --- a/examples/codegen_rccsd.py +++ b/examples/codegen_rccsd.py @@ -6,11 +6,9 @@ from pdaggerq import pq_helper from albert.code.einsum import EinsumCodeGenerator -from albert.expression import Expression -from albert.opt._gristmill import optimise_gristmill -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_rhf -from albert.tensor import Tensor +from albert.opt import optimise +from albert.qc import adapt_spin, import_expression +from albert.qc._pdaggerq import remove_reference_energy # Suppress warnings since we're outputting the code to stdout warnings.filterwarnings("ignore") @@ -30,15 +28,14 @@ pq.simplify() expr = pq.strings() expr = remove_reference_energy(expr) -expr = import_from_pdaggerq(expr) -expr = ghf_to_rhf(expr).collect() -output = Tensor(name="e_cc") +expr = 
import_expression(expr, name="e_cc") +exprs = adapt_spin(expr, target_spin="rhf") # Optimise the energy expression -exprs = optimise_gristmill([Expression(output, expr)], strategy="exhaust") +exprs = optimise(exprs, strategy="exhaust") # Generate the code for the energy expression -codegen("energy", [output], exprs) +codegen("energy", [expr.lhs for expr in exprs], exprs) # Find the T1 expression pq.clear() @@ -47,9 +44,8 @@ pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() expr_t1 = pq.strings() -expr_t1 = import_from_pdaggerq(expr_t1) -expr_t1 = ghf_to_rhf(expr_t1).collect() -output_t1 = Tensor(*expr_t1.external_indices, name="t1new") +expr_t1 = import_expression(expr_t1, name="t1new") +exprs_t1 = adapt_spin(expr_t1, target_spin="rhf") # Find the T2 expression pq.clear() @@ -58,20 +54,19 @@ pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() expr_t2 = pq.strings() -expr_t2 = import_from_pdaggerq(expr_t2) -expr_t2 = ghf_to_rhf(expr_t2).collect() -output_t2 = Tensor(*expr_t2.external_indices, name="t2new") +expr_t2 = import_expression(expr_t2, name="t2new") +exprs_t2 = adapt_spin(expr_t2, target_spin="rhf") # Optimise the T1 and T2 expressions -exprs = optimise_gristmill( - [Expression(output_t1, expr_t1), Expression(output_t2, expr_t2)], +exprs = optimise( + exprs_t1 + exprs_t2, strategy="trav", ) # Generate the code for the T1 and T2 expressions codegen( "update_amplitudes", - [output_t1, output_t2], + [expr.lhs for expr in (exprs_t1 + exprs_t2)], exprs, as_dict=True, ) diff --git a/tests/test_opt.py b/tests/test_opt.py index febe757..b9f83d3 100644 --- a/tests/test_opt.py +++ b/tests/test_opt.py @@ -1,6 +1,10 @@ +import pytest + from albert.expression import Expression from albert.opt.tools import substitute_expressions +from albert.opt import optimise from albert.tensor import Tensor +from albert.index import from_list def test_substitute_expressions(): @@ -61,3 +65,18 @@ def test_substitute_expressions(): output_expr_sub[1].rhs == 
Tensor.from_string("(a(i,k,l) * b(k,l,j)) + (a(i,k,l) * c(k,l,j)) + (z(i,j))").expand() ) + + +@pytest.mark.parametrize("method", ["auto", "gristmill", "albert", "legacy"]) +def test_optimise(method: str): + i, j, k, l = from_list(["i", "j", "k", "l"], spaces="o") + lhs = Tensor(i, j, name="x") + rhs = ( + Tensor(i, k, l, name="a") * Tensor(k, l, j, name="b") + + Tensor(i, k, l, name="a") * Tensor(k, l, j, name="c") + ) + expr = Expression(lhs, rhs) + optimised_exprs = optimise([expr], method=method) + expr_recovered = substitute_expressions(optimised_exprs)[0] + assert len(optimised_exprs) == 2 + assert expr_recovered.lhs == expr.lhs diff --git a/tests/test_pdaggerq.py b/tests/test_pdaggerq.py index 26d52d6..8fa73ff 100644 --- a/tests/test_pdaggerq.py +++ b/tests/test_pdaggerq.py @@ -26,7 +26,7 @@ def test_ccsd_energy(): pq.simplify() terms = pq.strings() expr_ghf = import_from_pdaggerq(terms) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.5 * v(i,j,j,i)) + f(i,i) + (0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (f(i,a) * t1(i,a)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -36,7 +36,7 @@ def test_ccsd_energy(): terms = remove_reference_energy(terms) expr_ghf = import_from_pdaggerq(terms) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (f(i,a) * t1(i,a)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -55,7 +55,7 @@ def _filter_fock_terms(mul: Mul) -> Mul | Scalar: expr_ghf = expr_ghf.expand() expr_ghf = expr_ghf.apply(_filter_fock_terms, Mul) - expr_ghf = expr_ghf.canonicalise() + expr_ghf = expr_ghf.canonicalise(indices=True) assert ( repr(expr_ghf) == "(0.25 * t2(i,j,a,b) * v(i,j,a,b)) + (0.5 * t1(i,a) * t1(j,b) * v(i,j,a,b))" @@ -87,11 +87,11 @@ def _project(mul: Mul) -> Mul | Scalar: ), Mul, ) - expr_uhf_aaaa = expr_uhf_aaaa.canonicalise() - #assert ( - # repr(expr_uhf_aaaa) - # == "(0.5 * 
t2(iα,jα,aα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t2(iα,jα,aα,bα) * v(iα,bα,jα,aα)) + (0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,bα,jα,aα))" - #) + expr_uhf_aaaa = expr_uhf_aaaa.canonicalise(indices=True) + assert ( + repr(expr_uhf_aaaa) + == "(0.5 * t2(iα,jα,aα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t2(iα,jα,aα,bα) * v(iα,bα,jα,aα)) + (0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,aα,jα,bα)) + (-0.5 * t1(iα,aα) * t1(jα,bα) * v(iα,bα,jα,aα))" + ) expr_uhf_abab = expr_uhf.apply( _project_onto_indices( @@ -104,7 +104,7 @@ def _project(mul: Mul) -> Mul | Scalar: ), Mul, ) - expr_uhf_abab = expr_uhf_abab.canonicalise() + expr_uhf_abab = expr_uhf_abab.canonicalise(indices=True) assert ( repr(expr_uhf_abab) == "(0.25 * t2(iα,jβ,aα,bβ) * v(iα,aα,jβ,bβ)) + (0.5 * t1(iα,aα) * t1(jβ,bβ) * v(iα,aα,jβ,bβ))" diff --git a/tests/test_rccsd.py b/tests/test_rccsd.py index c8a6dda..ad1679a 100644 --- a/tests/test_rccsd.py +++ b/tests/test_rccsd.py @@ -10,8 +10,8 @@ from albert.code.einsum import EinsumCodeGenerator from albert.opt import optimise as _optimise -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_rhf +from albert.qc._pdaggerq import remove_reference_energy +from albert.qc import import_expression, adapt_spin from albert.tensor import Tensor from albert.expression import Expression @@ -26,27 +26,26 @@ def _kwargs(strategy, transposes, greedy_cutoff, drop_cutoff): @pytest.mark.parametrize( - "optimise, method, canonicalise, kwargs", + "optimise, method, kwargs", [ - (False, None, False, _kwargs(None, None, None, None)), - (True, "gristmill", True, _kwargs("trav", "natural", -1, -1)), - (True, "gristmill", True, _kwargs("opt", "natural", -1, -1)), - (True, "gristmill", False, _kwargs("greedy", "ignore", -1, 2)), - (True, "gristmill", True, _kwargs("greedy", "ignore", 2, 2)), - (True, "albert", True, {}), + (False, None, _kwargs(None, None, None, None)), + (True, "gristmill", _kwargs("trav", 
"natural", -1, -1)), + (True, "gristmill", _kwargs("greedy", "ignore", -1, 2)), + (True, "gristmill", _kwargs("greedy", "ignore", 2, 2)), + (True, "albert", {}), ], ) -def test_rccsd_einsum(helper, optimise, method, canonicalise, kwargs): +def test_rccsd_einsum(helper, optimise, method, kwargs): with open(f"{os.path.dirname(__file__)}/_test_rccsd.py", "w") as file: try: - _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs) + _test_rccsd_einsum(helper, file, optimise, method, kwargs) except Exception as e: raise e finally: os.remove(f"{os.path.dirname(__file__)}/_test_rccsd.py") -def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): +def _test_rccsd_einsum(helper, file, optimise, method, kwargs): codegen = EinsumCodeGenerator(stdout=file) codegen.preamble() @@ -59,17 +58,12 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.simplify() energy = pq.strings() energy = remove_reference_energy(energy) - energy = import_from_pdaggerq(energy) - energy = ghf_to_rhf(energy) - if canonicalise: - energy = energy.canonicalise(indices=True).collect() - output = Tensor(name="e_cc") - - exprs = [Expression(output, energy)] + energy = import_expression(energy, package="pdaggerq", name="e_cc") + exprs = adapt_spin(energy, target_spin="rhf") if optimise: exprs = _optimise(exprs, method=method, **kwargs) - codegen("energy", [output], exprs) + codegen("energy", [expr.lhs for expr in exprs], exprs) pq.clear() pq.set_left_operators([["e1(i,a)"]]) @@ -77,13 +71,8 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t1 = pq.strings() - t1 = import_from_pdaggerq(t1, index_spins=dict(i="a", a="a")) - t1 = ghf_to_rhf(t1) - if canonicalise: - t1 = t1.canonicalise(indices=True).collect() - output_t1 = Tensor( - *sorted(t1.external_indices, key=lambda i: "ijab".index(i.name)), name="t1new" - ) + t1 = import_expression(t1, 
package="pdaggerq", index_spins=dict(i="a", a="a"), name="t1new") + t1 = adapt_spin(t1, target_spin="rhf") pq.clear() pq.set_left_operators([["e2(i,j,b,a)"]]) @@ -91,19 +80,14 @@ def _test_rccsd_einsum(helper, file, optimise, method, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t2 = pq.strings() - t2 = import_from_pdaggerq(t2, index_spins=dict(i="a", j="b", a="a", b="b")) - t2 = ghf_to_rhf(t2) - if canonicalise: - t2 = t2.canonicalise(indices=True).collect() - output_t2 = Tensor( - *sorted(t2.external_indices, key=lambda i: "ijab".index(i.name)), name="t2new" - ) - - exprs = [Expression(output_t1, t1), Expression(output_t2, t2)] + t2 = import_expression(t2, package="pdaggerq", index_spins=dict(i="a", j="b", a="a", b="b"), name="t2new") + t2 = adapt_spin(t2, target_spin="rhf") + + exprs = t1 + t2 if optimise: exprs = _optimise(exprs, method=method, **kwargs) - codegen("update_amplitudes", [output_t1, output_t2], exprs, as_dict=True) + codegen("update_amplitudes", [expr.lhs for expr in exprs], exprs, as_dict=True) module = importlib.import_module(f"_test_rccsd") energy = module.energy diff --git a/tests/test_uccsd.py b/tests/test_uccsd.py index e892cb2..67ad7ef 100644 --- a/tests/test_uccsd.py +++ b/tests/test_uccsd.py @@ -10,8 +10,8 @@ from albert.code.einsum import EinsumCodeGenerator from albert.opt import optimise as _optimise -from albert.qc._pdaggerq import import_from_pdaggerq, remove_reference_energy -from albert.qc.spin import ghf_to_uhf +from albert.qc._pdaggerq import remove_reference_energy +from albert.qc import import_expression, adapt_spin from albert.tensor import Tensor from albert.expression import Expression @@ -26,24 +26,24 @@ def _kwargs(strategy, transposes, greedy_cutoff, drop_cutoff): @pytest.mark.parametrize( - "optimise, canonicalise, kwargs", + "optimise, kwargs", [ - (False, False, _kwargs(None, None, None, None)), - (True, False, _kwargs("greedy", "ignore", -1, 2)), - (True, True, _kwargs("greedy", 
"ignore", 2, 2)), + (False, _kwargs(None, None, None, None)), + (True, _kwargs("greedy", "ignore", -1, 2)), + (True, _kwargs("greedy", "ignore", 2, 2)), ], ) -def test_uccsd_einsum(helper, optimise, canonicalise, kwargs): +def test_uccsd_einsum(helper, optimise, kwargs): with open(f"{os.path.dirname(__file__)}/_test_uccsd.py", "w") as file: try: - _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs) + _test_uccsd_einsum(helper, file, optimise, kwargs) except Exception as e: raise e finally: os.remove(f"{os.path.dirname(__file__)}/_test_uccsd.py") -def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): +def _test_uccsd_einsum(helper, file, optimise, kwargs): codegen = EinsumCodeGenerator(stdout=file) codegen.preamble() @@ -56,17 +56,12 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.simplify() energy = pq.strings() energy = remove_reference_energy(energy) - energy = import_from_pdaggerq(energy) - energy = ghf_to_uhf(energy) - if canonicalise: - energy = tuple(e.canonicalise(indices=True).collect() for e in energy) - output = tuple(Tensor(name="e_cc") for _ in energy) - - exprs = [Expression(o, e) for o, e in zip(output, energy)] + energy = import_expression(energy, package="pdaggerq", name="e_cc") + exprs = adapt_spin(energy, target_spin="uhf") if optimise: exprs = _optimise(exprs, **kwargs) - codegen("energy", output, exprs) + codegen("energy", [expr.lhs for expr in exprs], exprs) pq.clear() pq.set_left_operators([["e1(i,a)"]]) @@ -74,14 +69,8 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t1 = pq.strings() - t1 = import_from_pdaggerq(t1) - t1 = ghf_to_uhf(t1) - if canonicalise: - t1 = tuple(t.canonicalise(indices=True).collect() for t in t1) - output_t1 = tuple( - Tensor(*sorted(t.external_indices, key=lambda i: "ijab".index(i.name)), name=f"t1new") - for i, t in enumerate(t1) - ) + t1 = import_expression(t1, 
package="pdaggerq", name="t1new") + t1 = adapt_spin(t1, target_spin="uhf") pq.clear() pq.set_left_operators([["e2(i,j,b,a)"]]) @@ -89,23 +78,19 @@ def _test_uccsd_einsum(helper, file, optimise, canonicalise, kwargs): pq.add_st_operator(1.0, ["v"], ["t1", "t2"]) pq.simplify() t2 = pq.strings() - t2_expr = tuple() - for spins in ("aaaa", "abab", "baba", "bbbb"): - index_spins = dict(zip("ijab", spins)) - t2_expr += ghf_to_uhf(import_from_pdaggerq(t2, index_spins=index_spins)) - t2 = t2_expr - if canonicalise: - t2 = tuple(t.canonicalise(indices=True).collect() for t in t2) - output_t2 = tuple( - Tensor(*sorted(t.external_indices, key=lambda i: "ijab".index(i.name)), name=f"t2new") - for i, t in enumerate(t2) - ) - - exprs = [Expression(o, t) for o, t in zip(output_t1 + output_t2, t1 + t2)] + t2 = [ + import_expression( + t2, package="pdaggerq", index_spins=dict(zip("ijab", spins)), name="t2new" + ) + for spins in ("aaaa", "abab", "baba", "bbbb") + ] + t2 = adapt_spin(t2, target_spin="uhf") + + exprs = t1 + t2 if optimise: exprs = _optimise(exprs, **kwargs) - codegen("update_amplitudes", output_t1 + output_t2, exprs, as_dict=True) + codegen("update_amplitudes", [expr.lhs for expr in exprs], exprs, as_dict=True) module = importlib.import_module(f"_test_uccsd") energy = module.energy