From 7690e2b635035e6d60a061fa2551b5988a206d3b Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Tue, 3 Jun 2025 17:37:37 -0700 Subject: [PATCH 1/5] fixed issues with activation store --- clt/training/data/local_activation_store.py | 32 ++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/clt/training/data/local_activation_store.py b/clt/training/data/local_activation_store.py index d532e07..04357dd 100644 --- a/clt/training/data/local_activation_store.py +++ b/clt/training/data/local_activation_store.py @@ -89,12 +89,28 @@ def _load_manifest(self) -> Optional[np.ndarray]: logger.error(f"Manifest file not found: {path}") return None try: - with open(path, "rb") as f: - data = np.frombuffer(f.read(), dtype=np.uint32).reshape(-1, 2) - logger.info(f"Manifest loaded from {path} ({len(data)} rows).") + file_size_bytes = path.stat().st_size + # Heuristic: older 2-field format is 8 bytes per entry (two uint32), + # newer 3-field format is 16 bytes per entry (int32, int32, int64). + if file_size_bytes % 16 == 0: + # New format with 3 fields (chunk_id, num_tokens, offset) + manifest_dtype = np.dtype([("chunk_id", np.int32), ("num_tokens", np.int32), ("offset", np.int64)]) + data_structured = np.fromfile(path, dtype=manifest_dtype) + logger.info(f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} rows).") + # Convert to Nx2 uint32 array expected by downstream code (drop offset) + data = np.stack((data_structured["chunk_id"], data_structured["num_tokens"]), axis=1).astype(np.uint32) + elif file_size_bytes % 8 == 0: + # Legacy 2-field format already matches expected shape + data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2) + logger.info(f"Manifest loaded (legacy 2-field format) from {path} ({data.shape[0]} rows).") + else: + logger.error( + f"Manifest file size ({file_size_bytes} bytes) is not compatible with known formats (8 or 16 bytes per row)." + ) + return None return data except ValueError as e: - logger.error(f"Error reshaping manifest data from {path} (expected Nx2): {e}") + logger.error(f"Error parsing manifest data from {path}: {e}") return None except OSError as e: logger.error(f"Error reading manifest file {path}: {e}") @@ -129,7 +145,9 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str): logger.error(f"Chunk file not found for fetch: {chunk_path}") raise except KeyError as e: - raise RuntimeError(f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}") from e + raise RuntimeError( + f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}" + ) from e except Exception as e: logger.error(f"Failed to open chunk at {chunk_path}: {e}") raise RuntimeError(f"Failed to access chunk HDF5 file: {chunk_path}") from e @@ -162,8 +180,8 @@ def _layer_sort_key(name: str) -> int: row_indices_h5 = row_indices for i, lk in enumerate(layer_keys): - input_data = self._load_chunk(chunk_path, lk, 'inputs')[row_indices_h5, :] - target_data = self._load_chunk(chunk_path, lk, 'targets')[row_indices_h5, :] + input_data = self._load_chunk(chunk_path, lk, "inputs")[row_indices_h5, :] + target_data = self._load_chunk(chunk_path, lk, "targets")[row_indices_h5, :] bufs.append(input_data.tobytes()) bufs.append(target_data.tobytes()) return b"".join(bufs) From 2f7a0fe6a085edf6fa514c538ecf6ea3b59d1e95 Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Tue, 3 Jun 2025 17:43:03 -0700 Subject: [PATCH 2/5] manifest now shows per chunk --- clt/training/data/local_activation_store.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/clt/training/data/local_activation_store.py b/clt/training/data/local_activation_store.py index 04357dd..93d822e 100644 --- a/clt/training/data/local_activation_store.py +++ b/clt/training/data/local_activation_store.py @@ -96,9 +96,22 @@ def _load_manifest(self) -> Optional[np.ndarray]: # New format with 3 fields (chunk_id, num_tokens, offset) manifest_dtype = np.dtype([("chunk_id", np.int32), ("num_tokens", np.int32), ("offset", np.int64)]) data_structured = np.fromfile(path, dtype=manifest_dtype) - logger.info(f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} rows).") - # Convert to Nx2 uint32 array expected by downstream code (drop offset) - data = np.stack((data_structured["chunk_id"], data_structured["num_tokens"]), axis=1).astype(np.uint32) + logger.info( + f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} chunks). Expanding to per-row entries." + ) + # Expand into per-row entries expected by downstream (chunk_id, row_in_chunk) + chunk_ids = data_structured["chunk_id"].astype(np.uint32) + num_tokens_arr = data_structured["num_tokens"].astype(np.uint32) + # Compute total rows + total_rows = int(num_tokens_arr.sum()) + logger.info(f"Expanding manifest: total rows = {total_rows}") + # Pre-allocate array + data = np.empty((total_rows, 2), dtype=np.uint32) + row_ptr = 0 + for cid, ntok in zip(chunk_ids, num_tokens_arr): + data[row_ptr : row_ptr + ntok, 0] = cid # chunk_id column + data[row_ptr : row_ptr + ntok, 1] = np.arange(ntok, dtype=np.uint32) # row index within chunk + row_ptr += ntok elif file_size_bytes % 8 == 0: # Legacy 2-field format already matches expected shape data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2) From f6ef708fe8c05c0afd51d739d7d83ca0e2f6de0d Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Tue, 3 Jun 2025 17:48:33 -0700 Subject: [PATCH 3/5] manifest now looks for correct file extension --- clt/training/data/local_activation_store.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/clt/training/data/local_activation_store.py b/clt/training/data/local_activation_store.py index 93d822e..f3e35dd 100644 --- a/clt/training/data/local_activation_store.py +++ b/clt/training/data/local_activation_store.py @@ -168,6 +168,19 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str): def _fetch_slice(self, chunk_id: int, row_indices: np.ndarray) -> bytes: chunk_path = self.dataset_path / f"chunk_{chunk_id}.h5" + if not chunk_path.exists(): + # Fall back to .hdf5 extension (newer generator default) + alt_path = self.dataset_path / f"chunk_{chunk_id}.hdf5" + if alt_path.exists(): + chunk_path = alt_path + else: + # Provide clearer error message before _open_h5 raises + logger.error( + "Chunk file for chunk_id %d not found with either .h5 or .hdf5 extension in %s", + chunk_id, + self.dataset_path, + ) + hf = _open_h5(chunk_path) try: From 41e1262de05256b13c2c1b41785277d9f41884fc Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Tue, 3 Jun 2025 19:04:39 -0700 Subject: [PATCH 4/5] reduced LRU cache size --- clt/training/data/local_activation_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clt/training/data/local_activation_store.py b/clt/training/data/local_activation_store.py index f3e35dd..11ed807 100644 --- a/clt/training/data/local_activation_store.py +++ b/clt/training/data/local_activation_store.py @@ -146,7 +146,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]: logger.error(f"Error reading norm_stats file {path}: {e}") return None - @lru_cache(maxsize=256) + @lru_cache(maxsize=64) def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str): """Loads entire HDF5 chunk from disk and caches""" From d2ede36cec0278011799b72272c7f294dbc1bc0c Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Tue, 3 Jun 2025 21:45:28 -0700 Subject: [PATCH 5/5] corrected total tokens calculation --- clt/training/metric_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clt/training/metric_utils.py b/clt/training/metric_utils.py index 8448fe3..785c806 100644 --- a/clt/training/metric_utils.py +++ b/clt/training/metric_utils.py @@ -49,7 +49,7 @@ def log_training_step( self.metrics["train_losses"].append({"step": step, **loss_dict}) if not self.distributed or self.rank == 0: - total_tokens_processed = self.training_config.train_batch_size_tokens * self.world_size * (step + 1) + total_tokens_processed = self.training_config.train_batch_size_tokens * (step + 1) self.wandb_logger.log_step( step,