
Add parallelism_auto flag to automatically set dp, tp and micro batch size #516

Open
pefontana wants to merge 22 commits into main from hardcode-parallelism-data

Conversation

Contributor

@pefontana pefontana commented Jan 23, 2026

Adds a --parallelism-auto flag to automatically detect the optimal dp, tp, and micro_batch_size based on the model and GPU hardware.

  1. Detects GPU type via nvidia-smi (with fallback to /proc/driver/nvidia/gpus/*/information for containers)
  2. Fetches parallelism_data.json from the model's HuggingFace repo
  3. Looks up config: GPU type → num GPUs → {dp, tp, micro_batch_size} (a lookup sketch follows the example below)
  {
    "H100": {
      "1": { "dp": 1, "tp": 1, "micro_batch_size": 4 },
      "8": { "dp": 4, "tp": 2, "micro_batch_size": 4 }
    },
    "H200": {
      "8": { "dp": 8, "tp": 1, "micro_batch_size": 8 }
    }
  }
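
For reference, a minimal sketch of the step-3 lookup, assuming the JSON is parsed with serde_json; the lookup_parallelism name and the tuple return type are illustrative, not the PR's actual API:

use serde_json::Value;

// Hypothetical helper: given the parsed parallelism table, the detected GPU type
// and the number of GPUs, return (dp, tp, micro_batch_size) if an entry exists.
fn lookup_parallelism(table: &Value, gpu_type: &str, num_gpus: u32) -> Option<(u64, u64, u64)> {
    let entry = &table[gpu_type][num_gpus.to_string()];
    Some((
        entry["dp"].as_u64()?,
        entry["tp"].as_u64()?,
        entry["micro_batch_size"].as_u64()?,
    ))
}

Unknown GPU types or GPU counts simply return None, so a caller could fall back to manually supplied dp/tp/micro_batch_size values.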

We considered sharing the parallelism config via P2P but it added unnecessary complexity—dp/tp/micro_batch_size are needed before model download begins, requiring a separate request/wait flow. Since the file is ~200 bytes, fetching from HF is fast and keeps the code simple.
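
As a rough sketch of that fetch, assuming the hf-hub crate's blocking API (the function name, error handling, and repo-id parameter are illustrative only, not necessarily what the PR does):

use hf_hub::api::sync::Api;

// Hypothetical sketch: download (or reuse a cached copy of) the ~200-byte
// parallelism_data.json from the model's HuggingFace repo and parse it.
fn fetch_parallelism_data(model_repo: &str) -> anyhow::Result<serde_json::Value> {
    let api = Api::new()?;
    let path = api.model(model_repo.to_string()).get("parallelism_data.json")?;
    let raw = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&raw)?)
}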

Note:

I think we can merge this PR now. The only caveat is that if the run is private, it is possible that some trainers don't have access to HuggingFace, so I don't recommend using this flag for private runs yet.
Once #455 is merged, all trainers must have access to HF/GCP, so it won't be a problem.

@pefontana pefontana changed the title from "Hardcode parallelism data" to "Add parallelism_auto flag to automatically set dp, tp and micro batch size" on Jan 26, 2026
@pefontana pefontana marked this pull request as ready for review January 26, 2026 20:40

fn get_gpu_type() -> String {
    // Try nvidia-smi first
    let raw = Command::new("nvidia-smi")
Contributor

small nit: maybe we can rename this variable to something better?

Contributor Author

done!

Comment on lines 21 to 57
fn get_gpu_type() -> String {
    // Try nvidia-smi first
    let raw_gpu_name = Command::new("nvidia-smi")
        .args(["--query-gpu=name", "--format=csv,noheader"])
        .output()
        .ok()
        .and_then(|o| String::from_utf8(o.stdout).ok())
        .and_then(|s| s.lines().next().map(|l| l.trim().to_string()))
        .filter(|s| !s.is_empty())
        // Fallback: read from /proc/driver/nvidia (works in containers without nvidia-smi)
        .or_else(|| {
            std::fs::read_dir("/proc/driver/nvidia/gpus")
                .ok()?
                .filter_map(|e| e.ok())
                .next()
                .and_then(|entry| {
                    let info_path = entry.path().join("information");
                    std::fs::read_to_string(info_path).ok()
                })
                .and_then(|content| {
                    content
                        .lines()
                        .find(|line| line.starts_with("Model:"))
                        .map(|line| line.trim_start_matches("Model:").trim().to_string())
                })
        })
        .unwrap_or_default();

    // Normalize GPU name to match table keys
    if raw_gpu_name.to_uppercase().contains("H200") {
        "H200".to_string()
    } else if raw_gpu_name.to_uppercase().contains("H100") {
        "H100".to_string()
    } else {
        raw_gpu_name
    }
}
Collaborator

have you considered using the nvml_wrapper crate instead of shelling out / reading /proc/fs stuff? we can grab gpu count from there too 🤷 and assert that they're all the same GPU for sanity checking :D

we use this in some of the metrics stuff already -
it would be something like:

use nvml_wrapper::Nvml;

#[derive(Debug)]
struct GpuInfo {
    name: String,
    device_count: u32,
}

fn get_gpu_info() -> anyhow::Result<GpuInfo> {
    let nvml = Nvml::init()?;
    let device_count = nvml.device_count()?;

    if device_count == 0 {
        anyhow::bail!("No GPUs found!");
    }

    let mut gpu_names = Vec::new();

    for i in 0..device_count {
        let device = nvml.device_by_index(i)?;
        gpu_names.push(device.name()?);
    }

    let first_name = &gpu_names[0];
    if !gpu_names.iter().all(|name| name == first_name) {
        anyhow::bail!(
            "All GPUs must be of the same type, but we have mismatching names: {:?}",
            gpu_names
        );
    }

    Ok(GpuInfo {
        name: gpu_names.pop().unwrap(),
        device_count,
    })
}

Contributor Author

@pefontana pefontana Jan 28, 2026

Thanks Ari!
I added it and it works.

Contributor Author

I just changed the let device_count = nvml.device_count()?; line because it was not taking the CUDA_VISIBLE_DEVICES env var into account.
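
A minimal sketch of that kind of adjustment (hypothetical helper, not necessarily the exact change in the PR): if CUDA_VISIBLE_DEVICES is set, count the devices it lists instead of trusting the raw NVML count.

// Hypothetical: prefer CUDA_VISIBLE_DEVICES over the raw NVML device count so
// that runs restricted to a subset of GPUs pick the right table entry.
fn visible_device_count(nvml_count: u32) -> u32 {
    match std::env::var("CUDA_VISIBLE_DEVICES") {
        Ok(list) if !list.trim().is_empty() => {
            list.split(',').filter(|id| !id.trim().is_empty()).count() as u32
        }
        _ => nvml_count,
    }
}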

@pefontana pefontana force-pushed the hardcode-parallelism-data branch from 0beffdf to ec28547 on January 28, 2026 21:50