From 0bad01d0e4823d95bce608e75dad2983e7c7fb16 Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Thu, 30 Oct 2025 16:44:57 +0000 Subject: [PATCH] Add introduction notebook --- examples/introduction.ipynb | 1155 +++++++++++++++++++++++++++++++++++ 1 file changed, 1155 insertions(+) create mode 100644 examples/introduction.ipynb diff --git a/examples/introduction.ipynb b/examples/introduction.ipynb new file mode 100644 index 0000000..f0aec13 --- /dev/null +++ b/examples/introduction.ipynb @@ -0,0 +1,1155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f7c484da-ffe6-4b62-be9d-40eeafa867bb", + "metadata": {}, + "source": [ + "# Introduction to `albert`" + ] + }, + { + "cell_type": "markdown", + "id": "2bae6eb3-d8cf-4df8-a03b-67c3df8e7532", + "metadata": {}, + "source": [ + "This notebook is intended to introduce the core functionality of `albert`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "737fde72-0fdf-4d18-9d0a-7ba2b57740f9", + "metadata": {}, + "outputs": [], + "source": [ + "from albert.tensor import Tensor\n", + "from albert.scalar import Scalar\n", + "from albert.algebra import Add, Mul\n", + "from albert.base import Base\n", + "from albert.index import Index\n", + "from albert.expression import Expression\n", + "from albert.symmetry import Permutation, Symmetry\n", + "from albert.misc import from_string" + ] + }, + { + "cell_type": "markdown", + "id": "cbfb1583-2f61-4a4e-af58-f76506292c9a", + "metadata": {}, + "source": [ + "## Basics" + ] + }, + { + "cell_type": "markdown", + "id": "45fb8a50-2355-4ca6-b836-84b544543f92", + "metadata": {}, + "source": [ + "The primary building blocks of `albert` are computational graphs which can be classified as [arborescences](https://en.wikipedia.org/wiki/Arborescence_(graph_theory)); that is, directed acyclic graphs (DAGs) in which every vertex other than the root has exactly one parent, and all vertices share the same root node. 
This structure is sufficient to define any sum of tensor contractions obeying the [Einstein summation convention](https://en.wikipedia.org/wiki/Einstein_notation) using just four types of vertices:\n", + "\n", + "1. a scalar node, consisting of a single scalar factor;\n", + "2. a tensor node, consisting of a tensor and the associated indices;\n", + "3. an addition node, consisting of a collection of nodes that are added;\n", + "4. a multiplication node, consisting of a collection of nodes that are contracted according to their indices.\n", + "\n", + "These nodes (vertices) are implemented in `albert` as `Scalar`, `Tensor`, `Add`, and `Mul`, respectively. The first two of these nodes represent leaves in the graph, as they do not have child nodes themselves. The two algebraic nodes are internal nodes, each with children that define the arguments to the respective operation. Each class is a subclass of `Base`, which offers a powerful toolbox for walking the graph and constructing, analysing, and applying tensor contractions." + ] + }, + { + "cell_type": "markdown", + "id": "00791fa1-dc1c-4410-b00d-25d7c541f26d", + "metadata": {}, + "source": [ + "The simplest node is the `Scalar`, which just requires the value:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "20a153e1-6af3-4823-b738-86f1ac8485ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s = Scalar(2.0)\n", + "s" + ] + }, + { + "cell_type": "markdown", + "id": "290f74cc-f504-48b0-b455-94a063fb5e7d", + "metadata": {}, + "source": [ + "The `Tensor` object requires the indices and the name (label) of the tensor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "07eedbc2-03f0-4edd-820e-bf238598dcfe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x(i,j)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "indices = (Index(\"i\"), Index(\"j\"))\n", + "x = Tensor(*indices, name=\"x\")\n", + "x" + ] + }, + { + "cell_type": "markdown", + "id": "c62d3001-7808-440b-9e39-a68498622447", + "metadata": {}, + "source": [ + "The `Add` object requires specification of the tensors that are being added together." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "227c8eb6-1bea-43b0-a481-a56c0fab9de7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x(i,j) + y(i,j)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "add = Add(\n", + " Tensor(*indices, name=\"x\"),\n", + " Tensor(*indices, name=\"y\"),\n", + ")\n", + "add" + ] + }, + { + "cell_type": "markdown", + "id": "39efa6f3-4bf2-4c31-8cf6-480d9f38d229", + "metadata": {}, + "source": [ + "The arguments must have the same indices (within transposition)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f26740ed-79a8-4594-85e0-f3ea0262401c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ValueError('External indices in additions must be equal.')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "error = None\n", + "try:\n", + " add = Add(\n", + " Tensor(Index(\"i\"), Index(\"j\"), name=\"x\"),\n", + " Tensor(Index(\"i\"), Index(\"k\"), name=\"y\"),\n", + " )\n", + "except ValueError as e:\n", + " error = e\n", + "error" + ] + }, + { + "cell_type": "markdown", + "id": "a0b56c9a-b81f-4b7b-9004-b2b47a852d49", + "metadata": {}, + "source": [ + "The `Mul` requires specification of the tensors that are being contracted." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "df25d27d-9549-479c-a81a-ca88c4df0002", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x(i,j) * y(j,k)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mul = Mul(\n", + " Tensor(Index(\"i\"), Index(\"j\"), name=\"x\"),\n", + " Tensor(Index(\"j\"), Index(\"k\"), name=\"y\"),\n", + ")\n", + "mul" + ] + }, + { + "cell_type": "markdown", + "id": "427b411d-cef9-4b5e-9923-a485f0617356", + "metadata": {}, + "source": [ + "The indices must obey the Einstein summation convention." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "82a887c0-9fc8-4d35-b7fb-87cd2f874bef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ValueError('Input arguments are not a valid Einstein notation. Each index must appear at most twice.')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "error = None\n", + "try:\n", + " mul = Mul(\n", + " Tensor(Index(\"i\"), Index(\"j\"), name=\"x\"),\n", + " Tensor(Index(\"i\"), Index(\"k\"), name=\"y\"),\n", + " Tensor(Index(\"i\"), Index(\"l\"), name=\"z\"),\n", + " )\n", + "except ValueError as e:\n", + " error = e\n", + "error" + ] + }, + { + "cell_type": "markdown", + "id": "e3eda021-48cc-4e91-bf0f-c4e235cf1d4f", + "metadata": {}, + "source": [ + "Graphs can be constructed by either\n", + "\n", + "- instantiating the node classes as above, along with the children as arguments;\n", + "- using the `Base.factory` method of each node, with the same signature as the class constructor;\n", + "- using the overloaded operators `*` and `+`.\n", + "\n", + "The latter two methods perform basic simplifications in trivial cases, for example, collecting like terms or detecting necessary zeros." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "17830144-d09a-47ae-a22b-f4ca4a770283", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "Mul.factory(Scalar(0.0), x)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "17a4fd3a-3dcc-41e3-8863-46e42b40b7a4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(x(i,j) * y(j,k)) + z(i,k)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "y = Tensor(Index(\"j\"), Index(\"k\"), name=\"y\")\n", + "z = Tensor(Index(\"i\"), Index(\"k\"), name=\"z\")\n", + "rhs = (x * y) + z\n", + "rhs" + ] + }, + { + "cell_type": "markdown", + "id": "329a39cd-d82c-493e-9760-1af9f50bf37b", + "metadata": {}, + "source": [ + "The default `repr` of the graphs displays their algebraic form. For large nested expressions this can become unintuitive, and instead there is a function to view the graph in a tree-like representation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3f81cc36-ff02-41af-925f-4b3c5f451d49", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + "├── Mul\n", + "│ ├── Add\n", + "│ │ ├── Mul\n", + "│ │ │ ├── x(i,j)\n", + "│ │ │ └── y(j,k)\n", + "│ │ └── z(i,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── Add\n", + "│ │ ├── Mul\n", + "│ │ │ ├── x(i,j)\n", + "│ │ │ └── y(j,k)\n", + "│ │ └── z(i,k)\n", + "│ └── d(i,a)\n", + "└── c(a,k)\n" + ] + } + ], + "source": [ + "x = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "y = Tensor(Index(\"j\"), Index(\"k\"), name=\"y\")\n", + "z = Tensor(Index(\"i\"), Index(\"k\"), name=\"z\")\n", + "c = Tensor(Index(\"a\"), Index(\"k\"), name=\"c\")\n", + "d = Tensor(Index(\"i\"), Index(\"a\"), name=\"d\")\n", + "rhs = (x * y) + z\n", + "rhs = (rhs * d) + (rhs * d) + c\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "markdown", + "id": "8585953b-1274-4f68-904f-f30cbb934651", + "metadata": {}, + "source": [ + "## Simplification" + ] + }, + { + "cell_type": "markdown", + "id": "96af7eb0-9d30-4cde-b746-221d7c1dfadf", + "metadata": {}, + "source": [ + "Graphs have a series of simple methods for basic simplification\n", + "\n", + "- `expand`, to expand brackets;\n", + "- `squeeze`, to remove any redundant algebraic operations;\n", + "- `collect`, to collect like terms and sum them via modified scalar factors." + ] + }, + { + "cell_type": "markdown", + "id": "a2546c2c-b387-4338-bbd6-6b6a58913a24", + "metadata": {}, + "source": [ + "Expansion via `expand` converts arbitrarily nested graphs into the form `Add[Mul[Tensor | Scalar]]` by expanding brackets." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "04e57181-2dbb-4f37-b1b3-fd9cfb259688", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + "├── Mul\n", + "│ ├── x(i,j)\n", + "│ ├── y(j,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── z(i,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── x(i,j)\n", + "│ ├── y(j,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── z(i,k)\n", + "│ └── d(i,a)\n", + "└── Mul\n", + " └── c(a,k)\n" + ] + } + ], + "source": [ + "rhs = rhs.expand()\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "markdown", + "id": "bffed623-6e26-4131-adc5-a79a092ce558", + "metadata": {}, + "source": [ + "Squeezing via `squeeze` will remove any redundant nodes; for example, the final `Mul` in this example has a single argument, and this node can be replaced with the tensor argument itself." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "74856590-96f5-4e18-8402-64cc2983f722", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + "├── Mul\n", + "│ ├── x(i,j)\n", + "│ ├── y(j,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── z(i,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── x(i,j)\n", + "│ ├── y(j,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── z(i,k)\n", + "│ └── d(i,a)\n", + "└── c(a,k)\n" + ] + } + ], + "source": [ + "rhs = rhs.squeeze()\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "markdown", + "id": "903586b7-7a93-4454-a789-8b777ab5fff5", + "metadata": {}, + "source": [ + "Collection via `collect` will collect like terms and insert a multiplication by a scalar factor to account for their frequency. This may not be sufficient to reverse `expand`, as it is not a true factorisation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "28defe89-279d-46bc-af6b-a95f8911a1cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + "├── Mul\n", + "│ ├── 2\n", + "│ ├── x(i,j)\n", + "│ ├── y(j,k)\n", + "│ └── d(i,a)\n", + "├── Mul\n", + "│ ├── 2\n", + "│ ├── z(i,k)\n", + "│ └── d(i,a)\n", + "└── c(a,k)\n" + ] + } + ], + "source": [ + "rhs = rhs.collect()\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "markdown", + "id": "33076074-4284-441b-96d1-24c99188f88d", + "metadata": {}, + "source": [ + "## Permutation" + ] + }, + { + "cell_type": "markdown", + "id": "0d75f516-e101-4ff0-a658-d7611de158b3", + "metadata": {}, + "source": [ + "Nodes in graphs have external indices; those being the indices that are not contracted over within that particular node. For a `Tensor` those are simply the specified indices, for an `Add` the equal indices of each argument, and for a `Mul` the external indices within the Einstein summation convention." + ] + }, + { + "cell_type": "markdown", + "id": "570e4301-df89-472b-9221-4a2dded676cb", + "metadata": {}, + "source": [ + "The external indices of a node can be permuted using the `permute_indices` method." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "fcaa8200-cfe5-429d-bedb-6f74ac6fcbb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x(i,k,j)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensor = Tensor(Index(\"i\"), Index(\"j\"), Index(\"k\"), name=\"x\")\n", + "tensor.permute_indices((0, 2, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "1b98dc26-0b9a-49e5-b0a9-1f7be95c6508", + "metadata": {}, + "source": [ + "More generally, all indices (both internal and external) can be permuted according to an explicit mapping using the `map_indices` method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "72bced0b-a777-434f-bd0b-71ff867b3b19", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + "├── Mul\n", + "│ ├── 2\n", + "│ ├── x(j,i)\n", + "│ ├── y(i,k)\n", + "│ └── d(j,w)\n", + "├── Mul\n", + "│ ├── 2\n", + "│ ├── z(j,k)\n", + "│ └── d(j,w)\n", + "└── c(w,k)\n" + ] + } + ], + "source": [ + "rhs = rhs.map_indices(\n", + " {\n", + " Index(\"i\"): Index(\"j\"),\n", + " Index(\"j\"): Index(\"i\"),\n", + " Index(\"a\"): Index(\"w\"),\n", + " }\n", + ")\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "markdown", + "id": "38188b9b-8780-4983-bf17-60b26be2c892", + "metadata": {}, + "source": [ + "Nodes can also be permuted using a `Permutation` object, which contains the permutation and sign." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "c225ac7b-ec53-4f8d-9f5d-97423b2e5363", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x(j,i)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensor = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "permutation = Permutation((1, 0), +1)\n", + "permutation(tensor)" + ] + }, + { + "cell_type": "markdown", + "id": "59818df7-0efc-472f-a42e-7ba10bbd57fc", + "metadata": {}, + "source": [ + "A collection of `Permutation` objects can be composed into a `Symmetry` object to represent the permutational symmetry of a tensor. These symmetry groups can be passed during the initialisation of the tensor, and then the resulting graph canonicalised under those symmetries using the `canonicalise` method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d13bb608-30b7-49f2-af87-19dc2f23ef23", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2 * x(i,j)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "symmetry = Symmetry(\n", + " Permutation((0, 1), +1),\n", + " Permutation((1, 0), +1),\n", + ")\n", + "rhs = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\", symmetry=symmetry)\n", + "rhs += Tensor(Index(\"j\"), Index(\"i\"), name=\"x\", symmetry=symmetry)\n", + "rhs.canonicalise().collect()" + ] + }, + { + "cell_type": "markdown", + "id": "e9e0dd67-fac8-4f2c-a5fb-7b04034346b8", + "metadata": {}, + "source": [ + "## Traversal" + ] + }, + { + "cell_type": "markdown", + "id": "35fcac97-51e7-49f3-8032-4a31de320290", + "metadata": {}, + "source": [ + "Since the graphs are essentially n-ary generalisations of a binary tree, we can use familiar tree traversal methods to perform depth-first search over the graph to find specific nodes, or do surgery on the graph itself, using the four functions\n", + "\n", + "- `search`, to search for nodes matching some type filter;\n", + "- `find`, to search for the first node matching some type filter;\n", + "- `delete`, to set the value of a node matching some type filter to zero;\n", + "- `apply`, to call a function on a node matching some type filter, and replace the node with the result.\n", + "\n", + "The type filters can be a type to match the instance of, a tuple of types, or a function that evaluates to `True` or `False` when called with the node as the argument. Since nodes are immutable objects, each method returns a copy of the original node, which may contain references to existing nodes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e3a2a4a8-e7c7-465d-a47b-99b8bc4c1a62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mul\n", + "├── Add\n", + "│ ├── Mul\n", + "│ │ ├── a(i,l)\n", + "│ │ └── b(k,l)\n", + "│ └── c(i,k)\n", + "└── d(j,k)\n" + ] + } + ], + "source": [ + "rhs = from_string(\"(a(i,l) * b(k,l) + c(i,k)) * d(j,k)\")\n", + "print(rhs.tree_repr())" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f2265862-2512-4475-b41c-23e32f6b4ec8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[a(i,l), b(k,l), c(i,k), d(j,k)]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(rhs.search(Tensor))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c00a44f1-3f81-4975-bae1-ca6e2fa96c39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "c(i,k)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rhs.find(lambda node: isinstance(node, Tensor) and node.name == \"c\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "c9d0cf7b-fbdd-4427-a8a0-bd53bc498fd2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "a(i,l) * b(k,l) * d(j,k)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rhs.delete(lambda node: isinstance(node, Tensor) and node.name == \"c\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "aa06b5a4-7edf-446f-b1cb-ad094c140be7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((A(i,l) * B(k,l)) + C(i,k)) * D(j,k)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rhs.apply(lambda tensor: tensor.copy(name=tensor.name.upper()), Tensor)" + ] + }, + { + "cell_type": 
"markdown", + "id": "b4682871-4ecf-4ebb-a241-e10223f185d1", + "metadata": {}, + "source": [ + "The `apply` function should be used for the majority of cases where one wishes to algebraically manipulate the expression. Many complex substitutions, morphisms, and other manipulations are possible through this approach." + ] + }, + { + "cell_type": "markdown", + "id": "68563376-beda-4972-bf3c-4262dca60a1c", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "markdown", + "id": "d5ca4edb-53d2-4c99-b625-da71e03ac7ce", + "metadata": {}, + "source": [ + "Graph expressions can be numerically evaluated without the need to use the code generation functionality, using the `evaluate` function. This requires the arrays corresponding to each tensor name to be passed in a dictionary, and an `einsum` driver function must be provided. In more advanced cases where the `Index` objects possess spaces, the dictionary of arrays must be a nested dictionary where each entry is itself a dictionary mapping spaces to arrays." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "57ed7b0f-8b06-4624-a96c-bca78bd716cc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "\n", + "rhs = from_string(\"(a(i,l) * b(k,l) + c(i,k)) * d(j,k)\")\n", + "\n", + "np.random.seed(123)\n", + "i, j, k, l = 6, 7, 8, 9\n", + "a = np.random.random((i, l))\n", + "b = np.random.random((k, l))\n", + "c = np.random.random((i, k))\n", + "d = np.random.random((j, k))\n", + "\n", + "reference = (a @ b.T + c) @ d.T\n", + "result = rhs.evaluate(dict(a=a, b=b, c=c, d=d), np.einsum)\n", + "np.allclose(reference, result)" + ] + }, + { + "cell_type": "markdown", + "id": "60a6de0f-6141-4416-b983-ffe90ca55faf", + "metadata": {}, + "source": [ + "Alternatively, the `code` module provides template code generators that are intended to be subclassed to fit specific code generation requirements. Using one of the default code generators, we can generate a list of `np.einsum` calls. The `__call__` method of the code generator takes the desired function name, a list of tensors to be returned from the function, and a list of `Expression` objects, which each wrap a LHS tensor and a RHS definition." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6e932a2e-6ec7-49a8-9118-a61b50ce8a9f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"\"\"Code generated by `albert` version 0.0.0.\n", + "\n", + " * date: 2025-10-30T16:40:26.509729\n", + " * python version: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]\n", + " * albert version: 0.0.0\n", + " * caller: /home/ollie/git/albert/albert/code/einsum.py\n", + " * node: ollie-desktop\n", + " * system: Linux\n", + " * processor: x86_64\n", + " * release: 6.8.0-86-generic\n", + "\"\"\"\n", + "\n", + "from types import SimpleNamespace\n", + "import numpy as np\n", + "\n", + "\n", + "def my_function(a=None, b=None, c=None, d=None, **kwargs):\n", + " \"\"\"Code generated by `albert` 0.0.0.\n", + "\n", + " Args:\n", + " a: \n", + " b: \n", + " c: \n", + " d: \n", + "\n", + " Returns:\n", + " x: \n", + " \"\"\"\n", + "\n", + " x = np.einsum(a, (0, 1), b, (2, 1), d, (3, 2), (0, 3), optimize=True)\n", + " x += np.einsum(c, (0, 1), d, (2, 1), (0, 2), optimize=True)\n", + "\n", + " return x\n", + "\n" + ] + } + ], + "source": [ + "from albert.code.einsum import EinsumCodeGenerator\n", + "import sys\n", + "\n", + "lhs = from_string(\"x(i,j)\")\n", + "rhs = from_string(\"(a(i,l) * b(k,l) + c(i,k)) * d(j,k)\")\n", + "expr = Expression(lhs, rhs)\n", + "\n", + "codegen = EinsumCodeGenerator(stdout=sys.stdout)\n", + "codegen.preamble()\n", + "codegen(\"my_function\", [lhs], [expr])\n", + "codegen.postamble()" + ] + }, + { + "cell_type": "markdown", + "id": "587700ac-a8cd-4f38-8a88-cc3304b1461e", + "metadata": {}, + "source": [ + "## Serialisation" + ] + }, + { + "cell_type": "markdown", + "id": "0e0f8952-aa64-49a1-9b99-3fd27c1683e3", + "metadata": {}, + "source": [ + "Essentially everything in `albert` inherits from the `Serialisable` base class, which means that they are serialisable, deserialisable, hashable, comparable, and sortable." 
+ ] + }, + { + "cell_type": "markdown", + "id": "01304b42-0067-47ba-9f37-266ba4e93ee4", + "metadata": {}, + "source": [ + "There is a global intern table carrying a weak reference from the hash of a node to its value, and constructing nodes via their `factory` methods will make a cache request. This means that instantiation in this way can return nodes identical to existing ones, permitting efficiency and memory optimisations." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a6d62317-c347-4328-a311-8aa4b0a71921", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(True, False)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x1 = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "x2 = Tensor(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "(x1 == x2, x1 is x2)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "935019aa-5766-4106-88a3-eb5a3ac0b475", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(True, True)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x1 = Tensor.factory(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "x2 = Tensor.factory(Index(\"i\"), Index(\"j\"), name=\"x\")\n", + "(x1 == x2, x1 is x2)" + ] + }, + { + "cell_type": "markdown", + "id": "6929dc0c-41a0-4b6f-9f70-08d2f6c836aa", + "metadata": {}, + "source": [ + "Comparison and hashing are dispatched via the `_hashable_fields` method, which yields identical fields for any subclass of `Base`. This means that any pair of nodes can be compared or sorted, and equality takes advantage of short-circuited evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "644dd968-7241-477a-933f-651e29f84459", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, a(i,j), a(i,j) + b(i,j), a(i,j) * c(j,k)]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nodes = [\n", + " from_string(\"a(i,j) * c(j,k)\"),\n", + " from_string(\"a(i,j)\"),\n", + " from_string(\"1\"),\n", + " from_string(\"a(i,j) + b(i,j)\"),\n", + "]\n", + "sorted(nodes)" + ] + }, + { + "cell_type": "markdown", + "id": "d1d0ed20-c43c-43a6-bd99-aa0dd755d80e", + "metadata": {}, + "source": [ + "Using an internal JSON data format, `as_json` and `from_json` can be used for serialisation and deserialisation. This allows arbitrary tensor expressions to be saved and reused in the native `albert` format, rather than only as generated code." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "fc3dd23e-d80a-417f-8da3-c930e4691a3d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'_type': 'Mul',\n", + " '_module': 'albert.algebra',\n", + " 'children': ({'_type': 'Tensor',\n", + " '_module': 'albert.tensor',\n", + " 'indices': ({'_type': 'Index',\n", + " '_module': 'albert.index',\n", + " 'name': 'i',\n", + " 'spin': None,\n", + " 'space': None},\n", + " {'_type': 'Index',\n", + " '_module': 'albert.index',\n", + " 'name': 'j',\n", + " 'spin': None,\n", + " 'space': None}),\n", + " 'name': 'a',\n", + " 'symmetry': None},\n", + " {'_type': 'Tensor',\n", + " '_module': 'albert.tensor',\n", + " 'indices': ({'_type': 'Index',\n", + " '_module': 'albert.index',\n", + " 'name': 'j',\n", + " 'spin': None,\n", + " 'space': None},\n", + " {'_type': 'Index',\n", + " '_module': 'albert.index',\n", + " 'name': 'k',\n", + " 'spin': None,\n", + " 'space': None}),\n", + " 'name': 'b',\n", + " 'symmetry': None})}" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "rhs = from_string(\"a(i,j) * b(j,k)\")\n", + "rhs.as_json()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "2fa3a00f-45a9-429b-9c8c-f2e0e0d93085", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "a(i,j) * b(j,k)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded = Base.from_json(rhs.as_json())\n", + "loaded" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "dd805277-8707-4975-8210-d7ef3a6c72f2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded == rhs" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}