From b144be2f7ece9aaa4adb42f3970e9e624c75f87f Mon Sep 17 00:00:00 2001
From: Shubham Ugare
Date: Sat, 7 Dec 2024 18:57:06 -0600
Subject: [PATCH] For instruct mode, the input can be a list of user and system messages

---
 README.md                              | 17 ++++-
 notebooks/example.ipynb                |  4 +-
 notebooks/example_instruct_model.ipynb | 99 ++++++++++++++++++++++++++
 syncode/evaluation/code_eval.py        |  2 +-
 syncode/infer.py                       | 32 +++------
 syncode/language_model.py              | 71 +++++++++---------
 tests/test_language_model.py           |  4 +-
 7 files changed, 162 insertions(+), 67 deletions(-)
 create mode 100644 notebooks/example_instruct_model.ipynb

diff --git a/README.md b/README.md
index cf49475d..263a3e1a 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,20 @@ print(f"SynCode output:\n{output}")
 
 Check more examples of using Python, Go, and other grammars in Notebooks and a quick example at   [](https://colab.research.google.com/drive/1rYm8iehx_qYdtgWmqLkmhIjizhUVTb9E?usp=sharing)
 
+#### Instruct-tuned Models
+SynCode can also be used with Instruct-tuned models. For example, the following code snippet shows how to use SynCode with an Instruct-tuned LLM.
+
+``` python
+messages = [
+    {"role": "system", "content": "You are a chatbot who always returns a JSON object."},
+    {"role": "user", "content": "can you give me a JSON object describing University of Illinois at Urbana-Champaign?"},
+]
+
+out = syn_llm.infer(messages)
+```
+
+ See the [notebook](./notebooks/example_instruct_model.ipynb) for an example.
+
 ### Environment Variables
 Optionally, you can set the directories for cache by exporting the following environment variables. Add the following lines to your .bashrc or .zshrc file:
 ```
@@ -179,8 +193,6 @@ export HF_ACCESS_TOKEN="your_huggingface_api_key"
 - `dataset` (str, optional): Dataset. Defaults to "input". "input" indicates that the user can provide input via CLI or by passing in a prompt as a string.
 
 - `num_few_shot` (int, optional): Number of examples for few-shot prompting. Defaults to 0.
-
-- `chat_mode` (bool, optional): True if using a Chat/Instruct LLM. False otherwise. Defaults to False.
 
 - `dev_mode` (bool, optional): Development mode where we do not fail silently with parser errors. Defaults to False.
 
@@ -217,7 +229,6 @@ python3 syncode/infer.py --dataset [mbxp, humaneval, mathqa-x, input] --few_shot [True, False] --num_fs_examples [num_fs_examples] - --chat_mode [True, False] --dev_mode [True, False] --log_level [0, 1, 2] --new_mask_store [True, False] diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index bcd2c419..1225c22a 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -79,10 +79,10 @@ "model_name = \"microsoft/phi-2\"\n", "\n", "# Load the unconstrained original model\n", - "llm = Syncode(model = model_name, mode='original', max_new_tokens=20, chat_mode=True)\n", + "llm = Syncode(model = model_name, mode='original', max_new_tokens=20)\n", "\n", "# Load the Syncode augmented model\n", - "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar=grammar, chat_mode=True)" + "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar=grammar)" ] }, { diff --git a/notebooks/example_instruct_model.ipynb b/notebooks/example_instruct_model.ipynb new file mode 100644 index 00000000..37ba2ec1 --- /dev/null +++ b/notebooks/example_instruct_model.ipynb @@ -0,0 +1,99 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 12.69it/s]\n" + ] + } + ], + "source": [ + "from syncode import Syncode\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n", + "\n", + "# Load the Syncode augmented model\n", + "syn_llm = Syncode(model=model_name, grammar='json', max_new_tokens=400)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ \"name\": \"University of Illinois at Urbana-Champaign\", \n", + "\"location\": \"Champaign, Illinois, USA\", \n", + "\"established\": 1867, \n", + "\"type\": \"Public research university\", \n", + "\"enrollment\": 50000, \n", + "\"campus\": \"Urban\", \n", + "\"mascot\": \"Chief Illiniwek\", \n", + "\"athletic_conference\": \"Big Ten Conference\", \n", + "\"notable_alumni\": [\"Ray Kroc\", \"Lee DeForest\", \"Napoleon Hill\", \"John Bardeen\", \"Roger Ebert\"], \n", + "\"reputation\": \"Highly ranked in the US News & World Report's Best Colleges\", \n", + "\"programs\": [\"Engineering\", \"Business\", \"Agriculture\", \"Computer Science\", \"Education\"], \n", + "\"research_areas\": [\"Materials Science\", \"Biotechnology\", \"Energy\", \"Aerospace Engineering\", \"Computer Science\"], \n", + "\"campus_size\": \"1,783 acres\", \n", + "\"student_to_faculty_ratio\": \"18:1\", \n", + "\"acceptance_rate\": \"64.4%\", \n", + "\"tuition\": \"$15,622 (in-state), $29,444 (out-of-state)\" } \n", + "\n", + " \n" + ] + } + ], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a chatbot who always returns a JSON object.\"},\n", + " {\"role\": \"user\", \"content\": \"can you give me a JSON object describing University of Illinois at Urbana-Champaign?\"},\n", + "]\n", + "\n", + "out = syn_llm.infer(messages)\n", + "print(out[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "codex", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/syncode/evaluation/code_eval.py b/syncode/evaluation/code_eval.py index bd2d1fc0..21396c8d 100644 --- a/syncode/evaluation/code_eval.py +++ b/syncode/evaluation/code_eval.py @@ -68,7 +68,7 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp else: prompt = problems[task_id]["prompt"] - batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True) + batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True) batch_size = len(batch_completions) all_completions = [] diff --git a/syncode/infer.py b/syncode/infer.py index d2090974..e970e759 100644 --- a/syncode/infer.py +++ b/syncode/infer.py @@ -5,7 +5,7 @@ import torch from syncode.language_model import HuggingFaceModel from syncode.grammar_decoder import SyncodeLogitsProcessor -from typing import Optional, Literal +from typing import Optional, Literal, Union from syncode.parsers.grammars import Grammar from syncode.dataset import Dataset from syncode.evaluation.code_eval import CodeEval @@ -15,9 +15,9 @@ from syncode.evaluation.fol_eval import FOLEval -def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs): +def 
compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs): - syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs) + syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs) if dataset == "input": syncode.infer(debug=debug) @@ -54,8 +54,6 @@ class Syncode: parse_output_only (bool, optional): Parse only the output. Defaults to True. new_mask_store (bool, optional): Use new DFA mask store. Defaults to False. - - chat_mode (bool, optional): Parse only the (output) and not (prompt+output) in chat mode. Defaults to False. dev_mode (bool, optional): Development mode. Defaults to False. @@ -70,7 +68,6 @@ def __init__( quantize: bool = True, device: str = "cuda", grammar: Optional[str] = None, - chat_mode: bool = False, parse_output_only: bool = True, dev_mode: bool = False, log_level: int = 1, @@ -94,17 +91,13 @@ def __init__( self.num_samples = kwargs.get('num_return_sequences', 1) self.new_mask_store = new_mask_store self.parser = parser - self.chat_mode = chat_mode self.log_level = log_level # Set seed if seed is not None: torch.manual_seed(seed) - if self.chat_mode: - self.parse_output_only = True - else: - self.parse_output_only = parse_output_only + self.parse_output_only = parse_output_only # Set the grammar self.language = grammar @@ -196,7 +189,7 @@ def evaluate( logger.close() return output - def user_input(self, prompt:str, stop_words=None, debug=False): + def user_input(self, prompt:Union[str, list], stop_words=None, debug=False): """ Run user input on the model with grammar mask @@ -205,14 +198,11 @@ def user_input(self, prompt:str, stop_words=None, debug=False): stop_words (list, optional): Stop words to use. Defaults to None. debug (bool, optional): Debug mode. Defaults to False. 
""" - if prompt: - if self.grammar_decoder is not None: # TODO: Remove this check - self.grammar_decoder.reset(prompt) - - if self.chat_mode: - return self.model.generate_chat_completion_grammar(prompt) - else: - return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words, debug=debug) + if prompt: + if isinstance(prompt, list): + assert self.parse_output_only == True, "Prompt must be a string for input+output parsing" + + return self.model.generate_grammar_constrained_completion(prompt, self.num_samples, stop_words=stop_words, debug=debug) else: while True: prompt = input('Enter prompt: ') @@ -220,7 +210,7 @@ def user_input(self, prompt:str, stop_words=None, debug=False): if prompt == "exit": break - batch_completions = self.model.generate_batch_completion_grammar(prompt, self.num_samples) + batch_completions = self.model.generate_grammar_constrained_completion(prompt, self.num_samples) for i, completion in enumerate(batch_completions): print(prompt + completion) diff --git a/syncode/language_model.py b/syncode/language_model.py index ec3be393..fd282806 100644 --- a/syncode/language_model.py +++ b/syncode/language_model.py @@ -1,3 +1,4 @@ +from ast import Tuple import time import torch import syncode.common as common @@ -5,7 +6,7 @@ from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria from syncode.parsers.grammars import Grammar from syncode.utils.generation import filter_code, fix_indents -from typing import Callable, Iterable +from typing import Callable, Iterable, Union from transformers.generation.utils import GenerationMode from transformers.generation.configuration_utils import GenerationConfig @@ -67,9 +68,9 @@ def get_grammar_decoder(self): return None @torch.inference_mode() - def generate_batch_completion_grammar( + def generate_grammar_constrained_completion( self, - prompt, + prompt: Union[str, list], batch_size, stop_words=None, return_token_ids=False, @@ -84,13 +85,13 @@ def generate_batch_completion_grammar( stop_words (list): A list of stop words. If the completion ends with any of the stop words, the completion is returned. return_token_ids (bool): If True, returns the token ids of the completions. debug (bool): If True, prints debug information. - ''' + ''' + inputs, prompt_str = self.get_tokenized_input(prompt, batch_size) + # Reset the grammar decoder if self.grammar_decoder is not None: - self.grammar_decoder.reset(prompt) + self.grammar_decoder.reset(prompt_str) - input_batch = [prompt for _ in range(batch_size)] - inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device) input_ids_cutoff = inputs.input_ids.size(dim=1) # Get the generation config @@ -134,6 +135,8 @@ def generate_batch_completion_grammar( # Total tokens generated self.total_tokens = 0 + + # Decode the completions for i in range(batch_size): self.total_tokens += len(generated_ids[i])-input_ids_cutoff+1 completion = self.tokenizer.decode(generated_ids[i][input_ids_cutoff:len(generated_ids[i])], skip_special_tokens=True) @@ -145,6 +148,29 @@ def generate_batch_completion_grammar( return batch_completions + + def get_tokenized_input(self, prompt: Union[str, list], batch_size: int): + """ + Tokenizes the input prompt and returns the input dictionary for the model. 
+ """ + if isinstance(prompt, list): + assert 'apply_chat_template' in dir(self.tokenizer), "Tokenizer does not support chat completion" + prompt_str = self.tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + return_tensors="pt", + tokenize=False + ) + elif isinstance(prompt, str): + prompt_str = prompt + else: + raise ValueError("Prompt should be either a string or a list! It is currently of type: "+str(type(prompt))) + + input_batch = [prompt_str for _ in range(batch_size)] + inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device) + + return inputs, prompt_str + @torch.inference_mode() def _generate( self, @@ -252,37 +278,6 @@ def _get_generation_mode( else: generation_mode = GenerationMode.BEAM_SEARCH return generation_mode - - @torch.inference_mode() - def generate_chat_completion_grammar(self, prompt) -> str: - ''' - Generates chat completion for the given prompt. - ''' - assert 'apply_chat_template' in dir(self.tokenizer), "Tokenizer does not support chat completion" - - message = [{"role": "user", "content": prompt}] - inputs = self.tokenizer.apply_chat_template( - message, - add_generation_prompt=True, - return_tensors="pt" - ).to(self.model.device) - - input_ids_cutoff = inputs.size(dim=1) - - # input_ids_cutoff = inputs.input_ids.size(dim=1) - start_time = time.time() - - generated_ids = self.model.generate( - inputs, - logits_processor=self.grammar_processor, - **self.gen_args - ) - completion = self.tokenizer.decode( - generated_ids[0][input_ids_cutoff:len(generated_ids[0])], - skip_special_tokens=True) - - return completion - def tokenize(self, s: str) -> 'Iterable[int]': return self.tokenizer.encode(s, add_special_tokens=False) diff --git a/tests/test_language_model.py b/tests/test_language_model.py index c1986401..c77c0be3 100644 --- a/tests/test_language_model.py +++ b/tests/test_language_model.py @@ -62,7 +62,7 @@ def test_generate_batch_completion_grammar(self): logger = common.EmptyLogger() lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu') prompt = "113 + 235 + 17" - output = lm.generate_batch_completion_grammar(prompt, 1) + output = lm.generate_grammar_constrained_completion(prompt, 1) self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.") def test_generate_batch_completion_grammar2(self): @@ -72,7 +72,7 @@ def test_generate_batch_completion_grammar2(self): logger = common.EmptyLogger() lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu') prompt = "113 + 235 + 17" - output = lm.generate_batch_completion_grammar(prompt, 2) + output = lm.generate_grammar_constrained_completion(prompt, 2) self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.") self.assertEqual(len(output[1]), 15, "The output length does not match the expected value.")
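
Below is a minimal standalone sketch (not part of the patch itself) of the chat-template flow that the new `get_tokenized_input` relies on: a list of role-tagged messages is flattened into a single prompt string with `apply_chat_template(..., tokenize=False)`, so the grammar decoder can be reset with that string and the prompt can then be batched and tokenized exactly like a plain-text prompt. The checkpoint name and batch size here are illustrative.

``` python
from transformers import AutoTokenizer

# Illustrative instruct-tuned checkpoint; any tokenizer with a chat template works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a chatbot who always returns a JSON object."},
    {"role": "user", "content": "Describe the University of Illinois at Urbana-Champaign."},
]

# Flatten the chat messages into the model's prompt format (a plain string).
prompt_str = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

# From here on, the string follows the usual path: replicate it for the batch and
# tokenize it, as generate_grammar_constrained_completion does for string prompts.
batch_size = 2
inputs = tokenizer([prompt_str] * batch_size, return_tensors="pt")
print(inputs.input_ids.shape)
```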