From 350b6b12084eaf52a3bd2231374823fd64597cda Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 25 Jan 2024 12:02:56 -0800 Subject: [PATCH 1/3] fixed gradio app --- ferret/serve/gradio_web_server.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py index 67576c8..a8f4b64 100644 --- a/ferret/serve/gradio_web_server.py +++ b/ferret/serve/gradio_web_server.py @@ -554,9 +554,10 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer): last_mask = refer_input_state['masks'][-1] diff_mask = mask_new - last_mask - if torch.all(diff_mask == 0): - print('Init Uploading Images.') + if mask_new.sum() == 0: return (refer_input_state, refer_text_show, image) + elif torch.all(diff_mask == 0): + return (refer_input_state, refer_input_state['refer_text_show'][-1], refer_input_state['imagebox_refer'][-1]) else: refer_input_state['masks'].append(mask_new) @@ -600,12 +601,15 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer): refer_input_state['region_coordinates'].append(cur_region_coordinates) refer_input_state['region_masks'].append(cur_region_masks) assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens']) - refer_text_show.append((cur_region_token, '')) + refer_text_show.append((cur_region_token, None)) # Show Parsed Referring. imagebox_refer = draw_box(sampled_coor, cur_region_masks, \ cur_region_token, imagebox_refer, input_mode) + refer_input_state['refer_text_show'].append(refer_text_show) + refer_input_state['imagebox_refer'].append(imagebox_refer) + return (refer_input_state, refer_text_show, imagebox_refer) def build_demo(embed_mode): @@ -635,7 +639,7 @@ def build_demo(embed_mode): visible=False) # Added for any-format input. - sketch_pad = ImageMask(label="Image & Sketch", type="pil", elem_id="img2text") + sketch_pad = gr.ImageMask(label="Image & Sketch", type="pil", elem_id="img2text") refer_input_mode = gr.Radio( ["Point", "Box", "Sketch"], value="Point", @@ -645,6 +649,8 @@ def build_demo(embed_mode): 'region_masks':[], 'region_masks_in_prompts':[], 'masks':[], + 'refer_text_show': [], + 'imagebox_refer': [], }) refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache") From 1646a57c84a7b976ac8b26824117db40c6ac6c2a Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 27 Jan 2024 10:25:41 -0800 Subject: [PATCH 2/3] don't initially skip --- ferret/serve/gradio_web_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py index a8f4b64..3cad9e9 100644 --- a/ferret/serve/gradio_web_server.py +++ b/ferret/serve/gradio_web_server.py @@ -616,6 +616,7 @@ def build_demo(embed_mode): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False) with gr.Blocks(title="FERRET", theme=gr.themes.Base(), css=css) as demo: state = gr.State() + state.skip_next = False if not embed_mode: gr.Markdown(title_markdown) From d5f64a7b4d00493b57870b506db1eed4e3a58201 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 27 Jan 2024 10:58:13 -0800 Subject: [PATCH 3/3] add refer_text_show to state --- ferret/serve/gradio_web_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py index 3cad9e9..d6902ba 100644 --- a/ferret/serve/gradio_web_server.py +++ b/ferret/serve/gradio_web_server.py @@ -555,6 +555,7 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer): diff_mask = mask_new - last_mask if mask_new.sum() == 0: + refer_input_state['refer_text_show'].append(refer_text_show) return (refer_input_state, refer_text_show, image) elif torch.all(diff_mask == 0): return (refer_input_state, refer_input_state['refer_text_show'][-1], refer_input_state['imagebox_refer'][-1])