ModalityDance/LatentTTS

Parallel Test-Time Scaling for Latent Reasoning Models


This repository implements the paper Parallel Test-Time Scaling for Latent Reasoning Models, which enables efficient exploration of continuous thought spaces through stochastic sampling and reward-model-guided search. It provides two stochastic sampling methods (Monte Carlo Dropout and Additive Gaussian Noise) and a latent reward model (LatentRM) for best-of-N and beam search strategies. The repository includes training scripts, evaluation pipelines, and inference code for multiple backbone models (COCONUT, CODI, and CoLaR), evaluated on benchmarks such as GSM8K Test, GSM8K Hard, and MultiArith.

πŸͺ Key Features

Important

🧩 Full Transformers Integration: All models (COCONUT, CODI, and CoLaR) are seamlessly integrated with Hugging Face Transformers, providing native support for:

  • ✅ Batch processing for efficient parallel inference
  • ✅ Standard Transformers APIs (generate(), from_pretrained(), etc.)
  • ✅ Device management with device_map and multi-GPU support
  • ✅ Easy integration into existing Transformers-based workflows

Simply use model.generate() with batch inputs just like any other Transformers model!

🧭 Stochastic Sampling Methods: Two complementary approaches for exploring continuous thought spaces, Monte Carlo Dropout and Additive Gaussian Noise, enable diverse reasoning-path generation during inference.
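As a toy illustration, the two perturbations can be sketched on a plain Python list standing in for a latent embedding (the function names here are illustrative, not the repository's API; in practice the perturbations are applied to tensors inside the model's forward pass):

```python
import random

def mc_dropout(latent, p=0.1, rng=None):
    """Monte Carlo dropout: zero each dimension with probability p,
    rescaling survivors by 1/(1-p) to keep the expected value unchanged."""
    rng = rng or random.Random()
    return [0.0 if rng.random() < p else x / (1.0 - p) for x in latent]

def additive_gaussian_noise(latent, sigma=0.1, rng=None):
    """Additive Gaussian noise: perturb each dimension with N(0, sigma^2)."""
    rng = rng or random.Random()
    return [x + rng.gauss(0.0, sigma) for x in latent]

latent = [0.5, -1.2, 0.3, 2.0]
rng = random.Random(0)
print(mc_dropout(latent, p=0.5, rng=rng))                   # some dimensions zeroed, survivors rescaled
print(additive_gaussian_noise(latent, sigma=0.1, rng=rng))  # small jitter on every dimension
```

Running either function several times on the same latent vector yields different outputs, which is what produces diverse reasoning paths from a single input.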

🌌 Latent Reward Model (LatentRM): A trained reward model that guides best-of-N selection and beam search, significantly improving reasoning accuracy by identifying high-quality latent reasoning paths.
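In essence, best-of-N amounts to sampling N reasoning paths independently and keeping the one the reward model scores highest. A minimal sketch with stand-in generator and scorer functions (illustrative names, not the repository's API; in practice LatentRM plays the role of score_fn):

```python
import random

def best_of_n(generate_fn, score_fn, n=16, rng=None):
    """Sample n candidate reasoning paths and return the highest-scoring one."""
    rng = rng or random.Random()
    candidates = [generate_fn(rng) for _ in range(n)]
    return max(candidates, key=score_fn)

# Stand-in: each "path" is just a random float, scored by its own value.
best = best_of_n(generate_fn=lambda r: r.random(),
                 score_fn=lambda path: path,
                 n=8, rng=random.Random(0))
print(best)
```

Larger n explores more of the latent thought space at the cost of more parallel forward passes; that trade-off is the test-time scaling knob.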

🚀 Quick Start

1. Installation

Conda (recommended)

conda create -n latenttts python=3.11 -y
conda activate latenttts
pip install -r requirements.txt

Hardware Requirements

  • GPU: Recommended for training and inference (CUDA-compatible)
  • Python: 3.11
  • CUDA: Compatible with PyTorch 2.8.0
  • Frameworks: PyTorch 2.8.0, Transformers 4.52.4, Accelerate 1.7.0

2. Preparation

Dataset

The datasets are located in the data/ directory. They are obtained from the coconut project.

Latent Reasoning Models

Download the pre-trained models from HuggingFace to the checkpoints/ directory:

# Download COCONUT model
huggingface-cli download ModalityDance/latent-tts-coconut --local-dir checkpoints/coconut

# Download CODI model
huggingface-cli download ModalityDance/latent-tts-codi --local-dir checkpoints/codi

# Download CoLaR model
huggingface-cli download ModalityDance/latent-tts-colar --local-dir checkpoints/colar

# Optionally download LatentRM (for reward-guided generation)
huggingface-cli download ModalityDance/latent-tts-rm --local-dir checkpoints/latentRM

Simple Generation Example

Here's a minimal example of using .generate() with a latent reasoning model:

from transformers import AutoTokenizer
from src.generation_mixin import LatentGenerationMixin, LatentGenerationConfig
from src.paths import MODELS

# Load tokenizer
model_type = "coconut"  # or "codi", "colar"
model_id = MODELS[model_type]["id"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Get latent token IDs
latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

# Create model class with generation mixin
class LatentModel(MODELS[model_type]["class"], LatentGenerationMixin):
    def __init__(self, config):
        super().__init__(config)

# Load model
model = LatentModel.from_pretrained(
    model_id,
    latent_id=latent_id,
    latent_start_id=start_id,
    latent_end_id=end_id,
    device_map="auto",
)

# Prepare input
question = "What is 2 + 2?\n<|start-latent|>"
inputs = tokenizer(question, return_tensors="pt").to(model.device)

# Configure generation
generation_config = LatentGenerationConfig(
    max_new_tokens=512,
    latent_length=6,
    latent_do_sample=True,
    latent_do_sample_by="dropout",  # or "noise"
    dropout_p=0.1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Generate
output = model.generate(
    **inputs,
    generation_config=generation_config,
    num_return_sequences=1,
)

# Decode result
result = tokenizer.decode(output[0], skip_special_tokens=True)
print(result)

Data Annotation

First, run the data annotation process to prepare training data for LatentRM:

./run_annotation.sh

This script will:

  • Process training data and validation data with specified batch size and sampling parameters
  • Generate annotated data for LatentRM training
  • Save results to the specified output directory

3. Running

Training Configuration

Configure your training parameters in the training_args/ directory. The main configuration file is train_coconut.yaml:

run_name: "run1"
metric_for_best_model: "test_n_64_recall_at_1"
output_dir: "/workspace/model-out/"
# ... other parameters

Model Training

Navigate to your project directory and launch training:

cd your/path/to/latent-tts
accelerate launch -m src.train training_args/train_coconut.yaml

The training process will:

  • Load the annotated data from the previous step
  • Train the latentRM with the specified configuration
  • Save checkpoints and evaluation results

Note

A pre-trained LatentRM checkpoint is available on Hugging Face (ModalityDance/latent-tts-rm).

Evaluation and Testing

Majority Voting and Coverage Testing

Run comprehensive evaluation using majority voting and coverage metrics:

# For LLaMA model (CoLaR)
./run_tests_llama.sh

# For GPT-2 models (COCONUT and CODI)
./run_tests.sh

These scripts will:

  • Test different sampling strategies (dropout, noise)
  • Evaluate on multiple datasets (GSM8K Test, MultiArith, GSM8K Hard)
  • Generate detailed performance metrics including Pass@k, Coverage, and Voting Accuracy

Beam Search and Best-of-N Testing

For beam search evaluation:

./run_tts_with_rm.sh

This script will:

  • Test beam search with different beam sizes (1, 2, 4, 8)
  • Test best-of-N with different values of n_return_sequences (1, 4, 16, 64)
  • Generate logs for different configurations
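Conceptually, beam search keeps the beam_size highest-scoring partial paths at each latent step instead of committing to a single one. A toy sketch with stand-in expansion and scoring functions (illustrative names, not the repository's API; in practice LatentRM supplies the scores):

```python
def beam_search(initial, expand_fn, score_fn, beam_size=4, steps=6):
    """At each step, expand every path in the beam and keep the
    beam_size highest-scoring candidates."""
    beam = [initial]
    for _ in range(steps):
        candidates = [path + [step] for path in beam for step in expand_fn(path)]
        beam = sorted(candidates, key=score_fn, reverse=True)[:beam_size]
    return beam[0]  # highest-scoring complete path

# Toy example: each step appends a digit; paths are scored by their sum.
result = beam_search([], expand_fn=lambda p: [1, 2, 3],
                     score_fn=lambda p: sum(p), beam_size=2, steps=3)
print(result)  # → [3, 3, 3]
```

A beam size of 1 reduces to greedy selection; the evaluated sizes (1, 2, 4, 8) trade extra compute for broader exploration.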

✨ How It Works

πŸͺ LatentTTS is built around a modular research pipeline for parallel test-time scaling of latent reasoning models, where each component corresponds to a well-defined stage in the overall method.
The system separates input processing, stochastic latent reasoning, and reward-guided selection into independent modules, allowing controlled experimentation and analysis.
This design enables flexible replacement of individual components (e.g., switching between dropout and noise sampling, or different backbone models) without affecting the rest of the pipeline.

At a high level, the workflow proceeds as follows:

  1. Input Processing and Tokenization: Raw problem inputs (e.g., math word problems) are tokenized and prepared with special latent tokens (<|latent|>, <|start-latent|>, <|end-latent|>). The model processes these inputs through its embedding layer, setting up the context for latent reasoning generation.
  2. Stochastic Latent Reasoning Generation: The model generates multiple diverse reasoning paths in the continuous latent space using one of two stochastic sampling methods: Monte Carlo Dropout (randomly dropping activations during forward passes to create variability) or Additive Gaussian Noise (injecting noise directly into latent embeddings). Each sampling method explores different regions of the latent thought space, producing varied reasoning trajectories for the same input.
  3. Reward-Guided Selection and Output Generation: The trained Latent Reward Model (LatentRM) evaluates the quality of each generated reasoning path by scoring latent embeddings. Based on these scores, the system applies either best-of-N selection (choosing the top-N highest-scoring paths) or beam search (maintaining multiple high-quality candidates during generation) to identify the most promising reasoning paths. The final answer is extracted from the selected path, significantly improving accuracy through parallel exploration and intelligent selection.

πŸ“ Project Structure

latent-tts/
├── src/                   # Source code
│   ├── models/            # Model implementations
│   │   ├── coconut.py     # COCONUT model
│   │   ├── codi.py        # CODI model
│   │   ├── colar.py       # CoLaR model
│   │   ├── gpt2.py        # GPT-2 base models
│   │   ├── llama.py       # LLaMA base models
│   │   ├── loss.py        # Loss functions
│   │   └── perturbation.py # Perturbation methods
│   ├── annotate_data.py   # Data annotation script
│   ├── train.py           # latentRM training script
│   ├── trainer.py         # Training utilities
│   ├── infer_gpt2.py      # GPT-2 inference
│   ├── infer_llama.py     # LLaMA inference
│   ├── infer_gpt2_rm.py   # latentRM-based inference
│   ├── dataset.py         # Dataset handling
│   ├── generation_mixin.py # Generation utilities
│   ├── paths.py           # Path utilities
│   └── utils.py           # Utility functions
├── training_args/         # Training configurations
│   └── train_coconut.yaml # COCONUT training config
├── data/                  # Dataset files
├── checkpoints/           # Model checkpoints
│   ├── coconut/           # COCONUT checkpoint
│   └── latentRM/          # latentRM checkpoint
├── run_annotation.sh      # Data annotation script
├── run_tests.sh           # GPT-2 evaluation script
├── run_tests_llama.sh     # LLaMA evaluation script
├── run_tts_with_rm.sh     # Beam search evaluation script
└── requirements.txt       # Python dependencies

🤝 Join the Community

We welcome researchers, developers, and enthusiasts to join the LatentTTS community. You can participate by reporting issues, contributing features, or sharing feedback to help us improve and grow the project.

Tip

📄 Explore the paper on Hugging Face Papers, which includes community discussions, citation tools, and related resources. If you find our work insightful, please consider giving it an upvote to support further research!

🌱 Acknowledgements

We would like to thank the contributors, open-source projects, and research communities whose work made LatentTTS possible. This project builds upon ideas, tools, and datasets developed by the broader machine learning and reasoning research ecosystem. We also acknowledge helpful discussions and support from the members of Modality Dance Group and the open-source community.

This project is licensed under the MIT License. Please refer to the LICENSE file for more details.

📚 Citation

If you use LatentTTS in your research or applications, please consider citing:

@misc{you2025paralleltesttimescalinglatent,
      title={Parallel Test-Time Scaling for Latent Reasoning Models}, 
      author={Runyang You and Yongqi Li and Meng Liu and Wenjie Wang and Liqiang Nie and Wenjie Li},
      year={2025},
      eprint={2510.07745},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2510.07745}, 
}
