diff --git a/atom/config.py b/atom/config.py index e273fb73..55b955dc 100644 --- a/atom/config.py +++ b/atom/config.py @@ -594,6 +594,7 @@ class Config: default_factory=lambda: QuantizationConfig() ) asyncio_mode: bool = False + mark_trace: bool = False load_dummy: bool = False enable_expert_parallel: bool = False master_addr: str = "127.0.0.1" diff --git a/atom/model_engine/arg_utils.py b/atom/model_engine/arg_utils.py index a3c9881e..f1636270 100644 --- a/atom/model_engine/arg_utils.py +++ b/atom/model_engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: enable_dp_attention: bool = False method: Optional[str] = None num_speculative_tokens: int = 1 + mark_trace: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -170,6 +171,11 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help="Apply a delay (of delay factor multiplied by previous" "prompt latency) before scheduling next prompt.", ) + parser.add_argument( + "--mark-trace", + action="store_true", + help="Enable graph_marker nodes for tracing/profile instrumentation.", + ) return parser diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index f85baa0b..f5df16fa 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -63,6 +63,7 @@ def __init__(self, config: Config, input_address: str, output_address: str): self.input_thread.start() self.profile_enbaled = config.torch_profiler_dir is not None + self.mark_trace = getattr(config, "mark_trace", False) init_exit_handler(self) self._init_data_parallel(config) @@ -83,6 +84,9 @@ def __init__(self, config: Config, input_address: str, output_address: str): config.num_kvcache_blocks = num_blocks if not config.enforce_eager: + # Start profiler before cudagraph capture only if mark-trace is enabled. 
+ if self.profile_enbaled and self.mark_trace: + self.runner_mgr.call_func("start_profiler", wait_out=True) cap_cost, bs = self.runner_mgr.call_func( "capture_cudagraph", wait_out=True ) @@ -284,7 +288,7 @@ def process_output_sockets(self, output_address: str): def start_profiler(self): if self.profile_enbaled: - self.runner_mgr.call_func("start_profiler") + self.runner_mgr.call_func("start_profiler", wait_out=True) def stop_profiler(self): if self.profile_enbaled: diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 508c6b4f..5e379246 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -5,6 +5,7 @@ import math import os import time +from contextlib import nullcontext from typing import Any, Optional, Union import numpy as np @@ -19,6 +20,7 @@ graph_capture, ) from aiter.dist.utils import get_distributed_init_method +from torch.profiler import record_function from atom.config import Config, KVCacheTensor, set_current_atom_config from atom.model_engine.scheduler import ScheduledBatch, ScheduledBatchOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType @@ -464,6 +466,10 @@ class ModelRunner: def __init__(self, rank: int, config: Config): self.config = config + self.mark_trace = getattr(config, "mark_trace", False) + from atom.utils.graph_marker import set_graph_marker_enabled + + set_graph_marker_enabled(self.mark_trace) set_current_atom_config(config) hf_config = config.hf_config self.block_size = config.kv_cache_block_size @@ -681,6 +687,7 @@ def start_profiler(self): ), ) self.profiler.__enter__() + return True def stop_profiler(self): """Stop profiling for this rank""" @@ -1316,19 +1323,31 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor is_prefill = context.is_prefill positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: - hidden_states = self.model(input_ids, positions) - logits = 
self.model.compute_logits(hidden_states) - else: - graph_bs = context.graph_bs - max_q_len = forward_context.attn_metadata.max_seqlen_q - graph_key = (graph_bs, max_q_len) - self.graphs[graph_key].replay() - num_tokens = context.batch_size * max_q_len - hidden_states = self.forward_vars["outputs"][:num_tokens] - if self.logits_in_graph: - logits = self.graph_logits[graph_key][:num_tokens] - else: + with ( + record_function( + f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}" + ) + if self.mark_trace + else nullcontext() + ): + hidden_states = self.model(input_ids, positions) logits = self.model.compute_logits(hidden_states) + else: + with ( + record_function(f"decode_step_bs_{bs}") + if self.mark_trace + else nullcontext() + ): + graph_bs = context.graph_bs + max_q_len = forward_context.attn_metadata.max_seqlen_q + graph_key = (graph_bs, max_q_len) + self.graphs[graph_key].replay() + num_tokens = context.batch_size * max_q_len + hidden_states = self.forward_vars["outputs"][:num_tokens] + if self.logits_in_graph: + logits = self.graph_logits[graph_key][:num_tokens] + else: + logits = self.model.compute_logits(hidden_states) return logits, hidden_states @@ -1537,12 +1556,19 @@ def capture_cudagraph(self): # Capture: include compute_logits only when TP=1 since # ParallelLMHead uses NCCL all_gather which is not # graph-capturable on HIP when TP > 1. 
- with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): - outputs[:num_tokens] = self.model( - input_ids[:num_tokens], positions[:num_tokens] - ) - if self.logits_in_graph: - graph_logits = self.model.compute_logits(outputs[:num_tokens]) + with ( + record_function(f"capture_graph_bs_{bs}") + if self.mark_trace + else nullcontext() + ): + with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): + outputs[:num_tokens] = self.model( + input_ids[:num_tokens], positions[:num_tokens] + ) + if self.logits_in_graph: + graph_logits = self.model.compute_logits( + outputs[:num_tokens] + ) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef..ed611eda 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -26,6 +26,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from atom.utils.decorators import mark_trace, record_function from atom.model_ops.utils import shuffle_weights from atom.utils import envs @@ -168,10 +169,12 @@ def gemm_a8w8_blockscale_preshuffle_fake( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, + prefix: str = "", ) -> torch.Tensor: return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device) +@record_function @torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[]) def gemm_a8w8_blockscale_preshuffle_impl( x: torch.Tensor, @@ -179,6 +182,7 @@ def gemm_a8w8_blockscale_preshuffle_impl( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, + prefix: str = "", ) -> torch.Tensor: if gemm_a8w8_blockscale_bpreshuffle_triton is not None: weight_shuffled = weight.reshape(weight.shape[0] // 16, weight.shape[1] * 16) @@ -190,6 +194,7 @@ def gemm_a8w8_blockscale_preshuffle_impl( return y +@mark_trace class 
LinearBase(nn.Module): def __init__( self, @@ -200,6 +205,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = False, source_quant_dtype: torch.dtype | None = None, + prefix: str = "", ): if quant_config is None: quant_config = QuantizationConfig() @@ -207,6 +213,7 @@ def __init__( quant_type = quant_config["quant_type"] params_dtype = quant_config["quant_dtype"] super().__init__() + self.prefix = prefix self.reduce_results = reduce_results self.input_size = input_size self.output_size = ( @@ -421,7 +428,12 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: y = gemm_a8w8_blockscale_preshuffle_impl( - x, self.weight, x_scale, self.weight_scale, dtype=otype + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + prefix=self.prefix, ) if self.bias is not None: y += self.bias @@ -460,6 +472,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -485,6 +498,7 @@ def __init__( bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -516,6 +530,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -650,6 +665,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader( @@ -711,6 +727,7 @@ def __init__( quant_config=quant_config, reduce_results=reduce_results, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -746,6 +763,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", 
""), ) def weight_loader( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f0342dce..51006bf4 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1276,6 +1276,7 @@ def __init__( bias=False, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=f"{prefix}.q_a_proj", ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( diff --git a/atom/utils/backends.py b/atom/utils/backends.py index 57eec887..3ef45496 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -230,6 +230,34 @@ def compile( handle, ) + # When mark-trace is enabled, post-processing may rewrite generated + # artifact sources. Force reloading from cache artifact so the current + # process uses the rewritten artifact immediately. + try: + from atom.utils.graph_marker import is_graph_marker_enabled + + force_reload = is_graph_marker_enabled() + except Exception: + force_reload = False + if force_reload and handle is not None and not self.disable_cache: + try: + reloaded_graph = self.load( + graph, example_inputs, graph_index, runtime_shape + ) + if reloaded_graph is not None: + compiled_graph = reloaded_graph + logger.info( + "Force reloaded compiled graph from cache artifact " + "(graph_index=%s, runtime_shape=%s).", + graph_index, + runtime_shape, + ) + except Exception: + logger.exception( + "Failed to force reload compiled graph from cache artifact; " + "falling back to in-memory compiled callable." 
+ ) + # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() diff --git a/atom/utils/compiler_inferface.py b/atom/utils/compiler_inferface.py index 71d9fd23..70cf9572 100644 --- a/atom/utils/compiler_inferface.py +++ b/atom/utils/compiler_inferface.py @@ -168,6 +168,15 @@ def set_inductor_config(config, runtime_shape): config["max_autotune"] = True config["coordinate_descent_tuning"] = True + try: + from atom.utils.graph_marker import is_graph_marker_enabled + + if is_graph_marker_enabled(): + config["size_asserts"] = False + config["compile_threads"] = 1 + except Exception: + pass + class InductorAdaptor(CompilerInterface): """ @@ -395,6 +404,21 @@ def _get_shape_env() -> AlwaysHitShapeEnv: assert ( file_path is not None ), "failed to get the file path of the compiled graph" + # Best-effort post-process the generated wrapper file too (PyTorch <2.8 path). + try: + # Only run post-processing when mark-trace is enabled (to avoid any + # overhead / file churn in default runs). + from atom.utils.graph_marker import is_graph_marker_enabled + + if is_graph_marker_enabled(): + # Local import to avoid extra package-level side effects. + from .graph_marker_instrumentation import ( + instrument_record_functions_in_file, + ) + + instrument_record_functions_in_file(file_path, strip_markers=False) + except Exception: + pass return compiled_graph, (hash_str, file_path) def load( @@ -560,6 +584,23 @@ def compile( # if not envs.VLLM_DISABLE_COMPILE_CACHE: compiled_graph.save(path=path, format="unpacked") compilation_counter.num_compiled_artifacts_saved += 1 + # Post-process generated wrapper Python files: wrap regions between + # _start / _end graph markers with record_function(""). + try: + # Only run post-processing when mark-trace is enabled (to avoid any + # overhead / file churn in default runs). 
+ from atom.utils.graph_marker import is_graph_marker_enabled + + if is_graph_marker_enabled(): + # Local import to avoid extra package-level side effects. + from .graph_marker_instrumentation import ( + instrument_record_functions_in_dir, + ) + + instrument_record_functions_in_dir(path, strip_markers=False) + except Exception: + # Best-effort: never fail compilation due to instrumentation. + pass return compiled_graph, (key, path) def load( diff --git a/atom/utils/decorators.py b/atom/utils/decorators.py index 71344d2b..8fa91621 100644 --- a/atom/utils/decorators.py +++ b/atom/utils/decorators.py @@ -5,6 +5,7 @@ import inspect import os import sys +from functools import wraps from types import CodeType from abc import abstractmethod from contextlib import contextmanager @@ -17,6 +18,8 @@ from atom.config import CompilationConfig, Config, CompilationLevel +from atom.utils.graph_marker import graph_marker + # from atom.utils import start_monitoring_torch_compile _T = TypeVar("_T", bound=type[nn.Module]) @@ -25,6 +28,143 @@ torch_compile_start_time: float = 0.0 +def record_function(prefix: Union[str, Callable, None] = None): + """ + Decorator that wraps a function with torch.profiler.record_function. + + Usage: + - @record_function + - @record_function("my_prefix") + """ + + def _decorate(func: Callable): + # Try to recover the original callable signature even when func is wrapped + # by other decorators. + base_func = inspect.unwrap(func) + try: + base_sig = inspect.signature(base_func) + except (TypeError, ValueError): + base_sig = None + + @wraps(func) + def _wrapped(*args, **kwargs): + # Keep this decorator no-op unless mark-trace is enabled. 
+ from atom.utils.graph_marker import is_graph_marker_enabled + + if not is_graph_marker_enabled(): + return func(*args, **kwargs) + + # Priority: + # 1) explicit decorator prefix: @record_function("xxx") + # 2) runtime function argument named "prefix" when non-empty + # 3) function name fallback + if prefix is not None: + span_name = str(prefix) + else: + span_name = func.__name__ + runtime_prefix = kwargs.get("prefix") + if not (isinstance(runtime_prefix, str) and runtime_prefix): + if base_sig is not None: + try: + bound = base_sig.bind_partial(*args, **kwargs) + runtime_prefix = bound.arguments.get("prefix") + except Exception: + runtime_prefix = None + if isinstance(runtime_prefix, str) and runtime_prefix: + span_name = runtime_prefix + + with torch.profiler.record_function(f"{span_name}"): + return func(*args, **kwargs) + + return _wrapped + + # Support @record_function without parentheses. + if callable(prefix): + func = prefix + prefix = None + return _decorate(func) + return _decorate + + +def _graph_marker_first_tensor(obj, name: str): + if torch.is_tensor(obj): + return graph_marker(obj, name=name), True + if isinstance(obj, tuple): + out = [] + marked_any = False + for v in obj: + if marked_any: + out.append(v) + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out.append(vv) + out_t = tuple(out) + # namedtuple support + if hasattr(obj, "_fields"): + return obj.__class__(*out_t), marked_any + return out_t, marked_any + if isinstance(obj, list): + out = [] + marked_any = False + for v in obj: + if marked_any: + out.append(v) + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out.append(vv) + return out, marked_any + if isinstance(obj, dict): + out = {} + marked_any = False + for k, v in obj.items(): + if marked_any: + out[k] = v + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out[k] = vv + return out, marked_any + return obj, False + + +def mark_trace(cls): + forward = getattr(cls, "forward", None) + 
if forward is None: + return cls + if getattr(forward, "__mark_trace_wrapped__", False): + return cls + + from atom.utils.graph_marker import is_graph_marker_enabled + + def wrapped_forward(self, *args, **kwargs): + # When mark-trace is disabled, bypass all wrapping logic entirely + if not is_graph_marker_enabled(): + return forward(self, *args, **kwargs) + + prefix = getattr(self, "prefix", cls.__name__) + # Mark only the first tensor across args/kwargs, keeping names stable. + args_l = list(args) + marked = False + for i, a in enumerate(args_l): + if marked: + break + aa, marked = _graph_marker_first_tensor(a, f"{prefix}_start") + args_l[i] = aa + if not marked: + for k, v in list(kwargs.items()): + if marked: + break + vv, marked = _graph_marker_first_tensor(v, f"{prefix}_start") + kwargs[k] = vv + args = tuple(args_l) + y = forward(self, *args, **kwargs) + yy, _ = _graph_marker_first_tensor(y, f"{prefix}_end") + return yy + + wrapped_forward.__mark_trace_wrapped__ = True + cls.forward = wrapped_forward + return cls + + # We remove it from utils/__init__.py to avoid circular import def start_monitoring_torch_compile(vllm_config: Config): global torch_compile_start_time diff --git a/atom/utils/graph_marker.py b/atom/utils/graph_marker.py new file mode 100644 index 00000000..d9caf0b9 --- /dev/null +++ b/atom/utils/graph_marker.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. +# +# A tiny, graph-friendly marker op for debugging/graph inspection. +# It is an identity at runtime, but it shows up in FX/graph dumps. 
+ +# from __future__ import annotations + +import torch + +from aiter.jit.utils.torch_guard import torch_compile_guard + +_GRAPH_MARKER_ENABLED: bool = False + + +def set_graph_marker_enabled(enabled: bool) -> None: + """Enable/disable graph markers globally (per-process).""" + global _GRAPH_MARKER_ENABLED + _GRAPH_MARKER_ENABLED = bool(enabled) + + +def is_graph_marker_enabled() -> bool: + return _GRAPH_MARKER_ENABLED + + +def _graph_marker_impl(x: torch.Tensor) -> torch.Tensor: + # Runtime behavior: identity. + # Keep this side-effect free to avoid graph breaks. + return x + + +def _graph_marker_fake(x: torch.Tensor, name: str) -> torch.Tensor: + # FakeTensor / meta behavior: identity with preserved shape/stride/dtype. + return x + + +@torch_compile_guard(gen_fake=_graph_marker_fake) +def graph_marker(x: torch.Tensor, name: str) -> torch.Tensor: + """Insert a no-op marker node into the compiled/traced graph. + + The marker `name` is embedded as a constant in the graph dump so you can + grep it in `computation_graph.py` / generated wrapper files. + """ + # When disabled, return early so the marker does not even appear in the + # traced/compiled graph. + if not _GRAPH_MARKER_ENABLED: + return x + return _graph_marker_impl(x) diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py new file mode 100644 index 00000000..ba97f8f4 --- /dev/null +++ b/atom/utils/graph_marker_instrumentation.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. 
+ +from __future__ import annotations + +import ast +import os +import re +from dataclasses import dataclass +from typing import Iterable, Optional + +_SUBGRAPH_ID_RE = re.compile(r"artifact_shape_[^/]+_subgraph_(\d+)") + +_ASSIGNMENT_RE = re.compile( + r"""^(?P\s*) + (?P[A-Za-z_]\w*)\s*=\s* + (?P.+?)\s*$""", + re.VERBOSE, +) + + +@dataclass(frozen=True) +class _Marker: + idx: int + indent: str + name: str + + +@dataclass(frozen=True) +class _ParsedMarkerAssignment: + indent: str + lhs: str + arg: str + name: str + + +_GRAPH_MARKER_PREFIX = "torch.ops.aiter.graph_marker.default(" + + +def _find_matching_paren(s: str, open_idx: int) -> Optional[int]: + depth = 0 + in_str: Optional[str] = None + escaped = False + for i in range(open_idx, len(s)): + ch = s[i] + if in_str is not None: + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == in_str: + in_str = None + continue + if ch in ("'", '"'): + in_str = ch + continue + if ch == "(": + depth += 1 + continue + if ch == ")": + depth -= 1 + if depth == 0: + return i + return None + + +def _split_top_level_args(s: str) -> list[str]: + out: list[str] = [] + depth_paren = 0 + depth_bracket = 0 + depth_brace = 0 + in_str: Optional[str] = None + escaped = False + start = 0 + for i, ch in enumerate(s): + if in_str is not None: + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == in_str: + in_str = None + continue + if ch in ("'", '"'): + in_str = ch + continue + if ch == "(": + depth_paren += 1 + continue + if ch == ")": + depth_paren -= 1 + continue + if ch == "[": + depth_bracket += 1 + continue + if ch == "]": + depth_bracket -= 1 + continue + if ch == "{": + depth_brace += 1 + continue + if ch == "}": + depth_brace -= 1 + continue + if ch == "," and depth_paren == 0 and depth_bracket == 0 and depth_brace == 0: + out.append(s[start:i].strip()) + start = i + 1 + out.append(s[start:].strip()) + return out + + +def 
_parse_graph_marker_call_expr(expr: str) -> Optional[tuple[str, str]]: + idx = expr.find(_GRAPH_MARKER_PREFIX) + if idx < 0: + return None + open_idx = idx + len(_GRAPH_MARKER_PREFIX) - 1 + close_idx = _find_matching_paren(expr, open_idx) + if close_idx is None: + return None + call_src = expr[idx : close_idx + 1] + if call_src.count("(") != call_src.count(")"): + return None + # Ensure the matched call is the whole RHS expression (allowing whitespace). + if expr[:idx].strip() or expr[close_idx + 1 :].strip(): + return None + inner = call_src[len(_GRAPH_MARKER_PREFIX) : -1] + args = _split_top_level_args(inner) + if len(args) < 2: + return None + arg_expr = args[0] + try: + name = ast.literal_eval(args[1]) + except Exception: + return None + if not isinstance(name, str): + return None + return arg_expr, name + + +def _parse_graph_marker_assignment_line(line: str) -> Optional[_ParsedMarkerAssignment]: + m = _ASSIGNMENT_RE.match(line.rstrip("\n")) + if not m: + return None + rhs = m.group("rhs") + parsed = _parse_graph_marker_call_expr(rhs) + if parsed is None: + return None + arg, name = parsed + return _ParsedMarkerAssignment( + indent=m.group("indent"), + lhs=m.group("lhs"), + arg=arg, + name=name, + ) + + +def _extract_graph_marker_name(line: str) -> Optional[str]: + idx = line.find(_GRAPH_MARKER_PREFIX) + if idx < 0: + return None + open_idx = idx + len(_GRAPH_MARKER_PREFIX) - 1 + close_idx = _find_matching_paren(line, open_idx) + if close_idx is None: + return None + call_src = line[idx : close_idx + 1] + inner = call_src[len(_GRAPH_MARKER_PREFIX) : -1] + args = _split_top_level_args(inner) + if len(args) < 2: + return None + try: + name = ast.literal_eval(args[1]) + except Exception: + return None + return name if isinstance(name, str) else None + + +def _iter_py_files(root: str) -> Iterable[str]: + for dirpath, _, filenames in os.walk(root): + for fn in filenames: + if fn.endswith(".py"): + yield os.path.join(dirpath, fn) + + +def 
_ensure_record_function_import(lines: list[str]) -> None: + # If already imported or referenced via qualified name, do nothing. + if any( + ("record_function" in line and ("import" in line or "from torch" in line)) + for line in lines + ): + return + + # Insert `from torch.profiler import record_function` after the first + # real `import torch` line outside the initial docstring. + in_doc = False + for i, line in enumerate(lines): + if i == 0 and line.lstrip().startswith('"""'): + in_doc = True + if in_doc and line.rstrip().endswith('"""') and i != 0: + in_doc = False + continue + if in_doc: + continue + if re.match(r"^\s*import\s+torch\b", line): + lines.insert(i + 1, "from torch.profiler import record_function\n") + return + + # Fallback: put it near the top (after any shebang/encoding if present). + insert_at = 0 + if lines and lines[0].startswith("#!"): + insert_at = 1 + lines.insert(insert_at, "from torch.profiler import record_function\n") + + +def _collect_markers(lines: list[str]) -> list[_Marker]: + out: list[_Marker] = [] + for i, line in enumerate(lines): + name = _extract_graph_marker_name(line) + if name is None: + continue + indent = re.match(r"^(\s*)", line).group(1) # type: ignore[union-attr] + out.append(_Marker(idx=i, indent=indent, name=name)) + return out + + +def _prefix_and_kind(name: str) -> Optional[tuple[str, str]]: + if name.endswith("_start"): + return name[: -len("_start")], "start" + if name.endswith("_end"): + return name[: -len("_end")], "end" + return None + + +def _already_wrapped( + lines: list[str], indent: str, prefix: str, start_idx: int, end_idx: int +) -> bool: + needle = f'{indent}with record_function("{prefix}"):\n' + for i in range(start_idx, min(end_idx, len(lines))): + if lines[i] == needle: + return True + return False + + +def _wrap_region_with_record_function( + lines: list[str], + *, + start_marker_idx: int, + end_marker_idx: int, + prefix: str, + indent: str, + layer_id: Optional[int] = None, +) -> None: + """ + 
Transform: + ... graph_marker(..., "_start") + LINE_A + LINE_B + ... graph_marker(..., "_end") + Into: + ... graph_marker(..., "_start") + with record_function(""): + LINE_A + LINE_B + ... graph_marker(..., "_end") + """ + if end_marker_idx <= start_marker_idx + 1: + return + + # Only add layer prefix if prefix doesn't already contain layer info + if layer_id is not None and layer_id >= 0 and "model.layers" not in prefix: + tag = f"layer_{layer_id}_{prefix}" + else: + tag = prefix + with_line = f'{indent}with record_function("{tag}"):\n' + insert_at = start_marker_idx + 1 + + # If we already inserted a record_function line in a previous run, upgrade it + # in-place (e.g. "mlp" -> "layer_0_mlp") and exit without touching indentation. + if insert_at < len(lines) and lines[insert_at].startswith( + f"{indent}with record_function(" + ): + if lines[insert_at] != with_line: + lines[insert_at] = with_line + return + + # Otherwise, insert a new record_function wrapper and indent the region. + lines.insert(insert_at, with_line) + + # Re-indent the region between start_marker and end_marker (exclusive of end marker). + region_start = insert_at + 1 + region_end = end_marker_idx + 1 # end marker shifted down by 1 due to insertion + indent_prefix = indent + extra = " " * 4 + for i in range(region_start, region_end): + line = lines[i] + if line.strip() == "": + continue + if line.startswith(indent_prefix): + lines[i] = indent_prefix + extra + line[len(indent_prefix) :] + + +def _layer_id_from_wrapper_path(path: str) -> Optional[int]: + """ + Derive layer id from wrapper file path: + .../artifact_shape__subgraph_/... -> layer_id = N - 1 + Returns None if the pattern isn't found. + """ + m = _SUBGRAPH_ID_RE.search(path) + if not m: + return None + try: + subgraph_id = int(m.group(1)) + except ValueError: + return None + return subgraph_id - 1 + + +def _strip_runtime_graph_markers(lines: list[str]) -> bool: + """ + Remove runtime overhead of graph markers in generated wrapper code. 
+ + - Replace `x = torch.ops.aiter.graph_marker.default(y, '...')` with `x = y` + - Drop assert_size_stride / assert_alignment lines that specifically refer to + `torch.ops.aiter.graph_marker.default` (they become redundant). + """ + out: list[str] = [] + changed = False + for line in lines: + parsed = _parse_graph_marker_assignment_line(line) + if parsed is not None: + out.append(f"{parsed.indent}{parsed.lhs} = {parsed.arg}\n") + changed = True + continue + + if ( + "assert_size_stride" in line or "assert_alignment" in line + ) and "torch.ops.aiter.graph_marker.default" in line: + changed = True + continue + + out.append(line) + + if changed: + lines[:] = out + return changed + + +def instrument_record_functions_in_file( + path: str, *, strip_markers: bool = True +) -> bool: + """ + Returns True if the file was modified. + """ + try: + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + except OSError: + return False + + markers = _collect_markers(lines) + if not markers: + return False + + # Build intervals by matching _start / _end. + stack: dict[str, _Marker] = {} + intervals: list[tuple[_Marker, _Marker, str]] = [] + for mk in markers: + pk = _prefix_and_kind(mk.name) + if pk is None: + continue + prefix, kind = pk + if kind == "start": + stack[prefix] = mk + else: + start_mk = stack.pop(prefix, None) + if start_mk is None: + continue + # Use start indent as the wrapping indent (generated code is consistent). + intervals.append((start_mk, mk, prefix)) + + # Even if we can't form any intervals, we still might want to strip marker + # calls from already-instrumented wrappers (best-effort). + has_intervals = bool(intervals) + + layer_id = _layer_id_from_wrapper_path(path) + + # Apply from bottom to top so indices stay valid. 
+ wrapped_or_upgraded = False + if has_intervals: + for start_mk, end_mk, prefix in sorted( + intervals, key=lambda t: t[0].idx, reverse=True + ): + _wrap_region_with_record_function( + lines, + start_marker_idx=start_mk.idx, + end_marker_idx=end_mk.idx, + prefix=prefix, + indent=start_mk.indent, + layer_id=layer_id, + ) + wrapped_or_upgraded = True + + stripped = False + if strip_markers: + # Only strip marker calls if we either: + # - wrapped/upgraded this run, or + # - the file already contains record_function blocks (previous run) + already_has_record = any("with record_function(" in line for line in lines) + if wrapped_or_upgraded or already_has_record: + stripped = _strip_runtime_graph_markers(lines) + + changed = wrapped_or_upgraded or stripped + if changed: + if wrapped_or_upgraded: + _ensure_record_function_import(lines) + with open(path, "w", encoding="utf-8") as f: + f.writelines(lines) + return changed + + +def instrument_record_functions_in_dir(root: str, *, strip_markers: bool = True) -> int: + """ + Walk `root` and instrument all generated `.py` wrapper files. + Returns the number of modified files. 
+ """ + changed = 0 + for fp in _iter_py_files(root): + if instrument_record_functions_in_file(fp, strip_markers=strip_markers): + changed += 1 + return changed diff --git a/pyproject.toml b/pyproject.toml index 7e6c623b..9b3b1a80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" description = "a lightweight vLLM implementation built from scratch" requires-python = ">=3.10,<3.13" dynamic = ["version"] -dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp", "datasets"] +dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp", "datasets", "openpyxl"] [project.urls] Homepage = "https://github.com/ROCm/ATOM" diff --git a/tools/parse_trace.py b/tools/parse_trace.py new file mode 100644 index 00000000..ba0ff04b --- /dev/null +++ b/tools/parse_trace.py @@ -0,0 +1,938 @@ +#!/usr/bin/env python3 +""" +Parse PyTorch profiler trace JSON to extract kernel information. 
+ +Usage: + python parse_trace.py [--layer N] +""" + +import json +import gzip +import sys +import bisect +import argparse +import re +from typing import List, Dict, Any, Tuple, Optional +from openpyxl import Workbook + +# Modules to filter out (no corresponding GPU kernel in decode) +FILTER_OUT = ["fill_"] + +# Sampling-related modules and low-level ops to filter out in prefill +FILTER_OUT_PREFILL = ["aten::", "aiter::gemm_a16w16", "aiter::mixed_sample"] + + +# ============================================================================= +# Utility Functions +# ============================================================================= + + +def load_trace(filepath: str) -> Dict[str, Any]: + """Load trace JSON file (supports .gz).""" + opener = gzip.open if filepath.endswith(".gz") else open + with opener(filepath, "rt", encoding="utf-8") as f: + return json.load(f) + + +def is_within( + child_ts: float, child_dur: float, parent_ts: float, parent_dur: float +) -> bool: + """Check if child event is within parent's time range.""" + return child_ts >= parent_ts and (child_ts + child_dur) <= (parent_ts + parent_dur) + + +def is_kernel_launch(name: str) -> bool: + """Check if name is a kernel launch (contains 'launch' and 'kernel').""" + n = name.lower() + return "launch" in n and "kernel" in n + + +def should_filter(name: str) -> bool: + """Check if module should be filtered out.""" + return any(f in name for f in FILTER_OUT) + + +def should_filter_prefill(name: str) -> bool: + """Check if module should be filtered out in prefill (sampling ops).""" + return any(f in name for f in FILTER_OUT_PREFILL) + + +def write_breakdown_xlsx( + output_xlsx: str, + rows: List[List[Any]], + sheet_name: str, + avg_rows: Optional[List[List[Any]]] = None, +) -> None: + """ + Write XLSX breakdown with columns: + cpu_module, gpu_kernel, duration_us, sum per module, + avg duration_us, avg sum per module. + + The 1st/4th columns are merged for contiguous identical modules. 
+ AVG columns are appended to the right in the same table. + """ + wb = Workbook() + ws = wb.active + ws.title = sheet_name + ws.append( + [ + "cpu_module", + "gpu_kernel", + "duration_us", + "sum per module", + "avg duration_us", + "avg sum per module", + ] + ) + + def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float]]: + groups: List[Tuple[int, int, str, float]] = [] + i = 0 + while i < len(block_rows): + mod = block_rows[i][0] + j = i + 1 + total = float(block_rows[i][2]) + while j < len(block_rows) and block_rows[j][0] == mod: + total += float(block_rows[j][2]) + j += 1 + groups.append((i, j - 1, mod, total)) + i = j + return groups + + main_groups = build_groups(rows) if rows else [] + renamed_group_mods = [g[2] for g in main_groups] + seen_rmsnorm = 0 + for gi, mod in enumerate(renamed_group_mods): + if isinstance(mod, str) and "rmsnorm" in mod.lower(): + if seen_rmsnorm == 0: + renamed_group_mods[gi] = "input_layernorm" + elif seen_rmsnorm == 1: + renamed_group_mods[gi] = "post_attn_layernorm" + seen_rmsnorm += 1 + + avg_sum_by_row: Dict[int, float] = {} + if avg_rows: + avg_groups = build_groups(avg_rows) + for start, end, _, total in avg_groups: + for i in range(start, end + 1): + avg_sum_by_row[i] = total + + data_start_row = ws.max_row + 1 + for gi, (start, end, _, total) in enumerate(main_groups): + renamed_mod = renamed_group_mods[gi] + for idx in range(start, end + 1): + _, kernel, dur = rows[idx] + avg_dur = ( + float(avg_rows[idx][2]) if avg_rows and idx < len(avg_rows) else "" + ) + avg_sum = avg_sum_by_row.get(idx, "") + ws.append([renamed_mod, kernel, dur, total, avg_dur, avg_sum]) + + for start, end, _, _ in main_groups: + if end > start: + r1 = data_start_row + start + r2 = data_start_row + end + ws.merge_cells(start_row=r1, start_column=1, end_row=r2, end_column=1) + ws.merge_cells(start_row=r1, start_column=4, end_row=r2, end_column=4) + if avg_rows: + ws.merge_cells(start_row=r1, start_column=6, end_row=r2, 
end_column=6) + + total_duration = sum(float(r[2]) for r in rows) if rows else 0.0 + total_avg_duration = sum(float(r[2]) for r in avg_rows) if avg_rows else "" + ws.append(["TOTAL", "", total_duration, "", total_avg_duration, ""]) + + wb.save(output_xlsx) + + +def _normalize_module_for_avg(name: str) -> str: + if not isinstance(name, str): + return str(name) + return re.sub(r"model\.layers\.\d+\.", "model.layers.*.", name) + + +def build_avg_rows_from_layers( + layer_rows_list: List[List[List[Any]]], + layer_start_idx: int, + section_name: str, +) -> Optional[List[List[Any]]]: + """ + Build AVG rows across layers using layer-3 rows as template. + Returns None if any layer cannot be aligned by (module, kernel) sequence. + """ + if not layer_rows_list: + return [] + + base = layer_rows_list[0] + base_sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in base] + + for rel_idx, rows in enumerate(layer_rows_list[1:], start=1): + sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in rows] + if sig != base_sig: + bad_layer = layer_start_idx + rel_idx + print( + f"{section_name} avg skipped: layer {bad_layer} does not match layer {layer_start_idx} layout." + ) + return None + + n = len(layer_rows_list) + avg_rows: List[List[Any]] = [] + for i, (mod, kernel) in enumerate(base_sig): + # Keep original module display style from layer_start_idx rows. 
+        display_mod = base[i][0]
+        avg_dur = (
+            sum(float(layer_rows_list[layer_idx][i][2]) for layer_idx in range(n)) / n
+        )
+        avg_rows.append([display_mod, kernel, avg_dur])
+    return avg_rows
+
+
+# =============================================================================
+# Optimized Event Index for fast time-range queries
+# =============================================================================
+
+
+class EventIndex:
+    """Pre-indexed events for fast time-range queries."""
+
+    def __init__(self, events: List[Dict]):
+        # Filter duration events only
+        self.duration_events = [e for e in events if e.get("ph") == "X"]
+        self.duration_events.sort(key=lambda x: x["ts"])
+        self.ts_list = [e["ts"] for e in self.duration_events]
+
+        # Pre-compute kernel launch flags and prefix sum. NOTE(review): _kernel_prefix_sum is never consulted by any query in this file — dead precompute, candidate for removal.
+        self._is_kernel_launch = [
+            is_kernel_launch(e.get("name", "")) for e in self.duration_events
+        ]
+        self._kernel_prefix_sum = [0]
+        for is_kl in self._is_kernel_launch:
+            self._kernel_prefix_sum.append(
+                self._kernel_prefix_sum[-1] + (1 if is_kl else 0)
+            )
+
+    def events_in_range(self, start_ts: float, end_ts: float) -> List[Dict]:
+        """Get duration events fully contained within [start_ts, end_ts]."""
+        left = bisect.bisect_left(self.ts_list, start_ts)
+        right = bisect.bisect_right(self.ts_list, end_ts)
+        return [
+            e
+            for e in self.duration_events[left:right]
+            if e["ts"] + e.get("dur", 0) <= end_ts
+        ]
+
+    def count_kernel_launches_in_range(self, start_ts: float, end_ts: float) -> int:
+        """Count kernel launches fully inside [start_ts, end_ts] (linear scan over the bisected candidate window)."""
+        left = bisect.bisect_left(self.ts_list, start_ts)
+        right = bisect.bisect_right(self.ts_list, end_ts)
+        count = 0
+        for i in range(left, right):
+            e = self.duration_events[i]
+            if e["ts"] + e.get("dur", 0) <= end_ts and self._is_kernel_launch[i]:
+                count += 1
+        return count
+
+    def get_direct_children(self, parent: Dict) -> List[Dict]:
+        """Get direct children of parent event (optimized)."""
+        p_ts = parent["ts"]
+        p_end = p_ts + 
parent.get("dur", 0) + + # Get candidates in parent's time range + candidates = [e for e in self.events_in_range(p_ts, p_end) if e is not parent] + + if not candidates: + return [] + + # Filter to direct children only (not nested in other candidates) + # Sort by duration descending - larger events are potential parents + candidates_sorted = sorted(candidates, key=lambda x: -x.get("dur", 0)) + + direct = [] + for i, c in enumerate(candidates_sorted): + c_ts, c_dur = c["ts"], c.get("dur", 0) + c_end = c_ts + c_dur + # Check if c is nested inside any larger candidate + is_nested = False + for j in range(i): # Only check larger (earlier in sorted list) + o = candidates_sorted[j] + o_ts = o["ts"] + o_end = o_ts + o.get("dur", 0) + if c_ts >= o_ts and c_end <= o_end: + is_nested = True + break + if not is_nested: + direct.append(c) + + return sorted(direct, key=lambda x: x["ts"]) + + def count_kernel_launches(self, event: Dict) -> int: + """Count kernel launches within event's time range.""" + e_ts = event["ts"] + e_end = e_ts + event.get("dur", 0) + return self.count_kernel_launches_in_range(e_ts, e_end) + + def has_kernel_launch(self, event: Dict) -> bool: + """Check if event contains any kernel launch.""" + return self.count_kernel_launches(event) > 0 + + +# ============================================================================= +# Legacy functions (for prefill compatibility) +# ============================================================================= + + +def find_events(events: List[Dict], name: str, prefix: bool = False) -> List[Dict]: + """Find all duration events (ph='X') with given name, sorted by time.""" + if prefix: + result = [ + e + for e in events + if e.get("name", "").startswith(name) and e.get("ph") == "X" + ] + else: + result = [e for e in events if e.get("name") == name and e.get("ph") == "X"] + return sorted(result, key=lambda x: x["ts"]) + + +def get_gpu_kernels(events: List[Dict], start_ts: float) -> List[Dict]: + """Get GPU kernels 
(cat='kernel') starting from given timestamp.""" + result = [e for e in events if e.get("cat") == "kernel" and e["ts"] >= start_ts] + return sorted(result, key=lambda x: x["ts"]) + + +def get_direct_children(parent: Dict, events: List[Dict]) -> List[Dict]: + """Get direct children of parent event (excluding nested children).""" + p_ts, p_dur = parent["ts"], parent.get("dur", 0) + + candidates = [ + e + for e in events + if e.get("ph") == "X" + and e is not parent + and is_within(e.get("ts", 0), e.get("dur", 0), p_ts, p_dur) + ] + + direct = [] + for c in candidates: + c_ts, c_dur = c["ts"], c.get("dur", 0) + is_direct = not any( + is_within(c_ts, c_dur, o["ts"], o.get("dur", 0)) + for o in candidates + if o is not c + ) + if is_direct: + direct.append(c) + + return sorted(direct, key=lambda x: x["ts"]) + + +def count_kernel_launches(event: Dict, events: List[Dict]) -> int: + """Count kernel launches within event's subtree.""" + e_ts, e_dur = event["ts"], event.get("dur", 0) + return sum( + 1 + for e in events + if e.get("ph") == "X" + and is_kernel_launch(e.get("name", "")) + and is_within(e.get("ts", 0), e.get("dur", 0), e_ts, e_dur) + ) + + +def has_kernel_launch(event: Dict, events: List[Dict]) -> bool: + """Check if event's subtree contains any kernel launch.""" + return count_kernel_launches(event, events) > 0 + + +# ============================================================================= +# Parse Functions +# ============================================================================= + + +def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: + """ + Parse prefill phase: find the actual prefill event on CPU trace (user_annotation). + + Warmup rule: + - If only one prefill exists, it is the actual prefill (no warmup). + - If >=2 prefills exist: + - If there is a decode_step_bs* event between prefill[0] and prefill[1], prefill[0] + is treated as warmup and prefill[1] is the actual prefill. 
+ - Otherwise, prefill[0] is the actual prefill. + """ + # CPU side prefill/decode annotations. + # Accept both legacy "prefill" and traced variants like + # "prefill_bs_1_ctxlens_tensor([417], ...)". + prefills = [ + e + for e in events + if (e.get("name") == "prefill" or e.get("name", "").startswith("prefill_bs_")) + and e.get("ph") == "X" + and e.get("cat") == "user_annotation" + ] + prefills = sorted(prefills, key=lambda x: x["ts"]) + + if not prefills: + print("No prefill (user_annotation) events found.") + write_breakdown_xlsx(output_xlsx, [], sheet_name="prefill") + return + + actual_prefill_idx = 0 + warmup_detected = False + + # Only evaluate warmup when there are at least two prefills. + if len(prefills) >= 2: + first = prefills[0] + second = prefills[1] + gap_start = first["ts"] + first.get("dur", 0) + gap_end = second["ts"] + + # If decode_step_bs appears in [gap_start, gap_end], first prefill is warmup. + has_decode_between = any( + e.get("ph") == "X" + and e.get("cat") == "user_annotation" + and e.get("name", "").startswith("decode_step_bs") + and gap_start <= e.get("ts", 0) <= gap_end + for e in events + ) + if has_decode_between: + actual_prefill_idx = 1 + warmup_detected = True + + actual_prefill = prefills[actual_prefill_idx] + print(f"Found {len(prefills)} prefill events (user_annotation)") + if warmup_detected: + print("Warmup detected: decode_step_bs found between prefill[0] and prefill[1]") + else: + print("No warmup prefill detected by rule, using prefill[0]") + print( + f"Using prefill[{actual_prefill_idx}] " + f"(ts={actual_prefill.get('ts', 0):.0f}, dur={actual_prefill.get('dur', 0):.0f})" + ) + + # Build prefill hierarchy on the same thread as the selected CPU prefill. + # Using thread affinity is more robust than category-only filtering. 
+ prefill_tid = actual_prefill.get("tid") + prefill_pid = actual_prefill.get("pid") + prefill_hierarchy_events = [ + e + for e in events + if e.get("ph") == "X" + and e.get("tid") == prefill_tid + and e.get("pid") == prefill_pid + ] + # Build index once for fast subtree queries in prefill parsing. + prefill_idx = EventIndex(prefill_hierarchy_events) + level1_children = prefill_idx.get_direct_children(actual_prefill) + print( + f"Prefill level 1 (same thread pid={prefill_pid}, tid={prefill_tid}): " + f"{len(level1_children)} nodes" + ) + + # Keep only level2 children that have kernel launch in their subtree. + launch_level2_items = [] + for l1 in level1_children: + l1_name = l1.get("name", "") + level2_children = prefill_idx.get_direct_children(l1) + level2_with_launch = [ + l2 for l2 in level2_children if prefill_idx.has_kernel_launch(l2) + ] + for l2 in level2_with_launch: + launch_level2_items.append( + { + "level1_name": l1_name, + "level2_event": l2, + } + ) + + print(f"Level2 children with kernel launch: {len(launch_level2_items)}") + + # Layer extraction by rmsnorm positions: + # each layer has 2 rmsnorm modules, layer N starts at rmsnorm index 2*N (0-based). + TARGET_LAYER = target_layer + all_norm_indices = [ + i + for i, item in enumerate(launch_level2_items) + if "rmsnorm" in item["level2_event"].get("name", "").lower() + ] + # Last rmsnorm is final layernorm, not part of transformer layers. 
+ norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] + print( + f"Found {len(all_norm_indices)} rmsnorm modules in level2-with-launch rows " + f"({len(norm_indices)} used for layer split, excluding final layernorm)" + ) + + mod_start = 0 + mod_end = 0 + norm_start_idx = TARGET_LAYER * 2 + norm_end_idx = (TARGET_LAYER + 1) * 2 + final_norm_idx = ( + all_norm_indices[-1] if len(all_norm_indices) > 0 else len(launch_level2_items) + ) + if norm_start_idx >= len(norm_indices): + print( + f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX" + ) + else: + mod_start = norm_indices[norm_start_idx] + mod_end = ( + norm_indices[norm_end_idx] + if norm_end_idx < len(norm_indices) + else final_norm_idx + ) + print( + f"Layer {TARGET_LAYER} range by rmsnorm: " + f"rows [{mod_start}:{mod_end}) from rmsnorm #{norm_start_idx+1} to #{norm_end_idx+1}" + ) + print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") + + # Build launch->kernel mapping by correlation id. + # Build launch candidates from current prefill thread/range once. + runtime_launches = [ + e + for e in prefill_hierarchy_events + if e.get("cat") == "cuda_runtime" and is_kernel_launch(e.get("name", "")) + ] + runtime_launches.sort(key=lambda x: x.get("ts", 0)) + runtime_launch_ts = [e.get("ts", 0) for e in runtime_launches] + + item_corrs: List[List[Any]] = [] + corr_needed = set() + for item in launch_level2_items: + l2 = item["level2_event"] + l2_start = l2.get("ts", 0) + l2_end = l2_start + l2.get("dur", 0) + + left = bisect.bisect_left(runtime_launch_ts, l2_start) + right = bisect.bisect_right(runtime_launch_ts, l2_end) + launches_in_l2 = runtime_launches[left:right] + curr_corrs = [] + for launch in launches_in_l2: + corr = (launch.get("args") or {}).get("correlation") + if corr is not None: + corr_needed.add(corr) + curr_corrs.append(corr) + item_corrs.append(curr_corrs) + + # Build kernel index only for correlations we actually need. 
+ kernel_by_corr: Dict[Any, List[Dict]] = {} + if corr_needed: + for e in events: + if e.get("ph") != "X" or e.get("cat") != "kernel": + continue + corr = (e.get("args") or {}).get("correlation") + if corr is None or corr not in corr_needed: + continue + kernel_by_corr.setdefault(corr, []).append(e) + for corr in kernel_by_corr: + kernel_by_corr[corr].sort(key=lambda x: x.get("ts", 0)) + + item_kernels: List[List[Dict[str, Any]]] = [] + for corrs in item_corrs: + kernels = [] + for corr in corrs: + for k in kernel_by_corr.get(corr, []): + kernels.append({"name": k.get("name", "N/A"), "dur": k.get("dur", 0)}) + item_kernels.append(kernels) + + def build_rows_from_item_range(start: int, end: int) -> List[List[Any]]: + rows = [] + for i in range(start, end): + item = launch_level2_items[i] + mod_name = item["level2_event"].get("name", "") + if should_filter_prefill(mod_name): + continue + kernels = [k for k in item_kernels[i] if k["name"] not in ("", "N/A")] + if not kernels: + continue + if "moe_forward" in mod_name.lower(): + rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) + else: + for k in kernels: + rows.append( + [clean_module_name(mod_name, k["name"]), k["name"], k["dur"]] + ) + return rows + + # Target layer rows. + csv_rows = ( + build_rows_from_item_range(mod_start, mod_end) + if norm_start_idx < len(norm_indices) + else [] + ) + print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(csv_rows)}") + + print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") + + # AVG rows from layer 3 to last layer. 
+ avg_rows = None + avg_layer_rows: List[List[List[Any]]] = [] + avg_start_layer = 3 + layer = avg_start_layer + while 2 * layer < len(norm_indices): + s = norm_indices[2 * layer] + e_idx = 2 * (layer + 1) + e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx + avg_layer_rows.append(build_rows_from_item_range(s, e)) + layer += 1 + if avg_layer_rows: + avg_rows = build_avg_rows_from_layers( + avg_layer_rows, avg_start_layer, "Prefill" + ) + if avg_rows is not None: + print(f"Prefill avg rows: {len(avg_rows)}") + + # Write XLSX for prefill. + write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name="prefill", avg_rows=avg_rows) + + +def clean_module_name(name: str, mapped_kernel_name: str = "") -> str: + """Clean and simplify module name.""" + # Runtime launch wrappers should display the actual launched operator name. + if "hipmodulelaunchkernel" in name.lower() and mapped_kernel_name not in ( + "", + "N/A", + ): + name = mapped_kernel_name + + # Remove 'aiter::' prefix if present + if name.startswith("aiter::"): + name = name[7:] # len('aiter::') == 7 + + # Rename based on keywords (rope takes priority) + name_lower = name.lower() + if "rope" in name_lower and "cache" in name_lower: + return "rope & kv_cache" + if "rope" in name_lower: + return "rope" + if "cache" in name_lower and "gemm" not in name_lower: + return "kv_cache" + + return name + + +def process_moe_module( + mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] +) -> List[List]: + """ + Process moe_forward module: categorize kernels by name. + + - 'moesort' in kernel name -> moe_sort + - 'topk' in kernel name -> moe_topk + - others -> keep original mod_name + + Returns list of [display_name, gpu_kernel_name, gpu_dur] rows. 
+ """ + rows = [] + for i in range(kernel_count): + gpu_idx = start_gpu_idx + i + gpu_kernel_name = "N/A" + gpu_dur = 0 + if gpu_idx < len(gpu_kernels): + gpu = gpu_kernels[gpu_idx] + gpu_kernel_name = gpu.get("name", "N/A") + gpu_dur = gpu.get("dur", 0) + + # Determine category based on kernel name + kernel_lower = gpu_kernel_name.lower() + if "moesort" in kernel_lower: + category = "moe_sort" + elif "topk" in kernel_lower: + category = "moe_topk" + else: + category = clean_module_name(mod_name, gpu_kernel_name) + + # Always show category/module name on each row. + display_name = category + rows.append([display_name, gpu_kernel_name, gpu_dur]) + + return rows + + +def process_regular_module( + mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] +) -> List[List]: + """Process regular module and show module name on every row.""" + rows = [] + for i in range(kernel_count): + gpu_idx = start_gpu_idx + i + gpu_kernel_name = "N/A" + gpu_dur = 0 + if gpu_idx < len(gpu_kernels): + gpu = gpu_kernels[gpu_idx] + gpu_kernel_name = gpu.get("name", "N/A") + gpu_dur = gpu.get("dur", 0) + display_name = clean_module_name(mod_name, gpu_kernel_name) + rows.append([display_name, gpu_kernel_name, gpu_dur]) + return rows + + +def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: + """ + Parse decode phase: map capture_graph modules to GPU kernels. 
+ + Output CSV columns: cpu_module, gpu_kernel, duration_us + """ + print("Building event index...") + + # Find GPU-annotated decode_step events (cat='gpu_user_annotation') + decode_steps = [ + e + for e in events + if e.get("name", "").startswith("decode_step") + and e.get("ph") == "X" + and e.get("cat") == "gpu_user_annotation" + ] + decode_steps = sorted(decode_steps, key=lambda x: x["ts"]) + + if not decode_steps: + print("No decode_step (gpu_user_annotation) events found.") + return + + # Skip warmup: find first gap > 100ms (warmup/run boundary) + # Normal decode gaps are < 5ms, so 100ms is safe threshold + WARMUP_GAP_THRESHOLD = 100000 # 100ms in microseconds + actual_run_idx = 0 + found_warmup_boundary = False + for i in range(1, len(decode_steps)): + gap = decode_steps[i]["ts"] - ( + decode_steps[i - 1]["ts"] + decode_steps[i - 1].get("dur", 0) + ) + if gap > WARMUP_GAP_THRESHOLD: + actual_run_idx = i + found_warmup_boundary = True + print(f"Warmup/run boundary at [{i-1}]->[{i}], gap={gap/1000:.1f}ms") + break + + if not found_warmup_boundary: + print("No warmup detected (no gap > 100ms), using first decode_step") + + first_ds = decode_steps[actual_run_idx] + first_ds_name = first_ds.get("name", "") + target_bs: Optional[int] = None + if "_bs_" in first_ds_name: + bs = first_ds_name.split("_bs_")[-1] + target_cg_name = f"capture_graph_bs_{bs}" + try: + target_bs = int(bs) + except ValueError: + target_bs = None + else: + target_cg_name = "capture_graph" + + print(f"First decode_step: {first_ds_name}") + print(f"Looking for: {target_cg_name}") + + # Find matching capture_graph + capture_graphs = [ + e for e in events if e.get("name") == target_cg_name and e.get("ph") == "X" + ] + if not capture_graphs and target_bs is not None: + # Prefer the largest capture_graph_bs_K where K < target_bs. 
+ lower_bs_candidates: List[Tuple[int, Dict[str, Any]]] = [] + for e in events: + if e.get("ph") != "X": + continue + n = e.get("name", "") + m = re.match(r"^capture_graph_bs_(\d+)$", n) + if not m: + continue + k = int(m.group(1)) + if k < target_bs: + lower_bs_candidates.append((k, e)) + if lower_bs_candidates: + best_bs = max(k for k, _ in lower_bs_candidates) + capture_graphs = sorted( + [e for k, e in lower_bs_candidates if k == best_bs], + key=lambda x: x.get("ts", 0), + ) + print(f"No exact match, using nearest lower capture_graph_bs_{best_bs}") + if not capture_graphs: + # Fallback: find any capture_graph + capture_graphs = [ + e + for e in events + if e.get("name", "").startswith("capture_graph") and e.get("ph") == "X" + ] + capture_graphs = sorted(capture_graphs, key=lambda x: x["ts"]) + print("No exact match, using first capture_graph") + + if not capture_graphs: + print("No capture_graph events found.") + return + + cg = capture_graphs[0] + print(f"Using: {cg.get('name')}") + + # Build optimized index only for capture_graph's time range + cg_start = cg["ts"] + cg_end = cg_start + cg.get("dur", 0) + cg_events = [ + e + for e in events + if e.get("ph") == "X" + and e.get("ts", 0) >= cg_start + and e.get("ts", 0) + e.get("dur", 0) <= cg_end + ] + print(f"Events in capture_graph: {len(cg_events)}") + idx = EventIndex(cg_events) + + # Get GPU kernels from first decode_step (within its duration) + ds1_start = first_ds["ts"] + ds1_end = ds1_start + first_ds.get("dur", 0) + + gpu_kernels = [ + e + for e in events + if e.get("cat") == "kernel" and ds1_start <= e["ts"] <= ds1_end + ] + gpu_kernels = sorted(gpu_kernels, key=lambda x: x["ts"]) + print(f"First decode_step (tid={first_ds.get('tid')}): {first_ds_name}") + print( + f" Range: {ds1_start:.0f} ~ {ds1_end:.0f} (dur={first_ds.get('dur', 0):.0f})" + ) + print(f" GPU kernels: {len(gpu_kernels)}") + + # Use optimized index for children lookup + direct_children = idx.get_direct_children(cg) + kernel_children = 
[c for c in direct_children if idx.has_kernel_launch(c)] + print(f"Direct children with kernels: {len(kernel_children)}") + + # Collect all modules with their kernel info + all_modules = [] # list of (mod_name, kernel_count, start_gpu_idx) + gpu_idx = 0 + + for child in kernel_children: + child_name = child.get("name", "") + if should_filter(child_name): + continue + + # Get sub-children (actual module names) + sub_children = idx.get_direct_children(child) + sub_kernel_children = [sc for sc in sub_children if idx.has_kernel_launch(sc)] + + # Determine modules to process + modules = sub_kernel_children if sub_kernel_children else [child] + + for mod in modules: + mod_name = mod.get("name", "") + kernel_count = idx.count_kernel_launches(mod) + all_modules.append((mod_name, kernel_count, gpu_idx)) + gpu_idx += kernel_count + + # Find norm positions (rmsnorm in name) + all_norm_indices = [ + i for i, (name, _, _) in enumerate(all_modules) if "rmsnorm" in name.lower() + ] + # Last rmsnorm is final layernorm, not part of transformer layers. 
+    norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else []
+    print(
+        f"Found {len(all_norm_indices)} norm modules "
+        f"({len(norm_indices)} used for layer split, excluding final layernorm)"
+    )
+
+    # Extract the target layer (0-indexed; defaults to layer 3 via --layer)
+    # Each layer has 2 norms, so layer N starts at norm index 2*N
+    TARGET_LAYER = target_layer
+    norm_start_idx = TARGET_LAYER * 2  # e.g. 6 (7th norm, 0-indexed) when target layer is 3
+    norm_end_idx = (TARGET_LAYER + 1) * 2  # e.g. 8 (9th norm, 0-indexed) when target layer is 3
+
+    final_norm_idx = (
+        all_norm_indices[-1] if len(all_norm_indices) > 0 else len(all_modules)
+    )
+    if norm_start_idx >= len(norm_indices):
+        print(f"Not enough norms for layer {TARGET_LAYER}")
+        return
+
+    # Module range for the target layer: [norm_indices[norm_start_idx], norm_indices[norm_end_idx]) — end exclusive.
+    mod_start = norm_indices[norm_start_idx]
+    mod_end = (
+        norm_indices[norm_end_idx]
+        if norm_end_idx < len(norm_indices)
+        else final_norm_idx
+    )
+
+    print(
+        f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})"
+    )
+
+    def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]:
+        rows = []
+        for mod_name, kernel_count, start_gpu_idx in all_modules[start:end]:
+            if "moe_forward" in mod_name.lower():
+                rows.extend(
+                    process_moe_module(
+                        mod_name, kernel_count, start_gpu_idx, gpu_kernels
+                    )
+                )
+            else:
+                rows.extend(
+                    process_regular_module(
+                        mod_name, kernel_count, start_gpu_idx, gpu_kernels
+                    )
+                )
+        return rows
+
+    # Target layer rows.
+    rows = build_rows_for_module_range(mod_start, mod_end)
+
+    # AVG rows: always averaged from layer 3 through the last layer (independent of --layer). 
+    avg_rows = None
+    avg_layer_rows: List[List[List[Any]]] = []
+    layer = 3  # averaging window always starts at layer 3, regardless of --layer
+    while 2 * layer < len(norm_indices):
+        s = norm_indices[2 * layer]
+        e_idx = 2 * (layer + 1)
+        e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx
+        avg_layer_rows.append(build_rows_for_module_range(s, e))
+        layer += 1
+    if avg_layer_rows:
+        avg_rows = build_avg_rows_from_layers(avg_layer_rows, 3, "Decode")
+        if avg_rows is not None:
+            print(f"Decode avg rows: {len(avg_rows)}")
+
+    # Write XLSX (decode sheet; AVG columns included only when layers align)
+    write_breakdown_xlsx(output_xlsx, rows, sheet_name="decode", avg_rows=avg_rows)
+
+    print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}")
+    print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)")
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+
+def main():
+    # CLI entry point: loads the trace, then runs prefill and decode analyses.
+    parser = argparse.ArgumentParser(
+        description="Parse PyTorch profiler trace JSON to extract kernel information."
+    )
+    parser.add_argument("filepath", help="Path to trace JSON or JSON.GZ file")
+    parser.add_argument(
+        "--layer", type=int, default=3, help="Target layer index (default: 3)"
+    )
+    args = parser.parse_args()
+
+    if args.layer < 0:
+        print("--layer must be >= 0")
+        sys.exit(1)
+
+    filepath = args.filepath
+    target_layer = args.layer
+
+    print(f"Loading: {filepath}")
+    trace = load_trace(filepath)
+    events = trace.get("traceEvents", [])
+    print(f"Loaded {len(events)} events\n")
+
+    print("=" * 60)
+    print("PREFILL ANALYSIS")
+    print("=" * 60)
+    parse_prefill(events, "prefill_breakdown.xlsx", target_layer=target_layer)
+
+    print("\n" + "=" * 60)
+    print("DECODE ANALYSIS")
+    print("=" * 60)
+    parse_decode(events, "decode_breakdown.xlsx", target_layer=target_layer)
+
+
+if __name__ == "__main__":
+    main()