Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 9 additions & 84 deletions egomimic/rldb/zarr/zarr_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def write(
self,
numeric_data: dict[str, np.ndarray] | None = None,
image_data: dict[str, np.ndarray] | None = None,
pre_encoded_image_data: dict[str, tuple[np.ndarray, list[int]]] | None = None,
metadata_override: dict[str, Any] | None = None,
) -> None:
"""
Expand All @@ -339,9 +338,6 @@ def write(
All arrays must have same length along axis 0.
image_data: Dictionary of image arrays with shape (T, H, W, 3).
Images will be JPEG-compressed.
pre_encoded_image_data: Dictionary mapping key to (encoded_array, image_shape).
encoded_array is np.ndarray(dtype=object) of JPEG bytes.
image_shape is [H, W, 3]. Skips internal JPEG encoding.
metadata_override: Optional metadata overrides to apply after building metadata.

Raises:
Expand All @@ -350,26 +346,18 @@ def write(
"""
numeric_data = numeric_data or {}
image_data = image_data or {}
pre_encoded_image_data = pre_encoded_image_data or {}

if not numeric_data and not image_data and not pre_encoded_image_data:
raise ValueError(
"Must provide at least one of numeric_data, image_data, "
"or pre_encoded_image_data"
)
if not numeric_data and not image_data:
raise ValueError("Must provide at least one of numeric_data or image_data")

# Infer total_frames from data
all_lengths: list[int] = []
for arr in numeric_data.values():
all_lengths.append(len(arr))
for arr in image_data.values():
all_lengths = []
for key, arr in {**numeric_data, **image_data}.items():
all_lengths.append(len(arr))
for enc_arr, _shape in pre_encoded_image_data.values():
all_lengths.append(len(enc_arr))

if len(set(all_lengths)) > 1:
raise ValueError(
f"Inconsistent frame counts across arrays: {all_lengths}"
f"Inconsistent frame counts across arrays: {dict(zip(numeric_data.keys() | image_data.keys(), all_lengths))}"
)

self.total_frames = all_lengths[0]
Expand All @@ -392,16 +380,10 @@ def write(
for key, arr in numeric_data.items():
self._write_numeric_array(store, key, arr, padded_frames)

# Write image arrays (with internal JPEG encoding)
# Write image arrays
for key, arr in image_data.items():
self._write_image_array(store, key, arr, padded_frames)

# Write pre-encoded image arrays (skip JPEG encoding)
for key, (enc_arr, img_shape) in pre_encoded_image_data.items():
self._write_pre_encoded_image_array(
store, key, enc_arr, img_shape, padded_frames
)

# Write language annotations if provided
if self.annotations is not None:
self._write_annotations(store, self.annotations)
Expand Down Expand Up @@ -562,61 +544,6 @@ def _write_image_array(
"names": ["height", "width", "channel"],
}

def _write_pre_encoded_image_array(
self,
store: zarr.Group,
key: str,
encoded_arr: np.ndarray,
image_shape: list[int],
padded_frames: int,
) -> None:
"""
Write already-JPEG-encoded image data to the Zarr store.

Pads to ``padded_frames`` by repeating the final JPEG frame, stores one
encoded frame per chunk, and records the key in ``self._features`` with
dtype ``"jpeg"``.

Args:
store: Zarr group to write to.
key: Array key name.
encoded_arr: Object array of JPEG bytes with shape (T,).
image_shape: Original image dimensions [H, W, 3].
padded_frames: Target frame count after padding.
"""
num_frames = len(encoded_arr)

# Pad by duplicating the last encoded frame so the stored array always
# has exactly padded_frames entries. Each slot holds the same bytes
# object; frames are never decoded here.
if padded_frames > num_frames:
padded = np.empty((padded_frames,), dtype=object)
padded[:num_frames] = encoded_arr
last_jpeg = encoded_arr[-1]
for i in range(num_frames, padded_frames):
padded[i] = last_jpeg
encoded_arr = padded

# One frame per chunk: each JPEG blob stays independently addressable.
chunk_shape = (1,)

if self.enable_sharding:
# A single shard spanning the whole array groups the per-frame
# chunks into one storage object.
shard_shape = encoded_arr.shape
store.create_array(
key,
shape=encoded_arr.shape,
chunks=chunk_shape,
shards=shard_shape,
dtype=VariableLengthBytes(),
)
else:
store.create_array(
key,
shape=encoded_arr.shape,
chunks=chunk_shape,
dtype=VariableLengthBytes(),
)

# Bulk-assign the (possibly padded) encoded frames into the new array.
store[key][:] = encoded_arr

# Feature metadata records the decoded image geometry, not the
# variable-length JPEG payload size.
self._features[key] = {
"dtype": "jpeg",
"shape": image_shape,
"names": ["height", "width", "channel"],
}

def _write_annotations(
self, store: zarr.Group, annotations: list[tuple[str, int, int]]
) -> None:
Expand Down Expand Up @@ -703,7 +630,6 @@ def create_and_write(
episode_path: str | Path,
numeric_data: dict[str, np.ndarray] | None = None,
image_data: dict[str, np.ndarray] | None = None,
pre_encoded_image_data: dict[str, tuple[np.ndarray, list[int]]] | None = None,
embodiment: str = "",
fps: int = 30,
task: str = "",
Expand All @@ -719,8 +645,6 @@ def create_and_write(
episode_path: Path to episode .zarr directory.
numeric_data: Dictionary of numeric arrays (state, actions, etc.).
image_data: Dictionary of image arrays with shape (T, H, W, 3).
pre_encoded_image_data: Dict mapping key to (encoded_array, image_shape).
Skips internal JPEG encoding for these keys.
embodiment: Robot type identifier.
fps: Frames per second (default: 30).
task: Task description.
Expand All @@ -733,8 +657,9 @@ def create_and_write(
Path to created episode.

Raises:
ValueError: If no data is provided.
ValueError: If neither numeric_data nor image_data is provided.
"""
# Create writer
writer = ZarrWriter(
episode_path=episode_path,
embodiment=embodiment,
Expand All @@ -745,10 +670,10 @@ def create_and_write(
enable_sharding=enable_sharding,
)

# Write data
writer.write(
numeric_data=numeric_data,
image_data=image_data,
pre_encoded_image_data=pre_encoded_image_data,
metadata_override=metadata_override,
)

Expand Down
170 changes: 0 additions & 170 deletions external/scale/scripts/scale_api.py

This file was deleted.

Loading