-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprompting.py
More file actions
182 lines (153 loc) · 7.85 KB
/
prompting.py
File metadata and controls
182 lines (153 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# TODO: rename to prompting.py
import time
import pandas as pd
from together import Together
from openai import OpenAI
from huggingface_hub import InferenceClient
from constants import *
def parse_example_for_system_prompt(example, output_prefix=""):
    """Format one few-shot example for inclusion in the system prompt.

    Args:
        example: Either a plain string (rendered verbatim on its own line) or a
            mapping that may contain INPUT / OUTPUT / EXPLANATION keys.
        output_prefix: Optional label (e.g. "CORRECT") prepended to the OUTPUT
            line; a trailing space is added if missing.

    Returns:
        A string starting with a newline, ready to be appended to the prompt.
    """
    # isinstance (not `type(...) is str`) so str subclasses are handled too.
    if isinstance(example, str):
        return "\n" + example
    parsed = ""
    if output_prefix and not output_prefix.endswith(" "):
        output_prefix += " "
    if INPUT in example:
        parsed += f"\n{INPUT}: {example[INPUT]}"
    if OUTPUT in example:
        parsed += f"\n{output_prefix}{OUTPUT}: {example[OUTPUT]}"
    if EXPLANATION in example:
        parsed += f"\n{EXPLANATION}: {example[EXPLANATION]}"
    return parsed
def get_private_examples(sheet):
    """Return the private few-shot examples stored in *sheet*.

    Currently a stub: always yields an empty list, whether or not a sheet
    identifier is supplied.
    """
    if not sheet:
        return []
    # TODO: needs implementation! Sheet-backed private examples are not fetched yet.
    return []
def get_system_prompt(prompt_type_task=DIRECT_CODING_TASK):
    """Assemble the full system prompt for the currently configured coding task.

    Args:
        prompt_type_task: CHAT_TASK for the conversational instruction layout,
            anything else (default DIRECT_CODING_TASK) for the strict
            input/output-format layout.

    Returns:
        The complete system-prompt string, including any public and private
        correct/incorrect example sections.
    """
    task_params = PARAMETERS_BY_CODING_TASK[validate_model_config()[CODING_TASK]]
    in_fmt = task_params[INPUT_FORMAT_INSTRUCTION]
    out_fmt = task_params[OUTPUT_FORMAT_INSTRUCTION]
    parts = [f"{SYSTEM_INTRO}\n\nCODING SCHEME:\n{task_params[TASK_DEFINITION]}"]
    if prompt_type_task == CHAT_TASK:
        parts.append(f"\n\n{CHAT_INSTRUCTION_FORMAT.format(in_fmt, out_fmt)}")
    else:  # prompt_type_task == DIRECT_CODING_TASK
        parts.append(f"\n\nINPUT FORMAT:\n{in_fmt}\n\nOUTPUT FORMAT:\n{out_fmt}")
        parts.append(f"\n{STRICT_OUTPUT_FORMAT_REMINDER}")
    correct = (task_params[PUBLIC_CORRECT_EXAMPLES]
               + get_private_examples(task_params[PRIVATE_CORRECT_EXAMPLES_SHEET]))
    incorrect = (task_params[PUBLIC_INCORRECT_EXAMPLES]
                 + get_private_examples(task_params[PRIVATE_INCORRECT_EXAMPLES_SHEET]))
    for examples, title, prefix in ((correct, "EXAMPLES FOR CORRECT CODINGS", "CORRECT"),
                                    (incorrect, "EXAMPLES FOR INCORRECT CODINGS", "INCORRECT")):
        if examples:
            parts.append(f"\n\n{title}:")
            parts.extend(parse_example_for_system_prompt(ex, prefix) for ex in examples)
    return "".join(parts)
def get_model_config_parameters():
    """Return ``(client, service, base_llm, coding_task)`` for the current config.

    The API client is created once per service and cached in
    ``st.session_state`` keyed by the service name, so Streamlit reruns reuse
    the same connection instead of re-authenticating.
    """
    model_config = validate_model_config()
    service = model_config[MODEL_SERVICE]
    if service not in st.session_state:
        if service == PRIVATE_SERVICE:
            new_client = InferenceClient(provider="hf-inference", api_key=st.secrets["HF_API_KEY"])
        else:  # service == FREE_SERVICE
            # client = Together(api_key=st.secrets["TOGETHER_API_KEY"])
            new_client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=st.secrets["HF_API_KEY"])
        st.session_state[service] = new_client
    return (st.session_state[service], service,
            model_config[BASE_LLM], model_config[CODING_TASK])
