diff --git a/t3f/riemannian.py b/t3f/riemannian.py index 5a92f5a7..6f4ba89b 100644 --- a/t3f/riemannian.py +++ b/t3f/riemannian.py @@ -389,12 +389,16 @@ def project_matmul(what, where, matrix): a TensorTrain with the TT-ranks equal 2 * tangent_space_tens.get_tt_ranks() Complexity: - O(d r_where^3 m) for orthogonalizing the TT-cores of where + O(d r_where^3 m) for orthogonalizing the TT-cores +O(batch_size d R r_what r_where (n r_what + n m R + m r_where)) - d is the number of TT-cores (what.ndims()); - r_what is the largest TT-rank of what max(what.get_tt_rank()) - r_where is the largest TT-rank of where - matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n). + or equivalently + O(d r_where^3 m) + +O(batch_size d R^3 r_what r_where n m R) + given that m r_where < n m R and n r_what < n m R, where + d is the number of TT-cores (what.ndims()); + r_what is the largest TT-rank of what (max(what.get_tt_rank())) + r_where is the largest TT-rank of where + matrix is of TT-rank R and of raw-shape (m, m, ..., m) x (n, n, ..., n). """ if not isinstance(where, TensorTrain): @@ -446,8 +450,8 @@ def project_matmul(what, where, matrix): tens_core = what.tt_cores[core_idx] right_tang_core = right_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] - rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core, - right_tang_core, rhs[core_idx + 1], tens_core) + rhs[core_idx] = tf.einsum('cikf,sdef,bije,sajkd->sabc', right_tang_core, + rhs[core_idx + 1], matrix_core, tens_core) # Prepare lhs vectors. # lhs[core_idx] is of size # batch_size x tangent_tt_ranks[core_idx] x matrix_tt_ranks[core_idx] x tensor_tt_ranks[core_idx] @@ -458,8 +462,8 @@ def project_matmul(what, where, matrix): left_tang_core = left_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] # TODO: brutforce order of indices in lhs?? - lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef', matrix_core, - left_tang_core, lhs[core_idx], tens_core) + lhs[core_idx + 1] = tf.einsum('aikd,sabc,bije,scjkf->sdef', left_tang_core, + lhs[core_idx], matrix_core, tens_core) # Left to right sweep. res_cores_list = []