
Experiment: Reconsideration steering vector increases accuracy of R1-Llama-8B on MATH500 #28

@aranguri


TL;DR

We can increase the accuracy of R1-Llama-8B on MATH500 by steering in the 'Reconsideration' direction!

| Steering coef. | Accuracy | Avg. CoT length (tokens) |
|---|---|---|
| 0.0 | 90.00% | 3078 |
| 0.1 | 93.75% | 4415 |

More generally, steering along the 'Reconsideration' direction trades off accuracy against CoT length.

Description

In a separate experiment, @wendlerc used the GSM8K dataset to find a steering vector that roughly represents reconsideration behavior in the chain of thought of R1-Llama-8B.

Here we show that this steering vector trades off accuracy against CoT length on the MATH500 dataset. More precisely, adding the vector increases accuracy while producing longer CoTs, and subtracting it produces shorter CoTs at the cost of lower accuracy.

Subtracting vector. We take a question from MATH500 (see the prompt template below) and, at each of layers 12 to 15 of R1-Llama-8B, add the corresponding steering vector found by @wendlerc, scaled by the steering coefficient, to the residual stream at the last token; we then sample 64 new tokens. Next, we add the steering vectors at the new last token, sample another 64 tokens, and so on until the model outputs <EOS>, with an upper bound of 16384 tokens. Both accuracy and CoT length drop as the coefficient decreases.
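The chunked steering loop just described can be sketched as follows. This is a minimal, framework-agnostic sketch under stated assumptions, not the code linked from this issue: `sample_chunk` is a hypothetical callback standing in for the model, assumed to add the given per-layer offsets to the residual stream at the last position of `ids` only and then return up to `n` newly sampled token ids. Steering vectors are plain lists of floats keyed by layer index (12 to 15 in this experiment).

```python
from typing import Callable, Dict, List

Vector = List[float]
# Hypothetical model interface (an assumption, not a real library API):
# sample_chunk(ids, offsets, n) adds offsets[layer] to the residual stream
# at the last position of `ids` for each layer, then samples up to n tokens.
SampleChunk = Callable[[List[int], Dict[int, Vector], int], List[int]]


def generate_with_periodic_steering(
    sample_chunk: SampleChunk,
    prompt_ids: List[int],
    steering_vectors: Dict[int, Vector],  # layer index -> steering vector
    coef: float,
    eos_id: int,
    steer_every: int = 64,
    max_new_tokens: int = 16384,
) -> List[int]:
    """Steer at the current last token, sample `steer_every` tokens, repeat."""
    # Scale each layer's vector once; a negative coef subtracts the direction.
    offsets = {layer: [coef * x for x in vec] for layer, vec in steering_vectors.items()}
    ids = list(prompt_ids)
    generated: List[int] = []
    while len(generated) < max_new_tokens:
        n = min(steer_every, max_new_tokens - len(generated))
        chunk = sample_chunk(ids, offsets, n)
        if not chunk:
            break  # guard against a callback that returns nothing
        if eos_id in chunk:
            # Keep tokens up to and including <EOS>, then stop.
            generated.extend(chunk[: chunk.index(eos_id) + 1])
            break
        generated.extend(chunk)
        ids.extend(chunk)  # the next steering step targets the new last token
    return generated
```

The same function covers the "adding vector" setting below by passing `coef=0.1` and `steer_every=4096`; only the coefficient sign and the steering interval change.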

[Figure: accuracy and CoT length as a function of the steering coefficient]

This is computed on 110 MATH500 questions. The CoT-length plot shows the average CoT length over the correctly answered questions only.
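The two plotted quantities can be computed as in the sketch below: accuracy over all questions, and mean CoT length over correctly answered questions only. The `Result` record and its field names are hypothetical, chosen for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Result:
    correct: bool    # whether the boxed answer matched the reference
    cot_tokens: int  # length of the generated chain of thought, in tokens


def summarize(results: List[Result]) -> Tuple[float, Optional[float]]:
    """Return (accuracy, mean CoT length over correctly answered questions)."""
    if not results:
        raise ValueError("no results")
    accuracy = sum(r.correct for r in results) / len(results)
    correct_lengths = [r.cot_tokens for r in results if r.correct]
    # None when nothing was answered correctly, so the average is undefined.
    mean_len = sum(correct_lengths) / len(correct_lengths) if correct_lengths else None
    return accuracy, mean_len
```

Because the length average is restricted to correct answers, it excludes failed runs, including any that hit the 16384-token cap.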

Adding vector. Using the same 110 MATH500 questions, I took the ones that vanilla R1-Llama-8B (steering coefficient 0) got wrong and optimized over two hyperparameters: the steering coefficient and the number of tokens sampled before steering again (64 in the previous experiment). A steering coefficient of 0.1 with steering every 4096 tokens solves around half of the questions that the vanilla model got wrong.
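The hyperparameter search described above can be sketched as a grid search over the steering coefficient and the steering interval, scored by the fraction of the vanilla model's failures that the steered model gets right. This is a hypothetical sketch: `solve(question, coef, steer_every)` stands in for running the steered model and grading its boxed answer, and the grid values in the test are illustrative, not the ones actually searched.

```python
from itertools import product
from typing import Callable, Iterable, Sequence, Tuple


def tune_on_vanilla_failures(
    questions: Sequence[str],
    solve: Callable[[str, float, int], bool],  # hypothetical: True if answered correctly
    coefs: Iterable[float],
    intervals: Iterable[int],
    vanilla_interval: int = 64,
) -> Tuple[float, int, float]:
    """Return (best coef, best steering interval, fraction of vanilla failures fixed)."""
    # Vanilla means coefficient 0, so the interval has no effect there.
    failures = [q for q in questions if not solve(q, 0.0, vanilla_interval)]
    if not failures:
        raise ValueError("vanilla model solved every question")
    # Default: stay vanilla if no configuration fixes any failure.
    best = (0.0, vanilla_interval, 0.0)
    for coef, every in product(coefs, intervals):
        fixed = sum(solve(q, coef, every) for q in failures) / len(failures)
        if fixed > best[2]:
            best = (coef, every, fixed)
    return best
```

Scoring only on the vanilla failures ignores questions that steering might break, which is one reason to check the chosen setting on a fresh test set.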

I then took a fresh sample of 90 MATH500 questions as a test set and compared vanilla R1-Llama-8B with the steered version (steering coefficient 0.1, steering every 4096 tokens). We see that, at the cost of longer CoTs, steering increases the accuracy of R1-Llama-8B!

| Steering coef. | Accuracy | Avg. CoT length (tokens) |
|---|---|---|
| 0.0 | 90.00% | 3078 |
| 0.1 | 93.75% | 4415 |

Additional points

  • Code is here
  • The template we use for feeding the question is
    <|begin▁of▁sentence|><|User|>[MATH500 question] Please reason step by step, and put your final answer within \\boxed{{}}.<|Assistant|><think>\n
  • I believe this is a more principled version of the s1: Simple test-time scaling paper, which replaces <EOS> at the end of a CoT with the Wait token and continues sampling; the method here works with the representation of the reconsideration behavior rather than in token space. A next step is to compare the two methods and check whether the method proposed here achieves better accuracy.
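The template in the list above appears to be written as a Python format string, where `\\` and `{{}}` escape a literal backslash and literal braces; that reading is an assumption. Under it, filling in a question looks like this sketch (`build_prompt` is a hypothetical helper name).

```python
# Assumption: the issue's template, written as a Python format string.
# "{{}}" renders as "{}", "\\boxed" as "\boxed", and "\n" as a newline.
TEMPLATE = (
    "<|begin▁of▁sentence|><|User|>{question} Please reason step by step, "
    "and put your final answer within \\boxed{{}}.<|Assistant|><think>\n"
)


def build_prompt(question: str) -> str:
    """Fill a MATH500 question into the R1-Llama-8B prompt template."""
    # Braces inside `question` are not interpreted by str.format.
    return TEMPLATE.format(question=question)
```

The rendered prompt ends in `\boxed{}.<|Assistant|><think>` followed by a newline, so the model starts generating inside its thinking block.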
