diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md
index 4f113e69f..cebb1449e 100644
--- a/docs/finetune_parameters.md
+++ b/docs/finetune_parameters.md
@@ -16,6 +16,7 @@ The following are the parameters supported in the finetuning workflow.
|lora_config|task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1|Will be passed to the LoraConfig `__init__()` method, then it'll be used as config to build Peft model object.|
|deltatuner_config|"algo": "lora"
"denas": True
"best_model_structure": "/path/to/best_structure_of_deltatuner_model"|Will be passed to the DeltaTunerArguments `__init__()` method, then it'll be used as config to build [Deltatuner model](https://github.com/intel/e2eAIOK/tree/main/e2eAIOK/deltatuner) object.|
|enable_gradient_checkpointing|False|enable gradient checkpointing to save GPU memory, but will cost more compute runtime|
+|chat_template|None|User-defined chat template.|
## Dataset Parameters
diff --git a/llm_on_ray/common/__init__.py b/llm_on_ray/common/__init__.py
index 0e8e821ad..e002976b6 100644
--- a/llm_on_ray/common/__init__.py
+++ b/llm_on_ray/common/__init__.py
@@ -18,4 +18,13 @@
from llm_on_ray.common.torch_config import TorchConfig
from llm_on_ray.common.config import Config
from llm_on_ray.common.init import init
-from llm_on_ray.common import agentenv, dataset, initializer, model, optimizer, tokenizer, trainer
+from llm_on_ray.common import (
+ agentenv,
+ dataset,
+ initializer,
+ model,
+ optimizer,
+ tokenizer,
+ trainer,
+ dataprocesser,
+)
diff --git a/llm_on_ray/common/dataprocesser/__init__.py b/llm_on_ray/common/dataprocesser/__init__.py
index 2b5152764..c1bf68ae8 100644
--- a/llm_on_ray/common/dataprocesser/__init__.py
+++ b/llm_on_ray/common/dataprocesser/__init__.py
@@ -15,7 +15,8 @@
#
from llm_on_ray.common.dataprocesser.dataprocesser import DataProcesser
-from llm_on_ray.common.dataprocesser.general_processer import GeneralProcesser
+from llm_on_ray.common.dataprocesser.general_processer import ChatDataPreprocess
+from llm_on_ray.common.dataprocesser.general_processer import SlimOrcaDataPreprocess
from llm_on_ray.common.dataprocesser.rm_dataprocesser import RMDataProcesser
diff --git a/llm_on_ray/common/dataprocesser/general_processer.py b/llm_on_ray/common/dataprocesser/general_processer.py
index 37235b425..c51329718 100644
--- a/llm_on_ray/common/dataprocesser/general_processer.py
+++ b/llm_on_ray/common/dataprocesser/general_processer.py
@@ -24,10 +24,9 @@
from llm_on_ray.common.dataprocesser import DataProcesser
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
-INSTRUCTION_KEY = "### Instruction:"
-INPUT_KEY = "Input:"
-RESPONSE_KEY = "### Response:"
-END_KEY = "### End"
+INSTRUCTION_KEY = "### Instruction: "
+INPUT_KEY = "Input: "
+RESPONSE_KEY = "### Response: "
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
PROMPT_NO_INPUT_FORMAT = """{intro}
@@ -36,15 +35,12 @@
{instruction}
{response_key}
-{response}
-
-{end_key}""".format(
+{response}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
response_key=RESPONSE_KEY,
response="{response}",
- end_key=END_KEY,
)
PROMPT_WITH_INPUT_FORMAT = """{intro}
@@ -56,9 +52,7 @@
{input}
{response_key}
-{response}
-
-{end_key}""".format(
+{response}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
@@ -66,10 +60,14 @@
input="{input}",
response_key=RESPONSE_KEY,
response="{response}",
- end_key=END_KEY,
)
TEXT_COLUMN_NAME = "text"
+SLIMORCA_PROMPT_DICT = {
+ "prompt_with_input": ("### System: {system} \n" "### User: {user} \n### Assistant: {gpt}"),
+ "prompt_without_input": ("### System: {system} \n" "### Assistant: {gpt}"),
+}
+
class DataCollatorForCompletionOnlyLM(transformers.DataCollatorForLanguageModeling):
def torch_call(self, examples):
@@ -98,9 +96,74 @@ def torch_call(self, examples):
return batch
-class GeneralProcesser(DataProcesser):
+class ChatDataPreprocess(DataProcesser):
+ base_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.prompt_template = self.base_template
+ self.user = "### Instruction:\n"
+ self.assistant = "### Response:\n"
+ self.end = "### End\n"
+
+ def create_data(self, examples):
+ if self.config.get("gpt_base_model"):
+ instruction = examples["instruction"]
+ response = examples["response"]
+ context = examples.get("context")
+ if not instruction:
+ raise ValueError(f"Expected an instruction in: {examples}")
+ if not response:
+ raise ValueError(f"Expected a response in: {examples}")
+ if context:
+ new_messages = PROMPT_WITH_INPUT_FORMAT.format(
+ instruction=instruction, response=response, input=context
+ )
+ else:
+ new_messages = PROMPT_NO_INPUT_FORMAT.format(
+ instruction=instruction, response=response
+ )
+ else:
+ new_messages = [
+ {
+ "role": "system",
+ "content": INTRO_BLURB + "\n",
+ },
+ {
+ "role": "user",
+                    "content": examples["instruction"]
+                    + "\n"
+                    + INPUT_KEY
+                    + examples.get("context", "")
+                    + "\n",
+ },
+ {"role": "assistant", "content": examples["response"] + "\n"},
+ ]
+
+ return new_messages
+
+ def tokenize_func(self, tokenizer, message):
+ if self.config.get("gpt_base_model"):
+ return tokenizer(
+ message, add_special_tokens=False, max_length=self.config.get("max_length")
+ )
+ else:
+ if self.config.get("chat_template") is not None:
+ tokenizer.chat_template = self.config.get("chat_template")
+ elif tokenizer.chat_template is not None:
+ pass
+ else:
+ tokenizer.chat_template = self.config.get("default_chat_template")
+
+ new_tokenizer = tokenizer.apply_chat_template(
+ message,
+ tokenize=False,
+ )
+ return tokenizer(
+ new_tokenizer, add_special_tokens=False, max_length=self.config.get("max_length")
+ )
+
def tokenize_dataset(self, tokenizer, dataset):
- max_length = self.config.get("max_length")
group = self.config.get("group")
block_size = self.config.get("block_size")
tokenizer.pad_token = tokenizer.eos_token
@@ -111,38 +174,8 @@ def tokenize_dataset(self, tokenizer, dataset):
if isinstance(dataset, datasets.DatasetDict):
column_names = dataset["train"].column_names
- if column_names and TEXT_COLUMN_NAME not in column_names:
-
- def prompt(rec):
- instruction = rec["instruction"]
- response = rec["response"]
- context = rec.get("context")
- if not instruction:
- raise ValueError(f"Expected an instruction in: {rec}")
- if not response:
- raise ValueError(f"Expected a response in: {rec}")
- if context:
- rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(
- instruction=instruction, response=response, input=context
- )
- else:
- rec["text"] = PROMPT_NO_INPUT_FORMAT.format(
- instruction=instruction, response=response
- )
- return rec
-
- dataset = dataset.map(
- prompt,
- load_from_cache_file=False,
- desc="Prompt",
- )
- column_names += [TEXT_COLUMN_NAME]
-
- def tokenize_function(examples):
- return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)
-
tokenized_datasets = dataset.map(
- tokenize_function,
+ lambda examples: self.tokenize_func(tokenizer, self.create_data(examples)),
remove_columns=column_names,
load_from_cache_file=False,
desc="Tokenize dataset",
@@ -208,3 +241,94 @@ def prepare_dataloader(self, tokenizer, dataset):
}
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params)
return train_dataloader, eval_dataloader
+
+
+class SlimOrcaDataPreprocess(ChatDataPreprocess):
+ chat_template = (
+ "{% for message in messages %}"
+ "{% if message['role'] == 'system' %}"
+ "{{ '### System: ' + message['content'] }}"
+ "{% elif message['role'] == 'user' %}"
+ "{{ '### User: ' + message['content'] }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '### Assistant: ' + message['content'] }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config["chat_template"] = self.chat_template
+ self.default_system = "You are a helpful, respectful and honest assistant."
+
+ def create_data(self, data):
+ examples = {}
+ conv = data["conversations"]
+ # system
+ if conv[0]["from"] != "system":
+ examples["system"] = self.default_system
+ start = 0
+ elif conv[0]["from"] == "system" and conv[0]["value"] == "":
+ examples[conv[0]["from"]] = self.default_system
+ start = 1
+ else:
+ examples[conv[0]["from"]] = conv[0]["value"]
+ start = 1
+
+ for j in range(start, len(conv) - 1, 2):
+ examples[conv[j]["from"]] = conv[j]["value"]
+ examples[conv[j + 1]["from"]] = conv[j + 1]["value"]
+
+ if self.config.get("gpt_base_model"):
+ if examples["human"]:
+                return SLIMORCA_PROMPT_DICT["prompt_with_input"].format(
+                    system=examples["system"],
+                    gpt=examples["gpt"],
+                    user=examples["human"],
+                )
+ else:
+                return SLIMORCA_PROMPT_DICT["prompt_without_input"].format(
+                    system=examples["system"], gpt=examples["gpt"]
+                )
+ else:
+ new_messages = [
+ {"role": "system", "content": examples["system"] + "\n"},
+ {
+ "role": "user",
+ "content": examples["system"] + "\n" + INPUT_KEY + examples["human"] + "\n",
+ },
+ {"role": "assistant", "content": examples["gpt"] + "\n"},
+ ]
+ return new_messages
+
+
+class OpenOrcaDataPreprocess(ChatDataPreprocess):
+ def __init__(self, config):
+ super().__init__(config)
+ self.default_system = "You are an AI assistant. You will be given a task. You must generate a detailed and long answer."
+
+ def create_data(self, examples):
+ if self.config.get("gpt_base_model"):
+ if not examples["system"]:
+ examples["system"] = self.default_system
+
+ if examples["question"]:
+ return PROMPT_WITH_INPUT_FORMAT.format(
+ instruction=examples["system"],
+ response=examples["chosen"],
+ input=examples["question"],
+ )
+ else:
+ return PROMPT_NO_INPUT_FORMAT.format(
+ instruction=examples["system"], response=examples["chosen"]
+ )
+ else:
+ new_messages = [
+ {"role": "system", "content": INTRO_BLURB + "\n"},
+ {
+ "role": "user",
+ "content": examples["system"] + "\n" + INPUT_KEY + examples["question"] + "\n",
+ },
+ {"role": "assistant", "content": examples["chosen"] + "\n"},
+ ]
+ return new_messages
diff --git a/llm_on_ray/common/trainer/default_trainer.py b/llm_on_ray/common/trainer/default_trainer.py
index 8825f08be..61d9d6015 100644
--- a/llm_on_ray/common/trainer/default_trainer.py
+++ b/llm_on_ray/common/trainer/default_trainer.py
@@ -33,10 +33,17 @@
class DefaultTrainer(Trainer):
def __init__(self, config):
self.model = None
+ self.tokenizer = None
self.config = config
dataprocesser_config = config.get("dataprocesser")
dataprocesser_type = dataprocesser_config.get("type")
- Factory = dataprocesser.DataProcesser.registory.get(dataprocesser_type)
+        # Map the configured type to its registered preprocesser (kept consistent with finetune.py).
+        factory_name = {"chat": "ChatDataPreprocess", "SlimOrca": "SlimOrcaDataPreprocess",
+                        "OpenOrca": "OpenOrcaDataPreprocess"}.get(dataprocesser_type)
+        if factory_name is None:
+            raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")
+        Factory = dataprocesser.DataProcesser.registory.get(factory_name)
+
if Factory is None:
raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")
self.dataprocesser = Factory(dataprocesser_config)
@@ -121,7 +128,7 @@ def _get_lr_scheduler(
def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
self._coordinate(accelerator)
-
+ self.tokenizer = tokenizer
embedding_size = model.get_input_embeddings().weight.shape[0]
logger.info(f"model embedding size: {embedding_size}")
if len(tokenizer) > embedding_size:
@@ -290,6 +297,11 @@ def train(self):
is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save,
)
+ self.tokenizer.save_pretrained(
+ output,
+ is_main_process=self.accelerator.is_main_process,
+ save_function=self.accelerator.save,
+ )
logger.info(f"finish save model to {output}")
self.accelerator.wait_for_everyone()
diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index 29d955a49..a4ecd9b07 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
-#!/usr/bin/env python
+# !/usr/bin/env python
import os
import argparse
@@ -221,7 +221,17 @@ def train_func(config: Dict[str, Any]):
}
)
- dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")(
+ dataprocesser_type = config["Dataset"]["type"]
+ if dataprocesser_type == "chat":
+ preprocesser_name = "ChatDataPreprocess"
+ elif dataprocesser_type == "OpenOrca":
+ preprocesser_name = "OpenOrcaDataPreprocess"
+ elif dataprocesser_type == "SlimOrca":
+ preprocesser_name = "SlimOrcaDataPreprocess"
+ else:
+ raise ValueError(f"there is no {dataprocesser_type} dataprocesser.")
+
+ dataprocesser = common.dataprocesser.DataProcesser.registory.get(preprocesser_name)(
config={
"per_device_train_batch_size": config["Training"]["batch_size"],
"per_device_eval_batch_size": config["Training"]["batch_size"],
@@ -230,6 +240,11 @@ def train_func(config: Dict[str, Any]):
"group": config["Dataset"].get("group", True),
"block_size": config["Dataset"].get("block_size", 512),
"shuffle": config["Dataset"].get("shuffle", False),
+ "name": tokenizer_name,
+ "config": config["General"]["config"],
+ "gpt_base_model": config["General"].get("gpt_base_model", False),
+ "chat_template": config["General"]["chat_template"],
+ "default_chat_template": config["General"]["default_chat_template"],
}
)
tokenized_datasets = dataprocesser.tokenize_dataset(tokenizer, datasets)
@@ -356,7 +371,15 @@ def main(external_config=None):
) # additional 1 for head worker
ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
else:
- ray.init(runtime_env=runtime_env)
+ import intel_extension_for_pytorch as ipex
+
+ if "xpu" in ipex.__version__:
+ num_cpus = (
+ resources_per_worker["CPU"] * num_training_workers + 1
+ ) # additional 1 for head worker
+ ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
+ else:
+ ray.init(runtime_env=runtime_env)
common.logger.info(f"ray available resources = {ray.available_resources()}")
use_gpu = True if device == "gpu" else False
diff --git a/llm_on_ray/finetune/finetune.yaml b/llm_on_ray/finetune/finetune.yaml
index 627a88753..78a9e1c57 100644
--- a/llm_on_ray/finetune/finetune.yaml
+++ b/llm_on_ray/finetune/finetune.yaml
@@ -13,7 +13,8 @@ General:
lora_dropout: 0.1
enable_gradient_checkpointing: false
Dataset:
- train_file: examples/data/sample_finetune_data_small.jsonl
+ type: "SlimOrca"
+ train_file: Open-Orca/SlimOrca
group: true
max_length: 512
block_size: 512
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index 030fcc5a6..0a25ad777 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -17,7 +17,6 @@
from pydantic import BaseModel, validator
from typing import Optional, List
-
PRECISION_BF16 = "bf16"
PRECISION_FP16 = "fp16"
PRECISION_NO = "no"
@@ -61,9 +60,36 @@ class General(BaseModel):
lora_config: Optional[LoraConfig] = None
deltatuner_config: Optional[DeltatunerConfig] = None
enable_gradient_checkpointing: bool = False
+ chat_template: Optional[str] = None
+ default_chat_template: str = (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message %}"
+ "{{ system_message }}"
+        "{% endif %}"
+        "{% if message['role'] == 'user' %}"
+ "{{ '### Instruction: ' + message['content'].strip() }}"
+ "{% elif message['role'] == 'assistant' %}"
+        "{{ '### Response: ' + message['content'].strip() }}"
+ "{% endif %}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}"
+ "{{ '### Response: '}}"
+ "{% endif %}"
+ )
class Dataset(BaseModel):
+ type: str = "chat"
train_file: str
validation_file: Optional[str]
validation_split_percentage: int
diff --git a/pyproject.toml b/pyproject.toml
index b319045cc..451d2649d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,8 @@ dependencies = [
"deltatuner==1.1.9",
"py-cpuinfo",
"pydantic-yaml",
- "async_timeout",
+ "async-timeout",
+ "jinja2>=3.0.0",
"typer"
]
diff --git a/tests/finetune/test_chat_template.py b/tests/finetune/test_chat_template.py
new file mode 100644
index 000000000..31d0eed12
--- /dev/null
+++ b/tests/finetune/test_chat_template.py
@@ -0,0 +1,172 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import transformers
+from transformers import AutoTokenizer
+from llm_on_ray.common.dataprocesser.general_processer import ChatDataPreprocess
+
+
+class TestTokenizeFunction(unittest.TestCase):
+ def setUp(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+ self.config = {
+ "gpt_base_model": True,
+ "max_length": 512,
+ "trust_remote_code": False,
+ "chat_template": "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message %}"
+ "{{ system_message }}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ '### Instruction: ' + message['content'] + eos_token }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '### Response:' + message['content'] + eos_token }}"
+ "{% endif %}{% endfor %}"
+ "{{'### End \n'}}",
+ }
+ self.processer = ChatDataPreprocess(self.config)
+
+ def test_tokenize_function_with_gpt_model(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request.\n"
+ "\n"
+ "### Instruction: \n"
+ "Test instruction\n"
+ "\n"
+ "Input: \n"
+ "Test context\n"
+ "\n"
+ "### Response: \n"
+ "Test response\n"
+ "\n"
+ "### End"
+ )
+
+ print(self.processer.create_data(examples))
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples))
+ print(self.tokenizer.decode(result["input_ids"]))
+
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_custom_chat_template(self):
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "<|im_start|>user\n"
+ "Test instruction\n"
+ "\n"
+ "Input: Test context\n"
+ "\n"
+ "<|im_end|><|im_start|>assistant\n"
+ "Test response\n"
+ "\n"
+ "<|im_end|>"
+ )
+
+ print(expected_result)
+ # Set custom chat template
+ self.config["chat_template"] = (
+ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
+ "+ message['content'] + '<|im_end|>'}}{% endfor %}"
+ )
+
+ self.config["gpt_base_model"] = False
+ print(self.processer.create_data(examples))
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples))
+ print(self.tokenizer.decode(result["input_ids"]))
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_default_chat_template(self):
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request\n"
+ "### Instruction: Test instruction\n"
+ "\n"
+ "Input: Test context\n"
+ "\n"
+ "### Response: Test response\n"
+ "\n"
+ "### End \n"
+ )
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples))
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_tokenizer_chat_template(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ chat_example = [
+ {
+ "role": "user",
+ "content": "Test instruction\n\nInput: Test context\n\n",
+ },
+ {
+ "role": "assistant",
+ "content": "Test response\n\n",
+ },
+ ]
+
+ # Verify the format of the result
+ expected_result = self.tokenizer.apply_chat_template(
+ chat_example, tokenize=False, max_length=self.config.get("max_length")
+ )
+
+ self.config["chat_template"] = None
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(examples))
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/finetune/test_slimOrca_chat_template.py b/tests/finetune/test_slimOrca_chat_template.py
new file mode 100644
index 000000000..059a316d1
--- /dev/null
+++ b/tests/finetune/test_slimOrca_chat_template.py
@@ -0,0 +1,128 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import transformers
+from datasets import Dataset
+from transformers import AutoTokenizer
+from llm_on_ray.common.dataprocesser.general_processer import (
+ ChatDataPreprocess,
+ SlimOrcaDataPreprocess,
+)
+
+
+class TestTokenizeFunction(unittest.TestCase):
+ def setUp(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+ self.config = {
+ "gpt_base_model": True,
+ "max_length": 512,
+ "trust_remote_code": False,
+ "chat_template": "Below is an instruction that describes a task. Write a response that appropriately "
+ "completes the request\n {% if messages[0]['role'] == 'system' %}{{ raise_exception("
+ "'System role not supported') }}{% endif %}{% for message in messages %}{% if (message["
+ "'role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles "
+ "must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] "
+ "== 'user' %}{{ '### Instruction: ' + message['content'] }}{% elif message['role'] == "
+ "'assistant' %}{{ '### Response: ' + message['content'] }}{% endif %}{% endfor %}{{'### "
+ "End \n'}}",
+ }
+ self.processer = SlimOrcaDataPreprocess(self.config)
+ examples = {
+ "conversations": [
+ {"from": "system", "value": "Test system", "weight": None},
+ {"from": "human", "value": "Test human", "weight": 0},
+ {"from": "gpt", "value": "Test gpt.", "weight": 1},
+ ]
+ }
+
+ self.ds = Dataset.from_dict(examples)
+
+ def test_tokenize_function_with_gpt_model(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+
+ # Verify the format of the result
+ expected_result = (
+ "### System: Test system \n" "### User: Test human \n" "### Assistant: Test gpt."
+ )
+
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds))
+
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_custom_chat_template(self):
+ # Verify the format of the result
+ expected_result = (
+ "<|im_start|>system\n"
+ "Test system\n"
+ "<|im_end|><|im_start|>user\n"
+ "Test human\n"
+ "<|im_end|><|im_start|>assistant\n"
+ "Test gpt.\n"
+ "<|im_end|>"
+ )
+
+ print(expected_result)
+ # Set custom chat template
+ self.config["chat_template"] = (
+ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
+ "+ message['content'] + '<|im_end|>'}}{% endfor %}"
+ )
+
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds))
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_default_chat_template(self):
+ # Verify the format of the result
+ expected_result = (
+ "### System: Test system\n" "### User: Test human\n" "### Assistant: Test gpt.\n"
+ )
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds))
+ self.assertEqual(expected_result, self.tokenizer.decode(result["input_ids"]))
+
+ def test_tokenize_function_with_tokenizer_chat_template(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+ chat_example = [
+ {
+ "role": "system",
+ "content": "Test system\n",
+ },
+ {
+ "role": "user",
+ "content": "Test human\n",
+ },
+ {
+ "role": "assistant",
+ "content": "Test gpt.\n",
+ },
+ ]
+
+ # Verify the format of the result
+ expected_result = self.tokenizer.apply_chat_template(
+ chat_example, tokenize=True, max_length=self.config.get("max_length")
+ )
+
+ self.config["chat_template"] = None
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_func(self.tokenizer, self.processer.create_data(self.ds))
+ self.assertEqual(expected_result, result["input_ids"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/inference/test_query_single.py b/tests/inference/test_query_single.py
new file mode 100644
index 000000000..d48727a30
--- /dev/null
+++ b/tests/inference/test_query_single.py
@@ -0,0 +1,123 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import subprocess
+import pytest
+import os
+
+os.environ["no_proxy"] = "localhost,127.0.0.1"
+
+
+def start_serve(model_name):
+ current_path = os.path.dirname(os.path.abspath(__file__))
+
+ config_path = os.path.join(
+ current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml"
+ )
+
+ cmd_serve = ["llm_on_ray-serve", "--config_file", config_path, "--simple"]
+
+ result_serve = subprocess.run(cmd_serve, capture_output=True, text=True)
+
+ # Ensure there are no errors in the serve script execution
+ assert result_serve.returncode == 0, print(
+ "\n" + "Serve error stderr message: " + "\n", result_serve.stderr
+ )
+
+ # Print the output of subprocess.run for checking if output is expected
+ print("\n" + "Serve message: " + "\n", result_serve.stdout)
+
+ # Ensure there are no errors in the serve script execution
+ assert "Error" not in result_serve.stderr
+
+
+def script_with_args(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+):
+ current_path = os.path.dirname(os.path.abspath(__file__))
+
+ os.path.join(current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml")
+
+ example_query_single_path = os.path.join(
+ current_path, "../../examples/inference/api_server_simple/query_single.py"
+ )
+
+ cmd_single = [
+ "python",
+ example_query_single_path,
+ "--model_endpoint",
+ base_url + model_name,
+ ]
+
+ if streaming_response:
+ cmd_single.append("--streaming_response")
+
+ if max_new_tokens is not None:
+ cmd_single.extend(["--max_new_tokens", str(max_new_tokens)])
+
+ if temperature is not None:
+ cmd_single.extend(["--temperature", str(temperature)])
+
+ if top_p is not None:
+ cmd_single.extend(["--top_p", str(top_p)])
+
+ if top_k is not None:
+ cmd_single.extend(["--top_k", str(top_k)])
+
+ result_query_single = subprocess.run(cmd_single, capture_output=True, text=True)
+
+ # Print the output of subprocess.run for checking if output is expected
+ print(result_query_single)
+
+ # Ensure there are no errors in the OpenAI API query script execution
+ assert "Error" not in result_query_single.stderr
+
+ # Returncode should be 0 when there is no exception
+ assert result_query_single.returncode == 0
+
+
+executed_models = {}
+
+
+# Parametrize the test function with different combinations of parameters
+# TODO: more models and combinations will be added and tested.
+@pytest.mark.parametrize(
+ "base_url,model_name,streaming_response,max_new_tokens,temperature,top_p, top_k",
+ [
+ (base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k)
+ for base_url in ["http://localhost:8000/"]
+ for model_name in ["gpt2"]
+ for streaming_response in [None]
+ for max_new_tokens in [None]
+ for temperature in [None]
+ for top_p in [None]
+ for top_k in [None]
+ ],
+)
+def test_script(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+):
+ global executed_models
+
+ # Check if this modelname has already executed start_serve
+ if model_name not in executed_models:
+ start_serve(model_name)
+ # Mark this modelname has already executed start_serve
+ executed_models[model_name] = True
+
+ script_with_args(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+ )