From 6f04e7d7dba0299937c685f34851403a8236a96a Mon Sep 17 00:00:00 2001
From: ZhiyuLi-goog
Date: Sat, 10 Aug 2024 16:46:32 -0700
Subject: [PATCH] update extract_anthropic_prompt

---
 preference_datasets.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/preference_datasets.py b/preference_datasets.py
index 9964ce8..e9c97c6 100644
--- a/preference_datasets.py
+++ b/preference_datasets.py
@@ -1,4 +1,5 @@
 import datasets
+import os
 import torch
 from torch.utils.data import DataLoader, Dataset
 from utils import get_local_dir, TemporarilySeededRandom
@@ -11,12 +12,14 @@
 from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple
 
 
-def extract_anthropic_prompt(prompt_and_response):
+def extract_anthropic_prompt(chosen, rejected):
     """Extract the anthropic prompt from a prompt and response pair."""
+    # Use the longest common prefix between the chosen and rejected as the prompt
+    common_prefix = os.path.commonprefix([chosen, rejected])
     search_term = '\n\nAssistant:'
-    search_term_idx = prompt_and_response.rfind(search_term)
+    search_term_idx = common_prefix.rfind(search_term)
     assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
-    return prompt_and_response[:search_term_idx + len(search_term)]
+    return chosen[:search_term_idx + len(search_term)]
 
 
 def strip_html_tags(html_string):
@@ -143,7 +146,7 @@ def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str,
     print('done')
 
     def split_prompt_and_responses(ex):
-        prompt = extract_anthropic_prompt(ex['chosen'])
+        prompt = extract_anthropic_prompt(ex['chosen'], ex['rejected'])
         chosen_response = ex['chosen'][len(prompt):]
         rejected_response = ex['rejected'][len(prompt):]
         return prompt, chosen_response, rejected_response
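
For illustration, a minimal standalone sketch of the patched helper, using hypothetical HH-style strings rather than real Anthropic/hh-rlhf records. The chosen and rejected completions share the same prompt and diverge only after the final '\n\nAssistant:' turn, so taking the prompt boundary from their longest common prefix keeps both responses split at the same point even if '\n\nAssistant:' also happens to appear inside one of the responses.

import os

def extract_anthropic_prompt(chosen, rejected):
    """Extract the anthropic prompt shared by a chosen/rejected pair."""
    # Use the longest common prefix between the chosen and rejected as the prompt
    common_prefix = os.path.commonprefix([chosen, rejected])
    search_term = '\n\nAssistant:'
    search_term_idx = common_prefix.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return chosen[:search_term_idx + len(search_term)]

# Hypothetical pair: identical prompt, divergent final responses.
chosen = '\n\nHuman: What is 2+2?\n\nAssistant: 2+2 is 4.'
rejected = '\n\nHuman: What is 2+2?\n\nAssistant: 5.'
prompt = extract_anthropic_prompt(chosen, rejected)
assert prompt == '\n\nHuman: What is 2+2?\n\nAssistant:'
assert chosen[len(prompt):] == ' 2+2 is 4.'
assert rejected[len(prompt):] == ' 5.'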