diff --git a/superbench/benchmarks/base.py b/superbench/benchmarks/base.py
index 8e6e58bfe..02b072b7e 100644
--- a/superbench/benchmarks/base.py
+++ b/superbench/benchmarks/base.py
@@ -93,7 +93,7 @@ def get_configurable_settings(self):
         return message
 
     def parse_args(self, ignore_invalid=False):
-        """Parse the arguments.
+        """Parse the arguments, accepting unknown args for forwarding.
 
         Return:
             ret (bool): whether parse succeed or not.
@@ -104,20 +104,48 @@ def parse_args(self, ignore_invalid=False):
             args, unknown = self._parser.parse_known_args(self._argv)
         except BaseException as e:
             if ignore_invalid:
-                logger.info('Missing or invliad parameters, will ignore the error and skip the args checking.')
+                logger.info('Missing or invalid parameters, will ignore the error and skip the args checking.')
                 return True, None, []
             else:
                 logger.error('Invalid argument - benchmark: {}, message: {}.'.format(self._name, str(e)))
                 return False, None, []
 
-        ret = True
+        # Normalize unknown arguments (convert underscores to hyphens)
         if len(unknown) > 0:
-            logger.error(
-                'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
-            )
-            ret = False
+            if not getattr(self, '_ignore_unknown_args', False):
+                logger.error(
+                    'Unknown arguments - benchmark: {}, unknown arguments: {}'.format(self._name, ' '.join(unknown))
+                )
+                return False, None, []
+            else:
+                unknown = self._normalize_unknown_args(unknown)
+
+        return True, args, unknown
 
-        return ret, args, unknown
+    def _normalize_unknown_args(self, unknown):
+        """Normalize unknown args by converting underscores to hyphens in flag names.
+
+        Args:
+            unknown (list): List of unknown arguments.
+
+        Return:
+            list: Normalized list of arguments.
+        """
+        normalized = []
+        i = 0
+        while i < len(unknown):
+            arg = unknown[i]
+            # Check if it's a flag (starts with --)
+            if arg.startswith('--'):
+                # Convert underscores to hyphens in the flag name
+                flag = arg.split('=')[0]
+                value = arg.split('=')[1] if '=' in arg else None
+                normalized_flag = flag.replace('_', '-')
+                normalized.append(f'{normalized_flag} {value}' if value is not None else normalized_flag)
+            else:
+                # It's a value, keep as-is
+                normalized.append(arg)
+            i += 1
+        return normalized
 
     def _preprocess(self):
         """Preprocess/preparation operations before the benchmarking.
@@ -126,7 +154,7 @@ def _preprocess(self):
             True if _preprocess() succeed.
""" self.add_parser_arguments() - ret, self._args, unknown = self.parse_args() + ret, self._args, self._unknown_args = self.parse_args() if not ret: self._result = BenchmarkResult(self._name, self._benchmark_type, ReturnCode.INVALID_ARGUMENT) diff --git a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py index 37d27bf1a..d0c61e6be 100644 --- a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py +++ b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py @@ -37,6 +37,7 @@ def __init__(self, name, parameters=''): """ super().__init__(name, parameters) self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16, Precision.BFLOAT16] + self._ignore_unknown_args = True def add_parser_arguments(self): """Add the specified arguments.""" @@ -90,7 +91,7 @@ def add_parser_arguments(self): # Parallelism configs self._parser.add_argument('--zero_stage', type=int, default=1, help='Zero stage.') # Misc configs - self._parser.add_argument('--log-interval', type=int, required=False, default=1, help='Log interval.') + self._parser.add_argument('--log_interval', type=int, required=False, default=1, help='Log interval.') self._parser.add_argument('--eval_iters', type=int, default=0, help='Eval iters.') self._parser.add_argument('--eval_interval', type=int, default=10, help='Eval interval.') self._parser.add_argument('--num_save', type=int, default=10000, help='Num save.') @@ -187,7 +188,7 @@ def add_parser_arguments(self): ) self._parser.add_argument('--moe_ffn_hidden_size', type=int, help='MoE FFN hidden size.') self._parser.add_argument('--enable_shared_expert', action='store_true', help='Enable shared expert in MoE.') - self._parser.add_argument('--moe_layer_freq', type=int, help='MoE layer frequency.') + self._parser.add_argument('--moe_layer_freq', type=str, help='MoE layer frequency.') self._parser.add_argument('--num_shared_experts', type=int, help='Number of shared experts.') self._parser.add_argument('--moe_router_topk', type=int, help='Top-k routing for MoE.') self._parser.add_argument('--moe_aux_loss_coeff', type=float, help='Auxiliary loss coefficient.') @@ -232,8 +233,11 @@ def add_parser_arguments(self): ) def _preprocess(self): + """Preprocess with support for unknown args.""" if not super()._preprocess(): return False + + # Original MegatronGPT preprocessing logic if not self._args.code_base: if self._args.deepspeed: self._args.code_base = os.path.join( @@ -531,7 +535,9 @@ def _megatron_command(self, precision): # noqa: C901 command = f'deepspeed {script_path} {megatron_options} {self._data_options} {deepspeed_option}' else: command = f'torchrun {self._distributed_args} {script_path} {megatron_options} {self._data_options}' - + # Transparently append any unknown args captured during parsing for forward compatibility. 
+        if getattr(self, '_unknown_args', None):
+            command = f"{command} {' '.join(self._unknown_args)}"
         return command
 
     def _train_step(self, precision):    # noqa: E501
@@ -786,13 +792,74 @@ def _cal_params_count(self):
         '--load=deepseek-ai/DeepSeek-V2-Lite '
         '--no_load_optim '
         '--no_load_rng '
-        '--ckpt_format=torch '
         '--eod_mask_loss '
         '--train_mode=pretrain '
-        '--data_cache_path=/root/cache '
         '--max_padding_length=4096 '
         '--kv_lora_rank=512 '
-        '--dataloader_type=cyclic'
+        '--dataloader_type=cyclic '
     ),
     platform=Platform.ROCM
 )
+BenchmarkRegistry.register_benchmark(
+    'megatron-deepseek-v2-lite',
+    MegatronGPT,
+    parameters=(
+        '--model=gpt '
+        '--transformer_impl=transformer_engine '
+        '--tokenizer_type=HuggingFaceTokenizer '
+        '--tokenizer_model=/opt/superbench/third_party/Megatron/data/DeepSeek-V2-Lite '
+        '--num_layers=27 '
+        '--hidden_size=1024 '
+        '--seq_len=4096 '
+        '--num_attn_heads=16 '
+        '--moe_ffn_hidden_size=1408 '
+        '--ffn_hidden_size=10944 '
+        '--dataloader_type=cyclic '
+        '--num_experts=64 '
+        '--no_async_tensor_model_parallel_allreduce '
+        '--use_rotary_position_embeddings '
+        '--no_gradient_accumulation_fusion '
+        '--mock_data '
+        '--use_flash_attn '
+        '--no_load_optim '
+        '--no_load_rng '
+        '--swiglu '
+        '--normalization=RMSNorm '
+        '--norm_epsilon=1e-06 '
+        '--no_bias_swiglu_fusion '
+        '--no_rope_fusion '
+        '--position_embedding_type=rope '
+        '--untie_embeddings_and_output_weights '
+        '--disable_bias_linear '
+        '--ckpt_format=torch '
+        '--rotary_percent=1.0 '
+        '--rotary_base=10000 '
+        '--rotary_scaling_factor=40 '
+        '--eod_mask_loss '
+        '--data_cache_path=/tmp/cache '
+        '--moe_layer_freq="([0]+[1]*26)" '
+        '--moe_router_topk=6 '
+        '--moe_router_topk_scaling_factor=1.0 '
+        '--moe_aux_loss_coeff=1e-3 '
+        '--kv_lora_rank=512 '
+        '--v_head_dim=128 '
+        '--qk_head_dim=128 '
+        '--qk_layernorm '
+        '--qk_pos_emb_head_dim=64 '
+        '--no_masked_softmax_fusion '
+        '--kv_channels=16 '
+        '--multi_latent_attention '
+        '--moe_router_score_function=softmax '
+        '--moe_router_topk=6 '
+        '--moe_router_pre_softmax '
+        '--moe_shared_expert_intermediate_size=2816 '
+        '--moe_token_dispatcher_type=alltoall '
+        '--moe_token_drop_policy=probs '
+        '--make_vocab_size_divisible_by=3200 '
+        '--attention_softmax_in_fp32 '
+        '--use_mcore_models '
+        '--mscale=0.707 '
+        '--mscale_all_dim=0.707 '
+    ),
+    platform=Platform.CUDA
+)
diff --git a/superbench/benchmarks/registry.py b/superbench/benchmarks/registry.py
index 62f32868e..1be6e4138 100644
--- a/superbench/benchmarks/registry.py
+++ b/superbench/benchmarks/registry.py
@@ -84,7 +84,7 @@ def __parse_and_check_args(cls, name, class_def, parameters):
         benchmark = class_def(name, parameters)
         benchmark.add_parser_arguments()
         ret, args, unknown = benchmark.parse_args(ignore_invalid=True)
-        if not ret or len(unknown) >= 1:
+        if not ret or (len(unknown) >= 1 and not getattr(benchmark, '_ignore_unknown_args', False)):
             logger.log_and_raise(
                 TypeError,
                 'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
diff --git a/tests/benchmarks/model_benchmarks/test_megatron_gpt.py b/tests/benchmarks/model_benchmarks/test_megatron_gpt.py
index b7c588677..2f99070a6 100644
--- a/tests/benchmarks/model_benchmarks/test_megatron_gpt.py
+++ b/tests/benchmarks/model_benchmarks/test_megatron_gpt.py
@@ -500,6 +500,55 @@ def test_deepseek_v2_command(self):
 
         self.assertEqual(actual_units, expected_units)
 
+    @mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
+    def test_megatron_gpt_unknown_args(self, mock_generate_dataset):
+        """Test unknown args forwarding and normalization."""
+        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
+        assert (benchmark_cls)
+        os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
+        os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
+        os.environ['OMPI_COMM_WORLD_RANK'] = '0'
+        os.environ['MASTER_ADDR'] = 'localhost'
+        os.environ['MASTER_PORT'] = '12345'
+        with open(self.hostfile_path, 'w') as f:
+            f.write('host1\n')
+
+        # Test with unknown args that have underscores (should be converted to hyphens)
+        benchmark = benchmark_cls(
+            self.benchmark_name,
+            parameters=f'--code_base {self._tmp_dir} --hostfile {self.hostfile_path} '
+            '--num_warmup 0 --num_steps 10 --batch_size 2048 '
+            '--my_custom_flag 128 --another_option --third_param value',
+        )
+        mock_generate_dataset.return_value = True
+        ret = benchmark._preprocess()
+        assert (ret is True)
+
+        # Verify unknown args are stored and normalized
+        assert (hasattr(benchmark, '_unknown_args'))
+        assert (len(benchmark._unknown_args) > 0)
+
+        # Check that underscores are converted to hyphens
+        assert ('--my-custom-flag' in benchmark._unknown_args)
+        assert ('128' in benchmark._unknown_args)
+        assert ('--another-option' in benchmark._unknown_args)
+        assert ('--third-param' in benchmark._unknown_args)
+        assert ('value' in benchmark._unknown_args)
+
+        # Verify unknown args appear in the generated command
+        benchmark._data_options = '--mock-data'
+        command = benchmark._megatron_command(Precision.FLOAT32)
+
+        # Check that normalized unknown args are in the command
+        assert ('--my-custom-flag 128' in command)
+        assert ('--another-option' in command)
+        assert ('--third-param value' in command)
+
+        # Ensure original underscore versions are NOT in the command
+        assert ('--my_custom_flag' not in command)
+        assert ('--another_option' not in command)
+        assert ('--third_param' not in command)
+
     @decorator.load_data('tests/data/megatron_deepspeed.log')
     @mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
     def test_megatron_parse_log(self, raw_output, mock_generate_dataset):
diff --git a/third_party/Makefile b/third_party/Makefile
index 2a09f5990..fee844822 100755
--- a/third_party/Makefile
+++ b/third_party/Makefile
@@ -226,7 +226,14 @@ directx_amf_encoding_latency:
 megatron_lm:
 	cd Megatron && \
 	apt install -y python3-mpi4py && \
-	python -m pip install --no-cache-dir -r requirements.txt
+	python -m pip install --no-cache-dir -r requirements.txt && \
+	mkdir -p data/gpt && \
+	wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -O data/gpt2-vocab.json && \
+	wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O data/gpt2-merges.txt && \
+	mkdir -p data/DeepSeek-V2-Lite && \
+	wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/config.json -O data/DeepSeek-V2-Lite/config.json && \
+	wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/tokenizer.json -O data/DeepSeek-V2-Lite/tokenizer.json && \
+	wget https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/tokenizer_config.json -O data/DeepSeek-V2-Lite/tokenizer_config.json
 
 # Install requirements for Megatron-DeepSpeed
 megatron_deepspeed:
diff --git a/third_party/Megatron/Megatron-LM b/third_party/Megatron/Megatron-LM
index 52b7a18a0..6cc29a208 160000
--- a/third_party/Megatron/Megatron-LM
+++ b/third_party/Megatron/Megatron-LM
@@ -1 +1 @@
-Subproject commit 52b7a18a00bced8b3670eededfd58ee0c4bd7d06
+Subproject commit 6cc29a2081ec435c69e6614c9afceb9c9e99b666
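
Illustrative sketch (not part of the patch above): the snippet below mirrors the underscore-to-hyphen normalization that the new _normalize_unknown_args in base.py applies to unrecognized flags before they are joined and appended to the torchrun/deepspeed command. The flag names used here (--my_custom_flag, --another_option) are hypothetical examples, not real Megatron options.

def normalize_unknown_args(unknown):
    """Convert '--some_flag' style names to '--some-flag', keeping bare values untouched."""
    normalized = []
    for arg in unknown:
        if arg.startswith('--'):
            # Split off an inline '=value' if present, then rewrite the flag name.
            flag, _, value = arg.partition('=')
            flag = flag.replace('_', '-')
            normalized.append(f'{flag} {value}' if value else flag)
        else:
            # Positional values (e.g. a number following a flag) are passed through as-is.
            normalized.append(arg)
    return normalized


# Hypothetical leftovers from argparse's parse_known_args().
print(normalize_unknown_args(['--my_custom_flag', '128', '--another_option=value']))
# ['--my-custom-flag', '128', '--another-option value']
# The patch then appends ' '.join(...) of these tokens to the generated launch command.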