def get_generation_kwargs(**kwargs):
    """Merge caller-supplied generation options with the project defaults.

    Caller kwargs always win; defaults only fill in keys the caller omitted.
    Insertion order (caller keys first, then missing defaults) is preserved
    because the resulting dict is stringified into the generation log.
    """
    generation_kwargs = dict(kwargs)
    for parameter, value in DEFAULT_GENERATION_PARAMETERS.items():
        # setdefault: add the default only when the caller didn't override it.
        generation_kwargs.setdefault(parameter, value)
    return generation_kwargs
def get_generation_log(service, base_llm, coding_task, messages,
                       generation_kwargs, output, task=DIRECT_CODING_TASK):
    """Build one generation-log row (column-name -> value) for the log sheet.

    Messages and generation kwargs are stringified so every cell is a plain
    string. The username falls back to "error" when no user is in session.
    """
    return {
        TIMESTAMP_COLUMN: time.strftime("%x %X"),  # locale date + time
        USERNAME_COLUMN: st.session_state.get("user", "error"),
        SERVICE_COLUMN: service,
        BASE_LLM_COLUMN: base_llm,
        CODING_TASK_COLUMN: coding_task,
        INPUT_COLUMN: str(messages),
        GEN_KWARGS_COLUMN: str(generation_kwargs),
        OUTPUT_COLUMN: output,
        TASK_COLUMN: task
    }
def save_generation_log(single_generation_log: dict[str, str] | None = None,
                        multiple_generation_logs: list[dict[str, str]] | None = None):
    """Append generation-log rows to the Google Sheet behind the app.

    Args:
        single_generation_log: One log row (column -> value), appended first.
        multiple_generation_logs: A batch of log rows, appended after.

    Both arguments may be given; both may be None (then the sheet is rewritten
    unchanged, as before).
    """
    conn = get_gsheets_connection()
    df = conn.read(ttl=0)  # ttl=0: always fetch the latest sheet contents
    new_rows = []
    if single_generation_log:
        new_rows.append(single_generation_log)
    if multiple_generation_logs:
        new_rows.extend(multiple_generation_logs)
    if new_rows:
        # pd.concat for the single-row case too: the previous
        # `df.loc[len(df)] = row` silently OVERWRITES an existing row whenever
        # the index isn't a contiguous RangeIndex.
        df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
    conn.update(data=df)
def code_text(new_message: str, message_history: list[dict[str, str]] = None, **kwargs):
    """Code one text with the configured LLM, optionally continuing a history.

    With no history, a fresh conversation is started from the system prompt;
    otherwise the given history is copied (the caller's list is not mutated).

    Returns:
        (output, messages, log) — the model output, the message list actually
        sent (user turn included, assistant turn NOT appended), and the log row.
    """
    messages = ([{"role": "system", "content": get_system_prompt()}]
                if message_history is None
                else message_history.copy())
    messages.append({"role": "user", "content": new_message})
    output, log = generate_with_retries(messages, raw_generation, **kwargs)
    # messages.append({"role": "assistant", "content": output})
    return output, messages, log
