Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PreCommit_Python_Dill.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"revision": 2
"revision": 3
}
92 changes: 53 additions & 39 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,17 @@ def is_deterministic(self):
"""
return False

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
"""Returns a deterministic version of self, if possible.

Otherwise raises a value error.

Args:
step_label: A label for the step requiring determinism.
error_message: Optional custom error message if coder cannot be made
deterministic.
options: Optional PipelineOptions for version compatibility checks.
"""
if self.is_deterministic():
return self
Expand Down Expand Up @@ -538,10 +545,13 @@ def is_deterministic(self):
# Map ordering is non-deterministic
return False

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
return DeterministicMapCoder(
self._key_coder.as_deterministic_coder(step_label, error_message),
self._value_coder.as_deterministic_coder(step_label, error_message))
self._key_coder.as_deterministic_coder(
step_label, error_message, options),
self._value_coder.as_deterministic_coder(
step_label, error_message, options))

def __eq__(self, other):
return (
Expand Down Expand Up @@ -616,12 +626,13 @@ def is_deterministic(self):
# type: () -> bool
return self._value_coder.is_deterministic()

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
if self.is_deterministic():
return self
else:
deterministic_value_coder = self._value_coder.as_deterministic_coder(
step_label, error_message)
step_label, error_message, options)
return NullableCoder(deterministic_value_coder)

def __eq__(self, other):
Expand Down Expand Up @@ -883,8 +894,10 @@ def _nonhashable_dumps(x):

return coder_impl.CallbackCoderImpl(_nonhashable_dumps, pickler.loads)

def as_deterministic_coder(self, step_label, error_message=None):
return FastPrimitivesCoder(self, requires_deterministic=step_label)
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
return _update_compatible_deterministic_fast_primitives_coder(
FastPrimitivesCoder(self), step_label, options)

def to_type_hint(self):
return Any
Expand All @@ -898,8 +911,10 @@ def _create_impl(self):
return coder_impl.CallbackCoderImpl(
lambda x: dumps(x, protocol), pickle.loads)

def as_deterministic_coder(self, step_label, error_message=None):
return FastPrimitivesCoder(self, requires_deterministic=step_label)
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
return _update_compatible_deterministic_fast_primitives_coder(
FastPrimitivesCoder(self), step_label, options)

def to_type_hint(self):
return Any
Expand Down Expand Up @@ -927,16 +942,14 @@ def _create_impl(self):

class DeterministicFastPrimitivesCoderV2(FastCoder):
"""Throws runtime errors when encoding non-deterministic values."""
def __init__(self, coder, step_label, update_compatibility_version=None):
def __init__(self, coder, step_label, options=None):
self._underlying_coder = coder
self._step_label = step_label
self._use_relative_filepaths = True
self._version_tag = "v2_69"
from apache_beam.transforms.util import is_v1_prior_to_v2

# Versions prior to 2.69.0 did not use relative filepaths.
if update_compatibility_version and is_v1_prior_to_v2(
v1=update_compatibility_version, v2="2.69.0"):
if options and options.is_compat_version_prior_to("2.69.0"):
self._version_tag = ""
self._use_relative_filepaths = False

Expand Down Expand Up @@ -1005,20 +1018,15 @@ def to_type_hint(self):
return Any


def _should_force_use_dill(registry):
def _should_force_use_dill(options=None):
# force_dill_deterministic_coders is for testing purposes. If there is a
# DeterministicFastPrimitivesCoder in the pipeline graph but the dill
# encoding path is not really triggered dill does not have to be installed.
# encoding path is not really triggered dill does not have to be installed
# and this check can be skipped.
if getattr(registry, 'force_dill_deterministic_coders', False):
if getattr(options, 'force_dill_deterministic_coders', False):
return True

from apache_beam.transforms.util import is_v1_prior_to_v2
update_compat_version = registry.update_compatibility_version
if not update_compat_version:
return False

if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"):
if options is None or not options.is_compat_version_prior_to("2.68.0"):
return False

try:
Expand All @@ -1032,7 +1040,8 @@ def _should_force_use_dill(registry):
return True


def _update_compatible_deterministic_fast_primitives_coder(coder, step_label):
def _update_compatible_deterministic_fast_primitives_coder(
coder, step_label, options=None):
""" Returns the update compatible version of DeterministicFastPrimitivesCoder
The differences are in how "special types" e.g. NamedTuples, Dataclasses are
deterministically encoded.
Expand All @@ -1043,12 +1052,9 @@ def _update_compatible_deterministic_fast_primitives_coder(coder, step_label):
- In SDK version 2.69.0 cloudpickle is used to encode "special types" with
relative filepaths in code objects and dynamic functions.
"""
from apache_beam.coders import typecoders

if _should_force_use_dill(typecoders.registry):
if _should_force_use_dill(options):
return DeterministicFastPrimitivesCoder(coder, step_label)
return DeterministicFastPrimitivesCoderV2(
coder, step_label, typecoders.registry.update_compatibility_version)
return DeterministicFastPrimitivesCoderV2(coder, step_label, options)


class FastPrimitivesCoder(FastCoder):
Expand All @@ -1067,12 +1073,13 @@ def is_deterministic(self):
# type: () -> bool
return self._fallback_coder.is_deterministic()

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
if self.is_deterministic():
return self
else:
return _update_compatible_deterministic_fast_primitives_coder(
self, step_label)
self, step_label, options)

def to_type_hint(self):
return Any
Expand Down Expand Up @@ -1167,7 +1174,8 @@ def is_deterministic(self):
# a Map.
return False

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
return DeterministicProtoCoder(self.proto_message_type)

def __eq__(self, other):
Expand Down Expand Up @@ -1213,7 +1221,8 @@ def is_deterministic(self):
# type: () -> bool
return True

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
return self


Expand Down Expand Up @@ -1300,12 +1309,13 @@ def is_deterministic(self):
# type: () -> bool
return all(c.is_deterministic() for c in self._coders)

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
if self.is_deterministic():
return self
else:
return TupleCoder([
c.as_deterministic_coder(step_label, error_message)
c.as_deterministic_coder(step_label, error_message, options)
for c in self._coders
])

Expand Down Expand Up @@ -1379,12 +1389,14 @@ def is_deterministic(self):
# type: () -> bool
return self._elem_coder.is_deterministic()

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
if self.is_deterministic():
return self
else:
return TupleSequenceCoder(
self._elem_coder.as_deterministic_coder(step_label, error_message))
self._elem_coder.as_deterministic_coder(
step_label, error_message, options))

@classmethod
def from_type_hint(cls, typehint, registry):
Expand Down Expand Up @@ -1419,12 +1431,14 @@ def is_deterministic(self):
# type: () -> bool
return self._elem_coder.is_deterministic()

def as_deterministic_coder(self, step_label, error_message=None):
def as_deterministic_coder(
self, step_label, error_message=None, options=None):
if self.is_deterministic():
return self
else:
return type(self)(
self._elem_coder.as_deterministic_coder(step_label, error_message))
self._elem_coder.as_deterministic_coder(
step_label, error_message, options))

def value_coder(self):
return self._elem_coder
Expand Down
77 changes: 63 additions & 14 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
from apache_beam.coders import coders
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import typecoders
from apache_beam.coders.row_coder import RowCoder
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.internal import pickler
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
Expand Down Expand Up @@ -202,9 +205,6 @@ def tearDownClass(cls):
assert not standard - cls.seen, str(standard - cls.seen)
assert not cls.seen_nested - standard, str(cls.seen_nested - standard)

def tearDown(self):
typecoders.registry.update_compatibility_version = None

@classmethod
def _observe(cls, coder):
cls.seen.add(type(coder))
Expand Down Expand Up @@ -275,13 +275,14 @@ def test_deterministic_coder(self, compat_version):
with relative filepaths in code objects and dynamic functions.
"""

