diff --git a/nemo_curator_semantic_dedup/Dockerfile b/nemo_curator_semantic_dedup/Dockerfile new file mode 100644 index 0000000..200ef69 --- /dev/null +++ b/nemo_curator_semantic_dedup/Dockerfile @@ -0,0 +1,64 @@ +# NeMo Curator Image Deduplication Example +# Uses CUDA 12.8 for GPU-accelerated processing +FROM anyscale/ray:2.52.0-slim-py312-cu128 + +# Install system dependencies +RUN sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends \ + build-essential \ + unzip \ + wget \ + curl && \ + sudo apt-get clean && \ + sudo rm -rf /var/lib/apt/lists/* + +# Install uv for fast package management +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install Python dependencies +# NeMo Curator with CUDA 12 support for image processing +RUN uv pip install --system "nemo-curator[image_cuda12]" + +# Additional dependencies for image downloading and processing +RUN uv pip install --system \ + loguru \ + Pillow \ + aiohttp \ + tqdm \ + pandas \ + pyarrow \ + huggingface_hub \ + transformers + +# Pre-download CLIP model weights to avoid runtime downloads +# This makes job startup faster and more reliable +RUN python -c "\ +from huggingface_hub import snapshot_download; \ +import os; \ +model_dir = '/home/ray/model_weights/openai/clip-vit-large-patch14'; \ +os.makedirs(model_dir, exist_ok=True); \ +snapshot_download('openai/clip-vit-large-patch14', local_dir=model_dir)" + +# Set environment variable for model directory +ENV MODEL_DIR=/home/ray/model_weights + +# Download and prepare the example dataset from HuggingFace +# Downloads MS COCO parquet, deduplicates URLs, and truncates to 100k rows +RUN mkdir -p /home/ray/data && \ + curl -L https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT/resolve/main/mscoco.parquet \ + -o /home/ray/data/mscoco.parquet && \ + python -c "\ +import pandas as pd; \ +df = pd.read_parquet('/home/ray/data/mscoco.parquet'); \ +deduped = df[~df['URL'].duplicated()]; \ +truncated = deduped[:100000]; \ +truncated.to_parquet('/home/ray/data/truncated_100k_mscoco.parquet'); \ +print(f'Created truncated dataset with {len(truncated)} rows')" && \ + rm /home/ray/data/mscoco.parquet + +# Create output directories +RUN mkdir -p /home/ray/data/webdataset \ + /home/ray/data/results \ + /home/ray/data/embeddings \ + /home/ray/data/removal_ids + diff --git a/nemo_curator_semantic_dedup/README.md b/nemo_curator_semantic_dedup/README.md new file mode 100644 index 0000000..84b872e --- /dev/null +++ b/nemo_curator_semantic_dedup/README.md @@ -0,0 +1,57 @@ +# Image Semantic Deduplication with NeMo Curator + +This example uses [NVIDIA NeMo Curator](https://github.com/NVIDIA-NeMo/Curator) to perform GPU-accelerated semantic deduplication on image datasets. + +NeMo Curator is a scalable data curation library that leverages NVIDIA RAPIDS™ for GPU acceleration. This example downloads images from a parquet file, generates CLIP embeddings, and removes near-duplicate images based on semantic similarity. + +## Install the Anyscale CLI + +```bash +pip install -U anyscale +anyscale login +``` + +## Run the job + +Clone the example from GitHub. + +```bash +git clone https://github.com/anyscale/examples.git +cd examples/nemo_curator_semantic_dedup +``` + +Submit the job. 
+ +```bash +anyscale job submit -f job.yaml +``` + +## Understanding the example + +- The [Dockerfile](./Dockerfile) builds a custom image with NeMo Curator CUDA dependencies (`nemo-curator[image_cuda12]`), downloads the MS COCO sample dataset from HuggingFace, and pre-downloads the CLIP model weights to speed up job startup. + +- The entrypoint defined in [job.yaml](./job.yaml) runs `image_dedup_example.py` which executes a 3-step pipeline: + 1. **Download WebDataset**: Fetches images from URLs in the parquet file and saves them as WebDataset tar files to `/mnt/cluster_storage/nemo_curator/webdataset` + 2. **Generate CLIP embeddings**: Uses OpenAI's CLIP ViT-L/14 model to create 768-dimensional embeddings for each image + 3. **Semantic deduplication**: Clusters embeddings with k-means and removes near-duplicates based on cosine similarity + +- The `/mnt/cluster_storage/` directory is an ephemeral shared filesystem attached to the cluster for the duration of the job. All outputs (embeddings, duplicate IDs, and deduplicated images) are saved here. + +- To use your own data, prepare a parquet file with `URL` and `TEXT` columns, upload it to cluster storage, and override the `INPUT_PARQUET` environment variable: + ```bash + anyscale job submit -f job.yaml \ + --env INPUT_PARQUET=/mnt/cluster_storage/your_data.parquet \ + --env OUTPUT_DIR=/mnt/cluster_storage/your_results + ``` + +- The [helper.py](./helper.py) module provides utilities for downloading images in parallel and converting them to [WebDataset](https://github.com/webdataset/webdataset) format, which is optimized for streaming large-scale image datasets. + +## View the job + +View the job in the [jobs tab](https://console.anyscale.com/jobs) of the Anyscale console. + +## Learn more + +- [NeMo Curator Documentation](https://docs.nvidia.com/nemo/curator/latest/) +- [NeMo Curator Image Tutorials](https://github.com/NVIDIA-NeMo/Curator/tree/main/tutorials/image/getting-started) +- [Anyscale Jobs Documentation](https://docs.anyscale.com/platform/jobs/) diff --git a/nemo_curator_semantic_dedup/helper.py b/nemo_curator_semantic_dedup/helper.py new file mode 100644 index 0000000..a83a1fa --- /dev/null +++ b/nemo_curator_semantic_dedup/helper.py @@ -0,0 +1,327 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
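+"""Helper utilities for the image semantic deduplication example.
+
+``download_webdataset`` fetches the images listed in a parquet file (``URL`` and
+``TEXT`` columns) and packs them into WebDataset tar shards with matching parquet
+metadata; ``save_imagebatch_to_webdataset`` writes curated ``ImageBatch`` results
+back out in the same sharded tar-plus-parquet layout.
+"""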
+ +from __future__ import annotations + +import asyncio +import io +import json +import math +import os +import tarfile +from functools import partial +from multiprocessing import Pool +from typing import TYPE_CHECKING + +import aiohttp +import pandas as pd +from loguru import logger +from PIL import Image +from tqdm import tqdm + +if TYPE_CHECKING: + from nemo_curator.tasks import ImageObject + from nemo_curator.tasks.image import ImageBatch + +# HTTP status codes +HTTP_OK = 200 + + +async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, retries: int = 3) -> bytes | None: + for attempt in range(1, retries + 1): + try: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=15)) as response: + if response.status == HTTP_OK: + return await response.read() + elif attempt > 1: + logger.debug(f"[Attempt {attempt}] Failed to download {url}: HTTP status {response.status}") + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt > 1: + logger.debug(f"[Attempt {attempt}] Failed to download {url}: {e}") + + if attempt < retries: + await asyncio.sleep(1) + + logger.debug(f"All {retries} attempts failed for {url}") + return None + + +async def process_batch(batch: pd.DataFrame, output_dir: str, batch_num: int) -> None: + tar_filename = os.path.join(output_dir, f"{batch_num:05d}.tar") + + metadatas = [] + # Set timeout and connection limits for the session + timeout = aiohttp.ClientTimeout(total=15) + connector = aiohttp.TCPConnector(limit=256, limit_per_host=16) + + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + tasks = [] + for i, (_, row) in enumerate(batch.iterrows()): + caption = row["TEXT"] + url = row["URL"] + + key = f"{batch_num:05d}{i:04d}" + + meta = {"url": url, "caption": caption, "key": key} + metadatas.append(meta) + + tasks.append(fetch_image_bytes(session, url, retries=3)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + with tarfile.open(tar_filename, "w") as tar: + for i, result in enumerate(results): + # Only proceed for successful downloads (bytes) + if isinstance(result, bytes) and result: + key = f"{batch_num:05d}{i:04d}" + + # Add image bytes + jpg_info = tarfile.TarInfo(name=f"{key}.jpg") + jpg_info.size = len(result) + tar.addfile(jpg_info, fileobj=io.BytesIO(result)) + + # Add caption text + caption_bytes = str(metadatas[i]["caption"]).encode("utf-8") + txt_info = tarfile.TarInfo(name=f"{key}.txt") + txt_info.size = len(caption_bytes) + tar.addfile(txt_info, fileobj=io.BytesIO(caption_bytes)) + + # Add JSON metadata + json_bytes = json.dumps(metadatas[i]).encode("utf-8") + json_info = tarfile.TarInfo(name=f"{key}.json") + json_info.size = len(json_bytes) + tar.addfile(json_info, fileobj=io.BytesIO(json_bytes)) + + # Write parquet + meta_df = pd.DataFrame(metadatas) + parquet_path = os.path.join(output_dir, f"{batch_num:05d}.parquet") + meta_df.to_parquet(parquet_path) + + +def process_parquet_chunk(chunk: tuple[int, pd.DataFrame], output_dir: str) -> None: + batch_num, batch = chunk + + asyncio.run(process_batch(batch, output_dir, batch_num)) + + +def download_webdataset( + parquet_path: str, + output_dir: str, + entries_per_tar: int = 10000, + num_processes: int = 2, +) -> None: + os.makedirs(output_dir, exist_ok=True) + + # Read the parquet file + df = pd.read_parquet(parquet_path) + print(f"Loaded {len(df)} entries from parquet file") + + # Split the dataframe into chunks for multiprocessing + chunks = [ + (batch_num, df[i : i + entries_per_tar]) for batch_num, i in 
enumerate(range(0, len(df), entries_per_tar)) + ] + print(f"Split into {len(chunks)} chunks of {entries_per_tar} entries each") + + # Use multiprocessing to process chunks in parallel with progress tracking + with Pool(processes=num_processes) as pool: + func = partial(process_parquet_chunk, output_dir=output_dir) + + # Use tqdm to track progress of chunk processing + list(tqdm( + pool.imap(func, chunks), + total=len(chunks), + desc="Processing chunks", + unit="chunk" + )) + + # Best-effort cleanup of legacy tmp dir from previous versions + tmp_dir = os.path.join(output_dir, "tmp") + try: + if os.path.isdir(tmp_dir) and not os.listdir(tmp_dir): + os.rmdir(tmp_dir) + except OSError as e: + logger.debug(f"Failed to remove tmp dir {tmp_dir}: {e}") + + +def _prepare_metadata_record( + image_obj: ImageObject, + new_id: str, + old_id_col: str | None, +) -> dict: + """Prepare metadata record for an image object.""" + metadata_record = { + "id": new_id, + "original_id": image_obj.image_id, + "original_path": image_obj.image_path, + } + + # Preserve original ID in specified column if requested + if old_id_col: + metadata_record[old_id_col] = image_obj.image_id + + # Add scores and embeddings to metadata + if image_obj.aesthetic_score is not None: + metadata_record["aesthetic_score"] = image_obj.aesthetic_score + if image_obj.nsfw_score is not None: + metadata_record["nsfw_score"] = image_obj.nsfw_score + if image_obj.embedding is not None: + # Convert embedding to list for JSON serialization + metadata_record["embedding"] = image_obj.embedding.tolist() + metadata_record["embedding_dim"] = len(image_obj.embedding) + + # Add original metadata + if image_obj.metadata: + metadata_record.update(image_obj.metadata) + + return metadata_record + + +def _add_caption_to_metadata(image_obj: ImageObject, metadata_record: dict) -> None: + """Add caption/text to metadata record.""" + if "caption" in image_obj.metadata: + metadata_record["caption"] = str(image_obj.metadata["caption"]) + elif "text" in image_obj.metadata: + metadata_record["caption"] = str(image_obj.metadata["text"]) + elif "TEXT" in image_obj.metadata: + metadata_record["caption"] = str(image_obj.metadata["TEXT"]) + + +def _add_image_to_tar(tar: tarfile.TarFile, image_obj: ImageObject, new_id: str) -> None: + """Add image data to tar file if available.""" + if image_obj.image_data is not None: + # Convert numpy array to PIL Image and save as bytes + image_pil = Image.fromarray(image_obj.image_data) + image_bytes = _image_to_bytes(image_pil) + + # Add image to tar + image_info = tarfile.TarInfo(name=f"{new_id}.jpg") + image_info.size = len(image_bytes.getvalue()) + tar.addfile(image_info, fileobj=image_bytes) + + +def _add_json_to_tar(tar: tarfile.TarFile, metadata_record: dict, new_id: str) -> None: + """Add JSON metadata to tar file.""" + json_data = json.dumps(metadata_record, indent=2) + json_bytes = json_data.encode("utf-8") + json_info = tarfile.TarInfo(name=f"{new_id}.json") + json_info.size = len(json_bytes) + tar.addfile(json_info, fileobj=io.BytesIO(json_bytes)) + + +def save_imagebatch_to_webdataset( + image_batches: list[ImageBatch], + output_path: str, + samples_per_shard: int = 10000, + max_shards: int = 5, + old_id_col: str | None = None, +) -> None: + """ + Save ImageBatch objects to WebDataset format with resharding. 
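+    Each shard is written as a zero-padded ``NNNNN.tar`` containing the images and
+    per-sample JSON metadata, plus a matching ``NNNNN.parquet`` with the same
+    metadata; sample ids concatenate the shard index with the sample's index
+    within that shard (see ``_combine_id``).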
+ + Args: + image_batches: List of ImageBatch objects from pipeline output + output_path: Directory path where the WebDataset should be saved + samples_per_shard: Number of samples to include in each tar file + max_shards: Order of magnitude of max shards (for zero-padding filenames) + old_id_col: If specified, will preserve the original image_id in this column + """ + os.makedirs(output_path, exist_ok=True) + + # Flatten all ImageObjects from all batches + all_image_objects = [] + for batch in image_batches: + all_image_objects.extend(batch.data) + + if not all_image_objects: + print("No images to save") + return + + print(f"Processing {len(all_image_objects)} images into {samples_per_shard} samples per shard") + + max_samples_per_shard = math.ceil(math.log10(samples_per_shard)) + + # Process images in shards + shard_id = 0 + for i in range(0, len(all_image_objects), samples_per_shard): + shard_images = all_image_objects[i:i + samples_per_shard] + + # Create output file paths + parquet_filename = _name_partition(shard_id, max_shards=max_shards) + tar_filename = _name_partition(shard_id, max_shards=max_shards, ext="tar") + parquet_path = os.path.join(output_path, parquet_filename) + tar_path = os.path.join(output_path, tar_filename) + + # Prepare metadata for parquet + metadata_records = [] + + # Create tar file with images and metadata + with tarfile.open(tar_path, "w") as tar: + for sample_idx, image_obj in enumerate(shard_images): + # Generate new ID combining shard and sample indices + new_id = _combine_id( + shard_id, + sample_idx, + max_shards=max_shards, + max_samples_per_shard=max_samples_per_shard + ) + + # Prepare metadata record for parquet + metadata_record = _prepare_metadata_record(image_obj, new_id, old_id_col) + metadata_records.append(metadata_record) + + # Save image data if available and requested + _add_image_to_tar(tar, image_obj, new_id) + + # Store caption/text in metadata (no separate .txt file) + _add_caption_to_metadata(image_obj, metadata_record) + + # Add JSON metadata to tar + _add_json_to_tar(tar, metadata_record, new_id) + + # Save metadata to parquet + metadata_df = pd.DataFrame(metadata_records) + metadata_df.to_parquet(parquet_path, index=False) + + print(f"✓ Saved shard {shard_id:0{max_shards}d} with {len(shard_images)} samples") + print(f" - Tar file: {tar_filename}") + print(f" - Parquet file: {parquet_filename}") + + shard_id += 1 + + print(f"\nSuccessfully saved {len(all_image_objects)} images to {shard_id} shards") + print(f"Output directory: {output_path}") + + +def _name_partition( + partition_index: int, + max_shards: int = 5, + ext: str = "parquet", +) -> str: + """Generate partition filename with proper zero-padding.""" + return f"{partition_index:0{max_shards}d}.{ext}" + + +def _combine_id(shard_id: int, sample_id: int, max_shards: int = 5, max_samples_per_shard: int = 4) -> str: + """Combine shard and sample IDs into a unique identifier.""" + int_id = sample_id + (10**max_samples_per_shard) * shard_id + n_digits = max_samples_per_shard + max_shards + return f"{int_id:0{n_digits}d}" + + +def _image_to_bytes(image_pil: Image.Image, image_format: str = "JPEG") -> io.BytesIO: + """Convert PIL Image to BytesIO object for tarfile.""" + buffer = io.BytesIO() + image_pil.save(buffer, format=image_format) + buffer.seek(0) + return buffer \ No newline at end of file diff --git a/nemo_curator_semantic_dedup/image_dedup_example.py b/nemo_curator_semantic_dedup/image_dedup_example.py new file mode 100644 index 0000000..077b94e --- /dev/null +++ 
b/nemo_curator_semantic_dedup/image_dedup_example.py @@ -0,0 +1,368 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time + +import ray +from helper import download_webdataset + +from nemo_curator.core.client import RayClient +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.deduplication.semantic import SemanticDeduplicationWorkflow +from nemo_curator.stages.file_partitioning import FilePartitioningStage +from nemo_curator.stages.image.deduplication.removal import ImageDuplicatesRemovalStage +from nemo_curator.stages.image.embedders.clip_embedder import ImageEmbeddingStage +from nemo_curator.stages.image.io.convert import ConvertImageBatchToDocumentBatchStage +from nemo_curator.stages.image.io.image_reader import ImageReaderStage +from nemo_curator.stages.image.io.image_writer import ImageWriterStage +from nemo_curator.stages.text.io.writer.parquet import ParquetWriter + + +def create_image_embedding_pipeline(args: argparse.Namespace) -> Pipeline: + """Create image curation pipeline with file partitioning, image reading, embedding, deduplication.""" + + # Define pipeline + pipeline = Pipeline(name="image_curation", description="Curate images with embeddings and quality scoring") + + # Stage 0: Partition tar files for parallel processing + pipeline.add_stage(FilePartitioningStage( + file_paths=args.input_wds_dataset_dir, + files_per_partition=args.tar_files_per_partition, + file_extensions=[".tar"], + )) + + # Stage 1: Read images from webdataset tar files (now runs in parallel) + pipeline.add_stage(ImageReaderStage( + task_batch_size=args.batch_size, + verbose=args.verbose, + num_threads=16, # More threads for I/O + num_gpus_per_worker=0.25, + )) + + # Stage 2: Generate CLIP embeddings for images + pipeline.add_stage(ImageEmbeddingStage( + model_dir=args.model_dir, + num_gpus_per_worker=args.embedding_gpus_per_worker, + model_inference_batch_size=args.embedding_batch_size, + verbose=args.verbose, + )) + + # Stage 3: Convert embeddings to document batch + pipeline.add_stage(ConvertImageBatchToDocumentBatchStage(fields=["image_id", "embedding"])) + + # Stage 4: Save embeddings to parquet file + pipeline.add_stage(ParquetWriter( + path=args.embeddings_dir, + )) + + return pipeline + +def create_embedding_deduplication_workflow(args: argparse.Namespace) -> Pipeline: + """Create image deduplication pipeline with embedding deduplication.""" + return SemanticDeduplicationWorkflow( + input_path=args.embeddings_dir, + output_path=args.removal_parquets_dir, + id_field="image_id", + embedding_field="embedding", + n_clusters=100, + eps=0.01, + read_kwargs={"storage_options": {}}, + write_kwargs={"storage_options": {}}, + verbose=args.verbose, + ) + +def create_image_deduplication_pipeline(args: argparse.Namespace) -> Pipeline: + """Create image deduplication pipeline with image deduplication.""" + # Define pipeline + pipeline = Pipeline(name="image_deduplication", 
description="Deduplicate images with image deduplication") + + # Stage 0: Partition tar files for parallel processing + pipeline.add_stage(FilePartitioningStage( + file_paths=args.input_wds_dataset_dir, + files_per_partition=args.tar_files_per_partition, + file_extensions=[".tar"], + )) + + # Stage 1: Read images from webdataset tar files (now runs in parallel) + pipeline.add_stage(ImageReaderStage( + task_batch_size=args.batch_size, + verbose=args.verbose, + num_threads=16, # More threads for I/O + num_gpus_per_worker=0.25, + )) + + # Stage 2: Read removal list from parquet file and filter images + pipeline.add_stage(ImageDuplicatesRemovalStage( + removal_parquets_dir=args.removal_parquets_dir + "/duplicates", + duplicate_id_field="id", + verbose=args.verbose, + )) + + # Stage 3: Write filtered images to disk + pipeline.add_stage(ImageWriterStage( + output_dir=args.output_dataset_dir, + remove_image_data=True, + verbose=args.verbose, + )) + + return pipeline + + +def main(args: argparse.Namespace) -> None: + """Main execution function for image curation pipeline.""" + + ray_client = RayClient() + ray_client.start() + + print("Starting image curation pipeline...") + print(f"Input parquet file: {args.input_parquet}") + print(f"Input webdataset directory: {args.input_wds_dataset_dir}") + print(f"Output webdataset directory: {args.output_dataset_dir}") + print(f"Model directory: {args.model_dir}") + print(f"Tar files per partition: {args.tar_files_per_partition}") + print(f"Task batch size: {args.batch_size}") + print("\n" + "=" * 50 + "\n") + + # Step 1: Download and prepare webdataset from parquet file + if not args.skip_download: + print("Step 1: Downloading webdataset from parquet file...") + download_start = time.time() + + # Create output directory if it doesn't exist + os.makedirs(args.input_wds_dataset_dir, exist_ok=True) + + # Download webdataset using helper function + download_webdataset( + parquet_path=args.input_parquet, + output_dir=args.input_wds_dataset_dir, + num_processes=args.download_processes, + entries_per_tar=args.entries_per_tar, + ) + + download_time = time.time() - download_start + print(f"✓ Dataset download completed in {download_time:.2f} seconds") + print(f"✓ Webdataset saved to: {args.input_wds_dataset_dir}") + print("\n" + "=" * 50 + "\n") + else: + print("Step 1: Skipping download (using existing dataset)") + print(f"Using existing dataset at: {args.input_wds_dataset_dir}") + print("\n" + "=" * 50 + "\n") + + # Step 2: Create and run curation pipelines + # Step 2.1: Create image embedding pipeline + print("Step 2.1: Running image embedding pipeline...") + start_time = time.time() + pipeline = create_image_embedding_pipeline(args) + print(pipeline.describe()) + print("\n" + "=" * 50 + "\n") + pipeline.run() + + # Step 2.2: Create image deduplication pipeline (pairwise executor is XennaExecutor by default) + print("Step 2.2: Running image deduplication pipeline...") + start_time = time.time() + pipeline = create_embedding_deduplication_workflow(args) + print("\n" + "=" * 50 + "\n") + pipeline.run() + + # Step 2.3: Create image deduplication pipeline + print("Step 2.3: Running image deduplication pipeline...") + start_time = time.time() + pipeline = create_image_deduplication_pipeline(args) + print(pipeline.describe()) + print("\n" + "=" * 50 + "\n") + pipeline.run() + + end_time = time.time() + + # Calculate and print execution time + execution_time = end_time - start_time + hours, remainder = divmod(execution_time, 3600) + minutes, seconds = divmod(remainder, 
60) + + print("\nImage curation pipeline completed!") + print(f"Total execution time: {int(hours):02d}:{int(minutes):02d}:{seconds:.2f}") + print(f"Total execution time: {execution_time:.2f} seconds") + print(f"\nProcessed dataset available at: {args.output_dataset_dir}") + + ray_client.stop() + + +def get_env_or_arg(env_var: str, arg_value, default=None): + """Get value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return env_value + if arg_value is not None: + return arg_value + return default + + +def get_env_bool(env_var: str, arg_value: bool, default: bool = False) -> bool: + """Get boolean value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return env_value.lower() in ("true", "1", "yes") + return arg_value if arg_value is not None else default + + +def get_env_int(env_var: str, arg_value: int, default: int) -> int: + """Get integer value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return int(env_value) + return arg_value if arg_value is not None else default + + +def get_env_float(env_var: str, arg_value: float, default: float) -> float: + """Get float value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return float(env_value) + return arg_value if arg_value is not None else default + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Image curation pipeline with embedding generation and quality scoring. " + "Arguments can also be set via environment variables (see job.yaml)." + ) + + # Dataset arguments + parser.add_argument( + "--input-parquet", + type=str, + required=False, + default=None, + help="Path to input parquet file containing image URLs and metadata (env: INPUT_PARQUET)" + ) + parser.add_argument( + "--input-wds-dataset-dir", + type=str, + required=False, + default=None, + help="Directory to save the downloaded webdataset (env: INPUT_WDS_DIR)" + ) + parser.add_argument( + "--output-dataset-dir", + type=str, + required=False, + default=None, + help="Directory to save the resulting webdataset (env: OUTPUT_DIR)" + ) + parser.add_argument( + "--embeddings-dir", + type=str, + required=False, + default=None, + help="Directory to save the embeddings (env: EMBEDDINGS_DIR)" + ) + parser.add_argument( + "--removal-parquets-dir", + type=str, + required=False, + default=None, + help="Directory to save the remove parquets (env: REMOVAL_DIR)" + ) + parser.add_argument( + "--download-processes", + type=int, + default=None, + help="Number of parallel processes for downloading images (env: DOWNLOAD_PROCESSES)" + ) + parser.add_argument( + "--entries-per-tar", + type=int, + default=None, + help="Number of entries per tar shard during download (env: ENTRIES_PER_TAR)" + ) + parser.add_argument( + "--skip-download", + action="store_true", + default=None, + help="Skip dataset download and use existing webdataset (env: SKIP_DOWNLOAD)" + ) + + # Image reader arguments + parser.add_argument( + "--tar-files-per-partition", + type=int, + default=None, + help="Number of tar files to process per partition (env: TAR_FILES_PER_PARTITION)" + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Number of images per ImageBatch for the reader stage (env: BATCH_SIZE)" + ) + + # General arguments + parser.add_argument( + "--model-dir", + type=str, 
+ required=False, + default=None, + help="Path to model directory containing all model weights (env: MODEL_DIR)" + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Enable verbose logging for all stages" + ) + + # Embedding stage arguments + parser.add_argument( + "--embedding-batch-size", + type=int, + default=None, + help="Batch size for embedding generation (env: EMBEDDING_BATCH_SIZE)" + ) + parser.add_argument( + "--embedding-gpus-per-worker", + type=float, + default=None, + help="GPU allocation per worker for embedding generation" + ) + + cli_args = parser.parse_args() + + # Resolve arguments from environment variables or command-line args + args = argparse.Namespace( + input_parquet=get_env_or_arg("INPUT_PARQUET", cli_args.input_parquet), + input_wds_dataset_dir=get_env_or_arg("INPUT_WDS_DIR", cli_args.input_wds_dataset_dir), + output_dataset_dir=get_env_or_arg("OUTPUT_DIR", cli_args.output_dataset_dir), + embeddings_dir=get_env_or_arg("EMBEDDINGS_DIR", cli_args.embeddings_dir), + removal_parquets_dir=get_env_or_arg("REMOVAL_DIR", cli_args.removal_parquets_dir), + model_dir=get_env_or_arg("MODEL_DIR", cli_args.model_dir, "/home/ray/model_weights"), + download_processes=get_env_int("DOWNLOAD_PROCESSES", cli_args.download_processes, 8), + entries_per_tar=get_env_int("ENTRIES_PER_TAR", cli_args.entries_per_tar, 1000), + skip_download=get_env_bool("SKIP_DOWNLOAD", cli_args.skip_download, False), + tar_files_per_partition=get_env_int("TAR_FILES_PER_PARTITION", cli_args.tar_files_per_partition, 1), + batch_size=get_env_int("BATCH_SIZE", cli_args.batch_size, 100), + embedding_batch_size=get_env_int("EMBEDDING_BATCH_SIZE", cli_args.embedding_batch_size, 32), + embedding_gpus_per_worker=get_env_float("EMBEDDING_GPUS_PER_WORKER", cli_args.embedding_gpus_per_worker, 0.25), + verbose=cli_args.verbose, + ) + + # Validate required arguments + required_args = ["input_wds_dataset_dir", "output_dataset_dir", "embeddings_dir", "removal_parquets_dir"] + missing = [arg for arg in required_args if getattr(args, arg) is None] + if missing: + parser.error(f"Missing required arguments: {', '.join(missing)}. " + "Set them via command-line or environment variables.") + + main(args) \ No newline at end of file diff --git a/nemo_curator_semantic_dedup/job.yaml b/nemo_curator_semantic_dedup/job.yaml new file mode 100644 index 0000000..1cca312 --- /dev/null +++ b/nemo_curator_semantic_dedup/job.yaml @@ -0,0 +1,72 @@ +# NeMo Curator Image Semantic Deduplication Job +# View the docs: https://docs.anyscale.com/reference/job-api#jobconfig + +name: nemo-curator-image-dedup + +# Build custom image with NeMo Curator CUDA dependencies +containerfile: ./Dockerfile + +# Compute configuration with L4 GPU for CUDA-accelerated image processing +# Head + worker nodes for distributed processing +compute_config: + head_node: + instance_type: g6.8xlarge # 1x L4 GPU, 32 vCPUs, 128GB RAM + # Ensure Ray reports CPU resources on the head node for cosmos_xenna + resources: + CPU: 32 + worker_nodes: + - instance_type: g6.8xlarge # 1x L4 GPU per worker + min_nodes: 2 + max_nodes: 2 + +# Working directory - upload only the example code, not data +working_dir: . 
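+# Note: the entrypoint script reads the env vars in this file through its
+# get_env_* helpers, so values set here take precedence over any command-line
+# flags passed to image_dedup_example.py.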
+ +# Environment variables for job configuration +# Override these when submitting to use your own data paths +env_vars: + # Input parquet file with image URLs (TEXT and URL columns) + # This file is copied into the Docker image during build + INPUT_PARQUET: "/home/ray/data/truncated_100k_mscoco.parquet" + + # Directory for WebDataset tar files (created from parquet) + # Use /mnt/cluster_storage for persistence, or /home/ray/data for ephemeral + INPUT_WDS_DIR: "/mnt/cluster_storage/nemo_curator/webdataset" + + # Output directory for deduplicated images + OUTPUT_DIR: "/mnt/cluster_storage/nemo_curator/results" + + # Directory to store CLIP embeddings + EMBEDDINGS_DIR: "/mnt/cluster_storage/nemo_curator/embeddings" + + # Directory for duplicate removal parquets + REMOVAL_DIR: "/mnt/cluster_storage/nemo_curator/removal_ids" + + # Model weights directory (pre-downloaded in Docker image) + MODEL_DIR: "/home/ray/model_weights" + + # Processing settings + BATCH_SIZE: "32" + EMBEDDING_BATCH_SIZE: "32" + TAR_FILES_PER_PARTITION: "10" + DOWNLOAD_PROCESSES: "8" + ENTRIES_PER_TAR: "1000" + + + SKIP_DOWNLOAD: "false" # Always keep false + + # Ray memory settings to avoid OOM + RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5" + + # Increase Ray API server limit for cosmos_xenna monitoring + RAY_MAX_LIMIT_FROM_API_SERVER: "100000" + +# The entrypoint script +entrypoint: python image_dedup_example.py + +# Don't retry on failure - easier to debug +max_retries: 0 + +# Kill after 4 hours to control costs (adjust based on dataset size) +timeout_s: 14400 +
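+
+# To run the same pipeline on your own dataset, point INPUT_PARQUET at a parquet
+# file with URL and TEXT columns and choose a fresh output root under
+# /mnt/cluster_storage, for example (hypothetical paths, mirroring the override
+# shown in the README):
+#
+#   INPUT_PARQUET: "/mnt/cluster_storage/your_data.parquet"
+#   OUTPUT_DIR: "/mnt/cluster_storage/your_results"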