diff --git a/vmoe/evaluate/evaluator_test.py b/vmoe/evaluate/evaluator_test.py index cc5f1f3..543c1b1 100644 --- a/vmoe/evaluate/evaluator_test.py +++ b/vmoe/evaluate/evaluator_test.py @@ -75,7 +75,7 @@ def _create_dataset_and_expected_state(cls): # 0 or loss[i]. sum_loss=tf.reduce_sum(loss * valid).numpy(), rngs={}) - return TfDatasetIterator(dataset), expected_eval_state + return TfDatasetIterator(dataset, checkpoint=False), expected_eval_state def test_evaluate_dataset(self): # Create random test dataset. diff --git a/vmoe/evaluate/fewshot_test.py b/vmoe/evaluate/fewshot_test.py index 96f152b..f621d92 100644 --- a/vmoe/evaluate/fewshot_test.py +++ b/vmoe/evaluate/fewshot_test.py @@ -112,9 +112,12 @@ def setUp(self): 'label': labels, fewshot.VALID_KEY: valid, }) - self.mock_get_dataset = self.enter_context(mock.patch.object( - fewshot.vmoe.data.input_pipeline, 'get_dataset', - side_effect=lambda *a, **kw: clu.data.TfDatasetIterator(dataset))) + self.mock_get_dataset = self.enter_context( + mock.patch.object( + fewshot.vmoe.data.input_pipeline, + 'get_dataset', + side_effect=lambda *a, **kw: clu.data.TfDatasetIterator( # pylint: disable=g-long-lambda + dataset, checkpoint=False))) @classmethod def _apply_fn(cls, variables, images, rngs=None):