diff --git a/README.md b/README.md
index fd6ae84b..cf49475d 100644
--- a/README.md
+++ b/README.md
@@ -266,7 +266,7 @@ from syncode import Syncode
 model_name = "WizardLM/WizardCoder-1B-V1.0"
 
 # Load the Syncode augmented model
-syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python')
+syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python', parse_output_only=False)
 partial_code = "def is_prime(n):\n    '''Return if prime'''\n    "
 
 #generate a completion to the input partial code
diff --git a/notebooks/example_python.ipynb b/notebooks/example_python.ipynb
index edfbd82d..197a26d4 100644
--- a/notebooks/example_python.ipynb
+++ b/notebooks/example_python.ipynb
@@ -25,7 +25,7 @@
     "llm = Syncode(model = model_name, mode='original', max_new_tokens=200)\n",
     "\n",
     "# Load the Syncode augmented model\n",
-    "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python')"
+    "syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python', parse_output_only=False)"
    ]
   },
   {
diff --git a/syncode/dfa_mask_store.py b/syncode/dfa_mask_store.py
index 31036b5c..d1063738 100644
--- a/syncode/dfa_mask_store.py
+++ b/syncode/dfa_mask_store.py
@@ -155,7 +155,6 @@ def incomplete_case_lookup(self, dfa_state: DFAState) -> Any:
             if dfa_state in self._exact_lookup:
                 return self._exact_lookup[dfa_state]
             else:
-                print(f"Warning: Exact lookup not found for {dfa_state} in the DFA mask store. This could be an error.", flush=True)
                 return self._overapprox_lookup[dfa_state]
 
         raise ValueError(f"Invalid mode: {self._mode}")
diff --git a/syncode/grammar_decoder.py b/syncode/grammar_decoder.py
index 18a195af..a0993f0d 100644
--- a/syncode/grammar_decoder.py
+++ b/syncode/grammar_decoder.py
@@ -27,7 +27,7 @@ def __init__(self,
         tokenizer: PreTrainedTokenizer,
         logger: common.Logger=common.EmptyLogger(),
         use_cache=True,
-        parse_output_only=False,
+        parse_output_only=True,
         num_samples=1,
         dev_mode=False,
         parser='lalr',
@@ -122,8 +122,9 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
             return False
 
         if input_ids[0, -1] == self.tokenizer.eos_token_id:
-            is_valid = AcceptSequence(['$END']) in r.accept_sequences
-
+            # Do not allow the model to generate EOS token until $END in the grammar is reached
+            return AcceptSequence(['$END']) in r.accept_sequences
+
         if r.remainder_state == RemainderState.COMPLETE or r.remainder_state == RemainderState.MAYBE_COMPLETE:
             is_valid = True
 
diff --git a/syncode/infer.py b/syncode/infer.py
index b4354fae..88245231 100644
--- a/syncode/infer.py
+++ b/syncode/infer.py
@@ -47,8 +47,12 @@ class Syncode:
         num_samples (int, optional): Number of samples. Defaults to 1.
 
        grammar (str, optional): Language. Defaults to "input". "input" is used for user input.
-        other options currently supported are "python", "go", "calc"
-
+        other options currently supported are "python", "go", "calc", "sql", "json", "fol".
+
+        parser (str, optional): Parser to use. Defaults to "lalr".
+
+        parse_output_only (bool, optional): Parse only the output. Defaults to True.
+
         new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
 
         chat_mode (bool, optional): Parse only the (output) and not (prompt+output) in chat mode. Defaults to False.
@@ -57,7 +61,6 @@ class Syncode:
 
         log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
 
-        parser (str, optional): Parser to use. Defaults to "lalr".
         opp (bool, optional): Whether to use opportunistic generation. Defaults to True.
""" def __init__( @@ -68,7 +71,7 @@ def __init__( device: str = "cuda", grammar: Optional[str] = None, chat_mode: bool = False, - parse_output_only: bool = False, + parse_output_only: bool = True, dev_mode: bool = False, log_level: int = 1, new_mask_store: bool = False, diff --git a/tests/test_misc.py b/tests/test_misc.py index 1ded4383..7dc83bbd 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -2,7 +2,10 @@ import os import sys +import torch + from syncode.dfa_mask_store import DFAMaskStore +from syncode.grammar_decoder import SyncodeLogitsProcessor # Adjusting the path so the modules can be imported correctly sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../') @@ -69,6 +72,30 @@ def test_mask_store_misc3(self): dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger()) mask = dfa_mask.get_accept_mask(r, get_list=True) self.assertIn(' I', mask) + + def test_grammar_decoder_empty(self): + """ + Make sure the grammar decoder does not accept an empty input + """ + grammar = Grammar('json') + model = 'microsoft/Phi-3-mini-128k-instruct' + tokenizer = common.load_tokenizer(model) + grammar_decoder = SyncodeLogitsProcessor( + grammar, + tokenizer=tokenizer, + use_cache=True, + parse_output_only=True, + ) + + prompt = "Generate a JSON object" + input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'] + grammar_decoder.reset(prompt) + + # Create an empty 2D tensor + next_token = torch.tensor([tokenizer.eos_token_id], dtype=torch.long) + + is_valid = grammar_decoder.is_valid(input_ids=input_ids, next_token=next_token) + self.assertFalse(is_valid) def test_parser_calc(self): inc_parser = create_parser(Grammar('calc')) diff --git a/tests/test_syncode.py b/tests/test_syncode.py index ed497d47..3ce0d755 100644 --- a/tests/test_syncode.py +++ b/tests/test_syncode.py @@ -9,7 +9,7 @@ class TestSyncode(unittest.TestCase): @classmethod def setUpClass(cls): # Assuming Syncode and get_data are set up correctly in your environment - cls.sg_mask = Syncode(model="test", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=200, grammar='python') + cls.sg_mask = Syncode(model="test", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=200, grammar='python', parse_output_only=False) cls.problems = list(get_data('multi-humaneval', 'python').values()) cls.sg_mask_instruct = Syncode(model="test-instruct", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=20, grammar=cls.custom_grammar())