[refactor] Convert BatchMeta to columnar layout; enable zero-copy serialization by default #39
Conversation
CLA Signature Pass: mpb159753, thanks for your pull request. All authors of the commits have signed the CLA. 👍
…ialization by default

- Restructure BatchMeta from row-based (FieldMeta/SampleMeta) to columnar storage
- Add _SampleView for lazy read-only row access into columnar BatchMeta
- Enable zero-copy (msgpack) serialization by default; auto-fallback to pickle
- Remove TQ_ZERO_COPY_SERIALIZATION env var toggle
- Update all related tests and managers to the new columnar BatchMeta API
- Add serial_utils.py with shared serialization helpers

Signed-off-by: 看我72遍 <m.pb@msn.com>
…st layout

- Change BatchMeta.custom_meta and _custom_backend_meta from dict[int, dict] to list[dict], aligned positionally with global_indexes
- Update __post_init__ to validate length and initialize empty lists
- Adapt get/update/clear_custom_meta, select_samples, select_fields, concat, reorder, empty, to_dict, from_dict methods
- Add dict→list bridge in controller._build_batch_metadata
- Update simple_backend_manager and base.py to access by position index
- Adapt test_metadata.py and test_kv_storage_manager.py accordingly

Signed-off-by: 看我72遍 <m.pb@msn.com>
Add BatchMeta.with_data_fields() method that returns a new BatchMeta with the given field list, allowing field names not yet present in the current field_schema. This lets callers request newly added fields on a known sample range without triggering poll_for_meta, which may return samples outside the intended region.

Fix test_cross_shard_complex_update, which incorrectly used poll_for_meta(force_fetch) to verify new fields — this could fetch samples (indices 0-9 or 30-39) that do not have the new_extra_* fields, causing a RuntimeError. Use meta_update.with_data_fields() instead to target exactly the 10-29 update region.

Signed-off-by: 看我72遍 <m.pb@msn.com>
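A minimal self-contained sketch of how a `with_data_fields()` method as described in the commit message could behave. The class and field layout below are simplified assumptions for illustration, not the real TransferQueue API:

```python
from dataclasses import dataclass, field, replace

# Hypothetical simplified BatchMeta; only the two columns needed for the demo.
@dataclass
class BatchMeta:
    global_indexes: list
    field_schema: dict = field(default_factory=dict)

    def with_data_fields(self, fields: list) -> "BatchMeta":
        # Allow field names not yet present in field_schema: unknown names get
        # an empty placeholder entry instead of raising, so callers can target
        # newly added fields on a known sample range without poll_for_meta.
        schema = {name: self.field_schema.get(name, {}) for name in fields}
        return replace(self, field_schema=schema)

meta = BatchMeta(global_indexes=[10, 11], field_schema={"input_ids": {"dtype": "int64"}})
narrowed = meta.with_data_fields(["input_ids", "new_extra_score"])
assert set(narrowed.field_schema) == {"input_ids", "new_extra_score"}
assert narrowed.global_indexes == [10, 11]  # sample range is preserved
```

The key property is that the sample range is fixed up front, so verification code can target exactly the updated region.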
Force-pushed from efdcc43 to 0a17163
- Add CUSTOM_TYPE_NUMPY = 6 extension type
- _encode_numpy(): direct memoryview extraction, no torch intermediary
- _decode_numpy(): reconstruct np.ndarray from raw bytes
- Update enc_hook to use _encode_numpy with exclusion-based dtype check
- Register CUSTOM_TYPE_NUMPY in ext_hook
- Update test_numpy_numeric_arrays_zero_copy to assert np.ndarray (not Tensor)
- Update test_zmq_msg_serialization with correct TensorDict numpy comment
- Add TestNumpyNativeSerialization with 22 new parametrized/edge-case tests

Fixes: numpy arrays decoded as torch.Tensor (type information lost), and the torch.from_numpy + flatten() path could trigger extra copies.

Signed-off-by: 看我72遍 <m.pb@msn.com>
Force-pushed from 89b43bb to ed4a66f
…eview

- metadata.py: fix BatchMeta.empty().to_dict() crash when production_status is None
- metadata.py: explicitly serialize dtype as string in to_dict(); add _parse_dtype() helper for from_dict() to reconstruct torch/numpy dtypes without implicit pickle
- controller.py: add field_schema_cache to DataPartitionStatus for O(F) get_field_schema() (replaces O(G*F) double-scan of field_dtypes/field_shapes per-sample maps)
- zmq_utils.py: downgrade pickle-fallback log from WARNING to INFO; it is a normal degradation path (e.g. body contains torch.dtype objects), not an error
- simple_backend_manager.py: move defaultdict import to top level; remove two redundant local imports in get_data() and clear_data()
- tutorial/03_metadata_concepts.py: update tutorial for columnar BatchMeta API
- docs/plans: add serial_utils warning background doc for future discussion

Signed-off-by: 看我72遍 <m.pb@msn.com>
Force-pushed from ed4a66f to 83a3ee7
| """ | ||
| Demonstrate FieldMeta - specific data fields of each training sample. | ||
| Demonstrate BatchMeta field_schema - field-level metadata for the batch. | ||
| After the columnar refactoring, field metadata is stored once at the batch level |
There was a problem hiding this comment.
We can directly introduce the BatchMeta as it is now, don't need to compare with previous design
|
|
||
| print("FieldMeta represents a single field in ONE sample:") | ||
| print("- name: Field identifier ('Prompt', 'Response', etc.)") | ||
| print("field_schema stores metadata for each field ONCE per batch (not per sample):") |
| print(f" For global_index: use batch.global_indexes[0] = {ready_batch.global_indexes[0]}") | ||
|
|
||
|
|
||
```
def demonstrate_batch_meta_construction():
```

The core logic here is not demonstrating the construction, but the usage.

What is the relationship between this demo and demonstrate_batch_meta()?
| """ | ||
| Demonstrate SampleMeta - describes a single data sample. | ||
| Demonstrate how to construct BatchMeta directly and operate on it | ||
| (analogous to old SampleMeta operations: add_fields, select_fields, union). |
|
|
||
|
|
||
```
def demonstrate_field_meta():
def demonstrate_batch_meta_schema():
```

I believe we don't need to deliberately demonstrate field_schema here. We can simplify it by showing how to construct a BatchMeta manually.
```
dtype=torch.int64,
shape=(512,),  # Sequence length for ONE sample
production_status=ProductionStatus.NOT_PRODUCED,
# Example 2: Create a field schema entry for attention_mask
```

Functionally repeated.
```
return dtype_str
```
```
class _SampleView:
return self.fields.get(name)
def fields(self) -> dict:
    """Read-only access to field_schema: batch.samples[i].fields['a'] -> field meta dict."""
    return self._batch.field_schema
```

Should we only return the field_schema of THE selected sample?
```
)
partition_id = "demo_partition"
batch_meta = tq_client.put(data=data_batch, partition_id=partition_id)
```

I have tried this code, but the returned batch_meta only has one field_schema entry. Is that because the field has a uniform shape? What if it's nested?

```
BatchMeta(global_indexes=[8, 9, 10, 11, 12, 13, 14, 15], partition_ids=['demo_partition', 'demo_partition', 'demo_partition', 'demo_partition', 'demo_partition', 'demo_partition', 'demo_partition', 'demo_partition'], field_schema={'input_ids': {'dtype': torch.int64, 'shape': torch.Size([512]), 'is_nested': False, 'is_non_tensor': False}, 'attention_mask': {'dtype': torch.float32, 'shape': torch.Size([512]), 'is_nested': False, 'is_non_tensor': False}, 'nested': {'dtype': torch.float32, 'shape': None, 'is_nested': True, 'is_non_tensor': False, 'per_sample_shapes': [(4, 3), (2, 4), (4, 4), (2, 3), (4, 5), (2, 2), (4, 6), (2, 1)]}}, extra_info={}, custom_meta=[{}, {}, {}, {}, {}, {}, {}, {}], _custom_backend_meta=[{}, {}, {}, {}, {}, {}, {}, {}])
```
I have found a bug:

```python
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)
nested = torch.nested.as_nested_tensor(
    [torch.randn(4, 3).int(), torch.randn(2, 4), torch.randn(4, 4), torch.randn(2, 3),
     torch.randn(4, 5), torch.randn(2, 2), torch.randn(4, 6), torch.randn(2, 1)],
    layout=torch.strided,
)
data_batch = TensorDict(
    {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "nested": nested,
    },
    batch_size=8,
)
partition_id = "demo_partition"
batch_meta = tq_client.put(data=data_batch, partition_id=partition_id)
```

The dtype handling does not support per_sample_dtypes and directly uses the first element's dtype.

Please check whether the current design fits the backend's requirements. @dpj135 @tianyi-ge
```
Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata.
The order matches the key/value order: sorted by field name, then by global index.
O(F) optimized version that uses field_schema instead of per-sample metadata.
```

No need to compare with the old version; it will confuse users.
```
Serializes the input tensors, stores them using the storage client,
extracts per-sample dtype and shape information, and sends a notification
to the controller that new data is available.
O(F) optimized version that extracts field-level schema instead of per-sample metadata.
for field_name, field_data in data.items():
    first_item = field_data[0] if len(field_data) > 0 else None
    is_nested = isinstance(field_data, torch.Tensor) and field_data.is_nested
    field_dtype = getattr(first_item, "dtype", type(first_item) if first_item is not None else None)
```

For a nested tensor, we cannot assume the first item's dtype equals every other item's dtype.
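The fix the reviewer asks for — collecting dtype per item rather than assuming the first item is representative — can be sketched like this. Plain numpy arrays stand in for the unbound nested-tensor components; all names are illustrative:

```python
import numpy as np

def field_dtypes(items):
    """Return (shared_dtype, None) when all items agree, else (None, per_sample_dtypes)."""
    dtypes = [item.dtype for item in items]
    if all(dt == dtypes[0] for dt in dtypes):
        return dtypes[0], None  # uniform: one dtype entry suffices
    return None, dtypes  # heterogeneous: record per-sample dtypes, like per_sample_shapes

uniform = [np.zeros(3, dtype=np.float32), np.zeros(2, dtype=np.float32)]
mixed = [np.zeros(3, dtype=np.int32), np.zeros(2, dtype=np.float32)]
assert field_dtypes(uniform) == (np.float32, None)
shared, per_sample = field_dtypes(mixed)
assert shared is None and per_sample == [np.int32, np.float32]
```

This mirrors the existing per_sample_shapes pattern: collapse to one entry only when every item agrees.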
```
per_field_dtypes[global_idx][field_name] = field_dtype
if is_nested:
    assert unbound is not None  # is_nested=True implies unbind() was called
    per_field_shapes[global_idx][field_name] = tuple(unbound[i].shape)
```

Lacks the per-field dtype logic.
```
# Pre-compute unbind once to avoid O(B²) repeated calls inside the loop
unbound = field_data.unbind() if is_nested else None
for i, global_idx in enumerate(metadata.global_indexes):
```

Should we only enter this loop when the field is nested, to speed things up?
```
def __post_init__(self):
    """Initialize all computed properties during initialization"""
    self.samples = copy.deepcopy(self.samples)
    self.global_indexes = copy.deepcopy(self.global_indexes)
```

We'd better remove these deepcopy calls from __post_init__, since they put much higher pressure on the controller -- the controller is now invoked for every single PUT/GET operation.

I suggest providing a dedicated method that deepcopies these variables, and actively calling that method for operations like chunk/union/reorder...
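The reviewer's suggestion — an explicit copy method instead of an implicit deepcopy in `__post_init__` — could look like this sketch. Class and method names (`detach`, the two-column layout) are illustrative assumptions, not the real API:

```python
import copy
from dataclasses import dataclass, field

@dataclass
class BatchMeta:
    # No deepcopy in __post_init__: plain PUT/GET construction stays cheap.
    global_indexes: list = field(default_factory=list)
    custom_meta: list = field(default_factory=list)

    def detach(self) -> "BatchMeta":
        """Explicitly copy mutable state; called only by mutating operations."""
        return BatchMeta(
            global_indexes=list(self.global_indexes),
            custom_meta=copy.deepcopy(self.custom_meta),
        )

    def reorder(self, order: list) -> "BatchMeta":
        # Mutating operations build a detached copy in the new order,
        # so the source BatchMeta is never aliased or modified.
        return BatchMeta(
            global_indexes=[self.global_indexes[i] for i in order],
            custom_meta=[copy.deepcopy(self.custom_meta[i]) for i in order],
        )

src = BatchMeta(global_indexes=[3, 1, 2], custom_meta=[{"a": 1}, {}, {}])
dst = src.reorder([1, 2, 0])
dst.custom_meta[2]["a"] = 99
assert src.custom_meta[0]["a"] == 1  # source is untouched
assert dst.global_indexes == [1, 2, 3]
```

The trade-off: construction on the controller's hot path does no copying, and only chunk/union/reorder pay the deepcopy cost.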
```
Args:
    fields: Field names used for getting data.
    local_indexes: Local indexes used for getting data.
    local_keys: Global indexes used as dict keys.
```

Do we still need local keys now?

They can directly be the global_indexes.
```
# global_index is used directly as dict key in storage, no local_index conversion needed
global_indexes_slice = metadata.global_indexes[start:end]
local_indexes = list(global_indexes_slice)
```

Unnecessary code and comments here.
```
self._prepare_and_send_to_unit(
    unit_idx=unit_idx,
    storage_id=storage_id,
    chunk_size=chunk_size,
    batch_size=batch_size,
    start_offset=0,  # fixed; no cross-batch rotation
    num_units=num_units,
    data=data,
    metadata=metadata,
```

There are too many input params; many of them can be simplified, e.g. chunk_size and num_units.

And I suggest providing a mapping function as in the earlier version, as it helps make the dispatching logic clearer.
| """ | ||
| Retrieve data from remote StorageUnit based on metadata. | ||
|
|
||
| Routes to each SU using the _su_id recorded in metadata._custom_backend_meta |
There was a problem hiding this comment.
Why? We don't need _custom_backend_meta to do routing
```
groups: dict[str, list[int]] = defaultdict(list)
for i, gi in enumerate(metadata.global_indexes):
    backend_meta = metadata._custom_backend_meta[i]
    if not backend_meta or "_su_id" not in backend_meta:
        raise RuntimeError(
            f"get_data: missing _su_id for global_index {gi} in _custom_backend_meta. "
            f"Make sure put_data was called before get_data."
        )
    groups[backend_meta["_su_id"]].append(gi)
```

This can be simplified by providing a mapping function.
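The "mapping function" the reviewer suggests could be a single helper shared by get_data and clear_data, replacing the duplicated grouping loops. A sketch under assumed names:

```python
from collections import defaultdict

def group_by_storage_unit(global_indexes, su_id_of):
    """Group global indexes by storage-unit id; su_id_of maps gi -> su_id or None."""
    groups = defaultdict(list)
    for gi in global_indexes:
        su_id = su_id_of(gi)
        if su_id is None:
            raise RuntimeError(f"missing _su_id for global_index {gi}")
        groups[su_id].append(gi)
    return dict(groups)

# Illustrative mapping: in the real code this would read _custom_backend_meta.
backend_meta = {0: "su-a", 1: "su-b", 2: "su-a"}
groups = group_by_storage_unit([0, 1, 2], backend_meta.get)
assert groups == {"su-a": [0, 2], "su-b": [1]}
```

Both get_data and clear_data would then call this one helper with their own error context, instead of each carrying its own loop.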
```
async def _get_from_single_storage_unit(
    self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None
    self,
    gi_list: list[int],
```

Spell it out fully as global_indexes. Use gi only in cases like [gi for gi in global_indexes].
```
groups: dict[str, list[int]] = defaultdict(list)
for i, gi in enumerate(metadata.global_indexes):
    backend_meta = metadata._custom_backend_meta[i]
    if not backend_meta or "_su_id" not in backend_meta:
        raise RuntimeError(
            f"clear_data: missing _su_id for global_index {gi} in _custom_backend_meta. "
            f"Make sure put_data was called before clear_data."
        )
    groups[backend_meta["_su_id"]].append(gi)
```
```
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_TENSOR = 3  # For tensor with buffer reference
CUSTOM_TYPE_NESTED_TENSOR = 4  # For nested tensor (strided or jagged)
CUSTOM_TYPE_BATCHMETA = 5  # For BatchMeta serialization
```

Is there any serialization performance comparison for BatchMeta? I'm wondering if pickle could be faster for BatchMeta objects.
```
else:
    return pickle.loads(frames[0])
# pickle fallback path: serialize() sets frame[0] to _PICKLE_FALLBACK_SENTINEL on failure.
if len(frames) >= 2 and frames[0] == _PICKLE_FALLBACK_SENTINEL:
```

Why do we have to handle this outside serial_utils.py?
```
if global_idx not in self.field_dtypes:
    self.field_dtypes[global_idx] = {}
self.field_dtypes[global_idx].update(dtype_value[i])
# Update field_schema_cache with new dtype info
```

What will happen if:

- We put a uniform tensor (same dtype & shape) in field A -> it only has one dtype and shape entry
- We then append tensors with a different dtype & shape to the same field A? Will we update the dtype and shape info to reflect the new changes?

A following question: if the user retrieves the data put in step 1, what will the BatchMeta look like? Will it say this is a nested or an ordinary tensor?
jianjunzhong left a comment:

Scripts in recipe/simple_use_case cannot run properly.
```
self.production_status = np.zeros(batch_size, dtype=np.int8)
for field_name, meta in self.field_schema.items():
    if meta.get("per_sample_shapes") is not None:
```

per_sample_shapes is only meaningful when is_nested=True, but the validation logic doesn't check this correlation.
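The missing check the reviewer describes could be added to schema validation roughly like this. The function name and exact error messages are illustrative assumptions:

```python
def validate_field_schema(field_schema: dict, batch_size: int) -> None:
    """Reject schemas where per_sample_shapes appears without is_nested=True."""
    for name, meta in field_schema.items():
        shapes = meta.get("per_sample_shapes")
        if shapes is not None:
            if not meta.get("is_nested", False):
                raise ValueError(f"field '{name}': per_sample_shapes set but is_nested is False")
            if len(shapes) != batch_size:
                raise ValueError(f"field '{name}': expected {batch_size} per-sample shapes, got {len(shapes)}")

ok = {"nested": {"is_nested": True, "per_sample_shapes": [(4, 3), (2, 4)]}}
validate_field_schema(ok, batch_size=2)  # passes silently

bad = {"x": {"is_nested": False, "per_sample_shapes": [(1,), (2,)]}}
try:
    validate_field_schema(bad, batch_size=2)
    raised = False
except ValueError:
    raised = True
assert raised
```

Running this from `__post_init__` would surface inconsistent schemas at construction time instead of during later storage operations.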
```
    return np.dtype(dtype_str)
except TypeError:
    pass
# Fallback: return as-is (e.g. plain Python type repr like "<class 'int'>")
```

Should record a warning: `logger.warning(f"Unknown dtype string '{dtype_str}', returning as-is")`
```
def test_storage_unit_data_dict_key():
    """StorageUnitData dict-key: gi 直接作为 key,clear 真正释放内存."""
```

It's better to use English instead of Chinese.

Columnar BatchMeta + Zero-Copy Default
1. Context & Motivation
Closes: [refactor] Convert BatchMeta from row-oriented to column-oriented layout
The current BatchMeta uses a row-oriented design (BatchMeta → List[SampleMeta] → Dict[str, FieldMeta]), which introduces scaling issues in high-throughput scenarios: hot paths (build_storage_meta_groups, add_fields, _filter_storage_data) involve nested loops over every sample × every field, incurred multiple times per PUT.

This PR refactors BatchMeta to a column-oriented (structure-of-arrays) design, reducing metadata complexity from O(B×F) to O(B) + O(F), and enables zero-copy serialization by default with automatic pickle fallback.

2. Key Changes
2.1 Columnar BatchMeta (metadata.py)

- BatchMeta.samples: List[SampleMeta] is replaced by parallel columns: global_indexes, partition_ids, production_status
- Per-sample FieldMeta objects (B×F instances) are replaced by a single field_schema dict (F entries)
- Batch readiness is checked with np.all() on an ndarray instead of a per-sample scan
- The public surface shrinks from BatchMeta, SampleMeta, FieldMeta to BatchMeta only; the SampleMeta and FieldMeta classes are removed entirely
- The field_schema dict supports three field types: Regular Tensor, Nested Tensor (is_nested), Non-Tensor (is_non_tensor)
- production_status is stored as an np.ndarray (int8), enabling O(1) readiness checks via np.all()
serial_utils.py)ZERO_COPY_SERIALIZATIONenvironment variable switch2.3 Storage & Transport Adaptation
- simple_backend.py / simple_backend_manager.py / controller.py: adapted to the columnar API; clear() uses del instead of None assignment to reduce memory fragmentation
- zmq_utils.py: ZMQ transport uses the new serialization utilities; frame count reduced from O(B) to F+1 (one metadata header + one frame per field)

2.4 Test Suite
- test_metadata.py: fully rewritten for the columnar API (net -799 lines)
- Remaining tests adapted to the new BatchMeta constructor

3. Benchmark Results
Tests were conducted in Docker (single-node Ray) across 7 payload sizes, comparing three configurations.
Throughput Comparison (Gbps)
Speedup vs Baseline (main-no-zerocopy)
Visualization
Resource Usage
Columnar layout reduces CPU time by eliminating per-sample object creation and pickle overhead:
4. API Breaking Changes
- BatchMeta.samples (List[SampleMeta]): removed
- SampleMeta class: removed
- FieldMeta class: removed
- sample.fields['x'].dtype → batch.field_schema['x']['dtype']
- BatchMeta(samples=[...]) → BatchMeta(global_indexes=..., partition_ids=..., field_schema=..., production_status=...)
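A small migration sketch for the breaking change above, with plain dicts standing in for the real classes (the old row-oriented access is shown only in comments, since those classes no longer exist):

```python
# Old (row-oriented) access, per the breaking-change list:
#   dtype = batch.samples[0].fields["input_ids"].dtype
# New (columnar) access: one schema entry per field, regardless of batch size.
batch = {
    "global_indexes": [0, 1, 2],
    "field_schema": {"input_ids": {"dtype": "torch.int64", "shape": (512,)}},
}
dtype = batch["field_schema"]["input_ids"]["dtype"]
assert dtype == "torch.int64"
assert len(batch["field_schema"]) == 1  # F entries, not B×F
```

Callers that previously iterated samples to read field metadata should now read field_schema once per batch.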
- metadata.py
- serial_utils.py, zmq_utils.py
- simple_backend.py, simple_backend_manager.py, base.py
- controller.py
- test_metadata.py + 7 test files
- put_benchmark.py

6. Conclusion
The columnar BatchMeta refactoring, combined with default zero-copy serialization, delivers O(B) + O(F) metadata complexity, higher transport throughput, and lower CPU usage.