diff --git a/statemachine/factory.py b/statemachine/factory.py index e5428e4..4d2789d 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -2,8 +2,10 @@ from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Generic from typing import List from typing import Tuple +from typing import TypeVar from . import registry from .event import Event @@ -17,6 +19,10 @@ from .transition_list import TransitionList +TModel = TypeVar("TModel") +"""TypeVar for the model type in StateMachine.""" + + class StateMachineMetaclass(type): "Metaclass for constructing StateMachine classes" @@ -36,7 +42,9 @@ def __init__( cls._abstract = True cls._strict_states = strict_states - cls._events: Dict[Event, None] = {} # used Dict to preserve order and avoid duplicates + cls._events: Dict[Event, None] = ( + {} + ) # used Dict to preserve order and avoid duplicates cls._protected_attrs: set = set() cls._events_to_update: Dict[Event, Event | None] = {} @@ -98,9 +106,9 @@ def _check_final_states(cls): if final_state_with_invalid_transitions: raise InvalidDefinition( - _("Cannot declare transitions from final state. Invalid state(s): {}").format( - [s.id for s in final_state_with_invalid_transitions] - ) + _( + "Cannot declare transitions from final state. Invalid state(s): {}" + ).format([s.id for s in final_state_with_invalid_transitions]) ) def _check_trap_states(cls): @@ -133,7 +141,8 @@ def _states_without_path_to_final_states(cls): return [ state for state in cls.states - if not state.final and not any(s.final for s in visit_connected_states(state)) + if not state.final + and not any(s.final for s in visit_connected_states(state)) ] def _disconnected_states(cls, starting_state): @@ -259,3 +268,17 @@ def _update_event_references(cls): @property def events(self): return list(self._events) + + +class GenericStateMachineMetaclass(StateMachineMetaclass, type(Generic)): # type: ignore[misc] + """ + Metaclass that combines StateMachineMetaclass with Generic. + + This allows StateMachine to be parameterized with a model type using Generic[TModel], + enabling type checkers to infer the correct type of the `model` attribute. + + The type: ignore[misc] is necessary because mypy has limitations with generic metaclasses, + but this pattern works correctly at runtime and with type checkers. + """ + + pass diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index e5fe262..e0f8f0a 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Generic from typing import List from .callbacks import SPECS_ALL @@ -18,7 +19,8 @@ from .exceptions import InvalidDefinition from .exceptions import InvalidStateValue from .exceptions import TransitionNotAllowed -from .factory import StateMachineMetaclass +from .factory import GenericStateMachineMetaclass +from .factory import TModel from .graph import iterate_states_and_transitions from .i18n import _ from .model import Model @@ -29,7 +31,7 @@ from .state import State -class StateMachine(metaclass=StateMachineMetaclass): +class StateMachine(Generic[TModel], metaclass=GenericStateMachineMetaclass): """ Args: @@ -68,14 +70,14 @@ class StateMachine(metaclass=StateMachineMetaclass): def __init__( self, - model: Any = None, + model: "TModel | None" = None, state_field: str = "state", start_value: Any = None, rtc: bool = True, allow_event_without_transition: bool = False, listeners: "List[object] | None" = None, ): - self.model = model if model is not None else Model() + self.model: TModel = model if model is not None else Model() # type: ignore[assignment] self.state_field = state_field self.start_value = start_value self.allow_event_without_transition = allow_event_without_transition @@ -149,7 +151,9 @@ def __setstate__(self, state): self._engine = self._get_engine(rtc) def _get_initial_state(self): - initial_state_value = self.start_value if self.start_value else self.initial_state.value + initial_state_value = ( + self.start_value if self.start_value else self.initial_state.value + ) try: return self.states_map[initial_state_value] except KeyError as err: @@ -170,7 +174,9 @@ def bind_events_to(self, *targets): continue setattr(target, event, trigger) - def _add_listener(self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL): + def _add_listener( + self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL + ): registry = self._callbacks for visited in iterate_states_and_transitions(self.states): listeners.resolve( @@ -292,7 +298,10 @@ def events(self) -> "List[Event]": @property def allowed_events(self) -> "List[Event]": """List of the current allowed events.""" - return [getattr(self, event) for event in self.current_state.transitions.unique_events] + return [ + getattr(self, event) + for event in self.current_state.transitions.unique_events + ] def _put_nonblocking(self, trigger_data: TriggerData): """Put the trigger on the queue without blocking the caller.""" diff --git a/tests/test_generic_support.py b/tests/test_generic_support.py new file mode 100644 index 0000000..91877aa --- /dev/null +++ b/tests/test_generic_support.py @@ -0,0 +1,128 @@ +"""Tests for Generic[TModel] support in StateMachine. + +Test that type checkers can infer the correct model type when using Generic[TModel]. +""" + +import pytest + +from statemachine import State +from statemachine import StateMachine + + +class CustomModel: + """Custom model for testing""" + + def __init__(self): + self.state = None + self.custom_attr = "test_value" + self.counter = 0 + + +class GenericStateMachine(StateMachine[CustomModel]): + """State machine using Generic[CustomModel] for type safety""" + + initial = State("Initial", initial=True) + processing = State("Processing") + final = State("Final", final=True) + + start = initial.to(processing) + finish = processing.to(final) + + +class TestGenericSupport: + """Test suite for Generic[TModel] support""" + + def test_generic_statemachine_with_custom_model(self): + """Test that StateMachine[CustomModel] works with a custom model instance""" + model = CustomModel() + sm = GenericStateMachine(model=model) + + assert sm.model is model + assert sm.model.custom_attr == "test_value" + assert sm.model.counter == 0 + + def test_generic_statemachine_with_default_model(self): + """Test that StateMachine[CustomModel] works with default Model()""" + sm = GenericStateMachine() + + # Default model should be Model(), not CustomModel + assert sm.model is not None + assert sm.current_state == sm.initial + + def test_generic_statemachine_transitions_work(self): + """Test that transitions work correctly with generic state machine""" + model = CustomModel() + sm = GenericStateMachine(model=model) + + assert sm.current_state == sm.initial + + sm.start() + assert sm.current_state == sm.processing + + sm.finish() + assert sm.current_state == sm.final + + def test_generic_statemachine_model_persists_across_transitions(self): + """Test that model state persists across transitions""" + model = CustomModel() + sm = GenericStateMachine(model=model) + + # Modify model + sm.model.counter = 42 + sm.model.custom_attr = "modified" + + # Transition + sm.start() + + # Model state should persist + assert sm.model.counter == 42 + assert sm.model.custom_attr == "modified" + + def test_backward_compatibility_without_generic(self): + """Test that traditional usage without Generic still works""" + + class TraditionalMachine(StateMachine): + """Non-generic state machine for backward compatibility""" + + idle = State("Idle", initial=True) + running = State("Running") + + run = idle.to(running) + + sm = TraditionalMachine() + assert sm.current_state == sm.idle + + sm.run() + assert sm.current_state == sm.running + + def test_multiple_generic_machines_with_different_models(self): + """Test that different generic machines can use different model types""" + + class ModelA: + def __init__(self): + self.state = None + self.value_a = "A" + + class ModelB: + def __init__(self): + self.state = None + self.value_b = "B" + + class MachineA(StateMachine[ModelA]): + initial = State("Initial", initial=True) + final = State("Final", final=True) + go = initial.to(final) + + class MachineB(StateMachine[ModelB]): + start = State("Start", initial=True) + end = State("End", final=True) + advance = start.to(end) + + model_a = ModelA() + model_b = ModelB() + + sm_a = MachineA(model=model_a) + sm_b = MachineB(model=model_b) + + assert sm_a.model.value_a == "A" + assert sm_b.model.value_b == "B"