diff --git a/main.py b/main.py index 8d41c5f..9891232 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ 'block_size': 256, 'world_size': 1, 'model_name_or_path': 'Qwen/Qwen3-0.6B', - 'enforce_eager': True, + 'enforce_eager': False, 'vocab_size': 151936, # Fixed: was 151643, HF model uses 151936 'hidden_size': 1024, 'num_heads': 16, diff --git a/src/myvllm/engine/model_runner.py b/src/myvllm/engine/model_runner.py index 5dd66ce..457ab30 100644 --- a/src/myvllm/engine/model_runner.py +++ b/src/myvllm/engine/model_runner.py @@ -285,12 +285,14 @@ def prepare_prefill(self, seqs: list[Sequence]) -> torch.Tensor: seqlens_k.append(len(token_ids)) cu_seqlens_q.append(cu_seqlens_q[-1] + seqlens_q[-1]) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlens_k[-1]) + # by token generate slot_mapping if seq.block_table: - for i, block_id in enumerate(seq.block_table[seq.num_cached_blocks:]): - if seq.num_cached_blocks + i != seq.num_blocks - 1: - slot_mappings.extend(list(range(block_id * self.block_size, (block_id+1) * self.block_size))) - else: - slot_mappings.extend(list(range(block_id * self.block_size, block_id * self.block_size + seq.last_block_num_tokens))) + for pos in range(num_cached_tokens, len(token_ids)): + block_idx = pos // self.block_size + block_offset = pos % self.block_size + block_id = seq.block_table[block_idx] + slot_mappings.append(block_id * self.block_size + block_offset) + if cu_seqlens_q[-1] < cu_seqlens_k[-1]: # pad block_tables all_block_tables = [seq.block_table for seq in seqs] @@ -404,7 +406,7 @@ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: # (later use graph.replay() to run the captured graph) @torch.inference_mode() def capture_cudagraph(self) -> None: - max_bs = self.config['max_num_seqs'] + max_bs = self.config['max_num_sequences'] max_len = self.config['max_model_length'] max_num_blocks = math.ceil(max_len / self.block_size) # for decoding, input is always of shape (batch_size, 1) @@ -417,7 +419,7 @@ def capture_cudagraph(self) -> None: # where to read KV values in the cache block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device=f'cuda:{self.rank}') # output logits - outputs = torch.zeros(max_bs, self.config['vocab_size'], device=f'cuda:{self.rank}') + outputs = torch.zeros(max_bs, self.config['hidden_size'], device=f'cuda:{self.rank}') # graphs to be captured for different batch sizes batch_sizes = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) @@ -440,8 +442,9 @@ def capture_cudagraph(self) -> None: with torch.cuda.graph(graph, graph_pool): outputs[:batch_size] = self.model(input_ids[:batch_size]) - if graph_pool is None: - graph_pool = graph.pool() + + if graph_pool is None: + graph_pool = graph.pool() # store the captured graph self.graphs[batch_size] = graph