diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 1bd8edd34d18..7cfb73cb668f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -333,22 +333,21 @@ def test_single_model_convergence_with_fluctuations(self): """ model_name = "fluctuating_model" model_cost = 3000.0 - load_cost = 2500.0 # Fix random seed for reproducibility random.seed(42) def loader(): - self.mock_monitor.allocate(load_cost) + self.mock_monitor.allocate(model_cost) return model_name model = self.manager.acquire_model(model_name, loader) self.manager.release_model(model_name, model) initial_est = self.manager._estimator.get_estimate(model_name) - self.assertEqual(initial_est, load_cost) + self.assertEqual(initial_est, model_cost) def run_inference(): model = self.manager.acquire_model(model_name, loader) - noise = model_cost - load_cost + random.uniform(-300.0, 300.0) + noise = random.uniform(-300.0, 300.0) self.mock_monitor.allocate(noise) time.sleep(0.1) self.mock_monitor.free(noise) diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 7b2b11857bfd..3c74903b8d99 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -289,6 +289,7 @@ def setUp(self): for tag in ['basic', 'main', 'to_delete', + 'to_keep', 'mix1', 'mix2', 'test_process_exit', @@ -310,7 +311,7 @@ def tearDown(self): def test_call(self): shared = multi_process_shared.MultiProcessShared( - Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + Counter, tag='main', always_proxy=True, spawn_process=True).acquire() self.assertEqual(shared.get(), 0) self.assertEqual(shared.increment(), 1) self.assertEqual(shared.increment(10), 11) @@ -323,7 +324,8 @@ def test_unsafe_hard_delete_autoproxywrapper(self): shared2 = multi_process_shared.MultiProcessShared( Counter, tag='to_delete', always_proxy=True, spawn_process=True) counter3 = multi_process_shared.MultiProcessShared( - Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + Counter, tag='to_keep', always_proxy=True, + spawn_process=True).acquire() counter1 = shared1.acquire() counter2 = shared2.acquire()