Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# Copyright © 2025 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE

from typing import Optional, Union, TYPE_CHECKING, Sequence
from typing import Optional, Union, TYPE_CHECKING, Sequence, get_origin, get_args
import dataclasses
from dataclasses import field
from dataclasses import field, fields
import logging
from math import ceil
import enum

from numpy import ndarray
from torch import Tensor
import torch
import numpy as np

import ptychi.api.enums as enums
import ptychi.utils as utils

if TYPE_CHECKING:
import ptychi.api.options.task as task_options
Expand All @@ -36,6 +38,49 @@ def check(self, *args, **kwargs) -> None:
"""Check if options values are valid.
"""
return

def resolve_type(self, ann_type) -> type:
"""Resolve annotation to underlying type (handles Optional, etc.)."""
origin = get_origin(ann_type)
if origin is Union:
args = get_args(ann_type)
# Drop NoneType from Optional[...]
return next((arg for arg in args if arg is not type(None)), None)
return ann_type

def get_non_data_fields(self) -> dict:
"""Get fields that do not contain large arrays or tensors."""
d = self.__dict__.copy()
return d

def get_dict(self) -> dict:
"""Get a dictionary representation of the options."""
d = self.get_non_data_fields()
for k, v in d.items():
if isinstance(v, Options):
d[k] = v.get_dict()
else:
d[k] = utils.jsonize(v)
return d

def load_from_dict(self, d: dict) -> "Options":
"""Load options from a dictionary."""
for k, v in d.items():
field_type = self.resolve_type(self.get_field_type(k))
if isinstance(field_type, type) and issubclass(field_type, Options):
self.__setattr__(k, self.resolve_type(self.get_field_type(k))().load_from_dict(v))
elif isinstance(field_type, type) and issubclass(field_type, enum.StrEnum) and isinstance(v, str):
self.__setattr__(k, field_type(v))
else:
self.__setattr__(k, v)
return self

def get_field_type(self, name: str) -> type:
"""Get the type of a field."""
for f in fields(self):
if f.name == name:
return f.type
raise ValueError(f"Field {name} not found in {self.__class__.__name__}.")


@dataclasses.dataclass
Expand Down Expand Up @@ -110,10 +155,6 @@ class ParameterOptions(Options):
def check(self, options: "task_options.PtychographyTaskOptions"):
return super().check(options)

def get_non_data_fields(self) -> dict:
d = self.__dict__.copy()
return d


@dataclasses.dataclass
class FeatureOptions(Options):
Expand Down
6 changes: 6 additions & 0 deletions src/ptychi/api/options/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ class PtychographyDataOptions(base.Options):

save_data_on_device: bool = False
"""Whether to save the diffraction data on acceleration devices like GPU."""

def get_non_data_fields(self) -> dict:
d = self.__dict__.copy()
del d['data']
del d['valid_pixel_mask']
return d
12 changes: 12 additions & 0 deletions src/ptychi/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ def copy_data_from_task(
)
else:
raise ValueError(f"Invalid parameter name: {param}")

def get_options_as_dict(self) -> dict:
return self.options.get_dict()

def load_options_from_dict(self, d: dict) -> None:
self.options.load_from_dict(d)
self.data_options = self.options.data_options
self.object_options = self.options.object_options
self.probe_options = self.options.probe_options
self.position_options = self.options.probe_position_options
self.opr_mode_weight_options = self.options.opr_mode_weight_options
self.reconstructor_options = self.options.reconstructor_options

def __exit__(self, exc_type, exc_value, exc_tb):
del self.object
Expand Down
3 changes: 0 additions & 3 deletions src/ptychi/data_structures/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,6 @@ def optimization_enabled(self, epoch: int):
enabled = False
logger.debug(f"{self.name} optimization enabled at epoch {epoch}: {enabled}")
return enabled

def get_config_dict(self):
return self.options.get_non_data_fields()

def step_optimizer(self, limit: float = None):
"""Step the optimizer with gradient filled in. This function
Expand Down
3 changes: 0 additions & 3 deletions src/ptychi/data_structures/parameter_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ def get_optimizable_parameters(self) -> list["dsbase.ReconstructParameter"]:
ovs.append(var)
return ovs

def get_config_dict(self):
return {var.name: var.get_config_dict() for var in self.get_all_parameters()}


@dataclasses.dataclass
class PtychographyParameterGroup(ParameterGroup):
Expand Down
9 changes: 0 additions & 9 deletions src/ptychi/reconstructors/ad_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,3 @@ def get_forward_model(self) -> "fm.ForwardModel":
else:
return self.forward_model

def get_config_dict(self) -> dict:
d = super().get_config_dict()
d.update(
{
"forward_model_class": str(self.forward_model_class),
"loss_function": str(self.loss_function),
}
)
return d
12 changes: 0 additions & 12 deletions src/ptychi/reconstructors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,6 @@ def get_option_class(self):
except KeyError:
return api.options.base.ReconstructorOptions

def get_config_dict(self) -> dict:
d = self.parameter_group.get_config_dict()
reconstructor_options = {"name": self.__class__.__name__}
reconstructor_options.update(self.options.__dict__)
d["reconstructor_options"] = reconstructor_options
return d


class PtychographyReconstructor(Reconstructor):
parameter_group: "pg.PtychographyParameterGroup"
Expand Down Expand Up @@ -264,11 +257,6 @@ def build_counter(self):
leave=False
)
self.current_epoch = 0

def get_config_dict(self) -> dict:
d = super().get_config_dict()
d.update({"batch_size": self.batch_size, "n_epochs": self.n_epochs})
return d

def prepare_batch_data(self, batch_data: Sequence[Tensor]) -> Tuple[Sequence[Tensor], Tensor]:
input_data = batch_data[:-1]
Expand Down
5 changes: 0 additions & 5 deletions src/ptychi/reconstructors/lsqml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,8 +1201,3 @@ def run_minibatch(self, input_data, y_true, *args, **kwargs) -> None:
self.run_real_space_step(psi_opt, indices)

self.loss_tracker.update_batch_loss_with_metric_function(y_pred, y_true)

def get_config_dict(self) -> dict:
d = super().get_config_dict()
d.update({"noise_model": self.noise_model.noise_statistics})
return d
14 changes: 14 additions & 0 deletions src/ptychi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,17 @@ def clear_memory(task: Optional["PtychographyTask"] = None):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


def jsonize(val):
"""Convert a value to a JSON-serializable object."""
if isinstance(val, np.generic):
return val.item()
elif isinstance(val, np.ndarray):
return val.tolist()
elif isinstance(val, torch.Tensor):
return val.tolist()
elif isinstance(val, (list, tuple, dict, str, int, float, bool, type(None))):
return val
else:
raise TypeError(f"Object of type {type(val).__name__} is not JSON serializable")