diff --git a/t3f/__init__.py b/t3f/__init__.py
index 10402eec..c9a9b772 100644
--- a/t3f/__init__.py
+++ b/t3f/__init__.py
@@ -15,6 +15,7 @@ from t3f.ops import multiply
 from t3f.ops import quadratic_form
 from t3f.ops import transpose
+from t3f.ops import gather_batch_dim
 from t3f.ops import gather_nd
 
 from t3f.batch_ops import concat_along_batch_dim
 
diff --git a/t3f/ops.py b/t3f/ops.py
index 9cfd2833..c3e34acb 100644
--- a/t3f/ops.py
+++ b/t3f/ops.py
@@ -1080,6 +1080,29 @@ def cast(tt_a, dtype):
                     'TensorTrainBatch.' % tt_a)
 
 
+def gather_batch_dim(tt_batch, indices):
+  """out[i] = tt_batch[indices[i]]
+
+  Args:
+    tt_batch: `TensorTrainBatch` object.
+    indices: 1-D integer numpy array or `tf.Tensor` with indices into the
+      batch dimension.
+
+  Returns:
+    `TensorTrainBatch` whose batch size equals the length of `indices`
+    (None if the static shape of `indices` is unknown).
+
+  TODO: move this to indexing! tt_batch[indices] totally makes sense.
+  """
+  indices = tf.convert_to_tensor(indices)
+  new_tt_cores = []
+  for core in tt_batch.tt_cores:
+    new_tt_cores.append(tf.gather(core, indices))
+  return TensorTrainBatch(new_tt_cores, tt_batch.get_raw_shape(),
+                          tt_batch.get_tt_ranks().as_list(),
+                          indices.get_shape()[0].value)
+
+
 def gather_nd(tt, indices):
   """out[i] = tt[indices[i, 0], indices[i, 1], ...]
 
diff --git a/t3f/ops_test.py b/t3f/ops_test.py
index 4fb29032..a159984a 100644
--- a/t3f/ops_test.py
+++ b/t3f/ops_test.py
@@ -165,6 +165,19 @@ def testCastIntFloat(self):
     self.assertEqual(dtype, casted.dtype)
     self.assertTrue(dtype, casted_val.dtype)
 
+  def testGatherND(self):
+    idx = [[0, 0, 0], [0, 1, 2], [0, 1, 0]]
+    pl_idx = tf.placeholder(tf.int32, [None, 3])
+    tt = initializers.random_tensor((3, 4, 5), tt_rank=2)
+    res_np = ops.gather_nd(tt, idx)
+    res_pl = ops.gather_nd(tt, pl_idx)
+    res_desired = tf.gather_nd(ops.full(tt), idx)
+    to_run = [res_np, res_pl, res_desired]
+    with self.test_session() as sess:
+      res_np_v, res_pl_v, des_v = sess.run(to_run, feed_dict={pl_idx: idx})
+      self.assertAllClose(res_np_v, des_v)
+      self.assertAllClose(res_pl_v, des_v)
+
 
 class TTMatrixTest(tf.test.TestCase):
 
@@ -667,19 +680,20 @@ def testMultiplyUnknownSizeBatchAndBatch(self):
       with self.assertRaises(tf.errors.InvalidArgumentError):
         sess.run(to_run, feed_dict=feed_dict_err)
 
-  def testGatherND(self):
-    idx = [[0, 0, 0], [0, 1, 2], [0, 1, 0]]
-    pl_idx = tf.placeholder(tf.int32, [None, 3])
-    tt = initializers.random_tensor((3, 4, 5), tt_rank=2)
-    res_np = ops.gather_nd(tt, idx)
-    res_pl = ops.gather_nd(tt, pl_idx)
-    res_desired = tf.gather_nd(ops.full(tt), idx)
+  def testGatherBatchDim(self):
+    idx = [0, 0, 2, 1]
+    pl_idx = tf.placeholder(tf.int32, [None])
+    tt = initializers.random_tensor_batch((3, 4, 5), tt_rank=2, batch_size=3)
+    res_np = ops.full(ops.gather_batch_dim(tt, idx))
+    res_pl = ops.full(ops.gather_batch_dim(tt, pl_idx))
+    res_desired = tf.gather(ops.full(tt), idx)
     to_run = [res_np, res_pl, res_desired]
     with self.test_session() as sess:
       res_np_v, res_pl_v, des_v = sess.run(to_run, feed_dict={pl_idx: idx})
       self.assertAllClose(res_np_v, des_v)
       self.assertAllClose(res_pl_v, des_v)
+
 
   def testGatherNDBatch(self):
     idx = [[0, 0, 0, 0], [1, 0, 1, 2], [0, 0, 1, 0]]
     pl_idx = tf.placeholder(tf.int32, [None, 4])