
Commit ff7464a

fix lint

committed · 1 parent 1b2ab01 · commit ff7464a

File tree

6 files changed: +99 -99 lines changed


src/twinkle/model/megatron/_mindspeed_args.py

Lines changed: 82 additions & 64 deletions
@@ -3,19 +3,14 @@
 
 import argparse
 import json
-from typing import Any, Dict
-
 import torch
+from typing import Any, Dict
 
 from .utils import convert_hf_config
 
 
 def sanitize_mindspeed_values(values: Dict[str, Any]) -> Dict[str, Any]:
-    return {
-        key: value
-        for key, value in values.items()
-        if isinstance(key, str) and key.isidentifier()
-    }
+    return {key: value for key, value in values.items() if isinstance(key, str) and key.isidentifier()}
 
 
 def _resolve_optimization_level(values: Dict[str, Any]) -> int:
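
Note: the one-line rewrite of sanitize_mindspeed_values keeps the same filter as before: only string keys that are valid Python identifiers survive, so every remaining key can safely become an attribute of an argparse.Namespace. A minimal sketch of the behaviour (the input dict is made up for illustration):

from argparse import Namespace

raw = {'tensor_model_parallel_size': 2, 'moe-router-topk': 4, 3: 'not a str key'}
clean = {key: value for key, value in raw.items() if isinstance(key, str) and key.isidentifier()}
# Only 'tensor_model_parallel_size' survives; hyphenated and non-string keys are dropped.
ns = Namespace(**clean)
print(ns.tensor_model_parallel_size)  # 2
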
@@ -35,9 +30,81 @@ def _resolve_optimization_level(values: Dict[str, Any]) -> int:
     return 0
 
 
+def _update_sanitized(values: Dict[str, Any], section: Dict[str, Any]) -> None:
+    values.update(sanitize_mindspeed_values(section))
+
+
+def _build_fixed_runtime_defaults() -> Dict[str, Any]:
+    # Fixed MindSpeed / TE runtime defaults.
+    return {
+        'transformer_impl': 'transformer_engine',
+        'fp8': None,
+        'optimizer_selection': 'fused_adamw',
+        'shape_order': 'SBH',
+        'use_ascend_mc2': False,
+        'enable_gloo_process_groups': True,
+        'disable_gloo_group': False,
+    }
+
+
+def _build_topology_and_shape_defaults(args: Any, values: Dict[str, Any], rope_scaling: Dict[str,
+                                                                                              Any]) -> Dict[str, Any]:
+    # Core topology and transformer shape.
+    return {
+        'tensor_model_parallel_size': args.tp_size,
+        'pipeline_model_parallel_size': args.pp_size,
+        'context_parallel_size': args.cp_size,
+        'expert_model_parallel_size': args.ep_size,
+        'expert_tensor_parallel_size': args.etp_size,
+        'virtual_pipeline_model_parallel_size': args.vpp_size,
+        'sequence_parallel': bool(args.sequence_parallel),
+        'num_layers': int(args.num_layers),
+        'hidden_size': int(args.hidden_size),
+        'num_attention_heads': int(args.num_attention_heads),
+        'num_query_groups': int(args.num_query_groups or args.num_attention_heads),
+        'ffn_hidden_size': int(args.ffn_hidden_size),
+        'mtp_num_layers': int(args.mtp_num_layers or 0),
+        'bf16': args.params_dtype == torch.bfloat16,
+        'fp16': args.params_dtype == torch.float16,
+        'position_embedding_type': values.get('position_embedding_type', 'rope'),
+        'rope_scaling_type': rope_scaling.get('rope_type') or rope_scaling.get('type'),
+        'yarn_scaling_factor': rope_scaling.get('factor'),
+        'rope_scaling_mscale': rope_scaling.get('mscale'),
+        'rope_scaling_mscale_all_dim': rope_scaling.get('mscale_all_dim'),
+    }
+
+
+def _build_moe_runtime_defaults(values: Dict[str, Any], args: Any, num_experts: int) -> Dict[str, Any]:
+    # MoE runtime knobs.
+    return {
+        'num_experts': num_experts,
+        'num_moe_experts': num_experts or None,
+        'moe_grouped_gemm': bool(values.get('moe_grouped_gemm', False) or num_experts > 0),
+        'moe_token_dispatcher_type': values.get('moe_token_dispatcher_type')
+        or ('alltoall' if num_experts > 0 else None),
+        'moe_router_topk': int(values.get('moe_router_topk', args.num_experts_per_tok) or 2),
+    }
+
+
+def _build_mla_runtime_defaults(values: Dict[str, Any], q_lora_rank: Any, multi_latent_attention: bool,
+                                qk_layernorm: bool, args: Any) -> Dict[str, Any]:
+    # MLA / DeepSeek-style attention knobs.
+    return {
+        'multi_latent_attention': multi_latent_attention,
+        'multi_head_latent_attention': multi_latent_attention,
+        'q_lora_rank': q_lora_rank,
+        'kv_lora_rank': values.get('kv_lora_rank'),
+        'qk_layernorm': qk_layernorm,
+        'use_qk_norm': qk_layernorm,
+        'qk_nope_head_dim': values.get('qk_head_dim', values.get('qk_nope_head_dim')),
+        'qk_rope_head_dim': values.get('qk_pos_emb_head_dim', values.get('qk_rope_head_dim')),
+        'v_head_dim': values.get('v_head_dim', args.kv_channels),
+    }
+
+
 def build_mindspeed_namespace(args: Any, defaults: Dict[str, Any]) -> argparse.Namespace:
     """Build MindSpeed runtime args namespace from Twinkle args.
-
+
     If there are fields with the same name, the one at the lowest level will be overwritten.
 
     Merges three layers in order of precedence (later layers override earlier ones):
@@ -64,64 +131,15 @@ def build_mindspeed_namespace(args: Any, defaults: Dict[str, Any]) -> argparse.Namespace:
     num_experts = int(getattr(args, 'num_experts', 0) or values.get('num_experts', 0) or 0)
     q_lora_rank = values.get('q_lora_rank', getattr(args, 'q_lora_rank', None))
     multi_latent_attention = bool(
-        getattr(args, 'multi_latent_attention', False)
-        or values.get('multi_latent_attention', False)
-        or values.get('multi_head_latent_attention', False)
-        or q_lora_rank is not None
-    )
+        getattr(args, 'multi_latent_attention', False) or values.get('multi_latent_attention', False)
+        or values.get('multi_head_latent_attention', False) or q_lora_rank is not None)
     qk_layernorm = bool(getattr(args, 'qk_layernorm', False) or values.get('qk_layernorm', False))
 
-    values.update(
-        sanitize_mindspeed_values({
-            # Fixed MindSpeed / TE runtime defaults.
-            'transformer_impl': 'transformer_engine',
-            'fp8': None,
-            'optimizer_selection': 'fused_adamw',
-            'shape_order': 'SBH',
-            'use_ascend_mc2': False,
-            'enable_gloo_process_groups': True,
-            'disable_gloo_group': False,
-
-            # Core topology and transformer shape.
-            'tensor_model_parallel_size': args.tp_size,
-            'pipeline_model_parallel_size': args.pp_size,
-            'context_parallel_size': args.cp_size,
-            'expert_model_parallel_size': args.ep_size,
-            'expert_tensor_parallel_size': args.etp_size,
-            'virtual_pipeline_model_parallel_size': args.vpp_size,
-            'sequence_parallel': bool(args.sequence_parallel),
-            'num_layers': int(args.num_layers),
-            'hidden_size': int(args.hidden_size),
-            'num_attention_heads': int(args.num_attention_heads),
-            'num_query_groups': int(args.num_query_groups or args.num_attention_heads),
-            'ffn_hidden_size': int(args.ffn_hidden_size),
-            'mtp_num_layers': int(args.mtp_num_layers or 0),
-            'bf16': args.params_dtype == torch.bfloat16,
-            'fp16': args.params_dtype == torch.float16,
-            'position_embedding_type': values.get('position_embedding_type', 'rope'),
-            'rope_scaling_type': rope_scaling.get('rope_type') or rope_scaling.get('type'),
-            'yarn_scaling_factor': rope_scaling.get('factor'),
-            'rope_scaling_mscale': rope_scaling.get('mscale'),
-            'rope_scaling_mscale_all_dim': rope_scaling.get('mscale_all_dim'),
-
-            # MoE runtime knobs.
-            'num_experts': num_experts,
-            'num_moe_experts': num_experts or None,
-            'moe_grouped_gemm': bool(values.get('moe_grouped_gemm', False) or num_experts > 0),
-            'moe_token_dispatcher_type': values.get('moe_token_dispatcher_type') or ('alltoall' if num_experts > 0 else None),
-            'moe_router_topk': int(values.get('moe_router_topk', args.num_experts_per_tok) or 2),
-
-            # MLA / DeepSeek-style attention knobs.
-            'multi_latent_attention': multi_latent_attention,
-            'multi_head_latent_attention': multi_latent_attention,
-            'q_lora_rank': q_lora_rank,
-            'kv_lora_rank': values.get('kv_lora_rank'),
-            'qk_layernorm': qk_layernorm,
-            'use_qk_norm': qk_layernorm,
-            'qk_nope_head_dim': values.get('qk_head_dim', values.get('qk_nope_head_dim')),
-            'qk_rope_head_dim': values.get('qk_pos_emb_head_dim', values.get('qk_rope_head_dim')),
-            'v_head_dim': values.get('v_head_dim', args.kv_channels),
-        }))
+    _update_sanitized(values, _build_fixed_runtime_defaults())
+    _update_sanitized(values, _build_topology_and_shape_defaults(args, values, rope_scaling))
+    _update_sanitized(values, _build_moe_runtime_defaults(values, args, num_experts))
+    _update_sanitized(values,
+                      _build_mla_runtime_defaults(values, q_lora_rank, multi_latent_attention, qk_layernorm, args))
     values['optimization_level'] = _resolve_optimization_level(values)
     return argparse.Namespace(**sanitize_mindspeed_values(values))

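Note: the refactor above replaces one monolithic values.update(...) with four _update_sanitized calls, one per section. The docstring's precedence rule is preserved because each call is a plain dict.update, so a key set by a later section overwrites the same key from an earlier one. A toy sketch of that layered-update pattern (section contents are invented, not the real defaults):

from typing import Any, Dict


def _sanitize(section: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in section.items() if isinstance(k, str) and k.isidentifier()}


values: Dict[str, Any] = {}
for section in ({'shape_order': 'SBH', 'fp8': None},   # earlier layer: fixed runtime defaults
                {'shape_order': 'BSH'}):                # later layer wins on key conflicts
    values.update(_sanitize(section))

assert values == {'shape_order': 'BSH', 'fp8': None}
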
src/twinkle/model/megatron/args.py

Lines changed: 11 additions & 25 deletions
@@ -42,12 +42,8 @@ def _allreduce_word_embedding_grads_allow_none(*call_args, **call_kwargs):
     it can survive Megatron helper signature drift across versions.
     """
     from megatron.core import parallel_state
-    from megatron.core.distributed.finalize_model_grads import (
-        _get_main_grad_attr,
-        _reshard_if_dtensor,
-        _unshard_if_dtensor,
-        get_attr_wrapped_model,
-    )
+    from megatron.core.distributed.finalize_model_grads import (_get_main_grad_attr, _reshard_if_dtensor,
+                                                                _unshard_if_dtensor, get_attr_wrapped_model)
 
     model, config, embd_group, pp_group, _ = _normalize_word_embedding_allreduce_call(*call_args, **call_kwargs)
     if embd_group is None:
@@ -65,8 +61,8 @@ def _get_main_grad_attr_compat(weight, ddp_config):
             return _get_main_grad_attr(weight)
         return _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
 
-    if parallel_state.is_rank_in_embedding_group(ignore_virtual=True) and torch.distributed.get_world_size(
-            embd_group) > 1:
+    if parallel_state.is_rank_in_embedding_group(
+            ignore_virtual=True) and torch.distributed.get_world_size(embd_group) > 1:
         if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             model_module = model[0]
         elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
@@ -328,9 +324,8 @@ def expert_tensor_parallel_size(self) -> int:
         # the parameters were built according to tp_size.
         tp = self.device_mesh.tp_world_size or 1
         if self.device_mesh.etp_size is not None and self.device_mesh.etp_world_size != tp:
-            logger.warning(
-                f'etp_size={self.device_mesh.etp_world_size} is ignored on '
-                f'megatron_core<0.13; expert TP is tied to tp_size={tp}')
+            logger.warning(f'etp_size={self.device_mesh.etp_world_size} is ignored on '
+                           f'megatron_core<0.13; expert TP is tied to tp_size={tp}')
             return tp
         return self.device_mesh.etp_world_size

@@ -438,9 +433,7 @@ def from_hf_config(
         # The registry import chain can pull in megatron.core, which must stay
         # behind the MindSpeed bootstrap on NPU.
         from .model.constant import MLLMModelType
-        is_multimodal = model_type in {
-            value for key, value in vars(MLLMModelType).items() if not key.startswith('_')
-        }
+        is_multimodal = model_type in {value for key, value in vars(MLLMModelType).items() if not key.startswith('_')}
 
         # Determine QKV bias
         if hasattr(text_config, 'attention_bias'):
@@ -589,15 +582,10 @@ def create_model(self, ) -> List[nn.Module]:
 
 def finalize_model_grads_for_lora(model, *args, **kwargs):
     import importlib
-
-    from megatron.core.distributed import DistributedDataParallel as MegatronDDP
-    from megatron.core.distributed.finalize_model_grads import (
-        _get_main_grad_attr,
-        _reshard_if_dtensor,
-        _unshard_if_dtensor,
-        get_attr_wrapped_model,
-    )
     from megatron.core import parallel_state
+    from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+    from megatron.core.distributed.finalize_model_grads import (_get_main_grad_attr, _reshard_if_dtensor,
+                                                                _unshard_if_dtensor, get_attr_wrapped_model)
     from peft import PeftModel as _PeftModel
 
     # Unwrap PeftModel -> LoraModel -> real model to check DDP capability.
@@ -610,9 +598,7 @@ def _get_base_model(m):
     base_model = _get_base_model(model[0])
     if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'finish_grad_sync'):
         # Fix 2: temporarily swap in the None-safe embedding allreduce.
-        finalize_model_grads_mod = importlib.import_module(
-            'megatron.core.distributed.finalize_model_grads'
-        )
+        finalize_model_grads_mod = importlib.import_module('megatron.core.distributed.finalize_model_grads')
         orig_allreduce_word_embedding_grads = finalize_model_grads_mod._allreduce_word_embedding_grads
         finalize_model_grads_mod._allreduce_word_embedding_grads = _allreduce_word_embedding_grads_allow_none
         try:

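Note: the hunks above rely on a small compat shim (_get_main_grad_attr_compat) so the patched allreduce keeps working when Megatron's _get_main_grad_attr gains or loses a parameter between releases. A hedged sketch of that general pattern, probing the installed signature with inspect (the wrapper name and fallback order are illustrative, not the exact Twinkle code):

import inspect


def call_get_main_grad_attr(get_main_grad_attr, weight, ddp_config):
    # Older megatron_core exposes _get_main_grad_attr(weight); newer releases also take a flag.
    if len(inspect.signature(get_main_grad_attr).parameters) == 1:
        return get_main_grad_attr(weight)
    return get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
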
src/twinkle/model/megatron/megatron.py

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ def forward_step_func(data_iterator, model):
         masked_labels = torch.where(loss_mask, labels, torch.zeros_like(labels))
 
         output_tensor.div_(temperature)
-
+
         logps = selective_log_softmax(output_tensor, masked_labels)
         if cp_size > 1:
             logps = self._postprocess_tensor_cp(logps)

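For context on the lines around this whitespace-only fix: the logits are divided by the sampling temperature and selective_log_softmax then extracts the log-probability of each (masked) label token. A reference sketch of that computation, assuming the usual gather-based definition rather than Twinkle's exact import:

import torch


def selective_log_softmax_ref(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Log-probability of each label token under the temperature-scaled logits.
    logps = torch.log_softmax(logits, dim=-1)                  # [batch, seq, vocab]
    return logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [batch, seq]


logits = torch.randn(2, 4, 8) / 0.7          # div by temperature, as in the diff above
labels = torch.randint(0, 8, (2, 4))
print(selective_log_softmax_ref(logits, labels).shape)         # torch.Size([2, 4])
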
src/twinkle/model/megatron/mindspeed_bootstrap.py

Lines changed: 2 additions & 5 deletions
@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional
 
 from twinkle import Platform
-
 from ._mindspeed_args import build_mindspeed_namespace, get_mindspeed_signature, sanitize_mindspeed_values
 
 _DEFAULT_MINDSPEED_VALUES: Optional[Dict[str, Any]] = None
@@ -60,10 +59,8 @@ def bootstrap_mindspeed_for_npu(args: Any) -> Optional[Dict[str, Any]]:
     try:
         args_utils = importlib.import_module('mindspeed.args_utils')
     except ModuleNotFoundError as exc:
-        raise RuntimeError(
-            'MindSpeed is required for Twinkle NPU Megatron runs. '
-            'Please install MindSpeed in the current environment.'
-        ) from exc
+        raise RuntimeError('MindSpeed is required for Twinkle NPU Megatron runs. '
+                           'Please install MindSpeed in the current environment.') from exc
     # Fetch MindSpeed defaults here, then merge them with Twinkle args to
     # build the final MindSpeed runtime args.
     runtime_args = build_mindspeed_namespace(args, _get_mindspeed_defaults(args_utils))
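
Note: the reflowed RuntimeError keeps the same guarded-import behaviour as before: probe for MindSpeed and fail with an actionable message instead of a bare ModuleNotFoundError. A generic sketch of the pattern (the helper name is invented for illustration):

import importlib


def require_module(name: str, hint: str):
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError as exc:
        raise RuntimeError(f'{name} is required. {hint}') from exc


# e.g. require_module('mindspeed.args_utils', 'Please install MindSpeed in the current environment.')
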
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 from .constant import MegatronModelType
 from .gpt_bridge import GPTBridge
-from .register import (MegatronModelLoader, MegatronModelMeta, ensure_megatron_model_registry,
-                       get_megatron_model_meta, register_megatron_model)
+from .register import (MegatronModelLoader, MegatronModelMeta, ensure_megatron_model_registry, get_megatron_model_meta,
+                       register_megatron_model)

src/twinkle/utils/framework.py

Lines changed: 1 addition & 2 deletions
@@ -58,8 +58,7 @@ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None):
             from megatron.core import parallel_state as mpu
 
             process_group = mpu.get_data_parallel_group_gloo(
-                with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1
-            )
+                with_context_parallel=getattr(device_mesh, 'cp_world_size', 1) > 1)
         except Exception:
             pass
     group_size = dist.get_world_size(group=process_group)

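Note: gather_object picks Megatron's gloo-backed data-parallel group here so arbitrary picklable objects can be gathered without routing them through device memory. A small sketch of the underlying idiom with torch.distributed.all_gather_object (assumed usage, not the full Twinkle helper, and it requires an already-initialized process group):

import torch.distributed as dist


def gather_objects(obj, process_group=None):
    # Collect one picklable object per rank in the given group.
    group_size = dist.get_world_size(group=process_group)
    gathered = [None] * group_size
    dist.all_gather_object(gathered, obj, group=process_group)
    return gathered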