typecoders.registry.update_compatibility_version = compat_version
options = PipelineOptions(update_compatibility_version=compat_version)
coder = coders.FastPrimitivesCoder()
if not dill and compat_version == "2.67.0":
with self.assertRaises(RuntimeError):
coder.as_deterministic_coder(step_label="step")
coder.as_deterministic_coder(step_label="step", options=options)
self.skipTest('Dill not installed')
deterministic_coder = coder.as_deterministic_coder(step_label="step")
deterministic_coder = coder.as_deterministic_coder(
step_label="step", options=options)

self.check_coder(deterministic_coder, *self.test_values_deterministic)
for v in self.test_values_deterministic:
Expand Down Expand Up @@ -364,7 +365,7 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version):
- In SDK version >=2.69.0 cloudpickle is used to encode "special types"
with relative file.
"""
typecoders.registry.update_compatibility_version = compat_version
options = PipelineOptions(update_compatibility_version=compat_version)
values = [{
MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i)
for i in range(10)
Expand All @@ -375,10 +376,11 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version):

if not dill and compat_version == "2.67.0":
with self.assertRaises(RuntimeError):
coder.as_deterministic_coder(step_label="step")
coder.as_deterministic_coder(step_label="step", options=options)
self.skipTest('Dill not installed')

deterministic_coder = coder.as_deterministic_coder(step_label="step")
deterministic_coder = coder.as_deterministic_coder(
step_label="step", options=options)

assert isinstance(
deterministic_coder._key_coder,
Expand All @@ -387,6 +389,53 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version):

self.check_coder(deterministic_coder, *values)

@parameterized.expand([
param(compat_version=None),
param(compat_version="2.67.0"),
param(compat_version="2.68.0"),
])
def test_deterministic_row_coder_is_update_compatible(self, compat_version):
""" Test that RowCoder.as_deterministic_coder propagates options to
component coders for proper version compatibility.

