Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 42 additions & 34 deletions shelly/shelly.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import socket
import threading
import json
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple
from config import ClientFilter
from powermeter import Powermeter
Expand All @@ -20,6 +21,8 @@ def __init__(
self._udp_thread = None
self._stop = False
self._value_mutex = threading.Lock()
self._executor = ThreadPoolExecutor(max_workers=5)
self._send_lock = threading.Lock()

def _calculate_derived_values(self, power):
decimal_point_enforcer = 0.001
Expand Down Expand Up @@ -76,6 +79,43 @@ def _create_em1_response(self, request_id, powers):
},
}

def _handle_request(self, sock, data, addr):
request_str = data.decode()
logger.debug(f"Received UDP message: {request_str}")
logger.debug(f"From: {addr[0]}:{addr[1]}")

try:
request = json.loads(request_str)
logger.debug(f"Parsed request: {json.dumps(request, indent=2)}")
if isinstance(request.get("params", {}).get("id"), int):
powermeter = None
for pm, client_filter in self._powermeters:
if client_filter.matches(addr[0]):
powermeter = pm
break
if powermeter is None:
logger.warning(f"No powermeter found for client {addr[0]}")
return

powers = powermeter.get_powermeter_watts()

if request.get("method") == "EM.GetStatus":
response = self._create_em_response(request["id"], powers)
elif request.get("method") == "EM1.GetStatus":
response = self._create_em1_response(request["id"], powers)
else:
return

response_json = json.dumps(response, separators=(",", ":"))
logger.debug(f"Sending response: {response_json}")
response_data = response_json.encode()
with self._send_lock:
sock.sendto(response_data, addr)
except json.JSONDecodeError:
logger.error("Error: Invalid JSON")
except Exception as e:
logger.error(f"Error processing message: {e}")

def udp_server(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(("", self._udp_port))
Expand All @@ -84,40 +124,7 @@ def udp_server(self):
try:
while not self._stop:
data, addr = sock.recvfrom(1024)
request_str = data.decode()
logger.debug(f"Received UDP message: {request_str}")
logger.debug(f"From: {addr[0]}:{addr[1]}")

try:
request = json.loads(request_str)
logger.debug(f"Parsed request: {json.dumps(request, indent=2)}")
if isinstance(request.get("params", {}).get("id"), int):
powermeter = None
for pm, client_filter in self._powermeters:
if client_filter.matches(addr[0]):
powermeter = pm
break
if powermeter is None:
logger.warning(f"No powermeter found for client {addr[0]}")
continue

powers = powermeter.get_powermeter_watts()

if request.get("method") == "EM.GetStatus":
response = self._create_em_response(request["id"], powers)
elif request.get("method") == "EM1.GetStatus":
response = self._create_em1_response(request["id"], powers)
else:
continue

response_json = json.dumps(response, separators=(",", ":"))
logger.debug(f"Sending response: {response_json}")
response_data = response_json.encode()
sock.sendto(response_data, addr)
except json.JSONDecodeError:
logger.error(f"Error: Invalid JSON")
except Exception as e:
logger.error(f"Error processing message: {e}")
self._executor.submit(self._handle_request, sock, data, addr)

finally:
sock.close()
Expand All @@ -138,3 +145,4 @@ def stop(self):
if self._udp_thread:
self._udp_thread.join()
self._udp_thread = None
self._executor.shutdown(wait=True)
65 changes: 65 additions & 0 deletions shelly/shelly_udp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
import socket
import json
import threading
import time
from ipaddress import IPv4Network

from config import ClientFilter
from powermeter import Powermeter, ThrottledPowermeter
from shelly.shelly import Shelly


class DummyPowermeter(Powermeter):
def get_powermeter_watts(self):
return [1.0]


class TestShellyUDP(unittest.TestCase):
def test_multiple_requests_with_throttling(self):
pm = ThrottledPowermeter(DummyPowermeter(), throttle_interval=0.2)
cf = ClientFilter([IPv4Network("127.0.0.1/32")])

tmp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
tmp.bind(("", 0))
port = tmp.getsockname()[1]
tmp.close()

shelly = Shelly([(pm, cf)], udp_port=port, device_id="test")
shelly.start()
try:
client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
responses = []

def send_req(i):
req = {
"id": i,
"src": "cli",
"method": "EM.GetStatus",
"params": {"id": 0},
}
client.sendto(json.dumps(req).encode(), ("127.0.0.1", port))
data, _ = client.recvfrom(1024)
responses.append(json.loads(data.decode())["id"])

threads = []
start = time.time()
for i in range(3):
t = threading.Thread(target=send_req, args=(i,))
t.start()
threads.append(t)
for t in threads:
t.join()
duration = time.time() - start
self.assertEqual(sorted(responses), [0, 1, 2])
self.assertLess(duration, 0.6)
finally:
client.close()
# send dummy packet to unblock server if waiting on recv
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.sendto(b"{}", ("127.0.0.1", port))
shelly.stop()


if __name__ == "__main__":
unittest.main()