From 491beac5fac31854580391f2f1b385e78bc7fce9 Mon Sep 17 00:00:00 2001 From: zhuxiaoxu Date: Fri, 6 Mar 2026 17:41:15 +0800 Subject: [PATCH] fix: apply temperature scaling to logits before softmax in sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Temperature sampling requires dividing raw logits by T before softmax, i.e. softmax(logits / T). The previous code applied temperature after log_softmax (logprobs / T), which computes softmax(logits)^(1/T) — a mathematically different distribution whose denominator is (Σ exp(logit_j))^(1/T) instead of Σ exp(logit_j / T). This produces incorrect sampling distributions whenever a non-greedy temperature (> 1e-6) other than 1 is used; at T = 1 the two forms coincide. Fix applied to both sample_audio_logits and sample_text_logits. --- kimia_infer/utils/sampler.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/kimia_infer/utils/sampler.py b/kimia_infer/utils/sampler.py index 35b8b99..dc031c3 100644 --- a/kimia_infer/utils/sampler.py +++ b/kimia_infer/utils/sampler.py @@ -62,13 +62,14 @@ def sample_audio_logits( logits.scatter_(dim=0, index=recent_window, src=scores) logits = logits.unsqueeze(0) # Add batch dimension back + # Apply temperature scaling if not greedy + if self.audio_temperature > 1e-6: + logits = logits / self.audio_temperature + # Convert to probabilities with softmax logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - # Apply temperature scaling if not greedy if self.audio_temperature > 1e-6: - logprobs = logprobs / self.audio_temperature - # Apply top-k sampling if self.audio_top_k > 0: # Get probabilities from logprobs @@ -134,13 +135,14 @@ def sample_text_logits( logits.scatter_(dim=0, index=recent_window, src=scores) logits = logits.unsqueeze(0) # Add batch dimension back + # Apply temperature scaling if not greedy + if self.text_temperature > 1e-6: + logits = logits / self.text_temperature + # Convert to probabilities with softmax logprobs = torch.log_softmax(logits,
dim=-1, dtype=torch.float) - # Apply temperature scaling if not greedy if self.text_temperature > 1e-6: - logprobs = logprobs / self.text_temperature - # Apply top-k sampling if self.text_top_k > 0: # Get probabilities from logprobs