[tx] Paged Attention in SkyRL-tx #786
base: main
Conversation
Code Review
This pull request introduces paged attention for the Qwen3 model, a valuable feature for optimizing memory usage during inference. The changes are well-structured, with a new paged_attention.py module for the core logic, integrations into the model, and new tests. However, I've identified a few critical issues in the core implementation that affect both correctness and performance. Specifically, the paged attention logic is missing causal masking, which is a correctness bug, and the cache update mechanism has a significant performance bottleneck. Additionally, the padding mask is not being passed correctly. Addressing these points is crucial before merging.
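For reference, the causal-masking fix the review asks for typically looks something like the sketch below once the K/V pages have been gathered into contiguous tensors. This is not the PR's code; the shapes, argument names, and the gather-then-attend layout are assumptions.

```python
import jax
import jax.numpy as jnp

def masked_attention(q, k, v, padding_mask, num_cached):
    """Vanilla attention over a gathered KV cache with causal + padding masking.

    q: [B, H, Tq, D] queries for the new tokens
    k, v: [B, H, Tkv, D] keys/values gathered from the cache, Tkv = num_cached + Tq
    padding_mask: [B, Tkv], 1 for real tokens, 0 for padding
    num_cached: number of tokens already in the cache for this sequence
    """
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])

    # Query i sits at absolute position num_cached + i and may only attend to
    # key positions <= num_cached + i.
    q_pos = num_cached + jnp.arange(q.shape[2])[:, None]   # [Tq, 1]
    k_pos = jnp.arange(k.shape[2])[None, :]                # [1, Tkv]
    causal = k_pos <= q_pos                                # [Tq, Tkv]

    mask = causal[None, None, :, :] & (padding_mask[:, None, None, :] > 0)
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)
```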
Hey @pcmoritz, can you take a look and review the changes?
Thanks a lot for getting this started, I already took a first look -- I'm currently busy with getting the FSDP PR merged and will have a much closer look once that's done. Here are some high-level comments on things I'm thinking about:
Anyway, these don't need to happen all at once, and it's good to start with something small. If you want to do some more stuff in the meantime, I would encourage you to think a little more about the continuous batching problem, which is a very real performance optimization; see also #769, which reduces the need for it a bit (but doesn't solve the problem of high variance in the length distribution). We also need to implement early stopping etc. Your approach of copying together the chunks in the KV cache and then using a vanilla attention implementation makes a lot of sense as a first step for that (have you measured the performance impact of that btw?). Let's discuss more in this thread, I would love to hear your thoughts about these things :)
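To make the copy-then-vanilla-attention idea above concrete, here is a rough sketch under an assumed cache layout (a shared pool of fixed-size pages plus a per-sequence page table); none of these names come from the PR.

```python
import jax.numpy as jnp

def gather_kv(kv_pool, page_table, seq_len):
    """Copy one sequence's KV pages out of the shared pool into a contiguous tensor.

    kv_pool:    [num_pages, page_size, H, D]  shared pool of fixed-size pages
    page_table: [max_pages_per_seq] int array of page ids for this sequence
    seq_len:    number of valid tokens for this sequence
    """
    pages = kv_pool[page_table]                    # [max_pages_per_seq, page_size, H, D]
    flat = pages.reshape(-1, *kv_pool.shape[2:])   # [max_pages_per_seq * page_size, H, D]
    return flat[:seq_len]                          # drop the unused tail of the last page
```

The cost is one gather and a contiguous copy of the cache per step, which is the performance question raised above; a fused paged-attention kernel avoids the copy but is much more involved.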
The compresslevel parameter is not supported with streaming mode 'w|gz' in Python's tarfile module, causing BadRequestError in tests. Fixes test_api.py::test_training_workflow and test_api.py::test_sample[lora_model]
Force-pushed from 56ee5f6 to cf0ae0a
I looked into the SGLang-JAX repository; they use a memory cache and custom kernels to implement KV caching, and their kernels are specific to TPUs. I have created a document for this. I tried to integrate those with Tinker and the inference results were 20% slower: with that integration I was getting about 480 tokens generated per second. When I was going through the skyrl-tx engine I realized that we can just use vLLM/SGLang directly for inference, which should give us better performance than trying to implement KV caching ourselves. There is a placeholder for this in Tinker, so I implemented the external inference; the results are better, at 920 tokens generated per second.
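For context on the numbers above, delegating generation to an external engine usually means calling its OpenAI-compatible HTTP API. The sketch below is illustrative only; the endpoint path, request fields, and token-level prompt handling are assumptions and may differ from the client in this PR.

```python
import requests

def generate_external(base_url, model, prompt_tokens, max_tokens=256, temperature=1.0):
    """Request a completion from an OpenAI-compatible server (e.g. vLLM or SGLang)."""
    response = requests.post(
        f"{base_url}/v1/completions",
        json={
            "model": model,
            "prompt": prompt_tokens,   # many engines accept token ids directly
            "max_tokens": max_tokens,
            "temperature": temperature,
            "logprobs": 1,             # ask the engine to return per-token logprobs
        },
        timeout=300,
    )
    response.raise_for_status()
    choice = response.json()["choices"][0]
    return choice["text"], choice.get("logprobs")
```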
Code Review
This pull request introduces an external inference client, likely to leverage engines with paged attention like vLLM. The code is well-structured, separating external from local inference logic. However, there are several significant issues in the new ExternalInferenceClient. The error handling is problematic, as exceptions are swallowed, which can lead to silent failures. Additionally, the concurrent request handling has a potential resource exhaustion issue. I've provided specific comments and suggestions to address these points and improve the robustness and efficiency of the implementation.
except Exception as e:
    logger.error(f"Error generating with external engine: {e}")
    # Return empty sequence on error
    results.append(
        types.GeneratedSequence(
            stop_reason="stop",
            tokens=[],
            logprobs=[],
        )
    )
Swallowing the exception and returning an empty GeneratedSequence hides the error from the caller. The caller will treat a failed request as a successful generation that produced no tokens, which is misleading and can lead to silent failures. The exception should be propagated to the caller so it can be handled appropriately. Re-raising the exception is a simple way to achieve this, which would cause the entire batch to fail as per the logic in engine.py. For more granular error handling, you could consider changing generate_batch to return both results and exceptions, allowing the caller to report partial failures.
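Sketching the more granular option mentioned above (the simpler re-raise suggestion follows below): collect per-prompt results and exceptions side by side instead of hiding failures. The generate_one callable and the return shape are assumptions, not the PR's actual interface.

```python
import logging

logger = logging.getLogger(__name__)

def generate_batch(generate_one, prompts):
    """Run generate_one over each prompt, returning results and exceptions in parallel lists."""
    results, errors = [], []
    for i, prompt in enumerate(prompts):
        try:
            results.append(generate_one(prompt))
            errors.append(None)
        except Exception as e:
            logger.error(f"Request {i} failed on the external engine: {e}")
            results.append(None)
            errors.append(e)
    return results, errors
```

The caller can then decide whether to fail the whole batch or report partial failures.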
Suggested change:

except Exception as e:
    logger.error(f"Error generating with external engine: {e}")
    # Re-raise the exception to allow the caller to handle the error.
    # Returning an empty sequence hides the failure.
    raise

        model=model_name,
        lora_path=lora_path,
    )
except Exception as e:
Catching a broad Exception can mask underlying issues and make debugging difficult, as it can catch unexpected errors like KeyboardInterrupt or SystemExit. It's better to catch more specific exceptions. The ExternalInferenceClient could be modified to raise a custom, more specific exception (e.g., ExternalInferenceError) that can be caught here.
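A minimal sketch of the narrower exception suggested above; ExternalInferenceError is the name from the review comment, and the wrapper function is hypothetical.

```python
import requests

class ExternalInferenceError(Exception):
    """Raised when the external inference engine fails or returns a bad response."""

def post_completion(url, payload):
    # Wrap the expected failure modes in the narrower exception so that callers
    # (e.g. in engine.py) can catch ExternalInferenceError instead of bare Exception.
    try:
        response = requests.post(url, json=payload, timeout=300)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        raise ExternalInferenceError(f"External engine request failed: {e}") from e
```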
import concurrent.futures

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
Using len(prompts) as max_workers can lead to resource exhaustion if a large number of prompts are processed in a batch. This could create an excessive number of threads, potentially causing performance issues or even crashing the service. It's safer to cap the number of workers to a reasonable limit. This limit could also be made configurable.
Suggested change:

-    with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
+    with concurrent.futures.ThreadPoolExecutor(max_workers=min(32, len(prompts))) as executor:
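If the cap should also be configurable, as suggested above, one possible shape is sketched below; the environment variable name and the default of 32 are assumptions.

```python
import os
import concurrent.futures

# Hypothetical knob; any existing config mechanism would work just as well.
MAX_INFERENCE_WORKERS = int(os.environ.get("TX_MAX_INFERENCE_WORKERS", "32"))

def run_batch(fn, prompts):
    """Run fn over prompts with a bounded thread pool."""
    workers = max(1, min(MAX_INFERENCE_WORKERS, len(prompts)))
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        return list(executor.map(fn, prompts))
```

The max(1, ...) also guards against an empty batch, since ThreadPoolExecutor rejects max_workers=0.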
# Expand for num_samples
for _ in range(request_data.num_samples):
    prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens]
# Parse response
result = response.json()
choice = result["choices"][0]
Accessing result["choices"][0] directly can cause an IndexError if the choices list from the external API is empty. While this might be rare for successful calls, it's safer to validate that the list is not empty before accessing its first element to prevent unexpected crashes. For example: if not (choices := result.get("choices")): ...
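A sketch of that defensive parsing; the field names follow the OpenAI-style response format and the error type is illustrative.

```python
def parse_completion(result: dict) -> tuple[str, dict | None]:
    """Extract text and logprobs from an OpenAI-style completion response."""
    choices = result.get("choices")
    if not choices:
        raise ValueError(f"External engine returned no choices: {result}")
    choice = choices[0]
    return choice["text"], choice.get("logprobs")
```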
Adding a paged attention implementation in skyrl-tx.