This repository is built on top of MEDUSA, a method that accelerates decoding of large language models (LLMs) by predicting multiple next tokens at once. MEDUSA constructs candidate sequences by combining the predicted tokens according to fixed patterns, then evaluates all candidates in parallel using a customized attention mask called tree attention. This repository improves MEDUSA by introducing a dynamic tree attention mechanism, which increases decoding efficiency measured in tokens per inference step.
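To make the idea concrete, here is a minimal sketch of how a tree attention mask can be built from candidate sequences. This is illustrative only, not the repository's or MEDUSA's actual implementation: the function name, data layout, and token ids are all assumptions. Candidate sequences (the Cartesian product of each head's top tokens) share prefixes, so they are merged into a tree, and each tree node is allowed to attend only to itself and its ancestors.

```python
import itertools
import numpy as np

def build_tree_attention_mask(top_k_per_head):
    """Build a tree attention mask for candidate sequences formed by the
    Cartesian product of each head's candidate tokens.

    top_k_per_head: list of lists; top_k_per_head[h] holds head h's candidate
    token ids (placeholder values here). Returns (nodes, mask) where nodes are
    tree positions (token-id prefixes) and mask[i, j] is True iff node j is an
    ancestor of node i or node i itself, i.e. node i may attend to node j.
    """
    # Every prefix of every candidate path is a tree node; shared prefixes
    # appear once, which is what merges the candidates into a tree.
    nodes = []
    for path in itertools.product(*top_k_per_head):
        for depth in range(1, len(path) + 1):
            prefix = path[:depth]
            if prefix not in nodes:
                nodes.append(prefix)
    n = len(nodes)
    mask = np.zeros((n, n), dtype=bool)
    for i, a in enumerate(nodes):
        for j, b in enumerate(nodes):
            # b is an ancestor of a (or a itself) iff b is a prefix of a.
            mask[i, j] = a[:len(b)] == b
    return nodes, mask

# Two heads, two candidate tokens each: 4 candidate sequences, 6 tree nodes.
nodes, mask = build_tree_attention_mask([[10, 11], [20, 21]])
```

With this mask added to the attention scores, all candidate sequences are verified in a single forward pass while each position only sees its own prefix.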
First, install the MEDUSA package. The core code of this repository is in `medusa_dynamic.py`:
```python
import medusa_dynamic
```

In `examples.ipynb`, we provide a complete example demonstrating how to run the model. We also visualize the dynamic tree structure.
If you want to generate data for benchmarking, put `gen_model_answer_medusa_dynamic.py` into the `llm_judge` folder of the MEDUSA repository. For more details, please refer to MEDUSA.
- Optimize the code to reduce the overhead.
- Explore better strategies to approximate the joint distribution of candidate sequences. Currently, it is approximated by the Cartesian product of the per-head marginal distributions.
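The current Cartesian-product approximation can be sketched as follows. This is a toy illustration under an independence assumption between heads, not the repository's API; the function name and data layout are hypothetical. Each head contributes a marginal distribution over its next token, the joint probability of a multi-token continuation is taken as the product of its per-head marginals, and the highest-scoring continuations become the candidates.

```python
import itertools
import numpy as np

def top_candidates_from_marginals(marginals, num_candidates):
    """Approximate the joint distribution over multi-token continuations by
    the Cartesian product of per-head marginals, keeping the most probable
    candidates.

    marginals: list of dicts; marginals[h] maps token id -> P(token) under
    decoding head h. Heads are treated as independent, which is exactly the
    approximation described above. Returns [(token_tuple, joint_prob), ...]
    sorted by descending probability.
    """
    heads = [list(m.items()) for m in marginals]
    scored = []
    for combo in itertools.product(*heads):
        tokens = tuple(tok for tok, _ in combo)
        prob = float(np.prod([p for _, p in combo]))
        scored.append((tokens, prob))
    scored.sort(key=lambda tp: tp[1], reverse=True)
    return scored[:num_candidates]

# Two heads with toy marginals; keep the 3 most probable 2-token continuations.
marginals = [{1: 0.7, 2: 0.3}, {5: 0.6, 6: 0.4}]
best = top_candidates_from_marginals(marginals, 3)
# best[0] == ((1, 5), 0.42) under the independence assumption
```

A dynamic tree is then built from the retained candidates, so the tree shape adapts to the actual probabilities at each step instead of following a fixed pattern.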
For more details, please refer to our report.
```bibtex
@misc{zhang2025accelerationmultipleheadsdecoding,
      title={Acceleration Multiple Heads Decoding for LLM via Dynamic Tree Attention},
      author={Zhendong Zhang},
      year={2025},
      eprint={2502.05947},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2502.05947},
}
```