Patch summary (free text before the first "diff --git" header; patch tools skip it):
  * models/chatglm/__init__.py: run_web_demo() no longer loops forever yielding
    self.run(...); it returns self.model.stream_chat(...) directly.
  * web_demo.py: predict() yields one extra trailing "" value, and both the
    button.click and the new txt.submit handlers list txt as a final output,
    so that empty string is routed to the input textbox. The placeholder text
    now says "shift + enter".
NOTE(review): this patch was recovered from a copy whose line breaks and leading
whitespace were collapsed. Hunk line counts were checked against the headers, but
the indentation inside the hunks is reconstructed by convention. Confirm it
against the target files, or apply with whitespace-insensitive matching
(git apply --ignore-whitespace, or patch -l). Blank context lines must consist
of a single space.

diff --git a/models/chatglm/__init__.py b/models/chatglm/__init__.py
index 0440751..f44d5a8 100644
--- a/models/chatglm/__init__.py
+++ b/models/chatglm/__init__.py
@@ -38,8 +38,7 @@ def chat(self) -> str:
             print(flush=True)
 
     def run_web_demo(self, input_text, history=[]):
-        while True:
-            yield self.run(input_text, history=history)
+        return self.model.stream_chat(self.tokenizer, input_text, history=history)
 
     def run(self, text, history=[]):
         return self.model.chat(self.tokenizer, text, history=history)
diff --git a/web_demo.py b/web_demo.py
index d07b908..bfa962d 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -16,7 +16,7 @@ def predict(input, history=None):
             updates.append(gr.update(visible=True, value=f"{args.model}:" + response))
         if len(updates) < MAX_BOXES:
             updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+        yield [history] + updates + [""]
 
 
 if __name__ == "__main__":
@@ -36,12 +36,13 @@ def predict(input, history=None):
 
         with gr.Row():
             with gr.Column(scale=4):
-                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                txt = gr.Textbox(show_label=False, placeholder="Enter text and press shift + enter", lines=11).style(
                     container=False)
             with gr.Column(scale=1):
                 max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
                 top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
                 temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
                 button = gr.Button("Generate")
-        button.click(predict, [txt, state], [state] + text_boxes)
+        button.click(predict, [txt, state], [state] + text_boxes + [txt])
+        txt.submit(predict, [txt, state], [state] + text_boxes + [txt])
     demo.queue().launch(share=False, inbrowser=False, server_port=51234, server_name="0.0.0.0")