diff --git a/atom/config.py b/atom/config.py
index 8ef9be071..c905daf48 100644
--- a/atom/config.py
+++ b/atom/config.py
@@ -130,7 +130,7 @@ class CompilationConfig:
     - FULL_AND_PIECEWISE.
 
     PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
-    incompatiable ops (i.e. some attention ops) outside the cudagraph
+    incompatible ops (i.e. some attention ops) outside the cudagraph
     for general flexibility. This is the default mode.
@@ -337,9 +337,8 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
         f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
     )
     quant_dtype = d_dtypes.get(quant_dtype_str, None)
-    assert (
-        quant_dtype is not None
-    ), f"Cannot parse quant dtype from {orig_quant_config_str}"
+    if quant_dtype is None:
+        raise ValueError(f"Cannot parse quant dtype from {orig_quant_config_str}")
 
     if quant_dtype == d_dtypes["fp4x2"]:
         quant_type = QuantType.per_1x32
@@ -613,7 +612,10 @@ def _set_cudagraph_sizes(self):
 
     def __post_init__(self):
         # assert os.path.isdir(self.model)
-        assert 1 <= self.tensor_parallel_size <= 8
+        if not (1 <= self.tensor_parallel_size <= 8):
+            raise ValueError(
+                f"tensor_parallel_size must be between 1 and 8, got {self.tensor_parallel_size}."
+            )
         self.hf_config = get_hf_config(self.model)
         if not hasattr(self.hf_config, "rope_parameters"):
             # Compatible with both transformers < 5
@@ -659,7 +661,7 @@ def __post_init__(self):
         if self.speculative_config is not None:
             if self.speculative_config.num_speculative_tokens > 4:
                 raise ValueError(
-                    f"num_speculative_tokens must be between 1 and 4,, got {self.speculative_config.num_speculative_tokens}. "
+                    f"num_speculative_tokens must be between 1 and 4, got {self.speculative_config.num_speculative_tokens}."
                 )
 
     def compute_hash(self) -> str:
diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py
index 80f85fd90..77008627f 100644
--- a/atom/model_engine/engine_core.py
+++ b/atom/model_engine/engine_core.py
@@ -62,7 +62,7 @@ def __init__(self, config: Config, input_address: str, output_address: str):
         )
         self.input_thread.start()
 
-        self.profile_enbaled = config.torch_profiler_dir is not None
+        self.profile_enabled = config.torch_profiler_dir is not None
         init_exit_handler(self)
 
         self._init_data_parallel(config)
@@ -283,11 +283,11 @@ def process_output_sockets(self, output_address: str):
                 break
 
     def start_profiler(self):
-        if self.profile_enbaled:
+        if self.profile_enabled:
             self.runner_mgr.call_func("start_profiler")
 
     def stop_profiler(self):
-        if self.profile_enbaled:
+        if self.profile_enabled:
             logger.info("Profiler stopping...")
             self.runner_mgr.call_func("stop_profiler", wait_out=True)
             logger.info("Profiler stopped.")
diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py
index ef42263d5..6cfa28894 100644
--- a/atom/model_engine/llm_engine.py
+++ b/atom/model_engine/llm_engine.py
@@ -35,7 +35,6 @@ def __init__(self, model, **kwargs):
         # Set data parallel size in config
         config.parallel_config.data_parallel_size = data_parallel_size
         self.data_parallel_size = data_parallel_size
-        self.rquest_ids = set()
         self.io_processor = InputOutputProcessor(
             config, self.tokenizer, config.kv_cache_block_size
         )
diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py
index 46405682f..daff94c89 100644
--- a/atom/model_engine/scheduler.py
+++ b/atom/model_engine/scheduler.py
@@ -382,6 +382,7 @@ def postprocess(
             if self.spec_stats:
                 self.spec_stats.update(num_new_token)
             idx = fwd_output.req_ids.index(seq.id)
+            num_rejected = 0
             if is_deferred_out or self.use_spec:
                 num_rejected = fwd_output.num_rejected[idx]
                 num_bonus = fwd_output.num_bonus[idx]
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 324c10a9c..55129ca24 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -121,7 +121,10 @@ def _prefill(self, scheduler, seq):
 
     def _output(self, seq_id, tokens):
         return ScheduledBatchOutput(
-            token_ids={seq_id: tuple(tokens)}, draft_token_ids=None
+            token_ids={seq_id: tuple(tokens)},
+            num_rejected=None,
+            num_bonus=None,
+            draft_token_ids=None,
         )
 
     def test_appends_token(self, scheduler, seq_factory):
@@ -166,7 +169,12 @@ def test_stop_token_ids(self, seq_factory):
         sched.schedule()
         finished = sched.postprocess(
             list(sched.running),
-            ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None),
+            ScheduledBatchOutput(
+                token_ids={seq.id: (99,)},
+                num_rejected=None,
+                num_bonus=None,
+                draft_token_ids=None,
+            ),
         )
         assert len(finished) == 1
         assert "stop_99" in finished[0].leave_reason