diff --git a/clt/activation_generation/__init__.py b/clt/activation_generation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clt/activation_generation/generator.py b/clt/activation_generation/generator.py index a873ca0..b0575a0 100644 --- a/clt/activation_generation/generator.py +++ b/clt/activation_generation/generator.py @@ -40,7 +40,8 @@ from collections import defaultdict import psutil -from clt.training.utils import torch_bfloat16_to_numpy_uint16 +# Local application imports +# from clt.training.utils import torch_bfloat16_to_numpy_uint16 # Removed unused import try: import GPUtil @@ -141,13 +142,22 @@ def log_system_metrics(self, interval_name: str = "interval"): self.system_metrics_log.append(metrics) return metrics - def report(self): + def report(self, top_n_ops: Optional[int] = 20): logger.info("\n=== Performance Report ===") # Sort by total time descending for timings - sorted_timings = sorted(self.timings.items(), key=lambda item: sum(item[1]), reverse=True) + # Filter out operations with zero total time before sorting and slicing + valid_timings = {name: times for name, times in self.timings.items() if sum(times) > 0} + sorted_timings = sorted(valid_timings.items(), key=lambda item: sum(item[1]), reverse=True) - for name, times in sorted_timings: - if not times: + if top_n_ops is not None and top_n_ops > 0: + logger.info(f"--- Showing Top {top_n_ops} Timed Operations (by total time) ---") + timings_to_show = sorted_timings[:top_n_ops] + else: + logger.info("--- Showing All Timed Operations (by total time) ---") + timings_to_show = sorted_timings + + for name, times in timings_to_show: + if not times: # Should be redundant due to pre-filtering but safe continue avg_time = sum(times) / len(times) total_time = sum(times) @@ -386,10 +396,22 @@ def _async_uploader(upload_q: "queue.Queue[Optional[Path]]", cfg: ActivationConf # Welford online stats helper # --------------------------------------------------------------------------- class _RunningStat: - def __init__(self, dim: int): + def __init__(self, dim: int, device: Optional[torch.device | str] = None): self.n = 0 - self.mean = torch.zeros(dim, dtype=torch.float64) - self.M2 = torch.zeros(dim, dtype=torch.float64) + self.device = ( + torch.device(device) if isinstance(device, str) else device + ) # Resolve device string to torch.device + + if self.device and self.device.type == "mps": + self.stats_dtype = torch.float32 + logger.info("Using float32 for running stats on MPS device.") + else: + self.stats_dtype = torch.float64 + + # Initialize to CPU if device is None, then move on first update, or initialize directly if device is known. + initial_device_for_zeros = self.device if self.device else "cpu" + self.mean = torch.zeros(dim, dtype=self.stats_dtype, device=initial_device_for_zeros) + self.M2 = torch.zeros(dim, dtype=self.stats_dtype, device=initial_device_for_zeros) def update(self, x: torch.Tensor): """Update running mean & M2 using a mini-batch (Welford, parallel form). @@ -397,9 +419,19 @@ def update(self, x: torch.Tensor): This corrects the previous implementation which **under-estimated** the variance by failing to include the between-batch mean shift term. """ - - # Promote to float64 for numerical stability - x = x.to(torch.float64) + if self.device is None: + self.device = x.device + # Update self.stats_dtype if it was default and first tensor is MPS + if self.device.type == "mps" and self.stats_dtype == torch.float64: + self.stats_dtype = torch.float32 + logger.info("Switched running stats to float32 due to MPS device tensor.") + self.mean = self.mean.to(device=self.device, dtype=self.stats_dtype) + self.M2 = self.M2.to(device=self.device, dtype=self.stats_dtype) + elif x.device != self.device: + x = x.to(self.device) + + # Ensure x is on the correct device and has the stats_dtype for calculations + x = x.to(device=self.device, dtype=self.stats_dtype) cnt = x.shape[0] if cnt == 0: @@ -424,7 +456,10 @@ def update(self, x: torch.Tensor): def finalize(self) -> Tuple[np.ndarray, np.ndarray]: var = self.M2 / max(self.n - 1, 1) - return self.mean.cpu().numpy().astype("float32"), np.sqrt(var).cpu().numpy().astype("float32") + # Ensure tensors are moved to CPU before NumPy conversion and operations like np.sqrt + mean_cpu = self.mean.cpu() + var_cpu = var.cpu() + return mean_cpu.numpy().astype("float32"), np.sqrt(var_cpu.numpy()).astype("float32") # --------------------------------------------------------------------------- @@ -544,8 +579,8 @@ def generate_and_save(self): buf_tgt[lid] = [] if cfg.compute_norm_stats: stats[lid] = { - "inputs": _RunningStat(d_model), - "targets": _RunningStat(d_model), + "inputs": _RunningStat(d_model, device=self.device), + "targets": _RunningStat(d_model, device=self.device), } logger.info( "Layers=%d d_model=%d dtype=%s", len(layer_ids) if layer_ids else 0, d_model, dtype_str @@ -555,12 +590,12 @@ def generate_and_save(self): if layer_ids and batch_inp.get(layer_ids[0]) is not None: n_tok_in_batch = batch_inp[layer_ids[0]].shape[0] - with self._conditional_measure("batch_cpu_transfer_and_accumulate"): + with self._conditional_measure("batch_gpu_tensor_accumulate"): if layer_ids: for lid in layer_ids: if lid in batch_inp and lid in batch_tgt: - inp = batch_inp[lid].detach().cpu() - tgt = batch_tgt[lid].detach().cpu() + inp = batch_inp[lid].detach() + tgt = batch_tgt[lid].detach() buf_inp[lid].append(inp) buf_tgt[lid].append(tgt) if cfg.compute_norm_stats and lid in stats: @@ -721,8 +756,8 @@ def _conditional_measure(self, name: str): def _write_chunk( self, chunk_idx: int, - buf_inp: Dict[int, List[torch.Tensor]], - buf_tgt: Dict[int, List[torch.Tensor]], + buf_inp_gpu: Dict[int, List[torch.Tensor]], + buf_tgt_gpu: Dict[int, List[torch.Tensor]], layer_ids: List[int], d_model: int, rows: int, @@ -730,10 +765,7 @@ def _write_chunk( offset: int, ): with self._conditional_measure(f"chunk_write_total_idx_{chunk_idx}"): - with self._conditional_measure(f"chunk_{chunk_idx}_permutation_generation"): - perm = torch.randperm(rows) - - p = self.out_dir / f"chunk_{chunk_idx}.h5" + p = self.out_dir / f"chunk_{chunk_idx}.{self.cfg.output_format}" if self.torch_dtype == torch.float32: h5py_dtype_str = "float32" @@ -745,95 +777,141 @@ def _write_chunk( else: raise ValueError(f"Unsupported torch_dtype for HDF5: {self.torch_dtype}") - try: + if self.cfg.output_format == "hdf5": + num_write_workers = min(4, len(layer_ids) if layer_ids else 1) + with self._conditional_measure(f"chunk_{chunk_idx}_hdf5_file_open_and_create_datasets"): with h5py.File(p, "w", libver="latest") as hf: - _create_datasets(hf, layer_ids, rows, d_model, h5py_dtype=h5py_dtype_str) + for layer_id in layer_ids: + hf.create_dataset( + f"layer_{layer_id}/inputs", + shape=(rows, d_model), + dtype=h5py_dtype_str, + compression=self.cfg.compression if self.cfg.compression else None, + chunks=(min(rows, 16384), d_model), + ) + hf.create_dataset( + f"layer_{layer_id}/targets", + shape=(rows, d_model), + dtype=h5py_dtype_str, + compression=self.cfg.compression if self.cfg.compression else None, + chunks=(min(rows, 16384), d_model), + ) - # --- Phase 2a: Parallel Layer Writing --- - # First, prepare all the data outside the parallel execution - layer_data = {} - for lid in layer_ids: - with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_data_prep"): - # Concatenate tensors - with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_concat"): - inp_concat = torch.cat(buf_inp[lid], 0) - tgt_concat = torch.cat(buf_tgt[lid], 0) - - # Apply permutation - with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_permute"): - inp_perm = inp_concat[perm] - tgt_perm = tgt_concat[perm] - - # Convert to numpy - with self._conditional_measure(f"chunk_{chunk_idx}_layer_{lid}_convert_numpy"): - # Handle bfloat16 conversion - if h5py_dtype_str == "uint16": - inp_np = torch_bfloat16_to_numpy_uint16(inp_perm) - tgt_np = torch_bfloat16_to_numpy_uint16(tgt_perm) - else: - inp_np = inp_perm.to(self.torch_dtype).numpy() - tgt_np = tgt_perm.to(self.torch_dtype).numpy() - - # Store prepared data - layer_data[lid] = (inp_np, tgt_np) - - # Helper function for writing a single layer's data - def write_layer_data(layer_id: int, inputs_data: np.ndarray, targets_data: np.ndarray): - """Write a single layer's data to HDF5""" + layer_data_to_write = [] + for layer_id in layer_ids: + with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_data_prep"): + with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_concat"): + layer_inp_gpu = torch.cat(buf_inp_gpu[layer_id], dim=0) + layer_tgt_gpu = torch.cat(buf_tgt_gpu[layer_id], dim=0) + + with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_permute"): + perm = torch.randperm(rows, device=layer_inp_gpu.device) + layer_inp_gpu_perm = layer_inp_gpu[perm] + layer_tgt_gpu_perm = layer_tgt_gpu[perm] + + with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_cpu_transfer"): + layer_inp_cpu = layer_inp_gpu_perm.cpu() + layer_tgt_cpu = layer_tgt_gpu_perm.cpu() + + with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_convert_numpy"): + inputs_np = ( + layer_inp_cpu.numpy().view(np.uint16) + if self.torch_dtype == torch.bfloat16 + else layer_inp_cpu.numpy() + ) + targets_np = ( + layer_tgt_cpu.numpy().view(np.uint16) + if self.torch_dtype == torch.bfloat16 + else layer_tgt_cpu.numpy() + ) + layer_data_to_write.append((layer_id, inputs_np, targets_np)) + + def write_layer_data(layer_id_arg: int, inputs_data: np.ndarray, targets_data: np.ndarray): try: - hf[f"layer_{layer_id}/inputs"][:] = inputs_data - hf[f"layer_{layer_id}/targets"][:] = targets_data - return layer_id, None # Success + with h5py.File(p, "a", libver="latest") as hf_thread: + hf_thread[f"layer_{layer_id_arg}/inputs"][:] = inputs_data + hf_thread[f"layer_{layer_id_arg}/targets"][:] = targets_data + return layer_id_arg, None except Exception as e: - logger.error(f"Error writing layer {layer_id}: {e}") - return layer_id, e + logger.error(f"Error writing layer {layer_id_arg} to HDF5 chunk {chunk_idx}: {e}") + return layer_id_arg, e - # Use ThreadPoolExecutor for parallel writes + futures = [] with self._conditional_measure(f"chunk_{chunk_idx}_parallel_hdf5_writes"): - # Limit workers to avoid overwhelming I/O - max_workers = min(4, len(layer_ids)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all write tasks - futures = {} - for lid, (inp_data, tgt_data) in layer_data.items(): - future = executor.submit(write_layer_data, lid, inp_data, tgt_data) - futures[future] = lid - - # Wait for all writes to complete and check for errors - write_errors = [] - for future in as_completed(futures): - layer_id = futures[future] - try: - _, error = future.result() - if error: - write_errors.append((layer_id, error)) - except Exception as e: - write_errors.append((layer_id, e)) - - # If any writes failed, raise an error - if write_errors: - error_msg = "; ".join([f"Layer {lid}: {e}" for lid, e in write_errors]) - raise RuntimeError(f"Failed to write some layers: {error_msg}") - - except (IOError, OSError) as e: - logger.error(f"Failed to write HDF5 chunk {p}: {e}", exc_info=True) - try: - p.unlink(missing_ok=True) - except OSError: - logger.warning(f"Failed to remove partial chunk file {p} after write error.") - raise RuntimeError(f"Fatal error writing HDF5 chunk {chunk_idx}") from e - - m = np.empty((rows, 2), dtype="