diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index 8632724..aee5728 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -40,7 +40,6 @@ def __post_init__(self): assert self.num_features > 0, "Number of features must be positive" assert self.num_layers > 0, "Number of layers must be positive" assert self.d_model > 0, "Model dimension must be positive" - assert self.jumprelu_threshold > 0, "JumpReLU threshold must be positive" valid_norm_methods = ["auto", "estimated_mean_std", "none"] assert ( self.normalization_method in valid_norm_methods diff --git a/clt/models/theta.py b/clt/models/theta.py index a42cff1..e843f4a 100644 --- a/clt/models/theta.py +++ b/clt/models/theta.py @@ -48,6 +48,10 @@ def __init__( self.rank = dist.get_rank(process_group) if self.config.activation_fn == "jumprelu": + if self.config.jumprelu_threshold == 0: + logger.warning( + f"Rank {self.rank}: jumprelu_threshold is 0, expecting to load log_threshold from checkpoint." + ) initial_threshold_val = torch.ones( config.num_layers, config.num_features, device=self.device, dtype=self.dtype ) * torch.log(torch.tensor(config.jumprelu_threshold, device=self.device, dtype=self.dtype))