def generate_with_retries(messages, generation_func, task=DIRECT_CODING_TASK, **kwargs):
    """Generate with up to MAX_ALLOWED_RETRIES attempts, growing the token budget.

    An empty output (e.g. the model spent its whole budget on thinking tokens)
    triggers a retry with MAX_TOKENS_PARAM increased by MAX_TOKENS_INC_STEP.

    Returns:
        (output, log) — output is "" when every attempt failed; log is the
        generation-log row for the final attempt.
    """
    client, service, base_llm, coding_task = get_model_config_parameters()
    generation_kwargs = get_generation_kwargs(**kwargs)
    allowed_tries = MAX_ALLOWED_RETRIES
    output = ""
    while allowed_tries > 0:
        if allowed_tries < MAX_ALLOWED_RETRIES:  # every attempt after the first
            # Trailing space before the closing quote: adjacent f-string
            # literals are concatenated with NO separator, so without it the
            # message read "...max tokens,probably...".
            st.warning(f"Failed to generate with {generation_kwargs[MAX_TOKENS_PARAM]} max tokens, "
                       f"probably due to the model's thinking tokens. Re-trying with more...")
            generation_kwargs[MAX_TOKENS_PARAM] += MAX_TOKENS_INC_STEP
        output = generation_func(client, base_llm, messages, generation_kwargs)
        if output:
            break
        allowed_tries -= 1
    if not output:
        # Same missing-space fix as above ("...tokens.(re-tried" before).
        st.error(f"Failed to generate a response, probably due to the model's thinking tokens. "
                 f"(re-tried {MAX_ALLOWED_RETRIES} times)")
    log = get_generation_log(service, base_llm, coding_task, messages, generation_kwargs, output, task)
    return output, log
def raw_generation(client, base_llm, messages, generation_kwargs):
    """Run one non-streaming chat completion and return its text.

    Returns "" when the completion stopped on the token limit
    (finish_reason == "length"), which the retry loop treats as a failure.
    """
    completion = client.chat.completions.create(
        model=base_llm, messages=messages, **generation_kwargs)
    first_choice = completion.choices[0]
    truncated = getattr(first_choice, "finish_reason", None) == "length"
    return "" if truncated else first_choice.message.content
def raw_stream_generation(client, base_llm, messages, generation_kwargs):
    """Start a streaming chat completion and return the raw chunk iterator."""
    return client.chat.completions.create(
        model=base_llm,
        messages=messages,
        stream=True,
        **generation_kwargs,
    )
def join_write_stream(stream):
    """Render *stream* incrementally via ``st.write_stream`` and return the text.

    Malformed chunks (missing attributes, empty choices) are skipped rather
    than aborting the stream.
    """
    collected = []

    def _chunks():
        for chunk in stream:
            try:
                piece = chunk.choices[0].delta.content  # TODO: consider using chunk.choices[0].finish_reason == "length"
            except (AttributeError, IndexError, KeyError):
                continue  # skip malformed chunks
            if piece:
                collected.append(piece)
                yield piece

    st.write_stream(_chunks())
    return "".join(collected)
def generate_for_chat_with_write_stream(messages: list[dict[str, str]], **kwargs):
    """Generate a chat reply, streaming it to the UI as it arrives.

    Returns:
        (output, log) — the joined streamed text and the generation-log row.
    """
    def streamed(*args):
        # Stream the answer directly to the page while collecting it.
        return join_write_stream(raw_stream_generation(*args))

    output, log = generate_with_retries(messages, streamed, CHAT_TASK, **kwargs)
    # appending the message to all messages should be outside of assistant scope
    return output, log
def generate_for_chat(messages: list[dict[str, str]], temperature: float = 0):
    """Deprecated single-shot chat generation: generate, append to *messages*,
    and persist the log row.

    NOTE(review): mutates the caller's *messages* list in place, and only sets
    `temperature` (defaults from DEFAULT_GENERATION_PARAMETERS are NOT applied
    here, unlike generate_with_retries).
    """
    # original deprecated function
    client, service, base_llm, coding_task = get_model_config_parameters()
    generation_kwargs = dict(temperature=temperature)
    output = raw_generation(client, base_llm, messages, generation_kwargs)
    log = get_generation_log(service, base_llm, coding_task, messages, generation_kwargs, output, CHAT_TASK)
    messages.append({"role": "assistant", "content": output})  # first add to show user
    save_generation_log(single_generation_log=log)  # then save log