Train a ResNet-18 on CIFAR-10 while a PPO agent keeps adjusting the learning rate to squeeze out extra throughput. Metrics flow into PostgreSQL through PySpark and get graphed in Grafana. A JAX pmap harness checks the theoretical upper limit on multi-GPU HPC nodes.
`README.md` this guide
`requirements.txt` Python package list
`src/` core training code
`└── train.py` main training loop
`└── rl_tuner.py` PPO environment and wrapper
`logging/` metric pipeline
`└── spark_collector.py` aggregates runs to PostgreSQL
`validation/` hardware ceiling benchmark
`└── jax_harness.py` JAX pmap throughput probe
`infra/` container scaffolding
`└── docker-compose.yml` PostgreSQL and Spark services
`dashboards/` Grafana import
`└── resnet_lr_tuner_grafana.json`
---
git clone [https://github.com/your-handle/rl_hyperparam_tuner.git](https://github.com/your-handle/rl_hyperparam_tuner.git)
cd rl_hyperparam_tuner
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txtmkdir -p data
python - <<'PY'
import torchvision, torchvision.transforms as T
torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=T.ToTensor())
torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=T.ToTensor())
PYdocker compose -f infra/docker-compose.yml up -d
export JDBC_URL=jdbc:postgresql://localhost:5432/metrics?user=metrics&password=metricspython src/train.py --epochs 10 --batch_size 128spark-submit logging/spark_collector.pyImport dashboards/resnet_lr_tuner_grafana.json into Grafana.
Watch throughput and loss curves update in real time.
python validation/jax_harness.py| Stage | Tooling | Purpose |
|---|---|---|
| Learning-rate control | Stable-Baselines3 PPO | Agent picks log-scaled LR each epoch |
| Training loop | PyTorch | ResNet-18 on CIFAR-10 |
| Metric capture | PySpark | Parse per-batch JSON logs |
| Storage | PostgreSQL | Persist run-level summaries |
| Dashboards | Grafana + Prometheus SQL | Live charts of throughput and loss |
| Hardware limit benchmark | JAX pmap |
Determine max imgs/s across all GPUs |
| Scenario | Throughput (img/s) |
|---|---|
| SGD fixed 0.1 | 865 |
| PPO tuned LR (this repo) | 910 |
| JAX pmap ceiling | 912 |
PPO closes the gap to within one percent of the hardware limit while cutting manual tuning effort to zero.
- Change
--batch_sizeinsrc/train.pyto match available VRAM. - Edit
logging/spark_collector.pyif your JDBC string differs. - Grafana assumes Prometheus SQL plugin v2 or later.