From 1085438cd92fdeceb21af8f1faf1700ef8a0cb7d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E7=83=81?=
Date: Tue, 27 Jun 2023 09:50:33 +0800
Subject: [PATCH 1/2] fix_web_demo_always_repeat_first_question_bug

---
 models/chatglm/__init__.py | 3 +--
 web_demo.py                | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/models/chatglm/__init__.py b/models/chatglm/__init__.py
index 0440751..f44d5a8 100644
--- a/models/chatglm/__init__.py
+++ b/models/chatglm/__init__.py
@@ -38,8 +38,7 @@ def chat(self) -> str:
         print(flush=True)
 
     def run_web_demo(self, input_text, history=[]):
-        while True:
-            yield self.run(input_text, history=history)
+        return self.model.stream_chat(self.tokenizer, input_text, history=history)
 
     def run(self, text, history=[]):
         return self.model.chat(self.tokenizer, text, history=history)
diff --git a/web_demo.py b/web_demo.py
index d07b908..1c2e1cb 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -36,7 +36,7 @@ def predict(input, history=None):
 
     with gr.Row():
         with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+            txt = gr.Textbox(show_label=False, placeholder="Enter text and press shift + enter", lines=11).style(
                 container=False)
         with gr.Column(scale=1):
             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
@@ -44,4 +44,5 @@ def predict(input, history=None):
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
             button = gr.Button("Generate")
             button.click(predict, [txt, state], [state] + text_boxes)
+            txt.submit(predict, [txt, state], [state] + text_boxes)
 demo.queue().launch(share=False, inbrowser=False, server_port=51234, server_name="0.0.0.0")

From c1d0c91bf824a3c254e2e81578f7b4141404e8fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E7=83=81?=
Date: Tue, 27 Jun 2023 13:02:44 +0800
Subject: [PATCH 2/2] clear input function after output generates

---
 web_demo.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/web_demo.py b/web_demo.py
index 1c2e1cb..bfa962d 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -16,7 +16,7 @@ def predict(input, history=None):
             updates.append(gr.update(visible=True, value=f"{args.model}:" + response))
         if len(updates) < MAX_BOXES:
             updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+        yield [history] + updates + [""]
 
 
 if __name__ == "__main__":
@@ -43,6 +43,6 @@ def predict(input, history=None):
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
             button = gr.Button("Generate")
-            button.click(predict, [txt, state], [state] + text_boxes)
-            txt.submit(predict, [txt, state], [state] + text_boxes)
+            button.click(predict, [txt, state], [state] + text_boxes + [txt])
+            txt.submit(predict, [txt, state], [state] + text_boxes + [txt])
 demo.queue().launch(share=False, inbrowser=False, server_port=51234, server_name="0.0.0.0")