diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 0f554c3..2300d63 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -7,6 +7,7 @@ import random from abc import ABC from pathlib import Path +from warnings import warn from typing import TYPE_CHECKING, Dict import numpy as np @@ -295,11 +296,16 @@ def __init__( # noqa: PLR0915, PLR0912 or isinstance(self.env.unwrapped, CARLENV) ): self.result_buffer["instances"] = [] + + def default_serialization(obj): + warn(f"'{type(obj)}' not serializable", UserWarning) + return str(obj) + with open(Path(self.output_dir) / "instance_set.json", "w+") as f: - json.dump(self.env.instance_set, f) + json.dump(self.env.instance_set, f, default=default_serialization) with open(Path(self.output_dir) / "test_set.json", "w+") as f: - json.dump(self.eval_env.instance_set, f) + json.dump(self.eval_env.instance_set, f, default=default_serialization) self.eval_buffer = { "eval_after_n_steps": [],