From 03d0792f913a52cbd1250cfa704090ae26930ab1 Mon Sep 17 00:00:00 2001 From: Mikhail Usvyatsov Date: Mon, 26 Aug 2019 20:52:04 +0200 Subject: [PATCH] fixed full on master pytorch --- t3nsor/decompositions.py | 2 +- t3nsor/tensor_train.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/t3nsor/decompositions.py b/t3nsor/decompositions.py index 4bcf4e8..2da644a 100755 --- a/t3nsor/decompositions.py +++ b/t3nsor/decompositions.py @@ -19,7 +19,7 @@ def to_tt_tensor(tens, max_tt_rank=10, epsilon=None): for core_idx in range(d - 1): curr_mode = shape[core_idx] rows = ranks[core_idx] * curr_mode - tens = tens.view(rows, -1) + tens = tens.reshape(rows, -1) columns = tens.shape[1] u, s, v = svd_fix(tens) if max_tt_rank[core_idx + 1] == 1: diff --git a/t3nsor/tensor_train.py b/t3nsor/tensor_train.py index 4327d03..cbc634c 100755 --- a/t3nsor/tensor_train.py +++ b/t3nsor/tensor_train.py @@ -33,7 +33,7 @@ def __init__(self, tt_cores, shape=None, tt_ranks=None, convert_to_tensors=True) self._parameter = None self._dof = np.sum([np.prod(list(tt_core.shape)) for tt_core in self._tt_cores]) self._total = np.prod(self._shape) - + @property def tt_cores(self): @@ -73,15 +73,15 @@ def parameter(self): return self._parameter else: raise ValueError('Not a parameter, run .to_parameter() first') - + @property def dof(self): return self._dof - + @property def total(self): return self._total - + def to(self, device): new_cores = [] @@ -109,7 +109,7 @@ def to_parameter(self): new_cores.append(core) tt_p = TensorTrain(new_cores, convert_to_tensors=False) - tt_p._parameter = nn.ParameterList(tt_p.tt_cores) + tt_p._parameter = nn.ParameterList(tt_p.tt_cores) tt_p._is_parameter = True return tt_p @@ -122,7 +122,7 @@ def full(self): for i in range(1, num_dims): res = res.view(-1, ranks[i]) - curr_core = self.tt_cores[i].view(ranks[i], -1) + curr_core = self.tt_cores[i].reshape(ranks[i], -1) res = torch.matmul(res, curr_core) if self.is_tt_matrix: @@ -138,7 +138,7 @@ def full(self): for i in range(1, 2 * num_dims, 2): transpose.append(i) res = res.permute(*transpose) - + if self.is_tt_matrix: res = res.contiguous().view(*shape) else: @@ -268,8 +268,8 @@ def full(self): for i in range(1, 2 * num_dims, 2): transpose.append(i + 1) res = res.permute(transpose) - - if self.is_tt_matrix: + + if self.is_tt_matrix: res = res.contiguous().view(*shape) else: res = res.view(*shape)