diff --git a/.github/trigger_files/beam_PreCommit_Python_Dill.json b/.github/trigger_files/beam_PreCommit_Python_Dill.json index 8c604b0a135c..840d064bdbb0 100644 --- a/.github/trigger_files/beam_PreCommit_Python_Dill.json +++ b/.github/trigger_files/beam_PreCommit_Python_Dill.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "revision": 2 + "revision": 3 } diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index b1c607457f9f..b9bee4585688 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -927,16 +927,16 @@ def _create_impl(self): class DeterministicFastPrimitivesCoderV2(FastCoder): """Throws runtime errors when encoding non-deterministic values.""" - def __init__(self, coder, step_label, update_compatibility_version=None): + def __init__(self, coder, step_label): self._underlying_coder = coder self._step_label = step_label self._use_relative_filepaths = True self._version_tag = "v2_69" - from apache_beam.transforms.util import is_v1_prior_to_v2 # Versions prior to 2.69.0 did not use relative filepaths. - if update_compatibility_version and is_v1_prior_to_v2( - v1=update_compatibility_version, v2="2.69.0"): + from apache_beam.options.pipeline_options_context import get_pipeline_options + opts = get_pipeline_options() + if opts and opts.is_compat_version_prior_to("2.69.0"): self._version_tag = "" self._use_relative_filepaths = False @@ -1005,20 +1005,11 @@ def to_type_hint(self): return Any -def _should_force_use_dill(registry): - # force_dill_deterministic_coders is for testing purposes. If there is a - # DeterministicFastPrimitivesCoder in the pipeline graph but the dill - # encoding path is not really triggered dill does not have to be installed. - # and this check can be skipped. - if getattr(registry, 'force_dill_deterministic_coders', False): - return True +def _should_force_use_dill(): + from apache_beam.options.pipeline_options_context import get_pipeline_options - from apache_beam.transforms.util import is_v1_prior_to_v2 - update_compat_version = registry.update_compatibility_version - if not update_compat_version: - return False - - if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"): + opts = get_pipeline_options() + if opts is None or not opts.is_compat_version_prior_to("2.68.0"): return False try: @@ -1043,12 +1034,9 @@ def _update_compatible_deterministic_fast_primitives_coder(coder, step_label): - In SDK version 2.69.0 cloudpickle is used to encode "special types" with relative filepaths in code objects and dynamic functions. """ - from apache_beam.coders import typecoders - - if _should_force_use_dill(typecoders.registry): + if _should_force_use_dill(): return DeterministicFastPrimitivesCoder(coder, step_label) - return DeterministicFastPrimitivesCoderV2( - coder, step_label, typecoders.registry.update_compatibility_version) + return DeterministicFastPrimitivesCoderV2(coder, step_label) class FastPrimitivesCoder(FastCoder): diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index ad742665fb8a..5b7f5f65a560 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -42,6 +42,8 @@ from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message from apache_beam.coders import typecoders from apache_beam.internal import pickler +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options_context import scoped_pipeline_options from apache_beam.runners import pipeline_context from apache_beam.transforms import userstate from apache_beam.transforms import window @@ -202,9 +204,6 @@ def tearDownClass(cls): assert not standard - cls.seen, str(standard - cls.seen) assert not cls.seen_nested - standard, str(cls.seen_nested - standard) - def tearDown(self): - typecoders.registry.update_compatibility_version = None - @classmethod def _observe(cls, coder): cls.seen.add(type(coder)) @@ -274,80 +273,82 @@ def test_deterministic_coder(self, compat_version): - In SDK version >=2.69.0 cloudpickle is used to encode "special types" with relative filepaths in code objects and dynamic functions. """ + with scoped_pipeline_options( + PipelineOptions(update_compatibility_version=compat_version)): + coder = coders.FastPrimitivesCoder() + if not dill and compat_version == "2.67.0": + with self.assertRaises(RuntimeError): + coder.as_deterministic_coder(step_label="step") + self.skipTest('Dill not installed') + deterministic_coder = coder.as_deterministic_coder(step_label="step") + + self.check_coder(deterministic_coder, *self.test_values_deterministic) + for v in self.test_values_deterministic: + self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, )) + self.check_coder( + coders.TupleCoder( + (deterministic_coder, ) * len(self.test_values_deterministic)), + tuple(self.test_values_deterministic)) - typecoders.registry.update_compatibility_version = compat_version - coder = coders.FastPrimitivesCoder() - if not dill and compat_version == "2.67.0": - with self.assertRaises(RuntimeError): - coder.as_deterministic_coder(step_label="step") - self.skipTest('Dill not installed') - deterministic_coder = coder.as_deterministic_coder(step_label="step") - - self.check_coder(deterministic_coder, *self.test_values_deterministic) - for v in self.test_values_deterministic: - self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, )) - self.check_coder( - coders.TupleCoder( - (deterministic_coder, ) * len(self.test_values_deterministic)), - tuple(self.test_values_deterministic)) - - self.check_coder(deterministic_coder, {}) - self.check_coder(deterministic_coder, {2: 'x', 1: 'y'}) - with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, {1: 'x', 'y': 2}) - self.check_coder(deterministic_coder, [1, {}]) - with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}]) - - self.check_coder( - coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}])) + self.check_coder(deterministic_coder, {}) + self.check_coder(deterministic_coder, {2: 'x', 1: 'y'}) + with self.assertRaises(TypeError): + self.check_coder(deterministic_coder, {1: 'x', 'y': 2}) + self.check_coder(deterministic_coder, [1, {}]) + with self.assertRaises(TypeError): + self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}]) - self.check_coder(deterministic_coder, test_message.MessageA(field1='value')) + self.check_coder( + coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}])) - # Skip this test during cloudpickle. Dill monkey patches the __reduce__ - # method for anonymous named tuples (MyNamedTuple) which is not pickleable. - # Since the test is parameterized the type gets colbbered. - if compat_version == "2.67.0": self.check_coder( - deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) + deterministic_coder, test_message.MessageA(field1='value')) - self.check_coder( - deterministic_coder, - [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) + # Skip this test during cloudpickle. Dill monkey patches the __reduce__ + # method for anonymous named tuples (MyNamedTuple) which is not + # pickleable. Since the test is parameterized the type gets colbbered. + if compat_version == "2.67.0": + self.check_coder( + deterministic_coder, + [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) - if dataclasses is not None: - self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) - self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2)) self.check_coder( - deterministic_coder, FrozenUnInitKwOnlyDataClass(side=11)) + deterministic_coder, + [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) - with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2)) - - with self.assertRaises(TypeError): + if dataclasses is not None: + self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) + self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2)) self.check_coder( - deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3)) + deterministic_coder, FrozenUnInitKwOnlyDataClass(side=11)) + with self.assertRaises(TypeError): - self.check_coder( - deterministic_coder, - AnotherNamedTuple(UnFrozenDataClass(1, 2), 3)) + self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2)) - self.check_coder(deterministic_coder, list(MyEnum)) - self.check_coder(deterministic_coder, list(MyIntEnum)) - self.check_coder(deterministic_coder, list(MyIntFlag)) - self.check_coder(deterministic_coder, list(MyFlag)) + with self.assertRaises(TypeError): + self.check_coder( + deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3)) + with self.assertRaises(TypeError): + self.check_coder( + deterministic_coder, + AnotherNamedTuple(UnFrozenDataClass(1, 2), 3)) - self.check_coder( - deterministic_coder, - [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))]) + self.check_coder(deterministic_coder, list(MyEnum)) + self.check_coder(deterministic_coder, list(MyIntEnum)) + self.check_coder(deterministic_coder, list(MyIntFlag)) + self.check_coder(deterministic_coder, list(MyFlag)) - with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, DefinesGetState(1)) - with self.assertRaises(TypeError): self.check_coder( - deterministic_coder, DefinesGetAndSetState({ - 1: 'x', 'y': 2 - })) + deterministic_coder, + [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))]) + + with self.assertRaises(TypeError): + self.check_coder(deterministic_coder, DefinesGetState(1)) + with self.assertRaises(TypeError): + self.check_coder( + deterministic_coder, DefinesGetAndSetState({ + 1: 'x', 'y': 2 + })) @parameterized.expand([ param(compat_version=None), @@ -364,28 +365,29 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version): - In SDK version >=2.69.0 cloudpickle is used to encode "special types" with relative file. """ - typecoders.registry.update_compatibility_version = compat_version - values = [{ - MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i) - for i in range(10) - }] + with scoped_pipeline_options( + PipelineOptions(update_compatibility_version=compat_version)): + values = [{ + MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i) + for i in range(10) + }] - coder = coders.MapCoder( - coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()) + coder = coders.MapCoder( + coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()) - if not dill and compat_version == "2.67.0": - with self.assertRaises(RuntimeError): - coder.as_deterministic_coder(step_label="step") - self.skipTest('Dill not installed') + if not dill and compat_version == "2.67.0": + with self.assertRaises(RuntimeError): + coder.as_deterministic_coder(step_label="step") + self.skipTest('Dill not installed') - deterministic_coder = coder.as_deterministic_coder(step_label="step") + deterministic_coder = coder.as_deterministic_coder(step_label="step") - assert isinstance( - deterministic_coder._key_coder, - coders.DeterministicFastPrimitivesCoderV2 if compat_version - in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder) + assert isinstance( + deterministic_coder._key_coder, + coders.DeterministicFastPrimitivesCoderV2 if compat_version + in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder) - self.check_coder(deterministic_coder, *values) + self.check_coder(deterministic_coder, *values) def test_dill_coder(self): if not dill: @@ -738,7 +740,6 @@ def test_cross_process_encoding_of_special_types_is_deterministic( if sys.executable is None: self.skipTest('No Python interpreter found') - typecoders.registry.update_compatibility_version = compat_version # pylint: disable=line-too-long script = textwrap.dedent( @@ -750,7 +751,8 @@ def test_cross_process_encoding_of_special_types_is_deterministic( import logging from apache_beam.coders import coders - from apache_beam.coders import typecoders + from apache_beam.options.pipeline_options_context import scoped_pipeline_options + from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.coders.coders_test_common import MyNamedTuple from apache_beam.coders.coders_test_common import MyTypedNamedTuple from apache_beam.coders.coders_test_common import MyEnum @@ -802,20 +804,20 @@ def test_cross_process_encoding_of_special_types_is_deterministic( ]) compat_version = {'"'+ compat_version +'"' if compat_version else None} - typecoders.registry.update_compatibility_version = compat_version - coder = coders.FastPrimitivesCoder() - deterministic_coder = coder.as_deterministic_coder("step") - - results = dict() - for test_name, value in test_cases: - try: - encoded = deterministic_coder.encode(value) - results[test_name] = encoded - except Exception as e: - logging.warning("Encoding failed with %s", e) - sys.exit(1) - - sys.stdout.buffer.write(pickle.dumps(results)) + with scoped_pipeline_options(PipelineOptions(update_compatibility_version=compat_version)): + coder = coders.FastPrimitivesCoder() + deterministic_coder = coder.as_deterministic_coder("step") + + results = dict() + for test_name, value in test_cases: + try: + encoded = deterministic_coder.encode(value) + results[test_name] = encoded + except Exception as e: + logging.warning("Encoding failed with %s", e) + sys.exit(1) + + sys.stdout.buffer.write(pickle.dumps(results)) ''') diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index ef75a21ce9ef..9683e00f0c2a 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -84,7 +84,6 @@ def __init__(self, fallback_coder=None): self._coders: Dict[Any, Type[coders.Coder]] = {} self.custom_types: List[Any] = [] self.register_standard_coders(fallback_coder) - self.update_compatibility_version = None def register_standard_coders(self, fallback_coder): """Register coders for all basic and composite types.""" diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index 76f465ddebbf..738ace67a5f7 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -1120,8 +1120,8 @@ def _load_data( of the load jobs would fail but not other. If any of them fails, then copy jobs are not triggered. """ - self.reshuffle_before_load = not util.is_compat_version_prior_to( - p.options, "2.65.0") + self.reshuffle_before_load = not p.options.is_compat_version_prior_to( + "2.65.0") if self.reshuffle_before_load: # Ensure that TriggerLoadJob retry inputs are deterministic by breaking # fusion for inputs. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index 30f09ff4f56a..191719e6a208 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -485,9 +485,9 @@ def test_records_traverse_transform_with_mocks(self): param(compat_version=None), param(compat_version="2.64.0"), ]) - def test_reshuffle_before_load(self, compat_version): - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True + @mock.patch( + 'apache_beam.coders.coders._should_force_use_dill', return_value=True) + def test_reshuffle_before_load(self, mock_force_dill, compat_version): destination = 'project1:dataset1.table1' job_reference = bigquery_api.JobReference() @@ -523,7 +523,6 @@ def test_reshuffle_before_load(self, compat_version): reshuffle_before_load = compat_version is None assert transform.reshuffle_before_load == reshuffle_before_load - typecoders.registry.force_dill_deterministic_coders = False def test_load_job_id_used(self): job_reference = bigquery_api.JobReference() @@ -998,10 +997,10 @@ def dynamic_destination_resolver(element, *side_inputs): param( is_streaming=True, with_auto_sharding=True, compat_version="2.64.0"), ]) + @mock.patch( + 'apache_beam.coders.coders._should_force_use_dill', return_value=True) def test_triggering_frequency( - self, is_streaming, with_auto_sharding, compat_version): - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True + self, mock_force_dill, is_streaming, with_auto_sharding, compat_version): destination = 'project1:dataset1.table1' @@ -1108,8 +1107,6 @@ def __call__(self): label='CheckDestinations') assert_that(jobs, equal_to(expected_jobs), label='CheckJobs') - typecoders.registry.force_dill_deterministic_coders = False - class BigQueryFileLoadsIT(unittest.TestCase): diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 0e1012b2de65..d60d75283eab 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -668,6 +668,25 @@ def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT: view._all_options = self._all_options return view + def is_compat_version_prior_to(self, breaking_change_version): + """Check if update_compatibility_version is prior to a breaking change. + + Returns True if the pipeline should use old behavior (i.e., the + update_compatibility_version is set and is earlier than the given version). + Returns False if update_compatibility_version is not set or is >= the + breaking change version. + + Args: + breaking_change_version: Version string (e.g., "2.72.0") at which + the breaking change was introduced. + """ + v1 = self.view_as(StreamingOptions).update_compatibility_version + if v1 is None: + return False + v1_parts = (v1.split('.') + ['0', '0', '0'])[:3] + v2_parts = (breaking_change_version.split('.') + ['0', '0', '0'])[:3] + return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts)) + def _visible_option_list(self) -> List[str]: return sorted( option for option in dir(self._visible_options) if option[0] != '_') diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index c683c9625272..215c44156ea6 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -987,6 +987,72 @@ def test_comma_separated_dataflow_service_options(self): options.get_all_options()['dataflow_service_options']) +class CompatVersionTest(unittest.TestCase): + def test_is_compat_version_prior_to(self): + test_cases = [ + # Basic comparison cases + ("1.0.0", "2.0.0", True), # v1 < v2 in major + ("2.0.0", "1.0.0", False), # v1 > v2 in major + ("1.1.0", "1.2.0", True), # v1 < v2 in minor + ("1.2.0", "1.1.0", False), # v1 > v2 in minor + ("1.0.1", "1.0.2", True), # v1 < v2 in patch + ("1.0.2", "1.0.1", False), # v1 > v2 in patch + + # Equal versions + ("1.0.0", "1.0.0", False), # Identical + ("0.0.0", "0.0.0", False), # Both zero + + # Different lengths - shorter vs longer + ("1.0", "1.0.0", False), # Should be equal (1.0 = 1.0.0) + ("1.0", "1.0.1", True), # 1.0.0 < 1.0.1 + ("1.2", "1.2.0", False), # Should be equal (1.2 = 1.2.0) + ("1.2", "1.2.3", True), # 1.2.0 < 1.2.3 + ("2", "2.0.0", False), # Should be equal (2 = 2.0.0) + ("2", "2.0.1", True), # 2.0.0 < 2.0.1 + ("1", "2.0", True), # 1.0.0 < 2.0.0 + + # Different lengths - longer vs shorter + ("1.0.0", "1.0", False), # Should be equal + ("1.0.1", "1.0", False), # 1.0.1 > 1.0.0 + ("1.2.0", "1.2", False), # Should be equal + ("1.2.3", "1.2", False), # 1.2.3 > 1.2.0 + ("2.0.0", "2", False), # Should be equal + ("2.0.1", "2", False), # 2.0.1 > 2.0.0 + ("2.0", "1", False), # 2.0.0 > 1.0.0 + + # Mixed length comparisons + ("1.0", "2.0.0", True), # 1.0.0 < 2.0.0 + ("2.0", "1.0.0", False), # 2.0.0 > 1.0.0 + ("1", "1.0.1", True), # 1.0.0 < 1.0.1 + ("1.1", "1.0.9", False), # 1.1.0 > 1.0.9 + + # Large numbers + ("1.9.9", "2.0.0", True), # 1.9.9 < 2.0.0 + ("10.0.0", "9.9.9", False), # 10.0.0 > 9.9.9 + ("1.10.0", "1.9.0", False), # 1.10.0 > 1.9.0 + ("1.2.10", "1.2.9", False), # 1.2.10 > 1.2.9 + + # Sequential versions + ("1.0.0", "1.0.1", True), + ("1.0.1", "1.0.2", True), + ("1.0.9", "1.1.0", True), + ("1.9.9", "2.0.0", True), + ] + + for v1, v2, expected in test_cases: + options = PipelineOptions(update_compatibility_version=v1) + self.assertEqual( + options.is_compat_version_prior_to(v2), + expected, + msg=f"Failed {v1} < {v2} == {expected}") + + # None case: no update_compatibility_version set + options_no_compat = PipelineOptions() + self.assertFalse( + options_no_compat.is_compat_version_prior_to("1.0.0"), + msg="Should return False when update_compatibility_version is not set") + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 3e7d083cb2fb..b28fe3c3d14e 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -1023,6 +1023,7 @@ def test_dir(self): self.assertEqual({ 'from_dictionary', 'get_all_options', + 'is_compat_version_prior_to', 'slices', 'style', 'view_as', @@ -1038,6 +1039,7 @@ def test_dir(self): self.assertEqual({ 'from_dictionary', 'get_all_options', + 'is_compat_version_prior_to', 'style', 'view_as', 'display_data', diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 21a863069e63..9469ac717dfc 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -48,7 +48,6 @@ from apache_beam.runners.portability import artifact_service from apache_beam.transforms import environments from apache_beam.transforms import ptransform -from apache_beam.transforms.util import is_compat_version_prior_to from apache_beam.typehints import WithTypeHints from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import row_type @@ -499,9 +498,9 @@ def expand(self, pcolls): expansion_service = self._expansion_service if self._managed_replacement: - compat_version_prior_to_current = is_compat_version_prior_to( - pcolls.pipeline._options, - self._managed_replacement.update_compatibility_version) + compat_version_prior_to_current = ( + pcolls.pipeline._options.is_compat_version_prior_to( + self._managed_replacement.update_compatibility_version)) if not compat_version_prior_to_current: payload_builder = self._managed_payload_builder expansion_service = self._managed_expansion_service diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index d5985b6212df..06ea822aa0be 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -1062,10 +1062,8 @@ def expand(self, pcoll): return self._fn(pcoll, *args, **kwargs) def set_options(self, options): - # Avoid circular import. - from apache_beam.transforms.util import is_compat_version_prior_to - self._use_backwards_compatible_label = is_compat_version_prior_to( - options, '2.68.0') + self._use_backwards_compatible_label = options.is_compat_version_prior_to( + '2.68.0') def default_label(self) -> str: # Attempt to give a reasonable name to this transform. diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index fbaab6b4ebbb..770a5baec366 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -47,7 +47,6 @@ from apache_beam import pvalue from apache_beam import typehints from apache_beam.metrics import Metrics -from apache_beam.options import pipeline_options from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput @@ -1350,27 +1349,6 @@ def get_window_coder(self): return self._window_coder -def is_v1_prior_to_v2(*, v1, v2): - if v1 is None: - return False - - v1_parts = (v1.split('.') + ['0', '0', '0'])[:3] - v2_parts = (v2.split('.') + ['0', '0', '0'])[:3] - return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts)) - - -def is_compat_version_prior_to(options, breaking_change_version): - # This function is used in a branch statement to determine whether we should - # keep the old behavior prior to a breaking change or use the new behavior. - # - If update_compatibility_version < breaking_change_version, we will return - # True and keep the old behavior. - update_compatibility_version = options.view_as( - pipeline_options.StreamingOptions).update_compatibility_version - - return is_v1_prior_to_v2( - v1=update_compatibility_version, v2=breaking_change_version) - - def reify_metadata_default_window( element, timestamp=DoFn.TimestampParam, pane_info=DoFn.PaneInfoParam): key, value = element @@ -1448,8 +1426,8 @@ def restore_timestamps(element): for (value, timestamp) in values ] - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): pre_gbk_map = Map(reify_timestamps).with_output_types(Any) else: pre_gbk_map = Map(reify_timestamps).with_input_types( @@ -1468,8 +1446,8 @@ def restore_timestamps(element): key, windowed_values = element return [wv.with_value((key, wv.value)) for wv in windowed_values] - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): pre_gbk_map = Map(reify_timestamps).with_output_types(Any) else: pre_gbk_map = Map(reify_timestamps).with_input_types( @@ -1493,7 +1471,7 @@ def restore_timestamps(element): return result def expand(self, pcoll): - if is_compat_version_prior_to(pcoll.pipeline.options, "2.65.0"): + if pcoll.pipeline.options.is_compat_version_prior_to("2.65.0"): return self.expand_2_64_0(pcoll) windowing_saved = pcoll.windowing @@ -1550,8 +1528,8 @@ def __init__(self, num_buckets=None): def expand(self, pcoll): # type: (pvalue.PValue) -> pvalue.PCollection - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): reshuffle_step = ReshufflePerKey() else: reshuffle_step = ReshufflePerKey().with_input_types( diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7389568691cd..98edb4cc2bd0 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1321,10 +1321,11 @@ def test_reshuffle_streaming_global_window_with_buckets(self): param(compat_version=None), param(compat_version="2.64.0"), ]) - def test_reshuffle_custom_window_preserves_metadata(self, compat_version): + @mock.patch( + 'apache_beam.coders.coders._should_force_use_dill', return_value=True) + def test_reshuffle_custom_window_preserves_metadata( + self, mock_force_dill, compat_version): """Tests that Reshuffle preserves pane info.""" - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True element_count = 12 timestamp_value = timestamp.Timestamp(0) l = [ @@ -1418,17 +1419,17 @@ def test_reshuffle_custom_window_preserves_metadata(self, compat_version): equal_to(expected), label='CheckMetadataPreserved', reify_windows=True) - typecoders.registry.force_dill_deterministic_coders = False @parameterized.expand([ param(compat_version=None), param(compat_version="2.64.0"), ]) - def test_reshuffle_default_window_preserves_metadata(self, compat_version): + @mock.patch( + 'apache_beam.coders.coders._should_force_use_dill', return_value=True) + def test_reshuffle_default_window_preserves_metadata( + self, mock_force_dill, compat_version): """Tests that Reshuffle preserves timestamp, window, and pane info metadata.""" - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True no_firing = PaneInfo( is_first=True, is_last=True, @@ -1502,7 +1503,6 @@ def test_reshuffle_default_window_preserves_metadata(self, compat_version): equal_to(expected), label='CheckMetadataPreserved', reify_windows=True) - typecoders.registry.force_dill_deterministic_coders = False @pytest.mark.it_validatesrunner def test_reshuffle_preserves_timestamps(self): @@ -2521,68 +2521,6 @@ def record(tag): label='result') -class CompatCheckTest(unittest.TestCase): - def test_is_v1_prior_to_v2(self): - test_cases = [ - # Basic comparison cases - ("1.0.0", "2.0.0", True), # v1 < v2 in major - ("2.0.0", "1.0.0", False), # v1 > v2 in major - ("1.1.0", "1.2.0", True), # v1 < v2 in minor - ("1.2.0", "1.1.0", False), # v1 > v2 in minor - ("1.0.1", "1.0.2", True), # v1 < v2 in patch - ("1.0.2", "1.0.1", False), # v1 > v2 in patch - - # Equal versions - ("1.0.0", "1.0.0", False), # Identical - ("0.0.0", "0.0.0", False), # Both zero - - # Different lengths - shorter vs longer - ("1.0", "1.0.0", False), # Should be equal (1.0 = 1.0.0) - ("1.0", "1.0.1", True), # 1.0.0 < 1.0.1 - ("1.2", "1.2.0", False), # Should be equal (1.2 = 1.2.0) - ("1.2", "1.2.3", True), # 1.2.0 < 1.2.3 - ("2", "2.0.0", False), # Should be equal (2 = 2.0.0) - ("2", "2.0.1", True), # 2.0.0 < 2.0.1 - ("1", "2.0", True), # 1.0.0 < 2.0.0 - - # Different lengths - longer vs shorter - ("1.0.0", "1.0", False), # Should be equal - ("1.0.1", "1.0", False), # 1.0.1 > 1.0.0 - ("1.2.0", "1.2", False), # Should be equal - ("1.2.3", "1.2", False), # 1.2.3 > 1.2.0 - ("2.0.0", "2", False), # Should be equal - ("2.0.1", "2", False), # 2.0.1 > 2.0.0 - ("2.0", "1", False), # 2.0.0 > 1.0.0 - - # Mixed length comparisons - ("1.0", "2.0.0", True), # 1.0.0 < 2.0.0 - ("2.0", "1.0.0", False), # 2.0.0 > 1.0.0 - ("1", "1.0.1", True), # 1.0.0 < 1.0.1 - ("1.1", "1.0.9", False), # 1.1.0 > 1.0.9 - - # Large numbers - ("1.9.9", "2.0.0", True), # 1.9.9 < 2.0.0 - ("10.0.0", "9.9.9", False), # 10.0.0 > 9.9.9 - ("1.10.0", "1.9.0", False), # 1.10.0 > 1.9.0 - ("1.2.10", "1.2.9", False), # 1.2.10 > 1.2.9 - - # Sequential versions - ("1.0.0", "1.0.1", True), - ("1.0.1", "1.0.2", True), - ("1.0.9", "1.1.0", True), - ("1.9.9", "2.0.0", True), - - # Null/None cases - (None, "1.0.0", False), # v1 is None - ] - - for v1, v2, expected in test_cases: - self.assertEqual( - util.is_v1_prior_to_v2(v1=v1, v2=v2), - expected, - msg=f"Failed {v1} < {v2} == {expected}") - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()