diff --git a/nemo_curator_semantic_dedup/Dockerfile b/nemo_curator_semantic_dedup/Dockerfile new file mode 100644 index 0000000..997dc07 --- /dev/null +++ b/nemo_curator_semantic_dedup/Dockerfile @@ -0,0 +1,66 @@ +# NeMo Curator Image Deduplication Example +# Uses CUDA 12.8 for GPU-accelerated processing +FROM anyscale/ray:2.52.0-slim-py312-cu128 + +# Note: Cache busting for git clone is done via CURATOR_CACHE_BUST arg below + +# Install system dependencies +RUN sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends \ + build-essential \ + unzip \ + wget \ + curl \ + git && \ + sudo apt-get clean && \ + sudo rm -rf /var/lib/apt/lists/* + +# Install uv for fast package management +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install Python dependencies +# Use uv pip install --system to install into the base anaconda environment +# so all Ray workers (not just the driver) have these packages +RUN python -m pip install --upgrade pip setuptools wheel + +# IMPORTANT: Uninstall any pre-existing RAPIDS/cuML packages from the base image +# The base image may have incompatible versions that conflict with scikit-learn +RUN python -m pip uninstall -y cuml-cu12 cudf-cu12 cugraph-cu12 pylibraft-cu12 raft-dask-cu12 rmm-cu12 || true && \ + echo "Cleaned up pre-existing RAPIDS packages" + +# Clone NeMo-Curator from fork and install in editable mode +# This ensures all Ray workers have the same code with your local edits +ARG CURATOR_REPO=https://github.com/avigyabb/Curator.git +ARG CURATOR_REF=avi-test +# ARG CURATOR_REF=main +# Cache bust for git clone - change this value to force re-clone after pushing to branch +ARG CURATOR_CACHE_BUST=2025-12-29-v3 +RUN echo "Cache bust: ${CURATOR_CACHE_BUST}" && \ + git clone --depth 1 -b ${CURATOR_REF} ${CURATOR_REPO} /home/ray/NeMo-Curator && \ + uv pip install --system -e /home/ray/NeMo-Curator[image_cuda12] + +# Re-upgrade scikit-learn AFTER nemo-curator in case it was downgraded +# cuML 25.6.* needs sklearn >= 1.5 (has _get_default_requests) +RUN uv pip install --system "scikit-learn>=1.5,<1.6" && \ + python -c "import sklearn; print(f'Final scikit-learn version: {sklearn.__version__}')" + +# Additional dependencies for image downloading and processing +RUN uv pip install --system \ + loguru \ + Pillow \ + aiohttp \ + tqdm \ + pandas \ + pyarrow \ + huggingface_hub \ + transformers + +# Set environment variable for model directory +ENV MODEL_DIR=/home/ray/model_weights + +# Create output directories +RUN mkdir -p /home/ray/data/webdataset \ + /home/ray/data/results \ + /home/ray/data/embeddings \ + /home/ray/data/removal_ids + diff --git a/nemo_curator_semantic_dedup/README.md b/nemo_curator_semantic_dedup/README.md new file mode 100644 index 0000000..84b872e --- /dev/null +++ b/nemo_curator_semantic_dedup/README.md @@ -0,0 +1,57 @@ +# Image Semantic Deduplication with NeMo Curator + +This example uses [NVIDIA NeMo Curator](https://github.com/NVIDIA-NeMo/Curator) to perform GPU-accelerated semantic deduplication on image datasets. + +NeMo Curator is a scalable data curation library that leverages NVIDIA RAPIDS™ for GPU acceleration. This example downloads images from a parquet file, generates CLIP embeddings, and removes near-duplicate images based on semantic similarity. + +## Install the Anyscale CLI + +```bash +pip install -U anyscale +anyscale login +``` + +## Run the job + +Clone the example from GitHub. 
+
+```bash
+git clone https://github.com/anyscale/examples.git
+cd examples/nemo_curator_semantic_dedup
+```
+
+Submit the job.
+
+```bash
+anyscale job submit -f job.yaml
+```
+
+## Understanding the example
+
+- The [Dockerfile](./Dockerfile) builds a custom image with NeMo Curator and its CUDA dependencies (`nemo-curator[image_cuda12]`), plus the Python packages used for downloading and preparing images (aiohttp, Pillow, pyarrow, huggingface_hub, and transformers).
+
+- The entrypoint defined in [job.yaml](./job.yaml) runs `image_dedup_example.py`, which executes a three-step pipeline:
+  1. **Download WebDataset**: Fetches images from URLs in the parquet file and saves them as WebDataset tar files to `/mnt/cluster_storage/nemo_curator/webdataset`
+  2. **Generate CLIP embeddings**: Uses OpenAI's CLIP ViT-L/14 model to create 768-dimensional embeddings for each image
+  3. **Semantic deduplication**: Clusters embeddings with k-means and removes near-duplicates based on cosine similarity
+
+- The `/mnt/cluster_storage/` directory is an ephemeral shared filesystem attached to the cluster for the duration of the job. All outputs (embeddings, duplicate IDs, and deduplicated images) are saved here.
+
+- To use your own data, prepare a parquet file with `URL` and `TEXT` columns, upload it to cluster storage, and override the `INPUT_PARQUET` environment variable:
+  ```bash
+  anyscale job submit -f job.yaml \
+    --env INPUT_PARQUET=/mnt/cluster_storage/your_data.parquet \
+    --env OUTPUT_DIR=/mnt/cluster_storage/your_results
+  ```
+
+- The [helper.py](./helper.py) module provides utilities for downloading images in parallel and converting them to [WebDataset](https://github.com/webdataset/webdataset) format, which is optimized for streaming large-scale image datasets.
+
+## View the job
+
+View the job in the [jobs tab](https://console.anyscale.com/jobs) of the Anyscale console.
+
+## Learn more
+
+- [NeMo Curator Documentation](https://docs.nvidia.com/nemo/curator/latest/)
+- [NeMo Curator Image Tutorials](https://github.com/NVIDIA-NeMo/Curator/tree/main/tutorials/image/getting-started)
+- [Anyscale Jobs Documentation](https://docs.anyscale.com/platform/jobs/)
diff --git a/nemo_curator_semantic_dedup/helper.py b/nemo_curator_semantic_dedup/helper.py
new file mode 100644
index 0000000..78b0cbd
--- /dev/null
+++ b/nemo_curator_semantic_dedup/helper.py
@@ -0,0 +1,524 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Helper functions for downloading and preparing image datasets.
+
+This module provides two approaches for converting parquet files (with URLs) to WebDataset format:
+
+1. `parquet_to_webdataset_ray()` - Distributed approach using Ray Data (recommended)
+   - Scales across all nodes in the cluster
+   - Uses Ray Data for parallel reading and processing
+   - Best for large datasets (millions of images)
+
+2. 
`download_webdataset()` - Single-node multiprocessing approach (legacy) + - Runs on a single machine + - Uses Python multiprocessing for parallelism + - Simpler but doesn't scale beyond one node +""" + +from __future__ import annotations + +import asyncio +import io +import json +import os +import tarfile +import uuid +from typing import TYPE_CHECKING, Any + +import aiohttp +import pandas as pd +from loguru import logger +from PIL import Image + +if TYPE_CHECKING: + pass + +# HTTP status codes +HTTP_OK = 200 + + +# ============================================================================= +# Image Download and Validation Utilities +# ============================================================================= + +async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, retries: int = 3) -> bytes | None: + """Fetch image bytes from URL with retries.""" + last_error = None + for attempt in range(1, retries + 1): + try: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=15)) as response: + if response.status == HTTP_OK: + return await response.read() + last_error = f"HTTP {response.status}" + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + last_error = str(e) + + if attempt < retries: + await asyncio.sleep(1) + + return None + + +def validate_and_convert_to_jpeg(image_bytes: bytes) -> bytes | None: + """ + Validate image and convert to JPEG format for DALI compatibility. + + Args: + image_bytes: Raw image bytes + + Returns: + JPEG bytes if valid, None if image is invalid/corrupted + """ + try: + img = Image.open(io.BytesIO(image_bytes)) + img.verify() # Verify it's a valid image + # Re-open after verify (verify consumes the file) + img = Image.open(io.BytesIO(image_bytes)) + + # Robust RGB conversion for ALL image modes (L, LA, P, PA, RGBA, CMYK, etc.) + # This ensures CLIP gets 3-channel images + if img.mode != "RGB": + # For palette images, convert to RGBA first to preserve transparency info + if img.mode == "P": + img = img.convert("RGBA") + # For any mode with alpha, composite onto white background + if img.mode in ("RGBA", "LA", "PA"): + background = Image.new("RGB", img.size, (255, 255, 255)) + # Use alpha channel as mask + if img.mode == "LA": + img = img.convert("RGBA") + background.paste(img, mask=img.split()[-1]) + img = background + else: + # Simple conversion for grayscale (L), CMYK, etc. + img = img.convert("RGB") + + # Final safety check - ensure we have exactly 3 channels + if img.mode != "RGB": + return None + + # Skip images that are too small (CLIP needs at least 3x3 to avoid channel ambiguity) + if img.size[0] < 3 or img.size[1] < 3: + return None + + # Re-encode as JPEG to ensure DALI compatibility + jpeg_buffer = io.BytesIO() + img.save(jpeg_buffer, format="JPEG", quality=95) + return jpeg_buffer.getvalue() + except Exception: + return None + + +async def download_batch_images( + batch: pd.DataFrame, + url_col: str = "URL", + text_col: str = "TEXT", +) -> list[dict[str, Any]]: + """ + Download images for a batch of URLs asynchronously. 
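+
+    A single aiohttp session with a bounded connection pool (256 connections,
+    16 per host) is shared across the batch; entries whose download or JPEG
+    validation fails are returned with 'jpeg_bytes' set to None so callers can
+    track the failure rate.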
+ + Args: + batch: DataFrame with URL and TEXT columns + url_col: Name of URL column + text_col: Name of text/caption column + + Returns: + List of dicts with 'url', 'caption', 'jpeg_bytes' (None if failed) + """ + timeout = aiohttp.ClientTimeout(total=15) + connector = aiohttp.TCPConnector(limit=256, limit_per_host=16) + + results = [] + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + tasks = [] + metadata = [] + + for _, row in batch.iterrows(): + url = row[url_col] + caption = row[text_col] + metadata.append({"url": url, "caption": caption}) + tasks.append(fetch_image_bytes(session, url, retries=3)) + + raw_results = await asyncio.gather(*tasks, return_exceptions=True) + + for meta, raw_bytes in zip(metadata, raw_results): + jpeg_bytes = None + if isinstance(raw_bytes, bytes) and raw_bytes: + jpeg_bytes = validate_and_convert_to_jpeg(raw_bytes) + + results.append({ + "url": meta["url"], + "caption": meta["caption"], + "jpeg_bytes": jpeg_bytes, + }) + + return results + + +def write_tar_shard( + images: list[dict[str, Any]], + output_path: str, + shard_id: str, +) -> dict[str, int]: + """ + Write a tar shard with downloaded images. + + Args: + images: List of dicts with 'url', 'caption', 'jpeg_bytes' + output_path: Path to write tar file + shard_id: Unique identifier for this shard + + Returns: + Dict with 'success_count' and 'total_count' + """ + success_count = 0 + metadatas = [] + + with tarfile.open(output_path, "w") as tar: + for i, img_data in enumerate(images): + if img_data["jpeg_bytes"] is None: + continue + + key = f"{shard_id}_{i:06d}" + jpeg_bytes = img_data["jpeg_bytes"] + + # Add image bytes + jpg_info = tarfile.TarInfo(name=f"{key}.jpg") + jpg_info.size = len(jpeg_bytes) + tar.addfile(jpg_info, fileobj=io.BytesIO(jpeg_bytes)) + + # Add caption text + caption_bytes = str(img_data["caption"]).encode("utf-8") + txt_info = tarfile.TarInfo(name=f"{key}.txt") + txt_info.size = len(caption_bytes) + tar.addfile(txt_info, fileobj=io.BytesIO(caption_bytes)) + + # Add JSON metadata + meta = {"url": img_data["url"], "caption": img_data["caption"], "key": key} + json_bytes = json.dumps(meta).encode("utf-8") + json_info = tarfile.TarInfo(name=f"{key}.json") + json_info.size = len(json_bytes) + tar.addfile(json_info, fileobj=io.BytesIO(json_bytes)) + + metadatas.append(meta) + success_count += 1 + + # Write parquet sidecar + if metadatas: + parquet_path = output_path.replace(".tar", ".parquet") + pd.DataFrame(metadatas).to_parquet(parquet_path) + + return {"success_count": success_count, "total_count": len(images)} + + +# ============================================================================= +# Ray Data Approach (Distributed) +# ============================================================================= + +def process_batch_ray(batch: dict[str, Any], output_dir: str) -> dict[str, Any]: + """ + Ray Data map function to process a batch of URLs. + + This function is called by Ray Data's map_batches() and runs distributed + across all nodes in the cluster. 
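+
+    Each invocation writes one tar shard named '<node_id>_<uuid>.tar' (and,
+    when any image downloads successfully, a matching '.parquet' sidecar with
+    per-image metadata); the random suffix keeps shard names unique across
+    concurrent workers.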
+ + Args: + batch: Dict with 'URL' and 'TEXT' arrays (Ray Data batch format) + output_dir: Directory to write tar files + + Returns: + Dict with statistics about processing + """ + import ray + + # Convert Ray Data batch format to DataFrame + df = pd.DataFrame({ + "URL": batch["URL"], + "TEXT": batch["TEXT"], + }) + + # Generate unique shard ID using node ID + UUID to avoid collisions + node_id = ray.get_runtime_context().get_node_id()[:8] + shard_id = f"{node_id}_{uuid.uuid4().hex[:8]}" + tar_path = os.path.join(output_dir, f"{shard_id}.tar") + + # Download images asynchronously + images = asyncio.run(download_batch_images(df)) + + # Write tar shard + stats = write_tar_shard(images, tar_path, shard_id) + + # Return statistics as a single-row batch + return { + "shard_id": [shard_id], + "success_count": [stats["success_count"]], + "total_count": [stats["total_count"]], + } + + +def parquet_to_webdataset_ray( + parquet_path: str, + output_dir: str, + entries_per_tar: int = 1000, + max_entries: int | None = None, + concurrency: int | None = None, +) -> dict[str, int]: + """ + Convert parquet file with URLs to WebDataset tar files using Ray Data. + + This distributes the download work across all nodes in the Ray cluster, + providing much better scalability than single-node processing. + + Args: + parquet_path: Path to parquet file with URL and TEXT columns + output_dir: Directory to save tar files + entries_per_tar: Number of entries per tar shard + max_entries: Maximum entries to process (for testing) + concurrency: Number of concurrent download tasks (defaults to num CPUs) + + Returns: + Dict with 'total_success' and 'total_attempted' counts + """ + import ray.data + + os.makedirs(output_dir, exist_ok=True) + + print(f"Reading parquet from: {parquet_path}") + + # Read parquet with Ray Data - this distributes reading across the cluster + ds = ray.data.read_parquet(parquet_path) + + # Get schema and normalize column names + schema = ds.schema() + col_names = schema.names if hasattr(schema, 'names') else [f.name for f in schema] + col_map = {} + + # Handle case-insensitive column matching + for col in col_names: + if col.lower() == "url": + col_map[col] = "URL" + elif col.lower() in ("text", "caption"): + col_map[col] = "TEXT" + + if col_map: + # Rename columns to standard names + def rename_cols(batch): + result = {} + for old_name, new_name in col_map.items(): + if old_name in batch: + result[new_name] = batch[old_name] + # Keep any columns that weren't renamed + for col in batch: + if col not in col_map and col not in result: + result[col] = batch[col] + return result + + ds = ds.map_batches(rename_cols, batch_format="pandas") + + # Select only the columns we need + ds = ds.select_columns(["URL", "TEXT"]) + + # Apply max_entries limit + if max_entries is not None: + print(f"Limiting to {max_entries} entries for testing") + ds = ds.limit(max_entries) + + # Count total rows for progress reporting + total_rows = ds.count() + print(f"Total entries to process: {total_rows}") + + # Process batches in parallel across the cluster + # Each batch becomes one tar shard + from functools import partial + + process_fn = partial(process_batch_ray, output_dir=output_dir) + + # Determine concurrency based on cluster resources + if concurrency is None: + import ray + cluster_resources = ray.cluster_resources() + concurrency = max(1, int(cluster_resources.get("CPU", 4) // 2)) + + print(f"Processing with concurrency={concurrency}, entries_per_tar={entries_per_tar}") + + # map_batches distributes work across 
all nodes + results_ds = ds.map_batches( + process_fn, + batch_size=entries_per_tar, + batch_format="numpy", + concurrency=concurrency, + ) + + # Materialize results and aggregate statistics + results = results_ds.take_all() + + total_success = sum(r["success_count"] for r in results) + total_attempted = sum(r["total_count"] for r in results) + num_shards = len(results) + + # Report results + success_rate = (total_success / total_attempted * 100) if total_attempted > 0 else 0 + print(f"\n✓ Download complete: {total_success} images in {num_shards} shards ({success_rate:.1f}% success rate)") + print(f" Note: LAION datasets have high link rot - many URLs no longer work.") + + if total_success == 0: + print("\n⚠️ WARNING: No images were downloaded successfully!") + print(" This is likely due to LAION link rot. Try increasing MAX_ENTRIES.") + + return { + "total_success": total_success, + "total_attempted": total_attempted, + "num_shards": num_shards, + } + + +# ============================================================================= +# Single-Node Multiprocessing Approach (Legacy) +# ============================================================================= + +async def process_batch_single_node(batch: pd.DataFrame, output_dir: str, batch_num: int) -> int: + """Process a batch of URLs and return the number of successfully downloaded images.""" + tar_filename = os.path.join(output_dir, f"{batch_num:05d}.tar") + shard_id = f"{batch_num:05d}" + + # Download images + images = await download_batch_images(batch) + + # Write tar shard + stats = write_tar_shard(images, tar_filename, shard_id) + return stats["success_count"] + + +def process_parquet_chunk(chunk: tuple[int, pd.DataFrame], output_dir: str) -> int: + """Process a chunk and return the number of successfully downloaded images.""" + batch_num, batch = chunk + return asyncio.run(process_batch_single_node(batch, output_dir, batch_num)) + + +def download_webdataset( + parquet_path: str, + output_dir: str, + entries_per_tar: int = 10000, + num_processes: int = 2, + max_entries: int | None = None, +) -> None: + """ + Single-node approach: Stream parquet into WebDataset tar shards using multiprocessing. + + This is the legacy approach that runs on a single machine. For distributed + processing across a Ray cluster, use `parquet_to_webdataset_ray()` instead. + + Args: + parquet_path: Path to the parquet file containing URLs and text + output_dir: Directory to save the webdataset tar files + entries_per_tar: Number of entries per tar file + num_processes: Number of parallel download processes + max_entries: Maximum number of entries to process (for testing). None = no limit. 
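+
+    Example (paths are illustrative, not shipped with this example):
+
+        download_webdataset(
+            parquet_path="/mnt/cluster_storage/laion_subset.parquet",
+            output_dir="/mnt/cluster_storage/webdataset",
+            entries_per_tar=1000,
+            num_processes=8,
+            max_entries=10_000,
+        )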
+ """ + import math + from functools import partial + from multiprocessing import Pool + + import pyarrow.dataset as pa_ds + from tqdm import tqdm + + os.makedirs(output_dir, exist_ok=True) + + # Stream the Parquet in batches + dataset = pa_ds.dataset(parquet_path, format="parquet") + schema = dataset.schema + available = set(schema.names) + + def resolve_cols() -> list[str]: + resolved = [] + for col in ["URL", "TEXT"]: + if col in available: + resolved.append(col) + continue + lower = col.lower() + if lower in available: + resolved.append(lower) + continue + if col.upper() == "TEXT" and "caption" in available: + resolved.append("caption") + if not resolved: + raise ValueError(f"No URL/TEXT-like columns found in {parquet_path}; available: {sorted(available)}") + return resolved + + resolved_cols = resolve_cols() + total_rows = dataset.count_rows() + + # Apply max_entries limit for testing + if max_entries is not None and total_rows is not None: + total_rows = min(total_rows, max_entries) + print(f"Limiting to {max_entries} entries for testing") + + total_chunks = math.ceil(total_rows / entries_per_tar) if total_rows is not None else None + + def batch_iter(): + batch_num = 0 + rows_yielded = 0 + for batch in dataset.to_batches(columns=resolved_cols, batch_size=entries_per_tar): + df = batch.to_pandas() + + # Apply max_entries limit + if max_entries is not None: + remaining = max_entries - rows_yielded + if remaining <= 0: + break + if len(df) > remaining: + df = df.head(remaining) + + # normalize column names to URL/TEXT expected downstream + col_map: dict[str, str] = {} + if "url" in df.columns and "URL" not in df.columns: + col_map["url"] = "URL" + if "caption" in df.columns and "TEXT" not in df.columns: + col_map["caption"] = "TEXT" + df = df.rename(columns=col_map) + yield (batch_num, df) + rows_yielded += len(df) + batch_num += 1 + + total_success = 0 + total_attempted = 0 + with Pool(processes=num_processes) as pool: + func = partial(process_parquet_chunk, output_dir=output_dir) + for success_count in tqdm( + pool.imap_unordered(func, batch_iter()), + total=total_chunks, + desc="Processing chunks", + unit="chunk", + ): + total_success += success_count + total_attempted += entries_per_tar # approximate + + # Report download success rate + success_rate = (total_success / total_attempted * 100) if total_attempted > 0 else 0 + print(f"\n✓ Download complete: {total_success} images saved ({success_rate:.1f}% success rate)") + print(f" Note: LAION datasets have high link rot - many URLs no longer work.") + + if total_success == 0: + print("\n⚠️ WARNING: No images were downloaded successfully!") + print(" This is likely due to LAION link rot. Try increasing MAX_ENTRIES.") + + # Best-effort cleanup of legacy tmp dir from previous versions + tmp_dir = os.path.join(output_dir, "tmp") + try: + if os.path.isdir(tmp_dir) and not os.listdir(tmp_dir): + os.rmdir(tmp_dir) + except OSError as e: + logger.debug(f"Failed to remove tmp dir {tmp_dir}: {e}") diff --git a/nemo_curator_semantic_dedup/image_dedup_example.py b/nemo_curator_semantic_dedup/image_dedup_example.py new file mode 100644 index 0000000..fdcd198 --- /dev/null +++ b/nemo_curator_semantic_dedup/image_dedup_example.py @@ -0,0 +1,423 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time + +import ray +from helper import download_webdataset, parquet_to_webdataset_ray + +from nemo_curator.backends.experimental.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.backends.experimental.ray_data import RayDataExecutor +from nemo_curator.core.client import RayClient +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.deduplication.semantic import SemanticDeduplicationWorkflow +from nemo_curator.stages.file_partitioning import FilePartitioningStage +from nemo_curator.stages.image.deduplication.removal import ImageDuplicatesRemovalStage +from nemo_curator.stages.image.embedders.clip_embedder import ImageEmbeddingStage +from nemo_curator.stages.image.io.convert import ConvertImageBatchToDocumentBatchStage +from nemo_curator.stages.image.io.image_reader import ImageReaderStage +from nemo_curator.stages.image.io.image_writer import ImageWriterStage +from nemo_curator.stages.text.io.writer.parquet import ParquetWriter + + +def create_image_embedding_pipeline(args: argparse.Namespace) -> Pipeline: + """Create image curation pipeline with file partitioning, image reading, embedding, deduplication.""" + + # Define pipeline + pipeline = Pipeline(name="image_curation", description="Curate images with embeddings and quality scoring") + + # Stage 0: Partition tar files for parallel processing + pipeline.add_stage(FilePartitioningStage( + file_paths=args.input_wds_dataset_dir, + files_per_partition=args.tar_files_per_partition, + file_extensions=[".tar"], + )) + + # Stage 1: Read images from webdataset tar files (now runs in parallel) + pipeline.add_stage(ImageReaderStage( + batch_size=args.batch_size, + verbose=args.verbose, + )) + + # Stage 2: Generate CLIP embeddings for images + pipeline.add_stage(ImageEmbeddingStage( + model_dir=args.model_dir, + num_gpus_per_worker=args.embedding_gpus_per_worker, + model_inference_batch_size=args.embedding_batch_size, + verbose=args.verbose, + )) + + # Stage 3: Convert embeddings to document batch + pipeline.add_stage(ConvertImageBatchToDocumentBatchStage(fields=["image_id", "embedding"])) + + # Stage 4: Save embeddings to parquet file + pipeline.add_stage(ParquetWriter( + path=args.embeddings_dir, + )) + + return pipeline + +def create_embedding_deduplication_workflow(args: argparse.Namespace) -> Pipeline: + """Create image deduplication pipeline with embedding deduplication.""" + return SemanticDeduplicationWorkflow( + input_path=args.embeddings_dir, + output_path=args.removal_parquets_dir, + id_field="image_id", + embedding_field="embedding", + n_clusters=100, + eps=0.01, + read_kwargs={"storage_options": {}}, + write_kwargs={"storage_options": {}}, + verbose=args.verbose, + ) + +def create_image_deduplication_pipeline(args: argparse.Namespace) -> Pipeline: + """Create image deduplication pipeline with image deduplication.""" + # Define pipeline + pipeline = Pipeline(name="image_deduplication", description="Deduplicate images with image deduplication") + + # Stage 0: Partition tar files for parallel processing + pipeline.add_stage(FilePartitioningStage( + 
file_paths=args.input_wds_dataset_dir, + files_per_partition=args.tar_files_per_partition, + file_extensions=[".tar"], + )) + + # Stage 1: Read images from webdataset tar files (now runs in parallel) + pipeline.add_stage(ImageReaderStage( + batch_size=args.batch_size, + verbose=args.verbose, + )) + + # Stage 2: Read removal list from parquet file and filter images + pipeline.add_stage(ImageDuplicatesRemovalStage( + removal_parquets_dir=args.removal_parquets_dir + "/duplicates", + duplicate_id_field="id", + verbose=args.verbose, + )) + + # Stage 3: Write filtered images to disk + pipeline.add_stage(ImageWriterStage( + output_dir=args.output_dataset_dir, + remove_image_data=True, + verbose=args.verbose, + )) + + return pipeline + + +def main(args: argparse.Namespace) -> None: + """Main execution function for image curation pipeline.""" + + ray_client = RayClient() + ray_client.start() + + print("Starting image curation pipeline...") + print(f"Input parquet file: {args.input_parquet}") + print(f"Input webdataset directory: {args.input_wds_dataset_dir}") + print(f"Output webdataset directory: {args.output_dataset_dir}") + print(f"Model directory: {args.model_dir}") + print(f"Tar files per partition: {args.tar_files_per_partition}") + print(f"Task batch size: {args.batch_size}") + print(f"Use Ray Data for parquet->tar: {args.use_ray_data}") + if args.max_entries: + print(f"Max entries (testing): {args.max_entries}") + print("\n" + "=" * 50 + "\n") + + # Step 1: Download and prepare webdataset from parquet file + if not args.skip_download: + print("Step 1: Converting parquet to WebDataset tar files...") + print(f" Approach: {'Ray Data (distributed)' if args.use_ray_data else 'Single-node multiprocessing'}") + download_start = time.time() + + # Create output directory if it doesn't exist + os.makedirs(args.input_wds_dataset_dir, exist_ok=True) + + if args.use_ray_data: + # Use Ray Data for distributed processing across the cluster + stats = parquet_to_webdataset_ray( + parquet_path=args.input_parquet, + output_dir=args.input_wds_dataset_dir, + entries_per_tar=args.entries_per_tar, + max_entries=args.max_entries, + concurrency=args.download_concurrency, + ) + print(f" Created {stats['num_shards']} tar shards with {stats['total_success']} images") + else: + # Legacy single-node approach + download_webdataset( + parquet_path=args.input_parquet, + output_dir=args.input_wds_dataset_dir, + num_processes=args.download_processes, + entries_per_tar=args.entries_per_tar, + max_entries=args.max_entries, + ) + + download_time = time.time() - download_start + print(f"✓ Dataset conversion completed in {download_time:.2f} seconds") + print(f"✓ Webdataset saved to: {args.input_wds_dataset_dir}") + print("\n" + "=" * 50 + "\n") + else: + print("Step 1: Skipping download (using existing dataset)") + print(f"Using existing dataset at: {args.input_wds_dataset_dir}") + print("\n" + "=" * 50 + "\n") + + # Step 2: Create and run curation pipelines + # Use experimental executors with ignore_head_node=True to avoid scheduling on head node + # This allows using a CPU-only head node while GPU tasks run on workers + streaming_executor = RayDataExecutor(ignore_head_node=True) + batch_executor = RayActorPoolExecutor(ignore_head_node=True) + + # Step 2.1: Create image embedding pipeline + print("Step 2.1: Running image embedding pipeline...") + start_time = time.time() + pipeline = create_image_embedding_pipeline(args) + print(pipeline.describe()) + print("\n" + "=" * 50 + "\n") + pipeline.run(executor=streaming_executor) + + 
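+    # Report the embedding stage duration here, because `start_time` is reset
+    # for each subsequent step below.
+    embedding_time = time.time() - start_time
+    print(f"✓ Embedding pipeline completed in {embedding_time:.2f} seconds")
+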
# Step 2.2: Create image deduplication pipeline (semantic dedup workflow) + print("Step 2.2: Running image deduplication pipeline...") + start_time = time.time() + workflow = create_embedding_deduplication_workflow(args) + print("\n" + "=" * 50 + "\n") + workflow.run( + kmeans_executor=RayActorPoolExecutor(ignore_head_node=True), + pairwise_executor=RayActorPoolExecutor(ignore_head_node=True), + ) + + # Step 2.3: Create image deduplication pipeline + print("Step 2.3: Running image deduplication pipeline...") + start_time = time.time() + pipeline = create_image_deduplication_pipeline(args) + print(pipeline.describe()) + print("\n" + "=" * 50 + "\n") + pipeline.run(executor=streaming_executor) + + end_time = time.time() + + # Calculate and print execution time + execution_time = end_time - start_time + hours, remainder = divmod(execution_time, 3600) + minutes, seconds = divmod(remainder, 60) + + print("\nImage curation pipeline completed!") + print(f"Total execution time: {int(hours):02d}:{int(minutes):02d}:{seconds:.2f}") + print(f"Total execution time: {execution_time:.2f} seconds") + print(f"\nProcessed dataset available at: {args.output_dataset_dir}") + + ray_client.stop() + + +def get_env_or_arg(env_var: str, arg_value, default=None): + """Get value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return env_value + if arg_value is not None: + return arg_value + return default + + +def get_env_bool(env_var: str, arg_value: bool, default: bool = False) -> bool: + """Get boolean value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return env_value.lower() in ("true", "1", "yes") + return arg_value if arg_value is not None else default + + +def get_env_int(env_var: str, arg_value: int, default: int) -> int: + """Get integer value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return int(env_value) + return arg_value if arg_value is not None else default + + +def get_env_float(env_var: str, arg_value: float, default: float) -> float: + """Get float value from environment variable or command-line argument.""" + env_value = os.environ.get(env_var) + if env_value is not None: + return float(env_value) + return arg_value if arg_value is not None else default + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Image curation pipeline with embedding generation and quality scoring. " + "Arguments can also be set via environment variables (see job.yaml)." 
+ ) + + # Dataset arguments + parser.add_argument( + "--input-parquet", + type=str, + required=False, + default=None, + help="Path to input parquet file containing image URLs and metadata (env: INPUT_PARQUET)" + ) + parser.add_argument( + "--input-wds-dataset-dir", + type=str, + required=False, + default=None, + help="Directory to save the downloaded webdataset (env: INPUT_WDS_DIR)" + ) + parser.add_argument( + "--output-dataset-dir", + type=str, + required=False, + default=None, + help="Directory to save the resulting webdataset (env: OUTPUT_DIR)" + ) + parser.add_argument( + "--embeddings-dir", + type=str, + required=False, + default=None, + help="Directory to save the embeddings (env: EMBEDDINGS_DIR)" + ) + parser.add_argument( + "--removal-parquets-dir", + type=str, + required=False, + default=None, + help="Directory to save the remove parquets (env: REMOVAL_DIR)" + ) + parser.add_argument( + "--download-processes", + type=int, + default=None, + help="Number of parallel processes for downloading images (env: DOWNLOAD_PROCESSES)" + ) + parser.add_argument( + "--entries-per-tar", + type=int, + default=None, + help="Number of entries per tar shard during download (env: ENTRIES_PER_TAR)" + ) + parser.add_argument( + "--skip-download", + action="store_true", + default=None, + help="Skip dataset download and use existing webdataset (env: SKIP_DOWNLOAD)" + ) + parser.add_argument( + "--max-entries", + type=int, + default=None, + help="Maximum entries to download for testing (env: MAX_ENTRIES). None = no limit." + ) + parser.add_argument( + "--use-ray-data", + action="store_true", + default=None, + help="Use Ray Data for distributed parquet->tar conversion (env: USE_RAY_DATA, default: true)" + ) + parser.add_argument( + "--no-ray-data", + action="store_true", + default=False, + help="Disable Ray Data, use single-node multiprocessing instead" + ) + parser.add_argument( + "--download-concurrency", + type=int, + default=None, + help="Number of concurrent download tasks for Ray Data (env: DOWNLOAD_CONCURRENCY)" + ) + + # Image reader arguments + parser.add_argument( + "--tar-files-per-partition", + type=int, + default=None, + help="Number of tar files to process per partition (env: TAR_FILES_PER_PARTITION)" + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Number of images per ImageBatch for the reader stage (env: BATCH_SIZE)" + ) + + # General arguments + parser.add_argument( + "--model-dir", + type=str, + required=False, + default=None, + help="Path to model directory containing all model weights (env: MODEL_DIR)" + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Enable verbose logging for all stages" + ) + + # Embedding stage arguments + parser.add_argument( + "--embedding-batch-size", + type=int, + default=None, + help="Batch size for embedding generation (env: EMBEDDING_BATCH_SIZE)" + ) + parser.add_argument( + "--embedding-gpus-per-worker", + type=float, + default=None, + help="GPU allocation per worker for embedding generation" + ) + + cli_args = parser.parse_args() + + # Resolve arguments from environment variables or command-line args + # Determine if Ray Data should be used (default: True unless --no-ray-data is set) + use_ray_data_env = os.environ.get("USE_RAY_DATA", "true").lower() in ("true", "1", "yes") + use_ray_data = use_ray_data_env if not cli_args.no_ray_data else False + if cli_args.use_ray_data: + use_ray_data = True + + args = argparse.Namespace( + input_parquet=get_env_or_arg("INPUT_PARQUET", 
cli_args.input_parquet), + input_wds_dataset_dir=get_env_or_arg("INPUT_WDS_DIR", cli_args.input_wds_dataset_dir), + output_dataset_dir=get_env_or_arg("OUTPUT_DIR", cli_args.output_dataset_dir), + embeddings_dir=get_env_or_arg("EMBEDDINGS_DIR", cli_args.embeddings_dir), + removal_parquets_dir=get_env_or_arg("REMOVAL_DIR", cli_args.removal_parquets_dir), + model_dir=get_env_or_arg("MODEL_DIR", cli_args.model_dir, "/home/ray/model_weights"), + download_processes=get_env_int("DOWNLOAD_PROCESSES", cli_args.download_processes, 8), + entries_per_tar=get_env_int("ENTRIES_PER_TAR", cli_args.entries_per_tar, 1000), + max_entries=int(get_env_or_arg("MAX_ENTRIES", cli_args.max_entries)) if get_env_or_arg("MAX_ENTRIES", cli_args.max_entries) else None, + skip_download=get_env_bool("SKIP_DOWNLOAD", cli_args.skip_download, False), + use_ray_data=use_ray_data, + download_concurrency=get_env_int("DOWNLOAD_CONCURRENCY", cli_args.download_concurrency, None) if get_env_or_arg("DOWNLOAD_CONCURRENCY", cli_args.download_concurrency) else None, + tar_files_per_partition=get_env_int("TAR_FILES_PER_PARTITION", cli_args.tar_files_per_partition, 1), + batch_size=get_env_int("BATCH_SIZE", cli_args.batch_size, 100), + embedding_batch_size=get_env_int("EMBEDDING_BATCH_SIZE", cli_args.embedding_batch_size, 32), + embedding_gpus_per_worker=get_env_float("EMBEDDING_GPUS_PER_WORKER", cli_args.embedding_gpus_per_worker, 0.25), + verbose=cli_args.verbose, + ) + + # Validate required arguments + required_args = ["input_wds_dataset_dir", "output_dataset_dir", "embeddings_dir", "removal_parquets_dir"] + missing = [arg for arg in required_args if getattr(args, arg) is None] + if missing: + parser.error(f"Missing required arguments: {', '.join(missing)}. " + "Set them via command-line or environment variables.") + + main(args) \ No newline at end of file diff --git a/nemo_curator_semantic_dedup/job.yaml b/nemo_curator_semantic_dedup/job.yaml new file mode 100644 index 0000000..af08da5 --- /dev/null +++ b/nemo_curator_semantic_dedup/job.yaml @@ -0,0 +1,84 @@ +# NeMo Curator Image Semantic Deduplication Job +# View the docs: https://docs.anyscale.com/reference/job-api#jobconfig +# +# This job runs a two-phase pipeline: +# Phase 1: Convert parquet (URLs) → WebDataset tar files (using Ray Data, distributed) +# Phase 2: Run NeMo Curator image deduplication (CLIP embeddings → semantic dedup) +# +# The parquet → tar conversion uses Ray Data to distribute image downloads +# across all nodes in the cluster, providing much better scalability than +# single-node processing. 
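+#
+# Phase 2 itself consists of three runs inside image_dedup_example.py:
+#   2.1 CLIP embedding pipeline          -> EMBEDDINGS_DIR
+#   2.2 semantic deduplication workflow  -> REMOVAL_DIR (duplicate IDs)
+#   2.3 duplicate removal + image write  -> OUTPUT_DIR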
+
+name: nemo-curator-image-dedup
+
+# Build custom image with NeMo Curator CUDA dependencies
+containerfile: ./Dockerfile
+
+# Compute configuration with A10G GPU workers for CUDA-accelerated image processing
+# CPU-only head node + GPU worker nodes (using ignore_head_node=True in executors)
+compute_config:
+  head_node:
+    instance_type: m6i.2xlarge  # CPU-only, 8 vCPUs, 32GB RAM
+    # No tasks scheduled here - using RayDataExecutor/RayActorPoolExecutor with ignore_head_node=True
+    resources:
+      CPU: 0  # Prevent any task scheduling on head node
+  worker_nodes:
+    - instance_type: g5.12xlarge  # 4x A10G GPUs per worker, 48 vCPUs, 192GB RAM
+      min_nodes: 2
+      max_nodes: 2
+
+# Working directory - use the repo root (absolute) so Curator/ is included
+working_dir: /home/ray/default
+
+# Environment variables for job configuration
+# Override these when submitting to use your own data paths
+env_vars:
+  # Input parquet file with image URLs (TEXT and URL columns)
+  # LAION dataset (relative to working_dir)
+  INPUT_PARQUET: "examples/nemo_curator_semantic_dedup/laion_meta/laion_subset_10m.parquet"
+  MAX_ENTRIES: "10000"  # Limit for testing
+
+  # Directory for WebDataset tar files (created from parquet)
+  # Use /mnt/cluster_storage for persistence, or /home/ray/data for ephemeral
+  INPUT_WDS_DIR: "/mnt/cluster_storage/nemo_curator/webdataset"
+
+  # Output directory for deduplicated images
+  OUTPUT_DIR: "/mnt/cluster_storage/nemo_curator/results"
+
+  # Directory to store CLIP embeddings
+  EMBEDDINGS_DIR: "/mnt/cluster_storage/nemo_curator/embeddings"
+
+  # Directory for duplicate removal parquets
+  REMOVAL_DIR: "/mnt/cluster_storage/nemo_curator/removal_ids"
+
+  # Model weights directory (MODEL_DIR is also set in the Docker image)
+  MODEL_DIR: "/home/ray/model_weights"
+
+  # Processing settings (reduced to prevent OOM)
+  BATCH_SIZE: "4"
+  EMBEDDING_BATCH_SIZE: "8"
+  TAR_FILES_PER_PARTITION: "1"
+  ENTRIES_PER_TAR: "500"
+
+  # Ray Data settings for parquet -> tar conversion
+  # Uses distributed processing across all nodes in the cluster
+  USE_RAY_DATA: "true"  # Set to "false" for single-node multiprocessing
+  # DOWNLOAD_CONCURRENCY: ""  # Auto-detected from cluster resources if not set
+
+  SKIP_DOWNLOAD: "false"  # Set to "true" to skip parquet->tar and use existing tars
+
+  # Don't hide GPUs from tasks that request num_gpus=0 (needed for DALI)
+  RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO: "0"
+
+  # Disable Python output buffering for real-time logs
+  PYTHONUNBUFFERED: "1"
+
+# The entrypoint script (-u for unbuffered output)
+entrypoint: python -u examples/nemo_curator_semantic_dedup/image_dedup_example.py
+
+# Don't retry on failure - easier to debug
+max_retries: 0
+
+# Kill after 4 hours to control costs (adjust based on dataset size)
+timeout_s: 14400
+
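+# Example: point the job at your own data when submitting (paths are illustrative):
+#   anyscale job submit -f job.yaml \
+#     --env INPUT_PARQUET=/mnt/cluster_storage/your_data.parquet \
+#     --env OUTPUT_DIR=/mnt/cluster_storage/your_results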