diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index cd6eee42..cd861b04 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -21,7 +21,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v5 @@ -38,7 +38,7 @@ jobs: just docs - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: path: docs/build diff --git a/.github/workflows/license_headers.yml b/.github/workflows/license_headers.yml index d6c7dc61..f16e801b 100644 --- a/.github/workflows/license_headers.yml +++ b/.github/workflows/license_headers.yml @@ -15,9 +15,9 @@ jobs: steps: - name: Checkout Repository - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: 3.9 - name: Install dependencies diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e88d1c91..c3475653 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,7 +15,7 @@ jobs: id-token: write # This permission is mandatory for trusted publishing steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v5 @@ -32,4 +32,4 @@ jobs: just build - name: Publish package - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@release/v1.13 # release/v1.13 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 97de837b..287f770a 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d0e18016..bf5c18bd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v5 diff --git a/docs/source/advanced/crude_datasets.md b/docs/source/advanced/crude_datasets.md index dcb091f8..18e55015 100644 --- a/docs/source/advanced/crude_datasets.md +++ b/docs/source/advanced/crude_datasets.md @@ -74,6 +74,34 @@ All it does, is to forward the key, restore key and flavors from the dict to the In a real use-case you will want to do a lot more here and we recommend keeping the cook methods in separate files and importing them where you define your TaskEncoder. +### Using Media Metadata in Monolithic Datasets +If you prepared your dataset with media metadata, you can access it in the cooker using the {py:meth}`get_media_metadata ` method of the {py:class}`FileStore `. This will be faster than decoding the media files to infer the media metadata like video duration or image size. +Here we are looking at the use case for primary datasets, but for [auxiliary datasets](aux-data), this is even more beneficial. + +#### Metadata from the Primary Dataset + +```python +from megatron.energon import FileStore + +# ... 
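+# (The `# ...` above stands for the remaining imports. Besides FileStore, this
+# example also uses TextSample, basic_sample_keys, stateless and cooker, which
+# are assumed to be importable from megatron.energon.)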
+
+@stateless
+@cooker(need_primary=True)
+def cook_media_metadata(sample: dict, primary: FileStore) -> TextSample:
+    """This cooker loads the media metadata for the current sample
+    and returns a text sample with the metadata description."""
+
+    # The `.` prefix is used to access the media metadata of the primary dataset,
+    # referring to a part of the current sample.
+    primary_media_metadata = primary.get_media_metadata(".wav")
+
+    return TextSample(
+        **basic_sample_keys(sample),
+        text=f"This audio has duration {primary_media_metadata.audio_duration} seconds",
+    )
+```
+
+
 (aux-data)=
 ## Auxiliary Data for Polylithic Datasets
@@ -165,6 +193,35 @@ def cook_text(sample: dict, primary: FileStore, foo_bar_source: FileStore) -> Te
 You can then retrieve files by their names from the primary dataset.
+
+### Using Media Metadata in Polylithic Datasets
+If you prepared your auxiliary dataset with media metadata, you can access it in the cooker using the {py:meth}`get_media_metadata ` method of the {py:class}`FileStore `.
+This is much faster than reading the media files themselves to infer the media metadata like video duration or image size.
+This is especially useful if you are working with Lazy objects: you can defer loading the media files entirely until you actually need them,
+for example in {py:meth}`postencode_sample(self, sample: T_sample) -> T_encoded_sample ` when using packing.
+
+```python
+from megatron.energon import FileStore
+
+# ...
+
+def cook_media_metadata(sample: dict, foo_bar_source: FileStore) -> TextSample:
+    # Use the image filename from the primary sample to get the media metadata from the auxiliary dataset
+    media_metadata = foo_bar_source.get_media_metadata(sample['image'])
+
+    return TextSample(
+        **basic_sample_keys(sample),
+        text=f"This image has size {media_metadata.width}x{media_metadata.height} and format {media_metadata.format}",
+    )
+
+```
+
+The path of the media file within the auxiliary dataset is what needs to be passed to the `get_media_metadata` method, for example `0001.jpg`, or `images/0001.jpg` if the file is inside a subfolder of the WebDataset or filesystem dataset.
+
+The return value will be either {py:class}`AVMetadata ` or {py:class}`ImageMetadata `.
+Click on them to see the fields and their types.
+
+
 (cache-pools)=
 ## Cache Pools
diff --git a/docs/source/advanced/data_prep_api.md b/docs/source/advanced/data_prep_api.md
new file mode 100644
index 00000000..8768f7ea
--- /dev/null
+++ b/docs/source/advanced/data_prep_api.md
@@ -0,0 +1,70 @@
+
+
+# Programmatic Data Preparation
+
+Below, we provide an example function that shows how to prepare a dataset programmatically.
+
+```python
+import shutil
+from pathlib import Path
+
+import click
+
+from megatron.energon.flavors import BaseWebdatasetFactory
+
+def prepare_one_dataset(path: Path, num_workers: int, template_dir: Path):
+
+    if (path / ".nv-meta" / "dataset.yaml").exists():
+        print(f"Dataset {path} already prepared. Skipping.")
+        return
+
+    # Fixed settings
+    tar_index_only = False
+    split_parts_ratio = [("train", 1), ("val", 0), ("test", 0)]
+    split_parts_patterns = None
+
+    # Get all tar files
+    all_tars = list(path.glob("**/*.tar")) + list(path.glob("**/*.tgz"))
+    all_tars = [str(p.relative_to(path)) for p in sorted(all_tars)]
+
+    if len(all_tars) == 0:
+        print("Did not find any tar files. Exiting.")
+        return
+
+    print(f"Found {len(all_tars)} tar files in total. 
The first and last ones are:") + print(f"- {all_tars[0]}") + print(f"- {all_tars[-1]}") + + def progress_fn(els, length=None): + with click.progressbar( + els, + label="Indexing shards", + show_pos=True, + length=length, + ) as bar: + for el in bar: + yield el + + found_types, duplicates = BaseWebdatasetFactory.prepare_dataset( + path, + all_tars, + split_parts_ratio=split_parts_ratio, + split_parts_patterns=split_parts_patterns, + progress_fn=progress_fn, + tar_index_only=tar_index_only, + shuffle_seed=None, + workers=num_workers, + ) + + # Copy sample loader and dataset.yaml templates + for file in template_dir.glob("*"): + shutil.copy(file, path / ".nv-meta" / file.name) +``` + +Example usage: + +First, create a template directory with the `dataset.yaml` file, and optionally the `sample_loader.py` file. +Let's call it `template_dir`. + +Then, run the script: + +```python +if __name__ == "__main__": + prepare_one_dataset(Path("/path/to/dataset"), 16, Path("/path/to/template_dir")) +``` diff --git a/docs/source/advanced/error_handling.md b/docs/source/advanced/error_handling.md new file mode 100644 index 00000000..7d208d44 --- /dev/null +++ b/docs/source/advanced/error_handling.md @@ -0,0 +1,159 @@ + + +# Error Handling + +Energon provides configurable error handling to gracefully manage exceptions during data loading and processing. +By default, errors in task encoder code are logged and the sample is skipped, allowing training to continue. +Some system exceptions (clearly programming errors) always cause immediate failure. + +By default, if more than 100 samples in a row raise an error, processing will also fail. This is configurable for the +{py:class}`~megatron.energon.task_encoder.base.TaskEncoder` user-functions by setting the `@stateless(failure_tolerance=100)` parameter. + +## Error Handler Configuration + +Error handlers are configured via {py:class}`WorkerConfig `: + +```python +from megatron.energon import WorkerConfig, log_exception, reraise_exception + +worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=4, + global_error_handler=log_exception, # Default: logs and skips failed samples + restore_error_handler=reraise_exception, # Default: fails fast on restore +) +``` + +**`global_error_handler`:** Handles exceptions during normal iteration (training/validation). +Default is {py:func}`log_exception ` which logs details and skips the sample. + +**`restore_error_handler`:** Handles exceptions during checkpoint restoration. +Default is {py:func}`reraise_exception ` which fails fast to ensure checkpoint integrity. + +## System Exceptions + +Programming errors and critical system issues always cause immediate failure and are never handled by error handlers: +`SystemError`, `SyntaxError`, `ImportError`, `StopIteration`, `StopAsyncIteration`, `MemoryError`, `RecursionError`, `ReferenceError`, `NameError`, `UnboundLocalError`, and {py:exc}`FatalSampleError `. +{py:exc}`FatalSampleError ` is raised automatically when consecutive failure tolerance is exceeded or when a system exception occurs during sample processing. 
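+
+For illustration, here is a minimal sketch (the task encoder and its check are made up, not part of the energon API) of which path an exception takes:
+
+```python
+from megatron.energon import DefaultTaskEncoder, stateless
+
+class MyTaskEncoder(DefaultTaskEncoder):
+    @stateless
+    def encode_sample(self, sample):
+        # A regular exception like this ValueError is passed to the configured
+        # global_error_handler; with the default log_exception, the error is
+        # logged and the sample is skipped.
+        if sample.text is None:
+            raise ValueError("Sample has no text")
+        # A typo such as `sampel.text` would raise a NameError instead, which is
+        # a system exception: it bypasses the error handler and fails immediately
+        # (surfacing as a FatalSampleError).
+        return sample
+```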
+ +## Built-in Error Handlers + +### `log_exception` + +Logs detailed error information and continues: +- Exception traceback +- Source information (dataset path, shard, index) +- Sample details in readable format + +```python +from megatron.energon import log_exception + +worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=4, + global_error_handler=log_exception, +) +``` + +### `reraise_exception` + +Immediately reraises the exception to halt iteration: + +```python +from megatron.energon import reraise_exception + +worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=4, + global_error_handler=reraise_exception, # Fail on any error +) +``` + +### Custom Error Handlers + +Implement custom error handlers with this signature: + +```python +def my_error_handler( + exception: Exception, + sample: Any, + sources: list[SourceInfo] | None +) -> None: + # Log to your monitoring system + log_to_monitoring(exception, sample) + + # Optionally reraise for critical errors + if isinstance(exception, CriticalError): + raise exception +``` + +```python +worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=4, + global_error_handler=my_error_handler, +) +``` + +## Failure Tolerance for Task Encoder Functions + +By default, if more than 100 samples in a row raise an error, processing will fail with a {py:exc}`FatalSampleError `. + +For {py:class}`TaskEncoder ` methods, configure this via the `@stateless` decorator: + +```python +from megatron.energon import DefaultTaskEncoder, stateless + +class MyTaskEncoder(DefaultTaskEncoder): + @stateless(failure_tolerance=50) + def encode_sample(self, sample): + # Process sample - tolerates up to 50 consecutive failures + if sample.is_corrupted(): + raise ValueError("Corrupted sample") + return sample + + @stateless(restore_seeds=True, failure_tolerance=200) + def pack_selected_samples(self, samples): + # Packing with higher tolerance and deterministic randomness + return pack_samples(samples) +``` + +Set `failure_tolerance=0` to disable tolerance checking for a specific function. + +```{admonition} Note +:class: important +Tolerance limits count *consecutive* failures. A single successful sample resets the counter. +``` + +## Skip or Fail Explicitly + +Raise {py:exc}`SkipSample ` to explicitly skip a sample without logging an error: + +```python +from megatron.energon import SkipSample + +def process_sample(sample): + try: + ... + except MySpecificError: + raise SkipSample() + return sample +``` + +Raise {py:exc}`FatalSampleError ` to cause immediate failure, bypassing the error handler: + +```python +from megatron.energon import FatalSampleError + +def process_sample(sample): + try: + ... + except MyFatalError as e: + raise FatalSampleError.from_sample(sample, "Critical corruption detected") from e + return sample +``` diff --git a/docs/source/api/modules_data.md b/docs/source/api/modules_data.md index 873c9516..29510cc7 100644 --- a/docs/source/api/modules_data.md +++ b/docs/source/api/modules_data.md @@ -24,3 +24,13 @@ SPDX-License-Identifier: BSD-3-Clause --> :undoc-members: :show-inheritance: ``` + + +# megatron.energon.media + +```{eval-rst} +.. 
automodule:: megatron.energon.media + :members: + :undoc-members: + :show-inheritance: +``` diff --git a/docs/source/basic/data_prep.md b/docs/source/basic/data_prep.md index a0fc1c97..73629696 100644 --- a/docs/source/basic/data_prep.md +++ b/docs/source/basic/data_prep.md @@ -16,14 +16,13 @@ Depending on what your data looks like and how you are planning to use it, you w **Monolithic Dataset vs. Polylithic (primary and auxiliary) Datasets** -You can include the media (images/video/audio) inside the same webdataset along with the text and metadata of each sample. +You can include the media (images/video/audio) inside the same webdataset along with the text-based data of each sample (such as labels, captions, etc.). Or you can keep the media separate (either in another indexed webdataset or as individual files on disk). When using JSONL, the media will always be separate, so JSONL datasets are always polylithic unless they are text-only. -If you can, you should go for the monolithic option, because it's faster to load. -However, there are a few reasons why the other option may be needed: +The monolithic option is faster to load. However, there are a few reasons why the other option may be preferable: -* You need to keep the original media and you don't want to duplicate it +* You need to keep the original media files and you don't want to duplicate them in the tar files * Your media data is very large (e.g. long videos) and you need to keep your primary dataset small (containing just the text-based data and meta information) * You want to re-use the same media with different labels or you want to train on different subsets * You want to train with [online packing](../advanced/packing.md) and can't fit all the media of the packing buffer in memory. With polylithic datasets you can use caching to avoid that issue. @@ -63,13 +62,14 @@ These are the typical steps to get your data ready: (polylithic-dataset)= ## Steps to Create a Polylithic Dataset -1. Create the primary [WebDataset](https://github.com/webdataset/webdataset) or JSONL file from your text-based part of the data (meta information, labels, sizes etc.) +1. Create the primary [WebDataset](https://github.com/webdataset/webdataset) or JSONL file from your text-based part of the data (meta information, labels etc.) * Include the file names (don't use absolute paths) of the media that belongs to each sample (e.g. as strings inside a json entry) 2. Create the auxiliary dataset(s). Can be multiple datasets, e.g. one per modality. * Either as a folder on disk with all the media files inside * Or as another WebDataset that contains just the media files (with the exact same names) 3. Run our preparation tool `energon prepare` **on both datasets** (yes also on the JSONL) to convert to an energon-compatible format * Configure both datasets as `CrudeWebdataset` (JSONL always is by default) + * For the auxiliary datasets, we recommend to enable the [media metadata feature](media-metadata) to store additional information about the media (like image size, resolution, video duration etc.) 4. Create a [metadataset](../basic/metadataset) that specifies what auxiliary data to load for each primary dataset * For more details read about [crude data](crude-data) @@ -211,6 +211,53 @@ The command will * Ask you how to decode the data if not using crude data (field map or sample_loader.py) * Store all this information in a subfolder `.nv-meta/`, see details [below](data-on-disk). 
+(media-metadata)= +### Media Metadata + +If you are preparing a dataset with media files, energon can retrieve and store additional information about the media (like image size, resolution, video duration etc.). +This information will be stored inside an SQLite database file next to the dataset. +Later, inside the [cooker](crude-data), you can access this information using the {py:meth}`get_media_metadata ` method of the {py:class}`FileStore `. + +#### During normal initial preparation of a WebDataset + +```sh +> energon prepare --media-metadata-by-extension /path/to/dataset +``` + +#### Adding media metadata to an existing dataset +```sh +> energon prepare-media --media-metadata-by-extension /path/to/dataset +``` + +```{admonition} Good to know +:class: tip +That also works for filesystem datasets. I.e. you can run `energon prepare-media` on a normal folder with media files and it will create the media metadata database file next to the dataset. +``` + +#### Customizing the selection of media files + +You can customize the selection of media files by using the `--media-metadata-by-glob`, `--media-metadata-by-header` and `--media-metadata-by-extension` options. +You must specify exactly one of the options. + +To select media files by our default extension list (recommended), you can use the `--media-metadata-by-extension` option. +```sh +> energon prepare --media-metadata-by-extension /path/to/dataset +``` + +The list can be found in the [extractor.py](https://github.com/NVIDIA/Megatron-Energon/blob/develop/src/megatron/energon/media/extractor.py) file. + +To select all media files with the extensions `.jpg`, `.png` and `.webp`, you can use the following command: +```sh +> energon prepare --media-metadata-by-glob '*.jpg,*.png,*.webp' /path/to/dataset +``` + +To select media files by reading their contents/header, you can use the `--media-metadata-by-header` option. +```sh +> energon prepare --media-metadata-by-header /path/to/dataset +``` +Note that this option may be slower than the other options, as it needs to read the contents of the files. + + ### Splitting the dataset into train/val/test The first thing that the `energon prepare` assistant will ask you, is how you want to split the data by ratios. @@ -581,7 +628,8 @@ The order of tar files is important, as it's used by the sqlite database below. #### index.sqlite and index.uuid (read-only) The sqlite database was introduced in Energon 7 and allows for fully random access of samples and files by their names. -This is a precondition for polylithic datasets and for the [`energon mount`](energon-mount) command. +This is a precondition for [polylithic datasets](aux-data) and for the [`energon mount`](energon-mount) command. +Later, media metadata was added to the database to allow for fast access to the media metadata (audio length, image resolution etc.) of the media files. Below there is some detailed information for the interested reader. Note that the internal table structure can change in any release without notice. @@ -616,6 +664,28 @@ directly access the content without parsing the tar header. Both tables can be joined over the `tar_file_id` and the `sample_index`. Note that the `tar_file_id` refers to the list of tar files in the `.info.json` file. +Since version 8 of Energon, `media_filters` and `media_metadata` tables are added to the database: + +The filters table is used to store the media filters that were used to select the media files. 
+The media metadata will be stored *for the union* of all the media files that were selected by the filters. + +| filter_id | strategy | patterns | created_at_utc | +| --- | --- | --- | --- | +| 1 | EXTENSION | | 2025-01-01 12:00:00 | +| 2 | GLOB | \*.jpg,\*.png,\*.webp | 2025-01-01 12:00:00 | +| 3 | HEADER | | 2025-01-01 12:00:00 | +| 4 | ... | ... | ... | + + +The `media_metadata` table is used to store the media metadata for the selected media files: + +| entry_key | metadata_type | metadata_json | +| --- | --- | --- | +| 00000.jpg | image | {"width": 1024, "height": 768, "format": "jpg"} | +| 00001.wav | av | {"audio_duration": 39.0 , "audio_channels": 1, "audio_sample_rate": 16000} | +| 00002.mp4 | ... | ... | + + (data-on-disk-jsonl)= ## Dataset Format on Disk for JSONL Datasets @@ -626,3 +696,12 @@ So if your dataset is named `my_dataset.jsonl`, a new file `my_dataset.jsonl.idx That's all. The dataset type will always be `CrudeWebdataset` and the split part is `train` by default. However, when loading the dataset you can change the split type to `val` or `test`. + +(data-on-disk-filesystem)= +## Dataset Format on Disk for Filesystem Datasets + +Filesystem datasets are datasets that are stored on disk as individual files in a folder. +They are not indexed and cannot be accessed randomly. They are only used as auxiliary datasets. + +They can be used without an `.nv-meta` folder, but if you run `energon prepare-media` on them, an sqlite database file will be created inside an `.nv-meta` folder. +The database will contain just the `media_filters` and `media_metadata` tables as explained above. diff --git a/docs/source/basic/quickstart.md b/docs/source/basic/quickstart.md index 622089ac..c817e4b4 100644 --- a/docs/source/basic/quickstart.md +++ b/docs/source/basic/quickstart.md @@ -147,6 +147,8 @@ Let's also talk about the {py:class}`WorkerConfig ` to see how the worker config is constructed. Also don't be afraid to click the *`[source]`* link and look at the very short source code of it. +The worker config also controls error handling behavior. See [](../advanced/error_handling) for details on customizing error handlers. + ## Tutorial 3: Batch Size Actually, we would like to use a `batch_size` of more than one, let's go with 2 for now. diff --git a/docs/source/basic/save_restore.md b/docs/source/basic/save_restore.md index 2c9792b1..372637aa 100644 --- a/docs/source/basic/save_restore.md +++ b/docs/source/basic/save_restore.md @@ -140,6 +140,14 @@ state = torch.load('dataloader_state.pth') loader.restore_state_global(state, src_rank=None) ``` +## Error Handling During Restore + +By default, energon uses a strict error handler during checkpoint restoration to ensure data integrity. +If sample restoration fails, the exception is immediately raised. + +You can customize this behavior via the `restore_error_handler` parameter in {py:class}`WorkerConfig `. +See [](../advanced/error_handling) for details. + ## Summary In each of these scenarios, ensure that the logic for saving and restoring the state is appropriately synchronized across ranks to maintain consistency. 
diff --git a/docs/source/index.md b/docs/source/index.md index c1cf490a..e41831fb 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -57,6 +57,8 @@ advanced/subsets advanced/epochized_blending advanced/custom_blending advanced/parallelism +advanced/error_handling +advanced/data_prep_api ``` diff --git a/pyproject.toml b/pyproject.toml index 192fc68e..793fb161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,8 @@ classifiers = [ dependencies = [ "braceexpand", "click", - "dataslots; python_version<'3.10'", "mfusepy", - "multi-storage-client>=0.18.0,<0.26.0", + "multi-storage-client>=0.33.0", "numpy", "pillow>=10.0.1", # WEBP vulnerability fixed starting from 10.0.1 "pyyaml", @@ -49,6 +48,7 @@ dev = [ transforms = [ "torchvision", # Needed for megatron.energon.transforms ] + # Storage services for MSC s3 = [ "multi-storage-client[boto3]", @@ -65,6 +65,9 @@ google-cloud-storage = [ oci = [ "multi-storage-client[oci]", ] +huggingface = [ + "multi-storage-client[huggingface]", +] # Dependencies for video decoding av_decode = [ @@ -79,6 +82,9 @@ av_decode = [ guess_content = [ "filetype>=1.0.0", ] +tar_patcher = [ + "numba", +] [project.urls] Homepage = "https://github.com/NVIDIA/Megatron-Energon" diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index 1a54c389..0d92c702 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -20,9 +20,11 @@ Lazy, MockLazy, NoCachePool, + PrimaryFileStore, SystemFileStore, ) from megatron.energon.edataclass import edataclass +from megatron.energon.errors import log_exception, reraise_exception from megatron.energon.flavors import ( BaseCoreDatasetFactory, BaseWebdatasetFactory, @@ -168,6 +170,7 @@ "Lazy", "LimitDataset", "load_dataset", + "log_exception", "LogSampleDataset", "MapDataset", "Metadataset", @@ -180,8 +183,10 @@ "OCRSample", "OCRWebdataset", "PackingDataset", + "PrimaryFileStore", "prepare_metadataset", "RepeatDataset", + "reraise_exception", "Sample", "SampleDecoder", "SavableDataLoader", diff --git a/src/megatron/energon/av/__init__.py b/src/megatron/energon/av/__init__.py index 60c333d5..741970d8 100644 --- a/src/megatron/energon/av/__init__.py +++ b/src/megatron/energon/av/__init__.py @@ -2,16 +2,11 @@ # SPDX-License-Identifier: BSD-3-Clause from megatron.energon.av.av_data import AVData -from megatron.energon.av.av_decoder import ( - AVDecoder, - AVMetadata, - AVWebdatasetDecoder, -) +from megatron.energon.av.av_decoder import AVDecoder, AVWebdatasetDecoder from megatron.energon.av.utils import get_clips_uniform, get_single_frames_uniform __all__ = [ "AVDecoder", - "AVMetadata", "AVWebdatasetDecoder", "AVData", "get_clips_uniform", diff --git a/src/megatron/energon/av/av_decoder.py b/src/megatron/energon/av/av_decoder.py index ee6555b9..4170443a 100644 --- a/src/megatron/energon/av/av_decoder.py +++ b/src/megatron/energon/av/av_decoder.py @@ -3,23 +3,24 @@ import io import warnings -from collections.abc import Iterator from fractions import Fraction from typing import BinaryIO, Literal, Optional, Sequence, Union, overload import numpy as np import torch -from megatron.energon.av.av_init import av_open -from megatron.energon.edataclass import edataclass - # Try importing optional dependencies try: import av import av.container import filetype - from megatron.energon.av.fastseek import Fastseek + from megatron.energon.av.fastseek.reader import ( + Fastseek, + FastseekReader, + FastseekReaderByFrames, + FastseekReaderByPts, + ) AV_DECODE_AVAILABLE = True except 
ImportError as e: @@ -27,6 +28,8 @@ MISSING_DEPENDENCY = str(e) from megatron.energon.av.av_data import AVData +from megatron.energon.av.av_init import av_open +from megatron.energon.media import AVMetadata class AVDecoder: @@ -37,7 +40,7 @@ class AVDecoder: decoding parameters. """ - seeker: "Fastseek" + seeker: "Fastseek | None" = None stream: BinaryIO suppress_warnings: bool @@ -55,13 +58,11 @@ def __init__(self, stream: BinaryIO, suppress_warnings: bool = False) -> None: "Stream must not be opened in text mode" ) - try: + def _ensure_seeker(self) -> None: + if self.seeker is None: + self.stream.seek(0) self.seeker = Fastseek(self.stream) - except ValueError: self.stream.seek(0) - self.seeker = Fastseek(self.stream, probe=True) - - self.stream.seek(0) def get_video(self) -> AVData: """Get the entire video data from the stream (without audio).""" @@ -83,7 +84,8 @@ def get_video_clips( """Get video clips from the video stream. Args: - video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit) + video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit). + The end is inclusive! For each range at least one frame is returned. video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp) video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size @@ -93,6 +95,8 @@ def get_video_clips( - video_clips_timestamps: List of timestamps for each video clip start and end in seconds """ + self._ensure_seeker() + assert video_unit in ("frames", "seconds") self.stream.seek(0) # Reset the video stream so that pyav can read the entire container @@ -105,14 +109,19 @@ def get_video_clips( video_stream = input_container.streams.video[0] # Pre-calculate timing info for video - average_rate: Fraction = video_stream.average_rate # Frames per second + # Frames per second + average_rate: Fraction = video_stream.average_rate assert average_rate, "Video stream has no FPS." - time_base: Fraction = video_stream.time_base # Seconds per PTS unit + # Seconds per PTS unit + time_base: Fraction = video_stream.time_base - if video_clip_ranges is not None: - # Convert video_clip_ranges to seeker unit - if video_unit == "frames" and self.seeker.unit == "pts": + reader: FastseekReader + # Convert video_clip_ranges to seeker unit + if video_unit == "frames": + if self.seeker.frame_index_supported: + reader = FastseekReaderByFrames(self.seeker, input_container) + elif self.seeker.pts_supported: # Convert from frames to pts units video_clip_ranges = [ ( @@ -121,14 +130,29 @@ def get_video_clips( ) for clip in video_clip_ranges ] + reader = FastseekReaderByPts(self.seeker, input_container) + video_unit = "seconds" if not self.suppress_warnings: warnings.warn( "Video container unit is frames, but seeking in time units. 
The resulting frames may be slightly off.", RuntimeWarning, ) - elif video_unit == "seconds" and self.seeker.unit == "frames": + else: + raise ValueError("Video container does not support seeking in frames or PTS") + elif video_unit == "seconds": + if self.seeker.pts_supported: # Convert from seconds to frames + video_clip_ranges = [ + ( + clip[0] / time_base, + clip[1] / time_base, + ) + for clip in video_clip_ranges + ] + reader = FastseekReaderByPts(self.seeker, input_container) + elif self.seeker.frame_index_supported: + # Convert from frames to pts units video_clip_ranges = [ ( clip[0] * average_rate, @@ -136,88 +160,58 @@ def get_video_clips( ) for clip in video_clip_ranges ] + reader = FastseekReaderByFrames(self.seeker, input_container) + video_unit = "frames" + if not self.suppress_warnings: warnings.warn( - "Video container unit is time units, but seeking using frame number. The resulting frames may be slightly off.", + "Video container unit is seconds, but seeking only supports frames. The resulting frames may be slightly off.", RuntimeWarning, ) - elif video_unit == "seconds" and self.seeker.unit == "pts": - # Convert from seconds to pts units - video_clip_ranges = [ - (clip[0] / time_base, clip[1] / time_base) for clip in video_clip_ranges - ] - - frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0) - previous_frame_index: int = 0 + else: + raise ValueError("Video container does not support seeking in frames or PTS") + else: + raise ValueError(f"Invalid video unit: {video_unit!r}") video_clips_frames: list[list[torch.Tensor]] = [] video_clips_timestamps: list[tuple[float, float]] = [] for video_clip_range in video_clip_ranges: - start_frame_index, end_frame_index = video_clip_range + range_start, range_end = video_clip_range # Convert to int if possible, set end to None if infinite - start_frame_index = int(start_frame_index) - end_frame_index = int(end_frame_index) if end_frame_index != float("inf") else None + range_start = int(range_start) + range_end = int(range_end) if range_end != float("inf") else None clip_frames: list[torch.Tensor] = [] clip_timestamp_start = None clip_timestamp_end = None - # Find start frame - if ( - iframe_info := self.seeker.should_seek(previous_frame_index, start_frame_index) - ) is not None: - input_container.seek(iframe_info.pts, stream=input_container.streams.video[0]) - previous_frame_index = iframe_info.index - - for frame in frame_iterator: - take_frame = False - last_frame = False - - # Container uses frame counts, we can find the exact target frame by counting from the iframe which is at a known offset - if self.seeker.unit == "frames": - if previous_frame_index >= start_frame_index: - take_frame = True - if end_frame_index is not None and previous_frame_index >= end_frame_index: - last_frame = True - - # Container uses time, the target frame might not correspond exactly to any metadata but the desired timestamp should - # fall within a frames display period - if self.seeker.unit == "pts": - if start_frame_index <= (frame.pts + frame.duration): - take_frame = True - if end_frame_index is not None and end_frame_index <= ( - frame.pts + frame.duration - ): - last_frame = True - - if take_frame: - if video_out_frame_size is not None: - frame = frame.reformat( - width=video_out_frame_size[0], - height=video_out_frame_size[1], - format="rgb24", - interpolation="BILINEAR", - ) - else: - frame = frame.reformat(format="rgb24") - - clip_frames.append(torch.from_numpy(frame.to_ndarray())) - if clip_timestamp_start is None: - 
clip_timestamp_start = float(frame.pts * frame.time_base) - - clip_timestamp_end = float((frame.pts + frame.duration) * frame.time_base) + frame = None + for frame in reader.seek_read(range_start, range_end): + # print(f"Taking frame {frame.pts}+{frame.duration}") + if video_out_frame_size is not None: + frame = frame.reformat( + width=video_out_frame_size[0], + height=video_out_frame_size[1], + format="rgb24", + interpolation="BILINEAR", + ) + else: + frame = frame.reformat(format="rgb24") - previous_frame_index += 1 + clip_frames.append(torch.from_numpy(frame.to_ndarray())) + if clip_timestamp_start is None: + clip_timestamp_start = float(frame.pts * frame.time_base) - if last_frame: - break + if frame is not None: + clip_timestamp_end = float((frame.pts + frame.duration) * frame.time_base) if clip_timestamp_start is not None and clip_timestamp_end is not None: video_clips_frames.append(clip_frames) video_clips_timestamps.append((clip_timestamp_start, clip_timestamp_end)) + # print(f"Skipped {seeker.skipped} frames") # Stack frames within each clip out_video_clips = [ torch.stack(clip_frames).permute((0, 3, 1, 2)) for clip_frames in video_clips_frames @@ -551,6 +545,7 @@ def get_metadata( get_video_frame_size: bool = True, get_audio: bool = True, get_audio_duration: bool = True, + get_audio_num_samples: bool = False, ) -> "AVMetadata": """Get the metadata of the media object. @@ -561,6 +556,7 @@ def get_metadata( get_video_frame_size: Compute video frame size if not found in header. get_audio: Compute audio metadata. get_audio_duration: Compute audio duration if not found in header. + get_audio_num_samples: Compute audio number of samples. This requires decoding the audio stream. """ self.stream.seek(0) with av_open(self.stream) as input_container: @@ -613,6 +609,21 @@ def get_metadata( audio_stream = input_container.streams.audio[0] metadata.audio_sample_rate = audio_stream.sample_rate metadata.audio_duration = audio_stream.duration + if get_audio_num_samples and metadata.audio_num_samples is None: + num_samples = 0 + last_packet = None + input_container.seek(0) + for p in input_container.decode(audio=0): + if p.pts is not None: + last_packet = p + num_samples += p.samples + + metadata.audio_num_samples = num_samples + + if last_packet is not None and last_packet.duration is not None: + assert last_packet.pts is not None + metadata.audio_duration = last_packet.pts + last_packet.duration + if get_audio_duration and metadata.audio_duration is None: last_packet = None input_container.seek(0) @@ -715,18 +726,3 @@ def __call__( ) else: raise ValueError(f"Invalid av_decode value: {self.av_decode}") - - -@edataclass -class AVMetadata: - """Metadata of the media object.""" - - video_duration: Optional[float] = None - video_num_frames: Optional[int] = None - video_fps: Optional[float] = None - video_width: Optional[int] = None - video_height: Optional[int] = None - - audio_duration: Optional[float] = None - audio_channels: Optional[int] = None - audio_sample_rate: Optional[int] = None diff --git a/src/megatron/energon/av/av_init.py b/src/megatron/energon/av/av_init.py index 8a930cbd..28c8a943 100644 --- a/src/megatron/energon/av/av_init.py +++ b/src/megatron/energon/av/av_init.py @@ -1,6 +1,8 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +from bitstring.bits import BitsType + try: # Try importing optional dependencies import av @@ -10,7 +12,7 @@ pass -def av_open(file: str) -> "av.container.InputContainer": +def av_open(file: BitsType) -> "av.container.InputContainer": """Open a file with PyAV. This function is a wrapper around av.open that disables additional threads in the container. diff --git a/src/megatron/energon/av/fastseek/__init__.py b/src/megatron/energon/av/fastseek/__init__.py index d9d71f04..357a4ed6 100644 --- a/src/megatron/energon/av/fastseek/__init__.py +++ b/src/megatron/energon/av/fastseek/__init__.py @@ -1,4 +1,13 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from .fastseek import Fastseek as Fastseek -from .keyframeinfo import KeyframeInfo as KeyframeInfo +from .fastseek import Fastseek +from .keyframeinfo import KeyframeInfo +from .reader import FastseekReader, FastseekReaderByFrames, FastseekReaderByPts + +__all__ = [ + "Fastseek", + "KeyframeInfo", + "FastseekReader", + "FastseekReaderByFrames", + "FastseekReaderByPts", +] diff --git a/src/megatron/energon/av/fastseek/containers/matroska.py b/src/megatron/energon/av/fastseek/containers/matroska.py index 3a323886..bf93362c 100644 --- a/src/megatron/energon/av/fastseek/containers/matroska.py +++ b/src/megatron/energon/av/fastseek/containers/matroska.py @@ -4,12 +4,12 @@ from bitstring.bits import BitsType from ebmlite import MasterElement, loadSchema -from sortedcontainers import SortedList from ..keyframeinfo import KeyframeInfo class CueTrackPositions: + __slots__ = ("track",) track: int def __init__(self, el: MasterElement) -> None: @@ -19,6 +19,7 @@ def __init__(self, el: MasterElement) -> None: class CuePoint: + __slots__ = ("time", "track_positions") time: int track_positions: CueTrackPositions @@ -30,7 +31,7 @@ def __init__(self, el: MasterElement) -> None: self.track_positions = CueTrackPositions(c) -def parse_matroska(file: BitsType) -> SortedList: +def parse_matroska(file: BitsType) -> dict[int, list[KeyframeInfo]]: try: schema = loadSchema("matroska.xml") doc = schema.load(file, headers=True) @@ -39,12 +40,12 @@ def parse_matroska(file: BitsType) -> SortedList: # Get cue times stack = [c for c in doc if c.name == "Segment"] - cues = defaultdict(SortedList) + cues = defaultdict(list) while len(stack) > 0: el = stack.pop() if el.name == "CuePoint": cue = CuePoint(el) - cues[cue.track_positions.track].add(KeyframeInfo(cue.time, cue.time)) + cues[cue.track_positions.track].append(KeyframeInfo(None, cue.time)) elif isinstance(el, MasterElement): stack.extend([c for c in el if c.name in ["Cues", "CuePoint"]]) diff --git a/src/megatron/energon/av/fastseek/containers/mpeg.py b/src/megatron/energon/av/fastseek/containers/mpeg.py deleted file mode 100644 index 566c95c3..00000000 --- a/src/megatron/energon/av/fastseek/containers/mpeg.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. 
-# SPDX-License-Identifier: BSD-3-Clause -from collections import defaultdict -from itertools import accumulate -from typing import Any, Generator - -from bitstring import ConstBitStream, Error -from bitstring.bits import BitsType -from sortedcontainers import SortedList - -from ..keyframeinfo import KeyframeInfo - -box_atoms = {"moov", "trak", "mdia", "minf", "stbl", "edts"} # Non-exhaustive - - -def parse_table(cbs: ConstBitStream, table_size: int, struct: dict[str, str]) -> dict[str, Any]: - return [ - dict(zip(struct.keys(), cbs.readlist(", ".join(struct.values())))) - for _ in range(table_size) - ] - - -class Atom: - skip_version_and_flags: bool = False - - @staticmethod - def make_atom(cbs: ConstBitStream) -> "Atom": - size: int = cbs.read("uint:32") - name: str = cbs.read("bytes:4").decode("ascii") - box: bool = name in box_atoms - - if size == 0: - raise RuntimeError( - "MPEG parser detected a zero byte atom, this likely indicates a corrupt video." - ) - - subclass_list = [c for c in Atom.__subclasses__() if c.__name__ == name.upper()] - atom_class: type = Atom - if len(subclass_list) > 0: - atom_class: type = subclass_list[0] - cbs.bytepos += 4 # Skip version and flags TODO not every atom needs this - - atom = atom_class(size, name, box) - atom._parse(cbs) - - return atom - - def __init__(self, size: int, name: str, box: bool) -> None: - self.size: int = size - self.name: str = name - self.box: bool = box - - def _parse(self, cbs: ConstBitStream) -> None: - if not self.box: - cbs.bytepos += self.size - 8 - - def __str__(self) -> str: - return f"{self.name=}, {self.size=}, {self.box=}" - - -class TKHD(Atom): - """ - Parses the track header atom, see https://developer.apple.com/documentation/quicktime-file-format/track_header_atom - """ - - def _parse(self, cbs: ConstBitStream) -> None: - cbs.bytepos += 8 # skip creation time and modification time - self.track_id: int = cbs.read("uint:32") - cbs.bytepos += 68 # Skip rest of structure - - -class HDLR(Atom): - """ - Parses the media handler atom, see https://developer.apple.com/documentation/quicktime-file-format/handler_reference_atom - - NOTE: currently unused but could speed up parsing by skipping audio tracks - """ - - def _parse(self, cbs: ConstBitStream) -> None: - self.component_type = cbs.read("bytes:4").decode("ascii") - self.component_subtype = cbs.read("bytes:4").decode("ascii") - - # Skip rest of structure, the last field is variable so we need to use the total size - # 24 bytes already read (size (4), type (4), version (1), flags (3), component type (4), component subtype (4)) - cbs.bytepos += self.size - 20 - - -class STSS(Atom): - """ - Parses the sync sample atom https://developer.apple.com/documentation/quicktime-file-format/sample_table_atom/sync_sample_atom - """ - - def _parse(self, cbs: ConstBitStream) -> None: - self.number_of_entries: int = cbs.read("uint:32") - self.sync_sample_table: dict[str, Any] = parse_table( - cbs, self.number_of_entries, {"number": "uint:32"} - ) - - -class STTS(Atom): - """ - Parses the time to sample atom https://developer.apple.com/documentation/quicktime-file-format/time-to-sample_atom - """ - - def _parse(self, cbs: ConstBitStream) -> None: - self.number_of_entries: int = cbs.read("uint:32") - self.time_to_sample_table: dict[str, Any] = parse_table( - cbs, - self.number_of_entries, - {"sample_count": "uint:32", "sample_duration": "uint:32"}, - ) - - -class CTTS(Atom): - """ - Parses the composition offset atom 
https://developer.apple.com/documentation/quicktime-file-format/composition_offset_atom - """ - - def _parse(self, cbs: ConstBitStream) -> None: - self.number_of_entries: int = cbs.read("uint:32") - self.composition_offset_table: dict[str, Any] = parse_table( - cbs, - self.number_of_entries, - { - "sample_count": "uint:32", - "composition_offset": "int:32", - "media_rate": "", - }, - ) - - -class ELST(Atom): - """ - Parses the edit list atom https://developer.apple.com/documentation/quicktime-file-format/edit_list_atom - """ - - def _parse(self, cbs: ConstBitStream) -> None: - self.number_of_entries: int = cbs.read("uint:32") - self.edit_list_table: dict[str, Any] = parse_table( - cbs, - self.number_of_entries, - { - "track_duration": "uint:32", - "media_time": "int:32", - "media_rate": "int:32", - }, - ) - - -class MDAT(Atom): - """ - Parses the media data atom https: https://developer.apple.com/documentation/quicktime-file-format/movie_data_atom - - This is only here to handle the unusual size handling of mdat, if the normal size field is set to 1 - then the actual size is stored as a 64 bit integer - """ - - def _parse(self, cbs: ConstBitStream) -> None: - if self.size == 1: - cbs.bytepos -= 4 # No version or flags for mdat - self.size = cbs.read("uint:64") - seekto = self.size - 16 - else: - seekto = self.size - 12 - - if cbs.bytepos + seekto >= (cbs.len / 8): - raise StopIteration() - - cbs.bytepos += seekto - - -def parse_atoms(file: BitsType) -> Generator[Atom, None, None]: - try: - cbs = ConstBitStream(file) - while cbs.pos < len(cbs): - try: - yield Atom.make_atom(cbs) - except StopIteration: - return - except Error as e: - raise ValueError(f"MPEG parsing failed with error {e}") - - -def parse_mpeg(file: BitsType) -> dict[int, SortedList]: - sync_samples = {} - decode_timestamps = {} - presentation_time_offsets = {} - start_offsets = defaultdict(int) - current_track = -1 - for a in parse_atoms(file): - if a.name == "tkhd": - a: TKHD - current_track = a.track_id - elif a.name == "stts": - a: STTS - decode_timestamps[current_track] = list( - accumulate( - sum( - [ - [entry["sample_duration"]] * entry["sample_count"] - for entry in a.time_to_sample_table - ], - [0], - ) - ) - ) - elif a.name == "ctts": - a: CTTS - presentation_time_offsets[current_track] = sum( - [ - [entry["composition_offset"]] * entry["sample_count"] - for entry in a.composition_offset_table - ], - [], - ) - elif a.name == "stss": - a: STSS - sync_samples[current_track] = [ss["number"] - 1 for ss in a.sync_sample_table] - elif a.name == "elst": - # NOTE the "media_time" here is a "delay" between decoding and presenting the first sample. 
- # We follow the ffmpeg convention that the first frame displays at time 0 which means we should - # *subtract* this offset from the decoding time values rather than adding it to presentation time values - # TODO there can be more than one of these, figure out how to handle it - a: ELST - start_offsets[current_track] = -a.edit_list_table[0]["media_time"] - keyframes = defaultdict(SortedList) - try: - for track_id in sync_samples.keys(): - ptos = presentation_time_offsets.get(track_id) - dts = decode_timestamps[track_id] - for keyframe_number in sync_samples[track_id]: - pts = ( - dts[keyframe_number] - + start_offsets[track_id] - + (0 if ptos is None else ptos[keyframe_number]) - ) - keyframes[track_id].add(KeyframeInfo(keyframe_number, pts)) - except (KeyError, IndexError) as e: - raise ValueError(f"MPEG parsing failed with error {e}") - - return keyframes diff --git a/src/megatron/energon/av/fastseek/containers/probe.py b/src/megatron/energon/av/fastseek/containers/probe.py index 6646b7b0..e7cddf03 100644 --- a/src/megatron/energon/av/fastseek/containers/probe.py +++ b/src/megatron/energon/av/fastseek/containers/probe.py @@ -1,23 +1,29 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from sortedcontainers import SortedList +from bitstring.bits import BitsType from ...av_init import av_open from ..keyframeinfo import KeyframeInfo -def parse_probe(file): +def parse_probe(file: BitsType) -> dict[int, list[KeyframeInfo]]: + """ + Parse the container file using pyav to find keyframes. + + Args: + file: The container file to parse. + + Returns: + A dictionary of keyframes, keyed by stream id. dict> + """ keyframes = {} with av_open(file) as input_container: for stream_idx, stream in enumerate(input_container.streams.video): - packet_pts = [ - (index, p.pts) + keyframes[stream.id] = [ + KeyframeInfo(index, p.pts) for index, p in enumerate(input_container.demux(video=stream_idx)) if p.is_keyframe ] - packet_pts.sort(key=lambda x: x[1]) - - keyframes[stream.id] = SortedList([KeyframeInfo(*p) for p in packet_pts]) - return keyframes + return keyframes diff --git a/src/megatron/energon/av/fastseek/fastseek.py b/src/megatron/energon/av/fastseek/fastseek.py index 84b7f78e..e626183a 100644 --- a/src/megatron/energon/av/fastseek/fastseek.py +++ b/src/megatron/energon/av/fastseek/fastseek.py @@ -1,17 +1,53 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Literal, Optional +from bisect import bisect_right +from typing import Optional import filetype from bitstring.bits import BitsType -from sortedcontainers import SortedList from .containers.matroska import parse_matroska -from .containers.mpeg import parse_mpeg from .containers.probe import parse_probe from .keyframeinfo import KeyframeInfo +class StreamInfo: + """Keyframe information about a video stream.""" + + __slots__ = ("id", "keyframe_pts", "keyframe_indexes", "keyframes") + + #: The stream id. + id: int + #: The keyframes, sorted by frame index. + keyframes: list[KeyframeInfo] + #: The PTS times of the keyframes, sorted by PTS. + keyframe_pts: list[int] | None + #: The frame indexes of the keyframes if access by keyframes is allowed. + keyframe_indexes: list[int] | None + + def __init__(self, id: int, keyframes: list[KeyframeInfo]) -> None: + """Initialize the StreamInfo object. Given an unsorted list of KeyframeInfos, it provides attributes to access those sorted by PTS or by frame index. + + Args: + id: The stream id. 
+ keyframes: The list of KeyframeInfos. + has_frame_index: Whether frame indexes are supported for this container. + """ + self.id = id + self.keyframes = sorted(keyframes, key=lambda x: (x.index or -1, x.pts or -1)) + if all(x.pts is None for x in keyframes): + self.keyframe_pts = None + else: + self.keyframe_pts = sorted(x.pts for x in keyframes if x.pts is not None) + if all(x.index is None for x in keyframes): + self.keyframe_indexes = None + else: + self.keyframe_indexes = [x.index for x in self.keyframes if x.index is not None] + + def __repr__(self) -> str: + return f"StreamInfo(id={self.id}, keyframe_pts={self.keyframe_pts}, keyframe_indexes={self.keyframe_indexes}, keyframes={self.keyframes})" + + class Fastseek: """ Gathers information from the video container file (e.g. metadata which requires minimal decoding) @@ -21,32 +57,37 @@ class Fastseek: to make informed decisions about the best seeking behavior Currently supports: - - MP4/MOV: frames are indexed by number and frame counting can be used to get the exact frame - - Matroska/WebM: frames are indexed by time and inter-frame duration must be accounted for to get to the right frame + - Matroska/WebM: frames are indexed by time and inter-frame duration must be accounted for to get to the right frame. Use force_probe=True to use pyav to get frame-accurate keyframes. + - All other formats: Use pyav to find keyframes. - If your container is not listed above, pass "probe=True" to the constructor, this will use ffmpeg to parse the stream - without decoding it. Frames will be indexed by number. This is not as fast as using a supported container but is still + Frames will be indexed by number. This is not as fast as using a supported container but is still significantly faster than sequential decoding. """ - keyframes: dict[int, SortedList[KeyframeInfo]] - unit: Literal["frames", "pts"] + #: Keyframe info by stream id + keyframes: dict[int, StreamInfo] + + #: List of stream ids for indexed access. + streams: list[int] + #: Whether frame indexes are supported for this container. + #: If True, supports frame index access. + frame_index_supported: bool + #: Whether PTS are supported for this container. + #: If True, supports PTS access. + pts_supported: bool + #: MIME type of the container. mime: str - def __init__(self, file: BitsType, probe: bool = False) -> None: + def __init__(self, file: BitsType, force_probe: bool = False) -> None: """Initialize the Fastseek object. Args: file: The video file data as a bitstring BitsType object. This should contain the raw bytes of the video file. - probe: If True, use ffmpeg to probe the stream without decoding. This is slower but works with any container format. - If False (default), attempt to parse the container format directly. Only works with MP4/MOV and Matroska/WebM. - - Raises: - ValueError: If the file type cannot be determined or if the container format is not supported (when probe=False). + force_probe: If True, use ffmpeg to probe the stream without decoding. This may be slower but works with any container format. + If False (default), attempt to parse the container format directly (only optimized for matroska). 
""" - if probe: - self.keyframes = parse_probe(file) - self.unit = "frames" + if force_probe: + keyframes = parse_probe(file) else: ftype = filetype.guess(file) @@ -57,28 +98,33 @@ def __init__(self, file: BitsType, probe: bool = False) -> None: self.mime = ftype.mime - if ftype.mime in ["video/mp4", "video/quicktime"]: - self.keyframes = parse_mpeg(file) - self.unit = "frames" - elif ftype.mime in ["video/x-matroska", "video/webm"]: - self.keyframes = parse_matroska(file) - self.unit = "pts" + if ftype.mime in ("video/x-matroska", "video/webm"): + keyframes = parse_matroska(file) else: - raise ValueError( - f"Unsupported container: {ftype.mime} (hint: try passing probe=True to the Fastseek constructor)" - ) + keyframes = parse_probe(file) - if len(self.keyframes) == 0: - raise ValueError( - f"The parser for {ftype.mime} was unable to find any streams (hint: try passing probe=True to the Fastseek constructor)" - ) + if len(keyframes) == 0: + raise ValueError( + f"The parser for {ftype.mime} was unable to find any streams (hint: try passing probe=True to the Fastseek constructor)" + ) - if all(len(kf) == 0 for kf in self.keyframes.values()): - raise ValueError( - f"The parser for {ftype.mime} was unable to find any keyframes (hint: try passing probe=True to the Fastseek constructor)" - ) + if all(len(kf) == 0 for kf in keyframes.values()): + raise ValueError( + f"The parser for {ftype.mime} was unable to find any keyframes (hint: try passing probe=True to the Fastseek constructor)" + ) - def should_seek(self, current: int, target: int, stream: int = 0) -> Optional[KeyframeInfo]: + self.keyframes = {k: StreamInfo(k, keyframes) for k, keyframes in keyframes.items()} + self.frame_index_supported = any( + stream.keyframe_indexes is not None for stream in self.keyframes.values() + ) + self.pts_supported = any( + stream.keyframe_pts is not None for stream in self.keyframes.values() + ) + self.streams = list(self.keyframes.keys()) + + def should_seek_by_frame( + self, current_frame_index: int, target_frame_index: int, stream: int = 0 + ) -> Optional[KeyframeInfo]: """Determine if seeking to a keyframe is necessary to reach the target frame. This method helps optimize video seeking by determining whether a seek operation @@ -87,28 +133,23 @@ def should_seek(self, current: int, target: int, stream: int = 0) -> Optional[Ke the current position would be less efficient). Args: - current: The current frame number or timestamp (depending on container format) - target: The desired frame number or timestamp to seek to + current_frame_index: The current frame number + target_frame_index: The desired frame number to seek to stream: The video stream index to use. Defaults to 0. Returns: Information about the nearest keyframe if seeking would be beneficial, or None if sequential decoding from current position is more efficient. - The KeyframeInfo contains the keyframe's position and timing information. 
- - Note: - The units for current and target depend on the container format: - - For MP4/MOV: frame numbers (count-based) - - For Matroska/WebM: timestamps (time-based) """ - nearest_iframe: KeyframeInfo = self.nearest_keyframe(target, stream) + nearest_iframe: KeyframeInfo = self.nearest_keyframe_by_frame(target_frame_index, stream) return ( nearest_iframe - if (current < nearest_iframe.index <= target) or (target < current) + if (current_frame_index < nearest_iframe.index <= target_frame_index) + or (target_frame_index < current_frame_index) else None ) - def nearest_keyframe(self, target: int, stream: int = 0) -> KeyframeInfo: + def nearest_keyframe_by_frame(self, target_frame_index: int, stream: int = 0) -> KeyframeInfo: """Find the nearest keyframe that comes before the target frame. This method performs a binary search to find the keyframe that is closest to, @@ -116,8 +157,7 @@ def nearest_keyframe(self, target: int, stream: int = 0) -> KeyframeInfo: optimal starting point for decoding to reach a specific frame. Args: - target: The target frame number or timestamp to find the nearest keyframe for. - The unit (frame count or timestamp) depends on the container format. + target_frame_index: The target frame number to find the nearest keyframe for. stream: The video stream index to use. Defaults to 0. Used when the container has multiple video streams. @@ -131,13 +171,57 @@ def nearest_keyframe(self, target: int, stream: int = 0) -> KeyframeInfo: workaround and may be updated in the future. """ - if stream >= len(self.keyframes): + if stream >= len(self.streams): raise ValueError(f"No stream with index {stream}") - stream_id = list(self.keyframes.keys())[stream] + assert self.frame_index_supported, "Frame indexes are not supported for this container" - if len(self.keyframes[stream_id]) == 0: + stream_id = self.streams[stream] + stream_info = self.keyframes[stream_id] + + if len(stream_info.keyframes) == 0: + raise ValueError(f"No keyframes found for stream {stream}") + assert stream_info.keyframe_indexes is not None, ( + "Frame indexes are not supported for this container" + ) + + # bisect_right returns the rightmost insertion point, so subtracting 1 gives + # us the index of the last keyframe with index <= target + nearest_iframe_to_target_index: int = ( + bisect_right(stream_info.keyframe_indexes, target_frame_index) - 1 + ) + return stream_info.keyframes[max(0, nearest_iframe_to_target_index)] + + def should_seek_by_pts(self, current_pts: int, target_pts: int, stream: int = 0) -> int | None: + """ + Find the nearest keyframe that comes before the target frame index and return its PTS. + + Args: + current_pts: The PTS of the current frame to be decoded. + target_pts: The target PTS to seek to. + stream: The video stream index to use. Defaults to 0. + Used when the container has multiple video streams. + + Returns: + The PTS of the nearest keyframe at or before the target PTS. + If no such keyframe is found, returns None. 
+ """ + if stream >= len(self.streams): + raise ValueError(f"No stream with index {stream}") + + assert self.pts_supported, "PTS are not supported for this container" + + stream_id = self.streams[stream] + stream_info = self.keyframes[stream_id] + + if stream_info.keyframe_pts is None or len(stream_info.keyframe_pts) == 0: raise ValueError(f"No keyframes found for stream {stream}") - nearest_iframe_to_target_index: int = self.keyframes[stream_id].bisect_left(target) - 1 - return self.keyframes[stream_id][max(0, nearest_iframe_to_target_index)] + pos = bisect_right(stream_info.keyframe_pts, target_pts) - 1 + keyframe_pts = stream_info.keyframe_pts[max(0, pos)] + + return ( + keyframe_pts + if (current_pts < keyframe_pts <= target_pts) or (target_pts < current_pts) + else None + ) diff --git a/src/megatron/energon/av/fastseek/keyframeinfo.py b/src/megatron/energon/av/fastseek/keyframeinfo.py index 5134997a..7b9bd0b6 100644 --- a/src/megatron/energon/av/fastseek/keyframeinfo.py +++ b/src/megatron/energon/av/fastseek/keyframeinfo.py @@ -3,7 +3,7 @@ from dataclasses import dataclass -@dataclass +@dataclass(slots=True) class KeyframeInfo: """ Information about a detected keyframe @@ -15,17 +15,7 @@ class KeyframeInfo: """ - index: int - pts: int - - def __eq__(self, other) -> bool: - if isinstance(other, KeyframeInfo): - return self.index == other.index - - return self.index == other - - def __lt__(self, other) -> bool: - if isinstance(other, KeyframeInfo): - return self.index < other.index - - return self.index < other + #: The index of the keyframe. If None, the keyframe is not indexed by frame number. + index: int | None + #: The PTS of the keyframe. If None, the keyframe is not indexed by PTS. + pts: int | None diff --git a/src/megatron/energon/av/fastseek/reader.py b/src/megatron/energon/av/fastseek/reader.py new file mode 100644 index 00000000..5cc738d3 --- /dev/null +++ b/src/megatron/energon/av/fastseek/reader.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Iterator + +import av +import av.container +import av.stream + +from .fastseek import Fastseek + + +class FastseekReader(ABC): + """A class that provides a interface for reading video frames from a video stream for frame range extraction.""" + + #: The Fastseek object to use for seeking. + seeker: Fastseek + #: The input container to read from. + input_container: av.container.InputContainer + #: The iterator over the video frames. + frame_iterator: Iterator[av.VideoFrame] + #: The video stream to read from. + stream: av.stream.Stream + #: Number of frames skipped by the reader. For statistical purposes. + skipped: int + + def __init__( + self, seeker: Fastseek, input_container: av.container.InputContainer, stream_idx: int = 0 + ) -> None: + """Initialize the fastseek reader. + + Args: + seeker: The Fastseek object to use for seeking. + input_container: The pyav input container to read from. + stream_idx: The index of the video stream to read from. + """ + self.seeker = seeker + self.input_container = input_container + self.frame_iterator = input_container.decode(video=stream_idx) + self.stream = input_container.streams.video[stream_idx] + self.skipped = 0 + + @abstractmethod + def seek_read( + self, + range_start: int, + range_end: int | None, + ) -> Generator[av.VideoFrame, None, None]: + """ + Read video frames from the video stream for the given range. 
+ `range_start <= range_end` must hold. If `range_start == range_end` and within the video range, one frame is returned. + + Args: + range_start: The start of the range to read from. The type of the range is defined by the subclass. + range_end: The end of the range to read from. The type of the range is defined by the subclass. If None, the range goes to the end of the video. This is inclusive! + + Returns: + A generator of video frames. + """ + ... + + +class FastseekReaderByFrames(FastseekReader): + """A video frame reader that seeks by frame index.""" + + #: The next frame index that would be returned by the iterator. + _next_frame_index: int = 0 + #: The previous frame that was returned by the iterator. + _previous_frame: av.VideoFrame | None = None + + def seek_read( + self, range_start: int, range_end: int | None + ) -> Generator[av.VideoFrame, None, None]: + if ( + keyframe_info := self.seeker.should_seek_by_frame( + self._next_frame_index - 1, range_start + ) + ) is not None: + seek_pts = keyframe_info.pts + if seek_pts is None: + # input_container.seek should seek to the nearest keyframe, which should be the requested one. + seek_pts = range_start + self.input_container.seek(seek_pts, stream=self.stream) + assert keyframe_info.index is not None, "Frame index is required for this container" + self._next_frame_index = keyframe_info.index + frame = None + else: + frame = self._previous_frame + # Skip the frames between the keyframe / previous frame and the requested range_start. + next_idx = self._next_frame_index + for next_idx, frame in zip( + range(self._next_frame_index + 1, range_start + 1), self.frame_iterator + ): + pass + self.skipped += next_idx - self._next_frame_index + self._next_frame_index = next_idx + + if frame is not None and self._next_frame_index - 1 == range_start: + # Repeat the previous frame. + yield frame + if range_end == range_start: + # Special case: User requested the last frame again, not more. + return + for frame in self.frame_iterator: + self._next_frame_index += 1 + self._previous_frame = frame + yield frame + if range_end is not None and self._next_frame_index > range_end: + break + + +class FastseekReaderByPts(FastseekReader): + """A video frame reader that seeks by PTS.""" + + #: The PTS of the next frame that would be returned by the iterator. + _next_frame_pts: int = 0 + #: The previous frame that was returned by the iterator. 
+ _previous_frame: av.VideoFrame | None = None + + def seek_read( + self, range_start: int, range_end: int | None + ) -> Generator[av.VideoFrame, None, None]: + assert range_end is None or range_start <= range_end, ( + f"Range start {range_start} must be less or equal than range end {range_end}" + ) + if ( + seek_keyframe_pts := self.seeker.should_seek_by_pts(self._next_frame_pts, range_start) + ) is not None: + # Seeking backward or forward beyond the next keyframe + # print(f"Seeking to frame {self.next_keyframe_pts} for {range_start} from {self.next_frame_pts}") + self.input_container.seek(seek_keyframe_pts, stream=self.stream) + self._next_frame_pts = seek_keyframe_pts + frame = self._previous_frame = None + else: + frame = self._previous_frame + # Skip frames before start + if frame is None or range_start >= (frame.pts + frame.duration): + skipped = 0 + for frame in self.frame_iterator: + if range_start < (frame.pts + frame.duration): + break + skipped += 1 + else: + # Out of the end of the video + frame = None + self.skipped += skipped + if frame is not None: + self._next_frame_pts = frame.pts + self._previous_frame = frame + if frame is None: + # No frame available -> at the end of the video + # Just keep the next_frame_pts as it was + return + # Yield at least the current frame. It's after the start. + yield frame + while range_end is None or (frame.pts + frame.duration) <= range_end: + try: + frame = next(self.frame_iterator) + except StopIteration: + self._previous_frame = None + break + # Store the current frame's PTS, because we can still access that frame! + self._next_frame_pts = frame.pts + self._previous_frame = frame + yield frame diff --git a/src/megatron/energon/cache/__init__.py b/src/megatron/energon/cache/__init__.py index 13e3ca2b..5ba07c72 100644 --- a/src/megatron/energon/cache/__init__.py +++ b/src/megatron/energon/cache/__init__.py @@ -1,7 +1,14 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from megatron.energon.cache.base import CachePool, FileStore, FileStoreDecoder, Lazy, MockLazy +from megatron.energon.cache.base import ( + CachePool, + FileStore, + FileStoreDecoder, + Lazy, + MockLazy, + PrimaryFileStore, +) from megatron.energon.cache.file_cache_pool import FileCacheLazy, FileStoreCachePool from megatron.energon.cache.file_store import DecodeFileStore, SystemFileStore from megatron.energon.cache.no_cache import DirectLazy, NoCachePool @@ -17,5 +24,6 @@ "Lazy", "MockLazy", "NoCachePool", + "PrimaryFileStore", "SystemFileStore", ] diff --git a/src/megatron/energon/cache/base.py b/src/megatron/energon/cache/base.py index 46068711..c4bf4ec4 100644 --- a/src/megatron/energon/cache/base.py +++ b/src/megatron/energon/cache/base.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Generic, TypeVar from megatron.energon.edataclass import edataclass +from megatron.energon.media.metadata import MediaMetadataBase from megatron.energon.source_info import SourceInfo, add_source_info T = TypeVar("T") @@ -18,6 +19,16 @@ def __getitem__(self, key: str) -> tuple[T, SourceInfo]: """Returns the data for the given key.""" ... + def _get_raw(self, key: str) -> tuple[bytes, SourceInfo]: + """Returns the raw data for the given key. Without decoding. + For non-wrapped FileStores, this is the same as __getitem__.""" + return self[key] + + def _decode_raw(self, data: T, **kwargs) -> T: + """Decodes the raw data into the target type. 
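Putting the new seeking pieces above together, a rough usage sketch for the `Fastseek` + reader split. The module paths, the byte-stream inputs and the frame/PTS ranges are assumptions for illustration, not documented API:

```python
import io

import av

from megatron.energon.av.fastseek.fastseek import Fastseek
from megatron.energon.av.fastseek.reader import FastseekReaderByFrames, FastseekReaderByPts

data = open("clip.mp4", "rb").read()  # hypothetical input file

seeker = Fastseek(io.BytesIO(data))  # parses the keyframe index once
container = av.open(io.BytesIO(data))  # separate handle used for decoding

if seeker.frame_index_supported:
    reader = FastseekReaderByFrames(seeker, container)
    frames = list(reader.seek_read(10, 20))  # frames 10..20, inclusive
else:
    reader = FastseekReaderByPts(seeker, container)
    frames = list(reader.seek_read(0, 90_000))  # PTS range in the stream's time base

print(f"decoded {len(frames)} frames, skipped {reader.skipped}")
```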
+ For non-wrapped FileStores, this is the same as identity""" + return data + def get(self, key: str, sample: Any = None) -> Any: """Returns the data for the given key and adds the source info to the sample.""" data, source_info = self[key] @@ -29,6 +40,39 @@ def get_path(self) -> str: """Returns the path to the dataset.""" ... + def get_media_metadata(self, key: str) -> MediaMetadataBase: + """Return the media metadata for the given key if available.""" + + raise NotImplementedError( + f"{type(self).__name__} does not support media metadata retrieval" + ) + + +class FileStoreWrapper(FileStore[T]): + """ + A wrapper around a FileStore that can be used to add additional functionality. + Typically for decoding the data on access. + """ + + _inner: FileStore[T] + + def __init__(self, inner: FileStore[T]): + self._inner = inner + + def _get_raw(self, key: str) -> bytes: + """Returns the raw data for the given key. Without decoding.""" + return self._inner._get_raw(key) + + def _decode_raw(self, data: T, **kwargs) -> T: + """Decodes the raw data into the target type. + Args: + data: The raw data to decode. + **kwargs: Additional keyword arguments to pass to the decoder. + Returns: + The decoded data. + """ + return self._inner._decode_raw(data, **kwargs) + @edataclass class Lazy(Generic[T]): @@ -159,3 +203,30 @@ def decode(self, fname: str, data: bytes) -> Any: The decoded field's data. """ ... + + +class PrimaryFileStore(FileStoreWrapper[T]): + """Same as the FileStore, but additionally uses the current sample's key as a prefix for the key, + if the key passed to `__getitem__`, `.get` or `.get_media_metadata` starts with a '.'.""" + + def __init__(self, inner: FileStore[T], current_key: str): + super().__init__(inner) + self._current_key = current_key + + def __getitem__(self, key: str) -> tuple[T, SourceInfo]: + if key.startswith("."): + key = f"{self._current_key}{key}" + return self._inner[key] + + def get(self, key: str, sample: Any = None) -> Any: + if key.startswith("."): + key = f"{self._current_key}{key}" + return self._inner.get(key, sample) + + def get_path(self) -> str: + return self._inner.get_path() + + def get_media_metadata(self, key: str) -> MediaMetadataBase: + if key.startswith("."): + key = f"{self._current_key}{key}" + return self._inner.get_media_metadata(key) diff --git a/src/megatron/energon/cache/file_cache_pool.py b/src/megatron/energon/cache/file_cache_pool.py index dc909800..6710a22d 100644 --- a/src/megatron/energon/cache/file_cache_pool.py +++ b/src/megatron/energon/cache/file_cache_pool.py @@ -14,7 +14,6 @@ from typing import Any, Dict, Literal, Optional, Tuple, TypeVar from megatron.energon.cache.base import CachePool, FileStore, Lazy -from megatron.energon.cache.file_store import DecodeFileStore from megatron.energon.edataclass import edataclass from megatron.energon.fork_hook import ForkMixin from megatron.energon.source_info import SourceInfo, add_source_info @@ -152,7 +151,7 @@ class FileStoreCachePool(CachePool, ForkMixin): _pending_tasks: Dict[Tuple[str, str], _PendingTask] # Lock for all shared structures - _lock: threading.Lock + _lock: threading.RLock # Condition variable to signal when cache space is available _cache_space_available: threading.Condition # Whether the pool is shutting down @@ -203,7 +202,7 @@ def __init__( self.current_cache_count = 0 # A lock to protect all shared structures - self._lock = threading.Lock() + self._lock = threading.RLock() # Condition variable to signal when cache space is available self._cache_space_available = 
threading.Condition(self._lock) @@ -282,10 +281,7 @@ def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> boo # Perform the data read if self.method == "raw": - if isinstance(ds, DecodeFileStore): - data, entry.source_info = ds.inner_reader[fname] - else: - data, entry.source_info = ds[fname] + data, entry.source_info = ds._get_raw(fname) elif self.method == "pickle": data, entry.source_info = ds[fname] data = pickle.dumps(data) @@ -428,10 +424,7 @@ def _read_from_cache(self, entry: _PendingTask) -> tuple[Any, SourceInfo]: with open(entry.cache_path, "rb") as f: if self.method == "raw": raw = f.read() - if isinstance(entry.ds, DecodeFileStore): - return entry.ds.decoder.decode(entry.fname, raw), entry.source_info - else: - return raw, entry.source_info + return entry.ds._decode_raw(raw, fname=entry.fname), entry.source_info else: return pickle.load(f), entry.source_info diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index 6247d748..556b0bdd 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -1,42 +1,54 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import sqlite3 from typing import Any, Optional, Union -from megatron.energon.cache.base import FileStore, FileStoreDecoder +from megatron.energon.cache.base import FileStore, FileStoreDecoder, FileStoreWrapper from megatron.energon.epathlib import EPath from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader +from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader +from megatron.energon.flavors.webdataset.thread_local_sqlite import ThreadLocalSqlite +from megatron.energon.local_copy import ensure_local_copy +from megatron.energon.media.metadata import MediaMetadataBase, deserialize_media_metadata from megatron.energon.source_info import SourceInfo -class DecodeFileStore(FileStore[Any]): +class DecodeFileStore(FileStoreWrapper[Any]): """Used to wrap a FileStore and decode the data on access.""" def __init__( self, - inner_reader: FileStore[bytes], + inner: FileStore[bytes], *, decoder: FileStoreDecoder, ): """ Args: - inner_reader: The FileStore to wrap. + inner: The FileStore to wrap. decoder: The decoder to apply to every item read from the FileStore. 
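As a side note, a minimal sketch of the `_get_raw`/`_decode_raw` contract that the cache pool above now relies on; the decoder, base directory and key below are hypothetical:

```python
from megatron.energon.cache import DecodeFileStore, FileStoreDecoder, SystemFileStore


class SizeDecoder(FileStoreDecoder):
    """Toy decoder for illustration: reports the size of each entry."""

    def decode(self, fname: str, data: bytes):
        return {"fname": fname, "num_bytes": len(data)}


store = DecodeFileStore(SystemFileStore("/data/my_dataset"), decoder=SizeDecoder())

key = "images/cat_0001.jpg"  # made-up key
raw, source_info = store._get_raw(key)  # undecoded bytes from the wrapped store
# The cache pool can persist `raw` cheaply and re-decode only when the sample is read back:
decoded = store._decode_raw(raw, fname=key)  # same value as store[key][0]
```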
""" - self.inner_reader = inner_reader + super().__init__(inner) self.decoder = decoder def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]: - data, source_info = self.inner_reader[fname] + data, source_info = self._inner[fname] return self.decoder.decode(fname, data), source_info + def _decode_raw(self, data: bytes, **kwargs) -> Any: + fname = kwargs["fname"] + return self.decoder.decode(fname, self._inner._decode_raw(data, **kwargs)) + def get_path(self) -> str: - return self.inner_reader.get_path() + return self._inner.get_path() def __str__(self): - return f"DecodeFileStore(inner_reader={self.inner_reader}, decoder={self.decoder})" + return f"DecodeFileStore(inner={self._inner}, decoder={self.decoder})" + + def get_media_metadata(self, key: str) -> MediaMetadataBase: + return self._inner.get_media_metadata(key) class SystemFileStore(FileStore[bytes]): @@ -50,6 +62,8 @@ def __init__(self, base_dir: Optional[Union[EPath, str]] = None): """ self.base_dir = EPath(base_dir) if base_dir is not None else None + self._media_metadata_reader: Optional[ThreadLocalSqlite] = None + self._media_metadata_checked = False def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]: # Construct the full path from the dataset path and the file key @@ -64,8 +78,8 @@ def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]: return data, SourceInfo( dataset_path=self.base_dir, - index=None, - shard_name=None, + index=key, + shard_name=str(self.base_dir), file_names=(key,), ) @@ -76,6 +90,56 @@ def get_path(self) -> str: def __str__(self): return f"SystemFileStore(base_dir={self.base_dir})" + def get_media_metadata(self, key: str) -> MediaMetadataBase: + if self.base_dir is None: + raise RuntimeError("Media metadata requires a base directory for SystemFileStore") + + reader = self._ensure_media_metadata_reader() + row = reader.select_one( + "SELECT metadata_type, metadata_json FROM media_metadata WHERE entry_key = ?", + (key,), + ) + if row is None: + file_path = self.base_dir / key + if file_path.is_file(): + raise KeyError( + f"Media metadata missing for {key}. " + "Run `energon prepare --media-metadata-by-...` to regenerate it." + ) + raise KeyError(f"File {file_path} not found") + metadata_type, metadata_json = row + return deserialize_media_metadata(metadata_type, metadata_json) + + def _ensure_media_metadata_reader(self) -> ThreadLocalSqlite: + assert self.base_dir is not None + if self._media_metadata_reader is None: + sqlite_path = self.base_dir / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME + + if not sqlite_path.is_file(): + raise RuntimeError( + f"Media metadata database missing at {sqlite_path}. " + "Run `energon prepare --media-metadata-by-...` for this dataset." + ) + + local_sqlite_path = ensure_local_copy(sqlite_path) + db_uri = f"file:{str(local_sqlite_path)}?mode=ro&immutable=1" + self._media_metadata_reader = ThreadLocalSqlite(db_uri, is_uri=True) + + if not self._media_metadata_checked: + assert self._media_metadata_reader is not None + exists = self._media_metadata_reader.select_one( + "SELECT name FROM sqlite_master WHERE type='table' AND name='media_metadata'" + ) + if exists is None: + self._media_metadata_reader.thread_close() + self._media_metadata_reader = None + raise RuntimeError( + "Media metadata table missing. Re-run `energon prepare --media-metadata-by-...`." 
+ ) + self._media_metadata_checked = True + + return self._media_metadata_reader + class WebdatasetFileStore(SqliteITarEntryReader, FileStore[bytes]): """This dataset will directly read files from the dataset tar files from a prepared energon dataset.""" @@ -85,10 +149,41 @@ def __init__( dataset_path: EPath, ): super().__init__(base_path=dataset_path, key_is_full_entryname=True) + self._media_metadata_available: Optional[bool] = None def get_path(self) -> str: return str(self.base_path) + def get_media_metadata(self, key: str) -> MediaMetadataBase: + if self._media_metadata_available is None: + try: + has_metadata = self.sqlite_reader.db_has_media_metadata() + except sqlite3.Error as exc: # pragma: no cover - defensive + raise RuntimeError( + "Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`." + ) from exc + + if not has_metadata: + raise RuntimeError( + "Media metadata is not available for this dataset. " + "Run `energon prepare --media-metadata-by-...` to generate it." + ) + + self._media_metadata_available = True + + try: + row = self.sqlite_reader.get_media_metadata(key) + except sqlite3.Error as exc: # pragma: no cover - defensive + raise RuntimeError( + "Failed to load media metadata. Re-run `energon prepare --media-metadata-by-...`." + ) from exc + + if row is None: + raise KeyError(f"Sample {key!r} not found") + + metadata_type, metadata_json = row + return deserialize_media_metadata(metadata_type, metadata_json) + class JsonlFileStore(IJsonlReader, FileStore[bytes]): """This dataset will directly read entries from a jsonl file.""" diff --git a/src/megatron/energon/cli/main.py b/src/megatron/energon/cli/main.py index c9f72a59..c18ff2d6 100644 --- a/src/megatron/energon/cli/main.py +++ b/src/megatron/energon/cli/main.py @@ -11,6 +11,7 @@ from megatron.energon.tools.lint import command as lint_command from megatron.energon.tools.mount import command as mount_command from megatron.energon.tools.prepare import command as prepare_command +from megatron.energon.tools.prepare_media import command as prepare_media_command from megatron.energon.tools.preview import command as preview_command @@ -56,6 +57,7 @@ def main(ctx): main.add_command(info_command) main.add_command(mount_command) main.add_command(prepare_command) +main.add_command(prepare_media_command) main.add_command(preview_command) if __name__ == "__main__": diff --git a/src/megatron/energon/dataset_config.py b/src/megatron/energon/dataset_config.py index eb5fb20c..68642af0 100644 --- a/src/megatron/energon/dataset_config.py +++ b/src/megatron/energon/dataset_config.py @@ -12,7 +12,7 @@ DefaultCrudeJsonlDatasetFactory, StandardWebdatasetFactory, ) -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type from megatron.energon.typed_converter import JsonParser from megatron.energon.worker import WorkerConfig @@ -127,6 +127,8 @@ def get_dataset_from_config( ), default_type=StandardWebdatasetFactory, ) + elif ds_type == EnergonDatasetType.FILESYSTEM: + raise ValueError("Filesystem datasets are only supported as auxiliary datasets. ") else: raise ValueError( f"Path {path} does not contain a {MAIN_FOLDER_NAME}/.info.yaml or .info.json file nor is it a jsonl file. 
" diff --git a/src/megatron/energon/deprecation.py b/src/megatron/energon/deprecation.py new file mode 100644 index 00000000..f55b5e04 --- /dev/null +++ b/src/megatron/energon/deprecation.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +import warnings +from functools import wraps + + +def warn_deprecated(reason, stacklevel=2): + warnings.warn(reason, FutureWarning, stacklevel=stacklevel) + + +def deprecated(reason): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + warn_deprecated(f"{func.__name__} is deprecated: {reason}", stacklevel=3) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py index 74ae9729..55c890ca 100644 --- a/src/megatron/energon/epathlib/epath.py +++ b/src/megatron/energon/epathlib/epath.py @@ -9,7 +9,7 @@ import multistorageclient as msc -from megatron.energon.errors import warn_deprecated +from megatron.energon.deprecation import warn_deprecated class EPath: @@ -108,22 +108,35 @@ def _internal_str_path(self) -> str: return str(self.internal_path) @overload - def open(self, mode: Literal["r", "w"] = "r", block_size: Optional[int] = None) -> TextIO: ... + def open( + self, + mode: Literal["r", "w"] = "r", + block_size: Optional[int] = None, + prefetch_file: bool = False, + ) -> TextIO: ... @overload - def open(self, mode: Literal["rb", "wb"], block_size: Optional[int] = None) -> BinaryIO: ... + def open( + self, + mode: Literal["rb", "wb"], + block_size: Optional[int] = None, + prefetch_file: bool = False, + ) -> BinaryIO: ... def open( - self, mode: Literal["r", "rb", "w", "wb"] = "r", block_size: Optional[int] = None + self, + mode: Literal["r", "rb", "w", "wb"] = "r", + block_size: Optional[int] = None, + prefetch_file: bool = False, ) -> Union[TextIO, BinaryIO]: - return self.fs.open(self._internal_str_path, mode) + return self.fs.open(self._internal_str_path, mode, prefetch_file=prefetch_file) def read_text(self) -> str: - with self.open() as f: + with self.open(prefetch_file=True) as f: return f.read() def read_bytes(self) -> bytes: - with self.open("rb") as f: + with self.open("rb", prefetch_file=True) as f: return f.read() def write_text(self, text: str) -> None: diff --git a/src/megatron/energon/errors.py b/src/megatron/energon/errors.py index a6464dd3..44c14eb8 100644 --- a/src/megatron/energon/errors.py +++ b/src/megatron/energon/errors.py @@ -1,83 +1,19 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses -import itertools -import warnings -from functools import wraps -from typing import Any, Type, TypeVar, Union - -import numpy as np -import torch - - -def compact_str( - value: Union[dict, list, str, int, bool, None], - depth: int = 3, - max_items: int = 10, - max_str_len: int = 50, -) -> str: - """ - Compact representation of a value as a string. 
+from contextlib import contextmanager +from typing import Any, Callable, Generator, Type, TypeVar - Args: - value: The value to compact - depth: The maximum depth to compact - max_items: The maximum number of items to show in a list or dict - max_str_len: The maximum string length to show +from megatron.energon.sample_utils import format_sample_compact, format_sample_detailed +from megatron.energon.source_info import SourceInfo, get_source_info - Returns: The printable string - """ - if isinstance(value, dict): - if depth <= 0: - return "{...}" - return ( - "{" - + ", ".join( - ( - f"{k}: {v!r}" - if isinstance(k, str) and k.startswith("__") - else f"{k}: {compact_str(v, depth - 1, max_items, max_str_len)}" - ) - for k, v in itertools.islice(value.items(), max_items) - ) - + "}" - ) - elif isinstance(value, list): - if depth <= 0: - return "[...]" - return ( - "[" - + ", ".join( - compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items] - ) - + "]" - ) - elif isinstance(value, tuple): - if depth <= 0: - return "(...)" - return ( - "(" - + ", ".join( - compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items] - ) - + ")" - ) - elif isinstance(value, str): - if len(value) > max_str_len: - return repr(value[:max_str_len] + "...") - return repr(value) - elif isinstance(value, torch.Tensor): - return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})" - elif isinstance(value, np.ndarray): - return f"np.ndarray(shape={value.shape}, dtype={value.dtype})" - elif dataclasses.is_dataclass(value): - return f"{value.__class__.__name__}({', '.join(f'{field.name}={compact_str(getattr(value, field.name))}' for field in dataclasses.fields(value))})" - else: - return compact_str(repr(value), depth, max_items, max_str_len) +T = TypeVar("T") -T = TypeVar("T") +class SkipSample(Exception): + """Raise this exception in any processing function to skip the current sample.""" + + pass class SampleException(ValueError): @@ -89,7 +25,7 @@ def from_sample_key(cls: Type[T], sample_key: str) -> T: def from_sample(cls: Type[T], sample: Any, message: str = "") -> T: if message: message = f": {message}" - return cls(f"Sample {compact_str(sample)} failed{message}") + return cls(f"Sample {format_sample_compact(sample)} failed{message}") class FatalSampleError(SampleException): @@ -97,22 +33,6 @@ class FatalSampleError(SampleException): pass -def warn_deprecated(reason, stacklevel=2): - warnings.warn(reason, FutureWarning, stacklevel=stacklevel) - - -def deprecated(reason): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - warn_deprecated(f"{func.__name__} is deprecated: {reason}", stacklevel=3) - return func(*args, **kwargs) - - return wrapper - - return decorator - - SYSTEM_EXCEPTIONS = ( SystemError, SyntaxError, @@ -126,3 +46,183 @@ def wrapper(*args, **kwargs): UnboundLocalError, FatalSampleError, ) + + +class ErrorContext: + """Tracks consecutive errors and enforces error tolerance limits. + + This class helps prevent infinite error loops by tracking consecutive failures + and raising a FatalSampleError when a tolerance threshold is exceeded. 
+ + Example: + error_ctx = ErrorContext( + name="MapDataset.map_fn", + handler=self.worker_config.global_error_handler, + tolerance=100, + ) + + with error_ctx.handle_errors(sample): + result = process_sample(sample) + """ + + name: str + tolerance: int + handler: Callable[[Exception, Any, list["SourceInfo"] | None], None] + + _consecutive_failures: int = 0 + + def __init__( + self, + name: str, + handler: Callable[[Exception, Any, list["SourceInfo"] | None], None], + tolerance: int = 100, + ): + """Initialize error context. + + Args: + name: Name of the operation being tracked (for error messages). + handler: Error handler function to call on exceptions. Takes (exception, sample, sources). + If None, no handler is called and the failure only counts toward the tolerance. + tolerance: Maximum number of consecutive failures before raising FatalSampleError. + Set to 0 to disable tolerance checking. + """ + self.name = name + self.tolerance = tolerance + self.handler = handler + + def reset(self) -> None: + """Reset the consecutive failures counter.""" + self._consecutive_failures = 0 + + @contextmanager + def handle_errors( + self, + sample: Any, + ) -> Generator[None, None, None]: + """Context manager for handling exceptions during sample processing. + + Automatically tracks consecutive failures and resets on success. + + Args: + sample: The sample being processed (used in error reporting). + """ + try: + yield + # Success - reset counter + self._consecutive_failures = 0 + except GeneratorExit: + raise + except SkipSample: + pass + except SYSTEM_EXCEPTIONS as e: + raise FatalSampleError.from_sample( + sample, f"{self.name} failed due to system exception: {e}." + ) + except Exception as e: + print(f"Exception {e} in {self.name}") + # Call the error handler if provided + if self.handler is not None: + # Call the error handler + self.handler(e, sample, get_source_info(sample)) + + # Increment counter; FatalSampleError is raised below if the tolerance is exceeded + self._consecutive_failures += 1 + + if self._consecutive_failures > 1: + print( + f"ErrorContext {self.name} failed {self._consecutive_failures}/{self.tolerance} times in a row." + ) + if self.tolerance > 0 and self._consecutive_failures >= self.tolerance: + raise FatalSampleError.from_sample( + sample, + ( + f"{self.name} failed {self._consecutive_failures} times in a row. " + f"Likely your code or dataset are broken." + ), + ) + + def __repr__(self) -> str: + return f"ErrorContext(name={self.name!r}, tolerance={self.tolerance}, count={self._consecutive_failures})" + + +@contextmanager +def handle_restore_errors( + error_handler: Callable[[Exception, Any, list["SourceInfo"] | None], None], + sample: Any, +) -> Generator[None, None, None]: + """Context manager for handling exceptions during sample restoration. + + Args: + error_handler: Function to call when an exception occurs. Takes (exception, sample, sources). + sample: The sample being restored. + """ + try: + yield + except SkipSample as e: + # Unexpected skip sample + try: + raise ValueError(f"Unexpected skip sample {sample} during restoration.") from e + except Exception as e: + error_handler(e, sample, get_source_info(sample)) + except GeneratorExit as e: + # Unexpected generator exit / early stop + try: + raise ValueError( + f"Unexpected generator early stopping for sample {sample} during restoration."
+ ) from e + except Exception as e: + error_handler(e, sample, get_source_info(sample)) + except SYSTEM_EXCEPTIONS as e: + raise FatalSampleError.from_sample(sample) from e + except Exception as e: + error_handler(e, sample, get_source_info(sample)) + + +def log_exception(e: Exception, sample: Any, sources: list["SourceInfo"] | None = None) -> None: + """Error handler that logs exceptions with sample information. + + This function prints the exception traceback, source information if available, + and a smart representation of the failed sample to help with debugging. + + Args: + e: The exception that was raised. + sample: The sample that caused the exception. + sources: Optional list of SourceInfo objects with sample provenance. + """ + import traceback + + traceback.print_exc() + print("-" * 10) + + if sources: + print("Sources:") + for source in sources: + if hasattr(source, "dataset_path"): + print( + f" - {source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}" + ) + print("-" * 10) + + sample_str = format_sample_detailed(sample) + print(sample_str) + + print("-" * 10) + + +def reraise_exception( + e: Exception, _sample: Any, _sources: list["SourceInfo"] | None = None +) -> None: + """Error handler that simply reraises the exception. + + This is useful when you want failures to propagate immediately without + any tolerance or logging. + + Args: + e: The exception to reraise. + _sample: The sample (unused). + _sources: Source info (unused). + + Raises: + The original exception. + """ + raise e diff --git a/src/megatron/energon/flavors/captioning.py b/src/megatron/energon/flavors/captioning.py index 449bc290..606ca2c8 100644 --- a/src/megatron/energon/flavors/captioning.py +++ b/src/megatron/energon/flavors/captioning.py @@ -4,9 +4,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/image.py b/src/megatron/energon/flavors/image.py index 7d36a70d..e756ffcc 100644 --- a/src/megatron/energon/flavors/image.py +++ b/src/megatron/energon/flavors/image.py @@ -4,9 +4,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/image_classification.py b/src/megatron/energon/flavors/image_classification.py index ad8764de..fd6debc2 100644 --- a/src/megatron/energon/flavors/image_classification.py +++ b/src/megatron/energon/flavors/image_classification.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/interleaved.py b/src/megatron/energon/flavors/interleaved.py index 10ec8ac6..e9bb921d 100644 --- 
a/src/megatron/energon/flavors/interleaved.py +++ b/src/megatron/energon/flavors/interleaved.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py b/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py index 8358ca11..651315cc 100644 --- a/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py +++ b/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py @@ -10,13 +10,11 @@ from megatron.energon.flavors.base_dataset import ( BaseCoreDatasetFactory, SavableDataset, - legacy_handler, ) from megatron.energon.flavors.crude import CrudeSample from megatron.energon.flavors.jsonl.ijsonl import IJsonlIndexReader from megatron.energon.flavors.jsonl.jsonl_prepare import JsonlPreparator from megatron.energon.flavors.webdataset.base_webdataset import _print_shard_slices -from megatron.energon.flavors.webdataset.error_handler import ErrorHandler from megatron.energon.flavors.webdataset.sample_loader import ( RawSampleData, WebdatasetSampleLoaderDataset, @@ -26,9 +24,7 @@ DatasetSubset, FilteredSample, ShardInfo, - reraise_exception, ) -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.map_dataset import MapDataset @@ -39,7 +35,6 @@ class CrudeJsonlDatasetFactory( BaseCoreDatasetFactory[CrudeSample], JsonlPreparator, Sharder, - ErrorHandler, ): """ Factory class for creating a crude dataset from JSONL (JSON Lines) files. @@ -66,9 +61,6 @@ def __init__( max_samples_per_sequence: Optional[int] = None, subset: Optional[DatasetSubset] = None, part_filter: Optional[Callable[[str], bool]] = None, - handler: Callable[ - [Exception, Optional[str], Optional[list[SourceInfo]]], None - ] = reraise_exception, ): """ Factory for a jsonl file as a crude dataset. @@ -89,7 +81,6 @@ def __init__( will be sequentially iterated). subset: If specified, the dataset will be subsetted. part_filter: (internal) Function for filtering tar files by dict keys - handler: Exception handler. Args: (exception, key). 
""" assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__" self.path = path @@ -101,7 +92,6 @@ def __init__( self.max_samples_per_sequence = max_samples_per_sequence self.subset = subset self.part_filter = part_filter - self.handler = legacy_handler(handler) if part_filter is None or part_filter("json"): self._len = IJsonlIndexReader.count_samples(path) else: @@ -157,7 +147,6 @@ def build(self, worker_rotation_offset: int = 0) -> SavableDataset[CrudeSample]: return MapDataset( dataset, self._load_sample_raw, - error_handler=self.error_handler, stateless_map_fn=True, map_fn_config=self.config, worker_config=self.worker_config, diff --git a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py index f50b242c..202f1926 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py @@ -88,7 +88,7 @@ def _get_item_by_sample_pointer( return None return FilteredSample( - __key__=f"{self.jsonl_path.name}/{key}", + __key__=key, __shard__=self.jsonl_path.name, __restore_key__=("Webdataset", sample_pointer.index), __sources__=( diff --git a/src/megatron/energon/flavors/multichoice_vqa.py b/src/megatron/energon/flavors/multichoice_vqa.py index a89c8bba..fe941060 100644 --- a/src/megatron/energon/flavors/multichoice_vqa.py +++ b/src/megatron/energon/flavors/multichoice_vqa.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/ocr.py b/src/megatron/energon/flavors/ocr.py index 35c67857..4ed932d4 100644 --- a/src/megatron/energon/flavors/ocr.py +++ b/src/megatron/energon/flavors/ocr.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/similarity_interleaved.py b/src/megatron/energon/flavors/similarity_interleaved.py index b80bb155..13415198 100644 --- a/src/megatron/energon/flavors/similarity_interleaved.py +++ b/src/megatron/energon/flavors/similarity_interleaved.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/text.py b/src/megatron/energon/flavors/text.py index a3bd498b..e3ad5002 100644 --- a/src/megatron/energon/flavors/text.py +++ b/src/megatron/energon/flavors/text.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from 
megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/vid_qa.py b/src/megatron/energon/flavors/vid_qa.py index 89642d23..31fe30b4 100644 --- a/src/megatron/energon/flavors/vid_qa.py +++ b/src/megatron/energon/flavors/vid_qa.py @@ -6,9 +6,9 @@ import torch from megatron.energon.av import AVDecoder +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/vqa.py b/src/megatron/energon/flavors/vqa.py index 33f3bb18..1c2f989f 100644 --- a/src/megatron/energon/flavors/vqa.py +++ b/src/megatron/energon/flavors/vqa.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/vqa_and_ocr.py b/src/megatron/energon/flavors/vqa_and_ocr.py index 00400cd8..63f46486 100644 --- a/src/megatron/energon/flavors/vqa_and_ocr.py +++ b/src/megatron/energon/flavors/vqa_and_ocr.py @@ -5,9 +5,9 @@ import torch +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.base_dataset import Sample from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory diff --git a/src/megatron/energon/flavors/webdataset/__init__.py b/src/megatron/energon/flavors/webdataset/__init__.py index a0de97f0..464528cc 100644 --- a/src/megatron/energon/flavors/webdataset/__init__.py +++ b/src/megatron/energon/flavors/webdataset/__init__.py @@ -3,7 +3,6 @@ from megatron.energon.av import AVData, AVDecoder from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory -from megatron.energon.flavors.webdataset.config import JOIN_INDEX_FILENAME, MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.default_decoder_webdataset import ( DefaultDecoderWebdatasetFactory, ) @@ -29,9 +28,7 @@ "DefaultGenericWebdatasetFactory", "EmptyDatasetError", "ImageDecoderType", - "JOIN_INDEX_FILENAME", "JoinedWebdatasetFactory", - "MAIN_FOLDER_NAME", "SampleDecoder", "StandardWebdatasetFactory", "WebdatasetMeta", diff --git a/src/megatron/energon/flavors/webdataset/aggregator_pool.py b/src/megatron/energon/flavors/webdataset/aggregator_pool.py index b6e3bc0b..803e2414 100644 --- a/src/megatron/energon/flavors/webdataset/aggregator_pool.py +++ b/src/megatron/energon/flavors/webdataset/aggregator_pool.py @@ -143,18 +143,32 @@ def process(self) -> T_result: """ workers: List[multiprocessing.Process] = [] - # Start worker processes - for w_id in range(self.num_workers): - p = multiprocessing.Process(target=self._worker, args=(w_id,), daemon=True) - p.start() - workers.append(p) + # Only works with forking for now because we pass self to the worker processes. 
+ orig_start_method = multiprocessing.get_start_method() + multiprocessing.set_start_method("fork", force=True) + try: + # Start worker processes + for w_id in range(self.num_workers): + p = multiprocessing.Process(target=self._worker, args=(w_id,), daemon=True) + p.start() + workers.append(p) + finally: + multiprocessing.set_start_method(orig_start_method, force=True) # Send the sentinel (None) to each worker for _ in range(self.num_workers): self.task_queue.put(None) # Now wait for aggregator to finish reading the queue - result = self._aggregator_run() + try: + result = self._aggregator_run() + except Exception: + for p in workers: + if p.is_alive(): + p.terminate() + for p in workers: + p.join() + raise # Wait for all workers to finish for p in workers: diff --git a/src/megatron/energon/flavors/webdataset/base_webdataset.py b/src/megatron/energon/flavors/webdataset/base_webdataset.py index 1c84a5b6..4128fc16 100644 --- a/src/megatron/energon/flavors/webdataset/base_webdataset.py +++ b/src/megatron/energon/flavors/webdataset/base_webdataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import logging +import os from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar @@ -12,9 +13,7 @@ from megatron.energon.flavors.base_dataset import ( BaseCoreDatasetFactory, SavableDataset, - legacy_handler, ) -from megatron.energon.flavors.webdataset.error_handler import ErrorHandler from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator from megatron.energon.flavors.webdataset.sample_loader import ( @@ -26,9 +25,7 @@ DatasetSubset, FilteredSample, ShardInfo, - reraise_exception, ) -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.map_dataset import MapDataset @@ -37,12 +34,13 @@ logger = logging.getLogger(__name__) +DEBUG_SHARD_PRINT = os.getenv("ENERGON_DEBUG_SHARD_PRINT", "0") == "1" + class BaseWebdatasetFactory( BaseCoreDatasetFactory[T_sample], WebdatasetPreparator, Sharder, - ErrorHandler, Generic[T_sample], ABC, ): @@ -67,9 +65,6 @@ class BaseWebdatasetFactory( subset: Optional[DatasetSubset] part_filter: Optional[Callable[[str], bool]] - handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None] - - shards: List[ShardInfo] def __init__( self, @@ -84,9 +79,6 @@ def __init__( subset: Optional[DatasetSubset] = None, split_config: Optional[str] = None, part_filter: Optional[Callable[[str], bool]] = None, - handler: Callable[ - [Exception, Optional[str], Optional[list[SourceInfo]]], None - ] = reraise_exception, ): """ Base factory for the webdataset sample loader. @@ -109,7 +101,6 @@ def __init__( subset: If specified, the dataset will be subsetted. split_config: Config file to use for shard split definitions. part_filter: (internal) Function for filtering tar files by dict keys - handler: Exception handler. Args: (exception, key, source_info). 
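A hedged note on the new `ENERGON_DEBUG_SHARD_PRINT` switch above: the flag is read once at import time, so opting back in to the shard-slice printout could look like the sketch below (import path taken from this diff):

```python
import os

# Must be set before the module is imported, because DEBUG_SHARD_PRINT is
# evaluated once at import time.
os.environ["ENERGON_DEBUG_SHARD_PRINT"] = "1"

from megatron.energon.flavors.webdataset import base_webdataset  # noqa: E402
```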
""" assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__" wds_meta = WebdatasetMeta.from_config( @@ -127,7 +118,6 @@ def __init__( self.max_samples_per_sequence = max_samples_per_sequence self.subset = subset self.part_filter = part_filter - self.handler = legacy_handler(handler) def __len__(self) -> int: return sum(shard.count for shard in self.shards) @@ -151,7 +141,8 @@ def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]: rotation_offset=worker_rotation_offset, subset=self.subset, ) - _print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets) + if DEBUG_SHARD_PRINT: + _print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets) itar_reader = ShardInfosITarReader( self.path, @@ -171,7 +162,6 @@ def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]: return MapDataset( dataset, self._load_sample_raw, - error_handler=self.error_handler, stateless_map_fn=True, map_fn_config=self.config, worker_config=self.worker_config, diff --git a/src/megatron/energon/flavors/webdataset/config.py b/src/megatron/energon/flavors/webdataset/config.py index 30d56f90..f1057d6e 100644 --- a/src/megatron/energon/flavors/webdataset/config.py +++ b/src/megatron/energon/flavors/webdataset/config.py @@ -7,3 +7,8 @@ skip_meta_re = re.compile(r"__[^/]*__($|/)") MAIN_FOLDER_NAME = ".nv-meta" JOIN_INDEX_FILENAME = "join_index.bin" +INDEX_SQLITE_FILENAME = "index.sqlite" +INDEX_UUID_FILENAME = "index.uuid" +INFO_JSON_FILENAME = ".info.json" +INFO_YAML_FILENAME = ".info.yaml" +SPLIT_YAML_FILENAME = "split.yaml" diff --git a/src/megatron/energon/flavors/webdataset/error_handler.py b/src/megatron/energon/flavors/webdataset/error_handler.py deleted file mode 100644 index f12583a1..00000000 --- a/src/megatron/energon/flavors/webdataset/error_handler.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. 
-# SPDX-License-Identifier: BSD-3-Clause - -from typing import Callable, Generic, Optional, Tuple, TypeVar, Union - -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import Sample -from megatron.energon.flavors.webdataset.structs import FilteredSample, reraise_exception -from megatron.energon.source_info import SourceInfo - -T_sample = TypeVar("T_sample", covariant=True) - - -class ErrorHandler(Generic[T_sample]): - handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None] = ( - reraise_exception - ) - - def sample_error_handler( - self, e: Exception, sample_key: Optional[str], sources: Optional[list[SourceInfo]] = None - ): - if isinstance(e, SYSTEM_EXCEPTIONS): - raise FatalSampleError(f"Error in sample {sample_key!r}: {e}") from e - - self.handler(e, sample_key, sources) - - def error_handler( - self, - e: Exception, - sample: Union[ - T_sample, - dict, - FilteredSample, - None, - Tuple[Union[T_sample, dict, FilteredSample, None], ...], - ], - ): - if isinstance(sample, dict): - key = sample.get("__key__") - sources = sample.get("__sources__") - elif isinstance(sample, list): - if isinstance(sample[0], dict): - key = ",".join("None" if s is None else s.get("__key__") for s in sample) - sources = [src for s in sample for src in s.get("__sources__", ())] - elif isinstance(sample[0], Sample): - key = ",".join("None" if s is None else s.__key__ for s in sample) - sources = [src for s in sample for src in s.__sources__] - else: - key = None - sources = None - elif isinstance(sample, Sample): - key = sample.__key__ - sources = sample.__sources__ - else: - key = None - sources = None - self.sample_error_handler(e, key, sources) diff --git a/src/megatron/energon/flavors/webdataset/field_access.py b/src/megatron/energon/flavors/webdataset/field_access.py index 4a120d39..a661e915 100644 --- a/src/megatron/energon/flavors/webdataset/field_access.py +++ b/src/megatron/energon/flavors/webdataset/field_access.py @@ -4,7 +4,8 @@ import re from typing import Any, List, Union -from megatron.energon.errors import SampleException, compact_str +from megatron.energon.errors import SampleException +from megatron.energon.sample_utils import format_sample_compact _field_split_re = re.compile(r"^(?P[^[\]]+)(?P(?:\[([^[\]]+)])*)$") @@ -60,12 +61,14 @@ def _field_access(value: Union[dict, list, str, int, bool, None], field: List[st return _field_access(value[int(field[0])], field[1:]) else: raise FieldAccessError( - f"Cannot access literal value {compact_str(value)} with {field!r}" + f"Cannot access literal value {format_sample_compact(value)} with {field!r}" ) except FieldAccessError: raise except KeyError: - raise FieldAccessError(f"Cannot access {'.'.join(field)!r} in {compact_str(value)}") + raise FieldAccessError( + f"Cannot access {'.'.join(field)!r} in {format_sample_compact(value)}" + ) def field_access(value: Union[dict, list, str, int, bool, None], field: List[List[str]]) -> Any: diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index 635fbbf7..b3ef98d1 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -14,12 +14,29 @@ from megatron.energon.local_copy import ensure_local_copy +class DuplicateSampleKeyError(RuntimeError): + """Raised when attempting to insert a duplicate sample key into the index.""" + + def __init__(self, sample_key: str) -> None: + 
super().__init__(f"Duplicate sample key encountered while indexing: {sample_key!r}") + self.sample_key = sample_key + + class SqliteIndexWriter: sqlite_path: EPath db: Optional[sqlite3.Connection] - duplicates: List[Tuple[str, int]] + enable_sample_tables: bool + enable_media_metadata: bool + reset_tables: bool - def __init__(self, sqlite_path: EPath): + def __init__( + self, + sqlite_path: EPath, + *, + enable_sample_tables: bool = True, + enable_media_metadata: bool = False, + reset_tables: bool = True, + ): """ Initializes an SQLite database and sets up the samples table: - samples(tar_file_id INTEGER, @@ -42,6 +59,9 @@ def __init__(self, sqlite_path: EPath): # Final path and temporary path self.sqlite_path = sqlite_path + self.enable_sample_tables = enable_sample_tables + self.enable_media_metadata = enable_media_metadata + self.reset_tables = reset_tables # Initialize SQLite connection path = str(self.sqlite_path) @@ -53,41 +73,65 @@ def __init__(self, sqlite_path: EPath): Path(path).parent.mkdir(parents=True, exist_ok=True) self.db = sqlite3.connect(path) self.db.execute("PRAGMA busy_timeout = 5000;") # wait up to 5000ms when locked - self.db.execute("PRAGMA journal_mode = WAL;") - # Create the sample table - self.db.execute("DROP INDEX IF EXISTS idx_samples_sample_key") - self.db.execute("DROP INDEX IF EXISTS idx_samples_by_tar_and_idx") - self.db.execute("DROP TABLE IF EXISTS samples") - self.db.execute( - """ - CREATE TABLE samples ( - tar_file_id INTEGER, - sample_key TEXT, - sample_index INTEGER, - byte_offset INTEGER, - byte_size INTEGER - ) - """ - ) + if self.enable_sample_tables: + assert self.reset_tables, "Reset tables is required when enabling sample tables" - # Create the sample parts table - self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_seq") - self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_full") - self.db.execute("DROP TABLE IF EXISTS sample_parts") - self.db.execute( - """ - CREATE TABLE sample_parts ( - tar_file_id INTEGER, - sample_index INTEGER, - part_name TEXT, - content_byte_offset INTEGER, - content_byte_size INTEGER + self.db.execute("DROP INDEX IF EXISTS idx_samples_sample_key") + self.db.execute("DROP INDEX IF EXISTS idx_samples_by_tar_and_idx") + self.db.execute("DROP TABLE IF EXISTS samples") + + self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_seq") + self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_full") + self.db.execute("DROP TABLE IF EXISTS sample_parts") + + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS samples ( + tar_file_id INTEGER NOT NULL, + sample_key TEXT NOT NULL UNIQUE, + sample_index INTEGER NOT NULL, + byte_offset INTEGER, + byte_size INTEGER + ) + """ + ) + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS sample_parts ( + tar_file_id INTEGER, + sample_index INTEGER, + part_name TEXT, + content_byte_offset INTEGER, + content_byte_size INTEGER + ) + """ ) - """ - ) - self.duplicates = [] + if self.enable_media_metadata: + if self.reset_tables: + self.db.execute("DROP TABLE IF EXISTS media_metadata") + self.db.execute("DROP TABLE IF EXISTS media_filters") + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS media_metadata ( + entry_key TEXT PRIMARY KEY, + metadata_type TEXT NOT NULL, + metadata_json TEXT NOT NULL + ) + """ + ) + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS media_filters ( + filter_id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy TEXT NOT NULL, + patterns TEXT, + created_at_utc TEXT DEFAULT CURRENT_TIMESTAMP, + UNIQUE(strategy, patterns) + ) + """ + ) def append_sample( self, @@ 
-111,13 +155,16 @@ def append_sample( assert self.db is not None, "Database is closed" # Insert a row in the samples table - self.db.execute( - """ - INSERT INTO samples (tar_file_id, sample_key, sample_index, byte_offset, byte_size) - VALUES (?, ?, ?, ?, ?) - """, - (tar_file_id, sample_key, sample_index, byte_offset, byte_size), - ) + try: + self.db.execute( + """ + INSERT INTO samples (tar_file_id, sample_key, sample_index, byte_offset, byte_size) + VALUES (?, ?, ?, ?, ?) + """, + (tar_file_id, sample_key, sample_index, byte_offset, byte_size), + ) + except sqlite3.IntegrityError as exc: # pragma: no cover - defensive programming + raise DuplicateSampleKeyError(sample_key) from exc def append_part( self, @@ -140,6 +187,33 @@ def append_part( (tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size), ) + def append_media_metadata( + self, + entry_key: str, + metadata_type: str, + metadata_json: str, + ) -> None: + """Insert or update a media metadata record.""" + + assert self.enable_media_metadata, "Adding media metadata, although not enabled" + + assert self.db is not None, "Database is closed" + + self.db.execute( + """ + INSERT OR REPLACE INTO media_metadata (entry_key, metadata_type, metadata_json) + VALUES (?, ?, ?) + """, + (entry_key, metadata_type, metadata_json), + ) + + def append_media_filter(self, *, strategy: str, patterns: str | None) -> None: + assert self.db is not None, "Database is closed" + self.db.execute( + "INSERT OR IGNORE INTO media_filters (strategy, patterns) VALUES (?, ?)", + (strategy, patterns), + ) + def close(self): """ Closes the DB connection. If finalize=True, the temporary database is @@ -147,32 +221,27 @@ def close(self): """ assert self.db is not None, "Database is closed" - # Create the index after adding all the samples for better speed - # Index on sample_key for fast lookups - self.db.execute("CREATE INDEX IF NOT EXISTS idx_samples_sample_key ON samples(sample_key)") - - # Create index on the samples table. Help the planner if it chooses `samples` as the probe side of the join - self.db.execute( - "CREATE INDEX IF NOT EXISTS idx_samples_by_tar_and_idx ON samples(tar_file_id, sample_index)" - ) + if self.enable_sample_tables: + # Create the index after adding all the samples for better speed + # Index on sample_key for fast lookups + self.db.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_samples_sample_key ON samples(sample_key)" + ) - # Create index on the sample_parts table for fast sequential access - self.db.execute( - "CREATE INDEX IF NOT EXISTS idx_sample_parts_seq ON sample_parts(tar_file_id, sample_index, content_byte_offset)" - ) + # Create index on the samples table. 
Help the planner if it chooses `samples` as the probe side of the join + self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_samples_by_tar_and_idx ON samples(tar_file_id, sample_index)" + ) - # Create a full index on the sample_parts table for equality lookups and getting offsets directly from key - self.db.execute( - "CREATE INDEX IF NOT EXISTS idx_sample_parts_full ON sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)" - ) + # Create index on the sample_parts table for fast sequential access + self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_sample_parts_seq ON sample_parts(tar_file_id, sample_index, content_byte_offset)" + ) - # Check if sample_key are all unique - # self.db.execute("CREATE TEMP TABLE temp AS SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1") - duplicates = self.db.execute( - "SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1 LIMIT 5" - ).fetchall() - if len(duplicates) > 0: - self.duplicates = duplicates + # Create a full index on the sample_parts table for equality lookups and getting offsets directly from key + self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_sample_parts_full ON sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)" + ) if self.db is not None: self.db.commit() @@ -283,6 +352,17 @@ def db_has_sample_parts(self) -> bool: self.db.thread_close() return db_exists is not None + def db_has_media_metadata(self) -> bool: + """Check if the database has a media_metadata table.""" + + assert self.db is not None, "Database is closed" + + db_exists = self.db.select_one( + "SELECT name FROM sqlite_master WHERE type='table' AND name='media_metadata'" + ) + self.db.thread_close() + return db_exists is not None + def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: """List all sample keys in the database. 
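# --- Illustrative sketch (not part of this patch) ----------------------------
# Minimal example of how the new media_metadata table can be written and read
# back with plain sqlite3, mirroring append_media_metadata() above and
# get_media_metadata() further below. The database path and the JSON payload
# are hypothetical; the table and column names come from the CREATE TABLE
# statements in this file.
import json
import sqlite3

con = sqlite3.connect("/data/my_dataset/.nv-meta/index.sqlite")  # hypothetical path
con.execute(
    "INSERT OR REPLACE INTO media_metadata (entry_key, metadata_type, metadata_json) VALUES (?, ?, ?)",
    (
        "images/0001.jpg",
        "image",
        json.dumps({"width": 640, "height": 480, "format": "JPEG", "mode": "RGB"}),
    ),
)
con.commit()

row = con.execute(
    "SELECT metadata_type, metadata_json FROM media_metadata WHERE entry_key = ?",
    ("images/0001.jpg",),
).fetchone()
print(row)  # e.g. ('image', '{"width": 640, ...}')
con.close()
# ------------------------------------------------------------------------------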
@@ -414,6 +494,17 @@ def get_sample_pointer_by_key(self, key: str) -> ITarSamplePointer: byte_size=sample[4], ) + def get_media_metadata(self, entry_key: str) -> Tuple[str, str] | None: + """Fetch the media metadata record for an entry, if available.""" + + assert self.db is not None, "Database is closed" + + row = self.db.select_one( + "SELECT metadata_type, metadata_json FROM media_metadata WHERE entry_key = ?", + (entry_key,), + ) + return (row[0], row[1]) if row is not None else None + def close(self): """Close the database connection.""" if self.db is not None: diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index deb95352..0134f338 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -17,7 +17,11 @@ ) from megatron.energon.epathlib import EPath -from megatron.energon.flavors.webdataset.config import skip_meta_re, split_name_re +from megatron.energon.flavors.webdataset.config import ( + INDEX_SQLITE_FILENAME, + skip_meta_re, + split_name_re, +) from megatron.energon.flavors.webdataset.indexing import JoinIndexReader, SqliteIndexReader from megatron.energon.flavors.webdataset.itar import ( CachedItarOffsetReader, @@ -225,7 +229,7 @@ def _get_item_by_sample_pointer( raise ValueError(f"No valid files found in sample {sample_pointer}") return FilteredSample( - __key__=f"{shard_name}/{sample_base_name}", + __key__=sample_base_name, __shard__=self.tar_filenames[sample_pointer.tar_file_id], __restore_key__=("Webdataset", restore_index), __sources__=( @@ -465,7 +469,7 @@ def __init__( tar_filepaths = [base_path / fn for fn in tar_filenames] # Initialize the SQLite reader - sqlite_path = base_path / MAIN_FOLDER_NAME / "index.sqlite" + sqlite_path = base_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME self.sqlite_reader = SqliteIndexReader(sqlite_path) self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts() diff --git a/src/megatron/energon/flavors/webdataset/joined_webdataset.py b/src/megatron/energon/flavors/webdataset/joined_webdataset.py index 45ca153f..ba45439f 100644 --- a/src/megatron/energon/flavors/webdataset/joined_webdataset.py +++ b/src/megatron/energon/flavors/webdataset/joined_webdataset.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +import os from abc import ABC from typing import ( Any, @@ -22,11 +23,9 @@ BaseCoreDatasetFactory, Sample, SavableDataset, - legacy_handler, set_sample_restore_key, ) from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory -from megatron.energon.flavors.webdataset.error_handler import ErrorHandler from megatron.energon.flavors.webdataset.indexing import JoinIndexReader from megatron.energon.flavors.webdataset.itar_reader import JoinIndexFileITarReader from megatron.energon.flavors.webdataset.sample_loader import ( @@ -34,17 +33,16 @@ WebdatasetSampleLoaderDataset, ) from megatron.energon.flavors.webdataset.sharder import Sharder -from megatron.energon.flavors.webdataset.structs import DatasetSubset, ShardInfo, reraise_exception -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.webdataset.structs import DatasetSubset, ShardInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.map_dataset import MapDataset T_sample = TypeVar("T_sample", covariant=True) +DEBUG_SHARD_PRINT = os.getenv("ENERGON_DEBUG_SHARD_PRINT", "0") == "1" -class JoinedWebdatasetFactory( - BaseCoreDatasetFactory[T_sample], Sharder, ErrorHandler[T_sample], Generic[T_sample], ABC -): + +class JoinedWebdatasetFactory(BaseCoreDatasetFactory[T_sample], Sharder, Generic[T_sample], ABC): """ Base class for all webdataset loaders. Applies proper sharding across workers. Can join multiple datasets. """ @@ -57,7 +55,6 @@ class JoinedWebdatasetFactory( subset: Optional[DatasetSubset] join_index: EPath - handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None] shards: List[Sequence[ShardInfo]] part_datasets: SavableDataset[T_sample] @@ -78,9 +75,6 @@ def __init__( subset: Optional[DatasetSubset] = None, join_index: EPath, joiner: Union[Type[T_sample], Callable[..., T_sample]], - handler: Callable[ - [Exception, Optional[str], Optional[list[SourceInfo]]], None - ] = reraise_exception, ): """ Constructs the loader for a joined webdataset. The samples from the inner datasets are joined into a single @@ -105,7 +99,6 @@ def __init__( subset: If specified, the inner dataset(s) will be subsetted. join_index: Path to the join index file. Only required for join_method="left". joiner: Type of the joined samples or a method for joining the samples. - handler: Exception handler. Args: (exception, key). 
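# --- Illustrative note (not part of this patch) -------------------------------
# The DEBUG_SHARD_PRINT flag defined above gates the per-worker shard-slice
# printout in build() below, instead of always printing it. A minimal sketch of
# enabling it; since the flag is read at import time, it must be set before
# megatron.energon is imported (train.py is a hypothetical entry script):
#
#   ENERGON_DEBUG_SHARD_PRINT=1 python train.py
#
import os

os.environ["ENERGON_DEBUG_SHARD_PRINT"] = "1"  # must happen before the import above
# ------------------------------------------------------------------------------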
""" self.__sample_type__ = joiner assert all(not hasattr(d, "dataset") for d in inner_datasets), ( @@ -135,7 +128,6 @@ def __init__( self.parallel_shard_iters = parallel_shard_iters self.max_samples_per_sequence = max_samples_per_sequence self.subset = subset - self.handler = legacy_handler(handler) def __len__(self) -> int: return sum(shard.count for shard in self.inner_datasets[0].shards) @@ -169,19 +161,20 @@ def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]: subset=self.subset, ) - for worker_idx, sample_slice_offsets in enumerate(workers_sample_slice_offsets): - start_idx = sample_slice_offsets[0] - end_idx = sample_slice_offsets[-1] + if DEBUG_SHARD_PRINT: + for worker_idx, sample_slice_offsets in enumerate(workers_sample_slice_offsets): + start_idx = sample_slice_offsets[0] + end_idx = sample_slice_offsets[-1] - if len(sample_slice_offsets) > 6: - offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}" - else: - offset_str = ", ".join(str(o) for o in sample_slice_offsets) + if len(sample_slice_offsets) > 6: + offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}" + else: + offset_str = ", ".join(str(o) for o in sample_slice_offsets) - print( - f"rank={self.worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}) in {len(sample_slice_offsets) - 1} slices, " - f"sum(count)={end_idx - start_idx}: [{offset_str}]" - ) + print( + f"rank={self.worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}) in {len(sample_slice_offsets) - 1} slices, " + f"sum(count)={end_idx - start_idx}: [{offset_str}]" + ) itar_readers = [ JoinIndexFileITarReader( @@ -216,7 +209,6 @@ def _process_samples(self, dataset: SavableDataset[RawSampleData]) -> SavableDat return MapDataset( dataset, self.load_sample, - error_handler=self.error_handler, stateless_map_fn=True, map_fn_config=self.config, worker_config=self.worker_config, diff --git a/src/megatron/energon/flavors/webdataset/metadata.py b/src/megatron/energon/flavors/webdataset/metadata.py index d90d04fa..1b2d3694 100644 --- a/src/megatron/energon/flavors/webdataset/metadata.py +++ b/src/megatron/energon/flavors/webdataset/metadata.py @@ -10,7 +10,12 @@ from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath from megatron.energon.eyaml import load_yaml, load_yaml_json -from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import ( + INDEX_SQLITE_FILENAME, + INFO_JSON_FILENAME, + INFO_YAML_FILENAME, + MAIN_FOLDER_NAME, +) from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError from megatron.energon.flavors.webdataset.structs import ( ShardInfo, @@ -24,6 +29,7 @@ class EnergonDatasetType(Enum): METADATASET = "metadataset" WEBDATASET = "webdataset" JSONL = "jsonl" + FILESYSTEM = "filesystem" INVALID = "invalid" @@ -120,7 +126,7 @@ def get_dataset_info(path: EPath) -> dict: return the dataset info as a dict. 
""" - info_config = path / MAIN_FOLDER_NAME / ".info.json" + info_config = path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME # YAML for backwards compatibility yaml_info_config = path / MAIN_FOLDER_NAME / ".info.yaml" @@ -137,8 +143,8 @@ def check_dataset_info_present(path: EPath) -> bool: """Given the path to an energon webdataset that contains a .nv-meta folder, return True if the dataset info is present, False otherwise. """ - return (path / MAIN_FOLDER_NAME / ".info.json").is_file() or ( - path / MAIN_FOLDER_NAME / ".info.yaml" + return (path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME).is_file() or ( + path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME ).is_file() @@ -151,6 +157,8 @@ def get_dataset_type(path: EPath) -> EnergonDatasetType: Returns: The type of the dataset. """ + metadata_db = path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME + if path.is_file(): if path.name.endswith(".jsonl"): return EnergonDatasetType.JSONL @@ -160,5 +168,9 @@ def get_dataset_type(path: EPath) -> EnergonDatasetType: return EnergonDatasetType.INVALID elif check_dataset_info_present(path): return EnergonDatasetType.WEBDATASET + elif metadata_db.is_file(): + # There is an sqlite, but no .info.json or .info.yaml, + # so it's a filesystem dataset + return EnergonDatasetType.FILESYSTEM else: return EnergonDatasetType.INVALID diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py index 4fc6b485..c97436a5 100644 --- a/src/megatron/energon/flavors/webdataset/prepare.py +++ b/src/megatron/energon/flavors/webdataset/prepare.py @@ -6,6 +6,7 @@ import logging import random import re +import sys import tarfile import uuid from dataclasses import asdict @@ -34,10 +35,24 @@ from megatron.energon.epathlib import EPath from megatron.energon.eyaml import load_yaml from megatron.energon.flavors.webdataset.aggregator_pool import AggregatorPool, BaseAggregator -from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME, skip_meta_re, split_name_re -from megatron.energon.flavors.webdataset.indexing import SqliteIndexWriter +from megatron.energon.flavors.webdataset.config import ( + INDEX_SQLITE_FILENAME, + INDEX_UUID_FILENAME, + INFO_JSON_FILENAME, + INFO_YAML_FILENAME, + MAIN_FOLDER_NAME, + skip_meta_re, + split_name_re, +) +from megatron.energon.flavors.webdataset.indexing import ( + DuplicateSampleKeyError, + SqliteIndexWriter, +) from megatron.energon.flavors.webdataset.itar import TarIndexWriter +from megatron.energon.flavors.webdataset.metadata import get_dataset_info from megatron.energon.flavors.webdataset.structs import ShardInfo, WebdatasetInfo, WebdatasetSplits +from megatron.energon.media.extractor import MediaFilterConfig +from megatron.energon.media.metadata import serialize_media_metadata from megatron.energon.typed_converter import to_json_object logger = logging.getLogger(__name__) @@ -73,6 +88,13 @@ class IndexSamplePart(IndexAggregatable): content_byte_size: int +@edataclass +class IndexMediaMetadata(IndexAggregatable): + entry_key: str + metadata_type: str + metadata_json: str + + @edataclass class IndexShardInfo(IndexAggregatable): shard_info: ShardInfo @@ -80,9 +102,7 @@ class IndexShardInfo(IndexAggregatable): class SqliteIndexWriterAggregator( - BaseAggregator[ - Tuple[ShardInfo, Set[str]], Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]] - ] + BaseAggregator[IndexAggregatable, Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]]] ): sqlite_path: EPath total_tasks: int @@ -92,12 +112,24 @@ class 
SqliteIndexWriterAggregator( shards: List[ShardInfo] found_parts: Set[str] prog_iter: Iterator + enable_sample_tables: bool + enable_media_metadata: bool + media_filter: Optional[MediaFilterConfig] + reset_tables: bool + media_metadata_written: int + progress_on_media: bool def __init__( self, sqlite_path: EPath, total_tasks: int, progress_fn: Optional[Callable[[Iterator[Any], int], Iterator[T]]] = None, + *, + enable_sample_tables: bool = True, + enable_media_metadata: bool = False, + media_filter: Optional[MediaFilterConfig] = None, + reset_tables: bool = True, + progress_on_media: bool = False, ): self.sqlite_path = sqlite_path self.total_tasks = total_tasks @@ -105,6 +137,12 @@ def __init__( self.had_update = False self.shards = [] self.found_parts = set() + self.enable_sample_tables = enable_sample_tables + self.enable_media_metadata = enable_media_metadata + self.media_filter = media_filter + self.reset_tables = reset_tables + self.media_metadata_written = 0 + self.progress_on_media = progress_on_media if progress_fn is not None: self.prog_iter = progress_fn(iter(range(self.total_tasks)), self.total_tasks) @@ -112,7 +150,12 @@ def __init__( self.prog_iter = iter(range(self.total_tasks)) def on_start(self, aggregator_pool: AggregatorPool) -> None: - self.writer = SqliteIndexWriter(self.sqlite_path) + self.writer = SqliteIndexWriter( + self.sqlite_path, + enable_sample_tables=self.enable_sample_tables, + enable_media_metadata=self.enable_media_metadata, + reset_tables=self.reset_tables, + ) def on_item( self, @@ -125,9 +168,20 @@ def on_item( self.had_update = True elif isinstance(item, IndexSamplePart): self.writer.append_part(**asdict(item)) + elif isinstance(item, IndexMediaMetadata): + self.writer.append_media_metadata( + entry_key=item.entry_key, + metadata_type=item.metadata_type, + metadata_json=item.metadata_json, + ) + self.had_update = True + self.media_metadata_written += 1 + if self.progress_on_media: + self._advance_progress() elif isinstance(item, IndexShardInfo): # This is a (shard_info, parts) tuple - next(self.prog_iter) + if not self.progress_on_media: + self._advance_progress() shard_info, cur_parts = item.shard_info, item.parts assert shard_info.count != 0, f"Shard {shard_info.name} has no samples." @@ -137,22 +191,57 @@ def on_item( def on_finish(self, aggregator_pool: AggregatorPool) -> None: assert self.writer is not None, "Writer is not initialized." + if self.enable_media_metadata and self.media_filter is not None: + self.writer.append_media_filter( + strategy=self.media_filter.strategy.value, + patterns=",".join(self.media_filter.patterns), + ) self.writer.close() def get_final_result_data( self, - ) -> Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]]: + ) -> Tuple[List[ShardInfo], Set[str], bool]: assert self.writer is not None, "Writer is not initialized." 
- return self.shards, self.found_parts, self.had_update, self.writer.duplicates + return self.shards, self.found_parts, self.had_update + + def _advance_progress(self) -> None: + try: + next(self.prog_iter) + except StopIteration: + pass class WebdatasetPreparator: + @staticmethod + def _iter_tar_sample_members( + tar: tarfile.TarFile, + ) -> Iterator[Tuple[tarfile.TarInfo, str, str]]: + """Yield (member, base_name, part_name) for relevant tar entries.""" + + member: tarfile.TarInfo + for member in tar: + if not member.isreg(): + continue + if member.name is None: + continue + if skip_meta_re.match(member.name): + continue + + name_match = split_name_re.match(member.name) + if name_match is None: + continue + + base_name = name_match.group(1) + part_name = name_match.group(2) + yield member, base_name, part_name + @staticmethod def _preprocess_tar( path: str, shard_to_idx: Dict[str, int], parent_path: EPath, max_parts: int, + media_filter: Optional[MediaFilterConfig] = None, ) -> Generator[IndexAggregatable, None, None]: """Process a single tar file, i.e. read the tarinfos, generate the tar index and return stats. @@ -189,25 +278,16 @@ def _preprocess_tar( parts = set() last_base_name = None - member: tarfile.TarInfo next_index_sample = None - for member in tar: - if not member.isreg(): - continue - if member.name is None: - continue - if skip_meta_re.match(member.name): - continue - - name_match = split_name_re.match(member.name) - if name_match is None: - continue - - base_name = name_match.group(1) + for ( + member, + base_name, + part_name, + ) in WebdatasetPreparator._iter_tar_sample_members(tar): if len(parts) < max_parts: - parts.add(name_match.group(2)) + parts.add(part_name) if last_base_name != base_name: iw.append(member.offset) @@ -227,15 +307,37 @@ def _preprocess_tar( last_base_name = base_name count += 1 + entry_key = f"{base_name}.{part_name}" + # Yield this part of the sample to the aggregator yield IndexSamplePart( tar_file_id=shard_to_idx[path], sample_index=count - 1, - part_name=name_match.group(2), + part_name=part_name, content_byte_offset=member.offset_data, content_byte_size=member.size, ) + if media_filter is not None: + if not media_filter.should_consider_media(entry_key): + continue + file_member = tar.extractfile(member) + if file_member is not None: + data = file_member.read() + extracted_metadata = media_filter.extract_metadata( + data, + filename=entry_key, + ) + if extracted_metadata is not None: + stored_type, metadata_json = serialize_media_metadata( + extracted_metadata + ) + yield IndexMediaMetadata( + entry_key=entry_key, + metadata_type=stored_type.value, + metadata_json=metadata_json, + ) + shard_info.count = count iw.append(tar.offset) if next_index_sample is not None: @@ -250,6 +352,68 @@ def _preprocess_tar( yield IndexShardInfo(shard_info=shard_info, parts=set()) return + @staticmethod + def _extract_media_from_tar( + path: str, + *, + parent_path: EPath, + media_filter: MediaFilterConfig, + shard_counts: Dict[str, int], + ) -> Generator[IndexAggregatable, None, None]: + """Yield ``IndexMediaMetadata`` entries for media within an existing tar shard.""" + + shard_path = parent_path / path + + try: + with shard_path.open("rb") as handle: + with tarfile.open(fileobj=handle, mode="r:*") as tar: + for ( + member, + base_name, + part_name, + ) in WebdatasetPreparator._iter_tar_sample_members(tar): + entry_key = f"{base_name}.{part_name}" + + if not media_filter.should_consider_media(entry_key): + continue + + file_member = tar.extractfile(member) + if 
file_member is None: + continue + + extracted_metadata = media_filter.extract_metadata( + file_member, + filename=entry_key, + ) + if extracted_metadata is None: + continue + + stored_type, metadata_json = serialize_media_metadata(extracted_metadata) + + yield IndexMediaMetadata( + entry_key=entry_key, + metadata_type=stored_type.value, + metadata_json=metadata_json, + ) + + except BaseException: # pragma: no cover - dependent on malformed archives + logger.exception( + f"Shard failed to load when extracting media metadata: {path!r}. Skipping it." + ) + + shard_count = shard_counts.get(path) + if shard_count is None: + raise ValueError(f"Shard count for '{path}' not found in dataset metadata") + + yield IndexShardInfo( + shard_info=ShardInfo( + name=path, + path=parent_path / path, + count=shard_count, + ), + parts=set(), + ) + @staticmethod def iter_dataset_content( path: Union[str, EPath], @@ -307,6 +471,8 @@ def prepare_dataset( progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x), workers: int = 32, tar_index_only: bool = False, + media_filter: Optional[MediaFilterConfig] = None, + fix_duplicates: bool = False, ) -> Tuple[Set[str], List[Tuple[str, int]]]: """ Preprocess the shards and write the split config. Preprocessing is done in parallel. @@ -323,6 +489,8 @@ def prepare_dataset( progress_fn: Callback for progress bar workers: Number of parallel workers for reading each shard tar_index_only: Only create tar-index, then exit + media_filter: Media filter configuration + fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards. Returns: The set of all parts found in the shards. But at most 50. @@ -336,10 +504,39 @@ def prepare_dataset( (parent_path / MAIN_FOLDER_NAME).mkdir(exist_ok=True) + if fix_duplicates: + try: + from megatron.energon.flavors.webdataset.tar_patcher import TarPatcher + except ImportError: + print("Install energon with [tar_patcher] extra to fix duplicate keys.") + raise + tar_patcher = TarPatcher(show_progress=True) + scan_result = tar_patcher.dataset_scan( + paths, parent_path=parent_path, num_workers=workers + ) + + if scan_result.has_duplicates: + print("The dataset contains duplicate keys.") + if not scan_result.compatible: + print( + "But the tar files are not compatible with the in-place rename, aborting." + ) + sys.exit(1) + + print("Fixing the dataset now.") + tar_patcher.dataset_apply_prefix( + paths, parent_path=parent_path, num_workers=workers + ) + print("Duplicate keys fixed successfully.") + else: + print("No duplicate keys found, continuing.") + aggregator = SqliteIndexWriterAggregator( - parent_path / MAIN_FOLDER_NAME / "index.sqlite", + parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, total_tasks=len(paths), progress_fn=progress_fn, + enable_media_metadata=media_filter is not None, + media_filter=media_filter, ) process_tar = functools.partial( @@ -347,6 +544,7 @@ def prepare_dataset( shard_to_idx=shard_to_idx, parent_path=parent_path, max_parts=50, + media_filter=media_filter, ) pool = AggregatorPool( @@ -358,15 +556,30 @@ def prepare_dataset( for path in paths: pool.submit_task(path) - shards, found_parts, had_update, duplicates = pool.process() + try: + shards, found_parts, had_update = pool.process() + except DuplicateSampleKeyError as error: + print("The data contains duplicate keys (e.g. 
same filename in different shards).") + print(f'Example duplicate key: "{error.sample_key}"') + print() + print( + "Energon does not support duplicate keys anymore, but we offer a tool to fix your dataset. " + "Run `energon prepare` with `--fix-duplicates` to fix your dataset. Inside each tar, it will " + "put each file in a subfolder with the shard name like `shard_0/filename.ext`." + ) + + if (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file(): + (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).unlink() + + sys.exit(1) if had_update: logger.info("Regenerating dataset UUID...") - with (parent_path / MAIN_FOLDER_NAME / "index.uuid").open("w") as f: + with (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).open("w") as f: f.write(str(uuid.uuid4())) - json_info_config = parent_path / MAIN_FOLDER_NAME / ".info.json" - yaml_info_config = parent_path / MAIN_FOLDER_NAME / ".info.yaml" + json_info_config = parent_path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME + yaml_info_config = parent_path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME if tar_index_only: if yaml_info_config.is_file() and not json_info_config.is_file(): @@ -374,7 +587,7 @@ def prepare_dataset( with json_info_config.open("w") as f: json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2) - return found_parts, duplicates + return found_parts assert len(shards) == len(shard_to_idx), ( f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}" @@ -453,4 +666,61 @@ def prepare_dataset( else: raise ValueError(f"Invalid split config extension: {split_config}") - return found_parts, duplicates + return found_parts + + @classmethod + def add_media_metadata( + cls, + parent_path: Union[Path, EPath], + *, + media_filter: MediaFilterConfig, + workers: int = 32, + progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x), + ) -> int: + """Add or refresh media metadata in an existing WebDataset index.""" + + parent_path = EPath(parent_path) + + dataset_info = get_dataset_info(parent_path) + shard_counts: Dict[str, int] = dataset_info.get("shard_counts", {}) + + paths = list(shard_counts.keys()) + + expanded_paths = [path for path in paths for path in braceexpand.braceexpand(path)] + if not expanded_paths: + return 0 + + for path in expanded_paths: + if path not in shard_counts: + raise ValueError(f"Shard '{path}' not present in dataset metadata") + + aggregator = SqliteIndexWriterAggregator( + parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, + total_tasks=len(expanded_paths), + progress_fn=progress_fn, + enable_sample_tables=False, + enable_media_metadata=True, + media_filter=media_filter, + reset_tables=False, + progress_on_media=False, + ) + + process_tar = functools.partial( + cls._extract_media_from_tar, + parent_path=parent_path, + media_filter=media_filter, + shard_counts=shard_counts, + ) + + pool = AggregatorPool( + num_workers=min(workers, len(expanded_paths)) or 1, + user_produce_data=process_tar, + aggregator=aggregator, + ) + + for path in expanded_paths: + pool.submit_task(path) + + pool.process() + + return aggregator.media_metadata_written diff --git a/src/megatron/energon/flavors/webdataset/standard_webdataset.py b/src/megatron/energon/flavors/webdataset/standard_webdataset.py index 66212a0a..56346e9f 100644 --- a/src/megatron/energon/flavors/webdataset/standard_webdataset.py +++ b/src/megatron/energon/flavors/webdataset/standard_webdataset.py @@ -55,7 +55,6 @@ def __init__( will be sequentially iterated). 
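# --- Illustrative sketch (not part of this patch) -----------------------------
# Adding media metadata to an already prepared WebDataset via the new
# WebdatasetPreparator.add_media_metadata() classmethod from prepare.py above.
# The dataset path and glob patterns are hypothetical.
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.media.extractor import MediaFilterConfig

# Select media files by comma-separated glob patterns (as parsed by MediaFilterConfig.parse).
media_filter = MediaFilterConfig.parse(glob="*.jpg,*.mp4", header=False, extension=False)

written = WebdatasetPreparator.add_media_metadata(
    EPath("/data/my_dataset"),  # hypothetical, already prepared dataset
    media_filter=media_filter,
    workers=16,
)
print(f"Wrote {written} media metadata entries")
# ------------------------------------------------------------------------------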
subset: If specified, the dataset will be subsetted. split_config: Config file to use for shard split definitions. - handler: Exception handler. Args: (exception, key). """ self.__sample_type__ = sample_type assert issubclass(sample_type, Sample), ( diff --git a/src/megatron/energon/flavors/webdataset/structs.py b/src/megatron/energon/flavors/webdataset/structs.py index 1ee955eb..1b8842cd 100644 --- a/src/megatron/energon/flavors/webdataset/structs.py +++ b/src/megatron/energon/flavors/webdataset/structs.py @@ -113,16 +113,3 @@ def config(self) -> dict: "range": self.range, "absolute_range": self.absolute_range, } - - -def reraise_exception( - exc: Exception, key: Optional[str], sources: Optional[list[SourceInfo]] = None -) -> None: - if sources: - raise Exception( - f"For sample {key!r} from {', '.join(f'{source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}' for source in sources)}" - ) from exc - elif key: - raise Exception(f"For sample {key!r}") from exc - else: - raise diff --git a/src/megatron/energon/flavors/webdataset/tar_patcher.py b/src/megatron/energon/flavors/webdataset/tar_patcher.py new file mode 100644 index 00000000..9cdbf681 --- /dev/null +++ b/src/megatron/energon/flavors/webdataset/tar_patcher.py @@ -0,0 +1,835 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +import concurrent.futures +import multiprocessing as mp +import re +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from typing import BinaryIO, Dict, List, Sequence, Set + +import click +import numpy as np +from tqdm import tqdm + +from megatron.energon.epathlib import EPath + +# mp.set_start_method("fork") + +# Try to import numba for JIT compilation of hot paths +try: + import numba as nb + from numba import njit # type: ignore[import-not-found] + + HAS_NUMBA = True +except ImportError: + raise ImportError("numba is required for this module") + +U8C = nb.types.Array(nb.types.uint8, 1, "C", readonly=False) +U8C_RO = nb.types.Array(nb.types.uint8, 1, "C", readonly=True) +INT64 = nb.types.int64 +B1 = nb.types.boolean +VOID = nb.types.void + +BLOCK_SIZE = 512 +NUM_WORKERS = 16 + + +class TarPatcherError(Exception): + """Raised when a tar patch operation fails.""" + + +@dataclass(slots=True) +class TarScanResult: + """Result of scanning a tar archive.""" + + sample_keys: Set[bytes] = field(default_factory=set) + compatible: bool = True + + +@dataclass(slots=True) +class DatasetScanResult: + """Aggregated result of scanning a set of tar archives.""" + + compatible: bool + duplicates: Dict[str, List[str]] + scan_results: Dict[str, TarScanResult] + + @property + def has_duplicates(self) -> bool: + return bool(self.duplicates) + + +@njit(B1(U8C_RO), cache=True, fastmath=True, inline="always") +def _nb_is_zero_block(block: bytearray | bytes) -> bool: + """Numba-optimized: Check if a block is all zeros.""" + for i in range(len(block)): + if block[i] != 0: + return False + return True + + +@njit(INT64(U8C_RO), cache=True, inline="always") +def _nb_parse_size(size: np.ndarray) -> int: + # Base-256 (binary) encoding (POSIX) + # # [124:136] + if size[0] & 0x80: + # Numba doesn't support int.from_bytes, so implement base-256 parsing manually + # Big endian encoding + n = nb.int64(size[0] & 0x3F) + for i in range(1, size.size): + n = (n << 8) | nb.int64(size[i]) + # If the sign bit is set, compute the negative value per tar spec + if size[0] & 0x40: + # print("negative binary size") + return 0 + 
return n + + # Parse ascii integer + n = nb.int64(0) + for i in range(size.size): + byte = size[i] + if byte == 0 or byte == 32: + continue + if byte < 48 or byte > 57: + return 0 + n = (n * 8) + nb.int64(byte - 48) + return n + + +@njit(nb.types.Tuple((U8C_RO, U8C_RO))(U8C_RO), cache=True, inline="always") +def split_ustar_path(path: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Split path into (prefix, name) suitable for ustar: + - name: up to 100 bytes + - prefix: up to 155 bytes + Return (prefix_bytes, name_bytes) or ([]], []]) if it doesn't fit. + """ + if len(path) <= 100: + return np.empty(0, dtype=np.uint8), path + + # Try to split at a '/' so prefix <=155 bytes and name <=100 bytes. + cut = -1 + for i in range(path.size): + if path[i] == 47: # ord(b"/") = 47 + if i <= 155: + cut = i + + if cut == -1: + return np.empty(0, dtype=np.uint8), np.empty(0, dtype=np.uint8) + + prefix_str = path[0:cut] + name_str = path[cut + 1 :] + + prefix_b = prefix_str + name_b = name_str + + if prefix_b.size > 155 or name_b.size > 100: + return np.empty(0, dtype=np.uint8), np.empty(0, dtype=np.uint8) + + return prefix_b, name_b + + +@njit(INT64(U8C_RO), cache=True, fastmath=True, inline="always") +def _nb_compute_checksum(header: np.ndarray) -> int: + """ + Numba-optimized: Compute tar header checksum. + Treat chksum field (148-155) as spaces (0x20) during calculation. + """ + total = nb.int64(0) + for i in range(148): + total += header[i] + for i in range(148, 156): + total += 32 + for i in range(156, len(header)): + total += header[i] + return total + + +@njit(VOID(INT64, U8C), cache=True, inline="always") +def _nb_format_chksum(val: int, dst: np.ndarray) -> None: + dst[0] = (nb.uint8(val >> 15) & 0o7) | 0x30 + dst[1] = (nb.uint8(val >> 12) & 0o7) | 0x30 + dst[2] = (nb.uint8(val >> 9) & 0o7) | 0x30 + dst[3] = (nb.uint8(val >> 6) & 0o7) | 0x30 + dst[4] = (nb.uint8(val >> 3) & 0o7) | 0x30 + dst[5] = (nb.uint8(val) & 0o7) | 0x30 + dst[6] = 0 + dst[7] = 32 + + +@njit(nb.types.Tuple((U8C_RO, B1))(U8C_RO), cache=True, inline="always") +def _nb_extract_full_path(hdr: np.ndarray) -> tuple[np.ndarray, bool]: + for i in range(100): + if hdr[i] == 0: + name = hdr[0:i] + break + else: + name = hdr[0:100] + # if magic == b"ustar\0" or magic == b"ustar ": + if ( + hdr[257] == 117 + and hdr[258] == 115 + and hdr[259] == 116 + and hdr[260] == 97 + and hdr[261] == 114 + and (hdr[262] == 0 or hdr[262] == 32) + ): + for i in range(345, 500): + if hdr[i] == 0: + prefix_field = hdr[345:i] + break + else: + prefix_field = hdr[345:500] + if prefix_field.size > 0 and name.size > 0: + out = np.empty(prefix_field.size + 1 + name.size, dtype=np.uint8) + out[: prefix_field.size] = prefix_field + out[prefix_field.size] = 47 + out[prefix_field.size + 1 :] = name + return out, True + elif prefix_field.size > 0: + return prefix_field, True + else: + return name, True + else: + return name, False + + +@njit(cache=True) +def pax_parse(data: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]]: + """ + Parse PAX extended header into list of (key, value). + Format lines: "%d key=value\n". 
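# --- Illustrative sketch (not part of this patch) -----------------------------
# Worked example of the PAX record format parsed here: "%d key=value\n", where
# the leading decimal length counts the entire record, including the length
# field itself, the separating space, and the trailing newline.
record = b"30 mtime=1526831771.465862009\n"
assert len(record) == 30  # "30 " (3) + "mtime=" (6) + value (20) + "\n" (1)

# pax_parse() below expects a writable uint8 array and yields (key, value)
# byte arrays, e.g. (b"mtime", b"1526831771.465862009") for this record:
#   pax_parse(np.frombuffer(record, dtype=np.uint8).copy())
# ------------------------------------------------------------------------------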
+ """ + out = [] + i = nb.int64(0) + n = len(data) + while i < n: + # parse length + j = i + while j < n and data[j] != 32: + j += 1 + if j == n: + break + rec_len = nb.int64(0) + for k in range(i, j): + if data[k] < 48 or data[k] > 57: + rec_len = 512 + break + rec_len = (rec_len * 10) + (data[k] - 48) + + if i + rec_len > n: + break + + sp = -1 + end = i + rec_len + if data[end - 1] != 10: + break + end -= 1 + for k in range(i, end): + if data[k] == 32: + sp = k + break + else: + break + + for i in range(sp + 1, end): + # ord(b"=") = 61 + if data[i] == 61: + key = data[sp + 1 : i] + val = data[i + 1 : end] + out.append((key, val)) + break + + i += rec_len + return out + + +@njit(nb.types.Tuple((nb.types.boolean, INT64))(U8C, U8C_RO), cache=True) +def update_header(hdr: np.ndarray, prefix: np.ndarray) -> tuple[bool, int]: + """Update the header with a new path, prefixing the path with prefix. + + Args: + hdr: The header to update. + prefix: The prefix to add to the path. + + Returns: + True if the header was updated successfully, False otherwise. + And the number of blocks to skip. + + """ + size_val = _nb_parse_size(hdr[124:136]) + + typeflag = hdr[156] + # ord(b"L") = 76, ord(b"K") = 75 + if typeflag == 76 or typeflag == 75: + raise TarPatcherError("Unexpected GNU longname/longlink encountered during patch.") + + # ord(b"x") = 120, ord(b"g") = 103 + if typeflag == 120 or typeflag == 103: + return False, size_val + + orig_path, is_ustar = _nb_extract_full_path(hdr) + # new_path = prefix + (name_prefix + b'/' if len(name_prefix) > 0 and len(name) > 0 else b'') + name + # ord(b"/") = 47 + new_path = np.empty(prefix.size + orig_path.size, dtype=np.uint8) + new_path[: prefix.size] = prefix + new_path[prefix.size :] = orig_path + + if is_ustar: + new_prefix_b, new_name_b = split_ustar_path(new_path) + if new_name_b.size == 0: + raise TarPatcherError( + "Internal error: ustar fields don't fit for " + repr(new_path) + "." + ) + hdr[: new_name_b.size] = new_name_b + for i in range(new_name_b.size, 100): + hdr[i] = 0 + hdr[345 : 345 + new_prefix_b.size] = new_prefix_b + for i in range(345 + new_prefix_b.size, 500): + hdr[i] = 0 + else: + new_name_b = new_path + if new_name_b.size > 100: + raise TarPatcherError( + "Internal error: legacy name too long for " + repr(new_path) + "." 
+ ) + hdr[0 : new_name_b.size] = new_name_b + for i in range(new_name_b.size, 100): + hdr[i] = 0 + + checksum = _nb_compute_checksum(hdr) + _nb_format_chksum(checksum, hdr[148:156]) + + return True, size_val + + +@njit(nb.types.boolean(U8C_RO, INT64, INT64)) +def _nb_evaluate_pax_header( + raw_data: np.ndarray, + position: nb.int64, + size_val: int, +) -> None: + PAX_PATH_KEYS = ( + np.frombuffer(b"path", dtype=np.uint8), + np.frombuffer(b"linkpath", dtype=np.uint8), + np.frombuffer(b"gnu.path", dtype=np.uint8), + np.frombuffer(b"gnu.linkpath", dtype=np.uint8), + np.frombuffer(b"SCHILY.path", dtype=np.uint8), + np.frombuffer(b"SCHILY.linkpath", dtype=np.uint8), + ) + blocks = (size_val + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE + data = raw_data[position : position + blocks] + if data.size < blocks: + raise TarPatcherError("Truncated PAX extended header data.") + records = pax_parse(data[:size_val]) + for key, _ in records: + for pax_key in PAX_PATH_KEYS: + if key.size == pax_key.size and np.all(key == pax_key): + return False + return True + + +@njit(nb.types.Tuple((nb.types.boolean, nb.types.ListType(U8C_RO)))(U8C_RO, U8C_RO), cache=True) +def _nb_scan_file(raw_data: np.ndarray, prefix_bytes: np.ndarray) -> tuple[bool, list[np.ndarray]]: + position = nb.int64(0) + total = raw_data.size + compatible = True + sample_key_list = nb.typed.List.empty_list(U8C_RO) + + rd_buf = np.empty(65536, dtype=np.uint8) + + last_sample_key = np.empty(0, dtype=np.uint8) + + while True: + header = raw_data[position : position + BLOCK_SIZE].copy() + if position + BLOCK_SIZE > total: + raise TarPatcherError("Unexpected EOF while reading header.") + + if _nb_is_zero_block(header): + break + + size_val = _nb_parse_size(header[124:136]) + + typeflag = header[156] + # ord(b"L") = 76, ord(b"K") = 75 + if typeflag == 76 or typeflag == 75: + # "Unexpected GNU longname/longlink encountered during patch." + if compatible: + print("Found GNU longname/longlink entry") + compatible = False + # TODO: Still parse the filename to get the sample key + raise TarPatcherError( + "Found GNU longname/longlink entry; in-place rename is unsupported." + ) + + # ord("x") = 120, ord("g") = 103 + if typeflag in (120, 103): + if not _nb_evaluate_pax_header(raw_data, position, size_val): + # "PAX header contains unsupported key; in-place rename is unsafe." + if compatible: + print("PAX header contains unsupported key") + compatible = False + else: + full_path, is_ustar = _nb_extract_full_path(header) + + new_path = np.empty(prefix_bytes.size + full_path.size, dtype=np.uint8) + new_path[: prefix_bytes.size] = prefix_bytes + new_path[prefix_bytes.size :] = full_path + + # Apply the regex for splitting the sample key from the extension. + # Find last slash to find filename + for i in range(full_path.size - 1, -1, -1): + # ord('/') = 47 + if full_path[i] == 47: + last_path_idx = i + break + else: + last_path_idx = 0 + # Find extension (i.e. first dot after last slash) + for i in range(last_path_idx + 1, full_path.size): + # ord('.') = 46 + if full_path[i] == 46: + extension_idx = i + break + else: + extension_idx = 0 + sample_key = full_path[:extension_idx] + if sample_key.size > 0: + # Next group, store the sample key + if last_sample_key.size != sample_key.size or not np.all( + last_sample_key == sample_key + ): + sample_key_list.append(sample_key) + last_sample_key = sample_key + + if is_ustar: + new_prefix_b, new_name_b = split_ustar_path(new_path) + if new_name_b.size == 0: + # "Internal error: ustar fields don't fit." 
+ if compatible: + print("Internal error: ustar fields don't fit") + compatible = False + else: + if new_path.size > 100: + # "New name too long for legacy header." + if compatible: + print("New name too long for legacy header") + compatible = False + + # Dummy read + inc = BLOCK_SIZE + (size_val + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE + last_block_size = inc % 65536 + for i in range(position, position + inc - last_block_size, 65536): + rd_buf[:] = raw_data[i : i + 65536] + rd_buf[:last_block_size] = raw_data[position + inc - last_block_size : position + inc] + position += inc + + return compatible, sample_key_list + + +@njit(VOID(U8C, U8C_RO), cache=True) +def _nb_process_file(raw_data: np.ndarray, prefix_bytes: np.ndarray) -> None: + position = nb.int64(0) + total = raw_data.size + rd_buf = np.empty(65536, dtype=np.uint8) + + while True: + header = raw_data[position : position + BLOCK_SIZE].copy() + if position + BLOCK_SIZE > total: + raise TarPatcherError("Unexpected EOF while reading header.") + + if _nb_is_zero_block(header): + break + + was_updated, size_val = update_header(header, prefix_bytes) + + if was_updated: + raw_data[position : position + BLOCK_SIZE] = header + + inc = BLOCK_SIZE + (size_val + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE + last_block_size = inc % 65536 + for i in range(position, position + inc - last_block_size, 65536): + rd_buf[:] = raw_data[i : i + 65536] + rd_buf[:last_block_size] = raw_data[position + inc - last_block_size : position + inc] + position += inc + + +# Keys that, if present in PAX, mean path info is controlled by PAX, +# so in-place rename of classic headers is NOT safe. +PAX_PATH_KEYS = ( + np.frombuffer(b"path", dtype=np.uint8), + np.frombuffer(b"linkpath", dtype=np.uint8), + np.frombuffer(b"gnu.path", dtype=np.uint8), + np.frombuffer(b"gnu.linkpath", dtype=np.uint8), + np.frombuffer(b"SCHILY.path", dtype=np.uint8), + np.frombuffer(b"SCHILY.linkpath", dtype=np.uint8), +) + + +split_name_re_bytes = re.compile(rb"^((?:.*/|)[^.]+)[.]([^/]*)$") + + +class TarPatcher: + """Utility for scanning and renaming tar archive entries in place.""" + + def __init__(self, *, show_progress: bool = True) -> None: + self._show_progress = show_progress + + def dataset_scan( + self, tar_files: Sequence[str], parent_path: EPath, num_workers: int = NUM_WORKERS + ) -> DatasetScanResult: + """Scan multiple tar files, checking compatibility for in-place renaming and for duplicate sample keys. + Each tar_file string must be a relative or absolute path to a tar file. + + Args: + tar_files: List of relative or absolute paths to the tar files to scan. + parent_path: Parent path of the tar files, used if tar_files are relative paths. + + Returns: + DatasetScanResult: Result of the scan. 
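# --- Illustrative sketch (not part of this patch) -----------------------------
# Typical caller-side use of the scan/apply pair, mirroring how prepare.py
# drives it when --fix-duplicates is given. Paths and shard names are hypothetical.
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.tar_patcher import TarPatcher

patcher = TarPatcher(show_progress=True)
parent = EPath("/data/my_dataset")  # hypothetical dataset root
shards = ["shard_0000.tar", "shard_0001.tar"]

result = patcher.dataset_scan(shards, parent_path=parent, num_workers=8)
if result.has_duplicates:
    if result.compatible:
        # Prefix every member with "<shard name>/" in place.
        patcher.dataset_apply_prefix(shards, parent_path=parent, num_workers=8)
    else:
        raise RuntimeError("Tar files are not compatible with the in-place rename")
# ------------------------------------------------------------------------------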
+ """ + + scan_results: Dict[str, TarScanResult] = {} + + # Maps from sample key to list of tar files containing it + duplicates: Dict[str, Set[str]] = {} + + compatible = True + have_duplicates = False + + tasks: list[tuple[str, str]] = [] + for rel_tar_file in tar_files: + tar_file_path = parent_path / rel_tar_file + rel_file_path = tar_file_path.relative_to(parent_path) + tar_file = str(tar_file_path) + prefix = f"{rel_file_path}/" + tasks.append((tar_file, prefix)) + + if not tasks: + return DatasetScanResult( + compatible=True, + duplicates={}, + scan_results={}, + ) + + max_workers = min(len(tasks), num_workers) + + with tqdm( + total=len(tasks), + desc="Scanning dataset", + unit="shards", + disable=not self._show_progress, + ) as dataset_pbar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + jobs = [executor.submit(_scan_tar_worker, *task) for task in tasks] + for future in concurrent.futures.as_completed(jobs): + try: + tar_file, result = future.result() + except: + import traceback + + traceback.print_exc() + raise + scan_results[tar_file] = result + if not result.compatible: + compatible = False + for sample_key in result.sample_keys: + duplicates.setdefault(sample_key, set()).add(tar_file) + if len(duplicates[sample_key]) > 1: + have_duplicates = True + + if have_duplicates and not compatible: + # Let's stop early if we have duplicates and the dataset is not compatible for fixing + break + + dataset_pbar.update() + + for job in jobs: + job.cancel() + + duplicate_map = {key: sorted(paths) for key, paths in duplicates.items() if len(paths) > 1} + + return DatasetScanResult( + compatible=compatible, + duplicates=duplicate_map, + scan_results=scan_results, + ) + + def dataset_apply_prefix( + self, + tar_files: Sequence[str], + parent_path: EPath, + num_workers: int = NUM_WORKERS, + ) -> None: + """Apply shard-specific prefixes to a set of tar files.""" + + tasks: list[tuple[str, str]] = [] + for rel_tar_file in tar_files: + tar_file_path = parent_path / rel_tar_file + rel_file_path = tar_file_path.relative_to(parent_path) + + tar_file = str(tar_file_path) + prefix = f"{rel_file_path}/" + tasks.append((tar_file, prefix)) + + if not tasks: + return + + max_workers = min(len(tasks), num_workers) + + with tqdm( + total=len(tasks), + desc="Applying prefixes", + unit="shards", + disable=not self._show_progress, + ) as dataset_pbar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + jobs = [executor.submit(_apply_prefix_worker, *task) for task in tasks] + for future in concurrent.futures.as_completed(jobs): + future.result() + dataset_pbar.update() + + def scan(self, tar_path: Path | str, prefix: str) -> TarScanResult: + """Scan *tar_path* and evaluate compatibility for prefixing entries.""" + + prefix_bytes = np.frombuffer(prefix.encode("utf-8"), dtype=np.uint8) + + raw_data = np.memmap(tar_path, dtype=np.uint8, mode="r+") + + compatible, sample_keys_list = _nb_scan_file(raw_data, prefix_bytes) + + # Convert numpy arrays to bytes for the set + sample_keys_set = {sample_key.tobytes() for sample_key in sample_keys_list} + if len(sample_keys_set) != len(sample_keys_list): + from collections import Counter + + print( + f"Duplicate sample keys within a single tar file {tar_path}: {len(sample_keys_list)} keys, {len(sample_keys_set)} unique keys" + ) + print( + f"Most common sample keys: {Counter(key.tobytes() for key in sample_keys_list).most_common(10)}" + ) + compatible = False + + return TarScanResult( + sample_keys=sample_keys_set, + compatible=compatible, 
+ ) + + def apply_prefix( + self, + tar_path: Path | str, + prefix: str, + ) -> None: + """Apply *prefix* to entries in *tar_path* in place.""" + prefix_bytes = np.frombuffer(prefix.encode("utf-8"), dtype=np.uint8) + raw_data = np.memmap(tar_path, dtype=np.uint8, mode="readwrite") + _nb_process_file(raw_data, prefix_bytes) + raw_data.flush() + + def _evaluate_pax_header( + self, + handle: BinaryIO, + size_val: int, + pbar: tqdm, + result: TarScanResult, + ) -> None: + blocks = (size_val + BLOCK_SIZE - 1) // BLOCK_SIZE + data = handle.read(blocks * BLOCK_SIZE) + if len(data) < blocks * BLOCK_SIZE: + raise TarPatcherError("Truncated PAX extended header data.") + pbar.update(blocks * BLOCK_SIZE) + records = pax_parse(data[:size_val]) + for key, _ in records: + if any( + key.size == pax_key.size and np.all(key == pax_key) for pax_key in PAX_PATH_KEYS + ): + result.compatible = False + result.issues.append( + f"PAX header contains unsupported key {key!r}; in-place rename is unsafe." + ) + + def _skip_payload(self, handle: BinaryIO, size_val: int, pbar: tqdm) -> None: + data_blocks = (size_val + BLOCK_SIZE - 1) // BLOCK_SIZE + if data_blocks: + handle.seek(data_blocks * BLOCK_SIZE, 1) + pbar.update(data_blocks * BLOCK_SIZE) + + def _progress(self, total: int, desc: str, disable: bool = False) -> tqdm: + return tqdm( + total=total, + unit="B", + unit_scale=True, + desc=desc, + leave=False, + disable=disable, + ) + + +def _scan_tar_worker(tar_file: str, prefix: str) -> tuple[str, TarScanResult]: + patcher = TarPatcher(show_progress=False) + result = patcher.scan(tar_file, prefix) + return tar_file, result + + +def _apply_prefix_worker(tar_file: str, prefix: str) -> str: + patcher = TarPatcher(show_progress=False) + patcher.apply_prefix(tar_file, prefix) + return tar_file + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument( + "dataset_path", + type=click.Path(exists=True, file_okay=False, path_type=Path), +) +@click.option( + "--dry-run", + is_flag=True, + help="Only check if in-place renaming is possible; don't modify the file.", +) +def run_dataset(dataset_path: Path, dry_run: bool): + """ + PREFIX all member names in all tar files in DATASET_PATH in-place, if safely possible. + + Rules: + - Allowed: + * Classic/ustar entries. + * PAX x/g entries that do NOT contain path-related keys. + - Rejected: + * Any PAX entry with path/linkpath-style keys. + * Any GNU longname/longlink (L/K). + * Any resulting name that doesn't fit fixed-size header fields. 
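# --- Illustrative usage (not part of this patch) ------------------------------
# Invoking this command directly on a hypothetical dataset directory, assuming
# Click's default underscore-to-dash command naming (run-dataset). With
# --dry-run only the feasibility/duplicate scan is performed:
#
#   python -m megatron.energon.flavors.webdataset.tar_patcher run-dataset /data/my_dataset --dry-run
# ------------------------------------------------------------------------------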
+ """ + import time + + patcher = TarPatcher() + + # files = ["shard-0.tar", "audio_0.tar"] + # for file in files: + # print(f"Reading {file}...") + # start = time.time() + # buf = np.memmap(file, dtype=np.uint8, mode="r") + # for i in range(0, buf.size, 65536): + # buf[i:i+65536].copy() + # # with open(file, "rb") as f: + # # f.read(65536) + # end = time.time() + # print(f"Read time: {end - start} seconds") + + files = tuple(str(p.relative_to(dataset_path)) for p in dataset_path.rglob("*.tar")) + + click.echo(f"Scanning {dataset_path} for in-place rename feasibility...") + start = time.time() + res = patcher.dataset_scan(files, dataset_path) + end = time.time() + print(f"Scan time: {end - start} seconds") + print( + f"compatible: {res.compatible}, {len(res.duplicates)} duplicates: {[key for _, key in zip(range(10), res.duplicates.items())]}" + ) + if not res.compatible: + raise click.ClickException("In-place rename is not possible under current rules.") + if len(res.duplicates) == 0: + raise click.ClickException("No duplicates found.") + + click.echo("OK: in-place modification is possible under current rules.") + + if dry_run: + return + + click.echo("Applying prefix in-place...") + try: + start = time.time() + patcher.dataset_apply_prefix(files, dataset_path) + end = time.time() + print(f"Apply time: {end - start} seconds") + except TarPatcherError as exc: + raise click.ClickException(str(exc)) from exc + click.echo("Done. All eligible member names have been updated.") + + +@cli.command() +@click.argument( + "tar_file", + type=click.Path(exists=True, dir_okay=False, path_type=Path), +) +@click.argument("prefix", type=str) +@click.option( + "--dry-run", + is_flag=True, + help="Only check if in-place renaming is possible; don't modify the file.", +) +def run_file(tar_file: Path, prefix: str, dry_run: bool): + """ + PREFIX all member names in TAR_FILE in-place, if safely possible. + + Rules: + - Allowed: + * Classic/ustar entries. + * PAX x/g entries that do NOT contain path-related keys. + - Rejected: + * Any PAX entry with path/linkpath-style keys. + * Any GNU longname/longlink (L/K). + * Any resulting name that doesn't fit fixed-size header fields. 
+ """ + import time + + tar_path = Path(tar_file) + + patcher = TarPatcher() + + # files = ["shard-0.tar", "audio_0.tar"] + # for file in files: + # print(f"Reading {file}...") + # start = time.time() + # buf = np.memmap(file, dtype=np.uint8, mode="r") + # for i in range(0, buf.size, 65536): + # buf[i:i+65536].copy() + # # with open(file, "rb") as f: + # # f.read(65536) + # end = time.time() + # print(f"Read time: {end - start} seconds") + + click.echo(f"Scanning {tar_path} for in-place rename feasibility...") + orig_start_method = mp.get_start_method() + # mp.set_start_method("fork", force=True) + for i in range(2): + try: + print(f"Scanning {i}...") + start = time.time() + res = patcher.scan(tar_path, prefix) + end = time.time() + print( + f"compatible: {res.compatible}, {len(res.sample_keys)} keys, keys: {[key for _, key in zip(range(10), res.sample_keys)]}" + ) + print(f"Scan time: {end - start} seconds") + except TarPatcherError as exc: + raise click.ClickException(str(exc)) from exc + + mp.set_start_method(orig_start_method, force=True) + + click.echo("OK: in-place modification is possible under current rules.") + + if dry_run: + return + + click.echo("Applying prefix in-place...") + try: + start = time.time() + patcher.apply_prefix(tar_path, prefix) + end = time.time() + print(f"Apply time: {end - start} seconds") + except TarPatcherError as exc: + raise click.ClickException(str(exc)) from exc + click.echo("Done. All eligible member names have been updated.") + + +if __name__ == "__main__": + cli() diff --git a/src/megatron/energon/loader.py b/src/megatron/energon/loader.py index 97b796b9..b973b672 100644 --- a/src/megatron/energon/loader.py +++ b/src/megatron/energon/loader.py @@ -4,7 +4,7 @@ from typing import Optional, TypeVar from megatron.energon.cache import CachePool -from megatron.energon.errors import warn_deprecated +from megatron.energon.deprecation import warn_deprecated from megatron.energon.flavors import SavableDataset from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader from megatron.energon.worker import WorkerConfig diff --git a/src/megatron/energon/media/__init__.py b/src/megatron/energon/media/__init__.py new file mode 100644 index 00000000..dcb83eb2 --- /dev/null +++ b/src/megatron/energon/media/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +from .metadata import ( + AVMetadata, + ImageMetadata, + MediaMetadataBase, + MediaMetadataType, + deserialize_media_metadata, + serialize_media_metadata, +) + +__all__ = [ + "AVMetadata", + "ImageMetadata", + "MediaMetadataBase", + "MediaMetadataType", + "deserialize_media_metadata", + "serialize_media_metadata", +] diff --git a/src/megatron/energon/media/extractor.py b/src/megatron/energon/media/extractor.py new file mode 100644 index 00000000..1126584e --- /dev/null +++ b/src/megatron/energon/media/extractor.py @@ -0,0 +1,241 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import io +import logging +from dataclasses import dataclass, field +from enum import Enum +from fnmatch import fnmatch +from pathlib import Path +from typing import BinaryIO, Union + +import filetype +from PIL import Image, UnidentifiedImageError + +from megatron.energon.av import AVDecoder +from megatron.energon.media.metadata import ImageMetadata, MediaMetadataBase, MediaMetadataType + +logger = logging.getLogger(__name__) + + +SourceData = Union[bytes, Path, BinaryIO] + + +class MediaFilterStrategy(str, Enum): + """Strategy used to decide whether an entry should be treated as media.""" + + EXTENSION = "EXTENSION" # by file extension + HEADER = "HEADER" # by header filetype (detected using filetype package) + GLOB = "GLOB" # by one or more glob patterns (e.g. '*.jpg') + + +@dataclass(frozen=True) +class MediaFilterConfig: + """Configuration for media detection during dataset preparation.""" + + strategy: MediaFilterStrategy + patterns: list[str] = field(default_factory=list) + + @classmethod + def parse(cls, glob: str | None, header: bool, extension: bool) -> "MediaFilterConfig": + # Check that exactly one of the strategies is enabled + strategy_count = sum(bool(s) for s in [glob, header, extension]) + + if strategy_count != 1: + raise ValueError( + "Exactly one of GLOB, HEADER, or EXTENSION media filters must be enabled. " + "You can use multiple glob patterns by separating them by commas." + ) + + if glob: + if "," in glob: + patterns = glob.split(",") + else: + patterns = [glob] + return cls(strategy=MediaFilterStrategy.GLOB, patterns=patterns) + if header: + return cls(strategy=MediaFilterStrategy.HEADER) + if extension: + return cls(strategy=MediaFilterStrategy.EXTENSION) + assert False, "Internal error: Should not be reached" + + def should_consider_all(self) -> bool: + """Check whether all files need to be considered for metadata extraction. + This is the case, if we need to inspect the file content to determine the media type.""" + + return self.strategy == MediaFilterStrategy.HEADER + + def should_consider_media(self, name: str) -> bool: + """Check whether a file name qualifies for metadata extraction under the filter. + This is a first stage check to avoid loading the file content into memory if possible.""" + + lower_name = name.lower() + + if self.strategy == MediaFilterStrategy.HEADER: + # TYPE detection relies on file content, hence it always requires inspection. + return True + + if self.strategy == MediaFilterStrategy.EXTENSION: + return _guess_type_from_extension(lower_name) is not None + + assert self.patterns is not None, "Pattern strategy requires a glob expression" + return any(fnmatch(lower_name, pattern) for pattern in self.patterns) + + def extract_metadata( + self, + source: SourceData, + filename: str | None = None, + ) -> MediaMetadataBase | None: + """Extract media metadata from the source, if the file is a media file according to the filter. + If the file is found not to be a media file, None is returned. + + Args: + source: The source data to extract metadata from. This can be a bytes, Path, or an open file. + filename: The filename of the source data. This is required when extracting metadata from bytes or an open file. + + Returns: + The media metadata, if the file is a media file according to the filter. None otherwise. 
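# --- Illustrative sketch (not part of this patch) -----------------------------
# Extracting metadata for a single on-disk file with the extension-based filter
# defined below. The file path is hypothetical; the result is an ImageMetadata
# (or AVMetadata for audio/video), or None if the file is not recognized.
from pathlib import Path
from megatron.energon.media.extractor import MediaFilterConfig

config = MediaFilterConfig.parse(glob=None, header=False, extension=True)

image_path = Path("/data/my_dataset/images/0001.jpg")  # hypothetical
if config.should_consider_media(image_path.name):
    metadata = config.extract_metadata(image_path)
    if metadata is not None:
        print(metadata.to_dict())  # e.g. {'width': 640, 'height': 480, 'format': 'JPEG', 'mode': 'RGB'}
# ------------------------------------------------------------------------------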
+ """ + + if isinstance(source, (bytes, bytearray, io.IOBase)): + assert filename is not None, ( + "Filename is required when extracting metadata from bytes or IOBase" + ) + else: + assert filename is None, "Filename is not allowed when extracting metadata from path" + filename = source.name + + media_type = _detect_media_type(filename, self, source) + + if media_type is None: + return None + + metadata = _build_metadata(media_type, source) + if metadata is None: + return None + return metadata + + +_IMAGE_EXTENSIONS: set[str] = { + "bmp", + "gif", + "ico", + "j2k", + "jp2", + "jpx", + "jpeg", + "jpg", + "png", + "tif", + "tiff", + "webp", +} + +_AV_EXTENSIONS: set[str] = { + "aac", + "avi", + "flac", + "m4a", + "m4v", + "mkv", + "mov", + "mp3", + "mp4", + "ogg", + "wav", + "webm", +} + +_FILETYPE_PROBE_SIZE = 262 + + +def _detect_media_type( + name: str, + config: MediaFilterConfig, + source: SourceData, +) -> MediaMetadataType | None: + # Case 1: GLOB strategy + if config.strategy == MediaFilterStrategy.GLOB: + if not any(fnmatch(name, pattern) for pattern in config.patterns): + return None + + extension_guess = _guess_type_from_extension(name) + + # Case 2: EXTENSION strategy + if config.strategy == MediaFilterStrategy.EXTENSION: + return extension_guess + + # Case 3: HEADER strategy + assert config.strategy == MediaFilterStrategy.HEADER, ( + "Internal error: Unexpected media filter strategy" + ) + + detected = _guess_type_from_filetype(source) + return detected if detected is not None else extension_guess + + +def _guess_type_from_extension(name: str) -> MediaMetadataType | None: + suffix = Path(name).suffix.lstrip(".").lower() + if suffix in _IMAGE_EXTENSIONS: + return MediaMetadataType.IMAGE + if suffix in _AV_EXTENSIONS: + return MediaMetadataType.AV + return None + + +def _guess_type_from_filetype(source: SourceData) -> MediaMetadataType | None: + kind = filetype.guess(source) + + if kind is None or kind.mime is None: + return None + mime = kind.mime + if mime.startswith("image/"): + return MediaMetadataType.IMAGE + if mime.startswith("video/") or mime.startswith("audio/"): + return MediaMetadataType.AV + return None + + +def _build_metadata( + media_type: MediaMetadataType, + source: SourceData, +) -> MediaMetadataBase | None: + if media_type is MediaMetadataType.IMAGE: + return _build_image_metadata(source) + if media_type is MediaMetadataType.AV: + return _build_av_metadata(source) + return None + + +def _build_image_metadata(source: SourceData) -> ImageMetadata | None: + try: + if isinstance(source, (bytes, bytearray)): + source = io.BytesIO(source) + + with Image.open(source) as image: + image.load() + return ImageMetadata( + width=image.width, + height=image.height, + format=image.format or "UNKNOWN", + mode=image.mode or "UNKNOWN", + ) + except UnidentifiedImageError: + logger.debug("Failed to parse image metadata", exc_info=True) + return None + + +def _build_av_metadata(source: SourceData) -> MediaMetadataBase | None: + try: + if isinstance(source, (bytes, bytearray)): + return AVDecoder(io.BytesIO(source)).get_metadata(get_audio_num_samples=True) + elif isinstance(source, io.IOBase): + return AVDecoder(source).get_metadata(get_audio_num_samples=True) + else: + with source.open("rb") as stream: + return AVDecoder(stream).get_metadata(get_audio_num_samples=True) + except Exception: # pragma: no cover - depends on external libs/media support + logger.debug("Failed to parse AV metadata", exc_info=True) + return None diff --git 
a/src/megatron/energon/media/filesystem_prepare.py b/src/megatron/energon/media/filesystem_prepare.py new file mode 100644 index 00000000..9d8100f4 --- /dev/null +++ b/src/megatron/energon/media/filesystem_prepare.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +from functools import partial +from pathlib import Path +from typing import Callable, Iterator + +from megatron.energon.epathlib import EPath +from megatron.energon.flavors.webdataset.aggregator_pool import AggregatorPool +from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.prepare import ( + IndexAggregatable, + IndexMediaMetadata, + SqliteIndexWriterAggregator, +) +from megatron.energon.flavors.webdataset.structs import ShardInfo +from megatron.energon.media.extractor import MediaFilterConfig +from megatron.energon.media.metadata import serialize_media_metadata + + +def prepare_filesystem_dataset( + root_path: EPath, + media_filter: MediaFilterConfig, + *, + progress: bool, + workers: int = 16, +) -> int: + """Scan a filesystem dataset and materialize media metadata into SQLite. + + Args: + root_path: Dataset root directory. + media_filter: Media filtering configuration. + progress: Whether to display a tqdm progress bar. + + Returns: + Number of metadata entries written to the database. + """ + + # Only supporting local file system, because sqlite does not support remote file systems. + # TODO: Implement remote file systems. Maybe create locally in tmp then upload? + assert str(root_path).startswith("/"), ( + f"SQLite path must be absolute local file system path: {root_path}" + ) + + root = Path(str(root_path)) + assert root.is_dir(), f"Expected directory for filesystem dataset, got {root}" + assert root.is_absolute(), f"Filesystem dataset path must be absolute: {root}" + + meta_dir = root / MAIN_FOLDER_NAME + meta_dir.mkdir(exist_ok=True, parents=True) + + files = _collect_media_files(root=root, media_filter=media_filter, progress=progress) + + sqlite_path = EPath(meta_dir / INDEX_SQLITE_FILENAME) + + agg_progress_fn: Callable[[Iterator[int], int], Iterator[int]] | None = None + if progress: + from tqdm.auto import tqdm + + def agg_progress_fn(iterator: Iterator[int], total: int) -> Iterator[int]: + with tqdm(iterator, total=total, unit="file", desc="Processing media files") as bar: + yield from bar + + aggregator = SqliteIndexWriterAggregator( + sqlite_path, + total_tasks=len(files), + progress_fn=agg_progress_fn, + enable_media_metadata=True, + media_filter=media_filter, + reset_tables=False, + enable_sample_tables=False, + progress_on_media=progress, + ) + + pool = AggregatorPool[ + Path, + IndexAggregatable, + tuple[list[ShardInfo], set[str], bool, list[tuple[str, int]]], + ]( + num_workers=min(workers, len(files)) or 1, + user_produce_data=partial( + _process_filesystem_entry, + root=root, + media_filter=media_filter, + ), + aggregator=aggregator, + ) + + for file_path in files: + pool.submit_task(file_path) + + pool.process() + + return aggregator.media_metadata_written + + +def _collect_media_files( + *, root: Path, media_filter: MediaFilterConfig, progress: bool = False +) -> list[Path]: + """Return a sorted list of files to process based on the media filter.""" + + consider_all = media_filter.should_consider_all() + files: list[Path] = [] + + progress_bar = None + if progress: + from tqdm.auto import tqdm + + progress_bar = 
tqdm(total=None, unit="file", desc="Collecting media files") + + for dirpath, dirnames, filenames in os.walk(root, followlinks=False): + current_dir = Path(dirpath) + + if current_dir.name == MAIN_FOLDER_NAME: + dirnames[:] = [] + continue + + dirnames[:] = [d for d in dirnames if d != MAIN_FOLDER_NAME] + + for filename in filenames: + if progress_bar is not None: + progress_bar.update() + + if not consider_all and not media_filter.should_consider_media(filename): + continue + + candidate = current_dir / filename + if MAIN_FOLDER_NAME in candidate.parts: + continue + + if candidate.is_file(): + files.append(candidate) + + if progress_bar is not None: + progress_bar.close() + + files.sort() + return files + + +def _process_filesystem_entry( + file_path: Path | str, + *, + root: Path, + media_filter: MediaFilterConfig, +) -> Iterator[IndexMediaMetadata]: + file_path = Path(file_path) + extracted_metadata = media_filter.extract_metadata(file_path) + if extracted_metadata is None: + return + + stored_type, metadata_json = serialize_media_metadata(extracted_metadata) + entry_key = file_path.relative_to(root).as_posix() + + yield IndexMediaMetadata( + entry_key=entry_key, + metadata_type=stored_type.value, + metadata_json=metadata_json, + ) diff --git a/src/megatron/energon/media/metadata.py b/src/megatron/energon/media/metadata.py new file mode 100644 index 00000000..72534723 --- /dev/null +++ b/src/megatron/energon/media/metadata.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import json +from dataclasses import asdict +from enum import Enum +from typing import Any, ClassVar, Dict, Mapping, Type, TypeVar + +from megatron.energon.edataclass import edataclass + + +class MediaMetadataType(str, Enum): + """Enumerates the supported media metadata payload kinds.""" + + AV = "av" + IMAGE = "image" + + @classmethod + def from_string(cls, value: str) -> "MediaMetadataType": + try: + return cls(value) + except ValueError as exc: # pragma: no cover - defensive + raise ValueError(f"Unsupported media metadata type: {value!r}") from exc + + +TMetadata = TypeVar("TMetadata", bound="MediaMetadataBase") + + +@edataclass +class MediaMetadataBase: + """Base class for metadata payloads to support typed JSON storage.""" + + metadata_type: ClassVar[MediaMetadataType] + + def to_dict(self) -> Dict[str, Any]: + """Return a JSON-serialisable mapping representation.""" + + return {key: value for key, value in asdict(self).items() if value is not None} + + @classmethod + def from_dict(cls: Type[TMetadata], payload: Mapping[str, Any]) -> TMetadata: + """Construct the metadata object from its JSON representation.""" + + return cls(**payload) + + +@edataclass +class AVMetadata(MediaMetadataBase): + """Metadata of a video or audio asset.""" + + metadata_type: ClassVar[MediaMetadataType] = MediaMetadataType.AV + + video_duration: float | None = None + video_num_frames: int | None = None + video_fps: float | None = None + video_width: int | None = None + video_height: int | None = None + + audio_duration: float | None = None + audio_channels: int | None = None + audio_sample_rate: int | None = None + audio_num_samples: int | None = None + + +@edataclass +class ImageMetadata(MediaMetadataBase): + """Metadata for an encoded image file.""" + + metadata_type: ClassVar[MediaMetadataType] = MediaMetadataType.IMAGE + + width: int + height: int + format: str + mode: str + + +_MEDIA_METADATA_REGISTRY: Dict[MediaMetadataType, 
Type[MediaMetadataBase]] = { + MediaMetadataType.AV: AVMetadata, + MediaMetadataType.IMAGE: ImageMetadata, +} + + +def serialize_media_metadata(metadata: MediaMetadataBase) -> tuple[MediaMetadataType, str]: + """Serialise the metadata to a tuple of (type, json.dumps(payload)).""" + + payload_json = json.dumps(metadata.to_dict(), separators=(",", ":")) + return metadata.metadata_type, payload_json + + +def deserialize_media_metadata( + metadata_type: str | MediaMetadataType, + metadata_json: str, +) -> MediaMetadataBase: + """Deserialize a metadata record from stored SQLite values.""" + + if not isinstance(metadata_type, MediaMetadataType): + metadata_type = MediaMetadataType.from_string(metadata_type) + + try: + payload_cls = _MEDIA_METADATA_REGISTRY[metadata_type] + except KeyError as exc: # pragma: no cover - future proofing + raise ValueError(f"Unsupported media metadata type: {metadata_type}") from exc + + payload_dict = json.loads(metadata_json) if metadata_json else {} + return payload_cls.from_dict(payload_dict) diff --git a/src/megatron/energon/metadataset/join_dataset_loader.py b/src/megatron/energon/metadataset/join_dataset_loader.py index 9403326d..4ea1cbd4 100644 --- a/src/megatron/energon/metadataset/join_dataset_loader.py +++ b/src/megatron/energon/metadataset/join_dataset_loader.py @@ -23,7 +23,11 @@ JoinedWebdatasetFactory, Sample, ) -from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import ( + INDEX_SQLITE_FILENAME, + INDEX_UUID_FILENAME, + MAIN_FOLDER_NAME, +) from megatron.energon.flavors.webdataset.indexing import JoinIndexWriter from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta from megatron.energon.flavors.webdataset.structs import DatasetSubset @@ -158,7 +162,6 @@ def join_multiple_indices( # For safety, enable a read-only or big timeouts conn.execute("PRAGMA busy_timeout = 5000;") - conn.execute("PRAGMA journal_mode = WAL;") # 2. Attach each secondary DB under a unique alias, e.g. db1, db2, ... secondary_aliases = [] @@ -373,7 +376,7 @@ def _get_joined_meta(self, split_part: str) -> Tuple[EPath, List[JoinedDatasetMe for dataset in datasets: print(f" - {dataset}") - uuid_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.uuid" + uuid_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME try: uuid = uuid_path.read_text() except FileNotFoundError: @@ -382,7 +385,7 @@ def _get_joined_meta(self, split_part: str) -> Tuple[EPath, List[JoinedDatasetMe "(with a recent version of energon). If you have already prepared the " "dataset, it should be sufficient to run prepare with --tar-index-only." ) - db_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.sqlite" + db_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME # Precedence for split_part is: # 1. 
Join dataset split part (overrides individual dataset split parts) diff --git a/src/megatron/energon/metadataset/loader.py b/src/megatron/energon/metadataset/loader.py index c35ac4fb..2056826f 100644 --- a/src/megatron/energon/metadataset/loader.py +++ b/src/megatron/energon/metadataset/loader.py @@ -40,6 +40,8 @@ def load_dataset( ds = DatasetLoader(path=path, **kwargs) ds.post_initialize() return ds + elif ds_type == EnergonDatasetType.FILESYSTEM: + raise ValueError("Filesystem datasets can only be used as auxiliary datasets.") else: raise ValueError(f"Invalid dataset at {path}") diff --git a/src/megatron/energon/metadataset/metadataset.py b/src/megatron/energon/metadataset/metadataset.py index 6da4d57e..e5abef64 100644 --- a/src/megatron/energon/metadataset/metadataset.py +++ b/src/megatron/energon/metadataset/metadataset.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Literal, Optional, Union from megatron.energon.dataset_config import load_config +from megatron.energon.deprecation import warn_deprecated from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors.webdataset.metadata import check_dataset_info_present from megatron.energon.flavors.webdataset.structs import DatasetSubset from megatron.energon.metadataset.dataset_loader import DatasetLoader diff --git a/src/megatron/energon/metadataset/metadataset_v2.py b/src/megatron/energon/metadataset/metadataset_v2.py index a36e3217..5f67837f 100644 --- a/src/megatron/energon/metadataset/metadataset_v2.py +++ b/src/megatron/energon/metadataset/metadataset_v2.py @@ -14,7 +14,7 @@ from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath from megatron.energon.flavors import Sample -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type from megatron.energon.flavors.webdataset.structs import DatasetSubset from megatron.energon.metadataset.dataset_loader import DatasetLoader @@ -42,7 +42,7 @@ def post_initialize(self, mds_path: Optional[EPath] = None) -> None: assert not self.path.is_file(), ( "Auxiliary datasets must not be metadataset, but direct dataset references" ) - assert (self.path / MAIN_FOLDER_NAME / "index.sqlite").is_file(), ( + assert (self.path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file(), ( "Auxiliary datasets must be prepared Energon datasets. This one does not exist or is not prepared: " + str(self.path) ) @@ -225,6 +225,10 @@ def post_initialize(self, mds_path: Optional[EPath] = None) -> None: new_aux[k].post_initialize(mds_path) self.aux = new_aux + elif ds_type == EnergonDatasetType.FILESYSTEM: + raise ValueError( + "Filesystem datasets are not supported within metadatasets except as auxiliary datasets." + ) else: raise FileNotFoundError(self.path) diff --git a/src/megatron/energon/sample_utils.py b/src/megatron/energon/sample_utils.py new file mode 100644 index 00000000..a695817f --- /dev/null +++ b/src/megatron/energon/sample_utils.py @@ -0,0 +1,247 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
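# --- Editor's illustrative sketch (not part of the patch) ---------------------
# A minimal round trip through the serialization helpers defined in
# media/metadata.py above. The field values are made up; the comparison only
# relies on attribute access, not on any particular equality semantics.
from megatron.energon.media.metadata import (
    ImageMetadata,
    deserialize_media_metadata,
    serialize_media_metadata,
)

image_meta = ImageMetadata(width=640, height=480, format="JPEG", mode="RGB")
stored_type, payload_json = serialize_media_metadata(image_meta)  # (MediaMetadataType.IMAGE, compact JSON string)
restored = deserialize_media_metadata(stored_type, payload_json)
assert isinstance(restored, ImageMetadata) and restored.width == 640
# --- End of sketch -------------------------------------------------------------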
+# SPDX-License-Identifier: BSD-3-Clause + +"""Utilities for inspecting and debugging samples.""" + +import dataclasses +import itertools +from typing import Any + +import numpy as np +import torch + + +def default_get_batch_keys(batch: Any) -> list[str] | None: + """Extract sample keys from a batch using common heuristics. + + This function attempts to extract keys from samples by checking for common + key attributes/fields in the following order: + - __key__ or __keys__ attributes + - __key__ or __keys__ dict keys + - "keys" dict key + + Args: + batch: A sample or batch to extract keys from. If a list, uses the first element. + + Returns: + List of string keys if found, None otherwise. + """ + if isinstance(batch, list): + batch = batch[0] + if ( + hasattr(batch, "__key__") + and isinstance(batch.__key__, list) + and all(isinstance(k, str) for k in batch.__key__) + ): + return batch.__key__ + elif ( + hasattr(batch, "__keys__") + and isinstance(batch.__keys__, list) + and all(isinstance(k, str) for k in batch.__keys__) + ): + return batch.__keys__ + elif ( + isinstance(batch, dict) + and "__key__" in batch + and all(isinstance(k, str) for k in batch["__key__"]) + ): + return batch["__key__"] + elif ( + isinstance(batch, dict) + and "__keys__" in batch + and all(isinstance(k, str) for k in batch["__keys__"]) + ): + return batch["__keys__"] + elif ( + isinstance(batch, dict) + and "keys" in batch + and all(isinstance(k, str) for k in batch["keys"]) + ): + return batch["keys"] + return None + + +def format_sample_compact( + sample: Any, + depth: int = 3, + max_items: int = 10, + max_str_len: int = 50, +) -> str: + """Create a compact single-line string representation of a sample. + + Designed for inline use in error messages and logs. For detailed + multi-line debugging output, use format_sample_detailed(). + + Args: + sample: The sample to represent as a string. + depth: Maximum nesting depth to show before truncation. + max_items: Maximum number of items to show in collections. + max_str_len: Maximum string length before truncation. + + Returns: + A compact single-line string representation. 
+ + Example: + >>> format_sample_compact({"key": "value", "count": 42}) + "{'key': 'value', 'count': 42}" + """ + if isinstance(sample, dict): + if depth <= 0: + return "{...}" + return ( + "{" + + ", ".join( + ( + f"{k}: {v!r}" + if isinstance(k, str) and k.startswith("__") + else f"{k}: {format_sample_compact(v, depth - 1, max_items, max_str_len)}" + ) + for k, v in itertools.islice(sample.items(), max_items) + ) + + "}" + ) + elif isinstance(sample, list): + if depth <= 0: + return "[...]" + return ( + "[" + + ", ".join( + format_sample_compact(v, depth - 1, max_items, max_str_len) + for v in sample[:max_items] + ) + + "]" + ) + elif isinstance(sample, tuple): + if depth <= 0: + return "(...)" + return ( + "(" + + ", ".join( + format_sample_compact(v, depth - 1, max_items, max_str_len) + for v in sample[:max_items] + ) + + ")" + ) + elif isinstance(sample, str): + if len(sample) > max_str_len: + return repr(sample[:max_str_len] + "...") + return repr(sample) + elif isinstance(sample, torch.Tensor): + return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, device={sample.device})" + elif isinstance(sample, np.ndarray): + return f"np.ndarray(shape={sample.shape}, dtype={sample.dtype})" + elif dataclasses.is_dataclass(sample): + return f"{type(sample).__name__}({', '.join(f'{field.name}={format_sample_compact(getattr(sample, field.name), depth, max_items, max_str_len)}' for field in dataclasses.fields(sample))})" + else: + repr_str = repr(sample) + return ( + format_sample_compact(repr_str, depth, max_items, max_str_len) + if not isinstance(sample, str) + else repr_str + ) + + +def format_sample_detailed(sample: Any, indent: str = "") -> str: + """Create a detailed multi-line string representation of a sample for debugging. + + Produces human-readable representations with proper indentation and detailed + information about tensors, arrays, and nested structures. For compact inline + representations in error messages, use format_sample_compact(). + + Args: + sample: The sample to represent as a string. + indent: Current indentation level (used internally for recursion). + + Returns: + A formatted multi-line string representation with detailed information. + + Example: + >>> print(format_sample_detailed({"image": torch.zeros(3, 224, 224), "label": 5})) + - image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...) + - label: 5 + """ + if isinstance(sample, dict): + result = [] + for _, (key, value) in zip(range(25), sample.items()): + result.append(f"{indent} - {key}: {format_sample_detailed(value, indent + ' ')}") + if len(sample) > 25: + result.append(f"{indent} - ... (and {len(sample) - 25} more items)") + return "\n".join(result) + elif isinstance(sample, str): + if len(sample) > 1000: + sample = f"{sample[:1000]}... (and {len(sample) - 1000} more characters)" + if "\n" in sample: + # represent as """ string if it contains newlines: + return '"""' + sample.replace("\n", "\n " + indent) + '"""' + return repr(sample) + elif isinstance(sample, (int, float, bool, type(None))): + return repr(sample) + elif isinstance(sample, (list, tuple)): + if all(isinstance(value, (str, int, float, bool, type(None))) for value in sample): + return f"[{', '.join(repr(value) for value in sample)}]" + result = [] + for _, value in zip(range(10), sample): + result.append(f"{indent} - {format_sample_detailed(value, indent + ' ')}") + if len(sample) > 10: + result.append(f"{indent} - ... 
(and {len(sample) - 10} more items)") + return "\n".join(result) + elif isinstance(sample, torch.Tensor): + try: + min_val = sample.min().item() + max_val = sample.max().item() + values_repr = "" + # flatten tensor, get first and last 3 values if possible + numel = sample.numel() + flat = sample.flatten() + n_show = 3 + if numel == 0: + values_repr = "values=[]" + elif numel <= n_show * 2: + shown = ", ".join(repr(v.item()) for v in flat) + values_repr = f"values=[{shown}]" + else: + first_vals = ", ".join(repr(v.item()) for v in flat[:n_show]) + last_vals = ", ".join(repr(v.item()) for v in flat[-n_show:]) + values_repr = f"values=[{first_vals}, ..., {last_vals}]" + return ( + f"Tensor(shape={sample.shape}, dtype={sample.dtype}, device={sample.device}, " + f"min={min_val}, max={max_val}, {values_repr})" + ) + except (RuntimeError, ValueError): + # Handle empty tensors or non-numeric dtypes + return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, device={sample.device})" + elif isinstance(sample, np.ndarray): + try: + min_val = sample.min() + max_val = sample.max() + values_repr = "" + flat = sample.ravel() + n_show = 3 + numel = flat.size + if numel == 0: + values_repr = "values=[]" + elif numel <= n_show * 2: + shown = ", ".join(repr(x) for x in flat) + values_repr = f"values=[{shown}]" + else: + first_vals = ", ".join(repr(x) for x in flat[:n_show]) + last_vals = ", ".join(repr(x) for x in flat[-n_show:]) + values_repr = f"values=[{first_vals}, ..., {last_vals}]" + return ( + f"np.ndarray(shape={sample.shape}, dtype={sample.dtype}, " + f"min={min_val}, max={max_val}, {values_repr})" + ) + except (ValueError, TypeError): + # Handle empty arrays or non-numeric dtypes + return f"np.ndarray(shape={sample.shape}, dtype={sample.dtype})" + elif dataclasses.is_dataclass(sample): + result = [f"{indent}{type(sample).__name__}("] + for field in dataclasses.fields(sample): + result.append( + f"{indent} {field.name}={format_sample_detailed(getattr(sample, field.name), indent + ' ')}" + ) + result.append(f"{indent})") + return "\n".join(result) + else: + repr_str = repr(sample) + return repr_str[:200] + "..." if len(repr_str) > 200 else repr_str diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index c750c4ec..7eb55678 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -30,8 +30,8 @@ from torch.utils.data import DataLoader, IterableDataset from megatron.energon.cache import CachePool +from megatron.energon.deprecation import deprecated from megatron.energon.edataclass import edataclass -from megatron.energon.errors import deprecated from megatron.energon.flavors.base_dataset import ( FlexState, SavableDataset, @@ -43,7 +43,7 @@ from megatron.energon.wrappers.base import BaseWrapperDataset from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from megatron.energon.wrappers.log_sample_dataset import default_get_keys +from megatron.energon.wrappers.log_sample_dataset import default_get_batch_keys from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset T = TypeVar("T") @@ -772,6 +772,7 @@ def __init__( if self.worker_config.num_workers > 0: kwargs["persistent_workers"] = True kwargs["prefetch_factor"] = prefetch_factor + kwargs["multiprocessing_context"] = "fork" # Assert that prefetch_factor works well with num_checkpoints. 
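# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The two formatters defined in sample_utils.py above, applied to a toy sample.
# `format_sample_compact` yields a single truncated line suitable for error
# messages, while `format_sample_detailed` produces an indented multi-line dump
# including tensor shape, dtype, min/max and edge values.
import torch

from megatron.energon.sample_utils import format_sample_compact, format_sample_detailed

toy_sample = {
    "image": torch.zeros(3, 224, 224),
    "label": 5,
    "caption": "a very long caption " * 20,  # truncated in the compact form
}
print(format_sample_compact(toy_sample))
print(format_sample_detailed(toy_sample))
# --- End of sketch -------------------------------------------------------------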
# This ensures that the oldest checkpoint is old enough to cover @@ -841,7 +842,7 @@ def _epoch_iter(self): # ) # self._debugf.flush() if self.worker_config.should_log(level=1): - keys = default_get_keys(sample) + keys = default_get_batch_keys(sample) self.worker_config.worker_log( { **{ @@ -1272,6 +1273,7 @@ def __init__( # These must not be specified for num_workers =0 kwargs["persistent_workers"] = True kwargs["prefetch_factor"] = prefetch_factor + kwargs["multiprocessing_context"] = "fork" seed_per_worker = [ self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) @@ -1323,7 +1325,7 @@ def _inner_generator(iterator): for worker_id, sample_idx, sample in iterator: # If the next sample will be from the first worker, we can safely resume if self.worker_config.should_log(level=1): - keys = default_get_keys(sample) + keys = default_get_batch_keys(sample) self.worker_config.worker_log( { **{ diff --git a/src/megatron/energon/source_info.py b/src/megatron/energon/source_info.py index 1a90c228..5de0a279 100644 --- a/src/megatron/energon/source_info.py +++ b/src/megatron/energon/source_info.py @@ -36,3 +36,23 @@ def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any: else: sample["__sources__"] = (*sample["__sources__"], *source_infos) return sample + + +def get_source_info(sample: Any) -> list[SourceInfo] | None: + """Extracts source info from a sample.""" + from megatron.energon.flavors.base_dataset import Sample + + if isinstance(sample, dict): + sources = sample.get("__sources__") + elif isinstance(sample, list): + if len(sample) > 0 and isinstance(sample[0], dict): + sources = [src for s in sample for src in s.get("__sources__", ())] + elif len(sample) > 0 and isinstance(sample[0], Sample): + sources = [src for s in sample for src in s.__sources__] + else: + sources = None + elif isinstance(sample, Sample): + sources = sample.__sources__ + else: + sources = None + return list(sources) if sources else None diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index e75d8220..5335d018 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -28,6 +28,7 @@ from typing_extensions import ParamSpec from megatron.energon.cache import CachePool, DecodeFileStore, FileStore +from megatron.energon.cache.base import PrimaryFileStore from megatron.energon.edataclass import edataclass from megatron.energon.flavors import ( CrudeSample, @@ -250,11 +251,9 @@ def get_stateless(fn: Callable) -> bool: return getattr(fn, "__stateless__", False) -def get_failure_tolerance( - fn: Callable, default_failure_tolerance: Optional[int] = None -) -> Optional[int]: +def get_failure_tolerance(fn: Callable, default_failure_tolerance: Optional[int] = None) -> int: """Get the failure tolerance of a function.""" - return getattr(fn, "__failure_tolerance__", default_failure_tolerance) + return getattr(fn, "__failure_tolerance__", default_failure_tolerance) or 0 @edataclass @@ -374,47 +373,6 @@ class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch] #: The decoder to use for decoding samples. Set manually as needed to override options. decoder: Optional[SampleDecoder] = SampleDecoder() - @stateless - def cook_crude_sample( - self, - sample: Union[T_sample, CrudeSample], - get_primary_aux: Callable[[], FileStore], - **aux: FileStore, - ) -> T_sample: - """ - Cooks a crude sample. - - Args: - sample: The sample to cook. 
- get_primary_aux: A function that returns the (cached) primary auxiliary dataset. - **aux: The auxiliary side dishes to use for cooking. - - Returns: The cooked sample. - """ - if isinstance(sample, CrudeSample): - for cooker in self.cookers: - if cooker.is_match(sample): - assert get_stateless(cooker.cook), "Cooker must be stateless" - if not cooker.need_primary and not cooker.need_cache: - kwargs = aux - else: - kwargs: dict = {} - if cooker.need_primary: - kwargs["primary"] = get_primary_aux() - kwargs.update(aux) - if cooker.need_cache: - kwargs["cache"] = self.cache - return cooker.cook(sample, **kwargs) - - raise NotImplementedError( - "You are using crude samples but not providing a way to cook them: " - f"Sample key={sample['__key__']}, subflavors={sample['__subflavors__']}, " - f"self.cookers={self.cookers}" - ) - else: - assert isinstance(sample, Sample), "Sample must be a complete Sample or a CrudeSample" - return sample - def _is_overridden( self, bound_method: Callable[..., Any], bases: Optional[Sequence[Type[Any]]] = None ) -> bool: @@ -442,6 +400,33 @@ def _is_overridden( bases = (TaskEncoder,) return not any(getattr(base, func.__name__) is func for base in bases) + @stateless + def cook_crude_sample( + self, + sample: CrudeSample, + cooker: Cooker[CrudeSample], + aux: dict[str, FileStore], + ) -> T_sample: + """ + Cooks a crude sample. + + Args: + sample: The sample to cook. + cooker: The cooker to use. + aux: Aux datasets to use. + + Returns: The cooked sample. + """ + + assert get_stateless(cooker.cook), "Cooker must be stateless" + if cooker.need_primary or cooker.need_cache: + aux = {**aux} + if cooker.need_cache: + aux["cache"] = self.cache + if cooker.need_primary: + aux["primary"] = PrimaryFileStore(aux["primary"], current_key=sample["__key__"]) + return cooker.cook(sample, **aux) + @stateless def encode_sample( self, sample: T_sample @@ -464,9 +449,7 @@ def preencode_sample( return sample @stateless - def postencode_sample( - self, sample: T_sample - ) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]: + def postencode_sample(self, sample: T_sample) -> T_encoded_sample: """Post-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. Use in conjunction with packing and caching. @@ -622,7 +605,7 @@ def build_batch( final_packer_failure_tolerance=get_failure_tolerance( self.pack_selected_samples, self.__default_failure_tolerance__ ), - sample_encoder_failure_tolerance=None + sample_encoder_failure_tolerance=0 if post_encode_fn is None else get_failure_tolerance(post_encode_fn, self.__default_failure_tolerance__), ) @@ -702,49 +685,42 @@ def build_cook_crude_sample( assert self.cookers, "No cookers registered, but got crude dataset." - if aux is not None and self.decoder is not None: - aux = {k: DecodeFileStore(v, decoder=self.decoder) for k, v in aux.items()} + if aux is None: + aux = {} - # Cache the primary auxiliary dataset for this dataset, i.e. 
construct it once when needed - primary_aux = None + if self.decoder is not None: + aux = {k: DecodeFileStore(v, decoder=self.decoder) for k, v in aux.items()} - def _get_primary_aux(): - nonlocal primary_aux - if primary_aux is None: - try: - if aux is not None: - primary_aux = aux.get("primary") - if primary_aux is None: - primary_aux = get_primary_aux() - assert primary_aux is not None, "Primary auxiliary dataset must always exist" - if self.decoder is not None: - primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder) - except Exception as e: - # Make the exception throw through for the sample being loaded - raise SystemError("Error getting primary auxiliary dataset") from e - return primary_aux - - if aux is not None: - cook_fn = functools.partial( - self.cook_crude_sample, get_primary_aux=_get_primary_aux, **aux - ) + for cooker in self.cookers: + if cooker.is_match(subflavors): + break else: - cook_fn = functools.partial(self.cook_crude_sample, get_primary_aux=_get_primary_aux) + raise ValueError(f"No cooker found for subflavors: {subflavors}") + + if cooker.need_primary and "primary" not in aux: + try: + primary_aux = get_primary_aux() + assert primary_aux is not None, "Primary auxiliary dataset must always exist" + if self.decoder is not None: + primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder) + aux["primary"] = primary_aux + except Exception as e: + # Make the exception throw through for the sample being loaded + raise SystemError("Error getting primary auxiliary dataset") from e + + cook_fn = functools.partial(self.cook_crude_sample, cooker=cooker, aux=aux) return MapDataset( dataset, cook_fn, worker_config=worker_config, - stateless_map_fn=True, + stateless_map_fn=get_stateless(self.cook_crude_sample), map_fn_config=dict( - cookers=[ - dict( - cook=SavableDataset._function_config(cooker.cook), - has_subflavors=cooker.has_subflavors, - ) - for cooker in self.cookers - ], - subflavors=subflavors, + cooker=dict( + cook=SavableDataset._function_config(cooker.cook), + has_subflavors=cooker.has_subflavors, + aux={k: {"_path": str(v.get_path())} for k, v in aux.items()}, + ), ), failure_tolerance=get_failure_tolerance(cook_fn, self.__default_failure_tolerance__), ) diff --git a/src/megatron/energon/task_encoder/cooking.py b/src/megatron/energon/task_encoder/cooking.py index 6ab394ef..23a0c9cc 100644 --- a/src/megatron/energon/task_encoder/cooking.py +++ b/src/megatron/energon/task_encoder/cooking.py @@ -7,7 +7,6 @@ from typing import Callable, Generic, Optional, TypeVar, Union, overload from megatron.energon.flavors.base_dataset import Sample -from megatron.energon.flavors.crude import CrudeSample from megatron.energon.source_info import SourceInfo T_sample = TypeVar("T_sample", bound=Sample, covariant=True) @@ -94,15 +93,12 @@ def need_primary(self) -> bool: def need_cache(self) -> bool: return get_cooker_need_cache(self.cook) - def is_match(self, crude_sample: CrudeSample) -> bool: + def is_match(self, subflavors: dict) -> bool: if self.has_subflavors is not None: # Checks if the dict entries provided as a filter all match # the ones in the sample. The sample may have additional entries. 
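# --- Editor's illustrative sketch (not part of the patch) ---------------------
# How the dict-based subflavor matching above behaves. Constructing `Cooker`
# with the `cook` and `has_subflavors` keywords is an assumption based on how
# it is used elsewhere in this diff, and the subflavor dicts are made up.
def cook_caption(sample: dict) -> dict:
    return sample

caption_cooker = Cooker(cook=cook_caption, has_subflavors={"kind": "caption"})
assert caption_cooker.is_match({"kind": "caption", "lang": "en"})  # extra keys are allowed
assert not caption_cooker.is_match({"kind": "ocr"})

unfiltered_cooker = Cooker(cook=cook_caption)  # no filter set: matches any subflavors
assert unfiltered_cooker.is_match({})
# --- End of sketch -------------------------------------------------------------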
for k, v in self.has_subflavors.items(): - if ( - k not in crude_sample["__subflavors__"] - or crude_sample["__subflavors__"][k] != v - ): + if k not in subflavors or subflavors[k] != v: return False return True diff --git a/src/megatron/energon/task_encoder/loader.py b/src/megatron/energon/task_encoder/loader.py index 07d17da4..5c680bb2 100644 --- a/src/megatron/energon/task_encoder/loader.py +++ b/src/megatron/energon/task_encoder/loader.py @@ -4,8 +4,8 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union +from megatron.energon.deprecation import warn_deprecated from megatron.energon.epathlib import EPath -from megatron.energon.errors import warn_deprecated from megatron.energon.flavors import BaseCoreDatasetFactory, SavableDataset from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder from megatron.energon.metadataset import load_dataset @@ -31,13 +31,24 @@ def _split_kwargs(kwargs: dict) -> dict: return loader_kwargs -def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> None: +def _split_deprecated_dataset_kwargs( + kwargs: dict, task_encoder: TaskEncoder, worker_config: WorkerConfig +) -> None: """ + This function takes deprecated kwargs to the get_train_dataset and get_val_dataset functions and splits them into the new decoder kwargs and the worker_config error handler. + + Args: + kwargs: The kwargs to consider (will modify this in place). + task_encoder: The task encoder to apply the kwargs to. + worker_config: The worker configuration to apply the kwargs to. + + List of deprecated kwargs: auto_decode: bool = True, image_decode: ImageDecoder = "torchrgb", ignore_decoder_errors: bool = False, av_decode: AVDecoder = "AVDecoder", video_decode_audio: bool = False, + handler: Callable[[Exception, Any, list[SourceInfo] | None], None] = None, """ auto_decode = True @@ -51,6 +62,13 @@ def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> if "video_decode_audio" in kwargs: decoder_kwargs["video_decode_audio"] = kwargs.pop("video_decode_audio") + if "handler" in kwargs: + warn_deprecated( + "The handler kwarg is deprecated and will be removed in a future version. Instead, use the error handler in the worker_config." + "Ignoring the handler argument and using the worker_config handler." + ) + kwargs.pop("handler") + if not auto_decode: task_encoder.decoder = None elif len(decoder_kwargs) > 0: @@ -137,7 +155,7 @@ def get_train_dataset( """ loader = load_dataset(path, **_split_kwargs(kwargs)) - _split_deprecated_decoder_kwargs(kwargs, task_encoder) + _split_deprecated_dataset_kwargs(kwargs, task_encoder, worker_config) datasets = loader.get_datasets( training=True, @@ -197,7 +215,7 @@ def get_val_dataset( Returns: The loaded dataset. """ - _split_deprecated_decoder_kwargs(kwargs, task_encoder) + _split_deprecated_dataset_kwargs(kwargs, task_encoder, worker_config) loader = load_dataset(path, **_split_kwargs(kwargs)) datasets = loader.get_datasets( training=False, @@ -252,7 +270,7 @@ def get_val_datasets( Returns: The loaded val datasets, with the source datasets. 
""" - _split_deprecated_decoder_kwargs(kwargs, task_encoder) + _split_deprecated_dataset_kwargs(kwargs, task_encoder, worker_config) loader = load_dataset(path, **_split_kwargs(kwargs)) datasets = loader.get_datasets( training=False, diff --git a/src/megatron/energon/tools/info.py b/src/megatron/energon/tools/info.py index eb1eef6f..22e7c8c7 100644 --- a/src/megatron/energon/tools/info.py +++ b/src/megatron/energon/tools/info.py @@ -11,7 +11,7 @@ from megatron.energon.epathlib import EPath from megatron.energon.eyaml import load_yaml, load_yaml_json from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.metadata import get_dataset_info fmt = """ diff --git a/src/megatron/energon/tools/prepare.py b/src/megatron/energon/tools/prepare.py index 4ab5302e..5e6643cf 100644 --- a/src/megatron/energon/tools/prepare.py +++ b/src/megatron/energon/tools/prepare.py @@ -19,13 +19,14 @@ CrudeWebdataset, Sample, ) -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.metadata import ( EnergonDatasetType, check_dataset_info_present, get_dataset_info, get_dataset_type, ) +from megatron.energon.media.extractor import MediaFilterConfig from megatron.energon.metadataset.loader import prepare_metadataset @@ -127,6 +128,62 @@ def printify_json(data: Any) -> Any: help="If set, the tar files will be shuffled before splitting.", is_flag=True, ) +@click.option( + "--media-metadata-by-glob", + type=str, + help="Media detection by using one or more glob patterns such as '*.jpg'. Separate multiple patterns by commas.", +) +@click.option( + "--media-metadata-by-header", + is_flag=True, + help="Media detection by binary file header.", +) +@click.option( + "--media-metadata-by-extension", + is_flag=True, + help="Media detection by standard file extensions.", +) +@click.option( + "--fix-duplicates", + help="Fix duplicate keys in the dataset.", + is_flag=True, +) +@click.option( + "--non-interactive", + help="If set, the prepare will not ask for interactive input.", + is_flag=True, +) +@click.option( + "--split-ratio", + help="Train/val/test split ratio in the form '0.8,0.1,0.1' or '8,1,1'. Required for non-interactive mode unless --split-parts or --tar-index-only is used.", + default=None, +) +@click.option( + "--force-overwrite", + help="Overwrite existing dataset preparation without confirmation.", + is_flag=True, +) +@click.option( + "--skip-dataset-yaml", + help="Skip dataset.yaml creation (i.e. no sample loader or dataset.yaml will be created).", + is_flag=True, +) +@click.option( + "--dataset-yaml-name", + help="Name of the dataset.yaml file to create.", + default="dataset.yaml", + type=str, +) +@click.option( + "--sample-type", + help="Sample type class name (e.g., 'CaptioningSample', 'CrudeWebdataset'). Required for non-interactive dataset.yaml creation if creating the dataset.yaml.", + default=None, +) +@click.option( + "--field-map", + help='Field mapping in JSON format (e.g., \'{"image": "jpg", "caption": "txt"}\'). If not set in non-interactive mode, a sample loader from template will be created. Use with --sample-type. 
Only applies if sample_type is not set to CrudeWebdataset.', + default=None, +) def command( path: EPath, progress: bool, @@ -135,6 +192,17 @@ def command( num_workers: int, tar_index_only: bool, shuffle_tars: bool, + media_metadata_by_glob: str | None, + media_metadata_by_header: bool, + media_metadata_by_extension: bool, + fix_duplicates: bool, + non_interactive: bool, + split_ratio: Optional[str], + force_overwrite: bool, + sample_type: Optional[str], + field_map: Optional[str], + skip_dataset_yaml: bool, + dataset_yaml_name: str, ): """Prepare WebDataset for use with energon. @@ -143,16 +211,45 @@ def command( details. """ + do_media_metadata = bool( + media_metadata_by_glob is not None + or media_metadata_by_header + or media_metadata_by_extension + ) + + if do_media_metadata and tar_index_only: + raise click.UsageError("--media-metadata-by-... cannot be combined with --tar-index-only") + + media_filter_config = ( + MediaFilterConfig.parse( + media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension + ) + if do_media_metadata + else None + ) + ds_type = get_dataset_type(path) if ds_type == EnergonDatasetType.METADATASET: + if do_media_metadata: + raise click.ClickException( + "Metadatasets cannot store media metadata. Remove --media-metadata-by-... to continue." + ) print("Preparing metadataset...") prepare_metadataset(path) return elif ds_type == EnergonDatasetType.JSONL: + if do_media_metadata: + raise click.ClickException( + "JSONL datasets do not support media metadata. Remove --media-metadata-by-... to continue." + ) print("Preparing jsonl dataset...") count = CrudeJsonlDatasetFactory.prepare_dataset(path) print(f"Done. Found {count} samples.") return + elif ds_type == EnergonDatasetType.FILESYSTEM: + raise click.ClickException( + "Filesystem datasets must be prepared using 'energon prepare-media'." + ) assert path.is_dir(), f"Path {path} is not a known dataset type" @@ -161,10 +258,18 @@ def command( all_tars = list(info["shard_counts"].keys()) else: if check_dataset_info_present(path): - if not click.confirm( - "It seems the dataset had already been prepared. Do you want to continue?" - ): - return + if force_overwrite: + # Silently continue if force_overwrite is set + pass + elif non_interactive: + raise click.ClickException( + "Dataset has already been prepared. Use --force-overwrite to overwrite." + ) + else: + if not click.confirm( + "It seems the dataset had already been prepared. Do you want to continue?" + ): + return all_tars = list(path.glob("**/*.tar")) + list(path.glob("**/*.tgz")) all_tars = [str(p.relative_to(path)) for p in sorted(all_tars)] @@ -190,9 +295,19 @@ def command( split_parts_patterns = [tuple(x.split(":", 1)) for x in split_parts] split_parts_ratio = None elif not tar_index_only: - split_input = click.prompt( - 'Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1"', type=str - ) + if split_ratio is not None: + # Use the provided split_ratio flag + split_input = split_ratio + elif non_interactive: + raise click.ClickException( + "--split-ratio is required in non-interactive mode " + "(unless --split-parts or --tar-index-only is used)." 
+ ) + else: + split_input = click.prompt( + 'Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1"', + type=str, + ) # Extract split floats try: split = [float(x.strip()) for x in split_input.split(",")] @@ -222,7 +337,7 @@ def progress_fn(els, length=None): def progress_fn(els, length=None): return els - found_types, duplicates = BaseWebdatasetFactory.prepare_dataset( + found_types = BaseWebdatasetFactory.prepare_dataset( path, all_tars, split_parts_ratio=split_parts_ratio, @@ -231,25 +346,14 @@ def progress_fn(els, length=None): tar_index_only=tar_index_only, shuffle_seed=42 if shuffle_tars else None, workers=num_workers, + media_filter=media_filter_config, + fix_duplicates=fix_duplicates, ) - if duplicates: - print(f"Examples of duplicates found: {duplicates}") - print() - print( - "The dataset has duplicate keys. Best practice is to use unique keys. " - "You won't be able to use this dataset for joining " - "later on." - ) - found_types = list(found_types) if tar_index_only: return - if duplicates: - if not click.confirm("Do you want to continue?"): - return - # Print json of first two samples for sample_idx, data in enumerate( BaseWebdatasetFactory.iter_dataset_content(path / all_tars[0], ("json",)) @@ -272,7 +376,20 @@ def progress_fn(els, length=None): click.echo(f"Found the following part types in the dataset: {', '.join(found_types)}") allow_interactive_field_map = True - if click.confirm("Do you want to create a dataset.yaml interactively?", default=True): + # Determine if we should create dataset.yaml + if skip_dataset_yaml: + should_create_yaml = False + elif sample_type is not None: + should_create_yaml = True + elif non_interactive: + click.echo("Skipping dataset.yaml creation (use --sample-type to create it).") + should_create_yaml = False + else: + should_create_yaml = click.confirm( + "Do you want to create a dataset.yaml interactively?", default=True + ) + + if should_create_yaml: # Get a list of all classes in megatron.energon that are subclasses of WebdatasetBase import megatron.energon as data_import @@ -283,18 +400,34 @@ def progress_fn(els, length=None): ] display_name_and_class.append(("Crude sample (plain dict for cooking)", CrudeWebdataset)) - # Print all classes and ask user to pick one - click.echo("The following sample types are available:") - for i, (name, cls) in enumerate(display_name_and_class): - click.echo(f"{i}. {name}") - while True: - choice = click.prompt("Please enter a number to choose a class", type=int) - try: - _, cls = display_name_and_class[choice] - break - except IndexError: - click.echo("Invalid choice. Please try again.") - continue + # Find the class by name if sample_type is provided + cls = None + if sample_type is not None: + for name, candidate_cls in display_name_and_class: + if candidate_cls.__name__ == sample_type: + cls = candidate_cls + break + if cls is None: + available = "\n".join(f" - {c.__name__}" for _, c in display_name_and_class) + raise click.ClickException( + f"Sample type '{sample_type}' not found.\nAvailable sample types:\n{available}" + ) + elif non_interactive: + # This should not happen due to earlier checks, but just in case + raise click.ClickException("--sample-type is required in non-interactive mode.") + else: + # Print all classes and ask user to pick one + click.echo("The following sample types are available:") + for i, (name, candidate_cls) in enumerate(display_name_and_class): + click.echo(f"{i}. 
{name}") + while True: + choice = click.prompt("Please enter a number to choose a class", type=int) + try: + _, cls = display_name_and_class[choice] + break + except IndexError: + click.echo("Invalid choice. Please try again.") + continue if cls == CrudeWebdataset: click.echo( @@ -308,8 +441,9 @@ def progress_fn(els, length=None): "__class__": cls.__name__, } else: - click.echo("The sample type you selected:\n") - click.echo(inspect.getsource(cls)) + if not non_interactive: + click.echo("The sample type you selected:\n") + click.echo(inspect.getsource(cls)) dataset_definition = { "sample_type": { @@ -318,74 +452,111 @@ def progress_fn(els, length=None): }, } - if not allow_interactive_field_map: + if not allow_interactive_field_map and not non_interactive: click.echo( "You cannot set a field_map for this dataset. You will need a sample_loader." ) - if allow_interactive_field_map and click.confirm( + # Determine whether to use field_map or sample_loader + use_field_map = False + if field_map is not None: + use_field_map = True + elif non_interactive: + # In non-interactive mode without field_map, use sample_loader + use_field_map = False + elif allow_interactive_field_map and click.confirm( "Do you want to set a simple field_map[Y] (or write your own sample_loader [n])?", default=True, ): - click.echo( - "\nFor each field, please specify the corresponding name in the WebDataset." - ) - click.echo(f"Available types in WebDataset: {', '.join(found_types)}") - click.echo("Leave empty for skipping optional field") - click.echo( - "You may also access json fields e.g. by setting the field to: json[field][field]" - ) - click.echo("You may also specify alternative fields e.g. by setting to: jpg,png") - - click.echo(f"Please enter the field_map for {cls.__name__}:") - - dataset_definition["field_map"] = field_map = {} - for field in dataclasses.fields(cls): - if field.name in ( - "__key__", - "__restore_key__", - "__subflavors__", - "__sources__", - ): - continue - while True: - if ( - field.default is dataclasses.MISSING - and field.default_factory is dataclasses.MISSING + use_field_map = True + + if use_field_map: + if field_map is not None: + # Parse field_map from JSON string + try: + parsed_field_map = json.loads(field_map) + dataset_definition["field_map"] = parsed_field_map + except json.JSONDecodeError as e: + raise click.ClickException(f"Invalid JSON in --field-map: {e}") + elif non_interactive: + # This shouldn't happen due to earlier checks, but just in case + raise click.ClickException( + "--field-map is required when using field mapping in non-interactive mode." + ) + else: + click.echo( + "\nFor each field, please specify the corresponding name in the WebDataset." + ) + click.echo(f"Available types in WebDataset: {', '.join(found_types)}") + click.echo("Leave empty for skipping optional field") + click.echo( + "You may also access json fields e.g. by setting the field to: json[field][field]" + ) + click.echo( + "You may also specify alternative fields e.g. 
by setting to: jpg,png" + ) + + click.echo(f"Please enter the field_map for {cls.__name__}:") + + dataset_definition["field_map"] = field_map_dict = {} + for field in dataclasses.fields(cls): + if field.name in ( + "__key__", + "__restore_key__", + "__subflavors__", + "__sources__", ): - default = "" - elif field.default is not dataclasses.MISSING: - default = f", default: {field.default}" - elif field.default_factory is not dataclasses.MISSING: - default = f", default: {field.default_factory!r}" - else: - raise RuntimeError("This should never happen") - field_map[field.name] = input( - f"Please enter a webdataset field name for '{field.name}' " - f"({field.type}{default}): ", - ) - if not field_map[field.name] and default: - del field_map[field.name] - break - type_ok = True - for option in field_map[field.name].split(","): - field_name = option.split("[", 1)[0] - if field_name not in found_types: - click.echo( - "That type doesn't exist in the WebDataset. Please try again." - ) - type_ok = False - if type_ok: - break + continue + while True: + if ( + field.default is dataclasses.MISSING + and field.default_factory is dataclasses.MISSING + ): + default = "" + elif field.default is not dataclasses.MISSING: + default = f", default: {field.default}" + elif field.default_factory is not dataclasses.MISSING: + default = f", default: {field.default_factory!r}" + else: + raise RuntimeError("This should never happen") + field_map_dict[field.name] = input( + f"Please enter a webdataset field name for '{field.name}' " + f"({field.type}{default}): ", + ) + if not field_map_dict[field.name] and default: + del field_map_dict[field.name] + break + type_ok = True + for option in field_map_dict[field.name].split(","): + field_name = option.split("[", 1)[0] + if field_name not in found_types: + click.echo( + "That type doesn't exist in the WebDataset. Please try again." + ) + type_ok = False + if type_ok: + break else: if not allow_interactive_field_map: template_part_types = set(["TODO"]) else: template_part_types = found_types - if not (path / MAIN_FOLDER_NAME / "sample_loader.py").is_file() or click.confirm( - "Do you want to override the existing sample loader?" - ): + sample_loader_exists = (path / MAIN_FOLDER_NAME / "sample_loader.py").is_file() + + if not sample_loader_exists: + # Always create if it doesn't exist + should_create_sample_loader = True + elif non_interactive: + # In non-interactive mode, don't override existing file + should_create_sample_loader = False + else: + # In interactive mode, ask user + should_create_sample_loader = click.confirm( + "Do you want to override the existing sample loader?" + ) + + if should_create_sample_loader: with (path / MAIN_FOLDER_NAME / "sample_loader.py").open("w") as f: f.write( sample_loader_template( @@ -401,7 +572,7 @@ def progress_fn(els, length=None): dataset_definition["part_filter"] = "sample_loader.py:part_filter" # Write the dataset.yaml file - with (path / MAIN_FOLDER_NAME / "dataset.yaml").open("w") as f: + with (path / MAIN_FOLDER_NAME / dataset_yaml_name).open("w") as f: yaml.dump(dataset_definition, f, sort_keys=False) else: click.echo("You will have to add a dataset.yaml manually.") diff --git a/src/megatron/energon/tools/prepare_media.py b/src/megatron/energon/tools/prepare_media.py new file mode 100644 index 00000000..bcc99bf6 --- /dev/null +++ b/src/megatron/energon/tools/prepare_media.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
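# --- Editor's illustrative sketch (not part of the patch) ---------------------
# Driving the new non-interactive path of `energon prepare` (tools/prepare.py
# above) from Python via click's test runner. The flag names come from this
# diff; the dataset path and the choice of CaptioningSample are hypothetical.
from click.testing import CliRunner

from megatron.energon.tools.prepare import command as prepare_command

runner = CliRunner()
result = runner.invoke(
    prepare_command,
    [
        "/data/my_webdataset",  # hypothetical dataset root
        "--non-interactive",
        "--split-ratio", "8,1,1",
        "--sample-type", "CaptioningSample",
        "--field-map", '{"image": "jpg", "caption": "txt"}',
        "--media-metadata-by-extension",
    ],
)
print(result.output)
# --- End of sketch -------------------------------------------------------------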
+# SPDX-License-Identifier: BSD-3-Clause + + +import click + +from megatron.energon.epathlib import EPath +from megatron.energon.flavors.webdataset.metadata import ( + EnergonDatasetType, + get_dataset_type, +) +from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator +from megatron.energon.media.extractor import MediaFilterConfig +from megatron.energon.media.filesystem_prepare import prepare_filesystem_dataset + + +@click.command(name="prepare-media") +@click.argument( + "path", + type=click.Path(path_type=EPath), +) +@click.option( + "--progress/--no-progress", + default=True, +) +@click.option( + "--num-workers", + type=int, + default=16, + help="Number of workers to use to scan files", +) +@click.option( + "--media-metadata-by-glob", + type=str, + help="Media detection by using one or more glob patterns such as '*.jpg'. Separate multiple patterns by commas.", +) +@click.option( + "--media-metadata-by-header", + is_flag=True, + help="Media detection by binary file header.", +) +@click.option( + "--media-metadata-by-extension", + is_flag=True, + help="Media detection by standard file extensions.", +) +def command( + path: EPath, + progress: bool, + num_workers: int, + media_metadata_by_glob: str | None, + media_metadata_by_header: bool, + media_metadata_by_extension: bool, +): + """Prepare a filesystem dataset by collecting media metadata.""" + + media_filter_config = MediaFilterConfig.parse( + media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension + ) + + ds_type = get_dataset_type(path) + if ds_type == EnergonDatasetType.WEBDATASET: + click.echo("Preparing webdataset and computing media metadata...") + + if progress: + + def progress_fn(els, length=None): + with click.progressbar( + els, + label="Processing shards", + show_pos=True, + length=length, + ) as bar: + yield from bar + + else: + + def progress_fn(els, length=None): + return els + + count = WebdatasetPreparator.add_media_metadata( + path, + media_filter=media_filter_config, + workers=num_workers, + progress_fn=progress_fn, + ) + + click.echo(f"Done. Stored metadata for {count} files.") + return + elif ds_type not in (EnergonDatasetType.FILESYSTEM, EnergonDatasetType.INVALID): + raise click.ClickException( + f"'prepare-media' only supports WebDatasets or filesystem datasets, but this path is a '{ds_type}' dataset" + ) + + click.echo("Preparing filesystem dataset and computing media metadata...") + stored = prepare_filesystem_dataset( + path, + media_filter_config, + progress=progress, + workers=num_workers, + ) + click.echo(f"Done. 
Stored metadata for {stored} files.") + + +if __name__ == "__main__": + command() diff --git a/src/megatron/energon/tools/preview.py b/src/megatron/energon/tools/preview.py index f1d449b9..92c0a918 100644 --- a/src/megatron/energon/tools/preview.py +++ b/src/megatron/energon/tools/preview.py @@ -12,7 +12,7 @@ from megatron.energon import Sample, StandardWebdatasetFactory from megatron.energon.dataset_config import load_config from megatron.energon.epathlib import EPath -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError from megatron.energon.loader import get_loader from megatron.energon.worker import WorkerConfig diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index ec01e0b2..fca26369 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -6,12 +6,14 @@ import multiprocessing from dataclasses import dataclass from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar +from typing import Any, Callable, ClassVar, Dict, List, Optional, TextIO, TypeVar import torch.distributed import torch.utils.data from megatron.energon.cache import CachePool +from megatron.energon.errors import log_exception, reraise_exception +from megatron.energon.source_info import SourceInfo __all__ = ("WorkerConfig",) @@ -61,6 +63,14 @@ class WorkerConfig: _worker_debug_file: Optional[TextIO] = None #: worker_id of the opened worker debug file _worker_debug_file_worker_id: Optional[int] = None + #: The global error handler to use for the worker during normal iteration. + global_error_handler: Callable[ + [Exception, Any | list[Any], Optional[list[SourceInfo]]], None + ] = log_exception + #: The error handler to use when restoring samples. Defaults to reraising the exception. + restore_error_handler: Callable[ + [Exception, Any | list[Any], Optional[list[SourceInfo]]], None + ] = reraise_exception #: The current sample index within the current iterating worker _sample_index_stack: ClassVar[Optional[List[int]]] = None diff --git a/src/megatron/energon/wrappers/_log_exception.py b/src/megatron/energon/wrappers/_log_exception.py deleted file mode 100644 index 791be674..00000000 --- a/src/megatron/energon/wrappers/_log_exception.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause - -import traceback -from typing import Optional, Sequence - -from megatron.energon.source_info import SourceInfo - - -def log_exception(_e: Exception, sample, sources: Optional[Sequence[SourceInfo]] = None): - traceback.print_exc() - print("-" * 10) - - if sources: - print("Sources:") - for source in sources: - print( - f" - {source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}" - ) - print("-" * 10) - - sample_str = str(sample) - if len(sample_str) > 400: - sample_str = sample_str[:200] + "..." 
+ sample_str[-200:] - - print(sample_str) - - print("-" * 10) diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 35473390..8cdfa8eb 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -11,19 +11,15 @@ Iterator, List, Optional, - Sequence, Tuple, TypeVar, Union, ) -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError +from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key -from megatron.energon.wrappers.skip import SkipSample T_batch = TypeVar("T_batch", covariant=True) T_batch_sample = TypeVar("T_batch_sample", covariant=True) @@ -35,11 +31,10 @@ class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ batch_size: int batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, list[T_batch_sample], Sequence[SourceInfo]], None] _sample_index: SampleIndex _generator_sample_keys: Optional[Any] _generator_offset: Optional[int] - _last_batch_failures: int = 0 + _batch_failure_handler: ErrorContext _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset") @@ -52,9 +47,6 @@ def __init__( batcher_stateless: bool = False, batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, drop_last: bool = False, - error_handler: Callable[ - [Exception, list[T_batch_sample], Sequence[SourceInfo]], None - ] = log_exception, failure_tolerance: int = 100, worker_config: WorkerConfig, ): @@ -70,8 +62,6 @@ def __init__( batcher_config: Configuration for the batcher function. If callable, it should return the configuration. Defaults to None. drop_last: If True, the last batch is dropped if it is smaller than the batch size. - error_handler: Function which handles exceptions raised by the batcher. The default - implementation logs the exception. failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable. worker_config: Configuration for the workers. 
""" @@ -81,8 +71,12 @@ def __init__( self.batcher_stateless = batcher_stateless self.batcher_config = batcher_config self.drop_last = drop_last - self.error_handler = error_handler self.failure_tolerance = failure_tolerance + self._batch_failure_handler = ErrorContext( + name=f"BatchDataset.{self.batcher}", + handler=worker_config.global_error_handler, + tolerance=failure_tolerance, + ) self.reset_state_own() @@ -133,7 +127,7 @@ def __iter__(self) -> Iterator[T_batch]: sample_restore_keys = [] def flush() -> Generator[T_batch, None, None]: - try: + with self._batch_failure_handler.handle_errors(batch): with self._sample_index.ctx() as sample_idx: batch_sample = self.batcher(batch) if isinstance(batch_sample, Generator): @@ -145,8 +139,8 @@ def flush() -> Generator[T_batch, None, None]: for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( self._sample_index.iter_ctx(batch_sample, sample_idx) ): - self._last_batch_failures = 0 self._generator_offset = batch_sub_idx + 1 + self._batch_failure_handler.reset() yield set_sample_restore_key( inner_batch_sample, sample_idx, @@ -157,28 +151,10 @@ def flush() -> Generator[T_batch, None, None]: self._generator_sample_keys = None self._generator_offset = None else: - self._last_batch_failures = 0 + self._batch_failure_handler.reset() set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) yield batch_sample - except GeneratorExit: - raise - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch) - except Exception as e: - self.error_handler(e, batch) - self._last_batch_failures += 1 - if ( - self.failure_tolerance > 0 - and self._last_batch_failures >= self.failure_tolerance - ): - raise FatalSampleError.from_sample( - batch, - f"BatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.", - ) - finally: - sample_restore_keys.clear() + sample_restore_keys.clear() for sample in self.dataset: batch.append(sample) @@ -211,7 +187,7 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_b assert id == type(self).__name__ batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - try: + with handle_restore_errors(self.worker_config.restore_error_handler, batch): with self._sample_index.ctx(sample_idx): batch_sample = self.batcher(batch) if isinstance(batch_sample, Generator): @@ -221,7 +197,6 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_b for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( self._sample_index.iter_ctx(batch_sample, sample_idx) ): - self._last_batch_failures = 0 if cur_batch_sub_idx == batch_sub_idx: return set_sample_restore_key( inner_batch_sample, @@ -232,32 +207,12 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_b ) assert False, f"Batch sub-index {batch_sub_idx} not found in batch" else: - self._last_batch_failures = 0 return set_sample_restore_key( batch_sample, sample_idx, *samples_restore_keys, src=self, ) - except GeneratorExit: - raise FatalSampleError.from_sample( - batch, - f"BatchDataset {self.batcher} generator exitedwhile trying to restore a batch.", - ) - except SkipSample: - raise FatalSampleError.from_sample( - batch, f"BatchDataset {self.batcher} skipped while trying to restore a batch." 
- ) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch) - except Exception as e: - self.error_handler(e, batch) - self._last_batch_failures += 1 - if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance: - raise FatalSampleError.from_sample( - batch, - f"BatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.", - ) def config(self) -> Dict[str, Any]: return { @@ -277,7 +232,6 @@ def config(self) -> Dict[str, Any]: ), "batcher_stateless": self.batcher_stateless, "drop_last": self.drop_last, - "error_handler": self._function_config(self.error_handler), "worker_config": self.worker_config.config(), "dataset": self.dataset.config(), } diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index 70f31688..5ba79e06 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -10,6 +10,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensorBase from torch.distributed.distributed_c10d import reduce_op +from megatron.energon.errors import SYSTEM_EXCEPTIONS from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -52,6 +53,13 @@ def gc_init_worker(worker_id: int): o._shutdown = True except ReferenceError: # Can happen if the object is a weakref proxy, don't care + # (Also in SYSTEM_EXCEPTIONS, thus catch before) + pass + except SYSTEM_EXCEPTIONS: + # Reraise these exceptions + raise + except Exception: + # Can happen if the object is a weakref proxy, or any other exception, don't care pass _frozen_cuda_tensors_initialized = True diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 72cefc77..41a45b20 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -12,26 +12,22 @@ Iterator, List, Optional, - Sequence, Tuple, TypeVar, Union, ) from megatron.energon.edataclass import edataclass -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError +from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import ( FlexState, SavableDataset, set_sample_restore_key, ) from megatron.energon.savable import Savable -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex from megatron.energon.wrappers.buffer import SavableSampleBuffer -from megatron.energon.wrappers.skip import SkipSample T_batch = TypeVar("T_batch", covariant=True) T_batch_sample = TypeVar("T_batch_sample", covariant=True) @@ -68,11 +64,11 @@ class GroupBatchDataset( sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]] batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None] _group_key_sample_index: SampleIndex _batch_sample_index: SampleIndex _buckets: Dict[Hashable, Bucket[T_batch_sample]] - _last_batch_failures: int = 0 + _batch_failure_handler: ErrorContext + _group_key_failure_handler: ErrorContext def __init__( self, @@ -84,9 +80,6 @@ def __init__( batcher_stateless: bool = False, 
batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, drop_last: bool = False, - error_handler: Callable[ - [Exception, List[T_batch_sample], Sequence[SourceInfo]], None - ] = log_exception, failure_tolerance: int = 100, worker_config: WorkerConfig, ): @@ -99,7 +92,6 @@ def __init__( batcher: Function which combines separate samples into a single object. May raise :exc:`megatron.energon.SkipSample` to skip a sample. drop_last: If True, the last batch is dropped if it is smaller than the batch size. - error_handler: Handler for errors. Defaults to logging and ignoring the exception. failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable. worker_config: Configuration for the workers. """ @@ -110,8 +102,17 @@ def __init__( self.batcher_stateless = batcher_stateless self.batcher_config = batcher_config self.drop_last = drop_last - self.error_handler = error_handler self.failure_tolerance = failure_tolerance + self._batch_failure_handler = ErrorContext( + name=f"GroupBatchDataset.{self.batcher}", + handler=worker_config.global_error_handler, + tolerance=failure_tolerance, + ) + self._group_key_failure_handler = ErrorContext( + name=f"GroupBatchDataset.{self.sample_group_key}", + handler=worker_config.global_error_handler, + tolerance=failure_tolerance, + ) self.reset_state_own() @@ -152,34 +153,19 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: # dbg_bucket.samples.debug_print(" ") batch_items, sample_restore_keys = bucket.samples.flush() # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") - try: + with self._batch_failure_handler.handle_errors(batch_items): with self._batch_sample_index.ctx() as sample_idx: batch_sample = self.batcher(batch_items) assert not isinstance(batch_sample, Generator), ( f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." ) - self._last_batch_failures = 0 + self._batch_failure_handler.reset() set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) yield batch_sample - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch_items) - except Exception as e: - self.error_handler(e, batch_items) - self._last_batch_failures += 1 - if ( - self.failure_tolerance > 0 - and self._last_batch_failures >= self.failure_tolerance - ): - raise FatalSampleError.from_sample( - batch_items, - f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. 
Likely your code or dataset are broken.", - ) # Add samples to the buckets for sample in self.dataset: - try: + with self._group_key_failure_handler.handle_errors(sample): with self._group_key_sample_index.ctx(): bucket_key, batch_size = self.sample_group_key(sample) assert (batch_size is None) != (self.fixed_batch_size is None), ( @@ -188,13 +174,6 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: ) if self.fixed_batch_size is not None: batch_size = self.fixed_batch_size - except SkipSample: - continue - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, [sample]) - continue bucket = buckets.get(bucket_key) if bucket is None: assert batch_size is not None @@ -251,23 +230,11 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: id, sample_idx, *sample_restore_keys = index assert id == type(self).__name__ batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - try: + + with handle_restore_errors(self.worker_config.restore_error_handler, batch): with self._batch_sample_index.ctx(sample_idx): batch_sample = self.batcher(batch) set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - self._last_batch_failures = 0 - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch) - except Exception as e: - self.error_handler(e, batch) - self._last_batch_failures += 1 - if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance: - raise FatalSampleError.from_sample( - batch, - f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.", - ) return batch_sample @@ -289,7 +256,6 @@ def config(self) -> Dict[str, Any]: ), "batcher_stateless": self.batcher_stateless, "drop_last": self.drop_last, - "error_handler": self._function_config(self.error_handler), "worker_config": self.worker_config.config(), "dataset": self.dataset.config(), } diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index 6d4d8956..a595e83b 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -16,11 +16,9 @@ from torch.utils.data import IterableDataset -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError +from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key T_sample = TypeVar("T_sample") @@ -36,10 +34,10 @@ class IterMapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sampl iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]] len_map_fn: Callable[[int], int] - error_handler: Callable[[Exception, Optional[T_sample], list[SourceInfo]], None] stateless_iter_fn: bool iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex + _iter_map_failure_handler: ErrorContext _savable_fields = ("_sample_index",) @@ -49,9 +47,6 @@ def __init__( iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]], *, len_map_fn: 
Callable[[int], int] = lambda x: x, - error_handler: Callable[ - [Exception, Optional[T_sample], list[SourceInfo]], None - ] = log_exception, stateless_iter_fn: bool = False, iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, worker_config: WorkerConfig, @@ -67,7 +62,6 @@ def __init__( len_map_fn: The function to apply to the length of the dataset. Returns the new (approximate) length of the resulting stream of samples based on the original length. - error_handler: Handler for errors. Defaults to logging and ignoring the exception. stateless_iter_fn: If true, assume the iter_map_fn is deterministic and stateless (it does not aggregate samples (thus key for random access can propagate to inner dataset), yielding zero or multiple samples per fetched sample is fine). @@ -79,9 +73,12 @@ def __init__( super().__init__(dataset, worker_config=worker_config) self.iter_map_fn = iter_map_fn self.len_map_fn = len_map_fn - self.error_handler = error_handler self.stateless_iter_fn = stateless_iter_fn self.iter_map_fn_config = iter_map_fn_config + self._iter_map_failure_handler = ErrorContext( + name=f"IterMapDataset.{self.iter_map_fn}", + handler=worker_config.global_error_handler, + ) self.reset_state_own() @@ -114,7 +111,7 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: # While True will break when the inner dataset is exhausted, but may continue on exception while True: iter_idx = 0 - try: + with self._iter_map_failure_handler.handle_errors(last_sample_wrapper.last_sample): for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)): yield set_sample_restore_key( sample, @@ -125,11 +122,6 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: ) sample_restore_keys.clear() iter_idx += 1 - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) - except Exception as e: - self.error_handler(e, last_sample_wrapper.last_sample) - else: break def can_restore_sample(self) -> bool: @@ -146,42 +138,34 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s id, sample_idx, iter_idx, *sample_restore_keys = restore_key assert id == type(self).__name__ assert isinstance(iter_idx, int) - to_be_mapped = ( + to_be_mapped = tuple( self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys ) - try: - inner_iter = iter(self.iter_map_fn(to_be_mapped)) - # Skip inner yielded samples to get the correct sample - for skip_idx in range(iter_idx): - with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): - next(inner_iter) - # This is the sample to restore - with self._sample_index.ctx(sample_idx): - sample = next(inner_iter) - return set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, - ) - except StopIteration: - raise RuntimeError( - "Generator did not yield enough samples, but is marked stateless/deterministic." 
- ) - except GeneratorExit: - raise FatalSampleError.from_sample( - to_be_mapped, - f"IterMapDataset {self.iter_map_fn} generator exited while trying to restore a sample.", - ) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(to_be_mapped) - except Exception as e: - self.error_handler(e, to_be_mapped) - finally: - # Properly close if it's a generator - if hasattr(inner_iter, "close"): - inner_iter.close() + with handle_restore_errors(self.worker_config.restore_error_handler, to_be_mapped): + inner_iter = iter(self.iter_map_fn(iter(to_be_mapped))) + try: + # Skip inner yielded samples to get the correct sample + for skip_idx in range(iter_idx): + with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): + next(inner_iter) + # This is the sample to restore + with self._sample_index.ctx(sample_idx): + sample = next(inner_iter) + return set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + except StopIteration: + raise RuntimeError( + "Generator did not yield enough samples, but is marked stateless/deterministic." + ) + finally: + # Properly close if it's a generator + if hasattr(inner_iter, "close"): + inner_iter.close() def config(self) -> Dict[str, Any]: return { @@ -200,7 +184,6 @@ def config(self) -> Dict[str, Any]: else {} ), "len_map_fn": self._function_config(self.len_map_fn), - "error_handler": self._function_config(self.error_handler), } def __str__(self): diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 415b9f0d..d291fb87 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -1,56 +1,20 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Iterator, Literal, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.sample_utils import default_get_batch_keys from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset T_sample = TypeVar("T_sample") -def default_get_keys(batch: Any) -> Optional[List[str]]: - """Default get_keys, which has some heuristics to find the sample keys.""" - if isinstance(batch, list): - batch = batch[0] - if ( - hasattr(batch, "__key__") - and isinstance(batch.__key__, list) - and all(isinstance(k, str) for k in batch.__key__) - ): - return batch.__key__ - elif ( - hasattr(batch, "__keys__") - and isinstance(batch.__keys__, list) - and all(isinstance(k, str) for k in batch.__keys__) - ): - return batch.__keys__ - elif ( - isinstance(batch, dict) - and "__key__" in batch - and all(isinstance(k, str) for k in batch["__key__"]) - ): - return batch["__key__"] - elif ( - isinstance(batch, dict) - and "__keys__" in batch - and all(isinstance(k, str) for k in batch["__keys__"]) - ): - return batch["__keys__"] - elif ( - isinstance(batch, dict) - and "keys" in batch - and all(isinstance(k, str) for k in batch["keys"]) - ): - return batch["keys"] - return None - - class LogSampleDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """This dataset logs every yielded sample to the debug logs.""" - get_keys_fn: Callable[[T_sample], Optional[List[str]]] + get_keys_fn: Callable[[T_sample], list[str] | None] mode: Literal["train", "val"] _step: int @@ -61,7 +25,7 @@ def __init__( dataset: SavableDataset[T_sample], mode: Literal["train", "val"], worker_config: WorkerConfig, - get_keys_fn: Callable[[T_sample], Optional[List[str]]] = default_get_keys, + get_keys_fn: Callable[[T_sample], list[str] | None] = default_get_batch_keys, ): """Construct the log sample dataset, which logs every yielded sample to the debug logs. 
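Illustrative sketch, not part of this patch: the wrapper changes above and below drop the per-wrapper `error_handler=` arguments in favor of the new `global_error_handler` / `restore_error_handler` fields on `WorkerConfig` (see `worker.py` above). Assuming the usual top-level re-export of `WorkerConfig`, a caller now selects error handling once, roughly like this:

```python
from megatron.energon import WorkerConfig
from megatron.energon.errors import reraise_exception

# Error handling is configured once on the WorkerConfig instead of per wrapper:
# - global_error_handler: called for failures during normal iteration
#   (default log_exception: the exception is logged and the wrapper skips the sample)
# - restore_error_handler: called when restoring samples from a saved state
#   (default reraise_exception: fail immediately)
worker_config = WorkerConfig(
    rank=0,
    world_size=1,
    num_workers=0,
    global_error_handler=reraise_exception,  # fail fast instead of logging and skipping
)
```

This mirrors the updated `test_crudedataset.py` later in this patch, where `handler=reraise_exception` is removed from the `get_train_dataset` call and the same handler is passed via the `WorkerConfig` instead.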
diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 33819ace..a4b4873b 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -10,19 +10,18 @@ Generic, Iterator, Optional, - Sequence, Tuple, TypeVar, Union, ) -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError +from megatron.energon.errors import ( + ErrorContext, + handle_restore_errors, +) from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key -from megatron.energon.wrappers.skip import SkipSample T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") @@ -32,13 +31,12 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T """This dataset wrapper applies a custom function to transform each sample.""" map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]] - error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] stateless_map_fn: bool map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex _generator_sample_key: Optional[Any] _generator_offset: Optional[int] - _last_map_failures: int = 0 + _map_failure_handler: ErrorContext _savable_fields = ( "_sample_index", @@ -51,7 +49,6 @@ def __init__( dataset: SavableDataset[T_sample], map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]], *, - error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] = log_exception, stateless_map_fn: bool = False, map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, failure_tolerance: int = 100, @@ -67,7 +64,6 @@ def __init__( map_fn: The function to apply to each sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, may return a generator to yield multiple or no samples. - error_handler: Handler for errors. Defaults to logging and ignoring the exception. stateless_map_fn: If true, the map_fn is deterministic and stateless (thus key for random access can propagate to inner dataset). Defaults to False. map_fn_config: Configuration for the map_fn function. 
If callable, it should return the @@ -77,10 +73,14 @@ def __init__( """ super().__init__(dataset, worker_config=worker_config) self.map_fn = map_fn - self.error_handler = error_handler self.stateless_map_fn = stateless_map_fn self.map_fn_config = map_fn_config self.failure_tolerance = failure_tolerance + self._map_failure_handler = ErrorContext( + name=f"MapDataset.{self.map_fn}", + tolerance=failure_tolerance, + handler=worker_config.global_error_handler, + ) self.reset_state_own() @@ -122,7 +122,7 @@ def __iter__(self) -> Iterator[T_sample_out]: for sample in self.dataset: restore_key = get_sample_restore_key(sample) - try: + with self._map_failure_handler.handle_errors(sample): with self._sample_index.ctx() as sample_idx: mapped_sample = self.map_fn(sample) if isinstance(mapped_sample, Generator): @@ -137,7 +137,7 @@ def __iter__(self) -> Iterator[T_sample_out]: self._sample_index.iter_ctx(mapped_sample, sample_idx) ): self._generator_offset = idx + 1 - self._last_map_failures = 0 + self._map_failure_handler.reset() yield add_sample_restore_key( inner_sample, sample_idx, @@ -147,29 +147,12 @@ def __iter__(self) -> Iterator[T_sample_out]: self._generator_sample_key = None self._generator_offset = None else: - self._last_map_failures = 0 + self._map_failure_handler.reset() yield add_sample_restore_key( mapped_sample, sample_idx, src=self, ) - except GeneratorExit: - raise - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, sample) - self._last_map_failures += 1 - print( - f"MapDataset {self.map_fn} failed {self._last_map_failures}/{self.failure_tolerance} times in a row." - ) - if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance: - raise FatalSampleError.from_sample( - sample, - f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.", - ) def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_map_fn @@ -193,7 +176,7 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s restore_key = restore_key[2:] inner_sample = self.dataset.restore_sample(restore_key) - try: + with handle_restore_errors(self.worker_config.restore_error_handler, inner_sample): with self._sample_index.ctx(sample_idx): mapped_sample = self.map_fn(inner_sample) if isinstance(mapped_sample, Generator): @@ -203,34 +186,13 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s for idx, (sample_idx, res_sample) in enumerate( self._sample_index.iter_ctx(mapped_sample, sample_idx) ): - self._last_map_failures = 0 if idx == local_idx: return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) assert False, ( "Generator did not yield enough samples, but is marked stateless/deterministic." ) else: - self._last_map_failures = 0 return add_sample_restore_key(mapped_sample, sample_idx, src=self) - except GeneratorExit: - raise FatalSampleError.from_sample( - inner_sample, - f"MapDataset {self.map_fn} generator exited while trying to restore a sample.", - ) - except SkipSample: - raise FatalSampleError.from_sample( - inner_sample, f"MapDataset {self.map_fn} skipped while trying to restore a sample." 
- ) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(inner_sample) - except Exception as e: - self.error_handler(e, inner_sample) - self._last_map_failures += 1 - if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance: - raise FatalSampleError.from_sample( - inner_sample, - f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.", - ) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 26472fd5..ee10e6c2 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -12,23 +12,19 @@ Iterator, List, Optional, - Sequence, TypeVar, Union, ) -from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError +from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import ( SavableDataset, add_sample_restore_key, set_sample_restore_key, ) -from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key from megatron.energon.wrappers.buffer import SavableSampleBuffer -from megatron.energon.wrappers.skip import SkipSample T_sample = TypeVar("T_sample") T_encoded_sample = TypeVar("T_encoded_sample") @@ -49,7 +45,6 @@ class PackingDataset( final_packer: Callable[[List[T_encoded_sample]], T_batch_sample] final_packer_stateless: bool packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] - error_handler: Callable[[Exception, List[T_sample], Sequence[SourceInfo]], None] #: The buffer for collecting the samples that shall be packed. _reading_buffer: SavableSampleBuffer @@ -73,10 +68,10 @@ class PackingDataset( #: Sample index for the final_packer _final_packing_sample_index: SampleIndex - # Local state: Tracking last failures for each component, to raise a fatal error after a certain number of failures. - _last_pre_pack_failures: int = 0 - _last_final_pack_failures: int = 0 - _last_sample_encoder_failures: int = 0 + #: Error handlers for tracking failures + _pre_pack_failure_handler: ErrorContext + _final_pack_failure_handler: ErrorContext + _sample_encoder_failure_handler: ErrorContext | None _savable_fields = ( "_reading_buffer", @@ -98,9 +93,6 @@ def __init__( sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] = None, sample_encoder_stateless: bool = False, packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, - error_handler: Callable[ - [Exception, List[T_sample], Sequence[SourceInfo]], None - ] = log_exception, pre_packer_failure_tolerance: int = 100, final_packer_failure_tolerance: int = 100, sample_encoder_failure_tolerance: int = 100, @@ -125,8 +117,6 @@ def __init__( stored/restored. packer_config: Configuration for the (pre|final)_packer functions. If callable, it should return the configuration. Defaults to None. - error_handler: Function which handles exceptions raised by the batcher. The default - implementation logs the exception. pre_packer_failure_tolerance: Maximum number of pre-packer failures before raising an error. Set to 0 to disable. final_packer_failure_tolerance: Maximum number of final-packer failures before raising an error. Set to 0 to disable. 
sample_encoder_failure_tolerance: Maximum number of sample-encoder failures before raising an error. Set to 0 to disable. @@ -143,12 +133,30 @@ def __init__( self.sample_encoder = sample_encoder self.sample_encoder_stateless = True if sample_encoder is None else sample_encoder_stateless self.packer_config = packer_config - self.error_handler = error_handler self.pre_packer_failure_tolerance = pre_packer_failure_tolerance self.final_packer_failure_tolerance = final_packer_failure_tolerance self.sample_encoder_failure_tolerance = sample_encoder_failure_tolerance + self._pre_pack_failure_handler = ErrorContext( + name=f"PackingDataset.{self.pre_packer}", + handler=worker_config.global_error_handler, + tolerance=pre_packer_failure_tolerance, + ) + self._final_pack_failure_handler = ErrorContext( + name=f"PackingDataset.{self.final_packer}", + handler=worker_config.global_error_handler, + tolerance=final_packer_failure_tolerance, + ) + if self.sample_encoder is not None: + self._sample_encoder_failure_handler = ErrorContext( + name=f"PackingDataset.{self.sample_encoder}", + handler=worker_config.global_error_handler, + tolerance=sample_encoder_failure_tolerance, + ) + else: + self._sample_encoder_failure_handler = None + self.reset_state_own() def reset_state_own(self) -> None: @@ -218,10 +226,11 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: return pack encoded_pack = [] for sample in pack: - try: + with self._sample_encoder_failure_handler.handle_errors(sample): with self._sample_encoder_sample_index.ctx() as encode_idx: encoded_sample = self.sample_encoder(sample) assert not isinstance(encoded_sample, Generator), "Generator not supported" + self._sample_encoder_failure_handler.reset() encoded_pack.append( add_sample_restore_key( encoded_sample, @@ -229,23 +238,6 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: src=self, ) ) - self._last_sample_encoder_failures = 0 - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(pack) - except Exception as e: - self.error_handler(e, [sample]) - self._last_sample_encoder_failures += 1 - if ( - self.sample_encoder_failure_tolerance > 0 - and self._last_sample_encoder_failures - >= self.sample_encoder_failure_tolerance - ): - raise FatalSampleError.from_sample( - pack, - f"Sample encoder {self.sample_encoder} failed {self._last_sample_encoder_failures} times. Likely your code or dataset are broken.", - ) return encoded_pack def next_pre_pack(): @@ -260,26 +252,10 @@ def next_pre_pack(): self._reading_buffer.clear() pre_packing_lengths.clear() # Now pre pack the samples - try: + pre_packs = [] + with self._pre_pack_failure_handler.handle_errors(samples): with self._pre_packing_sample_index.ctx(): pre_packs = self.pre_packer(samples) - self._last_pre_pack_failures = 0 - except SkipSample: - pre_packs = [] - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(samples) - except Exception as e: - self.error_handler(e, samples) - pre_packs = [] - self._last_pre_pack_failures += 1 - if ( - self.pre_packer_failure_tolerance > 0 - and self._last_pre_pack_failures >= self.pre_packer_failure_tolerance - ): - raise FatalSampleError.from_sample( - samples, - f"Pre packer {self.pre_packer} failed {self._last_pre_pack_failures} times. 
Likely your code or dataset are broken.", - ) # Put the pre-packed samples into the pre_packing_buffer # They will be flattened here to avoid nested buffers @@ -300,7 +276,7 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: del self._pre_packing_buffer[: pre_packing_lengths[0]] del pre_packing_lengths[0] - try: + with self._final_pack_failure_handler.handle_errors(pack): pack_restore_keys = tuple(get_sample_restore_key(sample) for sample in pack) with self._final_packing_sample_index.ctx() as pack_idx: final_packed_sample = self.final_packer(pack) @@ -311,7 +287,7 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate( self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx) ): - self._last_final_pack_failures = 0 + self._final_pack_failure_handler.reset() yield set_sample_restore_key( inner_batch_sample, pack_idx, @@ -320,28 +296,13 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: src=self, ) else: - self._last_final_pack_failures = 0 + self._final_pack_failure_handler.reset() yield set_sample_restore_key( final_packed_sample, pack_idx, *pack_restore_keys, src=self, ) - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(pack) - except Exception as e: - self.error_handler(e, pack) - self._last_final_pack_failures += 1 - if ( - self.final_packer_failure_tolerance > 0 - and self._last_final_pack_failures >= self.final_packer_failure_tolerance - ): - raise FatalSampleError.from_sample( - pack, - f"Final packer {self.final_packer} failed {self._last_final_pack_failures} times. Likely your code or dataset are broken.", - ) # Main loop: pre_pack_round = 0 @@ -423,34 +384,16 @@ def restore_sample(self, restore_key: Any) -> T_sample: assert id == type(self).__name__ assert isinstance(sample_idx, int) sample = self.dataset.restore_sample(inner_idx) - try: - if self.sample_encoder is not None: + if self.sample_encoder is not None: + with handle_restore_errors(self.worker_config.restore_error_handler, sample): with self._sample_encoder_sample_index.ctx(sample_idx): sample = self.sample_encoder(sample) assert not isinstance(sample, Generator), "Generator not supported" - self._last_sample_encoder_failures = 0 sample = add_sample_restore_key(sample, sample_idx, src=self) - except SkipSample: - raise FatalSampleError.from_sample( - sample, - f"PackingDataset sample encoder {self.sample_encoder} skipped while trying to restore a sample.", - ) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, sample) - self._last_sample_encoder_failures += 1 - if ( - self.sample_encoder_failure_tolerance > 0 - and self._last_sample_encoder_failures >= self.sample_encoder_failure_tolerance - ): - raise FatalSampleError.from_sample( - sample, - f"PackingDataset sample encoder {self.sample_encoder} failed {self._last_sample_encoder_failures} times. 
Likely your code or dataset are broken.", - ) pack.append(sample) - try: + + with handle_restore_errors(self.worker_config.restore_error_handler, pack): with self._final_packing_sample_index.ctx(pack_idx): final_pack = self.final_packer(pack) if isinstance(final_pack, Generator): @@ -460,7 +403,6 @@ def restore_sample(self, restore_key: Any) -> T_sample: for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( self._final_packing_sample_index.iter_ctx(final_pack, pack_idx) ): - self._last_final_pack_failures = 0 if cur_batch_sub_idx == pack_sub_idx: return set_sample_restore_key( inner_batch_sample, @@ -471,30 +413,7 @@ def restore_sample(self, restore_key: Any) -> T_sample: ) assert False, f"Pack sub-index {pack_sub_idx} not found in pack" else: - self._last_final_pack_failures = 0 return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) - except GeneratorExit: - raise FatalSampleError.from_sample( - pack, - f"PackingDataset {self.final_packer} generator exited while trying to restore a pack.", - ) - except SkipSample: - raise FatalSampleError.from_sample( - pack, f"PackingDataset {self.final_packer} skipped while trying to restore a pack." - ) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(pack) - except Exception as e: - self.error_handler(e, pack) - self._last_final_pack_failures += 1 - if ( - self.final_packer_failure_tolerance > 0 - and self._last_final_pack_failures >= self.final_packer_failure_tolerance - ): - raise FatalSampleError.from_sample( - pack, - f"PackingDataset {self.final_packer} failed {self._last_final_pack_failures} times. Likely your code or dataset are broken.", - ) def config(self) -> Dict[str, Any]: return { @@ -512,7 +431,6 @@ def config(self) -> Dict[str, Any]: if self.packer_config else {} ), - "error_handler": self._function_config(self.error_handler), "worker_config": self.worker_config.config(), "dataset": self.dataset.config(), } diff --git a/tests/test_av_decoder.py b/tests/test_av_decoder.py index 6564d866..eaed0c4c 100644 --- a/tests/test_av_decoder.py +++ b/tests/test_av_decoder.py @@ -17,7 +17,14 @@ import torch import torchvision.transforms as transforms -from megatron.energon.av import AVDecoder, AVMetadata, get_clips_uniform, get_single_frames_uniform +from megatron.energon.av import AVDecoder, get_clips_uniform, get_single_frames_uniform +from megatron.energon.av.fastseek.fastseek import Fastseek +from megatron.energon.av.fastseek.keyframeinfo import KeyframeInfo +from megatron.energon.av.fastseek.reader import ( + FastseekReaderByFrames, + FastseekReaderByPts, +) +from megatron.energon.media import AVMetadata # Set multiprocessing start method to 'spawn' on macOS to avoid DataLoader cleanup issues if sys.platform == "darwin": @@ -72,6 +79,319 @@ def tensors_close(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float return mae <= tolerance +class TestFastseek(unittest.TestCase): + """Test fastseek functionality.""" + + def test_fastseek_mp4(self): + """Test fastseek.""" + fastseek = Fastseek(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + assert fastseek.mime == "video/mp4" + assert fastseek.frame_index_supported + assert fastseek.pts_supported + print(fastseek.keyframes) + assert list(fastseek.keyframes.keys()) == [1] + # There is one stream (id=1). Check that the first stream's keyframes are correct. 
+ assert fastseek.keyframes[1].keyframes == [ + KeyframeInfo(index=0, pts=0), + KeyframeInfo(index=250, pts=128000), + KeyframeInfo(index=500, pts=256000), + KeyframeInfo(index=750, pts=384000), + KeyframeInfo(index=1000, pts=512000), + KeyframeInfo(index=1250, pts=640000), + KeyframeInfo(index=1500, pts=768000), + KeyframeInfo(index=1750, pts=896000), + ] + assert fastseek.keyframes[1].keyframe_pts == [ + 0, + 128000, + 256000, + 384000, + 512000, + 640000, + 768000, + 896000, + ] + + # Going forward + assert fastseek.should_seek_by_frame(0, 0) is None + assert fastseek.should_seek_by_frame(0, 249) is None + assert fastseek.should_seek_by_frame(0, 250) == KeyframeInfo(index=250, pts=128000) + assert fastseek.should_seek_by_frame(0, 251) == KeyframeInfo(index=250, pts=128000) + assert fastseek.should_seek_by_frame(0, 499) == KeyframeInfo(index=250, pts=128000) + assert fastseek.should_seek_by_frame(499, 500) == KeyframeInfo(index=500, pts=256000) + assert fastseek.should_seek_by_frame(250, 499) is None + assert fastseek.should_seek_by_frame(250, 500) == KeyframeInfo(index=500, pts=256000) + + # Going backward + assert fastseek.should_seek_by_frame(1, 0) == KeyframeInfo(index=0, pts=0) + assert fastseek.should_seek_by_frame(249, 0) == KeyframeInfo(index=0, pts=0) + assert fastseek.should_seek_by_frame(250, 0) == KeyframeInfo(index=0, pts=0) + + assert fastseek.should_seek_by_pts(0, 0) is None + assert fastseek.should_seek_by_pts(0, 1) is None + assert fastseek.should_seek_by_pts(0, 127999) is None + assert fastseek.should_seek_by_pts(128000, 128000) is None + assert fastseek.should_seek_by_pts(128000, 255999) is None + assert fastseek.should_seek_by_pts(0, 128000) == 128000 + assert fastseek.should_seek_by_pts(0, 256000) == 256000 + assert fastseek.should_seek_by_pts(128000, 256000) == 256000 + + def test_fastseek_mkv(self): + """Test fastseek.""" + fastseek = Fastseek(io.BytesIO(Path("tests/data/sync_test.mkv").read_bytes())) + assert fastseek.mime == "video/x-matroska" + assert not fastseek.frame_index_supported + assert fastseek.pts_supported + print(fastseek.keyframes) + assert list(fastseek.keyframes.keys()) == [1] + assert fastseek.keyframes[1].keyframe_pts == [ + 0, + 8354, + 16688, + 25021, + 33354, + 41688, + 50021, + 58354, + ] + assert fastseek.keyframes[1].keyframe_indexes is None + assert fastseek.keyframes[1].keyframes == [ + KeyframeInfo(index=None, pts=0), + KeyframeInfo(index=None, pts=8354), + KeyframeInfo(index=None, pts=16688), + KeyframeInfo(index=None, pts=25021), + KeyframeInfo(index=None, pts=33354), + KeyframeInfo(index=None, pts=41688), + KeyframeInfo(index=None, pts=50021), + KeyframeInfo(index=None, pts=58354), + ] + + assert fastseek.should_seek_by_pts(0, 0) is None + assert fastseek.should_seek_by_pts(0, 8353) is None + assert fastseek.should_seek_by_pts(8350, 8353) is None + assert fastseek.should_seek_by_pts(8354, 16687) is None + assert fastseek.should_seek_by_pts(0, 8354) == 8354 + assert fastseek.should_seek_by_pts(0, 8355) == 8354 + assert fastseek.should_seek_by_pts(0, 16688) == 16688 + assert fastseek.should_seek_by_pts(1, 0) == 0 + assert fastseek.should_seek_by_pts(10000, 0) == 0 + assert fastseek.should_seek_by_pts(10000, 8354) == 8354 + assert fastseek.should_seek_by_pts(10000, 8355) == 8354 + + def test_fastseek_reader_mp4(self): + """Test fastseek reader.""" + with open("tests/data/sync_test.mp4", "rb") as f: + fastseek = Fastseek(f) + f.seek(0) + with av.open(f, mode="r") as container: + reader = FastseekReaderByFrames(fastseek, container) + + 
frames = list(reader.seek_read(0, 1)) + assert len(frames) == 2 + assert frames[0].time == 0 + assert frames[1].time == 1 / 30 + assert reader.skipped == 0 + + # Read more frames, repeating the previous frame + frames = list(reader.seek_read(1, 2)) + assert len(frames) == 2, len(frames) + assert frames[0].time == 1 / 30 + assert frames[1].time == 2 / 30 + assert reader.skipped == 0 + + # Read next frame + frames = list(reader.seek_read(3, 3)) + assert len(frames) == 1, len(frames) + assert frames[0].time == 3 / 30 + assert reader.skipped == 0 + + # Seek to last frame before next keyframe + frames = list(reader.seek_read(249, 249)) + assert len(frames) == 1, len(frames) + assert frames[0].time == 249 / 30 + # Skipped frames 4-248 (inclusive) + assert reader.skipped == 245, reader.skipped + + reader.skipped = 0 + + # Seek through keyframe, repeating the previous frame + frames = list(reader.seek_read(249, 250)) + assert len(frames) == 2, len(frames) + assert frames[0].time == 249 / 30 + assert frames[1].time == 250 / 30 + assert reader.skipped == 0, reader.skipped + + reader.skipped = 0 + + # Seek to next keyframe + frames = list(reader.seek_read(500, 500)) + assert len(frames) == 1, len(frames) + assert frames[0].time == 500 / 30 + assert reader.skipped == 0, reader.skipped + + # Seek backwards 1 frame, but need previous keyframe + frames = list(reader.seek_read(499, 499)) + assert len(frames) == 1, len(frames) + assert frames[0].time == 499 / 30 + # Skipped frames 250-498 (inclusive) + assert reader.skipped == 249, reader.skipped + + reader.skipped = 0 + + # Seek to previous keyframe + frames = list(reader.seek_read(498, 498)) + assert len(frames) == 1, len(frames) + assert frames[0].time == 498 / 30 + # Skipped frames 250-497 (inclusive) + assert reader.skipped == 248, reader.skipped + + f.seek(0) + with av.open(f, mode="r") as container: + reader = FastseekReaderByPts(fastseek, container) + + # 512 PTS per frame + + frames = list(reader.seek_read(0, 1023)) + assert len(frames) == 2, len(frames) + assert frames[0].pts == 0 + assert frames[0].duration == 512 + assert frames[1].pts == 512 + assert frames[1].duration == 512 + assert reader.skipped == 0 + + frames = list(reader.seek_read(512, 3 * 512 - 1)) + assert len(frames) == 2, len(frames) + assert frames[0].time == 1 / 30 + assert frames[1].time == 2 / 30 + assert frames[0].pts == 512 + assert frames[1].pts == 1024 + assert reader.skipped == 0 + + frames = list(reader.seek_read(249 * 512, 249 * 512)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 249 * 512 + # Skipped frames 2-248 (inclusive) + assert reader.skipped == 246 + + reader.skipped = 0 + + frames = list(reader.seek_read(249 * 512, 250 * 512)) + assert len(frames) == 2, len(frames) + assert frames[0].pts == 249 * 512 + assert frames[1].pts == 250 * 512 + assert reader.skipped == 0 + + def test_fastseek_reader_mkv(self): + """Test fastseek reader.""" + with open("tests/data/sync_test.mkv", "rb") as f: + fastseek = Fastseek(f) + f.seek(0) + with av.open(f, mode="r") as container: + # Note: This video has frames of 33 PTS duration, + # But the time for the next frame increases by 54, 34, 12, 54, 34, 33 + # (afterwards by [33, 33, 34], repeating) + # last_pts = 0 + # for idx, frame in enumerate(container.decode(video=0)): + # print(f"+{frame.pts - last_pts}") + # last_pts = frame.pts + # print(f"{idx}: {frame.pts}+{frame.duration} {'KF' if frame.key_frame else ''}") + # if idx > 500: + # break + reader = FastseekReaderByPts(fastseek, container) + + """ + Frame 1 
and 2 actually overlap + 0: 0+33 KF + +54 + 1: 54+33 + +34 + 2: 88+33 + +12 + 3: 100+33 + +54 + 4: 154+33 + +34 + 5: 188+33 + + 247: 8254+33 + +34 + 248: 8288+33 + +33 + 249: 8321+33 + +33 + 250: 8354+33 KF + +34 + 251: 8388+33 + +33 + 252: 8421+33 + +33 + 253: 8454+33 + """ + + # Frame 0, 1 + frames = list(reader.seek_read(0, 85)) + assert len(frames) == 2, len(frames) + assert frames[0].pts == 0 + assert frames[1].pts == 54 + assert reader.skipped == 0 + + # Frame 1, 2 + frames = list(reader.seek_read(85, 103)) + assert len(frames) == 2, len(frames) + assert frames[0].pts == 54 + assert frames[1].pts == 88 + assert reader.skipped == 0 + + # Frame 2 + frames = list(reader.seek_read(100, 100)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 88 + assert reader.skipped == 0 + + # Frame 3 (overlaps with frame 2) + frames = list(reader.seek_read(132, 132)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 100 + assert reader.skipped == 0 + + # Frame 249 + frames = list(reader.seek_read(8321, 8321)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 8321 + # Skipped frames 4-248 (inclusive) + assert reader.skipped == 245 + + reader.skipped = 0 + + # Frame 249 + frames = list(reader.seek_read(8353, 8353)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 8321 + assert reader.skipped == 0 + + # Frame 249-251 + frames = list(reader.seek_read(8353, 8388)) + assert len(frames) == 3, len(frames) + assert frames[0].pts == 8321 + assert frames[1].pts == 8354 + assert frames[2].pts == 8388 + assert reader.skipped == 0 + + # Frame 0 (go backwards and skip 1 frame) + frames = list(reader.seek_read(85, 85)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 54 + assert reader.skipped == 1 + + reader.skipped = 0 + + # Frame 251 (seek to keyframe 250 and skip 1 frame) + frames = list(reader.seek_read(8388, 8388)) + assert len(frames) == 1, len(frames) + assert frames[0].pts == 8388 + assert reader.skipped == 1 + + class TestVideoDecode(unittest.TestCase): """Test video decoding functionality.""" @@ -108,6 +428,58 @@ def test_decode_all_frames(self): "Energon decoded video does not match baseline" ) + def test_verify_video_decode(self): + """Verify the video decode matches the baseline.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + all_timestamps = [] + for frame in [*range(5), *range(245, 255), *range(1881, 1891)]: + # print(f"Loading frame {frame}") + video_data, timestamps = av_decoder.get_video_clips( + video_clip_ranges=[(frame, frame)], video_unit="frames" + ) + assert len(video_data) == 1 + assert video_data[0].shape == (1, 3, 108, 192), ( + f"Shape of frame {frame} is {video_data[0].shape}" + ) + assert (video_data[0] == self.complete_video_tensor[frame : frame + 1]).all() + # print(f"Timestamp for frame {frame}: {timestamps[0]}") + all_timestamps.append(0.5 * (timestamps[0][0] + timestamps[0][1])) + + for frame, timestamp1, timestamp2 in zip( + [*range(5), *range(245, 255), *range(1881, 1891)], + all_timestamps, + all_timestamps[1:] + [float("inf")], + ): + if frame in (4, 254): + continue + # print(f"Loading frame {frame}") + video_data, timestamps = av_decoder.get_video_clips( + video_clip_ranges=[(timestamp1, timestamp1)], video_unit="seconds" + ) + assert len(video_data) == 1 + assert video_data[0].shape == (1, 3, 108, 192), ( + f"Shape of frame {frame} is {video_data[0].shape}" + ) + assert (video_data[0] == self.complete_video_tensor[frame : frame + 1]).all() + assert 0.5 * 
(timestamps[0][0] + timestamps[0][1]) == timestamp1, ( + f"Timestamp for frame {frame} is {timestamps[0][0]} + {timestamps[0][1]}" + ) + + video_data, timestamps = av_decoder.get_video_clips( + video_clip_ranges=[(timestamp1, timestamp2)], video_unit="seconds" + ) + assert len(video_data) == 1 + if frame == 1890: + assert video_data[0].shape == (1, 3, 108, 192), ( + f"Shape of frame {frame} is {video_data[0].shape}" + ) + assert (video_data[0] == self.complete_video_tensor[frame : frame + 1]).all() + else: + assert video_data[0].shape == (2, 3, 108, 192), ( + f"Shape of frame {frame} is {video_data[0].shape}" + ) + assert (video_data[0] == self.complete_video_tensor[frame : frame + 2]).all() + def test_decode_metadata(self): """Test decoding metadata.""" expected_metadata = [ @@ -120,6 +492,7 @@ def test_decode_metadata(self): audio_duration=63.103, audio_channels=2, audio_sample_rate=48000, + audio_num_samples=3028992, ), AVMetadata( video_duration=63.03333333333333, @@ -130,13 +503,14 @@ def test_decode_metadata(self): audio_duration=63.068, audio_channels=2, audio_sample_rate=48000, + audio_num_samples=3027968, ), ] for video_file, expected_metadata in zip( ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata ): av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( + assert av_decoder.get_metadata(get_audio_num_samples=True) == expected_metadata, ( f"Metadata does not match expected metadata for {video_file}" ) @@ -178,6 +552,51 @@ def test_decode_strided_resized(self): "Energon decoded video does not match baseline" ) + def test_time_precision(self): + """Test decoding video frames with time precision.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + video_data, timestamps = av_decoder.get_video_clips( + video_clip_ranges=[ + (4 + 1 / 30, 4 + 1 / 30), + (4 + 1 / 30 + 1 / 60, 4 + 1 / 30 + 1 / 60), + ], + video_unit="seconds", + ) + assert (timestamps[0][0] == 4 + 1 / 30) and (timestamps[0][1] == 4 + 2 / 30), ( + f"Timestamp for frame 0 is {timestamps[0][0]} and {timestamps[0][1]}" + ) + assert (timestamps[1][0] == 4 + 1 / 30) and (timestamps[1][1] == 4 + 2 / 30), ( + f"Timestamp for frame 0 is {timestamps[1][0]} and {timestamps[1][1]}" + ) + # from PIL import Image + + # Image.fromarray(video_data[0][0, :, 18:55, 18:55].numpy().transpose(1, 2, 0)).save( + # "circ.png" + # ) + assert (video_data[0][0, :, 18:55, 18:55] > 250).all(), ( + "First extracted frame is not all white in the area (18, 18, 55, 55)" + ) + + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + video_data, timestamps = av_decoder.get_video_clips( + video_clip_ranges=[(4 * 30 + 1, 4 * 30 + 1), (4 * 30 + 1, 4 * 30 + 1)], + video_unit="frames", + ) + assert (timestamps[0][0] == 4 + 1 / 30) and (timestamps[0][1] == 4 + 2 / 30), ( + f"Timestamp for frame 0 is {timestamps[0][0]} and {timestamps[0][1]}" + ) + assert (timestamps[1][0] == 4 + 1 / 30) and (timestamps[1][1] == 4 + 2 / 30), ( + f"Timestamp for frame 0 is {timestamps[1][0]} and {timestamps[1][1]}" + ) + from PIL import Image + + Image.fromarray(video_data[0][0, :, 18:55, 18:55].numpy().transpose(1, 2, 0)).save( + "circ.png" + ) + assert (video_data[0][0, :, 18:55, 18:55] > 250).all(), ( + "First extracted frame is not all white in the area (18, 18, 55, 55)" + ) + def test_video_audio_sync(self): """Test decoding video frames and audio clips together.""" av_decoder = 
AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) @@ -199,7 +618,11 @@ def test_video_audio_sync(self): video_clips = av_data.video_clips[2:] audio_clips = av_data.audio_clips[2:] # Then we check that the first extracted frame is all white in the area (18, 18, 55, 55) - # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png') + # from PIL import Image + + # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1, 2, 0)).save( + # "circ.png" + # ) assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), ( "First extracted frame is not all white in the area (18, 18, 55, 55)" ) @@ -222,29 +645,29 @@ def test_video_audio_sync(self): def test_pickle_decoder(self): """Test AVDecoder on a video file can be pickled and unpickled.""" av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) - + # Get metadata from original decoder original_metadata = av_decoder.get_metadata() - + # Pickle the decoder pickled_data = pickle.dumps(av_decoder) - + # Unpickle the decoder unpickled_decoder = pickle.loads(pickled_data) - + # Verify metadata matches unpickled_metadata = unpickled_decoder.get_metadata() assert unpickled_metadata == original_metadata, ( f"Unpickled metadata {unpickled_metadata} does not match original {original_metadata}" ) - + # Verify we can still decode frames from the unpickled decoder video_tensor = get_single_frames_uniform( av_decoder=unpickled_decoder, num_frames=16, video_out_frame_size=(64, 64), ) - + # Check that we got the expected shape assert video_tensor.shape == (16, 3, 64, 64), ( f"Expected shape (16, 3, 64, 64), got {video_tensor.shape}" @@ -464,18 +887,20 @@ def test_decode_metadata(self): audio_duration=10.0, audio_channels=1, audio_sample_rate=32000, + audio_num_samples=320000, ), AVMetadata( audio_duration=12.782585034013605, audio_channels=2, audio_sample_rate=44100, + audio_num_samples=563712, ), ] for audio_file, expected_metadata in zip( ["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata ): av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( + assert av_decoder.get_metadata(get_audio_num_samples=True) == expected_metadata, ( f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}" ) diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 637ce8fb..15949c85 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -4,9 +4,11 @@ """This module defines tests for crude datasets.""" import gc +import io import logging import pickle import re +import shutil import sys import tempfile import unittest @@ -32,14 +34,42 @@ from megatron.energon.cache.base import CachePool from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath +from megatron.energon.errors import reraise_exception from megatron.energon.flavors.base_dataset import Sample -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder -from megatron.energon.flavors.webdataset.structs import reraise_exception +from megatron.energon.media.extractor import MediaFilterConfig, MediaFilterStrategy +from megatron.energon.media.filesystem_prepare import prepare_filesystem_dataset +from megatron.energon.media.metadata import AVMetadata, ImageMetadata from 
megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder.cooking import cooker +def _noise_image_bytes(size: tuple[int, int], fmt: str, seed: int) -> bytes: + import numpy as np + from PIL import Image + + rng = np.random.default_rng(seed) + data = rng.integers(0, 256, size=(size[1], size[0], 3), dtype=np.uint8) + image = Image.fromarray(data.astype(np.uint8), "RGB") + buffer = io.BytesIO() + image.save(buffer, format=fmt) + return buffer.getvalue() + + +def _metadata_description(metadata) -> str: + if isinstance(metadata, ImageMetadata): + return f"IMG-{metadata.width}x{metadata.height}-{metadata.format}" + if isinstance(metadata, AVMetadata): + if metadata.video_fps is not None: + return f"VIDEO-{metadata.video_width}x{metadata.video_height}@{metadata.video_fps}fps-{metadata.video_duration:0.1f}s" + elif metadata.audio_sample_rate is not None: + return f"AUDIO-{metadata.audio_duration:0.1f}s@{metadata.audio_sample_rate}Hz" + else: + return "AV-UNKNOWN" + return "UNKNOWN" + + @edataclass class LazyTextSample(Sample): txt: str @@ -86,6 +116,29 @@ def cook_aux(sample: dict, pkl_source: FileStore, fs_source: FileStore) -> TextS ) +@stateless +@cooker(need_primary=True) +def cook_media_metadata(sample: dict, primary: FileStore, media: FileStore) -> TextSample: + """This cooker loads the media from the primary and auxiliary datasets and + returns a text sample with the metadata descriptions of each.""" + + # print(f"Cooking media metadata for {sample}") + filename = sample["__sources__"][0].file_names[0] + + primary_media_metadata = primary.get_media_metadata(filename) + aux_media_metadata = media.get_media_metadata(filename) + + return TextSample( + **basic_sample_keys(sample), + text="|".join( + [ + _metadata_description(primary_media_metadata), + _metadata_description(aux_media_metadata), + ] + ), + ) + + class CookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, TextBatch]): """A simple task encoder for captioning.""" @@ -93,6 +146,7 @@ class CookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, T Cooker(cook_text, has_subflavors={"crude_type": "txtpkl"}), Cooker(cook_other, has_subflavors={"crude_type": "otherpkl"}), Cooker(cook_aux, has_subflavors={"crude_type": "aux_random_access"}), + Cooker(cook_media_metadata, has_subflavors={"crude_type": "media_metadata"}), ] def batch(self, samples: List[TextSample]) -> TextBatch: @@ -291,6 +345,30 @@ def setUp(self): ) ) + self.multimedia_wds_path = self.dataset_path / "multimedia_wds" + self.create_multimedia_webdataset(self.multimedia_wds_path) + + self.multimedia_fs_path = self.dataset_path / "multimedia_fs" + self.create_multimedia_filesystem_dataset(self.multimedia_fs_path) + + self.media_mds_path = self.dataset_path / "media_metadataset.yaml" + with open(self.media_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: multimedia_wds", + " aux:", + " media: filesystem://multimedia_fs", + " subflavors:", + " crude_type: media_metadata", + ] + ) + ) + print(self.dataset_path) def tearDown(self): @@ -326,6 +404,7 @@ def create_crude_text_test_dataset(path: Path, offset: int): split_parts_ratio=[("train", 1.0)], shuffle_seed=None, workers=1, + media_filter=None, ) with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: @@ -341,12 +420,66 @@ def create_crude_text_test_dataset(path: Path, offset: int): ) ) + @staticmethod + def create_multimedia_webdataset(path: Path): + 
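+ """Build a small WebDataset with two noise images, a FLAC audio clip and an MP4 video, and prepare it with extension-based media metadata extraction."""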
path.mkdir(exist_ok=True, parents=True) + (path / "parts").mkdir(exist_ok=True, parents=True) + + jpg_bytes = _noise_image_bytes((32, 16), "JPEG", seed=0) + png_bytes = _noise_image_bytes((24, 24), "PNG", seed=1) + video_bytes = Path("tests/data/sync_test.mp4").read_bytes() + audio_bytes = Path("tests/data/test_audio.flac").read_bytes() + + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + shard_writer.write({"__key__": "image000", "jpg": jpg_bytes}) + shard_writer.write({"__key__": "image001", "png": png_bytes}) + shard_writer.write({"__key__": "audio001", "flac": audio_bytes}) + shard_writer.write({"__key__": "video001", "mp4": video_bytes}) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + workers=1, + media_filter=MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: CrudeWebdataset", + "subflavors:", + " crude_type: media_metadata", + ] + ) + ) + + @staticmethod + def create_multimedia_filesystem_dataset(path: Path): + path.mkdir(exist_ok=True, parents=True) + + (path / "image000.jpg").write_bytes(_noise_image_bytes((32, 16), "JPEG", seed=0)) + (path / "image001.png").write_bytes(_noise_image_bytes((24, 24), "PNG", seed=1)) + shutil.copyfile("tests/data/sync_test.mp4", path / "video001.mp4") + shutil.copyfile("tests/data/test_audio.flac", path / "audio001.flac") + + prepare_filesystem_dataset( + EPath(path), MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), progress=False + ) + def test_metadataset(self): torch.manual_seed(42) worker_config = WorkerConfig( rank=0, world_size=1, num_workers=0, + global_error_handler=reraise_exception, ) # Train mode dataset @@ -358,7 +491,6 @@ def test_metadataset(self): task_encoder=CookingTaskEncoder(), shuffle_buffer_size=None, max_samples_per_sequence=None, - handler=reraise_exception, ) loader = get_savable_loader( train_dataset, @@ -701,6 +833,42 @@ def test_aux_filesystem_reference(self): assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") + def test_media_metadata_webdataset(self): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + loader = get_savable_loader( + get_train_dataset( + self.media_mds_path, + batch_size=1, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) + + descriptions = [] + for _, batch in zip(range(4), loader): + descriptions.extend(batch.txts) + + # from pprint import pprint + # pprint(descriptions, indent=4) + + # The descriptions are like "A|B", where A is the format + # in the WebDataset and B is the format in the auxiliary dataset. 
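+ # Both halves of each pair should match, because setUp writes the same media files to the WebDataset shards and to the auxiliary filesystem dataset.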
+ + assert descriptions == [ + "IMG-32x16-JPEG|IMG-32x16-JPEG", + "IMG-24x24-PNG|IMG-24x24-PNG", + "AUDIO-10.0s@32000Hz|AUDIO-10.0s@32000Hz", + "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s", + ] + def test_nomds(self): torch.manual_seed(42) worker_config = WorkerConfig( diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 43567c9d..01aa74ee 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -48,7 +48,7 @@ from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.edataclass import edataclass from megatron.energon.flavors import BaseWebdatasetFactory -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.analyze_debug import command as analyze_debug_command from megatron.energon.tools.info import command as info_command @@ -93,6 +93,10 @@ class CaptioningBatch(Batch): caption: torch.Tensor +class ShouldRaiseException(Exception): + pass + + class TestDataset(unittest.TestCase): # Set up the test fixture def setUp(self): @@ -431,9 +435,7 @@ def test_sample_loader_key(self): sample_type=CaptioningSample, ) captions = set(sample["caption"] for sample in self.samples) - keys = set( - f"parts/data-{idx // 30:d}.tar/{idx:06d}" for idx in range(len(self.samples)) - ) + keys = set(f"{idx:06d}" for idx in range(len(self.samples))) for sample in get_loader(ds.build()): assert sample.caption[:4] == "" captions.remove(sample.caption[4:]) @@ -452,9 +454,7 @@ def test_exclusion(self): ) keys = [entry.__key__ for entry in get_loader(ds.build())] - assert keys == [ - f"parts/data-1.tar/{i:06d}" for i in list(range(30, 35)) + list(range(40, 50)) - ], keys + assert keys == [f"{i:06d}" for i in list(range(30, 35)) + list(range(40, 50))], keys def test_loader(self): torch.manual_seed(42) @@ -516,8 +516,8 @@ def hist(data): assert len(loader2) == 5 # The order in the split is shuffled this way assert list(key for batch in loader2 for key in batch.__key__) == [ - f"parts/data-1.tar/{i:06d}" for i in range(30, 50) - ] + [f"parts/data-0.tar/{i:06d}" for i in range(30)] + f"{i:06d}" for i in range(30, 50) + ] + [f"{i:06d}" for i in range(30)] def test_default_dataset(self): torch.manual_seed(42) @@ -1567,7 +1567,7 @@ class GroupingTaskEncoder( ): @stateless def encode_sample(self, sample: CaptioningSample) -> CaptioningSample: - sample.caption = sample.__key__.split("/")[-2] + sample.caption = sample.__sources__[0].shard_name.split("/")[-1] return sample def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]: @@ -1743,6 +1743,66 @@ def test_prepare_dataset(self): assert result.exit_code == 0, "Prepare failed, see output" assert "Done" in result.stdout, "Prepare failed, see output" + def test_prepare_dataset_noninteractive(self): + runner = CliRunner() + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + "--force-overwrite", + "--split-ratio=1,0,0", + "--sample-type=CaptioningSample", + '--field-map={"image": "png", "caption": "txt"}', + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + + # Check failure with non-interactive mode + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + ], + catch_exceptions=True, + ) + assert result.exit_code == 1, "Prepare 
should have failed in non-interactive mode" + + def test_prepare_dataset_noninteractive_crude(self): + runner = CliRunner() + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + "--force-overwrite", + "--split-ratio=1,0,0", + "--sample-type=CrudeWebdataset", + "--dataset-yaml-name=dataset_crude.yaml", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + + # Check failure with non-interactive mode + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + ], + catch_exceptions=True, + ) + assert result.exit_code == 1, "Prepare should have failed in non-interactive mode" + + with open(self.dataset_path / MAIN_FOLDER_NAME / "dataset_crude.yaml", "r") as f: + content = f.read() + assert "CrudeWebdataset" in content + def test_preview_captioning_dataset(self): runner = CliRunner() result = runner.invoke( @@ -1752,7 +1812,7 @@ def test_preview_captioning_dataset(self): catch_exceptions=False, ) # First sample! - assert "__key__ (): 'parts/data-1.tar/000030'" in result.stdout + assert "__key__ (): '000030'" in result.stdout assert result.exit_code == 0, "Preview failed, see output" def test_info_captioning_dataset(self): @@ -1769,6 +1829,106 @@ def test_info_captioning_dataset(self): assert "train" in result.stdout assert result.exit_code == 0, "Preview failed, see output" + def test_custom_error_handler(self): + """Test that custom error handlers work correctly in TaskEncoder.""" + torch.manual_seed(42) + + # Track error handler calls + error_calls = [] + + class ErrorProneTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + # Intentionally raise an error for specific samples to test error handling + if "000035" in sample.__key__: + raise ValueError(f"Intentional error for {sample.__key__}") + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), + ) + + # Test with custom error handler + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + global_error_handler=lambda e, s, sources: error_calls.append( + { + "exception": e, + "sample_key": getattr(s, "__key__", None), + "exception_type": type(e).__name__, + } + ), + ) + + loader = get_loader( + get_train_dataset( + self.dataset_path, + batch_size=5, + worker_config=worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + virtual_epoch_length=50, + task_encoder=ErrorProneTaskEncoder(), + ) + ) + + # Iterate through the loader - errors should be handled by custom handler + batches = [] + for i, batch in enumerate(loader): + batches.append(batch) + if i >= 9: # Get 10 batches (50 samples total) + break + + # Verify that the error handler was called + assert len(error_calls) > 0, "Error handler should have been called" + + # Verify that the error was for the right sample + assert any("000035" in call["sample_key"] for call in error_calls), ( + f"Error should have been for sample 000035, got: {error_calls}" + ) + + # Verify the exception type + assert all(call["exception_type"] == "ValueError" for call in error_calls), ( + "All errors should be ValueError" + ) + + print("Step 2: Reraise") + + def reraise(e, s, sources): + raise ShouldRaiseException() from e + + worker_config_r1 = WorkerConfig( + rank=0, + 
world_size=1, + num_workers=1, + global_error_handler=reraise, + ) + + loader = get_loader( + get_train_dataset( + self.dataset_path, + batch_size=5, + worker_config=worker_config_r1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + virtual_epoch_length=50, + task_encoder=ErrorProneTaskEncoder(), + ) + ) + + with self.assertRaises(ShouldRaiseException): + batches = [] + for i, batch in enumerate(loader): + batches.append(batch) + if i >= 9: # Get 10 batches (50 samples total) + break + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 92919b0a..15650307 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -26,7 +26,7 @@ get_train_dataset, ) from megatron.energon.dataset_config import get_dataset_from_config -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.loader import get_savable_loader from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.checkpoint import command_redist @@ -170,21 +170,21 @@ def test_split_parts(self): all_keys = [sample.__key__ for sample in dl] assert all_keys == [ - "parts/data-4.tar/000011", # Shard 4 first - "parts/data-4.tar/000012", - "parts/data-4.tar/000013", - "parts/data-4.tar/000014", - "parts/data-4.tar/000015", - "parts/data-4.tar/000016", - "parts/data-4.tar/000017", - "parts/data-4.tar/000018", - "parts/data-4.tar/000019", - "parts/data-4.tar/000020", - "parts/data-0.tar/000000", # Shard 0 - "parts/data-0.tar/000001", - "parts/data-2.tar/000004", # Shard 2 - "parts/data-2.tar/000005", - "parts/data-2.tar/000006", + "000011", # Shard 4 first + "000012", + "000013", + "000014", + "000015", + "000016", + "000017", + "000018", + "000019", + "000020", + "000000", # Shard 0 + "000001", + "000004", # Shard 2 + "000005", + "000006", ] def test_text_dataset(self): diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index 76e45f19..bb00467c 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -194,12 +194,14 @@ def test_multi_storage_client(self): # Test move and delete p4 = EPath("msc://default/tmp/random_file_0001") - p4.unlink() + if p4.is_file(): + p4.unlink() with p4.open("w") as fp: fp.write("*****") assert p4.is_file() p5 = EPath("msc://default/tmp/random_file_0002") - p5.unlink() + if p5.is_file(): + p5.unlink() assert p5.is_file() is False p4.move(p5) assert p5.is_file() diff --git a/tests/test_file_cache_pool.py b/tests/test_file_cache_pool.py index 08beb35b..f4600aca 100644 --- a/tests/test_file_cache_pool.py +++ b/tests/test_file_cache_pool.py @@ -65,8 +65,8 @@ def test_get_method(self): ) mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, decoder=MockDecoder(), - inner_reader=mock_raw_file_store, ) pool = FileStoreCachePool(parent_cache_dir=self.temp_path) try: @@ -530,8 +530,8 @@ def test_raw_method(self): } ) mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, decoder=MockDecoder(), - inner_reader=mock_raw_file_store, ) try: # Request lazy loading @@ -563,8 +563,8 @@ def test_pickle_method(self): } ) mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, decoder=MockDecoder(), - inner_reader=mock_raw_file_store, ) try: # Request lazy loading @@ -616,13 +616,18 @@ def test_concurrent_access(self): results = [] def worker(filename): - lazy_ref = pool.get_lazy(mock_raw_file_store, filename) - result, source_info = lazy_ref.get() - 
results.append(result) - assert source_info.dataset_path == mock_raw_file_store.get_path() - assert source_info.index is None - assert source_info.shard_name is None - assert source_info.file_names == (filename,) + try: + lazy_ref = pool.get_lazy(mock_raw_file_store, filename) + sample_for_source_info = {"__sources__": ()} + result = lazy_ref.get(sample_for_source_info) + assert sample_for_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path() + assert sample_for_source_info["__sources__"][0].index is None + assert sample_for_source_info["__sources__"][0].shard_name is None + assert sample_for_source_info["__sources__"][0].file_names == (filename,) + except Exception as e: + results.append(e) + else: + results.append(result) try: # Start multiple threads accessing the same file @@ -638,6 +643,8 @@ def worker(filename): # All threads should get the correct result for r in results: + if isinstance(r, Exception): + raise r assert r == b"test data 1" finally: pool.close() diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index a6ec0b14..4126a295 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -26,7 +26,7 @@ get_val_dataset, load_dataset, ) -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.metadataset.loader_interface import DatasetBlendMode from megatron.energon.wrappers.blend_dataset import BlendDataset @@ -757,7 +757,6 @@ def new_loader(): "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", "batcher_stateless": True, "drop_last": False, - "error_handler": "megatron.energon.wrappers._log_exception.log_exception", "worker_config": wrk_cfg, "dataset": { "type": "MapDataset", diff --git a/tests/test_metadataset_fewsamp.py b/tests/test_metadataset_fewsamp.py index f1ad30c5..e8918237 100644 --- a/tests/test_metadataset_fewsamp.py +++ b/tests/test_metadataset_fewsamp.py @@ -22,7 +22,7 @@ get_savable_loader, get_train_dataset, ) -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME # Speed up tests significantly by reducing the torch status check interval for broken worker shutdown try: diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index 70228273..e5ff6444 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -29,7 +29,7 @@ ) from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.metadataset.loader import prepare_metadataset from megatron.energon.metadataset.loader_interface import DatasetBlendMode from megatron.energon.task_encoder.base import DefaultTaskEncoder diff --git a/uv.lock b/uv.lock index 0534ec6e..e5b6238f 100644 --- a/uv.lock +++ b/uv.lock @@ -810,7 +810,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.4.1" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -820,9 +820,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bd/ef/7cefdca67a6c8b3af0ec38612f9e78e5a9f6179dd91352772ae1a9849246/google_cloud_storage-3.4.1.tar.gz", hash = 
"sha256:6f041a297e23a4b485fad8c305a7a6e6831855c208bcbe74d00332a909f82268", size = 17238203, upload-time = "2025-10-08T18:43:39.665Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/a6/6e0a318f70975a3c048c0e1a18aee4f7b6d7dac1e798fdc5353c5248d418/google_cloud_storage-3.4.0.tar.gz", hash = "sha256:4c77ec00c98ccc6428e4c39404926f41e2152f48809b02af29d5116645c3c317", size = 17226847, upload-time = "2025-09-15T10:40:05.045Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/83/6e/b47d83d3a35231c6232566341b0355cce78fd4e6988a7343725408547b2c/google_cloud_storage-3.4.1-py3-none-any.whl", hash = "sha256:972764cc0392aa097be8f49a5354e22eb47c3f62370067fb1571ffff4a1c1189", size = 290142, upload-time = "2025-10-08T18:43:37.524Z" }, + { url = "https://files.pythonhosted.org/packages/16/12/164a90e4692423ed5532274928b0e19c8cae345ae1aa413d78c6b688231b/google_cloud_storage-3.4.0-py3-none-any.whl", hash = "sha256:16eeca305e4747a6871f8f7627eef3b862fdd365b872ca74d4a89e9841d0f8e8", size = 278423, upload-time = "2025-09-15T10:40:03.349Z" }, ] [[package]] @@ -884,6 +884,72 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[[package]] +name = "hf-transfer" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/eb/8fc64f40388c29ce8ce3b2b180a089d4d6b25b1d0d232d016704cb852104/hf_transfer-0.1.9.tar.gz", hash = "sha256:035572865dab29d17e783fbf1e84cf1cb24f3fcf8f1b17db1cfc7fdf139f02bf", size = 25201, upload-time = "2025-01-07T10:05:12.947Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/78/0dce00208f585fae675f40033ef9a30dedfa83665d5ac79f16beb4a0a6c2/hf_transfer-0.1.9-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:6e94e8822da79573c9b6ae4d6b2f847c59a7a06c5327d7db20751b68538dc4f6", size = 1386084, upload-time = "2025-01-07T10:04:47.874Z" }, + { url = "https://files.pythonhosted.org/packages/ea/2e/3d60b1a9e9f29a2152aa66c823bf5e399ae7be3fef310ff0de86779c5d2d/hf_transfer-0.1.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ebc4ab9023414880c8b1d3c38174d1c9989eb5022d37e814fa91a3060123eb0", size = 1343558, upload-time = "2025-01-07T10:04:42.313Z" }, + { url = "https://files.pythonhosted.org/packages/fb/38/130a5ac3747f104033591bcac1c961cb1faadfdc91704f59b09c0b465ff2/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8674026f21ed369aa2a0a4b46000aca850fc44cd2b54af33a172ce5325b4fc82", size = 3726676, upload-time = "2025-01-07T10:04:11.539Z" }, + { url = "https://files.pythonhosted.org/packages/15/a1/f4e27c5ad17aac616ae0849e2aede5aae31db8267a948c6b3eeb9fd96446/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a736dfbb2c84f5a2c975478ad200c0c8bfcb58a25a35db402678fb87ce17fa4", size = 3062920, upload-time = "2025-01-07T10:04:16.297Z" }, + { url = "https://files.pythonhosted.org/packages/8d/0d/727abdfba39bc3f1132cfa4c970588c2c0bb0d82fe2d645cc10f4e2f8e0b/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504b8427fd785dd8546d53b9fafe6e436bd7a3adf76b9dce556507650a7b4567", size = 3578681, upload-time = "2025-01-07T10:04:29.702Z" }, + { url = 
"https://files.pythonhosted.org/packages/50/d0/2b213eb1ea8b1252ccaf1a6c804d0aba03fea38aae4124df6a3acb70511a/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c7fc1b85f4d0f76e452765d7648c9f4bfd0aedb9ced2ae1ebfece2d8cfaf8e2", size = 3398837, upload-time = "2025-01-07T10:04:22.778Z" }, + { url = "https://files.pythonhosted.org/packages/8c/8a/79dbce9006e0bd6b74516f97451a7b7c64dbbb426df15d901dd438cfeee3/hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d991376f0eac70a60f0cbc95602aa708a6f7c8617f28b4945c1431d67b8e3c8", size = 3546986, upload-time = "2025-01-07T10:04:36.415Z" }, + { url = "https://files.pythonhosted.org/packages/a9/f7/9ac239b6ee6fe0bad130325d987a93ea58c4118e50479f0786f1733b37e8/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e6ac4eddcd99575ed3735ed911ddf9d1697e2bd13aa3f0ad7e3904dd4863842e", size = 4071715, upload-time = "2025-01-07T10:04:53.224Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/0ed697279f5eeb7a40f279bd783cf50e6d0b91f24120dcf66ef2cf8822b4/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:57fd9880da1ee0f47250f735f791fab788f0aa1ee36afc49f761349869c8b4d9", size = 3388081, upload-time = "2025-01-07T10:04:57.818Z" }, + { url = "https://files.pythonhosted.org/packages/dc/eb/47e477bdf1d784f31c7540db6cc8c354b777e51a186897a7abda34517f36/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:5d561f0520f493c66b016d99ceabe69c23289aa90be38dd802d2aef279f15751", size = 3658654, upload-time = "2025-01-07T10:05:03.168Z" }, + { url = "https://files.pythonhosted.org/packages/45/07/6661e43fbee09594a8a5e9bb778107d95fe38dac4c653982afe03d32bd4d/hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a5b366d34cd449fe9b20ef25941e6eef0460a2f74e7389f02e673e1f88ebd538", size = 3690551, upload-time = "2025-01-07T10:05:09.238Z" }, + { url = "https://files.pythonhosted.org/packages/81/f5/461d2e5f307e5048289b1168d5c642ae3bb2504e88dff1a38b92ed990a21/hf_transfer-0.1.9-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e66acf91df4a8b72f60223059df3003062a5ae111757187ed1a06750a30e911b", size = 1393046, upload-time = "2025-01-07T10:04:51.003Z" }, + { url = "https://files.pythonhosted.org/packages/41/ba/8d9fd9f1083525edfcb389c93738c802f3559cb749324090d7109c8bf4c2/hf_transfer-0.1.9-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8669dbcc7a3e2e8d61d42cd24da9c50d57770bd74b445c65123291ca842a7e7a", size = 1348126, upload-time = "2025-01-07T10:04:45.712Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a2/cd7885bc9959421065a6fae0fe67b6c55becdeda4e69b873e52976f9a9f0/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd0167c4407a3bc4cdd0307e65ada2294ec04f1813d8a69a5243e379b22e9d8", size = 3728604, upload-time = "2025-01-07T10:04:14.173Z" }, + { url = "https://files.pythonhosted.org/packages/f6/2e/a072cf196edfeda3310c9a5ade0a0fdd785e6154b3ce24fc738c818da2a7/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee8b10afedcb75f71091bcc197c526a6ebf5c58bbbadb34fdeee6160f55f619f", size = 3064995, upload-time = "2025-01-07T10:04:18.663Z" }, + { url = "https://files.pythonhosted.org/packages/c2/84/aec9ef4c0fab93c1ea2b1badff38c78b4b2f86f0555b26d2051dbc920cde/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5828057e313de59300dd1abb489444bc452efe3f479d3c55b31a8f680936ba42", size = 3580908, upload-time = 
"2025-01-07T10:04:32.834Z" }, + { url = "https://files.pythonhosted.org/packages/29/63/b560d39651a56603d64f1a0212d0472a44cbd965db2fa62b99d99cb981bf/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc6bd19e1cc177c66bdef15ef8636ad3bde79d5a4f608c158021153b4573509d", size = 3400839, upload-time = "2025-01-07T10:04:26.122Z" }, + { url = "https://files.pythonhosted.org/packages/d6/d8/f87ea6f42456254b48915970ed98e993110521e9263472840174d32c880d/hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdca9bfb89e6f8f281890cc61a8aff2d3cecaff7e1a4d275574d96ca70098557", size = 3552664, upload-time = "2025-01-07T10:04:40.123Z" }, + { url = "https://files.pythonhosted.org/packages/d6/56/1267c39b65fc8f4e2113b36297320f102718bf5799b544a6cbe22013aa1d/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:89a23f58b7b7effbc047b8ca286f131b17728c99a9f972723323003ffd1bb916", size = 4073732, upload-time = "2025-01-07T10:04:55.624Z" }, + { url = "https://files.pythonhosted.org/packages/82/1a/9c748befbe3decf7cb415e34f8a0c3789a0a9c55910dea73d581e48c0ce5/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:dc7fff1345980d6c0ebb92c811d24afa4b98b3e07ed070c8e38cc91fd80478c5", size = 3390096, upload-time = "2025-01-07T10:04:59.98Z" }, + { url = "https://files.pythonhosted.org/packages/72/85/4c03da147b6b4b7cb12e074d3d44eee28604a387ed0eaf7eaaead5069c57/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1a6bd16c667ebe89a069ca163060127a794fa3a3525292c900b8c8cc47985b0d", size = 3664743, upload-time = "2025-01-07T10:05:05.416Z" }, + { url = "https://files.pythonhosted.org/packages/e7/6e/e597b04f753f1b09e6893075d53a82a30c13855cbaa791402695b01e369f/hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d2fde99d502093ade3ab1b53f80da18480e9902aa960dab7f74fb1b9e5bc5746", size = 3695243, upload-time = "2025-01-07T10:05:11.411Z" }, + { url = "https://files.pythonhosted.org/packages/09/89/d4e234727a26b2546c8fb70a276cd924260d60135f2165bf8b9ed67bb9a4/hf_transfer-0.1.9-cp38-abi3-win32.whl", hash = "sha256:435cc3cdc8524ce57b074032b8fd76eed70a4224d2091232fa6a8cef8fd6803e", size = 1086605, upload-time = "2025-01-07T10:05:18.873Z" }, + { url = "https://files.pythonhosted.org/packages/a1/14/f1e15b851d1c2af5b0b1a82bf8eb10bda2da62d98180220ba6fd8879bb5b/hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad", size = 1160240, upload-time = "2025-01-07T10:05:14.324Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.1.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/31/feeddfce1748c4a233ec1aa5b7396161c07ae1aa9b7bdbc9a72c3c7dd768/hf_xet-1.1.10.tar.gz", hash = "sha256:408aef343800a2102374a883f283ff29068055c111f003ff840733d3b715bb97", size = 487910, upload-time = "2025-09-12T20:10:27.12Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/a2/343e6d05de96908366bdc0081f2d8607d61200be2ac802769c4284cc65bd/hf_xet-1.1.10-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:686083aca1a6669bc85c21c0563551cbcdaa5cf7876a91f3d074a030b577231d", size = 2761466, upload-time = "2025-09-12T20:10:22.836Z" }, + { url = "https://files.pythonhosted.org/packages/31/f9/6215f948ac8f17566ee27af6430ea72045e0418ce757260248b483f4183b/hf_xet-1.1.10-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:71081925383b66b24eedff3013f8e6bbd41215c3338be4b94ba75fd75b21513b", size = 2623807, 
upload-time = "2025-09-12T20:10:21.118Z" }, + { url = "https://files.pythonhosted.org/packages/15/07/86397573efefff941e100367bbda0b21496ffcdb34db7ab51912994c32a2/hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6bceb6361c80c1cc42b5a7b4e3efd90e64630bcf11224dcac50ef30a47e435", size = 3186960, upload-time = "2025-09-12T20:10:19.336Z" }, + { url = "https://files.pythonhosted.org/packages/01/a7/0b2e242b918cc30e1f91980f3c4b026ff2eedaf1e2ad96933bca164b2869/hf_xet-1.1.10-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:eae7c1fc8a664e54753ffc235e11427ca61f4b0477d757cc4eb9ae374b69f09c", size = 3087167, upload-time = "2025-09-12T20:10:17.255Z" }, + { url = "https://files.pythonhosted.org/packages/4a/25/3e32ab61cc7145b11eee9d745988e2f0f4fafda81b25980eebf97d8cff15/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0a0005fd08f002180f7a12d4e13b22be277725bc23ed0529f8add5c7a6309c06", size = 3248612, upload-time = "2025-09-12T20:10:24.093Z" }, + { url = "https://files.pythonhosted.org/packages/2c/3d/ab7109e607ed321afaa690f557a9ada6d6d164ec852fd6bf9979665dc3d6/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f900481cf6e362a6c549c61ff77468bd59d6dd082f3170a36acfef2eb6a6793f", size = 3353360, upload-time = "2025-09-12T20:10:25.563Z" }, + { url = "https://files.pythonhosted.org/packages/ee/0e/471f0a21db36e71a2f1752767ad77e92d8cde24e974e03d662931b1305ec/hf_xet-1.1.10-cp37-abi3-win_amd64.whl", hash = "sha256:5f54b19cc347c13235ae7ee98b330c26dd65ef1df47e5316ffb1e87713ca7045", size = 2804691, upload-time = "2025-09-12T20:10:28.433Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "0.35.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/7e/a0a97de7c73671863ca6b3f61fa12518caf35db37825e43d63a70956738c/huggingface_hub-0.35.3.tar.gz", hash = "sha256:350932eaa5cc6a4747efae85126ee220e4ef1b54e29d31c3b45c5612ddf0b32a", size = 461798, upload-time = "2025-09-29T14:29:58.625Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/a0/651f93d154cb72323358bf2bbae3e642bdb5d2f1bfc874d096f7cb159fa0/huggingface_hub-0.35.3-py3-none-any.whl", hash = "sha256:0e3a01829c19d86d03793e4577816fe3bdfc1602ac62c7fb220d593d351224ba", size = 564262, upload-time = "2025-09-29T14:29:55.813Z" }, +] + [[package]] name = "humanize" version = "4.12.3" @@ -982,11 +1048,39 @@ wheels = [ [[package]] name = "lark" -version = "1.2.2" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/37/a13baf0135f348af608c667633cbe5d13aa2c5c15a56ae9ad3e6cba45ae3/lark-1.3.0.tar.gz", hash = "sha256:9a3839d0ca5e1faf7cfa3460e420e859b66bcbde05b634e73c369c8244c5fa48", size = 259551, upload-time = "2025-09-22T13:45:05.072Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/3e/1c6b43277de64fc3c0333b0e72ab7b52ddaaea205210d60d9b9f83c3d0c7/lark-1.3.0-py3-none-any.whl", hash = "sha256:80661f261fb2584a9828a097a2432efd575af27d20be0fd35d17f0fe37253831", size = 113002, upload-time = "2025-09-22T13:45:03.747Z" }, +] + +[[package]] +name = "llvmlite" +version = "0.45.1" source = { 
registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132, upload-time = "2024-08-13T19:49:00.652Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/8d/5baf1cef7f9c084fb35a8afbde88074f0d6a727bc63ef764fe0e7543ba40/llvmlite-0.45.1.tar.gz", hash = "sha256:09430bb9d0bb58fc45a45a57c7eae912850bedc095cd0810a57de109c69e1c32", size = 185600, upload-time = "2025-10-01T17:59:52.046Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036, upload-time = "2024-08-13T19:48:58.603Z" }, + { url = "https://files.pythonhosted.org/packages/cf/6d/585c84ddd9d2a539a3c3487792b3cf3f988e28ec4fa281bf8b0e055e1166/llvmlite-0.45.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:1b1af0c910af0978aa55fa4f60bbb3e9f39b41e97c2a6d94d199897be62ba07a", size = 43043523, upload-time = "2025-10-01T18:02:58.621Z" }, + { url = "https://files.pythonhosted.org/packages/ae/34/992bd12d3ff245e0801bcf6013961daa8c19c9b9c2e61cb4b8bce94566f9/llvmlite-0.45.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02a164db2d79088bbd6e0d9633b4fe4021d6379d7e4ac7cc85ed5f44b06a30c5", size = 37253122, upload-time = "2025-10-01T18:03:55.159Z" }, + { url = "https://files.pythonhosted.org/packages/a6/7b/6d7585998a5991fa74dc925aae57913ba8c7c2efff909de9d34cc1cd3c27/llvmlite-0.45.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f2d47f34e4029e6df3395de34cc1c66440a8d72712993a6e6168db228686711b", size = 56288210, upload-time = "2025-10-01T18:00:41.978Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e2/a4abea058633bfc82eb08fd69ce242c118fdb9b0abad1fdcbe0bc6aedab5/llvmlite-0.45.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f7319e5f9f90720578a7f56fbc805bdfb4bc071b507c7611f170d631c3c0f1e0", size = 55140958, upload-time = "2025-10-01T18:01:55.694Z" }, + { url = "https://files.pythonhosted.org/packages/74/c0/233468e96ed287b953239c3b24b1d69df47c6ba9262bfdca98eda7e83a04/llvmlite-0.45.1-cp310-cp310-win_amd64.whl", hash = "sha256:4edb62e685867799e336723cb9787ec6598d51d0b1ed9af0f38e692aa757e898", size = 38132232, upload-time = "2025-10-01T18:04:41.538Z" }, + { url = "https://files.pythonhosted.org/packages/04/ad/9bdc87b2eb34642c1cfe6bcb4f5db64c21f91f26b010f263e7467e7536a3/llvmlite-0.45.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:60f92868d5d3af30b4239b50e1717cb4e4e54f6ac1c361a27903b318d0f07f42", size = 43043526, upload-time = "2025-10-01T18:03:15.051Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ea/c25c6382f452a943b4082da5e8c1665ce29a62884e2ec80608533e8e82d5/llvmlite-0.45.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98baab513e19beb210f1ef39066288784839a44cd504e24fff5d17f1b3cf0860", size = 37253118, upload-time = "2025-10-01T18:04:06.783Z" }, + { url = "https://files.pythonhosted.org/packages/fe/af/85fc237de98b181dbbe8647324331238d6c52a3554327ccdc83ced28efba/llvmlite-0.45.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3adc2355694d6a6fbcc024d59bb756677e7de506037c878022d7b877e7613a36", size = 56288209, upload-time = "2025-10-01T18:01:00.168Z" }, + { url = 
"https://files.pythonhosted.org/packages/0a/df/3daf95302ff49beff4230065e3178cd40e71294968e8d55baf4a9e560814/llvmlite-0.45.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2f3377a6db40f563058c9515dedcc8a3e562d8693a106a28f2ddccf2c8fcf6ca", size = 55140958, upload-time = "2025-10-01T18:02:11.199Z" }, + { url = "https://files.pythonhosted.org/packages/a4/56/4c0d503fe03bac820ecdeb14590cf9a248e120f483bcd5c009f2534f23f0/llvmlite-0.45.1-cp311-cp311-win_amd64.whl", hash = "sha256:f9c272682d91e0d57f2a76c6d9ebdfccc603a01828cdbe3d15273bdca0c3363a", size = 38132232, upload-time = "2025-10-01T18:04:52.181Z" }, + { url = "https://files.pythonhosted.org/packages/e2/7c/82cbd5c656e8991bcc110c69d05913be2229302a92acb96109e166ae31fb/llvmlite-0.45.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:28e763aba92fe9c72296911e040231d486447c01d4f90027c8e893d89d49b20e", size = 43043524, upload-time = "2025-10-01T18:03:30.666Z" }, + { url = "https://files.pythonhosted.org/packages/9d/bc/5314005bb2c7ee9f33102c6456c18cc81745d7055155d1218f1624463774/llvmlite-0.45.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a53f4b74ee9fd30cb3d27d904dadece67a7575198bd80e687ee76474620735f", size = 37253123, upload-time = "2025-10-01T18:04:18.177Z" }, + { url = "https://files.pythonhosted.org/packages/96/76/0f7154952f037cb320b83e1c952ec4a19d5d689cf7d27cb8a26887d7bbc1/llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b3796b1b1e1c14dcae34285d2f4ea488402fbd2c400ccf7137603ca3800864f", size = 56288211, upload-time = "2025-10-01T18:01:24.079Z" }, + { url = "https://files.pythonhosted.org/packages/00/b1/0b581942be2683ceb6862d558979e87387e14ad65a1e4db0e7dd671fa315/llvmlite-0.45.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:779e2f2ceefef0f4368548685f0b4adde34e5f4b457e90391f570a10b348d433", size = 55140958, upload-time = "2025-10-01T18:02:30.482Z" }, + { url = "https://files.pythonhosted.org/packages/33/94/9ba4ebcf4d541a325fd8098ddc073b663af75cc8b065b6059848f7d4dce7/llvmlite-0.45.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e6c9949baf25d9aa9cd7cf0f6d011b9ca660dd17f5ba2b23bdbdb77cc86b116", size = 38132231, upload-time = "2025-10-01T18:05:03.664Z" }, + { url = "https://files.pythonhosted.org/packages/1d/e2/c185bb7e88514d5025f93c6c4092f6120c6cea8fe938974ec9860fb03bbb/llvmlite-0.45.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:d9ea9e6f17569a4253515cc01dade70aba536476e3d750b2e18d81d7e670eb15", size = 43043524, upload-time = "2025-10-01T18:03:43.249Z" }, + { url = "https://files.pythonhosted.org/packages/09/b8/b5437b9ecb2064e89ccf67dccae0d02cd38911705112dd0dcbfa9cd9a9de/llvmlite-0.45.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:c9f3cadee1630ce4ac18ea38adebf2a4f57a89bd2740ce83746876797f6e0bfb", size = 37253121, upload-time = "2025-10-01T18:04:30.557Z" }, + { url = "https://files.pythonhosted.org/packages/f7/97/ad1a907c0173a90dd4df7228f24a3ec61058bc1a9ff8a0caec20a0cc622e/llvmlite-0.45.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:57c48bf2e1083eedbc9406fb83c4e6483017879714916fe8be8a72a9672c995a", size = 56288210, upload-time = "2025-10-01T18:01:40.26Z" }, + { url = "https://files.pythonhosted.org/packages/32/d8/c99c8ac7a326e9735401ead3116f7685a7ec652691aeb2615aa732b1fc4a/llvmlite-0.45.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3aa3dfceda4219ae39cf18806c60eeb518c1680ff834b8b311bd784160b9ce40", size = 55140957, upload-time = "2025-10-01T18:02:46.244Z" }, 
+ { url = "https://files.pythonhosted.org/packages/09/56/ed35668130e32dbfad2eb37356793b0a95f23494ab5be7d9bf5cb75850ee/llvmlite-0.45.1-cp313-cp313-win_amd64.whl", hash = "sha256:080e6f8d0778a8239cd47686d402cb66eb165e421efa9391366a9b7e5810a38b", size = 38132232, upload-time = "2025-10-01T18:05:14.477Z" }, ] [[package]] @@ -1128,12 +1222,18 @@ google-cloud-storage = [ guess-content = [ { name = "filetype" }, ] +huggingface = [ + { name = "multi-storage-client", extra = ["huggingface"] }, +] oci = [ { name = "multi-storage-client", extra = ["oci"] }, ] s3 = [ { name = "multi-storage-client", extra = ["boto3"] }, ] +tar-patcher = [ + { name = "numba" }, +] transforms = [ { name = "torchvision" }, ] @@ -1144,18 +1244,19 @@ requires-dist = [ { name = "bitstring", marker = "extra == 'av-decode'", specifier = ">=4.2.3" }, { name = "braceexpand" }, { name = "click" }, - { name = "dataslots", marker = "python_full_version < '3.10'" }, { name = "ebmlite", marker = "extra == 'av-decode'", specifier = ">=3.3.1" }, { name = "filetype", marker = "extra == 'av-decode'", specifier = ">=1.2.0" }, { name = "filetype", marker = "extra == 'guess-content'", specifier = ">=1.0.0" }, { name = "mfusepy" }, - { name = "multi-storage-client", specifier = ">=0.18.0,<0.26.0" }, + { name = "multi-storage-client", specifier = ">=0.33.0" }, { name = "multi-storage-client", extras = ["aistore"], marker = "extra == 'aistore'" }, { name = "multi-storage-client", extras = ["azure-storage-blob"], marker = "extra == 'azure-storage-blob'" }, { name = "multi-storage-client", extras = ["boto3"], marker = "extra == 's3'" }, { name = "multi-storage-client", extras = ["google-cloud-storage"], marker = "extra == 'google-cloud-storage'" }, + { name = "multi-storage-client", extras = ["huggingface"], marker = "extra == 'huggingface'" }, { name = "multi-storage-client", extras = ["oci"], marker = "extra == 'oci'" }, { name = "myst-parser", marker = "extra == 'dev'" }, + { name = "numba", marker = "extra == 'tar-patcher'" }, { name = "numpy" }, { name = "pillow", specifier = ">=10.0.1" }, { name = "pyyaml" }, @@ -1173,7 +1274,7 @@ requires-dist = [ { name = "tqdm" }, { name = "webdataset" }, ] -provides-extras = ["aistore", "av-decode", "azure-storage-blob", "dev", "google-cloud-storage", "guess-content", "oci", "s3", "transforms"] +provides-extras = ["aistore", "av-decode", "azure-storage-blob", "dev", "google-cloud-storage", "guess-content", "huggingface", "oci", "s3", "tar-patcher", "transforms"] [[package]] name = "mfusepy" @@ -1231,7 +1332,7 @@ wheels = [ [[package]] name = "multi-storage-client" -version = "0.25.0" +version = "0.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1239,18 +1340,31 @@ dependencies = [ { name = "jsonschema" }, { name = "lark" }, { name = "opentelemetry-api" }, + { name = "prettytable" }, + { name = "psutil" }, { name = "python-dateutil" }, { name = "pyyaml" }, - { name = "tabulate" }, { name = "tqdm" }, { name = "wcmatch" }, { name = "xattr" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/91/d748ad8faebc79018b5334fcc019426eaa0ce1ab250fc013d87d3f14ce79/multi_storage_client-0.25.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:aefdfeac5558caf00d585688309f344ba637f711709648b8e8e9b99a5d643226", size = 4670834, upload-time = "2025-07-10T23:35:49.784Z" }, - { url = "https://files.pythonhosted.org/packages/1b/47/858af2b7a879bb01d0f0173c888dab6adeba5a9e42182e6505cd86a69f22/multi_storage_client-0.25.0-cp39-abi3-macosx_11_0_x86_64.whl", hash 
= "sha256:6fb38fa6c6c8ea98c291117ee5c59056d40fe88b5a3fc70c64bd2c1feafe0851", size = 4772658, upload-time = "2025-07-10T23:36:37.412Z" }, - { url = "https://files.pythonhosted.org/packages/8b/eb/f129e2eabcabd5e3b185692c1b0f93c2bbe3fb18107b4bd129a80baea8db/multi_storage_client-0.25.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e77b75dd13a57a0dc2fa1d8ccc4c712522216fa6fd0879eb9c107dfd2f7e56fa", size = 2665745, upload-time = "2025-07-10T23:36:14.168Z" }, - { url = "https://files.pythonhosted.org/packages/a6/8b/c24bd9aea4492bbab7b8f361bf0b4fc8666e903d77ace8224f6705bc3cf1/multi_storage_client-0.25.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:234f4f14033a87af6b29d57a41b193ba371f3dd759730b5b54c37752b30ee094", size = 2812708, upload-time = "2025-07-10T23:35:26.208Z" }, + { url = "https://files.pythonhosted.org/packages/5c/c4/6279fb7d4b8b0a7af060047d592f00f8d49c547adfebe50bcd8d0d2dc8a5/multi_storage_client-0.33.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df52b3040ef5698c6388fa589bd63812ae0d2f967d358a792abcad5638686590", size = 5282006, upload-time = "2025-10-23T03:45:37.761Z" }, + { url = "https://files.pythonhosted.org/packages/22/3b/23d8beccd73b887c4552bf884275611255b5028388fa3317365cd56c2a93/multi_storage_client-0.33.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:370da04b1e56a601ba505a29d42fcabc19b583e10d725a37bc0c11ba3573d211", size = 5403083, upload-time = "2025-10-23T03:53:11.998Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ad/dc355d05fd369da0d800e5f7de24da0393f542c5a6f775f6bcee7edcacb1/multi_storage_client-0.33.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c57749a28ec5d49440f465fd73e4e2feaab18ece9b6e57c73395308b41950f66", size = 3178432, upload-time = "2025-10-23T04:07:00.543Z" }, + { url = "https://files.pythonhosted.org/packages/e0/ad/97b54419d8a58f696b85504568391a627641152f80650d7d2697fc2702ed/multi_storage_client-0.33.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7d95f5fe094aab00a240bf6aa11dfe85bec293b76b3688ec3a9c33d86c751d2", size = 3351102, upload-time = "2025-10-23T03:47:47.622Z" }, + { url = "https://files.pythonhosted.org/packages/52/28/1038a68b9df1b179a61967ce9f7d2e80b9954cdb289801afecde5f7660db/multi_storage_client-0.33.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b5a0f5a0b7684835be20ae6782070884982a86665e9bab317375a56a20294d1", size = 5281523, upload-time = "2025-10-23T04:06:36.671Z" }, + { url = "https://files.pythonhosted.org/packages/6c/c5/e18de5e2a2671efdc0a12383b8d63f523044ca453525725b3450d0179c0e/multi_storage_client-0.33.0-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:0db694311f90f44ee8f6f7734a14a0857738a467f2ae201649218a3ecf1f6ab2", size = 5403353, upload-time = "2025-10-23T04:07:25.941Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c9/d9f65eb2370151dbbb06925f4216ee017e6cdbf7657263fd98e60944e52b/multi_storage_client-0.33.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cbe3a0b856f0b968f9fc693670a521b5a995b625351241ca008f866fdfff62a", size = 3180052, upload-time = "2025-10-23T03:57:32.797Z" }, + { url = "https://files.pythonhosted.org/packages/e7/38/08b9d84c93b19ae87caf542ae77f17dfa44a85281ba09de660ffcf3a7718/multi_storage_client-0.33.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018e7e82255feeff973ff02563f11a30f5e507e4cbc87a2167a9568740144ef2", size = 3351389, upload-time = "2025-10-23T04:02:07.348Z" }, + { url = 
"https://files.pythonhosted.org/packages/6a/31/c95634a27723b5ba9d2d74158444cc5e40b151b51ae59ca196fc9993f039/multi_storage_client-0.33.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:030b3a592c6352605e9ebdb8d9303dd42daf5d171ffa684f3283d4a5c6e2edfe", size = 5273976, upload-time = "2025-10-23T04:04:35.99Z" }, + { url = "https://files.pythonhosted.org/packages/8c/cf/82d1778d73c3baaec331da4ae8d01fa7934bcd73336aa88a08d86d080347/multi_storage_client-0.33.0-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:14dc0ace16d3830917427d6376d14ef62bd053fb2509f893998555ca1e9c4dcb", size = 5400735, upload-time = "2025-10-23T03:58:37.149Z" }, + { url = "https://files.pythonhosted.org/packages/fc/34/a6194ec725ef80c02de58b5ed3520bb1711807df75a27f7214effd22df34/multi_storage_client-0.33.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2821765d5c6de365b5b1dcdc7cf2ebba719ff4061fd02975639629f8aa319f6", size = 3182623, upload-time = "2025-10-23T04:03:29.551Z" }, + { url = "https://files.pythonhosted.org/packages/8f/36/7ec85178fd1dd69c278407a82acaccfb806449deda13f3dbd41f653d73bd/multi_storage_client-0.33.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f92f89480c58067fa53c178785b86e7650e16f277a61a732a8a7019173b16129", size = 3352104, upload-time = "2025-10-23T04:08:51.005Z" }, + { url = "https://files.pythonhosted.org/packages/88/ef/f2eb2efefb0e0588b29ed573b8354ecd72c38e6143da7ed5ecf53e859bf8/multi_storage_client-0.33.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed9af7e77e3cbac1f614816062b36975dcbc610bd3f8c86741d48aa18c718781", size = 5272154, upload-time = "2025-10-23T04:07:49.572Z" }, + { url = "https://files.pythonhosted.org/packages/1e/49/050aa4fccb2579d2ef5bd0d27169ec98fe85c92bba7a2c31154c491a4f75/multi_storage_client-0.33.0-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:c9d75e95a266ee858cf20c88ed255021552de67a40af9c8884d2fc22037dcd2b", size = 5399474, upload-time = "2025-10-23T04:09:14.545Z" }, + { url = "https://files.pythonhosted.org/packages/f6/4b/70c2df3b60c28360f185188d351e9c3958b702614963a09ffb1dc251c1ca/multi_storage_client-0.33.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48195a2ab9e6e9a2763bde17184cad2bdef82684353e210d0d325f20cea18869", size = 3181788, upload-time = "2025-10-23T04:03:10.404Z" }, + { url = "https://files.pythonhosted.org/packages/9b/96/5008852677fdad10eb9d8dd08a6ea58c6f7e820199a3b2c56607186ac6d5/multi_storage_client-0.33.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd64403efdcee2a6efcf7bfdb01422dd174c146014563b09f44590346fd835e6", size = 3351269, upload-time = "2025-10-23T04:00:34.714Z" }, ] [package.optional-dependencies] @@ -1266,6 +1380,10 @@ boto3 = [ google-cloud-storage = [ { name = "google-cloud-storage" }, ] +huggingface = [ + { name = "hf-transfer" }, + { name = "huggingface-hub" }, +] oci = [ { name = "oci" }, ] @@ -1394,6 +1512,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, ] +[[package]] +name = "numba" +version = "0.62.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a3/20/33dbdbfe60e5fd8e3dbfde299d106279a33d9f8308346022316781368591/numba-0.62.1.tar.gz", hash = "sha256:7b774242aa890e34c21200a1fc62e5b5757d5286267e71103257f4e2af0d5161", size = 2749817, upload-time = "2025-09-29T10:46:31.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/27/a5a9a58f267ec3b72f609789b2a8eefd6156bd7117e41cc9b7cf5de30490/numba-0.62.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a323df9d36a0da1ca9c592a6baaddd0176d9f417ef49a65bb81951dce69d941a", size = 2684281, upload-time = "2025-09-29T10:43:31.863Z" }, + { url = "https://files.pythonhosted.org/packages/3a/9d/ffc091c0bfd7b80f66df3887a7061b6af80c8c2649902444026ee1454391/numba-0.62.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1e1f4781d3f9f7c23f16eb04e76ca10b5a3516e959634bd226fc48d5d8e7a0a", size = 2687311, upload-time = "2025-09-29T10:43:54.441Z" }, + { url = "https://files.pythonhosted.org/packages/a1/13/9a27bcd0baeea236116070c7df458414336f25e9dd5a872b066cf36b74bf/numba-0.62.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:14432af305ea68627a084cd702124fd5d0c1f5b8a413b05f4e14757202d1cf6c", size = 3734548, upload-time = "2025-09-29T10:42:38.232Z" }, + { url = "https://files.pythonhosted.org/packages/a7/00/17a1ac4a60253c784ce59549375e047da98330b82de7df6ac7f4ecc90902/numba-0.62.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f180922adf159ae36c2fe79fb94ffaa74cf5cb3688cb72dba0a904b91e978507", size = 3441277, upload-time = "2025-09-29T10:43:06.124Z" }, + { url = "https://files.pythonhosted.org/packages/86/94/20ae0ff78612c4697eaf942a639db01dd4e2d90f634ac41fa3e015c961fc/numba-0.62.1-cp310-cp310-win_amd64.whl", hash = "sha256:f41834909d411b4b8d1c68f745144136f21416547009c1e860cc2098754b4ca7", size = 2745647, upload-time = "2025-09-29T10:44:15.282Z" }, + { url = "https://files.pythonhosted.org/packages/dd/5f/8b3491dd849474f55e33c16ef55678ace1455c490555337899c35826836c/numba-0.62.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:f43e24b057714e480fe44bc6031de499e7cf8150c63eb461192caa6cc8530bc8", size = 2684279, upload-time = "2025-09-29T10:43:37.213Z" }, + { url = "https://files.pythonhosted.org/packages/bf/18/71969149bfeb65a629e652b752b80167fe8a6a6f6e084f1f2060801f7f31/numba-0.62.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:57cbddc53b9ee02830b828a8428757f5c218831ccc96490a314ef569d8342b7b", size = 2687330, upload-time = "2025-09-29T10:43:59.601Z" }, + { url = "https://files.pythonhosted.org/packages/0e/7d/403be3fecae33088027bc8a95dc80a2fda1e3beff3e0e5fc4374ada3afbe/numba-0.62.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:604059730c637c7885386521bb1b0ddcbc91fd56131a6dcc54163d6f1804c872", size = 3739727, upload-time = "2025-09-29T10:42:45.922Z" }, + { url = "https://files.pythonhosted.org/packages/e0/c3/3d910d08b659a6d4c62ab3cd8cd93c4d8b7709f55afa0d79a87413027ff6/numba-0.62.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d6c540880170bee817011757dc9049dba5a29db0c09b4d2349295991fe3ee55f", size = 3445490, upload-time = "2025-09-29T10:43:12.692Z" }, + { url = "https://files.pythonhosted.org/packages/5b/82/9d425c2f20d9f0a37f7cb955945a553a00fa06a2b025856c3550227c5543/numba-0.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:03de6d691d6b6e2b76660ba0f38f37b81ece8b2cc524a62f2a0cfae2bfb6f9da", size = 2745550, upload-time = "2025-09-29T10:44:20.571Z" }, + { url = 
"https://files.pythonhosted.org/packages/5e/fa/30fa6873e9f821c0ae755915a3ca444e6ff8d6a7b6860b669a3d33377ac7/numba-0.62.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:1b743b32f8fa5fff22e19c2e906db2f0a340782caf024477b97801b918cf0494", size = 2685346, upload-time = "2025-09-29T10:43:43.677Z" }, + { url = "https://files.pythonhosted.org/packages/a9/d5/504ce8dc46e0dba2790c77e6b878ee65b60fe3e7d6d0006483ef6fde5a97/numba-0.62.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fa21b0142bcf08ad8e32a97d25d0b84b1e921bc9423f8dda07d3652860eef6", size = 2688139, upload-time = "2025-09-29T10:44:04.894Z" }, + { url = "https://files.pythonhosted.org/packages/50/5f/6a802741176c93f2ebe97ad90751894c7b0c922b52ba99a4395e79492205/numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ef84d0ac19f1bf80431347b6f4ce3c39b7ec13f48f233a48c01e2ec06ecbc59", size = 3796453, upload-time = "2025-09-29T10:42:52.771Z" }, + { url = "https://files.pythonhosted.org/packages/7e/df/efd21527d25150c4544eccc9d0b7260a5dec4b7e98b5a581990e05a133c0/numba-0.62.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9315cc5e441300e0ca07c828a627d92a6802bcbf27c5487f31ae73783c58da53", size = 3496451, upload-time = "2025-09-29T10:43:19.279Z" }, + { url = "https://files.pythonhosted.org/packages/80/44/79bfdab12a02796bf4f1841630355c82b5a69933b1d50eb15c7fa37dabe8/numba-0.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:44e3aa6228039992f058f5ebfcfd372c83798e9464297bdad8cc79febcf7891e", size = 2745552, upload-time = "2025-09-29T10:44:26.399Z" }, + { url = "https://files.pythonhosted.org/packages/22/76/501ea2c07c089ef1386868f33dff2978f43f51b854e34397b20fc55e0a58/numba-0.62.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:b72489ba8411cc9fdcaa2458d8f7677751e94f0109eeb53e5becfdc818c64afb", size = 2685766, upload-time = "2025-09-29T10:43:49.161Z" }, + { url = "https://files.pythonhosted.org/packages/80/68/444986ed95350c0611d5c7b46828411c222ce41a0c76707c36425d27ce29/numba-0.62.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:44a1412095534a26fb5da2717bc755b57da5f3053965128fe3dc286652cc6a92", size = 2688741, upload-time = "2025-09-29T10:44:10.07Z" }, + { url = "https://files.pythonhosted.org/packages/78/7e/bf2e3634993d57f95305c7cee4c9c6cb3c9c78404ee7b49569a0dfecfe33/numba-0.62.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8c9460b9e936c5bd2f0570e20a0a5909ee6e8b694fd958b210e3bde3a6dba2d7", size = 3804576, upload-time = "2025-09-29T10:42:59.53Z" }, + { url = "https://files.pythonhosted.org/packages/e8/b6/8a1723fff71f63bbb1354bdc60a1513a068acc0f5322f58da6f022d20247/numba-0.62.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:728f91a874192df22d74e3fd42c12900b7ce7190b1aad3574c6c61b08313e4c5", size = 3503367, upload-time = "2025-09-29T10:43:26.326Z" }, + { url = "https://files.pythonhosted.org/packages/9c/ec/9d414e7a80d6d1dc4af0e07c6bfe293ce0b04ea4d0ed6c45dad9bd6e72eb/numba-0.62.1-cp313-cp313-win_amd64.whl", hash = "sha256:bbf3f88b461514287df66bc8d0307e949b09f2b6f67da92265094e8fa1282dd8", size = 2745529, upload-time = "2025-09-29T10:44:31.738Z" }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -1591,7 +1741,7 @@ wheels = [ [[package]] name = "oci" -version = "2.150.3" +version = "2.161.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1601,9 +1751,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "pytz" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/50/f6/b6a45db992966e15a166e3ebe7db7ef0231fb99a1a0a1e438f312ea957bd/oci-2.150.3.tar.gz", hash = "sha256:f5208670d9dd92e4af2e2c46f9747cc6ac897416d245f8ff9a87b305feb01f69", size = 14689750, upload-time = "2025-04-29T08:32:46.095Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/a2/0295ef211f8687b85505fb79ab3833ba8d56bb7aaaf2c0568ab289d2edec/oci-2.161.0.tar.gz", hash = "sha256:1322069822babf472feba130da131bce114e9070f95f7c5bf96d034520470c7e", size = 15836650, upload-time = "2025-10-07T06:01:02.165Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e2/a8/a3503ef76a1a7bc2062aa74dee22c9f8fcd9f00f2d7596bc8578de409859/oci-2.150.3-py3-none-any.whl", hash = "sha256:d8873513cefb76a1d6677f3f92d9525ba3dd4e695431ff7d05e3ae531c8cd50c", size = 29899180, upload-time = "2025-04-29T08:32:35.52Z" }, + { url = "https://files.pythonhosted.org/packages/b5/7d/1a19fb91620d8dc82529860ec5f40730277749c4967c67cd3c91cb23e247/oci-2.161.0-py3-none-any.whl", hash = "sha256:e189272f165d2ae32d2839ce300f50ad8376a861500cf93e8295a10b51172b94", size = 32331958, upload-time = "2025-10-07T06:00:54.045Z" }, ] [[package]] @@ -1726,6 +1876,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/2f/a4583c70fbd8cd04910e2884bcc2bdd670e884061f7b4d70bc13e632a993/pockets-0.9.1-py2.py3-none-any.whl", hash = "sha256:68597934193c08a08eb2bf6a1d85593f627c22f9b065cc727a4f03f669d96d86", size = 26263, upload-time = "2019-11-02T14:46:17.814Z" }, ] +[[package]] +name = "prettytable" +version = "3.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/b1/85e18ac92afd08c533603e3393977b6bc1443043115a47bb094f3b98f94f/prettytable-3.16.0.tar.gz", hash = "sha256:3c64b31719d961bf69c9a7e03d0c1e477320906a98da63952bc6698d6164ff57", size = 66276, upload-time = "2025-03-24T19:39:04.008Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/c7/5613524e606ea1688b3bdbf48aa64bafb6d0a4ac3750274c43b6158a390f/prettytable-3.16.0-py3-none-any.whl", hash = "sha256:b5eccfabb82222f5aa46b798ff02a8452cf530a352c31bddfa29be41242863aa", size = 33863, upload-time = "2025-03-24T19:39:02.359Z" }, +] + [[package]] name = "propcache" version = "0.3.1" @@ -1841,6 +2003,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/01/1ed1d482960a5718fd99c82f6d79120181947cfd4667ec3944d448ed44a3/protobuf-6.31.0-py3-none-any.whl", hash = "sha256:6ac2e82556e822c17a8d23aa1190bbc1d06efb9c261981da95c71c9da09e9e23", size = 168558, upload-time = "2025-05-14T17:58:26.923Z" }, ] +[[package]] +name = "psutil" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/31/4723d756b59344b643542936e37a31d1d3204bcdc42a7daa8ee9eb06fb50/psutil-7.1.0.tar.gz", hash = "sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2", size = 497660, upload-time = "2025-09-17T20:14:52.902Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/62/ce4051019ee20ce0ed74432dd73a5bb087a6704284a470bb8adff69a0932/psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13", size = 245242, upload-time = "2025-09-17T20:14:56.126Z" }, + { url = "https://files.pythonhosted.org/packages/38/61/f76959fba841bf5b61123fbf4b650886dc4094c6858008b5bf73d9057216/psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = 
"sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5", size = 246682, upload-time = "2025-09-17T20:14:58.25Z" }, + { url = "https://files.pythonhosted.org/packages/88/7a/37c99d2e77ec30d63398ffa6a660450b8a62517cabe44b3e9bae97696e8d/psutil-7.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e4454970b32472ce7deaa45d045b34d3648ce478e26a04c7e858a0a6e75ff3", size = 287994, upload-time = "2025-09-17T20:14:59.901Z" }, + { url = "https://files.pythonhosted.org/packages/9d/de/04c8c61232f7244aa0a4b9a9fbd63a89d5aeaf94b2fc9d1d16e2faa5cbb0/psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70e113920d51e89f212dd7be06219a9b88014e63a4cec69b684c327bc474e3", size = 291163, upload-time = "2025-09-17T20:15:01.481Z" }, + { url = "https://files.pythonhosted.org/packages/f4/58/c4f976234bf6d4737bc8c02a81192f045c307b72cf39c9e5c5a2d78927f6/psutil-7.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d4a113425c037300de3ac8b331637293da9be9713855c4fc9d2d97436d7259d", size = 293625, upload-time = "2025-09-17T20:15:04.492Z" }, + { url = "https://files.pythonhosted.org/packages/79/87/157c8e7959ec39ced1b11cc93c730c4fb7f9d408569a6c59dbd92ceb35db/psutil-7.1.0-cp37-abi3-win32.whl", hash = "sha256:09ad740870c8d219ed8daae0ad3b726d3bf9a028a198e7f3080f6a1888b99bca", size = 244812, upload-time = "2025-09-17T20:15:07.462Z" }, + { url = "https://files.pythonhosted.org/packages/bf/e9/b44c4f697276a7a95b8e94d0e320a7bf7f3318521b23de69035540b39838/psutil-7.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:57f5e987c36d3146c0dd2528cd42151cf96cd359b9d67cfff836995cc5df9a3d", size = 247965, upload-time = "2025-09-17T20:15:09.673Z" }, + { url = "https://files.pythonhosted.org/packages/26/65/1070a6e3c036f39142c2820c4b52e9243246fcfc3f96239ac84472ba361e/psutil-7.1.0-cp37-abi3-win_arm64.whl", hash = "sha256:6937cb68133e7c97b6cc9649a570c9a18ba0efebed46d8c5dae4c07fa1b67a07", size = 244971, upload-time = "2025-09-17T20:15:12.262Z" }, +] + [[package]] name = "pyasn1" version = "0.6.1" @@ -2543,15 +2721,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] -[[package]] -name = "tabulate" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, -] - [[package]] name = "tenacity" version = "9.1.2" @@ -2752,6 +2921,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d8/0d1d2e9d3fabcf5d6840362adcf05f8cf3cd06a73358140c3a97189238ae/wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a", size = 39854, upload-time = "2025-06-22T19:14:00.978Z" }, ] 
+[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +] + [[package]] name = "webdataset" version = "0.2.111" @@ -2832,54 +3010,54 @@ wheels = [ [[package]] name = "xattr" -version = "1.1.4" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/bf/8b98081f9f8fd56d67b9478ff1e0f8c337cde08bcb92f0d592f0a7958983/xattr-1.1.4.tar.gz", hash = "sha256:b7b02ecb2270da5b7e7deaeea8f8b528c17368401c2b9d5f63e91f545b45d372", size = 16729, upload-time = "2025-01-06T19:19:32.557Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/9d/99cf83aa9e02604e88ad5e843b0f7a003740e24a60de71e7089acf54bee6/xattr-1.1.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:acb85b6249e9f3ea10cbb56df1021d43f4027212f0d004304bc9075dc7f54769", size = 23923, upload-time = "2025-01-06T19:17:26.152Z" }, - { url = "https://files.pythonhosted.org/packages/2e/89/bf59d0b7b718823ae5535cdb367195c50681625e275896eb8eed7cfd4100/xattr-1.1.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1a848ab125c0fafdc501ccd83b4c9018bba576a037a4ca5960a22f39e295552e", size = 18886, upload-time = "2025-01-06T19:17:28.77Z" }, - { url = "https://files.pythonhosted.org/packages/33/e3/b5aeaa2ff5f4ee08024eb6b271f37f59a088849b1338e29836afb318df12/xattr-1.1.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:467ee77471d26ae5187ee7081b82175b5ca56ead4b71467ec2e6119d1b08beed", size = 19220, upload-time = "2025-01-06T19:17:34.285Z" }, - { url = "https://files.pythonhosted.org/packages/bc/dc/719ae036ebe4e4a121c0489e1865dbf2f9547dd75e1af9c299ff79859d98/xattr-1.1.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fd35f46cb0154f7033f9d5d0960f226857acb0d1e0d71fd7af18ed84663007c", size = 39105, upload-time = "2025-01-06T19:17:35.575Z" }, - { url = "https://files.pythonhosted.org/packages/8a/74/a9764be50b298e7b36d95de4f7fe444ada6845061d37818d6682117b1c3e/xattr-1.1.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d956478e9bb98a1efd20ebc6e5703497c1d2d690d5a13c4df4abf59881eed50", size = 36984, upload-time = "2025-01-06T19:17:36.669Z" }, - { url = "https://files.pythonhosted.org/packages/aa/76/dc89306ec7a111926d463afa2be547edb737d3c5900a19eff98aa79f7d63/xattr-1.1.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f25dfdcd974b700fb04a40e14a664a80227ee58e02ea062ac241f0d7dc54b4e", size = 38895, upload-time = "2025-01-06T19:17:39.099Z" }, - { url = "https://files.pythonhosted.org/packages/f6/58/95a151ccab2c176848adf386ab0da1395312dffd5824e5797bfdb836637b/xattr-1.1.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33b63365c1fcbc80a79f601575bac0d6921732e0245b776876f3db3fcfefe22d", size = 38350, upload-time = "2025-01-06T19:17:40.851Z" }, - { url = 
"https://files.pythonhosted.org/packages/14/c0/7e04c29ea105a7f9d94576420f27939ad2bc1ae79813c8caac26c73ecc36/xattr-1.1.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:544542be95c9b49e211f0a463758f200de88ba6d5a94d3c4f42855a484341acd", size = 36740, upload-time = "2025-01-06T19:17:43.263Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e0/0342d023a22d41bb9a45bc8747c03736541cdb54c97c7686708bf207b204/xattr-1.1.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac14c9893f3ea046784b7702be30889b200d31adcd2e6781a8a190b6423f9f2d", size = 37946, upload-time = "2025-01-06T19:17:46.253Z" }, - { url = "https://files.pythonhosted.org/packages/78/5b/f64ba0f93e6447e1997068959f22ff99e08d77dd88d9edcf97ddcb9e9016/xattr-1.1.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bb4bbe37ba95542081890dd34fa5347bef4651e276647adaa802d5d0d7d86452", size = 23920, upload-time = "2025-01-06T19:17:48.234Z" }, - { url = "https://files.pythonhosted.org/packages/c8/54/ad66655f0b1317b0a297aa2d6ed7d6e5d5343495841fad535bee37a56471/xattr-1.1.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3da489ecef798705f9a39ea8cea4ead0d1eeed55f92c345add89740bd930bab6", size = 18883, upload-time = "2025-01-06T19:17:49.46Z" }, - { url = "https://files.pythonhosted.org/packages/4d/5d/7d5154570bbcb898e6123c292f697c87c33e12258a1a8b9741539f952681/xattr-1.1.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:798dd0cbe696635a6f74b06fc430818bf9c3b24314e1502eadf67027ab60c9b0", size = 19221, upload-time = "2025-01-06T19:17:51.654Z" }, - { url = "https://files.pythonhosted.org/packages/b9/b7/135cf3018278051f57bb5dde944cb1ca4f7ad4ec383465a08c6a5c7f7152/xattr-1.1.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b2b6361626efad5eb5a6bf8172c6c67339e09397ee8140ec41258737bea9681", size = 39098, upload-time = "2025-01-06T19:17:53.099Z" }, - { url = "https://files.pythonhosted.org/packages/a5/62/577e2eb0108158b78cd93ea3782c7a8d464693f1338a5350a1db16f69a89/xattr-1.1.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7fa20a0c9ce022d19123b1c5b848d00a68b837251835a7929fe041ee81dcd0", size = 36982, upload-time = "2025-01-06T19:17:54.493Z" }, - { url = "https://files.pythonhosted.org/packages/59/cc/ab3bd7a4bedf445be4b35de4a4627ef2944786724d18eaf28d05c1238c7c/xattr-1.1.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e20eeb08e2c57fc7e71f050b1cfae35cbb46105449853a582bf53fd23c5379e", size = 38891, upload-time = "2025-01-06T19:17:55.853Z" }, - { url = "https://files.pythonhosted.org/packages/45/e8/2285651d92f1460159753fe6628af259c943fcc5071e48a0540fa11dc34d/xattr-1.1.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:477370e75821bded901487e5e752cffe554d1bd3bd4839b627d4d1ee8c95a093", size = 38362, upload-time = "2025-01-06T19:17:57.078Z" }, - { url = "https://files.pythonhosted.org/packages/5f/af/7856c0b1970272a53a428bb20dc125f9fd350fb1b40ebca4e54610af1b79/xattr-1.1.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a8682091cd34a9f4a93c8aaea4101aae99f1506e24da00a3cc3dd2eca9566f21", size = 36724, upload-time = "2025-01-06T19:17:58.534Z" }, - { url = "https://files.pythonhosted.org/packages/5d/34/087e02b32d6288a40b7f6573e97a119016e6c3713d4f4b866bbf56cfb803/xattr-1.1.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2e079b3b1a274ba2121cf0da38bbe5c8d2fb1cc49ecbceb395ce20eb7d69556d", size = 37945, upload-time = "2025-01-06T19:17:59.764Z" }, - { url = 
"https://files.pythonhosted.org/packages/f0/2a/d0f9e46de4cec5e4aa45fd939549b977c49dd68202fa844d07cb24ce5f17/xattr-1.1.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ae6579dea05bf9f335a082f711d5924a98da563cac72a2d550f5b940c401c0e9", size = 23917, upload-time = "2025-01-06T19:18:00.868Z" }, - { url = "https://files.pythonhosted.org/packages/83/e0/a5764257cd9c9eb598f4648a3658d915dd3520ec111ecbd251b685de6546/xattr-1.1.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cd6038ec9df2e67af23c212693751481d5f7e858156924f14340376c48ed9ac7", size = 18891, upload-time = "2025-01-06T19:18:02.029Z" }, - { url = "https://files.pythonhosted.org/packages/8b/83/a81a147987387fd2841a28f767efedb099cf90e23553ead458f2330e47c5/xattr-1.1.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:608b2877526674eb15df4150ef4b70b7b292ae00e65aecaae2f192af224be200", size = 19213, upload-time = "2025-01-06T19:18:03.303Z" }, - { url = "https://files.pythonhosted.org/packages/4b/52/bf093b4eb9873ffc9e9373dcb38ec8a9b5cd4e6a9f681c4c5cf6bf067a42/xattr-1.1.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c54dad1a6a998c6a23edfd25e99f4d38e9b942d54e518570044edf8c767687ea", size = 39302, upload-time = "2025-01-06T19:18:05.846Z" }, - { url = "https://files.pythonhosted.org/packages/2d/d8/9d7315ebae76a7f48bc5e1aecc7e592eb43376a0f6cf470a854d895d2093/xattr-1.1.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0dab6ff72bb2b508f3850c368f8e53bd706585012676e1f71debba3310acde8", size = 37224, upload-time = "2025-01-06T19:18:07.226Z" }, - { url = "https://files.pythonhosted.org/packages/c8/b2/10eb17bea7e378b2bcd76fc8c2e5158318e2c08e774b13f548f333d7318a/xattr-1.1.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a3c54c6af7cf09432b2c461af257d5f4b1cb2d59eee045f91bacef44421a46d", size = 39145, upload-time = "2025-01-06T19:18:08.403Z" }, - { url = "https://files.pythonhosted.org/packages/74/fb/95bbc28116b3c19a21acc34ec0a5973e9cc97fe49d3f47a65775af3760a8/xattr-1.1.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e346e05a158d554639fbf7a0db169dc693c2d2260c7acb3239448f1ff4a9d67f", size = 38469, upload-time = "2025-01-06T19:18:09.602Z" }, - { url = "https://files.pythonhosted.org/packages/af/03/23db582cb271ed47f2d62956e112501d998b5493f892a77104b5795ae2fc/xattr-1.1.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3ff6d9e2103d0d6e5fcd65b85a2005b66ea81c0720a37036445faadc5bbfa424", size = 36797, upload-time = "2025-01-06T19:18:10.709Z" }, - { url = "https://files.pythonhosted.org/packages/90/c4/b631d0174e097cf8c44d4f70c66545d91dc8ba15bbfa5054dd7da8371461/xattr-1.1.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7a2ee4563c6414dfec0d1ac610f59d39d5220531ae06373eeb1a06ee37cd193f", size = 38128, upload-time = "2025-01-06T19:18:11.884Z" }, - { url = "https://files.pythonhosted.org/packages/41/7c/3b8e82ba6f5d24753314ef9922390d9c8e78f157159621bb01f4741d3240/xattr-1.1.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878df1b38cfdadf3184ad8c7b0f516311128d5597b60ac0b3486948953658a83", size = 23910, upload-time = "2025-01-06T19:18:14.745Z" }, - { url = "https://files.pythonhosted.org/packages/77/8d/30b04121b42537aa969a797b89138bb1abd213d5777e9d4289284ebc7dee/xattr-1.1.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0c9b8350244a1c5454f93a8d572628ff71d7e2fc2f7480dcf4c4f0e8af3150fe", size = 18890, upload-time = "2025-01-06T19:18:17.68Z" }, - { url = 
"https://files.pythonhosted.org/packages/fe/94/a95c7db010265a449935452db54d614afb1e5e91b1530c61485fc0fea4b5/xattr-1.1.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a46bf48fb662b8bd745b78bef1074a1e08f41a531168de62b5d7bd331dadb11a", size = 19211, upload-time = "2025-01-06T19:18:24.625Z" }, - { url = "https://files.pythonhosted.org/packages/3e/8d/d5c703970d669a3fbaaff244048ed835f4358f89fb783720088d228da50e/xattr-1.1.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83fc3c07b583777b1dda6355329f75ca6b7179fe0d1002f1afe0ef96f7e3b5de", size = 39213, upload-time = "2025-01-06T19:18:25.756Z" }, - { url = "https://files.pythonhosted.org/packages/e7/8d/7db74dab96b60c70a16c440294d15c8bd8e1521d9251afda114ca7d5e656/xattr-1.1.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6308b19cff71441513258699f0538394fad5d66e1d324635207a97cb076fd439", size = 37142, upload-time = "2025-01-06T19:18:28.293Z" }, - { url = "https://files.pythonhosted.org/packages/27/c9/bc4d09c9926d04c94373aa5e386bbffa9767cd3a95558dc884c5152034b4/xattr-1.1.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48c00ddc15ddadc9c729cd9504dabf50adb3d9c28f647d4ac9a3df45a046b1a0", size = 39048, upload-time = "2025-01-06T19:18:29.488Z" }, - { url = "https://files.pythonhosted.org/packages/39/bf/de1f4f94035aaefbecfb655f30eaf2bed2b63b55a7a400c85359524ba362/xattr-1.1.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a06136196f26293758e1b244200b73156a0274af9a7349fa201c71c7af3bb9e8", size = 38463, upload-time = "2025-01-06T19:18:30.667Z" }, - { url = "https://files.pythonhosted.org/packages/a8/06/de7ae6b5ba7e0646878f4a75292dc5b56568ecca5e622d943570185aad3c/xattr-1.1.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:8fc2631a3c6cfcdc71f7f0f847461839963754e76a2015de71e7e71e3304abc0", size = 36812, upload-time = "2025-01-06T19:18:32.553Z" }, - { url = "https://files.pythonhosted.org/packages/0e/ee/d840adbee6cb33c0193f168d2b719b24174712859022684d78c7a3f06e91/xattr-1.1.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d6e1e835f9c938d129dd45e7eb52ebf7d2d6816323dab93ce311bf331f7d2328", size = 38131, upload-time = "2025-01-06T19:18:33.693Z" }, - { url = "https://files.pythonhosted.org/packages/d7/c9/abcc190a7e24de9feead2404f3bd6dbaceda28034277ffc96ad21b2134f8/xattr-1.1.4-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7c72667f19d3a9acf324aed97f58861d398d87e42314731e7c6ab3ac7850c971", size = 15610, upload-time = "2025-01-06T19:19:03.772Z" }, - { url = "https://files.pythonhosted.org/packages/2f/e8/aa3b2db13f12f9fcbeb79c69a0e8a6dc420845e0a78a37a52bf392bc8471/xattr-1.1.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:67ae934d75ea2563fc48a27c5945749575c74a6de19fdd38390917ddcb0e4f24", size = 16100, upload-time = "2025-01-06T19:19:07.866Z" }, - { url = "https://files.pythonhosted.org/packages/da/0c/e2c7468b7624dcd8fc64562bbd5ed76974d1b263a45af302a424314adc06/xattr-1.1.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a1b0c348dd8523554dc535540d2046c0c8a535bb086561d8359f3667967b6ca", size = 17952, upload-time = "2025-01-06T19:19:11.437Z" }, - { url = "https://files.pythonhosted.org/packages/d8/09/6c47e93d2d96b584e56177f15fcb849fdeeb25fc2a2b75ea514a9e92cdf8/xattr-1.1.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:22284255d2a8e8f3da195bd8e8d43ce674dbc7c38d38cb6ecfb37fae7755d31f", size = 17664, upload-time = "2025-01-06T19:19:12.661Z" }, - { url = "https://files.pythonhosted.org/packages/c4/02/06f994685af57d74f788ad81dd88231bcffa65e4f5b064dc0748545110cc/xattr-1.1.4-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b38aac5ef4381c26d3ce147ca98fba5a78b1e5bcd6be6755b4908659f2705c6d", size = 17574, upload-time = "2025-01-06T19:19:15.234Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/50/65/14438ae55acf7f8fc396ee8340d740a3e1d6ef382bf25bf24156cfb83563/xattr-1.2.0.tar.gz", hash = "sha256:a64c8e21eff1be143accf80fd3b8fde3e28a478c37da298742af647ac3e5e0a7", size = 17293, upload-time = "2025-07-14T03:15:44.884Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/cd/a7db5dc24e03074f02457c76ddcd35f721db2fe9945bafa058b8796056dc/xattr-1.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3df4d8d91e2996c3c72a390ec82e8544acdcb6c7df67b954f1736ff37ea4293e", size = 24248, upload-time = "2025-07-14T03:14:23.279Z" }, + { url = "https://files.pythonhosted.org/packages/5a/6c/236b7be6afe3f2fae6a0834f3ddca3d1cd7695d76247312069a7247f8a5a/xattr-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f5eec248976bbfa6c23df25d4995413df57dccf4161f6cbae36f643e99dbc397", size = 19213, upload-time = "2025-07-14T03:14:24.472Z" }, + { url = "https://files.pythonhosted.org/packages/4a/db/776dc933799addf692a8e1a2094f87f5615a5b7de3a4ec83a264a1a23783/xattr-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fafecfdedf7e8d455443bec2c3edab8a93d64672619cd1a4ee043a806152e19c", size = 19547, upload-time = "2025-07-14T03:14:25.619Z" }, + { url = "https://files.pythonhosted.org/packages/df/51/6e40331e5effd8f592cab3a6001eb91c9f023ab0c2c1f54cc076e90eee36/xattr-1.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c229e245c6c9a85d2fd7d07531498f837dd34670e556b552f73350f11edf000c", size = 39433, upload-time = "2025-07-14T03:14:27.143Z" }, + { url = "https://files.pythonhosted.org/packages/5e/0d/7e072a6d30434e93c0046ef1267229162445f15485a1a133dcc9efde3b60/xattr-1.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:376631e2383918fbc3dc9bcaeb9a533e319322d2cff1c119635849edf74e1126", size = 37315, upload-time = "2025-07-14T03:14:28.274Z" }, + { url = "https://files.pythonhosted.org/packages/51/5b/be272ba051442fb308494675a8e49b69c04cb97123d257eac810cfabe0ba/xattr-1.2.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbae24ab22afe078d549645501ecacaa17229e0b7769c8418fad69b51ad37c9", size = 39222, upload-time = "2025-07-14T03:14:29.676Z" }, + { url = "https://files.pythonhosted.org/packages/48/50/5e0e900461ada1628d7909da5a21189087fd2ae80d313983d4cd55631d70/xattr-1.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a161160211081d765ac41fa056f4f9b1051f027f08188730fbc9782d0dce623e", size = 38679, upload-time = "2025-07-14T03:14:31.061Z" }, + { url = "https://files.pythonhosted.org/packages/1e/6c/e76b0fb90934fbc991efd5f4c0d1f1e41e8ed9d53f2a141f1626eae0f101/xattr-1.2.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a542acf6c4e8221664b51b35e0160c44bd0ed1f2fd80019476f7698f4911e560", size = 37069, upload-time = "2025-07-14T03:14:32.456Z" }, + { url = 
"https://files.pythonhosted.org/packages/8f/1a/ea62d888abf8850baba65ebea887f70de486c10a7b854e87091a15c0939f/xattr-1.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:034f075fc5a9391a1597a6c9a21cb57b688680f0f18ecf73b2efc22b8d330cff", size = 38276, upload-time = "2025-07-14T03:14:33.852Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e2/bf74df197a415f25e07378bfa301788e3bf2ac029c3a6c7bd56a900934ff/xattr-1.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:00c26c14c90058338993bb2d3e1cebf562e94ec516cafba64a8f34f74b9d18b4", size = 24246, upload-time = "2025-07-14T03:14:34.873Z" }, + { url = "https://files.pythonhosted.org/packages/a5/51/922df424556ff35b20ca043da5e4dcf0f99cbcb674f59046d08ceff3ebc7/xattr-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b4f43dc644db87d5eb9484a9518c34a864cb2e588db34cffc42139bf55302a1c", size = 19212, upload-time = "2025-07-14T03:14:35.905Z" }, + { url = "https://files.pythonhosted.org/packages/7c/72/1ed37812e8285c8002b8834395c53cc89a2d83aa088db642b217be439017/xattr-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7602583fc643ca76576498e2319c7cef0b72aef1936701678589da6371b731b", size = 19546, upload-time = "2025-07-14T03:14:37.242Z" }, + { url = "https://files.pythonhosted.org/packages/d4/b8/ec75db23d81beec68e3be20ea176c11f125697d3bbb5e118b9de9ea7a9ab/xattr-1.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90c3ad4a9205cceb64ec54616aa90aa42d140c8ae3b9710a0aaa2843a6f1aca7", size = 39426, upload-time = "2025-07-14T03:14:38.264Z" }, + { url = "https://files.pythonhosted.org/packages/d4/9f/c24950641b138072eda7f34d86966dd15cfe3af9a111b5e77b85ee55f99c/xattr-1.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83d87cfe19cd606fc0709d45a4d6efc276900797deced99e239566926a5afedf", size = 37311, upload-time = "2025-07-14T03:14:39.347Z" }, + { url = "https://files.pythonhosted.org/packages/d0/d5/3b7e0dab706d09c6cdb2f05384610e6c5693c72e3794d54a4cad8c838373/xattr-1.2.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c67dabd9ddc04ead63fbc85aed459c9afcc24abfc5bb3217fff7ec9a466faacb", size = 39222, upload-time = "2025-07-14T03:14:40.768Z" }, + { url = "https://files.pythonhosted.org/packages/0e/16/80cf8ec7d92d20b2860c96a1eca18d25e27fa4770f32c9e8250ff32e7386/xattr-1.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9a18ee82d8ba2c17f1e8414bfeb421fa763e0fb4acbc1e124988ca1584ad32d5", size = 38694, upload-time = "2025-07-14T03:14:41.93Z" }, + { url = "https://files.pythonhosted.org/packages/38/c0/b154b254e6e4596aed3210dd48b2e82d958b16d9a7f65346b9154968d2d0/xattr-1.2.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:38de598c47b85185e745986a061094d2e706e9c2d9022210d2c738066990fe91", size = 37055, upload-time = "2025-07-14T03:14:43.435Z" }, + { url = "https://files.pythonhosted.org/packages/dc/1d/3a615660849ef9bdf46d04f9c6d40ee082f7427678013ff85452ed9497db/xattr-1.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:15e754e854bdaac366ad3f1c8fbf77f6668e8858266b4246e8c5f487eeaf1179", size = 38275, upload-time = "2025-07-14T03:14:45.18Z" }, + { url = "https://files.pythonhosted.org/packages/37/e5/b048a5f6c5a489915026b70b9133242a2a368383ddab24e4e3a5bdba7ebd/xattr-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:daff0c1f5c5e4eaf758c56259c4f72631fa9619875e7a25554b6077dc73da964", size = 24240, upload-time = "2025-07-14T03:14:46.173Z" }, + { url = 
"https://files.pythonhosted.org/packages/cc/f5/d795774f719a0be6137041d4833ca00b178f816e538948548dff79530f34/xattr-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:109b11fb3f73a0d4e199962f11230ab5f462e85a8021874f96c1732aa61148d5", size = 19218, upload-time = "2025-07-14T03:14:47.412Z" }, + { url = "https://files.pythonhosted.org/packages/cb/8b/65f3bed09ca9ced27bbba8d4a3326f14a58b98ac102163d85b545f81d9c2/xattr-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c7c12968ce0bf798d8ba90194cef65de768bee9f51a684e022c74cab4218305", size = 19539, upload-time = "2025-07-14T03:14:48.413Z" }, + { url = "https://files.pythonhosted.org/packages/96/2d/01ecfdf41ce70f7e29c8a21e730de3c157fb1cb84391923581af81a44c45/xattr-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d37989dabf25ff18773e4aaeebcb65604b9528f8645f43e02bebaa363e3ae958", size = 39631, upload-time = "2025-07-14T03:14:49.665Z" }, + { url = "https://files.pythonhosted.org/packages/c9/e9/15cbf9c59cf1117e3c45dd429c52f9dab25d95e65ac245c5ad9532986bec/xattr-1.2.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:165de92b0f2adafb336f936931d044619b9840e35ba01079f4dd288747b73714", size = 37552, upload-time = "2025-07-14T03:14:50.718Z" }, + { url = "https://files.pythonhosted.org/packages/9d/f5/cb4dad87843fe79d605cf5d10caad80e2c338a06f0363f1443449185f489/xattr-1.2.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82191c006ae4c609b22b9aea5f38f68fff022dc6884c4c0e1dba329effd4b288", size = 39472, upload-time = "2025-07-14T03:14:51.74Z" }, + { url = "https://files.pythonhosted.org/packages/5a/d9/012df7b814cc4a0ae41afb59ac31d0469227397b29f58c1377e8db0f34ba/xattr-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b2e9c87dc643b09d86befad218e921f6e65b59a4668d6262b85308de5dbd1dd", size = 38802, upload-time = "2025-07-14T03:14:52.801Z" }, + { url = "https://files.pythonhosted.org/packages/d8/08/e107a5d294a816586f274c33aea480fe740fd446276efc84c067e6c82de2/xattr-1.2.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:14edd5d47d0bb92b23222c0bb6379abbddab01fb776b2170758e666035ecf3aa", size = 37125, upload-time = "2025-07-14T03:14:54.313Z" }, + { url = "https://files.pythonhosted.org/packages/3e/6c/a6f9152e10543af67ea277caae7c5a6400a581e407c42156ffce71dd8242/xattr-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:12183d5eb104d4da787638c7dadf63b718472d92fec6dbe12994ea5d094d7863", size = 38456, upload-time = "2025-07-14T03:14:55.383Z" }, + { url = "https://files.pythonhosted.org/packages/b6/f9/6c98102949691f7e9caf9a31118be6e46720a23049f417dcf77cc689d06e/xattr-1.2.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c385ea93a18aeb6443a719eb6a6b1d7f7b143a4d1f2b08bc4fadfc429209e629", size = 24242, upload-time = "2025-07-14T03:14:56.392Z" }, + { url = "https://files.pythonhosted.org/packages/22/6a/130f6cd5cbb0ea0e470c9b366a21b9474eb607288fd17256d60e50f05d0b/xattr-1.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2d39d7b36842c67ab3040bead7eb6d601e35fa0d6214ed20a43df4ec30b6f9f9", size = 19219, upload-time = "2025-07-14T03:14:57.367Z" }, + { url = "https://files.pythonhosted.org/packages/3d/40/93f2dd033544028e7b9512b8b9fb6872ec74a804fbb686e62b83fdf72e21/xattr-1.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:320ef856bb817f4c40213b6de956dc440d0f23cdc62da3ea02239eb5147093f8", size = 19538, upload-time = "2025-07-14T03:14:58.434Z" }, + { url = 
"https://files.pythonhosted.org/packages/13/d5/7e301840afb7e3d3ad07b95af1815c7b674373d1f7d95cb6f2ecc794fdb1/xattr-1.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26d306bfb3b5641726f2ee0da6f63a2656aa7fdcfd15de61c476e3ca6bc3277e", size = 39544, upload-time = "2025-07-14T03:14:59.66Z" }, + { url = "https://files.pythonhosted.org/packages/50/19/64a1b02d237126c3198257ebd7c643374d928915a86d36db7ad4da0a4f28/xattr-1.2.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c67e70d5d8136d328ad13f85b887ffa97690422f1a11fb29ab2f702cf66e825a", size = 37468, upload-time = "2025-07-14T03:15:01.096Z" }, + { url = "https://files.pythonhosted.org/packages/59/53/f794e3630cf16840e199f086520aca6a59a30f9428b1423a8581bc9cee9d/xattr-1.2.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8904d3539afe1a84fc0b7f02fa91da60d2505adf2d5951dc855bf9e75fe322b2", size = 39378, upload-time = "2025-07-14T03:15:02.149Z" }, + { url = "https://files.pythonhosted.org/packages/f0/a2/ee2d1cdba5e5273886b9f157cb7ef5ba6d83b177d0c17a203d3ac11ee7f7/xattr-1.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2520516c1d058895eae00b2b2f10833514caea6dc6802eef1e431c474b5317ad", size = 38797, upload-time = "2025-07-14T03:15:03.206Z" }, + { url = "https://files.pythonhosted.org/packages/73/28/9216ba5a4485561cf628ea8f7a0753f246e7f0df31656a1cf363c1b7bed4/xattr-1.2.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:29d06abbef4024b7469fcd0d4ade6d2290582350a4df95fcc48fa48b2e83246b", size = 37142, upload-time = "2025-07-14T03:15:04.249Z" }, + { url = "https://files.pythonhosted.org/packages/fd/20/dee2ec6153323592e33f2b82c8c0f0946b9d1989e3c521a9f3d6daac47e5/xattr-1.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:093c75f7d9190be355b8e86da3f460b9bfe3d6a176f92852d44dcc3289aa10dc", size = 38462, upload-time = "2025-07-14T03:15:05.387Z" }, + { url = "https://files.pythonhosted.org/packages/d2/aa/5ea6dd94d0ea7affdd57a6eeb88a9e62a6b600e76aff03d32e89474b7c2c/xattr-1.2.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:29ae44247d46e63671311bf7e700826a97921278e2c0c04c2d11741888db41b8", size = 15938, upload-time = "2025-07-14T03:15:27.426Z" }, + { url = "https://files.pythonhosted.org/packages/24/a4/5bab900c0b715b96bfdd16f0b9d160ae8f7e2065d3ff74e9497087d21828/xattr-1.2.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:629c42c1dd813442d90f281f69b88ef0c9625f604989bef8411428671f70f43e", size = 16428, upload-time = "2025-07-14T03:15:28.439Z" }, + { url = "https://files.pythonhosted.org/packages/dd/14/70d531b536d6aea9032b1ed4fd241be6a59301a86082564c6bbd7bbdc80c/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:549f8fbda5da48cafc81ba6ab7bb8e8e14c4b0748c37963dc504bcae505474b7", size = 18286, upload-time = "2025-07-14T03:15:29.653Z" }, + { url = "https://files.pythonhosted.org/packages/26/a4/1b2e04ea684fc081183eca6faff485da5ab87b25b4dcfcc4164ae87865a1/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa83e677b5f92a3c5c86eaf875e9d3abbc43887ff1767178def865fa9f12a3a0", size = 17997, upload-time = "2025-07-14T03:15:30.997Z" }, + { url = "https://files.pythonhosted.org/packages/1f/03/75a399549e82b6a20ff84d71ee9e777caf6bc687e8004d8b3699565a6aad/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:bb669f01627962ce2bc556f19d421162247bc2cad0d4625d6ea5eb32af4cf29b", size = 17908, upload-time = "2025-07-14T03:15:32.335Z" }, ] [[package]]