 
 import argparse
 import json
-from typing import Any, Dict
-
 import torch
+from typing import Any, Dict
 
 from .utils import convert_hf_config
 
 
 def sanitize_mindspeed_values(values: Dict[str, Any]) -> Dict[str, Any]:
-    return {
-        key: value
-        for key, value in values.items()
-        if isinstance(key, str) and key.isidentifier()
-    }
+    return {key: value for key, value in values.items() if isinstance(key, str) and key.isidentifier()}
 
 
 def _resolve_optimization_level(values: Dict[str, Any]) -> int:
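For context on the hunk above: `sanitize_mindspeed_values` keeps only entries whose keys can safely become `argparse.Namespace` attributes. A minimal, self-contained sketch of that behavior (the input dict here is hypothetical):

```python
from typing import Any, Dict

def sanitize_mindspeed_values(values: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only string keys that are valid Python identifiers, so every
    # surviving entry can become an argparse.Namespace attribute.
    return {key: value for key, value in values.items() if isinstance(key, str) and key.isidentifier()}

# Hypothetical input: 'rope-type' (hyphen) and the int key 123 are dropped.
raw = {'tp_size': 8, 'rope-type': 'yarn', 123: 'x', 'fp8': None}
assert sanitize_mindspeed_values(raw) == {'tp_size': 8, 'fp8': None}
```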
@@ -35,9 +30,81 @@ def _resolve_optimization_level(values: Dict[str, Any]) -> int:
     return 0
 
 
+def _update_sanitized(values: Dict[str, Any], section: Dict[str, Any]) -> None:
+    values.update(sanitize_mindspeed_values(section))
+
+
+def _build_fixed_runtime_defaults() -> Dict[str, Any]:
+    # Fixed MindSpeed / TE runtime defaults.
+    return {
+        'transformer_impl': 'transformer_engine',
+        'fp8': None,
+        'optimizer_selection': 'fused_adamw',
+        'shape_order': 'SBH',
+        'use_ascend_mc2': False,
+        'enable_gloo_process_groups': True,
+        'disable_gloo_group': False,
+    }
+
+
+def _build_topology_and_shape_defaults(args: Any, values: Dict[str, Any],
+                                       rope_scaling: Dict[str, Any]) -> Dict[str, Any]:
+    # Core topology and transformer shape.
+    return {
+        'tensor_model_parallel_size': args.tp_size,
+        'pipeline_model_parallel_size': args.pp_size,
+        'context_parallel_size': args.cp_size,
+        'expert_model_parallel_size': args.ep_size,
+        'expert_tensor_parallel_size': args.etp_size,
+        'virtual_pipeline_model_parallel_size': args.vpp_size,
+        'sequence_parallel': bool(args.sequence_parallel),
+        'num_layers': int(args.num_layers),
+        'hidden_size': int(args.hidden_size),
+        'num_attention_heads': int(args.num_attention_heads),
+        'num_query_groups': int(args.num_query_groups or args.num_attention_heads),
+        'ffn_hidden_size': int(args.ffn_hidden_size),
+        'mtp_num_layers': int(args.mtp_num_layers or 0),
+        'bf16': args.params_dtype == torch.bfloat16,
+        'fp16': args.params_dtype == torch.float16,
+        'position_embedding_type': values.get('position_embedding_type', 'rope'),
+        'rope_scaling_type': rope_scaling.get('rope_type') or rope_scaling.get('type'),
+        'yarn_scaling_factor': rope_scaling.get('factor'),
+        'rope_scaling_mscale': rope_scaling.get('mscale'),
+        'rope_scaling_mscale_all_dim': rope_scaling.get('mscale_all_dim'),
+    }
+
+
+def _build_moe_runtime_defaults(values: Dict[str, Any], args: Any, num_experts: int) -> Dict[str, Any]:
+    # MoE runtime knobs.
+    return {
+        'num_experts': num_experts,
+        'num_moe_experts': num_experts or None,
+        'moe_grouped_gemm': bool(values.get('moe_grouped_gemm', False) or num_experts > 0),
+        'moe_token_dispatcher_type': values.get('moe_token_dispatcher_type')
+                                     or ('alltoall' if num_experts > 0 else None),
+        'moe_router_topk': int(values.get('moe_router_topk', args.num_experts_per_tok) or 2),
+    }
+
+
+def _build_mla_runtime_defaults(values: Dict[str, Any], q_lora_rank: Any, multi_latent_attention: bool,
+                                qk_layernorm: bool, args: Any) -> Dict[str, Any]:
+    # MLA / DeepSeek-style attention knobs.
+    return {
+        'multi_latent_attention': multi_latent_attention,
+        'multi_head_latent_attention': multi_latent_attention,
+        'q_lora_rank': q_lora_rank,
+        'kv_lora_rank': values.get('kv_lora_rank'),
+        'qk_layernorm': qk_layernorm,
+        'use_qk_norm': qk_layernorm,
+        'qk_nope_head_dim': values.get('qk_head_dim', values.get('qk_nope_head_dim')),
+        'qk_rope_head_dim': values.get('qk_pos_emb_head_dim', values.get('qk_rope_head_dim')),
+        'v_head_dim': values.get('v_head_dim', args.kv_channels),
+    }
+
+
 def build_mindspeed_namespace(args: Any, defaults: Dict[str, Any]) -> argparse.Namespace:
     """Build MindSpeed runtime args namespace from Twinkle args.
-
+
     If there are fields with the same name, the one at the lowest level will be overwritten.
 
     Merges three layers in order of precedence (later layers override earlier ones):
@@ -64,64 +131,15 @@ def build_mindspeed_namespace(args: Any, defaults: Dict[str, Any]) -> argparse.Namespace:
     num_experts = int(getattr(args, 'num_experts', 0) or values.get('num_experts', 0) or 0)
     q_lora_rank = values.get('q_lora_rank', getattr(args, 'q_lora_rank', None))
     multi_latent_attention = bool(
-        getattr(args, 'multi_latent_attention', False)
-        or values.get('multi_latent_attention', False)
-        or values.get('multi_head_latent_attention', False)
-        or q_lora_rank is not None
-    )
+        getattr(args, 'multi_latent_attention', False) or values.get('multi_latent_attention', False)
+        or values.get('multi_head_latent_attention', False) or q_lora_rank is not None)
     qk_layernorm = bool(getattr(args, 'qk_layernorm', False) or values.get('qk_layernorm', False))
 
-    values.update(
-        sanitize_mindspeed_values({
-            # Fixed MindSpeed / TE runtime defaults.
-            'transformer_impl': 'transformer_engine',
-            'fp8': None,
-            'optimizer_selection': 'fused_adamw',
-            'shape_order': 'SBH',
-            'use_ascend_mc2': False,
-            'enable_gloo_process_groups': True,
-            'disable_gloo_group': False,
-
-            # Core topology and transformer shape.
-            'tensor_model_parallel_size': args.tp_size,
-            'pipeline_model_parallel_size': args.pp_size,
-            'context_parallel_size': args.cp_size,
-            'expert_model_parallel_size': args.ep_size,
-            'expert_tensor_parallel_size': args.etp_size,
-            'virtual_pipeline_model_parallel_size': args.vpp_size,
-            'sequence_parallel': bool(args.sequence_parallel),
-            'num_layers': int(args.num_layers),
-            'hidden_size': int(args.hidden_size),
-            'num_attention_heads': int(args.num_attention_heads),
-            'num_query_groups': int(args.num_query_groups or args.num_attention_heads),
-            'ffn_hidden_size': int(args.ffn_hidden_size),
-            'mtp_num_layers': int(args.mtp_num_layers or 0),
-            'bf16': args.params_dtype == torch.bfloat16,
-            'fp16': args.params_dtype == torch.float16,
-            'position_embedding_type': values.get('position_embedding_type', 'rope'),
-            'rope_scaling_type': rope_scaling.get('rope_type') or rope_scaling.get('type'),
-            'yarn_scaling_factor': rope_scaling.get('factor'),
-            'rope_scaling_mscale': rope_scaling.get('mscale'),
-            'rope_scaling_mscale_all_dim': rope_scaling.get('mscale_all_dim'),
-
-            # MoE runtime knobs.
-            'num_experts': num_experts,
-            'num_moe_experts': num_experts or None,
-            'moe_grouped_gemm': bool(values.get('moe_grouped_gemm', False) or num_experts > 0),
-            'moe_token_dispatcher_type': values.get('moe_token_dispatcher_type') or ('alltoall' if num_experts > 0 else None),
-            'moe_router_topk': int(values.get('moe_router_topk', args.num_experts_per_tok) or 2),
-
-            # MLA / DeepSeek-style attention knobs.
-            'multi_latent_attention': multi_latent_attention,
-            'multi_head_latent_attention': multi_latent_attention,
-            'q_lora_rank': q_lora_rank,
-            'kv_lora_rank': values.get('kv_lora_rank'),
-            'qk_layernorm': qk_layernorm,
-            'use_qk_norm': qk_layernorm,
-            'qk_nope_head_dim': values.get('qk_head_dim', values.get('qk_nope_head_dim')),
-            'qk_rope_head_dim': values.get('qk_pos_emb_head_dim', values.get('qk_rope_head_dim')),
-            'v_head_dim': values.get('v_head_dim', args.kv_channels),
-        }))
+    _update_sanitized(values, _build_fixed_runtime_defaults())
+    _update_sanitized(values, _build_topology_and_shape_defaults(args, values, rope_scaling))
+    _update_sanitized(values, _build_moe_runtime_defaults(values, args, num_experts))
+    _update_sanitized(values,
+                      _build_mla_runtime_defaults(values, q_lora_rank, multi_latent_attention, qk_layernorm, args))
     values['optimization_level'] = _resolve_optimization_level(values)
     return argparse.Namespace(**sanitize_mindspeed_values(values))
 
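Taken together, the refactor preserves the original merge semantics: each `_update_sanitized` call layers a sanitized section over `values` via `dict.update`, so a later section wins on key collisions, and only identifier-safe keys reach the final `argparse.Namespace`. A minimal sketch of that override order (the layer contents below are hypothetical, not the real defaults):

```python
import argparse
from typing import Any, Dict

def sanitize_mindspeed_values(values: Dict[str, Any]) -> Dict[str, Any]:
    return {k: v for k, v in values.items() if isinstance(k, str) and k.isidentifier()}

def _update_sanitized(values: Dict[str, Any], section: Dict[str, Any]) -> None:
    values.update(sanitize_mindspeed_values(section))

# Config-derived values first, then one section per builder;
# later sections overwrite earlier ones on key collisions.
values = {'fp8': 'e4m3', 'num_layers': 4}
_update_sanitized(values, {'fp8': None, 'shape_order': 'SBH'})  # fixed runtime defaults
_update_sanitized(values, {'num_layers': 8})                    # topology/shape defaults
assert values == {'fp8': None, 'num_layers': 8, 'shape_order': 'SBH'}

ns = argparse.Namespace(**sanitize_mindspeed_values(values))
assert ns.num_layers == 8 and ns.fp8 is None
```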