From f26099e809deeaa61d8609494008ba62ee907a60 Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Wed, 18 Oct 2017 14:02:07 +0300 Subject: [PATCH 1/4] Init batch dimension gather --- t3f/__init__.py | 1 + t3f/ops.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/t3f/__init__.py b/t3f/__init__.py index 10402eec..c9a9b772 100644 --- a/t3f/__init__.py +++ b/t3f/__init__.py @@ -15,6 +15,7 @@ from t3f.ops import multiply from t3f.ops import quadratic_form from t3f.ops import transpose +from t3f.ops import gather_batch_dim from t3f.ops import gather_nd from t3f.batch_ops import concat_along_batch_dim diff --git a/t3f/ops.py b/t3f/ops.py index 9cfd2833..a9a2993c 100644 --- a/t3f/ops.py +++ b/t3f/ops.py @@ -1080,6 +1080,17 @@ def cast(tt_a, dtype): 'TensorTrainBatch.' % tt_a) +def gather_batch_dim(tt_batch, indices): + """out[i] = tt_batch[indices[i]] + + """ + new_tt_cores = [] + for core in tt_batch.tt_cores: + new_tt_cores.append(tf.gather(core, indices)) + return TensorTrainBatch(new_tt_cores, tt_batch.get_raw_shape(), + tt_batch.get_tt_ranks(), tt_batch.batch_size()) + + def gather_nd(tt, indices): """out[i] = tt[indices[i, 0], indices[i, 1], ...] From 6485b41995142c40501533fd97ed6fb8baf298e7 Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Wed, 18 Oct 2017 23:16:53 +0300 Subject: [PATCH 2/4] Move test to appropriate class --- t3f/ops_test.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/t3f/ops_test.py b/t3f/ops_test.py index 4fb29032..27af0cf0 100644 --- a/t3f/ops_test.py +++ b/t3f/ops_test.py @@ -165,6 +165,19 @@ def testCastIntFloat(self): self.assertEqual(dtype, casted.dtype) self.assertTrue(dtype, casted_val.dtype) + def testGatherND(self): + idx = [[0, 0, 0], [0, 1, 2], [0, 1, 0]] + pl_idx = tf.placeholder(tf.int32, [None, 3]) + tt = initializers.random_tensor((3, 4, 5), tt_rank=2) + res_np = ops.gather_nd(tt, idx) + res_pl = ops.gather_nd(tt, pl_idx) + res_desired = tf.gather_nd(ops.full(tt), idx) + to_run = [res_np, res_pl, res_desired] + with self.test_session() as sess: + res_np_v, res_pl_v, des_v = sess.run(to_run, feed_dict={pl_idx: idx}) + self.assertAllClose(res_np_v, des_v) + self.assertAllClose(res_pl_v, res_pl_v) + class TTMatrixTest(tf.test.TestCase): @@ -667,19 +680,6 @@ def testMultiplyUnknownSizeBatchAndBatch(self): with self.assertRaises(tf.errors.InvalidArgumentError): sess.run(to_run, feed_dict=feed_dict_err) - def testGatherND(self): - idx = [[0, 0, 0], [0, 1, 2], [0, 1, 0]] - pl_idx = tf.placeholder(tf.int32, [None, 3]) - tt = initializers.random_tensor((3, 4, 5), tt_rank=2) - res_np = ops.gather_nd(tt, idx) - res_pl = ops.gather_nd(tt, pl_idx) - res_desired = tf.gather_nd(ops.full(tt), idx) - to_run = [res_np, res_pl, res_desired] - with self.test_session() as sess: - res_np_v, res_pl_v, des_v = sess.run(to_run, feed_dict={pl_idx: idx}) - self.assertAllClose(res_np_v, des_v) - self.assertAllClose(res_pl_v, res_pl_v) - def testGatherNDBatch(self): idx = [[0, 0, 0, 0], [1, 0, 1, 2], [0, 0, 1, 0]] pl_idx = tf.placeholder(tf.int32, [None, 4]) From 7b1e5a494ae7166e92cd6b65187f1de6aa322219 Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Wed, 18 Oct 2017 23:26:49 +0300 Subject: [PATCH 3/4] Test gather_batch_dim --- t3f/ops_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/t3f/ops_test.py b/t3f/ops_test.py index 27af0cf0..a159984a 100644 --- a/t3f/ops_test.py +++ b/t3f/ops_test.py @@ -680,6 +680,20 @@ def testMultiplyUnknownSizeBatchAndBatch(self): with self.assertRaises(tf.errors.InvalidArgumentError): sess.run(to_run, feed_dict=feed_dict_err) + def testGatherBatchDim(self): + idx = [0, 0, 2, 1] + pl_idx = tf.placeholder(tf.int32, [None]) + tt = initializers.random_tensor_batch((3, 4, 5), tt_rank=2, batch_size=3) + res_np = ops.full(ops.gather_batch_dim(tt, idx)) + res_pl = ops.full(ops.gather_batch_dim(tt, pl_idx)) + res_desired = tf.gather(ops.full(tt), idx) + to_run = [res_np, res_pl, res_desired] + with self.test_session() as sess: + res_np_v, res_pl_v, des_v = sess.run(to_run, feed_dict={pl_idx: idx}) + self.assertAllClose(res_np_v, des_v) + self.assertAllClose(res_pl_v, res_pl_v) + + def testGatherNDBatch(self): idx = [[0, 0, 0, 0], [1, 0, 1, 2], [0, 0, 1, 0]] pl_idx = tf.placeholder(tf.int32, [None, 4]) From 3896b05af19e6816cf9971d4758b864ce21eca0d Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Wed, 18 Oct 2017 23:27:03 +0300 Subject: [PATCH 4/4] Fix gather_batch_dim --- t3f/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/t3f/ops.py b/t3f/ops.py index a9a2993c..c3e34acb 100644 --- a/t3f/ops.py +++ b/t3f/ops.py @@ -1082,13 +1082,15 @@ def cast(tt_a, dtype): def gather_batch_dim(tt_batch, indices): """out[i] = tt_batch[indices[i]] - + TODO: move this to indexing! tt_batch[indices] totally makes sense. """ + indices = tf.convert_to_tensor(indices) new_tt_cores = [] for core in tt_batch.tt_cores: new_tt_cores.append(tf.gather(core, indices)) return TensorTrainBatch(new_tt_cores, tt_batch.get_raw_shape(), - tt_batch.get_tt_ranks(), tt_batch.batch_size()) + tt_batch.get_tt_ranks().as_list(), + indices.get_shape()[0].value) def gather_nd(tt, indices):