From f00682883451a351df5d4c5b3b477baeff8524ce Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Sun, 15 Mar 2026 20:36:26 -0700 Subject: [PATCH] Add periodic WeightWatcher plot export and S3 sync support --- XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb | 54 ++++++++++++++++--- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb b/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb index 2d977b1..2bf1868 100644 --- a/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb +++ b/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint.ipynb @@ -81,6 +81,7 @@ "DRIVE_ROOT.mkdir(parents=True, exist_ok=True)\n", "\n", "CHECKPOINT_EVERY_ROUNDS = 1 # evaluate/save every N rounds\n", + "WW_PLOT_EVERY_ROUNDS = 50 # run WW plot+save every N rounds\n", "MAX_ROUNDS = 10000 # safety cap\n", "TARGET_ALPHA = 2.0 # stop when alpha(W7) and alpha(W8) <= this value\n", "TEST_SIZE = 0.2\n", @@ -88,7 +89,11 @@ "FORCE_FRESH_START = False # True = ignore prior checkpoints and start over\n", "RESTART_RUNTIME_AFTER_INSTALL = False\n", "\n", + "S3_WW_PLOT_PREFIX = os.environ.get('WW_S3_PLOT_PREFIX', '').strip()\n", + "\n", "print('Checkpoint folder:', DRIVE_ROOT)\n", + "print('WW plot cadence:', WW_PLOT_EVERY_ROUNDS)\n", + "print('S3 plot destination:', S3_WW_PLOT_PREFIX if S3_WW_PLOT_PREFIX else '(disabled)')\n", "print('Started at:', datetime.utcnow().isoformat() + 'Z')\n" ], "id": "ujweiMJQcLj4" @@ -182,6 +187,7 @@ "import time\n", "import json\n", "import warnings\n", + "import subprocess\n", "warnings.filterwarnings('ignore')\n", "\n", "import numpy as np\n", @@ -201,6 +207,8 @@ "MODEL_PATH = DRIVE_ROOT / 'model_latest.json'\n", "SPLIT_PATH = DRIVE_ROOT / 'data_split.npz'\n", "SUMMARY_PATH = DRIVE_ROOT / 'summary.json'\n", + "WW_PLOTS_ROOT = DRIVE_ROOT / 'ww_plots'\n", + "WW_PLOTS_ROOT.mkdir(parents=True, exist_ok=True)\n", "\n", "\n", "def _extract_weight_shape(layer):\n", @@ -234,9 +242,17 @@ " return int(min(shape))\n", "\n", "\n", - "def ww_stats_for_matrix(layer, matrix_name):\n", + "def ww_stats_for_matrix(layer, matrix_name, plot=False, savedir=None):\n", " watcher = ww.WeightWatcher(model=layer)\n", - " details = watcher.analyze(randomize=True, detX=True, ERG=True, plot=False)\n", + " analyze_kwargs = dict(randomize=True, detX=True, ERG=True, plot=False)\n", + " if plot:\n", + " if savedir is None:\n", + " raise ValueError(f\"savedir is required when plot=True for {matrix_name}\")\n", + " savedir = Path(savedir)\n", + " savedir.mkdir(parents=True, exist_ok=True)\n", + " analyze_kwargs.update(plot=True, savefig=True, savedir=str(savedir))\n", + "\n", + " details = watcher.analyze(**analyze_kwargs)\n", " if 'alpha' not in details.columns:\n", " raise RuntimeError(f\"WeightWatcher output missing alpha for {matrix_name}: columns={list(details.columns)}\")\n", "\n", @@ -248,6 +264,19 @@ " }\n", "\n", "\n", + "def sync_round_plots_to_s3(round_plot_dir, round_idx):\n", + " if not S3_WW_PLOT_PREFIX:\n", + " return\n", + "\n", + " s3_round_dir = f\"{S3_WW_PLOT_PREFIX.rstrip('/')}/round_{round_idx:05d}/\"\n", + " cmd = ['aws', 's3', 'cp', '--recursive', str(round_plot_dir), s3_round_dir]\n", + " try:\n", + " subprocess.run(cmd, check=True)\n", + " print(f\"round={round_idx:4d} | uploaded WW plots to {s3_round_dir}\")\n", + " except Exception as err:\n", + " print(f\"round={round_idx:4d} | WARNING: failed to upload WW plots to S3 ({err})\")\n", + "\n", + "\n", "def convert_matrix_layer(model, Xtr, ytr, matrix_name, train_params=None, num_boost_round=None):\n", " return convert(\n", " model,\n", @@ -489,9 +518,17 @@ " )\n", " continue\n", "\n", - " ww_w2 = ww_stats_for_matrix(layer_w2, 'W2')\n", - " ww_w7 = ww_stats_for_matrix(layer_w7, 'W7')\n", - " ww_w8 = ww_stats_for_matrix(layer_w8, 'W8')\n", + " plot_this_round = (r % WW_PLOT_EVERY_ROUNDS == 0)\n", + " round_plot_dir = WW_PLOTS_ROOT / f'round_{r:05d}'\n", + "\n", + " if plot_this_round:\n", + " (round_plot_dir / 'W2').mkdir(parents=True, exist_ok=True)\n", + " (round_plot_dir / 'W7').mkdir(parents=True, exist_ok=True)\n", + " (round_plot_dir / 'W8').mkdir(parents=True, exist_ok=True)\n", + "\n", + " ww_w2 = ww_stats_for_matrix(layer_w2, 'W2', plot=plot_this_round, savedir=round_plot_dir / 'W2' if plot_this_round else None)\n", + " ww_w7 = ww_stats_for_matrix(layer_w7, 'W7', plot=plot_this_round, savedir=round_plot_dir / 'W7' if plot_this_round else None)\n", + " ww_w8 = ww_stats_for_matrix(layer_w8, 'W8', plot=plot_this_round, savedir=round_plot_dir / 'W8' if plot_this_round else None)\n", "\n", " alpha_w2 = ww_w2['alpha']\n", " alpha_w7 = ww_w7['alpha']\n", @@ -533,6 +570,10 @@ " f\"| test_acc={te_acc:.4f} logloss={te_loss:.4f}\"\n", " )\n", "\n", + " if plot_this_round:\n", + " print(f\"round={r:4d} | saved WW plots to {round_plot_dir}\")\n", + " sync_round_plots_to_s3(round_plot_dir, r)\n", + "\n", " if state.get('target_round_all') is not None:\n", " print(f\"Reached alpha <= {TARGET_ALPHA} for W2, W7, and W8 at round {r}.\")\n", " break\n", @@ -549,8 +590,7 @@ "}\n", "SUMMARY_PATH.write_text(json.dumps(summary, indent=2))\n", "\n", - "print('\n", - "Summary:', json.dumps(summary, indent=2))\n" + "print('\\nSummary:', json.dumps(summary, indent=2))\n" ], "id": "6gcD-VQhcLj6" },