|
148 | 148 | "from devinterp.optim.sgld import SGLD\n", |
149 | 149 | "\n", |
150 | 150 | "\n", |
151 | | - "\n", |
152 | 151 | "class DLN(nn.Module):\n", |
153 | 152 | " \"\"\"\n", |
154 | 153 | " A deep linear network with `L` layers with dimensions `dims`.\n", |
|
180 | 179 | " return f\"DLN({self.dims})\"\n", |
181 | 180 | "\n", |
182 | 181 | " @classmethod\n", |
183 | | - " def make_rectangular(cls, input_dim: int, output_dim: int, L: int, w: int, gamma: float):\n", |
| 182 | + " def make_rectangular(\n", |
| 183 | + " cls, input_dim: int, output_dim: int, L: int, w: int, gamma: float\n", |
| 184 | + " ):\n", |
184 | 185 | " \"\"\"\n", |
185 | 186 | " Make a rectangular DLN with `L` layers and constant hidden width `w`.\n", |
186 | 187 | "\n", |
|
189 | 190 | " The weights are initialized from a normal distribution with variance `w ** (-gamma)`.\n", |
190 | 191 | " \"\"\"\n", |
191 | 192 | " init_variance = w ** (-gamma)\n", |
192 | | - " return cls([input_dim] + [w] * (L - 1) + [output_dim], init_variance=init_variance)\n", |
| 193 | + " return cls(\n", |
| 194 | + " [input_dim] + [w] * (L - 1) + [output_dim], init_variance=init_variance\n", |
| 195 | + " )\n", |
193 | 196 | "\n", |
194 | 197 | " def to_matrix(self):\n", |
195 | 198 | " \"\"\"Return the collapsed matrix representation of the DLN.\"\"\"\n", |
|
212 | 215 | "\n", |
213 | 216 | " def ranks(self, **kwargs):\n", |
214 | 217 | " \"\"\"Return the ranks of the individual layers of the DLN.\"\"\"\n", |
215 | | - " return [torch.linalg.matrix_rank(l.weight.data.to(\"cpu\"), **kwargs) for l in self.linears]\n", |
| 218 | + " return [\n", |
| 219 | + " torch.linalg.matrix_rank(l.weight.data.to(\"cpu\"), **kwargs)\n", |
| 220 | + " for l in self.linears\n", |
| 221 | + " ]\n", |
216 | 222 | "\n", |
217 | 223 | " def norm(self, p: Union[int, float, str] = 2):\n", |
218 | 224 | " \"\"\"Return the nuclear norm of the DLN.\"\"\"\n", |
|
247 | 253 | " def device(self):\n", |
248 | 254 | " return next(self.parameters()).device\n", |
249 | 255 | "\n", |
| 256 | + "\n", |
250 | 257 | "class DLNDataset(Dataset):\n", |
251 | 258 | " teacher: DLN\n", |
252 | 259 | "\n", |
|
323 | 330 | "\n", |
324 | 331 | "DEVICE = os.environ.get(\n", |
325 | 332 | " \"DEVICE\",\n", |
326 | | - " \"cuda:0\"\n", |
327 | | - " if torch.cuda.is_available()\n", |
328 | | - " else \"mps\"\n", |
329 | | - " if torch.backends.mps.is_available()\n", |
330 | | - " else \"cpu\",\n", |
| 333 | + " (\n", |
| 334 | + " \"cuda:0\"\n", |
| 335 | + " if torch.cuda.is_available()\n", |
| 336 | + " else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", |
| 337 | + " ),\n", |
331 | 338 | ")\n", |
332 | 339 | "DEVICE = torch.device(DEVICE)\n", |
333 | 340 | "NUM_CORES = int(os.environ.get(\"NUM_CORES\", 1))\n", |
|
477 | 484 | "\n", |
478 | 485 | " def eval_rlct(model: DLN):\n", |
479 | 486 | " model.to(\"cpu\")\n", |
480 | | - " optimizer_kwargs = dict(\n", |
481 | | - " lr=1e-4, temperature=\"adaptive\", num_samples=len(trainset), elasticity=1.0\n", |
482 | | - " )\n", |
| 487 | + " optimizer_kwargs = dict(lr=1e-4, localization=1.0)\n", |
483 | 488 | " optimizer_kwargs.update(kwargs)\n", |
484 | 489 | " rlct = estimate_learning_coeff(\n", |
485 | 490 | " model,\n", |
|
658 | 663 | "\n", |
659 | 664 | " # Train and test error\n", |
660 | 665 | " ax.plot(df.step, df[\"mse/test\"], label=\"Test error\", color=PRIMARY)\n", |
661 | | - " ax.plot(df.step, df[\"mse/train\"], label=\"Train error\", color=PRIMARY_LIGHT, alpha=0.5)\n", |
| 666 | + " ax.plot(\n", |
| 667 | + " df.step, df[\"mse/train\"], label=\"Train error\", color=PRIMARY_LIGHT, alpha=0.5\n", |
| 668 | + " )\n", |
662 | 669 | " ax.set_yscale(\"log\")\n", |
663 | 670 | " ax.set_ylabel(\"MSE\", color=PRIMARY)\n", |
664 | 671 | " ax.tick_params(axis=\"y\", labelcolor=PRIMARY)\n", |
|
952 | 959 | " seed=seed,\n", |
953 | 960 | " )\n", |
954 | 961 | " learner = config.create_learner(\n", |
955 | | - " num_draws=10, num_chains=100, lr=1e-4, elasticity=1.0, repeats=5\n", |
| 962 | + " num_draws=10, num_chains=100, lr=1e-4, localization=1.0, repeats=5\n", |
956 | 963 | " )\n", |
957 | 964 | " df = train(learner)\n", |
958 | 965 | " dfs.append(df)\n", |
|
1698 | 1705 | " for noise_level in [0.0, 10.0]:\n", |
1699 | 1706 | " name = f\"rk{rk}_L4_w100_noise{noise_level}\"\n", |
1700 | 1707 | " results[name] = run_experiment(rk5_matrix, seed=SEED, **default_settings)\n", |
1701 | | - " plot_all(results[name], xlog=False, title=f\"r={rk}, L=4, w=100, noise={noise_level}\")\n", |
| 1708 | + " plot_all(\n", |
| 1709 | + " results[name], xlog=False, title=f\"r={rk}, L=4, w=100, noise={noise_level}\"\n", |
| 1710 | + " )\n", |
1702 | 1711 | "\n", |
1703 | 1712 | "df = None\n", |
1704 | 1713 | "\n", |
|
2084 | 2093 | "for gamma in [0.75, 1.0, 1.5]:\n", |
2085 | 2094 | " # for w in [10, 100, 1000]:\n", |
2086 | 2095 | " for w in [10, 100]:\n", |
2087 | | - " results = run_experiment(rk5_matrix, seed=SEED, w=w, gamma=gamma, **fig5_settings)\n", |
| 2096 | + " results = run_experiment(\n", |
| 2097 | + " rk5_matrix, seed=SEED, w=w, gamma=gamma, **fig5_settings\n", |
| 2098 | + " )\n", |
2088 | 2099 | " _df = pd.DataFrame(results)\n", |
2089 | 2100 | " _df[\"w\"] = w\n", |
2090 | 2101 | " _df[\"gamma\"] = gamma\n", |
|
2498 | 2509 | "name": "python", |
2499 | 2510 | "nbconvert_exporter": "python", |
2500 | 2511 | "pygments_lexer": "ipython3", |
2501 | | - "version": "3.8.10" |
| 2512 | + "version": "3.9.18" |
2502 | 2513 | } |
2503 | 2514 | }, |
2504 | 2515 | "nbformat": 4, |
|
0 commit comments