diff --git a/python/sglang/srt/configs/nemotron_h.py b/python/sglang/srt/configs/nemotron_h.py
index 833e97d87928..a6383d140e6a 100644
--- a/python/sglang/srt/configs/nemotron_h.py
+++ b/python/sglang/srt/configs/nemotron_h.py
@@ -15,7 +15,6 @@
 
 """NemotronH model configuration"""
 
-import regex as re
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
@@ -151,14 +150,44 @@ class NemotronHConfig(PretrainedConfig):
     model_type = "nemotron_h"
     keys_to_ignore_at_inference = ["past_key_values"]
 
+    @staticmethod
+    def _validate_layers_block_type(
+        layers_block_type, expected_length=None, param_name="layers_block_type"
+    ):
+        """
+        Validate layers_block_type list.
+        Args:
+            layers_block_type: List of layer types to validate
+            expected_length: If provided, validate the list has this length
+            param_name: Parameter name for error messages
+        Raises:
+            ValueError: If validation fails
+        """
+        if not isinstance(layers_block_type, list):
+            raise ValueError(
+                f"{param_name} must be a list of strings. Got type: {type(layers_block_type)}"
+            )
+
+        if expected_length is not None and len(layers_block_type) != expected_length:
+            raise ValueError(
+                f"{param_name} must have length {expected_length}. Got length {len(layers_block_type)}."
+            )
+
+        valid_types = {"mamba", "attention", "moe"}
+        if not all(block_type in valid_types for block_type in layers_block_type):
+            invalid = set(layers_block_type) - valid_types
+            raise ValueError(
+                f"{param_name} contains invalid types: {invalid}. Must be one of: {valid_types}"
+            )
+
     def __init__(
         self,
         vocab_size=131072,
         tie_word_embeddings=False,
         hidden_size=4096,
         intermediate_size=21504,
-        num_hidden_layers=52,
-        hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
+        num_hidden_layers=None,  # Deprecated, only for backward compatibility
+        layers_block_type=None,
         num_attention_heads=32,
         head_dim=128,
         num_key_value_heads=8,  # nemo: num_query_groups
@@ -204,14 +233,47 @@ def __init__(
         n_group=1,
         topk_group=1,
         norm_topk_prob=True,
+        num_nextn_predict_layers=0,
+        # NOTE(review): shared mutable default list — nothing below mutates it in
+        # place, but callers must not append to config.mtp_layers_block_type.
+        mtp_layers_block_type=["attention", "moe"],
         **kwargs,
     ):
+
+        # Backward compatibility: convert hybrid_override_pattern to layers_block_type
+        # Always pop hybrid_override_pattern from kwargs to prevent it from being set as an attribute
+        if "hybrid_override_pattern" in kwargs:
+            pattern = kwargs.pop("hybrid_override_pattern")
+            if layers_block_type is None:
+                layers_block_type = self._pattern_to_list(pattern)
+        elif layers_block_type is None:
+            # Default layers_block_type if not provided
+            layers_block_type = ["mamba", "moe", "attention", "moe"]
+
+        # Note: num_hidden_layers is deprecated and ignored if layers_block_type is explicitly provided
+        # It's only kept for backward compatibility when loading old configs
+        if num_hidden_layers is not None:
+            # Warn if num_hidden_layers is provided but doesn't match layers_block_type
+            if len(layers_block_type) != num_hidden_layers:
+                logger.warning(
+                    f"num_hidden_layers ({num_hidden_layers}) is deprecated and doesn't match "
+                    f"layers_block_type length ({len(layers_block_type)}). Using layers_block_type length."
+                )
+
+        # Backward compatibility: convert mtp_hybrid_override_pattern to mtp_layers_block_type
+        # Always pop mtp_hybrid_override_pattern from kwargs to prevent it from being set as an attribute
+        if "mtp_hybrid_override_pattern" in kwargs:
+            pattern = kwargs.pop("mtp_hybrid_override_pattern")
+            if mtp_layers_block_type is None or mtp_layers_block_type == [
+                "attention",
+                "moe",
+            ]:
+                mtp_layers_block_type = self._pattern_to_list(pattern)
+
         self.vocab_size = vocab_size
         self.tie_word_embeddings = tie_word_embeddings
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.hybrid_override_pattern = hybrid_override_pattern
         self.num_attention_heads = num_attention_heads
         self.head_dim = head_dim
         self.sliding_window = sliding_window
@@ -219,15 +281,11 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.hidden_dropout = hidden_dropout
 
-        # Validate hybrid_override_pattern
-        # M: Mamba2, *: Attention, -: MLP
-        assert (
-            len(self.hybrid_override_pattern) == self.num_hidden_layers
-        ), "hybrid_override_pattern must have same length as num_hidden_layers"
-        assert re.match(
-            r"^[*\-ME]+$", self.hybrid_override_pattern
-        ), "hybrid_override_pattern must only contain characters 'M', '*', '-' or 'E'"
-
+        # Validate layers_block_type (no longer checking length against num_hidden_layers)
+        self._validate_layers_block_type(
+            layers_block_type, expected_length=None, param_name="layers_block_type"
+        )
+        self.layers_block_type = layers_block_type
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
@@ -271,6 +329,22 @@ def __init__(
         self.n_group = n_group
         self.topk_group = topk_group
         self.norm_topk_prob = norm_topk_prob
+        # MTP config
+        self.num_nextn_predict_layers = num_nextn_predict_layers
+
+        # Validate mtp_layers_block_type is provided when MTP is enabled
+        if self.num_nextn_predict_layers > 0:
+            if mtp_layers_block_type is None:
+                raise ValueError(
+                    "mtp_layers_block_type is required when num_nextn_predict_layers > 0. "
+                    "Please provide an explicit list of layer types for MTP layers. "
+                    "Example: mtp_layers_block_type=['attention', 'moe']"
+                )
+            self._validate_layers_block_type(
+                mtp_layers_block_type, None, "mtp_layers_block_type"
+            )
+        self.mtp_layers_block_type = mtp_layers_block_type
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -312,3 +386,56 @@ def mamba2_cache_params(self) -> Mamba2CacheParams:
         return Mamba2CacheParams(
             shape=shape, layers=self.mamba_layer_ids, dtype=mamba2_state_dtype(self)
         )
+
+    @property
+    def num_hidden_layers(self) -> int:
+        """
+        Number of hidden layers derived from the length of layers_block_type.
+        This property replaces the deprecated num_hidden_layers parameter.
+        """
+        return len(self.layers_block_type)
+
+    @num_hidden_layers.setter
+    def num_hidden_layers(self, value):
+        """
+        Setter for backward compatibility when loading configs.
+        The value is ignored since num_hidden_layers is computed from layers_block_type.
+        """
+        # Ignore the value - num_hidden_layers is always derived from layers_block_type
+        pass
+
+    @property
+    def hybrid_override_pattern(self) -> str:
+        """
+        Backward compatibility property.
+        Returns the pattern string representation of layers_block_type.
+        """
+        return self._list_to_pattern(self.layers_block_type)
+
+    @hybrid_override_pattern.setter
+    def hybrid_override_pattern(self, value):
+        """
+        Setter for backward compatibility when loading configs.
+        The provided pattern string is converted and stored as layers_block_type.
+        """
+        self.layers_block_type = self._pattern_to_list(value)
+
+    @property
+    def mtp_hybrid_override_pattern(self) -> str:
+        """
+        Backward compatibility property.
+        Returns the pattern string representation of mtp_layers_block_type.
+        """
+        return self._list_to_pattern(self.mtp_layers_block_type)
+
+    @staticmethod
+    def _list_to_pattern(layers_list: list) -> str:
+        """Convert list of layer types back to pattern string (for backward compatibility)."""
+        reverse_mapping = {"mamba": "M", "moe": "E", "attention": "*"}
+        return "".join(reverse_mapping[layer_type] for layer_type in layers_list)
+
+    @staticmethod
+    def _pattern_to_list(pattern: str) -> list:
+        """Convert pattern string to list of layer types (for backward compatibility)."""
+        pattern_mapping = {"M": "mamba", "E": "moe", "*": "attention"}
+        return [pattern_mapping[char] for char in pattern]