diff --git a/models/llama/__init__.py b/models/llama/__init__.py index 085da15..59b4743 100644 --- a/models/llama/__init__.py +++ b/models/llama/__init__.py @@ -71,16 +71,12 @@ def run(self, input_text: str) -> str: return text_out def chat(self): - history = "" while True: - session = "用户输入: " input_text = input("用户输入: ") - session += input_text + "\n" - session += "LLaMA : " with jt.no_grad(): - for output_text in self.generator.generate([input_text], max_gen_len=256, temperature=0.8, top_p=0.95): - print(history + session + output_text, flush=True) - history += session + output_text + "\n" + output_text = self.generator.generate([input_text], max_gen_len=256, temperature=0.8, top_p=0.95) + output_text = list(output_text) + print("LLaMA : "+output_text[-1]) def run_web_demo(self, input_text, history=[]): response = self.run(input_text)