So I'm building a class that can switch between the huggingface and sagemaker clients, and I set all my os.environ variables at the top of the class like so:
os.environ["AWS_ACCESS_KEY_ID"] = "<key_id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = '<access_key>'
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["HF_TOKEN"] = "<hf_token>"
os.environ["HUGGINGFACE_PROMPT"] = "llama2"
and even later on in the class, just to be sure, I set huggingface.prompt_builder = 'llama2'
I also tried importing build_llama2_prompt directly and passing it as a callable, which didn't work either
and I tried setting sagemaker.prompt_builder = 'llama2' just for fun to see if that would do anything... nope
I still get the warning telling me I haven't set a prompt builder, which is kinda weird. It's also clear that the prompt is occasionally being formatted incorrectly, because the same prompt, formatted as in the example below and passed directly to the sagemaker endpoint, yields a noticeably better response from the same endpoint.
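For reference, here's roughly what the easyllm side of my setup looks like. This is a minimal sketch rather than my exact code: I'm assuming the client modules live under easyllm.clients, that build_llama2_prompt is importable from easyllm.prompt_utils, that ChatCompletion.create follows the OpenAI-style API the lib exposes, and the model id is just a placeholder.

from easyllm.clients import huggingface
from easyllm.prompt_utils import build_llama2_prompt  # assumed import path

# what I tried: the string name...
huggingface.prompt_builder = "llama2"
# ...and the callable itself
huggingface.prompt_builder = build_llama2_prompt

response = huggingface.ChatCompletion.create(
    model="meta-llama/Llama-2-70b-chat-hf",  # placeholder model id
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
)
print(response["choices"][0]["message"]["content"])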
It's nbd that this doesn't work for me, I might just be doing something dumb, but below is how I've worked around it by calling sagemaker's HuggingFacePredictor class directly:
import sagemaker
from sagemaker.huggingface.model import HuggingFacePredictor

sess = sagemaker.Session()

# predictor pointed at the already-deployed llama endpoint
llm = HuggingFacePredictor(endpoint_name="llama-party", sagemaker_session=sess)

def build_llama2_prompt(messages):
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # assistant turns close the current [INST] block and open the next one
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")
    return startPrompt + "".join(conversation) + endPrompt

# messages is the usual OpenAI-style list of {"role": ..., "content": ...} dicts
prompt = build_llama2_prompt(messages)

payload = {
    "inputs": prompt,
    "parameters": {
        "do_sample": True,
        "top_p": 0.6,
        "temperature": 0.9,
        "top_k": 50,
        "max_new_tokens": 512,
        "repetition_penalty": 1.03,
        "stop": ["</s>"]
    }
}

chat = llm.predict(payload)

# strip the echoed prompt from the generated text
print(chat[0]["generated_text"][len(prompt):])
this code was pretty much fully taken from the sagemaker llama deployment blog post here: https://www.philschmid.de/sagemaker-llama-llm
It works fine, I just don't know why the same setup doesn't work right inside of the lib (easyllm).
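For comparison, this is roughly the call I'd expect to do the same thing through easyllm's sagemaker client. Again just a sketch under assumptions: that the sagemaker client mirrors the huggingface one (ChatCompletion.create taking the endpoint name as model, prompt_builder set on the module), and that the parameter names are the OpenAI-style equivalents.

from easyllm.clients import sagemaker
from easyllm.prompt_utils import build_llama2_prompt  # assumed import path

sagemaker.prompt_builder = build_llama2_prompt

response = sagemaker.ChatCompletion.create(
    model="llama-party",       # the deployed endpoint name
    messages=messages,         # same messages list as in the workaround above
    temperature=0.9,
    top_p=0.6,
    max_tokens=512,
)
print(response["choices"][0]["message"]["content"])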