-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3) #37565
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
0aa8bcb
47b5a9b
53454f3
4837afb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -178,6 +178,8 @@ def __init__( | |
| max_batch_duration_secs: Optional[int] = None, | ||
| max_batch_weight: Optional[int] = None, | ||
| element_size_fn: Optional[Callable[[Any], int]] = None, | ||
| length_fn: Optional[Callable[[Any], int]] = None, | ||
| bucket_boundaries: Optional[list[int]] = None, | ||
| large_model: bool = False, | ||
| model_copies: Optional[int] = None, | ||
| **kwargs): | ||
|
|
@@ -190,6 +192,11 @@ def __init__( | |
| before emitting; used in streaming contexts. | ||
| max_batch_weight: the maximum weight of a batch. Requires element_size_fn. | ||
| element_size_fn: a function that returns the size (weight) of an element. | ||
| length_fn: a callable mapping an element to its length. When set with | ||
| max_batch_duration_secs, enables length-aware bucketed keying so | ||
| elements of similar length are batched together. | ||
| bucket_boundaries: sorted list of positive boundary values for length | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could we add more data to this description, similar to below? |
||
| bucketing. Requires length_fn. | ||
| large_model: set to true if your model is large enough to run into | ||
| memory pressure if you load multiple copies. | ||
| model_copies: The exact number of models that you would like loaded | ||
|
|
@@ -209,6 +216,10 @@ def __init__( | |
| self._batching_kwargs['max_batch_weight'] = max_batch_weight | ||
| if element_size_fn is not None: | ||
| self._batching_kwargs['element_size_fn'] = element_size_fn | ||
| if length_fn is not None: | ||
| self._batching_kwargs['length_fn'] = length_fn | ||
| if bucket_boundaries is not None: | ||
| self._batching_kwargs['bucket_boundaries'] = bucket_boundaries | ||
| self._large_model = large_model | ||
| self._model_copies = model_copies | ||
| self._share_across_processes = large_model or (model_copies is not None) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |||||
|
|
||||||
| # pytype: skip-file | ||||||
|
|
||||||
| import bisect | ||||||
| import collections | ||||||
| import contextlib | ||||||
| import hashlib | ||||||
|
|
@@ -1208,6 +1209,28 @@ def process(self, element): | |||||
| yield (self.key, element) | ||||||
|
|
||||||
|
|
||||||
| class WithLengthBucketKey(DoFn): | ||||||
| """Keys elements with (worker_uuid, length_bucket) for length-aware | ||||||
| stateful batching. Elements of similar length are routed to the same | ||||||
| state partition, reducing padding waste.""" | ||||||
| def __init__(self, length_fn, bucket_boundaries): | ||||||
| self.shared_handle = shared.Shared() | ||||||
| self._length_fn = length_fn | ||||||
| self._bucket_boundaries = bucket_boundaries | ||||||
|
|
||||||
| def setup(self): | ||||||
| self.key = self.shared_handle.acquire( | ||||||
| load_shared_key, "WithLengthBucketKey").key | ||||||
|
|
||||||
| def _get_bucket(self, length): | ||||||
| return bisect.bisect_left(self._bucket_boundaries, length) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using With the current If you make this change, please also update the assertions in
Suggested change
|
||||||
|
|
||||||
| def process(self, element): | ||||||
| length = self._length_fn(element) | ||||||
| bucket = self._get_bucket(length) | ||||||
| yield ((self.key, bucket), element) | ||||||
|
|
||||||
|
|
||||||
| @typehints.with_input_types(T) | ||||||
| @typehints.with_output_types(list[T]) | ||||||
| class BatchElements(PTransform): | ||||||
|
|
@@ -1267,7 +1290,18 @@ class BatchElements(PTransform): | |||||
| downstream operations (mostly for testing) | ||||||
| record_metrics: (optional) whether or not to record beam metrics on | ||||||
| distributions of the batch size. Defaults to True. | ||||||
| length_fn: (optional) a callable mapping an element to its length (int). | ||||||
| When set together with max_batch_duration_secs, enables length-aware | ||||||
| bucketed keying on the stateful path so that elements of similar length | ||||||
| are routed to the same batch, reducing padding waste. | ||||||
| bucket_boundaries: (optional) a sorted list of positive boundary values | ||||||
| for length bucketing. Elements with length < boundaries[i] go to | ||||||
| bucket i; overflow goes to bucket len(boundaries). Defaults to | ||||||
| [16, 32, 64, 128, 256, 512] when length_fn is set. Requires | ||||||
| length_fn. | ||||||
| """ | ||||||
| _DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512] | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| min_batch_size=1, | ||||||
|
|
@@ -1280,7 +1314,17 @@ def __init__( | |||||
| element_size_fn=lambda x: 1, | ||||||
| variance=0.25, | ||||||
| clock=time.time, | ||||||
| record_metrics=True): | ||||||
| record_metrics=True, | ||||||
| length_fn=None, | ||||||
| bucket_boundaries=None): | ||||||
| if bucket_boundaries is not None and length_fn is None: | ||||||
| raise ValueError('bucket_boundaries requires length_fn to be set.') | ||||||
| if bucket_boundaries is not None: | ||||||
| if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or | ||||||
| bucket_boundaries != sorted(bucket_boundaries)): | ||||||
| raise ValueError( | ||||||
| 'bucket_boundaries must be a non-empty sorted list of ' | ||||||
| 'positive values.') | ||||||
| self._batch_size_estimator = _BatchSizeEstimator( | ||||||
| min_batch_size=min_batch_size, | ||||||
| max_batch_size=max_batch_size, | ||||||
|
|
@@ -1294,13 +1338,23 @@ def __init__( | |||||
| self._element_size_fn = element_size_fn | ||||||
| self._max_batch_dur = max_batch_duration_secs | ||||||
| self._clock = clock | ||||||
| self._length_fn = length_fn | ||||||
| if length_fn is not None and bucket_boundaries is None: | ||||||
| self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES | ||||||
| else: | ||||||
| self._bucket_boundaries = bucket_boundaries | ||||||
|
|
||||||
| def expand(self, pcoll): | ||||||
| if getattr(pcoll.pipeline.runner, 'is_streaming', False): | ||||||
| raise NotImplementedError("Requires stateful processing (BEAM-2687)") | ||||||
| elif self._max_batch_dur is not None: | ||||||
| coder = coders.registry.get_coder(pcoll) | ||||||
| return pcoll | ParDo(WithSharedKey()) | ParDo( | ||||||
| if self._length_fn is not None: | ||||||
| keying_dofn = WithLengthBucketKey( | ||||||
| self._length_fn, self._bucket_boundaries) | ||||||
| else: | ||||||
| keying_dofn = WithSharedKey() | ||||||
| return pcoll | ParDo(keying_dofn) | ParDo( | ||||||
| _pardo_stateful_batch_elements( | ||||||
| coder, | ||||||
| self._batch_size_estimator, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make it clear these are batching parameters? e.g.
batch_length_fn and batch_bucket_boundaries?