61 changes: 24 additions & 37 deletions nemo_deploy/llm/inference/inference_base.py
@@ -16,7 +16,7 @@
 import atexit
 import logging
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import megatron.core.dist_checkpointing.serialization as dist_ckpt
 import torch
@@ -27,20 +27,20 @@
     get_default_load_sharded_strategy,
 )
 from megatron.core.dist_checkpointing.validation import StrictHandling
-from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.engines.mcore_engine import MCoreEngine
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
 from megatron.core.transformer.enums import AttnBackend
 from megatron.core.transformer.module import MegatronModule
 from packaging import version

-from nemo_export_deploy_common.import_utils import MISSING_NEMO_MSG, UnavailableError
-
 from .tron_utils import (
     DistributedInitConfig,
     RNGConfig,
@@ -62,29 +62,13 @@
 except ImportError:
     HAVE_TRITON = False

-try:
-    if not HAVE_TRITON:
-        raise ImportError("Triton is not installed")
-    from nemo.collections.llm.gpt.model.base import GPTConfig
-    from nemo.collections.llm.inference.base import MCoreTokenizerWrappper
-    from nemo.collections.llm.modelopt import set_modelopt_spec_if_exists_in_ckpt
-    from nemo.collections.llm.t5.model.t5 import T5Config
-    from nemo.lightning import io
-    from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
-    from nemo.lightning.io.pl import ckpt_to_weights_subdir
-
-    HAVE_NEMO = True
-except (ImportError, ModuleNotFoundError):
-    HAVE_NEMO = False
-    from typing import Any
-
-    io = None
-    GPTConfig = Any
-    T5Config = Any
-    MCoreTokenizerWrappper = Any
-    set_modelopt_spec_if_exists_in_ckpt = None
-    ckpt_to_weights_subdir = None
-    ckpt_to_context_subdir = None
+from .nemo_utils import (
+    MCoreTokenizerWrappper,
+    ckpt_to_context_subdir,
+    ckpt_to_weights_subdir,
+    io,
+    set_modelopt_spec_if_exists_in_ckpt,
+)

 LOGGER = logging.getLogger("NeMo")

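Not shown in this diff is the new nemo_deploy/llm/inference/nemo_utils.py module that these five symbols now come from. Below is a plausible sketch of what it centralizes, assuming it simply relocates the try/except fallback deleted above; the module's actual contents are not part of this diff.

# nemo_utils.py: hypothetical sketch based on the fallback block deleted above;
# the real module may differ.
from typing import Any

try:
    from nemo.collections.llm.inference.base import MCoreTokenizerWrappper
    from nemo.collections.llm.modelopt import set_modelopt_spec_if_exists_in_ckpt
    from nemo.lightning import io
    from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
    from nemo.lightning.io.pl import ckpt_to_weights_subdir

    HAVE_NEMO = True
except (ImportError, ModuleNotFoundError):
    HAVE_NEMO = False

    # Placeholders keep `from .nemo_utils import ...` in inference_base.py
    # working even when NeMo is not installed.
    io = None
    MCoreTokenizerWrappper = Any
    set_modelopt_spec_if_exists_in_ckpt = None
    ckpt_to_weights_subdir = None
    ckpt_to_context_subdir = None

One consequence visible in the later hunks: the inline `if not HAVE_NEMO: raise UnavailableError(...)` guards are dropped, so under a sketch like this a missing NeMo installation would surface when one of these placeholders is actually used, rather than at function entry.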
@@ -201,8 +185,6 @@ def load_nemo_checkpoint_to_tron_model(model: List[MegatronModule], path: Path,
         path (Path): Path to NeMo checkpoint directory
         legacy_ckpt (bool): Whether to use legacy checkpoint format
     """
-    if not HAVE_NEMO:
-        raise UnavailableError(MISSING_NEMO_MSG)
     weights_dir = ckpt_to_weights_subdir(path, is_saving=False)
     LOGGER.info(f"Loading NeMo checkpoint from {weights_dir}")

@@ -324,9 +306,6 @@ def setup_model_and_tokenizer_for_inference(
     Raises:
         ValueError: If checkpoint_path is not a valid NeMo-2.0 checkpoint
     """
-    if not HAVE_NEMO:
-        raise UnavailableError(MISSING_NEMO_MSG)
-
     checkpoint_path = Path(checkpoint_path)

     # Load model context for config and tokenizer
@@ -465,6 +444,7 @@ def create_mcore_engine(
     model_type: str = "gpt",
     model_format: str = "nemo",
     micro_batch_size: Optional[int] = None,
+    buffer_size_gb: float = 10.0,
     **model_config_kwargs,
 ) -> Tuple[MCoreEngineWithCleanup, GPTInferenceWrapper, Union[MCoreTokenizerWrappper, MegatronTokenizer]]:
     """Set up the model, tokenizer and MCoreEngine for inference.
Expand Down Expand Up @@ -492,9 +472,6 @@ def create_mcore_engine(
- GPTInferenceWrapper: Inference-wrapped model
- Union[MCoreTokenizerWrappper, MegatronTokenizer]: Tokenizer instance
"""
if not HAVE_NEMO and model_format == "nemo":
raise UnavailableError(MISSING_NEMO_MSG)

# Default to 1 for any parallelism dimension that's None
tensor_model_parallel_size = tensor_model_parallel_size if tensor_model_parallel_size is not None else 1
pipeline_model_parallel_size = pipeline_model_parallel_size if pipeline_model_parallel_size is not None else 1
Expand Down Expand Up @@ -534,15 +511,25 @@ def create_mcore_engine(
else:
raise ValueError(f"Model format {model_format} not supported.")

inference_context = StaticInferenceContext(max_batch_size, inference_max_seq_length)
model_inference_wrapper = GPTInferenceWrapper(model, inference_context)
inner_model = peel(model)
model_config = inner_model.config
inference_wrapper_config = InferenceWrapperConfig(
hidden_size=model_config.hidden_size,
params_dtype=model_config.params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=inner_model.vocab_size,
inference_max_requests=max_batch_size,
inference_max_seq_length=inference_max_seq_length,
)
model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config)
text_generation_controller = TextGenerationController(
inference_wrapped_model=model_inference_wrapper, tokenizer=tokenizer
)
mcore_engine = MCoreEngine(
text_generation_controller=text_generation_controller,
max_batch_size=max_batch_size,
random_seed=random_seed,
buffer_size_gb=buffer_size_gb,
)

# Wrap the engine to ensure cleanup
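With this hunk, engine construction goes through an InferenceWrapperConfig built from the unwrapped model's config instead of a StaticInferenceContext, and the new buffer_size_gb argument is forwarded to MCoreEngine. A minimal caller-side sketch follows: only model_type, model_format, micro_batch_size, and buffer_size_gb are keyword names confirmed by this diff's signature; the checkpoint-path keyword and all values are assumptions for illustration.

# Hypothetical usage sketch; unpacks the three-tuple from the function's
# return annotation shown above.
from nemo_deploy.llm.inference.inference_base import create_mcore_engine

engine, inference_wrapped_model, tokenizer = create_mcore_engine(
    path="/checkpoints/my_model",  # assumed parameter name, illustrative value
    model_type="gpt",
    model_format="nemo",
    micro_batch_size=1,
    buffer_size_gb=10.0,  # new in this PR: forwarded to MCoreEngine
)

The engine comes back as an MCoreEngineWithCleanup, which, per the trailing comment in the hunk, wraps the MCoreEngine so that teardown is handled for the caller.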