
Conversation

@aditchawdhary (Contributor)

Adding paged attention implementation in skyrl-tx

@gemini-code-assist bot left a comment


Code Review

This pull request introduces paged attention for the Qwen3 model, a valuable feature for optimizing memory usage during inference. The changes are well-structured, with a new paged_attention.py module for the core logic, integrations into the model, and new tests. However, I've identified a few critical issues in the core implementation that affect both correctness and performance. Specifically, the paged attention logic is missing causal masking, which is a correctness bug, and the cache update mechanism has a significant performance bottleneck. Additionally, the padding mask is not being passed correctly. Addressing these points is crucial before merging.
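To illustrate the causal-masking point: a naive prefill path over a paged cache still needs an explicit lower-triangular mask so that a query position cannot attend to later positions. A minimal sketch of that mask in JAX (illustrative only; shapes, names, and scaling are assumptions, not the PR's actual code):

import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    # q, k, v: [num_heads, seq_len, head_dim]
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("hqd,hkd->hqk", q, k) * scale
    seq_len = q.shape[1]
    # Lower-triangular mask: query position i may only attend to keys 0..i.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)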

@pcmoritz added the tx label on Dec 17, 2025
@aditchawdhary (Contributor, Author)

Hey @pcmoritz, can you take a look and review the changes?

@pcmoritz (Collaborator)

Thanks a lot for getting this started, I already took a first look -- I'm currently busy with getting the FSDP PR merged and will have a much closer look once that's done. Here are some high-level comments on things I'm thinking about:

  • In my mind, the primary motivation to implement paged attention is to implement continuous batching (see e.g. https://www.anyscale.com/blog/continuous-batching-llm-inference), and the main motivation for that is that if you have requests of very different lengths, you can return the requests that are already finished as soon as they complete, so their slots can be taken up by new requests. This can both lead to computational savings (request length distributions often have high variance, and currently in that setting we would do a decent amount of padding) and also faster responses. There is some secondary motivation of making sure other features like prefix caching etc. are building on the right foundations (and there might be more). So we should think a little more about what the best way to implement continuous batching will be; a rough sketch of such a scheduling loop follows this list. In the multi-tenancy setting there could be a lot of interesting questions about how it interacts with training (e.g. sampling prefill and forward / backward passes have very similar computational characteristics and could be scheduled together), but we can ignore all of that to start off with and just think about the single-tenant sampling -> training -> sampling -> training -> ... setting and see how that can be made more efficient in the context of large variance in request length.
  • Going forward, I think we should have only one format of KV cache for the tx native engine that works well -- most likely a paged one -- and we should study which data structures make the most sense (in my experience, deciding on the data representation is the most important part, and the code follows). In particular, we should study both sglang and vllm.
  • It is also good to study which GPU / TPU kernels are available (e.g. @atemaguer pointed out https://github.com/sgl-project/sglang-jax/blob/main/python/sgl_jax/srt/kernels/ragged_paged_attention/ragged_paged_attention.py). Most likely we will need some sort of abstraction for the attention layer going forward, since there are many different kernels and techniques. Probably not the highest priority right now.
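To make the continuous-batching idea in the first bullet concrete, here is a rough, non-authoritative sketch of the scheduling loop it implies; all names (continuous_batching_loop, decode_step, is_done, etc.) are hypothetical and not part of skyrl-tx:

from collections import deque

def continuous_batching_loop(request_queue: deque, decode_step, max_batch_size: int):
    # Toy scheduler: finished requests free their slots immediately, and
    # waiting requests are admitted as soon as a slot opens up.
    active, finished = [], []
    while active or request_queue:
        # Admit new requests into any free slots.
        while request_queue and len(active) < max_batch_size:
            active.append(request_queue.popleft())
        # One decode step for the whole active batch; with a paged KV cache
        # the sequences do not need to be padded to a common length.
        decode_step(active)
        # Retire requests that hit EOS or their token limit so their
        # KV-cache pages can be handed to the next waiting request.
        finished.extend(r for r in active if r.is_done())
        active = [r for r in active if not r.is_done()]
    return finished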

Anyways, these don't need to happen all at once, and it is good to start with something small. If you want to do some more stuff in the meantime, I would encourage you to think a little more about the continuous batching problem, which is a very real performance optimization; see also #769, which mitigates the need a little bit (but doesn't solve the high-variance-in-length-distribution problem). We also need to implement early stopping etc. And your approach of copying together the chunks in the kv cache and then using a vanilla attention implementation makes a lot of sense as the first step for that (have you measured the performance impact of that, btw?).
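For reference, a minimal sketch of the "copy the cached chunks together, then run vanilla attention" approach mentioned above, assuming a block table that maps a sequence to its cache pages (shapes and names are assumptions, not the PR's actual code):

import jax.numpy as jnp

def gather_kv(cache, block_table, context_len):
    # cache:       [num_pages, page_size, num_heads, head_dim]
    # block_table: [pages_for_this_seq] -- page indices of one sequence, in order
    # Copy the sequence's pages into one contiguous tensor, then trim the
    # padding at the end of the last page. (Under jit the trailing slice
    # would need a mask or dynamic_slice instead of Python-style slicing.)
    pages = cache[block_table]                  # [P, page_size, H, D]
    kv = pages.reshape(-1, *cache.shape[2:])    # [P * page_size, H, D]
    return kv[:context_len]

With K and V reconstructed this way, the existing dense attention path can be reused unchanged, which is also why measuring the cost of that copy (as asked above) is worthwhile.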

Let's discuss more in this thread, I would love to hear your thoughts about these things :)

The compresslevel parameter is not supported with streaming mode 'w|gz'
in Python's tarfile module, causing BadRequestError in tests.

Fixes test_api.py::test_training_workflow and test_api.py::test_sample[lora_model]
@aditchawdhary (Contributor, Author) commented Dec 28, 2025

I looked into the SGLang-JAX repository; they are using a memory cache and kernels to implement KV caching, and their kernels are specific to TPUs. I have created a document for this. I tried to integrate those kernels with tinker, and the inference results were about 20% slower.

Current Benchmark:
python benchmarks/benchmark_engine.py --base-model Qwen/Qwen3-0.6B --benchmark sample --num-steps 5 --num-requests 4 --seq-len 128 --sample-max-tokens 64
=== Sampling Benchmark ===
Warming up (2 steps)...
INFO:     JIT compiling for sample seq_len=128 in progress...
INFO:     JIT compilation for sample seq_len=128 took 34.50s
Running benchmark (5 steps)...

Results:
  steps:                5
  elapsed:              1.954 s
  tokens generated/sec: 655
  sec/step:              0.39

When I did the integration, I was getting about 480 tokens generated per second.

When I was going through the skyrl-tx engine, I realized that we can just use vLLM/SGLang directly for inference, and that should give us better performance than trying to implement KV caching ourselves.

And there is a placeholder for this in Tinker, so I implemented the external inference path; the results are better, at 920 tokens generated/sec.

(main) root@C.29297665:/workspace$ CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen3-0.6B --port 8000

(main) root@C.29297665:/workspace/skyrl-tx$ CUDA_VISIBLE_DEVICES=1 python benchmarks/benchmark_engine.py --base-model Qwen/Qwen3-0.6B --benchmark sample --external-inference-url http://127.0.0.1:8000 --num-steps 5 --num-requests 4 --seq-len 128 --sample-max-tokens 64

=== Sampling Benchmark ===
Warming up (2 steps)...
Running benchmark (5 steps)...

Results:
  steps:                5
  elapsed:              1.392 s
  tokens generated/sec: 920
  sec/step:              0.28
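For reference, a minimal sketch of what such an external call could look like against vLLM's OpenAI-compatible /v1/completions endpoint (field names, logprob handling, and error handling here are illustrative assumptions, not the actual skyrl-tx client):

import requests

def external_complete(base_url: str, model: str, prompt_tokens: list[int],
                      max_tokens: int, temperature: float = 1.0) -> dict:
    # One completion request against an OpenAI-compatible server such as vLLM.
    response = requests.post(
        f"{base_url}/v1/completions",
        json={
            "model": model,
            "prompt": prompt_tokens,   # token IDs as the prompt (OpenAI completions schema)
            "max_tokens": max_tokens,
            "temperature": temperature,
            "logprobs": 1,
        },
        timeout=120,
    )
    response.raise_for_status()
    result = response.json()
    choices = result.get("choices") or []
    if not choices:
        raise RuntimeError(f"External engine returned no choices: {result}")
    return choices[0]

# e.g. external_complete("http://127.0.0.1:8000", "Qwen/Qwen3-0.6B", [1, 2, 3], max_tokens=64)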

@gemini-code-assist bot left a comment


Code Review

This pull request introduces an external inference client, likely to leverage engines with paged attention like vLLM. The code is well-structured, separating external from local inference logic. However, there are several significant issues in the new ExternalInferenceClient. The error handling is problematic, as exceptions are swallowed, which can lead to silent failures. Additionally, the concurrent request handling has a potential resource exhaustion issue. I've provided specific comments and suggestions to address these points and improve the robustness and efficiency of the implementation.

Comment on lines +75 to +84
                except Exception as e:
                    logger.error(f"Error generating with external engine: {e}")
                    # Return empty sequence on error
                    results.append(
                        types.GeneratedSequence(
                            stop_reason="stop",
                            tokens=[],
                            logprobs=[],
                        )
                    )

Severity: critical

Swallowing the exception and returning an empty GeneratedSequence hides the error from the caller. The caller will treat a failed request as a successful generation that produced no tokens, which is misleading and can lead to silent failures. The exception should be propagated to the caller so it can be handled appropriately. Re-raising the exception is a simple way to achieve this, which would cause the entire batch to fail as per the logic in engine.py. For more granular error handling, you could consider changing generate_batch to return both results and exceptions, allowing the caller to report partial failures.

                except Exception as e:
                    logger.error(f"Error generating with external engine: {e}")
                    # Re-raise the exception to allow the caller to handle the error.
                    # Returning an empty sequence hides the failure.
                    raise
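If the more granular option is preferred, one possible shape (hypothetical signature, not an existing skyrl-tx API) is to collect failures alongside results instead of failing the whole batch:

def generate_batch(self, requests):
    # Returns (results, errors); errors[i] is None when requests[i] succeeded,
    # so the caller can report partial failures instead of silent empty outputs.
    results, errors = [], []
    for request in requests:
        try:
            results.append(self._generate_one(request))  # _generate_one is hypothetical
            errors.append(None)
        except Exception as e:
            results.append(None)
            errors.append(e)
    return results, errors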

        model=model_name,
        lora_path=lora_path,
    )
except Exception as e:

Severity: high

Catching a broad Exception can mask underlying issues and make debugging difficult, as it can catch unexpected errors like KeyboardInterrupt or SystemExit. It's better to catch more specific exceptions. The ExternalInferenceClient could be modified to raise a custom, more specific exception (e.g., ExternalInferenceError) that can be caught here.
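A sketch of what that could look like (ExternalInferenceError is a suggested name, not an existing class in the codebase):

import requests

class ExternalInferenceError(RuntimeError):
    """Raised when the external inference engine fails or returns a bad response."""

def post_completion(url: str, payload: dict) -> dict:
    # Wrap transport/HTTP failures in a domain-specific error so callers can
    # catch ExternalInferenceError instead of a broad Exception, and so
    # KeyboardInterrupt/SystemExit are never swallowed.
    try:
        response = requests.post(url, json=payload, timeout=120)
        response.raise_for_status()
        return response.json()
    except requests.RequestException as e:
        raise ExternalInferenceError(f"External engine request failed: {e}") from e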

import concurrent.futures

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Severity: high

Using len(prompts) as max_workers can lead to resource exhaustion if a large number of prompts are processed in a batch. This could create an excessive number of threads, potentially causing performance issues or even crashing the service. It's safer to cap the number of workers to a reasonable limit. This limit could also be made configurable.

Suggested change
with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=min(32, len(prompts))) as executor:


# Expand for num_samples
for _ in range(request_data.num_samples):
    prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens]

Severity: medium

The list comprehension to create prompt_tokens is inside the for _ in range(request_data.num_samples): loop. This is inefficient as it reconstructs the same list of tokens multiple times if num_samples > 1. For better performance, this list should be created once before the loop.
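i.e., roughly (sketch; the loop body is elided here):

# Build the flat token list once per request, not once per sample.
prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens]
for _ in range(request_data.num_samples):
    ...  # submit / generate using prompt_tokens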


# Parse response
result = response.json()
choice = result["choices"][0]

Severity: medium

Accessing result["choices"][0] directly can cause an IndexError if the choices list from the external API is empty. While this might be rare for successful calls, it's safer to validate that the list is not empty before accessing its first element to prevent unexpected crashes. For example: if not (choices := result.get("choices")): ...
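Spelled out, the check could look like this (sketch, reusing the ExternalInferenceError suggested above):

result = response.json()
choices = result.get("choices")
if not choices:
    raise ExternalInferenceError(f"External engine returned no choices: {result}")
choice = choices[0]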

