From 868cab44f933a6bb30e25eb2da2f553982a20444 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Tue, 3 Mar 2026 02:50:27 +0000 Subject: [PATCH 1/9] add mark_trace --- atom/config.py | 1 + atom/entrypoints/openai_server.py | 10 + atom/model_engine/arg_utils.py | 6 + atom/model_engine/engine_core.py | 6 +- atom/model_engine/model_runner.py | 45 +- atom/model_ops/linear.py | 10 + atom/models/deepseek_v2.py | 1 + atom/utils/compiler_inferface.py | 46 +- atom/utils/decorators.py | 79 +++ atom/utils/graph_marker.py | 50 ++ atom/utils/graph_marker_instrumentation.py | 290 ++++++++ pyproject.toml | 2 +- tools/parse_trace.py | 769 +++++++++++++++++++++ 13 files changed, 1296 insertions(+), 19 deletions(-) create mode 100644 atom/utils/graph_marker.py create mode 100644 atom/utils/graph_marker_instrumentation.py create mode 100644 tools/parse_trace.py diff --git a/atom/config.py b/atom/config.py index 8ef9be071..120fa2e62 100644 --- a/atom/config.py +++ b/atom/config.py @@ -590,6 +590,7 @@ class Config: default_factory=lambda: QuantizationConfig() ) asyncio_mode: bool = False + mark_trace: bool = False load_dummy: bool = False enable_expert_parallel: bool = False master_addr: str = "127.0.0.1" diff --git a/atom/entrypoints/openai_server.py b/atom/entrypoints/openai_server.py index 9ed9c3b45..affcec098 100644 --- a/atom/entrypoints/openai_server.py +++ b/atom/entrypoints/openai_server.py @@ -910,6 +910,16 @@ def main(): engine_args = EngineArgs.from_cli_args(args) engine = engine_args.create_engine() + if args.mark_trace: + logger.info( + "--mark-trace enabled: reinitializing engine once to apply compile-time instrumentation changes." + ) + try: + engine.core_mgr.close() + except Exception as e: + logger.warning(f"Failed to close engine before reinitialize: {e}") + engine = engine_args.create_engine() + print(f"Starting server on {args.host}:{args.server_port}...") uvicorn.run(app, host=args.host, port=args.server_port) diff --git a/atom/model_engine/arg_utils.py b/atom/model_engine/arg_utils.py index a3c9881e4..f16362705 100644 --- a/atom/model_engine/arg_utils.py +++ b/atom/model_engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: enable_dp_attention: bool = False method: Optional[str] = None num_speculative_tokens: int = 1 + mark_trace: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -170,6 +171,11 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help="Apply a delay (of delay factor multiplied by previous" "prompt latency) before scheduling next prompt.", ) + parser.add_argument( + "--mark-trace", + action="store_true", + help="Enable graph_marker nodes for tracing/profile instrumentation.", + ) return parser diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index 80f85fd90..021507e94 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -63,6 +63,7 @@ def __init__(self, config: Config, input_address: str, output_address: str): self.input_thread.start() self.profile_enbaled = config.torch_profiler_dir is not None + self.mark_trace = getattr(config, "mark_trace", False) init_exit_handler(self) self._init_data_parallel(config) @@ -83,6 +84,9 @@ def __init__(self, config: Config, input_address: str, output_address: str): config.num_kvcache_blocks = num_blocks if not config.enforce_eager: + # Start profiler before cudagraph capture only if mark-trace is enabled. + if self.profile_enbaled and self.mark_trace: + self.runner_mgr.call_func("start_profiler", wait_out=True) cap_cost, bs = self.runner_mgr.call_func( "capture_cudagraph", wait_out=True ) @@ -284,7 +288,7 @@ def process_output_sockets(self, output_address: str): def start_profiler(self): if self.profile_enbaled: - self.runner_mgr.call_func("start_profiler") + self.runner_mgr.call_func("start_profiler", wait_out=True) def stop_profiler(self): if self.profile_enbaled: diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index bad557093..703d3ab92 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -5,6 +5,8 @@ import math import os import time +from contextlib import nullcontext +from itertools import chain, islice from typing import Any, Optional, Union import numpy as np @@ -19,6 +21,7 @@ graph_capture, ) from aiter.dist.utils import get_distributed_init_method +from torch.profiler import record_function from atom.config import Config, KVCacheTensor, set_current_atom_config from atom.model_engine.scheduler import ScheduledBatch, ScheduledBatchOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType @@ -462,6 +465,9 @@ class ModelRunner: def __init__(self, rank: int, config: Config): self.config = config + self.mark_trace = getattr(config, "mark_trace", False) + from atom.utils.graph_marker import set_graph_marker_enabled + set_graph_marker_enabled(self.mark_trace) set_current_atom_config(config) hf_config = config.hf_config self.block_size = config.kv_cache_block_size @@ -696,6 +702,7 @@ def start_profiler(self): ), ) self.profiler.__enter__() + return True def stop_profiler(self): """Stop profiling for this rank""" @@ -1331,14 +1338,24 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor is_prefill = context.is_prefill positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: - hidden_states = self.model(input_ids, positions) + with ( + record_function("prefill") + if self.mark_trace + else nullcontext() + ): + hidden_states = self.model(input_ids, positions) else: - graph_bs = context.graph_bs - max_q_len = forward_context.attn_metadata.max_seqlen_q - graph_key = (graph_bs, max_q_len) - self.graphs[graph_key].replay() - num_tokens = context.batch_size * max_q_len - hidden_states = self.forward_vars["outputs"][:num_tokens] + with ( + record_function(f"decode_step_bs_{bs}") + if self.mark_trace + else nullcontext() + ): + graph_bs = context.graph_bs + max_q_len = forward_context.attn_metadata.max_seqlen_q + graph_key = (graph_bs, max_q_len) + self.graphs[graph_key].replay() + num_tokens = context.batch_size * max_q_len + hidden_states = self.forward_vars["outputs"][:num_tokens] logits = self.model.compute_logits(hidden_states) return logits, hidden_states @@ -1536,11 +1553,15 @@ def capture_cudagraph(self): outputs[:num_tokens] = self.model( input_ids[:num_tokens], positions[:num_tokens] ) # warmup - - with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): - outputs[:num_tokens] = self.model( - input_ids[:num_tokens], positions[:num_tokens] - ) # capture + with ( + record_function(f"capture_graph_bs_{bs}") + if self.mark_trace + else nullcontext() + ): + with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): + outputs[:num_tokens] = self.model( + input_ids[:num_tokens], positions[:num_tokens] + ) # capture if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef7..4046b652c 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -26,6 +26,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils +from atom.utils.decorators import mark_trace from atom.model_ops.utils import shuffle_weights from atom.utils import envs @@ -190,6 +191,7 @@ def gemm_a8w8_blockscale_preshuffle_impl( return y +@mark_trace class LinearBase(nn.Module): def __init__( self, @@ -200,6 +202,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = False, source_quant_dtype: torch.dtype | None = None, + prefix: str = "", ): if quant_config is None: quant_config = QuantizationConfig() @@ -207,6 +210,7 @@ def __init__( quant_type = quant_config["quant_type"] params_dtype = quant_config["quant_dtype"] super().__init__() + self.prefix = prefix self.reduce_results = reduce_results self.input_size = input_size self.output_size = ( @@ -460,6 +464,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -485,6 +490,7 @@ def __init__( bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -516,6 +522,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader( @@ -650,6 +657,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader( @@ -711,6 +719,7 @@ def __init__( quant_config=quant_config, reduce_results=reduce_results, source_quant_dtype=source_quant_dtype, + prefix=prefix, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): @@ -746,6 +755,7 @@ def __init__( bias=bias, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=kwargs.get("prefix", ""), ) def weight_loader( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f0342dce1..51006bf43 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1276,6 +1276,7 @@ def __init__( bias=False, quant_config=quant_config, source_quant_dtype=source_quant_dtype, + prefix=f"{prefix}.q_a_proj", ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( diff --git a/atom/utils/compiler_inferface.py b/atom/utils/compiler_inferface.py index 71d9fd231..ebabf59cd 100644 --- a/atom/utils/compiler_inferface.py +++ b/atom/utils/compiler_inferface.py @@ -168,6 +168,14 @@ def set_inductor_config(config, runtime_shape): config["max_autotune"] = True config["coordinate_descent_tuning"] = True + try: + from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): + config["size_asserts"] = False + config["compile_threads"] = 1 + except Exception: + pass + class InductorAdaptor(CompilerInterface): """ @@ -390,11 +398,23 @@ def _get_shape_env() -> AlwaysHitShapeEnv: "failed, leading to a corrupted compilation artifact. " "We recommend trying to " "remove ~/.cache/vllm/torch_compile_cache and try again " - "to see the real issue. " - ) - assert ( - file_path is not None - ), "failed to get the file path of the compiled graph" + "to see the real issue. ") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") + # Best-effort post-process the generated wrapper file too (PyTorch <2.8 path). + try: + # Only run post-processing when mark-trace is enabled (to avoid any + # overhead / file churn in default runs). + from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): + # Local import to avoid extra package-level side effects. + from .graph_marker_instrumentation import ( + instrument_record_functions_in_file, + ) + instrument_record_functions_in_file(file_path, + strip_markers=False) + except Exception: + pass return compiled_graph, (hash_str, file_path) def load( @@ -560,6 +580,22 @@ def compile( # if not envs.VLLM_DISABLE_COMPILE_CACHE: compiled_graph.save(path=path, format="unpacked") compilation_counter.num_compiled_artifacts_saved += 1 + # Post-process generated wrapper Python files: wrap regions between + # _start / _end graph markers with record_function(""). + try: + # Only run post-processing when mark-trace is enabled (to avoid any + # overhead / file churn in default runs). + from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): + # Local import to avoid extra package-level side effects. + from .graph_marker_instrumentation import ( + instrument_record_functions_in_dir, + ) + instrument_record_functions_in_dir(path, + strip_markers=False) + except Exception: + # Best-effort: never fail compilation due to instrumentation. + pass return compiled_graph, (key, path) def load( diff --git a/atom/utils/decorators.py b/atom/utils/decorators.py index 71344d2b2..14581d906 100644 --- a/atom/utils/decorators.py +++ b/atom/utils/decorators.py @@ -17,6 +17,7 @@ from atom.config import CompilationConfig, Config, CompilationLevel +from atom.utils.graph_marker import graph_marker # from atom.utils import start_monitoring_torch_compile _T = TypeVar("_T", bound=type[nn.Module]) @@ -25,6 +26,84 @@ torch_compile_start_time: float = 0.0 +def _graph_marker_first_tensor(obj, name: str): + if torch.is_tensor(obj): + return graph_marker(obj, name=name), True + if isinstance(obj, tuple): + out = [] + marked_any = False + for v in obj: + if marked_any: + out.append(v) + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out.append(vv) + out_t = tuple(out) + # namedtuple support + if hasattr(obj, "_fields"): + return obj.__class__(*out_t), marked_any + return out_t, marked_any + if isinstance(obj, list): + out = [] + marked_any = False + for v in obj: + if marked_any: + out.append(v) + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out.append(vv) + return out, marked_any + if isinstance(obj, dict): + out = {} + marked_any = False + for k, v in obj.items(): + if marked_any: + out[k] = v + continue + vv, marked_any = _graph_marker_first_tensor(v, name) + out[k] = vv + return out, marked_any + return obj, False + + +def mark_trace(cls): + forward = getattr(cls, "forward", None) + if forward is None: + return cls + if getattr(forward, "__mark_trace_wrapped__", False): + return cls + + from atom.utils.graph_marker import is_graph_marker_enabled + + def wrapped_forward(self, *args, **kwargs): + # When mark-trace is disabled, bypass all wrapping logic entirely + if not is_graph_marker_enabled(): + return forward(self, *args, **kwargs) + + prefix = getattr(self, "prefix", cls.__name__) + # Mark only the first tensor across args/kwargs, keeping names stable. + args_l = list(args) + marked = False + for i, a in enumerate(args_l): + if marked: + break + aa, marked = _graph_marker_first_tensor(a, f"{prefix}_start") + args_l[i] = aa + if not marked: + for k, v in list(kwargs.items()): + if marked: + break + vv, marked = _graph_marker_first_tensor(v, f"{prefix}_start") + kwargs[k] = vv + args = tuple(args_l) + y = forward(self, *args, **kwargs) + yy, _ = _graph_marker_first_tensor(y, f"{prefix}_end") + return yy + + wrapped_forward.__mark_trace_wrapped__ = True + cls.forward = wrapped_forward + return cls + # We remove it from utils/__init__.py to avoid circular import def start_monitoring_torch_compile(vllm_config: Config): global torch_compile_start_time diff --git a/atom/utils/graph_marker.py b/atom/utils/graph_marker.py new file mode 100644 index 000000000..1f71a98c9 --- /dev/null +++ b/atom/utils/graph_marker.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. +# +# A tiny, graph-friendly marker op for debugging/graph inspection. +# It is an identity at runtime, but it shows up in FX/graph dumps. + +# from __future__ import annotations + +import torch + +from aiter.jit.utils.torch_guard import torch_compile_guard + +_GRAPH_MARKER_ENABLED: bool = False + + +def set_graph_marker_enabled(enabled: bool) -> None: + """Enable/disable graph markers globally (per-process).""" + global _GRAPH_MARKER_ENABLED + _GRAPH_MARKER_ENABLED = bool(enabled) + + +def is_graph_marker_enabled() -> bool: + return _GRAPH_MARKER_ENABLED + + +def _graph_marker_impl(x: torch.Tensor) -> torch.Tensor: + # Runtime behavior: identity. + # Keep this side-effect free to avoid graph breaks. + return x + + +def _graph_marker_fake(x: torch.Tensor, name: str) -> torch.Tensor: + # FakeTensor / meta behavior: identity with preserved shape/stride/dtype. + return x + + +@torch_compile_guard(gen_fake=_graph_marker_fake) +def graph_marker(x: torch.Tensor, name: str) -> torch.Tensor: + """Insert a no-op marker node into the compiled/traced graph. + + The marker `name` is embedded as a constant in the graph dump so you can + grep it in `computation_graph.py` / generated wrapper files. + """ + # When disabled, return early so the marker does not even appear in the + # traced/compiled graph. + if not _GRAPH_MARKER_ENABLED: + return x + return _graph_marker_impl(x) + + diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py new file mode 100644 index 000000000..1c893686a --- /dev/null +++ b/atom/utils/graph_marker_instrumentation.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass +from typing import Iterable, Optional + + +_GRAPH_MARKER_RE = re.compile( + r"""torch\.ops\.aiter\.graph_marker\.default\( + \s*[^,]+,\s* + (?P['"])(?P.*?)(?P=q) + \s*\)""", + re.VERBOSE, +) + +_SUBGRAPH_ID_RE = re.compile(r"artifact_shape_[^/]+_subgraph_(\d+)") + +_GRAPH_MARKER_LINE_RE = re.compile( + r"""^(?P\s*) + (?P[A-Za-z_]\w*)\s*=\s* + torch\.ops\.aiter\.graph_marker\.default\( + \s*(?P[^,]+?)\s*,\s* + (?P['"])(?P.*?)(?P=q)\s* + \)\s*$""", + re.VERBOSE, +) + + +@dataclass(frozen=True) +class _Marker: + idx: int + indent: str + name: str + + +def _iter_py_files(root: str) -> Iterable[str]: + for dirpath, _, filenames in os.walk(root): + for fn in filenames: + if fn.endswith(".py"): + yield os.path.join(dirpath, fn) + + +def _ensure_record_function_import(lines: list[str]) -> None: + # If already imported or referenced via qualified name, do nothing. + if any( + ("record_function" in l and ("import" in l or "from torch" in l)) + for l in lines + ): + return + + # Insert `from torch.profiler import record_function` after the first + # real `import torch` line outside the initial docstring. + in_doc = False + for i, line in enumerate(lines): + if i == 0 and line.lstrip().startswith('"""'): + in_doc = True + if in_doc and line.rstrip().endswith('"""') and i != 0: + in_doc = False + continue + if in_doc: + continue + if re.match(r"^\s*import\s+torch\b", line): + lines.insert(i + 1, "from torch.profiler import record_function\n") + return + + # Fallback: put it near the top (after any shebang/encoding if present). + insert_at = 0 + if lines and lines[0].startswith("#!"): + insert_at = 1 + lines.insert(insert_at, "from torch.profiler import record_function\n") + + +def _collect_markers(lines: list[str]) -> list[_Marker]: + out: list[_Marker] = [] + for i, line in enumerate(lines): + m = _GRAPH_MARKER_RE.search(line) + if not m: + continue + indent = re.match(r"^(\s*)", line).group(1) # type: ignore[union-attr] + out.append(_Marker(idx=i, indent=indent, name=m.group("name"))) + return out + + +def _prefix_and_kind(name: str) -> Optional[tuple[str, str]]: + if name.endswith("_start"): + return name[: -len("_start")], "start" + if name.endswith("_end"): + return name[: -len("_end")], "end" + return None + + +def _already_wrapped(lines: list[str], indent: str, prefix: str, start_idx: int, end_idx: int) -> bool: + needle = f'{indent}with record_function("{prefix}"):\n' + for i in range(start_idx, min(end_idx, len(lines))): + if lines[i] == needle: + return True + return False + + +def _wrap_region_with_record_function( + lines: list[str], + *, + start_marker_idx: int, + end_marker_idx: int, + prefix: str, + indent: str, + layer_id: Optional[int] = None, +) -> None: + """ + Transform: + ... graph_marker(..., "_start") + LINE_A + LINE_B + ... graph_marker(..., "_end") + Into: + ... graph_marker(..., "_start") + with record_function(""): + LINE_A + LINE_B + ... graph_marker(..., "_end") + """ + if end_marker_idx <= start_marker_idx + 1: + return + + # Only add layer prefix if prefix doesn't already contain layer info + if layer_id is not None and layer_id >= 0 and "model.layers" not in prefix: + tag = f"layer_{layer_id}_{prefix}" + else: + tag = prefix + with_line = f'{indent}with record_function("{tag}"):\n' + insert_at = start_marker_idx + 1 + + # If we already inserted a record_function line in a previous run, upgrade it + # in-place (e.g. "mlp" -> "layer_0_mlp") and exit without touching indentation. + if insert_at < len(lines) and lines[insert_at].startswith(f"{indent}with record_function("): + if lines[insert_at] != with_line: + lines[insert_at] = with_line + return + + # Otherwise, insert a new record_function wrapper and indent the region. + lines.insert(insert_at, with_line) + + # Re-indent the region between start_marker and end_marker (exclusive of end marker). + region_start = insert_at + 1 + region_end = end_marker_idx + 1 # end marker shifted down by 1 due to insertion + indent_prefix = indent + extra = " " * 4 + for i in range(region_start, region_end): + line = lines[i] + if line.strip() == "": + continue + if line.startswith(indent_prefix): + lines[i] = indent_prefix + extra + line[len(indent_prefix):] + + +def _layer_id_from_wrapper_path(path: str) -> Optional[int]: + """ + Derive layer id from wrapper file path: + .../artifact_shape__subgraph_/... -> layer_id = N - 1 + Returns None if the pattern isn't found. + """ + m = _SUBGRAPH_ID_RE.search(path) + if not m: + return None + try: + subgraph_id = int(m.group(1)) + except ValueError: + return None + return subgraph_id - 1 + + +def _strip_runtime_graph_markers(lines: list[str]) -> bool: + """ + Remove runtime overhead of graph markers in generated wrapper code. + + - Replace `x = torch.ops.aiter.graph_marker.default(y, '...')` with `x = y` + - Drop assert_size_stride / assert_alignment lines that specifically refer to + `torch.ops.aiter.graph_marker.default` (they become redundant). + """ + out: list[str] = [] + changed = False + for line in lines: + m = _GRAPH_MARKER_LINE_RE.match(line.rstrip("\n")) + if m: + indent = m.group("indent") + lhs = m.group("lhs") + arg = m.group("arg").strip() + out.append(f"{indent}{lhs} = {arg}\n") + changed = True + continue + + if ( + ("assert_size_stride" in line or "assert_alignment" in line) + and "torch.ops.aiter.graph_marker.default" in line + ): + changed = True + continue + + out.append(line) + + if changed: + lines[:] = out + return changed + + +def instrument_record_functions_in_file(path: str, *, strip_markers: bool = True) -> bool: + """ + Returns True if the file was modified. + """ + try: + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + except OSError: + return False + + markers = _collect_markers(lines) + if not markers: + return False + + # Build intervals by matching _start / _end. + stack: dict[str, _Marker] = {} + intervals: list[tuple[_Marker, _Marker, str]] = [] + for mk in markers: + pk = _prefix_and_kind(mk.name) + if pk is None: + continue + prefix, kind = pk + if kind == "start": + stack[prefix] = mk + else: + start_mk = stack.pop(prefix, None) + if start_mk is None: + continue + # Use start indent as the wrapping indent (generated code is consistent). + intervals.append((start_mk, mk, prefix)) + + # Even if we can't form any intervals, we still might want to strip marker + # calls from already-instrumented wrappers (best-effort). + has_intervals = bool(intervals) + + layer_id = _layer_id_from_wrapper_path(path) + + # Apply from bottom to top so indices stay valid. + wrapped_or_upgraded = False + if has_intervals: + for start_mk, end_mk, prefix in sorted(intervals, key=lambda t: t[0].idx, reverse=True): + _wrap_region_with_record_function( + lines, + start_marker_idx=start_mk.idx, + end_marker_idx=end_mk.idx, + prefix=prefix, + indent=start_mk.indent, + layer_id=layer_id, + ) + wrapped_or_upgraded = True + + stripped = False + if strip_markers: + # Only strip marker calls if we either: + # - wrapped/upgraded this run, or + # - the file already contains record_function blocks (previous run) + already_has_record = any("with record_function(" in l for l in lines) + if wrapped_or_upgraded or already_has_record: + stripped = _strip_runtime_graph_markers(lines) + + changed = wrapped_or_upgraded or stripped + if changed: + if wrapped_or_upgraded: + _ensure_record_function_import(lines) + with open(path, "w", encoding="utf-8") as f: + f.writelines(lines) + return changed + + +def instrument_record_functions_in_dir(root: str, *, strip_markers: bool = True) -> int: + """ + Walk `root` and instrument all generated `.py` wrapper files. + Returns the number of modified files. + """ + changed = 0 + for fp in _iter_py_files(root): + if instrument_record_functions_in_file(fp, strip_markers=strip_markers): + changed += 1 + return changed + + diff --git a/pyproject.toml b/pyproject.toml index 49d3d8c0b..08625423e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" description = "a lightweight vLLM implementation built from scratch" requires-python = ">=3.10,<3.13" dynamic = ["version"] -dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp", "datasets"] +dependencies = ["pybind11", "transformers", "zmq", "xxhash", "fastapi", "psutil", "protobuf", "uvicorn", "aiohttp", "datasets", "openpyxl"] [project.urls] Homepage = "https://github.com/ROCm/ATOM" diff --git a/tools/parse_trace.py b/tools/parse_trace.py new file mode 100644 index 000000000..d93131889 --- /dev/null +++ b/tools/parse_trace.py @@ -0,0 +1,769 @@ +#!/usr/bin/env python3 +""" +Parse PyTorch profiler trace JSON to extract kernel information. + +Usage: + python parse_trace.py [--layer N] +""" + +import json +import gzip +import sys +import bisect +import argparse +from typing import List, Dict, Any, Tuple, Optional +from openpyxl import Workbook + +# Modules to filter out (no corresponding GPU kernel in decode) +FILTER_OUT = ['fill_'] + +# Sampling-related modules and low-level ops to filter out in prefill +FILTER_OUT_PREFILL = ['aten::', 'aiter::gemm_a16w16', 'aiter::mixed_sample'] + + +# ============================================================================= +# Utility Functions +# ============================================================================= + +def load_trace(filepath: str) -> Dict[str, Any]: + """Load trace JSON file (supports .gz).""" + opener = gzip.open if filepath.endswith('.gz') else open + with opener(filepath, 'rt', encoding='utf-8') as f: + return json.load(f) + + +def is_within(child_ts: float, child_dur: float, parent_ts: float, parent_dur: float) -> bool: + """Check if child event is within parent's time range.""" + return child_ts >= parent_ts and (child_ts + child_dur) <= (parent_ts + parent_dur) + + +def is_kernel_launch(name: str) -> bool: + """Check if name is a kernel launch (contains 'launch' and 'kernel').""" + n = name.lower() + return 'launch' in n and 'kernel' in n + + +def should_filter(name: str) -> bool: + """Check if module should be filtered out.""" + return any(f in name for f in FILTER_OUT) + + +def should_filter_prefill(name: str) -> bool: + """Check if module should be filtered out in prefill (sampling ops).""" + return any(f in name for f in FILTER_OUT_PREFILL) + + +def write_breakdown_xlsx(output_xlsx: str, rows: List[List[Any]], sheet_name: str) -> None: + """ + Write XLSX breakdown with columns: + cpu_module, gpu_kernel, duration_us, sum per module. + + The 1st and 4th columns are merged for contiguous identical modules. + """ + wb = Workbook() + ws = wb.active + ws.title = sheet_name + ws.append(['cpu_module', 'gpu_kernel', 'duration_us', 'sum per module']) + + if not rows: + wb.save(output_xlsx) + return + + # Build contiguous groups by cpu_module in current row order. + groups: List[Tuple[int, int, str, float]] = [] + i = 0 + while i < len(rows): + mod = rows[i][0] + j = i + 1 + total = float(rows[i][2]) + while j < len(rows) and rows[j][0] == mod: + total += float(rows[j][2]) + j += 1 + groups.append((i, j - 1, mod, total)) + i = j + + # Write rows with per-module sum in 4th column. + for start, end, _, total in groups: + for idx in range(start, end + 1): + mod, kernel, dur = rows[idx] + ws.append([mod, kernel, dur, total]) + + # Merge col A and D for each contiguous module group (row offset +2 for header). + for start, end, _, _ in groups: + if end > start: + r1 = start + 2 + r2 = end + 2 + ws.merge_cells(start_row=r1, start_column=1, end_row=r2, end_column=1) + ws.merge_cells(start_row=r1, start_column=4, end_row=r2, end_column=4) + + wb.save(output_xlsx) + + +# ============================================================================= +# Optimized Event Index for fast time-range queries +# ============================================================================= + +class EventIndex: + """Pre-indexed events for fast time-range queries.""" + + def __init__(self, events: List[Dict]): + # Filter duration events only + self.duration_events = [e for e in events if e.get('ph') == 'X'] + self.duration_events.sort(key=lambda x: x['ts']) + self.ts_list = [e['ts'] for e in self.duration_events] + + # Pre-compute kernel launch flags and prefix sum + self._is_kernel_launch = [is_kernel_launch(e.get('name', '')) for e in self.duration_events] + self._kernel_prefix_sum = [0] + for is_kl in self._is_kernel_launch: + self._kernel_prefix_sum.append(self._kernel_prefix_sum[-1] + (1 if is_kl else 0)) + + def events_in_range(self, start_ts: float, end_ts: float) -> List[Dict]: + """Get all duration events within [start_ts, end_ts].""" + left = bisect.bisect_left(self.ts_list, start_ts) + right = bisect.bisect_right(self.ts_list, end_ts) + return [ + e for e in self.duration_events[left:right] + if e['ts'] + e.get('dur', 0) <= end_ts + ] + + def count_kernel_launches_in_range(self, start_ts: float, end_ts: float) -> int: + """Count kernel launches within time range (fast using prefix sum).""" + left = bisect.bisect_left(self.ts_list, start_ts) + right = bisect.bisect_right(self.ts_list, end_ts) + count = 0 + for i in range(left, right): + e = self.duration_events[i] + if e['ts'] + e.get('dur', 0) <= end_ts and self._is_kernel_launch[i]: + count += 1 + return count + + def get_direct_children(self, parent: Dict) -> List[Dict]: + """Get direct children of parent event (optimized).""" + p_ts = parent['ts'] + p_end = p_ts + parent.get('dur', 0) + + # Get candidates in parent's time range + candidates = [ + e for e in self.events_in_range(p_ts, p_end) + if e is not parent + ] + + if not candidates: + return [] + + # Filter to direct children only (not nested in other candidates) + # Sort by duration descending - larger events are potential parents + candidates_sorted = sorted(candidates, key=lambda x: -x.get('dur', 0)) + + direct = [] + for i, c in enumerate(candidates_sorted): + c_ts, c_dur = c['ts'], c.get('dur', 0) + c_end = c_ts + c_dur + # Check if c is nested inside any larger candidate + is_nested = False + for j in range(i): # Only check larger (earlier in sorted list) + o = candidates_sorted[j] + o_ts = o['ts'] + o_end = o_ts + o.get('dur', 0) + if c_ts >= o_ts and c_end <= o_end: + is_nested = True + break + if not is_nested: + direct.append(c) + + return sorted(direct, key=lambda x: x['ts']) + + def count_kernel_launches(self, event: Dict) -> int: + """Count kernel launches within event's time range.""" + e_ts = event['ts'] + e_end = e_ts + event.get('dur', 0) + return self.count_kernel_launches_in_range(e_ts, e_end) + + def has_kernel_launch(self, event: Dict) -> bool: + """Check if event contains any kernel launch.""" + return self.count_kernel_launches(event) > 0 + + +# ============================================================================= +# Legacy functions (for prefill compatibility) +# ============================================================================= + +def find_events(events: List[Dict], name: str, prefix: bool = False) -> List[Dict]: + """Find all duration events (ph='X') with given name, sorted by time.""" + if prefix: + result = [e for e in events if e.get('name', '').startswith(name) and e.get('ph') == 'X'] + else: + result = [e for e in events if e.get('name') == name and e.get('ph') == 'X'] + return sorted(result, key=lambda x: x['ts']) + + +def get_gpu_kernels(events: List[Dict], start_ts: float) -> List[Dict]: + """Get GPU kernels (cat='kernel') starting from given timestamp.""" + result = [e for e in events if e.get('cat') == 'kernel' and e['ts'] >= start_ts] + return sorted(result, key=lambda x: x['ts']) + + +def get_direct_children(parent: Dict, events: List[Dict]) -> List[Dict]: + """Get direct children of parent event (excluding nested children).""" + p_ts, p_dur = parent['ts'], parent.get('dur', 0) + + candidates = [ + e for e in events + if e.get('ph') == 'X' and e is not parent + and is_within(e.get('ts', 0), e.get('dur', 0), p_ts, p_dur) + ] + + direct = [] + for c in candidates: + c_ts, c_dur = c['ts'], c.get('dur', 0) + is_direct = not any( + is_within(c_ts, c_dur, o['ts'], o.get('dur', 0)) + for o in candidates if o is not c + ) + if is_direct: + direct.append(c) + + return sorted(direct, key=lambda x: x['ts']) + + +def count_kernel_launches(event: Dict, events: List[Dict]) -> int: + """Count kernel launches within event's subtree.""" + e_ts, e_dur = event['ts'], event.get('dur', 0) + return sum( + 1 for e in events + if e.get('ph') == 'X' and is_kernel_launch(e.get('name', '')) + and is_within(e.get('ts', 0), e.get('dur', 0), e_ts, e_dur) + ) + + +def has_kernel_launch(event: Dict, events: List[Dict]) -> bool: + """Check if event's subtree contains any kernel launch.""" + return count_kernel_launches(event, events) > 0 + + +# ============================================================================= +# Parse Functions +# ============================================================================= + +def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: + """ + Parse prefill phase: find the actual prefill event on CPU trace (user_annotation). + + Warmup rule: + - If only one prefill exists, it is the actual prefill (no warmup). + - If >=2 prefills exist: + - If there is a decode_step_bs* event between prefill[0] and prefill[1], prefill[0] + is treated as warmup and prefill[1] is the actual prefill. + - Otherwise, prefill[0] is the actual prefill. + """ + # CPU side prefill/decode annotations + prefills = [ + e for e in events + if e.get('name') == 'prefill' + and e.get('ph') == 'X' + and e.get('cat') == 'user_annotation' + ] + prefills = sorted(prefills, key=lambda x: x['ts']) + + if not prefills: + print("No prefill (user_annotation) events found.") + write_breakdown_xlsx(output_xlsx, [], sheet_name='prefill') + return + + actual_prefill_idx = 0 + warmup_detected = False + + # Only evaluate warmup when there are at least two prefills. + if len(prefills) >= 2: + first = prefills[0] + second = prefills[1] + gap_start = first['ts'] + first.get('dur', 0) + gap_end = second['ts'] + + # If decode_step_bs appears in [gap_start, gap_end], first prefill is warmup. + has_decode_between = any( + e.get('ph') == 'X' + and e.get('cat') == 'user_annotation' + and e.get('name', '').startswith('decode_step_bs') + and gap_start <= e.get('ts', 0) <= gap_end + for e in events + ) + if has_decode_between: + actual_prefill_idx = 1 + warmup_detected = True + + actual_prefill = prefills[actual_prefill_idx] + print(f"Found {len(prefills)} prefill events (user_annotation)") + if warmup_detected: + print("Warmup detected: decode_step_bs found between prefill[0] and prefill[1]") + else: + print("No warmup prefill detected by rule, using prefill[0]") + print( + f"Using prefill[{actual_prefill_idx}] " + f"(ts={actual_prefill.get('ts', 0):.0f}, dur={actual_prefill.get('dur', 0):.0f})" + ) + + # Build prefill hierarchy on the same thread as the selected CPU prefill. + # Using thread affinity is more robust than category-only filtering. + prefill_tid = actual_prefill.get('tid') + prefill_pid = actual_prefill.get('pid') + prefill_hierarchy_events = [ + e for e in events + if e.get('ph') == 'X' + and e.get('tid') == prefill_tid + and e.get('pid') == prefill_pid + ] + # Build index once for fast subtree queries in prefill parsing. + prefill_idx = EventIndex(prefill_hierarchy_events) + level1_children = prefill_idx.get_direct_children(actual_prefill) + print( + f"Prefill level 1 (same thread pid={prefill_pid}, tid={prefill_tid}): " + f"{len(level1_children)} nodes" + ) + print("First 20 level 1 nodes:") + for i, child in enumerate(level1_children[:20]): + print( + f" [{i:02d}] {child.get('name', '')} " + f"(ts={child.get('ts', 0):.0f}, dur={child.get('dur', 0):.0f})" + ) + + # Keep only level2 children that have kernel launch in their subtree. + launch_level2_items = [] + for l1 in level1_children: + l1_name = l1.get('name', '') + level2_children = prefill_idx.get_direct_children(l1) + level2_with_launch = [ + l2 for l2 in level2_children + if prefill_idx.has_kernel_launch(l2) + ] + for l2 in level2_with_launch: + launch_level2_items.append({ + 'level1_name': l1_name, + 'level2_event': l2, + }) + + print(f"Level2 children with kernel launch: {len(launch_level2_items)}") + print("First 20 level2 (with kernel launch):") + for i, item in enumerate(launch_level2_items[:20]): + l2 = item['level2_event'] + print( + f" [{i:02d}] L1={item['level1_name']} | " + f"L2={l2.get('name', '')} | " + f"cat={l2.get('cat', '')} | dur={l2.get('dur', 0):.0f}" + ) + + # Layer extraction by rmsnorm positions: + # each layer has 2 rmsnorm modules, layer N starts at rmsnorm index 2*N (0-based). + TARGET_LAYER = target_layer + norm_indices = [ + i for i, item in enumerate(launch_level2_items) + if 'rmsnorm' in item['level2_event'].get('name', '').lower() + ] + print(f"Found {len(norm_indices)} rmsnorm modules in level2-with-launch rows") + + layer_items = [] + norm_start_idx = TARGET_LAYER * 2 + norm_end_idx = (TARGET_LAYER + 1) * 2 + if norm_start_idx >= len(norm_indices): + print(f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty CSV") + else: + mod_start = norm_indices[norm_start_idx] + mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else len(launch_level2_items) + layer_items = launch_level2_items[mod_start:mod_end] + print( + f"Layer {TARGET_LAYER} range by rmsnorm: " + f"rows [{mod_start}:{mod_end}) from rmsnorm #{norm_start_idx+1} to #{norm_end_idx+1}" + ) + print(f"Layer {TARGET_LAYER} modules: {len(layer_items)}") + + # Build launch->kernel mapping by correlation id. + # Build launch candidates from current prefill thread/range once. + runtime_launches = [ + e for e in prefill_hierarchy_events + if e.get('cat') == 'cuda_runtime' and is_kernel_launch(e.get('name', '')) + ] + runtime_launches.sort(key=lambda x: x.get('ts', 0)) + runtime_launch_ts = [e.get('ts', 0) for e in runtime_launches] + + output_rows = [] + level2_kernel_map: Dict[str, List[str]] = {} + corr_needed = set() + for item in layer_items: + l1_name = item['level1_name'] + l2 = item['level2_event'] + l2_start = l2.get('ts', 0) + l2_end = l2_start + l2.get('dur', 0) + l2_name = l2.get('name', '') + l2_cat = l2.get('cat', '') + l2_dur = l2.get('dur', 0) + + left = bisect.bisect_left(runtime_launch_ts, l2_start) + right = bisect.bisect_right(runtime_launch_ts, l2_end) + launches_in_l2 = runtime_launches[left:right] + + if not launches_in_l2: + output_rows.append([ + l1_name, l2_name, l2_cat, l2_dur, + 'N/A', 0, '', 'N/A', 0 + ]) + continue + + for launch in launches_in_l2: + launch_name = launch.get('name', 'N/A') + launch_dur = launch.get('dur', 0) + corr = (launch.get('args') or {}).get('correlation') + if corr is not None: + corr_needed.add(corr) + matched_kernels = [] # Fill after kernel index is built. + output_rows.append([ + l1_name, l2_name, l2_cat, l2_dur, + launch_name, launch_dur, corr if corr is not None else '', + 'PENDING', 0 + ]) + + # Build kernel index only for correlations we actually need. + kernel_by_corr: Dict[Any, List[Dict]] = {} + if corr_needed: + for e in events: + if e.get('ph') != 'X' or e.get('cat') != 'kernel': + continue + corr = (e.get('args') or {}).get('correlation') + if corr is None or corr not in corr_needed: + continue + kernel_by_corr.setdefault(corr, []).append(e) + for corr in kernel_by_corr: + kernel_by_corr[corr].sort(key=lambda x: x.get('ts', 0)) + + # Expand pending launch rows into launch->kernel rows. + expanded_rows = [] + for row in output_rows: + l1_name, l2_name, l2_cat, l2_dur, launch_name, launch_dur, corr, _, _ = row + corr_val = corr if corr != '' else None + matched_kernels = kernel_by_corr.get(corr_val, []) if corr_val is not None else [] + if not matched_kernels: + expanded_rows.append([ + l1_name, l2_name, l2_cat, l2_dur, + launch_name, launch_dur, corr, 'N/A', 0 + ]) + level2_kernel_map.setdefault(l2_name, []) + continue + for k in matched_kernels: + kernel_name = k.get('name', 'N/A') + expanded_rows.append([ + l1_name, l2_name, l2_cat, l2_dur, + launch_name, launch_dur, corr, kernel_name, k.get('dur', 0) + ]) + level2_kernel_map.setdefault(l2_name, []).append(kernel_name) + output_rows = expanded_rows + + print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(output_rows)}") + print("First 20 mapping rows:") + for i, row in enumerate(output_rows[:20]): + print( + f" [{i:02d}] L2={row[1]} | launch={row[4]} | " + f"corr={row[6]} | kernel={row[7]}" + ) + print(f"Layer {TARGET_LAYER} level2 -> gpu operators:") + for level2_name, kernels in level2_kernel_map.items(): + # Keep order while removing duplicates. + uniq = list(dict.fromkeys(kernels)) + if not uniq: + print(f" - {level2_name}: N/A") + continue + print(f" - {level2_name}: {len(uniq)} operator(s)") + for kname in uniq: + print(f" * {kname}") + + # Convert to decode-style CSV rows with prefill filters. + # Output columns: cpu_module, gpu_kernel, duration_us + csv_rows = [] + for row in output_rows: + l2_name = row[1] + gpu_kernel = row[7] + gpu_dur = row[8] + + if should_filter_prefill(l2_name): + continue + if gpu_kernel in ('', 'N/A'): + continue + + csv_rows.append([clean_module_name(l2_name), gpu_kernel, gpu_dur]) + + print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") + + # Write XLSX for prefill. + write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill') + + +def clean_module_name(name: str) -> str: + """Clean and simplify module name.""" + # Remove 'aiter::' prefix if present + if name.startswith('aiter::'): + name = name[7:] # len('aiter::') == 7 + + # Rename based on keywords (rope takes priority) + name_lower = name.lower() + if 'rope' in name_lower: + return 'rope' + if 'cache' in name_lower: + return 'kv_cache' + + return name + + +def process_moe_module( + mod_name: str, + kernel_count: int, + start_gpu_idx: int, + gpu_kernels: List[Dict] +) -> List[List]: + """ + Process moe_forward module: categorize kernels by name. + + - 'moesort' in kernel name -> moe_sort + - 'topk' in kernel name -> moe_topk + - others -> keep original mod_name + + Returns list of [display_name, gpu_kernel_name, gpu_dur] rows. + """ + rows = [] + clean_mod_name = clean_module_name(mod_name) + + for i in range(kernel_count): + gpu_idx = start_gpu_idx + i + gpu_kernel_name = 'N/A' + gpu_dur = 0 + if gpu_idx < len(gpu_kernels): + gpu = gpu_kernels[gpu_idx] + gpu_kernel_name = gpu.get('name', 'N/A') + gpu_dur = gpu.get('dur', 0) + + # Determine category based on kernel name + kernel_lower = gpu_kernel_name.lower() + if 'moesort' in kernel_lower: + category = 'moe_sort' + elif 'topk' in kernel_lower: + category = 'moe_topk' + else: + category = clean_mod_name + + # Always show category/module name on each row. + display_name = category + rows.append([display_name, gpu_kernel_name, gpu_dur]) + + return rows + + +def process_regular_module( + mod_name: str, + kernel_count: int, + start_gpu_idx: int, + gpu_kernels: List[Dict] +) -> List[List]: + """Process regular module and show module name on every row.""" + rows = [] + clean_mod_name = clean_module_name(mod_name) + for i in range(kernel_count): + gpu_idx = start_gpu_idx + i + gpu_kernel_name = 'N/A' + gpu_dur = 0 + if gpu_idx < len(gpu_kernels): + gpu = gpu_kernels[gpu_idx] + gpu_kernel_name = gpu.get('name', 'N/A') + gpu_dur = gpu.get('dur', 0) + display_name = clean_mod_name + rows.append([display_name, gpu_kernel_name, gpu_dur]) + return rows + + +def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: + """ + Parse decode phase: map capture_graph modules to GPU kernels. + + Output CSV columns: cpu_module, gpu_kernel, duration_us + """ + print("Building event index...") + + # Find GPU-annotated decode_step events (cat='gpu_user_annotation') + decode_steps = [ + e for e in events + if e.get('name', '').startswith('decode_step') + and e.get('ph') == 'X' + and e.get('cat') == 'gpu_user_annotation' + ] + decode_steps = sorted(decode_steps, key=lambda x: x['ts']) + + if not decode_steps: + print("No decode_step (gpu_user_annotation) events found.") + return + + # Skip warmup: find first gap > 100ms (warmup/run boundary) + # Normal decode gaps are < 5ms, so 100ms is safe threshold + WARMUP_GAP_THRESHOLD = 100000 # 100ms in microseconds + actual_run_idx = 0 + found_warmup_boundary = False + for i in range(1, len(decode_steps)): + gap = decode_steps[i]['ts'] - (decode_steps[i-1]['ts'] + decode_steps[i-1].get('dur', 0)) + if gap > WARMUP_GAP_THRESHOLD: + actual_run_idx = i + found_warmup_boundary = True + print(f"Warmup/run boundary at [{i-1}]->[{i}], gap={gap/1000:.1f}ms") + break + + if not found_warmup_boundary: + print("No warmup detected (no gap > 100ms), using first decode_step") + + first_ds = decode_steps[actual_run_idx] + first_ds_name = first_ds.get('name', '') + if '_bs_' in first_ds_name: + bs = first_ds_name.split('_bs_')[-1] + target_cg_name = f'capture_graph_bs_{bs}' + else: + target_cg_name = 'capture_graph' + + print(f"First decode_step: {first_ds_name}") + print(f"Looking for: {target_cg_name}") + + # Find matching capture_graph + capture_graphs = [ + e for e in events + if e.get('name') == target_cg_name and e.get('ph') == 'X' + ] + if not capture_graphs: + # Fallback: find any capture_graph + capture_graphs = [ + e for e in events + if e.get('name', '').startswith('capture_graph') and e.get('ph') == 'X' + ] + capture_graphs = sorted(capture_graphs, key=lambda x: x['ts']) + print(f"No exact match, using first capture_graph") + + if not capture_graphs: + print("No capture_graph events found.") + return + + cg = capture_graphs[0] + print(f"Using: {cg.get('name')}") + + # Build optimized index only for capture_graph's time range + cg_start = cg['ts'] + cg_end = cg_start + cg.get('dur', 0) + cg_events = [ + e for e in events + if e.get('ph') == 'X' and e.get('ts', 0) >= cg_start + and e.get('ts', 0) + e.get('dur', 0) <= cg_end + ] + print(f"Events in capture_graph: {len(cg_events)}") + idx = EventIndex(cg_events) + + # Get GPU kernels from first decode_step (within its duration) + ds1_start = first_ds['ts'] + ds1_end = ds1_start + first_ds.get('dur', 0) + + gpu_kernels = [ + e for e in events + if e.get('cat') == 'kernel' and ds1_start <= e['ts'] <= ds1_end + ] + gpu_kernels = sorted(gpu_kernels, key=lambda x: x['ts']) + print(f"First decode_step (tid={first_ds.get('tid')}): {first_ds_name}") + print(f" Range: {ds1_start:.0f} ~ {ds1_end:.0f} (dur={first_ds.get('dur', 0):.0f})") + print(f" GPU kernels: {len(gpu_kernels)}") + + # Use optimized index for children lookup + direct_children = idx.get_direct_children(cg) + kernel_children = [c for c in direct_children if idx.has_kernel_launch(c)] + print(f"Direct children with kernels: {len(kernel_children)}") + + # Collect all modules with their kernel info + all_modules = [] # list of (mod_name, kernel_count, start_gpu_idx) + gpu_idx = 0 + + for child in kernel_children: + child_name = child.get('name', '') + if should_filter(child_name): + continue + + # Get sub-children (actual module names) + sub_children = idx.get_direct_children(child) + sub_kernel_children = [sc for sc in sub_children if idx.has_kernel_launch(sc)] + + # Determine modules to process + modules = sub_kernel_children if sub_kernel_children else [child] + + for mod in modules: + mod_name = mod.get('name', '') + kernel_count = idx.count_kernel_launches(mod) + all_modules.append((mod_name, kernel_count, gpu_idx)) + gpu_idx += kernel_count + + # Find norm positions (rmsnorm in name) + norm_indices = [i for i, (name, _, _) in enumerate(all_modules) if 'rmsnorm' in name.lower()] + print(f"Found {len(norm_indices)} norm modules") + + # Extract layer 3 (4th layer, 0-indexed) + # Each layer has 2 norms, so layer N starts at norm index 2*N + TARGET_LAYER = target_layer + norm_start_idx = TARGET_LAYER * 2 # 6 (7th norm, 0-indexed) + norm_end_idx = (TARGET_LAYER + 1) * 2 # 8 (9th norm, 0-indexed) + + if norm_start_idx >= len(norm_indices): + print(f"Not enough norms for layer {TARGET_LAYER}") + return + + # Module range for layer 3: from norm_indices[6] to norm_indices[8] (exclusive) + mod_start = norm_indices[norm_start_idx] + mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else len(all_modules) + + print(f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})") + + # Build CSV rows for layer 3 only + rows = [] + for mod_name, kernel_count, start_gpu_idx in all_modules[mod_start:mod_end]: + if 'moe_forward' in mod_name.lower(): + rows.extend(process_moe_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + else: + rows.extend(process_regular_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + + # Write XLSX + write_breakdown_xlsx(output_xlsx, rows, sheet_name='decode') + + print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") + print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)") + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser(description="Parse PyTorch profiler trace JSON to extract kernel information.") + parser.add_argument('filepath', help='Path to trace JSON or JSON.GZ file') + parser.add_argument('--layer', type=int, default=3, help='Target layer index (default: 3)') + args = parser.parse_args() + + if args.layer < 0: + print("--layer must be >= 0") + sys.exit(1) + + filepath = args.filepath + target_layer = args.layer + + print(f"Loading: {filepath}") + trace = load_trace(filepath) + events = trace.get('traceEvents', []) + print(f"Loaded {len(events)} events\n") + + print("=" * 60) + print("PREFILL ANALYSIS") + print("=" * 60) + parse_prefill(events, 'prefill_breakdown.xlsx', target_layer=target_layer) + + print("\n" + "=" * 60) + print("DECODE ANALYSIS") + print("=" * 60) + parse_decode(events, 'decode_breakdown.xlsx', target_layer=target_layer) + + +if __name__ == '__main__': + main() From aa31f6cebe7822b7361a59a5d68eba4786c98560 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Tue, 3 Mar 2026 05:06:21 +0000 Subject: [PATCH 2/9] optimize --- tools/parse_trace.py | 54 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/tools/parse_trace.py b/tools/parse_trace.py index d93131889..c2c563657 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -99,6 +99,25 @@ def write_breakdown_xlsx(output_xlsx: str, rows: List[List[Any]], sheet_name: st wb.save(output_xlsx) +def rename_rmsnorm_modules(rows: List[List[Any]]) -> None: + """ + Rename first two rmsnorm cpu module names in-place: + - first -> input_layernorm + - second -> post_attn_layernorm + """ + seen = 0 + for row in rows: + if not row or len(row) < 1: + continue + mod = row[0] + if isinstance(mod, str) and 'rmsnorm' in mod.lower(): + if seen == 0: + row[0] = 'input_layernorm' + elif seen == 1: + row[0] = 'post_attn_layernorm' + seen += 1 + + # ============================================================================= # Optimized Event Index for fast time-range queries # ============================================================================= @@ -475,11 +494,12 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - for kname in uniq: print(f" * {kname}") - # Convert to decode-style CSV rows with prefill filters. + # Convert to decode-style rows with prefill filters and MoE categorization. # Output columns: cpu_module, gpu_kernel, duration_us - csv_rows = [] + module_instances: List[Dict[str, Any]] = [] + instance_idx: Dict[Tuple[Any, ...], int] = {} for row in output_rows: - l2_name = row[1] + l1_name, l2_name, l2_cat, l2_dur = row[0], row[1], row[2], row[3] gpu_kernel = row[7] gpu_dur = row[8] @@ -488,9 +508,32 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - if gpu_kernel in ('', 'N/A'): continue - csv_rows.append([clean_module_name(l2_name), gpu_kernel, gpu_dur]) + # Keep module instances separated by their parent+name+category+duration, + # preserving the original order from the trace. + key = (l1_name, l2_name, l2_cat, l2_dur) + if key not in instance_idx: + instance_idx[key] = len(module_instances) + module_instances.append({'mod_name': l2_name, 'kernels': []}) + module_instances[instance_idx[key]]['kernels'].append({ + 'name': gpu_kernel, + 'dur': gpu_dur, + }) + + csv_rows = [] + for inst in module_instances: + mod_name = inst['mod_name'] + kernels = inst['kernels'] + if not kernels: + continue + + if 'moe_forward' in mod_name.lower(): + csv_rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) + else: + for k in kernels: + csv_rows.append([clean_module_name(mod_name), k['name'], k['dur']]) print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") + rename_rmsnorm_modules(csv_rows) # Write XLSX for prefill. write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill') @@ -724,6 +767,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> rows.extend(process_moe_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) else: rows.extend(process_regular_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + rename_rmsnorm_modules(rows) # Write XLSX write_breakdown_xlsx(output_xlsx, rows, sheet_name='decode') @@ -745,7 +789,7 @@ def main(): if args.layer < 0: print("--layer must be >= 0") sys.exit(1) - + filepath = args.filepath target_layer = args.layer From 90245ce33028691964dbecabf82ed84b3d1c8425 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Tue, 3 Mar 2026 07:27:46 +0000 Subject: [PATCH 3/9] refactor marker --- atom/utils/graph_marker_instrumentation.py | 188 ++++++++++++++++++--- tools/parse_trace.py | 17 +- 2 files changed, 176 insertions(+), 29 deletions(-) diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py index 1c893686a..ba2170cde 100644 --- a/atom/utils/graph_marker_instrumentation.py +++ b/atom/utils/graph_marker_instrumentation.py @@ -3,29 +3,19 @@ from __future__ import annotations +import ast import os import re from dataclasses import dataclass from typing import Iterable, Optional -_GRAPH_MARKER_RE = re.compile( - r"""torch\.ops\.aiter\.graph_marker\.default\( - \s*[^,]+,\s* - (?P['"])(?P.*?)(?P=q) - \s*\)""", - re.VERBOSE, -) - _SUBGRAPH_ID_RE = re.compile(r"artifact_shape_[^/]+_subgraph_(\d+)") -_GRAPH_MARKER_LINE_RE = re.compile( +_ASSIGNMENT_RE = re.compile( r"""^(?P\s*) (?P[A-Za-z_]\w*)\s*=\s* - torch\.ops\.aiter\.graph_marker\.default\( - \s*(?P[^,]+?)\s*,\s* - (?P['"])(?P.*?)(?P=q)\s* - \)\s*$""", + (?P.+?)\s*$""", re.VERBOSE, ) @@ -37,6 +27,163 @@ class _Marker: name: str +@dataclass(frozen=True) +class _ParsedMarkerAssignment: + indent: str + lhs: str + arg: str + name: str + + +_GRAPH_MARKER_PREFIX = "torch.ops.aiter.graph_marker.default(" + + +def _find_matching_paren(s: str, open_idx: int) -> Optional[int]: + depth = 0 + in_str: Optional[str] = None + escaped = False + for i in range(open_idx, len(s)): + ch = s[i] + if in_str is not None: + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == in_str: + in_str = None + continue + if ch in ("'", '"'): + in_str = ch + continue + if ch == "(": + depth += 1 + continue + if ch == ")": + depth -= 1 + if depth == 0: + return i + return None + + +def _split_top_level_args(s: str) -> list[str]: + out: list[str] = [] + depth_paren = 0 + depth_bracket = 0 + depth_brace = 0 + in_str: Optional[str] = None + escaped = False + start = 0 + for i, ch in enumerate(s): + if in_str is not None: + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == in_str: + in_str = None + continue + if ch in ("'", '"'): + in_str = ch + continue + if ch == "(": + depth_paren += 1 + continue + if ch == ")": + depth_paren -= 1 + continue + if ch == "[": + depth_bracket += 1 + continue + if ch == "]": + depth_bracket -= 1 + continue + if ch == "{": + depth_brace += 1 + continue + if ch == "}": + depth_brace -= 1 + continue + if ( + ch == "," + and depth_paren == 0 + and depth_bracket == 0 + and depth_brace == 0 + ): + out.append(s[start:i].strip()) + start = i + 1 + out.append(s[start:].strip()) + return out + + +def _parse_graph_marker_call_expr(expr: str) -> Optional[tuple[str, str]]: + idx = expr.find(_GRAPH_MARKER_PREFIX) + if idx < 0: + return None + open_idx = idx + len(_GRAPH_MARKER_PREFIX) - 1 + close_idx = _find_matching_paren(expr, open_idx) + if close_idx is None: + return None + call_src = expr[idx : close_idx + 1] + if call_src.count("(") != call_src.count(")"): + return None + # Ensure the matched call is the whole RHS expression (allowing whitespace). + if expr[:idx].strip() or expr[close_idx + 1 :].strip(): + return None + inner = call_src[len(_GRAPH_MARKER_PREFIX) : -1] + args = _split_top_level_args(inner) + if len(args) < 2: + return None + arg_expr = args[0] + try: + name = ast.literal_eval(args[1]) + except Exception: + return None + if not isinstance(name, str): + return None + return arg_expr, name + + +def _parse_graph_marker_assignment_line(line: str) -> Optional[_ParsedMarkerAssignment]: + m = _ASSIGNMENT_RE.match(line.rstrip("\n")) + if not m: + return None + rhs = m.group("rhs") + parsed = _parse_graph_marker_call_expr(rhs) + if parsed is None: + return None + arg, name = parsed + return _ParsedMarkerAssignment( + indent=m.group("indent"), + lhs=m.group("lhs"), + arg=arg, + name=name, + ) + + +def _extract_graph_marker_name(line: str) -> Optional[str]: + idx = line.find(_GRAPH_MARKER_PREFIX) + if idx < 0: + return None + open_idx = idx + len(_GRAPH_MARKER_PREFIX) - 1 + close_idx = _find_matching_paren(line, open_idx) + if close_idx is None: + return None + call_src = line[idx : close_idx + 1] + inner = call_src[len(_GRAPH_MARKER_PREFIX) : -1] + args = _split_top_level_args(inner) + if len(args) < 2: + return None + try: + name = ast.literal_eval(args[1]) + except Exception: + return None + return name if isinstance(name, str) else None + + def _iter_py_files(root: str) -> Iterable[str]: for dirpath, _, filenames in os.walk(root): for fn in filenames: @@ -77,11 +224,11 @@ def _ensure_record_function_import(lines: list[str]) -> None: def _collect_markers(lines: list[str]) -> list[_Marker]: out: list[_Marker] = [] for i, line in enumerate(lines): - m = _GRAPH_MARKER_RE.search(line) - if not m: + name = _extract_graph_marker_name(line) + if name is None: continue indent = re.match(r"^(\s*)", line).group(1) # type: ignore[union-attr] - out.append(_Marker(idx=i, indent=indent, name=m.group("name"))) + out.append(_Marker(idx=i, indent=indent, name=name)) return out @@ -184,12 +331,9 @@ def _strip_runtime_graph_markers(lines: list[str]) -> bool: out: list[str] = [] changed = False for line in lines: - m = _GRAPH_MARKER_LINE_RE.match(line.rstrip("\n")) - if m: - indent = m.group("indent") - lhs = m.group("lhs") - arg = m.group("arg").strip() - out.append(f"{indent}{lhs} = {arg}\n") + parsed = _parse_graph_marker_assignment_line(line) + if parsed is not None: + out.append(f"{parsed.indent}{parsed.lhs} = {parsed.arg}\n") changed = True continue diff --git a/tools/parse_trace.py b/tools/parse_trace.py index c2c563657..5b6616722 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -530,7 +530,7 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - csv_rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) else: for k in kernels: - csv_rows.append([clean_module_name(mod_name), k['name'], k['dur']]) + csv_rows.append([clean_module_name(mod_name, k['name']), k['name'], k['dur']]) print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") rename_rmsnorm_modules(csv_rows) @@ -539,14 +539,20 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill') -def clean_module_name(name: str) -> str: +def clean_module_name(name: str, mapped_kernel_name: str = '') -> str: """Clean and simplify module name.""" + # Runtime launch wrappers should display the actual launched operator name. + if 'hipmodulelaunchkernel' in name.lower() and mapped_kernel_name not in ('', 'N/A'): + name = mapped_kernel_name + # Remove 'aiter::' prefix if present if name.startswith('aiter::'): name = name[7:] # len('aiter::') == 7 # Rename based on keywords (rope takes priority) name_lower = name.lower() + if 'rope' in name_lower and 'cache' in name_lower: + return 'rope & kv_cache' if 'rope' in name_lower: return 'rope' if 'cache' in name_lower: @@ -571,8 +577,6 @@ def process_moe_module( Returns list of [display_name, gpu_kernel_name, gpu_dur] rows. """ rows = [] - clean_mod_name = clean_module_name(mod_name) - for i in range(kernel_count): gpu_idx = start_gpu_idx + i gpu_kernel_name = 'N/A' @@ -589,7 +593,7 @@ def process_moe_module( elif 'topk' in kernel_lower: category = 'moe_topk' else: - category = clean_mod_name + category = clean_module_name(mod_name, gpu_kernel_name) # Always show category/module name on each row. display_name = category @@ -606,7 +610,6 @@ def process_regular_module( ) -> List[List]: """Process regular module and show module name on every row.""" rows = [] - clean_mod_name = clean_module_name(mod_name) for i in range(kernel_count): gpu_idx = start_gpu_idx + i gpu_kernel_name = 'N/A' @@ -615,7 +618,7 @@ def process_regular_module( gpu = gpu_kernels[gpu_idx] gpu_kernel_name = gpu.get('name', 'N/A') gpu_dur = gpu.get('dur', 0) - display_name = clean_mod_name + display_name = clean_module_name(mod_name, gpu_kernel_name) rows.append([display_name, gpu_kernel_name, gpu_dur]) return rows From 90e10acd00577cf7b7fb8919c480a923632ef229 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Wed, 4 Mar 2026 10:35:58 +0000 Subject: [PATCH 4/9] add avg and fix some bugs --- atom/entrypoints/openai_server.py | 10 - atom/model_ops/linear.py | 8 +- atom/utils/backends.py | 28 +++ atom/utils/decorators.py | 58 +++++ tools/parse_trace.py | 405 +++++++++++++++++------------- 5 files changed, 316 insertions(+), 193 deletions(-) diff --git a/atom/entrypoints/openai_server.py b/atom/entrypoints/openai_server.py index affcec098..9ed9c3b45 100644 --- a/atom/entrypoints/openai_server.py +++ b/atom/entrypoints/openai_server.py @@ -910,16 +910,6 @@ def main(): engine_args = EngineArgs.from_cli_args(args) engine = engine_args.create_engine() - if args.mark_trace: - logger.info( - "--mark-trace enabled: reinitializing engine once to apply compile-time instrumentation changes." - ) - try: - engine.core_mgr.close() - except Exception as e: - logger.warning(f"Failed to close engine before reinitialize: {e}") - engine = engine_args.create_engine() - print(f"Starting server on {args.host}:{args.server_port}...") uvicorn.run(app, host=args.host, port=args.server_port) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 4046b652c..56be52349 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -26,7 +26,7 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.tuned_gemm import tgemm from aiter.utility import fp4_utils -from atom.utils.decorators import mark_trace +from atom.utils.decorators import mark_trace, record_function from atom.model_ops.utils import shuffle_weights from atom.utils import envs @@ -169,10 +169,11 @@ def gemm_a8w8_blockscale_preshuffle_fake( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, + prefix: str="", ) -> torch.Tensor: return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device) - +@record_function @torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[]) def gemm_a8w8_blockscale_preshuffle_impl( x: torch.Tensor, @@ -180,6 +181,7 @@ def gemm_a8w8_blockscale_preshuffle_impl( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, + prefix: str="", ) -> torch.Tensor: if gemm_a8w8_blockscale_bpreshuffle_triton is not None: weight_shuffled = weight.reshape(weight.shape[0] // 16, weight.shape[1] * 16) @@ -425,7 +427,7 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: y = gemm_a8w8_blockscale_preshuffle_impl( - x, self.weight, x_scale, self.weight_scale, dtype=otype + x, self.weight, x_scale, self.weight_scale, dtype=otype, prefix=self.prefix ) if self.bias is not None: y += self.bias diff --git a/atom/utils/backends.py b/atom/utils/backends.py index b583ba447..ae9715300 100644 --- a/atom/utils/backends.py +++ b/atom/utils/backends.py @@ -230,6 +230,34 @@ def compile( handle, ) + # When mark-trace is enabled, post-processing may rewrite generated + # artifact sources. Force reloading from cache artifact so the current + # process uses the rewritten artifact immediately. + try: + from atom.utils.graph_marker import is_graph_marker_enabled + + force_reload = is_graph_marker_enabled() + except Exception: + force_reload = False + if force_reload and handle is not None and not self.disable_cache: + try: + reloaded_graph = self.load( + graph, example_inputs, graph_index, runtime_shape + ) + if reloaded_graph is not None: + compiled_graph = reloaded_graph + logger.info( + "Force reloaded compiled graph from cache artifact " + "(graph_index=%s, runtime_shape=%s).", + graph_index, + runtime_shape, + ) + except Exception: + logger.exception( + "Failed to force reload compiled graph from cache artifact; " + "falling back to in-memory compiled callable." + ) + # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() diff --git a/atom/utils/decorators.py b/atom/utils/decorators.py index 14581d906..f2be89785 100644 --- a/atom/utils/decorators.py +++ b/atom/utils/decorators.py @@ -5,6 +5,7 @@ import inspect import os import sys +from functools import wraps from types import CodeType from abc import abstractmethod from contextlib import contextmanager @@ -26,6 +27,63 @@ torch_compile_start_time: float = 0.0 +def record_function(prefix: Union[str, Callable, None] = None): + """ + Decorator that wraps a function with torch.profiler.record_function. + + Usage: + - @record_function + - @record_function("my_prefix") + """ + + def _decorate(func: Callable): + # Try to recover the original callable signature even when func is wrapped + # by other decorators. + base_func = inspect.unwrap(func) + try: + base_sig = inspect.signature(base_func) + except (TypeError, ValueError): + base_sig = None + + @wraps(func) + def _wrapped(*args, **kwargs): + # Keep this decorator no-op unless mark-trace is enabled. + from atom.utils.graph_marker import is_graph_marker_enabled + if not is_graph_marker_enabled(): + return func(*args, **kwargs) + + # Priority: + # 1) explicit decorator prefix: @record_function("xxx") + # 2) runtime function argument named "prefix" when non-empty + # 3) function name fallback + if prefix is not None: + span_name = str(prefix) + else: + span_name = func.__name__ + runtime_prefix = kwargs.get("prefix") + if not (isinstance(runtime_prefix, str) and runtime_prefix): + if base_sig is not None: + try: + bound = base_sig.bind_partial(*args, **kwargs) + runtime_prefix = bound.arguments.get("prefix") + except Exception: + runtime_prefix = None + if isinstance(runtime_prefix, str) and runtime_prefix: + span_name = runtime_prefix + + with torch.profiler.record_function(f"{span_name}"): + return func(*args, **kwargs) + + return _wrapped + + # Support @record_function without parentheses. + if callable(prefix): + func = prefix + prefix = None + return _decorate(func) + return _decorate + + def _graph_marker_first_tensor(obj, name: str): if torch.is_tensor(obj): return graph_marker(obj, name=name), True diff --git a/tools/parse_trace.py b/tools/parse_trace.py index 5b6616722..f233abafa 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -11,6 +11,7 @@ import sys import bisect import argparse +import re from typing import List, Dict, Any, Tuple, Optional from openpyxl import Workbook @@ -53,69 +54,124 @@ def should_filter_prefill(name: str) -> bool: return any(f in name for f in FILTER_OUT_PREFILL) -def write_breakdown_xlsx(output_xlsx: str, rows: List[List[Any]], sheet_name: str) -> None: +def write_breakdown_xlsx( + output_xlsx: str, + rows: List[List[Any]], + sheet_name: str, + avg_rows: Optional[List[List[Any]]] = None, +) -> None: """ Write XLSX breakdown with columns: - cpu_module, gpu_kernel, duration_us, sum per module. + cpu_module, gpu_kernel, duration_us, sum per module, + avg duration_us, avg sum per module. - The 1st and 4th columns are merged for contiguous identical modules. + The 1st/4th columns are merged for contiguous identical modules. + AVG columns are appended to the right in the same table. """ wb = Workbook() ws = wb.active ws.title = sheet_name - ws.append(['cpu_module', 'gpu_kernel', 'duration_us', 'sum per module']) - - if not rows: - wb.save(output_xlsx) - return - - # Build contiguous groups by cpu_module in current row order. - groups: List[Tuple[int, int, str, float]] = [] - i = 0 - while i < len(rows): - mod = rows[i][0] - j = i + 1 - total = float(rows[i][2]) - while j < len(rows) and rows[j][0] == mod: - total += float(rows[j][2]) - j += 1 - groups.append((i, j - 1, mod, total)) - i = j - - # Write rows with per-module sum in 4th column. - for start, end, _, total in groups: + ws.append([ + 'cpu_module', 'gpu_kernel', 'duration_us', 'sum per module', + 'avg duration_us', 'avg sum per module' + ]) + + def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float]]: + groups: List[Tuple[int, int, str, float]] = [] + i = 0 + while i < len(block_rows): + mod = block_rows[i][0] + j = i + 1 + total = float(block_rows[i][2]) + while j < len(block_rows) and block_rows[j][0] == mod: + total += float(block_rows[j][2]) + j += 1 + groups.append((i, j - 1, mod, total)) + i = j + return groups + + main_groups = build_groups(rows) if rows else [] + renamed_group_mods = [g[2] for g in main_groups] + seen_rmsnorm = 0 + for gi, mod in enumerate(renamed_group_mods): + if isinstance(mod, str) and 'rmsnorm' in mod.lower(): + if seen_rmsnorm == 0: + renamed_group_mods[gi] = 'input_layernorm' + elif seen_rmsnorm == 1: + renamed_group_mods[gi] = 'post_attn_layernorm' + seen_rmsnorm += 1 + + avg_sum_by_row: Dict[int, float] = {} + if avg_rows: + avg_groups = build_groups(avg_rows) + for start, end, _, total in avg_groups: + for i in range(start, end + 1): + avg_sum_by_row[i] = total + + data_start_row = ws.max_row + 1 + for gi, (start, end, _, total) in enumerate(main_groups): + renamed_mod = renamed_group_mods[gi] for idx in range(start, end + 1): - mod, kernel, dur = rows[idx] - ws.append([mod, kernel, dur, total]) + _, kernel, dur = rows[idx] + avg_dur = float(avg_rows[idx][2]) if avg_rows and idx < len(avg_rows) else '' + avg_sum = avg_sum_by_row.get(idx, '') + ws.append([renamed_mod, kernel, dur, total, avg_dur, avg_sum]) - # Merge col A and D for each contiguous module group (row offset +2 for header). - for start, end, _, _ in groups: + for start, end, _, _ in main_groups: if end > start: - r1 = start + 2 - r2 = end + 2 + r1 = data_start_row + start + r2 = data_start_row + end ws.merge_cells(start_row=r1, start_column=1, end_row=r2, end_column=1) ws.merge_cells(start_row=r1, start_column=4, end_row=r2, end_column=4) + if avg_rows: + ws.merge_cells(start_row=r1, start_column=6, end_row=r2, end_column=6) + + total_duration = sum(float(r[2]) for r in rows) if rows else 0.0 + total_avg_duration = sum(float(r[2]) for r in avg_rows) if avg_rows else '' + ws.append(['TOTAL', '', total_duration, '', total_avg_duration, '']) wb.save(output_xlsx) -def rename_rmsnorm_modules(rows: List[List[Any]]) -> None: +def _normalize_module_for_avg(name: str) -> str: + if not isinstance(name, str): + return str(name) + return re.sub(r"model\.layers\.\d+\.", "model.layers.*.", name) + + +def build_avg_rows_from_layers( + layer_rows_list: List[List[List[Any]]], + layer_start_idx: int, + section_name: str, +) -> Optional[List[List[Any]]]: """ - Rename first two rmsnorm cpu module names in-place: - - first -> input_layernorm - - second -> post_attn_layernorm + Build AVG rows across layers using layer-3 rows as template. + Returns None if any layer cannot be aligned by (module, kernel) sequence. """ - seen = 0 - for row in rows: - if not row or len(row) < 1: - continue - mod = row[0] - if isinstance(mod, str) and 'rmsnorm' in mod.lower(): - if seen == 0: - row[0] = 'input_layernorm' - elif seen == 1: - row[0] = 'post_attn_layernorm' - seen += 1 + if not layer_rows_list: + return [] + + base = layer_rows_list[0] + base_sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in base] + + for rel_idx, rows in enumerate(layer_rows_list[1:], start=1): + sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in rows] + if sig != base_sig: + bad_layer = layer_start_idx + rel_idx + print( + f"{section_name} avg skipped: layer {bad_layer} does not match layer {layer_start_idx} layout." + ) + return None + + n = len(layer_rows_list) + avg_rows: List[List[Any]] = [] + for i, (mod, kernel) in enumerate(base_sig): + # Keep original module display style from layer_start_idx rows. + display_mod = base[i][0] + avg_dur = sum(float(layer_rows_list[l][i][2]) for l in range(n)) / n + avg_rows.append([display_mod, kernel, avg_dur]) + return avg_rows + # ============================================================================= @@ -340,12 +396,6 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - f"Prefill level 1 (same thread pid={prefill_pid}, tid={prefill_tid}): " f"{len(level1_children)} nodes" ) - print("First 20 level 1 nodes:") - for i, child in enumerate(level1_children[:20]): - print( - f" [{i:02d}] {child.get('name', '')} " - f"(ts={child.get('ts', 0):.0f}, dur={child.get('dur', 0):.0f})" - ) # Keep only level2 children that have kernel launch in their subtree. launch_level2_items = [] @@ -363,38 +413,36 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - }) print(f"Level2 children with kernel launch: {len(launch_level2_items)}") - print("First 20 level2 (with kernel launch):") - for i, item in enumerate(launch_level2_items[:20]): - l2 = item['level2_event'] - print( - f" [{i:02d}] L1={item['level1_name']} | " - f"L2={l2.get('name', '')} | " - f"cat={l2.get('cat', '')} | dur={l2.get('dur', 0):.0f}" - ) # Layer extraction by rmsnorm positions: # each layer has 2 rmsnorm modules, layer N starts at rmsnorm index 2*N (0-based). TARGET_LAYER = target_layer - norm_indices = [ + all_norm_indices = [ i for i, item in enumerate(launch_level2_items) if 'rmsnorm' in item['level2_event'].get('name', '').lower() ] - print(f"Found {len(norm_indices)} rmsnorm modules in level2-with-launch rows") + # Last rmsnorm is final layernorm, not part of transformer layers. + norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] + print( + f"Found {len(all_norm_indices)} rmsnorm modules in level2-with-launch rows " + f"({len(norm_indices)} used for layer split, excluding final layernorm)" + ) - layer_items = [] + mod_start = 0 + mod_end = 0 norm_start_idx = TARGET_LAYER * 2 norm_end_idx = (TARGET_LAYER + 1) * 2 + final_norm_idx = all_norm_indices[-1] if len(all_norm_indices) > 0 else len(launch_level2_items) if norm_start_idx >= len(norm_indices): - print(f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty CSV") + print(f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX") else: mod_start = norm_indices[norm_start_idx] - mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else len(launch_level2_items) - layer_items = launch_level2_items[mod_start:mod_end] + mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else final_norm_idx print( f"Layer {TARGET_LAYER} range by rmsnorm: " f"rows [{mod_start}:{mod_end}) from rmsnorm #{norm_start_idx+1} to #{norm_end_idx+1}" ) - print(f"Layer {TARGET_LAYER} modules: {len(layer_items)}") + print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") # Build launch->kernel mapping by correlation id. # Build launch candidates from current prefill thread/range once. @@ -405,41 +453,23 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - runtime_launches.sort(key=lambda x: x.get('ts', 0)) runtime_launch_ts = [e.get('ts', 0) for e in runtime_launches] - output_rows = [] - level2_kernel_map: Dict[str, List[str]] = {} + item_corrs: List[List[Any]] = [] corr_needed = set() - for item in layer_items: - l1_name = item['level1_name'] + for item in launch_level2_items: l2 = item['level2_event'] l2_start = l2.get('ts', 0) l2_end = l2_start + l2.get('dur', 0) - l2_name = l2.get('name', '') - l2_cat = l2.get('cat', '') - l2_dur = l2.get('dur', 0) left = bisect.bisect_left(runtime_launch_ts, l2_start) right = bisect.bisect_right(runtime_launch_ts, l2_end) launches_in_l2 = runtime_launches[left:right] - - if not launches_in_l2: - output_rows.append([ - l1_name, l2_name, l2_cat, l2_dur, - 'N/A', 0, '', 'N/A', 0 - ]) - continue - + curr_corrs = [] for launch in launches_in_l2: - launch_name = launch.get('name', 'N/A') - launch_dur = launch.get('dur', 0) corr = (launch.get('args') or {}).get('correlation') if corr is not None: corr_needed.add(corr) - matched_kernels = [] # Fill after kernel index is built. - output_rows.append([ - l1_name, l2_name, l2_cat, l2_dur, - launch_name, launch_dur, corr if corr is not None else '', - 'PENDING', 0 - ]) + curr_corrs.append(corr) + item_corrs.append(curr_corrs) # Build kernel index only for correlations we actually need. kernel_by_corr: Dict[Any, List[Dict]] = {} @@ -454,89 +484,55 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - for corr in kernel_by_corr: kernel_by_corr[corr].sort(key=lambda x: x.get('ts', 0)) - # Expand pending launch rows into launch->kernel rows. - expanded_rows = [] - for row in output_rows: - l1_name, l2_name, l2_cat, l2_dur, launch_name, launch_dur, corr, _, _ = row - corr_val = corr if corr != '' else None - matched_kernels = kernel_by_corr.get(corr_val, []) if corr_val is not None else [] - if not matched_kernels: - expanded_rows.append([ - l1_name, l2_name, l2_cat, l2_dur, - launch_name, launch_dur, corr, 'N/A', 0 - ]) - level2_kernel_map.setdefault(l2_name, []) - continue - for k in matched_kernels: - kernel_name = k.get('name', 'N/A') - expanded_rows.append([ - l1_name, l2_name, l2_cat, l2_dur, - launch_name, launch_dur, corr, kernel_name, k.get('dur', 0) - ]) - level2_kernel_map.setdefault(l2_name, []).append(kernel_name) - output_rows = expanded_rows - - print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(output_rows)}") - print("First 20 mapping rows:") - for i, row in enumerate(output_rows[:20]): - print( - f" [{i:02d}] L2={row[1]} | launch={row[4]} | " - f"corr={row[6]} | kernel={row[7]}" - ) - print(f"Layer {TARGET_LAYER} level2 -> gpu operators:") - for level2_name, kernels in level2_kernel_map.items(): - # Keep order while removing duplicates. - uniq = list(dict.fromkeys(kernels)) - if not uniq: - print(f" - {level2_name}: N/A") - continue - print(f" - {level2_name}: {len(uniq)} operator(s)") - for kname in uniq: - print(f" * {kname}") - - # Convert to decode-style rows with prefill filters and MoE categorization. - # Output columns: cpu_module, gpu_kernel, duration_us - module_instances: List[Dict[str, Any]] = [] - instance_idx: Dict[Tuple[Any, ...], int] = {} - for row in output_rows: - l1_name, l2_name, l2_cat, l2_dur = row[0], row[1], row[2], row[3] - gpu_kernel = row[7] - gpu_dur = row[8] - - if should_filter_prefill(l2_name): - continue - if gpu_kernel in ('', 'N/A'): - continue - - # Keep module instances separated by their parent+name+category+duration, - # preserving the original order from the trace. - key = (l1_name, l2_name, l2_cat, l2_dur) - if key not in instance_idx: - instance_idx[key] = len(module_instances) - module_instances.append({'mod_name': l2_name, 'kernels': []}) - module_instances[instance_idx[key]]['kernels'].append({ - 'name': gpu_kernel, - 'dur': gpu_dur, - }) - - csv_rows = [] - for inst in module_instances: - mod_name = inst['mod_name'] - kernels = inst['kernels'] - if not kernels: - continue + item_kernels: List[List[Dict[str, Any]]] = [] + for corrs in item_corrs: + kernels = [] + for corr in corrs: + for k in kernel_by_corr.get(corr, []): + kernels.append({'name': k.get('name', 'N/A'), 'dur': k.get('dur', 0)}) + item_kernels.append(kernels) + + def build_rows_from_item_range(start: int, end: int) -> List[List[Any]]: + rows = [] + for i in range(start, end): + item = launch_level2_items[i] + mod_name = item['level2_event'].get('name', '') + if should_filter_prefill(mod_name): + continue + kernels = [k for k in item_kernels[i] if k['name'] not in ('', 'N/A')] + if not kernels: + continue + if 'moe_forward' in mod_name.lower(): + rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) + else: + for k in kernels: + rows.append([clean_module_name(mod_name, k['name']), k['name'], k['dur']]) + return rows - if 'moe_forward' in mod_name.lower(): - csv_rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) - else: - for k in kernels: - csv_rows.append([clean_module_name(mod_name, k['name']), k['name'], k['dur']]) + # Target layer rows. + csv_rows = build_rows_from_item_range(mod_start, mod_end) if norm_start_idx < len(norm_indices) else [] + print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(csv_rows)}") print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") - rename_rmsnorm_modules(csv_rows) + + # AVG rows from layer 3 to last layer. + avg_rows = None + avg_layer_rows: List[List[List[Any]]] = [] + avg_start_layer = 3 + layer = avg_start_layer + while 2 * layer < len(norm_indices): + s = norm_indices[2 * layer] + e_idx = 2 * (layer + 1) + e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx + avg_layer_rows.append(build_rows_from_item_range(s, e)) + layer += 1 + if avg_layer_rows: + avg_rows = build_avg_rows_from_layers(avg_layer_rows, avg_start_layer, "Prefill") + if avg_rows is not None: + print(f"Prefill avg rows: {len(avg_rows)}") # Write XLSX for prefill. - write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill') + write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill', avg_rows=avg_rows) def clean_module_name(name: str, mapped_kernel_name: str = '') -> str: @@ -555,7 +551,7 @@ def clean_module_name(name: str, mapped_kernel_name: str = '') -> str: return 'rope & kv_cache' if 'rope' in name_lower: return 'rope' - if 'cache' in name_lower: + if 'cache' in name_lower and 'gemm' not in name_lower: return 'kv_cache' return name @@ -662,9 +658,14 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> first_ds = decode_steps[actual_run_idx] first_ds_name = first_ds.get('name', '') + target_bs: Optional[int] = None if '_bs_' in first_ds_name: bs = first_ds_name.split('_bs_')[-1] target_cg_name = f'capture_graph_bs_{bs}' + try: + target_bs = int(bs) + except ValueError: + target_bs = None else: target_cg_name = 'capture_graph' @@ -676,6 +677,26 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> e for e in events if e.get('name') == target_cg_name and e.get('ph') == 'X' ] + if not capture_graphs and target_bs is not None: + # Prefer the largest capture_graph_bs_K where K < target_bs. + lower_bs_candidates: List[Tuple[int, Dict[str, Any]]] = [] + for e in events: + if e.get('ph') != 'X': + continue + n = e.get('name', '') + m = re.match(r'^capture_graph_bs_(\d+)$', n) + if not m: + continue + k = int(m.group(1)) + if k < target_bs: + lower_bs_candidates.append((k, e)) + if lower_bs_candidates: + best_bs = max(k for k, _ in lower_bs_candidates) + capture_graphs = sorted( + [e for k, e in lower_bs_candidates if k == best_bs], + key=lambda x: x.get('ts', 0), + ) + print(f"No exact match, using nearest lower capture_graph_bs_{best_bs}") if not capture_graphs: # Fallback: find any capture_graph capture_graphs = [ @@ -744,8 +765,13 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> gpu_idx += kernel_count # Find norm positions (rmsnorm in name) - norm_indices = [i for i, (name, _, _) in enumerate(all_modules) if 'rmsnorm' in name.lower()] - print(f"Found {len(norm_indices)} norm modules") + all_norm_indices = [i for i, (name, _, _) in enumerate(all_modules) if 'rmsnorm' in name.lower()] + # Last rmsnorm is final layernorm, not part of transformer layers. + norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] + print( + f"Found {len(all_norm_indices)} norm modules " + f"({len(norm_indices)} used for layer split, excluding final layernorm)" + ) # Extract layer 3 (4th layer, 0-indexed) # Each layer has 2 norms, so layer N starts at norm index 2*N @@ -753,27 +779,46 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> norm_start_idx = TARGET_LAYER * 2 # 6 (7th norm, 0-indexed) norm_end_idx = (TARGET_LAYER + 1) * 2 # 8 (9th norm, 0-indexed) + final_norm_idx = all_norm_indices[-1] if len(all_norm_indices) > 0 else len(all_modules) if norm_start_idx >= len(norm_indices): print(f"Not enough norms for layer {TARGET_LAYER}") return # Module range for layer 3: from norm_indices[6] to norm_indices[8] (exclusive) mod_start = norm_indices[norm_start_idx] - mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else len(all_modules) + mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else final_norm_idx print(f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})") - # Build CSV rows for layer 3 only - rows = [] - for mod_name, kernel_count, start_gpu_idx in all_modules[mod_start:mod_end]: - if 'moe_forward' in mod_name.lower(): - rows.extend(process_moe_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) - else: - rows.extend(process_regular_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) - rename_rmsnorm_modules(rows) + def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]: + rows = [] + for mod_name, kernel_count, start_gpu_idx in all_modules[start:end]: + if 'moe_forward' in mod_name.lower(): + rows.extend(process_moe_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + else: + rows.extend(process_regular_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + return rows + + # Target layer rows. + rows = build_rows_for_module_range(mod_start, mod_end) + + # AVG rows from layer 3 to last layer. + avg_rows = None + avg_layer_rows: List[List[List[Any]]] = [] + layer = 3 + while 2 * layer < len(norm_indices): + s = norm_indices[2 * layer] + e_idx = 2 * (layer + 1) + e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx + avg_layer_rows.append(build_rows_for_module_range(s, e)) + layer += 1 + if avg_layer_rows: + avg_rows = build_avg_rows_from_layers(avg_layer_rows, 3, "Decode") + if avg_rows is not None: + print(f"Decode avg rows: {len(avg_rows)}") # Write XLSX - write_breakdown_xlsx(output_xlsx, rows, sheet_name='decode') + write_breakdown_xlsx(output_xlsx, rows, sheet_name='decode', avg_rows=avg_rows) print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)") From abdf74bfea05efd36021c67235fc027f445a6c11 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 5 Mar 2026 03:40:12 +0000 Subject: [PATCH 5/9] fix --- atom/model_engine/model_runner.py | 3 +-- tools/parse_trace.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 703d3ab92..4cb487ca8 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -6,7 +6,6 @@ import os import time from contextlib import nullcontext -from itertools import chain, islice from typing import Any, Optional, Union import numpy as np @@ -1339,7 +1338,7 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: with ( - record_function("prefill") + record_function(f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}") if self.mark_trace else nullcontext() ): diff --git a/tools/parse_trace.py b/tools/parse_trace.py index f233abafa..fdcc8329b 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -332,10 +332,15 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - is treated as warmup and prefill[1] is the actual prefill. - Otherwise, prefill[0] is the actual prefill. """ - # CPU side prefill/decode annotations + # CPU side prefill/decode annotations. + # Accept both legacy "prefill" and traced variants like + # "prefill_bs_1_ctxlens_tensor([417], ...)". prefills = [ e for e in events - if e.get('name') == 'prefill' + if ( + e.get('name') == 'prefill' + or e.get('name', '').startswith('prefill_bs_') + ) and e.get('ph') == 'X' and e.get('cat') == 'user_annotation' ] @@ -704,7 +709,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> if e.get('name', '').startswith('capture_graph') and e.get('ph') == 'X' ] capture_graphs = sorted(capture_graphs, key=lambda x: x['ts']) - print(f"No exact match, using first capture_graph") + print("No exact match, using first capture_graph") if not capture_graphs: print("No capture_graph events found.") From 7ff38035d2f9fa498dd780feeb324b5eeab2f3ff Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 5 Mar 2026 05:38:36 +0000 Subject: [PATCH 6/9] fix_black --- atom/model_engine/model_runner.py | 5 +- atom/model_ops/linear.py | 12 +- atom/utils/compiler_inferface.py | 19 +- atom/utils/decorators.py | 5 +- atom/utils/graph_marker.py | 2 - atom/utils/graph_marker_instrumentation.py | 36 +- tools/parse_trace.py | 542 ++++++++++++--------- 7 files changed, 352 insertions(+), 269 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 83b5c29fe..58772f993 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -468,6 +468,7 @@ def __init__(self, rank: int, config: Config): self.config = config self.mark_trace = getattr(config, "mark_trace", False) from atom.utils.graph_marker import set_graph_marker_enabled + set_graph_marker_enabled(self.mark_trace) set_current_atom_config(config) hf_config = config.hf_config @@ -1323,7 +1324,9 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: with ( - record_function(f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}") + record_function( + f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}" + ) if self.mark_trace else nullcontext() ): diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 56be52349..ed611edaa 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -169,10 +169,11 @@ def gemm_a8w8_blockscale_preshuffle_fake( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, - prefix: str="", + prefix: str = "", ) -> torch.Tensor: return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device) + @record_function @torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[]) def gemm_a8w8_blockscale_preshuffle_impl( @@ -181,7 +182,7 @@ def gemm_a8w8_blockscale_preshuffle_impl( x_scale: torch.Tensor, w_scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16, - prefix: str="", + prefix: str = "", ) -> torch.Tensor: if gemm_a8w8_blockscale_bpreshuffle_triton is not None: weight_shuffled = weight.reshape(weight.shape[0] // 16, weight.shape[1] * 16) @@ -427,7 +428,12 @@ def forward( y += self.bias elif self.quant_type.value == QuantType.per_1x128.value: y = gemm_a8w8_blockscale_preshuffle_impl( - x, self.weight, x_scale, self.weight_scale, dtype=otype, prefix=self.prefix + x, + self.weight, + x_scale, + self.weight_scale, + dtype=otype, + prefix=self.prefix, ) if self.bias is not None: y += self.bias diff --git a/atom/utils/compiler_inferface.py b/atom/utils/compiler_inferface.py index ebabf59cd..70cf95722 100644 --- a/atom/utils/compiler_inferface.py +++ b/atom/utils/compiler_inferface.py @@ -170,6 +170,7 @@ def set_inductor_config(config, runtime_shape): try: from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): config["size_asserts"] = False config["compile_threads"] = 1 @@ -398,21 +399,24 @@ def _get_shape_env() -> AlwaysHitShapeEnv: "failed, leading to a corrupted compilation artifact. " "We recommend trying to " "remove ~/.cache/vllm/torch_compile_cache and try again " - "to see the real issue. ") - assert file_path is not None, ( - "failed to get the file path of the compiled graph") + "to see the real issue. " + ) + assert ( + file_path is not None + ), "failed to get the file path of the compiled graph" # Best-effort post-process the generated wrapper file too (PyTorch <2.8 path). try: # Only run post-processing when mark-trace is enabled (to avoid any # overhead / file churn in default runs). from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): # Local import to avoid extra package-level side effects. from .graph_marker_instrumentation import ( instrument_record_functions_in_file, ) - instrument_record_functions_in_file(file_path, - strip_markers=False) + + instrument_record_functions_in_file(file_path, strip_markers=False) except Exception: pass return compiled_graph, (hash_str, file_path) @@ -586,13 +590,14 @@ def compile( # Only run post-processing when mark-trace is enabled (to avoid any # overhead / file churn in default runs). from atom.utils.graph_marker import is_graph_marker_enabled + if is_graph_marker_enabled(): # Local import to avoid extra package-level side effects. from .graph_marker_instrumentation import ( instrument_record_functions_in_dir, ) - instrument_record_functions_in_dir(path, - strip_markers=False) + + instrument_record_functions_in_dir(path, strip_markers=False) except Exception: # Best-effort: never fail compilation due to instrumentation. pass diff --git a/atom/utils/decorators.py b/atom/utils/decorators.py index f2be89785..8fa916218 100644 --- a/atom/utils/decorators.py +++ b/atom/utils/decorators.py @@ -19,6 +19,7 @@ from atom.config import CompilationConfig, Config, CompilationLevel from atom.utils.graph_marker import graph_marker + # from atom.utils import start_monitoring_torch_compile _T = TypeVar("_T", bound=type[nn.Module]) @@ -49,6 +50,7 @@ def _decorate(func: Callable): def _wrapped(*args, **kwargs): # Keep this decorator no-op unless mark-trace is enabled. from atom.utils.graph_marker import is_graph_marker_enabled + if not is_graph_marker_enabled(): return func(*args, **kwargs) @@ -137,7 +139,7 @@ def wrapped_forward(self, *args, **kwargs): # When mark-trace is disabled, bypass all wrapping logic entirely if not is_graph_marker_enabled(): return forward(self, *args, **kwargs) - + prefix = getattr(self, "prefix", cls.__name__) # Mark only the first tensor across args/kwargs, keeping names stable. args_l = list(args) @@ -162,6 +164,7 @@ def wrapped_forward(self, *args, **kwargs): cls.forward = wrapped_forward return cls + # We remove it from utils/__init__.py to avoid circular import def start_monitoring_torch_compile(vllm_config: Config): global torch_compile_start_time diff --git a/atom/utils/graph_marker.py b/atom/utils/graph_marker.py index 1f71a98c9..d9caf0b9c 100644 --- a/atom/utils/graph_marker.py +++ b/atom/utils/graph_marker.py @@ -46,5 +46,3 @@ def graph_marker(x: torch.Tensor, name: str) -> torch.Tensor: if not _GRAPH_MARKER_ENABLED: return x return _graph_marker_impl(x) - - diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py index ba2170cde..8e8dd4e51 100644 --- a/atom/utils/graph_marker_instrumentation.py +++ b/atom/utils/graph_marker_instrumentation.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import Iterable, Optional - _SUBGRAPH_ID_RE = re.compile(r"artifact_shape_[^/]+_subgraph_(\d+)") _ASSIGNMENT_RE = re.compile( @@ -107,12 +106,7 @@ def _split_top_level_args(s: str) -> list[str]: if ch == "}": depth_brace -= 1 continue - if ( - ch == "," - and depth_paren == 0 - and depth_bracket == 0 - and depth_brace == 0 - ): + if ch == "," and depth_paren == 0 and depth_bracket == 0 and depth_brace == 0: out.append(s[start:i].strip()) start = i + 1 out.append(s[start:].strip()) @@ -194,8 +188,7 @@ def _iter_py_files(root: str) -> Iterable[str]: def _ensure_record_function_import(lines: list[str]) -> None: # If already imported or referenced via qualified name, do nothing. if any( - ("record_function" in l and ("import" in l or "from torch" in l)) - for l in lines + ("record_function" in l and ("import" in l or "from torch" in l)) for l in lines ): return @@ -240,7 +233,9 @@ def _prefix_and_kind(name: str) -> Optional[tuple[str, str]]: return None -def _already_wrapped(lines: list[str], indent: str, prefix: str, start_idx: int, end_idx: int) -> bool: +def _already_wrapped( + lines: list[str], indent: str, prefix: str, start_idx: int, end_idx: int +) -> bool: needle = f'{indent}with record_function("{prefix}"):\n' for i in range(start_idx, min(end_idx, len(lines))): if lines[i] == needle: @@ -283,7 +278,9 @@ def _wrap_region_with_record_function( # If we already inserted a record_function line in a previous run, upgrade it # in-place (e.g. "mlp" -> "layer_0_mlp") and exit without touching indentation. - if insert_at < len(lines) and lines[insert_at].startswith(f"{indent}with record_function("): + if insert_at < len(lines) and lines[insert_at].startswith( + f"{indent}with record_function(" + ): if lines[insert_at] != with_line: lines[insert_at] = with_line return @@ -301,7 +298,7 @@ def _wrap_region_with_record_function( if line.strip() == "": continue if line.startswith(indent_prefix): - lines[i] = indent_prefix + extra + line[len(indent_prefix):] + lines[i] = indent_prefix + extra + line[len(indent_prefix) :] def _layer_id_from_wrapper_path(path: str) -> Optional[int]: @@ -338,9 +335,8 @@ def _strip_runtime_graph_markers(lines: list[str]) -> bool: continue if ( - ("assert_size_stride" in line or "assert_alignment" in line) - and "torch.ops.aiter.graph_marker.default" in line - ): + "assert_size_stride" in line or "assert_alignment" in line + ) and "torch.ops.aiter.graph_marker.default" in line: changed = True continue @@ -351,7 +347,9 @@ def _strip_runtime_graph_markers(lines: list[str]) -> bool: return changed -def instrument_record_functions_in_file(path: str, *, strip_markers: bool = True) -> bool: +def instrument_record_functions_in_file( + path: str, *, strip_markers: bool = True +) -> bool: """ Returns True if the file was modified. """ @@ -391,7 +389,9 @@ def instrument_record_functions_in_file(path: str, *, strip_markers: bool = True # Apply from bottom to top so indices stay valid. wrapped_or_upgraded = False if has_intervals: - for start_mk, end_mk, prefix in sorted(intervals, key=lambda t: t[0].idx, reverse=True): + for start_mk, end_mk, prefix in sorted( + intervals, key=lambda t: t[0].idx, reverse=True + ): _wrap_region_with_record_function( lines, start_marker_idx=start_mk.idx, @@ -430,5 +430,3 @@ def instrument_record_functions_in_dir(root: str, *, strip_markers: bool = True) if instrument_record_functions_in_file(fp, strip_markers=strip_markers): changed += 1 return changed - - diff --git a/tools/parse_trace.py b/tools/parse_trace.py index fdcc8329b..a9407cb2d 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -16,24 +16,27 @@ from openpyxl import Workbook # Modules to filter out (no corresponding GPU kernel in decode) -FILTER_OUT = ['fill_'] +FILTER_OUT = ["fill_"] # Sampling-related modules and low-level ops to filter out in prefill -FILTER_OUT_PREFILL = ['aten::', 'aiter::gemm_a16w16', 'aiter::mixed_sample'] +FILTER_OUT_PREFILL = ["aten::", "aiter::gemm_a16w16", "aiter::mixed_sample"] # ============================================================================= # Utility Functions # ============================================================================= + def load_trace(filepath: str) -> Dict[str, Any]: """Load trace JSON file (supports .gz).""" - opener = gzip.open if filepath.endswith('.gz') else open - with opener(filepath, 'rt', encoding='utf-8') as f: + opener = gzip.open if filepath.endswith(".gz") else open + with opener(filepath, "rt", encoding="utf-8") as f: return json.load(f) -def is_within(child_ts: float, child_dur: float, parent_ts: float, parent_dur: float) -> bool: +def is_within( + child_ts: float, child_dur: float, parent_ts: float, parent_dur: float +) -> bool: """Check if child event is within parent's time range.""" return child_ts >= parent_ts and (child_ts + child_dur) <= (parent_ts + parent_dur) @@ -41,7 +44,7 @@ def is_within(child_ts: float, child_dur: float, parent_ts: float, parent_dur: f def is_kernel_launch(name: str) -> bool: """Check if name is a kernel launch (contains 'launch' and 'kernel').""" n = name.lower() - return 'launch' in n and 'kernel' in n + return "launch" in n and "kernel" in n def should_filter(name: str) -> bool: @@ -71,10 +74,16 @@ def write_breakdown_xlsx( wb = Workbook() ws = wb.active ws.title = sheet_name - ws.append([ - 'cpu_module', 'gpu_kernel', 'duration_us', 'sum per module', - 'avg duration_us', 'avg sum per module' - ]) + ws.append( + [ + "cpu_module", + "gpu_kernel", + "duration_us", + "sum per module", + "avg duration_us", + "avg sum per module", + ] + ) def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float]]: groups: List[Tuple[int, int, str, float]] = [] @@ -94,11 +103,11 @@ def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float renamed_group_mods = [g[2] for g in main_groups] seen_rmsnorm = 0 for gi, mod in enumerate(renamed_group_mods): - if isinstance(mod, str) and 'rmsnorm' in mod.lower(): + if isinstance(mod, str) and "rmsnorm" in mod.lower(): if seen_rmsnorm == 0: - renamed_group_mods[gi] = 'input_layernorm' + renamed_group_mods[gi] = "input_layernorm" elif seen_rmsnorm == 1: - renamed_group_mods[gi] = 'post_attn_layernorm' + renamed_group_mods[gi] = "post_attn_layernorm" seen_rmsnorm += 1 avg_sum_by_row: Dict[int, float] = {} @@ -113,8 +122,10 @@ def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float renamed_mod = renamed_group_mods[gi] for idx in range(start, end + 1): _, kernel, dur = rows[idx] - avg_dur = float(avg_rows[idx][2]) if avg_rows and idx < len(avg_rows) else '' - avg_sum = avg_sum_by_row.get(idx, '') + avg_dur = ( + float(avg_rows[idx][2]) if avg_rows and idx < len(avg_rows) else "" + ) + avg_sum = avg_sum_by_row.get(idx, "") ws.append([renamed_mod, kernel, dur, total, avg_dur, avg_sum]) for start, end, _, _ in main_groups: @@ -127,8 +138,8 @@ def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float ws.merge_cells(start_row=r1, start_column=6, end_row=r2, end_column=6) total_duration = sum(float(r[2]) for r in rows) if rows else 0.0 - total_avg_duration = sum(float(r[2]) for r in avg_rows) if avg_rows else '' - ws.append(['TOTAL', '', total_duration, '', total_avg_duration, '']) + total_avg_duration = sum(float(r[2]) for r in avg_rows) if avg_rows else "" + ws.append(["TOTAL", "", total_duration, "", total_avg_duration, ""]) wb.save(output_xlsx) @@ -173,35 +184,40 @@ def build_avg_rows_from_layers( return avg_rows - # ============================================================================= # Optimized Event Index for fast time-range queries # ============================================================================= + class EventIndex: """Pre-indexed events for fast time-range queries.""" - + def __init__(self, events: List[Dict]): # Filter duration events only - self.duration_events = [e for e in events if e.get('ph') == 'X'] - self.duration_events.sort(key=lambda x: x['ts']) - self.ts_list = [e['ts'] for e in self.duration_events] - + self.duration_events = [e for e in events if e.get("ph") == "X"] + self.duration_events.sort(key=lambda x: x["ts"]) + self.ts_list = [e["ts"] for e in self.duration_events] + # Pre-compute kernel launch flags and prefix sum - self._is_kernel_launch = [is_kernel_launch(e.get('name', '')) for e in self.duration_events] + self._is_kernel_launch = [ + is_kernel_launch(e.get("name", "")) for e in self.duration_events + ] self._kernel_prefix_sum = [0] for is_kl in self._is_kernel_launch: - self._kernel_prefix_sum.append(self._kernel_prefix_sum[-1] + (1 if is_kl else 0)) - + self._kernel_prefix_sum.append( + self._kernel_prefix_sum[-1] + (1 if is_kl else 0) + ) + def events_in_range(self, start_ts: float, end_ts: float) -> List[Dict]: """Get all duration events within [start_ts, end_ts].""" left = bisect.bisect_left(self.ts_list, start_ts) right = bisect.bisect_right(self.ts_list, end_ts) return [ - e for e in self.duration_events[left:right] - if e['ts'] + e.get('dur', 0) <= end_ts + e + for e in self.duration_events[left:right] + if e["ts"] + e.get("dur", 0) <= end_ts ] - + def count_kernel_launches_in_range(self, start_ts: float, end_ts: float) -> int: """Count kernel launches within time range (fast using prefix sum).""" left = bisect.bisect_left(self.ts_list, start_ts) @@ -209,52 +225,49 @@ def count_kernel_launches_in_range(self, start_ts: float, end_ts: float) -> int: count = 0 for i in range(left, right): e = self.duration_events[i] - if e['ts'] + e.get('dur', 0) <= end_ts and self._is_kernel_launch[i]: + if e["ts"] + e.get("dur", 0) <= end_ts and self._is_kernel_launch[i]: count += 1 return count - + def get_direct_children(self, parent: Dict) -> List[Dict]: """Get direct children of parent event (optimized).""" - p_ts = parent['ts'] - p_end = p_ts + parent.get('dur', 0) - + p_ts = parent["ts"] + p_end = p_ts + parent.get("dur", 0) + # Get candidates in parent's time range - candidates = [ - e for e in self.events_in_range(p_ts, p_end) - if e is not parent - ] - + candidates = [e for e in self.events_in_range(p_ts, p_end) if e is not parent] + if not candidates: return [] - + # Filter to direct children only (not nested in other candidates) # Sort by duration descending - larger events are potential parents - candidates_sorted = sorted(candidates, key=lambda x: -x.get('dur', 0)) - + candidates_sorted = sorted(candidates, key=lambda x: -x.get("dur", 0)) + direct = [] for i, c in enumerate(candidates_sorted): - c_ts, c_dur = c['ts'], c.get('dur', 0) + c_ts, c_dur = c["ts"], c.get("dur", 0) c_end = c_ts + c_dur # Check if c is nested inside any larger candidate is_nested = False for j in range(i): # Only check larger (earlier in sorted list) o = candidates_sorted[j] - o_ts = o['ts'] - o_end = o_ts + o.get('dur', 0) + o_ts = o["ts"] + o_end = o_ts + o.get("dur", 0) if c_ts >= o_ts and c_end <= o_end: is_nested = True break if not is_nested: direct.append(c) - - return sorted(direct, key=lambda x: x['ts']) - + + return sorted(direct, key=lambda x: x["ts"]) + def count_kernel_launches(self, event: Dict) -> int: """Count kernel launches within event's time range.""" - e_ts = event['ts'] - e_end = e_ts + event.get('dur', 0) + e_ts = event["ts"] + e_end = e_ts + event.get("dur", 0) return self.count_kernel_launches_in_range(e_ts, e_end) - + def has_kernel_launch(self, event: Dict) -> bool: """Check if event contains any kernel launch.""" return self.count_kernel_launches(event) > 0 @@ -264,51 +277,61 @@ def has_kernel_launch(self, event: Dict) -> bool: # Legacy functions (for prefill compatibility) # ============================================================================= + def find_events(events: List[Dict], name: str, prefix: bool = False) -> List[Dict]: """Find all duration events (ph='X') with given name, sorted by time.""" if prefix: - result = [e for e in events if e.get('name', '').startswith(name) and e.get('ph') == 'X'] + result = [ + e + for e in events + if e.get("name", "").startswith(name) and e.get("ph") == "X" + ] else: - result = [e for e in events if e.get('name') == name and e.get('ph') == 'X'] - return sorted(result, key=lambda x: x['ts']) + result = [e for e in events if e.get("name") == name and e.get("ph") == "X"] + return sorted(result, key=lambda x: x["ts"]) def get_gpu_kernels(events: List[Dict], start_ts: float) -> List[Dict]: """Get GPU kernels (cat='kernel') starting from given timestamp.""" - result = [e for e in events if e.get('cat') == 'kernel' and e['ts'] >= start_ts] - return sorted(result, key=lambda x: x['ts']) + result = [e for e in events if e.get("cat") == "kernel" and e["ts"] >= start_ts] + return sorted(result, key=lambda x: x["ts"]) def get_direct_children(parent: Dict, events: List[Dict]) -> List[Dict]: """Get direct children of parent event (excluding nested children).""" - p_ts, p_dur = parent['ts'], parent.get('dur', 0) - + p_ts, p_dur = parent["ts"], parent.get("dur", 0) + candidates = [ - e for e in events - if e.get('ph') == 'X' and e is not parent - and is_within(e.get('ts', 0), e.get('dur', 0), p_ts, p_dur) + e + for e in events + if e.get("ph") == "X" + and e is not parent + and is_within(e.get("ts", 0), e.get("dur", 0), p_ts, p_dur) ] - + direct = [] for c in candidates: - c_ts, c_dur = c['ts'], c.get('dur', 0) + c_ts, c_dur = c["ts"], c.get("dur", 0) is_direct = not any( - is_within(c_ts, c_dur, o['ts'], o.get('dur', 0)) - for o in candidates if o is not c + is_within(c_ts, c_dur, o["ts"], o.get("dur", 0)) + for o in candidates + if o is not c ) if is_direct: direct.append(c) - - return sorted(direct, key=lambda x: x['ts']) + + return sorted(direct, key=lambda x: x["ts"]) def count_kernel_launches(event: Dict, events: List[Dict]) -> int: """Count kernel launches within event's subtree.""" - e_ts, e_dur = event['ts'], event.get('dur', 0) + e_ts, e_dur = event["ts"], event.get("dur", 0) return sum( - 1 for e in events - if e.get('ph') == 'X' and is_kernel_launch(e.get('name', '')) - and is_within(e.get('ts', 0), e.get('dur', 0), e_ts, e_dur) + 1 + for e in events + if e.get("ph") == "X" + and is_kernel_launch(e.get("name", "")) + and is_within(e.get("ts", 0), e.get("dur", 0), e_ts, e_dur) ) @@ -321,6 +344,7 @@ def has_kernel_launch(event: Dict, events: List[Dict]) -> bool: # Parse Functions # ============================================================================= + def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: """ Parse prefill phase: find the actual prefill event on CPU trace (user_annotation). @@ -336,19 +360,17 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # Accept both legacy "prefill" and traced variants like # "prefill_bs_1_ctxlens_tensor([417], ...)". prefills = [ - e for e in events - if ( - e.get('name') == 'prefill' - or e.get('name', '').startswith('prefill_bs_') - ) - and e.get('ph') == 'X' - and e.get('cat') == 'user_annotation' + e + for e in events + if (e.get("name") == "prefill" or e.get("name", "").startswith("prefill_bs_")) + and e.get("ph") == "X" + and e.get("cat") == "user_annotation" ] - prefills = sorted(prefills, key=lambda x: x['ts']) + prefills = sorted(prefills, key=lambda x: x["ts"]) if not prefills: print("No prefill (user_annotation) events found.") - write_breakdown_xlsx(output_xlsx, [], sheet_name='prefill') + write_breakdown_xlsx(output_xlsx, [], sheet_name="prefill") return actual_prefill_idx = 0 @@ -358,15 +380,15 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - if len(prefills) >= 2: first = prefills[0] second = prefills[1] - gap_start = first['ts'] + first.get('dur', 0) - gap_end = second['ts'] + gap_start = first["ts"] + first.get("dur", 0) + gap_end = second["ts"] # If decode_step_bs appears in [gap_start, gap_end], first prefill is warmup. has_decode_between = any( - e.get('ph') == 'X' - and e.get('cat') == 'user_annotation' - and e.get('name', '').startswith('decode_step_bs') - and gap_start <= e.get('ts', 0) <= gap_end + e.get("ph") == "X" + and e.get("cat") == "user_annotation" + and e.get("name", "").startswith("decode_step_bs") + and gap_start <= e.get("ts", 0) <= gap_end for e in events ) if has_decode_between: @@ -386,13 +408,14 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # Build prefill hierarchy on the same thread as the selected CPU prefill. # Using thread affinity is more robust than category-only filtering. - prefill_tid = actual_prefill.get('tid') - prefill_pid = actual_prefill.get('pid') + prefill_tid = actual_prefill.get("tid") + prefill_pid = actual_prefill.get("pid") prefill_hierarchy_events = [ - e for e in events - if e.get('ph') == 'X' - and e.get('tid') == prefill_tid - and e.get('pid') == prefill_pid + e + for e in events + if e.get("ph") == "X" + and e.get("tid") == prefill_tid + and e.get("pid") == prefill_pid ] # Build index once for fast subtree queries in prefill parsing. prefill_idx = EventIndex(prefill_hierarchy_events) @@ -405,17 +428,18 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # Keep only level2 children that have kernel launch in their subtree. launch_level2_items = [] for l1 in level1_children: - l1_name = l1.get('name', '') + l1_name = l1.get("name", "") level2_children = prefill_idx.get_direct_children(l1) level2_with_launch = [ - l2 for l2 in level2_children - if prefill_idx.has_kernel_launch(l2) + l2 for l2 in level2_children if prefill_idx.has_kernel_launch(l2) ] for l2 in level2_with_launch: - launch_level2_items.append({ - 'level1_name': l1_name, - 'level2_event': l2, - }) + launch_level2_items.append( + { + "level1_name": l1_name, + "level2_event": l2, + } + ) print(f"Level2 children with kernel launch: {len(launch_level2_items)}") @@ -423,8 +447,9 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # each layer has 2 rmsnorm modules, layer N starts at rmsnorm index 2*N (0-based). TARGET_LAYER = target_layer all_norm_indices = [ - i for i, item in enumerate(launch_level2_items) - if 'rmsnorm' in item['level2_event'].get('name', '').lower() + i + for i, item in enumerate(launch_level2_items) + if "rmsnorm" in item["level2_event"].get("name", "").lower() ] # Last rmsnorm is final layernorm, not part of transformer layers. norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] @@ -437,12 +462,20 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - mod_end = 0 norm_start_idx = TARGET_LAYER * 2 norm_end_idx = (TARGET_LAYER + 1) * 2 - final_norm_idx = all_norm_indices[-1] if len(all_norm_indices) > 0 else len(launch_level2_items) + final_norm_idx = ( + all_norm_indices[-1] if len(all_norm_indices) > 0 else len(launch_level2_items) + ) if norm_start_idx >= len(norm_indices): - print(f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX") + print( + f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX" + ) else: mod_start = norm_indices[norm_start_idx] - mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else final_norm_idx + mod_end = ( + norm_indices[norm_end_idx] + if norm_end_idx < len(norm_indices) + else final_norm_idx + ) print( f"Layer {TARGET_LAYER} range by rmsnorm: " f"rows [{mod_start}:{mod_end}) from rmsnorm #{norm_start_idx+1} to #{norm_end_idx+1}" @@ -452,25 +485,26 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # Build launch->kernel mapping by correlation id. # Build launch candidates from current prefill thread/range once. runtime_launches = [ - e for e in prefill_hierarchy_events - if e.get('cat') == 'cuda_runtime' and is_kernel_launch(e.get('name', '')) + e + for e in prefill_hierarchy_events + if e.get("cat") == "cuda_runtime" and is_kernel_launch(e.get("name", "")) ] - runtime_launches.sort(key=lambda x: x.get('ts', 0)) - runtime_launch_ts = [e.get('ts', 0) for e in runtime_launches] + runtime_launches.sort(key=lambda x: x.get("ts", 0)) + runtime_launch_ts = [e.get("ts", 0) for e in runtime_launches] item_corrs: List[List[Any]] = [] corr_needed = set() for item in launch_level2_items: - l2 = item['level2_event'] - l2_start = l2.get('ts', 0) - l2_end = l2_start + l2.get('dur', 0) + l2 = item["level2_event"] + l2_start = l2.get("ts", 0) + l2_end = l2_start + l2.get("dur", 0) left = bisect.bisect_left(runtime_launch_ts, l2_start) right = bisect.bisect_right(runtime_launch_ts, l2_end) launches_in_l2 = runtime_launches[left:right] curr_corrs = [] for launch in launches_in_l2: - corr = (launch.get('args') or {}).get('correlation') + corr = (launch.get("args") or {}).get("correlation") if corr is not None: corr_needed.add(corr) curr_corrs.append(corr) @@ -480,42 +514,48 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - kernel_by_corr: Dict[Any, List[Dict]] = {} if corr_needed: for e in events: - if e.get('ph') != 'X' or e.get('cat') != 'kernel': + if e.get("ph") != "X" or e.get("cat") != "kernel": continue - corr = (e.get('args') or {}).get('correlation') + corr = (e.get("args") or {}).get("correlation") if corr is None or corr not in corr_needed: continue kernel_by_corr.setdefault(corr, []).append(e) for corr in kernel_by_corr: - kernel_by_corr[corr].sort(key=lambda x: x.get('ts', 0)) + kernel_by_corr[corr].sort(key=lambda x: x.get("ts", 0)) item_kernels: List[List[Dict[str, Any]]] = [] for corrs in item_corrs: kernels = [] for corr in corrs: for k in kernel_by_corr.get(corr, []): - kernels.append({'name': k.get('name', 'N/A'), 'dur': k.get('dur', 0)}) + kernels.append({"name": k.get("name", "N/A"), "dur": k.get("dur", 0)}) item_kernels.append(kernels) def build_rows_from_item_range(start: int, end: int) -> List[List[Any]]: rows = [] for i in range(start, end): item = launch_level2_items[i] - mod_name = item['level2_event'].get('name', '') + mod_name = item["level2_event"].get("name", "") if should_filter_prefill(mod_name): continue - kernels = [k for k in item_kernels[i] if k['name'] not in ('', 'N/A')] + kernels = [k for k in item_kernels[i] if k["name"] not in ("", "N/A")] if not kernels: continue - if 'moe_forward' in mod_name.lower(): + if "moe_forward" in mod_name.lower(): rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) else: for k in kernels: - rows.append([clean_module_name(mod_name, k['name']), k['name'], k['dur']]) + rows.append( + [clean_module_name(mod_name, k["name"]), k["name"], k["dur"]] + ) return rows # Target layer rows. - csv_rows = build_rows_from_item_range(mod_start, mod_end) if norm_start_idx < len(norm_indices) else [] + csv_rows = ( + build_rows_from_item_range(mod_start, mod_end) + if norm_start_idx < len(norm_indices) + else [] + ) print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(csv_rows)}") print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") @@ -532,93 +572,92 @@ def build_rows_from_item_range(start: int, end: int) -> List[List[Any]]: avg_layer_rows.append(build_rows_from_item_range(s, e)) layer += 1 if avg_layer_rows: - avg_rows = build_avg_rows_from_layers(avg_layer_rows, avg_start_layer, "Prefill") + avg_rows = build_avg_rows_from_layers( + avg_layer_rows, avg_start_layer, "Prefill" + ) if avg_rows is not None: print(f"Prefill avg rows: {len(avg_rows)}") # Write XLSX for prefill. - write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name='prefill', avg_rows=avg_rows) + write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name="prefill", avg_rows=avg_rows) -def clean_module_name(name: str, mapped_kernel_name: str = '') -> str: +def clean_module_name(name: str, mapped_kernel_name: str = "") -> str: """Clean and simplify module name.""" # Runtime launch wrappers should display the actual launched operator name. - if 'hipmodulelaunchkernel' in name.lower() and mapped_kernel_name not in ('', 'N/A'): + if "hipmodulelaunchkernel" in name.lower() and mapped_kernel_name not in ( + "", + "N/A", + ): name = mapped_kernel_name # Remove 'aiter::' prefix if present - if name.startswith('aiter::'): + if name.startswith("aiter::"): name = name[7:] # len('aiter::') == 7 - + # Rename based on keywords (rope takes priority) name_lower = name.lower() - if 'rope' in name_lower and 'cache' in name_lower: - return 'rope & kv_cache' - if 'rope' in name_lower: - return 'rope' - if 'cache' in name_lower and 'gemm' not in name_lower: - return 'kv_cache' - + if "rope" in name_lower and "cache" in name_lower: + return "rope & kv_cache" + if "rope" in name_lower: + return "rope" + if "cache" in name_lower and "gemm" not in name_lower: + return "kv_cache" + return name def process_moe_module( - mod_name: str, - kernel_count: int, - start_gpu_idx: int, - gpu_kernels: List[Dict] + mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] ) -> List[List]: """ Process moe_forward module: categorize kernels by name. - + - 'moesort' in kernel name -> moe_sort - 'topk' in kernel name -> moe_topk - others -> keep original mod_name - + Returns list of [display_name, gpu_kernel_name, gpu_dur] rows. """ rows = [] for i in range(kernel_count): gpu_idx = start_gpu_idx + i - gpu_kernel_name = 'N/A' + gpu_kernel_name = "N/A" gpu_dur = 0 if gpu_idx < len(gpu_kernels): gpu = gpu_kernels[gpu_idx] - gpu_kernel_name = gpu.get('name', 'N/A') - gpu_dur = gpu.get('dur', 0) - + gpu_kernel_name = gpu.get("name", "N/A") + gpu_dur = gpu.get("dur", 0) + # Determine category based on kernel name kernel_lower = gpu_kernel_name.lower() - if 'moesort' in kernel_lower: - category = 'moe_sort' - elif 'topk' in kernel_lower: - category = 'moe_topk' + if "moesort" in kernel_lower: + category = "moe_sort" + elif "topk" in kernel_lower: + category = "moe_topk" else: category = clean_module_name(mod_name, gpu_kernel_name) - + # Always show category/module name on each row. display_name = category rows.append([display_name, gpu_kernel_name, gpu_dur]) - + return rows def process_regular_module( - mod_name: str, - kernel_count: int, - start_gpu_idx: int, - gpu_kernels: List[Dict] + mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] ) -> List[List]: """Process regular module and show module name on every row.""" rows = [] for i in range(kernel_count): gpu_idx = start_gpu_idx + i - gpu_kernel_name = 'N/A' + gpu_kernel_name = "N/A" gpu_dur = 0 if gpu_idx < len(gpu_kernels): gpu = gpu_kernels[gpu_idx] - gpu_kernel_name = gpu.get('name', 'N/A') - gpu_dur = gpu.get('dur', 0) + gpu_kernel_name = gpu.get("name", "N/A") + gpu_dur = gpu.get("dur", 0) display_name = clean_module_name(mod_name, gpu_kernel_name) rows.append([display_name, gpu_kernel_name, gpu_dur]) return rows @@ -627,69 +666,71 @@ def process_regular_module( def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: """ Parse decode phase: map capture_graph modules to GPU kernels. - + Output CSV columns: cpu_module, gpu_kernel, duration_us """ print("Building event index...") - + # Find GPU-annotated decode_step events (cat='gpu_user_annotation') decode_steps = [ - e for e in events - if e.get('name', '').startswith('decode_step') - and e.get('ph') == 'X' - and e.get('cat') == 'gpu_user_annotation' + e + for e in events + if e.get("name", "").startswith("decode_step") + and e.get("ph") == "X" + and e.get("cat") == "gpu_user_annotation" ] - decode_steps = sorted(decode_steps, key=lambda x: x['ts']) - + decode_steps = sorted(decode_steps, key=lambda x: x["ts"]) + if not decode_steps: print("No decode_step (gpu_user_annotation) events found.") return - + # Skip warmup: find first gap > 100ms (warmup/run boundary) # Normal decode gaps are < 5ms, so 100ms is safe threshold WARMUP_GAP_THRESHOLD = 100000 # 100ms in microseconds actual_run_idx = 0 found_warmup_boundary = False for i in range(1, len(decode_steps)): - gap = decode_steps[i]['ts'] - (decode_steps[i-1]['ts'] + decode_steps[i-1].get('dur', 0)) + gap = decode_steps[i]["ts"] - ( + decode_steps[i - 1]["ts"] + decode_steps[i - 1].get("dur", 0) + ) if gap > WARMUP_GAP_THRESHOLD: actual_run_idx = i found_warmup_boundary = True print(f"Warmup/run boundary at [{i-1}]->[{i}], gap={gap/1000:.1f}ms") break - + if not found_warmup_boundary: print("No warmup detected (no gap > 100ms), using first decode_step") - + first_ds = decode_steps[actual_run_idx] - first_ds_name = first_ds.get('name', '') + first_ds_name = first_ds.get("name", "") target_bs: Optional[int] = None - if '_bs_' in first_ds_name: - bs = first_ds_name.split('_bs_')[-1] - target_cg_name = f'capture_graph_bs_{bs}' + if "_bs_" in first_ds_name: + bs = first_ds_name.split("_bs_")[-1] + target_cg_name = f"capture_graph_bs_{bs}" try: target_bs = int(bs) except ValueError: target_bs = None else: - target_cg_name = 'capture_graph' - + target_cg_name = "capture_graph" + print(f"First decode_step: {first_ds_name}") print(f"Looking for: {target_cg_name}") - + # Find matching capture_graph capture_graphs = [ - e for e in events - if e.get('name') == target_cg_name and e.get('ph') == 'X' + e for e in events if e.get("name") == target_cg_name and e.get("ph") == "X" ] if not capture_graphs and target_bs is not None: # Prefer the largest capture_graph_bs_K where K < target_bs. lower_bs_candidates: List[Tuple[int, Dict[str, Any]]] = [] for e in events: - if e.get('ph') != 'X': + if e.get("ph") != "X": continue - n = e.get('name', '') - m = re.match(r'^capture_graph_bs_(\d+)$', n) + n = e.get("name", "") + m = re.match(r"^capture_graph_bs_(\d+)$", n) if not m: continue k = int(m.group(1)) @@ -699,109 +740,133 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> best_bs = max(k for k, _ in lower_bs_candidates) capture_graphs = sorted( [e for k, e in lower_bs_candidates if k == best_bs], - key=lambda x: x.get('ts', 0), + key=lambda x: x.get("ts", 0), ) print(f"No exact match, using nearest lower capture_graph_bs_{best_bs}") if not capture_graphs: # Fallback: find any capture_graph capture_graphs = [ - e for e in events - if e.get('name', '').startswith('capture_graph') and e.get('ph') == 'X' + e + for e in events + if e.get("name", "").startswith("capture_graph") and e.get("ph") == "X" ] - capture_graphs = sorted(capture_graphs, key=lambda x: x['ts']) + capture_graphs = sorted(capture_graphs, key=lambda x: x["ts"]) print("No exact match, using first capture_graph") - + if not capture_graphs: print("No capture_graph events found.") return - + cg = capture_graphs[0] print(f"Using: {cg.get('name')}") - + # Build optimized index only for capture_graph's time range - cg_start = cg['ts'] - cg_end = cg_start + cg.get('dur', 0) + cg_start = cg["ts"] + cg_end = cg_start + cg.get("dur", 0) cg_events = [ - e for e in events - if e.get('ph') == 'X' and e.get('ts', 0) >= cg_start - and e.get('ts', 0) + e.get('dur', 0) <= cg_end + e + for e in events + if e.get("ph") == "X" + and e.get("ts", 0) >= cg_start + and e.get("ts", 0) + e.get("dur", 0) <= cg_end ] print(f"Events in capture_graph: {len(cg_events)}") idx = EventIndex(cg_events) - + # Get GPU kernels from first decode_step (within its duration) - ds1_start = first_ds['ts'] - ds1_end = ds1_start + first_ds.get('dur', 0) - + ds1_start = first_ds["ts"] + ds1_end = ds1_start + first_ds.get("dur", 0) + gpu_kernels = [ - e for e in events - if e.get('cat') == 'kernel' and ds1_start <= e['ts'] <= ds1_end + e + for e in events + if e.get("cat") == "kernel" and ds1_start <= e["ts"] <= ds1_end ] - gpu_kernels = sorted(gpu_kernels, key=lambda x: x['ts']) + gpu_kernels = sorted(gpu_kernels, key=lambda x: x["ts"]) print(f"First decode_step (tid={first_ds.get('tid')}): {first_ds_name}") - print(f" Range: {ds1_start:.0f} ~ {ds1_end:.0f} (dur={first_ds.get('dur', 0):.0f})") + print( + f" Range: {ds1_start:.0f} ~ {ds1_end:.0f} (dur={first_ds.get('dur', 0):.0f})" + ) print(f" GPU kernels: {len(gpu_kernels)}") - + # Use optimized index for children lookup direct_children = idx.get_direct_children(cg) kernel_children = [c for c in direct_children if idx.has_kernel_launch(c)] print(f"Direct children with kernels: {len(kernel_children)}") - + # Collect all modules with their kernel info all_modules = [] # list of (mod_name, kernel_count, start_gpu_idx) gpu_idx = 0 - + for child in kernel_children: - child_name = child.get('name', '') + child_name = child.get("name", "") if should_filter(child_name): continue - + # Get sub-children (actual module names) sub_children = idx.get_direct_children(child) sub_kernel_children = [sc for sc in sub_children if idx.has_kernel_launch(sc)] - + # Determine modules to process modules = sub_kernel_children if sub_kernel_children else [child] - + for mod in modules: - mod_name = mod.get('name', '') + mod_name = mod.get("name", "") kernel_count = idx.count_kernel_launches(mod) all_modules.append((mod_name, kernel_count, gpu_idx)) gpu_idx += kernel_count - + # Find norm positions (rmsnorm in name) - all_norm_indices = [i for i, (name, _, _) in enumerate(all_modules) if 'rmsnorm' in name.lower()] + all_norm_indices = [ + i for i, (name, _, _) in enumerate(all_modules) if "rmsnorm" in name.lower() + ] # Last rmsnorm is final layernorm, not part of transformer layers. norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] print( f"Found {len(all_norm_indices)} norm modules " f"({len(norm_indices)} used for layer split, excluding final layernorm)" ) - + # Extract layer 3 (4th layer, 0-indexed) # Each layer has 2 norms, so layer N starts at norm index 2*N TARGET_LAYER = target_layer norm_start_idx = TARGET_LAYER * 2 # 6 (7th norm, 0-indexed) norm_end_idx = (TARGET_LAYER + 1) * 2 # 8 (9th norm, 0-indexed) - - final_norm_idx = all_norm_indices[-1] if len(all_norm_indices) > 0 else len(all_modules) + + final_norm_idx = ( + all_norm_indices[-1] if len(all_norm_indices) > 0 else len(all_modules) + ) if norm_start_idx >= len(norm_indices): print(f"Not enough norms for layer {TARGET_LAYER}") return - + # Module range for layer 3: from norm_indices[6] to norm_indices[8] (exclusive) mod_start = norm_indices[norm_start_idx] - mod_end = norm_indices[norm_end_idx] if norm_end_idx < len(norm_indices) else final_norm_idx - - print(f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})") - + mod_end = ( + norm_indices[norm_end_idx] + if norm_end_idx < len(norm_indices) + else final_norm_idx + ) + + print( + f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})" + ) + def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]: rows = [] for mod_name, kernel_count, start_gpu_idx in all_modules[start:end]: - if 'moe_forward' in mod_name.lower(): - rows.extend(process_moe_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + if "moe_forward" in mod_name.lower(): + rows.extend( + process_moe_module( + mod_name, kernel_count, start_gpu_idx, gpu_kernels + ) + ) else: - rows.extend(process_regular_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels)) + rows.extend( + process_regular_module( + mod_name, kernel_count, start_gpu_idx, gpu_kernels + ) + ) return rows # Target layer rows. @@ -821,10 +886,10 @@ def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]: avg_rows = build_avg_rows_from_layers(avg_layer_rows, 3, "Decode") if avg_rows is not None: print(f"Decode avg rows: {len(avg_rows)}") - + # Write XLSX - write_breakdown_xlsx(output_xlsx, rows, sheet_name='decode', avg_rows=avg_rows) - + write_breakdown_xlsx(output_xlsx, rows, sheet_name="decode", avg_rows=avg_rows) + print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)") @@ -833,34 +898,39 @@ def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]: # Main # ============================================================================= + def main(): - parser = argparse.ArgumentParser(description="Parse PyTorch profiler trace JSON to extract kernel information.") - parser.add_argument('filepath', help='Path to trace JSON or JSON.GZ file') - parser.add_argument('--layer', type=int, default=3, help='Target layer index (default: 3)') + parser = argparse.ArgumentParser( + description="Parse PyTorch profiler trace JSON to extract kernel information." + ) + parser.add_argument("filepath", help="Path to trace JSON or JSON.GZ file") + parser.add_argument( + "--layer", type=int, default=3, help="Target layer index (default: 3)" + ) args = parser.parse_args() if args.layer < 0: print("--layer must be >= 0") sys.exit(1) - + filepath = args.filepath target_layer = args.layer - + print(f"Loading: {filepath}") trace = load_trace(filepath) - events = trace.get('traceEvents', []) + events = trace.get("traceEvents", []) print(f"Loaded {len(events)} events\n") - + print("=" * 60) print("PREFILL ANALYSIS") print("=" * 60) - parse_prefill(events, 'prefill_breakdown.xlsx', target_layer=target_layer) - + parse_prefill(events, "prefill_breakdown.xlsx", target_layer=target_layer) + print("\n" + "=" * 60) print("DECODE ANALYSIS") print("=" * 60) - parse_decode(events, 'decode_breakdown.xlsx', target_layer=target_layer) + parse_decode(events, "decode_breakdown.xlsx", target_layer=target_layer) -if __name__ == '__main__': +if __name__ == "__main__": main() From f44e2e7c97bc363f9346f13bbc9ecb66cbf6da07 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 5 Mar 2026 05:41:46 +0000 Subject: [PATCH 7/9] fix ruff --- atom/utils/graph_marker_instrumentation.py | 5 +++-- tools/parse_trace.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py index 8e8dd4e51..ba97f8f40 100644 --- a/atom/utils/graph_marker_instrumentation.py +++ b/atom/utils/graph_marker_instrumentation.py @@ -188,7 +188,8 @@ def _iter_py_files(root: str) -> Iterable[str]: def _ensure_record_function_import(lines: list[str]) -> None: # If already imported or referenced via qualified name, do nothing. if any( - ("record_function" in l and ("import" in l or "from torch" in l)) for l in lines + ("record_function" in line and ("import" in line or "from torch" in line)) + for line in lines ): return @@ -407,7 +408,7 @@ def instrument_record_functions_in_file( # Only strip marker calls if we either: # - wrapped/upgraded this run, or # - the file already contains record_function blocks (previous run) - already_has_record = any("with record_function(" in l for l in lines) + already_has_record = any("with record_function(" in line for line in lines) if wrapped_or_upgraded or already_has_record: stripped = _strip_runtime_graph_markers(lines) diff --git a/tools/parse_trace.py b/tools/parse_trace.py index a9407cb2d..153ce112c 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -179,7 +179,7 @@ def build_avg_rows_from_layers( for i, (mod, kernel) in enumerate(base_sig): # Keep original module display style from layer_start_idx rows. display_mod = base[i][0] - avg_dur = sum(float(layer_rows_list[l][i][2]) for l in range(n)) / n + avg_dur = sum(float(layer_rows_list[layer_idx][i][2]) for layer_idx in range(n)) / n avg_rows.append([display_mod, kernel, avg_dur]) return avg_rows From 6ace63fbf905202215635f724f63007d5ab9b046 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 5 Mar 2026 05:52:32 +0000 Subject: [PATCH 8/9] fix black --- tools/parse_trace.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/parse_trace.py b/tools/parse_trace.py index 153ce112c..ba0ff04ba 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -179,7 +179,9 @@ def build_avg_rows_from_layers( for i, (mod, kernel) in enumerate(base_sig): # Keep original module display style from layer_start_idx rows. display_mod = base[i][0] - avg_dur = sum(float(layer_rows_list[layer_idx][i][2]) for layer_idx in range(n)) / n + avg_dur = ( + sum(float(layer_rows_list[layer_idx][i][2]) for layer_idx in range(n)) / n + ) avg_rows.append([display_mod, kernel, avg_dur]) return avg_rows From 6c720500c261864a139580ace8c53ac3fb4b6cc7 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 5 Mar 2026 05:55:20 +0000 Subject: [PATCH 9/9] reformat --- atom/model_engine/model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index acdbc7b76..5e379246c 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1566,7 +1566,9 @@ def capture_cudagraph(self): input_ids[:num_tokens], positions[:num_tokens] ) if self.logits_in_graph: - graph_logits = self.model.compute_logits(outputs[:num_tokens]) + graph_logits = self.model.compute_logits( + outputs[:num_tokens] + ) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[(bs, max_q_len)] = graph