diff --git a/oocana/oocana/data.py b/oocana/oocana/data.py index 75834465..0b5560ba 100644 --- a/oocana/oocana/data.py +++ b/oocana/oocana/data.py @@ -1,8 +1,19 @@ -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import TypedDict, Literal +from simplejson import JSONEncoder +import simplejson as json EXECUTOR_NAME = "python" +def dumps(obj, **kwargs): + return json.dumps(obj, cls=DataclassJSONEncoder, ignore_nan=True, **kwargs) + +class DataclassJSONEncoder(JSONEncoder): + def default(self, o): # pyright: ignore[reportIncompatibleMethodOverride] + if hasattr(o, '__dataclass_fields__'): + return asdict(o) + return JSONEncoder.default(self, o) + class BinValueDict(TypedDict): value: str __OOMOL_TYPE__: Literal["oomol/bin"] diff --git a/oocana/oocana/mainframe.py b/oocana/oocana/mainframe.py index 2b2537a2..e99ec538 100644 --- a/oocana/oocana/mainframe.py +++ b/oocana/oocana/mainframe.py @@ -1,10 +1,10 @@ -import simplejson as json +from simplejson import loads import paho.mqtt.client as mqtt from paho.mqtt.enums import CallbackAPIVersion import operator from urllib.parse import urlparse import uuid -from .data import BlockDict, JobDict +from .data import BlockDict, JobDict, dumps import logging from typing import Optional @@ -63,19 +63,19 @@ def on_disconnect(self, client, userdata, flags, reason_code, properties): # 不等待 publish 完成,使用 qos 参数来会保证消息到达。 def send(self, job_info: JobDict, msg) -> mqtt.MQTTMessageInfo: return self.client.publish( - f'session/{job_info["session_id"]}', json.dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}, ignore_nan=True), qos=1 + f'session/{job_info["session_id"]}', dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}), qos=1 ) def report(self, block_info: BlockDict, msg: dict) -> mqtt.MQTTMessageInfo: - return self.client.publish("report", json.dumps({**block_info, **msg}, ignore_nan=True), qos=1) + return self.client.publish("report", dumps({**block_info, **msg}), qos=1) def notify_executor_ready(self, session_id: str, executor_name: str, package: str | None) -> None: - self.client.publish(f"session/{session_id}", json.dumps({ + self.client.publish(f"session/{session_id}", dumps({ "type": "ExecutorReady", "session_id": session_id, "executor_name": executor_name, "package": package, - }, ignore_nan=True), qos=1) + }), qos=1) def notify_block_ready(self, session_id: str, job_id: str) -> dict: @@ -85,16 +85,16 @@ def notify_block_ready(self, session_id: str, job_id: str) -> dict: def on_message_once(_client, _userdata, message): nonlocal replay self.client.unsubscribe(topic) - replay = json.loads(message.payload) + replay = loads(message.payload) self.client.subscribe(topic, qos=1) self.client.message_callback_add(topic, on_message_once) - self.client.publish(f"session/{session_id}", json.dumps({ + self.client.publish(f"session/{session_id}", dumps({ "type": "BlockReady", "session_id": session_id, "job_id": job_id, - }, ignore_nan=True), qos=1) + }), qos=1) while True: if replay is not None: @@ -102,12 +102,12 @@ def on_message_once(_client, _userdata, message): return replay def publish(self, topic, payload): - self.client.publish(topic, json.dumps(payload, ignore_nan=True), qos=1) + self.client.publish(topic, dumps(payload), qos=1) def subscribe(self, topic: str, callback): def on_message(_client, _userdata, message): logger.info("receive topic: {} payload: {}".format(topic, message.payload)) - payload = json.loads(message.payload) + payload = loads(message.payload) callback(payload) self.client.message_callback_add(topic, on_message) diff --git a/oocana/tests/test_data.py b/oocana/tests/test_data.py index d4aac0f9..3f4b077a 100644 --- a/oocana/tests/test_data.py +++ b/oocana/tests/test_data.py @@ -50,3 +50,22 @@ def test_block_info_extra(self): self.assertEqual(block_info.job_id, "job_id_one") self.assertEqual(block_info.stacks, ["stack1", "stack2"]) self.assertEqual(block_info.block_path, "block_path_one") + + def test_dataclass_dumps(self): + block_info_dict = { + "session_id": "session_id_one", + "job_id": "job_id_one", + "stacks": ["stack1", "stack2"], + "block_path": "block_path_one", + "extra": "extra" + } + + block_info = data.BlockInfo(**block_info_dict) + serialize_block_info = data.dumps(block_info) + self.assertEqual(serialize_block_info, '{"session_id": "session_id_one", "job_id": "job_id_one", "stacks": ["stack1", "stack2"], "block_path": "block_path_one"}') + + list_serialize_block_info = data.dumps([block_info]) + self.assertEqual(list_serialize_block_info, '[{"session_id": "session_id_one", "job_id": "job_id_one", "stacks": ["stack1", "stack2"], "block_path": "block_path_one"}]') + + key_serialize_block_info = data.dumps({"key": block_info}) + self.assertEqual(key_serialize_block_info, '{"key": {"session_id": "session_id_one", "job_id": "job_id_one", "stacks": ["stack1", "stack2"], "block_path": "block_path_one"}}') \ No newline at end of file