Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ from syncode import Syncode
model_name = "WizardLM/WizardCoder-1B-V1.0"

# Load the Syncode augmented model
syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python')
syn_llm = Syncode(model=model_name, mode='grammar_strict', grammar='python', parse_output_only=False)
partial_code = "def is_prime(n):\n '''Return if prime'''\n "

#generate a completion to the input partial code
Expand Down
2 changes: 1 addition & 1 deletion notebooks/example_python.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"llm = Syncode(model = model_name, mode='original', max_new_tokens=200)\n",
"\n",
"# Load the Syncode augmented model\n",
"syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python')"
"syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python', parse_output_only=False)"
]
},
{
Expand Down
1 change: 0 additions & 1 deletion syncode/dfa_mask_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def incomplete_case_lookup(self, dfa_state: DFAState) -> Any:
if dfa_state in self._exact_lookup:
return self._exact_lookup[dfa_state]
else:
print(f"Warning: Exact lookup not found for {dfa_state} in the DFA mask store. This could be an error.", flush=True)
return self._overapprox_lookup[dfa_state]
raise ValueError(f"Invalid mode: {self._mode}")

Expand Down
7 changes: 4 additions & 3 deletions syncode/grammar_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self,
tokenizer: PreTrainedTokenizer,
logger: common.Logger=common.EmptyLogger(),
use_cache=True,
parse_output_only=False,
parse_output_only=True,
num_samples=1,
dev_mode=False,
parser='lalr',
Expand Down Expand Up @@ -122,8 +122,9 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
return False

if input_ids[0, -1] == self.tokenizer.eos_token_id:
is_valid = AcceptSequence(['$END']) in r.accept_sequences

# Do not allow the model to generate EOS token until $END in the grammar is reached
return AcceptSequence(['$END']) in r.accept_sequences

if r.remainder_state == RemainderState.COMPLETE or r.remainder_state == RemainderState.MAYBE_COMPLETE:
is_valid = True

Expand Down
11 changes: 7 additions & 4 deletions syncode/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ class Syncode:
num_samples (int, optional): Number of samples. Defaults to 1.

grammar (str, optional): Language. Defaults to "input". "input" is used for user input.
other options currently supported are "python", "go", "calc"

other options currently supported are "python", "go", "calc", "sql", "json", "fol".

parser (str, optional): Parser to use. Defaults to "lalr".

parse_output_only (bool, optional): Parse only the output. Defaults to True.

new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.

chat_mode (bool, optional): Parse only the (output) and not (prompt+output) in chat mode. Defaults to False.
Expand All @@ -57,7 +61,6 @@ class Syncode:

log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.

parser (str, optional): Parser to use. Defaults to "lalr".
opp (bool, optional): Whether to use opportunistic generation. Defaults to True.
"""
def __init__(
Expand All @@ -68,7 +71,7 @@ def __init__(
device: str = "cuda",
grammar: Optional[str] = None,
chat_mode: bool = False,
parse_output_only: bool = False,
parse_output_only: bool = True,
dev_mode: bool = False,
log_level: int = 1,
new_mask_store: bool = False,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import os
import sys

import torch

from syncode.dfa_mask_store import DFAMaskStore
from syncode.grammar_decoder import SyncodeLogitsProcessor

# Adjusting the path so the modules can be imported correctly
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
Expand Down Expand Up @@ -69,6 +72,30 @@ def test_mask_store_misc3(self):
dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
mask = dfa_mask.get_accept_mask(r, get_list=True)
self.assertIn(' I', mask)

def test_grammar_decoder_empty(self):
"""
Make sure the grammar decoder does not accept an empty input
"""
grammar = Grammar('json')
model = 'microsoft/Phi-3-mini-128k-instruct'
tokenizer = common.load_tokenizer(model)
grammar_decoder = SyncodeLogitsProcessor(
grammar,
tokenizer=tokenizer,
use_cache=True,
parse_output_only=True,
)

prompt = "Generate a JSON object"
input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']
grammar_decoder.reset(prompt)

# Create an empty 2D tensor
next_token = torch.tensor([tokenizer.eos_token_id], dtype=torch.long)

is_valid = grammar_decoder.is_valid(input_ids=input_ids, next_token=next_token)
self.assertFalse(is_valid)

def test_parser_calc(self):
inc_parser = create_parser(Grammar('calc'))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_syncode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestSyncode(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Assuming Syncode and get_data are set up correctly in your environment
cls.sg_mask = Syncode(model="test", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=200, grammar='python')
cls.sg_mask = Syncode(model="test", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=200, grammar='python', parse_output_only=False)
cls.problems = list(get_data('multi-humaneval', 'python').values())

cls.sg_mask_instruct = Syncode(model="test-instruct", mode='grammar_mask', device='cpu', do_sample=False, max_new_tokens=20, grammar=cls.custom_grammar())
Expand Down