From a093b4aea74273aa478b842c5283ef0a44f326b0 Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Fri, 23 Sep 2022 04:05:16 -0700 Subject: [PATCH] Remove default for checkpoint argument. PiperOrigin-RevId: 476332142 --- vmoe/evaluate/evaluator_test.py | 2 +- vmoe/evaluate/fewshot_test.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vmoe/evaluate/evaluator_test.py b/vmoe/evaluate/evaluator_test.py index cc5f1f3..543c1b1 100644 --- a/vmoe/evaluate/evaluator_test.py +++ b/vmoe/evaluate/evaluator_test.py @@ -75,7 +75,7 @@ def _create_dataset_and_expected_state(cls): # 0 or loss[i]. sum_loss=tf.reduce_sum(loss * valid).numpy(), rngs={}) - return TfDatasetIterator(dataset), expected_eval_state + return TfDatasetIterator(dataset, checkpoint=False), expected_eval_state def test_evaluate_dataset(self): # Create random test dataset. diff --git a/vmoe/evaluate/fewshot_test.py b/vmoe/evaluate/fewshot_test.py index 96f152b..f621d92 100644 --- a/vmoe/evaluate/fewshot_test.py +++ b/vmoe/evaluate/fewshot_test.py @@ -112,9 +112,12 @@ def setUp(self): 'label': labels, fewshot.VALID_KEY: valid, }) - self.mock_get_dataset = self.enter_context(mock.patch.object( - fewshot.vmoe.data.input_pipeline, 'get_dataset', - side_effect=lambda *a, **kw: clu.data.TfDatasetIterator(dataset))) + self.mock_get_dataset = self.enter_context( + mock.patch.object( + fewshot.vmoe.data.input_pipeline, + 'get_dataset', + side_effect=lambda *a, **kw: clu.data.TfDatasetIterator( # pylint: disable=g-long-lambda + dataset, checkpoint=False))) @classmethod def _apply_fn(cls, variables, images, rngs=None):