Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,836 changes: 3,836 additions & 0 deletions analysis_ICLR/Analysis_WD_SGD_NORMS.ipynb

Large diffs are not rendered by default.

1,643 changes: 1,643 additions & 0 deletions analysis_ICLR/Main_analysis_notebook.ipynb

Large diffs are not rendered by default.

Binary file modified cnn/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified cnn/__pycache__/model.cpython-39.pyc
Binary file not shown.
162 changes: 126 additions & 36 deletions cnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,61 @@ def __init__(self,
mlp_layers,
pool_every=1,
dropout=0.5,
input_size=(3, 32, 32)):
"""
Constructs a CNN classifier.

Parameters:
conv_channels (list of int): A list specifying the channel dimensions for the convolutional layers.
For example, [3, 32, 64] means the input has 3 channels, then a conv layer maps 3 → 32 channels,
and the next conv layer maps 32 → 64.
kernel_size (int): The kernel (window) size for all convolutional layers (assumed odd to allow padding).
mlp_layers (list of int): A list defining the fully connected (MLP) layers after flattening.
For example, [512, 128, num_classes].
pool_every (int): Insert a MaxPool2d layer (with kernel size 2, stride 2) after every 'pool_every'
convolutional layers.
dropout (float): Dropout probability applied after each fully-connected layer (except the last).
input_size (tuple): Input dimensions as (channels, height, width).
"""
super(CNNClassifier, self).__init__()
input_size=(3, 32, 32),
initialization_factor: float = 1.0,
):
super().__init__()
self._init_factor = initialization_factor

# --- Build the convolutional block ---
conv_layers = []
in_channels = conv_channels[0]
for i, out_channels in enumerate(conv_channels[1:]):
conv_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
conv_layers.append(nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2))
conv_layers.append(nn.ReLU(inplace=True))
# Insert pooling layer every 'pool_every' conv layer(s)
if (i + 1) % pool_every == 0:
conv_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
in_channels = out_channels
self.features = nn.Sequential(*conv_layers)

# Determine the number of features after the conv block using a dummy forward pass.
# figure out flatten size
with torch.no_grad():
dummy_input = torch.zeros(1, *input_size)
feat = self.features(dummy_input)
dummy = torch.zeros(1, *input_size)
feat = self.features(dummy)
self.flatten_dim = feat.view(1, -1).size(1)

# --- Build the MLP (classifier) block ---
mlp_layers_list = []
prev_dim = self.flatten_dim
# Add hidden layers with Linear, ReLU and dropout.
# --- Build the MLP ---
mlp = []
prev = self.flatten_dim
for h in mlp_layers[:-1]:
mlp_layers_list.append(nn.Linear(prev_dim, h))
mlp_layers_list.append(nn.ReLU(inplace=True))
mlp.append(nn.Linear(prev, h))
mlp.append(nn.ReLU(inplace=True))
if dropout > 0:
mlp_layers_list.append(nn.Dropout(dropout))
prev_dim = h
# Final layer (without activation) for classification.
mlp_layers_list.append(nn.Linear(prev_dim, mlp_layers[-1]))
self.classifier = nn.Sequential(*mlp_layers_list)
mlp.append(nn.Dropout(dropout))
prev = h
mlp.append(nn.Linear(prev, mlp_layers[-1]))
self.classifier = nn.Sequential(*mlp)

# --- NOW scale *all* conv/linear weights ---
self._scale_initial_weights()

def _scale_initial_weights(self):
# 1.0 → no change; anything else scales
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
# multiply the *existing* initialization
m.weight.data.mul_(self._init_factor)
# leave bias as whatever it was (PyTorch default = 0)

def forward(self, x):
# Pass through convolutional features.
x = self.features(x)
x = x.view(x.size(0), -1) # flatten
# Pass through classifier (MLP).
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x