- In SDK version <= 2.67.0 dill is used to encode "special types"
- In SDK version 2.68.0 cloudpickle is used to encode "special types" with
absolute filepaths in code objects and dynamic functions.
- In SDK version >=2.69.0 cloudpickle is used to encode "special types"
with relative filepaths in code objects and dynamic functions.
"""
# Create a NamedTuple with an Any field which uses FastPrimitivesCoder
RowWithAny = NamedTuple('RowWithAny', [('name', str), ('data', Any)])
schema = typing_to_runner_api(RowWithAny).row_type.schema

options = PipelineOptions(update_compatibility_version=compat_version)
coder = RowCoder(schema)

if not dill and compat_version == "2.67.0":
with self.assertRaises(RuntimeError):
coder.as_deterministic_coder(step_label="step", options=options)
self.skipTest('Dill not installed')

deterministic_coder = coder.as_deterministic_coder(
step_label="step", options=options)

# The 'data' field (index 1) should have the appropriate deterministic coder
# based on the compat_version
data_coder = deterministic_coder.components[1]
expected_coder_type = (
coders.DeterministicFastPrimitivesCoderV2 if compat_version
in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder)
self.assertIsInstance(data_coder, expected_coder_type)

# Verify encoding/decoding works
test_values = [
RowWithAny(name='test', data={'key': 'value'}),
RowWithAny(name='test2', data=[1, 2, 3]),
]
for value in test_values:
self.assertEqual(
value, deterministic_coder.decode(deterministic_coder.encode(value)))

def test_dill_coder(self):
if not dill:
with self.assertRaises(RuntimeError):
Expand Down Expand Up @@ -738,7 +787,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(

if sys.executable is None:
self.skipTest('No Python interpreter found')
typecoders.registry.update_compatibility_version = compat_version
options = PipelineOptions(update_compatibility_version=compat_version)

# pylint: disable=line-too-long
script = textwrap.dedent(
Expand All @@ -750,7 +799,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
import logging

from apache_beam.coders import coders
from apache_beam.coders import typecoders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.coders.coders_test_common import MyNamedTuple
from apache_beam.coders.coders_test_common import MyTypedNamedTuple
from apache_beam.coders.coders_test_common import MyEnum
Expand Down Expand Up @@ -802,9 +851,9 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
])

compat_version = {'"'+ compat_version +'"' if compat_version else None}
typecoders.registry.update_compatibility_version = compat_version
options = PipelineOptions(update_compatibility_version=compat_version)
coder = coders.FastPrimitivesCoder()
deterministic_coder = coder.as_deterministic_coder("step")
deterministic_coder = coder.as_deterministic_coder("step", options=options)

results = dict()
for test_name, value in test_cases:
Expand Down Expand Up @@ -834,7 +883,7 @@ def run_subprocess():
results2 = run_subprocess()

coder = coders.FastPrimitivesCoder()
deterministic_coder = coder.as_deterministic_coder("step")
deterministic_coder = coder.as_deterministic_coder("step", options=options)

for test_name in results1:

Expand Down
Loading
Loading