From 8d1730adc0f309769c225fa30cb58aeb166cd155 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Sun, 22 Mar 2026 19:23:14 -0500 Subject: [PATCH 1/8] Add distributed launcher support for linex and metrix - Add DistributedContext dataclass and env detection for torchrun, mpirun, srun, horovodrun - Linex: rank-scoped output dirs, RankProfile objects, MCP per-rank hotspots - Metrix: rank metadata in ProfileResult/KernelResults/ProfilingResults, rank-suffixed output files - CLI: argparse.REMAINDER for `-- launcher ...` syntax - Both: normalize_command_argv with shlex, accept str | Sequence[str] - Tests for distributed helpers, shlex parsing, rank field propagation Note: command construction is still rocprofv3-wraps-launcher (wrong order). Next step: fix to launcher-wraps-rocprofv3 for correct distributed profiling. Co-Authored-By: Claude Opus 4.6 --- linex/README.md | 22 +++ linex/src/linex/__init__.py | 4 +- linex/src/linex/api.py | 132 +++++++++++++---- linex/src/linex/distributed.py | 115 +++++++++++++++ linex/src/linex/mcp/server.py | 84 ++++++++++- linex/tests/test_distributed_api.py | 73 +++++++++ metrix/README.md | 21 +++ metrix/src/metrix/api.py | 46 +++++- metrix/src/metrix/backends/base.py | 29 +++- metrix/src/metrix/backends/gfx1201.py | 4 +- metrix/src/metrix/backends/gfx90a.py | 4 +- metrix/src/metrix/backends/gfx942.py | 4 +- metrix/src/metrix/cli/main.py | 6 +- metrix/src/metrix/cli/profile_cmd.py | 112 +++++++++++--- metrix/src/metrix/mcp/server.py | 15 +- metrix/src/metrix/profiler/rocprof_wrapper.py | 29 +++- metrix/src/metrix/utils/distributed.py | 138 ++++++++++++++++++ metrix/tests/unit/test_distributed.py | 49 +++++++ metrix/tests/unit/test_rocprof_wrapper.py | 68 +++++++++ 19 files changed, 881 insertions(+), 74 deletions(-) create mode 100644 linex/src/linex/distributed.py create mode 100644 linex/tests/test_distributed_api.py create mode 100644 metrix/src/metrix/utils/distributed.py create mode 100644 metrix/tests/unit/test_distributed.py diff 
--git a/linex/README.md b/linex/README.md index a9fb24c..44057a9 100644 --- a/linex/README.md +++ b/linex/README.md @@ -24,6 +24,26 @@ for line in profiler.source_lines[:5]: print(f" {line.total_cycles:,} cycles ({line.stall_percent:.1f}% stalled)") ``` +## Distributed Launchers + +Linex can wrap distributed launcher commands (`torchrun`, `mpirun/mpiexec`, `srun`, +`horovodrun`) and automatically records rank metadata from common environment variables. + +```python +profiler = Linex() +profiler.profile( + "torchrun --nproc_per_node=8 train.py", + output_dir="linex_sqtt", +) + +print(profiler.distributed_context.global_rank) +for rank_key, rank_profile in profiler.rank_profiles.items(): + print(rank_key, len(rank_profile.source_lines)) +``` + +In distributed mode, Linex writes traces into rank-specific subdirectories +(`.../rank0000`, `.../rank0001`, ...) to avoid collisions. + ## What You Get **Instruction-level metrics mapped to source lines:** @@ -66,6 +86,8 @@ profiler = Linex( **Properties:** - `source_lines` - List[SourceLine] sorted by total_cycles - `instructions` - List[InstructionData] +- `rank_profiles` - Per-rank profiling data for distributed runs +- `distributed_context` - Detected launcher/rank metadata ### SourceLine diff --git a/linex/src/linex/__init__.py b/linex/src/linex/__init__.py index d832c18..be65762 100644 --- a/linex/src/linex/__init__.py +++ b/linex/src/linex/__init__.py @@ -8,7 +8,7 @@ providing cycle counts and performance metrics per source line. 
""" -from .api import Linex, SourceLine, InstructionData +from .api import InstructionData, Linex, RankProfile, SourceLine __version__ = "0.1.0" -__all__ = ["Linex", "SourceLine", "InstructionData"] +__all__ = ["Linex", "SourceLine", "InstructionData", "RankProfile"] diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index 2b9533f..da6c139 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -6,11 +6,14 @@ """ import json +import os import subprocess import urllib.request from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence + +from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv @dataclass @@ -99,6 +102,21 @@ def stall_percent(self) -> float: return 100.0 * self.stall_cycles / self.total_cycles if self.total_cycles > 0 else 0.0 +@dataclass +class RankProfile: + """Per-rank trace data produced by a Linex profiling run.""" + + rank_key: str + global_rank: int + local_rank: int + world_size: int + hostname: str + launcher: str + ui_output_dir: str + source_lines: List[SourceLine] + instructions: List[InstructionData] + + class Linex: """ Linex - Source-Level GPU Performance Profiler @@ -143,6 +161,8 @@ def __init__( # Data storage self._instructions: List[InstructionData] = [] self._source_lines: Dict[str, SourceLine] = {} + self._rank_profiles: Dict[str, RankProfile] = {} + self._distributed_context: DistributedContext = DistributedContext() def _ensure_decoder(self) -> Path: """Ensure decoder library is available, download if needed.""" @@ -158,10 +178,11 @@ def _ensure_decoder(self) -> Path: def profile( self, - command: str, + command: str | Sequence[str], output_dir: Optional[str] = None, kernel_filter: Optional[str] = None, force_cu_mask: bool = True, + env: Optional[Dict[str, str]] = None, ) -> "Linex": """ Profile an application and collect source-level performance data. 
@@ -180,12 +201,27 @@ def profile( Returns: self for chaining """ - import os import tempfile # Use temp directory if not specified if output_dir is None: - output_dir = tempfile.mkdtemp(prefix="linex_") + base_output_dir = Path(tempfile.mkdtemp(prefix="linex_")) + else: + base_output_dir = Path(output_dir) + + run_env = os.environ.copy() + if env: + run_env.update(env) + + dist_context = detect_distributed_context(run_env) + self._distributed_context = dist_context + + output_path = base_output_dir + if dist_context.is_distributed: + output_path = base_output_dir / dist_context.rank_tag + output_path.mkdir(parents=True, exist_ok=True) + + command_argv = normalize_command_argv(command) cmd = [ "rocprofv3", @@ -199,37 +235,60 @@ def profile( "--att-shader-engine-mask", self.shader_engine_mask, "-d", - output_dir, + str(output_path), ] if kernel_filter: cmd.extend(["--kernel-include-regex", kernel_filter]) - cmd.extend(["--", *command.split()]) + cmd.extend(["--", *command_argv]) - env = os.environ.copy() - if force_cu_mask: - env["HSA_CU_MASK"] = "0x1" # Force to CU 0 + if force_cu_mask and "HSA_CU_MASK" not in run_env: + run_env["HSA_CU_MASK"] = "0x1" # Force to CU 0 unless caller already set a mask - result = subprocess.run(cmd, env=env, capture_output=True, text=True) + result = subprocess.run(cmd, env=run_env, capture_output=True, text=True) if result.returncode != 0: # Include stderr in error message for debugging raise RuntimeError(f"rocprofv3 failed with code {result.returncode}\n{result.stderr}") # Find generated ui_output directory - output_path = Path(output_dir) - ui_dirs = list(output_path.glob("ui_output_*")) + ui_dirs = sorted(output_path.glob("ui_output_*"), key=lambda p: p.name) if not ui_dirs: - raise RuntimeError(f"No ui_output directories found in {output_dir}") + raise RuntimeError(f"No ui_output directories found in {output_path}") + + self._rank_profiles = {} + for idx, ui_dir in enumerate(ui_dirs): + instructions, source_lines = 
self._load_ui_output_data(ui_dir) + if idx == 0: + rank_key = dist_context.rank_tag + else: + rank_key = f"{dist_context.rank_tag}_dispatch{idx + 1:03d}" + self._rank_profiles[rank_key] = RankProfile( + rank_key=rank_key, + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, + ui_output_dir=str(ui_dir), + source_lines=sorted(source_lines.values(), key=lambda x: x.total_cycles, reverse=True), + instructions=instructions, + ) - # Load the first dispatch - self._load_ui_output(ui_dirs[0]) + # Preserve existing API behavior by exposing the first rank profile as top-level fields. + primary_rank = next(iter(self._rank_profiles.values())) + self._instructions = primary_rank.instructions + self._source_lines = { + line.source_location: line for line in primary_rank.source_lines if line.source_location + } return self - def _load_ui_output(self, ui_output_dir: Path) -> None: - """Internal: Load performance trace from ui_output directory.""" + def _load_ui_output_data( + self, ui_output_dir: Path + ) -> tuple[List[InstructionData], Dict[str, SourceLine]]: + """Internal: Load performance trace data from ui_output directory.""" code_file = ui_output_dir / "code.json" if not code_file.exists(): @@ -244,7 +303,7 @@ def _load_ui_output(self, ui_output_dir: Path) -> None: ) # Parse instructions - self._instructions = [] + instructions: List[InstructionData] = [] for entry in data["code"]: inst = InstructionData( isa=entry[0], @@ -257,21 +316,27 @@ def _load_ui_output(self, ui_output_dir: Path) -> None: stall_cycles=entry[8], idle_cycles=entry[9], ) - self._instructions.append(inst) + instructions.append(inst) - # Aggregate by source line - self._aggregate_source_lines() + return instructions, self._aggregate_source_lines(instructions) + + def _load_ui_output(self, ui_output_dir: Path) -> None: + """Backward-compatible loader for a single ui_output 
directory.""" + instructions, source_lines = self._load_ui_output_data(ui_output_dir) + self._instructions = instructions + self._source_lines = source_lines - def _aggregate_source_lines(self): + def _aggregate_source_lines(self, instructions: Optional[List[InstructionData]] = None): """Aggregate instruction data by source line.""" - self._source_lines = {} + source_lines: Dict[str, SourceLine] = {} + instructions_to_aggregate = self._instructions if instructions is None else instructions - for inst in self._instructions: + for inst in instructions_to_aggregate: source = inst.source_location if not source or source.startswith(";"): continue - if source not in self._source_lines: + if source not in source_lines: # Parse file:line from source if ":" in source: parts = source.rsplit(":", 1) @@ -285,7 +350,7 @@ def _aggregate_source_lines(self): file = source line = 0 - self._source_lines[source] = SourceLine( + source_lines[source] = SourceLine( file=file, line_number=line, source_location=source, @@ -296,12 +361,15 @@ def _aggregate_source_lines(self): instructions=[], ) - sl = self._source_lines[source] + sl = source_lines[source] sl.execution_count += inst.execution_count sl.total_cycles += inst.latency_cycles sl.stall_cycles += inst.stall_cycles sl.idle_cycles += inst.idle_cycles sl.instructions.append(inst) + if instructions is None: + self._source_lines = source_lines + return source_lines @property def source_lines(self) -> List[SourceLine]: @@ -312,3 +380,13 @@ def source_lines(self) -> List[SourceLine]: def instructions(self) -> List[InstructionData]: """Get all instructions.""" return self._instructions + + @property + def rank_profiles(self) -> Dict[str, RankProfile]: + """Get per-rank profiles for distributed runs.""" + return self._rank_profiles + + @property + def distributed_context(self) -> DistributedContext: + """Get distributed runtime metadata detected for this profile run.""" + return self._distributed_context diff --git 
a/linex/src/linex/distributed.py b/linex/src/linex/distributed.py new file mode 100644 index 0000000..30857a3 --- /dev/null +++ b/linex/src/linex/distributed.py @@ -0,0 +1,115 @@ +""" +Distributed launcher helpers for Linex. +""" + +from __future__ import annotations + +import os +import shlex +import socket +from dataclasses import dataclass +from typing import Mapping, Sequence + + +@dataclass(frozen=True) +class DistributedContext: + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" + + @property + def is_distributed(self) -> bool: + return self.world_size > 1 + + @property + def rank_tag(self) -> str: + return f"rank{self.global_rank:04d}" + + +def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: + for key in keys: + value = env.get(key) + if value is None: + continue + try: + return int(value) + except ValueError: + continue + return default + + +def detect_distributed_context(env: Mapping[str, str] | None = None) -> DistributedContext: + env_map = os.environ if env is None else env + global_rank = _first_int( + env_map, + [ + "RANK", + "OMPI_COMM_WORLD_RANK", + "PMI_RANK", + "PMIX_RANK", + "SLURM_PROCID", + "HOROVOD_RANK", + ], + 0, + ) + local_rank = _first_int( + env_map, + [ + "LOCAL_RANK", + "OMPI_COMM_WORLD_LOCAL_RANK", + "MPI_LOCALRANKID", + "SLURM_LOCALID", + "HOROVOD_LOCAL_RANK", + ], + 0, + ) + world_size = _first_int( + env_map, + [ + "WORLD_SIZE", + "OMPI_COMM_WORLD_SIZE", + "PMI_SIZE", + "PMIX_SIZE", + "SLURM_NTASKS", + "HOROVOD_SIZE", + ], + 1, + ) + node_rank = _first_int( + env_map, + ["GROUP_RANK", "NODE_RANK", "OMPI_COMM_WORLD_NODE_RANK", "SLURM_NODEID"], + 0, + ) + hostname = env_map.get("HOSTNAME", "") or socket.gethostname() + + launcher = "single" + if "TORCHELASTIC_RUN_ID" in env_map or "LOCAL_RANK" in env_map: + launcher = "torchrun" + elif "OMPI_COMM_WORLD_RANK" in env_map: + launcher = "mpirun" + elif "SLURM_PROCID" in 
env_map: + launcher = "srun" + elif "HOROVOD_RANK" in env_map: + launcher = "horovodrun" + + return DistributedContext( + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + node_rank=node_rank, + hostname=hostname, + launcher=launcher, + ) + + +def normalize_command_argv(command: str | Sequence[str]) -> list[str]: + if isinstance(command, (list, tuple)): + argv = [str(arg) for arg in command] + else: + argv = shlex.split(command) + if not argv: + raise ValueError("Command is empty") + return argv diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index cacc885..2d2d0b8 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -31,9 +31,17 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 profiler.profile(command, kernel_filter=kernel_filter) results = { + "distributed_context": { + "global_rank": profiler.distributed_context.global_rank, + "local_rank": profiler.distributed_context.local_rank, + "world_size": profiler.distributed_context.world_size, + "hostname": profiler.distributed_context.hostname, + "launcher": profiler.distributed_context.launcher, + }, "total_source_lines": len(profiler.source_lines), "total_instructions": len(profiler.instructions), "hotspots": [], + "per_rank_hotspots": [], } for i, line in enumerate(profiler.source_lines[:top_n], 1): @@ -52,6 +60,36 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 } ) + for rank_key, rank_profile in profiler.rank_profiles.items(): + rank_entry = { + "rank_key": rank_key, + "global_rank": rank_profile.global_rank, + "local_rank": rank_profile.local_rank, + "world_size": rank_profile.world_size, + "hostname": rank_profile.hostname, + "launcher": rank_profile.launcher, + "ui_output_dir": rank_profile.ui_output_dir, + "total_source_lines": len(rank_profile.source_lines), + "total_instructions": len(rank_profile.instructions), + "hotspots": [], + } + for i, line in 
enumerate(rank_profile.source_lines[:top_n], 1): + rank_entry["hotspots"].append( + { + "rank": i, + "file": line.file, + "line_number": line.line_number, + "source_location": line.source_location, + "total_cycles": line.total_cycles, + "stall_cycles": line.stall_cycles, + "stall_percent": round(line.stall_percent, 2), + "idle_cycles": line.idle_cycles, + "execution_count": line.execution_count, + "num_instructions": len(line.instructions), + } + ) + results["per_rank_hotspots"].append(rank_entry) + return results @@ -77,7 +115,17 @@ def analyze_instruction_hotspots( profiler = Linex() profiler.profile(command, kernel_filter=kernel_filter) - results = {"hotspot_analysis": []} + results = { + "distributed_context": { + "global_rank": profiler.distributed_context.global_rank, + "local_rank": profiler.distributed_context.local_rank, + "world_size": profiler.distributed_context.world_size, + "hostname": profiler.distributed_context.hostname, + "launcher": profiler.distributed_context.launcher, + }, + "hotspot_analysis": [], + "per_rank_hotspot_analysis": [], + } for line in profiler.source_lines[:top_lines]: # Sort instructions by latency @@ -105,6 +153,40 @@ def analyze_instruction_hotspots( results["hotspot_analysis"].append(line_data) + for rank_key, rank_profile in profiler.rank_profiles.items(): + rank_entry = { + "rank_key": rank_key, + "global_rank": rank_profile.global_rank, + "local_rank": rank_profile.local_rank, + "world_size": rank_profile.world_size, + "hostname": rank_profile.hostname, + "launcher": rank_profile.launcher, + "ui_output_dir": rank_profile.ui_output_dir, + "hotspot_analysis": [], + } + for line in rank_profile.source_lines[:top_lines]: + sorted_insts = sorted(line.instructions, key=lambda x: x.latency_cycles, reverse=True) + line_data = { + "source_location": line.source_location, + "total_cycles": line.total_cycles, + "stall_percent": round(line.stall_percent, 2), + "instructions": [], + } + for inst in 
sorted_insts[:top_instructions_per_line]: + line_data["instructions"].append( + { + "isa": inst.isa, + "latency_cycles": inst.latency_cycles, + "stall_cycles": inst.stall_cycles, + "stall_percent": round(inst.stall_percent, 2), + "idle_cycles": inst.idle_cycles, + "execution_count": inst.execution_count, + "instruction_address": f"0x{inst.instruction_address:08x}", + } + ) + rank_entry["hotspot_analysis"].append(line_data) + results["per_rank_hotspot_analysis"].append(rank_entry) + return results diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py new file mode 100644 index 0000000..65c96a8 --- /dev/null +++ b/linex/tests/test_distributed_api.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: MIT + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +from linex import Linex +from linex.distributed import detect_distributed_context, normalize_command_argv + + +def _write_code_json(ui_dir: Path, source_location: str) -> None: + ui_dir.mkdir(parents=True, exist_ok=True) + code = { + "code": [ + [ + "s_add_u32 s0, s0, s1", + 0, + 0, + source_location, + 1, + 0x1000, + 4, + 100, + 20, + 0, + ] + ] + } + (ui_dir / "code.json").write_text(json.dumps(code)) + + +def test_distributed_helpers_parse_common_env(): + ctx = detect_distributed_context({"SLURM_PROCID": "3", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8"}) + assert ctx.global_rank == 3 + assert ctx.local_rank == 1 + assert ctx.world_size == 8 + assert ctx.launcher == "srun" + assert normalize_command_argv('torchrun --nproc_per_node=2 train.py --arg "two words"')[-1] == ( + "two words" + ) + + +def test_profile_uses_rank_scoped_output_and_loads_deterministic_ui_dir(tmp_path): + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + def fake_run(cmd, **kwargs): + output_dir = Path(cmd[cmd.index("-d") + 1]) + _write_code_json(output_dir / 
"ui_output_200", "kernel2.hip:20") + _write_code_json(output_dir / "ui_output_100", "kernel1.hip:10") + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with ( + patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), + patch("subprocess.run", side_effect=fake_run), + ): + profiler = Linex() + profiler.profile( + command='python -c "print(1)"', + output_dir=str(tmp_path / "linex_out"), + env={"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"}, + ) + + # Primary profile should come from lexicographically first ui_output directory. + assert profiler.source_lines[0].source_location == "kernel1.hip:10" + assert profiler.distributed_context.is_distributed + assert profiler.distributed_context.global_rank == 1 + assert len(profiler.rank_profiles) == 2 diff --git a/metrix/README.md b/metrix/README.md index f6e8351..705e1fd 100644 --- a/metrix/README.md +++ b/metrix/README.md @@ -47,6 +47,27 @@ metrix --metrics memory.l2_hit_rate,memory.coalescing_efficiency ./my_app metrix -o results.json ./my_app ``` +## Distributed Launchers + +Metrix supports launcher commands such as `torchrun`, `mpirun/mpiexec`, `srun`, and +`horovodrun`. Pass launcher commands directly after `--` so arguments are preserved. + +```bash +# Torch distributed +metrix profile -- torchrun --nproc_per_node=8 train.py + +# MPI +metrix profile -- mpirun -np 8 ./my_app --problem-size 4096 + +# Slurm +metrix profile -- srun -N 2 -n 16 ./my_app +``` + +When distributed rank variables are present (`RANK`, `OMPI_COMM_WORLD_RANK`, +`SLURM_PROCID`, etc.), Metrix emits per-rank metadata and automatically suffixes +`--output/-o` files with a rank tag (for example, `results.rank0003.json`) to avoid +cross-rank clobbering. 
+ ## Python API ```python diff --git a/metrix/src/metrix/api.py b/metrix/src/metrix/api.py index 02738f2..b6828bd 100644 --- a/metrix/src/metrix/api.py +++ b/metrix/src/metrix/api.py @@ -8,7 +8,7 @@ """ import re -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence from dataclasses import dataclass from pathlib import Path @@ -16,6 +16,7 @@ from .backends.base import CounterBackend from .metrics import METRIC_PROFILES, METRIC_CATALOG from .logger import logger +from .utils.distributed import detect_distributed_context, normalize_command_argv @dataclass @@ -28,6 +29,10 @@ class KernelResults: duration_us: Statistics metrics: Dict[str, Statistics] dispatch_count: int = 1 + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + hostname: str = "" @property def avg_time_us(self) -> float: @@ -46,6 +51,11 @@ class ProfilingResults: command: str kernels: List[KernelResults] total_kernels: int + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + hostname: str = "" + launcher: str = "single" class Metrix: @@ -78,7 +88,7 @@ def __init__(self, arch: Optional[str] = None): def profile( self, - command: str, + command: str | Sequence[str], metrics: Optional[List[str]] = None, profile: Optional[str] = None, kernel_filter: Optional[str] = None, @@ -154,7 +164,9 @@ def profile( # Use simple kernel filter (no regex) rocprof_filter = kernel_filter - logger.info(f"Profiling: {command}") + command_argv = normalize_command_argv(command) + command_string = " ".join(command_argv) + logger.info(f"Profiling: {command_string}") logger.info(f"Collecting {len(metrics_to_compute)} metrics across {num_replays} replay(s)") if rocprof_filter: logger.info(f"Kernel filter: {rocprof_filter}") @@ -162,7 +174,7 @@ def profile( # Profile using backend (filtering at rocprofv3 level) logger.debug(f"Calling backend.profile with {len(metrics_to_compute)} metrics") self.backend.profile( - command=command, + command=command_argv, 
metrics=metrics_to_compute, num_replays=num_replays, aggregate_by_kernel=aggregate_by_kernel, @@ -177,10 +189,21 @@ def profile( if not dispatch_keys: logger.warning("No kernels profiled") - return ProfilingResults(command=command, kernels=[], total_kernels=0) + dist_context = detect_distributed_context() + return ProfilingResults( + command=command_string, + kernels=[], + total_kernels=0, + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, + ) # Build result objects kernel_results = [] + dist_context = detect_distributed_context() for dispatch_key in dispatch_keys: # Get duration duration = self.backend._aggregated[dispatch_key].get("duration_us") @@ -206,11 +229,22 @@ def profile( duration_us=duration, metrics=computed_metrics, dispatch_count=int(dispatch_count), + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, ) kernel_results.append(kernel_result) return ProfilingResults( - command=command, kernels=kernel_results, total_kernels=len(kernel_results) + command=command_string, + kernels=kernel_results, + total_kernels=len(kernel_results), + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, ) def list_metrics(self, category: Optional[str] = None) -> List[str]: diff --git a/metrix/src/metrix/backends/base.py b/metrix/src/metrix/backends/base.py index 65c8cc5..1dcff5f 100644 --- a/metrix/src/metrix/backends/base.py +++ b/metrix/src/metrix/backends/base.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Sequence from dataclasses import dataclass from collections import defaultdict @@ -66,6 +66,12 @@ class ProfileResult: 
arch_vgpr: int = 0 accum_vgpr: int = 0 sgpr: int = 0 + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" class CounterBackend(ABC): @@ -465,7 +471,7 @@ def _split_counters_into_passes(self, counters: List[str]) -> List[List[str]]: def profile( self, - command: str, + command: str | Sequence[str], metrics: List[str], num_replays: int = 5, aggregate_by_kernel: bool = False, @@ -774,7 +780,7 @@ def _compute_with_stat_type( @abstractmethod def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, @@ -807,7 +813,10 @@ def _aggregate_by_dispatch_across_runs( # Group by dispatch_id:kernel_name groups = defaultdict(list) for result in results: - key = f"dispatch_{result.dispatch_id}:{result.kernel_name}" + if result.world_size > 1: + key = f"rank_{result.global_rank}:dispatch_{result.dispatch_id}:{result.kernel_name}" + else: + key = f"dispatch_{result.dispatch_id}:{result.kernel_name}" groups[key].append(result) # Compute stats for each group @@ -841,7 +850,11 @@ def _aggregate_by_kernel_then_runs( # Now aggregate merged results across replays groups = defaultdict(list) for merged in merged_replays: - groups[merged.kernel_name].append(merged) + if merged.world_size > 1: + key = f"rank_{merged.global_rank}:{merged.kernel_name}" + else: + key = merged.kernel_name + groups[key].append(merged) aggregated = {} for kernel_name, dispatches in groups.items(): @@ -945,6 +958,12 @@ def _should_average(name: str) -> bool: arch_vgpr=first.arch_vgpr, accum_vgpr=first.accum_vgpr, sgpr=first.sgpr, + global_rank=first.global_rank, + local_rank=first.local_rank, + world_size=first.world_size, + node_rank=first.node_rank, + hostname=first.hostname, + launcher=first.launcher, ) merged._num_dispatches = len(dispatches) return merged diff --git a/metrix/src/metrix/backends/gfx1201.py 
b/metrix/src/metrix/backends/gfx1201.py index ee477af..5586c75 100644 --- a/metrix/src/metrix/backends/gfx1201.py +++ b/metrix/src/metrix/backends/gfx1201.py @@ -7,7 +7,7 @@ from .base import CounterBackend, DeviceSpecs, ProfileResult, Statistics from ..profiler.rocprof_wrapper import ROCProfV3Wrapper from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence class GFX1201Backend(CounterBackend): @@ -52,7 +52,7 @@ def _get_device_specs(self) -> DeviceSpecs: def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, diff --git a/metrix/src/metrix/backends/gfx90a.py b/metrix/src/metrix/backends/gfx90a.py index bc38992..5a811e3 100644 --- a/metrix/src/metrix/backends/gfx90a.py +++ b/metrix/src/metrix/backends/gfx90a.py @@ -9,7 +9,7 @@ from ..utils.common import split_counters_into_passes from .decorator import metric from ..profiler.rocprof_wrapper import ROCProfV3Wrapper -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence class GFX90aBackend(CounterBackend): @@ -94,7 +94,7 @@ def _get_counter_block_limits(self) -> Dict[str, int]: def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, diff --git a/metrix/src/metrix/backends/gfx942.py b/metrix/src/metrix/backends/gfx942.py index 3391012..4ddada5 100644 --- a/metrix/src/metrix/backends/gfx942.py +++ b/metrix/src/metrix/backends/gfx942.py @@ -9,7 +9,7 @@ from ..utils.common import split_counters_into_passes from .decorator import metric from ..profiler.rocprof_wrapper import ROCProfV3Wrapper -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence class GFX942Backend(CounterBackend): @@ -94,7 +94,7 @@ def _get_counter_block_limits(self) -> Dict[str, int]: def _run_rocprof( self, - command: str, + command: 
str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, diff --git a/metrix/src/metrix/cli/main.py b/metrix/src/metrix/cli/main.py index ae99015..4396112 100644 --- a/metrix/src/metrix/cli/main.py +++ b/metrix/src/metrix/cli/main.py @@ -53,7 +53,11 @@ def create_parser(): profile_parser.add_argument( "target", - help="Target application command (e.g., ./my_app or './my_app arg1 arg2')", + nargs=argparse.REMAINDER, + help=( + "Target application command. Use '--' before the command for complex launch lines " + "(e.g., metrix profile -- torchrun --nproc_per_node=8 train.py)" + ), ) profile_group = profile_parser.add_mutually_exclusive_group() diff --git a/metrix/src/metrix/cli/profile_cmd.py b/metrix/src/metrix/cli/profile_cmd.py index 9ddd17a..c7e06e4 100644 --- a/metrix/src/metrix/cli/profile_cmd.py +++ b/metrix/src/metrix/cli/profile_cmd.py @@ -13,9 +13,14 @@ from ..backends import get_backend, Statistics, detect_or_default from ..metrics import METRIC_PROFILES, METRIC_CATALOG from ..logger import logger +from ..utils.distributed import apply_rank_suffix, detect_distributed_context, normalize_command_argv def profile_command(args): + command_argv = normalize_command_argv(_normalize_cli_target(args.target)) + command_display = " ".join(command_argv) + dist_context = detect_distributed_context() + """Execute profile command using clean backend API""" # Auto-detect architecture @@ -79,7 +84,16 @@ def profile_command(args): # Log configuration logger.info(f"{'=' * 80}") logger.info(f"Metrix: {mode}") - logger.info(f"Target: {args.target}") + logger.info(f"Target: {command_display}") + if dist_context.is_distributed: + logger.info( + "Distributed context: launcher=%s rank=%s/%s local_rank=%s host=%s", + dist_context.launcher, + dist_context.global_rank, + dist_context.world_size, + dist_context.local_rank, + dist_context.hostname, + ) if args.num_replays > 1: logger.info(f"Replays: {args.num_replays}") if args.kernel: 
@@ -95,7 +109,7 @@ def profile_command(args): logger.info(f"Running {args.num_replays} replays...") backend.profile( - command=args.target, + command=command_argv, metrics=metrics_to_compute, num_replays=args.num_replays, aggregate_by_kernel=args.aggregate, @@ -146,28 +160,46 @@ def profile_command(args): logger.warning(f"Failed to compute {metric} for {dispatch_key}: {e}") # Output results - if args.output: + output_path = args.output + if output_path and dist_context.is_distributed: + output_path = apply_rank_suffix(output_path, dist_context) + logger.info("Distributed output path for rank %s: %s", dist_context.global_rank, output_path) + + if output_path: # Detect format from file extension - output_path = Path(args.output) - ext = output_path.suffix.lower() + output_file = Path(output_path) + ext = output_file.suffix.lower() if ext == ".json": - _write_json_output(output_path, results, metrics_to_compute) + _write_json_output(output_file, results, metrics_to_compute, dist_context) elif ext == ".csv": - _write_csv_output(output_path, results, metrics_to_compute, args.aggregate) + _write_csv_output(output_file, results, metrics_to_compute, args.aggregate, dist_context) else: # Default to text - _write_text_output(output_path, results, metrics_to_compute, args.aggregate) + _write_text_output(output_file, results, metrics_to_compute, args.aggregate, dist_context) else: # Print to stdout - _print_text_results(results, metrics_to_compute, args.aggregate, args.no_counters) + _print_text_results( + results, metrics_to_compute, args.aggregate, args.no_counters, dist_context + ) logger.info(f"Profiled {len(results)} dispatch(es)/kernel(s)") return 0 -def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_counters: bool): +def _normalize_cli_target(target) -> str | list[str]: + """Normalize argparse target (string or remainder list) into command input.""" + if isinstance(target, list): + if target and target[0] == "--": + return target[1:] + 
return target + return target + + +def _print_text_results( + results: Dict, metrics: List[str], aggregated: bool, no_counters: bool, dist_context +): """Print results to stdout in human-readable format""" # Group metrics by category @@ -194,14 +226,25 @@ def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_ if aggregated: print(f"Kernel: {dispatch_key}") else: - # dispatch_key is like "dispatch_1:kernel_name" - parts = dispatch_key.split(":", 1) - if len(parts) == 2: + # dispatch_key may be: + # - "dispatch_1:kernel_name" + # - "rank_1:dispatch_1:kernel_name" (distributed) + parts = dispatch_key.split(":") + if len(parts) >= 2 and parts[0].startswith("rank_") and parts[1].startswith("dispatch_"): + dispatch_id = parts[1].replace("dispatch_", "") + kernel_name = ":".join(parts[2:]) if len(parts) > 2 else "" + print(f"Dispatch #{dispatch_id}: {kernel_name}") + elif len(parts) >= 2 and parts[0].startswith("dispatch_"): dispatch_id = parts[0].replace("dispatch_", "") - kernel_name = parts[1] + kernel_name = ":".join(parts[1:]) print(f"Dispatch #{dispatch_id}: {kernel_name}") else: print(f"Kernel: {dispatch_key}") + if dist_context.is_distributed: + print( + f"Rank: {dist_context.global_rank}/{dist_context.world_size} " + f"(local={dist_context.local_rank}, host={dist_context.hostname})" + ) print(f"{'─' * 80}") # Duration @@ -228,9 +271,17 @@ def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_ print(f" {name:45s} {stats.avg:10.2f} {unit}") -def _write_json_output(output_path: Path, results: Dict, metrics: List[str]): +def _write_json_output(output_path: Path, results: Dict, metrics: List[str], dist_context): """Write results to JSON file""" - json_data = {} + json_data = { + "_rank": { + "global_rank": dist_context.global_rank, + "local_rank": dist_context.local_rank, + "world_size": dist_context.world_size, + "hostname": dist_context.hostname, + "launcher": dist_context.launcher, + } + } for dispatch_key, data 
in results.items(): json_data[dispatch_key] = { @@ -258,13 +309,24 @@ def _write_json_output(output_path: Path, results: Dict, metrics: List[str]): logger.info(f"Results written to {output_path}") -def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggregated: bool): +def _write_csv_output( + output_path: Path, results: Dict, metrics: List[str], aggregated: bool, dist_context +): """Write results to CSV file""" with open(output_path, "w", newline="") as f: writer = csv.writer(f) # Header - header = ["dispatch_key", "duration_min_us", "duration_max_us", "duration_avg_us"] + header = [ + "global_rank", + "local_rank", + "world_size", + "hostname", + "dispatch_key", + "duration_min_us", + "duration_max_us", + "duration_avg_us", + ] for metric in metrics: header.extend( [ @@ -277,7 +339,13 @@ def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggr # Data rows for dispatch_key, data in results.items(): - row = [dispatch_key] + row = [ + dist_context.global_rank, + dist_context.local_rank, + dist_context.world_size, + dist_context.hostname, + dispatch_key, + ] duration = data.get("duration_us") if duration: @@ -297,7 +365,9 @@ def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggr logger.info(f"Results written to {output_path}") -def _write_text_output(output_path: Path, results: Dict, metrics: List[str], aggregated: bool): +def _write_text_output( + output_path: Path, results: Dict, metrics: List[str], aggregated: bool, dist_context +): """Write results to text file""" buffer = StringIO() @@ -305,7 +375,7 @@ def _write_text_output(output_path: Path, results: Dict, metrics: List[str], agg old_stdout = sys.stdout sys.stdout = buffer - _print_text_results(results, metrics, aggregated, no_counters=False) + _print_text_results(results, metrics, aggregated, no_counters=False, dist_context=dist_context) sys.stdout = old_stdout diff --git a/metrix/src/metrix/mcp/server.py 
b/metrix/src/metrix/mcp/server.py index 2d62c89..f012631 100644 --- a/metrix/src/metrix/mcp/server.py +++ b/metrix/src/metrix/mcp/server.py @@ -35,11 +35,24 @@ def profile_metrics(command: str, metrics: list[str] = None) -> dict: results_obj = profiler.profile(command, metrics=metrics) - results = {"kernels": []} + results = { + "rank": { + "global_rank": results_obj.global_rank, + "local_rank": results_obj.local_rank, + "world_size": results_obj.world_size, + "hostname": results_obj.hostname, + "launcher": results_obj.launcher, + }, + "kernels": [], + } for kernel in results_obj.kernels: kernel_data = { "name": kernel.name, + "global_rank": kernel.global_rank, + "local_rank": kernel.local_rank, + "world_size": kernel.world_size, + "hostname": kernel.hostname, "duration_us_avg": float(kernel.duration_us.avg) if hasattr(kernel.duration_us, "avg") else 0.0, diff --git a/metrix/src/metrix/profiler/rocprof_wrapper.py b/metrix/src/metrix/profiler/rocprof_wrapper.py index 13f5315..cd69f1a 100644 --- a/metrix/src/metrix/profiler/rocprof_wrapper.py +++ b/metrix/src/metrix/profiler/rocprof_wrapper.py @@ -10,11 +10,12 @@ import os import yaml from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Sequence # Import ProfileResult from backends to avoid duplication from ..backends.base import ProfileResult from ..logger import logger +from ..utils.distributed import detect_distributed_context, normalize_command_argv class ROCProfV3Wrapper: @@ -70,7 +71,7 @@ def _needs_extra_counters(counter_defs_file: Path) -> bool: def profile( self, - command: str, + command: str | Sequence[str], counters: List[str], output_dir: Optional[Path] = None, kernel_filter: Optional[str] = None, @@ -78,6 +79,7 @@ def profile( kernel_iteration_range: Optional[str] = None, extra_counters_path: Optional[Path] = None, arch: Optional[str] = None, + env: Optional[Dict[str, str]] = None, ) -> List[ProfileResult]: """ Profile a command with specified 
counters (single pass). @@ -166,8 +168,9 @@ def profile( prof_cmd.extend(["--kernel-include-regex", kernel_filter]) # Add target command + command_argv = normalize_command_argv(command) prof_cmd.append("--") - prof_cmd.extend(command.split()) + prof_cmd.extend(command_argv) logger.debug(f"rocprofv3 command: {' '.join(prof_cmd)}") logger.info(f"Starting rocprofv3 with {len(counters)} counters") @@ -181,8 +184,17 @@ def profile( logger.debug(f"Timeout: {self.timeout}") logger.debug(f"CWD: {cwd}") + run_env = os.environ.copy() + if env: + run_env.update(env) + result = subprocess.run( - prof_cmd, capture_output=True, timeout=self.timeout, text=True, cwd=cwd + prof_cmd, + capture_output=True, + timeout=self.timeout, + text=True, + cwd=cwd, + env=run_env, ) logger.info("subprocess.run returned!") @@ -240,6 +252,15 @@ def profile( results = [r for r in results if pattern.search(r.kernel_name)] logger.info(f"After kernel filter '{kernel_filter}': {len(results)} dispatch(es)") + dist_context = detect_distributed_context(run_env) + for profile_result in results: + profile_result.global_rank = dist_context.global_rank + profile_result.local_rank = dist_context.local_rank + profile_result.world_size = dist_context.world_size + profile_result.node_rank = dist_context.node_rank + profile_result.hostname = dist_context.hostname + profile_result.launcher = dist_context.launcher + return results except subprocess.TimeoutExpired: diff --git a/metrix/src/metrix/utils/distributed.py b/metrix/src/metrix/utils/distributed.py new file mode 100644 index 0000000..bd5cea2 --- /dev/null +++ b/metrix/src/metrix/utils/distributed.py @@ -0,0 +1,138 @@ +""" +Helpers for distributed launch environments. 
+""" + +from __future__ import annotations + +import os +import shlex +import socket +from dataclasses import dataclass +from pathlib import Path +from typing import Mapping, Sequence + + +@dataclass(frozen=True) +class DistributedContext: + """Runtime rank information derived from environment variables.""" + + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" + + @property + def is_distributed(self) -> bool: + return self.world_size > 1 + + @property + def rank_tag(self) -> str: + return f"rank{self.global_rank:04d}" + + +def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: + for key in keys: + value = env.get(key) + if value is None: + continue + try: + return int(value) + except ValueError: + continue + return default + + +def detect_distributed_context(env: Mapping[str, str] | None = None) -> DistributedContext: + """Detect launcher/rank metadata from a process environment.""" + env_map = os.environ if env is None else env + + global_rank = _first_int( + env_map, + [ + "RANK", + "OMPI_COMM_WORLD_RANK", + "PMI_RANK", + "PMIX_RANK", + "SLURM_PROCID", + "HOROVOD_RANK", + ], + 0, + ) + local_rank = _first_int( + env_map, + [ + "LOCAL_RANK", + "OMPI_COMM_WORLD_LOCAL_RANK", + "MPI_LOCALRANKID", + "SLURM_LOCALID", + "HOROVOD_LOCAL_RANK", + ], + 0, + ) + world_size = _first_int( + env_map, + [ + "WORLD_SIZE", + "OMPI_COMM_WORLD_SIZE", + "PMI_SIZE", + "PMIX_SIZE", + "SLURM_NTASKS", + "HOROVOD_SIZE", + ], + 1, + ) + node_rank = _first_int( + env_map, + ["GROUP_RANK", "NODE_RANK", "OMPI_COMM_WORLD_NODE_RANK", "SLURM_NODEID"], + 0, + ) + hostname = env_map.get("HOSTNAME", "") or socket.gethostname() + + launcher = "single" + if "TORCHELASTIC_RUN_ID" in env_map or "LOCAL_RANK" in env_map: + launcher = "torchrun" + elif "OMPI_COMM_WORLD_RANK" in env_map: + launcher = "mpirun" + elif "SLURM_PROCID" in env_map: + launcher = "srun" + elif "HOROVOD_RANK" in env_map: + 
launcher = "horovodrun" + + return DistributedContext( + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + node_rank=node_rank, + hostname=hostname, + launcher=launcher, + ) + + +def normalize_command_argv(command: str | Sequence[str]) -> list[str]: + """ + Normalize command input into argv list. + + Accepts either a shell-like string or an explicit argv sequence. + """ + if isinstance(command, (list, tuple)): + argv = [str(arg) for arg in command] + else: + argv = shlex.split(command) + + if not argv: + raise ValueError("Command is empty") + return argv + + +def apply_rank_suffix(path: str, context: DistributedContext) -> str: + """Append rank suffix to output paths for distributed runs.""" + if not context.is_distributed: + return path + + p = Path(path) + suffix = p.suffix + if suffix: + return str(p.with_name(f"{p.stem}.{context.rank_tag}{suffix}")) + return str(p.with_name(f"{p.name}.{context.rank_tag}")) diff --git a/metrix/tests/unit/test_distributed.py b/metrix/tests/unit/test_distributed.py new file mode 100644 index 0000000..e05a4b3 --- /dev/null +++ b/metrix/tests/unit/test_distributed.py @@ -0,0 +1,49 @@ +""" +Unit tests for distributed launcher helpers. 
+""" + +from metrix.utils.distributed import ( + DistributedContext, + apply_rank_suffix, + detect_distributed_context, + normalize_command_argv, +) + + +def test_detect_distributed_context_torchrun_env(): + env = {"RANK": "2", "LOCAL_RANK": "0", "WORLD_SIZE": "4", "TORCHELASTIC_RUN_ID": "run-1"} + ctx = detect_distributed_context(env) + assert ctx.global_rank == 2 + assert ctx.local_rank == 0 + assert ctx.world_size == 4 + assert ctx.launcher == "torchrun" + + +def test_detect_distributed_context_mpi_env(): + env = {"OMPI_COMM_WORLD_RANK": "5", "OMPI_COMM_WORLD_LOCAL_RANK": "1", "OMPI_COMM_WORLD_SIZE": "8"} + ctx = detect_distributed_context(env) + assert ctx.global_rank == 5 + assert ctx.local_rank == 1 + assert ctx.world_size == 8 + assert ctx.launcher == "mpirun" + + +def test_apply_rank_suffix_distributed_file_path(): + ctx = DistributedContext(global_rank=3, world_size=8) + assert apply_rank_suffix("results.json", ctx) == "results.rank0003.json" + + +def test_normalize_command_argv_accepts_string_and_sequence(): + assert normalize_command_argv('torchrun --nproc_per_node=2 train.py --arg "two words"') == [ + "torchrun", + "--nproc_per_node=2", + "train.py", + "--arg", + "two words", + ] + assert normalize_command_argv(["mpirun", "-np", "4", "./app"]) == [ + "mpirun", + "-np", + "4", + "./app", + ] diff --git a/metrix/tests/unit/test_rocprof_wrapper.py b/metrix/tests/unit/test_rocprof_wrapper.py index 205c7c4..057c068 100644 --- a/metrix/tests/unit/test_rocprof_wrapper.py +++ b/metrix/tests/unit/test_rocprof_wrapper.py @@ -290,6 +290,74 @@ def fake_run(cmd, **kwargs): assert len(results) == 1 assert results[0].kernel_name == "my_kernel(float*, int)" + def test_command_string_uses_shlex_parsing(self, wrapper_no_rocm_check): + """Quoted command arguments are preserved via shlex parsing.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + 
mock_result.stdout = "" + mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command='python -c "print(1 + 2)"', + counters=[], + output_dir=Path(tmpdir), + ) + + assert "--" in captured_cmd + idx = captured_cmd.index("--") + assert captured_cmd[idx + 1 : idx + 4] == ["python", "-c", "print(1 + 2)"] + + def test_profile_sets_distributed_rank_fields(self, wrapper_no_rocm_check): + """ProfileResult receives distributed rank metadata from env.""" + wrapper = wrapper_no_rocm_check + + parsed = [ + ProfileResult( + dispatch_id=1, + kernel_name="my_kernel", + gpu_id=0, + duration_ns=1000, + grid_size=(256, 1, 1), + workgroup_size=(64, 1, 1), + counters={}, + ) + ] + + def fake_run(cmd, **kwargs): + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=parsed), + tempfile.TemporaryDirectory() as tmpdir, + ): + results = wrapper.profile( + command="true", + counters=[], + output_dir=Path(tmpdir), + env={"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}, + ) + + assert len(results) == 1 + assert results[0].global_rank == 3 + assert results[0].local_rank == 1 + assert results[0].world_size == 8 + def test_parse_missing_optional_fields(self, wrapper): """Handle missing optional fields gracefully""" row = { From f5ac2f0061e8105c96726ff7285168ffb4b319d4 Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 00:28:21 +0000 Subject: [PATCH 2/8] Fix command construction: launcher wraps rocprofv3, not the reverse Previously, distributed commands like `torchrun --nproc_per_node=8 train.py` produced `rocprofv3 ... -- torchrun --nproc_per_node=8 train.py` which is wrong. rocprofv3 would profile the launcher process, not the per-rank GPU work. 
Now we split the command into launcher args and app args, producing: `torchrun --nproc_per_node=8 rocprofv3 ... -- train.py` The launcher spawns N processes, each running rocprofv3 around the app. Changes: - Add split_launcher_command() to both distributed.py modules - Handles torchrun, python -m torch.distributed.*, mpirun/mpiexec, srun, horovodrun - Update linex/api.py and metrix/rocprof_wrapper.py to use launcher wrapping - Add tests verifying correct command ordering for all launcher types Co-Authored-By: Claude Opus 4.6 --- linex/src/linex/api.py | 9 +- linex/src/linex/distributed.py | 218 +++++++++++++++- linex/tests/test_distributed_api.py | 105 +++++++- metrix/src/metrix/profiler/rocprof_wrapper.py | 9 +- metrix/src/metrix/utils/distributed.py | 234 ++++++++++++++++-- metrix/tests/unit/test_distributed.py | 27 ++ metrix/tests/unit/test_rocprof_wrapper.py | 33 +++ 7 files changed, 615 insertions(+), 20 deletions(-) diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index da6c139..790a5cb 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence -from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv +from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv, split_launcher_command @dataclass @@ -222,6 +222,7 @@ def profile( output_path.mkdir(parents=True, exist_ok=True) command_argv = normalize_command_argv(command) + launcher_split = split_launcher_command(command_argv) cmd = [ "rocprofv3", @@ -241,7 +242,11 @@ def profile( if kernel_filter: cmd.extend(["--kernel-include-regex", kernel_filter]) - cmd.extend(["--", *command_argv]) + cmd.extend(["--", *launcher_split.app_argv]) + + # If a distributed launcher was detected, wrap: launcher rocprofv3 ... 
-- app + if launcher_split.is_distributed: + cmd = launcher_split.launcher_argv + cmd if force_cu_mask and "HSA_CU_MASK" not in run_env: run_env["HSA_CU_MASK"] = "0x1" # Force to CU 0 unless caller already set a mask diff --git a/linex/src/linex/distributed.py b/linex/src/linex/distributed.py index 30857a3..487f434 100644 --- a/linex/src/linex/distributed.py +++ b/linex/src/linex/distributed.py @@ -7,10 +7,20 @@ import os import shlex import socket -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Mapping, Sequence +KNOWN_LAUNCHERS = { + "torchrun": "torchrun", + "python": None, # only when followed by -m torch.distributed + "mpirun": "mpirun", + "mpiexec": "mpirun", + "srun": "srun", + "horovodrun": "horovodrun", +} + + @dataclass(frozen=True) class DistributedContext: global_rank: int = 0 @@ -29,6 +39,18 @@ def rank_tag(self) -> str: return f"rank{self.global_rank:04d}" +@dataclass +class LauncherSplit: + """Result of splitting a command into launcher prefix and application suffix.""" + launcher_argv: list[str] = field(default_factory=list) + app_argv: list[str] = field(default_factory=list) + launcher_name: str = "single" + + @property + def is_distributed(self) -> bool: + return len(self.launcher_argv) > 0 + + def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: for key in keys: value = env.get(key) @@ -113,3 +135,197 @@ def normalize_command_argv(command: str | Sequence[str]) -> list[str]: if not argv: raise ValueError("Command is empty") return argv + + +def split_launcher_command(argv: list[str]) -> LauncherSplit: + """Split a command argv into launcher prefix and application suffix. + + Recognizes torchrun, mpirun/mpiexec, srun, horovodrun. + For torchrun/python -m torch.distributed.*, all flags (--nproc_per_node etc.) + before the script name are launcher args; the script and everything after are app args. 
+ For mpirun/mpiexec/srun/horovodrun, we split at the first positional arg that + looks like an executable (not a flag). + + Returns a LauncherSplit with launcher_argv (empty if no launcher detected) + and app_argv. + """ + if not argv: + return LauncherSplit(app_argv=argv) + + binary = os.path.basename(argv[0]) + + # --- torchrun --- + if binary == "torchrun": + return _split_torchrun(argv) + + # --- python -m torch.distributed.launch / python -m torch.distributed.run --- + if binary in ("python", "python3") and len(argv) >= 3: + if argv[1] == "-m" and argv[2].startswith("torch.distributed"): + return _split_torchrun(argv) + + # --- mpirun / mpiexec --- + if binary in ("mpirun", "mpiexec"): + return _split_mpi(argv) + + # --- srun --- + if binary == "srun": + return _split_srun(argv) + + # --- horovodrun --- + if binary == "horovodrun": + return _split_horovodrun(argv) + + # No launcher detected + return LauncherSplit(app_argv=argv) + + +def _split_torchrun(argv: list[str]) -> LauncherSplit: + """Split torchrun command. Flags before the script are launcher args.""" + # torchrun [flags] script.py [script args] + # Flags all start with -- and some take a value argument. + # We find the first arg that doesn't start with - and isn't a value of a flag. 
+ TORCHRUN_VALUE_FLAGS = { + "--nproc_per_node", "--nproc-per-node", "--nnodes", + "--node_rank", "--node-rank", "--master_addr", "--master-addr", + "--master_port", "--master-port", "--rdzv_id", "--rdzv-id", + "--rdzv_backend", "--rdzv-backend", "--rdzv_endpoint", "--rdzv-endpoint", + "--rdzv_conf", "--rdzv-conf", "--max_restarts", "--max-restarts", + "--monitor_interval", "--monitor-interval", "--log_dir", "--log-dir", + "--redirects", "--tee", "-r", "-t", "--role", "--local_addr", + "--local-addr", "--logs_specs", "--logs-specs", + "--start_method", "--start-method", "--run_path", "--run-path", + "--omp_num_threads", "--omp-num-threads", + } + i = 1 # skip argv[0] (torchrun / python) + # skip python -m torch.distributed.run if present + if os.path.basename(argv[0]) in ("python", "python3") and len(argv) > 2 and argv[1] == "-m": + i = 3 # skip python -m torch.distributed.run + + while i < len(argv): + arg = argv[i] + if arg == "--": + # Explicit separator + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i + 1:], + launcher_name="torchrun", + ) + if arg.startswith("-"): + # Check if this flag takes a value + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in TORCHRUN_VALUE_FLAGS: + i += 2 # skip flag and its value + else: + i += 1 + else: + # First positional = script name, everything from here is app + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="torchrun", + ) + + # All flags, no script found — treat entire thing as app + return LauncherSplit(app_argv=argv) + + +def _split_mpi(argv: list[str]) -> LauncherSplit: + """Split mpirun/mpiexec command.""" + MPI_VALUE_FLAGS = { + "-np", "-n", "--np", "-N", "--map-by", "--bind-to", "--rank-by", + "-H", "--host", "--hostfile", "-x", "--mca", "-wdir", "--wdir", + "--prefix", "-output-filename", + "--output-filename", # NOTE: boolean flags (-oversubscribe, --oversubscribe, --report-bindings) take no value and must NOT be listed here, or the next arg (often the app executable) is swallowed + } + # Flags that take TWO values + MPI_DOUBLE_VALUE_FLAGS = {"--mca"} + + i
= 1 + while i < len(argv): + arg = argv[i] + if arg == ":": + # MPMD separator — everything before is one command spec + # For simplicity, treat everything up to : as launcher + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="mpirun", + ) + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in MPI_DOUBLE_VALUE_FLAGS: + i += 3 # --mca key value + elif "=" not in arg and flag_name in MPI_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + # First positional = executable + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="mpirun", + ) + + return LauncherSplit(app_argv=argv) + + +def _split_srun(argv: list[str]) -> LauncherSplit: + """Split srun command.""" + SRUN_VALUE_FLAGS = { + "-N", "--nodes", "-n", "--ntasks", "-c", "--cpus-per-task", + "-G", "--gpus", "--gpus-per-node", "--gpus-per-task", + "-p", "--partition", "-w", "--nodelist", "-x", "--exclude", + "-t", "--time", "-J", "--job-name", "-o", "--output", "-e", "--error", + "--mem", "--mem-per-cpu", "--mem-per-gpu", "-D", "--chdir", + "--export", "--mpi", "--distribution", + } + i = 1 + while i < len(argv): + arg = argv[i] + if arg == "--": + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i + 1:], + launcher_name="srun", + ) + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in SRUN_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="srun", + ) + + return LauncherSplit(app_argv=argv) + + +def _split_horovodrun(argv: list[str]) -> LauncherSplit: + """Split horovodrun command.""" + HOROVOD_VALUE_FLAGS = { + "-np", "-p", "--num-proc", "-H", "--hosts", "--hostfile", + "--start-timeout", "--network-interface", "--output-filename", + "--gloo-timeout-seconds", + } + i = 1 + while i < len(argv): + arg = argv[i] + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if 
"=" not in arg and flag_name in HOROVOD_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="horovodrun", + ) + + return LauncherSplit(app_argv=argv) diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py index 65c96a8..9c0db98 100644 --- a/linex/tests/test_distributed_api.py +++ b/linex/tests/test_distributed_api.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch from linex import Linex -from linex.distributed import detect_distributed_context, normalize_command_argv +from linex.distributed import detect_distributed_context, normalize_command_argv, split_launcher_command def _write_code_json(ui_dir: Path, source_location: str) -> None: @@ -71,3 +71,106 @@ def fake_run(cmd, **kwargs): assert profiler.distributed_context.is_distributed assert profiler.distributed_context.global_rank == 1 assert len(profiler.rank_profiles) == 2 + + +def test_split_launcher_command_torchrun(): + from linex.distributed import split_launcher_command + split = split_launcher_command([ + "torchrun", "--nproc_per_node=8", "--nnodes", "2", "train.py", "--lr", "0.01" + ]) + assert split.is_distributed + assert split.launcher_name == "torchrun" + assert split.launcher_argv == ["torchrun", "--nproc_per_node=8", "--nnodes", "2"] + assert split.app_argv == ["train.py", "--lr", "0.01"] + + +def test_split_launcher_command_mpirun(): + from linex.distributed import split_launcher_command + split = split_launcher_command([ + "mpirun", "-np", "4", "--bind-to", "core", "./my_app", "--size", "1024" + ]) + assert split.is_distributed + assert split.launcher_name == "mpirun" + assert split.launcher_argv == ["mpirun", "-np", "4", "--bind-to", "core"] + assert split.app_argv == ["./my_app", "--size", "1024"] + + +def test_split_launcher_command_srun(): + from linex.distributed import split_launcher_command + split = split_launcher_command([ + "srun", "-N", "2", "-n", "16", 
"--gpus-per-node", "8", "./app" + ]) + assert split.is_distributed + assert split.launcher_name == "srun" + assert split.launcher_argv == ["srun", "-N", "2", "-n", "16", "--gpus-per-node", "8"] + assert split.app_argv == ["./app"] + + +def test_split_launcher_command_no_launcher(): + from linex.distributed import split_launcher_command + split = split_launcher_command(["python3", "train.py", "--epochs", "10"]) + assert not split.is_distributed + assert split.launcher_argv == [] + assert split.app_argv == ["python3", "train.py", "--epochs", "10"] + + +def test_split_launcher_command_python_m_torch_distributed(): + from linex.distributed import split_launcher_command + split = split_launcher_command([ + "python3", "-m", "torch.distributed.run", "--nproc_per_node=4", "train.py" + ]) + assert split.is_distributed + assert split.launcher_name == "torchrun" + assert split.launcher_argv == ["python3", "-m", "torch.distributed.run", "--nproc_per_node=4"] + assert split.app_argv == ["train.py"] + + +def test_profile_builds_correct_command_order_with_launcher(tmp_path): + """Verify that when a launcher is detected, the subprocess command is + launcher_args + rocprofv3 ... 
-- app_args (not rocprofv3 -- launcher app).""" + import json + from pathlib import Path + from unittest.mock import MagicMock, patch + + from linex import Linex + + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + output_dir = Path(cmd[cmd.index("-d") + 1]) + # Write dummy trace data + ui_dir = output_dir / "ui_output_000" + ui_dir.mkdir(parents=True, exist_ok=True) + code = {"code": [["s_nop 0", 0, 0, "test.hip:1", 1, 0x1000, 4, 10, 2, 0]]} + (ui_dir / "code.json").write_text(json.dumps(code)) + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with ( + patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), + patch("subprocess.run", side_effect=fake_run), + ): + profiler = Linex() + profiler.profile( + command="torchrun --nproc_per_node=4 train.py --lr 0.01", + output_dir=str(tmp_path / "out"), + ) + + # The command should start with torchrun, then rocprofv3 + assert captured_cmd[0] == "torchrun" + assert captured_cmd[1] == "--nproc_per_node=4" + rocprofv3_idx = captured_cmd.index("rocprofv3") + assert rocprofv3_idx > 1 # rocprofv3 comes after launcher args + # After --, the app args should be train.py, not torchrun + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "train.py" + assert captured_cmd[separator_idx + 2] == "--lr" + assert captured_cmd[separator_idx + 3] == "0.01" diff --git a/metrix/src/metrix/profiler/rocprof_wrapper.py b/metrix/src/metrix/profiler/rocprof_wrapper.py index cd69f1a..8f4c3e1 100644 --- a/metrix/src/metrix/profiler/rocprof_wrapper.py +++ b/metrix/src/metrix/profiler/rocprof_wrapper.py @@ -15,7 +15,7 @@ # Import ProfileResult from backends to avoid duplication from ..backends.base import ProfileResult from ..logger import logger -from ..utils.distributed 
import detect_distributed_context, normalize_command_argv
+from ..utils.distributed import detect_distributed_context, normalize_command_argv, split_launcher_command
 
 
 class ROCProfV3Wrapper:
@@ -169,8 +169,13 @@ def profile(
 
         # Add target command
         command_argv = normalize_command_argv(command)
+        launcher_split = split_launcher_command(command_argv)
         prof_cmd.append("--")
-        prof_cmd.extend(command_argv)
+        prof_cmd.extend(launcher_split.app_argv)
+
+        # If a distributed launcher was detected, wrap: launcher rocprofv3 ... -- app
+        if launcher_split.is_distributed:
+            prof_cmd = launcher_split.launcher_argv + prof_cmd
 
         logger.debug(f"rocprofv3 command: {' '.join(prof_cmd)}")
         logger.info(f"Starting rocprofv3 with {len(counters)} counters")
diff --git a/metrix/src/metrix/utils/distributed.py b/metrix/src/metrix/utils/distributed.py
index bd5cea2..0821233 100644
--- a/metrix/src/metrix/utils/distributed.py
+++ b/metrix/src/metrix/utils/distributed.py
@@ -1,5 +1,5 @@
 """
-Helpers for distributed launch environments.
+Distributed launcher helpers for Metrix.
""" from __future__ import annotations @@ -7,15 +7,22 @@ import os import shlex import socket -from dataclasses import dataclass -from pathlib import Path +from dataclasses import dataclass, field from typing import Mapping, Sequence +KNOWN_LAUNCHERS = { + "torchrun": "torchrun", + "python": None, # only when followed by -m torch.distributed + "mpirun": "mpirun", + "mpiexec": "mpirun", + "srun": "srun", + "horovodrun": "horovodrun", +} + + @dataclass(frozen=True) class DistributedContext: - """Runtime rank information derived from environment variables.""" - global_rank: int = 0 local_rank: int = 0 world_size: int = 1 @@ -32,6 +39,18 @@ def rank_tag(self) -> str: return f"rank{self.global_rank:04d}" +@dataclass +class LauncherSplit: + """Result of splitting a command into launcher prefix and application suffix.""" + launcher_argv: list[str] = field(default_factory=list) + app_argv: list[str] = field(default_factory=list) + launcher_name: str = "single" + + @property + def is_distributed(self) -> bool: + return len(self.launcher_argv) > 0 + + def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: for key in keys: value = env.get(key) @@ -45,9 +64,7 @@ def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int def detect_distributed_context(env: Mapping[str, str] | None = None) -> DistributedContext: - """Detect launcher/rank metadata from a process environment.""" env_map = os.environ if env is None else env - global_rank = _first_int( env_map, [ @@ -111,27 +128,216 @@ def detect_distributed_context(env: Mapping[str, str] | None = None) -> Distribu def normalize_command_argv(command: str | Sequence[str]) -> list[str]: - """ - Normalize command input into argv list. - - Accepts either a shell-like string or an explicit argv sequence. 
- """ if isinstance(command, (list, tuple)): argv = [str(arg) for arg in command] else: argv = shlex.split(command) - if not argv: raise ValueError("Command is empty") return argv +def split_launcher_command(argv: list[str]) -> LauncherSplit: + """Split a command argv into launcher prefix and application suffix. + + Recognizes torchrun, mpirun/mpiexec, srun, horovodrun. + For torchrun/python -m torch.distributed.*, all flags (--nproc_per_node etc.) + before the script name are launcher args; the script and everything after are app args. + For mpirun/mpiexec/srun/horovodrun, we split at the first positional arg that + looks like an executable (not a flag). + + Returns a LauncherSplit with launcher_argv (empty if no launcher detected) + and app_argv. + """ + if not argv: + return LauncherSplit(app_argv=argv) + + binary = os.path.basename(argv[0]) + + # --- torchrun --- + if binary == "torchrun": + return _split_torchrun(argv) + + # --- python -m torch.distributed.launch / python -m torch.distributed.run --- + if binary in ("python", "python3") and len(argv) >= 3: + if argv[1] == "-m" and argv[2].startswith("torch.distributed"): + return _split_torchrun(argv) + + # --- mpirun / mpiexec --- + if binary in ("mpirun", "mpiexec"): + return _split_mpi(argv) + + # --- srun --- + if binary == "srun": + return _split_srun(argv) + + # --- horovodrun --- + if binary == "horovodrun": + return _split_horovodrun(argv) + + # No launcher detected + return LauncherSplit(app_argv=argv) + + +def _split_torchrun(argv: list[str]) -> LauncherSplit: + """Split torchrun command. Flags before the script are launcher args.""" + # torchrun [flags] script.py [script args] + # Flags all start with -- and some take a value argument. + # We find the first arg that doesn't start with - and isn't a value of a flag. 
+ TORCHRUN_VALUE_FLAGS = { + "--nproc_per_node", "--nproc-per-node", "--nnodes", + "--node_rank", "--node-rank", "--master_addr", "--master-addr", + "--master_port", "--master-port", "--rdzv_id", "--rdzv-id", + "--rdzv_backend", "--rdzv-backend", "--rdzv_endpoint", "--rdzv-endpoint", + "--rdzv_conf", "--rdzv-conf", "--max_restarts", "--max-restarts", + "--monitor_interval", "--monitor-interval", "--log_dir", "--log-dir", + "--redirects", "--tee", "-r", "-t", "--role", "--local_addr", + "--local-addr", "--logs_specs", "--logs-specs", + "--start_method", "--start-method", "--run_path", "--run-path", + "--omp_num_threads", "--omp-num-threads", + } + i = 1 # skip argv[0] (torchrun / python) + # skip python -m torch.distributed.run if present + if os.path.basename(argv[0]) in ("python", "python3") and len(argv) > 2 and argv[1] == "-m": + i = 3 # skip python -m torch.distributed.run + + while i < len(argv): + arg = argv[i] + if arg == "--": + # Explicit separator + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i + 1:], + launcher_name="torchrun", + ) + if arg.startswith("-"): + # Check if this flag takes a value + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in TORCHRUN_VALUE_FLAGS: + i += 2 # skip flag and its value + else: + i += 1 + else: + # First positional = script name, everything from here is app + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="torchrun", + ) + + # All flags, no script found — treat entire thing as app + return LauncherSplit(app_argv=argv) + + +def _split_mpi(argv: list[str]) -> LauncherSplit: + """Split mpirun/mpiexec command.""" + MPI_VALUE_FLAGS = { + "-np", "-n", "--np", "-N", "--map-by", "--bind-to", "--rank-by", + "-H", "--host", "--hostfile", "-x", "--mca", "-wdir", "--wdir", + "-oversubscribe", "--oversubscribe", "--prefix", "-output-filename", + "--output-filename", "--report-bindings", + } + # Flags that take TWO values + MPI_DOUBLE_VALUE_FLAGS = {"--mca"} + + i 
= 1 + while i < len(argv): + arg = argv[i] + if arg == ":": + # MPMD separator — everything before is one command spec + # For simplicity, treat everything up to : as launcher + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="mpirun", + ) + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in MPI_DOUBLE_VALUE_FLAGS: + i += 3 # --mca key value + elif "=" not in arg and flag_name in MPI_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + # First positional = executable + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="mpirun", + ) + + return LauncherSplit(app_argv=argv) + + +def _split_srun(argv: list[str]) -> LauncherSplit: + """Split srun command.""" + SRUN_VALUE_FLAGS = { + "-N", "--nodes", "-n", "--ntasks", "-c", "--cpus-per-task", + "-G", "--gpus", "--gpus-per-node", "--gpus-per-task", + "-p", "--partition", "-w", "--nodelist", "-x", "--exclude", + "-t", "--time", "-J", "--job-name", "-o", "--output", "-e", "--error", + "--mem", "--mem-per-cpu", "--mem-per-gpu", "-D", "--chdir", + "--export", "--mpi", "--distribution", + } + i = 1 + while i < len(argv): + arg = argv[i] + if arg == "--": + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i + 1:], + launcher_name="srun", + ) + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if "=" not in arg and flag_name in SRUN_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="srun", + ) + + return LauncherSplit(app_argv=argv) + + +def _split_horovodrun(argv: list[str]) -> LauncherSplit: + """Split horovodrun command.""" + HOROVOD_VALUE_FLAGS = { + "-np", "-p", "--num-proc", "-H", "--hosts", "--hostfile", + "--start-timeout", "--network-interface", "--output-filename", + "--gloo-timeout-seconds", + } + i = 1 + while i < len(argv): + arg = argv[i] + if arg.startswith("-"): + flag_name = arg.split("=")[0] + if 
"=" not in arg and flag_name in HOROVOD_VALUE_FLAGS: + i += 2 + else: + i += 1 + else: + return LauncherSplit( + launcher_argv=argv[:i], + app_argv=argv[i:], + launcher_name="horovodrun", + ) + + return LauncherSplit(app_argv=argv) + + def apply_rank_suffix(path: str, context: DistributedContext) -> str: """Append rank suffix to output paths for distributed runs.""" if not context.is_distributed: return path - p = Path(path) + from pathlib import Path as _Path + p = _Path(path) suffix = p.suffix if suffix: return str(p.with_name(f"{p.stem}.{context.rank_tag}{suffix}")) diff --git a/metrix/tests/unit/test_distributed.py b/metrix/tests/unit/test_distributed.py index e05a4b3..b83ea22 100644 --- a/metrix/tests/unit/test_distributed.py +++ b/metrix/tests/unit/test_distributed.py @@ -7,6 +7,7 @@ apply_rank_suffix, detect_distributed_context, normalize_command_argv, + split_launcher_command, ) @@ -47,3 +48,29 @@ def test_normalize_command_argv_accepts_string_and_sequence(): "4", "./app", ] + + +def test_split_launcher_command_torchrun(): + split = split_launcher_command([ + "torchrun", "--nproc_per_node=8", "train.py", "--lr", "0.01" + ]) + assert split.is_distributed + assert split.launcher_name == "torchrun" + assert split.launcher_argv == ["torchrun", "--nproc_per_node=8"] + assert split.app_argv == ["train.py", "--lr", "0.01"] + + +def test_split_launcher_command_mpirun(): + split = split_launcher_command([ + "mpirun", "-np", "4", "./my_app", "--size", "1024" + ]) + assert split.is_distributed + assert split.launcher_name == "mpirun" + assert split.launcher_argv == ["mpirun", "-np", "4"] + assert split.app_argv == ["./my_app", "--size", "1024"] + + +def test_split_launcher_command_no_launcher(): + split = split_launcher_command(["./my_app", "--size", "1024"]) + assert not split.is_distributed + assert split.app_argv == ["./my_app", "--size", "1024"] diff --git a/metrix/tests/unit/test_rocprof_wrapper.py b/metrix/tests/unit/test_rocprof_wrapper.py index 
057c068..b0f8c68 100644 --- a/metrix/tests/unit/test_rocprof_wrapper.py +++ b/metrix/tests/unit/test_rocprof_wrapper.py @@ -358,6 +358,39 @@ def fake_run(cmd, **kwargs): assert results[0].local_rank == 1 assert results[0].world_size == 8 + + def test_launcher_command_builds_correct_order(self, wrapper_no_rocm_check): + """torchrun command should produce: torchrun args rocprofv3 ... -- app args.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command="torchrun --nproc_per_node=4 train.py --lr 0.01", + counters=[], + output_dir=Path(tmpdir), + ) + + assert captured_cmd[0] == "torchrun" + assert captured_cmd[1] == "--nproc_per_node=4" + rocprofv3_idx = captured_cmd.index("rocprofv3") + assert rocprofv3_idx > 1 + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "train.py" + assert captured_cmd[separator_idx + 2] == "--lr" + def test_parse_missing_optional_fields(self, wrapper): """Handle missing optional fields gracefully""" row = { From a6406dab53e1aa3890b0a3c99b2bd55d9bc8899d Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 00:34:01 +0000 Subject: [PATCH 3/8] Replace auto-detection with explicit launcher parameter Instead of trying to parse launcher flags from a combined command string (fragile, requires hardcoded flag sets per launcher), let the user provide the launcher and app commands separately: # Python API profiler.profile(command="train.py", launcher="torchrun --nproc_per_node=8") # Metrix CLI metrix profile --launcher "torchrun --nproc_per_node=8" -- train.py This is unambiguous, works with any launcher (including custom ones), 
and requires no flag-parsing maintenance. - Remove split_launcher_command() and all _split_* helpers - Add launcher parameter to Linex.profile(), Metrix.profile(), ROCProfV3Wrapper.profile(), CounterBackend.profile(), all backend implementations, CLI (--launcher flag), and MCP tools - Update tests and READMEs Co-Authored-By: Claude Opus 4.6 --- linex/README.md | 11 +- linex/src/linex/api.py | 13 +- linex/src/linex/distributed.py | 218 +---------------- linex/src/linex/mcp/server.py | 6 +- linex/tests/test_distributed_api.py | 111 ++++----- metrix/README.md | 12 +- metrix/src/metrix/api.py | 2 + metrix/src/metrix/backends/base.py | 2 + metrix/src/metrix/backends/gfx1201.py | 1 + metrix/src/metrix/backends/gfx90a.py | 1 + metrix/src/metrix/backends/gfx942.py | 1 + metrix/src/metrix/cli/main.py | 9 + metrix/src/metrix/cli/profile_cmd.py | 1 + metrix/src/metrix/mcp/server.py | 4 +- metrix/src/metrix/profiler/rocprof_wrapper.py | 13 +- metrix/src/metrix/utils/distributed.py | 220 +----------------- metrix/tests/unit/test_distributed.py | 37 +-- metrix/tests/unit/test_rocprof_wrapper.py | 37 ++- 18 files changed, 140 insertions(+), 559 deletions(-) diff --git a/linex/README.md b/linex/README.md index 44057a9..418535e 100644 --- a/linex/README.md +++ b/linex/README.md @@ -26,13 +26,15 @@ for line in profiler.source_lines[:5]: ## Distributed Launchers -Linex can wrap distributed launcher commands (`torchrun`, `mpirun/mpiexec`, `srun`, -`horovodrun`) and automatically records rank metadata from common environment variables. +Linex supports distributed profiling with launchers like `torchrun`, `mpirun`, +`srun`, and `horovodrun`. Pass the launcher separately so Linex builds the +correct command order (`launcher rocprofv3 ... -- app`). 
```python profiler = Linex() profiler.profile( - "torchrun --nproc_per_node=8 train.py", + command="train.py", + launcher="torchrun --nproc_per_node=8", output_dir="linex_sqtt", ) @@ -42,7 +44,8 @@ for rank_key, rank_profile in profiler.rank_profiles.items(): ``` In distributed mode, Linex writes traces into rank-specific subdirectories -(`.../rank0000`, `.../rank0001`, ...) to avoid collisions. +(`.../rank0000`, `.../rank0001`, ...) to avoid collisions. Rank metadata is +automatically detected from environment variables set by the launcher. ## What You Get diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index 790a5cb..1b9fd39 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence -from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv, split_launcher_command +from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv @dataclass @@ -183,6 +183,7 @@ def profile( kernel_filter: Optional[str] = None, force_cu_mask: bool = True, env: Optional[Dict[str, str]] = None, + launcher: Optional[str | Sequence[str]] = None, ) -> "Linex": """ Profile an application and collect source-level performance data. @@ -222,7 +223,6 @@ def profile( output_path.mkdir(parents=True, exist_ok=True) command_argv = normalize_command_argv(command) - launcher_split = split_launcher_command(command_argv) cmd = [ "rocprofv3", @@ -242,11 +242,12 @@ def profile( if kernel_filter: cmd.extend(["--kernel-include-regex", kernel_filter]) - cmd.extend(["--", *launcher_split.app_argv]) + cmd.extend(["--", *command_argv]) - # If a distributed launcher was detected, wrap: launcher rocprofv3 ... -- app - if launcher_split.is_distributed: - cmd = launcher_split.launcher_argv + cmd + # If a launcher is specified, prepend it: launcher rocprofv3 ... 
-- app + if launcher is not None: + launcher_argv = normalize_command_argv(launcher) + cmd = launcher_argv + cmd if force_cu_mask and "HSA_CU_MASK" not in run_env: run_env["HSA_CU_MASK"] = "0x1" # Force to CU 0 unless caller already set a mask diff --git a/linex/src/linex/distributed.py b/linex/src/linex/distributed.py index 487f434..30857a3 100644 --- a/linex/src/linex/distributed.py +++ b/linex/src/linex/distributed.py @@ -7,20 +7,10 @@ import os import shlex import socket -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Mapping, Sequence -KNOWN_LAUNCHERS = { - "torchrun": "torchrun", - "python": None, # only when followed by -m torch.distributed - "mpirun": "mpirun", - "mpiexec": "mpirun", - "srun": "srun", - "horovodrun": "horovodrun", -} - - @dataclass(frozen=True) class DistributedContext: global_rank: int = 0 @@ -39,18 +29,6 @@ def rank_tag(self) -> str: return f"rank{self.global_rank:04d}" -@dataclass -class LauncherSplit: - """Result of splitting a command into launcher prefix and application suffix.""" - launcher_argv: list[str] = field(default_factory=list) - app_argv: list[str] = field(default_factory=list) - launcher_name: str = "single" - - @property - def is_distributed(self) -> bool: - return len(self.launcher_argv) > 0 - - def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: for key in keys: value = env.get(key) @@ -135,197 +113,3 @@ def normalize_command_argv(command: str | Sequence[str]) -> list[str]: if not argv: raise ValueError("Command is empty") return argv - - -def split_launcher_command(argv: list[str]) -> LauncherSplit: - """Split a command argv into launcher prefix and application suffix. - - Recognizes torchrun, mpirun/mpiexec, srun, horovodrun. - For torchrun/python -m torch.distributed.*, all flags (--nproc_per_node etc.) - before the script name are launcher args; the script and everything after are app args. 
- For mpirun/mpiexec/srun/horovodrun, we split at the first positional arg that - looks like an executable (not a flag). - - Returns a LauncherSplit with launcher_argv (empty if no launcher detected) - and app_argv. - """ - if not argv: - return LauncherSplit(app_argv=argv) - - binary = os.path.basename(argv[0]) - - # --- torchrun --- - if binary == "torchrun": - return _split_torchrun(argv) - - # --- python -m torch.distributed.launch / python -m torch.distributed.run --- - if binary in ("python", "python3") and len(argv) >= 3: - if argv[1] == "-m" and argv[2].startswith("torch.distributed"): - return _split_torchrun(argv) - - # --- mpirun / mpiexec --- - if binary in ("mpirun", "mpiexec"): - return _split_mpi(argv) - - # --- srun --- - if binary == "srun": - return _split_srun(argv) - - # --- horovodrun --- - if binary == "horovodrun": - return _split_horovodrun(argv) - - # No launcher detected - return LauncherSplit(app_argv=argv) - - -def _split_torchrun(argv: list[str]) -> LauncherSplit: - """Split torchrun command. Flags before the script are launcher args.""" - # torchrun [flags] script.py [script args] - # Flags all start with -- and some take a value argument. - # We find the first arg that doesn't start with - and isn't a value of a flag. 
- TORCHRUN_VALUE_FLAGS = { - "--nproc_per_node", "--nproc-per-node", "--nnodes", - "--node_rank", "--node-rank", "--master_addr", "--master-addr", - "--master_port", "--master-port", "--rdzv_id", "--rdzv-id", - "--rdzv_backend", "--rdzv-backend", "--rdzv_endpoint", "--rdzv-endpoint", - "--rdzv_conf", "--rdzv-conf", "--max_restarts", "--max-restarts", - "--monitor_interval", "--monitor-interval", "--log_dir", "--log-dir", - "--redirects", "--tee", "-r", "-t", "--role", "--local_addr", - "--local-addr", "--logs_specs", "--logs-specs", - "--start_method", "--start-method", "--run_path", "--run-path", - "--omp_num_threads", "--omp-num-threads", - } - i = 1 # skip argv[0] (torchrun / python) - # skip python -m torch.distributed.run if present - if os.path.basename(argv[0]) in ("python", "python3") and len(argv) > 2 and argv[1] == "-m": - i = 3 # skip python -m torch.distributed.run - - while i < len(argv): - arg = argv[i] - if arg == "--": - # Explicit separator - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i + 1:], - launcher_name="torchrun", - ) - if arg.startswith("-"): - # Check if this flag takes a value - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in TORCHRUN_VALUE_FLAGS: - i += 2 # skip flag and its value - else: - i += 1 - else: - # First positional = script name, everything from here is app - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="torchrun", - ) - - # All flags, no script found — treat entire thing as app - return LauncherSplit(app_argv=argv) - - -def _split_mpi(argv: list[str]) -> LauncherSplit: - """Split mpirun/mpiexec command.""" - MPI_VALUE_FLAGS = { - "-np", "-n", "--np", "-N", "--map-by", "--bind-to", "--rank-by", - "-H", "--host", "--hostfile", "-x", "--mca", "-wdir", "--wdir", - "-oversubscribe", "--oversubscribe", "--prefix", "-output-filename", - "--output-filename", "--report-bindings", - } - # Flags that take TWO values - MPI_DOUBLE_VALUE_FLAGS = {"--mca"} - - i 
= 1 - while i < len(argv): - arg = argv[i] - if arg == ":": - # MPMD separator — everything before is one command spec - # For simplicity, treat everything up to : as launcher - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="mpirun", - ) - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in MPI_DOUBLE_VALUE_FLAGS: - i += 3 # --mca key value - elif "=" not in arg and flag_name in MPI_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - # First positional = executable - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="mpirun", - ) - - return LauncherSplit(app_argv=argv) - - -def _split_srun(argv: list[str]) -> LauncherSplit: - """Split srun command.""" - SRUN_VALUE_FLAGS = { - "-N", "--nodes", "-n", "--ntasks", "-c", "--cpus-per-task", - "-G", "--gpus", "--gpus-per-node", "--gpus-per-task", - "-p", "--partition", "-w", "--nodelist", "-x", "--exclude", - "-t", "--time", "-J", "--job-name", "-o", "--output", "-e", "--error", - "--mem", "--mem-per-cpu", "--mem-per-gpu", "-D", "--chdir", - "--export", "--mpi", "--distribution", - } - i = 1 - while i < len(argv): - arg = argv[i] - if arg == "--": - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i + 1:], - launcher_name="srun", - ) - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in SRUN_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="srun", - ) - - return LauncherSplit(app_argv=argv) - - -def _split_horovodrun(argv: list[str]) -> LauncherSplit: - """Split horovodrun command.""" - HOROVOD_VALUE_FLAGS = { - "-np", "-p", "--num-proc", "-H", "--hosts", "--hostfile", - "--start-timeout", "--network-interface", "--output-filename", - "--gloo-timeout-seconds", - } - i = 1 - while i < len(argv): - arg = argv[i] - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if 
"=" not in arg and flag_name in HOROVOD_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="horovodrun", - ) - - return LauncherSplit(app_argv=argv) diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index 2d2d0b8..992b7a6 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -12,7 +12,7 @@ @mcp.tool() -def profile_application(command: str, kernel_filter: str = None, top_n: int = 10) -> dict: +def profile_application(command: str, kernel_filter: str = None, top_n: int = 10, launcher: str = None) -> dict: """ Profile a GPU application and get source-level performance metrics. @@ -28,7 +28,7 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 Dictionary with total_source_lines, total_instructions, and hotspots list """ profiler = Linex() - profiler.profile(command, kernel_filter=kernel_filter) + profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher) results = { "distributed_context": { @@ -113,7 +113,7 @@ def analyze_instruction_hotspots( Dictionary with hotspot_analysis list containing ISA-level details """ profiler = Linex() - profiler.profile(command, kernel_filter=kernel_filter) + profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher, launcher=launcher) results = { "distributed_context": { diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py index 9c0db98..4fa500b 100644 --- a/linex/tests/test_distributed_api.py +++ b/linex/tests/test_distributed_api.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch from linex import Linex -from linex.distributed import detect_distributed_context, normalize_command_argv, split_launcher_command +from linex.distributed import detect_distributed_context, normalize_command_argv def _write_code_json(ui_dir: Path, source_location: str) -> None: @@ -73,67 +73,9 @@ def fake_run(cmd, 
**kwargs): assert len(profiler.rank_profiles) == 2 -def test_split_launcher_command_torchrun(): - from linex.distributed import split_launcher_command - split = split_launcher_command([ - "torchrun", "--nproc_per_node=8", "--nnodes", "2", "train.py", "--lr", "0.01" - ]) - assert split.is_distributed - assert split.launcher_name == "torchrun" - assert split.launcher_argv == ["torchrun", "--nproc_per_node=8", "--nnodes", "2"] - assert split.app_argv == ["train.py", "--lr", "0.01"] - - -def test_split_launcher_command_mpirun(): - from linex.distributed import split_launcher_command - split = split_launcher_command([ - "mpirun", "-np", "4", "--bind-to", "core", "./my_app", "--size", "1024" - ]) - assert split.is_distributed - assert split.launcher_name == "mpirun" - assert split.launcher_argv == ["mpirun", "-np", "4", "--bind-to", "core"] - assert split.app_argv == ["./my_app", "--size", "1024"] - - -def test_split_launcher_command_srun(): - from linex.distributed import split_launcher_command - split = split_launcher_command([ - "srun", "-N", "2", "-n", "16", "--gpus-per-node", "8", "./app" - ]) - assert split.is_distributed - assert split.launcher_name == "srun" - assert split.launcher_argv == ["srun", "-N", "2", "-n", "16", "--gpus-per-node", "8"] - assert split.app_argv == ["./app"] - - -def test_split_launcher_command_no_launcher(): - from linex.distributed import split_launcher_command - split = split_launcher_command(["python3", "train.py", "--epochs", "10"]) - assert not split.is_distributed - assert split.launcher_argv == [] - assert split.app_argv == ["python3", "train.py", "--epochs", "10"] - - -def test_split_launcher_command_python_m_torch_distributed(): - from linex.distributed import split_launcher_command - split = split_launcher_command([ - "python3", "-m", "torch.distributed.run", "--nproc_per_node=4", "train.py" - ]) - assert split.is_distributed - assert split.launcher_name == "torchrun" - assert split.launcher_argv == ["python3", "-m", 
"torch.distributed.run", "--nproc_per_node=4"] - assert split.app_argv == ["train.py"] - - -def test_profile_builds_correct_command_order_with_launcher(tmp_path): - """Verify that when a launcher is detected, the subprocess command is - launcher_args + rocprofv3 ... -- app_args (not rocprofv3 -- launcher app).""" - import json - from pathlib import Path - from unittest.mock import MagicMock, patch - - from linex import Linex - +def test_profile_with_launcher_builds_correct_command_order(tmp_path): + """When launcher is provided, the subprocess command should be + launcher_argv + rocprofv3 ... -- app_argv.""" dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" dummy_decoder.parent.mkdir(parents=True, exist_ok=True) dummy_decoder.write_text("placeholder") @@ -143,7 +85,6 @@ def test_profile_builds_correct_command_order_with_launcher(tmp_path): def fake_run(cmd, **kwargs): captured_cmd.extend(cmd) output_dir = Path(cmd[cmd.index("-d") + 1]) - # Write dummy trace data ui_dir = output_dir / "ui_output_000" ui_dir.mkdir(parents=True, exist_ok=True) code = {"code": [["s_nop 0", 0, 0, "test.hip:1", 1, 0x1000, 4, 10, 2, 0]]} @@ -160,17 +101,53 @@ def fake_run(cmd, **kwargs): ): profiler = Linex() profiler.profile( - command="torchrun --nproc_per_node=4 train.py --lr 0.01", + command="train.py --lr 0.01", + launcher="torchrun --nproc_per_node=4", output_dir=str(tmp_path / "out"), ) - # The command should start with torchrun, then rocprofv3 + # Command should be: torchrun --nproc_per_node=4 rocprofv3 ... 
-- train.py --lr 0.01 assert captured_cmd[0] == "torchrun" assert captured_cmd[1] == "--nproc_per_node=4" rocprofv3_idx = captured_cmd.index("rocprofv3") - assert rocprofv3_idx > 1 # rocprofv3 comes after launcher args - # After --, the app args should be train.py, not torchrun + assert rocprofv3_idx == 2 separator_idx = captured_cmd.index("--") assert captured_cmd[separator_idx + 1] == "train.py" assert captured_cmd[separator_idx + 2] == "--lr" assert captured_cmd[separator_idx + 3] == "0.01" + + +def test_profile_without_launcher_uses_plain_rocprofv3(tmp_path): + """Without launcher, command should be: rocprofv3 ... -- app_argv.""" + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + output_dir = Path(cmd[cmd.index("-d") + 1]) + ui_dir = output_dir / "ui_output_000" + ui_dir.mkdir(parents=True, exist_ok=True) + code = {"code": [["s_nop 0", 0, 0, "test.hip:1", 1, 0x1000, 4, 10, 2, 0]]} + (ui_dir / "code.json").write_text(json.dumps(code)) + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with ( + patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), + patch("subprocess.run", side_effect=fake_run), + ): + profiler = Linex() + profiler.profile( + command="./my_app --size 1024", + output_dir=str(tmp_path / "out"), + ) + + assert captured_cmd[0] == "rocprofv3" + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "./my_app" diff --git a/metrix/README.md b/metrix/README.md index 705e1fd..c98cdd6 100644 --- a/metrix/README.md +++ b/metrix/README.md @@ -49,18 +49,20 @@ metrix -o results.json ./my_app ## Distributed Launchers -Metrix supports launcher commands such as `torchrun`, `mpirun/mpiexec`, `srun`, and -`horovodrun`. Pass launcher commands directly after `--` so arguments are preserved. 
+Metrix supports distributed profiling with launchers like `torchrun`, `mpirun`, +`srun`, and `horovodrun`. Use `--launcher` to specify the launcher command +separately from the application, ensuring the correct invocation order +(`launcher rocprofv3 ... -- app`). ```bash # Torch distributed -metrix profile -- torchrun --nproc_per_node=8 train.py +metrix profile --launcher "torchrun --nproc_per_node=8" -- train.py # MPI -metrix profile -- mpirun -np 8 ./my_app --problem-size 4096 +metrix profile --launcher "mpirun -np 8" -- ./my_app --problem-size 4096 # Slurm -metrix profile -- srun -N 2 -n 16 ./my_app +metrix profile --launcher "srun -N 2 -n 16" -- ./my_app ``` When distributed rank variables are present (`RANK`, `OMPI_COMM_WORLD_RANK`, diff --git a/metrix/src/metrix/api.py b/metrix/src/metrix/api.py index b6828bd..282348e 100644 --- a/metrix/src/metrix/api.py +++ b/metrix/src/metrix/api.py @@ -92,6 +92,7 @@ def profile( metrics: Optional[List[str]] = None, profile: Optional[str] = None, kernel_filter: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, time_only: bool = False, num_replays: int = 1, aggregate_by_kernel: bool = True, @@ -178,6 +179,7 @@ def profile( metrics=metrics_to_compute, num_replays=num_replays, aggregate_by_kernel=aggregate_by_kernel, + launcher=launcher, kernel_filter=rocprof_filter, cwd=cwd, timeout_seconds=timeout_seconds, diff --git a/metrix/src/metrix/backends/base.py b/metrix/src/metrix/backends/base.py index 1dcff5f..77f2114 100644 --- a/metrix/src/metrix/backends/base.py +++ b/metrix/src/metrix/backends/base.py @@ -475,6 +475,7 @@ def profile( metrics: List[str], num_replays: int = 5, aggregate_by_kernel: bool = False, + launcher: Optional[str | Sequence[str]] = None, kernel_filter: Optional[str] = None, cwd: Optional[str] = None, timeout_seconds: Optional[int] = 0, @@ -784,6 +785,7 @@ def _run_rocprof( counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | 
Sequence[str]] = None, timeout_seconds: Optional[int] = 0, kernel_iteration_range: Optional[str] = None, ) -> List[ProfileResult]: diff --git a/metrix/src/metrix/backends/gfx1201.py b/metrix/src/metrix/backends/gfx1201.py index 5586c75..eca553c 100644 --- a/metrix/src/metrix/backends/gfx1201.py +++ b/metrix/src/metrix/backends/gfx1201.py @@ -56,6 +56,7 @@ def _run_rocprof( counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, kernel_iteration_range: Optional[str] = None, ) -> List[ProfileResult]: diff --git a/metrix/src/metrix/backends/gfx90a.py b/metrix/src/metrix/backends/gfx90a.py index 5a811e3..0d63edc 100644 --- a/metrix/src/metrix/backends/gfx90a.py +++ b/metrix/src/metrix/backends/gfx90a.py @@ -98,6 +98,7 @@ def _run_rocprof( counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, ) -> List[ProfileResult]: """Run rocprofv3 and return results (single pass only - base class handles multi-pass)""" diff --git a/metrix/src/metrix/backends/gfx942.py b/metrix/src/metrix/backends/gfx942.py index 4ddada5..3c6e08a 100644 --- a/metrix/src/metrix/backends/gfx942.py +++ b/metrix/src/metrix/backends/gfx942.py @@ -98,6 +98,7 @@ def _run_rocprof( counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, ) -> List[ProfileResult]: """Run rocprofv3 and return results (single pass only - base class handles multi-pass)""" diff --git a/metrix/src/metrix/cli/main.py b/metrix/src/metrix/cli/main.py index 4396112..a63d202 100644 --- a/metrix/src/metrix/cli/main.py +++ b/metrix/src/metrix/cli/main.py @@ -51,6 +51,15 @@ def create_parser(): description="Collect performance metrics from GPU kernels", ) + profile_parser.add_argument( + 
"--launcher", + default=None, + help=( + "Distributed launcher command to wrap rocprofv3 " + "(e.g., 'torchrun --nproc_per_node=8' or 'mpirun -np 4')" + ), + ) + profile_parser.add_argument( "target", nargs=argparse.REMAINDER, diff --git a/metrix/src/metrix/cli/profile_cmd.py b/metrix/src/metrix/cli/profile_cmd.py index c7e06e4..ecf6691 100644 --- a/metrix/src/metrix/cli/profile_cmd.py +++ b/metrix/src/metrix/cli/profile_cmd.py @@ -113,6 +113,7 @@ def profile_command(args): metrics=metrics_to_compute, num_replays=args.num_replays, aggregate_by_kernel=args.aggregate, + launcher=args.launcher, kernel_filter=kernel_filter, ) except Exception as e: diff --git a/metrix/src/metrix/mcp/server.py b/metrix/src/metrix/mcp/server.py index f012631..43ec28b 100644 --- a/metrix/src/metrix/mcp/server.py +++ b/metrix/src/metrix/mcp/server.py @@ -12,7 +12,7 @@ @mcp.tool() -def profile_metrics(command: str, metrics: list[str] = None) -> dict: +def profile_metrics(command: str, metrics: list[str] = None, launcher: str = None) -> dict: """ Profile GPU application and collect hardware performance metrics. 
@@ -33,7 +33,7 @@ def profile_metrics(command: str, metrics: list[str] = None) -> dict: if metrics is None: metrics = ["memory.hbm_bandwidth_utilization"] - results_obj = profiler.profile(command, metrics=metrics) + results_obj = profiler.profile(command, metrics=metrics, launcher=launcher) results = { "rank": { diff --git a/metrix/src/metrix/profiler/rocprof_wrapper.py b/metrix/src/metrix/profiler/rocprof_wrapper.py index 8f4c3e1..c204942 100644 --- a/metrix/src/metrix/profiler/rocprof_wrapper.py +++ b/metrix/src/metrix/profiler/rocprof_wrapper.py @@ -15,7 +15,7 @@ # Import ProfileResult from backends to avoid duplication from ..backends.base import ProfileResult from ..logger import logger -from ..utils.distributed import detect_distributed_context, normalize_command_argv, split_launcher_command +from ..utils.distributed import detect_distributed_context, normalize_command_argv class ROCProfV3Wrapper: @@ -80,6 +80,7 @@ def profile( extra_counters_path: Optional[Path] = None, arch: Optional[str] = None, env: Optional[Dict[str, str]] = None, + launcher: Optional[str | Sequence[str]] = None, ) -> List[ProfileResult]: """ Profile a command with specified counters (single pass). @@ -169,13 +170,13 @@ def profile( # Add target command command_argv = normalize_command_argv(command) - launcher_split = split_launcher_command(command_argv) prof_cmd.append("--") - prof_cmd.extend(launcher_split.app_argv) + prof_cmd.extend(command_argv) - # If a distributed launcher was detected, wrap: launcher rocprofv3 ... -- app - if launcher_split.is_distributed: - prof_cmd = launcher_split.launcher_argv + prof_cmd + # If a launcher is specified, prepend it: launcher rocprofv3 ... 
-- app + if launcher is not None: + launcher_argv = normalize_command_argv(launcher) + prof_cmd = launcher_argv + prof_cmd logger.debug(f"rocprofv3 command: {' '.join(prof_cmd)}") logger.info(f"Starting rocprofv3 with {len(counters)} counters") diff --git a/metrix/src/metrix/utils/distributed.py b/metrix/src/metrix/utils/distributed.py index 0821233..9e47cc3 100644 --- a/metrix/src/metrix/utils/distributed.py +++ b/metrix/src/metrix/utils/distributed.py @@ -1,5 +1,5 @@ """ -Distributed launcher helpers for Linex. +Helpers for distributed launch environments. """ from __future__ import annotations @@ -7,20 +7,10 @@ import os import shlex import socket -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Mapping, Sequence -KNOWN_LAUNCHERS = { - "torchrun": "torchrun", - "python": None, # only when followed by -m torch.distributed - "mpirun": "mpirun", - "mpiexec": "mpirun", - "srun": "srun", - "horovodrun": "horovodrun", -} - - @dataclass(frozen=True) class DistributedContext: global_rank: int = 0 @@ -39,18 +29,6 @@ def rank_tag(self) -> str: return f"rank{self.global_rank:04d}" -@dataclass -class LauncherSplit: - """Result of splitting a command into launcher prefix and application suffix.""" - launcher_argv: list[str] = field(default_factory=list) - app_argv: list[str] = field(default_factory=list) - launcher_name: str = "single" - - @property - def is_distributed(self) -> bool: - return len(self.launcher_argv) > 0 - - def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: for key in keys: value = env.get(key) @@ -137,200 +115,6 @@ def normalize_command_argv(command: str | Sequence[str]) -> list[str]: return argv -def split_launcher_command(argv: list[str]) -> LauncherSplit: - """Split a command argv into launcher prefix and application suffix. - - Recognizes torchrun, mpirun/mpiexec, srun, horovodrun. - For torchrun/python -m torch.distributed.*, all flags (--nproc_per_node etc.) 
- before the script name are launcher args; the script and everything after are app args. - For mpirun/mpiexec/srun/horovodrun, we split at the first positional arg that - looks like an executable (not a flag). - - Returns a LauncherSplit with launcher_argv (empty if no launcher detected) - and app_argv. - """ - if not argv: - return LauncherSplit(app_argv=argv) - - binary = os.path.basename(argv[0]) - - # --- torchrun --- - if binary == "torchrun": - return _split_torchrun(argv) - - # --- python -m torch.distributed.launch / python -m torch.distributed.run --- - if binary in ("python", "python3") and len(argv) >= 3: - if argv[1] == "-m" and argv[2].startswith("torch.distributed"): - return _split_torchrun(argv) - - # --- mpirun / mpiexec --- - if binary in ("mpirun", "mpiexec"): - return _split_mpi(argv) - - # --- srun --- - if binary == "srun": - return _split_srun(argv) - - # --- horovodrun --- - if binary == "horovodrun": - return _split_horovodrun(argv) - - # No launcher detected - return LauncherSplit(app_argv=argv) - - -def _split_torchrun(argv: list[str]) -> LauncherSplit: - """Split torchrun command. Flags before the script are launcher args.""" - # torchrun [flags] script.py [script args] - # Flags all start with -- and some take a value argument. - # We find the first arg that doesn't start with - and isn't a value of a flag. 
- TORCHRUN_VALUE_FLAGS = { - "--nproc_per_node", "--nproc-per-node", "--nnodes", - "--node_rank", "--node-rank", "--master_addr", "--master-addr", - "--master_port", "--master-port", "--rdzv_id", "--rdzv-id", - "--rdzv_backend", "--rdzv-backend", "--rdzv_endpoint", "--rdzv-endpoint", - "--rdzv_conf", "--rdzv-conf", "--max_restarts", "--max-restarts", - "--monitor_interval", "--monitor-interval", "--log_dir", "--log-dir", - "--redirects", "--tee", "-r", "-t", "--role", "--local_addr", - "--local-addr", "--logs_specs", "--logs-specs", - "--start_method", "--start-method", "--run_path", "--run-path", - "--omp_num_threads", "--omp-num-threads", - } - i = 1 # skip argv[0] (torchrun / python) - # skip python -m torch.distributed.run if present - if os.path.basename(argv[0]) in ("python", "python3") and len(argv) > 2 and argv[1] == "-m": - i = 3 # skip python -m torch.distributed.run - - while i < len(argv): - arg = argv[i] - if arg == "--": - # Explicit separator - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i + 1:], - launcher_name="torchrun", - ) - if arg.startswith("-"): - # Check if this flag takes a value - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in TORCHRUN_VALUE_FLAGS: - i += 2 # skip flag and its value - else: - i += 1 - else: - # First positional = script name, everything from here is app - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="torchrun", - ) - - # All flags, no script found — treat entire thing as app - return LauncherSplit(app_argv=argv) - - -def _split_mpi(argv: list[str]) -> LauncherSplit: - """Split mpirun/mpiexec command.""" - MPI_VALUE_FLAGS = { - "-np", "-n", "--np", "-N", "--map-by", "--bind-to", "--rank-by", - "-H", "--host", "--hostfile", "-x", "--mca", "-wdir", "--wdir", - "-oversubscribe", "--oversubscribe", "--prefix", "-output-filename", - "--output-filename", "--report-bindings", - } - # Flags that take TWO values - MPI_DOUBLE_VALUE_FLAGS = {"--mca"} - - i 
= 1 - while i < len(argv): - arg = argv[i] - if arg == ":": - # MPMD separator — everything before is one command spec - # For simplicity, treat everything up to : as launcher - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="mpirun", - ) - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in MPI_DOUBLE_VALUE_FLAGS: - i += 3 # --mca key value - elif "=" not in arg and flag_name in MPI_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - # First positional = executable - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="mpirun", - ) - - return LauncherSplit(app_argv=argv) - - -def _split_srun(argv: list[str]) -> LauncherSplit: - """Split srun command.""" - SRUN_VALUE_FLAGS = { - "-N", "--nodes", "-n", "--ntasks", "-c", "--cpus-per-task", - "-G", "--gpus", "--gpus-per-node", "--gpus-per-task", - "-p", "--partition", "-w", "--nodelist", "-x", "--exclude", - "-t", "--time", "-J", "--job-name", "-o", "--output", "-e", "--error", - "--mem", "--mem-per-cpu", "--mem-per-gpu", "-D", "--chdir", - "--export", "--mpi", "--distribution", - } - i = 1 - while i < len(argv): - arg = argv[i] - if arg == "--": - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i + 1:], - launcher_name="srun", - ) - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if "=" not in arg and flag_name in SRUN_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="srun", - ) - - return LauncherSplit(app_argv=argv) - - -def _split_horovodrun(argv: list[str]) -> LauncherSplit: - """Split horovodrun command.""" - HOROVOD_VALUE_FLAGS = { - "-np", "-p", "--num-proc", "-H", "--hosts", "--hostfile", - "--start-timeout", "--network-interface", "--output-filename", - "--gloo-timeout-seconds", - } - i = 1 - while i < len(argv): - arg = argv[i] - if arg.startswith("-"): - flag_name = arg.split("=")[0] - if 
"=" not in arg and flag_name in HOROVOD_VALUE_FLAGS: - i += 2 - else: - i += 1 - else: - return LauncherSplit( - launcher_argv=argv[:i], - app_argv=argv[i:], - launcher_name="horovodrun", - ) - - return LauncherSplit(app_argv=argv) - - def apply_rank_suffix(path: str, context: DistributedContext) -> str: """Append rank suffix to output paths for distributed runs.""" if not context.is_distributed: diff --git a/metrix/tests/unit/test_distributed.py b/metrix/tests/unit/test_distributed.py index b83ea22..9519866 100644 --- a/metrix/tests/unit/test_distributed.py +++ b/metrix/tests/unit/test_distributed.py @@ -7,7 +7,6 @@ apply_rank_suffix, detect_distributed_context, normalize_command_argv, - split_launcher_command, ) @@ -34,6 +33,16 @@ def test_apply_rank_suffix_distributed_file_path(): assert apply_rank_suffix("results.json", ctx) == "results.rank0003.json" +def test_apply_rank_suffix_no_extension(): + ctx = DistributedContext(global_rank=0, world_size=4) + assert apply_rank_suffix("results", ctx) == "results.rank0000" + + +def test_apply_rank_suffix_single_process(): + ctx = DistributedContext(global_rank=0, world_size=1) + assert apply_rank_suffix("results.json", ctx) == "results.json" + + def test_normalize_command_argv_accepts_string_and_sequence(): assert normalize_command_argv('torchrun --nproc_per_node=2 train.py --arg "two words"') == [ "torchrun", @@ -48,29 +57,3 @@ def test_normalize_command_argv_accepts_string_and_sequence(): "4", "./app", ] - - -def test_split_launcher_command_torchrun(): - split = split_launcher_command([ - "torchrun", "--nproc_per_node=8", "train.py", "--lr", "0.01" - ]) - assert split.is_distributed - assert split.launcher_name == "torchrun" - assert split.launcher_argv == ["torchrun", "--nproc_per_node=8"] - assert split.app_argv == ["train.py", "--lr", "0.01"] - - -def test_split_launcher_command_mpirun(): - split = split_launcher_command([ - "mpirun", "-np", "4", "./my_app", "--size", "1024" - ]) - assert split.is_distributed - 
assert split.launcher_name == "mpirun" - assert split.launcher_argv == ["mpirun", "-np", "4"] - assert split.app_argv == ["./my_app", "--size", "1024"] - - -def test_split_launcher_command_no_launcher(): - split = split_launcher_command(["./my_app", "--size", "1024"]) - assert not split.is_distributed - assert split.app_argv == ["./my_app", "--size", "1024"] diff --git a/metrix/tests/unit/test_rocprof_wrapper.py b/metrix/tests/unit/test_rocprof_wrapper.py index b0f8c68..c51a671 100644 --- a/metrix/tests/unit/test_rocprof_wrapper.py +++ b/metrix/tests/unit/test_rocprof_wrapper.py @@ -359,8 +359,8 @@ def fake_run(cmd, **kwargs): assert results[0].world_size == 8 - def test_launcher_command_builds_correct_order(self, wrapper_no_rocm_check): - """torchrun command should produce: torchrun args rocprofv3 ... -- app args.""" + def test_launcher_param_builds_correct_order(self, wrapper_no_rocm_check): + """Explicit launcher param should produce: launcher rocprofv3 ... -- app.""" wrapper = wrapper_no_rocm_check captured_cmd = [] @@ -378,19 +378,48 @@ def fake_run(cmd, **kwargs): tempfile.TemporaryDirectory() as tmpdir, ): wrapper.profile( - command="torchrun --nproc_per_node=4 train.py --lr 0.01", + command="train.py --lr 0.01", counters=[], output_dir=Path(tmpdir), + launcher="torchrun --nproc_per_node=4", ) assert captured_cmd[0] == "torchrun" assert captured_cmd[1] == "--nproc_per_node=4" rocprofv3_idx = captured_cmd.index("rocprofv3") - assert rocprofv3_idx > 1 + assert rocprofv3_idx == 2 separator_idx = captured_cmd.index("--") assert captured_cmd[separator_idx + 1] == "train.py" assert captured_cmd[separator_idx + 2] == "--lr" + def test_no_launcher_uses_plain_rocprofv3(self, wrapper_no_rocm_check): + """Without launcher, command should start with rocprofv3.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + 
mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command="./my_app --size 1024", + counters=[], + output_dir=Path(tmpdir), + ) + + assert captured_cmd[0] == "rocprofv3" + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "./my_app" + def test_parse_missing_optional_fields(self, wrapper): """Handle missing optional fields gracefully""" row = { From cb1ee059a5579730923085cea835708a0464328d Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 01:03:47 +0000 Subject: [PATCH 4/8] Fix lint: remove duplicate launcher kwarg, fix parenthesized with for Python 3.8 Co-Authored-By: Claude Opus 4.6 --- linex/src/linex/mcp/server.py | 2 +- linex/tests/test_distributed_api.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index 992b7a6..00106bc 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -113,7 +113,7 @@ def analyze_instruction_hotspots( Dictionary with hotspot_analysis list containing ISA-level details """ profiler = Linex() - profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher, launcher=launcher) + profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher) results = { "distributed_context": { diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py index 4fa500b..2aba185 100644 --- a/linex/tests/test_distributed_api.py +++ b/linex/tests/test_distributed_api.py @@ -55,10 +55,8 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with ( - patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), - patch("subprocess.run", side_effect=fake_run), - ): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), \ + 
patch("subprocess.run", side_effect=fake_run): profiler = Linex() profiler.profile( command='python -c "print(1)"', @@ -95,10 +93,8 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with ( - patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), - patch("subprocess.run", side_effect=fake_run), - ): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), \ + patch("subprocess.run", side_effect=fake_run): profiler = Linex() profiler.profile( command="train.py --lr 0.01", @@ -138,10 +134,8 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with ( - patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), - patch("subprocess.run", side_effect=fake_run), - ): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), \ + patch("subprocess.run", side_effect=fake_run): profiler = Linex() profiler.profile( command="./my_app --size 1024", From 9b07645100c4f401dd82f4f75ac2beb4b14ccc81 Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 01:04:47 +0000 Subject: [PATCH 5/8] Fix lint: add missing launcher param to analyze_instruction_hotspots Co-Authored-By: Claude Opus 4.6 --- linex/src/linex/mcp/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index 00106bc..077e8c8 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -95,7 +95,7 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 @mcp.tool() def analyze_instruction_hotspots( - command: str, kernel_filter: str = None, top_lines: int = 5, top_instructions_per_line: int = 10 + command: str, kernel_filter: str = None, top_lines: int = 5, top_instructions_per_line: int = 10, launcher: str = None ) -> dict: """ Get detailed instruction-level analysis for the hottest source lines. 
From 22f8fada6bfc8195575db03f80082afd40c41f30 Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 01:06:12 +0000 Subject: [PATCH 6/8] Fix ruff formatting Co-Authored-By: Claude Opus 4.6 --- linex/src/linex/api.py | 4 +++- linex/src/linex/mcp/server.py | 10 ++++++++-- linex/tests/test_distributed_api.py | 19 +++++++++++------- metrix/src/metrix/backends/base.py | 4 +++- metrix/src/metrix/cli/profile_cmd.py | 24 ++++++++++++++++++----- metrix/src/metrix/utils/distributed.py | 1 + metrix/tests/unit/test_distributed.py | 6 +++++- metrix/tests/unit/test_rocprof_wrapper.py | 1 - 8 files changed, 51 insertions(+), 18 deletions(-) diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index 1b9fd39..aec03ef 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -279,7 +279,9 @@ def profile( hostname=dist_context.hostname, launcher=dist_context.launcher, ui_output_dir=str(ui_dir), - source_lines=sorted(source_lines.values(), key=lambda x: x.total_cycles, reverse=True), + source_lines=sorted( + source_lines.values(), key=lambda x: x.total_cycles, reverse=True + ), instructions=instructions, ) diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index 077e8c8..84c7a2d 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -12,7 +12,9 @@ @mcp.tool() -def profile_application(command: str, kernel_filter: str = None, top_n: int = 10, launcher: str = None) -> dict: +def profile_application( + command: str, kernel_filter: str = None, top_n: int = 10, launcher: str = None +) -> dict: """ Profile a GPU application and get source-level performance metrics. 
@@ -95,7 +97,11 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 @mcp.tool() def analyze_instruction_hotspots( - command: str, kernel_filter: str = None, top_lines: int = 5, top_instructions_per_line: int = 10, launcher: str = None + command: str, + kernel_filter: str = None, + top_lines: int = 5, + top_instructions_per_line: int = 10, + launcher: str = None, ) -> dict: """ Get detailed instruction-level analysis for the hottest source lines. diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py index 2aba185..c6feaeb 100644 --- a/linex/tests/test_distributed_api.py +++ b/linex/tests/test_distributed_api.py @@ -30,7 +30,9 @@ def _write_code_json(ui_dir: Path, source_location: str) -> None: def test_distributed_helpers_parse_common_env(): - ctx = detect_distributed_context({"SLURM_PROCID": "3", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8"}) + ctx = detect_distributed_context( + {"SLURM_PROCID": "3", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8"} + ) assert ctx.global_rank == 3 assert ctx.local_rank == 1 assert ctx.world_size == 8 @@ -55,8 +57,9 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), \ - patch("subprocess.run", side_effect=fake_run): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): profiler = Linex() profiler.profile( command='python -c "print(1)"', @@ -93,8 +96,9 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), \ - patch("subprocess.run", side_effect=fake_run): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): profiler = Linex() profiler.profile( command="train.py --lr 0.01", @@ -134,8 +138,9 @@ def fake_run(cmd, **kwargs): m.stderr = "" return m - with patch.object(Linex, 
"_ensure_decoder", return_value=dummy_decoder), \ - patch("subprocess.run", side_effect=fake_run): + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): profiler = Linex() profiler.profile( command="./my_app --size 1024", diff --git a/metrix/src/metrix/backends/base.py b/metrix/src/metrix/backends/base.py index 77f2114..f3547c7 100644 --- a/metrix/src/metrix/backends/base.py +++ b/metrix/src/metrix/backends/base.py @@ -816,7 +816,9 @@ def _aggregate_by_dispatch_across_runs( groups = defaultdict(list) for result in results: if result.world_size > 1: - key = f"rank_{result.global_rank}:dispatch_{result.dispatch_id}:{result.kernel_name}" + key = ( + f"rank_{result.global_rank}:dispatch_{result.dispatch_id}:{result.kernel_name}" + ) else: key = f"dispatch_{result.dispatch_id}:{result.kernel_name}" groups[key].append(result) diff --git a/metrix/src/metrix/cli/profile_cmd.py b/metrix/src/metrix/cli/profile_cmd.py index ecf6691..c3b99f9 100644 --- a/metrix/src/metrix/cli/profile_cmd.py +++ b/metrix/src/metrix/cli/profile_cmd.py @@ -13,7 +13,11 @@ from ..backends import get_backend, Statistics, detect_or_default from ..metrics import METRIC_PROFILES, METRIC_CATALOG from ..logger import logger -from ..utils.distributed import apply_rank_suffix, detect_distributed_context, normalize_command_argv +from ..utils.distributed import ( + apply_rank_suffix, + detect_distributed_context, + normalize_command_argv, +) def profile_command(args): @@ -164,7 +168,9 @@ def profile_command(args): output_path = args.output if output_path and dist_context.is_distributed: output_path = apply_rank_suffix(output_path, dist_context) - logger.info("Distributed output path for rank %s: %s", dist_context.global_rank, output_path) + logger.info( + "Distributed output path for rank %s: %s", dist_context.global_rank, output_path + ) if output_path: # Detect format from file extension @@ -174,10 +180,14 @@ def 
profile_command(args): if ext == ".json": _write_json_output(output_file, results, metrics_to_compute, dist_context) elif ext == ".csv": - _write_csv_output(output_file, results, metrics_to_compute, args.aggregate, dist_context) + _write_csv_output( + output_file, results, metrics_to_compute, args.aggregate, dist_context + ) else: # Default to text - _write_text_output(output_file, results, metrics_to_compute, args.aggregate, dist_context) + _write_text_output( + output_file, results, metrics_to_compute, args.aggregate, dist_context + ) else: # Print to stdout _print_text_results( @@ -231,7 +241,11 @@ def _print_text_results( # - "dispatch_1:kernel_name" # - "rank_1:dispatch_1:kernel_name" (distributed) parts = dispatch_key.split(":") - if len(parts) >= 2 and parts[0].startswith("rank_") and parts[1].startswith("dispatch_"): + if ( + len(parts) >= 2 + and parts[0].startswith("rank_") + and parts[1].startswith("dispatch_") + ): dispatch_id = parts[1].replace("dispatch_", "") kernel_name = ":".join(parts[2:]) if len(parts) > 2 else "" print(f"Dispatch #{dispatch_id}: {kernel_name}") diff --git a/metrix/src/metrix/utils/distributed.py b/metrix/src/metrix/utils/distributed.py index 9e47cc3..d799e95 100644 --- a/metrix/src/metrix/utils/distributed.py +++ b/metrix/src/metrix/utils/distributed.py @@ -121,6 +121,7 @@ def apply_rank_suffix(path: str, context: DistributedContext) -> str: return path from pathlib import Path as _Path + p = _Path(path) suffix = p.suffix if suffix: diff --git a/metrix/tests/unit/test_distributed.py b/metrix/tests/unit/test_distributed.py index 9519866..1afc0e3 100644 --- a/metrix/tests/unit/test_distributed.py +++ b/metrix/tests/unit/test_distributed.py @@ -20,7 +20,11 @@ def test_detect_distributed_context_torchrun_env(): def test_detect_distributed_context_mpi_env(): - env = {"OMPI_COMM_WORLD_RANK": "5", "OMPI_COMM_WORLD_LOCAL_RANK": "1", "OMPI_COMM_WORLD_SIZE": "8"} + env = { + "OMPI_COMM_WORLD_RANK": "5", + "OMPI_COMM_WORLD_LOCAL_RANK": 
"1", + "OMPI_COMM_WORLD_SIZE": "8", + } ctx = detect_distributed_context(env) assert ctx.global_rank == 5 assert ctx.local_rank == 1 diff --git a/metrix/tests/unit/test_rocprof_wrapper.py b/metrix/tests/unit/test_rocprof_wrapper.py index c51a671..13c490f 100644 --- a/metrix/tests/unit/test_rocprof_wrapper.py +++ b/metrix/tests/unit/test_rocprof_wrapper.py @@ -358,7 +358,6 @@ def fake_run(cmd, **kwargs): assert results[0].local_rank == 1 assert results[0].world_size == 8 - def test_launcher_param_builds_correct_order(self, wrapper_no_rocm_check): """Explicit launcher param should produce: launcher rocprofv3 ... -- app.""" wrapper = wrapper_no_rocm_check From f308ae1299977cf5696d799dece6bbb950d57837 Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 03:17:05 +0000 Subject: [PATCH 7/8] Address review feedback: fix docstring placement, forward launcher in batch path, clarify launcher semantics - Move profile_command docstring above code (was displaced) - Forward launcher param in CounterBackend recursive batch call - Add Note sections to Linex.profile() and Metrix.profile() explaining that launcher is for mpirun-style use, and for torchrun the correct pattern is running metrix/linex under torchrun (not the reverse) Co-Authored-By: Claude Opus 4.6 --- linex/src/linex/api.py | 9 +++++++++ metrix/src/metrix/api.py | 6 ++++++ metrix/src/metrix/cli/profile_cmd.py | 3 +-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index aec03ef..ae3f258 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -199,6 +199,15 @@ def profile( kernel_filter: Regex filter for kernel names (default: None = all kernels) force_cu_mask: Force waves to target CU using HSA_CU_MASK (default: True) + Note: + When using ``launcher``, the launcher spawns child processes that + each run rocprofv3. Rank metadata is detected from environment + variables (RANK, LOCAL_RANK, WORLD_SIZE, etc.) 
set by the launcher + in each child process. This means ``launcher`` is most useful when + Linex itself runs *inside* a launched process (e.g., + ``torchrun --no-python ... linex-cli ...``), not when calling + ``Linex().profile(launcher=...)`` from a non-distributed parent. + Returns: self for chaining """ diff --git a/metrix/src/metrix/api.py b/metrix/src/metrix/api.py index 282348e..dd7cae9 100644 --- a/metrix/src/metrix/api.py +++ b/metrix/src/metrix/api.py @@ -121,6 +121,12 @@ def profile( cwd: Working directory for command execution (default: None) timeout_seconds: Timeout in seconds for profiling (default: 0, zero or None for no timeout) + Note: + When using ``launcher``, rank metadata is detected from environment + variables (RANK, LOCAL_RANK, etc.) set by the launcher in child + processes. For torchrun, run ``torchrun --no-python metrix ...`` + so each rank gets its own metrix process with correct env vars. + Returns: ProfilingResults object with all collected data """ diff --git a/metrix/src/metrix/cli/profile_cmd.py b/metrix/src/metrix/cli/profile_cmd.py index c3b99f9..bcd9981 100644 --- a/metrix/src/metrix/cli/profile_cmd.py +++ b/metrix/src/metrix/cli/profile_cmd.py @@ -21,12 +21,11 @@ def profile_command(args): + """Execute profile command using clean backend API.""" command_argv = normalize_command_argv(_normalize_cli_target(args.target)) command_display = " ".join(command_argv) dist_context = detect_distributed_context() - """Execute profile command using clean backend API""" - # Auto-detect architecture arch = detect_or_default(None) logger.info(f"Detected architecture: {arch}") From 06ba4dea590f0a27741312ff81179320acb0ce05 Mon Sep 17 00:00:00 2001 From: muhaawad Date: Mon, 23 Mar 2026 03:17:42 +0000 Subject: [PATCH 8/8] Fix launcher forwarding in CounterBackend batch path Co-Authored-By: Claude Opus 4.6 --- metrix/src/metrix/backends/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metrix/src/metrix/backends/base.py 
b/metrix/src/metrix/backends/base.py index f3547c7..42f88b8 100644 --- a/metrix/src/metrix/backends/base.py +++ b/metrix/src/metrix/backends/base.py @@ -554,6 +554,7 @@ def profile( metrics=batch_metrics, num_replays=num_replays, aggregate_by_kernel=aggregate_by_kernel, + launcher=launcher, kernel_filter=kernel_filter, cwd=cwd, timeout_seconds=timeout_seconds,