diff --git a/kimia_infer/utils/sampler.py b/kimia_infer/utils/sampler.py
index 35b8b99..dc031c3 100644
--- a/kimia_infer/utils/sampler.py
+++ b/kimia_infer/utils/sampler.py
@@ -62,12 +62,12 @@ def sample_audio_logits(
     logits.scatter_(dim=0, index=recent_window, src=scores)
     logits = logits.unsqueeze(0)  # Add batch dimension back
 
+    # Apply temperature scaling if not greedy
+    if self.audio_temperature > 1e-6:
+        logits = logits / self.audio_temperature
+
     # Convert to probabilities with softmax
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-    # Apply temperature scaling if not greedy
-    if self.audio_temperature > 1e-6:
-        logprobs = logprobs / self.audio_temperature
-
     # Apply top-k sampling
     if self.audio_top_k > 0:
         # Get probabilities from logprobs
@@ -134,12 +134,12 @@ def sample_text_logits(
     logits.scatter_(dim=0, index=recent_window, src=scores)
     logits = logits.unsqueeze(0)  # Add batch dimension back
 
+    # Apply temperature scaling if not greedy
+    if self.text_temperature > 1e-6:
+        logits = logits / self.text_temperature
+
     # Convert to probabilities with softmax
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-    # Apply temperature scaling if not greedy
-    if self.text_temperature > 1e-6:
-        logprobs = logprobs / self.text_temperature
-
     # Apply top-k sampling
     if self.text_top_k > 0:
         # Get probabilities from logprobs