This repository implements the paper Parallel Test-Time Scaling for Latent Reasoning Models, which enables efficient exploration of continuous thought spaces through stochastic sampling and reward-model-guided search. It provides two stochastic sampling methods (Monte Carlo Dropout and Additive Gaussian Noise) and a latent reward model (LatentRM) for best-of-N and beam search strategies. The repository includes training scripts, evaluation pipelines, and inference code for multiple backbone models (COCONUT, CODI, and CoLaR), evaluated on benchmarks such as GSM8K Test, GSM8K Hard, and MultiArith.
Important
🧩 Full Transformers Integration. All models (COCONUT, CODI, and CoLaR) are seamlessly integrated with Hugging Face Transformers, providing native support for:
- ✅ Batch processing for efficient parallel inference
- ✅ Standard Transformers APIs (`generate()`, `from_pretrained()`, etc.)
- ✅ Device management with `device_map` and multi-GPU support
- ✅ Easy integration into existing Transformers-based workflows
Simply use `model.generate()` with batch inputs, just like any other Transformers model!
🧠 Stochastic Sampling Methods. Two complementary approaches for exploring continuous thought spaces, Monte Carlo Dropout and Additive Gaussian Noise, enable diverse reasoning-path generation during inference.
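As a rough illustration (not the repository's actual code), the two perturbation schemes can be sketched in a few lines of NumPy; the latent vector, dropout rate, and noise scale below are made-up stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for one continuous "thought" embedding (values are arbitrary).
latent = rng.standard_normal(64)

def mc_dropout(h, p=0.1):
    # Monte Carlo Dropout: keep dropout active at inference time, so each
    # forward pass zeroes a different random subset of activations
    # (with the usual inverted scaling by 1 / (1 - p)).
    mask = rng.random(h.shape) >= p
    return h * mask / (1.0 - p)

def add_gaussian_noise(h, sigma=0.1):
    # Additive Gaussian Noise: inject noise directly into the latent embedding.
    return h + sigma * rng.standard_normal(h.shape)

# Repeated calls yield distinct perturbed latents, so decoding from them
# can follow diverse reasoning trajectories for the same input.
dropout_samples = [mc_dropout(latent) for _ in range(4)]
noise_samples = [add_gaussian_noise(latent) for _ in range(4)]
```

Here `p` and `sigma` play the same role as the `dropout_p` and noise-scale knobs exposed by the generation config later in this README.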
🎯 Latent Reward Model (LatentRM). A trained reward model that guides best-of-N selection and beam search, significantly improving reasoning accuracy by identifying high-quality latent reasoning paths.
- 🚀 Quick Start
- ✨ How It Works
- 📁 Project Structure
- 🤝 Community
- 🌱 Acknowledgements
- 🔗 Related Projects
- 📖 Citation
conda create -n latenttts python=3.11 -y
conda activate latenttts
pip install -r requirements.txt

- GPU: Recommended for training and inference (CUDA-compatible)
- Python: 3.11
- CUDA: Compatible with PyTorch 2.8.0
- Frameworks: PyTorch 2.8.0, Transformers 4.52.4, Accelerate 1.7.0
The datasets are located in the `data/` directory and are obtained from the COCONUT project.
Download the pre-trained models from HuggingFace to the checkpoints/ directory:
# Download COCONUT model
huggingface-cli download ModalityDance/latent-tts-coconut --local-dir checkpoints/coconut
# Download CODI model
huggingface-cli download ModalityDance/latent-tts-codi --local-dir checkpoints/codi
# Download CoLaR model
huggingface-cli download ModalityDance/latent-tts-colar --local-dir checkpoints/colar
# Optionally download LatentRM (for reward-guided generation)
huggingface-cli download ModalityDance/latent-tts-rm --local-dir checkpoints/latentRM

Simple Generation Example
Here's a minimal example of using .generate() with a latent reasoning model:
from transformers import AutoTokenizer
from src.generation_mixin import LatentGenerationMixin, LatentGenerationConfig
from src.paths import MODELS
# Load tokenizer
model_type = "coconut" # or "codi", "colar"
model_id = MODELS[model_type]["id"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Get latent token IDs
latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")
# Create model class with generation mixin
class LatentModel(MODELS[model_type]["class"], LatentGenerationMixin):
def __init__(self, config):
super().__init__(config)
# Load model
model = LatentModel.from_pretrained(
model_id,
latent_id=latent_id,
latent_start_id=start_id,
latent_end_id=end_id,
device_map="auto",
)
# Prepare input
question = "What is 2 + 2?\n<|start-latent|>"
inputs = tokenizer(question, return_tensors="pt").to(model.device)
# Configure generation
generation_config = LatentGenerationConfig(
max_new_tokens=512,
latent_length=6,
latent_do_sample=True,
latent_do_sample_by="dropout", # or "noise"
dropout_p=0.1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Generate
output = model.generate(
**inputs,
generation_config=generation_config,
num_return_sequences=1,
)
# Decode result
result = tokenizer.decode(output[0], skip_special_tokens=True)
print(result)

First, run the data annotation process to prepare training data for LatentRM:
./run_annotation.sh

This script will:
- Process training data and validation data with specified batch size and sampling parameters
- Generate annotated data for LatentRM training
- Save results to the specified output directory
Configure your training parameters in the training_args/ directory. The main configuration file is train_coconut.yaml:
run_name: "run1"
metric_for_best_model: "test_n_64_recall_at_1"
output_dir: "/workspace/model-out/"
# ... other parameters

Navigate to your project directory and launch training:
cd your/path/to/latent-tts
accelerate launch -m src.train training_args/train_coconut.yaml

The training process will:
- Load the annotated data from the previous step
- Train the LatentRM with the specified configuration
- Save checkpoints and evaluation results
Note
A pre-trained checkpoint is available on Hugging Face.
Run comprehensive evaluation using majority voting and coverage metrics:
# For LLaMA model (CoLaR)
./run_tests_llama.sh
# For GPT-2 models (COCONUT and CODI)
./run_tests.sh

These scripts will:
- Test different sampling strategies (dropout, noise)
- Evaluate on multiple datasets (GSM8K Test, MultiArith, GSM8K Hard)
- Generate detailed performance metrics including Pass@k, Coverage, and Voting Accuracy
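For reference, the metrics named above can be computed as follows. This is a hedged sketch using the standard definitions (the unbiased pass@k estimator and plain majority voting); the scripts' exact implementation may differ in detail:

```python
from collections import Counter
from math import comb

def pass_at_k(n, c, k):
    # Unbiased pass@k estimator: chance that at least one of k samples drawn
    # from n generations (c of which are correct) is correct. Coverage is the
    # same quantity with k = n, i.e. 1.0 whenever any sample is correct.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

def majority_vote(answers):
    # Voting accuracy checks the most frequent final answer across samples.
    return Counter(answers).most_common(1)[0][0]

samples = ["18", "18", "17", "18", "21"]
print(majority_vote(samples))        # -> 18
print(pass_at_k(n=5, c=3, k=2))      # -> 0.9
```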
For beam search evaluation:
./run_tts_with_rm.sh

This script will:
- Test beam search with different beam sizes (1, 2, 4, 8)
- Test best-of-N with different `n_return_sequences` values (1, 4, 16, 64)
- Generate logs for each configuration
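In spirit, best-of-N selection just scores all `n_return_sequences` candidates and keeps the best one. The sketch below is a toy stand-in: the real LatentRM is a trained scorer over latent reasoning paths, whereas here the "reward" is a hard-coded number attached to each candidate:

```python
# Toy best-of-N selection (illustrative only; the real LatentRM is a trained
# model that scores latent embeddings, not a lookup of hard-coded rewards).
def best_of_n(candidates, reward_fn):
    # Keep the single highest-scoring candidate among N sampled paths.
    return max(candidates, key=reward_fn)

# Hypothetical (answer, reward) pairs for N = 3 sampled reasoning paths.
candidates = [("17", 0.2), ("18", 0.9), ("21", 0.4)]
best_answer, _ = best_of_n(candidates, reward_fn=lambda c: c[1])
print(best_answer)  # -> 18
```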
💪 LatentTTS is built around a modular research pipeline for parallel test-time scaling of latent reasoning models, where each component corresponds to a well-defined stage in the overall method.
The system separates input processing, stochastic latent reasoning, and reward-guided selection into independent modules, allowing controlled experimentation and analysis.
This design enables flexible replacement of individual components (e.g., switching between dropout and noise sampling, or different backbone models) without affecting the rest of the pipeline.
At a high level, the workflow proceeds as follows:
- Input Processing and Tokenization → Raw problem inputs (e.g., math word problems) are tokenized and prepared with the special latent tokens (`<|latent|>`, `<|start-latent|>`, `<|end-latent|>`). The model processes these inputs through its embedding layer, setting up the context for latent reasoning generation.
- Stochastic Latent Reasoning Generation → The model generates multiple diverse reasoning paths in the continuous latent space using one of two stochastic sampling methods: Monte Carlo Dropout (randomly dropping activations during forward passes to create variability) or Additive Gaussian Noise (injecting noise directly into latent embeddings). Each method explores different regions of the latent thought space, producing varied reasoning trajectories for the same input.
- Reward-Guided Selection and Output Generation → The trained Latent Reward Model (LatentRM) scores the latent embeddings of each generated reasoning path. Based on these scores, the system applies either best-of-N selection (choosing the highest-scoring path among N candidates) or beam search (maintaining multiple high-quality candidates during generation) to identify the most promising reasoning paths. The final answer is extracted from the selected path, improving accuracy through parallel exploration and informed selection.
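The beam-search variant of the last step can be sketched abstractly: at every latent step, each beam is expanded with several stochastic continuations, partial paths are scored, and only the top `beam_size` survive. Everything below is a toy stand-in (integer "latents" and a sum-based score in place of LatentRM):

```python
# Toy beam search (illustrative): expand each beam, score the partial paths,
# keep the best `beam_size` candidates, repeat for a fixed number of steps.
def beam_search(init, expand, score, steps, beam_size):
    beams = [init]
    for _ in range(steps):
        candidates = [path + [nxt] for path in beams for nxt in expand(path)]
        beams = sorted(candidates, key=score, reverse=True)[:beam_size]
    return beams[0]

# Stand-ins: "latent steps" are integers, expansion proposes two continuations,
# and the score just sums the path (the real LatentRM scores latent embeddings).
best_path = beam_search(
    init=[0],
    expand=lambda path: [path[-1] + 1, path[-1] + 2],
    score=sum,
    steps=3,
    beam_size=2,
)
print(best_path)  # -> [0, 2, 4, 6]
```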
latent-tts/
├── src/                      # Source code
│   ├── models/               # Model implementations
│   │   ├── coconut.py        # COCONUT model
│   │   ├── codi.py           # CODI model
│   │   ├── colar.py          # CoLaR model
│   │   ├── gpt2.py           # GPT-2 base models
│   │   ├── llama.py          # LLaMA base models
│   │   ├── loss.py           # Loss functions
│   │   └── perturbation.py   # Perturbation methods
│   ├── annotate_data.py      # Data annotation script
│   ├── train.py              # LatentRM training script
│   ├── trainer.py            # Training utilities
│   ├── infer_gpt2.py         # GPT-2 inference
│   ├── infer_llama.py        # LLaMA inference
│   ├── infer_gpt2_rm.py      # LatentRM-based inference
│   ├── dataset.py            # Dataset handling
│   ├── generation_mixin.py   # Generation utilities
│   ├── paths.py              # Path utilities
│   └── utils.py              # Utility functions
├── training_args/            # Training configurations
│   └── train_coconut.yaml    # COCONUT training config
├── data/                     # Dataset files
├── checkpoints/              # Model checkpoints
│   ├── latentRM/             # LatentRM checkpoint
│   └── coconut/              # COCONUT checkpoint
├── run_annotation.sh         # Data annotation script
├── run_tests.sh              # GPT-2 evaluation script
├── run_tests_llama.sh        # LLaMA evaluation script
├── run_tts_with_rm.sh        # Beam search evaluation script
└── requirements.txt          # Python dependencies
We welcome researchers, developers, and enthusiasts to join the LatentTTS community. You can participate by reporting issues, contributing features, or sharing feedback to help us improve and grow the project.
Tip
👉 Explore the paper on Hugging Face Papers: it includes community discussions, citation tools, and related resources. If you find our work insightful, please consider giving it an upvote to support further research!
We would like to thank the contributors, open-source projects, and research communities whose work made LatentTTS possible. This project builds upon ideas, tools, and datasets developed by the broader machine learning and reasoning research ecosystem. We also acknowledge helpful discussions and support from the members of Modality Dance Group and the open-source community.
This project is licensed under the MIT License. Please refer to the LICENSE file for more details.
- LLMs are Single-threaded Reasoners: Demystifying the Working Mechanism of Soft Thinking. Check out stochastic soft thinking!
- Awesome Latent Space: a curated collection of resources on latent space methods and applications.
- Awesome Latent CoT: a comprehensive list of latent chain-of-thought reasoning resources.
- Awesome Efficient Reasoning: a collection of efficient reasoning methods and techniques.
If you use LatentTTS in your research or applications, please consider citing:
@misc{you2025paralleltesttimescalinglatent,
title={Parallel Test-Time Scaling for Latent Reasoning Models},
author={Runyang You and Yongqi Li and Meng Liu and Wenjie Wang and Liqiang Nie and Wenjie Li},
year={2025},
eprint={2510.07745},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.07745},
}