From 96dfb68133df4bddd0b6fca6931463b3c60d5b94 Mon Sep 17 00:00:00 2001 From: "Amy X. Lu" Date: Tue, 8 Jul 2025 11:44:15 +0100 Subject: [PATCH 01/10] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 51bb5fd8..cdffd5e4 100644 --- a/.gitignore +++ b/.gitignore @@ -29,9 +29,11 @@ scripts/combined_db* *_play.py src/lobster/hydra_config/experiment/* src/lobster/mcp/claude_desktop_config.json +*.ipynb_checkpoints notebooks/nathan/* notebooks/karina/* +notebooks/amyxlu/* models/* From c7cf131a321ad649dfd5cc621dc2f1754de8d530 Mon Sep 17 00:00:00 2001 From: "Amy X. Lu" Date: Tue, 8 Jul 2025 11:45:04 +0100 Subject: [PATCH 02/10] Token selection scripts --- token_selection/scripts/inference.py | 264 ++++++++++++++++++ token_selection/scripts/save_token_losses.py | 224 +++++++++++++++ .../scripts/save_token_losses.slrm | 40 +++ 3 files changed, 528 insertions(+) create mode 100644 token_selection/scripts/inference.py create mode 100644 token_selection/scripts/save_token_losses.py create mode 100644 token_selection/scripts/save_token_losses.slrm diff --git a/token_selection/scripts/inference.py b/token_selection/scripts/inference.py new file mode 100644 index 00000000..6e08f398 --- /dev/null +++ b/token_selection/scripts/inference.py @@ -0,0 +1,264 @@ +import os +import glob +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler + +class ShardedParquetDataset(Dataset): + def __init__(self, + parquet_dir, + percentile_threshold=90, + loss_threshold=None, + stats_file=None, + rank=None, + world_size=None): + """ + Distributed dataset for sharded parquet files. 
+ + Args: + parquet_dir: Directory containing parquet shards + percentile_threshold: Only include tokens with loss below this percentile + loss_threshold: Optional explicit loss threshold (if pre-computed) + stats_file: Path to pre-computed statistics file + rank: Process rank in distributed training + world_size: Total number of processes + """ + self.parquet_dir = parquet_dir + self.percentile_threshold = percentile_threshold + + # Get list of all shard files + self.shard_files = sorted(glob.glob(f"{parquet_dir}/partition_id=*/part-*.parquet")) + + # If running distributed, only use shards for this rank + if rank is not None and world_size is not None: + # Distribute shards across workers + self.shard_files = [ + f for i, f in enumerate(self.shard_files) + if i % world_size == rank + ] + + # Set loss threshold either from argument or by loading stats + if loss_threshold is not None: + self.loss_threshold = loss_threshold + elif stats_file and os.path.exists(stats_file): + # Load pre-computed statistics + import json + with open(stats_file, 'r') as f: + stats = json.load(f) + self.loss_threshold = stats['percentiles'][str(percentile_threshold)] + else: + # Calculate threshold (ideally, this is pre-computed) + self.loss_threshold = self._calculate_percentile() + + # Load sequence metadata from all assigned shards + self.sequence_data = self._load_sequence_metadata() + + def _calculate_percentile(self): + """Calculate percentile threshold from samples.""" + # Only calculate on rank 0 and broadcast if distributed + if dist.is_initialized() and dist.get_rank() != 0: + # Non-root processes wait for result + threshold = torch.zeros(1, dtype=torch.float32).cuda() + dist.broadcast(threshold, 0) + return threshold.item() + + # Root process (or non-distributed) calculates + print(f"Calculating {self.percentile_threshold}th percentile threshold...") + samples = [] + + # Sample from each shard + for shard in self.shard_files[:10]: # Limit to 10 shards for efficiency + df = 
pd.read_parquet(shard, columns=['loss']) + # Take a sample proportional to size + sample_size = min(10000, len(df)) + if sample_size > 0: + samples.append(df.sample(sample_size)['loss'].values) + + # Calculate threshold from samples + if samples: + all_samples = np.concatenate(samples) + threshold = float(np.percentile(all_samples, self.percentile_threshold)) + else: + threshold = float('inf') # No samples available + + # Broadcast result if distributed + if dist.is_initialized(): + threshold_tensor = torch.tensor([threshold], dtype=torch.float32).cuda() + dist.broadcast(threshold_tensor, 0) + threshold = threshold_tensor.item() + + print(f"Using loss threshold: {threshold}") + return threshold + + def _load_sequence_metadata(self): + """Load sequence metadata from assigned shards.""" + sequences = [] + + for shard_file in self.shard_files: + # Read just sequence metadata for efficiency + try: + # Group by sequence_id and get sizes + df = pd.read_parquet( + shard_file, + columns=['sequence_id', 'position'] + ) + seq_info = df.groupby('sequence_id').agg({'position': 'max'}) + + for seq_id, max_pos in seq_info.itertuples(): + sequences.append({ + 'sequence_id': seq_id, + 'length': max_pos + 1, # Convert to length + 'shard_file': shard_file + }) + except Exception as e: + print(f"Error loading metadata from {shard_file}: {e}") + + return sequences + + def __len__(self): + return len(self.sequence_data) + + def __getitem__(self, idx): + """Get a filtered sequence by index.""" + seq_info = self.sequence_data[idx] + seq_id = seq_info['sequence_id'] + shard_file = seq_info['shard_file'] + + # Read this sequence with filtering + try: + # Use PyArrow filter pushdown for efficiency + df = pd.read_parquet( + shard_file, + filters=[ + ('sequence_id', '=', seq_id), + ('loss', '<=', self.loss_threshold) + ] + ) + + # Sort by position to maintain sequence order + if not df.empty: + df = df.sort_values('position') + + return { + 'sequence_id': seq_id, + 'tokens': 
df['token'].values, + 'positions': df['position'].values, + 'losses': df['loss'].values + } + else: + # No tokens passed the filter + return { + 'sequence_id': seq_id, + 'tokens': np.array([], dtype=np.int64), + 'positions': np.array([], dtype=np.int64), + 'losses': np.array([], dtype=np.float32) + } + + except Exception as e: + print(f"Error loading sequence {seq_id}: {e}") + # Return empty sequence on error + return { + 'sequence_id': seq_id, + 'tokens': np.array([], dtype=np.int64), + 'positions': np.array([], dtype=np.int64), + 'losses': np.array([], dtype=np.float32) + } + + +def collate_variable_length_sequences(batch): + """Custom collate function for variable-length sequences.""" + # Filter out empty sequences + non_empty = [b for b in batch if len(b['tokens']) > 0] + + if not non_empty: + # All sequences were empty after filtering + return { + 'sequence_ids': [], + 'tokens': torch.zeros(0, dtype=torch.int64), + 'positions': torch.zeros(0, dtype=torch.int64), + 'losses': torch.zeros(0, dtype=torch.float32), + 'batch_indices': torch.zeros(0, dtype=torch.int64) + } + + # Gather data + sequence_ids = [b['sequence_id'] for b in non_empty] + tokens_list = [torch.tensor(b['tokens'], dtype=torch.int64) for b in non_empty] + positions_list = [torch.tensor(b['positions'], dtype=torch.int64) for b in non_empty] + losses_list = [torch.tensor(b['losses'], dtype=torch.float32) for b in non_empty] + + # Create batch indices for reconstructing sequences later + batch_sizes = [len(t) for t in tokens_list] + batch_indices = torch.cat([ + torch.full((size,), i, dtype=torch.int64) + for i, size in enumerate(batch_sizes) + ]) + + # Concatenate all tokens + tokens = torch.cat(tokens_list) + positions = torch.cat(positions_list) + losses = torch.cat(losses_list) + + return { + 'sequence_ids': sequence_ids, + 'tokens': tokens, + 'positions': positions, + 'losses': losses, + 'batch_indices': batch_indices + } + + +def setup_distributed(): + """Initialize distributed training 
environment.""" + # Initialize process group + dist.init_process_group( + backend='nccl', # Use 'gloo' for CPU-only + init_method='env://' + ) + + # Get global rank and world size + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Set device for this process + torch.cuda.set_device(rank % torch.cuda.device_count()) + + return rank, world_size + +def create_distributed_dataloader(parquet_dir, percentile_threshold=90, + batch_size=32, num_workers=4): + """Create a distributed dataloader for sharded parquet files.""" + # Setup distributed environment + rank, world_size = setup_distributed() + + # Create dataset with this rank's shards + dataset = ShardedParquetDataset( + parquet_dir=parquet_dir, + percentile_threshold=percentile_threshold, + stats_file=f"{parquet_dir}/stats.json", + rank=rank, + world_size=world_size + ) + + # Create distributed sampler to handle partitioning + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=False + ) + + # Create dataloader with custom collate function + dataloader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_variable_length_sequences, + pin_memory=True + ) + + return dataloader, rank, world_size \ No newline at end of file diff --git a/token_selection/scripts/save_token_losses.py b/token_selection/scripts/save_token_losses.py new file mode 100644 index 00000000..c25c4eac --- /dev/null +++ b/token_selection/scripts/save_token_losses.py @@ -0,0 +1,224 @@ +from typing import Dict, List, Any +import os +import argparse +from pathlib import Path + + +import numpy as np +import pandas as pd +from transformers import AutoModelForCausalLM, AutoTokenizer + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel +from torch.nn.functional import cross_entropy + +from lobster.datasets import FASTADataset + + 
+torch.set_float32_matmul_precision('high') + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--fasta_file", + type=str, + default="/data/lux70/data/uniref90/partial.fasta", + ) + parser.add_argument( + "--offset_array_path", + type=str, + default="/data/lux70/data/uniref90/partial.fasta.offsets.npy", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/data/lux70/data/uniref90/token_losses", + ) + parser.add_argument( + "--model_name", + type=str, + default="lightonai/RITA_l", + ) + parser.add_argument( + "--batch_size", + type=int, + default=512, + ) + parser.add_argument( + "--max_length", + type=int, + default=512, + ) + parser.add_argument( + "--max_num_per_shard", + type=int, + default=100_000, + ) + parser.add_argument( + "--cur_num_in_shard", + type=int, + default=0, + ) + parser.add_argument( + "--cur_shard_num", + type=int, + default=0, + ) + return parser.parse_args() + + +def setup(rank, world_size): + # Initialize process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def load_model(model_name: str = "lightonai/RITA_xl", max_length: int = 512) -> torch.nn.Module: + """Load the model and tokenizer.""" + model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = "" + tokenizer.pad_token_id = tokenizer.vocab[''] + tokenizer.max_length = max_length + model.eval() + return model, tokenizer + + +def get_model_device(model: torch.nn.Module) -> torch.device: + return next(model.parameters()).device + + +def compute_loss(batch, model, tokenizer, max_length, device=None) -> List[Dict[str, Any]]: + sequences, headers = batch + if device is not None: + device = get_model_device(model) + + sequences = [s[:max_length] for s in sequences] + inputs = tokenizer(sequences, return_tensors="pt", padding=True, truncation=False) + input_ids = 
inputs['input_ids'].to(device) + attn_mask = inputs['attention_mask'].to(device) + N, L = input_ids.shape[0], input_ids.shape[1] - 1 # remove EOS token + + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attn_mask + ) + + targets = input_ids[:, 1:].reshape(-1) + logits = output['logits'] + logits = logits[:, :-1, :].reshape(-1, logits.shape[-1]) + per_token_loss = cross_entropy(logits, targets, reduction="none") + per_token_loss = per_token_loss.reshape(-1, L).half() # store as float16. + + processed = [ + { + "sequence": sequences[i], + "header": headers[i], + "per_token_loss": per_token_loss[i, :min(len(sequences[i]), max_length)].cpu().tolist(), + } + for i in range(len(sequences)) + ] + return processed + + +def main(rank, args, world_size): + if world_size > 1: + setup(rank, world_size) + + output_dir = Path(args.output_dir) / args.model_name.replace("/", "_") + if not output_dir.exists(): + output_dir.mkdir(parents=True) + + # the fasta loader relies on offsets to do file.seek operations + # we can paralellize this by splitting up the offset array into subsections for each GPU + offset_array = np.load(args.offset_array_path) + print("Original offset array shape:", offset_array.shape) + assert len(offset_array.shape) == 2 + assert offset_array.shape[0] == 2 + + # Partition data for this GPU + per_gpu_size = offset_array.shape[1] // world_size + start_idx = rank * per_gpu_size + end_idx = start_idx + per_gpu_size if rank < world_size - 1 else offset_array.shape[1] + + local_offsets = offset_array[:, start_idx:end_idx] + print(f"Rank {rank} processing offsets from {start_idx} to {end_idx}") + print(f"Rank {rank} processing {local_offsets.shape[1]} sequences from {args.fasta_file}") + + # Create dataset and dataloader for this GPU + dataset = FASTADataset( + root=args.fasta_file, + offsets_arr=local_offsets, + use_text_descriptions=True + ) + + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, 
num_replicas=world_size, rank=rank + ) + + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False + ) + + # Create model + model, tokenizer = load_model(args.model_name, args.max_length) + device = torch.device("cuda", rank) + model.to(device) + + # wrap in DDP and compile + if world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_model = DistributedDataParallel(model, device_ids=[rank], output_device=rank) + else: + ddp_model = model + + ddp_model = torch.compile(ddp_model) + + # Inference loop + results_tmp_list = [] + cur_shard_num = 0 + cur_num_in_shard = 0 + + for batch in dataloader: + with torch.no_grad(): + outputs = compute_loss(batch, ddp_model, tokenizer, args.max_length, device) + results_tmp_list.extend(outputs) + cur_num_in_shard += len(outputs) + + if cur_num_in_shard >= args.max_num_per_shard: + print(f"Saving shard {cur_shard_num} to {output_file}...") + output_file = output_dir / f"rank_{rank:02}_shard_{cur_shard_num:06}.parquet" + pd.DataFrame(results_tmp_list).to_parquet(output_file, engine='pyarrow', index=False) + + cur_shard_num += 1 + cur_num_in_shard = 0 + results_tmp_list = [] + + else: + print(f"Rank {rank} processed {cur_num_in_shard} sequences in shard {cur_shard_num}") + + +if __name__ == "__main__": + args = get_args() + world_size = torch.cuda.device_count() + + if world_size == 1: + print("Only one GPU available. 
Running without DDP.") + main(0, args, world_size) + exit() + + else: + print(f"Using {world_size} GPUs for DDP.") + rank = int(os.environ["LOCAL_RANK"]) + mp.spawn(main, (args, world_size), world_size, join=True) + + if world_size > 1: + cleanup() + diff --git a/token_selection/scripts/save_token_losses.slrm b/token_selection/scripts/save_token_losses.slrm new file mode 100644 index 00000000..09b71485 --- /dev/null +++ b/token_selection/scripts/save_token_losses.slrm @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +#SBATCH --job-name token_loss +#SBATCH --nodes 1 +#SBATCH --gpus-per-node 1 +#SBATCH --partition gpu2 +#SBATCH --cpus-per-gpu 4 +#SBATCH --mem 150G +#SBATCH --time=1-00:00:00 + +source !/.bashrc +eval "$(mamba shell hook --shell bash)" + +echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}" +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURMD_NODENAME = ${SLURMD_NODENAME}" +echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}" + +# make sure that this is already set! +cd $LOBSTER_PROJECT_DIR + +# use uv, which should already be set up +source .venv/bin/activate + +echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}" +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURMD_NODENAME = ${SLURMD_NODENAME}" +echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}" + +nvidia-smi +mamba activate plaid +mamba env list +echo $CONDA_PREFIX +which python + +# see save_token_losses.py for the default parser arguments +srun torchrun token_selection/scripts/save_token_losses.py \ + --fasta_file /data/bucket/freyn6/data/uniref50.fasta \ + --output_dir /data2/lux70/data/uniref50/per_token_losses \ + --max_num_per_shard 10000 \ No newline at end of file From abd7e6a532eb3d06f9eeaab1ac18ba721d49f033 Mon Sep 17 00:00:00 2001 From: "Amy X. 
Lu" Date: Wed, 9 Jul 2025 23:10:25 +0100 Subject: [PATCH 03/10] Load ProtGPT2 as option for model analysis --- src/lobster/model/_clm.py | 72 +++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/src/lobster/model/_clm.py b/src/lobster/model/_clm.py index 267b3613..fd6a342f 100644 --- a/src/lobster/model/_clm.py +++ b/src/lobster/model/_clm.py @@ -4,7 +4,7 @@ import lightning.pytorch as pl import torch from torch.nn import CrossEntropyLoss -from transformers import LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline from lobster.constants import SchedulerType from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform @@ -13,6 +13,9 @@ from ._clm_configuration import PCLM_CONFIG_ARGS +ALLOWABLE_MODEL_NAMES = list(PCLM_CONFIG_ARGS.keys()) + ["ProtGPT2"] + + class LobsterPCLM(pl.LightningModule): def __init__( self, @@ -68,36 +71,45 @@ def __init__( self.scheduler_kwargs = scheduler_kwargs or {} model_kwargs = model_kwargs or {} - if self._tokenizer_dir is not None: - path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir - self.tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False) - self._transform_fn = transform_fn or PmlmTokenizerTransform( - path, - padding="max_length", - truncation=True, - max_length=self._max_length, - mlm=False, + assert model_name in ALLOWABLE_MODEL_NAMES, f"model_name must be one of {ALLOWABLE_MODEL_NAMES}" + + if model_name == "ProtGPT2": + self.tokenizer = AutoTokenizer.from_pretrained("nferruz/ProtGPT2") + self.model = AutoModelForCausalLM.from_pretrained("nferruz/ProtGPT2") + self.config = self.model.config + + else: + # Create PCLM model + if self._tokenizer_dir is not None: + path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir + self.tokenizer = PmlmTokenizer.from_pretrained(path, 
do_lower_case=False) + self._transform_fn = transform_fn or PmlmTokenizerTransform( + path, + padding="max_length", + truncation=True, + max_length=self._max_length, + mlm=False, + ) + + config_args = PCLM_CONFIG_ARGS[model_name] + if num_key_value_heads is None: + num_key_value_heads = config_args["num_attention_heads"] + self._num_key_value_heads = num_key_value_heads + + config = LlamaConfig( + **config_args, + mask_token_id=self.tokenizer.mask_token_id, + pad_token_id=self.tokenizer.pad_token_id, + cls_token_id=self.tokenizer.cls_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=len(self.tokenizer.get_vocab()), + max_position_embeddings=self._max_length, + num_key_value_heads=self._num_key_value_heads, + attention_bias=self._attention_bias, + **model_kwargs, ) - - config_args = PCLM_CONFIG_ARGS[model_name] - if num_key_value_heads is None: - num_key_value_heads = config_args["num_attention_heads"] - self._num_key_value_heads = num_key_value_heads - - config = LlamaConfig( - **config_args, - mask_token_id=self.tokenizer.mask_token_id, - pad_token_id=self.tokenizer.pad_token_id, - cls_token_id=self.tokenizer.cls_token_id, - eos_token_id=self.tokenizer.eos_token_id, - vocab_size=len(self.tokenizer.get_vocab()), - max_position_embeddings=self._max_length, - num_key_value_heads=self._num_key_value_heads, - attention_bias=self._attention_bias, - **model_kwargs, - ) - self.model = LlamaForCausalLM(config) - self.config = self.model.config + self.model = LlamaForCausalLM(config) + self.config = self.model.config self.save_hyperparameters(logger=False) From 33044492324bb300aea6de64fbfac7adc8a42bd0 Mon Sep 17 00:00:00 2001 From: "Amy X. 
Lu" Date: Wed, 9 Jul 2025 23:11:27 +0100 Subject: [PATCH 04/10] Add option for explicit numpy offsets loading --- src/lobster/data/_fasta_datamodule.py | 17 ++++++++++------- src/lobster/datasets/_fasta_dataset.py | 20 ++++++++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/lobster/data/_fasta_datamodule.py b/src/lobster/data/_fasta_datamodule.py index 53fbea19..1a3b42f3 100644 --- a/src/lobster/data/_fasta_datamodule.py +++ b/src/lobster/data/_fasta_datamodule.py @@ -1,9 +1,10 @@ import importlib from collections.abc import Callable, Iterable, Sequence from pathlib import Path -from typing import Any, TypeVar +from typing import Any, TypeVar, Optional import pandas as pd +import numpy as np import torch.utils.data # from beignet.datasets import FASTADataset @@ -43,6 +44,7 @@ def __init__( is_relative_model: bool = False, tokenizer_dir: str | None = "pmlm_tokenizer", mlm: bool = True, + offsets_arr: Optional[np.ndarray] = None, ) -> None: """ :param path_to_fasta: path to fasta file @@ -139,6 +141,7 @@ def __init__( self._is_relative_model = is_relative_model self._tokenizer_dir = tokenizer_dir self._mlm = mlm + self._offsets_arr = offsets_arr path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir self._transform_fn = transform_fn or PmlmTokenizerTransform( @@ -159,16 +162,16 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002 if stage == "fit": if any(["train" in self._path_to_fasta]): # pre computed splits self._train_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "train" in p] + [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "train" in p] ) self._val_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "val" in p] + [FASTADataset(root=p, transform=self._transform_fn, 
offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "val" in p] ) self._test_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "test" in p] + [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "test" in p] ) else: # iid split - datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta] + datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta] dataset = torch.utils.data.ConcatDataset(datasets) ( self._train_dataset, @@ -181,7 +184,7 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002 ) if stage == "predict": - datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta] + datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta] dataset = torch.utils.data.ConcatDataset(datasets) self._predict_dataset = dataset @@ -236,4 +239,4 @@ def _clm_data_wrangle(self, dataset) -> Dataset: seq_dict = dict(seqs_for_dl) seq_dict_df = pd.DataFrame(seq_dict.items(), columns=["input_ids", "Labels"]) seq_dict_df = Dataset.from_pandas(seq_dict_df) - return seq_dict_df + return seq_dict_df \ No newline at end of file diff --git a/src/lobster/datasets/_fasta_dataset.py b/src/lobster/datasets/_fasta_dataset.py index e78ad354..7142a690 100644 --- a/src/lobster/datasets/_fasta_dataset.py +++ b/src/lobster/datasets/_fasta_dataset.py @@ -1,7 +1,7 @@ import subprocess from collections.abc import Callable from pathlib import Path -from typing import TypeVar +from typing import TypeVar, Optional import numpy from beignet.datasets._sized_sequence_dataset import SizedSequenceDataset @@ -17,6 +17,7 @@ def __init__( *, transform: Callable | None = None, use_text_descriptions: bool = True, + offsets_arr: Optional[numpy.ndarray] = None, ) -> 
None: if isinstance(root, str): root = Path(root) @@ -32,14 +33,17 @@ def __init__( self.data = ThreadSafeFile(self.root, open) - offsets = Path(f"{self.root}.offsets.npy") + if offsets_arr is None: + offsets_path = Path(f"{self.root}.offsets.npy") + if offsets_path.exists(): + self.offsets, sizes = numpy.load(f"{offsets_path}") + else: + self.offsets, sizes = self._build_index() + numpy.save(f"{offsets_path}", numpy.stack([self.offsets, sizes])) - if offsets.exists(): - self.offsets, sizes = numpy.load(f"{offsets}") else: - self.offsets, sizes = self._build_index() - - numpy.save(f"{offsets}", numpy.stack([self.offsets, sizes])) + self.offsets = offsets_arr[0, :] + sizes = offsets_arr[1, :] self.transform = transform @@ -93,4 +97,4 @@ def _build_index(self) -> tuple[numpy.ndarray, numpy.ndarray]: dtype=numpy.int64, sep=" ", ), - ) + ) \ No newline at end of file From 9014c613e3e12520a0d6b55ef231184c42b3be22 Mon Sep 17 00:00:00 2001 From: "Amy X. Lu" Date: Sun, 20 Jul 2025 15:48:24 +0100 Subject: [PATCH 05/10] Move save_tokens file --- .../save_token_losses.slrm => slurm/scripts/save_token_losses.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename token_selection/scripts/save_token_losses.slrm => slurm/scripts/save_token_losses.sh (100%) diff --git a/token_selection/scripts/save_token_losses.slrm b/slurm/scripts/save_token_losses.sh similarity index 100% rename from token_selection/scripts/save_token_losses.slrm rename to slurm/scripts/save_token_losses.sh From 4094f9dc983d519b8dd2d7f35640cb5e3f3fc0bf Mon Sep 17 00:00:00 2001 From: "Amy X. 
Lu" Date: Sun, 20 Jul 2025 15:50:29 +0100 Subject: [PATCH 06/10] Update parser defaults and help prompts --- token_selection/scripts/save_token_losses.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/token_selection/scripts/save_token_losses.py b/token_selection/scripts/save_token_losses.py index c25c4eac..fa229f3c 100644 --- a/token_selection/scripts/save_token_losses.py +++ b/token_selection/scripts/save_token_losses.py @@ -25,47 +25,53 @@ def get_args(): parser.add_argument( "--fasta_file", type=str, - default="/data/lux70/data/uniref90/partial.fasta", + help="Path to the FASTA file containing sequences.", ) parser.add_argument( "--offset_array_path", type=str, - default="/data/lux70/data/uniref90/partial.fasta.offsets.npy", + help="Path to the numpy array containing offsets for the FASTA file.", ) parser.add_argument( "--output_dir", type=str, - default="/data/lux70/data/uniref90/token_losses", + help="Directory to save the output files.", ) parser.add_argument( "--model_name", type=str, default="lightonai/RITA_l", + help="Name of the autoregressive model to use for token loss computation." ) parser.add_argument( "--batch_size", type=int, default=512, + help="Batch size for processing sequences. Adjust based on GPU memory.", ) parser.add_argument( "--max_length", type=int, default=512, + help="Maximum sequence length for the model. Sequences longer than this will be truncated.", ) parser.add_argument( "--max_num_per_shard", type=int, default=100_000, + help="Maximum number of sequences to process in each shard. Adjust based on GPU memory.", ) parser.add_argument( "--cur_num_in_shard", type=int, default=0, + help="Current number of sequences processed in the current shard. Used for resuming processing.", ) parser.add_argument( "--cur_shard_num", type=int, default=0, + help="Current shard number. 
Used for resuming processing.", ) return parser.parse_args() From 739b33fc143e9515adfa0c88e0a0c91ee11b1171 Mon Sep 17 00:00:00 2001 From: "Amy X. Lu" Date: Sun, 20 Jul 2025 15:52:07 +0100 Subject: [PATCH 07/10] Refactor sharded parquet dataset --- src/lobster/datasets/__init__.py | 3 +++ .../lobster/datasets/_sharded_parquet_dataset.py | 1 + 2 files changed, 4 insertions(+) rename token_selection/scripts/inference.py => src/lobster/datasets/_sharded_parquet_dataset.py (99%) diff --git a/src/lobster/datasets/__init__.py b/src/lobster/datasets/__init__.py index 9fa880d2..202acbbb 100644 --- a/src/lobster/datasets/__init__.py +++ b/src/lobster/datasets/__init__.py @@ -14,6 +14,8 @@ from ._shuffled_iterable_dataset import ShuffledIterableDataset from ._ume_streaming_dataset import UMEStreamingDataset from ._zinc_dataset import ZINCIterableDataset +from ._sharded_parquet_dataset import ShardedParquetDataset + __all__ = [ "CalmDataset", @@ -36,4 +38,5 @@ "ZINCIterableDataset", "OpenGenome2IterableDataset", "UMEStreamingDataset", + "ShardedParquetDataset", ] diff --git a/token_selection/scripts/inference.py b/src/lobster/datasets/_sharded_parquet_dataset.py similarity index 99% rename from token_selection/scripts/inference.py rename to src/lobster/datasets/_sharded_parquet_dataset.py index 6e08f398..8b753b60 100644 --- a/token_selection/scripts/inference.py +++ b/src/lobster/datasets/_sharded_parquet_dataset.py @@ -7,6 +7,7 @@ from torch.utils.data import Dataset, DataLoader from torch.utils.data.distributed import DistributedSampler + class ShardedParquetDataset(Dataset): def __init__(self, parquet_dir, From 8c3ead7e1ffebd6f7a566159dfebba2fe556c1c8 Mon Sep 17 00:00:00 2001 From: "Amy X. 
Lu" Date: Sun, 20 Jul 2025 16:03:51 +0100 Subject: [PATCH 08/10] Add README documentation --- token_selection/scripts/README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 token_selection/scripts/README.md diff --git a/token_selection/scripts/README.md b/token_selection/scripts/README.md new file mode 100644 index 00000000..b3e0b682 --- /dev/null +++ b/token_selection/scripts/README.md @@ -0,0 +1,20 @@ +# Selective Token Modeling + +This directory contains experiments related to calculating per-token losses on an existing pretrained model for the purpose of Selective Token Modeling (SLM) (see the Rho-1 [paper](https://arxiv.org/abs/2404.07965) by Lin et al.). +The core idea is that not all tokens are similarly difficult for the model to learn; in the English language, this might be tokens such as `the`. Faster convergence, better performance, and/or reduced model parameter size can be achieved by selectively training on useful tokens that align with the desired distribution. We can make use of previously trained models to determine this notion of "in-distribution". + +From the project root directory, running +``` +LOBSTER_PROJECT_DIR=$(pwd) +sbatch slurm/scripts/save_token_losses.sh +``` + +will launch a multi-GPU inference job that saves per-token losses for a FASTA sequence on a specified model (the autoregressive [RITA-Large](https://arxiv.org/abs/2205.05789) model is used by default) into Parquet format. + +Model training with selective token percentages can be done using the dataloader in `datasets/_sharded_parquet_dataset.py`. + +## Extensions: +- [ ] Perform the same experiment for other modalities for data mixture determination +- [ ] Perform the same experiment for downstream tasks to determine which tasks are more difficult for the model +- [ ] Perform ablation experiments by incorporating data at different loss percentages. +- [ ] Perform on masked language models to see if the pattern is different. 
Note: this will require O(L) forward passes. From 3022c6cb71f3b80974ab035582f1ae30c37635f5 Mon Sep 17 00:00:00 2001 From: "Amy X. Lu" Date: Sun, 20 Jul 2025 16:07:36 +0100 Subject: [PATCH 09/10] Ruff formatting --- token_selection/{scripts => }/README.md | 0 token_selection/scripts/save_token_losses.py | 61 ++++++++------------ 2 files changed, 25 insertions(+), 36 deletions(-) rename token_selection/{scripts => }/README.md (100%) diff --git a/token_selection/scripts/README.md b/token_selection/README.md similarity index 100% rename from token_selection/scripts/README.md rename to token_selection/README.md diff --git a/token_selection/scripts/save_token_losses.py b/token_selection/scripts/save_token_losses.py index fa229f3c..8d1ad32f 100644 --- a/token_selection/scripts/save_token_losses.py +++ b/token_selection/scripts/save_token_losses.py @@ -14,10 +14,11 @@ from torch.nn.parallel import DistributedDataParallel from torch.nn.functional import cross_entropy -from lobster.datasets import FASTADataset +from lobster.datasets import FASTADataset -torch.set_float32_matmul_precision('high') +torch.set_float32_matmul_precision("high") + def get_args(): parser = argparse.ArgumentParser() @@ -41,7 +42,7 @@ def get_args(): "--model_name", type=str, default="lightonai/RITA_l", - help="Name of the autoregressive model to use for token loss computation." 
+ help="Name of the autoregressive model to use for token loss computation.", ) parser.add_argument( "--batch_size", @@ -90,7 +91,7 @@ def load_model(model_name: str = "lightonai/RITA_xl", max_length: int = 512) -> model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = "" - tokenizer.pad_token_id = tokenizer.vocab[''] + tokenizer.pad_token_id = tokenizer.vocab[""] tokenizer.max_length = max_length model.eval() return model, tokenizer @@ -107,18 +108,15 @@ def compute_loss(batch, model, tokenizer, max_length, device=None) -> List[Dict[ sequences = [s[:max_length] for s in sequences] inputs = tokenizer(sequences, return_tensors="pt", padding=True, truncation=False) - input_ids = inputs['input_ids'].to(device) - attn_mask = inputs['attention_mask'].to(device) - N, L = input_ids.shape[0], input_ids.shape[1] - 1 # remove EOS token + input_ids = inputs["input_ids"].to(device) + attn_mask = inputs["attention_mask"].to(device) + N, L = input_ids.shape[0], input_ids.shape[1] - 1 # remove EOS token with torch.no_grad(): - output = model( - input_ids=input_ids, - attention_mask=attn_mask - ) + output = model(input_ids=input_ids, attention_mask=attn_mask) targets = input_ids[:, 1:].reshape(-1) - logits = output['logits'] + logits = output["logits"] logits = logits[:, :-1, :].reshape(-1, logits.shape[-1]) per_token_loss = cross_entropy(logits, targets, reduction="none") per_token_loss = per_token_loss.reshape(-1, L).half() # store as float16. 
@@ -127,7 +125,7 @@ def compute_loss(batch, model, tokenizer, max_length, device=None) -> List[Dict[ { "sequence": sequences[i], "header": headers[i], - "per_token_loss": per_token_loss[i, :min(len(sequences[i]), max_length)].cpu().tolist(), + "per_token_loss": per_token_loss[i, : min(len(sequences[i]), max_length)].cpu().tolist(), } for i in range(len(sequences)) ] @@ -137,7 +135,7 @@ def compute_loss(batch, model, tokenizer, max_length, device=None) -> List[Dict[ def main(rank, args, world_size): if world_size > 1: setup(rank, world_size) - + output_dir = Path(args.output_dir) / args.model_name.replace("/", "_") if not output_dir.exists(): output_dir.mkdir(parents=True) @@ -148,7 +146,7 @@ def main(rank, args, world_size): print("Original offset array shape:", offset_array.shape) assert len(offset_array.shape) == 2 assert offset_array.shape[0] == 2 - + # Partition data for this GPU per_gpu_size = offset_array.shape[1] // world_size start_idx = rank * per_gpu_size @@ -159,20 +157,12 @@ def main(rank, args, world_size): print(f"Rank {rank} processing {local_offsets.shape[1]} sequences from {args.fasta_file}") # Create dataset and dataloader for this GPU - dataset = FASTADataset( - root=args.fasta_file, - offsets_arr=local_offsets, - use_text_descriptions=True - ) + dataset = FASTADataset(root=args.fasta_file, offsets_arr=local_offsets, use_text_descriptions=True) - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=world_size, rank=rank - ) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False - ) - # Create model model, tokenizer = load_model(args.model_name, args.max_length) device = torch.device("cuda", rank) @@ -186,11 +176,11 @@ def main(rank, args, 
world_size): ddp_model = model ddp_model = torch.compile(ddp_model) - + # Inference loop results_tmp_list = [] - cur_shard_num = 0 - cur_num_in_shard = 0 + cur_shard_num = 0 + cur_num_in_shard = 0 for batch in dataloader: with torch.no_grad(): @@ -201,15 +191,15 @@ def main(rank, args, world_size): if cur_num_in_shard >= args.max_num_per_shard: print(f"Saving shard {cur_shard_num} to {output_file}...") output_file = output_dir / f"rank_{rank:02}_shard_{cur_shard_num:06}.parquet" - pd.DataFrame(results_tmp_list).to_parquet(output_file, engine='pyarrow', index=False) + pd.DataFrame(results_tmp_list).to_parquet(output_file, engine="pyarrow", index=False) cur_shard_num += 1 cur_num_in_shard = 0 results_tmp_list = [] - + else: print(f"Rank {rank} processed {cur_num_in_shard} sequences in shard {cur_shard_num}") - + if __name__ == "__main__": args = get_args() @@ -219,12 +209,11 @@ def main(rank, args, world_size): print("Only one GPU available. Running without DDP.") main(0, args, world_size) exit() - + else: print(f"Using {world_size} GPUs for DDP.") rank = int(os.environ["LOCAL_RANK"]) mp.spawn(main, (args, world_size), world_size, join=True) - + if world_size > 1: cleanup() - From 9c55725cafd12fcd270e08f97fc745fdec52d107 Mon Sep 17 00:00:00 2001 From: "Amy X. 
Lu" Date: Mon, 28 Jul 2025 18:01:58 +0100 Subject: [PATCH 10/10] Ruff --- src/lobster/data/_fasta_datamodule.py | 34 ++- src/lobster/datasets/_fasta_dataset.py | 6 +- .../datasets/_sharded_parquet_dataset.py | 204 ++++++++---------- 3 files changed, 121 insertions(+), 123 deletions(-) diff --git a/src/lobster/data/_fasta_datamodule.py b/src/lobster/data/_fasta_datamodule.py index 1a3b42f3..081fb3e6 100644 --- a/src/lobster/data/_fasta_datamodule.py +++ b/src/lobster/data/_fasta_datamodule.py @@ -1,7 +1,7 @@ import importlib from collections.abc import Callable, Iterable, Sequence from pathlib import Path -from typing import Any, TypeVar, Optional +from typing import Any, TypeVar import pandas as pd import numpy as np @@ -44,7 +44,7 @@ def __init__( is_relative_model: bool = False, tokenizer_dir: str | None = "pmlm_tokenizer", mlm: bool = True, - offsets_arr: Optional[np.ndarray] = None, + offsets_arr: np.ndarray | None = None, ) -> None: """ :param path_to_fasta: path to fasta file @@ -162,16 +162,31 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002 if stage == "fit": if any(["train" in self._path_to_fasta]): # pre computed splits self._train_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "train" in p] + [ + FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) + for p in self._path_to_fasta + if "train" in p + ] ) self._val_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "val" in p] + [ + FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) + for p in self._path_to_fasta + if "val" in p + ] ) self._test_dataset = torch.utils.data.ConcatDataset( - [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta if "test" in p] + [ 
+ FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) + for p in self._path_to_fasta + if "test" in p + ] ) else: # iid split - datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta] + datasets = [ + FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) + for p in self._path_to_fasta + ] dataset = torch.utils.data.ConcatDataset(datasets) ( self._train_dataset, @@ -184,7 +199,10 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002 ) if stage == "predict": - datasets = [FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) for p in self._path_to_fasta] + datasets = [ + FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr) + for p in self._path_to_fasta + ] dataset = torch.utils.data.ConcatDataset(datasets) self._predict_dataset = dataset @@ -239,4 +257,4 @@ def _clm_data_wrangle(self, dataset) -> Dataset: seq_dict = dict(seqs_for_dl) seq_dict_df = pd.DataFrame(seq_dict.items(), columns=["input_ids", "Labels"]) seq_dict_df = Dataset.from_pandas(seq_dict_df) - return seq_dict_df \ No newline at end of file + return seq_dict_df diff --git a/src/lobster/datasets/_fasta_dataset.py b/src/lobster/datasets/_fasta_dataset.py index 7142a690..0e36ef53 100644 --- a/src/lobster/datasets/_fasta_dataset.py +++ b/src/lobster/datasets/_fasta_dataset.py @@ -1,7 +1,7 @@ import subprocess from collections.abc import Callable from pathlib import Path -from typing import TypeVar, Optional +from typing import TypeVar import numpy from beignet.datasets._sized_sequence_dataset import SizedSequenceDataset @@ -17,7 +17,7 @@ def __init__( *, transform: Callable | None = None, use_text_descriptions: bool = True, - offsets_arr: Optional[numpy.ndarray] = None, + offsets_arr: numpy.ndarray | None = None, ) -> None: if isinstance(root, str): root = Path(root) @@ -97,4 +97,4 @@ def _build_index(self) -> 
tuple[numpy.ndarray, numpy.ndarray]: dtype=numpy.int64, sep=" ", ), - ) \ No newline at end of file + ) diff --git a/src/lobster/datasets/_sharded_parquet_dataset.py b/src/lobster/datasets/_sharded_parquet_dataset.py index 8b753b60..4556a72e 100644 --- a/src/lobster/datasets/_sharded_parquet_dataset.py +++ b/src/lobster/datasets/_sharded_parquet_dataset.py @@ -9,16 +9,12 @@ class ShardedParquetDataset(Dataset): - def __init__(self, - parquet_dir, - percentile_threshold=90, - loss_threshold=None, - stats_file=None, - rank=None, - world_size=None): + def __init__( + self, parquet_dir, percentile_threshold=90, loss_threshold=None, stats_file=None, rank=None, world_size=None + ): """ Distributed dataset for sharded parquet files. - + Args: parquet_dir: Directory containing parquet shards percentile_threshold: Only include tokens with loss below this percentile @@ -29,34 +25,32 @@ def __init__(self, """ self.parquet_dir = parquet_dir self.percentile_threshold = percentile_threshold - + # Get list of all shard files self.shard_files = sorted(glob.glob(f"{parquet_dir}/partition_id=*/part-*.parquet")) - + # If running distributed, only use shards for this rank if rank is not None and world_size is not None: # Distribute shards across workers - self.shard_files = [ - f for i, f in enumerate(self.shard_files) - if i % world_size == rank - ] - + self.shard_files = [f for i, f in enumerate(self.shard_files) if i % world_size == rank] + # Set loss threshold either from argument or by loading stats if loss_threshold is not None: self.loss_threshold = loss_threshold elif stats_file and os.path.exists(stats_file): # Load pre-computed statistics import json - with open(stats_file, 'r') as f: + + with open(stats_file) as f: stats = json.load(f) - self.loss_threshold = stats['percentiles'][str(percentile_threshold)] + self.loss_threshold = stats["percentiles"][str(percentile_threshold)] else: # Calculate threshold (ideally, this is pre-computed) self.loss_threshold = 
self._calculate_percentile() - + # Load sequence metadata from all assigned shards self.sequence_data = self._load_sequence_metadata() - + def _calculate_percentile(self): """Calculate percentile threshold from samples.""" # Only calculate on rank 0 and broadcast if distributed @@ -65,149 +59,141 @@ def _calculate_percentile(self): threshold = torch.zeros(1, dtype=torch.float32).cuda() dist.broadcast(threshold, 0) return threshold.item() - + # Root process (or non-distributed) calculates print(f"Calculating {self.percentile_threshold}th percentile threshold...") samples = [] - + # Sample from each shard for shard in self.shard_files[:10]: # Limit to 10 shards for efficiency - df = pd.read_parquet(shard, columns=['loss']) + df = pd.read_parquet(shard, columns=["loss"]) # Take a sample proportional to size sample_size = min(10000, len(df)) if sample_size > 0: - samples.append(df.sample(sample_size)['loss'].values) - + samples.append(df.sample(sample_size)["loss"].values) + # Calculate threshold from samples if samples: all_samples = np.concatenate(samples) threshold = float(np.percentile(all_samples, self.percentile_threshold)) else: - threshold = float('inf') # No samples available - + threshold = float("inf") # No samples available + # Broadcast result if distributed if dist.is_initialized(): threshold_tensor = torch.tensor([threshold], dtype=torch.float32).cuda() dist.broadcast(threshold_tensor, 0) threshold = threshold_tensor.item() - + print(f"Using loss threshold: {threshold}") return threshold - + def _load_sequence_metadata(self): """Load sequence metadata from assigned shards.""" sequences = [] - + for shard_file in self.shard_files: # Read just sequence metadata for efficiency try: # Group by sequence_id and get sizes - df = pd.read_parquet( - shard_file, - columns=['sequence_id', 'position'] - ) - seq_info = df.groupby('sequence_id').agg({'position': 'max'}) - + df = pd.read_parquet(shard_file, columns=["sequence_id", "position"]) + seq_info = 
df.groupby("sequence_id").agg({"position": "max"}) + for seq_id, max_pos in seq_info.itertuples(): - sequences.append({ - 'sequence_id': seq_id, - 'length': max_pos + 1, # Convert to length - 'shard_file': shard_file - }) + sequences.append( + { + "sequence_id": seq_id, + "length": max_pos + 1, # Convert to length + "shard_file": shard_file, + } + ) except Exception as e: print(f"Error loading metadata from {shard_file}: {e}") - + return sequences - + def __len__(self): return len(self.sequence_data) - + def __getitem__(self, idx): """Get a filtered sequence by index.""" seq_info = self.sequence_data[idx] - seq_id = seq_info['sequence_id'] - shard_file = seq_info['shard_file'] - + seq_id = seq_info["sequence_id"] + shard_file = seq_info["shard_file"] + # Read this sequence with filtering try: # Use PyArrow filter pushdown for efficiency df = pd.read_parquet( - shard_file, - filters=[ - ('sequence_id', '=', seq_id), - ('loss', '<=', self.loss_threshold) - ] + shard_file, filters=[("sequence_id", "=", seq_id), ("loss", "<=", self.loss_threshold)] ) - + # Sort by position to maintain sequence order if not df.empty: - df = df.sort_values('position') - + df = df.sort_values("position") + return { - 'sequence_id': seq_id, - 'tokens': df['token'].values, - 'positions': df['position'].values, - 'losses': df['loss'].values + "sequence_id": seq_id, + "tokens": df["token"].values, + "positions": df["position"].values, + "losses": df["loss"].values, } else: # No tokens passed the filter return { - 'sequence_id': seq_id, - 'tokens': np.array([], dtype=np.int64), - 'positions': np.array([], dtype=np.int64), - 'losses': np.array([], dtype=np.float32) + "sequence_id": seq_id, + "tokens": np.array([], dtype=np.int64), + "positions": np.array([], dtype=np.int64), + "losses": np.array([], dtype=np.float32), } - + except Exception as e: print(f"Error loading sequence {seq_id}: {e}") # Return empty sequence on error return { - 'sequence_id': seq_id, - 'tokens': np.array([], 
dtype=np.int64), - 'positions': np.array([], dtype=np.int64), - 'losses': np.array([], dtype=np.float32) + "sequence_id": seq_id, + "tokens": np.array([], dtype=np.int64), + "positions": np.array([], dtype=np.int64), + "losses": np.array([], dtype=np.float32), } def collate_variable_length_sequences(batch): """Custom collate function for variable-length sequences.""" # Filter out empty sequences - non_empty = [b for b in batch if len(b['tokens']) > 0] - + non_empty = [b for b in batch if len(b["tokens"]) > 0] + if not non_empty: # All sequences were empty after filtering return { - 'sequence_ids': [], - 'tokens': torch.zeros(0, dtype=torch.int64), - 'positions': torch.zeros(0, dtype=torch.int64), - 'losses': torch.zeros(0, dtype=torch.float32), - 'batch_indices': torch.zeros(0, dtype=torch.int64) + "sequence_ids": [], + "tokens": torch.zeros(0, dtype=torch.int64), + "positions": torch.zeros(0, dtype=torch.int64), + "losses": torch.zeros(0, dtype=torch.float32), + "batch_indices": torch.zeros(0, dtype=torch.int64), } - + # Gather data - sequence_ids = [b['sequence_id'] for b in non_empty] - tokens_list = [torch.tensor(b['tokens'], dtype=torch.int64) for b in non_empty] - positions_list = [torch.tensor(b['positions'], dtype=torch.int64) for b in non_empty] - losses_list = [torch.tensor(b['losses'], dtype=torch.float32) for b in non_empty] - + sequence_ids = [b["sequence_id"] for b in non_empty] + tokens_list = [torch.tensor(b["tokens"], dtype=torch.int64) for b in non_empty] + positions_list = [torch.tensor(b["positions"], dtype=torch.int64) for b in non_empty] + losses_list = [torch.tensor(b["losses"], dtype=torch.float32) for b in non_empty] + # Create batch indices for reconstructing sequences later batch_sizes = [len(t) for t in tokens_list] - batch_indices = torch.cat([ - torch.full((size,), i, dtype=torch.int64) - for i, size in enumerate(batch_sizes) - ]) - + batch_indices = torch.cat([torch.full((size,), i, dtype=torch.int64) for i, size in 
enumerate(batch_sizes)]) + # Concatenate all tokens tokens = torch.cat(tokens_list) positions = torch.cat(positions_list) losses = torch.cat(losses_list) - + return { - 'sequence_ids': sequence_ids, - 'tokens': tokens, - 'positions': positions, - 'losses': losses, - 'batch_indices': batch_indices + "sequence_ids": sequence_ids, + "tokens": tokens, + "positions": positions, + "losses": losses, + "batch_indices": batch_indices, } @@ -215,43 +201,37 @@ def setup_distributed(): """Initialize distributed training environment.""" # Initialize process group dist.init_process_group( - backend='nccl', # Use 'gloo' for CPU-only - init_method='env://' + backend="nccl", # Use 'gloo' for CPU-only + init_method="env://", ) - + # Get global rank and world size rank = dist.get_rank() world_size = dist.get_world_size() - + # Set device for this process torch.cuda.set_device(rank % torch.cuda.device_count()) - + return rank, world_size -def create_distributed_dataloader(parquet_dir, percentile_threshold=90, - batch_size=32, num_workers=4): + +def create_distributed_dataloader(parquet_dir, percentile_threshold=90, batch_size=32, num_workers=4): """Create a distributed dataloader for sharded parquet files.""" # Setup distributed environment rank, world_size = setup_distributed() - + # Create dataset with this rank's shards dataset = ShardedParquetDataset( parquet_dir=parquet_dir, percentile_threshold=percentile_threshold, stats_file=f"{parquet_dir}/stats.json", rank=rank, - world_size=world_size + world_size=world_size, ) - + # Create distributed sampler to handle partitioning - sampler = DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=False - ) - + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) + # Create dataloader with custom collate function dataloader = DataLoader( dataset, @@ -259,7 +239,7 @@ def create_distributed_dataloader(parquet_dir, percentile_threshold=90, 
sampler=sampler, num_workers=num_workers, collate_fn=collate_variable_length_sequences, - pin_memory=True + pin_memory=True, ) - - return dataloader, rank, world_size \ No newline at end of file + + return dataloader, rank, world_size