diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py index 67576c8..d6902ba 100644 --- a/ferret/serve/gradio_web_server.py +++ b/ferret/serve/gradio_web_server.py @@ -554,9 +554,11 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer): last_mask = refer_input_state['masks'][-1] diff_mask = mask_new - last_mask - if torch.all(diff_mask == 0): - print('Init Uploading Images.') + if mask_new.sum() == 0: + refer_input_state['refer_text_show'].append(refer_text_show) return (refer_input_state, refer_text_show, image) + elif torch.all(diff_mask == 0): + return (refer_input_state, refer_input_state['refer_text_show'][-1], refer_input_state['imagebox_refer'][-1]) else: refer_input_state['masks'].append(mask_new) @@ -600,18 +602,22 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer): refer_input_state['region_coordinates'].append(cur_region_coordinates) refer_input_state['region_masks'].append(cur_region_masks) assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens']) - refer_text_show.append((cur_region_token, '')) + refer_text_show.append((cur_region_token, None)) # Show Parsed Referring. imagebox_refer = draw_box(sampled_coor, cur_region_masks, \ cur_region_token, imagebox_refer, input_mode) + refer_input_state['refer_text_show'].append(refer_text_show) + refer_input_state['imagebox_refer'].append(imagebox_refer) + return (refer_input_state, refer_text_show, imagebox_refer) def build_demo(embed_mode): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False) with gr.Blocks(title="FERRET", theme=gr.themes.Base(), css=css) as demo: state = gr.State() + state.skip_next = False if not embed_mode: gr.Markdown(title_markdown) @@ -635,7 +641,7 @@ def build_demo(embed_mode): visible=False) # Added for any-format input. - sketch_pad = ImageMask(label="Image & Sketch", type="pil", elem_id="img2text") + sketch_pad = gr.ImageMask(label="Image & Sketch", type="pil", elem_id="img2text") refer_input_mode = gr.Radio( ["Point", "Box", "Sketch"], value="Point", @@ -645,6 +651,8 @@ def build_demo(embed_mode): 'region_masks':[], 'region_masks_in_prompts':[], 'masks':[], + 'refer_text_show': [], + 'imagebox_refer': [], }) refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")