diff --git a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py index 495e888a5167..4f2c5321edf1 100644 --- a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py +++ b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py @@ -161,11 +161,15 @@ class CloudPickleConfig: code changes: when a particular lambda function is slightly modified but the location of the function in the codebase has not changed, the pickled representation might stay the same. + + pickle_main_by_ref: An optional boolean. If provided, cloudpickle will + pickle main by reference instead of by value. """ id_generator: typing.Optional[callable] = uuid_generator skip_reset_dynamic_type_state: bool = False filepath_interceptor: typing.Optional[callable] = None get_code_object_params: typing.Optional[GetCodeObjectParams] = None + pickle_main_by_ref: bool = False DEFAULT_CONFIG = CloudPickleConfig() @@ -316,7 +320,7 @@ def _whichmodule(obj, name): return None -def _should_pickle_by_reference(obj, name=None): +def _should_pickle_by_reference(obj, name=None, config=DEFAULT_CONFIG): """Test whether an function or a class should be pickled by reference Pickling by reference means by that the object (typically a function or a @@ -331,7 +335,7 @@ def _should_pickle_by_reference(obj, name=None): explicitly registered to be pickled by value. """ if isinstance(obj, types.FunctionType) or issubclass(type(obj), type): - module_and_name = _lookup_module_and_qualname(obj, name=name) + module_and_name = _lookup_module_and_qualname(obj, name=name, config=config) if module_and_name is None: return False module, name = module_and_name @@ -351,7 +355,7 @@ def _should_pickle_by_reference(obj, name=None): "cannot check importability of {} instances".format(type(obj).__name__)) -def _lookup_module_and_qualname(obj, name=None): +def _lookup_module_and_qualname(obj, name=None, config=DEFAULT_CONFIG): if name is None: name = getattr(obj, "__qualname__", None) if name is None: # pragma: no cover @@ -367,7 +371,7 @@ def _lookup_module_and_qualname(obj, name=None): # imported module. obj is thus treated as dynamic. return None - if module_name == "__main__": + if module_name == "__main__" and not config.pickle_main_by_ref: return None # Note: if module_name is in sys.modules, the corresponding module is @@ -718,7 +722,8 @@ def _decompose_typevar(obj, config: CloudPickleConfig): def _typevar_reduce(obj, config: CloudPickleConfig): # TypeVar instances require the module information hence why we # are not using the _should_pickle_by_reference directly - module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__) + module_and_name = _lookup_module_and_qualname( + obj, name=obj.__name__, config=config) if module_and_name is None: return (_make_typevar, _decompose_typevar(obj, config)) @@ -1185,7 +1190,7 @@ def _class_reduce(obj, config: CloudPickleConfig): return type, (NotImplemented, ) elif obj in _BUILTIN_TYPE_NAMES: return _builtin_type, (_BUILTIN_TYPE_NAMES[obj], ) - elif not _should_pickle_by_reference(obj): + elif not _should_pickle_by_reference(obj, config=config): return _dynamic_class_reduce(obj, config) return NotImplemented @@ -1410,7 +1415,7 @@ def _function_reduce(self, obj): obj using a custom cloudpickle reducer designed specifically to handle dynamic functions. """ - if _should_pickle_by_reference(obj): + if _should_pickle_by_reference(obj, config=self.config): return NotImplemented elif self.config.get_code_object_params is not None: return self._stable_identifier_function_reduce(obj) @@ -1617,7 +1622,7 @@ def save_global(self, obj, name=None, pack=struct.pack): if name is not None: super().save_global(obj, name=name) - elif not _should_pickle_by_reference(obj, name=name): + elif not _should_pickle_by_reference(obj, name=name, config=self.config): self._save_reduce_pickle5( *_dynamic_class_reduce(obj, self.config), obj=obj) else: @@ -1642,7 +1647,7 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ - if _should_pickle_by_reference(obj, name=name): + if _should_pickle_by_reference(obj, name=name, config=self.config): return super().save_global(obj, name=name) elif PYPY and isinstance(obj.__code__, builtin_code_type): return self.save_pypy_builtin_func(obj)