46 changes: 37 additions & 9 deletions superbench/benchmarks/base.py
@@ -93,7 +93,7 @@ def get_configurable_settings(self):
return message

def parse_args(self, ignore_invalid=False):
"""Parse the arguments.
"""Parse the arguments, accepting unknown args for forwarding.

Return:
ret (bool): whether parse succeed or not.
@@ -104,20 +104,48 @@ def parse_args(self, ignore_invalid=False):
args, unknown = self._parser.parse_known_args(self._argv)
except BaseException as e:
if ignore_invalid:
logger.info('Missing or invliad parameters, will ignore the error and skip the args checking.')
logger.info('Missing or invalid parameters, will ignore the error and skip the args checking.')
return True, None, []
else:
logger.error('Invalid argument - benchmark: {}, message: {}.'.format(self._name, str(e)))
return False, None, []

ret = True
# Normalize unknown arguments (convert underscores to hyphens)
if len(unknown) > 0:
logger.error(
'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
)
ret = False
if not getattr(self, '_ignore_unknown_args', False):
logger.error(
'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
)
return False, None, []
else:
unknown = self._normalize_unknown_args(unknown)
return True, args, unknown

return ret, args, unknown
def _normalize_unknown_args(self, unknown):
"""Normalize unknown args by converting underscores to hyphens in flag names.

Args:
unknown (list): List of unknown arguments.

Return:
list: Normalized list of arguments.
"""
normalized = []
for arg in unknown:
    # Flags start with '--'; bare values pass through untouched.
    if arg.startswith('--'):
        # Split off an inline '=value' at most once so values containing '='
        # survive, then convert underscores to hyphens in the flag name only.
        flag = arg.split('=', 1)[0].replace('_', '-')
        value = arg.split('=', 1)[1] if '=' in arg else None
        normalized.append(f'{flag} {value}' if value is not None else flag)
    else:
        normalized.append(arg)
return normalized
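
For context, a minimal sketch of how a benchmark class opts into this path (the subclass and argv below are hypothetical; MegatronGPT does the same thing in its constructor further down):

class MyForwardingBenchmark(Benchmark):
    def __init__(self, name, parameters=''):
        super().__init__(name, parameters)
        # Opt in: unknown args are normalized and forwarded instead of failing parse_args().
        self._ignore_unknown_args = True

# With an argv containing the unrecognized '--my_custom_flag=128',
# parse_args() now returns (True, args, ['--my-custom-flag 128']).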

def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
@@ -126,7 +154,7 @@ def _preprocess(self):
True if _preprocess() succeed.
"""
self.add_parser_arguments()
ret, self._args, unknown = self.parse_args()
ret, self._args, self._unknown_args = self.parse_args()

if not ret:
self._result = BenchmarkResult(self._name, self._benchmark_type, ReturnCode.INVALID_ARGUMENT)
79 changes: 73 additions & 6 deletions superbench/benchmarks/model_benchmarks/megatron_gpt3.py
@@ -37,6 +37,7 @@ def __init__(self, name, parameters=''):
"""
super().__init__(name, parameters)
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16, Precision.BFLOAT16]
self._ignore_unknown_args = True

def add_parser_arguments(self):
"""Add the specified arguments."""
@@ -90,7 +91,7 @@ def add_parser_arguments(self):
# Parallelism configs
self._parser.add_argument('--zero_stage', type=int, default=1, help='Zero stage.')
# Misc configs
self._parser.add_argument('--log-interval', type=int, required=False, default=1, help='Log interval.')
self._parser.add_argument('--log_interval', type=int, required=False, default=1, help='Log interval.')
self._parser.add_argument('--eval_iters', type=int, default=0, help='Eval iters.')
self._parser.add_argument('--eval_interval', type=int, default=10, help='Eval interval.')
self._parser.add_argument('--num_save', type=int, default=10000, help='Num save.')
@@ -187,7 +188,7 @@ def add_parser_arguments(self):
)
self._parser.add_argument('--moe_ffn_hidden_size', type=int, help='MoE FFN hidden size.')
self._parser.add_argument('--enable_shared_expert', action='store_true', help='Enable shared expert in MoE.')
self._parser.add_argument('--moe_layer_freq', type=int, help='MoE layer frequency.')
self._parser.add_argument('--moe_layer_freq', type=str, help='MoE layer frequency.')
self._parser.add_argument('--num_shared_experts', type=int, help='Number of shared experts.')
self._parser.add_argument('--moe_router_topk', type=int, help='Top-k routing for MoE.')
self._parser.add_argument('--moe_aux_loss_coeff', type=float, help='Auxiliary loss coefficient.')
@@ -232,8 +233,11 @@ def add_parser_arguments(self):
)

def _preprocess(self):
"""Preprocess with support for unknown args."""
if not super()._preprocess():
return False

# Original MegatronGPT preprocessing logic
if not self._args.code_base:
if self._args.deepspeed:
self._args.code_base = os.path.join(
@@ -531,7 +535,9 @@ def _megatron_command(self, precision): # noqa: C901
command = f'deepspeed {script_path} {megatron_options} {self._data_options} {deepspeed_option}'
else:
command = f'torchrun {self._distributed_args} {script_path} {megatron_options} {self._data_options}'

# Transparently append any unknown args captured during parsing for forward compatibility.
if getattr(self, '_unknown_args', None):
command = f"{command} {' '.join(self._unknown_args)}"
return command
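
A hedged end-to-end illustration of the forwarding (the flag, its value, and the angle-bracket placeholders below are hypothetical, standing in for the strings assembled above):

# Suppose the user passes a Megatron-LM option that SuperBench never declared:
unknown_args = ['--my-new-megatron-flag', '4']    # normalized from '--my_new_megatron_flag 4'
command = 'torchrun <distributed_args> <script> <megatron_options> <data_options>'
if unknown_args:
    # Mirrors the append above: unknown args ride along at the end of the command.
    command = f"{command} {' '.join(unknown_args)}"
print(command)
# torchrun <distributed_args> <script> <megatron_options> <data_options> --my-new-megatron-flag 4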

def _train_step(self, precision): # noqa: E501
@@ -786,13 +792,74 @@ def _cal_params_count(self):
'--load=deepseek-ai/DeepSeek-V2-Lite '
'--no_load_optim '
'--no_load_rng '
'--ckpt_format=torch '
'--eod_mask_loss '
'--train_mode=pretrain '
'--data_cache_path=/root/cache '
'--max_padding_length=4096 '
'--kv_lora_rank=512 '
'--dataloader_type=cyclic'
'--dataloader_type=cyclic '
),
platform=Platform.ROCM
)
BenchmarkRegistry.register_benchmark(
'megatron-deepseek-v2-lite',
MegatronGPT,
parameters=(
'--model=gpt '
'--transformer_impl=transformer_engine '
'--tokenizer_type=HuggingFaceTokenizer '
'--tokenizer_model=/opt/superbench/third_party/Megatron/data/DeepSeek-V2-Lite '
'--num_layers=27 '
'--hidden_size=1024 '
'--seq_len=4096 '
'--num_attn_heads=16 '
'--moe_ffn_hidden_size=1408 '
'--ffn_hidden_size=10944 '
'--dataloader_type=cyclic '
'--num_experts=64 '
'--no_async_tensor_model_parallel_allreduce '
'--use_rotary_position_embeddings '
'--no_gradient_accumulation_fusion '
'--mock_data '
'--use_flash_attn '
'--no_load_optim '
'--no_load_rng '
'--swiglu '
'--normalization=RMSNorm '
'--norm_epsilon=1e-06 '
'--no_bias_swiglu_fusion '
'--no_rope_fusion '
'--position_embedding_type=rope '
'--untie_embeddings_and_output_weights '
'--disable_bias_linear '
'--ckpt_format=torch '
'--rotary_percent=1.0 '
'--rotary_base=10000 '
'--rotary_scaling_factor=40 '
'--eod_mask_loss '
'--data_cache_path=/tmp/cache '
'--moe_layer_freq="([0]+[1]*26)" '
'--moe_router_topk=6 '
'--moe_router_topk_scaling_factor=1.0 '
'--moe_aux_loss_coeff=1e-3 '
'--kv_lora_rank=512 '
'--v_head_dim=128 '
'--qk_head_dim=128 '
'--qk_layernorm '
'--qk_pos_emb_head_dim=64 '
'--no_masked_softmax_fusion '
'--kv_channels=16 '
'--multi_latent_attention '
'--moe_router_score_function=softmax '
'--moe_router_topk=6 '
'--moe_router_pre_softmax '
'--moe_shared_expert_intermediate_size=2816 '
'--moe_token_dispatcher_type=alltoall '
'--moe_token_drop_policy=probs '
'--make_vocab_size_divisible_by=3200 '
'--attention_softmax_in_fp32 '
'--use_mcore_models '
'--mscale=0.707 '
'--mscale_all_dim=0.707 '
),
platform=Platform.CUDA
)
2 changes: 1 addition & 1 deletion superbench/benchmarks/registry.py
@@ -84,7 +84,7 @@ def __parse_and_check_args(cls, name, class_def, parameters):
benchmark = class_def(name, parameters)
benchmark.add_parser_arguments()
ret, args, unknown = benchmark.parse_args(ignore_invalid=True)
if not ret or len(unknown) >= 1:
if not ret or (len(unknown) >= 1 and not getattr(benchmark, '_ignore_unknown_args', False)):
logger.log_and_raise(
TypeError,
'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
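
A short sketch of the registration-time consequence, assuming a class that sets _ignore_unknown_args (the benchmark name and flag below are hypothetical):

BenchmarkRegistry.register_benchmark(
    'my-forwarding-benchmark',
    MyForwardingBenchmark,
    # '--my_future_flag' is unknown to the parser; registration now succeeds for
    # opted-in benchmarks instead of raising TypeError in __parse_and_check_args.
    parameters='--batch_size 2048 --my_future_flag 1',
)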
49 changes: 49 additions & 0 deletions tests/benchmarks/model_benchmarks/test_megatron_gpt.py
@@ -500,6 +500,55 @@ def test_deepseek_v2_command(self):

self.assertEqual(actual_units, expected_units)

@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_gpt_unknown_args(self, mock_generate_dataset):
"""Test unknown args forwarding and normalization."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
with open(self.hostfile_path, 'w') as f:
f.write('host1\n')

# Test with unknown args that have underscores (should be converted to hyphens)
benchmark = benchmark_cls(
self.benchmark_name,
parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} '
'--num_warmup 0 --num_steps 10 --batch_size 2048 '
'--my_custom_flag 128 --another_option --third_param value',
)
mock_generate_dataset.return_value = True
ret = benchmark._preprocess()
assert (ret is True)

# Verify unknown args are stored and normalized
assert (hasattr(benchmark, '_unknown_args'))
assert (len(benchmark._unknown_args) > 0)

# Check that underscores are converted to hyphens
assert ('--my-custom-flag' in benchmark._unknown_args)
assert ('128' in benchmark._unknown_args)
assert ('--another-option' in benchmark._unknown_args)
assert ('--third-param' in benchmark._unknown_args)
assert ('value' in benchmark._unknown_args)

# Verify unknown args appear in the generated command
benchmark._data_options = '--mock-data'
command = benchmark._megatron_command(Precision.FLOAT32)

# Check that normalized unknown args are in the command
assert ('--my-custom-flag 128' in command)
assert ('--another-option' in command)
assert ('--third-param value' in command)

# Ensure original underscore versions are NOT in the command
assert ('--my_custom_flag' not in command)
assert ('--another_option' not in command)
assert ('--third_param' not in command)

@decorator.load_data('tests/data/megatron_deepspeed.log')
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_parse_log(self, raw_output, mock_generate_dataset):
9 changes: 8 additions & 1 deletion third_party/Makefile
@@ -226,7 +226,14 @@ directx_amf_encoding_latency:
megatron_lm:
cd Megatron && \
apt install -y python3-mpi4py && \
python -m pip install --no-cache-dir -r requirements.txt
python -m pip install --no-cache-dir -r requirements.txt && \
mkdir -p data/gpt && \
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -O data/gpt2-vocab.json && \
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O data/gpt2-merges.txt && \
mkdir -p data/DeepSeek-V2-Lite && \
wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/config.json -O data/DeepSeek-V2-Lite/config.json && \
wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/tokenizer.json -O data/DeepSeek-V2-Lite/tokenizer.json && \
wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/tokenizer_config.json -O data/DeepSeek-V2-Lite/tokenizer_config.json

# Install requirements for Megatron-DeepSpeed
megatron_deepspeed:
2 changes: 1 addition & 1 deletion third_party/Megatron/Megatron-LM
Submodule Megatron-LM updated 2142 files