From 86625f618c3bdcdec87d65b0061cccff642e057c Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Sun, 6 Jan 2019 21:17:23 +0000 Subject: [PATCH 1/2] more detailed complexities --- t3f/riemannian.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/t3f/riemannian.py b/t3f/riemannian.py index 5a92f5a7..fcfb4223 100644 --- a/t3f/riemannian.py +++ b/t3f/riemannian.py @@ -389,12 +389,16 @@ def project_matmul(what, where, matrix): a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks() Complexity: - O(d r_where^3 m) for orthogonalizing the TT-cores of where + O(d r_where^3 m) for orthogonalizing the TT-cores +O(batch_size d R r_what r_where (n r_what + n m R + m r_where)) - d is the number of TT-cores (what.ndims()); - r_what is the largest TT-rank of what max(what.get_tt_rank()) - r_where is the largest TT-rank of where - matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n). + or equivalently + O(d r_where^3 m) + +O(batch_size d R^3 r_what r_where n m R) + given that m r_where < n m R and n r_what < n m R, where + d is the number of TT-cores (what.ndims()); + r_what is the largest TT-rank of what (max(what.get_tt_rank())) + r_where is the largest TT-rank of where + matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n). """ if not isinstance(where, TensorTrain): From 25835c45843ca1806c3e231a210146890702a111 Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Sun, 6 Jan 2019 21:21:45 +0000 Subject: [PATCH 2/2] fix the code to actually get the advertaised complexity --- t3f/riemannian.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/t3f/riemannian.py b/t3f/riemannian.py index fcfb4223..6f4ba89b 100644 --- a/t3f/riemannian.py +++ b/t3f/riemannian.py @@ -450,8 +450,8 @@ def project_matmul(what, where, matrix): tens_core = what.tt_cores[core_idx] right_tang_core = right_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] - rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core, - right_tang_core, rhs[core_idx + 1], tens_core) + rhs[core_idx] = tf.einsum('cikf,sdef,bije,sajkd->sabc', right_tang_core, + rhs[core_idx + 1], matrix_core, tens_core) # Prepare lhs vectors. # lhs[core_idx] is of size # batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx] @@ -462,8 +462,8 @@ def project_matmul(what, where, matrix): left_tang_core = left_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] # TODO: brutforce order of indices in lhs?? - lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef', matrix_core, - left_tang_core, lhs[core_idx], tens_core) + lhs[core_idx + 1] = tf.einsum('aikd,sabc,bije,scjkf->sdef', left_tang_core, + lhs[core_idx], matrix_core, tens_core) # Left to right sweep. res_cores_list = []