From d0ed805d7d8c4adec78dea9af352d4aba131743c Mon Sep 17 00:00:00 2001 From: Tom Quist Date: Sat, 12 Jul 2025 18:58:13 +0200 Subject: [PATCH] Add threaded UDP handling for Shelly emulator --- shelly/shelly.py | 76 +++++++++++++++++++++------------------ shelly/shelly_udp_test.py | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 34 deletions(-) create mode 100644 shelly/shelly_udp_test.py diff --git a/shelly/shelly.py b/shelly/shelly.py index 6bcc6eb..5fded9a 100644 --- a/shelly/shelly.py +++ b/shelly/shelly.py @@ -1,6 +1,7 @@ import socket import threading import json +from concurrent.futures import ThreadPoolExecutor from typing import List, Tuple from config import ClientFilter from powermeter import Powermeter @@ -20,6 +21,8 @@ def __init__( self._udp_thread = None self._stop = False self._value_mutex = threading.Lock() + self._executor = ThreadPoolExecutor(max_workers=5) + self._send_lock = threading.Lock() def _calculate_derived_values(self, power): decimal_point_enforcer = 0.001 @@ -76,6 +79,43 @@ def _create_em1_response(self, request_id, powers): }, } + def _handle_request(self, sock, data, addr): + request_str = data.decode() + logger.debug(f"Received UDP message: {request_str}") + logger.debug(f"From: {addr[0]}:{addr[1]}") + + try: + request = json.loads(request_str) + logger.debug(f"Parsed request: {json.dumps(request, indent=2)}") + if isinstance(request.get("params", {}).get("id"), int): + powermeter = None + for pm, client_filter in self._powermeters: + if client_filter.matches(addr[0]): + powermeter = pm + break + if powermeter is None: + logger.warning(f"No powermeter found for client {addr[0]}") + return + + powers = powermeter.get_powermeter_watts() + + if request.get("method") == "EM.GetStatus": + response = self._create_em_response(request["id"], powers) + elif request.get("method") == "EM1.GetStatus": + response = self._create_em1_response(request["id"], powers) + else: + return + + response_json = json.dumps(response, separators=(",", ":")) + logger.debug(f"Sending response: {response_json}") + response_data = response_json.encode() + with self._send_lock: + sock.sendto(response_data, addr) + except json.JSONDecodeError: + logger.error("Error: Invalid JSON") + except Exception as e: + logger.error(f"Error processing message: {e}") + def udp_server(self): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(("", self._udp_port)) @@ -84,40 +124,7 @@ def udp_server(self): try: while not self._stop: data, addr = sock.recvfrom(1024) - request_str = data.decode() - logger.debug(f"Received UDP message: {request_str}") - logger.debug(f"From: {addr[0]}:{addr[1]}") - - try: - request = json.loads(request_str) - logger.debug(f"Parsed request: {json.dumps(request, indent=2)}") - if isinstance(request.get("params", {}).get("id"), int): - powermeter = None - for pm, client_filter in self._powermeters: - if client_filter.matches(addr[0]): - powermeter = pm - break - if powermeter is None: - logger.warning(f"No powermeter found for client {addr[0]}") - continue - - powers = powermeter.get_powermeter_watts() - - if request.get("method") == "EM.GetStatus": - response = self._create_em_response(request["id"], powers) - elif request.get("method") == "EM1.GetStatus": - response = self._create_em1_response(request["id"], powers) - else: - continue - - response_json = json.dumps(response, separators=(",", ":")) - logger.debug(f"Sending response: {response_json}") - response_data = response_json.encode() - sock.sendto(response_data, addr) - except json.JSONDecodeError: - logger.error(f"Error: Invalid JSON") - except Exception as e: - logger.error(f"Error processing message: {e}") + self._executor.submit(self._handle_request, sock, data, addr) finally: sock.close() @@ -138,3 +145,4 @@ def stop(self): if self._udp_thread: self._udp_thread.join() self._udp_thread = None + self._executor.shutdown(wait=True) diff --git a/shelly/shelly_udp_test.py b/shelly/shelly_udp_test.py new file mode 100644 index 0000000..ca1b26b --- /dev/null +++ b/shelly/shelly_udp_test.py @@ -0,0 +1,65 @@ +import unittest +import socket +import json +import threading +import time +from ipaddress import IPv4Network + +from config import ClientFilter +from powermeter import Powermeter, ThrottledPowermeter +from shelly.shelly import Shelly + + +class DummyPowermeter(Powermeter): + def get_powermeter_watts(self): + return [1.0] + + +class TestShellyUDP(unittest.TestCase): + def test_multiple_requests_with_throttling(self): + pm = ThrottledPowermeter(DummyPowermeter(), throttle_interval=0.2) + cf = ClientFilter([IPv4Network("127.0.0.1/32")]) + + tmp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + tmp.bind(("", 0)) + port = tmp.getsockname()[1] + tmp.close() + + shelly = Shelly([(pm, cf)], udp_port=port, device_id="test") + shelly.start() + try: + client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + responses = [] + + def send_req(i): + req = { + "id": i, + "src": "cli", + "method": "EM.GetStatus", + "params": {"id": 0}, + } + client.sendto(json.dumps(req).encode(), ("127.0.0.1", port)) + data, _ = client.recvfrom(1024) + responses.append(json.loads(data.decode())["id"]) + + threads = [] + start = time.time() + for i in range(3): + t = threading.Thread(target=send_req, args=(i,)) + t.start() + threads.append(t) + for t in threads: + t.join() + duration = time.time() - start + self.assertEqual(sorted(responses), [0, 1, 2]) + self.assertLess(duration, 0.6) + finally: + client.close() + # send dummy packet to unblock server if waiting on recv + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.sendto(b"{}", ("127.0.0.1", port)) + shelly.stop() + + +if __name__ == "__main__": + unittest.main()