Skip to content

Commit 9c943b4

Browse files
authored
Merge pull request #69 from timaeus-research/stan/fixes
Refactor num_samples / temperature, rename elasticity to localization, add warnings, init_loss before sampling, OnlineLLC fix
2 parents d901df9 + 2e46321 commit 9c943b4

35 files changed

Lines changed: 3121 additions & 2089 deletions

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ DevInterp is a python library for conducting research on developmental interpret
2222

2323
from devinterp.slt import sample, LLCEstimator
2424
from devinterp.optim import SGLD
25+
from devinterp.utils import optimal_temperature
2526

2627
# Assuming you have a PyTorch Module and DataLoader
27-
llc_estimator = LLCEstimator(...)
28+
llc_estimator = LLCEstimator(..., temperature=optimal_temperature(trainloader))
2829
sample(model, trainloader, ..., callbacks = [llc_estimator])
2930

3031
llc_mean = llc_estimator.sample()["llc/mean"]
@@ -46,10 +47,6 @@ For papers that either inspired or used the DevInterp package, [click here](http
4647

4748
## Known Issues
4849

49-
- We currently calculate the LLC taking the initial loss to be the loss after one sampling step. This is slightly wrong (it should be the loss before sampling), and there are a bunch of other reasonable and similarly compute-friendly alternative choices that can be made.
50-
51-
- Similarly, we now sample using minibatches that are passed along from the dataloader to sample(). This choice is not made explicit anywhere in the repo, and we should offer alternatives.
52-
5350
- The current implementation does not work with transformers out-of-the-box. This can be fixed by adding a wrapper to your model, for example passing Unpack(model) to sample() where Unpack is defined by:
5451
```python
5552
class Unpack(nn.Module):

docs/index.md

Whitespace-only changes.

docs/tutorial.md

Whitespace-only changes.

examples/diagnostics.ipynb

Lines changed: 391 additions & 272 deletions
Large diffs are not rendered by default.

examples/dlns.ipynb

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@
148148
"from devinterp.optim.sgld import SGLD\n",
149149
"\n",
150150
"\n",
151-
"\n",
152151
"class DLN(nn.Module):\n",
153152
" \"\"\"\n",
154153
" A deep linear network with `L` layers with dimensions `dims`.\n",
@@ -180,7 +179,9 @@
180179
" return f\"DLN({self.dims})\"\n",
181180
"\n",
182181
" @classmethod\n",
183-
" def make_rectangular(cls, input_dim: int, output_dim: int, L: int, w: int, gamma: float):\n",
182+
" def make_rectangular(\n",
183+
" cls, input_dim: int, output_dim: int, L: int, w: int, gamma: float\n",
184+
" ):\n",
184185
" \"\"\"\n",
185186
" Make a rectangular DLN with `L` layers and constant hidden width `w`.\n",
186187
"\n",
@@ -189,7 +190,9 @@
189190
" The weights are initialized from a normal distribution with variance`w ** (-gamma)`.\n",
190191
" \"\"\"\n",
191192
" init_variance = w ** (-gamma)\n",
192-
" return cls([input_dim] + [w] * (L - 1) + [output_dim], init_variance=init_variance)\n",
193+
" return cls(\n",
194+
" [input_dim] + [w] * (L - 1) + [output_dim], init_variance=init_variance\n",
195+
" )\n",
193196
"\n",
194197
" def to_matrix(self):\n",
195198
" \"\"\"Return the collapsed matrix representation of the DLN.\"\"\"\n",
@@ -212,7 +215,10 @@
212215
"\n",
213216
" def ranks(self, **kwargs):\n",
214217
" \"\"\"Return the ranks of the individual layers of the DLN.\"\"\"\n",
215-
" return [torch.linalg.matrix_rank(l.weight.data.to(\"cpu\"), **kwargs) for l in self.linears]\n",
218+
" return [\n",
219+
" torch.linalg.matrix_rank(l.weight.data.to(\"cpu\"), **kwargs)\n",
220+
" for l in self.linears\n",
221+
" ]\n",
216222
"\n",
217223
" def norm(self, p: Union[int, float, str] = 2):\n",
218224
" \"\"\"Return the nuclear norm of the DLN.\"\"\"\n",
@@ -247,6 +253,7 @@
247253
" def device(self):\n",
248254
" return next(self.parameters()).device\n",
249255
"\n",
256+
"\n",
250257
"class DLNDataset(Dataset):\n",
251258
" teacher: DLN\n",
252259
"\n",
@@ -323,11 +330,11 @@
323330
"\n",
324331
"DEVICE = os.environ.get(\n",
325332
" \"DEVICE\",\n",
326-
" \"cuda:0\"\n",
327-
" if torch.cuda.is_available()\n",
328-
" else \"mps\"\n",
329-
" if torch.backends.mps.is_available()\n",
330-
" else \"cpu\",\n",
333+
" (\n",
334+
" \"cuda:0\"\n",
335+
" if torch.cuda.is_available()\n",
336+
" else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
337+
" ),\n",
331338
")\n",
332339
"DEVICE = torch.device(DEVICE)\n",
333340
"NUM_CORES = int(os.environ.get(\"NUM_CORES\", 1))\n",
@@ -477,9 +484,7 @@
477484
"\n",
478485
" def eval_rlct(model: DLN):\n",
479486
" model.to(\"cpu\")\n",
480-
" optimizer_kwargs = dict(\n",
481-
" lr=1e-4, temperature=\"adaptive\", num_samples=len(trainset), elasticity=1.0\n",
482-
" )\n",
487+
" optimizer_kwargs = dict(lr=1e-4, localization=1.0)\n",
483488
" optimizer_kwargs.update(kwargs)\n",
484489
" rlct = estimate_learning_coeff(\n",
485490
" model,\n",
@@ -658,7 +663,9 @@
658663
"\n",
659664
" # Train error\n",
660665
" ax.plot(df.step, df[\"mse/test\"], label=\"Test error\", color=PRIMARY)\n",
661-
" ax.plot(df.step, df[\"mse/train\"], label=\"Train error\", color=PRIMARY_LIGHT, alpha=0.5)\n",
666+
" ax.plot(\n",
667+
" df.step, df[\"mse/train\"], label=\"Train error\", color=PRIMARY_LIGHT, alpha=0.5\n",
668+
" )\n",
662669
" ax.set_yscale(\"log\")\n",
663670
" ax.set_ylabel(\"MSE\", color=PRIMARY)\n",
664671
" ax.tick_params(axis=\"y\", labelcolor=PRIMARY)\n",
@@ -952,7 +959,7 @@
952959
" seed=seed,\n",
953960
" )\n",
954961
" learner = config.create_learner(\n",
955-
" num_draws=10, num_chains=100, lr=1e-4, elasticity=1.0, repeats=5\n",
962+
" num_draws=10, num_chains=100, lr=1e-4, localization=1.0, repeats=5\n",
956963
" )\n",
957964
" df = train(learner)\n",
958965
" dfs.append(df)\n",
@@ -1698,7 +1705,9 @@
16981705
" for noise_level in [0.0, 10.0]:\n",
16991706
" name = f\"rk{rk}_L4_w100_noise{noise_level}\"\n",
17001707
" results[name] = run_experiment(rk5_matrix, seed=SEED, **default_settings)\n",
1701-
" plot_all(results[name], xlog=False, title=f\"r={rk}, L=4, w=100, noise={noise_level}\")\n",
1708+
" plot_all(\n",
1709+
" results[name], xlog=False, title=f\"r={rk}, L=4, w=100, noise={noise_level}\"\n",
1710+
" )\n",
17021711
"\n",
17031712
"df = None\n",
17041713
"\n",
@@ -2084,7 +2093,9 @@
20842093
"for gamma in [0.75, 1.0, 1.5]:\n",
20852094
" # for w in [10, 100, 1000]:\n",
20862095
" for w in [10, 100]:\n",
2087-
" results = run_experiment(rk5_matrix, seed=SEED, w=w, gamma=gamma, **fig5_settings)\n",
2096+
" results = run_experiment(\n",
2097+
" rk5_matrix, seed=SEED, w=w, gamma=gamma, **fig5_settings\n",
2098+
" )\n",
20882099
" _df = pd.DataFrame(results)\n",
20892100
" _df[\"w\"] = w\n",
20902101
" _df[\"gamma\"] = gamma\n",
@@ -2498,7 +2509,7 @@
24982509
"name": "python",
24992510
"nbconvert_exporter": "python",
25002511
"pygments_lexer": "ipython3",
2501-
"version": "3.8.10"
2512+
"version": "3.9.18"
25022513
}
25032514
},
25042515
"nbformat": 4,

0 commit comments

Comments
 (0)