Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 139 additions & 14 deletions python/sglang/srt/configs/nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

"""NemotronH model configuration"""

import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

Expand Down Expand Up @@ -151,14 +150,44 @@ class NemotronHConfig(PretrainedConfig):
model_type = "nemotron_h"
keys_to_ignore_at_inference = ["past_key_values"]

@staticmethod
def _validate_layers_block_type(
layers_block_type, expected_length=None, param_name="layers_block_type"
):
"""
Validate layers_block_type list.
Args:
layers_block_type: List of layer types to validate
expected_length: If provided, validate the list has this length
param_name: Parameter name for error messages
Raises:
ValueError: If validation fails
"""
if not isinstance(layers_block_type, list):
raise ValueError(
f"{param_name} must be a list of strings. Got type: {type(layers_block_type)}"
)

if expected_length is not None and len(layers_block_type) != expected_length:
raise ValueError(
f"{param_name} must have length {expected_length}. Got length {len(layers_block_type)}."
)

valid_types = {"mamba", "attention", "moe"}
if not all(block_type in valid_types for block_type in layers_block_type):
invalid = set(layers_block_type) - valid_types
raise ValueError(
f"{param_name} contains invalid types: {invalid}. Must be one of: {valid_types}"
)

def __init__(
self,
vocab_size=131072,
tie_word_embeddings=False,
hidden_size=4096,
intermediate_size=21504,
num_hidden_layers=52,
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
num_hidden_layers=None, # Deprecated, only for backward compatibility
layers_block_type=None,
num_attention_heads=32,
head_dim=128,
num_key_value_heads=8, # nemo: num_query_groups
Expand Down Expand Up @@ -204,30 +233,57 @@ def __init__(
n_group=1,
topk_group=1,
norm_topk_prob=True,
num_nextn_predict_layers=0,
mtp_layers_block_type=["attention", "moe"],
**kwargs,
):

# Backward compatibility: convert hybrid_override_pattern to layers_block_type
# Always pop hybrid_override_pattern from kwargs to prevent it from being set as an attribute
if "hybrid_override_pattern" in kwargs:
pattern = kwargs.pop("hybrid_override_pattern")
if layers_block_type is None:
layers_block_type = self._pattern_to_list(pattern)
elif layers_block_type is None:
# Default layers_block_type if not provided
layers_block_type = ["mamba", "moe", "attention", "moe"]

# Note: num_hidden_layers is deprecated and ignored if layers_block_type is explicitly provided
# It's only kept for backward compatibility when loading old configs
if num_hidden_layers is not None:
# Warn if num_hidden_layers is provided but doesn't match layers_block_type
if len(layers_block_type) != num_hidden_layers:
logger.warning(
f"num_hidden_layers ({num_hidden_layers}) is deprecated and doesn't match "
f"layers_block_type length ({len(layers_block_type)}). Using layers_block_type length."
)

# Backward compatibility: convert mtp_hybrid_override_pattern to mtp_layers_block_type
# Always pop mtp_hybrid_override_pattern from kwargs to prevent it from being set as an attribute
if "mtp_hybrid_override_pattern" in kwargs:
pattern = kwargs.pop("mtp_hybrid_override_pattern")
if mtp_layers_block_type is None or mtp_layers_block_type == [
"attention",
"moe",
]:
mtp_layers_block_type = self._pattern_to_list(pattern)

self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.hybrid_override_pattern = hybrid_override_pattern
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout

# Validate hybrid_override_pattern
# M: Mamba2, *: Attention, -: MLP
assert (
len(self.hybrid_override_pattern) == self.num_hidden_layers
), "hybrid_override_pattern must have same length as num_hidden_layers"
assert re.match(
r"^[*\-ME]+$", self.hybrid_override_pattern
), "hybrid_override_pattern must only contain characters 'M', '*', '-' or 'E'"

# Validate layers_block_type (no longer checking length against num_hidden_layers)
self._validate_layers_block_type(
layers_block_type, expected_length=None, param_name="layers_block_type"
)
self.layers_block_type = layers_block_type
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
Expand Down Expand Up @@ -271,6 +327,22 @@ def __init__(
self.topk_group = topk_group
self.norm_topk_prob = norm_topk_prob

# MTP config
self.num_nextn_predict_layers = num_nextn_predict_layers

# Validate mtp_layers_block_type is provided when MTP is enabled
if self.num_nextn_predict_layers > 0:
if mtp_layers_block_type is None:
raise ValueError(
"mtp_layers_block_type is required when num_nextn_predict_layers > 0. "
"Please provide an explicit list of layer types for MTP layers. "
"Example: mtp_layers_block_type=['attention', 'moe']"
)
self._validate_layers_block_type(
mtp_layers_block_type, None, "mtp_layers_block_type"
)
self.mtp_layers_block_type = mtp_layers_block_type

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
Expand Down Expand Up @@ -312,3 +384,56 @@ def mamba2_cache_params(self) -> Mamba2CacheParams:
return Mamba2CacheParams(
shape=shape, layers=self.mamba_layer_ids, dtype=mamba2_state_dtype(self)
)

@property
def num_hidden_layers(self) -> int:
    """Total layer count, always derived from ``layers_block_type``.

    Exposed as a property so the deprecated ``num_hidden_layers`` argument
    can never drift out of sync with the per-layer block list.
    """
    return len(self.layers_block_type)

@num_hidden_layers.setter
def num_hidden_layers(self, _ignored):
    """Discard any assigned value, for backward compatibility.

    Old serialized configs still write ``num_hidden_layers`` during
    loading; the attribute is always recomputed from
    ``layers_block_type`` by the matching getter, so the write is a no-op.
    """

@property
def hybrid_override_pattern(self) -> str:
    """Legacy pattern-string view of ``layers_block_type``.

    Kept for backward compatibility with code that still reads the old
    ``hybrid_override_pattern`` field; the string is rebuilt on every
    access from the authoritative ``layers_block_type`` list.
    """
    return self._list_to_pattern(self.layers_block_type)

@hybrid_override_pattern.setter
def hybrid_override_pattern(self, value):
    """Accept a legacy pattern string and store its list form.

    Assigning the old-style pattern (e.g. when loading an old config)
    overwrites ``layers_block_type`` with the decoded per-layer list; the
    raw string itself is never stored.
    """
    converted = self._pattern_to_list(value)
    self.layers_block_type = converted

@property
def mtp_hybrid_override_pattern(self) -> str:
    """Legacy pattern-string view of ``mtp_layers_block_type``.

    Kept for backward compatibility with code that still reads the old
    ``mtp_hybrid_override_pattern`` field; rebuilt on each access from
    ``mtp_layers_block_type``.
    """
    return self._list_to_pattern(self.mtp_layers_block_type)

@staticmethod
def _list_to_pattern(layers_list: list) -> str:
"""Convert list of layer types back to pattern string (for backward compatibility)."""
reverse_mapping = {"mamba": "M", "moe": "E", "attention": "*"}
return "".join(reverse_mapping[layer_type] for layer_type in layers_list)

@staticmethod
def _pattern_to_list(pattern: str) -> list:
"""Convert pattern string to list of layer types (for backward compatibility)."""
pattern_mapping = {"M": "mamba", "E": "moe", "*": "attention"}
return [pattern_mapping[char] for char in pattern]