diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index f0e02e1..e60a6dd 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -56,8 +56,14 @@ def __init__(self, self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device - self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None) self.add_special_tokens = add_special_tokens + self.remove_bos = remove_bos + + if remove_bos and self.model.tokenizer.bos_token_id is None: + print( + "\n\n\nWARNING: remove_bos is True but tokenizer does not have a bos token. We are removing the first non-pad token instead. Don't use sequence packing.\n\n\n" + ) + def __iter__(self): return self @@ -138,9 +144,17 @@ def refresh(self): hidden_states = hidden_states.value if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] + if self.remove_bos: - bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id) - mask = mask & ~bos_mask + if self.model.tokenizer.bos_token_id is not None: + bos_mask = input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id + mask = mask & ~bos_mask + else: + # some models (like Qwen) don't have a bos token, so we need to remove the first non-pad token + assert mask.dim() == 2, "expected shape (batch_size, seq_len)" + first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask + mask = mask & ~first_one + hidden_states = hidden_states[mask] remaining_space = self.activation_buffer_size - current_idx diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py index 9d943e4..a31d4c6 100644 --- a/dictionary_learning/pytorch_buffer.py +++ b/dictionary_learning/pytorch_buffer.py @@ -119,7 +119,12 @@ def __init__( self.device = device self.add_special_tokens = add_special_tokens self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path) - self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None) + self.remove_bos = remove_bos + + if remove_bos and self.tokenizer.bos_token_id is None: + print( + "\n\n\nWARNING: remove_bos is True but tokenizer does not have a bos token. We are removing the first non-pad token instead. Don't use sequence packing.\n\n\n" + ) if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token @@ -192,10 +197,17 @@ def refresh(self): with t.no_grad(): input = self.tokenized_batch() hidden_states = collect_activations(self.model, self.submodule, input) - mask = (input["attention_mask"] != 0) + mask = input["attention_mask"] != 0 if self.remove_bos: - bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id) - mask = mask & ~bos_mask + if self.tokenizer.bos_token_id is not None: + bos_mask = input["input_ids"] == self.tokenizer.bos_token_id + mask = mask & ~bos_mask + else: + # some models (like Qwen) don't have a bos token, so we need to remove the first non-pad token + assert mask.dim() == 2, "expected shape (batch_size, seq_len)" + first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask + mask = mask & ~first_one + hidden_states = hidden_states[mask] remaining_space = self.activation_buffer_size - current_idx diff --git a/dictionary_learning/trainers/batch_top_k.py b/dictionary_learning/trainers/batch_top_k.py index 686dc0a..8cb2ecf 100644 --- a/dictionary_learning/trainers/batch_top_k.py +++ b/dictionary_learning/trainers/batch_top_k.py @@ -34,11 +34,15 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.encoder.bias.data.zero_() self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): + def encode( + self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True + ): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) else: # Flatten and perform batch top-k flattened_acts = post_relu_feat_acts_BF.flatten() @@ -105,6 +109,7 @@ def __init__( decay_start: Optional[int] = None, # when does the lr decay start threshold_beta: float = 0.999, threshold_start_step: int = 1000, + k_anneal_steps: Optional[int] = None, seed: Optional[int] = None, device: Optional[str] = None, wandb_name: str = "BatchTopKSAE", @@ -122,6 +127,7 @@ def __init__( self.k = k self.threshold_beta = threshold_beta self.threshold_start_step = threshold_start_step + self.k_anneal_steps = k_anneal_steps if seed is not None: t.manual_seed(seed) @@ -146,17 +152,43 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.logging_parameters = [ + "effective_l0", + "dead_features", + "pre_norm_auxk_loss", + ] self.effective_l0 = -1 self.dead_features = -1 self.pre_norm_auxk_loss = -1 - self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + self.optimizer = t.optim.Adam( + self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) + ) lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + def update_annealed_k( + self, step: int, activation_dim: int, k_anneal_steps: Optional[int] = None + ) -> None: + """Update k buffer in-place with annealed value""" + if k_anneal_steps is None: + return + + assert 0 <= k_anneal_steps < self.steps, ( + "k_anneal_steps must be >= 0 and < steps." + ) + # self.k is the target k set for the trainer, not the dictionary's current k + assert activation_dim > self.k, "activation_dim must be greater than k" + + step = min(step, k_anneal_steps) + ratio = step / k_anneal_steps + annealed_value = activation_dim * (1 - ratio) + self.k * ratio + + # Update in-place + self.ae.k.fill_(int(annealed_value)) + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) @@ -170,19 +202,28 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + auxk_acts_BF = auxk_buffer_BF.scatter_( + dim=-1, index=auxk_indices, src=auxk_acts + ) # Note: decoder(), not decode(), as we don't want to apply the bias x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) l2_loss_aux = ( - (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + (residual_BD.float() - x_reconstruct_aux.float()) + .pow(2) + .sum(dim=-1) + .mean() ) self.pre_norm_auxk_loss = l2_loss_aux # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 - residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) - loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to( + residual_BD.shape + ) + loss_denom = ( + (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + ) normalized_auxk_loss = l2_loss_aux / loss_denom return normalized_auxk_loss.nan_to_num(0.0) @@ -220,7 +261,7 @@ def loss(self, x, step=None, logging=False): e = x - x_hat - self.effective_l0 = self.k + self.effective_l0 = self.ae.k.item() num_tokens_in_step = x.size(0) did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) @@ -239,7 +280,11 @@ def loss(self, x, step=None, logging=False): x, x_hat, f, - {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, + { + "l2_loss": l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": loss.item(), + }, ) def update(self, step, x): @@ -263,6 +308,7 @@ def update(self, step, x): self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + self.update_annealed_k(step, self.ae.activation_dim, self.k_anneal_steps) # Make sure the decoder is still unit-norm self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( diff --git a/dictionary_learning/trainers/matryoshka_batch_top_k.py b/dictionary_learning/trainers/matryoshka_batch_top_k.py index 67647fb..03c195b 100644 --- a/dictionary_learning/trainers/matryoshka_batch_top_k.py +++ b/dictionary_learning/trainers/matryoshka_batch_top_k.py @@ -35,7 +35,9 @@ def apply_temperature(probabilities: list[float], temperature: float) -> list[fl class MatryoshkaBatchTopKSAE(Dictionary, nn.Module): - def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]): + def __init__( + self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int] + ): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size @@ -55,7 +57,9 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) self.b_enc = nn.Parameter(t.zeros(dict_size)) - self.W_dec = nn.Parameter(t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim))) + self.W_dec = nn.Parameter( + t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim)) + ) self.b_dec = nn.Parameter(t.zeros(activation_dim)) # We must transpose because we are using nn.Parameter, not nn.Linear @@ -64,11 +68,17 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis ).T self.W_enc.data = self.W_dec.data.clone().T - def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): - post_relu_feat_acts_BF = nn.functional.relu((x - self.b_dec) @ self.W_enc + self.b_enc) + def encode( + self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True + ): + post_relu_feat_acts_BF = nn.functional.relu( + (x - self.b_dec) @ self.W_enc + self.b_enc + ) if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) else: # Flatten and perform batch top-k flattened_acts = post_relu_feat_acts_BF.flatten() @@ -108,7 +118,9 @@ def scale_biases(self, scale: float): self.threshold *= scale @classmethod - def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatryoshkaBatchTopKSAE": + def from_pretrained( + cls, path, k=None, device=None, **kwargs + ) -> "MatryoshkaBatchTopKSAE": state_dict = t.load(path) activation_dim, dict_size = state_dict["W_enc"].shape if k is None: @@ -143,6 +155,7 @@ def __init__( decay_start: Optional[int] = None, # when does the lr decay start threshold_beta: float = 0.999, threshold_start_step: int = 1000, + k_anneal_steps: Optional[int] = None, seed: Optional[int] = None, device: Optional[str] = None, wandb_name: str = "BatchTopKSAE", @@ -160,6 +173,7 @@ def __init__( self.k = k self.threshold_beta = threshold_beta self.threshold_start_step = threshold_start_step + self.k_anneal_steps = k_anneal_steps if seed is not None: t.manual_seed(seed) @@ -200,17 +214,43 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper - self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + self.optimizer = t.optim.Adam( + self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) + ) lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.logging_parameters = [ + "effective_l0", + "dead_features", + "pre_norm_auxk_loss", + ] self.effective_l0 = -1 self.dead_features = -1 self.pre_norm_auxk_loss = -1 + def update_annealed_k( + self, step: int, activation_dim: int, k_anneal_steps: Optional[int] = None + ) -> None: + """Update k buffer in-place with annealed value""" + if k_anneal_steps is None: + return + + assert 0 <= k_anneal_steps < self.steps, ( + "k_anneal_steps must be >= 0 and < steps." + ) + # self.k is the target k set for the trainer, not the dictionary's current k + assert activation_dim > self.k, "activation_dim must be greater than k" + + step = min(step, k_anneal_steps) + ratio = step / k_anneal_steps + annealed_value = activation_dim * (1 - ratio) + self.k * ratio + + # Update in-place + self.ae.k.fill_(int(annealed_value)) + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) @@ -224,19 +264,28 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + auxk_acts_BF = auxk_buffer_BF.scatter_( + dim=-1, index=auxk_indices, src=auxk_acts + ) # We don't want to apply the bias x_reconstruct_aux = auxk_acts_BF @ self.ae.W_dec l2_loss_aux = ( - (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + (residual_BD.float() - x_reconstruct_aux.float()) + .pow(2) + .sum(dim=-1) + .mean() ) self.pre_norm_auxk_loss = l2_loss_aux # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 - residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) - loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to( + residual_BD.shape + ) + loss_denom = ( + (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + ) normalized_auxk_loss = l2_loss_aux / loss_denom return normalized_auxk_loss.nan_to_num(0.0) @@ -283,7 +332,9 @@ def loss(self, x, step=None, logging=False): acts_slice = f_chunks[i] x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice - l2_loss = (x - x_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[i] + l2_loss = (x - x_reconstruct).pow(2).sum( + dim=-1 + ).mean() * self.group_weights[i] total_l2_loss += l2_loss l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) @@ -299,7 +350,9 @@ def loss(self, x, step=None, logging=False): self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - auxk_loss = self.get_auxiliary_loss((x - x_reconstruct).detach(), post_relu_acts_BF) + auxk_loss = self.get_auxiliary_loss( + (x - x_reconstruct).detach(), post_relu_acts_BF + ) loss = mean_l2_loss + self.auxk_alpha * auxk_loss if not logging: @@ -329,13 +382,17 @@ def update(self, step, x): # We must transpose because we are using nn.Parameter, not nn.Linear self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( - self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size + self.ae.W_dec.T, + self.ae.W_dec.grad.T, + self.ae.activation_dim, + self.ae.dict_size, ).T t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + self.update_annealed_k(step, self.ae.activation_dim, self.k_anneal_steps) # We must transpose because we are using nn.Parameter, not nn.Linear self.ae.W_dec.data = set_decoder_norm_to_unit_norm( diff --git a/dictionary_learning/trainers/top_k.py b/dictionary_learning/trainers/top_k.py index f6f5692..e81259f 100644 --- a/dictionary_learning/trainers/top_k.py +++ b/dictionary_learning/trainers/top_k.py @@ -80,14 +80,23 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): + def encode( + self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False + ): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) if use_threshold: - encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + encoded_acts_BF = post_relu_feat_acts_BF * ( + post_relu_feat_acts_BF > self.threshold + ) if return_topk: post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) - return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF + return ( + encoded_acts_BF, + post_topk.values, + post_topk.indices, + post_relu_feat_acts_BF, + ) else: return encoded_acts_BF @@ -98,7 +107,9 @@ def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = F top_indices_BK = post_topk.indices buffer_BF = t.zeros_like(post_relu_feat_acts_BF) - encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK) + encoded_acts_BF = buffer_BF.scatter_( + dim=-1, index=top_indices_BK, src=tops_acts_BK + ) if return_topk: return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF @@ -161,6 +172,7 @@ def __init__( decay_start: Optional[int] = None, # when does the lr decay start threshold_beta: float = 0.999, threshold_start_step: int = 1000, + k_anneal_steps: Optional[int] = None, seed: Optional[int] = None, device: Optional[str] = None, wandb_name: str = "AutoEncoderTopK", @@ -180,6 +192,7 @@ def __init__( self.k = k self.threshold_beta = threshold_beta self.threshold_start_step = threshold_start_step + self.k_anneal_steps = k_anneal_steps if seed is not None: t.manual_seed(seed) @@ -204,18 +217,44 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.logging_parameters = [ + "effective_l0", + "dead_features", + "pre_norm_auxk_loss", + ] self.effective_l0 = -1 self.dead_features = -1 self.pre_norm_auxk_loss = -1 # Optimizer and scheduler - self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + self.optimizer = t.optim.Adam( + self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) + ) lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + def update_annealed_k( + self, step: int, activation_dim: int, k_anneal_steps: Optional[int] = None + ) -> None: + """Update k buffer in-place with annealed value""" + if k_anneal_steps is None: + return + + assert 0 <= k_anneal_steps < self.steps, ( + "k_anneal_steps must be >= 0 and < steps." + ) + # self.k is the target k set for the trainer, not the dictionary's current k + assert activation_dim > self.k, "activation_dim must be greater than k" + + step = min(step, k_anneal_steps) + ratio = step / k_anneal_steps + annealed_value = activation_dim * (1 - ratio) + self.k * ratio + + # Update in-place + self.ae.k.fill_(int(annealed_value)) + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) @@ -229,19 +268,28 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + auxk_acts_BF = auxk_buffer_BF.scatter_( + dim=-1, index=auxk_indices, src=auxk_acts + ) # Note: decoder(), not decode(), as we don't want to apply the bias x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) l2_loss_aux = ( - (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + (residual_BD.float() - x_reconstruct_aux.float()) + .pow(2) + .sum(dim=-1) + .mean() ) self.pre_norm_auxk_loss = l2_loss_aux # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 - residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) - loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to( + residual_BD.shape + ) + loss_denom = ( + (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + ) normalized_auxk_loss = l2_loss_aux / loss_denom return normalized_auxk_loss.nan_to_num(0.0) @@ -294,7 +342,9 @@ def loss(self, x, step=None, logging=False): l2_loss = e.pow(2).sum(dim=-1).mean() auxk_loss = ( - self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) if self.auxk_alpha > 0 else 0 + self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) + if self.auxk_alpha > 0 + else 0 ) loss = l2_loss + self.auxk_alpha * auxk_loss @@ -306,7 +356,11 @@ def loss(self, x, step=None, logging=False): x, x_hat, f, - {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, + { + "l2_loss": l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": loss.item(), + }, ) def update(self, step, x): @@ -334,6 +388,7 @@ def update(self, step, x): self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + self.update_annealed_k(step, self.ae.activation_dim, self.k_anneal_steps) # Make sure the decoder is still unit-norm self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 6f2d2c0..a5d0559 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -290,6 +290,7 @@ def get_submodule(model: AutoModelForCausalLM, layer: int): elif ( model.config.architectures[0] == "Qwen2ForCausalLM" or model.config.architectures[0] == "Gemma2ForCausalLM" + or model.config.architectures[0] == "Qwen3ForCausalLM" ): return model.model.layers[layer] else: @@ -309,6 +310,7 @@ def truncate_model(model: AutoModelForCausalLM, layer: int): if ( model.config.architectures[0] == "Qwen2ForCausalLM" or model.config.architectures[0] == "Gemma2ForCausalLM" + or model.config.architectures[0] == "Qwen3ForCausalLM" ): removed_layers = model.model.layers[layer + 1 :]