diff --git a/README.md b/README.md
index cf49475d..263a3e1a 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,20 @@ print(f"SynCode output:\n{output}")
Check more examples of using Python, Go, and other grammars in Notebooks and a quick example at
[
](https://colab.research.google.com/drive/1rYm8iehx_qYdtgWmqLkmhIjizhUVTb9E?usp=sharing)
+#### Instruct-tuned Models
+SynCode can also be used with Instruct-tuned models. Instead of a plain string prompt, you can pass a list of chat messages, as in the following snippet:
+
+``` python
+messages = [
+ {"role": "system", "content": "You are a chatbot who always returns a JSON object."},
+ {"role": "user", "content": "can you give me a JSON object describing University of Illinois at Urbana-Champaign?"},
+]
+
+out = syn_llm.infer(messages)
+```
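+
+Here, `syn_llm` is a SynCode instance wrapping an Instruct-tuned model. A minimal sketch of how it can be created, mirroring the accompanying notebook (the model name below is only an example; any HuggingFace model that provides a chat template should work):
+
+``` python
+from syncode import Syncode
+
+# Load the SynCode-augmented model with a JSON grammar (model name taken from the example notebook)
+syn_llm = Syncode(model="meta-llama/Meta-Llama-3-8B-Instruct", grammar='json', max_new_tokens=400)
+```
+
+`infer` returns a list of completions, so `out[0]` holds the generated JSON string.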
+
+See the [notebook](./notebooks/example_instruct_model.ipynb) for a complete example.
+
### Environment Variables
Optionally, you can set the directories for cache by exporting the following environment variables. Add the following lines to your .bashrc or .zshrc file:
```
@@ -179,8 +193,6 @@ export HF_ACCESS_TOKEN="your_huggingface_api_key"
- `dataset` (str, optional): Dataset. Defaults to "input". "input" indicates that the user can provide input via CLI or by passing in a prompt as a string.
- `num_few_shot` (int, optional): Number of examples for few-shot prompting. Defaults to 0.
-
-- `chat_mode` (bool, optional): True if using a Chat/Instruct LLM. False otherwise. Defaults to False.
- `dev_mode` (bool, optional): Development mode where we do not fail silently with parser errors. Defaults to False.
@@ -217,7 +229,6 @@ python3 syncode/infer.py
--dataset [mbxp, humaneval, mathqa-x, input]
--few_shot [True, False]
--num_fs_examples [num_fs_examples]
- --chat_mode [True, False]
--dev_mode [True, False]
--log_level [0, 1, 2]
--new_mask_store [True, False]
diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb
index bcd2c419..1225c22a 100644
--- a/notebooks/example.ipynb
+++ b/notebooks/example.ipynb
@@ -79,10 +79,10 @@
"model_name = \"microsoft/phi-2\"\n",
"\n",
"# Load the unconstrained original model\n",
- "llm = Syncode(model = model_name, mode='original', max_new_tokens=20, chat_mode=True)\n",
+ "llm = Syncode(model = model_name, mode='original', max_new_tokens=20)\n",
"\n",
"# Load the Syncode augmented model\n",
- "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar=grammar, chat_mode=True)"
+ "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar=grammar)"
]
},
{
diff --git a/notebooks/example_instruct_model.ipynb b/notebooks/example_instruct_model.ipynb
new file mode 100644
index 00000000..37ba2ec1
--- /dev/null
+++ b/notebooks/example_instruct_model.ipynb
@@ -0,0 +1,99 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 12.69it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from syncode import Syncode\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+ "\n",
+ "# Load the Syncode augmented model\n",
+ "syn_llm = Syncode(model=model_name, grammar='json', max_new_tokens=400)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+ "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{ \"name\": \"University of Illinois at Urbana-Champaign\", \n",
+ "\"location\": \"Champaign, Illinois, USA\", \n",
+ "\"established\": 1867, \n",
+ "\"type\": \"Public research university\", \n",
+ "\"enrollment\": 50000, \n",
+ "\"campus\": \"Urban\", \n",
+ "\"mascot\": \"Chief Illiniwek\", \n",
+ "\"athletic_conference\": \"Big Ten Conference\", \n",
+ "\"notable_alumni\": [\"Ray Kroc\", \"Lee DeForest\", \"Napoleon Hill\", \"John Bardeen\", \"Roger Ebert\"], \n",
+ "\"reputation\": \"Highly ranked in the US News & World Report's Best Colleges\", \n",
+ "\"programs\": [\"Engineering\", \"Business\", \"Agriculture\", \"Computer Science\", \"Education\"], \n",
+ "\"research_areas\": [\"Materials Science\", \"Biotechnology\", \"Energy\", \"Aerospace Engineering\", \"Computer Science\"], \n",
+ "\"campus_size\": \"1,783 acres\", \n",
+ "\"student_to_faculty_ratio\": \"18:1\", \n",
+ "\"acceptance_rate\": \"64.4%\", \n",
+ "\"tuition\": \"$15,622 (in-state), $29,444 (out-of-state)\" } \n",
+ "\n",
+ " \n"
+ ]
+ }
+ ],
+ "source": [
+ "messages = [\n",
+ " {\"role\": \"system\", \"content\": \"You are a chatbot who always returns a JSON object.\"},\n",
+ " {\"role\": \"user\", \"content\": \"can you give me a JSON object describing University of Illinois at Urbana-Champaign?\"},\n",
+ "]\n",
+ "\n",
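+    "# infer() takes a list of chat messages and returns a list of grammar-constrained completions\n",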
+ "out = syn_llm.infer(messages)\n",
+ "print(out[0])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "codex",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/syncode/evaluation/code_eval.py b/syncode/evaluation/code_eval.py
index bd2d1fc0..21396c8d 100644
--- a/syncode/evaluation/code_eval.py
+++ b/syncode/evaluation/code_eval.py
@@ -68,7 +68,7 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
else:
prompt = problems[task_id]["prompt"]
- batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True)
+ batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True)
batch_size = len(batch_completions)
all_completions = []
diff --git a/syncode/infer.py b/syncode/infer.py
index d2090974..e970e759 100644
--- a/syncode/infer.py
+++ b/syncode/infer.py
@@ -5,7 +5,7 @@
import torch
from syncode.language_model import HuggingFaceModel
from syncode.grammar_decoder import SyncodeLogitsProcessor
-from typing import Optional, Literal
+from typing import Optional, Literal, Union
from syncode.parsers.grammars import Grammar
from syncode.dataset import Dataset
from syncode.evaluation.code_eval import CodeEval
@@ -15,9 +15,9 @@
from syncode.evaluation.fol_eval import FOLEval
-def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs):
+def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", grammar=None, dataset="input", num_few_shot=0, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", num_tasks=None, task_id=None, seed=None, opp=True, debug=False, **kwargs):
- syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs)
+ syncode = Syncode(model, mode=mode, quantize=quantize, device=device, grammar=grammar, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, opp=opp, **kwargs)
if dataset == "input":
syncode.infer(debug=debug)
@@ -54,8 +54,6 @@ class Syncode:
parse_output_only (bool, optional): Parse only the output. Defaults to True.
new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
-
- chat_mode (bool, optional): Parse only the (output) and not (prompt+output) in chat mode. Defaults to False.
dev_mode (bool, optional): Development mode. Defaults to False.
@@ -70,7 +68,6 @@ def __init__(
quantize: bool = True,
device: str = "cuda",
grammar: Optional[str] = None,
- chat_mode: bool = False,
parse_output_only: bool = True,
dev_mode: bool = False,
log_level: int = 1,
@@ -94,17 +91,13 @@ def __init__(
self.num_samples = kwargs.get('num_return_sequences', 1)
self.new_mask_store = new_mask_store
self.parser = parser
- self.chat_mode = chat_mode
self.log_level = log_level
# Set seed
if seed is not None:
torch.manual_seed(seed)
- if self.chat_mode:
- self.parse_output_only = True
- else:
- self.parse_output_only = parse_output_only
+ self.parse_output_only = parse_output_only
# Set the grammar
self.language = grammar
@@ -196,7 +189,7 @@ def evaluate(
logger.close()
return output
- def user_input(self, prompt:str, stop_words=None, debug=False):
+ def user_input(self, prompt:Union[str, list], stop_words=None, debug=False):
"""
Run user input on the model with grammar mask
@@ -205,14 +198,11 @@ def user_input(self, prompt:str, stop_words=None, debug=False):
stop_words (list, optional): Stop words to use. Defaults to None.
debug (bool, optional): Debug mode. Defaults to False.
"""
- if prompt:
- if self.grammar_decoder is not None: # TODO: Remove this check
- self.grammar_decoder.reset(prompt)
-
- if self.chat_mode:
- return self.model.generate_chat_completion_grammar(prompt)
- else:
- return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words, debug=debug)
+ if prompt:
+ if isinstance(prompt, list):
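+                # Chat-style (list) prompts are flattened with the tokenizer's chat template, so only the output can be parsed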
+                assert self.parse_output_only, "Prompt must be a string for input+output parsing"
+
+ return self.model.generate_grammar_constrained_completion(prompt, self.num_samples, stop_words=stop_words, debug=debug)
else:
while True:
prompt = input('Enter prompt: ')
@@ -220,7 +210,7 @@ def user_input(self, prompt:str, stop_words=None, debug=False):
if prompt == "exit":
break
- batch_completions = self.model.generate_batch_completion_grammar(prompt, self.num_samples)
+ batch_completions = self.model.generate_grammar_constrained_completion(prompt, self.num_samples)
for i, completion in enumerate(batch_completions):
print(prompt + completion)
diff --git a/syncode/language_model.py b/syncode/language_model.py
index ec3be393..fd282806 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -1,3 +1,4 @@
+from typing import Tuple
import time
import torch
import syncode.common as common
@@ -5,7 +6,7 @@
from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria
from syncode.parsers.grammars import Grammar
from syncode.utils.generation import filter_code, fix_indents
-from typing import Callable, Iterable
+from typing import Callable, Iterable, Union
from transformers.generation.utils import GenerationMode
from transformers.generation.configuration_utils import GenerationConfig
@@ -67,9 +68,9 @@ def get_grammar_decoder(self):
return None
@torch.inference_mode()
- def generate_batch_completion_grammar(
+ def generate_grammar_constrained_completion(
self,
- prompt,
+ prompt: Union[str, list],
batch_size,
stop_words=None,
return_token_ids=False,
@@ -84,13 +85,13 @@ def generate_batch_completion_grammar(
stop_words (list): A list of stop words. If the completion ends with any of the stop words, the completion is returned.
return_token_ids (bool): If True, returns the token ids of the completions.
debug (bool): If True, prints debug information.
- '''
+ '''
+ inputs, prompt_str = self.get_tokenized_input(prompt, batch_size)
+
# Reset the grammar decoder
if self.grammar_decoder is not None:
- self.grammar_decoder.reset(prompt)
+ self.grammar_decoder.reset(prompt_str)
- input_batch = [prompt for _ in range(batch_size)]
- inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device)
input_ids_cutoff = inputs.input_ids.size(dim=1)
# Get the generation config
@@ -134,6 +135,8 @@ def generate_batch_completion_grammar(
# Total tokens generated
self.total_tokens = 0
+
+ # Decode the completions
for i in range(batch_size):
self.total_tokens += len(generated_ids[i])-input_ids_cutoff+1
completion = self.tokenizer.decode(generated_ids[i][input_ids_cutoff:len(generated_ids[i])], skip_special_tokens=True)
@@ -145,6 +148,29 @@ def generate_batch_completion_grammar(
return batch_completions
+
+ def get_tokenized_input(self, prompt: Union[str, list], batch_size: int):
+ """
+        Tokenizes the input prompt and returns the model inputs along with the prompt string; chat-style (list) prompts are first rendered with the chat template.
+ """
+ if isinstance(prompt, list):
+ assert 'apply_chat_template' in dir(self.tokenizer), "Tokenizer does not support chat completion"
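+            # Render the chat messages into a single prompt string with the model's chat template (tokenize=False returns a string)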
+            prompt_str = self.tokenizer.apply_chat_template(
+                prompt,
+                add_generation_prompt=True,
+                tokenize=False
+            )
+ elif isinstance(prompt, str):
+ prompt_str = prompt
+ else:
+ raise ValueError("Prompt should be either a string or a list! It is currently of type: "+str(type(prompt)))
+
+ input_batch = [prompt_str for _ in range(batch_size)]
+ inputs = self.tokenizer(input_batch, return_tensors="pt").to(self.model.device)
+
+ return inputs, prompt_str
+
@torch.inference_mode()
def _generate(
self,
@@ -252,37 +278,6 @@ def _get_generation_mode(
else:
generation_mode = GenerationMode.BEAM_SEARCH
return generation_mode
-
- @torch.inference_mode()
- def generate_chat_completion_grammar(self, prompt) -> str:
- '''
- Generates chat completion for the given prompt.
- '''
- assert 'apply_chat_template' in dir(self.tokenizer), "Tokenizer does not support chat completion"
-
- message = [{"role": "user", "content": prompt}]
- inputs = self.tokenizer.apply_chat_template(
- message,
- add_generation_prompt=True,
- return_tensors="pt"
- ).to(self.model.device)
-
- input_ids_cutoff = inputs.size(dim=1)
-
- # input_ids_cutoff = inputs.input_ids.size(dim=1)
- start_time = time.time()
-
- generated_ids = self.model.generate(
- inputs,
- logits_processor=self.grammar_processor,
- **self.gen_args
- )
- completion = self.tokenizer.decode(
- generated_ids[0][input_ids_cutoff:len(generated_ids[0])],
- skip_special_tokens=True)
-
- return completion
-
def tokenize(self, s: str) -> 'Iterable[int]':
return self.tokenizer.encode(s, add_special_tokens=False)
diff --git a/tests/test_language_model.py b/tests/test_language_model.py
index c1986401..c77c0be3 100644
--- a/tests/test_language_model.py
+++ b/tests/test_language_model.py
@@ -62,7 +62,7 @@ def test_generate_batch_completion_grammar(self):
logger = common.EmptyLogger()
lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu')
prompt = "113 + 235 + 17"
- output = lm.generate_batch_completion_grammar(prompt, 1)
+ output = lm.generate_grammar_constrained_completion(prompt, 1)
self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.")
def test_generate_batch_completion_grammar2(self):
@@ -72,7 +72,7 @@ def test_generate_batch_completion_grammar2(self):
logger = common.EmptyLogger()
lm = HuggingFaceModel(model, Grammar('calc'), tokenizer, mode='original', max_new_tokens=15, device='cpu')
prompt = "113 + 235 + 17"
- output = lm.generate_batch_completion_grammar(prompt, 2)
+ output = lm.generate_grammar_constrained_completion(prompt, 2)
self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.")
self.assertEqual(len(output[1]), 15, "The output length does not match the expected value.")