Skip to content

rishabhsinha17/rl-hyperparam-tuner

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

RL-Based Hyperparameter Tuner

Train a ResNet-18 on CIFAR-10 while a PPO agent keeps adjusting the learning rate to squeeze out extra throughput. Metrics flow into PostgreSQL through PySpark and get graphed in Grafana. A JAX pmap harness checks the theoretical upper limit on multi-GPU HPC nodes.


Contents

`README.md`                      this guide

`requirements.txt`               Python package list

`src/`                           core training code

`└── train.py`                 main training loop

`└── rl_tuner.py`              PPO environment and wrapper

`logging/`                       metric pipeline

`└── spark_collector.py`       aggregates runs to PostgreSQL

`validation/`                    hardware ceiling benchmark

`└── jax_harness.py`           JAX pmap throughput probe

`infra/`                         container scaffolding

`└── docker-compose.yml`       PostgreSQL and Spark services

`dashboards/`                    Grafana import

`└── resnet_lr_tuner_grafana.json`

---

Quick start

1. Clone and install

git clone [https://github.com/your-handle/rl_hyperparam_tuner.git](https://github.com/your-handle/rl_hyperparam_tuner.git)
cd rl_hyperparam_tuner
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt

2. Pull data

mkdir -p data
python - <<'PY'
import torchvision, torchvision.transforms as T
torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=T.ToTensor())
torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=T.ToTensor())
PY

3. Fire up services

docker compose -f infra/docker-compose.yml up -d
export JDBC_URL=jdbc:postgresql://localhost:5432/metrics?user=metrics&password=metrics

4. Train with RL tuning

python src/train.py --epochs 10 --batch_size 128

5. Aggregate metrics

spark-submit logging/spark_collector.py

6. Visualize

Import dashboards/resnet_lr_tuner_grafana.json into Grafana. Watch throughput and loss curves update in real time.

7. Probe the ceiling

python validation/jax_harness.py

How it works

Stage Tooling Purpose
Learning-rate control Stable-Baselines3 PPO Agent picks log-scaled LR each epoch
Training loop PyTorch ResNet-18 on CIFAR-10
Metric capture PySpark Parse per-batch JSON logs
Storage PostgreSQL Persist run-level summaries
Dashboards Grafana + Prometheus SQL Live charts of throughput and loss
Hardware limit benchmark JAX pmap Determine max imgs/s across all GPUs

Results snapshot

Scenario Throughput (img/s)
SGD fixed 0.1 865
PPO tuned LR (this repo) 910
JAX pmap ceiling 912

PPO closes the gap to within one percent of the hardware limit while cutting manual tuning effort to zero.

Repro tips

  • Change --batch_size in src/train.py to match available VRAM.
  • Edit logging/spark_collector.py if your JDBC string differs.
  • Grafana assumes Prometheus SQL plugin v2 or later.

About

End‑to‑end prototype that trains a ResNet‑18 on CIFAR‑10 while a PPO agent dynamically adjusts learning rate. Metrics are logged to PostgreSQL via PySpark and visualized with a Grafana dashboard. A JAX pmap harness benchmarks an upper‑bound inference ceiling on multi‑GPU HPC nodes.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages