diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index f55c99013..5ba529622 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -58,8 +58,8 @@ GuiVector3Handle, SupportsRemoveProtocol, UploadedFile, + _GuiHandle, _GuiHandleState, - _GuiInputHandle, _make_uuid, ) from ._icons import svg_from_icon @@ -203,7 +203,7 @@ def __init__( ) """Interface for sending and listening to messages.""" - self._gui_input_handle_from_uuid: dict[str, _GuiInputHandle[Any]] = {} + self._gui_handle_from_uuid: dict[str, _GuiHandle[Any]] = {} self._container_handle_from_uuid: dict[str, GuiContainerProtocol] = { "root": _RootGuiContainer({}) } @@ -228,7 +228,7 @@ async def _handle_gui_updates( self, client_id: ClientId, message: _messages.GuiUpdateMessage ) -> None: """Callback for handling GUI messages.""" - handle = self._gui_input_handle_from_uuid.get(message.uuid, None) + handle = self._gui_handle_from_uuid.get(message.uuid, None) if handle is None: return handle_state = handle._impl @@ -293,7 +293,7 @@ async def _handle_gui_updates( def _handle_file_transfer_start( self, client_id: ClientId, message: _messages.FileTransferStart ) -> None: - if message.source_component_uuid not in self._gui_input_handle_from_uuid: + if message.source_component_uuid not in self._gui_handle_from_uuid: return self._current_file_upload_states[message.transfer_uuid] = { "filename": message.filename, @@ -310,7 +310,7 @@ def _handle_file_transfer_part( ) -> None: if message.transfer_uuid not in self._current_file_upload_states: return - assert message.source_component_uuid in self._gui_input_handle_from_uuid + assert message.source_component_uuid in self._gui_handle_from_uuid state = self._current_file_upload_states[message.transfer_uuid] state["parts"][message.part] = message.content @@ -336,9 +336,7 @@ def _handle_file_transfer_part( assert state["transferred_bytes"] == total_bytes state = self._current_file_upload_states.pop(message.transfer_uuid) - handle = self._gui_input_handle_from_uuid.get( - message.source_component_uuid, None - ) + handle = self._gui_handle_from_uuid.get(message.source_component_uuid, None) if handle is None: return @@ -649,6 +647,7 @@ def add_image( media_type="image/png" if format == "png" else "image/jpeg", order=_apply_default_order(order), visible=visible, + _clickable=False, ), ) self._websock_interface.queue_message(message) @@ -657,9 +656,10 @@ def add_image( _GuiHandleState( message.uuid, self, - None, + [0.0, 0.0], props=message.props, parent_container_id=message.container_uuid, + is_button=True, ), _image=image, _jpeg_quality=jpeg_quality, diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index b245ebdc5..0bb6f0653 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -33,6 +33,7 @@ from ._messages import ( GuiBaseProps, GuiButtonGroupProps, + GuiButtonProps, GuiCheckboxProps, GuiCloseModalMessage, GuiDropdownProps, @@ -64,7 +65,7 @@ T = TypeVar("T") -TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle") +TGuiHandle = TypeVar("TGuiHandle", bound="_GuiHandle") NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine) @@ -165,8 +166,7 @@ def __init__(self, _impl: _GuiHandleState[T]) -> None: ] parent._children[self._impl.uuid] = self - if isinstance(self, _GuiInputHandle): - self._impl.gui_api._gui_input_handle_from_uuid[self._impl.uuid] = self + self._impl.gui_api._gui_handle_from_uuid[self._impl.uuid] = self def remove(self) -> None: """Permanently remove this GUI element from the visualizer.""" @@ -191,8 +191,7 @@ def remove(self) -> None: parent = gui_api._container_handle_from_uuid[self._impl.parent_container_id] parent._children.pop(self._impl.uuid) - if isinstance(self, _GuiInputHandle): - gui_api._gui_input_handle_from_uuid.pop(self._impl.uuid) + gui_api._gui_handle_from_uuid.pop(self._impl.uuid) class _GuiInputHandle( @@ -397,7 +396,7 @@ class GuiEvent(Generic[TGuiHandle]): """GUI element that was affected.""" -class GuiButtonHandle(_GuiInputHandle[bool]): +class GuiButtonHandle(_GuiInputHandle[bool], GuiButtonProps): """Handle for a button input in our visualizer. .. attribute:: value @@ -818,7 +817,7 @@ def figure(self, figure: go.Figure) -> None: self._plotly_json_str = json_str -class GuiImageHandle(_GuiHandle[None], GuiImageProps): +class GuiImageHandle(_GuiHandle[Tuple[float, float]], GuiImageProps): """Handle for updating and removing images.""" def __init__( @@ -831,6 +830,36 @@ def __init__( self._image = _image self._jpeg_quality = _jpeg_quality + @property + def clicked_xy(self) -> Tuple[float, float]: + """Last-clicked XY coordinate of the image. Normalized [0, 1].""" + return self._impl.value + + def on_click( + self, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine] + ) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]: + """Attach a function to call when an image is clicked.""" + self._impl.update_cb.append(func) + self._clickable = True + return func + + def remove_click_callback( + self, callback: Literal["all"] | Callable = "all" + ) -> None: + """Remove click callbacks from the GUI input. + + Args: + callback: Either "all" to remove all callbacks, or a specific callback function to remove. + """ + if callback == "all": + self._impl.update_cb.clear() + else: + self._impl.update_cb = [cb for cb in self._impl.update_cb if cb != callback] + + # Set clickable to False if not more callbacks. + if len(self._impl.update_cb) == 0: + self._clickable = False + @property def image(self) -> np.ndarray: """Current content of this image element. Synchronized automatically when assigned.""" diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 30b0652d5..50d31f2ff 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -915,6 +915,8 @@ class GuiImageProps: """Format of the provided image ('image/jpeg' or 'image/png'). Synchronized automatically when assigned.""" visible: bool """Visibility state of the image. Synchronized automatically when assigned.""" + _clickable: bool + """(Private) Whether the image is clickable. Synchronized automatically when assigned.""" @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index cea1b2b01..eb416ef3c 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -414,6 +414,7 @@ export interface GuiImageMessage { _data: Uint8Array | null; media_type: "image/jpeg" | "image/png"; visible: boolean; + _clickable: boolean; }; } /** GuiTabGroupMessage(uuid: 'str', container_uuid: 'str', props: 'GuiTabGroupProps') @@ -522,7 +523,7 @@ export interface GuiSliderMessage { _marks: { value: number; label: string | null }[] | null; }; } -/** GuiMultiSliderMessage(uuid: 'str', value: 'tuple[float, ...]', container_uuid: 'str', props: 'GuiMultiSliderProps') +/** GuiMultiSliderMessage(uuid: 'str', value: 'Tuple[float, ...]', container_uuid: 'str', props: 'GuiMultiSliderProps') * * (automatically generated) */ diff --git a/src/viser/client/src/components/Image.tsx b/src/viser/client/src/components/Image.tsx index 46c95e0f9..077acd74e 100644 --- a/src/viser/client/src/components/Image.tsx +++ b/src/viser/client/src/components/Image.tsx @@ -1,11 +1,14 @@ +import React from "react"; import { useEffect, useState } from "react"; import { GuiImageMessage } from "../WebsocketMessages"; import { Box, Text } from "@mantine/core"; +import { ViewerContext } from "../ViewerContext"; -function ImageComponent({ props }: GuiImageMessage) { +function ImageComponent({ uuid, props }: GuiImageMessage) { if (!props.visible) return <>; const [imageUrl, setImageUrl] = useState(null); + const viewer = React.useContext(ViewerContext)!; useEffect(() => { if (props._data === null) { @@ -28,11 +31,24 @@ function ImageComponent({ props }: GuiImageMessage) { {props.label} )} + { + if (props._clickable === false) return; + const rect = e.currentTarget.getBoundingClientRect(); + const x = (e.clientX - rect.left) / rect.width; + const y = (e.clientY - rect.top) / rect.height; + viewer.sendMessageRef.current({ + type: "GuiUpdateMessage", + uuid: uuid, + updates: { value: [x, y] }, + }); }} />