Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ class Config:
default_factory=lambda: QuantizationConfig()
)
asyncio_mode: bool = False
mark_trace: bool = False
load_dummy: bool = False
enable_expert_parallel: bool = False
master_addr: str = "127.0.0.1"
Expand Down
6 changes: 6 additions & 0 deletions atom/model_engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class EngineArgs:
enable_dp_attention: bool = False
method: Optional[str] = None
num_speculative_tokens: int = 1
mark_trace: bool = False

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -170,6 +171,11 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help="Apply a delay (of delay factor multiplied by previous"
"prompt latency) before scheduling next prompt.",
)
parser.add_argument(
"--mark-trace",
action="store_true",
help="Enable graph_marker nodes for tracing/profile instrumentation.",
)

return parser

Expand Down
6 changes: 5 additions & 1 deletion atom/model_engine/engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, config: Config, input_address: str, output_address: str):
self.input_thread.start()

self.profile_enbaled = config.torch_profiler_dir is not None
self.mark_trace = getattr(config, "mark_trace", False)
init_exit_handler(self)
self._init_data_parallel(config)

Expand All @@ -83,6 +84,9 @@ def __init__(self, config: Config, input_address: str, output_address: str):

config.num_kvcache_blocks = num_blocks
if not config.enforce_eager:
# Start profiler before cudagraph capture only if mark-trace is enabled.
if self.profile_enbaled and self.mark_trace:
self.runner_mgr.call_func("start_profiler", wait_out=True)
cap_cost, bs = self.runner_mgr.call_func(
"capture_cudagraph", wait_out=True
)
Expand Down Expand Up @@ -284,7 +288,7 @@ def process_output_sockets(self, output_address: str):

def start_profiler(self):
if self.profile_enbaled:
self.runner_mgr.call_func("start_profiler")
self.runner_mgr.call_func("start_profiler", wait_out=True)

def stop_profiler(self):
if self.profile_enbaled:
Expand Down
62 changes: 44 additions & 18 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import os
import time
from contextlib import nullcontext
from typing import Any, Optional, Union

import numpy as np
Expand All @@ -19,6 +20,7 @@
graph_capture,
)
from aiter.dist.utils import get_distributed_init_method
from torch.profiler import record_function
from atom.config import Config, KVCacheTensor, set_current_atom_config
from atom.model_engine.scheduler import ScheduledBatch, ScheduledBatchOutput
from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType
Expand Down Expand Up @@ -464,6 +466,10 @@ class ModelRunner:

def __init__(self, rank: int, config: Config):
self.config = config
self.mark_trace = getattr(config, "mark_trace", False)
from atom.utils.graph_marker import set_graph_marker_enabled

set_graph_marker_enabled(self.mark_trace)
set_current_atom_config(config)
hf_config = config.hf_config
self.block_size = config.kv_cache_block_size
Expand Down Expand Up @@ -681,6 +687,7 @@ def start_profiler(self):
),
)
self.profiler.__enter__()
return True

def stop_profiler(self):
"""Stop profiling for this rank"""
Expand Down Expand Up @@ -1316,19 +1323,31 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
is_prefill = context.is_prefill
positions = context.positions
if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]:
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
else:
graph_bs = context.graph_bs
max_q_len = forward_context.attn_metadata.max_seqlen_q
graph_key = (graph_bs, max_q_len)
self.graphs[graph_key].replay()
num_tokens = context.batch_size * max_q_len
hidden_states = self.forward_vars["outputs"][:num_tokens]
if self.logits_in_graph:
logits = self.graph_logits[graph_key][:num_tokens]
else:
with (
record_function(
f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}"
)
if self.mark_trace
else nullcontext()
):
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
else:
with (
record_function(f"decode_step_bs_{bs}")
if self.mark_trace
else nullcontext()
):
graph_bs = context.graph_bs
max_q_len = forward_context.attn_metadata.max_seqlen_q
graph_key = (graph_bs, max_q_len)
self.graphs[graph_key].replay()
num_tokens = context.batch_size * max_q_len
hidden_states = self.forward_vars["outputs"][:num_tokens]
if self.logits_in_graph:
logits = self.graph_logits[graph_key][:num_tokens]
else:
logits = self.model.compute_logits(hidden_states)

return logits, hidden_states

Expand Down Expand Up @@ -1537,12 +1556,19 @@ def capture_cudagraph(self):
# Capture: include compute_logits only when TP=1 since
# ParallelLMHead uses NCCL all_gather which is not
# graph-capturable on HIP when TP > 1.
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
outputs[:num_tokens] = self.model(
input_ids[:num_tokens], positions[:num_tokens]
)
if self.logits_in_graph:
graph_logits = self.model.compute_logits(outputs[:num_tokens])
with (
record_function(f"capture_graph_bs_{bs}")
if self.mark_trace
else nullcontext()
):
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
outputs[:num_tokens] = self.model(
input_ids[:num_tokens], positions[:num_tokens]
)
if self.logits_in_graph:
graph_logits = self.model.compute_logits(
outputs[:num_tokens]
)
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[(bs, max_q_len)] = graph
Expand Down
20 changes: 19 additions & 1 deletion atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.tuned_gemm import tgemm
from aiter.utility import fp4_utils
from atom.utils.decorators import mark_trace, record_function

from atom.model_ops.utils import shuffle_weights
from atom.utils import envs
Expand Down Expand Up @@ -168,17 +169,20 @@ def gemm_a8w8_blockscale_preshuffle_fake(
x_scale: torch.Tensor,
w_scale: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
prefix: str = "",
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device)


