Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion oocana/oocana/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import TypedDict, Literal
from simplejson import JSONEncoder
import simplejson as json

EXECUTOR_NAME = "python"

def dumps(obj, **kwargs):
    """Serialize *obj* to a JSON string.

    Dataclass instances (at any nesting depth) are converted to plain dicts
    via DataclassJSONEncoder, and NaN/Infinity are emitted as null instead of
    invalid JSON (simplejson's ``ignore_nan=True``). Extra keyword arguments
    are forwarded to ``json.dumps``.
    """
    return json.dumps(
        obj,
        cls=DataclassJSONEncoder,
        ignore_nan=True,
        **kwargs,
    )

class DataclassJSONEncoder(JSONEncoder):
    """JSON encoder that knows how to serialize dataclass instances.

    Any object exposing ``__dataclass_fields__`` is converted to a plain dict
    with :func:`dataclasses.asdict` (recursively); everything else falls back
    to the base encoder, which raises TypeError for unknown types.
    """

    def default(self, o):  # pyright: ignore[reportIncompatibleMethodOverride]
        # Dataclasses (and only dataclasses) carry __dataclass_fields__.
        if not hasattr(o, "__dataclass_fields__"):
            return super().default(o)
        return asdict(o)

class BinValueDict(TypedDict):
value: str
__OOMOL_TYPE__: Literal["oomol/bin"]
Expand Down
22 changes: 11 additions & 11 deletions oocana/oocana/mainframe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import simplejson as json
from simplejson import loads
import paho.mqtt.client as mqtt
from paho.mqtt.enums import CallbackAPIVersion
import operator
from urllib.parse import urlparse
import uuid
from .data import BlockDict, JobDict
from .data import BlockDict, JobDict, dumps
import logging
from typing import Optional

Expand Down Expand Up @@ -63,19 +63,19 @@ def on_disconnect(self, client, userdata, flags, reason_code, properties):
# 不等待 publish 完成,使用 qos 参数来会保证消息到达。
def send(self, job_info: JobDict, msg) -> mqtt.MQTTMessageInfo:
return self.client.publish(
f'session/{job_info["session_id"]}', json.dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}, ignore_nan=True), qos=1
f'session/{job_info["session_id"]}', dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}), qos=1
)

def report(self, block_info: BlockDict, msg: dict) -> mqtt.MQTTMessageInfo:
return self.client.publish("report", json.dumps({**block_info, **msg}, ignore_nan=True), qos=1)
return self.client.publish("report", dumps({**block_info, **msg}), qos=1)

def notify_executor_ready(self, session_id: str, executor_name: str, package: str | None) -> None:
self.client.publish(f"session/{session_id}", json.dumps({
self.client.publish(f"session/{session_id}", dumps({
"type": "ExecutorReady",
"session_id": session_id,
"executor_name": executor_name,
"package": package,
}, ignore_nan=True), qos=1)
}), qos=1)

def notify_block_ready(self, session_id: str, job_id: str) -> dict:

Expand All @@ -85,29 +85,29 @@ def notify_block_ready(self, session_id: str, job_id: str) -> dict:
def on_message_once(_client, _userdata, message):
nonlocal replay
self.client.unsubscribe(topic)
replay = json.loads(message.payload)
replay = loads(message.payload)

self.client.subscribe(topic, qos=1)
self.client.message_callback_add(topic, on_message_once)

self.client.publish(f"session/{session_id}", json.dumps({
self.client.publish(f"session/{session_id}", dumps({
"type": "BlockReady",
"session_id": session_id,
"job_id": job_id,
}, ignore_nan=True), qos=1)
}), qos=1)

while True:
if replay is not None:
logger.info("notify ready success in {} {}".format(session_id, job_id))
return replay

def publish(self, topic, payload):
self.client.publish(topic, json.dumps(payload, ignore_nan=True), qos=1)
self.client.publish(topic, dumps(payload), qos=1)

def subscribe(self, topic: str, callback):
def on_message(_client, _userdata, message):
logger.info("receive topic: {} payload: {}".format(topic, message.payload))
payload = json.loads(message.payload)
payload = loads(message.payload)
callback(payload)

self.client.message_callback_add(topic, on_message)
Expand Down
19 changes: 19 additions & 0 deletions oocana/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,22 @@ def test_block_info_extra(self):
self.assertEqual(block_info.job_id, "job_id_one")
self.assertEqual(block_info.stacks, ["stack1", "stack2"])
self.assertEqual(block_info.block_path, "block_path_one")

def test_dataclass_dumps(self):
    """dumps() serializes a dataclass — bare, in a list, and as a dict value —
    to the same plain JSON object, dropping the non-field `extra` kwarg."""
    block_info = data.BlockInfo(
        session_id="session_id_one",
        job_id="job_id_one",
        stacks=["stack1", "stack2"],
        block_path="block_path_one",
        extra="extra",
    )
    expected = (
        '{"session_id": "session_id_one", "job_id": "job_id_one", '
        '"stacks": ["stack1", "stack2"], "block_path": "block_path_one"}'
    )

    self.assertEqual(data.dumps(block_info), expected)
    self.assertEqual(data.dumps([block_info]), "[" + expected + "]")
    self.assertEqual(data.dumps({"key": block_info}), '{"key": ' + expected + "}")
Loading