diff --git a/syncode/evaluation/fol_eval.py b/syncode/evaluation/fol_eval.py
index a747a12b..5a937f0c 100644
--- a/syncode/evaluation/fol_eval.py
+++ b/syncode/evaluation/fol_eval.py
@@ -98,7 +98,7 @@ def run_eval(syncode, out_path: Optional[str]=None, debug_task_id=None):
     for task_id, problem in enumerate(problems):
         results[task_id] = []
         full_prompt = FOLEval._prompt_folio(problem)
-        completion = syncode.model.generate_batch_completion_grammar(
+        completion = syncode.model.generate_grammar_constrained_completion(
             full_prompt,
             syncode.num_samples,
             stop_words=['\n\n', '------']
diff --git a/syncode/evaluation/json_eval.py b/syncode/evaluation/json_eval.py
index 909603d1..89520e80 100644
--- a/syncode/evaluation/json_eval.py
+++ b/syncode/evaluation/json_eval.py
@@ -70,7 +70,7 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas
         prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
 
-        batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
+        batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task)
         for completion_id, completion in enumerate(batch_completions):
             result = dict(
                 task_id = task_id,
diff --git a/syncode/evaluation/math_eval.py b/syncode/evaluation/math_eval.py
index 96a088cc..1fa79ea2 100644
--- a/syncode/evaluation/math_eval.py
+++ b/syncode/evaluation/math_eval.py
@@ -22,7 +22,7 @@ def run_math_eval(syncode, out_path: Optional[str], debug_task_id=None, logger=c
     for task_id, problem in enumerate(problems):
         results[task_id] = []
-        batch_completions = syncode.model.generate_batch_completion_grammar(
+        batch_completions = syncode.model.generate_grammar_constrained_completion(
             problem['question'],
             syncode.num_samples
             )
diff --git a/tests/test_language_model.py b/tests/test_language_model.py
index c77c0be3..9882ec39 100644
--- a/tests/test_language_model.py
+++ b/tests/test_language_model.py
@@ -55,7 +55,7 @@ def get_vocab(self) -> Dict[str, int]:
         return {v: i for i, v in enumerate(self.vocab)}
 
 class TestHuggingFaceModel(unittest.TestCase):
-    def test_generate_batch_completion_grammar(self):
+    def test_generate_grammar_constrained_completion(self):
         torch.manual_seed(0)
         model = TestModel()
         tokenizer = TestTokenizer()
@@ -65,7 +65,7 @@ def test_generate_batch_completion_grammar(self):
         output = lm.generate_grammar_constrained_completion(prompt, 1)
         self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.")
 
-    def test_generate_batch_completion_grammar2(self):
+    def test_generate_grammar_constrained_completion2(self):
         torch.manual_seed(0)
         model = TestModel()
         tokenizer = TestTokenizer()
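
Minimal usage sketch of the renamed entry point. Only the method name and its
(prompt, num_samples, stop_words=...) call pattern are taken from the diff;
the already-constructed `syncode` runner and the prompt text are assumptions
for illustration, mirroring how the evaluation scripts above invoke it.

    # Assumes `syncode` is an already-initialized SynCode runner whose `.model`
    # exposes the renamed method, as used in fol_eval.py above.
    prompt = "Translate the premises and conclusion into first-order logic:"  # placeholder

    completions = syncode.model.generate_grammar_constrained_completion(
        prompt,
        1,                              # num_samples, positional as in the diff
        stop_words=['\n\n', '------'],  # optional, mirrors fol_eval.py
    )

    for completion in completions:
        print(completion)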