diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md index 4f113e69f..cebb1449e 100644 --- a/docs/finetune_parameters.md +++ b/docs/finetune_parameters.md @@ -16,6 +16,7 @@ The following are the parameters supported in the finetuning workflow. |lora_config|task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1|Will be passed to the LoraConfig `__init__()` method, then it'll be used as config to build Peft model object.| |deltatuner_config|"algo": "lora"
"denas": True
"best_model_structure": "/path/to/best_structure_of_deltatuner_model"|Will be passed to the DeltaTunerArguments `__init__()` method, then it'll be used as config to build [Deltatuner model](https://github.com/intel/e2eAIOK/tree/main/e2eAIOK/deltatuner) object.| |enable_gradient_checkpointing|False|enable gradient checkpointing to save GPU memory, but will cost more compute runtime| +|chat_template|None|User-defined chat template.| ## Dataset Parameters diff --git a/llm_on_ray/common/__init__.py b/llm_on_ray/common/__init__.py index 0e8e821ad..e002976b6 100644 --- a/llm_on_ray/common/__init__.py +++ b/llm_on_ray/common/__init__.py @@ -18,4 +18,13 @@ from llm_on_ray.common.torch_config import TorchConfig from llm_on_ray.common.config import Config from llm_on_ray.common.init import init -from llm_on_ray.common import agentenv, dataset, initializer, model, optimizer, tokenizer, trainer +from llm_on_ray.common import ( + agentenv, + dataset, + initializer, + model, + optimizer, + tokenizer, + trainer, + dataprocesser, +) diff --git a/llm_on_ray/common/dataprocesser/__init__.py b/llm_on_ray/common/dataprocesser/__init__.py index 2b5152764..c1bf68ae8 100644 --- a/llm_on_ray/common/dataprocesser/__init__.py +++ b/llm_on_ray/common/dataprocesser/__init__.py @@ -15,7 +15,8 @@ # from llm_on_ray.common.dataprocesser.dataprocesser import DataProcesser -from llm_on_ray.common.dataprocesser.general_processer import GeneralProcesser +from llm_on_ray.common.dataprocesser.general_processer import ChatDataPreprocess +from llm_on_ray.common.dataprocesser.general_processer import SlimOrcaDataPreprocess from llm_on_ray.common.dataprocesser.rm_dataprocesser import RMDataProcesser diff --git a/llm_on_ray/common/dataprocesser/general_processer.py b/llm_on_ray/common/dataprocesser/general_processer.py index 37235b425..c51329718 100644 --- a/llm_on_ray/common/dataprocesser/general_processer.py +++ b/llm_on_ray/common/dataprocesser/general_processer.py @@ -24,10 +24,9 @@ from 
llm_on_ray.common.dataprocesser import DataProcesser INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request." -INSTRUCTION_KEY = "### Instruction:" -INPUT_KEY = "Input:" -RESPONSE_KEY = "### Response:" -END_KEY = "### End" +INSTRUCTION_KEY = "### Instruction: " +INPUT_KEY = "Input: " +RESPONSE_KEY = "### Response: " RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n" PROMPT_NO_INPUT_FORMAT = """{intro} @@ -36,15 +35,12 @@ {instruction} {response_key} -{response} - -{end_key}""".format( +{response}""".format( intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY, response="{response}", - end_key=END_KEY, ) PROMPT_WITH_INPUT_FORMAT = """{intro} @@ -56,9 +52,7 @@ {input} {response_key} -{response} - -{end_key}""".format( +{response}""".format( intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", @@ -66,10 +60,14 @@ input="{input}", response_key=RESPONSE_KEY, response="{response}", - end_key=END_KEY, ) TEXT_COLUMN_NAME = "text" +SLIMORCA_PROMPT_DICT = { + "prompt_with_input": ("### System: {system} \n" "### User: {user} \n### Assistant: {gpt}"), + "prompt_without_input": ("### System: {system} \n" "### Assistant: {gpt}"), +} + class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling): def torch_call(self, examples): @@ -98,9 +96,74 @@ def torch_call(self, examples): return batch -class GeneralProcesser(DataProcesser): +class ChatDataPreprocess(DataProcesser): + base_template = """Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n""" + + def __init__(self, config): + super().__init__(config) + self.prompt_template = self.base_template + self.user = "### Instruction:\n" + self.assistant = "### Response:\n" + self.end = "### End\n" + + def create_data(self, examples): + if self.config.get("gpt_base_model"): + instruction = examples["instruction"] + response = examples["response"] + context = examples.get("context") + if not instruction: + raise ValueError(f"Expected an instruction in: {examples}") + if not response: + raise ValueError(f"Expected a response in: {examples}") + if context: + new_messages = PROMPT_WITH_INPUT_FORMAT.format( + instruction=instruction, response=response, input=context + ) + else: + new_messages = PROMPT_NO_INPUT_FORMAT.format( + instruction=instruction, response=response + ) + else: + new_messages = [ + { + "role": "system", + "content": INTRO_BLURB + "\n", + }, + { + "role": "user", + "content": examples["instruction"] + + "\n" + + INPUT_KEY + + examples["context"] + + "\n", + }, + {"role": "assistant", "content": examples["response"] + "\n"}, + ] + + return new_messages + + def tokenize_func(self, tokenizer, message): + if self.config.get("gpt_base_model"): + return tokenizer( + message, add_special_tokens=False, max_length=self.config.get("max_length") + ) + else: + if self.config.get("chat_template") is not None: + tokenizer.chat_template = self.config.get("chat_template") + elif tokenizer.chat_template is not None: + pass + else: + tokenizer.chat_template = self.config.get("default_chat_template") + + new_tokenizer = tokenizer.apply_chat_template( + message, + tokenize=False, + ) + return tokenizer( + new_tokenizer, add_special_tokens=False, max_length=self.config.get("max_length") + ) + def tokenize_dataset(self, tokenizer, dataset): - max_length = self.config.get("max_length") group = self.config.get("group") block_size = self.config.get("block_size") tokenizer.pad_token = tokenizer.eos_token @@ 
-111,38 +174,8 @@ def tokenize_dataset(self, tokenizer, dataset): if isinstance(dataset, datasets.DatasetDict): column_names = dataset["train"].column_names - if column_names and TEXT_COLUMN_NAME not in column_names: - - def prompt(rec): - instruction = rec["instruction"] - response = rec["response"] - context = rec.get("context") - if not instruction: - raise ValueError(f"Expected an instruction in: {rec}") - if not response: - raise ValueError(f"Expected a response in: {rec}") - if context: - rec["text"] = PROMPT_WITH_INPUT_FORMAT.format( - instruction=instruction, response=response, input=context - ) - else: - rec["text"] = PROMPT_NO_INPUT_FORMAT.format( - instruction=instruction, response=response - ) - return rec - - dataset = dataset.map( - prompt, - load_from_cache_file=False, - desc="Prompt", - ) - column_names += [TEXT_COLUMN_NAME] - - def tokenize_function(examples): - return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length) - tokenized_datasets = dataset.map( - tokenize_function, + lambda examples: self.tokenize_func(tokenizer, self.create_data(examples)), remove_columns=column_names, load_from_cache_file=False, desc="Tokenize dataset", @@ -208,3 +241,94 @@ def prepare_dataloader(self, tokenizer, dataset): } eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params) return train_dataloader, eval_dataloader + + +class SlimOrcaDataPreprocess(ChatDataPreprocess): + chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '### System: ' + message['content'] }}" + "{% elif message['role'] == 'user' %}" + "{{ '### User: ' + message['content'] }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '### Assistant: ' + message['content'] }}" + "{% endif %}" + "{% endfor %}" + ) + + def __init__(self, config): + super().__init__(config) + self.config["chat_template"] = self.chat_template + self.default_system = "You are a helpful, respectful and honest assistant." 
+ + def create_data(self, data): + examples = {} + conv = data["conversations"] + # system + if conv[0]["from"] != "system": + examples["system"] = self.default_system + start = 0 + elif conv[0]["from"] == "system" and conv[0]["value"] == "": + examples[conv[0]["from"]] = self.default_system + start = 1 + else: + examples[conv[0]["from"]] = conv[0]["value"] + start = 1 + + for j in range(start, len(conv) - 1, 2): + examples[conv[j]["from"]] = conv[j]["value"] + examples[conv[j + 1]["from"]] = conv[j + 1]["value"] + + if self.config.get("gpt_base_model"): + if examples["human"]: + return SLIMORCA_PROMPT_DICT["prompt_with_input"].format( + instruction=examples["system"], + response=examples["gpt"], + input=examples["human"], + ) + else: + return SLIMORCA_PROMPT_DICT["prompt_without_input"].format( + instruction=examples["system"], response=examples["gpt"] + ) + else: + new_messages = [ + {"role": "system", "content": examples["system"] + "\n"}, + { + "role": "user", + "content": examples["system"] + "\n" + INPUT_KEY + examples["human"] + "\n", + }, + {"role": "assistant", "content": examples["gpt"] + "\n"}, + ] + return new_messages + + +class OpenOrcaDataPreprocess(ChatDataPreprocess): + def __init__(self, config): + super().__init__(config) + self.default_system = "You are an AI assistant. You will be given a task. You must generate a detailed and long answer." 
+ + def create_data(self, examples): + if self.config.get("gpt_base_model"): + if not examples["system"]: + examples["system"] = self.default_system + + if examples["question"]: + return PROMPT_WITH_INPUT_FORMAT.format( + instruction=examples["system"], + response=examples["chosen"], + input=examples["question"], + ) + else: + return PROMPT_NO_INPUT_FORMAT.format( + instruction=examples["system"], response=examples["chosen"] + ) + else: + new_messages = [ + {"role": "system", "content": INTRO_BLURB + "\n"}, + { + "role": "user", + "content": examples["system"] + "\n" + INPUT_KEY + examples["question"] + "\n", + }, + {"role": "assistant", "content": examples["chosen"] + "\n"}, + ] + return new_messages diff --git a/llm_on_ray/common/trainer/default_trainer.py b/llm_on_ray/common/trainer/default_trainer.py index 8825f08be..61d9d6015 100644 --- a/llm_on_ray/common/trainer/default_trainer.py +++ b/llm_on_ray/common/trainer/default_trainer.py @@ -33,10 +33,17 @@ class DefaultTrainer(Trainer): def __init__(self, config): self.model = None + self.tokenizer = None self.config = config dataprocesser_config = config.get("dataprocesser") dataprocesser_type = dataprocesser_config.get("type") - Factory = dataprocesser.DataProcesser.registory.get(dataprocesser_type) + if dataprocesser_type == "chat": + Factory = dataprocesser.DataProcesser.registory.get("ChatDataPreprocess") + elif dataprocesser_type == "SlimOrca": + Factory = dataprocesser.DataProcesser.registory.get("SlimOrcaDataPreprocess") + else: + raise ValueError(f"there is no {dataprocesser_type} dataprocesser.") + if Factory is None: raise ValueError(f"there is no {dataprocesser_type} dataprocesser.") self.dataprocesser = Factory(dataprocesser_config) @@ -121,7 +128,7 @@ def _get_lr_scheduler( def prepare(self, model, tokenizer, dataset, optimizer, accelerator): self._coordinate(accelerator) - + self.tokenizer = tokenizer embedding_size = model.get_input_embeddings().weight.shape[0] logger.info(f"model embedding size: 
{embedding_size}") if len(tokenizer) > embedding_size: @@ -290,6 +297,11 @@ def train(self): is_main_process=self.accelerator.is_main_process, save_function=self.accelerator.save, ) + self.tokenizer.save_pretrained( + output, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, + ) logger.info(f"finish save model to {output}") self.accelerator.wait_for_everyone() diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py index 29d955a49..a4ecd9b07 100644 --- a/llm_on_ray/finetune/finetune.py +++ b/llm_on_ray/finetune/finetune.py @@ -14,7 +14,7 @@ # limitations under the License. # -#!/usr/bin/env python +# !/usr/bin/env python import os import argparse @@ -221,7 +221,17 @@ def train_func(config: Dict[str, Any]): } ) - dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")( + dataprocesser_type = config["Dataset"]["type"] + if dataprocesser_type == "chat": + preprocesser_name = "ChatDataPreprocess" + elif dataprocesser_type == "OpenOrca": + preprocesser_name = "OpenOrcaDataPreprocess" + elif dataprocesser_type == "SlimOrca": + preprocesser_name = "SlimOrcaDataPreprocess" + else: + raise ValueError(f"there is no {dataprocesser_type} dataprocesser.") + + dataprocesser = common.dataprocesser.DataProcesser.registory.get(preprocesser_name)( config={ "per_device_train_batch_size": config["Training"]["batch_size"], "per_device_eval_batch_size": config["Training"]["batch_size"], @@ -230,6 +240,11 @@ def train_func(config: Dict[str, Any]): "group": config["Dataset"].get("group", True), "block_size": config["Dataset"].get("block_size", 512), "shuffle": config["Dataset"].get("shuffle", False), + "name": tokenizer_name, + "config": config["General"]["config"], + "gpt_base_model": config["General"].get("gpt_base_model", False), + "chat_template": config["General"]["chat_template"], + "default_chat_template": config["General"]["default_chat_template"], } ) tokenized_datasets = 
dataprocesser.tokenize_dataset(tokenizer, datasets) @@ -356,7 +371,15 @@ def main(external_config=None): ) # additional 1 for head worker ray.init(num_cpus=num_cpus, runtime_env=runtime_env) else: - ray.init(runtime_env=runtime_env) + import intel_extension_for_pytorch as ipex + + if "xpu" in ipex.__version__: + num_cpus = ( + resources_per_worker["CPU"] * num_training_workers + 1 + ) # additional 1 for head worker + ray.init(num_cpus=num_cpus, runtime_env=runtime_env) + else: + ray.init(runtime_env=runtime_env) common.logger.info(f"ray available resources = {ray.available_resources()}") use_gpu = True if device == "gpu" else False diff --git a/llm_on_ray/finetune/finetune.yaml b/llm_on_ray/finetune/finetune.yaml index 627a88753..78a9e1c57 100644 --- a/llm_on_ray/finetune/finetune.yaml +++ b/llm_on_ray/finetune/finetune.yaml @@ -13,7 +13,8 @@ General: lora_dropout: 0.1 enable_gradient_checkpointing: false Dataset: - train_file: examples/data/sample_finetune_data_small.jsonl + type: "SlimOrca" + train_file: Open-Orca/SlimOrca group: true max_length: 512 block_size: 512 diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py index 030fcc5a6..0a25ad777 100644 --- a/llm_on_ray/finetune/finetune_config.py +++ b/llm_on_ray/finetune/finetune_config.py @@ -17,7 +17,6 @@ from pydantic import BaseModel, validator from typing import Optional, List - PRECISION_BF16 = "bf16" PRECISION_FP16 = "fp16" PRECISION_NO = "no" @@ -61,9 +60,36 @@ class General(BaseModel): lora_config: Optional[LoraConfig] = None deltatuner_config: Optional[DeltatunerConfig] = None enable_gradient_checkpointing: bool = False + chat_template: Optional[str] = None + default_chat_template: str = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in 
loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message %}" + "{{ system_message }}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ '### Instruction: ' + message['content'].strip() }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '### Response: ' + message['content'].strip() }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '### Response: '}}" + "{% endif %}" + ) class Dataset(BaseModel): + type: str = "chat" train_file: str validation_file: Optional[str] validation_split_percentage: int diff --git a/pyproject.toml b/pyproject.toml index b319045cc..451d2649d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "deltatuner==1.1.9", "py-cpuinfo", "pydantic-yaml", - "async_timeout", + "async-timeout", + "jinja2>=3.0.0", "typer" ] diff --git a/tests/finetune/test_chat_template.py b/tests/finetune/test_chat_template.py new file mode 100644 index 000000000..31d0eed12 --- /dev/null +++ b/tests/finetune/test_chat_template.py @@ -0,0 +1,172 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import transformers +from transformers import AutoTokenizer +from llm_on_ray.common.dataprocesser.general_processer import ChatDataPreprocess + + +class TestTokenizeFunction(unittest.TestCase): + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + self.config = { + "gpt_base_model": True, + "max_length": 512, + "trust_remote_code": False, + "chat_template": "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message %}" + "{{ system_message }}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ '### Instruction: ' + message['content'] + eos_token }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '### Response:' + message['content'] + eos_token }}" + "{% endif %}{% endfor %}" + "{{'### End \n'}}", + } + self.processer = ChatDataPreprocess(self.config) + + def test_tokenize_function_with_gpt_model(self): + self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") + + examples = { + "instruction": "Test instruction", + "response": "Test response", + "context": "Test context", + } + + # Verify the format of the result + expected_result = ( + "Below is an instruction that describes a task. 
Write a response that " + "appropriately completes the request.\n" + "\n" + "### Instruction: \n" + "Test instruction\n" + "\n" + "Input: \n" + "Test context\n" + "\n" + "### Response: \n" + "Test response\n" + "\n" + "### End" + ) + + print(self.processer.create_data(examples)) + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples)) + print(self.tokenizer.decode(result["input_ids"])) + + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_custom_chat_template(self): + examples = { + "instruction": "Test instruction", + "response": "Test response", + "context": "Test context", + } + + # Verify the format of the result + expected_result = ( + "<|im_start|>user\n" + "Test instruction\n" + "\n" + "Input: Test context\n" + "\n" + "<|im_end|><|im_start|>assistant\n" + "Test response\n" + "\n" + "<|im_end|>" + ) + + print(expected_result) + # Set custom chat template + self.config["chat_template"] = ( + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'" + "+ message['content'] + '<|im_end|>'}}{% endfor %}" + ) + + self.config["gpt_base_model"] = False + print(self.processer.create_data(examples)) + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples)) + print(self.tokenizer.decode(result["input_ids"])) + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_default_chat_template(self): + examples = { + "instruction": "Test instruction", + "response": "Test response", + "context": "Test context", + } + + # Verify the format of the result + expected_result = ( + "Below is an instruction that describes a task. 
Write a response that " + "appropriately completes the request\n" + "### Instruction: Test instruction\n" + "\n" + "Input: Test context\n" + "\n" + "### Response: Test response\n" + "\n" + "### End \n" + ) + self.config["gpt_base_model"] = False + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples)) + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_tokenizer_chat_template(self): + self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it") + examples = { + "instruction": "Test instruction", + "response": "Test response", + "context": "Test context", + } + + chat_example = [ + { + "role": "user", + "content": "Test instruction\n\nInput: Test context\n\n", + }, + { + "role": "assistant", + "content": "Test response\n\n", + }, + ] + + # Verify the format of the result + expected_result = self.tokenizer.apply_chat_template( + chat_example, tokenize=False, max_length=self.config.get("max_length") + ) + + self.config["chat_template"] = None + self.config["gpt_base_model"] = False + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples)) + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/finetune/test_slimOrca_chat_template.py b/tests/finetune/test_slimOrca_chat_template.py new file mode 100644 index 000000000..059a316d1 --- /dev/null +++ b/tests/finetune/test_slimOrca_chat_template.py @@ -0,0 +1,128 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +import transformers +from datasets import Dataset +from transformers import AutoTokenizer +from llm_on_ray.common.dataprocesser.general_processer import ( + ChatDataPreprocess, + SlimOrcaDataPreprocess, +) + + +class TestTokenizeFunction(unittest.TestCase): + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + self.config = { + "gpt_base_model": True, + "max_length": 512, + "trust_remote_code": False, + "chat_template": "Below is an instruction that describes a task. Write a response that appropriately " + "completes the request\n {% if messages[0]['role'] == 'system' %}{{ raise_exception(" + "'System role not supported') }}{% endif %}{% for message in messages %}{% if (message[" + "'role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles " + "must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] " + "== 'user' %}{{ '### Instruction: ' + message['content'] }}{% elif message['role'] == " + "'assistant' %}{{ '### Response: ' + message['content'] }}{% endif %}{% endfor %}{{'### " + "End \n'}}", + } + self.processer = SlimOrcaDataPreprocess(self.config) + examples = { + "conversations": [ + {"from": "system", "value": "Test system", "weight": None}, + {"from": "human", "value": "Test human", "weight": 0}, + {"from": "gpt", "value": "Test gpt.", "weight": 1}, + ] + } + + self.ds = Dataset.from_dict(examples) + + def test_tokenize_function_with_gpt_model(self): + self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") + + # 
Verify the format of the result + expected_result = ( + "### System: Test system \n" "### User: Test human \n" "### Assistant: Test gpt." + ) + + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds)) + + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_custom_chat_template(self): + # Verify the format of the result + expected_result = ( + "<|im_start|>system\n" + "Test system\n" + "<|im_end|><|im_start|>user\n" + "Test human\n" + "<|im_end|><|im_start|>assistant\n" + "Test gpt.\n" + "<|im_end|>" + ) + + print(expected_result) + # Set custom chat template + self.config["chat_template"] = ( + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'" + "+ message['content'] + '<|im_end|>'}}{% endfor %}" + ) + + self.config["gpt_base_model"] = False + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds)) + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_default_chat_template(self): + # Verify the format of the result + expected_result = ( + "### System: Test system\n" "### User: Test human\n" "### Assistant: Test gpt.\n" + ) + self.config["gpt_base_model"] = False + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds)) + self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"])) + + def test_tokenize_function_with_tokenizer_chat_template(self): + self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + chat_example = [ + { + "role": "system", + "content": "Test system\n", + }, + { + "role": "user", + "content": "Test human\n", + }, + { + "role": "assistant", + "content": "Test gpt.\n", + }, + ] + + # Verify the format of the result + expected_result = self.tokenizer.apply_chat_template( + chat_example, tokenize=True, max_length=self.config.get("max_length") + ) + + 
self.config["chat_template"] = None + self.config["gpt_base_model"] = False + result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds)) + self.assertEqual(expected_result, result["input_ids"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/inference/test_query_single.py b/tests/inference/test_query_single.py new file mode 100644 index 000000000..d48727a30 --- /dev/null +++ b/tests/inference/test_query_single.py @@ -0,0 +1,123 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import subprocess +import pytest +import os + +os.environ["no_proxy"] = "localhost,127.0.0.1" + + +def start_serve(model_name): + current_path = os.path.dirname(os.path.abspath(__file__)) + + config_path = os.path.join( + current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml" + ) + + cmd_serve = ["llm_on_ray-serve", "--config_file", config_path, "--simple"] + + result_serve = subprocess.run(cmd_serve, capture_output=True, text=True) + + # Ensure there are no errors in the serve script execution + assert result_serve.returncode == 0, print( + "\n" + "Serve error stderr message: " + "\n", result_serve.stderr + ) + + # Print the output of subprocess.run for checking if output is expected + print("\n" + "Serve message: " + "\n", result_serve.stdout) + + # Ensure there are no errors in the serve script execution + assert "Error" not in result_serve.stderr + + +def script_with_args( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k +): + current_path = os.path.dirname(os.path.abspath(__file__)) + + os.path.join(current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml") + + example_query_single_path = os.path.join( + current_path, "../../examples/inference/api_server_simple/query_single.py" + ) + + cmd_single = [ + "python", + example_query_single_path, + "--model_endpoint", + base_url + model_name, + ] + + if streaming_response: + cmd_single.append("--streaming_response") + + if max_new_tokens is not None: + cmd_single.extend(["--max_new_tokens", str(max_new_tokens)]) + + if temperature is not None: + cmd_single.extend(["--temperature", str(temperature)]) + + if top_p is not None: + cmd_single.extend(["--top_p", str(top_p)]) + + if top_k is not None: + cmd_single.extend(["--top_k", str(top_k)]) + + result_query_single = subprocess.run(cmd_single, capture_output=True, text=True) + + # Print the output of subprocess.run for checking if output is expected + print(result_query_single) + + # Ensure 
there are no errors in the OpenAI API query script execution + assert "Error" not in result_query_single.stderr + + # Returncode should be 0 when there is no exception + assert result_query_single.returncode == 0 + + +executed_models = {} + + +# Parametrize the test function with different combinations of parameters +# TODO: more models and combinations will be added and tested. +@pytest.mark.parametrize( + "base_url,model_name,streaming_response,max_new_tokens,temperature,top_p, top_k", + [ + (base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k) + for base_url in ["http://localhost:8000/"] + for model_name in ["gpt2"] + for streaming_response in [None] + for max_new_tokens in [None] + for temperature in [None] + for top_p in [None] + for top_k in [None] + ], +) +def test_script( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k +): + global executed_models + + # Check if this modelname has already executed start_serve + if model_name not in executed_models: + start_serve(model_name) + # Mark this modelname has already executed start_serve + executed_models[model_name] = True + + script_with_args( + base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k + )