diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index 8d6c1118a94..82e4418ab63 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -139,7 +139,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: is_cg_capturable = ( hasattr(self.config, 'cuda_graph_scope') and self.config.cuda_graph_scope - and 'full_iteration' in self.config.cuda_graph_scope + and (self.config.cuda_graph_scope is not None and 'full_iteration' in self.config.cuda_graph_scope) ) if is_cg_capturable and not is_te_min_version("2.7.0"): from megatron.core.utils import get_te_version