diff --git a/linex/README.md b/linex/README.md index a9fb24c..418535e 100644 --- a/linex/README.md +++ b/linex/README.md @@ -24,6 +24,29 @@ for line in profiler.source_lines[:5]: print(f" {line.total_cycles:,} cycles ({line.stall_percent:.1f}% stalled)") ``` +## Distributed Launchers + +Linex supports distributed profiling with launchers like `torchrun`, `mpirun`, +`srun`, and `horovodrun`. Pass the launcher separately so Linex builds the +correct command order (`launcher rocprofv3 ... -- app`). + +```python +profiler = Linex() +profiler.profile( + command="train.py", + launcher="torchrun --nproc_per_node=8", + output_dir="linex_sqtt", +) + +print(profiler.distributed_context.global_rank) +for rank_key, rank_profile in profiler.rank_profiles.items(): + print(rank_key, len(rank_profile.source_lines)) +``` + +In distributed mode, Linex writes traces into rank-specific subdirectories +(`.../rank0000`, `.../rank0001`, ...) to avoid collisions. Rank metadata is +automatically detected from environment variables set by the launcher. + ## What You Get **Instruction-level metrics mapped to source lines:** @@ -66,6 +89,8 @@ profiler = Linex( **Properties:** - `source_lines` - List[SourceLine] sorted by total_cycles - `instructions` - List[InstructionData] +- `rank_profiles` - Per-rank profiling data for distributed runs +- `distributed_context` - Detected launcher/rank metadata ### SourceLine diff --git a/linex/src/linex/__init__.py b/linex/src/linex/__init__.py index d832c18..be65762 100644 --- a/linex/src/linex/__init__.py +++ b/linex/src/linex/__init__.py @@ -8,7 +8,7 @@ providing cycle counts and performance metrics per source line. """ -from .api import Linex, SourceLine, InstructionData +from .api import InstructionData, Linex, RankProfile, SourceLine __version__ = "0.1.0" -__all__ = ["Linex", "SourceLine", "InstructionData"] +__all__ = ["Linex", "SourceLine", "InstructionData", "RankProfile"] diff --git a/linex/src/linex/api.py b/linex/src/linex/api.py index 2b9533f..ae3f258 100644 --- a/linex/src/linex/api.py +++ b/linex/src/linex/api.py @@ -6,11 +6,14 @@ """ import json +import os import subprocess import urllib.request from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence + +from .distributed import DistributedContext, detect_distributed_context, normalize_command_argv @dataclass @@ -99,6 +102,21 @@ def stall_percent(self) -> float: return 100.0 * self.stall_cycles / self.total_cycles if self.total_cycles > 0 else 0.0 +@dataclass +class RankProfile: + """Per-rank trace data produced by a Linex profiling run.""" + + rank_key: str + global_rank: int + local_rank: int + world_size: int + hostname: str + launcher: str + ui_output_dir: str + source_lines: List[SourceLine] + instructions: List[InstructionData] + + class Linex: """ Linex - Source-Level GPU Performance Profiler @@ -143,6 +161,8 @@ def __init__( # Data storage self._instructions: List[InstructionData] = [] self._source_lines: Dict[str, SourceLine] = {} + self._rank_profiles: Dict[str, RankProfile] = {} + self._distributed_context: DistributedContext = DistributedContext() def _ensure_decoder(self) -> Path: """Ensure decoder library is available, download if needed.""" @@ -158,10 +178,12 @@ def _ensure_decoder(self) -> Path: def profile( self, - command: str, + command: str | Sequence[str], output_dir: Optional[str] = None, kernel_filter: Optional[str] = None, force_cu_mask: bool = True, + env: Optional[Dict[str, str]] = None, + launcher: Optional[str | Sequence[str]] = None, ) -> "Linex": """ Profile an application and collect source-level performance data. @@ -177,15 +199,39 @@ def profile( kernel_filter: Regex filter for kernel names (default: None = all kernels) force_cu_mask: Force waves to target CU using HSA_CU_MASK (default: True) + Note: + When using ``launcher``, the launcher spawns child processes that + each run rocprofv3. Rank metadata is detected from environment + variables (RANK, LOCAL_RANK, WORLD_SIZE, etc.) set by the launcher + in each child process. This means ``launcher`` is most useful when + Linex itself runs *inside* a launched process (e.g., + ``torchrun --no-python ... linex-cli ...``), not when calling + ``Linex().profile(launcher=...)`` from a non-distributed parent. + Returns: self for chaining """ - import os import tempfile # Use temp directory if not specified if output_dir is None: - output_dir = tempfile.mkdtemp(prefix="linex_") + base_output_dir = Path(tempfile.mkdtemp(prefix="linex_")) + else: + base_output_dir = Path(output_dir) + + run_env = os.environ.copy() + if env: + run_env.update(env) + + dist_context = detect_distributed_context(run_env) + self._distributed_context = dist_context + + output_path = base_output_dir + if dist_context.is_distributed: + output_path = base_output_dir / dist_context.rank_tag + output_path.mkdir(parents=True, exist_ok=True) + + command_argv = normalize_command_argv(command) cmd = [ "rocprofv3", @@ -199,37 +245,67 @@ def profile( "--att-shader-engine-mask", self.shader_engine_mask, "-d", - output_dir, + str(output_path), ] if kernel_filter: cmd.extend(["--kernel-include-regex", kernel_filter]) - cmd.extend(["--", *command.split()]) + cmd.extend(["--", *command_argv]) - env = os.environ.copy() - if force_cu_mask: - env["HSA_CU_MASK"] = "0x1" # Force to CU 0 + # If a launcher is specified, prepend it: launcher rocprofv3 ... -- app + if launcher is not None: + launcher_argv = normalize_command_argv(launcher) + cmd = launcher_argv + cmd - result = subprocess.run(cmd, env=env, capture_output=True, text=True) + if force_cu_mask and "HSA_CU_MASK" not in run_env: + run_env["HSA_CU_MASK"] = "0x1" # Force to CU 0 unless caller already set a mask + + result = subprocess.run(cmd, env=run_env, capture_output=True, text=True) if result.returncode != 0: # Include stderr in error message for debugging raise RuntimeError(f"rocprofv3 failed with code {result.returncode}\n{result.stderr}") # Find generated ui_output directory - output_path = Path(output_dir) - ui_dirs = list(output_path.glob("ui_output_*")) + ui_dirs = sorted(output_path.glob("ui_output_*"), key=lambda p: p.name) if not ui_dirs: - raise RuntimeError(f"No ui_output directories found in {output_dir}") + raise RuntimeError(f"No ui_output directories found in {output_path}") + + self._rank_profiles = {} + for idx, ui_dir in enumerate(ui_dirs): + instructions, source_lines = self._load_ui_output_data(ui_dir) + if idx == 0: + rank_key = dist_context.rank_tag + else: + rank_key = f"{dist_context.rank_tag}_dispatch{idx + 1:03d}" + self._rank_profiles[rank_key] = RankProfile( + rank_key=rank_key, + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, + ui_output_dir=str(ui_dir), + source_lines=sorted( + source_lines.values(), key=lambda x: x.total_cycles, reverse=True + ), + instructions=instructions, + ) - # Load the first dispatch - self._load_ui_output(ui_dirs[0]) + # Preserve existing API behavior by exposing the first rank profile as top-level fields. + primary_rank = next(iter(self._rank_profiles.values())) + self._instructions = primary_rank.instructions + self._source_lines = { + line.source_location: line for line in primary_rank.source_lines if line.source_location + } return self - def _load_ui_output(self, ui_output_dir: Path) -> None: - """Internal: Load performance trace from ui_output directory.""" + def _load_ui_output_data( + self, ui_output_dir: Path + ) -> tuple[List[InstructionData], Dict[str, SourceLine]]: + """Internal: Load performance trace data from ui_output directory.""" code_file = ui_output_dir / "code.json" if not code_file.exists(): @@ -244,7 +320,7 @@ def _load_ui_output(self, ui_output_dir: Path) -> None: ) # Parse instructions - self._instructions = [] + instructions: List[InstructionData] = [] for entry in data["code"]: inst = InstructionData( isa=entry[0], @@ -257,21 +333,27 @@ def _load_ui_output(self, ui_output_dir: Path) -> None: stall_cycles=entry[8], idle_cycles=entry[9], ) - self._instructions.append(inst) + instructions.append(inst) - # Aggregate by source line - self._aggregate_source_lines() + return instructions, self._aggregate_source_lines(instructions) + + def _load_ui_output(self, ui_output_dir: Path) -> None: + """Backward-compatible loader for a single ui_output directory.""" + instructions, source_lines = self._load_ui_output_data(ui_output_dir) + self._instructions = instructions + self._source_lines = source_lines - def _aggregate_source_lines(self): + def _aggregate_source_lines(self, instructions: Optional[List[InstructionData]] = None): """Aggregate instruction data by source line.""" - self._source_lines = {} + source_lines: Dict[str, SourceLine] = {} + instructions_to_aggregate = self._instructions if instructions is None else instructions - for inst in self._instructions: + for inst in instructions_to_aggregate: source = inst.source_location if not source or source.startswith(";"): continue - if source not in self._source_lines: + if source not in source_lines: # Parse file:line from source if ":" in source: parts = source.rsplit(":", 1) @@ -285,7 +367,7 @@ def _aggregate_source_lines(self): file = source line = 0 - self._source_lines[source] = SourceLine( + source_lines[source] = SourceLine( file=file, line_number=line, source_location=source, @@ -296,12 +378,15 @@ def _aggregate_source_lines(self): instructions=[], ) - sl = self._source_lines[source] + sl = source_lines[source] sl.execution_count += inst.execution_count sl.total_cycles += inst.latency_cycles sl.stall_cycles += inst.stall_cycles sl.idle_cycles += inst.idle_cycles sl.instructions.append(inst) + if instructions is None: + self._source_lines = source_lines + return source_lines @property def source_lines(self) -> List[SourceLine]: @@ -312,3 +397,13 @@ def source_lines(self) -> List[SourceLine]: def instructions(self) -> List[InstructionData]: """Get all instructions.""" return self._instructions + + @property + def rank_profiles(self) -> Dict[str, RankProfile]: + """Get per-rank profiles for distributed runs.""" + return self._rank_profiles + + @property + def distributed_context(self) -> DistributedContext: + """Get distributed runtime metadata detected for this profile run.""" + return self._distributed_context diff --git a/linex/src/linex/distributed.py b/linex/src/linex/distributed.py new file mode 100644 index 0000000..30857a3 --- /dev/null +++ b/linex/src/linex/distributed.py @@ -0,0 +1,115 @@ +""" +Distributed launcher helpers for Linex. +""" + +from __future__ import annotations + +import os +import shlex +import socket +from dataclasses import dataclass +from typing import Mapping, Sequence + + +@dataclass(frozen=True) +class DistributedContext: + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" + + @property + def is_distributed(self) -> bool: + return self.world_size > 1 + + @property + def rank_tag(self) -> str: + return f"rank{self.global_rank:04d}" + + +def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: + for key in keys: + value = env.get(key) + if value is None: + continue + try: + return int(value) + except ValueError: + continue + return default + + +def detect_distributed_context(env: Mapping[str, str] | None = None) -> DistributedContext: + env_map = os.environ if env is None else env + global_rank = _first_int( + env_map, + [ + "RANK", + "OMPI_COMM_WORLD_RANK", + "PMI_RANK", + "PMIX_RANK", + "SLURM_PROCID", + "HOROVOD_RANK", + ], + 0, + ) + local_rank = _first_int( + env_map, + [ + "LOCAL_RANK", + "OMPI_COMM_WORLD_LOCAL_RANK", + "MPI_LOCALRANKID", + "SLURM_LOCALID", + "HOROVOD_LOCAL_RANK", + ], + 0, + ) + world_size = _first_int( + env_map, + [ + "WORLD_SIZE", + "OMPI_COMM_WORLD_SIZE", + "PMI_SIZE", + "PMIX_SIZE", + "SLURM_NTASKS", + "HOROVOD_SIZE", + ], + 1, + ) + node_rank = _first_int( + env_map, + ["GROUP_RANK", "NODE_RANK", "OMPI_COMM_WORLD_NODE_RANK", "SLURM_NODEID"], + 0, + ) + hostname = env_map.get("HOSTNAME", "") or socket.gethostname() + + launcher = "single" + if "TORCHELASTIC_RUN_ID" in env_map or "LOCAL_RANK" in env_map: + launcher = "torchrun" + elif "OMPI_COMM_WORLD_RANK" in env_map: + launcher = "mpirun" + elif "SLURM_PROCID" in env_map: + launcher = "srun" + elif "HOROVOD_RANK" in env_map: + launcher = "horovodrun" + + return DistributedContext( + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + node_rank=node_rank, + hostname=hostname, + launcher=launcher, + ) + + +def normalize_command_argv(command: str | Sequence[str]) -> list[str]: + if isinstance(command, (list, tuple)): + argv = [str(arg) for arg in command] + else: + argv = shlex.split(command) + if not argv: + raise ValueError("Command is empty") + return argv diff --git a/linex/src/linex/mcp/server.py b/linex/src/linex/mcp/server.py index cacc885..84c7a2d 100644 --- a/linex/src/linex/mcp/server.py +++ b/linex/src/linex/mcp/server.py @@ -12,7 +12,9 @@ @mcp.tool() -def profile_application(command: str, kernel_filter: str = None, top_n: int = 10) -> dict: +def profile_application( + command: str, kernel_filter: str = None, top_n: int = 10, launcher: str = None +) -> dict: """ Profile a GPU application and get source-level performance metrics. @@ -28,12 +30,20 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 Dictionary with total_source_lines, total_instructions, and hotspots list """ profiler = Linex() - profiler.profile(command, kernel_filter=kernel_filter) + profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher) results = { + "distributed_context": { + "global_rank": profiler.distributed_context.global_rank, + "local_rank": profiler.distributed_context.local_rank, + "world_size": profiler.distributed_context.world_size, + "hostname": profiler.distributed_context.hostname, + "launcher": profiler.distributed_context.launcher, + }, "total_source_lines": len(profiler.source_lines), "total_instructions": len(profiler.instructions), "hotspots": [], + "per_rank_hotspots": [], } for i, line in enumerate(profiler.source_lines[:top_n], 1): @@ -52,12 +62,46 @@ def profile_application(command: str, kernel_filter: str = None, top_n: int = 10 } ) + for rank_key, rank_profile in profiler.rank_profiles.items(): + rank_entry = { + "rank_key": rank_key, + "global_rank": rank_profile.global_rank, + "local_rank": rank_profile.local_rank, + "world_size": rank_profile.world_size, + "hostname": rank_profile.hostname, + "launcher": rank_profile.launcher, + "ui_output_dir": rank_profile.ui_output_dir, + "total_source_lines": len(rank_profile.source_lines), + "total_instructions": len(rank_profile.instructions), + "hotspots": [], + } + for i, line in enumerate(rank_profile.source_lines[:top_n], 1): + rank_entry["hotspots"].append( + { + "rank": i, + "file": line.file, + "line_number": line.line_number, + "source_location": line.source_location, + "total_cycles": line.total_cycles, + "stall_cycles": line.stall_cycles, + "stall_percent": round(line.stall_percent, 2), + "idle_cycles": line.idle_cycles, + "execution_count": line.execution_count, + "num_instructions": len(line.instructions), + } + ) + results["per_rank_hotspots"].append(rank_entry) + return results @mcp.tool() def analyze_instruction_hotspots( - command: str, kernel_filter: str = None, top_lines: int = 5, top_instructions_per_line: int = 10 + command: str, + kernel_filter: str = None, + top_lines: int = 5, + top_instructions_per_line: int = 10, + launcher: str = None, ) -> dict: """ Get detailed instruction-level analysis for the hottest source lines. @@ -75,9 +119,19 @@ def analyze_instruction_hotspots( Dictionary with hotspot_analysis list containing ISA-level details """ profiler = Linex() - profiler.profile(command, kernel_filter=kernel_filter) + profiler.profile(command, kernel_filter=kernel_filter, launcher=launcher) - results = {"hotspot_analysis": []} + results = { + "distributed_context": { + "global_rank": profiler.distributed_context.global_rank, + "local_rank": profiler.distributed_context.local_rank, + "world_size": profiler.distributed_context.world_size, + "hostname": profiler.distributed_context.hostname, + "launcher": profiler.distributed_context.launcher, + }, + "hotspot_analysis": [], + "per_rank_hotspot_analysis": [], + } for line in profiler.source_lines[:top_lines]: # Sort instructions by latency @@ -105,6 +159,40 @@ def analyze_instruction_hotspots( results["hotspot_analysis"].append(line_data) + for rank_key, rank_profile in profiler.rank_profiles.items(): + rank_entry = { + "rank_key": rank_key, + "global_rank": rank_profile.global_rank, + "local_rank": rank_profile.local_rank, + "world_size": rank_profile.world_size, + "hostname": rank_profile.hostname, + "launcher": rank_profile.launcher, + "ui_output_dir": rank_profile.ui_output_dir, + "hotspot_analysis": [], + } + for line in rank_profile.source_lines[:top_lines]: + sorted_insts = sorted(line.instructions, key=lambda x: x.latency_cycles, reverse=True) + line_data = { + "source_location": line.source_location, + "total_cycles": line.total_cycles, + "stall_percent": round(line.stall_percent, 2), + "instructions": [], + } + for inst in sorted_insts[:top_instructions_per_line]: + line_data["instructions"].append( + { + "isa": inst.isa, + "latency_cycles": inst.latency_cycles, + "stall_cycles": inst.stall_cycles, + "stall_percent": round(inst.stall_percent, 2), + "idle_cycles": inst.idle_cycles, + "execution_count": inst.execution_count, + "instruction_address": f"0x{inst.instruction_address:08x}", + } + ) + rank_entry["hotspot_analysis"].append(line_data) + results["per_rank_hotspot_analysis"].append(rank_entry) + return results diff --git a/linex/tests/test_distributed_api.py b/linex/tests/test_distributed_api.py new file mode 100644 index 0000000..c6feaeb --- /dev/null +++ b/linex/tests/test_distributed_api.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: MIT + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +from linex import Linex +from linex.distributed import detect_distributed_context, normalize_command_argv + + +def _write_code_json(ui_dir: Path, source_location: str) -> None: + ui_dir.mkdir(parents=True, exist_ok=True) + code = { + "code": [ + [ + "s_add_u32 s0, s0, s1", + 0, + 0, + source_location, + 1, + 0x1000, + 4, + 100, + 20, + 0, + ] + ] + } + (ui_dir / "code.json").write_text(json.dumps(code)) + + +def test_distributed_helpers_parse_common_env(): + ctx = detect_distributed_context( + {"SLURM_PROCID": "3", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8"} + ) + assert ctx.global_rank == 3 + assert ctx.local_rank == 1 + assert ctx.world_size == 8 + assert ctx.launcher == "srun" + assert normalize_command_argv('torchrun --nproc_per_node=2 train.py --arg "two words"')[-1] == ( + "two words" + ) + + +def test_profile_uses_rank_scoped_output_and_loads_deterministic_ui_dir(tmp_path): + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + def fake_run(cmd, **kwargs): + output_dir = Path(cmd[cmd.index("-d") + 1]) + _write_code_json(output_dir / "ui_output_200", "kernel2.hip:20") + _write_code_json(output_dir / "ui_output_100", "kernel1.hip:10") + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): + profiler = Linex() + profiler.profile( + command='python -c "print(1)"', + output_dir=str(tmp_path / "linex_out"), + env={"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"}, + ) + + # Primary profile should come from lexicographically first ui_output directory. + assert profiler.source_lines[0].source_location == "kernel1.hip:10" + assert profiler.distributed_context.is_distributed + assert profiler.distributed_context.global_rank == 1 + assert len(profiler.rank_profiles) == 2 + + +def test_profile_with_launcher_builds_correct_command_order(tmp_path): + """When launcher is provided, the subprocess command should be + launcher_argv + rocprofv3 ... -- app_argv.""" + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + output_dir = Path(cmd[cmd.index("-d") + 1]) + ui_dir = output_dir / "ui_output_000" + ui_dir.mkdir(parents=True, exist_ok=True) + code = {"code": [["s_nop 0", 0, 0, "test.hip:1", 1, 0x1000, 4, 10, 2, 0]]} + (ui_dir / "code.json").write_text(json.dumps(code)) + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): + profiler = Linex() + profiler.profile( + command="train.py --lr 0.01", + launcher="torchrun --nproc_per_node=4", + output_dir=str(tmp_path / "out"), + ) + + # Command should be: torchrun --nproc_per_node=4 rocprofv3 ... -- train.py --lr 0.01 + assert captured_cmd[0] == "torchrun" + assert captured_cmd[1] == "--nproc_per_node=4" + rocprofv3_idx = captured_cmd.index("rocprofv3") + assert rocprofv3_idx == 2 + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "train.py" + assert captured_cmd[separator_idx + 2] == "--lr" + assert captured_cmd[separator_idx + 3] == "0.01" + + +def test_profile_without_launcher_uses_plain_rocprofv3(tmp_path): + """Without launcher, command should be: rocprofv3 ... -- app_argv.""" + dummy_decoder = tmp_path / "decoder" / "librocprof-trace-decoder.so" + dummy_decoder.parent.mkdir(parents=True, exist_ok=True) + dummy_decoder.write_text("placeholder") + + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + output_dir = Path(cmd[cmd.index("-d") + 1]) + ui_dir = output_dir / "ui_output_000" + ui_dir.mkdir(parents=True, exist_ok=True) + code = {"code": [["s_nop 0", 0, 0, "test.hip:1", 1, 0x1000, 4, 10, 2, 0]]} + (ui_dir / "code.json").write_text(json.dumps(code)) + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with patch.object(Linex, "_ensure_decoder", return_value=dummy_decoder), patch( + "subprocess.run", side_effect=fake_run + ): + profiler = Linex() + profiler.profile( + command="./my_app --size 1024", + output_dir=str(tmp_path / "out"), + ) + + assert captured_cmd[0] == "rocprofv3" + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "./my_app" diff --git a/metrix/README.md b/metrix/README.md index f6e8351..c98cdd6 100644 --- a/metrix/README.md +++ b/metrix/README.md @@ -47,6 +47,29 @@ metrix --metrics memory.l2_hit_rate,memory.coalescing_efficiency ./my_app metrix -o results.json ./my_app ``` +## Distributed Launchers + +Metrix supports distributed profiling with launchers like `torchrun`, `mpirun`, +`srun`, and `horovodrun`. Use `--launcher` to specify the launcher command +separately from the application, ensuring the correct invocation order +(`launcher rocprofv3 ... -- app`). + +```bash +# Torch distributed +metrix profile --launcher "torchrun --nproc_per_node=8" -- train.py + +# MPI +metrix profile --launcher "mpirun -np 8" -- ./my_app --problem-size 4096 + +# Slurm +metrix profile --launcher "srun -N 2 -n 16" -- ./my_app +``` + +When distributed rank variables are present (`RANK`, `OMPI_COMM_WORLD_RANK`, +`SLURM_PROCID`, etc.), Metrix emits per-rank metadata and automatically suffixes +`--output/-o` files with a rank tag (for example, `results.rank0003.json`) to avoid +cross-rank clobbering. + ## Python API ```python diff --git a/metrix/src/metrix/api.py b/metrix/src/metrix/api.py index 02738f2..dd7cae9 100644 --- a/metrix/src/metrix/api.py +++ b/metrix/src/metrix/api.py @@ -8,7 +8,7 @@ """ import re -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence from dataclasses import dataclass from pathlib import Path @@ -16,6 +16,7 @@ from .backends.base import CounterBackend from .metrics import METRIC_PROFILES, METRIC_CATALOG from .logger import logger +from .utils.distributed import detect_distributed_context, normalize_command_argv @dataclass @@ -28,6 +29,10 @@ class KernelResults: duration_us: Statistics metrics: Dict[str, Statistics] dispatch_count: int = 1 + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + hostname: str = "" @property def avg_time_us(self) -> float: @@ -46,6 +51,11 @@ class ProfilingResults: command: str kernels: List[KernelResults] total_kernels: int + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + hostname: str = "" + launcher: str = "single" class Metrix: @@ -78,10 +88,11 @@ def __init__(self, arch: Optional[str] = None): def profile( self, - command: str, + command: str | Sequence[str], metrics: Optional[List[str]] = None, profile: Optional[str] = None, kernel_filter: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, time_only: bool = False, num_replays: int = 1, aggregate_by_kernel: bool = True, @@ -110,6 +121,12 @@ def profile( cwd: Working directory for command execution (default: None) timeout_seconds: Timeout in seconds for profiling (default: 0, zero or None for no timeout) + Note: + When using ``launcher``, rank metadata is detected from environment + variables (RANK, LOCAL_RANK, etc.) set by the launcher in child + processes. For torchrun, run ``torchrun --no-python metrix ...`` + so each rank gets its own metrix process with correct env vars. + Returns: ProfilingResults object with all collected data """ @@ -154,7 +171,9 @@ def profile( # Use simple kernel filter (no regex) rocprof_filter = kernel_filter - logger.info(f"Profiling: {command}") + command_argv = normalize_command_argv(command) + command_string = " ".join(command_argv) + logger.info(f"Profiling: {command_string}") logger.info(f"Collecting {len(metrics_to_compute)} metrics across {num_replays} replay(s)") if rocprof_filter: logger.info(f"Kernel filter: {rocprof_filter}") @@ -162,10 +181,11 @@ def profile( # Profile using backend (filtering at rocprofv3 level) logger.debug(f"Calling backend.profile with {len(metrics_to_compute)} metrics") self.backend.profile( - command=command, + command=command_argv, metrics=metrics_to_compute, num_replays=num_replays, aggregate_by_kernel=aggregate_by_kernel, + launcher=launcher, kernel_filter=rocprof_filter, cwd=cwd, timeout_seconds=timeout_seconds, @@ -177,10 +197,21 @@ def profile( if not dispatch_keys: logger.warning("No kernels profiled") - return ProfilingResults(command=command, kernels=[], total_kernels=0) + dist_context = detect_distributed_context() + return ProfilingResults( + command=command_string, + kernels=[], + total_kernels=0, + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, + ) # Build result objects kernel_results = [] + dist_context = detect_distributed_context() for dispatch_key in dispatch_keys: # Get duration duration = self.backend._aggregated[dispatch_key].get("duration_us") @@ -206,11 +237,22 @@ def profile( duration_us=duration, metrics=computed_metrics, dispatch_count=int(dispatch_count), + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, ) kernel_results.append(kernel_result) return ProfilingResults( - command=command, kernels=kernel_results, total_kernels=len(kernel_results) + command=command_string, + kernels=kernel_results, + total_kernels=len(kernel_results), + global_rank=dist_context.global_rank, + local_rank=dist_context.local_rank, + world_size=dist_context.world_size, + hostname=dist_context.hostname, + launcher=dist_context.launcher, ) def list_metrics(self, category: Optional[str] = None) -> List[str]: diff --git a/metrix/src/metrix/backends/base.py b/metrix/src/metrix/backends/base.py index 65c8cc5..42f88b8 100644 --- a/metrix/src/metrix/backends/base.py +++ b/metrix/src/metrix/backends/base.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Sequence from dataclasses import dataclass from collections import defaultdict @@ -66,6 +66,12 @@ class ProfileResult: arch_vgpr: int = 0 accum_vgpr: int = 0 sgpr: int = 0 + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" class CounterBackend(ABC): @@ -465,10 +471,11 @@ def _split_counters_into_passes(self, counters: List[str]) -> List[List[str]]: def profile( self, - command: str, + command: str | Sequence[str], metrics: List[str], num_replays: int = 5, aggregate_by_kernel: bool = False, + launcher: Optional[str | Sequence[str]] = None, kernel_filter: Optional[str] = None, cwd: Optional[str] = None, timeout_seconds: Optional[int] = 0, @@ -547,6 +554,7 @@ def profile( metrics=batch_metrics, num_replays=num_replays, aggregate_by_kernel=aggregate_by_kernel, + launcher=launcher, kernel_filter=kernel_filter, cwd=cwd, timeout_seconds=timeout_seconds, @@ -774,10 +782,11 @@ def _compute_with_stat_type( @abstractmethod def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, kernel_iteration_range: Optional[str] = None, ) -> List[ProfileResult]: @@ -807,7 +816,12 @@ def _aggregate_by_dispatch_across_runs( # Group by dispatch_id:kernel_name groups = defaultdict(list) for result in results: - key = f"dispatch_{result.dispatch_id}:{result.kernel_name}" + if result.world_size > 1: + key = ( + f"rank_{result.global_rank}:dispatch_{result.dispatch_id}:{result.kernel_name}" + ) + else: + key = f"dispatch_{result.dispatch_id}:{result.kernel_name}" groups[key].append(result) # Compute stats for each group @@ -841,7 +855,11 @@ def _aggregate_by_kernel_then_runs( # Now aggregate merged results across replays groups = defaultdict(list) for merged in merged_replays: - groups[merged.kernel_name].append(merged) + if merged.world_size > 1: + key = f"rank_{merged.global_rank}:{merged.kernel_name}" + else: + key = merged.kernel_name + groups[key].append(merged) aggregated = {} for kernel_name, dispatches in groups.items(): @@ -945,6 +963,12 @@ def _should_average(name: str) -> bool: arch_vgpr=first.arch_vgpr, accum_vgpr=first.accum_vgpr, sgpr=first.sgpr, + global_rank=first.global_rank, + local_rank=first.local_rank, + world_size=first.world_size, + node_rank=first.node_rank, + hostname=first.hostname, + launcher=first.launcher, ) merged._num_dispatches = len(dispatches) return merged diff --git a/metrix/src/metrix/backends/gfx1201.py b/metrix/src/metrix/backends/gfx1201.py index ee477af..eca553c 100644 --- a/metrix/src/metrix/backends/gfx1201.py +++ b/metrix/src/metrix/backends/gfx1201.py @@ -7,7 +7,7 @@ from .base import CounterBackend, DeviceSpecs, ProfileResult, Statistics from ..profiler.rocprof_wrapper import ROCProfV3Wrapper from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence class GFX1201Backend(CounterBackend): @@ -52,10 +52,11 @@ def _get_device_specs(self) -> DeviceSpecs: def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, kernel_iteration_range: Optional[str] = None, ) -> List[ProfileResult]: diff --git a/metrix/src/metrix/backends/gfx90a.py b/metrix/src/metrix/backends/gfx90a.py index bc38992..0d63edc 100644 --- a/metrix/src/metrix/backends/gfx90a.py +++ b/metrix/src/metrix/backends/gfx90a.py @@ -9,7 +9,7 @@ from ..utils.common import split_counters_into_passes from .decorator import metric from ..profiler.rocprof_wrapper import ROCProfV3Wrapper -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence class GFX90aBackend(CounterBackend): @@ -94,10 +94,11 @@ def _get_counter_block_limits(self) -> Dict[str, int]: def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, ) -> List[ProfileResult]: """Run rocprofv3 and return results (single pass only - base class handles multi-pass)""" diff --git a/metrix/src/metrix/backends/gfx942.py b/metrix/src/metrix/backends/gfx942.py index 3391012..3c6e08a 100644 --- a/metrix/src/metrix/backends/gfx942.py +++ b/metrix/src/metrix/backends/gfx942.py @@ -9,7 +9,7 @@ from ..utils.common import split_counters_into_passes from .decorator import metric from ..profiler.rocprof_wrapper import ROCProfV3Wrapper -from typing import List, Optional, Dict +from typing import Dict, List, Optional, Sequence class GFX942Backend(CounterBackend): @@ -94,10 +94,11 @@ def _get_counter_block_limits(self) -> Dict[str, int]: def _run_rocprof( self, - command: str, + command: str | Sequence[str], counters: List[str], kernel_filter: Optional[str] = None, cwd: Optional[str] = None, + launcher: Optional[str | Sequence[str]] = None, timeout_seconds: Optional[int] = 0, ) -> List[ProfileResult]: """Run rocprofv3 and return results (single pass only - base class handles multi-pass)""" diff --git a/metrix/src/metrix/cli/main.py b/metrix/src/metrix/cli/main.py index ae99015..a63d202 100644 --- a/metrix/src/metrix/cli/main.py +++ b/metrix/src/metrix/cli/main.py @@ -51,9 +51,22 @@ def create_parser(): description="Collect performance metrics from GPU kernels", ) + profile_parser.add_argument( + "--launcher", + default=None, + help=( + "Distributed launcher command to wrap rocprofv3 " + "(e.g., 'torchrun --nproc_per_node=8' or 'mpirun -np 4')" + ), + ) + profile_parser.add_argument( "target", - help="Target application command (e.g., ./my_app or './my_app arg1 arg2')", + nargs=argparse.REMAINDER, + help=( + "Target application command. Use '--' before the command for complex launch lines " + "(e.g., metrix profile -- torchrun --nproc_per_node=8 train.py)" + ), ) profile_group = profile_parser.add_mutually_exclusive_group() diff --git a/metrix/src/metrix/cli/profile_cmd.py b/metrix/src/metrix/cli/profile_cmd.py index 9ddd17a..bcd9981 100644 --- a/metrix/src/metrix/cli/profile_cmd.py +++ b/metrix/src/metrix/cli/profile_cmd.py @@ -13,10 +13,18 @@ from ..backends import get_backend, Statistics, detect_or_default from ..metrics import METRIC_PROFILES, METRIC_CATALOG from ..logger import logger +from ..utils.distributed import ( + apply_rank_suffix, + detect_distributed_context, + normalize_command_argv, +) def profile_command(args): - """Execute profile command using clean backend API""" + """Execute profile command using clean backend API.""" + command_argv = normalize_command_argv(_normalize_cli_target(args.target)) + command_display = " ".join(command_argv) + dist_context = detect_distributed_context() # Auto-detect architecture arch = detect_or_default(None) @@ -79,7 +87,16 @@ def profile_command(args): # Log configuration logger.info(f"{'=' * 80}") logger.info(f"Metrix: {mode}") - logger.info(f"Target: {args.target}") + logger.info(f"Target: {command_display}") + if dist_context.is_distributed: + logger.info( + "Distributed context: launcher=%s rank=%s/%s local_rank=%s host=%s", + dist_context.launcher, + dist_context.global_rank, + dist_context.world_size, + dist_context.local_rank, + dist_context.hostname, + ) if args.num_replays > 1: logger.info(f"Replays: {args.num_replays}") if args.kernel: @@ -95,10 +112,11 @@ def profile_command(args): logger.info(f"Running {args.num_replays} replays...") backend.profile( - command=args.target, + command=command_argv, metrics=metrics_to_compute, num_replays=args.num_replays, aggregate_by_kernel=args.aggregate, + launcher=args.launcher, kernel_filter=kernel_filter, ) except Exception as e: @@ -146,28 +164,52 @@ def profile_command(args): logger.warning(f"Failed to compute {metric} for {dispatch_key}: {e}") # Output results - if args.output: + output_path = args.output + if output_path and dist_context.is_distributed: + output_path = apply_rank_suffix(output_path, dist_context) + logger.info( + "Distributed output path for rank %s: %s", dist_context.global_rank, output_path + ) + + if output_path: # Detect format from file extension - output_path = Path(args.output) - ext = output_path.suffix.lower() + output_file = Path(output_path) + ext = output_file.suffix.lower() if ext == ".json": - _write_json_output(output_path, results, metrics_to_compute) + _write_json_output(output_file, results, metrics_to_compute, dist_context) elif ext == ".csv": - _write_csv_output(output_path, results, metrics_to_compute, args.aggregate) + _write_csv_output( + output_file, results, metrics_to_compute, args.aggregate, dist_context + ) else: # Default to text - _write_text_output(output_path, results, metrics_to_compute, args.aggregate) + _write_text_output( + output_file, results, metrics_to_compute, args.aggregate, dist_context + ) else: # Print to stdout - _print_text_results(results, metrics_to_compute, args.aggregate, args.no_counters) + _print_text_results( + results, metrics_to_compute, args.aggregate, args.no_counters, dist_context + ) logger.info(f"Profiled {len(results)} dispatch(es)/kernel(s)") return 0 -def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_counters: bool): +def _normalize_cli_target(target) -> str | list[str]: + """Normalize argparse target (string or remainder list) into command input.""" + if isinstance(target, list): + if target and target[0] == "--": + return target[1:] + return target + return target + + +def _print_text_results( + results: Dict, metrics: List[str], aggregated: bool, no_counters: bool, dist_context +): """Print results to stdout in human-readable format""" # Group metrics by category @@ -194,14 +236,29 @@ def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_ if aggregated: print(f"Kernel: {dispatch_key}") else: - # dispatch_key is like "dispatch_1:kernel_name" - parts = dispatch_key.split(":", 1) - if len(parts) == 2: + # dispatch_key may be: + # - "dispatch_1:kernel_name" + # - "rank_1:dispatch_1:kernel_name" (distributed) + parts = dispatch_key.split(":") + if ( + len(parts) >= 2 + and parts[0].startswith("rank_") + and parts[1].startswith("dispatch_") + ): + dispatch_id = parts[1].replace("dispatch_", "") + kernel_name = ":".join(parts[2:]) if len(parts) > 2 else "" + print(f"Dispatch #{dispatch_id}: {kernel_name}") + elif len(parts) >= 2 and parts[0].startswith("dispatch_"): dispatch_id = parts[0].replace("dispatch_", "") - kernel_name = parts[1] + kernel_name = ":".join(parts[1:]) print(f"Dispatch #{dispatch_id}: {kernel_name}") else: print(f"Kernel: {dispatch_key}") + if dist_context.is_distributed: + print( + f"Rank: {dist_context.global_rank}/{dist_context.world_size} " + f"(local={dist_context.local_rank}, host={dist_context.hostname})" + ) print(f"{'─' * 80}") # Duration @@ -228,9 +285,17 @@ def _print_text_results(results: Dict, metrics: List[str], aggregated: bool, no_ print(f" {name:45s} {stats.avg:10.2f} {unit}") -def _write_json_output(output_path: Path, results: Dict, metrics: List[str]): +def _write_json_output(output_path: Path, results: Dict, metrics: List[str], dist_context): """Write results to JSON file""" - json_data = {} + json_data = { + "_rank": { + "global_rank": dist_context.global_rank, + "local_rank": dist_context.local_rank, + "world_size": dist_context.world_size, + "hostname": dist_context.hostname, + "launcher": dist_context.launcher, + } + } for dispatch_key, data in results.items(): json_data[dispatch_key] = { @@ -258,13 +323,24 @@ def _write_json_output(output_path: Path, results: Dict, metrics: List[str]): logger.info(f"Results written to {output_path}") -def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggregated: bool): +def _write_csv_output( + output_path: Path, results: Dict, metrics: List[str], aggregated: bool, dist_context +): """Write results to CSV file""" with open(output_path, "w", newline="") as f: writer = csv.writer(f) # Header - header = ["dispatch_key", "duration_min_us", "duration_max_us", "duration_avg_us"] + header = [ + "global_rank", + "local_rank", + "world_size", + "hostname", + "dispatch_key", + "duration_min_us", + "duration_max_us", + "duration_avg_us", + ] for metric in metrics: header.extend( [ @@ -277,7 +353,13 @@ def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggr # Data rows for dispatch_key, data in results.items(): - row = [dispatch_key] + row = [ + dist_context.global_rank, + dist_context.local_rank, + dist_context.world_size, + dist_context.hostname, + dispatch_key, + ] duration = data.get("duration_us") if duration: @@ -297,7 +379,9 @@ def _write_csv_output(output_path: Path, results: Dict, metrics: List[str], aggr logger.info(f"Results written to {output_path}") -def _write_text_output(output_path: Path, results: Dict, metrics: List[str], aggregated: bool): +def _write_text_output( + output_path: Path, results: Dict, metrics: List[str], aggregated: bool, dist_context +): """Write results to text file""" buffer = StringIO() @@ -305,7 +389,7 @@ def _write_text_output(output_path: Path, results: Dict, metrics: List[str], agg old_stdout = sys.stdout sys.stdout = buffer - _print_text_results(results, metrics, aggregated, no_counters=False) + _print_text_results(results, metrics, aggregated, no_counters=False, dist_context=dist_context) sys.stdout = old_stdout diff --git a/metrix/src/metrix/mcp/server.py b/metrix/src/metrix/mcp/server.py index 2d62c89..43ec28b 100644 --- a/metrix/src/metrix/mcp/server.py +++ b/metrix/src/metrix/mcp/server.py @@ -12,7 +12,7 @@ @mcp.tool() -def profile_metrics(command: str, metrics: list[str] = None) -> dict: +def profile_metrics(command: str, metrics: list[str] = None, launcher: str = None) -> dict: """ Profile GPU application and collect hardware performance metrics. @@ -33,13 +33,26 @@ def profile_metrics(command: str, metrics: list[str] = None) -> dict: if metrics is None: metrics = ["memory.hbm_bandwidth_utilization"] - results_obj = profiler.profile(command, metrics=metrics) + results_obj = profiler.profile(command, metrics=metrics, launcher=launcher) - results = {"kernels": []} + results = { + "rank": { + "global_rank": results_obj.global_rank, + "local_rank": results_obj.local_rank, + "world_size": results_obj.world_size, + "hostname": results_obj.hostname, + "launcher": results_obj.launcher, + }, + "kernels": [], + } for kernel in results_obj.kernels: kernel_data = { "name": kernel.name, + "global_rank": kernel.global_rank, + "local_rank": kernel.local_rank, + "world_size": kernel.world_size, + "hostname": kernel.hostname, "duration_us_avg": float(kernel.duration_us.avg) if hasattr(kernel.duration_us, "avg") else 0.0, diff --git a/metrix/src/metrix/profiler/rocprof_wrapper.py b/metrix/src/metrix/profiler/rocprof_wrapper.py index 13f5315..c204942 100644 --- a/metrix/src/metrix/profiler/rocprof_wrapper.py +++ b/metrix/src/metrix/profiler/rocprof_wrapper.py @@ -10,11 +10,12 @@ import os import yaml from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Sequence # Import ProfileResult from backends to avoid duplication from ..backends.base import ProfileResult from ..logger import logger +from ..utils.distributed import detect_distributed_context, normalize_command_argv class ROCProfV3Wrapper: @@ -70,7 +71,7 @@ def _needs_extra_counters(counter_defs_file: Path) -> bool: def profile( self, - command: str, + command: str | Sequence[str], counters: List[str], output_dir: Optional[Path] = None, kernel_filter: Optional[str] = None, @@ -78,6 +79,8 @@ def profile( kernel_iteration_range: Optional[str] = None, extra_counters_path: Optional[Path] = None, arch: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + launcher: Optional[str | Sequence[str]] = None, ) -> List[ProfileResult]: """ Profile a command with specified counters (single pass). @@ -166,8 +169,14 @@ def profile( prof_cmd.extend(["--kernel-include-regex", kernel_filter]) # Add target command + command_argv = normalize_command_argv(command) prof_cmd.append("--") - prof_cmd.extend(command.split()) + prof_cmd.extend(command_argv) + + # If a launcher is specified, prepend it: launcher rocprofv3 ... -- app + if launcher is not None: + launcher_argv = normalize_command_argv(launcher) + prof_cmd = launcher_argv + prof_cmd logger.debug(f"rocprofv3 command: {' '.join(prof_cmd)}") logger.info(f"Starting rocprofv3 with {len(counters)} counters") @@ -181,8 +190,17 @@ def profile( logger.debug(f"Timeout: {self.timeout}") logger.debug(f"CWD: {cwd}") + run_env = os.environ.copy() + if env: + run_env.update(env) + result = subprocess.run( - prof_cmd, capture_output=True, timeout=self.timeout, text=True, cwd=cwd + prof_cmd, + capture_output=True, + timeout=self.timeout, + text=True, + cwd=cwd, + env=run_env, ) logger.info("subprocess.run returned!") @@ -240,6 +258,15 @@ def profile( results = [r for r in results if pattern.search(r.kernel_name)] logger.info(f"After kernel filter '{kernel_filter}': {len(results)} dispatch(es)") + dist_context = detect_distributed_context(run_env) + for profile_result in results: + profile_result.global_rank = dist_context.global_rank + profile_result.local_rank = dist_context.local_rank + profile_result.world_size = dist_context.world_size + profile_result.node_rank = dist_context.node_rank + profile_result.hostname = dist_context.hostname + profile_result.launcher = dist_context.launcher + return results except subprocess.TimeoutExpired: diff --git a/metrix/src/metrix/utils/distributed.py b/metrix/src/metrix/utils/distributed.py new file mode 100644 index 0000000..d799e95 --- /dev/null +++ b/metrix/src/metrix/utils/distributed.py @@ -0,0 +1,129 @@ +""" +Helpers for distributed launch environments. +""" + +from __future__ import annotations + +import os +import shlex +import socket +from dataclasses import dataclass +from typing import Mapping, Sequence + + +@dataclass(frozen=True) +class DistributedContext: + global_rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + node_rank: int = 0 + hostname: str = "" + launcher: str = "single" + + @property + def is_distributed(self) -> bool: + return self.world_size > 1 + + @property + def rank_tag(self) -> str: + return f"rank{self.global_rank:04d}" + + +def _first_int(env: Mapping[str, str], keys: Sequence[str], default: int) -> int: + for key in keys: + value = env.get(key) + if value is None: + continue + try: + return int(value) + except ValueError: + continue + return default + + +def detect_distributed_context(env: Mapping[str, str] | None = None) -> DistributedContext: + env_map = os.environ if env is None else env + global_rank = _first_int( + env_map, + [ + "RANK", + "OMPI_COMM_WORLD_RANK", + "PMI_RANK", + "PMIX_RANK", + "SLURM_PROCID", + "HOROVOD_RANK", + ], + 0, + ) + local_rank = _first_int( + env_map, + [ + "LOCAL_RANK", + "OMPI_COMM_WORLD_LOCAL_RANK", + "MPI_LOCALRANKID", + "SLURM_LOCALID", + "HOROVOD_LOCAL_RANK", + ], + 0, + ) + world_size = _first_int( + env_map, + [ + "WORLD_SIZE", + "OMPI_COMM_WORLD_SIZE", + "PMI_SIZE", + "PMIX_SIZE", + "SLURM_NTASKS", + "HOROVOD_SIZE", + ], + 1, + ) + node_rank = _first_int( + env_map, + ["GROUP_RANK", "NODE_RANK", "OMPI_COMM_WORLD_NODE_RANK", "SLURM_NODEID"], + 0, + ) + hostname = env_map.get("HOSTNAME", "") or socket.gethostname() + + launcher = "single" + if "TORCHELASTIC_RUN_ID" in env_map or "LOCAL_RANK" in env_map: + launcher = "torchrun" + elif "OMPI_COMM_WORLD_RANK" in env_map: + launcher = "mpirun" + elif "SLURM_PROCID" in env_map: + launcher = "srun" + elif "HOROVOD_RANK" in env_map: + launcher = "horovodrun" + + return DistributedContext( + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + node_rank=node_rank, + hostname=hostname, + launcher=launcher, + ) + + +def normalize_command_argv(command: str | Sequence[str]) -> list[str]: + if isinstance(command, (list, tuple)): + argv = [str(arg) for arg in command] + else: + argv = shlex.split(command) + if not argv: + raise ValueError("Command is empty") + return argv + + +def apply_rank_suffix(path: str, context: DistributedContext) -> str: + """Append rank suffix to output paths for distributed runs.""" + if not context.is_distributed: + return path + + from pathlib import Path as _Path + + p = _Path(path) + suffix = p.suffix + if suffix: + return str(p.with_name(f"{p.stem}.{context.rank_tag}{suffix}")) + return str(p.with_name(f"{p.name}.{context.rank_tag}")) diff --git a/metrix/tests/unit/test_distributed.py b/metrix/tests/unit/test_distributed.py new file mode 100644 index 0000000..1afc0e3 --- /dev/null +++ b/metrix/tests/unit/test_distributed.py @@ -0,0 +1,63 @@ +""" +Unit tests for distributed launcher helpers. +""" + +from metrix.utils.distributed import ( + DistributedContext, + apply_rank_suffix, + detect_distributed_context, + normalize_command_argv, +) + + +def test_detect_distributed_context_torchrun_env(): + env = {"RANK": "2", "LOCAL_RANK": "0", "WORLD_SIZE": "4", "TORCHELASTIC_RUN_ID": "run-1"} + ctx = detect_distributed_context(env) + assert ctx.global_rank == 2 + assert ctx.local_rank == 0 + assert ctx.world_size == 4 + assert ctx.launcher == "torchrun" + + +def test_detect_distributed_context_mpi_env(): + env = { + "OMPI_COMM_WORLD_RANK": "5", + "OMPI_COMM_WORLD_LOCAL_RANK": "1", + "OMPI_COMM_WORLD_SIZE": "8", + } + ctx = detect_distributed_context(env) + assert ctx.global_rank == 5 + assert ctx.local_rank == 1 + assert ctx.world_size == 8 + assert ctx.launcher == "mpirun" + + +def test_apply_rank_suffix_distributed_file_path(): + ctx = DistributedContext(global_rank=3, world_size=8) + assert apply_rank_suffix("results.json", ctx) == "results.rank0003.json" + + +def test_apply_rank_suffix_no_extension(): + ctx = DistributedContext(global_rank=0, world_size=4) + assert apply_rank_suffix("results", ctx) == "results.rank0000" + + +def test_apply_rank_suffix_single_process(): + ctx = DistributedContext(global_rank=0, world_size=1) + assert apply_rank_suffix("results.json", ctx) == "results.json" + + +def test_normalize_command_argv_accepts_string_and_sequence(): + assert normalize_command_argv('torchrun --nproc_per_node=2 train.py --arg "two words"') == [ + "torchrun", + "--nproc_per_node=2", + "train.py", + "--arg", + "two words", + ] + assert normalize_command_argv(["mpirun", "-np", "4", "./app"]) == [ + "mpirun", + "-np", + "4", + "./app", + ] diff --git a/metrix/tests/unit/test_rocprof_wrapper.py b/metrix/tests/unit/test_rocprof_wrapper.py index 205c7c4..13c490f 100644 --- a/metrix/tests/unit/test_rocprof_wrapper.py +++ b/metrix/tests/unit/test_rocprof_wrapper.py @@ -290,6 +290,135 @@ def fake_run(cmd, **kwargs): assert len(results) == 1 assert results[0].kernel_name == "my_kernel(float*, int)" + def test_command_string_uses_shlex_parsing(self, wrapper_no_rocm_check): + """Quoted command arguments are preserved via shlex parsing.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command='python -c "print(1 + 2)"', + counters=[], + output_dir=Path(tmpdir), + ) + + assert "--" in captured_cmd + idx = captured_cmd.index("--") + assert captured_cmd[idx + 1 : idx + 4] == ["python", "-c", "print(1 + 2)"] + + def test_profile_sets_distributed_rank_fields(self, wrapper_no_rocm_check): + """ProfileResult receives distributed rank metadata from env.""" + wrapper = wrapper_no_rocm_check + + parsed = [ + ProfileResult( + dispatch_id=1, + kernel_name="my_kernel", + gpu_id=0, + duration_ns=1000, + grid_size=(256, 1, 1), + workgroup_size=(64, 1, 1), + counters={}, + ) + ] + + def fake_run(cmd, **kwargs): + m = MagicMock() + m.returncode = 0 + m.stdout = "" + m.stderr = "" + return m + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=parsed), + tempfile.TemporaryDirectory() as tmpdir, + ): + results = wrapper.profile( + command="true", + counters=[], + output_dir=Path(tmpdir), + env={"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}, + ) + + assert len(results) == 1 + assert results[0].global_rank == 3 + assert results[0].local_rank == 1 + assert results[0].world_size == 8 + + def test_launcher_param_builds_correct_order(self, wrapper_no_rocm_check): + """Explicit launcher param should produce: launcher rocprofv3 ... -- app.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command="train.py --lr 0.01", + counters=[], + output_dir=Path(tmpdir), + launcher="torchrun --nproc_per_node=4", + ) + + assert captured_cmd[0] == "torchrun" + assert captured_cmd[1] == "--nproc_per_node=4" + rocprofv3_idx = captured_cmd.index("rocprofv3") + assert rocprofv3_idx == 2 + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "train.py" + assert captured_cmd[separator_idx + 2] == "--lr" + + def test_no_launcher_uses_plain_rocprofv3(self, wrapper_no_rocm_check): + """Without launcher, command should start with rocprofv3.""" + wrapper = wrapper_no_rocm_check + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "" + mock_result.stderr = "" + return mock_result + + with ( + patch("subprocess.run", side_effect=fake_run), + patch.object(wrapper, "_parse_output", return_value=[]), + tempfile.TemporaryDirectory() as tmpdir, + ): + wrapper.profile( + command="./my_app --size 1024", + counters=[], + output_dir=Path(tmpdir), + ) + + assert captured_cmd[0] == "rocprofv3" + separator_idx = captured_cmd.index("--") + assert captured_cmd[separator_idx + 1] == "./my_app" + def test_parse_missing_optional_fields(self, wrapper): """Handle missing optional fields gracefully""" row = {