@record_function
@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[])
def gemm_a8w8_blockscale_preshuffle_impl(
x: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
prefix: str = "",
) -> torch.Tensor:
if gemm_a8w8_blockscale_bpreshuffle_triton is not None:
weight_shuffled = weight.reshape(weight.shape[0] // 16, weight.shape[1] * 16)
Expand All @@ -190,6 +194,7 @@ def gemm_a8w8_blockscale_preshuffle_impl(
return y


@mark_trace
class LinearBase(nn.Module):
def __init__(
self,
Expand All @@ -200,13 +205,15 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = False,
source_quant_dtype: torch.dtype | None = None,
prefix: str = "",
):
if quant_config is None:
quant_config = QuantizationConfig()
self.source_quant_dtype = source_quant_dtype
quant_type = quant_config["quant_type"]
params_dtype = quant_config["quant_dtype"]
super().__init__()
self.prefix = prefix
self.reduce_results = reduce_results
self.input_size = input_size
self.output_size = (
Expand Down Expand Up @@ -421,7 +428,12 @@ def forward(
y += self.bias
elif self.quant_type.value == QuantType.per_1x128.value:
y = gemm_a8w8_blockscale_preshuffle_impl(
x, self.weight, x_scale, self.weight_scale, dtype=otype
x,
self.weight,
x_scale,
self.weight_scale,
dtype=otype,
prefix=self.prefix,
)
if self.bias is not None:
y += self.bias
Expand Down Expand Up @@ -460,6 +472,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=kwargs.get("prefix", ""),
)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
Expand All @@ -485,6 +498,7 @@ def __init__(
bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=kwargs.get("prefix", ""),
)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
Expand Down Expand Up @@ -516,6 +530,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=prefix,
)

def weight_loader(
Expand Down Expand Up @@ -650,6 +665,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=kwargs.get("prefix", ""),
)

def weight_loader(
Expand Down Expand Up @@ -711,6 +727,7 @@ def __init__(
quant_config=quant_config,
reduce_results=reduce_results,
source_quant_dtype=source_quant_dtype,
prefix=prefix,
)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
Expand Down Expand Up @@ -746,6 +763,7 @@ def __init__(
bias=bias,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=kwargs.get("prefix", ""),
)

def weight_loader(
Expand Down
1 change: 1 addition & 0 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,7 @@ def __init__(
bias=False,
quant_config=quant_config,
source_quant_dtype=source_quant_dtype,
prefix=f"{prefix}.q_a_proj",
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
Expand Down
28 changes: 28 additions & 0 deletions atom/utils/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,34 @@ def compile(
handle,
)

# When mark-trace is enabled, post-processing may rewrite the generated
# artifact sources. Force a reload from the cache artifact so that the
# current process immediately uses the rewritten artifact.
try:
from atom.utils.graph_marker import is_graph_marker_enabled

force_reload = is_graph_marker_enabled()
except Exception:
force_reload = False
if force_reload and handle is not None and not self.disable_cache:
try:
reloaded_graph = self.load(
graph, example_inputs, graph_index, runtime_shape
)
if reloaded_graph is not None:
compiled_graph = reloaded_graph
logger.info(
"Force reloaded compiled graph from cache artifact "
"(graph_index=%s, runtime_shape=%s).",
graph_index,
runtime_shape,
)
except Exception:
logger.exception(
"Failed to force reload compiled graph from cache artifact; "
"falling back to in-memory compiled callable."
)

# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
Expand Down
41 changes: 41 additions & 0 deletions atom/utils/compiler_inferface.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ def set_inductor_config(config, runtime_shape):
config["max_autotune"] = True
config["coordinate_descent_tuning"] = True

try:
from atom.utils.graph_marker import is_graph_marker_enabled

if is_graph_marker_enabled():
config["size_asserts"] = False
config["compile_threads"] = 1
except Exception:
pass


class InductorAdaptor(CompilerInterface):
"""
Expand Down Expand Up @@ -395,6 +404,21 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
assert (
file_path is not None
), "failed to get the file path of the compiled graph"
# Best-effort post-process the generated wrapper file too (PyTorch <2.8 path).
try:
# Only run post-processing when mark-trace is enabled (to avoid any
# overhead / file churn in default runs).
from atom.utils.graph_marker import is_graph_marker_enabled

if is_graph_marker_enabled():
# Local import to avoid extra package-level side effects.
from .graph_marker_instrumentation import (
instrument_record_functions_in_file,
)

instrument_record_functions_in_file(file_path, strip_markers=False)
except Exception:
pass
return compiled_graph, (hash_str, file_path)

def load(
Expand Down Expand Up @@ -560,6 +584,23 @@ def compile(
# if not envs.VLLM_DISABLE_COMPILE_CACHE:
compiled_graph.save(path=path, format="unpacked")
compilation_counter.num_compiled_artifacts_saved += 1
# Post-process generated wrapper Python files: wrap regions between
# <prefix>_start / <prefix>_end graph markers with record_function("<prefix>").
try:
# Only run post-processing when mark-trace is enabled (to avoid any
# overhead / file churn in default runs).
from atom.utils.graph_marker import is_graph_marker_enabled

if is_graph_marker_enabled():
# Local import to avoid extra package-level side effects.
from .graph_marker_instrumentation import (
instrument_record_functions_in_dir,
)

instrument_record_functions_in_dir(path, strip_markers=False)
except Exception:
# Best-effort: never fail compilation due to instrumentation.
pass
return compiled_graph, (key, path)

def load(
Expand Down
Loading
Loading