How does early exiting at each instance of backtracking affect the model's prediction
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
dataset: GSM8K
R1-distilled models tend to get the correct answer multiple times in their thought process, but continue to backtrack and check whether the solution is correct.
How confident is the model at each backtracking instance? (Here, each backtrack means the model continues thinking after it has already reached the answer.)
Experiment setting: at each instance where the model arrives at the answer, force an early exit and ask for the final answer, then record the logit and accuracy of that answer. I picked samples with backtracking counts between 4 and 10. (Some samples have very large backtracking counts, which may be an artifact of the function used to detect that the model has arrived at the answer: it is simple string matching, which could be flawed.)
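The harness described above can be sketched as follows. `truncation_points` mirrors the (admittedly flawed) string check for the answer, and `FINAL_ANSWER_PROMPT` is a hypothetical stand-in for whatever template the actual run used:

```python
FINAL_ANSWER_PROMPT = "\n</think>\n\nThe final answer is"  # hypothetical template

def truncation_points(cot: str, answer: str) -> list[int]:
    """Character offsets just after each occurrence of the answer string
    in the chain of thought (the crude string check described above)."""
    points, start = [], 0
    while (idx := cot.find(answer, start)) != -1:
        points.append(idx + len(answer))
        start = idx + len(answer)
    return points

def early_exit_prompts(question: str, cot: str, answer: str) -> list[str]:
    """One prompt per backtracking instance: truncate the CoT there and
    ask the model for the final answer."""
    return [question + cot[:p] + FINAL_ANSWER_PROMPT
            for p in truncation_points(cot, answer)]
```

Each resulting prompt is then fed back to the model and the logit of the answer token is recorded.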
Logits: [14.75, 23.98, 24.64, 25.11, 25.98]
Logits: [14.77, 24.92, 25.47, 26.08, 26.43, 26.96]
Logits: [15.19, 24.88, 25.49, 26.01, 26.39, 26.8, 27.49]
Logits: [14.96, 24.4, 25.12, 25.68, 26.06, 26.39, 26.59, 27.13]
Logits: [15.03, 24.37, 24.94, 25.65, 26.13, 26.49, 26.84, 26.97, 27.42]
Logits: [14.79, 24.51, 25.26, 25.83, 26.19, 26.63, 26.86, 27.11, 27.19, 27.61]
Confidence in the answer increases as the model backtracks more.
Surprisingly, even when the model first reaches the answer early, prompting it at the 1st instance already gives an accuracy of 0.9, and the accuracy does not increase with further backtracking, where the model keeps arriving at the same solution.
However, we do see an increase in confidence via the logits (which are normalized across the backtracking instances within each sample). The model gets more confident every time it backtracks and arrives at the same decision, so the logits increase with each additional occurrence.
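The exact normalization is not stated above, so this minimal sketch assumes a simple min-max normalization of the answer logits within each sample:

```python
import numpy as np

def normalize_within_sample(logits):
    """Min-max normalize the answer logits across the backtracking
    instances of a single sample, so that samples with different absolute
    logit scales become comparable (assumption: min-max, not z-scoring)."""
    x = np.asarray(logits, dtype=float)
    return (x - x.min()) / (x.max() - x.min() + 1e-9)
```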
Most of the time, the model gets the answer correct once it first arrives at a solution. Could we use this to save computation and exit the thought process early? We would need to do so without access to the ground-truth solution. Are there backtracking tokens we could watch for, so that once the model starts backtracking we trigger an early exit?
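Such a heuristic could be sketched as follows, assuming a small hand-picked set of cue tokens (the actual cue vocabulary would need to be validated on the dataset):

```python
BACKTRACK_MARKERS = {"Wait", "Hmm", "Alternatively", "But"}  # assumed cue set

def early_exit_stream(tokens):
    """Consume generated tokens and stop at the first backtracking cue.
    Returns the kept prefix and whether we exited early; the caller would
    then append a final-answer prompt, as in the early-exit experiment."""
    kept = []
    for tok in tokens:
        if tok.strip() in BACKTRACK_MARKERS:
            return kept, True
        kept.append(tok)
    return kept, False
```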
Importance of thought process
Judging from the x-axis at 0, where the model immediately generates the answer without thinking, it fails to get the answer correct most of the time. However, it's not clear whether the thought process is faithful, or whether the model uses the thought process as a sketchpad (continuous space) to derive the answer.
Experiment (14 Mar)
In Experiment: The surprising effectiveness of GSM8K steering vectors for reconsideration/completion #8, @wendlerc found a backtracking vector which, when steered, can cause the model to either backtrack or wrap up its thoughts. On samples where the model still gets the answer wrong, could we steer it to backtrack and see if it then gets the answer correct? (This turned out to be difficult to make work.)
When the model backtracks, which tokens is it looking at? Does it look at the most recent tokens and decide that the steps are not concrete and it should think some more? Can we use attribution methods, such as finding important attention heads and inspecting their attention patterns? We can bin the context into intervals and measure the attention scores assigned to each bin (below).
Experiment 1) Steering the model to backtrack and wrap up
I found steering toward backtracking hard to implement in practice; I tried two ways to steer:
Inject a high-coefficient steering vector once at the last input token: this only affects the 1st output token and does not actually affect the remaining tokens, presumably because the attention modules give it insufficient attention.
Continuous injection across all output tokens: hard to control and very easily results in degenerate responses. Moreover, the model gets trapped in endless over-thinking and does not conclude its thoughts.
I found that steering it to wrap up its thoughts is easier, since the generation then has a natural ending, though it can still sometimes cause degenerate responses.
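Both injection modes reduce to where the steering vector is added. A minimal NumPy sketch, treating the residual stream at one layer as a `(seq_len, d_model)` array (hook plumbing and the actual steering vector are omitted):

```python
import numpy as np

def steer(hidden, direction, coef, positions):
    """Add coef * (unit steering direction) to the residual stream at the
    given token positions. hidden: (seq_len, d_model)."""
    out = hidden.copy()
    unit = direction / np.linalg.norm(direction)
    for p in positions:
        out[p] = out[p] + coef * unit
    return out
```

One-shot injection corresponds to `positions=[last_input_pos]`; continuous injection applies the same update at every generated position on each forward pass, which is the mode that most easily degenerated.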
Experiment 2) Understanding backtracking - logit lens
Following the logit lens work done by @ajyl in https://ajyl.github.io/reasoning/2025/02/27/mlp-value-vecs.html, we can do something similar here. I first filter for samples with the following occurrence: when the model backtracks for the first time, it predicts the next token "Wait". Such occurrences are fairly frequent (around 40% of the dataset). The model probably predicts this token as a sign that it is starting to re-think its previous calculations.
Logit lens plot (resid_pre is the raw state after the embeddings, and mlp/attn add the computations from the respective modules at each layer)
We can see that common words include [wait, therefore] and, in some cases, 'alternatively', and that most of the spike in logits is due to the MLP layers, a finding similar to https://ajyl.github.io/reasoning/2025/02/27/mlp-value-vecs.html.
This plot also shows that MLP layers have larger direct effects than attention heads, which may be less surprising: MLPs are responsible for transforming local states and encoding information, while attention moves information from previous positions.
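The decomposition behind the plot can be sketched as follows, with a crude RMSNorm stand-in for the final layernorm (real code would use the model's actual normalization, unembedding, and cached activations):

```python
import numpy as np

def logit_lens_trace(resid_pre, attn_outs, mlp_outs, W_U, token_id):
    """Walk the residual stream at one position layer by layer, adding the
    attention output and then the MLP output, and read off the target
    token's logit through the unembedding after each addition.
    resid_pre: (d,); attn_outs, mlp_outs: (n_layers, d); W_U: (d, vocab)."""
    def logit(v):
        v = v / (np.sqrt((v ** 2).mean()) + 1e-6)  # crude RMSNorm stand-in
        return float(v @ W_U[:, token_id])
    trace = [("resid_pre", logit(resid_pre))]
    resid = resid_pre.copy()
    for l, (a, m) in enumerate(zip(attn_outs, mlp_outs)):
        resid = resid + a
        trace.append((f"attn_{l}", logit(resid)))
        resid = resid + m
        trace.append((f"mlp_{l}", logit(resid)))
    return trace
```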
Following the probe analysis for finding relevant value vectors in the MLP (probes on most layers, roughly 6 to 30, can predict backtrack/wrap-up almost perfectly):
Layer 16 vec 10967: ['ンピ', ' thoughts', 'solution', ' Solution', ' Thoughts', 'Solution', '.solution', ' verdict', ' solution', 'öst', ' Analysis', 'Thought', 'Analysis', '.double', 'thought', 'ンディ', 'ifetime', '_notes', 'myModal', ' Skinny']
Layer 17 vec 13623: [' therefore', ' Therefore', 'Therefore', ' donc', ' daher', '因此', ',因此', ' hence', ' erg', ' поэтому', ' więc', ' quindi', ' nên', ' لذا', ' Erg', ' Поэтому', ',所以', ' इसल', ' Hence', ' بنابراین']
Layer 17 vec 12955: [' thus', ' therefore', 'thus', ' Thus', ' Therefore', 'Thus', 'Therefore', ' hence', ' donc', '因此', ',因此', ' więc', ' thereby', '所以', ',所以', ' Hence', ' so', ' поэтому', ' quindi', ' così']
Layer 18 vec 10122: [' therefore', ' thus', ' Therefore', 'Therefore', 'thus', '所以', ',所以', ' hence', ' donc', ' so', ',因此', ' Thus', '因此', 'Thus', ' поэтому', ' Hence', ' więc', '于是', ' So', ' nên']
Layer 19 vec 1389: [' therefore', ' thus', ' Therefore', 'Therefore', 'thus', ' donc', ' hence', '因此', ' Thus', ',因此', 'Thus', ' consequently', ' więc', ' quindi', ' daher', ' accordingly', ' поэтому', ' Hence', ' thereby', ' इसल']
Layer 19 vec 4762: [' slow', 'slow', ' Slow', ' slower', 'Slow', ' �', ' slowly', ' slowing', ' hold', ' slows', ' calm', '慢', ' slowed', ' ease', ' calming', ' cal', 'hold', '_slow', ' Hold', ' relax']
Layer 19 vec 12911: [' pause', ' Pause', ' paused', ' pauses', 'Pause', ' pa', '停', 'pause', 'pa', 'STOP', 'PA', 'paused', '_pause', '.pause', ' STOP', ' Pa', ' PA', ' stopping', '_PAUSE', '_pa']
Layer 23 vec 11803: [' hold', ' held', ' Hold', ' holds', 'hold', 'Hold', ' holding', 'held', ' Held', '_hold', ' HOLD', ' Holds', ' Holding', 'holding', '-held', 'holds', ' delay', '_HOLD', ' Holden', ' holder']
Layer 24 vec 11117: [' wait', ' waiting', ' waited', ' Wait', ' waits', 'wait', 'Wait', ' Waiting', ' WAIT', '.wait', '_wait', 'waiting', 'Waiting', 'WAIT', '/wait', ' await', '.Wait', '(wait', '_WAIT', '\twait']
Layer 30 vec 12381: [' so', ' So', 'so', 'So', '如此', '所以', '-so', '_so', '(so', '.so', '那么', '\tSo', ' så', ' так', '"So', 'ので', ',所以', '“So', '这么', ' 所']
Most of the relevant vectors appear to be in the middle layers (based on my interpretation of the projections). I see fewer relevant ones for wrapping up, since it is less clear what kind of words to look for.
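The per-vector token lists above come from projecting each MLP value vector onto the unembedding; a minimal sketch of that projection:

```python
import numpy as np

def top_tokens_for_value_vec(value_vec, W_U, vocab, k=5):
    """Project one MLP value vector (a column of W_out) onto the
    unembedding matrix and return the k tokens it most promotes."""
    logits = value_vec @ W_U               # (vocab,)
    top = np.argsort(-logits)[:k]
    return [vocab[i] for i in top]
```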
Experiment 3) attribution patching - token-level scores
Next, I want to see whether we can find the tokens that are highly causal for the prediction of the backtrack token, "wait". Does looking at these tokens show why the model decides to backtrack?
I used attribution patching on the residual states and assigned each token a score equal to the max over all layers. Since we don't have a corrupted input, I use mean ablation as the corrupted state, averaging over the dataset and sequence dimensions for each layer.
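The scoring step can be sketched as below, where `grads` is the gradient of the "wait" metric with respect to each residual state (obtained from a single backward pass in the real setup):

```python
import numpy as np

def attribution_scores(clean_resid, mean_resid, grads):
    """First-order (attribution patching) estimate of the effect of
    replacing each token's residual state with its mean-ablated version:
        score[l, t] = (mean - clean)[l, t] . dMetric/dresid[l, t]
    The per-token score is the max over layers, as described above.
    clean_resid, grads: (n_layers, seq, d); mean_resid: (n_layers, d)."""
    delta = mean_resid[:, None, :] - clean_resid   # (L, T, D)
    per_layer = (delta * grads).sum(-1)            # (L, T)
    return per_layer.max(axis=0)                   # (T,)
```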
Some examples (answer is the final answer predicted, and rank is the importance rank over all words):
I cherry-picked these examples, in which we can actually find previous tokens that relate to re-checking.
Another thing to ponder: if the model backtracks on its solution, does the backtracking response attend to those important words that are causal for backtracking?
A simple analysis, computing the percentage of the top 10 important words (after removing stopwords) that appear in the next 100 words, gives roughly 41% overlap.
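A sketch of that overlap computation, with a toy stopword list standing in for whatever list was actually used:

```python
STOPWORDS = {"the", "a", "an", "is", "was", "of", "to", "and", "so", "it"}  # toy list

def _words(text):
    """Lowercase, punctuation-stripped word list."""
    return [w.strip(".,:;!?()\"'").lower() for w in text.split()]

def overlap_fraction(important_words, continuation, k=10, window=100):
    """Fraction of the top-k important words (stopwords removed) that
    appear among the next `window` words of the backtracking response."""
    top = [w.lower() for w in important_words if w.lower() not in STOPWORDS][:k]
    if not top:
        return 0.0
    nxt = set(_words(continuation)[:window])
    return sum(w in nxt for w in top) / len(top)
```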
I also looked at how the importance is spread over the question and the CoT (binned into 20% intervals):
Ques: 0.63, CoT (binned into 20%): [0.16, 0.14, 0.21, 0.38, 0.77]
It seems that the question context and the immediate CoT are responsible for the model deciding to backtrack.
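The binning can be sketched as follows, assuming per-token attribution scores for the question and the CoT (the 20% intervals correspond to `n_bins=5`):

```python
def bin_importance(question_scores, cot_scores, n_bins=5):
    """Mean attribution score over the question tokens, plus over n_bins
    equal slices of the CoT (20% intervals for n_bins=5)."""
    def mean(xs):
        return sum(xs) / len(xs) if xs else 0.0
    step = max(1, round(len(cot_scores) / n_bins))
    bins = []
    for i in range(n_bins):
        hi = (i + 1) * step if i < n_bins - 1 else len(cot_scores)
        bins.append(mean(cot_scores[i * step:hi]))  # last bin takes the remainder
    return mean(question_scores), bins
```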
Next Experiment
The above seems to provide some plausible evidence for why the model decides to backtrack. The 1st experiment shows that even if we truncate the model's thought immediately at the 1st instance of backtracking, we still get almost perfect performance.
An interesting next experiment is a faithfulness test such as the one in Lanham et al., inserting mistakes into the immediate CoT right before the model backtracks and checking whether the backtracking corrects them.
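The corruption step might be sketched as below, assuming "Wait" marks the first backtracking cue and that perturbing the most recent number counts as an inserted mistake:

```python
import re

def insert_mistake(cot: str, factor: int = 10):
    """Corrupt the last number appearing before the first backtracking cue
    ('Wait'), multiplying it by `factor`. Returns the corrupted CoT, or
    None if there is no cue or no number to corrupt."""
    cut = cot.find("Wait")
    if cut == -1:
        return None
    prefix = cot[:cut]
    nums = list(re.finditer(r"\d+", prefix))
    if not nums:
        return None
    m = nums[-1]
    wrong = str(int(m.group()) * factor)
    return prefix[:m.start()] + wrong + prefix[m.end():] + cot[cut:]
```

One would then let the model continue from the corrupted CoT and check whether the backtracking text catches the mistake.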