From 5060e7824220b7032e8ccebff8f7526e6690ba82 Mon Sep 17 00:00:00 2001 From: David Tellenbach Date: Tue, 15 Feb 2022 17:56:34 +0100 Subject: [PATCH] Add possibility to use opt_einsum instead of pure Numpy This patch enables to use the highly optimized tensor network contractor opt_einsum instead of pure Numpy. opt_einsum particularly optimizes contraction orders, therefore the `order` argument is ignored if opt_einsum is used. --- setup.py | 2 +- src/ncon/ncon.py | 15 +++++++++++++-- tests/test_ncon.py | 48 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index ab08160..f2630c7 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "Topic :: Scientific/Engineering", ], keywords=["tensor networks"], - install_requires=["numpy>=1.11.0"], + install_requires=["numpy>=1.11.0", "opt-einsum>=3.3.0"], extras_require={"tests": ["pytest", "coverage"]}, python_requires=">=3.6", ) diff --git a/src/ncon/ncon.py b/src/ncon/ncon.py index ae5d67d..cca2349 100644 --- a/src/ncon/ncon.py +++ b/src/ncon/ncon.py @@ -2,9 +2,10 @@ """ import numpy as np from collections.abc import Iterable +import opt_einsum - -def ncon(L, v, order=None, forder=None, check_indices=True): +def ncon(L, v, order=None, forder=None, check_indices=True, backend = 'numpy', + opt_einsum_strategy = 'auto'): """L = [A1, A2, ..., Ap] list of tensors. v = (v1, v2, ..., vp) tuple of lists of indices e.g. v1 = [3, 4, -1] labels @@ -25,6 +26,10 @@ def ncon(L, v, order=None, forder=None, check_indices=True): will be considered a tensor). """ + # Check that contraction backend is valid + if backend not in ['numpy', 'opt_einsum']: + raise RuntimeError(f"Unknown contraction backend '{backend}'") + # We want to handle the tensors as a list, regardless of what kind # of iterable we are given. In addition, if only a single element is # given, we make list out of it. Inputs are assumed to be non-empty. @@ -48,6 +53,12 @@ def ncon(L, v, order=None, forder=None, check_indices=True): # Raise a RuntimeError if the indices are wrong. do_check_indices(L, v, order, forder) + # If we use opt_einsum as a contraction backend we are almost done as + # permutations are handled by the backend + if backend == 'opt_einsum': + opt_einsum_args = [arg for pair in zip(L, v) for arg in pair] + return opt_einsum.contract(*opt_einsum_args, forder) + # If the graph is dinconnected, connect it with trivial indices that # will be contracted at the very end. connect_graph(L, v, order) diff --git a/tests/test_ncon.py b/tests/test_ncon.py index 6e2a4de..396e915 100644 --- a/tests/test_ncon.py +++ b/tests/test_ncon.py @@ -5,37 +5,61 @@ # the arguments lists or tuples. -def test_matrixproduct(): +def test_matrixproduct_numpy(): a = np.random.randn(3, 4) b = np.random.randn(4, 5) ab_ncon = ncon([a, b], ((-1, 1), (1, -2))) ab_np = np.dot(a, b) assert np.allclose(ab_ncon, ab_np) +def test_matrixproduct_opt_einsum(): + a = np.random.randn(3, 4) + b = np.random.randn(4, 5) + ab_ncon = ncon([a, b], ((-1, 1), (1, -2)), backend = 'opt_einsum') + ab_np = np.dot(a, b) + assert np.allclose(ab_ncon, ab_np) -def test_disconnected(): +def test_disconnected_numpy(): a = np.random.randn(2, 3) b = np.random.randn(4) ab_ncon = ncon((a, b), ([-3, -2], [-1])) ab_np = np.einsum("ij, k -> kji", a, b) assert np.allclose(ab_ncon, ab_np) +def test_disconnected_opt_einsum(): + a = np.random.randn(2, 3) + b = np.random.randn(4) + ab_ncon = ncon((a, b), ([-3, -2], [-1]), backend = 'opt_einsum') + ab_np = np.einsum("ij, k -> kji", a, b) + assert np.allclose(ab_ncon, ab_np) -def test_permutation(): +def test_permutation_numpy(): a = np.random.randn(2, 3, 4, 5) aperm_ncon = ncon(a, [-4, -2, -1, -3]) aperm_np = np.transpose(a, [2, 1, 3, 0]) assert np.allclose(aperm_ncon, aperm_np) +def test_permutation_opt_einsum(): + a = np.random.randn(2, 3, 4, 5) + aperm_ncon = ncon(a, [-4, -2, -1, -3], backend = 'opt_einsum') + aperm_np = np.transpose(a, [2, 1, 3, 0]) + assert np.allclose(aperm_ncon, aperm_np) -def test_trace(): + +def test_trace_numpy(): a = np.random.randn(3, 2, 3) atr_ncon = ncon((a,), ([1, -1, 1],)) atr_np = np.einsum("iji -> j", a) assert np.allclose(atr_ncon, atr_np) +def test_trace_opt_einsum(): + a = np.random.randn(3, 2, 3) + atr_ncon = ncon((a,), ([1, -1, 1],), backend = 'opt_einsum') + atr_np = np.einsum("iji -> j", a) + assert np.allclose(atr_ncon, atr_np) + -def test_large_contraction(): +def test_large_contraction_numpy(): a = np.random.randn(3, 4, 5) b = np.random.randn(5, 3, 6, 7, 6) c = np.random.randn(7, 2) @@ -46,3 +70,17 @@ def test_large_contraction(): ) result_np = np.einsum("ijk, kilml, mh, q, qp -> hjp", a, b, c, d, e) assert np.allclose(result_ncon, result_np) + +def test_large_contraction_opt_einsum(): + a = np.random.randn(3, 4, 5) + b = np.random.randn(5, 3, 6, 7, 6) + c = np.random.randn(7, 2) + d = np.random.randn(8) + e = np.random.randn(8, 9) + result_ncon = ncon( + (a, b, c, d, e), ([3, -2, 2], [2, 3, 1, 4, 1], [4, -1], [5], [5, -3]), + backend = 'opt_einsum' + ) + result_np = np.einsum("ijk, kilml, mh, q, qp -> hjp", a, b, c, d, e) + assert np.allclose(result_ncon, result_np) +