-
Notifications
You must be signed in to change notification settings - Fork 0
update flepimop2 provider package #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4d3c081
c30bdb5
1d38f47
bd49d88
d1096cc
b0e8ec9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,21 +1,183 @@ | ||
| """flepimop2 engine integration for op_engine. | ||
| """flepimop2 Engine integration for op_engine (thin, single-file).""" | ||
|
|
||
| This package intentionally defines the public Engine class in this module so that | ||
| flepimop2's dynamic loader can auto-inject a default `build()` function. | ||
| from __future__ import annotations | ||
|
|
||
| Why: | ||
| - flepimop2 resolves `module: op_engine` to `flepimop2.engine.op_engine` | ||
| - if that module has no `build`, it looks for a pydantic BaseModel subclass | ||
| defined *in this module* and generates `build()` automatically. | ||
| """ | ||
| from dataclasses import replace | ||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| from __future__ import annotations | ||
| import numpy as np | ||
| from flepimop2.configuration import IdentifierString, ModuleModel | ||
| from flepimop2.engine.abc import EngineABC | ||
| from flepimop2.exceptions import ValidationIssue | ||
| from flepimop2.typing import StateChangeEnum # noqa: TC002 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of |
||
| from pydantic import Field | ||
|
|
||
| from op_engine.core_solver import ( | ||
| CoreSolver, | ||
| ) | ||
| from op_engine.model_core import ModelCore, ModelCoreOptions | ||
|
|
||
| from .config import OpEngineEngineConfig, _coerce_operator_specs, _has_operator_specs | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable | ||
|
|
||
| from flepimop2.system.abc import SystemABC, SystemProtocol | ||
|
|
||
|
|
||
| def _as_float64_1d(x: object, *, name: str) -> np.ndarray: | ||
| arr = np.asarray(x, dtype=np.float64) | ||
| if arr.ndim != 1: | ||
| msg = f"{name} must be a 1D array" | ||
| raise ValueError(msg) | ||
| return np.ascontiguousarray(arr) | ||
MacdonaldJoshuaCaleb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _ensure_strictly_increasing(times: np.ndarray, *, name: str) -> None: | ||
| if times.size <= 1: | ||
| return | ||
| if np.any(np.diff(times) <= 0.0): | ||
| msg = f"{name} must be strictly increasing" | ||
| raise ValueError(msg) | ||
MacdonaldJoshuaCaleb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _rhs_from_stepper( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. work for another day: if adapter is necessary, we should probably be working the pipeline definition a bit. i imagine generally engines may need to adapt SystemABC somewhat, but all this re-shaping seems problematic.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This actually seems fine to me, and within the scope of engines. This is taking a stepper function and then vectorizing it because that's the type of "stepper" that
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm mostly bothered by the shaping related checks than the reshaping itself. I should have said if this amount of adaptor required
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, I think this is a consequence of the outputs being unstructured numpy arrays. To address this would require adding more structure on that front. Perhaps ACCIDDA/flepimop2#147 is a first step to that, but definitely more work required beyond that. |
||
| stepper: SystemProtocol, | ||
| *, | ||
| params: dict[IdentifierString, object], | ||
| n_state: int, | ||
| ) -> Callable[[float, np.ndarray], np.ndarray]: | ||
| def rhs(time: float, state: np.ndarray) -> np.ndarray: | ||
| state_arr = np.asarray(state, dtype=np.float64) | ||
| if state_arr.shape == (n_state,): | ||
| state_arr = state_arr.reshape((n_state, 1)) | ||
| expected_shape = (n_state, 1) | ||
| if state_arr.shape != expected_shape: | ||
| msg = ( | ||
| f"RHS received unexpected state shape {state_arr.shape}; " | ||
| f"expected {expected_shape}." | ||
| ) | ||
| raise ValueError(msg) | ||
| out = np.asarray( | ||
| stepper(np.float64(time), state_arr[:, 0], **params), dtype=np.float64 | ||
| ) | ||
| if out.shape != (n_state,): | ||
| msg = f"Stepper returned shape {out.shape}; expected {(n_state,)}." | ||
| raise ValueError(msg) | ||
| return out.reshape(expected_shape) | ||
|
|
||
| return rhs | ||
|
|
||
|
|
||
| def _extract_states_2d(core: ModelCore, *, n_state: int) -> np.ndarray: | ||
| state_array = getattr(core, "state_array", None) | ||
| if state_array is None: | ||
| msg = "ModelCore does not expose state_array; store_history must be enabled." | ||
| raise RuntimeError(msg) | ||
| arr = np.asarray(state_array, dtype=np.float64) | ||
| if arr.ndim == 3 and arr.shape[1] == n_state and arr.shape[2] == 1: | ||
| return arr[:, :, 0] | ||
| if arr.ndim == 2 and arr.shape[1] == n_state: | ||
| return arr | ||
| msg = ( | ||
| f"Unexpected state shape {arr.shape}; " | ||
| f"expected (T, {n_state}, 1) or (T, {n_state})." | ||
| ) | ||
| raise RuntimeError(msg) | ||
|
|
||
|
|
||
| def _make_core(times: np.ndarray, y0: np.ndarray) -> ModelCore: | ||
| n_states = int(y0.size) | ||
| core = ModelCore( | ||
| n_states, | ||
| 1, | ||
| np.asarray(times, dtype=np.float64), | ||
| options=ModelCoreOptions(other_axes=(), store_history=True, dtype=np.float64), | ||
| ) | ||
| core.set_initial_state(y0.reshape(n_states, 1)) | ||
| return core | ||
|
|
||
|
|
||
| class OpEngineFlepimop2Engine(ModuleModel, EngineABC): | ||
| """flepimop2 engine adapter backed by op_engine.CoreSolver.""" | ||
|
|
||
| module: Literal["flepimop2.engine.op_engine"] = "flepimop2.engine.op_engine" | ||
| state_change: StateChangeEnum | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'm a bit confused about this one. It seems like this seems to indicate that |
||
| config: OpEngineEngineConfig = Field(default_factory=OpEngineEngineConfig) | ||
|
|
||
| def validate_system(self, system: SystemABC) -> list[ValidationIssue] | None: | ||
| """Validate system compatibility against the engine state-change mode.""" | ||
| if system.state_change != self.state_change: | ||
| return [ | ||
| ValidationIssue( | ||
| msg=( | ||
| f"Engine state change type, '{self.state_change}', is not " | ||
| "compatible with system state change type " | ||
| f"'{system.state_change}'." | ||
| ), | ||
| kind="incompatible_system", | ||
| ) | ||
| ] | ||
| return None | ||
|
|
||
| def run( | ||
| self, | ||
| system: SystemABC, | ||
| eval_times: np.ndarray, | ||
| initial_state: np.ndarray, | ||
| params: dict[IdentifierString, object], | ||
| **kwargs: object, | ||
| ) -> np.ndarray: | ||
| """Execute simulation using op_engine and return `(time, state...)` output.""" | ||
| del kwargs | ||
|
|
||
| times = _as_float64_1d(eval_times, name="eval_times") | ||
| _ensure_strictly_increasing(times, name="eval_times") | ||
| y0 = _as_float64_1d(initial_state, name="initial_state") | ||
| n_state = int(y0.size) | ||
|
|
||
| run_cfg = self.config.to_run_config() | ||
| is_imex = run_cfg.method.startswith("imex-") | ||
| operators = run_cfg.operators | ||
|
|
||
| if is_imex and not _has_operator_specs(operators): | ||
| operators = ( | ||
| _coerce_operator_specs(system.option("operators", None)) or operators | ||
| ) | ||
| run_cfg = replace(run_cfg, operators=operators) | ||
|
|
||
| if is_imex and not _has_operator_specs(operators): | ||
| msg = ( | ||
| f"IMEX method '{run_cfg.method}' requires operators from engine config " | ||
| "or system option 'operators'." | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| operator_axis = self.config.operator_axis | ||
| if operator_axis == "state": | ||
| system_axis = system.option("operator_axis", None) | ||
| if isinstance(system_axis, str | int): | ||
| operator_axis = system_axis | ||
|
|
||
| stepper: SystemProtocol = system._stepper # noqa: SLF001 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another argument for pushing ACCIDDA/flepimop2#130 forward in the workplan |
||
|
|
||
| from .engine import _OpEngineFlepimop2EngineImpl | ||
| mixing_kernels = system.option("mixing_kernels", None) | ||
| merged_params = { | ||
| **(mixing_kernels if isinstance(mixing_kernels, dict) else {}), | ||
| **params, | ||
| } | ||
| rhs = _rhs_from_stepper(stepper, params=merged_params, n_state=n_state) | ||
| core = _make_core(times, y0) | ||
|
|
||
| solver = CoreSolver( | ||
| core, | ||
| operators=operators.default if is_imex else None, | ||
| operator_axis=operator_axis, | ||
| ) | ||
| solver.run(rhs, config=run_cfg) | ||
|
|
||
| class OpEngineFlepimop2Engine(_OpEngineFlepimop2EngineImpl): # noqa: RUF067 | ||
| """Public op_engine-backed flepimop2 Engine (default-build enabled).""" | ||
| states = _extract_states_2d(core, n_state=n_state) | ||
| return np.asarray(np.column_stack((times, states)), dtype=np.float64) | ||
|
|
||
|
|
||
| __all__ = ["OpEngineFlepimop2Engine"] | ||
| __all__ = ["OpEngineEngineConfig", "OpEngineFlepimop2Engine"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| """Connector between op_engine and flepimop2 config structure.""" | ||
| """Configuration model for op_engine provider integration.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
|
|
@@ -13,118 +13,82 @@ | |
| RunConfig, | ||
| ) | ||
|
|
||
| MethodName = Literal[ | ||
| "euler", | ||
| "heun", | ||
| "imex-euler", | ||
| "imex-heun-tr", | ||
| "imex-trbdf2", | ||
| ] | ||
|
|
||
| def _has_operator_specs(specs: OperatorSpecs | None) -> bool: | ||
| if specs is None: | ||
| return False | ||
| return any(getattr(specs, key) is not None for key in ("default", "tr", "bdf2")) | ||
|
|
||
|
|
||
| def _coerce_operator_specs(specs: object) -> OperatorSpecs | None: | ||
| if isinstance(specs, OperatorSpecs): | ||
| return specs | ||
| if isinstance(specs, dict): | ||
| return OperatorSpecs( | ||
| default=specs.get("default"), | ||
| tr=specs.get("tr"), | ||
| bdf2=specs.get("bdf2"), | ||
| ) | ||
| return None | ||
|
|
||
|
|
||
| class OpEngineEngineConfig(BaseModel): | ||
| """Configuration schema for op_engine when used as a flepimop2 engine.""" | ||
|
|
||
| model_config = ConfigDict(extra="allow") | ||
|
|
||
| method: MethodName = Field(default="heun", description="Time integration method") | ||
| adaptive: bool = Field( | ||
| default=False, | ||
| description="Enable adaptive substepping between output times", | ||
| ) | ||
| strict: bool = Field( | ||
| default=True, description="Fail fast on invalid configurations" | ||
| method: Literal["euler", "heun", "imex-euler", "imex-heun-tr", "imex-trbdf2"] = ( | ||
| "heun" | ||
| ) | ||
|
|
||
| # tolerances | ||
| adaptive: bool = False | ||
| strict: bool = True | ||
| rtol: float = Field(default=1e-6, ge=0.0) | ||
| atol: float = Field(default=1e-9, ge=0.0) | ||
|
|
||
| # controller | ||
| dt_min: float = Field(default=0.0, ge=0.0) | ||
| dt_max: float = Field(default=float("inf"), gt=0.0) | ||
| safety: float = Field(default=0.9, gt=0.0) | ||
| fac_min: float = Field(default=0.2, gt=0.0) | ||
| fac_max: float = Field(default=5.0, gt=0.0) | ||
|
|
||
| gamma: float | None = Field(default=None, gt=0.0, lt=1.0) | ||
|
|
||
| # Operator specs (default/tr/bdf2) for IMEX methods. | ||
| operators: dict[str, Any] | None = Field( | ||
| default=None, | ||
| description=( | ||
| "Operator specifications for IMEX methods. " | ||
| "Required when method is an IMEX variant." | ||
| ), | ||
| ) | ||
|
|
||
| operator_axis: str | int = Field( | ||
| default="state", | ||
| description="Axis along which implicit operators act (name or index).", | ||
| ) | ||
| operator_axis: str | int = "state" | ||
| operators: dict[str, Any] = Field(default_factory=dict) | ||
|
|
||
| @model_validator(mode="after") | ||
| def _validate_imex_requirements(self) -> OpEngineEngineConfig: | ||
| method = str(self.method) | ||
| if method.startswith("imex-") and not self._has_any_operator_specs( | ||
| self.operators | ||
| def _validate_explicit_empty_operators(self) -> OpEngineEngineConfig: | ||
| if ( | ||
| self.method.startswith("imex-") | ||
| and "operators" in self.model_fields_set | ||
| and not _has_operator_specs(_coerce_operator_specs(self.operators)) | ||
| ): | ||
| msg = ( | ||
| f"IMEX method '{method}' requires operator specifications, " | ||
| "but no operators were provided in the engine config." | ||
| f"IMEX method '{self.method}' received operators, " | ||
| "but none were populated. Provide at least one stage " | ||
| "or omit operators to use system options." | ||
| ) | ||
| raise ValueError(msg) | ||
| return self | ||
|
|
||
| @staticmethod | ||
| def _has_any_operator_specs(operators: dict[str, Any] | None) -> bool: | ||
| """Return True if any operator spec is provided.""" | ||
| if operators is None: | ||
| return False | ||
| return any( | ||
| operators.get(name) is not None for name in ("default", "tr", "bdf2") | ||
| ) | ||
|
|
||
| def to_run_config(self) -> RunConfig: | ||
| """ | ||
| Convert to op_engine RunConfig. | ||
| """Convert this provider config to an op_engine `RunConfig`. | ||
|
|
||
| Returns: | ||
| RunConfig instance reflecting this configuration. | ||
| `RunConfig` derived from this provider configuration. | ||
| """ | ||
| adaptive_cfg = AdaptiveConfig(rtol=self.rtol, atol=self.atol) | ||
| dt_controller = DtControllerConfig( | ||
| dt_min=self.dt_min, | ||
| dt_max=self.dt_max, | ||
| safety=self.safety, | ||
| fac_min=self.fac_min, | ||
| fac_max=self.fac_max, | ||
| ) | ||
|
|
||
| op_specs = self._coerce_operator_specs(self.operators) | ||
|
|
||
| return RunConfig( | ||
| method=self.method, | ||
| adaptive=self.adaptive, | ||
| strict=self.strict, | ||
| adaptive_cfg=adaptive_cfg, | ||
| dt_controller=dt_controller, | ||
| operators=op_specs, | ||
| adaptive_cfg=AdaptiveConfig(rtol=self.rtol, atol=self.atol), | ||
| dt_controller=DtControllerConfig( | ||
| dt_min=self.dt_min, | ||
| dt_max=self.dt_max, | ||
| safety=self.safety, | ||
| fac_min=self.fac_min, | ||
| fac_max=self.fac_max, | ||
| ), | ||
| operators=_coerce_operator_specs(self.operators) or OperatorSpecs(), | ||
| gamma=self.gamma, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _coerce_operator_specs(operators: dict[str, Any] | None) -> OperatorSpecs: | ||
| """ | ||
| Normalize operator inputs into OperatorSpecs. | ||
|
|
||
| Returns: | ||
| OperatorSpecs with default/tr/bdf2 fields populated when provided. | ||
| """ | ||
| if operators is None: | ||
| return OperatorSpecs() | ||
| return OperatorSpecs( | ||
| default=operators.get("default"), | ||
| tr=operators.get("tr"), | ||
| bdf2=operators.get("bdf2"), | ||
| ) | ||
| __all__ = ["OpEngineEngineConfig", "_coerce_operator_specs", "_has_operator_specs"] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why export |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes me a bit wary of ignoring so many linting/formatting suggestions. Perhaps outside the scope of this PR, but I think these rules exist to indicate where there might be refactoring/documentation opportunities. Although, for
RUF067we also need to think about that on theflepimop2front, the way we've structured the package as being an implicit namespace package with dynamic module loading makes it a little annoying to work around this one.