diff --git a/python/python/psyche/models/ttitan.py b/python/python/psyche/models/ttitan.py
index 93dd80518..cd8fa6a15 100644
--- a/python/python/psyche/models/ttitan.py
+++ b/python/python/psyche/models/ttitan.py
@@ -436,7 +436,7 @@ def forward(
             ),
         )
         if num_logits_to_keep:
-            pred = pred[:, -num_logits_to_keep, :]
+            pred = pred[:, -num_logits_to_keep:, :]
         loss = None
         if labels is not None:
             if labels.shape != pred.shape[:2]:
diff --git a/shared/eval/src/harness.rs b/shared/eval/src/harness.rs
index 836cdd728..209ec2874 100644
--- a/shared/eval/src/harness.rs
+++ b/shared/eval/src/harness.rs
@@ -493,29 +493,24 @@ impl PreparedTask {
                 // The request already contains [fewshot_tokens] + [question + choice_without_last_token]
                 let full_request = request;
-                let input_length = &full_request.len();
 
                 let request_tensor = Tensor::from_slice(&full_request)
                     .to(options.model.device())
                     .unsqueeze(0);
 
                 let (logits, _) = {
                     let _no_grad = tch::no_grad_guard();
-                    options
-                        .model
-                        .forward(&request_tensor, None, None, None, None, None)
+                    options.model.forward(
+                        &request_tensor,
+                        None,
+                        None,
+                        None,
+                        Some(choice.len() as i64),
+                        None,
+                    )
                 };
 
-                let logits = logits.unwrap().squeeze_dim(0).slice(0, 0, None, 1);
-
-                // Get tensor of shape `[choice.len(), vocab_size]` containing the
-                // model's logits for each token of the `choice` text.
-                // This should skip the fewshot tokens and get the tokens from the end.
-                let logits = logits.slice(
-                    0,
-                    *input_length as i64 - choice.len() as i64,
-                    *input_length as i64,
-                    1,
-                );
+                // Shape: [choice.len(), vocab_size]
+                let logits = logits.unwrap().squeeze_dim(0);
 
                 let greedy_tokens: Vec<i64> = logits.argmax(-1, false).try_into().unwrap();
                 let exact_match = greedy_tokens.eq(&choice);
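
Note (not part of the patch): the Python fix matters because `pred[:, -num_logits_to_keep, :]` is an integer index, which collapses the sequence dimension and returns only the single position at `-num_logits_to_keep`, while `pred[:, -num_logits_to_keep:, :]` is a slice that keeps the last `num_logits_to_keep` positions. A minimal illustration with PyTorch, assuming `pred` is a `[batch, seq_len, vocab]` logits tensor as used above (the shapes here are arbitrary):

    import torch

    pred = torch.randn(2, 10, 32000)  # [batch, seq_len, vocab]
    k = 3                             # stands in for num_logits_to_keep

    # Old behavior: integer index drops the sequence dimension.
    print(pred[:, -k, :].shape)   # torch.Size([2, 32000])

    # Fixed behavior: slice keeps the last k positions.
    print(pred[:, -k:, :].shape)  # torch.Size([2, 3, 32000])

Because the model now truncates its own output when `num_logits_to_keep` is set, the Rust harness can pass `Some(choice.len() as i64)` to `forward` and drop the manual `slice` over `input_length`, which is what the second hunk does.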