diff --git a/t3f/batch_ops.py b/t3f/batch_ops.py index bd89bb36..32f76121 100644 --- a/t3f/batch_ops.py +++ b/t3f/batch_ops.py @@ -215,4 +215,4 @@ def pairwise_flat_inner(tt_1, tt_2, matrix=None): # Squeeze to make the result of size batch_size x batch_size instead of # batch_size x batch_size x 1 x 1. - return tf.squeeze(res) + return tf.squeeze(res, axis=(2, 3))