Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions t3f/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from t3f.tensor_train import TensorTrain
from t3f.tensor_train_batch import TensorTrainBatch
import t3f.utils as utils
from t3f.tensor_train_base import TensorTrainBase
from t3f import shapes

Expand Down Expand Up @@ -247,8 +248,8 @@ def tensor_with_random_cores(shape, tt_rank=2, mean=0., stddev=1.):
num_dims = shape.size
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.insert(tt_rank, 0, 1)
tt_rank = np.append(tt_rank, 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(shape))

tt_rank = tt_rank.astype(int)
# TODO: variable (name?) scope.
Expand Down Expand Up @@ -286,8 +287,8 @@ def tensor_batch_with_random_cores(shape, tt_rank=2, batch_size=1,
num_dims = shape.size
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.insert(tt_rank, 0, 1)
tt_rank = np.append(tt_rank, 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(shape))
tt_rank = tt_rank.astype(int)
# TODO: variable (name?) scope.
tt_cores = [None] * num_dims
Expand Down Expand Up @@ -337,6 +338,8 @@ def matrix_with_random_cores(shape, tt_rank=2, mean=0., stddev=1.):
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
contracted_shape = np.prod(shape, axis=0)
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(contracted_shape))

tt_rank = tt_rank.astype(int)
# TODO: variable (name?) scope.
Expand Down Expand Up @@ -391,6 +394,9 @@ def matrix_batch_with_random_cores(shape, tt_rank=2, batch_size=1,
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
contracted_shape = np.prod(shape, axis=0)
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(contracted_shape))
# TODO: check that ints?
shape = shape.astype(int)
tt_rank = tt_rank.astype(int)
# TODO: variable (name?) scope.
Expand Down Expand Up @@ -481,8 +487,8 @@ def random_tensor(shape, tt_rank=2, mean=0., stddev=1.):
num_dims = shape.size
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.insert(tt_rank, 0, 1)
tt_rank = np.append(tt_rank, 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(shape))

tt_rank = tt_rank.astype(int)

Expand Down Expand Up @@ -534,8 +540,8 @@ def random_tensor_batch(shape, tt_rank=2, batch_size=1, mean=0., stddev=1.):
num_dims = shape.size
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.insert(tt_rank, 0, 1)
tt_rank = np.append(tt_rank, 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(shape))
tt_rank = tt_rank.astype(int)

cr_exponent = -1.0 / (2 * num_dims)
Expand Down Expand Up @@ -598,6 +604,8 @@ def random_matrix(shape, tt_rank=2, mean=0., stddev=1.):
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
contracted_shape = np.prod(shape, axis=0)
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(contracted_shape))

tt_rank = tt_rank.astype(int)
var = np.prod(tt_rank)
Expand Down Expand Up @@ -664,6 +672,8 @@ def random_matrix_batch(shape, tt_rank=2, batch_size=1, mean=0., stddev=1.):
if tt_rank.size == 1:
tt_rank = tt_rank * np.ones(num_dims - 1)
tt_rank = np.concatenate([[1], tt_rank, [1]])
contracted_shape = np.prod(shape, axis=0)
tt_rank = np.minimum(tt_rank, utils.max_tt_ranks(contracted_shape))

shape = shape.astype(int)
tt_rank = tt_rank.astype(int)
Expand Down
9 changes: 0 additions & 9 deletions t3f/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,15 +1028,6 @@ def quadratic_form(A, b, c):
Raises:
ValueError if the arguments are not TT-matrices or if the shapes are
not consistent.

Complexity:
O(batch_size r_A r_c r_b n d (r_b + r_A n + r_c))
d is the number of TT-cores (A.ndims());
r_A is the largest TT-rank of A max(A.get_tt_rank())
n is the size of the axis dimensions e.g.
if b and c are tensors of shape (3, 3, 3),
A is a 27 x 27 matrix of tensor shape (3, 3, 3) x (3, 3, 3)
then n is 3
"""
if not isinstance(A, TensorTrainBase) or not A.is_tt_matrix():
raise ValueError('The arguments should be a TT-matrix.')
Expand Down
44 changes: 44 additions & 0 deletions t3f/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,50 @@ def my_svd(tensor, full_matrices=False, compute_uv=True):
tf.svd = my_svd


def robust_cumprod(arr):
  """Cumulative product with overflowing values clamped to the dtype max.

  Works like np.cumprod, but whenever the running product no longer fits
  into arr's (integer) dtype, the result saturates at the maximal value
  representable in that dtype instead of silently wrapping around.

  robust_cumprod([10] * 100) = [10, 100, 1000, ..., MAX_INT, ..., MAX_INT]

  Args:
    arr: 1-D numpy array of an integer dtype.

  Returns:
    1-D numpy array of the same size and dtype as arr.
  """

  res = np.ones(arr.size, dtype=arr.dtype)
  if arr.size == 0:
    # The cumulative product of an empty array is empty; the original code
    # would crash on res[0] = arr[0] below.
    return res
  # The largest value representable in the dtype: used as the saturation
  # value once the true product overflows.
  change_large_to = np.iinfo(arr.dtype).max
  res[0] = arr[0]
  for i in range(1, arr.size):
    next_value = np.array(res[i - 1]) * np.array(arr[i])
    # Overflow check: if dividing the (possibly wrapped-around) product by
    # the current factor does not recover the previous value, the
    # multiplication overflowed -- clamp to the dtype maximum.
    if next_value / np.array(arr[i]) != np.array(res[i - 1]):
      next_value = change_large_to
    res[i] = next_value
  return res


def max_tt_ranks(raw_shape):
  """Maximal TT-ranks for a TT-object of given shape.

  For example, a tensor of shape (2, 3, 5, 7) has maximal TT-ranks
  (1, 2, 6, 7, 1): making the TT-ranks larger will not increase
  flexibility.

  If maximal TT-ranks result in integer overflows, the too-large values
  are substituted with MAX_INT (the largest representable int64).

  Args:
    raw_shape: an integer vector (anything convertible to a 1-D numpy
      array of ints).

  Returns:
    tt_ranks: an int64 numpy vector of size len(raw_shape) + 1, the
      maximal TT-rank for each dimension boundary.
  """
  dims = np.array(raw_shape).astype(np.int64)
  if dims.size == 0:
    # Degenerate zero-dimensional object: only the single boundary rank 1.
    # (The original code crashed here inside robust_cumprod.)
    return np.ones(1, dtype=np.int64)
  # Products of all modes to the left / right of each boundary, with
  # overflow clamped to MAX_INT by robust_cumprod.
  left_products = robust_cumprod(dims)
  right_products = robust_cumprod(dims[::-1])[::-1]
  # The rank at each inner boundary is capped by the smaller of the two
  # unfolding sizes; the outermost ranks are always 1.
  inner_ranks = np.minimum(left_products[:-1], right_products[1:])
  return np.concatenate(([1], inner_ranks, [1])).astype(np.int64)


def in_eager_mode():
"""Checks whether tensorflow eager mode is avaialable and active."""
try:
Expand Down