diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..3d48fd75 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: patch + changes: + added: + - Policy and Dynamic classes now support addition operator (__add__) to combine policies and dynamics. + - Parameter values are appended when combining policies/dynamics. + - Simulation modifiers are chained in sequence when combining policies/dynamics. diff --git a/src/policyengine/core/dynamic.py b/src/policyengine/core/dynamic.py index 9b312952..3b6ba553 100644 --- a/src/policyengine/core/dynamic.py +++ b/src/policyengine/core/dynamic.py @@ -15,3 +15,29 @@ class Dynamic(BaseModel): simulation_modifier: Callable | None = None created_at: datetime = Field(default_factory=datetime.now) updated_at: datetime = Field(default_factory=datetime.now) + + def __add__(self, other: "Dynamic") -> "Dynamic": + """Combine two dynamics by appending parameter values and chaining simulation modifiers.""" + if not isinstance(other, Dynamic): + return NotImplemented + + # Combine simulation modifiers + combined_modifier = None + if self.simulation_modifier is not None and other.simulation_modifier is not None: + + def combined_modifier(sim): + sim = self.simulation_modifier(sim) + sim = other.simulation_modifier(sim) + return sim + + elif self.simulation_modifier is not None: + combined_modifier = self.simulation_modifier + elif other.simulation_modifier is not None: + combined_modifier = other.simulation_modifier + + return Dynamic( + name=f"{self.name} + {other.name}", + description=f"Combined dynamic: {self.name} and {other.name}", + parameter_values=self.parameter_values + other.parameter_values, + simulation_modifier=combined_modifier, + ) diff --git a/src/policyengine/core/policy.py b/src/policyengine/core/policy.py index 20587d85..3aeb19b9 100644 --- a/src/policyengine/core/policy.py +++ b/src/policyengine/core/policy.py @@ -15,3 +15,29 @@ class Policy(BaseModel): simulation_modifier: Callable | None = None created_at: datetime = Field(default_factory=datetime.now) updated_at: datetime = Field(default_factory=datetime.now) + + def __add__(self, other: "Policy") -> "Policy": + """Combine two policies by appending parameter values and chaining simulation modifiers.""" + if not isinstance(other, Policy): + return NotImplemented + + # Combine simulation modifiers + combined_modifier = None + if self.simulation_modifier is not None and other.simulation_modifier is not None: + + def combined_modifier(sim): + sim = self.simulation_modifier(sim) + sim = other.simulation_modifier(sim) + return sim + + elif self.simulation_modifier is not None: + combined_modifier = self.simulation_modifier + elif other.simulation_modifier is not None: + combined_modifier = other.simulation_modifier + + return Policy( + name=f"{self.name} + {other.name}", + description=f"Combined policy: {self.name} and {other.name}", + parameter_values=self.parameter_values + other.parameter_values, + simulation_modifier=combined_modifier, + )