diff --git a/bitcoin_usb/device.py b/bitcoin_usb/device.py index c873961..7f48c74 100644 --- a/bitcoin_usb/device.py +++ b/bitcoin_usb/device.py @@ -16,7 +16,7 @@ from hwilib.devices.trezor import TrezorClient from hwilib.hwwclient import HardwareWalletClient from hwilib.psbt import PSBT -from PyQt6.QtCore import QEventLoop, QObject, Qt, QThread, pyqtSignal +from PyQt6.QtCore import QCoreApplication, QEventLoop, QObject, Qt, QThread, pyqtSignal from PyQt6.QtWidgets import ( QDialog, QDialogButtonBox, @@ -29,7 +29,7 @@ from bitcoin_usb.dialogs import Worker from bitcoin_usb.i18n import translate from bitcoin_usb.jade_ble_client import JadeBleClient -from bitcoin_usb.util import run_device_task, run_script +from bitcoin_usb.util import run_script from .address_types import ( AddressType, @@ -350,8 +350,7 @@ def _init_client(self): def __enter__(self): self.lock.acquire() try: - # _init_client is a synchronous function; we just use the common runner - run_device_task(loop_in_thread=self.loop_in_thread, task=self._init_client) + self._init_client() return self except Exception: self.lock.release() @@ -374,11 +373,18 @@ def get_fingerprint(self) -> str: return self.client.get_master_fingerprint().hex() def get_xpubs(self) -> dict[AddressType, str]: - xpubs = {} + xpubs: dict[AddressType, str] = {} for address_type in get_all_address_types(): xpubs[address_type] = self.get_xpub(address_type.key_origin(self.network)) + self._process_pending_gui_events() return xpubs + @staticmethod + def _process_pending_gui_events() -> None: + if QCoreApplication.instance() is None: + return + QCoreApplication.processEvents() + def get_xpub(self, key_origin: str) -> str: assert self.client return self.client.get_pubkey_at_path(key_origin).to_string() diff --git a/bitcoin_usb/dialogs.py b/bitcoin_usb/dialogs.py index 62e446d..2832084 100644 --- a/bitcoin_usb/dialogs.py +++ b/bitcoin_usb/dialogs.py @@ -106,35 +106,34 @@ def __init__( self._layout.addWidget(self.devices_group, stretch=1) self.actions_bar = QDialogButtonBox(self) - self.usb_scan_button = self.actions_bar.addButton( - self.tr("Scan USB devices"), QDialogButtonBox.ButtonRole.ActionRole - ) + self.usb_scan_button = QPushButton(self.tr("Scan USB devices"), self) + self.actions_bar.addButton(self.usb_scan_button, QDialogButtonBox.ButtonRole.ActionRole) self.usb_scan_button.setIcon(self.usb_icon) self.usb_scan_button.clicked.connect(self.scan_usb_devices) self.usb_scan_button.setAutoDefault(True) self.usb_scan_button.setDefault(True) - self.bluetooth_scan_button: QPushButton | None = None - if self.bluetooth_scan_callback: - self.bluetooth_scan_button = self.actions_bar.addButton( - self.tr("Scan Bluetooth devices"), QDialogButtonBox.ButtonRole.ActionRole - ) - self.bluetooth_scan_button.setIcon(self.bluetooth_icon) - self.bluetooth_scan_button.clicked.connect(self.scan_for_bluetooth_devices) - self.bluetooth_scan_button.setAutoDefault(False) - - self.install_udev_button: QPushButton | None = None - if self.install_udev_callback and sys.platform.startswith("linux"): - self.install_udev_button = self.actions_bar.addButton( - self.tr("Install udev rules"), QDialogButtonBox.ButtonRole.ActionRole - ) + self.bluetooth_scan_button = QPushButton(self.tr("Scan Bluetooth devices"), self) + if not self.bluetooth_scan_callback: + self.bluetooth_scan_button.setHidden(True) + self.actions_bar.addButton(self.bluetooth_scan_button, QDialogButtonBox.ButtonRole.ActionRole) + self.bluetooth_scan_button.setIcon(self.bluetooth_icon) + self.bluetooth_scan_button.clicked.connect(self.scan_for_bluetooth_devices) + self.bluetooth_scan_button.setAutoDefault(False) + + self.install_udev_button = QPushButton(self.tr("Install udev rules"), self) + if not self.install_udev_callback and sys.platform.startswith("linux"): + self.install_udev_button.setHidden(True) + elif self.install_udev_callback: self.install_udev_button.clicked.connect(self.install_udev_callback) - self.install_udev_button.setAutoDefault(False) - self.install_udev_button.setVisible(False) + self.actions_bar.addButton(self.install_udev_button, QDialogButtonBox.ButtonRole.ActionRole) + self.install_udev_button.setAutoDefault(False) + self.install_udev_button.setVisible(False) self.cancel_button = self.actions_bar.addButton(QDialogButtonBox.StandardButton.Cancel) self.actions_bar.rejected.connect(self.reject) - self.cancel_button.setAutoDefault(False) + if self.cancel_button: + self.cancel_button.setAutoDefault(False) self._layout.addWidget(self.actions_bar) # ensure the dialog has its “natural” size diff --git a/bitcoin_usb/jade_ble_client.py b/bitcoin_usb/jade_ble_client.py index d2f71b7..b6d7c21 100644 --- a/bitcoin_usb/jade_ble_client.py +++ b/bitcoin_usb/jade_ble_client.py @@ -10,6 +10,7 @@ from contextvars import ContextVar, Token from typing import Any +import aioitertools import semver from bleak import BleakScanner from hwilib.common import Chain @@ -26,6 +27,7 @@ DEFAULT_DISCOVERY_SCAN_TIMEOUT_SECONDS = 6.0 DEFAULT_BLE_CONNECT_TIMEOUT_SECONDS = 15.0 DEFAULT_BLE_GATT_OPERATION_TIMEOUT_SECONDS = 10.0 +DEFAULT_BLE_IO_TIMEOUT_SECONDS = 60.0 _IS_BT_DEVICE_PATCHED = False _ORIGINAL_JADEPY_SUBPROCESS_RUN = jade_ble_module.subprocess.run # Temporary per-call channel to pass a preferred BLE MAC address into the custom @@ -144,6 +146,7 @@ def __init__( self.write_task: asyncio.Task[Any] | None = None self.connect_timeout_seconds = DEFAULT_BLE_CONNECT_TIMEOUT_SECONDS self.gatt_operation_timeout_seconds = DEFAULT_BLE_GATT_OPERATION_TIMEOUT_SECONDS + self.io_timeout_seconds = DEFAULT_BLE_IO_TIMEOUT_SECONDS async def _await_ble_operation( self, @@ -225,6 +228,7 @@ def _disconnection_handler(client: Any) -> None: attempts_remaining = 5 client = None needs_set_disconnection_callback = False + attempted_windows_unpair = False # Bleak connect can fail transiently; retry a few times before giving up. while not connected: try: @@ -247,6 +251,15 @@ def _disconnection_handler(client: Any) -> None: jade_ble_module.logger.info(f"Connected: {connected}") except Exception as e: jade_ble_module.logger.warning(f"BLE connection exception: {e}") + if platform.system() == "Windows" and not attempted_windows_unpair: + attempted_windows_unpair = True + unpaired = await self._try_unpair_windows_device(device_mac=device_mac) + if unpaired: + jade_ble_module.logger.info( + "Removed Windows pairing state for %s; retrying BLE connect", + device_mac, + ) + if not attempts_remaining: jade_ble_module.logger.warning("Exhausted retries - BLE connection failed") raise JadeError( @@ -311,6 +324,26 @@ def _notification_handler(sender: Any, data: Any) -> None: self.client = connected_client + async def _try_unpair_windows_device(self, device_mac: str) -> bool: + if platform.system() != "Windows": + return False + try: + unpair_client = jade_ble_module.bleak.BleakClient(device_mac) + except Exception as e: + jade_ble_module.logger.warning("Unable to prepare Windows BLE unpair client: %s", e) + return False + + try: + result = await self._await_ble_operation( + unpair_client.unpair(), + timeout_seconds=self.gatt_operation_timeout_seconds, + operation_name="unpair", + ) + return bool(result) + except Exception as e: + jade_ble_module.logger.warning("Windows BLE unpair failed for %s: %s", device_mac, e) + return False + async def _disconnect_impl(self) -> None: try: if self.client is not None and self.client.is_connected: @@ -334,6 +367,57 @@ async def _disconnect_impl(self) -> None: self.write_task.cancel() self.write_task = None + async def _write_impl(self, bytes_: bytes) -> int: # type: ignore + assert self.client is not None + assert self.write_task is None + + towrite = len(bytes_) + written = 0 + + async def _write() -> None: + if self.client is None: + return + nonlocal written + + while written < towrite: + remaining = towrite - written + length = min(remaining, BlockstreamJadeBleImpl.BLE_MAX_WRITE_SIZE) + upper_limit = written + length + await self.client.write_gatt_char( + BlockstreamJadeBleImpl.IO_TX_CHAR_UUID, + bytearray(bytes_[written:upper_limit]), + response=True, + ) + written = upper_limit + + self.write_task = asyncio.create_task(_write()) + try: + await self._await_ble_operation( + self.write_task, + timeout_seconds=self.io_timeout_seconds, + operation_name="write", + ) + except asyncio.CancelledError: + jade_ble_module.logger.warning( + "write() task cancelled having written %d of %d bytes", written, towrite + ) + finally: + self.write_task = None + + return written + + async def _read_impl(self, n: int) -> bytes: + assert self.inputstream is not None + return await self._await_ble_operation( + self._read_bytes_from_stream(n=n), + timeout_seconds=self.io_timeout_seconds, + operation_name=f"read({n})", + ) + + async def _read_bytes_from_stream(self, n: int) -> bytes: + assert self.inputstream is not None + return bytes([b async for b in aioitertools.islice(self.inputstream, n)]) # type: ignore + class JadeBleClient(JadeClient): """ diff --git a/bitcoin_usb/usb_gui.py b/bitcoin_usb/usb_gui.py index 745f148..442ebd8 100644 --- a/bitcoin_usb/usb_gui.py +++ b/bitcoin_usb/usb_gui.py @@ -4,9 +4,10 @@ import re import tempfile from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from functools import partial from pathlib import Path -from typing import Any, cast +from typing import Any, TypeVar, cast import bdkpython as bdk import hwilib.commands as hwi_commands @@ -22,10 +23,10 @@ from bitcoin_usb.address_types import AddressType from bitcoin_usb.dialogs import DeviceDialog, get_message_box from bitcoin_usb.jade_ble_client import discover_jade_ble_devices -from bitcoin_usb.util import run_device_task from .device import USBDevice, bdknetwork_to_chain from .i18n import translate +from .util import run_device_task logger = logging.getLogger(__name__) @@ -34,11 +35,23 @@ def is_ble_available() -> bool: return BleakClient is not None and BleakScanner is not None +T = TypeVar("T") + + +def _run_ble_operation(operation: Callable[[], T]) -> T: + with ThreadPoolExecutor(max_workers=1, thread_name_prefix="ble") as executor: + return executor.submit(operation).result() + + def can_scan_bluetooth_devices(probe_timeout: float = 0.2) -> bool: if not is_ble_available(): return False + + def _probe_scan() -> list[Any]: + return asyncio.run(BleakScanner.discover(timeout=max(0.1, probe_timeout))) + try: - asyncio.run(BleakScanner.discover(timeout=max(0.1, probe_timeout))) + _run_ble_operation(_probe_scan) except Exception as e: logger.info("Bluetooth scanning unavailable in this environment: %s", e) return False @@ -132,7 +145,7 @@ def get_bluetooth_devices(self) -> list[dict[str, Any]]: raise RuntimeError(self.tr("Bluetooth support is disabled by configuration.")) if not self._is_bluetooth_scan_supported(): raise RuntimeError(self.tr("Bluetooth scanning is not available in this environment.")) - return discover_jade_ble_devices(scan_timeout=6.0) + return _run_ble_operation(self._discover_bluetooth_devices) def _is_bluetooth_scan_supported(self) -> bool: if not self.enable_bluetooth: @@ -141,19 +154,58 @@ def _is_bluetooth_scan_supported(self) -> bool: self._bluetooth_scan_supported = can_scan_bluetooth_devices() return self._bluetooth_scan_supported - def sign(self, psbt: bdk.Psbt, slow_hwi_listing=False) -> bdk.Psbt | None: - selected_device = self.get_device(slow_hwi_listing=slow_hwi_listing) - if not selected_device: - return None + def _discover_bluetooth_devices(self) -> list[dict[str, Any]]: + return discover_jade_ble_devices(scan_timeout=6.0) - try: + @staticmethod + def _is_jade_ble_device(selected_device: dict[str, Any]) -> bool: + return ( + str(selected_device.get("type", "")).lower() == "jade" + and str(selected_device.get("transport", "")).lower() == "bluetooth" + ) + + @staticmethod + def _is_usb_device(selected_device: dict[str, Any]) -> bool: + transport = str(selected_device.get("transport", "usb")).lower() + return transport != "bluetooth" + + @staticmethod + def _should_run_in_worker(selected_device: dict[str, Any]) -> bool: + if USBGui._is_jade_ble_device(selected_device): + return True + if platform.system() == "Darwin": + return False + return USBGui._is_usb_device(selected_device) + + def _with_device(self, selected_device: dict[str, Any], operation: Callable[[USBDevice], T]) -> T | None: + """ + Run one hardware-wallet operation with the required threading model. + + - macOS USB: keep calls on the caller/main thread to avoid crashes. + - Linux/Windows USB: run calls in a worker thread. + - Jade BLE: run calls in a worker thread for stable bleak behavior. + """ + + def _run_operation() -> T: with USBDevice( selected_device=selected_device, network=self.network, loop_in_thread=self.loop_in_thread, initalization_label=self.initalization_label, - ) as dev: - return run_device_task(loop_in_thread=self.loop_in_thread, task=partial(dev.sign_psbt, psbt)) + ) as device: + return operation(device) + + if self._should_run_in_worker(selected_device): + return run_device_task(self.loop_in_thread, _run_operation) + return _run_operation() + + def sign(self, psbt: bdk.Psbt, slow_hwi_listing=False) -> bdk.Psbt | None: + selected_device = self.get_device(slow_hwi_listing=slow_hwi_listing) + if not selected_device: + return None + + try: + return self._with_device(selected_device, partial(USBDevice.sign_psbt, psbt=psbt)) except Exception as e: if not self.handle_exception_sign(e): raise @@ -170,17 +222,11 @@ def get_fingerprint_and_xpubs( return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - def f(): - return (selected_device, dev.get_fingerprint(), dev.get_xpubs()) + def _collect_xpubs(device: USBDevice) -> tuple[dict[str, Any], str, dict[AddressType, str]]: + return (selected_device, device.get_fingerprint(), device.get_xpubs()) - return run_device_task(loop_in_thread=self.loop_in_thread, task=f) + return self._with_device(selected_device, _collect_xpubs) except Exception as e: if not self.handle_exception_get_fingerprint_and_xpubs(e): raise @@ -196,17 +242,11 @@ def get_fingerprint_and_xpub( return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - def f(): - return (selected_device, dev.get_fingerprint(), dev.get_xpub(key_origin)) + def _collect_xpub(device: USBDevice) -> tuple[dict[str, Any], str, str]: + return (selected_device, device.get_fingerprint(), device.get_xpub(key_origin)) - return run_device_task(loop_in_thread=self.loop_in_thread, task=f) + return self._with_device(selected_device, _collect_xpub) except Exception as e: if not self.handle_exception_get_fingerprint_and_xpubs(e): raise @@ -220,15 +260,10 @@ def sign_message(self, message: str, bip32_path: str, slow_hwi_listing=False) -> return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - return run_device_task( - loop_in_thread=self.loop_in_thread, task=partial(dev.sign_message, message, bip32_path) - ) + return self._with_device( + selected_device, + partial(USBDevice.sign_message, message=message, bip32_path=bip32_path), + ) except Exception as e: if not self.handle_exception_sign_message(e): raise @@ -242,15 +277,10 @@ def display_address(self, address_descriptor: str, slow_hwi_listing=False) -> st return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - return run_device_task( - loop_in_thread=self.loop_in_thread, task=partial(dev.display_address, address_descriptor) - ) + return self._with_device( + selected_device, + partial(USBDevice.display_address, address_descriptor=address_descriptor), + ) except Exception as e: if not self.handle_exception_display_address(e): raise @@ -264,13 +294,7 @@ def wipe_device(self, slow_hwi_listing=False) -> bool | None: return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - return run_device_task(loop_in_thread=self.loop_in_thread, task=dev.wipe_device) + return self._with_device(selected_device, USBDevice.wipe_device) except Exception as e: if not self.handle_exception_wipe(e): raise @@ -282,24 +306,23 @@ def write_down_seed(self, slow_hwi_listing=False) -> bool | None: selected_device = self.get_device(slow_hwi_listing=slow_hwi_listing) if not selected_device: return None + if str(selected_device.get("type", "")).lower() != "bitbox02": + QMessageBox.information( + None, + "Not supported", + "This is currently only supported for Bitbox02", + ) + self.signal_end_hwi_blocker.emit() + return None try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - if isinstance(dev.client, Bitbox02Client): - return run_device_task( - loop_in_thread=self.loop_in_thread, task=partial(dev.write_down_seed, dev.client) - ) - - QMessageBox.information( - None, - "Not supported", - "This is currently only supported for Bitbox02", - ) + + def _backup_seed(device: USBDevice) -> bool | None: + if not isinstance(device.client, Bitbox02Client): + return None + return device.write_down_seed(device.client) + + return self._with_device(selected_device, _backup_seed) except Exception as e: if not self.handle_exception_write_down_seed(e): raise @@ -320,15 +343,10 @@ def register_multisig(self, address_descriptor: str, slow_hwi_listing=False) -> ) try: - with USBDevice( - selected_device=selected_device, - network=self.network, - loop_in_thread=self.loop_in_thread, - initalization_label=self.initalization_label, - ) as dev: - return run_device_task( - loop_in_thread=self.loop_in_thread, task=partial(dev.display_address, address_descriptor) - ) + return self._with_device( + selected_device, + partial(USBDevice.display_address, address_descriptor=address_descriptor), + ) except Exception as e: if not self.handle_exception_display_address(e): raise diff --git a/pyproject.toml b/pyproject.toml index 1028dd2..a1cbb17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ disable_error_code = "assignment" [tool.poetry] name = "bitcoin-usb" -version = "4.0.0" +version = "4.0.1" authors = ["andreasgriffin "] license = "GPL-3.0" readme = "README.md" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..34e981b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +import os +from collections.abc import Iterator +from typing import Any + +import pytest + + +@pytest.fixture(autouse=True) +def disable_real_hardware_access(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]: + """ + Keep the default test suite hardware-independent. + + Set BITCOIN_USB_TEST_REAL_HWI=1 to run tests against real connected devices. + """ + if os.environ.get("BITCOIN_USB_TEST_REAL_HWI") == "1": + yield + return + + def fake_enumerate(allow_emulators: bool = False, chain: Any = None) -> list[dict[str, Any]]: + _ = allow_emulators + _ = chain + return [] + + def fake_get_client( + device_type: str, + device_path: str, + password: str = "", + expert: bool = False, + chain: Any = None, + **kwargs: Any, + ) -> Any: + _ = device_type + _ = device_path + _ = password + _ = expert + _ = chain + _ = kwargs + raise RuntimeError( + "Real hardware access is disabled during tests. " + "Set BITCOIN_USB_TEST_REAL_HWI=1 to enable integration behavior." + ) + + monkeypatch.setattr("hwilib.commands.enumerate", fake_enumerate) + monkeypatch.setattr("hwilib.commands.get_client", fake_get_client) + yield diff --git a/tests/test_bluetooth_scan_support.py b/tests/test_bluetooth_scan_support.py new file mode 100644 index 0000000..539a3a8 --- /dev/null +++ b/tests/test_bluetooth_scan_support.py @@ -0,0 +1,33 @@ +from bitcoin_usb import usb_gui + + +def test_can_scan_bluetooth_devices_uses_ble_operation_wrapper(monkeypatch) -> None: + calls: dict[str, bool] = {"wrapped": False, "run_called": False} + + monkeypatch.setattr(usb_gui, "is_ble_available", lambda: True) + monkeypatch.setattr(usb_gui.BleakScanner, "discover", lambda timeout: "probe") + + def fake_run_ble_operation(operation): + calls["wrapped"] = True + return operation() + + def fake_asyncio_run(_probe): + calls["run_called"] = True + return [] + + monkeypatch.setattr(usb_gui, "_run_ble_operation", fake_run_ble_operation) + monkeypatch.setattr(usb_gui.asyncio, "run", fake_asyncio_run) + + assert usb_gui.can_scan_bluetooth_devices() is True + assert calls == {"wrapped": True, "run_called": True} + + +def test_can_scan_bluetooth_devices_returns_false_on_probe_exception(monkeypatch) -> None: + monkeypatch.setattr(usb_gui, "is_ble_available", lambda: True) + + def raise_runtime_error(_operation): + raise RuntimeError("probe failed") + + monkeypatch.setattr(usb_gui, "_run_ble_operation", raise_runtime_error) + + assert usb_gui.can_scan_bluetooth_devices() is False diff --git a/tests/test_jade_ble_client.py b/tests/test_jade_ble_client.py new file mode 100644 index 0000000..4034895 --- /dev/null +++ b/tests/test_jade_ble_client.py @@ -0,0 +1,47 @@ +import asyncio + +import pytest +from hwilib.devices.jadepy.jade_error import JadeError + +from bitcoin_usb.jade_ble_client import CompatibleJadeBleImpl + + +async def _never_stream(): + while True: + await asyncio.sleep(1) + if False: + yield 0 # type: ignore + + +class _SlowGattClient: + async def write_gatt_char(self, uuid, payload, response): + _ = uuid + _ = payload + _ = response + await asyncio.sleep(1) + + +def test_read_impl_times_out() -> None: + loop = asyncio.new_event_loop() + client = CompatibleJadeBleImpl(device_name="Jade", serial_number=None, scan_timeout=1, loop=loop) + client.inputstream = _never_stream() + client.io_timeout_seconds = 0.01 + + with pytest.raises(JadeError, match="BLE operation timed out: read\\(1\\)"): + loop.run_until_complete(client._read_impl(1)) + + loop.run_until_complete(client.inputstream.aclose()) + loop.close() + + +def test_write_impl_times_out() -> None: + loop = asyncio.new_event_loop() + client = CompatibleJadeBleImpl(device_name="Jade", serial_number=None, scan_timeout=1, loop=loop) + client.client = _SlowGattClient() + client.io_timeout_seconds = 0.01 + + with pytest.raises(JadeError, match="BLE operation timed out: write"): + loop.run_until_complete(client._write_impl(b"abc")) + + assert client.write_task is None + loop.close() diff --git a/tests/test_usb_gui_jade_threading.py b/tests/test_usb_gui_jade_threading.py new file mode 100644 index 0000000..d2e21ed --- /dev/null +++ b/tests/test_usb_gui_jade_threading.py @@ -0,0 +1,133 @@ +import bdkpython as bdk + +from bitcoin_usb import usb_gui +from bitcoin_usb.usb_gui import USBGui + + +def test_get_fingerprint_and_xpubs_uses_worker_task_for_jade_ble(monkeypatch) -> None: + gui = USBGui(network=bdk.Network.REGTEST, loop_in_thread=object()) + selected_device = {"type": "jade", "transport": "bluetooth", "path": "ble:aa"} + calls: dict[str, object] = {"run_task": False} + + monkeypatch.setattr(gui, "get_device", lambda slow_hwi_listing=False: selected_device) + + class _FakeDevice: + def get_fingerprint(self) -> str: + return "f00dbabe" + + def get_xpubs(self) -> dict[str, str]: + return {} + + class _FakeContextDevice: + def __init__(self, *args, **kwargs): + _ = args + _ = kwargs + + def __enter__(self): + return _FakeDevice() + + def __exit__(self, exc_type, exc_value, traceback): + _ = exc_type + _ = exc_value + _ = traceback + + def fake_run_device_task(loop_in_thread, task): + calls["run_task"] = True + calls["loop"] = loop_in_thread + return task() + + monkeypatch.setattr(usb_gui, "run_device_task", fake_run_device_task) + monkeypatch.setattr(usb_gui, "USBDevice", _FakeContextDevice) + + result = gui.get_fingerprint_and_xpubs() + + assert result == (selected_device, "f00dbabe", {}) + assert calls == {"run_task": True, "loop": gui.loop_in_thread} + + +def test_get_fingerprint_and_xpubs_uses_worker_task_for_usb_on_linux(monkeypatch) -> None: + gui = USBGui(network=bdk.Network.REGTEST, loop_in_thread=object()) + selected_device = {"type": "trezor", "path": "usb:1"} + calls: dict[str, object] = {"run_task": False} + + def fake_get_device(slow_hwi_listing=False): + _ = slow_hwi_listing + return selected_device + + monkeypatch.setattr(gui, "get_device", fake_get_device) + monkeypatch.setattr(usb_gui.platform, "system", lambda: "Linux") + + class _FakeDevice: + def get_fingerprint(self) -> str: + return "f00dbabe" + + def get_xpubs(self) -> dict[str, str]: + return {} + + class _FakeContextDevice: + def __init__(self, *args, **kwargs): + _ = args + _ = kwargs + + def __enter__(self): + return _FakeDevice() + + def __exit__(self, exc_type, exc_value, traceback): + _ = exc_type + _ = exc_value + _ = traceback + + def fake_run_device_task(loop_in_thread, task): + calls["run_task"] = True + calls["loop"] = loop_in_thread + return task() + + monkeypatch.setattr(usb_gui, "run_device_task", fake_run_device_task) + monkeypatch.setattr(usb_gui, "USBDevice", _FakeContextDevice) + + result = gui.get_fingerprint_and_xpubs() + + assert result == (selected_device, "f00dbabe", {}) + assert calls == {"run_task": True, "loop": gui.loop_in_thread} + + +def test_get_fingerprint_and_xpubs_runs_inline_for_usb_on_macos(monkeypatch) -> None: + gui = USBGui(network=bdk.Network.REGTEST, loop_in_thread=object()) + selected_device = {"type": "trezor", "path": "usb:1"} + + def fake_get_device(slow_hwi_listing=False): + _ = slow_hwi_listing + return selected_device + + monkeypatch.setattr(gui, "get_device", fake_get_device) + monkeypatch.setattr(usb_gui.platform, "system", lambda: "Darwin") + + class _FakeDevice: + def get_fingerprint(self) -> str: + return "f00dbabe" + + def get_xpubs(self) -> dict[str, str]: + return {} + + class _FakeContextDevice: + def __init__(self, *args, **kwargs): + _ = args + _ = kwargs + + def __enter__(self): + return _FakeDevice() + + def __exit__(self, exc_type, exc_value, traceback): + _ = exc_type + _ = exc_value + _ = traceback + + def fail_run_device_task(_loop_in_thread, _task): + raise AssertionError("run_device_task must not be called for macOS USB") + + monkeypatch.setattr(usb_gui, "run_device_task", fail_run_device_task) + monkeypatch.setattr(usb_gui, "USBDevice", _FakeContextDevice) + + result = gui.get_fingerprint_and_xpubs() + + assert result == (selected_device, "f00dbabe", {}) diff --git a/tests/test_usb_gui_threading.py b/tests/test_usb_gui_threading.py new file mode 100644 index 0000000..d94922c --- /dev/null +++ b/tests/test_usb_gui_threading.py @@ -0,0 +1,40 @@ +import bdkpython as bdk + +from bitcoin_usb.device import USBDevice +from bitcoin_usb.usb_gui import USBGui + + +def test_usbdevice_enter_is_synchronous(monkeypatch) -> None: + init_calls: list[str] = [] + + def fake_init_client(self: USBDevice) -> None: + init_calls.append("called") + + monkeypatch.setattr(USBDevice, "_init_client", fake_init_client) + + device = USBDevice( + selected_device={"type": "jade", "path": "/dev/mock"}, + network=bdk.Network.REGTEST, + loop_in_thread=object(), + initalization_label="", + ) + with device as entered_device: + assert entered_device is device + + assert init_calls == ["called"] + + +def test_usbdevice_run_wrapper_removed() -> None: + assert not hasattr(USBDevice, "run") + + +def test_usbdevice_execute_with_client_removed() -> None: + assert not hasattr(USBDevice, "execute_with_client") + + +def test_usbgui_with_device_exists() -> None: + assert hasattr(USBGui, "_with_device") + + +def test_usbgui_run_with_device_wrapper_removed() -> None: + assert not hasattr(USBGui, "_run_with_device")