
Issue setting huggingface.prompt_builder = 'llama2' when using sagemaker as client #31

@bcarsley

Description


So I'm building a class that can alternate between the huggingface and sagemaker clients, and I declare all of my os.environ variables at the top of the class like so:

os.environ["AWS_ACCESS_KEY_ID"] = "<key_id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = '<access_key>'
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["HF_TOKEN"] = "<hf_token>"
os.environ["HUGGINGFACE_PROMPT"] = "llama2"

and even later on in the class, just to be sure, I declare huggingface.prompt_builder = 'llama2'. I also tried importing build_llama2_prompt directly and passing it as a callable, which didn't work either, and tried setting sagemaker.prompt_builder = 'llama2' just for fun to see if that would do anything... nope.

I still get the warning telling me I haven't set a prompt builder, which is kind of weird, and it's clear that the prompt is occasionally being formatted a bit oddly, because the same prompt, formatted as in the example below and passed directly to the SageMaker endpoint, yields a noticeably better response from the same endpoint.
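
For reference, this is roughly the easyllm call I'm making when the warning shows up (a minimal sketch; the endpoint name and messages are placeholders, and the exact kwargs may be slightly off):

from easyllm.clients import sagemaker
from easyllm.prompt_utils import build_llama2_prompt

sagemaker.prompt_builder = "llama2"                 # tried the string name...
# sagemaker.prompt_builder = build_llama2_prompt    # ...and the callable

response = sagemaker.ChatCompletion.create(
    model="llama-party",  # my endpoint name
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What should I bring to a llama party?"},
    ],
)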

It's not a big deal that this doesn't work for me (I might just be doing something dumb); below is how I've worked around it by calling the endpoint manually with sagemaker's HuggingFacePredictor class:

import sagemaker
from sagemaker.huggingface.model import HuggingFacePredictor

sess = sagemaker.Session()  # or however your existing SageMaker session is created
llm = HuggingFacePredictor("llama-party", sagemaker_session=sess)
def build_llama2_prompt(messages):
    # build a single Llama 2 chat prompt string from OpenAI-style messages
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # assistant turns close the current [INST] block and open the next one
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")

    return startPrompt + "".join(conversation) + endPrompt
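
For context, messages here is just the usual OpenAI-style chat list, e.g. (placeholder content):

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me about llamas."},
]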

prompt = build_llama2_prompt(messages)

payload = {
  "inputs":  prompt,
  "parameters": {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "top_k": 50,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    "stop": ["</s>"]
  }
}

chat = llm.predict(payload)

print(chat[0]["generated_text"][len(prompt):])  # drop the echoed prompt, print only the new text

This code was taken pretty much wholesale from the SageMaker Llama deployment blog post here: https://www.philschmid.de/sagemaker-llama-llm

It works fine; I just don't know why the same thing doesn't work correctly inside the lib (easyllm).
