Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,19 @@
"DRIVE_ROOT.mkdir(parents=True, exist_ok=True)\n",
"\n",
"CHECKPOINT_EVERY_ROUNDS = 1 # evaluate/save every N rounds\n",
"WW_PLOT_EVERY_ROUNDS = 50 # run WW plot+save every N rounds\n",
"MAX_ROUNDS = 10000 # safety cap\n",
"TARGET_ALPHA = 2.0 # stop when alpha(W7) and alpha(W8) <= this value\n",
"TEST_SIZE = 0.2\n",
"RANDOM_STATE = 42\n",
"FORCE_FRESH_START = False # True = ignore prior checkpoints and start over\n",
"RESTART_RUNTIME_AFTER_INSTALL = False\n",
"\n",
"S3_WW_PLOT_PREFIX = os.environ.get('WW_S3_PLOT_PREFIX', '').strip()\n",
"\n",
"print('Checkpoint folder:', DRIVE_ROOT)\n",
"print('WW plot cadence:', WW_PLOT_EVERY_ROUNDS)\n",
"print('S3 plot destination:', S3_WW_PLOT_PREFIX if S3_WW_PLOT_PREFIX else '(disabled)')\n",
"print('Started at:', datetime.utcnow().isoformat() + 'Z')\n"
],
"id": "ujweiMJQcLj4"
Expand Down Expand Up @@ -182,6 +187,7 @@
"import time\n",
"import json\n",
"import warnings\n",
"import subprocess\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"import numpy as np\n",
Expand All @@ -201,6 +207,8 @@
"MODEL_PATH = DRIVE_ROOT / 'model_latest.json'\n",
"SPLIT_PATH = DRIVE_ROOT / 'data_split.npz'\n",
"SUMMARY_PATH = DRIVE_ROOT / 'summary.json'\n",
"WW_PLOTS_ROOT = DRIVE_ROOT / 'ww_plots'\n",
"WW_PLOTS_ROOT.mkdir(parents=True, exist_ok=True)\n",
"\n",
"\n",
"def _extract_weight_shape(layer):\n",
Expand Down Expand Up @@ -234,9 +242,17 @@
" return int(min(shape))\n",
"\n",
"\n",
"def ww_stats_for_matrix(layer, matrix_name):\n",
"def ww_stats_for_matrix(layer, matrix_name, plot=False, savedir=None):\n",
" watcher = ww.WeightWatcher(model=layer)\n",
" details = watcher.analyze(randomize=True, detX=True, ERG=True, plot=False)\n",
" analyze_kwargs = dict(randomize=True, detX=True, ERG=True, plot=False)\n",
" if plot:\n",
" if savedir is None:\n",
" raise ValueError(f\"savedir is required when plot=True for {matrix_name}\")\n",
" savedir = Path(savedir)\n",
" savedir.mkdir(parents=True, exist_ok=True)\n",
" analyze_kwargs.update(plot=True, savefig=True, savedir=str(savedir))\n",
"\n",
" details = watcher.analyze(**analyze_kwargs)\n",
" if 'alpha' not in details.columns:\n",
" raise RuntimeError(f\"WeightWatcher output missing alpha for {matrix_name}: columns={list(details.columns)}\")\n",
"\n",
Expand All @@ -248,6 +264,19 @@
" }\n",
"\n",
"\n",
"def sync_round_plots_to_s3(round_plot_dir, round_idx):\n",
" if not S3_WW_PLOT_PREFIX:\n",
" return\n",
"\n",
" s3_round_dir = f\"{S3_WW_PLOT_PREFIX.rstrip('/')}/round_{round_idx:05d}/\"\n",
" cmd = ['aws', 's3', 'cp', '--recursive', str(round_plot_dir), s3_round_dir]\n",
" try:\n",
" subprocess.run(cmd, check=True)\n",
" print(f\"round={round_idx:4d} | uploaded WW plots to {s3_round_dir}\")\n",
" except Exception as err:\n",
" print(f\"round={round_idx:4d} | WARNING: failed to upload WW plots to S3 ({err})\")\n",
"\n",
"\n",
"def convert_matrix_layer(model, Xtr, ytr, matrix_name, train_params=None, num_boost_round=None):\n",
" return convert(\n",
" model,\n",
Expand Down Expand Up @@ -489,9 +518,17 @@
" )\n",
" continue\n",
"\n",
" ww_w2 = ww_stats_for_matrix(layer_w2, 'W2')\n",
" ww_w7 = ww_stats_for_matrix(layer_w7, 'W7')\n",
" ww_w8 = ww_stats_for_matrix(layer_w8, 'W8')\n",
" plot_this_round = (r % WW_PLOT_EVERY_ROUNDS == 0)\n",
" round_plot_dir = WW_PLOTS_ROOT / f'round_{r:05d}'\n",
"\n",
" if plot_this_round:\n",
" (round_plot_dir / 'W2').mkdir(parents=True, exist_ok=True)\n",
" (round_plot_dir / 'W7').mkdir(parents=True, exist_ok=True)\n",
" (round_plot_dir / 'W8').mkdir(parents=True, exist_ok=True)\n",
"\n",
" ww_w2 = ww_stats_for_matrix(layer_w2, 'W2', plot=plot_this_round, savedir=round_plot_dir / 'W2' if plot_this_round else None)\n",
" ww_w7 = ww_stats_for_matrix(layer_w7, 'W7', plot=plot_this_round, savedir=round_plot_dir / 'W7' if plot_this_round else None)\n",
" ww_w8 = ww_stats_for_matrix(layer_w8, 'W8', plot=plot_this_round, savedir=round_plot_dir / 'W8' if plot_this_round else None)\n",
"\n",
" alpha_w2 = ww_w2['alpha']\n",
" alpha_w7 = ww_w7['alpha']\n",
Expand Down Expand Up @@ -533,6 +570,10 @@
" f\"| test_acc={te_acc:.4f} logloss={te_loss:.4f}\"\n",
" )\n",
"\n",
" if plot_this_round:\n",
" print(f\"round={r:4d} | saved WW plots to {round_plot_dir}\")\n",
" sync_round_plots_to_s3(round_plot_dir, r)\n",
"\n",
" if state.get('target_round_all') is not None:\n",
" print(f\"Reached alpha <= {TARGET_ALPHA} for W2, W7, and W8 at round {r}.\")\n",
" break\n",
Expand All @@ -549,8 +590,7 @@
"}\n",
"SUMMARY_PATH.write_text(json.dumps(summary, indent=2))\n",
"\n",
"print('\n",
"Summary:', json.dumps(summary, indent=2))\n"
"print('\\nSummary:', json.dumps(summary, indent=2))\n"
],
"id": "6gcD-VQhcLj6"
},
Expand Down