Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"Topic :: Scientific/Engineering",
],
keywords=["tensor networks"],
install_requires=["numpy>=1.11.0"],
install_requires=["numpy>=1.11.0", "opt-einsum>=3.3.0"],
extras_require={"tests": ["pytest", "coverage"]},
python_requires=">=3.6",
)
15 changes: 13 additions & 2 deletions src/ncon/ncon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
"""
import numpy as np
from collections.abc import Iterable
import opt_einsum


def ncon(L, v, order=None, forder=None, check_indices=True):
def ncon(L, v, order=None, forder=None, check_indices=True, backend = 'numpy',
opt_einsum_strategy = 'auto'):
"""L = [A1, A2, ..., Ap] list of tensors.

v = (v1, v2, ..., vp) tuple of lists of indices e.g. v1 = [3, 4, -1] labels
Expand All @@ -25,6 +26,10 @@ def ncon(L, v, order=None, forder=None, check_indices=True):
will be considered a tensor).
"""

# Check that contraction backend is valid
if backend not in ['numpy', 'opt_einsum']:
raise RuntimeError(f"Unknown contraction backend '{backend}'")

# We want to handle the tensors as a list, regardless of what kind
# of iterable we are given. In addition, if only a single element is
# given, we make list out of it. Inputs are assumed to be non-empty.
Expand All @@ -48,6 +53,12 @@ def ncon(L, v, order=None, forder=None, check_indices=True):
# Raise a RuntimeError if the indices are wrong.
do_check_indices(L, v, order, forder)

# If we use opt_einsum as a contraction backend we are almost done as
# permutations are handled by the backend
if backend == 'opt_einsum':
opt_einsum_args = [arg for pair in zip(L, v) for arg in pair]
return opt_einsum.contract(*opt_einsum_args, forder)

# If the graph is dinconnected, connect it with trivial indices that
# will be contracted at the very end.
connect_graph(L, v, order)
Expand Down
48 changes: 43 additions & 5 deletions tests/test_ncon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,61 @@
# the arguments lists or tuples.


def test_matrixproduct():
def test_matrixproduct_numpy():
a = np.random.randn(3, 4)
b = np.random.randn(4, 5)
ab_ncon = ncon([a, b], ((-1, 1), (1, -2)))
ab_np = np.dot(a, b)
assert np.allclose(ab_ncon, ab_np)

def test_matrixproduct_opt_einsum():
a = np.random.randn(3, 4)
b = np.random.randn(4, 5)
ab_ncon = ncon([a, b], ((-1, 1), (1, -2)), backend = 'opt_einsum')
ab_np = np.dot(a, b)
assert np.allclose(ab_ncon, ab_np)

def test_disconnected():
def test_disconnected_numpy():
a = np.random.randn(2, 3)
b = np.random.randn(4)
ab_ncon = ncon((a, b), ([-3, -2], [-1]))
ab_np = np.einsum("ij, k -> kji", a, b)
assert np.allclose(ab_ncon, ab_np)

def test_disconnected_opt_einsum():
a = np.random.randn(2, 3)
b = np.random.randn(4)
ab_ncon = ncon((a, b), ([-3, -2], [-1]), backend = 'opt_einsum')
ab_np = np.einsum("ij, k -> kji", a, b)
assert np.allclose(ab_ncon, ab_np)

def test_permutation():
def test_permutation_numpy():
a = np.random.randn(2, 3, 4, 5)
aperm_ncon = ncon(a, [-4, -2, -1, -3])
aperm_np = np.transpose(a, [2, 1, 3, 0])
assert np.allclose(aperm_ncon, aperm_np)

def test_permutation_opt_einsum():
a = np.random.randn(2, 3, 4, 5)
aperm_ncon = ncon(a, [-4, -2, -1, -3], backend = 'opt_einsum')
aperm_np = np.transpose(a, [2, 1, 3, 0])
assert np.allclose(aperm_ncon, aperm_np)

def test_trace():

def test_trace_numpy():
a = np.random.randn(3, 2, 3)
atr_ncon = ncon((a,), ([1, -1, 1],))
atr_np = np.einsum("iji -> j", a)
assert np.allclose(atr_ncon, atr_np)

def test_trace_opt_einsum():
a = np.random.randn(3, 2, 3)
atr_ncon = ncon((a,), ([1, -1, 1],), backend = 'opt_einsum')
atr_np = np.einsum("iji -> j", a)
assert np.allclose(atr_ncon, atr_np)


def test_large_contraction():
def test_large_contraction_numpy():
a = np.random.randn(3, 4, 5)
b = np.random.randn(5, 3, 6, 7, 6)
c = np.random.randn(7, 2)
Expand All @@ -46,3 +70,17 @@ def test_large_contraction():
)
result_np = np.einsum("ijk, kilml, mh, q, qp -> hjp", a, b, c, d, e)
assert np.allclose(result_ncon, result_np)

def test_large_contraction_opt_einsum():
a = np.random.randn(3, 4, 5)
b = np.random.randn(5, 3, 6, 7, 6)
c = np.random.randn(7, 2)
d = np.random.randn(8)
e = np.random.randn(8, 9)
result_ncon = ncon(
(a, b, c, d, e), ([3, -2, 2], [2, 3, 1, 4, 1], [4, -1], [5], [5, -3]),
backend = 'opt_einsum'
)
result_np = np.einsum("ijk, kilml, mh, q, qp -> hjp", a, b, c, d, e)
assert np.allclose(result_ncon, result_np)