diff --git a/.github/trigger_files/beam_PreCommit_Python_Dill.json b/.github/trigger_files/beam_PreCommit_Python_Dill.json index 8c604b0a135c..840d064bdbb0 100644 --- a/.github/trigger_files/beam_PreCommit_Python_Dill.json +++ b/.github/trigger_files/beam_PreCommit_Python_Dill.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "revision": 2 + "revision": 3 } diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index ddcb80630487..2421447a5bc5 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -178,10 +178,17 @@ def is_deterministic(self): """ return False - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): """Returns a deterministic version of self, if possible. Otherwise raises a value error. + + Args: + step_label: A label for the step requiring determinism. + error_message: Optional custom error message if coder cannot be made + deterministic. + options: Optional PipelineOptions for version compatibility checks. """ if self.is_deterministic(): return self @@ -538,10 +545,13 @@ def is_deterministic(self): # Map ordering is non-deterministic return False - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): return DeterministicMapCoder( - self._key_coder.as_deterministic_coder(step_label, error_message), - self._value_coder.as_deterministic_coder(step_label, error_message)) + self._key_coder.as_deterministic_coder( + step_label, error_message, options), + self._value_coder.as_deterministic_coder( + step_label, error_message, options)) def __eq__(self, other): return ( @@ -616,12 +626,13 @@ def is_deterministic(self): # type: () -> bool return self._value_coder.is_deterministic() - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: deterministic_value_coder = self._value_coder.as_deterministic_coder( - step_label, error_message) + step_label, error_message, options) return NullableCoder(deterministic_value_coder) def __eq__(self, other): @@ -883,8 +894,10 @@ def _nonhashable_dumps(x): return coder_impl.CallbackCoderImpl(_nonhashable_dumps, pickler.loads) - def as_deterministic_coder(self, step_label, error_message=None): - return FastPrimitivesCoder(self, requires_deterministic=step_label) + def as_deterministic_coder( + self, step_label, error_message=None, options=None): + return _update_compatible_deterministic_fast_primitives_coder( + FastPrimitivesCoder(self), step_label, options) def to_type_hint(self): return Any @@ -898,8 +911,10 @@ def _create_impl(self): return coder_impl.CallbackCoderImpl( lambda x: dumps(x, protocol), pickle.loads) - def as_deterministic_coder(self, step_label, error_message=None): - return FastPrimitivesCoder(self, requires_deterministic=step_label) + def as_deterministic_coder( + self, step_label, error_message=None, options=None): + return _update_compatible_deterministic_fast_primitives_coder( + FastPrimitivesCoder(self), step_label, options) def to_type_hint(self): return Any @@ -927,16 +942,14 @@ def _create_impl(self): class DeterministicFastPrimitivesCoderV2(FastCoder): """Throws runtime errors when encoding non-deterministic values.""" - def __init__(self, coder, step_label, update_compatibility_version=None): + def __init__(self, coder, step_label, options=None): self._underlying_coder = coder self._step_label = step_label self._use_relative_filepaths = True self._version_tag = "v2_69" - from apache_beam.transforms.util import is_v1_prior_to_v2 # Versions prior to 2.69.0 did not use relative filepaths. - if update_compatibility_version and is_v1_prior_to_v2( - v1=update_compatibility_version, v2="2.69.0"): + if options and options.is_compat_version_prior_to("2.69.0"): self._version_tag = "" self._use_relative_filepaths = False @@ -1005,20 +1018,15 @@ def to_type_hint(self): return Any -def _should_force_use_dill(registry): +def _should_force_use_dill(options=None): # force_dill_deterministic_coders is for testing purposes. If there is a # DeterministicFastPrimitivesCoder in the pipeline graph but the dill - # encoding path is not really triggered dill does not have to be installed. + # encoding path is not really triggered dill does not have to be installed # and this check can be skipped. - if getattr(registry, 'force_dill_deterministic_coders', False): + if getattr(options, 'force_dill_deterministic_coders', False): return True - from apache_beam.transforms.util import is_v1_prior_to_v2 - update_compat_version = registry.update_compatibility_version - if not update_compat_version: - return False - - if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"): + if options is None or not options.is_compat_version_prior_to("2.68.0"): return False try: @@ -1032,7 +1040,8 @@ def _should_force_use_dill(registry): return True -def _update_compatible_deterministic_fast_primitives_coder(coder, step_label): +def _update_compatible_deterministic_fast_primitives_coder( + coder, step_label, options=None): """ Returns the update compatible version of DeterministicFastPrimitivesCoder The differences are in how "special types" e.g. NamedTuples, Dataclasses are deterministically encoded. @@ -1043,12 +1052,9 @@ def _update_compatible_deterministic_fast_primitives_coder(coder, step_label): - In SDK version 2.69.0 cloudpickle is used to encode "special types" with relative filepaths in code objects and dynamic functions. """ - from apache_beam.coders import typecoders - - if _should_force_use_dill(typecoders.registry): + if _should_force_use_dill(options): return DeterministicFastPrimitivesCoder(coder, step_label) - return DeterministicFastPrimitivesCoderV2( - coder, step_label, typecoders.registry.update_compatibility_version) + return DeterministicFastPrimitivesCoderV2(coder, step_label, options) class FastPrimitivesCoder(FastCoder): @@ -1067,12 +1073,13 @@ def is_deterministic(self): # type: () -> bool return self._fallback_coder.is_deterministic() - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: return _update_compatible_deterministic_fast_primitives_coder( - self, step_label) + self, step_label, options) def to_type_hint(self): return Any @@ -1167,7 +1174,8 @@ def is_deterministic(self): # a Map. return False - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): return DeterministicProtoCoder(self.proto_message_type) def __eq__(self, other): @@ -1213,7 +1221,8 @@ def is_deterministic(self): # type: () -> bool return True - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): return self @@ -1300,12 +1309,13 @@ def is_deterministic(self): # type: () -> bool return all(c.is_deterministic() for c in self._coders) - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: return TupleCoder([ - c.as_deterministic_coder(step_label, error_message) + c.as_deterministic_coder(step_label, error_message, options) for c in self._coders ]) @@ -1379,12 +1389,14 @@ def is_deterministic(self): # type: () -> bool return self._elem_coder.is_deterministic() - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: return TupleSequenceCoder( - self._elem_coder.as_deterministic_coder(step_label, error_message)) + self._elem_coder.as_deterministic_coder( + step_label, error_message, options)) @classmethod def from_type_hint(cls, typehint, registry): @@ -1419,12 +1431,14 @@ def is_deterministic(self): # type: () -> bool return self._elem_coder.is_deterministic() - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: return type(self)( - self._elem_coder.as_deterministic_coder(step_label, error_message)) + self._elem_coder.as_deterministic_coder( + step_label, error_message, options)) def value_coder(self): return self._elem_coder diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index fcc5e6ac58bf..b6f3001027fd 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -41,7 +41,10 @@ from apache_beam.coders import coders from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message from apache_beam.coders import typecoders +from apache_beam.coders.row_coder import RowCoder +from apache_beam.typehints.schemas import typing_to_runner_api from apache_beam.internal import pickler +from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners import pipeline_context from apache_beam.transforms import userstate from apache_beam.transforms import window @@ -202,9 +205,6 @@ def tearDownClass(cls): assert not standard - cls.seen, str(standard - cls.seen) assert not cls.seen_nested - standard, str(cls.seen_nested - standard) - def tearDown(self): - typecoders.registry.update_compatibility_version = None - @classmethod def _observe(cls, coder): cls.seen.add(type(coder)) @@ -275,13 +275,14 @@ def test_deterministic_coder(self, compat_version): with relative filepaths in code objects and dynamic functions. """ - typecoders.registry.update_compatibility_version = compat_version + options = PipelineOptions(update_compatibility_version=compat_version) coder = coders.FastPrimitivesCoder() if not dill and compat_version == "2.67.0": with self.assertRaises(RuntimeError): - coder.as_deterministic_coder(step_label="step") + coder.as_deterministic_coder(step_label="step", options=options) self.skipTest('Dill not installed') - deterministic_coder = coder.as_deterministic_coder(step_label="step") + deterministic_coder = coder.as_deterministic_coder( + step_label="step", options=options) self.check_coder(deterministic_coder, *self.test_values_deterministic) for v in self.test_values_deterministic: @@ -364,7 +365,7 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version): - In SDK version >=2.69.0 cloudpickle is used to encode "special types" with relative file. """ - typecoders.registry.update_compatibility_version = compat_version + options = PipelineOptions(update_compatibility_version=compat_version) values = [{ MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i) for i in range(10) @@ -375,10 +376,11 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version): if not dill and compat_version == "2.67.0": with self.assertRaises(RuntimeError): - coder.as_deterministic_coder(step_label="step") + coder.as_deterministic_coder(step_label="step", options=options) self.skipTest('Dill not installed') - deterministic_coder = coder.as_deterministic_coder(step_label="step") + deterministic_coder = coder.as_deterministic_coder( + step_label="step", options=options) assert isinstance( deterministic_coder._key_coder, @@ -387,6 +389,53 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version): self.check_coder(deterministic_coder, *values) + @parameterized.expand([ + param(compat_version=None), + param(compat_version="2.67.0"), + param(compat_version="2.68.0"), + ]) + def test_deterministic_row_coder_is_update_compatible(self, compat_version): + """ Test that RowCoder.as_deterministic_coder propagates options to + component coders for proper version compatibility. + + - In SDK version <= 2.67.0 dill is used to encode "special types" + - In SDK version 2.68.0 cloudpickle is used to encode "special types" with + absolute filepaths in code objects and dynamic functions. + - In SDK version >=2.69.0 cloudpickle is used to encode "special types" + with relative filepaths in code objects and dynamic functions. + """ + # Create a NamedTuple with an Any field which uses FastPrimitivesCoder + RowWithAny = NamedTuple('RowWithAny', [('name', str), ('data', Any)]) + schema = typing_to_runner_api(RowWithAny).row_type.schema + + options = PipelineOptions(update_compatibility_version=compat_version) + coder = RowCoder(schema) + + if not dill and compat_version == "2.67.0": + with self.assertRaises(RuntimeError): + coder.as_deterministic_coder(step_label="step", options=options) + self.skipTest('Dill not installed') + + deterministic_coder = coder.as_deterministic_coder( + step_label="step", options=options) + + # The 'data' field (index 1) should have the appropriate deterministic coder + # based on the compat_version + data_coder = deterministic_coder.components[1] + expected_coder_type = ( + coders.DeterministicFastPrimitivesCoderV2 if compat_version + in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder) + self.assertIsInstance(data_coder, expected_coder_type) + + # Verify encoding/decoding works + test_values = [ + RowWithAny(name='test', data={'key': 'value'}), + RowWithAny(name='test2', data=[1, 2, 3]), + ] + for value in test_values: + self.assertEqual( + value, deterministic_coder.decode(deterministic_coder.encode(value))) + def test_dill_coder(self): if not dill: with self.assertRaises(RuntimeError): @@ -738,7 +787,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( if sys.executable is None: self.skipTest('No Python interpreter found') - typecoders.registry.update_compatibility_version = compat_version + options = PipelineOptions(update_compatibility_version=compat_version) # pylint: disable=line-too-long script = textwrap.dedent( @@ -750,7 +799,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( import logging from apache_beam.coders import coders - from apache_beam.coders import typecoders + from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.coders.coders_test_common import MyNamedTuple from apache_beam.coders.coders_test_common import MyTypedNamedTuple from apache_beam.coders.coders_test_common import MyEnum @@ -802,9 +851,9 @@ def test_cross_process_encoding_of_special_types_is_deterministic( ]) compat_version = {'"'+ compat_version +'"' if compat_version else None} - typecoders.registry.update_compatibility_version = compat_version + options = PipelineOptions(update_compatibility_version=compat_version) coder = coders.FastPrimitivesCoder() - deterministic_coder = coder.as_deterministic_coder("step") + deterministic_coder = coder.as_deterministic_coder("step", options=options) results = dict() for test_name, value in test_cases: @@ -834,7 +883,7 @@ def run_subprocess(): results2 = run_subprocess() coder = coders.FastPrimitivesCoder() - deterministic_coder = coder.as_deterministic_coder("step") + deterministic_coder = coder.as_deterministic_coder("step", options=options) for test_name in results1: diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 1becf408cfbf..3a3b1a56cfac 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -52,13 +52,16 @@ class RowCoder(FastCoder): Implements the beam:coder:row:v1 standard coder spec. """ - def __init__(self, schema, force_deterministic=False): + def __init__(self, schema, force_deterministic=False, options=None): """Initializes a :class:`RowCoder`. Args: schema (apache_beam.portability.api.schema_pb2.Schema): The protobuf representation of the schema of the data that the RowCoder will be used to encode/decode. + force_deterministic: If truthy, used as the step_label for making + component coders deterministic. + options: Optional PipelineOptions for version compatibility. """ self.schema = schema @@ -71,7 +74,8 @@ def __init__(self, schema, force_deterministic=False): ] if force_deterministic: self.components = [ - c.as_deterministic_coder(force_deterministic) for c in self.components + c.as_deterministic_coder(force_deterministic, options=options) + for c in self.components ] self.forced_deterministic = bool(force_deterministic) @@ -81,11 +85,12 @@ def _create_impl(self): def is_deterministic(self): return all(c.is_deterministic() for c in self.components) - def as_deterministic_coder(self, step_label, error_message=None): + def as_deterministic_coder( + self, step_label, error_message=None, options=None): if self.is_deterministic(): return self else: - return RowCoder(self.schema, error_message or step_label) + return RowCoder(self.schema, error_message or step_label, options) def to_type_hint(self): return self._type_hint diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index ef75a21ce9ef..0a46f6f316b1 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -84,7 +84,6 @@ def __init__(self, fallback_coder=None): self._coders: Dict[Any, Type[coders.Coder]] = {} self.custom_types: List[Any] = [] self.register_standard_coders(fallback_coder) - self.update_compatibility_version = None def register_standard_coders(self, fallback_coder): """Register coders for all basic and composite types.""" @@ -185,7 +184,7 @@ def load_custom_type_coder_tuples(self, type_coder): for t, c in type_coder: self.register_coder(t, c) - def verify_deterministic(self, key_coder, op_name, silent=True): + def verify_deterministic(self, key_coder, op_name, silent=True, options=None): if not key_coder.is_deterministic(): error_msg = ( 'The key coder "%s" for %s ' @@ -195,7 +194,7 @@ def verify_deterministic(self, key_coder, op_name, silent=True): 'and for custom key classes, by writing a ' 'deterministic custom Coder. Please see the ' 'documentation for more details.' % (key_coder, op_name)) - return key_coder.as_deterministic_coder(op_name, error_msg) + return key_coder.as_deterministic_coder(op_name, error_msg, options) else: return key_coder diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index 76f465ddebbf..738ace67a5f7 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -1120,8 +1120,8 @@ def _load_data( of the load jobs would fail but not other. If any of them fails, then copy jobs are not triggered. """ - self.reshuffle_before_load = not util.is_compat_version_prior_to( - p.options, "2.65.0") + self.reshuffle_before_load = not p.options.is_compat_version_prior_to( + "2.65.0") if self.reshuffle_before_load: # Ensure that TriggerLoadJob retry inputs are deterministic by breaking # fusion for inputs. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index 30f09ff4f56a..f851602f47f2 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -486,8 +486,6 @@ def test_records_traverse_transform_with_mocks(self): param(compat_version="2.64.0"), ]) def test_reshuffle_before_load(self, compat_version): - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True destination = 'project1:dataset1.table1' job_reference = bigquery_api.JobReference() @@ -513,17 +511,14 @@ def test_reshuffle_before_load(self, compat_version): validate=False, temp_file_format=bigquery_tools.FileFormat.JSON) - options = PipelineOptions( - update_compatibility_version=compat_version, - # Disable unrelated compatibility change. - force_cloudpickle_deterministic_coders=True) + options = PipelineOptions(update_compatibility_version=compat_version) + object.__setattr__(options, 'force_dill_deterministic_coders', True) # Need to test this with the DirectRunner to avoid serializing mocks with TestPipeline('DirectRunner', options=options) as p: _ = p | beam.Create(_ELEMENTS) | transform reshuffle_before_load = compat_version is None assert transform.reshuffle_before_load == reshuffle_before_load - typecoders.registry.force_dill_deterministic_coders = False def test_load_job_id_used(self): job_reference = bigquery_api.JobReference() @@ -1000,9 +995,6 @@ def dynamic_destination_resolver(element, *side_inputs): ]) def test_triggering_frequency( self, is_streaming, with_auto_sharding, compat_version): - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True - destination = 'project1:dataset1.table1' job_reference = bigquery_api.JobReference() @@ -1048,6 +1040,7 @@ def __call__(self): flags=['--allow_unsafe_triggers'], update_compatibility_version=compat_version) test_options.view_as(StandardOptions).streaming = is_streaming + object.__setattr__(test_options, 'force_dill_deterministic_coders', True) with TestPipeline(runner='BundleBasedDirectRunner', options=test_options) as p: if is_streaming: @@ -1108,8 +1101,6 @@ def __call__(self): label='CheckDestinations') assert_that(jobs, equal_to(expected_jobs), label='CheckJobs') - typecoders.registry.force_dill_deterministic_coders = False - class BigQueryFileLoadsIT(unittest.TestCase): diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 0e1012b2de65..4a3a9f240bf8 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -668,6 +668,25 @@ def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT: view._all_options = self._all_options return view + def is_compat_version_prior_to(self, breaking_change_version): + """Check if update_compatibility_version is prior to a breaking change. + + Returns True if the pipeline should use old behavior (i.e., the + update_compatibility_version is set and is earlier than the given version). + Returns False if update_compatibility_version is not set or is >= the + breaking change version. + + Args: + breaking_change_version: Version string (e.g., "2.72.0") at which + the breaking change was introduced. + """ + v1 = self.view_as(StreamingOptions).update_compatibility_version + if v1 is None: + return False + v1_parts = (v1.split('.') + ['0', '0', '0'])[:3] + v2_parts = (breaking_change_version.split('.') + ['0', '0', '0'])[:3] + return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts)) + def _visible_option_list(self) -> List[str]: return sorted( option for option in dir(self._visible_options) if option[0] != '_') @@ -898,18 +917,6 @@ def _add_argparse_args(cls, parser): 'their condition met. Some operations, such as GroupByKey, disallow ' 'this. This exists for cases where such loss is acceptable and for ' 'backwards compatibility. See BEAM-9487.') - parser.add_argument( - '--force_cloudpickle_deterministic_coders', - default=False, - action='store_true', - help=( - 'Force the use of cloudpickle-based deterministic coders ' - 'instead of dill-based coders, even when ' - 'update_compatibility_version would normally trigger dill usage ' - 'for backward compatibility. This flag overrides automatic coder ' - 'selection to always use the modern cloudpickle serialization ' - ' path. Warning: May break pipeline update compatibility with ' - ' SDK versions prior to 2.68.0.')) def validate(self, unused_validator): errors = [] diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index c683c9625272..215c44156ea6 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -987,6 +987,72 @@ def test_comma_separated_dataflow_service_options(self): options.get_all_options()['dataflow_service_options']) +class CompatVersionTest(unittest.TestCase): + def test_is_compat_version_prior_to(self): + test_cases = [ + # Basic comparison cases + ("1.0.0", "2.0.0", True), # v1 < v2 in major + ("2.0.0", "1.0.0", False), # v1 > v2 in major + ("1.1.0", "1.2.0", True), # v1 < v2 in minor + ("1.2.0", "1.1.0", False), # v1 > v2 in minor + ("1.0.1", "1.0.2", True), # v1 < v2 in patch + ("1.0.2", "1.0.1", False), # v1 > v2 in patch + + # Equal versions + ("1.0.0", "1.0.0", False), # Identical + ("0.0.0", "0.0.0", False), # Both zero + + # Different lengths - shorter vs longer + ("1.0", "1.0.0", False), # Should be equal (1.0 = 1.0.0) + ("1.0", "1.0.1", True), # 1.0.0 < 1.0.1 + ("1.2", "1.2.0", False), # Should be equal (1.2 = 1.2.0) + ("1.2", "1.2.3", True), # 1.2.0 < 1.2.3 + ("2", "2.0.0", False), # Should be equal (2 = 2.0.0) + ("2", "2.0.1", True), # 2.0.0 < 2.0.1 + ("1", "2.0", True), # 1.0.0 < 2.0.0 + + # Different lengths - longer vs shorter + ("1.0.0", "1.0", False), # Should be equal + ("1.0.1", "1.0", False), # 1.0.1 > 1.0.0 + ("1.2.0", "1.2", False), # Should be equal + ("1.2.3", "1.2", False), # 1.2.3 > 1.2.0 + ("2.0.0", "2", False), # Should be equal + ("2.0.1", "2", False), # 2.0.1 > 2.0.0 + ("2.0", "1", False), # 2.0.0 > 1.0.0 + + # Mixed length comparisons + ("1.0", "2.0.0", True), # 1.0.0 < 2.0.0 + ("2.0", "1.0.0", False), # 2.0.0 > 1.0.0 + ("1", "1.0.1", True), # 1.0.0 < 1.0.1 + ("1.1", "1.0.9", False), # 1.1.0 > 1.0.9 + + # Large numbers + ("1.9.9", "2.0.0", True), # 1.9.9 < 2.0.0 + ("10.0.0", "9.9.9", False), # 10.0.0 > 9.9.9 + ("1.10.0", "1.9.0", False), # 1.10.0 > 1.9.0 + ("1.2.10", "1.2.9", False), # 1.2.10 > 1.2.9 + + # Sequential versions + ("1.0.0", "1.0.1", True), + ("1.0.1", "1.0.2", True), + ("1.0.9", "1.1.0", True), + ("1.9.9", "2.0.0", True), + ] + + for v1, v2, expected in test_cases: + options = PipelineOptions(update_compatibility_version=v1) + self.assertEqual( + options.is_compat_version_prior_to(v2), + expected, + msg=f"Failed {v1} < {v2} == {expected}") + + # None case: no update_compatibility_version set + options_no_compat = PipelineOptions() + self.assertFalse( + options_no_compat.is_compat_version_prior_to("1.0.0"), + msg="Should return False when update_compatibility_version is not set") + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 6ef06abb7436..1ca071bdfbc4 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -79,7 +79,6 @@ from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import StandardOptions -from apache_beam.options.pipeline_options import StreamingOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator from apache_beam.portability import common_urns @@ -226,9 +225,6 @@ def __init__( raise ValueError( 'Pipeline has validations errors: \n' + '\n'.join(errors)) - typecoders.registry.update_compatibility_version = self._options.view_as( - StreamingOptions).update_compatibility_version - # set default experiments for portable runners # (needs to occur prior to pipeline construction) if runner.is_fnapi_compatible(): @@ -1051,7 +1047,8 @@ def to_runner_api( context = pipeline_context.PipelineContext( use_fake_coders=use_fake_coders, component_id_map=self.component_id_map, - default_environment=default_environment) + default_environment=default_environment, + options=self.options) elif default_environment is not None: raise ValueError( 'Only one of context or default_environment may be specified.') diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 3e7d083cb2fb..8e1ef1fe9def 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -23,14 +23,20 @@ import platform import unittest import uuid +from typing import Any +from typing import NamedTuple import mock import pytest +from parameterized import param +from parameterized import parameterized import apache_beam as beam from apache_beam import coders from apache_beam import typehints from apache_beam.coders import BytesCoder +from apache_beam.coders.coders import DeterministicFastPrimitivesCoder +from apache_beam.coders.coders import DeterministicFastPrimitivesCoderV2 from apache_beam.io import Read from apache_beam.io.iobase import SourceBase from apache_beam.options.pipeline_options import PortableOptions @@ -67,6 +73,11 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP +class KeyWithAnyField(NamedTuple): + id: int + data: Any + + class FakeUnboundedSource(SourceBase): """Fake unbounded source. Does not work at runtime""" def is_bounded(self): @@ -1023,6 +1034,7 @@ def test_dir(self): self.assertEqual({ 'from_dictionary', 'get_all_options', + 'is_compat_version_prior_to', 'slices', 'style', 'view_as', @@ -1038,6 +1050,7 @@ def test_dir(self): self.assertEqual({ 'from_dictionary', 'get_all_options', + 'is_compat_version_prior_to', 'style', 'view_as', 'display_data', @@ -1650,6 +1663,53 @@ def expand(self, pcoll): # ParDo.with_outputs in ParentSalesSplitter. assert len(xform.outputs) == 2 + @parameterized.expand([ + param(compat_version=None), + param(compat_version="2.67.0"), + param(compat_version="2.68.0"), + ]) + def test_pipeline_options_propagate_to_deterministic_coders( + self, compat_version): + """End-to-end test verifying pipeline options propagate to coder selection. + + When a pipeline with update_compatibility_version is serialized via + to_runner_api(), the PipelineContext should use the options to select + the appropriate deterministic coder version: + - SDK version <= 2.67.0: DeterministicFastPrimitivesCoder (dill-based) + - SDK version >= 2.68.0: DeterministicFastPrimitivesCoderV2 (cloudpickle) + """ + options = PipelineOptions(update_compatibility_version=compat_version) + p = beam.Pipeline(options=options) + + # Create a pipeline with GroupByKey that requires deterministic key coders + # KeyWithAnyField is a NamedTuple with an Any field, requiring deterministic + # encoding + _ = ( + p + | beam.Create([(KeyWithAnyField(1, {'nested': 'data'}), 'value1')]) + | beam.GroupByKey()) + + # Serialize to runner API and get the context back + _, context = p.to_runner_api(return_context=True) + + # Get the coder for our key type and find its deterministic version + key_coder = coders.registry.get_coder(KeyWithAnyField) + self.assertIn( + key_coder, + context.deterministic_coder_map, + "Expected coder for KeyWithAnyField to be in deterministic_coder_map") + + deterministic_key_coder = context.deterministic_coder_map[key_coder] + expected_coder_type = ( + DeterministicFastPrimitivesCoderV2 if compat_version + in (None, "2.68.0") else DeterministicFastPrimitivesCoder) + + self.assertIsInstance( + deterministic_key_coder, + expected_coder_type, + f"Expected {expected_coder_type.__name__} for compat_version=" + f"{compat_version}, but got {type(deterministic_key_coder).__name__}") + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index d33c33f84fee..347362232855 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -544,7 +544,9 @@ def _verify_gbk_coders(self, transform, pcoll): 'key-value coder: %s.') % (transform.label, coder)) # TODO(robertwb): Update the coder itself if it changed. coders.registry.verify_deterministic( - coder.key_coder(), 'GroupByKey operation "%s"' % transform.label) + coder.key_coder(), + 'GroupByKey operation "%s"' % transform.label, + options=pcoll.pipeline.options) def get_default_gcp_region(self): """Get a default value for Google Cloud region according to diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 73b0321b5de4..bc5b77378300 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -357,7 +357,8 @@ def expand(self, pcoll): pcoll.element_type) typecoders.registry.verify_deterministic( typecoders.registry.get_coder(key_type), - 'GroupByKey operation "%s"' % self.label) + 'GroupByKey operation "%s"' % self.label, + options=pcoll.pipeline.options) reify_output_type = typehints.KV[ key_type, typehints.WindowedValue[value_type]] # type: ignore[misc] diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index f367598f9293..3aef38b92ebb 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -43,6 +43,7 @@ from apache_beam.coders.coder_impl import IterableStateReader from apache_beam.coders.coder_impl import IterableStateWriter from apache_beam.internal import pickler +from apache_beam.options import pipeline_options from apache_beam.pipeline import ComponentIdMap from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 @@ -177,6 +178,7 @@ def __init__( iterable_state_write: Optional[IterableStateWriter] = None, namespace: str = 'ref', requirements: Iterable[str] = (), + options: Optional[pipeline_options.PipelineOptions] = None, ) -> None: if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor): proto = beam_runner_api_pb2.Components( @@ -226,6 +228,7 @@ def __init__( self.iterable_state_read = iterable_state_read self.iterable_state_write = iterable_state_write self._requirements = set(requirements) + self.options = options self.enable_best_effort_deterministic_pickling = False self.enable_stable_code_identifier_pickling = False @@ -258,7 +261,8 @@ def coder_id_from_element_type( def deterministic_coder(self, coder: coders.Coder, msg: str) -> coders.Coder: if coder not in self.deterministic_coder_map: - self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg) # type: ignore + self.deterministic_coder_map[coder] = coder.as_deterministic_coder( # type: ignore + msg, options=self.options) return self.deterministic_coder_map[coder] def element_type_from_coder_id(self, coder_id: str) -> Any: diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 21a863069e63..9469ac717dfc 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -48,7 +48,6 @@ from apache_beam.runners.portability import artifact_service from apache_beam.transforms import environments from apache_beam.transforms import ptransform -from apache_beam.transforms.util import is_compat_version_prior_to from apache_beam.typehints import WithTypeHints from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import row_type @@ -499,9 +498,9 @@ def expand(self, pcolls): expansion_service = self._expansion_service if self._managed_replacement: - compat_version_prior_to_current = is_compat_version_prior_to( - pcolls.pipeline._options, - self._managed_replacement.update_compatibility_version) + compat_version_prior_to_current = ( + pcolls.pipeline._options.is_compat_version_prior_to( + self._managed_replacement.update_compatibility_version)) if not compat_version_prior_to_current: payload_builder = self._managed_payload_builder expansion_service = self._managed_expansion_service diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index d5985b6212df..06ea822aa0be 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -1062,10 +1062,8 @@ def expand(self, pcoll): return self._fn(pcoll, *args, **kwargs) def set_options(self, options): - # Avoid circular import. - from apache_beam.transforms.util import is_compat_version_prior_to - self._use_backwards_compatible_label = is_compat_version_prior_to( - options, '2.68.0') + self._use_backwards_compatible_label = options.is_compat_version_prior_to( + '2.68.0') def default_label(self) -> str: # Attempt to give a reasonable name to this transform. diff --git a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py index fb38a883dd39..a028d4df268c 100644 --- a/sdks/python/apache_beam/transforms/stats.py +++ b/sdks/python/apache_beam/transforms/stats.py @@ -152,8 +152,8 @@ def expand(self, pcoll): coder = coders.registry.get_coder(pcoll) return pcoll \ | 'CountGlobalUniqueValues' \ - >> (CombineGlobally(ApproximateUniqueCombineFn(self._sample_size, - coder))) + >> (CombineGlobally(ApproximateUniqueCombineFn( + self._sample_size, coder, pcoll.pipeline.options))) @typehints.with_input_types(typing.Tuple[K, V]) @typehints.with_output_types(typing.Tuple[K, int]) @@ -166,8 +166,8 @@ def expand(self, pcoll): coder = coders.registry.get_coder(pcoll) return pcoll \ | 'CountPerKeyUniqueValues' \ - >> (CombinePerKey(ApproximateUniqueCombineFn(self._sample_size, - coder))) + >> (CombinePerKey(ApproximateUniqueCombineFn( + self._sample_size, coder, pcoll.pipeline.options))) class _LargestUnique(object): @@ -242,10 +242,10 @@ class ApproximateUniqueCombineFn(CombineFn): ApproximateUniqueCombineFn computes an estimate of the number of unique values that were combined. """ - def __init__(self, sample_size, coder): + def __init__(self, sample_size, coder, options=None): self._sample_size = sample_size coder = coders.typecoders.registry.verify_deterministic( - coder, 'ApproximateUniqueCombineFn') + coder, 'ApproximateUniqueCombineFn', options=options) self._coder = coder self._hash_fn = _get_default_hash_fn() diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index dd14bd8f57bd..f3ffb39750f7 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -47,7 +47,6 @@ from apache_beam import pvalue from apache_beam import typehints from apache_beam.metrics import Metrics -from apache_beam.options import pipeline_options from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput @@ -705,7 +704,8 @@ def expand(self, pcoll): if kv_type_hint and kv_type_hint != typehints.Any: coder = coders.registry.get_coder(kv_type_hint) try: - coder = coder.as_deterministic_coder(self.label) + coder = coder.as_deterministic_coder( + self.label, options=pcoll.pipeline.options) except ValueError: logging.warning( 'GroupByEncryptedKey %s: ' @@ -1353,27 +1353,6 @@ def get_window_coder(self): return self._window_coder -def is_v1_prior_to_v2(*, v1, v2): - if v1 is None: - return False - - v1_parts = (v1.split('.') + ['0', '0', '0'])[:3] - v2_parts = (v2.split('.') + ['0', '0', '0'])[:3] - return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts)) - - -def is_compat_version_prior_to(options, breaking_change_version): - # This function is used in a branch statement to determine whether we should - # keep the old behavior prior to a breaking change or use the new behavior. - # - If update_compatibility_version < breaking_change_version, we will return - # True and keep the old behavior. - update_compatibility_version = options.view_as( - pipeline_options.StreamingOptions).update_compatibility_version - - return is_v1_prior_to_v2( - v1=update_compatibility_version, v2=breaking_change_version) - - def reify_metadata_default_window( element, timestamp=DoFn.TimestampParam, pane_info=DoFn.PaneInfoParam): key, value = element @@ -1451,8 +1430,8 @@ def restore_timestamps(element): for (value, timestamp) in values ] - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): pre_gbk_map = Map(reify_timestamps).with_output_types(Any) else: pre_gbk_map = Map(reify_timestamps).with_input_types( @@ -1471,8 +1450,8 @@ def restore_timestamps(element): key, windowed_values = element return [wv.with_value((key, wv.value)) for wv in windowed_values] - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): pre_gbk_map = Map(reify_timestamps).with_output_types(Any) else: pre_gbk_map = Map(reify_timestamps).with_input_types( @@ -1496,7 +1475,7 @@ def restore_timestamps(element): return result def expand(self, pcoll): - if is_compat_version_prior_to(pcoll.pipeline.options, "2.65.0"): + if pcoll.pipeline.options.is_compat_version_prior_to("2.65.0"): return self.expand_2_64_0(pcoll) windowing_saved = pcoll.windowing @@ -1553,8 +1532,8 @@ def __init__(self, num_buckets=None): def expand(self, pcoll): # type: (pvalue.PValue) -> pvalue.PCollection - if is_compat_version_prior_to(pcoll.pipeline.options, - RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): + if pcoll.pipeline.options.is_compat_version_prior_to( + RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION): reshuffle_step = ReshufflePerKey() else: reshuffle_step = ReshufflePerKey().with_input_types( diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 448ba8a7ad9d..6c6168ad3e12 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1322,9 +1322,6 @@ def test_reshuffle_streaming_global_window_with_buckets(self): param(compat_version="2.64.0"), ]) def test_reshuffle_custom_window_preserves_metadata(self, compat_version): - """Tests that Reshuffle preserves pane info.""" - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True element_count = 12 timestamp_value = timestamp.Timestamp(0) l = [ @@ -1390,6 +1387,7 @@ def test_reshuffle_custom_window_preserves_metadata(self, compat_version): ]) options = PipelineOptions(update_compatibility_version=compat_version) options.view_as(StandardOptions).streaming = True + object.__setattr__(options, 'force_dill_deterministic_coders', True) with beam.Pipeline(options=options) as p: stream_source = ( @@ -1418,7 +1416,6 @@ def test_reshuffle_custom_window_preserves_metadata(self, compat_version): equal_to(expected), label='CheckMetadataPreserved', reify_windows=True) - typecoders.registry.force_dill_deterministic_coders = False @parameterized.expand([ param(compat_version=None), @@ -1427,8 +1424,6 @@ def test_reshuffle_custom_window_preserves_metadata(self, compat_version): def test_reshuffle_default_window_preserves_metadata(self, compat_version): """Tests that Reshuffle preserves timestamp, window, and pane info metadata.""" - from apache_beam.coders import typecoders - typecoders.registry.force_dill_deterministic_coders = True no_firing = PaneInfo( is_first=True, is_last=True, @@ -1481,6 +1476,7 @@ def test_reshuffle_default_window_preserves_metadata(self, compat_version): if compat_version is None else expected_not_preserved) options = PipelineOptions(update_compatibility_version=compat_version) + object.__setattr__(options, 'force_dill_deterministic_coders', True) with TestPipeline(options=options) as pipeline: # Create windowed values with specific metadata elements = [ @@ -1502,7 +1498,6 @@ def test_reshuffle_default_window_preserves_metadata(self, compat_version): equal_to(expected), label='CheckMetadataPreserved', reify_windows=True) - typecoders.registry.force_dill_deterministic_coders = False @pytest.mark.it_validatesrunner def test_reshuffle_preserves_timestamps(self): @@ -2560,68 +2555,6 @@ def record(tag): label='result') -class CompatCheckTest(unittest.TestCase): - def test_is_v1_prior_to_v2(self): - test_cases = [ - # Basic comparison cases - ("1.0.0", "2.0.0", True), # v1 < v2 in major - ("2.0.0", "1.0.0", False), # v1 > v2 in major - ("1.1.0", "1.2.0", True), # v1 < v2 in minor - ("1.2.0", "1.1.0", False), # v1 > v2 in minor - ("1.0.1", "1.0.2", True), # v1 < v2 in patch - ("1.0.2", "1.0.1", False), # v1 > v2 in patch - - # Equal versions - ("1.0.0", "1.0.0", False), # Identical - ("0.0.0", "0.0.0", False), # Both zero - - # Different lengths - shorter vs longer - ("1.0", "1.0.0", False), # Should be equal (1.0 = 1.0.0) - ("1.0", "1.0.1", True), # 1.0.0 < 1.0.1 - ("1.2", "1.2.0", False), # Should be equal (1.2 = 1.2.0) - ("1.2", "1.2.3", True), # 1.2.0 < 1.2.3 - ("2", "2.0.0", False), # Should be equal (2 = 2.0.0) - ("2", "2.0.1", True), # 2.0.0 < 2.0.1 - ("1", "2.0", True), # 1.0.0 < 2.0.0 - - # Different lengths - longer vs shorter - ("1.0.0", "1.0", False), # Should be equal - ("1.0.1", "1.0", False), # 1.0.1 > 1.0.0 - ("1.2.0", "1.2", False), # Should be equal - ("1.2.3", "1.2", False), # 1.2.3 > 1.2.0 - ("2.0.0", "2", False), # Should be equal - ("2.0.1", "2", False), # 2.0.1 > 2.0.0 - ("2.0", "1", False), # 2.0.0 > 1.0.0 - - # Mixed length comparisons - ("1.0", "2.0.0", True), # 1.0.0 < 2.0.0 - ("2.0", "1.0.0", False), # 2.0.0 > 1.0.0 - ("1", "1.0.1", True), # 1.0.0 < 1.0.1 - ("1.1", "1.0.9", False), # 1.1.0 > 1.0.9 - - # Large numbers - ("1.9.9", "2.0.0", True), # 1.9.9 < 2.0.0 - ("10.0.0", "9.9.9", False), # 10.0.0 > 9.9.9 - ("1.10.0", "1.9.0", False), # 1.10.0 > 1.9.0 - ("1.2.10", "1.2.9", False), # 1.2.10 > 1.2.9 - - # Sequential versions - ("1.0.0", "1.0.1", True), - ("1.0.1", "1.0.2", True), - ("1.0.9", "1.1.0", True), - ("1.9.9", "2.0.0", True), - - # Null/None cases - (None, "1.0.0", False), # v1 is None - ] - - for v1, v2, expected in test_cases: - self.assertEqual( - util.is_v1_prior_to_v2(v1=v1, v2=v2), - expected, - msg=f"Failed {v1} < {v2} == {expected}") - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()