diff --git a/docs/benchmark/run_all_autodiff.py b/docs/benchmark/run_all_autodiff.py new file mode 100644 index 00000000..f26ca1e7 --- /dev/null +++ b/docs/benchmark/run_all_autodiff.py @@ -0,0 +1,38 @@ +""" +Running a suite of autodiff benchmarks. + python3 run_all_autodiff.py --logs=autodiff_logs.pkl 2> autodiff_logs.stderr +""" + +import argparse +import subprocess +import utils + + +parser = argparse.ArgumentParser() +parser.add_argument('--logs', type=str) +args = parser.parse_args() + +def run_single(case, n, d, r, R=None): + cmd = ['python3', 'run_single_autodiff.py', '--case=%s' % case, + '--n=%d' % n, '--d=%d' % d, + '--tt_rank_vec=%d' % r, '--logs=%s' % args.logs] + if R is not None: + cmd.append('--tt_rank_mat=%d' % R) + cmd.append('--m=%d' % n) + try: + print(' '.join(cmd)) + print(subprocess.check_output(cmd)) + except: + print('Running subprocess failed.') + pass + + +for n in [5, 10, 20]: + for d in [10, 20, 40]: + for r in [5, 10, 20, 40]: + run_single('completion', n, d, r) + run_single('ExpMachines', n, d, r) + for R in [5, 10, 20, 40]: + run_single('xAx', n, d, r, R) + run_single('xABx', n, d, r, R) + run_single('RayleighQuotient', n, d, r, R) diff --git a/docs/benchmark/run_single_autodiff.py b/docs/benchmark/run_single_autodiff.py new file mode 100644 index 00000000..30b992dc --- /dev/null +++ b/docs/benchmark/run_single_autodiff.py @@ -0,0 +1,34 @@ +import argparse +import utils +import pickle +import os.path + +parser = argparse.ArgumentParser() +parser.add_argument('--logs', type=str) +parser.add_argument('--case', type=str) +parser.add_argument('--m', type=int) +parser.add_argument('--n', type=int) +parser.add_argument('--d', type=int) +parser.add_argument('--tt_rank_mat', type=int) +parser.add_argument('--tt_rank_vec', type=int) +args = parser.parse_args() + +if args.case == 'completion': + assert args.m is None and args.tt_rank_mat is None + case = utils.Completion(args.n, args.d, args.tt_rank_vec) +elif args.case == 'xAx': + case = 
utils.BilinearXAX(args.m, args.n, args.d, args.tt_rank_mat, args.tt_rank_vec) +elif args.case == 'xABx': + case = utils.BilinearXABX(args.m, args.n, args.d, args.tt_rank_mat, args.tt_rank_vec) +elif args.case == 'ExpMachines': + assert args.m is None and args.tt_rank_mat is None + case = utils.ExpMachines(args.n, args.d, args.tt_rank_vec) +elif args.case == 'RayleighQuotient': + case = utils.RayleighQuotient(args.m, args.n, args.d, args.tt_rank_mat, args.tt_rank_vec) +else: + print('Dont know this case.') + +print(args.case, case.settings) +utils.benchmark(args.case, case, args.logs) + + diff --git a/docs/benchmark/utils.py b/docs/benchmark/utils.py new file mode 100644 index 00000000..a70e87d4 --- /dev/null +++ b/docs/benchmark/utils.py @@ -0,0 +1,418 @@ +import numpy as np +import tensorflow as tf +import numpy as np +import t3f +import json +import pickle +import copy +import os +from shutil import copyfile + + +def robust_cumprod(arr): + """Cumulative product with large values replaced by the MAX_DTYPE. + + robust_cumprod([10] * 100) = [10, 100, 1000, ..., MAX_INT, ..., MAX_INT] + """ + + res = np.ones(arr.size, dtype=arr.dtype) + change_large_to = np.iinfo(arr.dtype).max + res[0] = arr[0] + for i in range(1, arr.size): + next_value = np.array(res[i - 1]) * np.array(arr[i]) + if next_value / np.array(arr[i]) != np.array(res[i - 1]): + next_value = change_large_to + res[i] = next_value + return res + + +def max_tt_ranks(raw_shape): + """Maximal TT-ranks for a TT-object of given shape. + + For example, a tensor of shape (2, 3, 5, 7) has maximal TT-ranks + (1, 2, 6, 7, 1) + making the TT-ranks larger will not increase flexibility. + + If maximum TT-ranks result in integer overflows, it substitutes + the too-large-values with MAX_INT. 
+ + Args: + shape: an integer vector + Returns: + tt_ranks: an integer vector, maximal tt-rank for each dimension + """ + raw_shape = np.array(raw_shape).astype(np.int64) + d = raw_shape.size + tt_ranks = np.zeros(d + 1, dtype='int64') + tt_ranks[0] = 1 + tt_ranks[d] = 1 + left_to_right = robust_cumprod(raw_shape) + right_to_left = robust_cumprod(raw_shape[::-1])[::-1] + tt_ranks[1:-1] = np.minimum(left_to_right[:-1], right_to_left[1:]) + return tt_ranks + +def sparse(idx, shape, dtype=None): + cores = [] + for k in range(len(idx)): + eye = tf.eye(shape[k], dtype=dtype) + cores.append(tf.reshape(eye[idx[k]], (1, shape[k], 1))) + return t3f.TensorTrain(cores) + +def batch_sparse(idx_list, shape, weights=None, dtype=None): + cores = [] + for k in range(len(idx_list[0])): + curr_core = [] + eye = tf.eye(shape[k], dtype=dtype) + cores.append(tf.reshape(tf.gather(eye, idx_list[:, k]), (-1, 1, shape[k], 1))) + if weights is not None: + cores[0] *= weights[:, None, None, None] + return t3f.TensorTrainBatch(cores) + + +def reduce_sum_batch(x): + tt_cores = list(x.tt_cores) + for i, core in enumerate(tt_cores): + bs, r1, n, r2 = core.shape.as_list() + assert r1 == 1 and r2 == 1 + if i == 0: + core = tf.reshape(core, (bs, 1, n)) + core = tf.transpose(core, (1, 2, 0)) + elif i == len(tt_cores) - 1: + core = tf.reshape(core, (bs, n, 1)) + else: + core = tf.tile(core[:, :, :, None, :], (1, 1, 1, bs, 1)) + core = tf.reshape(core, (bs, n, bs)) + core *= tf.tile(tf.eye(bs, dtype=x.dtype)[:, None, :], (1, n, 1)) + tt_cores[i] = core + return t3f.TensorTrain(tt_cores) + + +def prune_ranks(tt_rank, shape): + tt_rank_arr = [1] + [tt_rank] * (len(shape) - 1) + [1] + return np.minimum(tt_rank_arr, max_tt_ranks(shape)) + + +class Task(object): + + def smart_grad(self): + return NotImplementedError() + + def naive_hessian_by_vector(self): + return NotImplementedError() + + def smart_hessian_by_vector(self): + return NotImplementedError() + + +class Completion(Task): + + def __init__(self, 
n, d, tt_rank): + self.settings = {'n': n, 'd': d, 'tt_rank': tt_rank} + shape = [n] * d + self.num_observed = 10 * d * n * tt_rank**2 ############################################################### + self.observation_idx = np.random.randint(0, n, size=(self.num_observed, len(shape))) + self.observations_np = np.random.randn(self.num_observed) + self.observations = tf.constant(self.observations_np) + tt_rank_x = [1] + [tt_rank] * (d - 1) + [1] + tt_rank_x = np.minimum(tt_rank_x, max_tt_ranks(shape)) + initialization = t3f.random_tensor(shape, tt_rank=tt_rank_x, dtype=tf.float64) + self.x = t3f.get_variable('x', initializer=initialization) + self.x *= 1.0 # Dtype bug + tt_rank_v = [1] + [2 * tt_rank] * (d - 1) + [1] + tt_rank_v = np.minimum(tt_rank_v, max_tt_ranks(shape)) + initialization = t3f.random_tensor(shape, tt_rank=tt_rank_v, dtype=tf.float64) + self.vector = t3f.get_variable('vector', initializer=initialization) + self.sparsity_mask_list_tt = batch_sparse(self.observation_idx, shape, dtype=tf.float64) + self.sparsity_mask_tt = reduce_sum_batch(self.sparsity_mask_list_tt) + self.sparse_observation_tt = reduce_sum_batch(batch_sparse(self.observation_idx, shape, self.observations_np, dtype=tf.float64)) + + def loss(self, x): + estimated_vals = t3f.gather_nd(x, self.observation_idx) + return 0.5 * tf.reduce_sum((estimated_vals - self.observations_np) ** 2) + + def naive_grad(self): + grad = self.sparsity_mask_tt * self.x - self.sparse_observation_tt + return t3f.project(grad, self.x) + + def smart_grad(self): + estimated_vals = t3f.gather_nd(self.x, self.observation_idx) + diff = estimated_vals - self.observations + return t3f.project_sum(self.sparsity_mask_list_tt, self.x, diff) + + def naive_hessian_by_vector(self): + return t3f.project(self.sparsity_mask_tt * t3f.project(self.vector, self.x), self.x) + + def smart_hessian_by_vector(self): + vector_nonzero = t3f.gather_nd(t3f.project(self.vector, self.x), self.observation_idx) + return 
t3f.project_sum(self.sparsity_mask_list_tt, self.x, vector_nonzero) + + +class BilinearXAX(Task): + + def __init__(self, m, n, d, tt_rank_mat, tt_rank_vec): + self.settings = {'n': n, 'm': m, 'd': d, 'tt_rank_mat': tt_rank_mat, 'tt_rank_vec': tt_rank_vec} + shape = ([m] * d, [n] * d) + ranks = prune_ranks(tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.x = t3f.get_variable('x', initializer=initialization) + ranks = prune_ranks(2 * tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.vector = t3f.get_variable('vector', initializer=initialization) + ranks = prune_ranks(tt_rank_mat, np.prod(shape, axis=0)) + mat = t3f.random_matrix(shape, tt_rank=ranks, dtype=tf.float64) + mat = t3f.transpose(mat) + mat + self.mat = t3f.get_variable('mat', initializer=mat) + + def loss(self, x): + return 0.5 * t3f.quadratic_form(self.mat, x, x) # DO NOT SUBMIT + + def naive_grad(self): + grad = t3f.matmul(self.mat, self.x) # DO NOT SUBMIT + return t3f.project(grad, self.x) + + def smart_grad(self): + return t3f.project_matmul(t3f.expand_batch_dim(self.x), self.x, self.mat)[0] # DO NOT SUBMIT + + def naive_hessian_by_vector(self): + grad = t3f.matmul(self.mat, self.vector) + return t3f.project(grad, self.x) + + def smart_hessian_by_vector(self): + return t3f.project_matmul(t3f.expand_batch_dim(self.vector), self.x, self.mat)[0] + + +class ExpMachines(Task): + + def __init__(self, n, d, tt_rank_vec, batch_size=32): + self.settings = {'n': n, 'd': d, 'tt_rank_vec': tt_rank_vec} + shape = [n] * d + ranks = prune_ranks(tt_rank_vec, shape) + initialization = t3f.random_tensor(shape, tt_rank=ranks, dtype=tf.float64) + self.x = t3f.get_variable('x', initializer=initialization) + initialization = t3f.random_tensor_batch(shape, tt_rank=1, dtype=tf.float64, batch_size=batch_size) + self.w = t3f.get_variable('w', initializer=initialization) + ranks = prune_ranks(2 
* tt_rank_vec, shape) + initialization = t3f.random_tensor(shape, tt_rank=ranks, dtype=tf.float64) + self.vector = t3f.get_variable('vector', initializer=initialization) + + def loss(self, x): + l = t3f.flat_inner(x, self.w) + return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=tf.ones(self.w.batch_size, dtype=tf.float64))) + + def naive_grad(self): + e = tf.exp(-1. * t3f.flat_inner(self.x, self.w)) + c = -e / (1 + e) + grad = c[0] * self.w[0] + for i in range(1, self.w.batch_size): + grad += c[i] * self.w[i] + return t3f.project(grad, self.x) + + def smart_grad(self): + e = tf.exp(-1. * t3f.flat_inner(self.x, self.w)) + c = -e / (1 + e) + return t3f.project_sum(self.w, 1. * self.x, c) + + def naive_hessian_by_vector(self): + e = tf.exp(-1. * t3f.flat_inner(self.x, self.w)) + s = 1. / (1 + e) + c = s * (1 - s) + c *= t3f.flat_inner(self.vector, self.w) + res = c[0] * self.w[0] + for i in range(1, self.w.batch_size): + res += c[i] * self.w[i] + return t3f.project(res, self.x) + + def smart_hessian_by_vector(self): + e = tf.exp(-1. * t3f.flat_inner(self.x, self.w)) + s = 1. / (1 + e) + c = s * (1 - s) + c *= t3f.flat_inner(self.vector, self.w) + return t3f.project_sum(self.w, 1. 
* self.x, c) + + +class BilinearXABX(Task): + + def __init__(self, m, n, d, tt_rank_mat, tt_rank_vec): + self.settings = {'n': n, 'm': m, 'd': d, 'tt_rank_mat': tt_rank_mat, 'tt_rank_vec': tt_rank_vec} + shape = ([m] * d, [n] * d) + ranks = prune_ranks(tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.x = t3f.get_variable('x', initializer=initialization) + ranks = prune_ranks(2 * tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.vector = t3f.get_variable('vector', initializer=initialization) + ranks = prune_ranks(tt_rank_mat, np.prod(shape, axis=0)) + initialization = t3f.random_matrix(shape, tt_rank=ranks, dtype=tf.float64) + self.mat = t3f.get_variable('mat', initializer=initialization) + + def loss(self, x): + return 0.5 * t3f.bilinear_xaby(x, t3f.transpose(self.mat), self.mat, x) + + def naive_grad(self): + grad = t3f.matmul(t3f.transpose(self.mat), t3f.matmul(self.mat, self.x)) + return t3f.project(grad, self.x) + + def smart_grad(self): + raise NotImplementedError() + + def naive_hessian_by_vector(self): + projected_vec = t3f.project(self.vector, self.x) + return t3f.project(t3f.matmul(t3f.transpose(self.mat), t3f.matmul(self.mat, projected_vec)), self.x) + + def smart_hessian_by_vector(self): + raise NotImplementedError() + + +class RayleighQuotient(Task): + + def __init__(self, m, n, d, tt_rank_mat, tt_rank_vec): + self.settings = {'n': n, 'm': m, 'd': d, 'tt_rank_mat': tt_rank_mat, 'tt_rank_vec': tt_rank_vec} + shape = ([m] * d, [n] * d) + ranks = prune_ranks(tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.x = t3f.get_variable('x', initializer=initialization) + ranks = prune_ranks(2 * tt_rank_vec, shape[1]) + initialization = t3f.random_matrix((shape[1], None), tt_rank=ranks, dtype=tf.float64) + self.vector = t3f.get_variable('vector', 
initializer=initialization) + ranks = prune_ranks(tt_rank_mat, np.prod(shape, axis=0)) + mat = t3f.random_matrix(shape, tt_rank=ranks, dtype=tf.float64) + mat = t3f.transpose(mat) + mat + self.mat = t3f.get_variable('mat', initializer=mat) + + def loss(self, x): + xAx = t3f.quadratic_form(self.mat, x, x) # bilinear_form + xx = t3f.flat_inner(x, x) + return xAx / xx + + def naive_grad(self): + xAx = t3f.quadratic_form(self.mat, self.x, self.x) # bilinear_form + xx = t3f.flat_inner(self.x, self.x) + grad = (1. / xx) * t3f.matmul(self.mat, self.x) + grad -= (xAx / (xx**2)) * self.x + return t3f.project(2 * grad, self.x) + + def smart_grad(self): + xAx = t3f.quadratic_form(self.mat, self.x, self.x) # bilinear_form + xx = t3f.frobenius_norm_squared(self.x, differentiable=True) + grad = (1. / xx) * t3f.project_matmul(t3f.expand_batch_dim(self.x), self.x, self.mat)[0] + grad -= (xAx / xx**2) * self.x + return 2 * grad + + def naive_hessian_by_vector(self): + xAx = t3f.quadratic_form(self.mat, self.x, self.x) # bilinear_form + xx = t3f.frobenius_norm_squared(self.x, differentiable=True) + res = (2 / xx) * t3f.matmul(self.mat, self.vector) + res -= (2 * xAx / xx**2) * self.vector + xv = t3f.flat_inner(self.x, self.vector) + res -= (4 * t3f.quadratic_form(self.mat, self.vector, self.x) / xx**2) * self.x + res -= (4 * xv / xx**2) * t3f.matmul(self.mat, self.x) + res += (8 * xAx * xv / xx**3) * self.x + return t3f.project(res, self.x) + + def smart_hessian_by_vector(self): + xAx = t3f.quadratic_form(self.mat, self.x, self.x) # bilinear_form + xx = t3f.frobenius_norm_squared(self.x, differentiable=True) + projected_vec = t3f.project(self.vector, self.x) + res = (2 / xx) * t3f.project_matmul(t3f.expand_batch_dim(self.vector), self.x, self.mat)[0] + res -= (2 * xAx / xx**2) * projected_vec + xv = t3f.flat_inner(self.x, projected_vec) + res -= (4 * t3f.quadratic_form(self.mat, self.vector, self.x) / xx**2) * self.x + res -= (4 * xv / xx**2) * 
t3f.project_matmul(t3f.expand_batch_dim(self.x), self.x, self.mat)[0] + res += (8 * xAx * xv / xx**3) * self.x + return res + + +def exist(all_logs, case_name, case): + for l in all_logs[case_name]: + s = l['settings'] + coincide = True + for k in case.settings: + if s[k] != case.settings[k]: + coincide = False + if coincide: + return True + return False + + +def did_smaller_fail(all_logs, name, case_name, case): + for l in all_logs[case_name]: + s = l['settings'] + if name in l and l[name] is None: + # If this attempt failed. + smaller = True + for k in case.settings: + if s[k] > case.settings[k]: + smaller = False + if smaller: + return True + return False + + +def benchmark(case_name, case, logs_path): + naive_grad = case.naive_grad() + auto_grad = t3f.gradients(case.loss, case.x, runtime_check=False) + try: + smart_grad = case.smart_grad() + except NotImplementedError: + smart_grad = None + + naive_hv = case.naive_hessian_by_vector() + auto_hv = t3f.hessian_vector_product(case.loss, case.x, case.vector, runtime_check=False) + try: + smart_hv = case.smart_hessian_by_vector() + except NotImplementedError: + smart_hv = None + try: + with open(logs_path, "rb") as output_file: + # Dict with case_name -> list of configurations. + all_logs = pickle.load(output_file) + except: + all_logs = {} + if case_name not in all_logs: + all_logs[case_name] = [] + # Single configuration. + current_case_logs = {} + with tf.Session(config=tf.test.benchmark_config()) as sess: + sess.run(tf.global_variables_initializer()) + benchmark = tf.test.Benchmark() + + current_case_logs['settings'] = case.settings + if exist(all_logs, case_name, case): + print('skipping') + return None + + + def benchmark_single(op, name, current_case_logs): + # First write None to indicate the attempt. 
+ if os.path.exists(logs_path): + copyfile(logs_path, logs_path + '_back') + with open(logs_path, "wb") as output_file: + all_logs_curr = copy.deepcopy(all_logs) + current_case_logs[name] = None + all_logs_curr[case_name].append(current_case_logs) + pickle.dump(all_logs_curr, output_file) + + try: + if did_smaller_fail(all_logs, name, case_name, case): + # No point in trying again, a smaller example failed already. + raise ValueError() + logs = benchmark.run_op_benchmark(sess, op) + current_case_logs[name] = logs + except: + current_case_logs[name] = None + + copyfile(logs_path, logs_path + '_back') + with open(logs_path, "wb") as output_file: + all_logs_curr = copy.deepcopy(all_logs) + all_logs_curr[case_name].append(current_case_logs) + pickle.dump(all_logs_curr, output_file) + + benchmark_single(auto_grad.op, 'auto_grad', current_case_logs) + benchmark_single(auto_hv.op, 'auto_hv', current_case_logs) + if smart_grad is not None: + benchmark_single(smart_grad.op, 'smart_grad', current_case_logs) + if smart_hv is not None: + benchmark_single(smart_hv.op, 'smart_hv', current_case_logs) + benchmark_single(naive_grad.op, 'naive_grad', current_case_logs) + benchmark_single(naive_hv.op, 'naive_hv', current_case_logs) + return current_case_logs diff --git a/docs/benchmark/utils_test.py b/docs/benchmark/utils_test.py new file mode 100644 index 00000000..2ce7dbb1 --- /dev/null +++ b/docs/benchmark/utils_test.py @@ -0,0 +1,85 @@ +import numpy as np +import tensorflow as tf + +import utils +import t3f + + +class UtilsTest(tf.test.TestCase): + + def _TestCaseGrad(self, case): + with self.session(): + tf.global_variables_initializer().run() + tensors = [] + manual = case.naive_grad() + try: + manual_2 = case.smart_grad() + self.assertAllClose(t3f.full(manual).eval(), t3f.full(manual_2).eval()) + except NotImplementedError: + pass + auto_g = t3f.gradients(case.loss, case.x, runtime_check=True) + self.assertAllClose(t3f.full(manual).eval(), t3f.full(auto_g).eval(), rtol=1e-5) + 
+ def _TestCaseHess(self, case): + with self.session(): + tf.global_variables_initializer().run() + manual = case.naive_hessian_by_vector() + try: + manual_2 = case.smart_hessian_by_vector() + self.assertAllClose(t3f.full(manual).eval(), t3f.full(manual_2).eval()) + except NotImplementedError: + pass + auto_hv = t3f.hessian_vector_product(case.loss, case.x, case.vector, runtime_check=True) + self.assertAllClose(t3f.full(manual).eval(), t3f.full(auto_hv).eval(), rtol=1e-5) + + def testCompletionGrad(self): + test_case = utils.Completion(3, 3, 4) + self._TestCaseGrad(test_case) + + def testCompletionHess(self): + test_case = utils.Completion(3, 3, 4) + self._TestCaseHess(test_case) + + def testXAXGrad(self): + test_case = utils.BilinearXAX(3, 3, 3, 4, 5) + self._TestCaseGrad(test_case) + + def testXAXHess(self): + test_case = utils.BilinearXAX(3, 3, 3, 4, 5) + self._TestCaseHess(test_case) + + def testXABXGrad(self): + test_case = utils.BilinearXABX(3, 3, 3, 4, 5) + self._TestCaseGrad(test_case) + + def testXABXHess(self): + test_case = utils.BilinearXABX(3, 3, 3, 4, 5) + self._TestCaseHess(test_case) + + def testExpMachinesGrad(self): + test_case = utils.ExpMachines(3, 4, 5, batch_size=3) + self._TestCaseGrad(test_case) + + def testExpMachinesHess(self): + test_case = utils.ExpMachines(3, 3, 3, batch_size=2) + self._TestCaseHess(test_case) + + def testRayleighQuotientGrad(self): + test_case = utils.RayleighQuotient(3, 3, 3, 4, 5) + self._TestCaseGrad(test_case) + + def testRayleighQuotientHess(self): + test_case = utils.RayleighQuotient(3, 3, 3, 4, 5) + self._TestCaseHess(test_case) + + +# class AutodiffTestFloat32(tf.test.TestCase, _AutodiffTest): +# dtype = tf.float32 + + +# class AutodiffTestFloat64(tf.test.TestCase, _AutodiffTest): +# dtype = tf.float64 + + +if __name__ == "__main__": + tf.test.main() diff --git a/setup.py b/setup.py index 5b165d43..b02fca67 100644 --- a/setup.py +++ b/setup.py @@ -10,5 +10,6 @@ packages=['t3f'], install_requires=[ 'numpy', + 
'opt_einsum', ], zip_safe=False) diff --git a/t3f/__init__.py b/t3f/__init__.py index 1b57c500..68218428 100644 --- a/t3f/__init__.py +++ b/t3f/__init__.py @@ -15,6 +15,7 @@ from t3f.ops import multiply from t3f.ops import quadratic_form from t3f.ops import bilinear_form +from t3f.ops import bilinear_form_two_mat from t3f.ops import transpose from t3f.ops import gather_nd from t3f.ops import renormalize_tt_cores diff --git a/t3f/batch_ops.py b/t3f/batch_ops.py index f062e0e1..2f43f7df 100644 --- a/t3f/batch_ops.py +++ b/t3f/batch_ops.py @@ -4,6 +4,7 @@ from t3f.tensor_train_base import TensorTrainBase from t3f.tensor_train_batch import TensorTrainBatch from t3f import ops +from t3f import utils def concat_along_batch_dim(tt_list, name='t3f_concat_along_batch_dim'): @@ -168,12 +169,12 @@ def pairwise_flat_inner(tt_1, tt_2, matrix=None, curr_core_2 = tt_2.tt_cores[0] mode_string = 'ij' if tt_1.is_tt_matrix() else 'i' einsum_str = 'pa{0}b,qc{0}d->pqbd'.format(mode_string) - res = tf.einsum(einsum_str, curr_core_1, curr_core_2) + res = utils.einsum(einsum_str, curr_core_1, curr_core_2) for core_idx in range(1, ndims): curr_core_1 = tt_1.tt_cores[core_idx] curr_core_2 = tt_2.tt_cores[core_idx] einsum_str = 'pqac,pa{0}b,qc{0}d->pqbd'.format(mode_string) - res = tf.einsum(einsum_str, res, curr_core_1, curr_core_2) + res = utils.einsum(einsum_str, res, curr_core_1, curr_core_2) else: # res[i, j] = tt_1[i] ^ T * matrix * tt_2[j] are_all_matrices = tt_1.is_tt_matrix() and tt_2.is_tt_matrix() @@ -221,13 +222,13 @@ def pairwise_flat_inner(tt_1, tt_2, matrix=None, curr_core_2 = tt_2.tt_cores[0] curr_matrix_core = matrix.tt_cores[0] # We enumerate the dummy dimension (that takes 1 value) with `k`. 
- res = tf.einsum('pakib,cijd,qekjf->pqbdf', curr_core_1, curr_matrix_core, + res = utils.einsum('pakib,cijd,qekjf->pqbdf', curr_core_1, curr_matrix_core, curr_core_2) for core_idx in range(1, ndims): curr_core_1 = tt_1.tt_cores[core_idx] curr_core_2 = tt_2.tt_cores[core_idx] curr_matrix_core = matrix.tt_cores[core_idx] - res = tf.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core_1, + res = utils.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core_1, curr_matrix_core, curr_core_2) # Squeeze to make the result of size batch_size x batch_size instead of diff --git a/t3f/ops.py b/t3f/ops.py index 86be42ec..85f4346a 100644 --- a/t3f/ops.py +++ b/t3f/ops.py @@ -85,7 +85,7 @@ def _full_tt_batch(tt): for i in range(1, num_dims): res = tf.reshape(res, (batch_size, -1, ranks[i])) curr_core = tf.reshape(tt.tt_cores[i], (batch_size, ranks[i], -1)) - res = tf.einsum('oqb,obw->oqw', res, curr_core) + res = utils.einsum('oqb,obw->oqw', res, curr_core) if tt.is_tt_matrix(): intermediate_shape = [batch_size] for i in range(num_dims): @@ -161,7 +161,7 @@ def tt_tt_matmul(tt_matrix_a, tt_matrix_b): for core_idx in range(ndims): a_core = tt_matrix_a.tt_cores[core_idx] b_core = tt_matrix_b.tt_cores[core_idx] - curr_res_core = tf.einsum(einsum_str, a_core, b_core) + curr_res_core = utils.einsum(einsum_str, a_core, b_core) res_left_rank = a_ranks[core_idx] * b_ranks[core_idx] res_right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1] @@ -221,7 +221,7 @@ def tt_dense_matmul(tt_matrix_a, matrix_b): curr_core = tt_matrix_a.tt_cores[core_idx] # On the k = core_idx iteration, after applying einsum the shape of data # becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k - data = tf.einsum('aijb,rjb->ira', curr_core, data) + data = utils.einsum('aijb,rjb->ira', curr_core, data) if core_idx > 0: # After reshape the shape of data becomes # (ik, ..., id-1, K, j0, ..., jk-2) x jk-1 x rank_k @@ -384,8 +384,8 @@ def tt_tt_flat_inner(tt_a, tt_b): b_core = tt_b.tt_cores[0] # Simplest 
example of this operation: # if both arguments are TT-tensors, then it is - # res = tf.einsum('aib,cid->bd', a_core, b_core) - res = tf.einsum(init_einsum_str, a_core, b_core) + # res = utils.einsum('aib,cid->bd', a_core, b_core) + res = utils.einsum(init_einsum_str, a_core, b_core) einsum_str = '{3}ac,{1}a{0}b,{2}c{0}d->{3}bd'.format(axes_str, a_batch_str, b_batch_str, @@ -395,8 +395,8 @@ def tt_tt_flat_inner(tt_a, tt_b): b_core = tt_b.tt_cores[core_idx] # Simplest example of this operation: # if both arguments are TT-tensors, then it is - # res = tf.einsum('ac,aib,cid->bd', res, a_core, b_core) - res = tf.einsum(einsum_str, res, a_core, b_core) + # res = utils.einsum('ac,aib,cid->bd', res, a_core, b_core) + res = utils.einsum(einsum_str, res, a_core, b_core) return tf.squeeze(res) @@ -891,7 +891,7 @@ def multiply(tt_left, right, name='t3f_multiply'): right_rank = a_ranks[core_idx + 1] * b_ranks[core_idx + 1] if is_matrix: with tf.control_dependencies(dependencies): - curr_core = tf.einsum('{0}aijb,{1}cijd->{2}acijbd'.format(bs_str_left, + curr_core = utils.einsum('{0}aijb,{1}cijd->{2}acijbd'.format(bs_str_left, bs_str_right, output_str), a_core, b_core) curr_core = tf.reshape(curr_core, (-1, left_rank, shape[0][core_idx], @@ -901,7 +901,7 @@ def multiply(tt_left, right, name='t3f_multiply'): curr_core = tf.squeeze(curr_core, axis=0) else: with tf.control_dependencies(dependencies): - curr_core = tf.einsum('{0}aib,{1}cid->{2}acibd'.format(bs_str_left, + curr_core = utils.einsum('{0}aib,{1}cid->{2}acibd'.format(bs_str_left, bs_str_right, output_str), a_core, b_core) curr_core = tf.reshape(curr_core, (-1, left_rank, shape[0][core_idx], right_rank)) @@ -944,19 +944,19 @@ def frobenius_norm_squared(tt, differentiable=False, else: bs_str = '' if tt.is_tt_matrix(): - running_prod = tf.einsum('{0}aijb,{0}cijd->{0}bd'.format(bs_str), + running_prod = utils.einsum('{0}aijb,{0}cijd->{0}bd'.format(bs_str), tt.tt_cores[0], tt.tt_cores[0]) else: - running_prod = 
tf.einsum('{0}aib,{0}cid->{0}bd'.format(bs_str), + running_prod = utils.einsum('{0}aib,{0}cid->{0}bd'.format(bs_str), tt.tt_cores[0], tt.tt_cores[0]) for core_idx in range(1, tt.ndims()): curr_core = tt.tt_cores[core_idx] if tt.is_tt_matrix(): - running_prod = tf.einsum('{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str), + running_prod = utils.einsum('{0}ac,{0}aijb,{0}cijd->{0}bd'.format(bs_str), running_prod, curr_core, curr_core) else: - running_prod = tf.einsum('{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str), + running_prod = utils.einsum('{0}ac,{0}aib,{0}cid->{0}bd'.format(bs_str), running_prod, curr_core, curr_core) return tf.squeeze(running_prod, [-1, -2]) @@ -1098,7 +1098,7 @@ def bilinear_form(A, b, c, name='t3f_bilinear_form'): # experience it's even a little bit slower (but neglectable in general). einsum_str = '{0}aikb,cijd,{1}ejkf->{2}bdf'.format(b_bs_str, c_bs_str, out_bs_str) - res = tf.einsum(einsum_str, curr_core_1, curr_matrix_core, curr_core_2) + res = utils.einsum(einsum_str, curr_core_1, curr_matrix_core, curr_core_2) for core_idx in range(1, ndims): curr_core_1 = b.tt_cores[core_idx] curr_core_2 = c.tt_cores[core_idx] @@ -1106,8 +1106,8 @@ def bilinear_form(A, b, c, name='t3f_bilinear_form'): einsum_str = '{2}ace,{0}aikb,cijd,{1}ejkf->{2}bdf'.format(b_bs_str, c_bs_str, out_bs_str) - res = tf.einsum(einsum_str, res, curr_core_1, - curr_matrix_core, curr_core_2) + res = utils.einsum(einsum_str, res, curr_core_1, + curr_matrix_core, curr_core_2) # Squeeze to make the result a number instead of 1 x 1 for NON batch case # and to make the result a tensor of size @@ -1120,7 +1120,6 @@ def bilinear_form(A, b, c, name='t3f_bilinear_form'): def bilinear_form_two_mat(x, A, B, y, name='t3f_bilinear_xaby'): """Bilinear form x^t A B y; A are B are TT-matrices, x and y can be batches. - Args: x: `TensorTrain` object containing a TT-matrix of size N x 1 or `TensorTrainBatch` with a batch of TT-matrices of size N x 1. 
diff --git a/t3f/riemannian.py b/t3f/riemannian.py index 649d3537..779c445b 100644 --- a/t3f/riemannian.py +++ b/t3f/riemannian.py @@ -4,6 +4,7 @@ from t3f.tensor_train_batch import TensorTrainBatch from t3f import shapes from t3f import decompositions +from t3f import utils def project_sum(what, where, weights=None): @@ -97,7 +98,7 @@ def project_sum(what, where, weights=None): tens_core = what.tt_cores[core_idx] right_tang_core = right_tangent_space_tens.tt_cores[core_idx] einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str) - rhs[core_idx] = tf.einsum(einsum_str, tens_core, rhs[core_idx + 1], + rhs[core_idx] = utils.einsum(einsum_str, tens_core, rhs[core_idx + 1], right_tang_core) # Prepare lhs vectors. @@ -109,7 +110,7 @@ def project_sum(what, where, weights=None): tens_core = what.tt_cores[core_idx] left_tang_core = left_tangent_space_tens.tt_cores[core_idx] einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str) - lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core, + lhs[core_idx + 1] = utils.einsum(einsum_str, lhs[core_idx], left_tang_core, tens_core) # Left to right sweep. 
@@ -121,27 +122,27 @@ def project_sum(what, where, weights=None): if core_idx < ndims - 1: einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core) + proj_core = utils.einsum(einsum_str, lhs[core_idx], tens_core) einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str) - proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1]) + proj_core -= utils.einsum(einsum_str, left_tang_core, lhs[core_idx + 1]) if weights is None: einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1]) + proj_core = utils.einsum(einsum_str, proj_core, rhs[core_idx + 1]) else: einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str, output_batch_str) - proj_core_s = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1]) + proj_core_s = utils.einsum(einsum_str, proj_core, rhs[core_idx + 1]) einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(mode_str, output_batch_str) - proj_core = tf.einsum(einsum_str, weights, proj_core_s) + proj_core = utils.einsum(einsum_str, weights, proj_core_s) if core_idx == ndims - 1: if weights is None: einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core) + proj_core = utils.einsum(einsum_str, lhs[core_idx], tens_core) else: einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str, output_batch_str) - proj_core_s = tf.einsum(einsum_str, lhs[core_idx], tens_core) + proj_core_s = utils.einsum(einsum_str, lhs[core_idx], tens_core) einsum_str = 's{1},sa{0}c->{1}a{0}c'.format(mode_str, output_batch_str) - proj_core = tf.einsum(einsum_str, weights, proj_core_s) + proj_core = utils.einsum(einsum_str, weights, proj_core_s) if output_is_batch: # Add batch dimension of size output_batch_size to left_tang_core and @@ -275,7 +276,7 @@ def project(what, where): tens_core = what.tt_cores[core_idx] right_tang_core = right_tangent_space_tens.tt_cores[core_idx] einsum_str = 'sa{0}b,sbd,c{0}d->sac'.format(mode_str) - rhs[core_idx] 
= tf.einsum(einsum_str, tens_core, rhs[core_idx + 1], + rhs[core_idx] = utils.einsum(einsum_str, tens_core, rhs[core_idx + 1], right_tang_core) # Prepare lhs vectors. @@ -287,7 +288,7 @@ def project(what, where): tens_core = what.tt_cores[core_idx] left_tang_core = left_tangent_space_tens.tt_cores[core_idx] einsum_str = 'sab,a{0}c,sb{0}d->scd'.format(mode_str) - lhs[core_idx + 1] = tf.einsum(einsum_str, lhs[core_idx], left_tang_core, + lhs[core_idx + 1] = utils.einsum(einsum_str, lhs[core_idx], left_tang_core, tens_core) # Left to right sweep. @@ -299,21 +300,21 @@ def project(what, where): if core_idx < ndims - 1: einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core) + proj_core = utils.einsum(einsum_str, lhs[core_idx], tens_core) einsum_str = 'a{0}b,sbc->sa{0}c'.format(mode_str) - proj_core -= tf.einsum(einsum_str, left_tang_core, lhs[core_idx + 1]) + proj_core -= utils.einsum(einsum_str, left_tang_core, lhs[core_idx + 1]) if output_is_batch: einsum_str = 'sa{0}b,sbc->sa{0}c'.format(mode_str) else: einsum_str = 'sa{0}b,sbc->a{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, proj_core, rhs[core_idx + 1]) + proj_core = utils.einsum(einsum_str, proj_core, rhs[core_idx + 1]) if core_idx == ndims - 1: if output_is_batch: einsum_str = 'sab,sb{0}c->sa{0}c'.format(mode_str) else: einsum_str = 'sab,sb{0}c->a{0}c'.format(mode_str) - proj_core = tf.einsum(einsum_str, lhs[core_idx], tens_core) + proj_core = utils.einsum(einsum_str, lhs[core_idx], tens_core) if output_is_batch: # Add batch dimension of size output_batch_size to left_tang_core and @@ -446,7 +447,7 @@ def project_matmul(what, where, matrix): tens_core = what.tt_cores[core_idx] right_tang_core = right_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] - rhs[core_idx] = tf.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core, + rhs[core_idx] = utils.einsum('bije,cikf,sdef,sajkd->sabc', matrix_core, right_tang_core, 
rhs[core_idx + 1], tens_core) # Prepare lhs vectors. # lhs[core_idx] is of size @@ -458,7 +459,7 @@ def project_matmul(what, where, matrix): left_tang_core = left_tangent_space_tens.tt_cores[core_idx] matrix_core = matrix.tt_cores[core_idx] # TODO: brutforce order of indices in lhs?? - lhs[core_idx + 1] = tf.einsum('bije,aikd,sabc,scjkf->sdef', matrix_core, + lhs[core_idx + 1] = utils.einsum('bije,aikd,sabc,scjkf->sdef', matrix_core, left_tang_core, lhs[core_idx], tens_core) # Left to right sweep. @@ -470,17 +471,17 @@ def project_matmul(what, where, matrix): right_tang_core = right_tangent_space_tens.tt_cores[core_idx] if core_idx < ndims - 1: - proj_core = tf.einsum('scjke,sabc,bijd->saikde', tens_core, + proj_core = utils.einsum('scjke,sabc,bijd->saikde', tens_core, lhs[core_idx], matrix_core) - proj_core -= tf.einsum('aikb,sbcd->saikcd', left_tang_core, + proj_core -= utils.einsum('aikb,sbcd->saikcd', left_tang_core, lhs[core_idx + 1]) - proj_core = tf.einsum('saikcb,sbcd->saikd', proj_core, rhs[core_idx + 1]) + proj_core = utils.einsum('saikcb,sbcd->saikd', proj_core, rhs[core_idx + 1]) if core_idx == ndims - 1: # d and e dimensions take 1 value, since its the last rank. # To make the result shape (?, ?, ?, 1), we are summing d and leaving e, # but we could have done the opposite -- sum e and leave d. 
- proj_core = tf.einsum('sabc,bijd,scjke->saike', lhs[core_idx], matrix_core, + proj_core = utils.einsum('sabc,bijd,scjke->saike', lhs[core_idx], matrix_core, tens_core) if output_is_batch: @@ -586,7 +587,7 @@ def pairwise_flat_inner_projected(projected_tt_vectors_1, curr_core_2 = projected_tt_vectors_2.tt_cores[0] curr_du_1 = curr_core_1[:, :, :, :, :right_size] curr_du_2 = curr_core_2[:, :, :, :, :right_size] - res = tf.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) + res = utils.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) for core_idx in range(1, ndims): left_size = tt_ranks[core_idx] // 2 right_size = tt_ranks[core_idx + 1] // 2 @@ -594,14 +595,14 @@ def pairwise_flat_inner_projected(projected_tt_vectors_1, curr_core_2 = projected_tt_vectors_2.tt_cores[core_idx] curr_du_1 = curr_core_1[:, left_size:, :, :, :right_size] curr_du_2 = curr_core_2[:, left_size:, :, :, :right_size] - res += tf.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) + res += utils.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) left_size = tt_ranks[-2] // 2 curr_core_1 = projected_tt_vectors_1.tt_cores[-1] curr_core_2 = projected_tt_vectors_2.tt_cores[-1] curr_du_1 = curr_core_1[:, left_size:, :, :, :] curr_du_2 = curr_core_2[:, left_size:, :, :, :] - res += tf.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) + res += utils.einsum('paijb,qaijb->pq', curr_du_1, curr_du_2) else: # Working with TT-tensor, not TT-matrix. 
right_size = tt_ranks[1] // 2 @@ -609,7 +610,7 @@ def pairwise_flat_inner_projected(projected_tt_vectors_1, curr_core_2 = projected_tt_vectors_2.tt_cores[0] curr_du_1 = curr_core_1[:, :, :, :right_size] curr_du_2 = curr_core_2[:, :, :, :right_size] - res = tf.einsum('paib,qaib->pq', curr_du_1, curr_du_2) + res = utils.einsum('paib,qaib->pq', curr_du_1, curr_du_2) for core_idx in range(1, ndims): left_size = tt_ranks[core_idx] // 2 right_size = tt_ranks[core_idx + 1] // 2 @@ -617,14 +618,14 @@ def pairwise_flat_inner_projected(projected_tt_vectors_1, curr_core_2 = projected_tt_vectors_2.tt_cores[core_idx] curr_du_1 = curr_core_1[:, left_size:, :, :right_size] curr_du_2 = curr_core_2[:, left_size:, :, :right_size] - res += tf.einsum('paib,qaib->pq', curr_du_1, curr_du_2) + res += utils.einsum('paib,qaib->pq', curr_du_1, curr_du_2) left_size = tt_ranks[-2] // 2 curr_core_1 = projected_tt_vectors_1.tt_cores[-1] curr_core_2 = projected_tt_vectors_2.tt_cores[-1] curr_du_1 = curr_core_1[:, left_size:, :, :] curr_du_2 = curr_core_2[:, left_size:, :, :] - res += tf.einsum('paib,qaib->pq', curr_du_1, curr_du_2) + res += utils.einsum('paib,qaib->pq', curr_du_1, curr_du_2) return res diff --git a/t3f/tensor_train.py b/t3f/tensor_train.py index 79807004..7fb14256 100644 --- a/t3f/tensor_train.py +++ b/t3f/tensor_train.py @@ -2,6 +2,7 @@ from t3f.tensor_train_base import TensorTrainBase from t3f import shapes +from t3f import utils class TensorTrain(TensorTrainBase): @@ -130,13 +131,13 @@ def __getitem__(self, slice_spec): if remainder is not None: # Add reminder from the previous collapsed cores to the current # core. - sliced_core = tf.einsum('ab,bid->aid', remainder, sliced_core) + sliced_core = utils.einsum('ab,bid->aid', remainder, sliced_core) remainder = None new_tt_cores.append(sliced_core) if remainder is not None: # The reminder obtained from collapsing the last cores. 
- new_tt_cores[-1] = tf.einsum('aib,bd->aid', new_tt_cores[-1], remainder) + new_tt_cores[-1] = utils.einsum('aib,bd->aid', new_tt_cores[-1], remainder) remainder = None # TODO: infer the output ranks and shape. return TensorTrain(new_tt_cores) diff --git a/t3f/tensor_train_batch.py b/t3f/tensor_train_batch.py index 976582ac..0a180317 100644 --- a/t3f/tensor_train_batch.py +++ b/t3f/tensor_train_batch.py @@ -4,6 +4,7 @@ from t3f.tensor_train_base import TensorTrainBase from t3f.tensor_train import TensorTrain from t3f import shapes +from t3f import utils class TensorTrainBatch(TensorTrainBase): @@ -203,17 +204,17 @@ def _full_getitem(self, slice_spec): remainder = sliced_core else: if do_collapse_batch_dim: - remainder = tf.einsum('ab,bd->ad', remainder, sliced_core) + remainder = utils.einsum('ab,bd->ad', remainder, sliced_core) else: - remainder = tf.einsum('oab,obd->oad', remainder, sliced_core) + remainder = utils.einsum('oab,obd->oad', remainder, sliced_core) else: if remainder is not None: # Add reminder from the previous collapsed cores to the current # core. if do_collapse_batch_dim: - sliced_core = tf.einsum('ab,bid->aid', remainder, sliced_core) + sliced_core = utils.einsum('ab,bid->aid', remainder, sliced_core) else: - sliced_core = tf.einsum('oab,obid->oaid', remainder, + sliced_core = utils.einsum('oab,obid->oaid', remainder, sliced_core) remainder = None new_tt_cores.append(sliced_core) @@ -221,11 +222,11 @@ def _full_getitem(self, slice_spec): if remainder is not None: # The reminder obtained from collapsing the last cores. if do_collapse_batch_dim: - new_tt_cores[-1] = tf.einsum('aib,bd->aid', new_tt_cores[-1], + new_tt_cores[-1] = utils.einsum('aib,bd->aid', new_tt_cores[-1], remainder) else: - new_tt_cores[-1] = tf.einsum('oaib,obd->oaid', new_tt_cores[-1], + new_tt_cores[-1] = utils.einsum('oaib,obd->oaid', new_tt_cores[-1], remainder) remainder = None # TODO: infer the output ranks and shape. 
diff --git a/t3f/utils.py b/t3f/utils.py
index 906664a0..7904e2b0 100644
--- a/t3f/utils.py
+++ b/t3f/utils.py
@@ -1,5 +1,29 @@
 import numpy as np
 import tensorflow as tf
+from opt_einsum import contract
+
+
+def einsum(*args, **kargs):
+  """Contract tensors via opt_einsum, defaulting to the TensorFlow backend.
+
+  All arguments are forwarded to `opt_einsum.contract`. `backend` defaults
+  to 'tensorflow' and `optimize` to 'optimal'; pass either keyword
+  explicitly to override it.
+
+  Args:
+    *args: positional arguments for `opt_einsum.contract` (the call sites
+      pass the einsum string followed by the tensors to contract).
+    **kargs: keyword arguments for `opt_einsum.contract`.
+
+  Returns:
+    The result of `opt_einsum.contract`.
+  """
+  # Use setdefault rather than explicit keywords: passing backend= or
+  # optimize= next to **kargs raises "got multiple values for keyword
+  # argument" as soon as a caller supplies either one.
+  kargs.setdefault('backend', 'tensorflow')
+  kargs.setdefault('optimize', 'optimal')
+  return contract(*args, **kargs)
 
 
 # TODO: substitute with native implementation when it's ready.