def compute_model_norm(self):
"""
Computes the model's spectral complexity, which is defined as:
Expand Down Expand Up @@ -144,6 +140,100 @@ def two_one_norm(weight):
norm_value = prod_spec * (correction_sum ** (3.0/2.0))
return norm_value

def compute_l1_norm(self):
    """
    Entry-wise L1 norm of the model: the sum of |w| over every parameter
    (weights and biases alike), returned as a 0-dim tensor on the model's device.
    """
    device = next(self.parameters()).device
    acc = torch.zeros((), device=device)
    for p in self.parameters():
        acc = acc + p.abs().sum()
    return acc

def compute_frobenius_norm(self):
    """
    Frobenius norm of the whole model: sqrt(sum of w^2) over every parameter,
    i.e. the L2 norm of the model treated as one flattened vector.
    """
    device = next(self.parameters()).device
    sq_sum = torch.zeros((), device=device)
    for p in self.parameters():
        sq_sum = sq_sum + (p * p).sum()
    return sq_sum.sqrt()

def compute_group_2_1_norm(self):
    """
    (2,1) group norm: for each Conv2d/Linear layer in ``features`` and
    ``classifier``, flatten the weight to a 2-D matrix (out x in*kH*kW for
    convs), take the L2 norm of every column, and sum all column norms.
    """
    device = next(self.parameters()).device
    result = torch.zeros((), device=device)
    for layer in list(self.features) + list(self.classifier):
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        w = layer.weight
        mat = w.reshape(w.size(0), -1) if w.ndim > 2 else w
        result = result + torch.norm(mat, p=2, dim=0).sum()
    return result

def compute_spectral_norm(self):
    """
    Product of per-layer spectral norms: each Conv2d/Linear weight is
    flattened to 2-D and its largest singular value taken; the values are
    multiplied across all layers in ``features`` and ``classifier``.
    """
    device = next(self.parameters()).device
    product = torch.ones((), device=device)
    for layer in list(self.features) + list(self.classifier):
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        w = layer.weight
        mat = w.reshape(w.size(0), -1) if w.ndim > 2 else w
        # svdvals returns singular values in descending order; [0] is the top.
        product = product * torch.linalg.svdvals(mat)[0]
    return product

def compute_path_norm(self):
    """
    Dynamic-programming approximation of the path norm over channels.

    Conv2d weights are collapsed over their spatial dims into a
    [out_channels, in_channels] connectivity matrix of absolute weights;
    Linear weights are used as |W| directly. An all-ones vector is pushed
    through the chain of these matrices and the final entries are summed —
    pooling/activation layers are ignored by this approximation.
    """
    device = next(self.parameters()).device
    running = None  # accumulated per-unit path weight; seeded lazily with ones
    for layer in list(self.features) + list(self.classifier):
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        w = layer.weight
        # Conv weight is [out, in, kH, kW] → collapse spatial dims.
        conn = w.abs().sum(dim=(2, 3)) if w.ndim > 2 else w.abs()
        if running is None:
            running = torch.ones(conn.size(1), device=device)
        running = conn.matmul(running)
    return torch.sum(running)

def compute_fisher_rao_norm(self, inputs, labels):
    """
    Diagonal approximation of the Fisher-Rao norm on one batch:
    diag(F)_i ≈ grad_i^2, so FR^2 ≈ sum_i theta_i^2 * diag(F)_i, where the
    gradient is that of the mean negative log-likelihood on (inputs, labels).

    NOTE(review): side effects are intentional and preserved — the model is
    put in eval mode and its gradients are left populated after the call.
    """
    self.eval()
    self.zero_grad()
    log_probs = F.log_softmax(self(inputs), dim=1)
    # Pick each sample's log-probability at its true label.
    picked = log_probs.gather(1, labels.view(-1, 1)).squeeze(1)
    nll = -picked.mean()
    nll.backward()
    device = next(self.parameters()).device
    acc = torch.zeros((), device=device)
    for p in self.parameters():
        if p.grad is not None:
            acc = acc + ((p.detach() ** 2) * (p.grad.detach() ** 2)).sum()
    return acc.sqrt()


def compute_margin_distribution(self, inputs, labels):
"""
Computes the normalized margin distribution on a given batch.
Expand Down
Empty file added resnet/__init__.py
Empty file.
Loading