diff --git a/nemo_deploy/llm/inference/inference_base.py b/nemo_deploy/llm/inference/inference_base.py index a39025400..54f43a812 100644 --- a/nemo_deploy/llm/inference/inference_base.py +++ b/nemo_deploy/llm/inference/inference_base.py @@ -16,7 +16,7 @@ import atexit import logging from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import megatron.core.dist_checkpointing.serialization as dist_ckpt import torch @@ -27,11 +27,13 @@ get_default_load_sharded_strategy, ) from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( GPTInferenceWrapper, ) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) @@ -39,8 +41,6 @@ from megatron.core.transformer.module import MegatronModule from packaging import version -from nemo_export_deploy_common.import_utils import MISSING_NEMO_MSG, UnavailableError - from .tron_utils import ( DistributedInitConfig, RNGConfig, @@ -62,29 +62,13 @@ except ImportError: HAVE_TRITON = False -try: - if not HAVE_TRITON: - raise ImportError("Triton is not installed") - from nemo.collections.llm.gpt.model.base import GPTConfig - from nemo.collections.llm.inference.base import MCoreTokenizerWrappper - from nemo.collections.llm.modelopt import set_modelopt_spec_if_exists_in_ckpt - from nemo.collections.llm.t5.model.t5 import T5Config - from nemo.lightning import io - from nemo.lightning.ckpt_utils import ckpt_to_context_subdir - from nemo.lightning.io.pl import ckpt_to_weights_subdir - - HAVE_NEMO = True -except (ImportError, ModuleNotFoundError): - HAVE_NEMO = False - from typing import Any - - io = None - GPTConfig = Any - T5Config = Any - MCoreTokenizerWrappper = Any - set_modelopt_spec_if_exists_in_ckpt = None - ckpt_to_weights_subdir = None - ckpt_to_context_subdir = None +from .nemo_utils import ( + MCoreTokenizerWrappper, + ckpt_to_context_subdir, + ckpt_to_weights_subdir, + io, + set_modelopt_spec_if_exists_in_ckpt, +) LOGGER = logging.getLogger("NeMo") @@ -201,8 +185,6 @@ def load_nemo_checkpoint_to_tron_model(model: List[MegatronModule], path: Path, path (Path): Path to NeMo checkpoint directory legacy_ckpt (bool): Whether to use legacy checkpoint format """ - if not HAVE_NEMO: - raise UnavailableError(MISSING_NEMO_MSG) weights_dir = ckpt_to_weights_subdir(path, is_saving=False) LOGGER.info(f"Loading NeMo checkpoint from {weights_dir}") @@ -324,9 +306,6 @@ def setup_model_and_tokenizer_for_inference( Raises: ValueError: If checkpoint_path is not a valid NeMo-2.0 checkpoint """ - if not HAVE_NEMO: - raise UnavailableError(MISSING_NEMO_MSG) - checkpoint_path = Path(checkpoint_path) # Load model context for config and tokenizer @@ -465,6 +444,7 @@ def create_mcore_engine( model_type: str = "gpt", model_format: str = "nemo", micro_batch_size: Optional[int] = None, + buffer_size_gb: float = 10.0, **model_config_kwargs, ) -> Tuple[MCoreEngineWithCleanup, GPTInferenceWrapper, Union[MCoreTokenizerWrappper, MegatronTokenizer]]: """Set up the model, tokenizer and MCoreEngine for inference. @@ -492,9 +472,6 @@ def create_mcore_engine( - GPTInferenceWrapper: Inference-wrapped model - Union[MCoreTokenizerWrappper, MegatronTokenizer]: Tokenizer instance """ - if not HAVE_NEMO and model_format == "nemo": - raise UnavailableError(MISSING_NEMO_MSG) - # Default to 1 for any parallelism dimension that's None tensor_model_parallel_size = tensor_model_parallel_size if tensor_model_parallel_size is not None else 1 pipeline_model_parallel_size = pipeline_model_parallel_size if pipeline_model_parallel_size is not None else 1 @@ -534,8 +511,17 @@ def create_mcore_engine( else: raise ValueError(f"Model format {model_format} not supported.") - inference_context = StaticInferenceContext(max_batch_size, inference_max_seq_length) - model_inference_wrapper = GPTInferenceWrapper(model, inference_context) + inner_model = peel(model) + model_config = inner_model.config + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=model_config.hidden_size, + params_dtype=model_config.params_dtype, + inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, + padded_vocab_size=inner_model.vocab_size, + inference_max_requests=max_batch_size, + inference_max_seq_length=inference_max_seq_length, + ) + model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config) text_generation_controller = TextGenerationController( inference_wrapped_model=model_inference_wrapper, tokenizer=tokenizer ) @@ -543,6 +529,7 @@ def create_mcore_engine( text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed, + buffer_size_gb=buffer_size_gb, ) # Wrap the engine to ensure cleanup diff --git a/nemo_deploy/llm/inference/nemo_io.py b/nemo_deploy/llm/inference/nemo_io.py new file mode 100644 index 000000000..67f453d15 --- /dev/null +++ b/nemo_deploy/llm/inference/nemo_io.py @@ -0,0 +1,417 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""IO utilities for loading NeMo 2.0 checkpoints without a direct nemo import. + +Copied from the NeMo project (https://github.com/NVIDIA/NeMo). Static +``from nemo import …`` statements are removed; the logic is otherwise +identical to the upstream sources. When a NeMo checkpoint is actually +loaded at runtime, NeMo classes are imported transitively through +``pydoc.locate`` — NeMo must therefore still be installed to read NeMo +checkpoints. + +Sources +------- + - IOProtocol : nemo/lightning/io/capture.py + - IO helpers, load : nemo/lightning/io/mixin.py + - load_context : nemo/lightning/io/api.py + - Torch-dtype fiddle + registration : nemo/lightning/io/fdl_torch.py +""" + +from __future__ import annotations + +import dataclasses +import functools +import inspect +import json +import logging +import threading +import uuid +from pathlib import Path +from pydoc import locate +from typing import Any, Dict, Generic, List, Optional, Protocol, TypeVar, runtime_checkable + +import fiddle as fdl +import fiddle._src.experimental.dataclasses as fdl_dc +import torch +from cloudpickle import dump +from cloudpickle import load as pickle_load +from fiddle._src.experimental import serialization + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Thread-local storage (mirrors nemo.lightning.io.mixin._thread_local) +# --------------------------------------------------------------------------- + +_thread_local = threading.local() + + +def _set_thread_local_output_dir(path: Path) -> None: + """Set output_dir in our thread-local and in NeMo's (if already imported). + + NeMo classes registered before our first load call will use NeMo's + _io_unflatten_object, which reads from NeMo's own _thread_local. We + mirror the value there so that pickle-based artifacts still resolve + correctly even in that edge case. + """ + _thread_local.output_dir = path + try: + import sys + + nemo_mixin = sys.modules.get("nemo.lightning.io.mixin") + if nemo_mixin is not None and hasattr(nemo_mixin, "_thread_local"): + nemo_mixin._thread_local.output_dir = path + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Register torch dtypes as fiddle constants +# (from nemo.lightning.io.fdl_torch.enable — only register_constant calls +# are needed for deserialization; libcst / codegen parts are omitted) +# --------------------------------------------------------------------------- + +_TORCH_DTYPE_NAMES = [ + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + "float16", + "bfloat16", + "float32", + "float64", + "complex64", + "complex128", +] + + +def _register_torch_dtypes() -> None: + for name in _TORCH_DTYPE_NAMES: + if hasattr(torch, name): + serialization.register_constant("torch", name, compare_by_identity=True) + + +_register_torch_dtypes() + +# --------------------------------------------------------------------------- +# IOProtocol (from nemo.lightning.io.capture) +# --------------------------------------------------------------------------- + +SelfT = TypeVar("SelfT", covariant=True) + + +@runtime_checkable +class IOProtocol(Protocol, Generic[SelfT]): + @property + def __io__(self) -> fdl.Config[SelfT]: ... # noqa: E704 + + +# --------------------------------------------------------------------------- +# IO helper functions (from nemo.lightning.io.mixin) +# --------------------------------------------------------------------------- + + +def _io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: + """Capture __init__ arguments as a plain dict for fdl.Config creation.""" + sig = inspect.signature(init_fn) + bound_args = sig.bind_partial(self, *args, **kwargs) + config_kwargs = {k: v for k, v in bound_args.arguments.items() if k != "self"} + + to_del: List[str] = [] + for key in config_kwargs: + if isinstance(config_kwargs[key], IOProtocol): + config_kwargs[key] = config_kwargs[key].__io__ + if dataclasses.is_dataclass(config_kwargs[key]): + config_kwargs[key] = fdl_dc.convert_dataclasses_to_configs(config_kwargs[key], allow_post_init=True) + if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": + to_del.append(key) + + for key in to_del: + del config_kwargs[key] + + return config_kwargs + + +def _io_init(self, **kwargs) -> fdl.Config: + """Create an fdl.Config for *self* from captured init kwargs.""" + try: + return fdl.Config(type(self), **kwargs) + except Exception as e: + raise RuntimeError( + f"Error creating fdl.Config for {type(self).__name__}: {e}\n" + f"Arguments that caused the error: {kwargs}" + ) from e + + +def _io_wrap_init(cls): + """Wrap cls.__init__ to populate __io__ on every instance.""" + original_init = cls.__init__ + + if getattr(cls, "__wrapped_init__", False): + return cls + + @functools.wraps(original_init) + def wrapped_init(self, *args, **kwargs): + if hasattr(self, "io_transform_args"): + cfg_kwargs = self.io_transform_args(original_init, *args, **kwargs) + else: + cfg_kwargs = _io_transform_args(self, original_init, *args, **kwargs) + if hasattr(self, "io_init"): + self.__io__ = self.io_init(**cfg_kwargs) + else: + self.__io__ = _io_init(self, **cfg_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = wrapped_init + cls.__wrapped_init__ = True + return cls + + +def _io_flatten_object(instance): + """Flatten an IOMixin object to a form fiddle can serialize.""" + try: + serialization.dump_json(instance.__io__) + except (serialization.UnserializableValueError, AttributeError) as exc: + if not hasattr(_thread_local, "local_artifacts_dir") or not hasattr(_thread_local, "output_path"): + raise exc + + local_artifact_path = Path(_thread_local.local_artifacts_dir) / f"{uuid.uuid4()}" + output_path = _thread_local.output_path + artifact_path = output_path / local_artifact_path + with open(artifact_path, "wb") as f: + dump(getattr(instance, "__io__", instance), f) + return (str(local_artifact_path),), None + + return instance.__io__.__flatten__() + + +def _io_unflatten_object(values, metadata): + """Unflatten an IOMixin object; load from pickle if it was saved that way.""" + if not hasattr(_thread_local, "output_dir"): + return fdl.Config.__unflatten__(values, metadata) + + output_dir = _thread_local.output_dir + if len(values) == 1: + pickle_path = values[0] + with open(Path(output_dir) / pickle_path, "rb") as f: + return pickle_load(f) + + return fdl.Config.__unflatten__(values, metadata) + + +def _io_path_elements_fn(x): + """Return the path elements for fiddle graph traversal.""" + try: + serialization.dump_json(x.__io__) + except (serialization.UnserializableValueError, AttributeError): + return (serialization.IdentityElement(),) + + return x.__io__.__path_elements__() + + +def _io_register_serialization(cls) -> None: + """Register fiddle traversal functions for *cls* using our _thread_local.""" + serialization.register_node_traverser( + cls, + flatten_fn=_io_flatten_object, + unflatten_fn=_io_unflatten_object, + path_elements_fn=_io_path_elements_fn, + ) + + +def track_io(target, artifacts=None): + """Add fiddle IO functionality to a class or all eligible classes in a module. + + Copied from ``nemo.lightning.io.mixin.track_io``. + """ + import types as _types + + def _add_io_to_class(cls): + if inspect.isclass(cls) and hasattr(cls, "__init__") and not hasattr(cls, "__io__"): + if cls in [str, int, float, tuple, list, dict, bool, type(None)]: + return cls + cls = _io_wrap_init(cls) + _io_register_serialization(cls) + cls.__io_artifacts__ = artifacts or [] + return cls + + def _is_in_module(obj, module): + return obj.__module__ == module.__name__ or obj.__module__.startswith(f"{module.__name__}.") + + def _process_module(module): + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and _is_in_module(obj, module): + setattr(module, name, _add_io_to_class(obj)) + return module + + if isinstance(target, _types.ModuleType): + return _process_module(target) + elif inspect.isclass(target): + return _add_io_to_class(target) + else: + raise TypeError("Target must be a module or a class") + + +def drop_unexpected_params(config: fdl.Config) -> bool: + """Remove deprecated / unexpected parameters from a fiddle Config tree. + + Copied from ``nemo.lightning.io.mixin.drop_unexpected_params``. + """ + updated = False + + def analyze(cfg, prefix: str): + nonlocal updated + if not isinstance(cfg, fdl.Config): + return + signature = inspect.signature(cfg.__fn_or_cls__) + accept_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values()) + if not accept_kwargs: + to_drop = [p for p in cfg.__arguments__ if p not in signature.parameters] + if to_drop: + updated = True + _logger.warning("Deprecated parameters to drop from %s: %s", prefix, to_drop) + for p in to_drop: + del cfg.__arguments__[p] + for key, value in cfg.__arguments__.items(): + analyze(value, f"{prefix}.{key}") + + analyze(config, "") + return updated + + +def _artifact_transform_load(cfg: fdl.Config, path: Path) -> None: + """Rewrite artifact paths stored in a config to absolute paths. + + Copied from ``nemo.lightning.io.mixin._artifact_transform_load``. + """ + for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): + current_val = getattr(cfg, artifact.attr) + if isinstance(current_val, fdl.Config): + setattr(cfg, artifact.attr, fdl.build(current_val).attr) + continue + if artifact.skip: + continue + current_val = getattr(cfg, artifact.attr) + if current_val is None: + continue + new_val = str(Path(path) / current_val) + setattr(cfg, artifact.attr, new_val) + + for attr in dir(cfg): + try: + child = getattr(cfg, attr) + if isinstance(child, fdl.Config): + _artifact_transform_load(child, path=path) + except (ValueError, AttributeError): + pass + + +# --------------------------------------------------------------------------- +# load (from nemo.lightning.io.mixin) +# --------------------------------------------------------------------------- + + +def load( + path: Path, + output_type: Any = None, + subpath: Optional[str] = None, + build: bool = True, +) -> Any: + """Load a fiddle-serialised NeMo checkpoint context from an ``io.json`` file. + + Copied from ``nemo.lightning.io.mixin.load``. + """ + _path = Path(path) + _set_thread_local_output_dir(_path) + + if _path.is_dir(): + _path = _path / "io.json" + + if not _path.is_file(): + raise FileNotFoundError(f"No such file: '{_path}'") + + if subpath: + subpath = "." + subpath + + # Register / re-register fiddle traversal for every class in the JSON. + # We always re-register (not just when missing) so that our _thread_local + # is used by _io_unflatten_object for classes that NeMo may have already + # registered before this call. + with open(_path) as f: + j = json.load(f) + + for obj, val in j.get("objects", {}).items(): + clss = ".".join([val["type"]["module"], val["type"]["name"]]) + if subpath and "paths" in val: + if all(subpath not in p for p in val["paths"]): + continue + cls_obj = locate(clss) + if cls_obj is None: + continue + if not serialization.find_node_traverser(cls_obj): + track_io(cls_obj) + else: + # Re-register with our traversal so our _thread_local is active. + _io_register_serialization(cls_obj) + + with open(_path, "rb") as f: + json_config = json.loads(f.read()) + + root_key = None + for obj, val in json_config.get("objects", {}).items(): + if "paths" in val and subpath in val["paths"]: + root_key = obj + break + + if subpath and not root_key: + _logger.warning("Could not find %s for %s in %s", subpath, output_type, _path) + + if root_key: + json_config["root"]["key"] = root_key + + config = serialization.Deserialization(json_config).result + _artifact_transform_load(config, path) + drop_unexpected_params(config) + + if not build: + return config + + return fdl.build(config) + + +# --------------------------------------------------------------------------- +# load_context (from nemo.lightning.io.api) +# --------------------------------------------------------------------------- + + +def load_context(path: Path, subpath: Optional[str] = None, build: bool = True) -> Any: + """Load a NeMo TrainerContext (or a subpath of it) from a checkpoint directory. + + Copied from ``nemo.lightning.io.api.load_context``. + """ + if not isinstance(path, Path): + path = Path(path) + + try: + return load(path, subpath=subpath, build=build) + except FileNotFoundError: + # Backwards compatibility: checkpoints without a ``/context`` sub-dir. + if path.parts[-1] == "context": + path = path.parent + else: + path = path / "context" + return load(path, subpath=subpath, build=build) diff --git a/nemo_deploy/llm/inference/nemo_utils.py b/nemo_deploy/llm/inference/nemo_utils.py new file mode 100644 index 000000000..de1463554 --- /dev/null +++ b/nemo_deploy/llm/inference/nemo_utils.py @@ -0,0 +1,247 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NeMo utility code copied from the NeMo project. + +All utilities here are copied directly from NeMo and have no static +dependency on the nemo package. When a NeMo checkpoint is loaded at +runtime, NeMo classes are imported transitively through pydoc.locate +inside nemo_io.load_context — NeMo must therefore still be installed +to read NeMo checkpoints. + +Sources: + - MCoreTokenizerWrappper : nemo/collections/llm/inference/base.py + - ckpt_to_dir, + idempotent_path_append, + ckpt_to_context_subdir : nemo/lightning/ckpt_utils.py + - ckpt_to_weights_subdir : nemo/lightning/io/pl.py + - constants : nemo/lightning/ckpt_utils.py + - set_modelopt_spec_* : nemo/collections/llm/modelopt/model_utils.py + - load_context, io : nemo_io.py (copied from nemo/lightning/io/) +""" + +import inspect +import logging +import types +from functools import partial +from pathlib import Path +from typing import Any, Union + +from .nemo_io import load_context as _load_context + +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# io namespace — exposes load_context under the same attribute name that +# inference_base.py uses (io.load_context(...)). +# --------------------------------------------------------------------------- + +io = types.SimpleNamespace(load_context=_load_context) + +# --------------------------------------------------------------------------- +# GPTConfig / T5Config — type stubs used only for annotations. +# The actual runtime objects are NeMo classes deserialized from the +# checkpoint; isinstance() checks use class-name strings instead. +# --------------------------------------------------------------------------- + +GPTConfig = Any +T5Config = Any + +# --------------------------------------------------------------------------- +# Constants (from nemo.lightning.ckpt_utils) +# --------------------------------------------------------------------------- + +# NeMo-2 checkpoint structure: +# /weights/ – model weights +# /context/ – hyper-parameters / IO context +WEIGHTS_PATH: str = "weights" +CONTEXT_PATH: str = "context" +ADAPTER_META_FILENAME: str = "adapter_metadata.json" + +# --------------------------------------------------------------------------- +# Checkpoint path utilities (simplified from nemo.lightning.ckpt_utils and +# nemo.lightning.io.pl – AdapterPath and MultiStorageClient branches removed +# because they are not required for basic NeMo-2 inference). +# --------------------------------------------------------------------------- + + +def ckpt_to_dir(filepath: Union[str, Path]) -> Path: + """Return the checkpoint directory path for a given filepath. + + PTL treats checkpoints as ``.ckpt`` files. This helper strips the + extension (appending it first when absent) and returns a :class:`Path` + suitable for use as a distributed-checkpoint directory. + + Copied from ``nemo.lightning.ckpt_utils.ckpt_to_dir`` with the + ``AdapterPath`` and ``MultiStorageClient`` branches removed. + """ + filepath = Path(filepath) + + if filepath.suffix != ".ckpt": + filepath = filepath.with_suffix(filepath.suffix + ".ckpt") + + assert filepath.suffix == ".ckpt", f"filepath: {filepath} must have .ckpt extension" + + # Return path whose name is the original filepath without the .ckpt extension. + return filepath.with_name(filepath.stem) + + +def idempotent_path_append(base_dir: Union[str, Path], suffix: str) -> Path: + """Append *suffix* to *base_dir* only when it is not already the last component. + + Copied from ``nemo.lightning.ckpt_utils.idempotent_path_append`` with the + ``AdapterPath`` and ``MultiStorageClient`` branches removed. + """ + base_dir = Path(base_dir) + if base_dir.parts[-1] != suffix: + base_dir = base_dir / suffix + return base_dir + + +def ckpt_to_context_subdir(filepath: Union[str, Path]) -> Path: + """Return the ``context`` sub-directory of a NeMo-2 checkpoint. + + Copied from ``nemo.lightning.ckpt_utils.ckpt_to_context_subdir``. + """ + base_dir = ckpt_to_dir(filepath=filepath) + return idempotent_path_append(base_dir, CONTEXT_PATH) + + +def ckpt_to_weights_subdir(filepath: Union[str, Path], is_saving: bool) -> Path: + """Return the ``weights`` sub-directory of a NeMo-2 checkpoint. + + Copied from ``nemo.lightning.io.pl.ckpt_to_weights_subdir`` with the + ``AdapterPath`` branch removed. + """ + filepath = ckpt_to_dir(filepath=filepath) + base_dir = filepath + + if base_dir.parts[-1] != WEIGHTS_PATH: + maybe_base_dir = base_dir / WEIGHTS_PATH + if maybe_base_dir.is_dir() or is_saving: + base_dir = maybe_base_dir + + if is_saving: + assert base_dir.parts[-1] == WEIGHTS_PATH + assert base_dir.parent == filepath + + return base_dir + + +# --------------------------------------------------------------------------- +# MCoreTokenizerWrappper (from nemo.collections.llm.inference.base) +# --------------------------------------------------------------------------- + + +class MCoreTokenizerWrappper: + """Thin wrapper that adapts a NeMo tokenizer to the MCore generate API. + + MCore's generate pipeline expects ``tokenizer.detokenize``, + ``tokenizer.tokenize``, ``tokenizer.bos``, and ``tokenizer.pad`` – + this wrapper maps those calls to the corresponding NeMo tokenizer + methods/properties. + + Copied verbatim from ``nemo.collections.llm.inference.base.MCoreTokenizerWrappper``. + """ + + def __init__(self, tokenizer, vocab_size=None): + self.tokenizer = tokenizer + self.eod = tokenizer.eod + self.vocab_size = vocab_size or tokenizer.vocab_size + + def detokenize(self, tokens, remove_special_tokens=False): + """Detokenize *tokens* into a string.""" + if "remove_special_tokens" in inspect.signature(self.tokenizer.ids_to_text).parameters: + return self.tokenizer.ids_to_text(tokens, remove_special_tokens) + return self.tokenizer.ids_to_text(tokens) + + def tokenize(self, prompt): + """Tokenize *prompt* into a list of token IDs.""" + return self.tokenizer.text_to_ids(prompt) + + @property + def additional_special_tokens_ids(self): + """IDs of additional special tokens.""" + return self.tokenizer.additional_special_tokens_ids + + @property + def bos(self): + """Beginning-of-sequence token ID.""" + return self.tokenizer.bos_id + + @property + def pad(self): + """Padding token ID.""" + return self.tokenizer.pad_id + + +# --------------------------------------------------------------------------- +# set_modelopt_spec_if_exists_in_ckpt +# +# Copied from nemo/collections/llm/modelopt/model_utils.py. +# NeMo model-type isinstance checks are replaced by class-name checks to +# avoid importing nemo at module level. +# --------------------------------------------------------------------------- + + +def set_modelopt_spec_if_exists_in_ckpt(model, path: str) -> None: + """Set model.config.transformer_layer_spec to a modelopt spec if the + checkpoint contains a ``modelopt_state`` directory. + + Copied from ``nemo.collections.llm.modelopt.model_utils.set_modelopt_spec_if_exists_in_ckpt`` + with NeMo isinstance checks replaced by class-name comparisons. + """ + path = str(path).removeprefix("nemo://") + modelopt_state_path = ckpt_to_weights_subdir(path, is_saving=False) / "modelopt_state" + if not modelopt_state_path.exists() or hasattr(model, "module"): + return + + model_type_name = type(model).__name__ + if model_type_name not in ("GPTModel", "MambaModel"): + _logger.warning( + "%s is neither a GPTModel nor MambaModel. Modelopt state will not be loaded.", + type(model), + ) + return + + config = model.config + config_type_name = type(config).__name__ + + try: + from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec + + _HAVE_GPT_MODELOPT_SPEC = True + except ImportError: + _HAVE_GPT_MODELOPT_SPEC = False + + if config_type_name == "GPTConfig": + if _HAVE_GPT_MODELOPT_SPEC: + config.transformer_layer_spec = partial( + get_gpt_modelopt_spec, + remap_te_layernorm=True, + local_core_attention=getattr(config, "softmax_type", "vanilla") != "vanilla", + ) + else: + _logger.warning("get_gpt_modelopt_spec not available; skipping modelopt layer spec.") + elif config_type_name == "SSMConfig": + try: + from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec + + config.mamba_stack_spec = partial(get_mamba_stack_modelopt_spec, remap_te_layernorm=True) + except ImportError: + _logger.warning("get_mamba_stack_modelopt_spec not available; skipping modelopt layer spec.") + else: + _logger.warning("No modelopt layer spec supported for config type %s.", type(config)) + return + + config.gradient_accumulation_fusion = False diff --git a/nemo_deploy/llm/inference/tron_utils.py b/nemo_deploy/llm/inference/tron_utils.py index ecf229209..4aa359a17 100644 --- a/nemo_deploy/llm/inference/tron_utils.py +++ b/nemo_deploy/llm/inference/tron_utils.py @@ -39,20 +39,7 @@ except ImportError: HAVE_TRITON = False -try: - if not HAVE_TRITON: - raise ImportError("Triton is not installed") - - from nemo.collections.llm.gpt.model.base import GPTConfig - from nemo.collections.llm.t5.model.t5 import T5Config - - HAVE_NEMO = True -except (ImportError, ModuleNotFoundError): - from typing import Any - - GPTConfig = Any - T5Config = Any - HAVE_NEMO = False +from .nemo_utils import GPTConfig, T5Config LOGGER = logging.getLogger("NeMo") @@ -369,7 +356,7 @@ def _get_model_type(model_config: Union[GPTConfig, T5Config]) -> ModelType: Returns: ModelType: The model type enum value (encoder_and_decoder or encoder_or_decoder) """ - return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder + return ModelType.encoder_and_decoder if type(model_config).__name__ == "T5Config" else ModelType.encoder_or_decoder def get_model_from_config( @@ -422,7 +409,7 @@ def get_model_from_config( pre_process = parallel_state.is_pipeline_first_stage() post_process = parallel_state.is_pipeline_last_stage() if model_type == ModelType.encoder_and_decoder: - assert isinstance(model_config, T5Config) + assert type(model_config).__name__ == "T5Config" if parallel_state.get_pipeline_model_parallel_world_size() > 1: rank = parallel_state.get_pipeline_model_parallel_rank() first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() diff --git a/nemo_deploy/llm/megatronllm_deployable.py b/nemo_deploy/llm/megatronllm_deployable.py index 98f637276..ecc173847 100755 --- a/nemo_deploy/llm/megatronllm_deployable.py +++ b/nemo_deploy/llm/megatronllm_deployable.py @@ -79,6 +79,7 @@ class MegatronLLMDeployable(ITritonDeployable): legacy_ckpt (bool): use legacy checkpoint format. Defaults to False. model_type (str): type of model to load. Defaults to "gpt". micro_batch_size (Optional[int]): micro batch size for model execution. Defaults to None. + buffer_size_gb (float): KV cache buffer size in GiB for DynamicInferenceContext. Defaults to 10.0. """ def __init__( @@ -102,6 +103,7 @@ def __init__( legacy_ckpt: bool = False, model_type: str = "gpt", micro_batch_size: Optional[int] = None, + buffer_size_gb: float = 10.0, **model_config_kwargs, ): if not HAVE_TRITON: @@ -131,6 +133,7 @@ def __init__( model_type=model_type, model_format="megatron", micro_batch_size=micro_batch_size, + buffer_size_gb=buffer_size_gb, **model_config_kwargs, ) self.enable_cuda_graphs = enable_cuda_graphs diff --git a/nemo_export/__init__.py b/nemo_export/__init__.py index 6c51dfc86..f69aa4d1c 100644 --- a/nemo_export/__init__.py +++ b/nemo_export/__init__.py @@ -12,21 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# WAR for trtllm and lightning conflict from nemo_export_deploy_common.package_info import __package_name__, __version__ -try: - from nemo.lightning import io - - HAVE_IO = True -except (ImportError, ModuleNotFoundError): - HAVE_IO = False - __all__ = ["__version__", "__package_name__"] -if HAVE_IO: - __all__ += ["io"] - # Optional convenience imports for TensorRT-LLM classes try: from nemo_export.tensorrt_llm import TensorRTLLM diff --git a/tests/unit_tests/deploy/test_etp_sequence_parallel.py b/tests/unit_tests/deploy/test_etp_sequence_parallel.py index 0024a68a0..86f168ad5 100644 --- a/tests/unit_tests/deploy/test_etp_sequence_parallel.py +++ b/tests/unit_tests/deploy/test_etp_sequence_parallel.py @@ -283,7 +283,6 @@ class TestSetupModelETPSequenceParallel(unittest.TestCase): def _common_patches(self): return [ - patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True), patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt"), patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init"), patch("nemo_deploy.llm.inference.inference_base.io", new_callable=MagicMock), @@ -391,7 +390,6 @@ def test_sequence_parallel_applied(self): class TestCreateMcoreEngineETPSequenceParallel(unittest.TestCase): """Tests that create_mcore_engine handles ETP/SP defaults and passes them down.""" - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -413,7 +411,6 @@ def test_etp_defaults_to_1_when_none( _, kwargs = mock_setup.call_args assert kwargs["expert_tensor_parallel_size"] == 1 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -435,7 +432,6 @@ def test_sp_defaults_to_1_when_none( _, kwargs = mock_setup.call_args assert kwargs["sequence_parallel"] == 1 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") @@ -457,7 +453,6 @@ def test_explicit_etp_passed_through( _, kwargs = mock_setup.call_args assert kwargs["expert_tensor_parallel_size"] == 4 - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngine") @patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup") diff --git a/tests/unit_tests/deploy/test_inference_base.py b/tests/unit_tests/deploy/test_inference_base.py index be74c9d44..d352be7d6 100644 --- a/tests/unit_tests/deploy/test_inference_base.py +++ b/tests/unit_tests/deploy/test_inference_base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types import unittest from pathlib import Path from unittest.mock import MagicMock, patch @@ -24,14 +25,6 @@ ) from megatron.core.transformer.module import MegatronModule -try: - from nemo.collections.llm.gpt.model.base import GPTConfig - from nemo.collections.llm.inference.base import MCoreTokenizerWrappper - - HAVE_NEMO = True -except (ImportError, ModuleNotFoundError): - HAVE_NEMO = False - from nemo_deploy.llm.inference.inference_base import ( MCoreEngineWithCleanup, _load_dist_shards_into_model, @@ -43,11 +36,10 @@ setup_megatron_model_and_tokenizer_for_inference, setup_model_and_tokenizer_for_inference, ) +from nemo_deploy.llm.inference.nemo_utils import MCoreTokenizerWrappper from nemo_deploy.llm.inference.tron_utils import DistributedInitConfig, RNGConfig -from nemo_export_deploy_common.import_utils import UnavailableError -@pytest.mark.skipif(not HAVE_NEMO, reason="NeMo is not installed") @pytest.mark.run_only_on("GPU") class TestInferenceBase(unittest.TestCase): def setUp(self): @@ -64,7 +56,7 @@ def setUp(self): self.mock_tokenizer.pad = 50257 # Padding token ID # Setup model config - self.model_config = GPTConfig( + self.model_config = types.SimpleNamespace( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, @@ -199,7 +191,6 @@ def test_load_nemo_checkpoint_to_tron_model(self, mock_ckpt_to_weights, mock_loa mock_ckpt_to_weights.assert_called_once_with(self.mock_path, is_saving=False) mock_load_shards.assert_called_once_with(self.mock_model_list, self.mock_weights_dir, False) - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -256,7 +247,6 @@ def test_setup_model_and_tokenizer_for_inference( mock_torch_dist_init.assert_called_once() mock_set_modelopt.assert_called_once() - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) @patch("nemo_deploy.llm.inference.inference_base.set_modelopt_spec_if_exists_in_ckpt") @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.io.load_context") @@ -366,16 +356,10 @@ def test_mcore_engine_with_cleanup_del(self, mock_cleanup): # Verify cleanup was called mock_cleanup.assert_called_once() - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", True) def test_create_mcore_engine_unknown_format_raises(self): with self.assertRaises(ValueError): create_mcore_engine(path=self.mock_path, model_format="unknown") - @patch("nemo_deploy.llm.inference.inference_base.HAVE_NEMO", False) - def test_create_mcore_engine_unavailable_nemo_raises(self): - with self.assertRaises(UnavailableError): - create_mcore_engine(path=self.mock_path) - @patch("nemo_deploy.llm.inference.inference_base.torch_distributed_init") @patch("nemo_deploy.llm.inference.inference_base.load_model_config") @patch("nemo_deploy.llm.inference.inference_base.initialize_megatron_for_inference")