Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 6 additions & 77 deletions corrai/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC
from typing import Callable
from functools import wraps

import numpy as np
import pandas as pd
Expand All @@ -10,12 +9,10 @@
from pymoo.core.problem import ElementwiseProblem
from pymoo.core.variable import Integer, Real, Choice, Binary

import plotly.graph_objects as go

from corrai.base.math import METHODS
from corrai.base.model import Model
from corrai.base.utils import check_indicators_configs
from corrai.sampling import Sample
from corrai.sampling import Sample, SampleMethodsMixin
from corrai.base.parameter import Parameter


Expand Down Expand Up @@ -577,7 +574,7 @@ def _evaluate(self, x, out, *args, **kwargs):
self._post_evaluate(pairs, out)


class SciOptimizer:
class SciOptimizer(SampleMethodsMixin):
"""
Optimization wrapper for models using SciPy.

Expand Down Expand Up @@ -635,6 +632,10 @@ def __init__(
def parameters(self):
return self.model_evaluator.parameters

@property
def sample(self):
return self.model_evaluator.sample

@property
def values(self):
return self.model_evaluator.sample.values
Expand Down Expand Up @@ -787,75 +788,3 @@ def diff_evo_minimize(
rng=rng,
workers=workers,
)

@wraps(Sample.plot_sample)
def plot_sample(
self,
indicator: str | None,
reference_timeseries: pd.Series | None = None,
title: str | None = None,
y_label: str | None = None,
x_label: str | None = None,
alpha: float = 0.5,
show_legends: bool = False,
round_ndigits: int = 2,
quantile_band: float = 0.75,
type_graph: str = "area",
) -> go.Figure:
return self.model_evaluator.sample.plot_sample(
indicator=indicator,
reference_timeseries=reference_timeseries,
title=title,
y_label=y_label,
x_label=x_label,
alpha=alpha,
show_legends=show_legends,
round_ndigits=round_ndigits,
quantile_band=quantile_band,
type_graph=type_graph,
)

@wraps(Sample.plot_pcp)
def plot_pcp(
self,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
return self.model_evaluator.sample.plot_pcp(
indicators_configs=indicators_configs,
color_by=color_by,
title=title,
html_file_path=html_file_path,
)

@wraps(Sample.plot_hist)
def plot_hist(
self,
indicator: str,
method: str = "mean",
unit: str = "",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
bins: int = 30,
colors: str = "orange",
reference_value: int | float = None,
reference_label: str = "Reference",
show_rug: bool = False,
title: str = None,
):
return self.model_evaluator.sample.plot_hist(
indicator=indicator,
method=method,
unit=unit,
agg_method_kwarg=agg_method_kwarg,
reference_time_series=reference_time_series,
bins=bins,
colors=colors,
reference_value=reference_value,
reference_label=reference_label,
show_rug=show_rug,
title=title,
)
181 changes: 118 additions & 63 deletions corrai/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,26 @@ class Sample:
values: pd.DataFrame = field(init=False)
results: pd.Series = field(default_factory=lambda: pd.Series(dtype=object))

def __repr__(self):
if not self.results.empty:
if not self.results[0].empty:
indicators_name = (
self.results[0].columns
if self.is_dynamic
else self.results[0].index
)
else:
indicators_name = None
else:
indicators_name = None

return (
f"is dynamic: {self.is_dynamic} \n"
f"n computed sample: {len(self.results)} \n"
f"parameters: {[par.name for par in self.parameters]} \n"
f"indicators: {[None] if indicators_name is None else list(indicators_name)}"
)

def __post_init__(self):
self.values = pd.DataFrame(columns=[par.name for par in self.parameters])

Expand Down Expand Up @@ -426,7 +446,7 @@ def get_score_df(
reference_time_series: pd.Series,
scoring_methods: list[str | Callable] = None,
resample_rule: str | pd.Timedelta | dt.timedelta = None,
agg_method: str = "mean",
resample_agg_method: str = "mean",
) -> pd.DataFrame:
"""
Compute scoring metrics for a given indicator across all sample results.
Expand Down Expand Up @@ -456,7 +476,7 @@ def get_score_df(
Examples: ``"D"`` (daily), ``"h"`` (hourly), ``"ME"`` (month end).
If None, no resampling is performed.
Default is None.
agg_method : str, optional
resample_agg_method : str, optional
Aggregation method to use when resampling. Common values include:
``"mean"``, ``"sum"``, ``"min"``, ``"max"``, ``"median"``.
Default is ``"mean"``.
Expand Down Expand Up @@ -519,7 +539,7 @@ def get_score_df(
... reference_time_series=reference,
... scoring_methods=["r2", "rmse", "mae"],
... resample_rule="D",
... agg_method="sum",
... resample_agg_method="sum",
... )
>>> print(scores)
r2_score rmse mae
Expand Down Expand Up @@ -553,10 +573,10 @@ def get_score_df(
for idx, sample_res in self.results.items():
data = sample_res[indicator]
if resample_rule:
data = data.resample(resample_rule).agg(agg_method)
data = data.resample(resample_rule).agg(resample_agg_method)
reference_time_series = reference_time_series.resample(
resample_rule
).agg(agg_method)
).agg(resample_agg_method)

for method in method_func:
scores.loc[idx, method.__name__] = method(reference_time_series, data)
Expand Down Expand Up @@ -960,7 +980,99 @@ def plot_pcp(
)


class Sampler:
class SampleMethodsMixin:
"""Mixin to expose Sample plotting methods to classes that contain a Sample object."""

sample: Sample

@wraps(Sample.get_aggregated_time_series)
def get_sample_aggregated_time_series(
self,
indicator: str,
method: str = "mean",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
freq: str | pd.Timedelta | dt.timedelta = None,
prefix: str = "aggregated",
):
return self.sample.get_aggregated_time_series(
indicator, method, agg_method_kwarg, reference_time_series, freq, prefix
)

@wraps(Sample.plot_sample)
def plot_sample(
self,
indicator: str | None,
reference_timeseries: pd.Series | None = None,
title: str | None = None,
y_label: str | None = None,
x_label: str | None = None,
alpha: float = 0.5,
show_legends: bool = False,
round_ndigits: int = 2,
quantile_band: float = 0.75,
type_graph: str = "area",
) -> go.Figure:
return self.sample.plot_sample(
indicator=indicator,
reference_timeseries=reference_timeseries,
title=title,
y_label=y_label,
x_label=x_label,
alpha=alpha,
show_legends=show_legends,
round_ndigits=round_ndigits,
quantile_band=quantile_band,
type_graph=type_graph,
)

@wraps(Sample.plot_pcp)
def plot_pcp(
self,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
return self.sample.plot_pcp(
indicators_configs=indicators_configs,
color_by=color_by,
title=title,
html_file_path=html_file_path,
)

@wraps(Sample.plot_hist)
def plot_hist(
self,
indicator: str,
method: str = "mean",
unit: str = "",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
bins: int = 30,
colors: str = "orange",
reference_value: int | float = None,
reference_label: str = "Reference",
show_rug: bool = False,
title: str = None,
):
return self.sample.plot_hist(
indicator,
method,
unit,
agg_method_kwarg,
reference_time_series,
bins,
colors,
reference_value,
reference_label,
show_rug,
title,
)


class Sampler(SampleMethodsMixin):
"""
Abstract base class for parameter samplers.

Expand Down Expand Up @@ -1160,63 +1272,6 @@ def simulate_pending(self, n_cpu: int = 1, simulation_kwargs: dict = None):
unsimulated_idx = self.sample.get_pending_index()
self.simulate_at(unsimulated_idx, n_cpu, simulation_kwargs)

@wraps(Sample.plot_sample)
def plot_sample(
self,
indicator: str | None,
reference_timeseries: pd.Series | None = None,
title: str | None = None,
y_label: str | None = None,
x_label: str | None = None,
alpha: float = 0.5,
show_legends: bool = False,
round_ndigits: int = 2,
quantile_band: float = 0.75,
type_graph: str = "area",
) -> go.Figure:
return self.sample.plot_sample(
indicator=indicator,
reference_timeseries=reference_timeseries,
title=title,
y_label=y_label,
x_label=x_label,
alpha=alpha,
show_legends=show_legends,
round_ndigits=round_ndigits,
quantile_band=quantile_band,
type_graph=type_graph,
)

@wraps(Sample.get_aggregated_time_series)
def get_sample_aggregated_time_series(
self,
indicator: str,
method: str = "mean",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
freq: str | pd.Timedelta | dt.timedelta = None,
prefix: str = "aggregated",
):
return self.sample.get_aggregated_time_series(
indicator, method, agg_method_kwarg, reference_time_series, freq, prefix
)

@wraps(Sample.plot_pcp)
def plot_pcp(
self,
indicators_configs: list[str]
| list[tuple[str, str | Callable] | tuple[str, str | Callable, pd.Series]],
color_by: str | None = None,
title: str | None = "Parallel Coordinates — Samples",
html_file_path: str | None = None,
) -> go.Figure:
return self.sample.plot_pcp(
indicators_configs=indicators_configs,
color_by=color_by,
title=title,
html_file_path=html_file_path,
)


class RealSampler(Sampler):
"""
Expand Down
Loading