diff --git a/config.go b/config.go index 2dd76d0..a3ad659 100644 --- a/config.go +++ b/config.go @@ -3,52 +3,66 @@ package stateless import ( "context" "fmt" + "reflect" ) type transitionKey struct{} -func withTransition(ctx context.Context, transition Transition) context.Context { +func withTransition[S State, T Trigger](ctx context.Context, transition Transition[S, T]) context.Context { return context.WithValue(ctx, transitionKey{}, transition) } // GetTransition returns the transition from the context. // If there is no transition the returned value is empty. -func GetTransition(ctx context.Context) Transition { - tr, _ := ctx.Value(transitionKey{}).(Transition) +func GetTransition[S State, T Trigger](ctx context.Context) Transition[S, T] { + tr, _ := ctx.Value(transitionKey{}).(Transition[S, T]) return tr } +// Args is a generic list of arguments. +type Args []any + +func (a Args) Len() int { + return len(a) +} + +func (a Args) TypeOf(i int) reflect.Type { + return reflect.TypeOf(a[i]) +} + +var _ Validatable = Args{} // Ensure Args implements Validatable + // ActionFunc describes a generic action function. // The context will always contain Transition information. -type ActionFunc = func(ctx context.Context, args ...any) error +type ActionFunc[A any] func(ctx context.Context, arg A) error // GuardFunc defines a generic guard function. -type GuardFunc = func(ctx context.Context, args ...any) bool +type GuardFunc[A any] func(ctx context.Context, arg A) bool // DestinationSelectorFunc defines a functions that is called to select a dynamic destination. -type DestinationSelectorFunc = func(ctx context.Context, args ...any) (State, error) +type DestinationSelectorFunc[S State, A any] func(ctx context.Context, arg A) (S, error) // StateConfiguration is the configuration for a single state value. -type StateConfiguration struct { - sm *StateMachine - sr *stateRepresentation - lookup func(State) *stateRepresentation +type StateConfiguration[S State, T Trigger, A any] struct { + sm *StateMachine[S, T, A] + sr *stateRepresentation[S, T, A] + lookup func(S) *stateRepresentation[S, T, A] } // State is configured with this configuration. -func (sc *StateConfiguration) State() State { +func (sc *StateConfiguration[S, T, _]) State() S { return sc.sr.State } // Machine that is configured with this configuration. -func (sc *StateConfiguration) Machine() *StateMachine { +func (sc *StateConfiguration[S, T, A]) Machine() *StateMachine[S, T, A] { return sc.sm } // InitialTransition adds an initial transition to this state. // When entering the current state the state machine will look for an initial transition, // and enter the target state. -func (sc *StateConfiguration) InitialTransition(targetState State) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) InitialTransition(targetState S) *StateConfiguration[S, T, A] { if sc.sr.HasInitialState { panic(fmt.Sprintf("stateless: This state has already been configured with an initial transition (%v).", sc.sr.InitialTransitionTarget)) } @@ -60,12 +74,12 @@ func (sc *StateConfiguration) InitialTransition(targetState State) *StateConfigu } // Permit accept the specified trigger and transition to the destination state if the guard conditions are met (if any). -func (sc *StateConfiguration) Permit(trigger Trigger, destinationState State, guards ...GuardFunc) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) Permit(trigger T, destinationState S, guards ...GuardFunc[A]) *StateConfiguration[S, T, A] { if destinationState == sc.sr.State { panic("stateless: Permit() require that the destination state is not equal to the source state. To accept a trigger without changing state, use either Ignore() or PermitReentry().") } - sc.sr.AddTriggerBehaviour(&transitioningTriggerBehaviour{ - baseTriggerBehaviour: baseTriggerBehaviour{Trigger: trigger, Guard: newtransitionGuard(guards...)}, + sc.sr.AddTriggerBehaviour(&transitioningTriggerBehaviour[S, T, A]{ + baseTriggerBehaviour: baseTriggerBehaviour[T, A]{Trigger: trigger, Guard: newtransitionGuard[A](guards...)}, Destination: destinationState, }) return sc @@ -73,9 +87,9 @@ func (sc *StateConfiguration) Permit(trigger Trigger, destinationState State, gu // InternalTransition add an internal transition to the state machine. // An internal action does not cause the Exit and Entry actions to be triggered, and does not change the state of the state machine. -func (sc *StateConfiguration) InternalTransition(trigger Trigger, action ActionFunc, guards ...GuardFunc) *StateConfiguration { - sc.sr.AddTriggerBehaviour(&internalTriggerBehaviour{ - baseTriggerBehaviour: baseTriggerBehaviour{Trigger: trigger, Guard: newtransitionGuard(guards...)}, +func (sc *StateConfiguration[S, T, A]) InternalTransition(trigger T, action ActionFunc[A], guards ...GuardFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.AddTriggerBehaviour(&internalTriggerBehaviour[S, T, A]{ + baseTriggerBehaviour: baseTriggerBehaviour[T, A]{Trigger: trigger, Guard: newtransitionGuard[A](guards...)}, Action: action, }) return sc @@ -85,37 +99,37 @@ func (sc *StateConfiguration) InternalTransition(trigger Trigger, action ActionF // Reentry behaves as though the configured state transitions to an identical sibling state. // Applies to the current state only. Will not re-execute superstate actions, or // cause actions to execute transitioning between super- and sub-states. -func (sc *StateConfiguration) PermitReentry(trigger Trigger, guards ...GuardFunc) *StateConfiguration { - sc.sr.AddTriggerBehaviour(&reentryTriggerBehaviour{ - baseTriggerBehaviour: baseTriggerBehaviour{Trigger: trigger, Guard: newtransitionGuard(guards...)}, +func (sc *StateConfiguration[S, T, A]) PermitReentry(trigger T, guards ...GuardFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.AddTriggerBehaviour(&reentryTriggerBehaviour[S, T, A]{ + baseTriggerBehaviour: baseTriggerBehaviour[T, A]{Trigger: trigger, Guard: newtransitionGuard[A](guards...)}, Destination: sc.sr.State, }) return sc } // Ignore the specified trigger when in the configured state, if the guards return true. -func (sc *StateConfiguration) Ignore(trigger Trigger, guards ...GuardFunc) *StateConfiguration { - sc.sr.AddTriggerBehaviour(&ignoredTriggerBehaviour{ - baseTriggerBehaviour: baseTriggerBehaviour{Trigger: trigger, Guard: newtransitionGuard(guards...)}, +func (sc *StateConfiguration[S, T, A]) Ignore(trigger T, guards ...GuardFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.AddTriggerBehaviour(&ignoredTriggerBehaviour[T, A]{ + baseTriggerBehaviour: baseTriggerBehaviour[T, A]{Trigger: trigger, Guard: newtransitionGuard(guards...)}, }) return sc } // PermitDynamic accept the specified trigger and transition to the destination state, calculated dynamically by the supplied function. -func (sc *StateConfiguration) PermitDynamic(trigger Trigger, selector DestinationSelectorFunc, guards ...GuardFunc) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) PermitDynamic(trigger T, selector DestinationSelectorFunc[S, A], guards ...GuardFunc[A]) *StateConfiguration[S, T, A] { guardDescriptors := make([]invocationInfo, len(guards)) for i, guard := range guards { guardDescriptors[i] = newinvocationInfo(guard) } - sc.sr.AddTriggerBehaviour(&dynamicTriggerBehaviour{ - baseTriggerBehaviour: baseTriggerBehaviour{Trigger: trigger, Guard: newtransitionGuard(guards...)}, + sc.sr.AddTriggerBehaviour(&dynamicTriggerBehaviour[S, T, A]{ + baseTriggerBehaviour: baseTriggerBehaviour[T, A]{Trigger: trigger, Guard: newtransitionGuard[A](guards...)}, Destination: selector, }) return sc } // OnActive specify an action that will execute when activating the configured state. -func (sc *StateConfiguration) OnActive(action func(context.Context) error) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) OnActive(action func(context.Context) error) *StateConfiguration[S, T, A] { sc.sr.ActivateActions = append(sc.sr.ActivateActions, actionBehaviourSteady{ Action: action, Description: newinvocationInfo(action), @@ -124,7 +138,7 @@ func (sc *StateConfiguration) OnActive(action func(context.Context) error) *Stat } // OnDeactivate specify an action that will execute when deactivating the configured state. -func (sc *StateConfiguration) OnDeactivate(action func(context.Context) error) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) OnDeactivate(action func(context.Context) error) *StateConfiguration[S, T, A] { sc.sr.DeactivateActions = append(sc.sr.DeactivateActions, actionBehaviourSteady{ Action: action, Description: newinvocationInfo(action), @@ -133,8 +147,8 @@ func (sc *StateConfiguration) OnDeactivate(action func(context.Context) error) * } // OnEntry specify an action that will execute when transitioning into the configured state. -func (sc *StateConfiguration) OnEntry(action ActionFunc) *StateConfiguration { - sc.sr.EntryActions = append(sc.sr.EntryActions, actionBehaviour{ +func (sc *StateConfiguration[S, T, A]) OnEntry(action ActionFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.EntryActions = append(sc.sr.EntryActions, actionBehaviour[S, T, A]{ Action: action, Description: newinvocationInfo(action), }) @@ -142,8 +156,8 @@ func (sc *StateConfiguration) OnEntry(action ActionFunc) *StateConfiguration { } // OnEntryFrom Specify an action that will execute when transitioning into the configured state from a specific trigger. -func (sc *StateConfiguration) OnEntryFrom(trigger Trigger, action ActionFunc) *StateConfiguration { - sc.sr.EntryActions = append(sc.sr.EntryActions, actionBehaviour{ +func (sc *StateConfiguration[S, T, A]) OnEntryFrom(trigger T, action ActionFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.EntryActions = append(sc.sr.EntryActions, actionBehaviour[S, T, A]{ Action: action, Description: newinvocationInfo(action), Trigger: &trigger, @@ -152,8 +166,8 @@ func (sc *StateConfiguration) OnEntryFrom(trigger Trigger, action ActionFunc) *S } // OnExit specify an action that will execute when transitioning from the configured state. -func (sc *StateConfiguration) OnExit(action ActionFunc) *StateConfiguration { - sc.sr.ExitActions = append(sc.sr.ExitActions, actionBehaviour{ +func (sc *StateConfiguration[S, T, A]) OnExit(action ActionFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.ExitActions = append(sc.sr.ExitActions, actionBehaviour[S, T, A]{ Action: action, Description: newinvocationInfo(action), }) @@ -161,8 +175,8 @@ func (sc *StateConfiguration) OnExit(action ActionFunc) *StateConfiguration { } // OnExitWith specifies an action that will execute when transitioning from the configured state with a specific trigger. -func (sc *StateConfiguration) OnExitWith(trigger Trigger, action ActionFunc) *StateConfiguration { - sc.sr.ExitActions = append(sc.sr.ExitActions, actionBehaviour{ +func (sc *StateConfiguration[S, T, A]) OnExitWith(trigger T, action ActionFunc[A]) *StateConfiguration[S, T, A] { + sc.sr.ExitActions = append(sc.sr.ExitActions, actionBehaviour[S, T, A]{ Action: action, Description: newinvocationInfo(action), Trigger: &trigger, @@ -176,7 +190,7 @@ func (sc *StateConfiguration) OnExitWith(trigger Trigger, action ActionFunc) *St // entry actions for the superstate are executed. // Likewise when leaving from the substate to outside the supserstate, // exit actions for the superstate will execute. -func (sc *StateConfiguration) SubstateOf(superstate State) *StateConfiguration { +func (sc *StateConfiguration[S, T, A]) SubstateOf(superstate S) *StateConfiguration[S, T, A] { state := sc.sr.State // Check for accidental identical cyclic configuration if state == superstate { @@ -185,7 +199,7 @@ func (sc *StateConfiguration) SubstateOf(superstate State) *StateConfiguration { // Check for accidental identical nested cyclic configuration var empty struct{} - supersets := map[State]struct{}{state: empty} + supersets := map[S]struct{}{state: empty} // Build list of super states and check for activeSc := sc.lookup(superstate) diff --git a/example_test.go b/example_test.go index bd53555..76bda87 100644 --- a/example_test.go +++ b/example_test.go @@ -3,9 +3,8 @@ package stateless_test import ( "context" "fmt" - "reflect" - "github.com/qmuntal/stateless" + "reflect" ) const ( @@ -29,7 +28,7 @@ const ( ) func Example() { - phoneCall := stateless.NewStateMachine(stateOffHook) + phoneCall := stateless.NewStateMachine[string, string, stateless.Args](stateOffHook) phoneCall.SetTriggerParameters(triggerSetVolume, reflect.TypeOf(0)) phoneCall.SetTriggerParameters(triggerCallDialed, reflect.TypeOf("")) @@ -37,7 +36,7 @@ func Example() { Permit(triggerCallDialed, stateRinging) phoneCall.Configure(stateRinging). - OnEntryFrom(triggerCallDialed, func(_ context.Context, args ...any) error { + OnEntryFrom(triggerCallDialed, func(_ context.Context, args stateless.Args) error { onDialed(args[0].(string)) return nil }). @@ -45,19 +44,19 @@ func Example() { phoneCall.Configure(stateConnected). OnEntry(startCallTimer). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, args stateless.Args) error { stopCallTimer() return nil }). - InternalTransition(triggerMuteMicrophone, func(_ context.Context, _ ...any) error { + InternalTransition(triggerMuteMicrophone, func(_ context.Context, _ stateless.Args) error { onMute() return nil }). - InternalTransition(triggerUnmuteMicrophone, func(_ context.Context, _ ...any) error { + InternalTransition(triggerUnmuteMicrophone, func(_ context.Context, _ stateless.Args) error { onUnmute() return nil }). - InternalTransition(triggerSetVolume, func(_ context.Context, args ...any) error { + InternalTransition(triggerSetVolume, func(_ context.Context, args stateless.Args) error { onSetVolume(args[0].(int)) return nil }). @@ -66,7 +65,7 @@ func Example() { phoneCall.Configure(stateOnHold). SubstateOf(stateConnected). - OnExitWith(triggerPhoneHurledAgainstWall, func(ctx context.Context, args ...any) error { + OnExitWith(triggerPhoneHurledAgainstWall, func(ctx context.Context, _ stateless.Args) error { onWasted() return nil }). @@ -75,16 +74,16 @@ func Example() { phoneCall.ToGraph() - phoneCall.Fire(triggerCallDialed, "qmuntal") - phoneCall.Fire(triggerCallConnected) - phoneCall.Fire(triggerSetVolume, 2) - phoneCall.Fire(triggerPlacedOnHold) - phoneCall.Fire(triggerMuteMicrophone) - phoneCall.Fire(triggerUnmuteMicrophone) - phoneCall.Fire(triggerTakenOffHold) - phoneCall.Fire(triggerSetVolume, 11) - phoneCall.Fire(triggerPlacedOnHold) - phoneCall.Fire(triggerPhoneHurledAgainstWall) + phoneCall.Fire(triggerCallDialed, stateless.Args{"qmuntal"}) + phoneCall.Fire(triggerCallConnected, nil) + phoneCall.Fire(triggerSetVolume, stateless.Args{2}) + phoneCall.Fire(triggerPlacedOnHold, nil) + phoneCall.Fire(triggerMuteMicrophone, nil) + phoneCall.Fire(triggerUnmuteMicrophone, nil) + phoneCall.Fire(triggerTakenOffHold, nil) + phoneCall.Fire(triggerSetVolume, stateless.Args{11}) + phoneCall.Fire(triggerPlacedOnHold, nil) + phoneCall.Fire(triggerPhoneHurledAgainstWall, nil) fmt.Printf("State is %v\n", phoneCall.MustState()) // Output: @@ -120,7 +119,7 @@ func onWasted() { fmt.Println("Wasted!") } -func startCallTimer(_ context.Context, _ ...any) error { +func startCallTimer(_ context.Context, _ stateless.Args) error { fmt.Println("[Timer:] Call started at 11:00am") return nil } diff --git a/graph.go b/graph.go index 2327771..7f73875 100644 --- a/graph.go +++ b/graph.go @@ -9,14 +9,14 @@ import ( "unicode" ) -type graph struct { +type graph[S State, T Trigger, A any] struct { } -func (g *graph) formatStateMachine(sm *StateMachine) string { +func (g *graph[S, T, A]) formatStateMachine(sm *StateMachine[S, T, A]) string { var sb strings.Builder sb.WriteString("digraph {\n\tcompound=true;\n\tnode [shape=Mrecord];\n\trankdir=\"LR\";\n\n") - stateList := make([]*stateRepresentation, 0, len(sm.stateConfig)) + stateList := make([]*stateRepresentation[S, T, A], 0, len(sm.stateConfig)) for _, st := range sm.stateConfig { stateList = append(stateList, st) } @@ -50,7 +50,7 @@ func (g *graph) formatStateMachine(sm *StateMachine) string { return sb.String() } -func (g *graph) formatActions(sr *stateRepresentation) string { +func (g *graph[S, T, A]) formatActions(sr *stateRepresentation[S, T, A]) string { es := make([]string, 0, len(sr.EntryActions)+len(sr.ExitActions)+len(sr.ActivateActions)+len(sr.DeactivateActions)) for _, act := range sr.ActivateActions { es = append(es, fmt.Sprintf("activated / %s", esc(act.Description.String(), false))) @@ -69,7 +69,7 @@ func (g *graph) formatActions(sr *stateRepresentation) string { return strings.Join(es, "\\n") } -func (g *graph) formatOneState(sb *strings.Builder, sr *stateRepresentation, level int) { +func (g *graph[S, T, A]) formatOneState(sb *strings.Builder, sr *stateRepresentation[S, T, A], level int) { var indent string for i := 0; i < level; i++ { indent += "\t" @@ -98,7 +98,7 @@ func (g *graph) formatOneState(sb *strings.Builder, sr *stateRepresentation, lev } } -func (g *graph) getEntryActions(ab []actionBehaviour, t Trigger) []string { +func (g *graph[S, T, A]) getEntryActions(ab []actionBehaviour[S, T, A], t T) []string { var actions []string for _, ea := range ab { if ea.Trigger != nil && *ea.Trigger == t { @@ -108,8 +108,8 @@ func (g *graph) getEntryActions(ab []actionBehaviour, t Trigger) []string { return actions } -func (g *graph) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine, sr *stateRepresentation) { - triggerList := make([]triggerBehaviour, 0, len(sr.TriggerBehaviours)) +func (g *graph[S, T, A]) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine[S, T, A], sr *stateRepresentation[S, T, A]) { + triggerList := make([]triggerBehaviour[T, A], 0, len(sr.TriggerBehaviours)) for _, triggers := range sr.TriggerBehaviours { triggerList = append(triggerList, triggers...) } @@ -120,35 +120,35 @@ func (g *graph) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine, }) type line struct { - source State - destination State + source S // State + destination S // State } lines := make(map[line][]string, len(triggerList)) order := make([]line, 0, len(triggerList)) for _, trigger := range triggerList { switch t := trigger.(type) { - case *ignoredTriggerBehaviour: + case *ignoredTriggerBehaviour[T, A]: ln := line{sr.State, sr.State} if _, ok := lines[ln]; !ok { order = append(order, ln) } lines[ln] = append(lines[ln], formatOneTransition(t.Trigger, nil, t.Guard)) - case *reentryTriggerBehaviour: + case *reentryTriggerBehaviour[S, T, A]: actions := g.getEntryActions(sr.EntryActions, t.Trigger) ln := line{sr.State, t.Destination} if _, ok := lines[ln]; !ok { order = append(order, ln) } lines[ln] = append(lines[ln], formatOneTransition(t.Trigger, actions, t.Guard)) - case *internalTriggerBehaviour: + case *internalTriggerBehaviour[S, T, A]: actions := g.getEntryActions(sr.EntryActions, t.Trigger) ln := line{sr.State, sr.State} if _, ok := lines[ln]; !ok { order = append(order, ln) } lines[ln] = append(lines[ln], formatOneTransition(t.Trigger, actions, t.Guard)) - case *transitioningTriggerBehaviour: + case *transitioningTriggerBehaviour[S, T, A]: src := sm.stateConfig[sr.State] if src == nil { continue @@ -158,7 +158,7 @@ func (g *graph) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine, if dest != nil { actions = g.getEntryActions(dest.EntryActions, t.Trigger) } - var destState State + var destState S if dest == nil { destState = t.Destination } else { @@ -169,7 +169,7 @@ func (g *graph) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine, order = append(order, ln) } lines[ln] = append(lines[ln], formatOneTransition(t.Trigger, actions, t.Guard)) - case *dynamicTriggerBehaviour: + case *dynamicTriggerBehaviour[S, T, A]: // TODO: not supported yet } } @@ -180,7 +180,7 @@ func (g *graph) formatAllStateTransitions(sb *strings.Builder, sm *StateMachine, } } -func formatOneTransition(trigger Trigger, actions []string, guards transitionGuard) string { +func formatOneTransition[T any, A any](trigger T, actions []string, guards transitionGuard[A]) string { var sb strings.Builder sb.WriteString(str(trigger, false)) if len(actions) > 0 { diff --git a/graph_test.go b/graph_test.go index 4edd6af..f2b4507 100644 --- a/graph_test.go +++ b/graph_test.go @@ -15,20 +15,20 @@ import ( var update = flag.Bool("update", false, "update golden files on failure") -func emptyWithInitial() *stateless.StateMachine { - return stateless.NewStateMachine("A") +func emptyWithInitial() *stateless.StateMachine[string, string, stateless.Args] { + return stateless.NewStateMachine[string, string, stateless.Args]("A") } -func withSubstate() *stateless.StateMachine { - sm := stateless.NewStateMachine("B") +func withSubstate() *stateless.StateMachine[string, string, stateless.Args] { + sm := stateless.NewStateMachine[string, string, stateless.Args]("B") sm.Configure("A").Permit("Z", "B") sm.Configure("B").SubstateOf("C").Permit("X", "A") sm.Configure("C").Permit("Y", "A").Ignore("X") return sm } -func withInitialState() *stateless.StateMachine { - sm := stateless.NewStateMachine("A") +func withInitialState() *stateless.StateMachine[string, string, stateless.Args] { + sm := stateless.NewStateMachine[string, string, stateless.Args]("A") sm.Configure("A"). Permit("X", "B") sm.Configure("B"). @@ -41,28 +41,28 @@ func withInitialState() *stateless.StateMachine { return sm } -func withGuards() *stateless.StateMachine { - sm := stateless.NewStateMachine("B") - sm.SetTriggerParameters("X", reflect.TypeOf(0)) +func withGuards() *stateless.StateMachine[string, string, stateless.Args] { + sm := stateless.NewStateMachine[string, string, stateless.Args]("B") + //sm.SetTriggerParameters("X", reflect.TypeOf(0)) sm.Configure("A"). - Permit("X", "D", func(_ context.Context, args ...any) bool { + Permit("X", "D", func(_ context.Context, args stateless.Args) bool { return args[0].(int) == 3 }) sm.Configure("B"). SubstateOf("A"). - Permit("X", "C", func(_ context.Context, args ...any) bool { + Permit("X", "C", func(_ context.Context, args stateless.Args) bool { return args[0].(int) == 2 }) return sm } -func œ(_ context.Context, args ...any) bool { +func œ(_ context.Context, args stateless.Args) bool { return args[0].(int) == 2 } -func withUnicodeNames() *stateless.StateMachine { - sm := stateless.NewStateMachine("Ĕ") +func withUnicodeNames() *stateless.StateMachine[string, string, stateless.Args] { + sm := stateless.NewStateMachine[string, string, stateless.Args]("Ĕ") sm.Configure("Ĕ"). Permit("◵", "ų", œ) sm.Configure("ų"). @@ -79,32 +79,32 @@ func withUnicodeNames() *stateless.StateMachine { return sm } -func phoneCall() *stateless.StateMachine { - phoneCall := stateless.NewStateMachine(stateOffHook) - phoneCall.SetTriggerParameters(triggerSetVolume, reflect.TypeOf(0)) - phoneCall.SetTriggerParameters(triggerCallDialed, reflect.TypeOf("")) +func phoneCall() *stateless.StateMachine[string, string, stateless.Args] { + phoneCall := stateless.NewStateMachine[string, string, stateless.Args](stateOffHook) + //phoneCall.SetTriggerParameters(triggerSetVolume, reflect.TypeOf(0)) + //phoneCall.SetTriggerParameters(triggerCallDialed, reflect.TypeOf("")) phoneCall.Configure(stateOffHook). Permit(triggerCallDialed, stateRinging) phoneCall.Configure(stateRinging). - OnEntryFrom(triggerCallDialed, func(_ context.Context, args ...any) error { + OnEntryFrom(triggerCallDialed, func(_ context.Context, _ stateless.Args) error { return nil }). Permit(triggerCallConnected, stateConnected) phoneCall.Configure(stateConnected). OnEntry(startCallTimer). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ stateless.Args) error { return nil }). - InternalTransition(triggerMuteMicrophone, func(_ context.Context, _ ...any) error { + InternalTransition(triggerMuteMicrophone, func(_ context.Context, _ stateless.Args) error { return nil }). - InternalTransition(triggerUnmuteMicrophone, func(_ context.Context, _ ...any) error { + InternalTransition(triggerUnmuteMicrophone, func(_ context.Context, _ stateless.Args) error { return nil }). - InternalTransition(triggerSetVolume, func(_ context.Context, args ...any) error { + InternalTransition(triggerSetVolume, func(_ context.Context, args stateless.Args) error { return nil }). Permit(triggerLeftMessage, stateOffHook). @@ -112,7 +112,7 @@ func phoneCall() *stateless.StateMachine { phoneCall.Configure(stateOnHold). SubstateOf(stateConnected). - OnExitWith(triggerPhoneHurledAgainstWall, func(ctx context.Context, args ...any) error { + OnExitWith(triggerPhoneHurledAgainstWall, func(ctx context.Context, _ stateless.Args) error { onWasted() return nil }). @@ -123,7 +123,7 @@ func phoneCall() *stateless.StateMachine { } func TestStateMachine_ToGraph(t *testing.T) { - tests := []func() *stateless.StateMachine{ + tests := []func() *stateless.StateMachine[string, string, stateless.Args]{ emptyWithInitial, withSubstate, withInitialState, diff --git a/modes.go b/modes.go index 3aec556..376d925 100644 --- a/modes.go +++ b/modes.go @@ -6,46 +6,46 @@ import ( "sync/atomic" ) -type fireMode interface { - Fire(ctx context.Context, trigger Trigger, args ...any) error +type fireMode[T Trigger, A any] interface { + Fire(ctx context.Context, trigger T, arg A) error Firing() bool } -type fireModeImmediate struct { +type fireModeImmediate[S State, T Trigger, A any] struct { ops atomic.Uint64 - sm *StateMachine + sm *StateMachine[S, T, A] } -func (f *fireModeImmediate) Firing() bool { +func (f *fireModeImmediate[_, _, _]) Firing() bool { return f.ops.Load() > 0 } -func (f *fireModeImmediate) Fire(ctx context.Context, trigger Trigger, args ...any) error { +func (f *fireModeImmediate[_, T, A]) Fire(ctx context.Context, trigger T, arg A) error { f.ops.Add(1) defer f.ops.Add(^uint64(0)) - return f.sm.internalFireOne(ctx, trigger, args...) + return f.sm.internalFireOne(ctx, trigger, arg) } -type queuedTrigger struct { +type queuedTrigger[T Trigger, A any] struct { Context context.Context - Trigger Trigger - Args []any + Trigger T + Arg A } -type fireModeQueued struct { +type fireModeQueued[S State, T Trigger, A any] struct { firing atomic.Bool - sm *StateMachine + sm *StateMachine[S, T, A] - triggers []queuedTrigger + triggers []queuedTrigger[T, A] mu sync.Mutex // guards triggers } -func (f *fireModeQueued) Firing() bool { +func (f *fireModeQueued[_, _, _]) Firing() bool { return f.firing.Load() } -func (f *fireModeQueued) Fire(ctx context.Context, trigger Trigger, args ...any) error { - f.enqueue(ctx, trigger, args...) +func (f *fireModeQueued[_, T, A]) Fire(ctx context.Context, trigger T, arg A) error { + f.enqueue(ctx, trigger, arg) for { et, ok := f.fetch() if !ok { @@ -59,30 +59,30 @@ func (f *fireModeQueued) Fire(ctx context.Context, trigger Trigger, args ...any) return nil } -func (f *fireModeQueued) enqueue(ctx context.Context, trigger Trigger, args ...any) { +func (f *fireModeQueued[_, T, A]) enqueue(ctx context.Context, trigger T, arg A) { f.mu.Lock() defer f.mu.Unlock() - f.triggers = append(f.triggers, queuedTrigger{Context: ctx, Trigger: trigger, Args: args}) + f.triggers = append(f.triggers, queuedTrigger[T, A]{Context: ctx, Trigger: trigger, Arg: arg}) } -func (f *fireModeQueued) fetch() (et queuedTrigger, ok bool) { +func (f *fireModeQueued[S, T, A]) fetch() (et queuedTrigger[T, A], ok bool) { f.mu.Lock() defer f.mu.Unlock() if len(f.triggers) == 0 { - return queuedTrigger{}, false + return queuedTrigger[T, A]{}, false } if !f.firing.CompareAndSwap(false, true) { - return queuedTrigger{}, false + return queuedTrigger[T, A]{}, false } et, f.triggers = f.triggers[0], f.triggers[1:] return et, true } -func (f *fireModeQueued) execute(et queuedTrigger) error { +func (f *fireModeQueued[S, T, A]) execute(et queuedTrigger[T, A]) error { defer f.firing.Swap(false) - return f.sm.internalFireOne(et.Context, et.Trigger, et.Args...) + return f.sm.internalFireOne(et.Context, et.Trigger, et.Arg) } diff --git a/statemachine.go b/statemachine.go index 99e4c2a..c94663e 100644 --- a/statemachine.go +++ b/statemachine.go @@ -8,10 +8,14 @@ import ( ) // State is used to to represent the possible machine states. -type State = any +type State interface { + comparable +} // Trigger is used to represent the triggers that cause state transitions. -type Trigger = any +type Trigger interface { + comparable +} // FiringMode enumerate the different modes used when Fire-ing a trigger. type FiringMode uint8 @@ -25,34 +29,34 @@ const ( ) // Transition describes a state transition. -type Transition struct { - Source State - Destination State - Trigger Trigger +type Transition[S State, T Trigger] struct { + Source S + Destination S + Trigger T isInitial bool } // IsReentry returns true if the transition is a re-entry, // i.e. the identity transition. -func (t *Transition) IsReentry() bool { +func (t *Transition[_, _]) IsReentry() bool { return t.Source == t.Destination } -type TransitionFunc = func(context.Context, Transition) +type TransitionFunc[S State, T Trigger] func(context.Context, Transition[S, T]) // UnhandledTriggerActionFunc defines a function that will be called when a trigger is not handled. -type UnhandledTriggerActionFunc = func(ctx context.Context, state State, trigger Trigger, unmetGuards []string) error +type UnhandledTriggerActionFunc[S State, T Trigger] func(ctx context.Context, state S, trigger T, unmetGuards []string) error // DefaultUnhandledTriggerAction is the default unhandled trigger action. -func DefaultUnhandledTriggerAction(_ context.Context, state State, trigger Trigger, unmetGuards []string) error { +func DefaultUnhandledTriggerAction[S State, T Trigger](_ context.Context, state S, trigger T, unmetGuards []string) error { if len(unmetGuards) != 0 { return fmt.Errorf("stateless: Trigger '%v' is valid for transition from state '%v' but a guard conditions are not met. Guard descriptions: '%v", trigger, state, unmetGuards) } return fmt.Errorf("stateless: No valid leaving transitions are permitted from state '%v' for trigger '%v', consider ignoring the trigger", state, trigger) } -func callEvents(events []TransitionFunc, ctx context.Context, transition Transition) { +func callEvents[S State, T Trigger](events []TransitionFunc[S, T], ctx context.Context, transition Transition[S, T]) { for _, e := range events { e(ctx, transition) } @@ -61,50 +65,50 @@ func callEvents(events []TransitionFunc, ctx context.Context, transition Transit // A StateMachine is an abstract machine that can be in exactly one of a finite number of states at any given time. // It is safe to use the StateMachine concurrently, but non of the callbacks (state manipulation, actions, events, ...) are guarded, // so it is up to the client to protect them against race conditions. -type StateMachine struct { - stateConfig map[State]*stateRepresentation - triggerConfig map[Trigger]triggerWithParameters - stateAccessor func(context.Context) (State, error) - stateMutator func(context.Context, State) error - unhandledTriggerAction UnhandledTriggerActionFunc - onTransitioningEvents []TransitionFunc - onTransitionedEvents []TransitionFunc +type StateMachine[S State, T Trigger, A any] struct { + stateConfig map[S]*stateRepresentation[S, T, A] + triggerConfig map[T]triggerWithParameters[T] + stateAccessor func(context.Context) (S, error) + stateMutator func(context.Context, S) error + unhandledTriggerAction UnhandledTriggerActionFunc[S, T] + onTransitioningEvents []TransitionFunc[S, T] + onTransitionedEvents []TransitionFunc[S, T] stateMutex sync.RWMutex - mode fireMode + mode fireMode[T, A] } -func newStateMachine(firingMode FiringMode) *StateMachine { - sm := &StateMachine{ - stateConfig: make(map[State]*stateRepresentation), - triggerConfig: make(map[Trigger]triggerWithParameters), - unhandledTriggerAction: UnhandledTriggerActionFunc(DefaultUnhandledTriggerAction), +func newStateMachine[S State, T Trigger, A any](firingMode FiringMode) *StateMachine[S, T, A] { + sm := &StateMachine[S, T, A]{ + stateConfig: make(map[S]*stateRepresentation[S, T, A]), + triggerConfig: make(map[T]triggerWithParameters[T]), + unhandledTriggerAction: UnhandledTriggerActionFunc[S, T](DefaultUnhandledTriggerAction[S, T]), } if firingMode == FiringImmediate { - sm.mode = &fireModeImmediate{sm: sm} + sm.mode = &fireModeImmediate[S, T, A]{sm: sm} } else { - sm.mode = &fireModeQueued{sm: sm} + sm.mode = &fireModeQueued[S, T, A]{sm: sm} } return sm } // NewStateMachine returns a queued state machine. -func NewStateMachine(initialState State) *StateMachine { - return NewStateMachineWithMode(initialState, FiringQueued) +func NewStateMachine[S State, T Trigger, A any](initialState S) *StateMachine[S, T, A] { + return NewStateMachineWithMode[S, T, A](initialState, FiringQueued) } // NewStateMachineWithMode returns a state machine with the desired firing mode -func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMachine { +func NewStateMachineWithMode[S State, T Trigger, A any](initialState S, firingMode FiringMode) *StateMachine[S, T, A] { var stateMutex sync.Mutex - sm := newStateMachine(firingMode) + sm := newStateMachine[S, T, A](firingMode) reference := &struct { - State State + State S }{State: initialState} - sm.stateAccessor = func(_ context.Context) (State, error) { + sm.stateAccessor = func(_ context.Context) (S, error) { stateMutex.Lock() defer stateMutex.Unlock() return reference.State, nil } - sm.stateMutator = func(_ context.Context, state State) error { + sm.stateMutator = func(_ context.Context, state S) error { stateMutex.Lock() defer stateMutex.Unlock() reference.State = state @@ -114,8 +118,8 @@ func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMa } // NewStateMachineWithExternalStorage returns a state machine with external state storage. -func NewStateMachineWithExternalStorage(stateAccessor func(context.Context) (State, error), stateMutator func(context.Context, State) error, firingMode FiringMode) *StateMachine { - sm := newStateMachine(firingMode) +func NewStateMachineWithExternalStorage[S State, T Trigger, A any](stateAccessor func(context.Context) (S, error), stateMutator func(context.Context, S) error, firingMode FiringMode) *StateMachine[S, T, A] { + sm := newStateMachine[S, T, A](firingMode) sm.stateAccessor = stateAccessor sm.stateMutator = stateMutator return sm @@ -123,12 +127,12 @@ func NewStateMachineWithExternalStorage(stateAccessor func(context.Context) (Sta // ToGraph returns the DOT representation of the state machine. // It is not guaranteed that the returned string will be the same in different executions. -func (sm *StateMachine) ToGraph() string { - return new(graph).formatStateMachine(sm) +func (sm *StateMachine[S, T, A]) ToGraph() string { + return new(graph[S, T, A]).formatStateMachine(sm) } // State returns the current state. -func (sm *StateMachine) State(ctx context.Context) (State, error) { +func (sm *StateMachine[S, _, _]) State(ctx context.Context) (S, error) { return sm.stateAccessor(ctx) } @@ -136,7 +140,7 @@ func (sm *StateMachine) State(ctx context.Context) (State, error) { // It is safe to use this method when used together with NewStateMachine // or when using NewStateMachineWithExternalStorage with an state accessor that // does not return an error. -func (sm *StateMachine) MustState() State { +func (sm *StateMachine[S, _, _]) MustState() S { st, err := sm.State(context.Background()) if err != nil { panic(err) @@ -145,28 +149,28 @@ func (sm *StateMachine) MustState() State { } // PermittedTriggers see PermittedTriggersCtx. -func (sm *StateMachine) PermittedTriggers(args ...any) ([]Trigger, error) { - return sm.PermittedTriggersCtx(context.Background(), args...) +func (sm *StateMachine[_, T, A]) PermittedTriggers(arg A) ([]T, error) { + return sm.PermittedTriggersCtx(context.Background(), arg) } // PermittedTriggersCtx returns the currently-permissible trigger values. -func (sm *StateMachine) PermittedTriggersCtx(ctx context.Context, args ...any) ([]Trigger, error) { +func (sm *StateMachine[_, T, A]) PermittedTriggersCtx(ctx context.Context, arg A) ([]T, error) { sr, err := sm.currentState(ctx) if err != nil { return nil, err } - return sr.PermittedTriggers(ctx, args...), nil + return sr.PermittedTriggers(ctx, arg), nil } // Activate see ActivateCtx. -func (sm *StateMachine) Activate() error { +func (sm *StateMachine[S, T, _]) Activate() error { return sm.ActivateCtx(context.Background()) } // ActivateCtx activates current state. Actions associated with activating the current state will be invoked. // The activation is idempotent and subsequent activation of the same current state // will not lead to re-execution of activation callbacks. -func (sm *StateMachine) ActivateCtx(ctx context.Context) error { +func (sm *StateMachine[S, T, _]) ActivateCtx(ctx context.Context) error { sr, err := sm.currentState(ctx) if err != nil { return err @@ -175,14 +179,14 @@ func (sm *StateMachine) ActivateCtx(ctx context.Context) error { } // Deactivate see DeactivateCtx. -func (sm *StateMachine) Deactivate() error { +func (sm *StateMachine[S, T, _]) Deactivate() error { return sm.DeactivateCtx(context.Background()) } // DeactivateCtx deactivates current state. Actions associated with deactivating the current state will be invoked. // The deactivation is idempotent and subsequent deactivation of the same current state // will not lead to re-execution of deactivation callbacks. -func (sm *StateMachine) DeactivateCtx(ctx context.Context) error { +func (sm *StateMachine[S, T, _]) DeactivateCtx(ctx context.Context) error { sr, err := sm.currentState(ctx) if err != nil { return err @@ -191,13 +195,13 @@ func (sm *StateMachine) DeactivateCtx(ctx context.Context) error { } // IsInState see IsInStateCtx. -func (sm *StateMachine) IsInState(state State) (bool, error) { +func (sm *StateMachine[S, T, _]) IsInState(state S) (bool, error) { return sm.IsInStateCtx(context.Background(), state) } // IsInStateCtx determine if the state machine is in the supplied state. // Returns true if the current state is equal to, or a substate of, the supplied state. -func (sm *StateMachine) IsInStateCtx(ctx context.Context, state State) (bool, error) { +func (sm *StateMachine[S, T, _]) IsInStateCtx(ctx context.Context, state S) (bool, error) { sr, err := sm.currentState(ctx) if err != nil { return false, err @@ -206,22 +210,22 @@ func (sm *StateMachine) IsInStateCtx(ctx context.Context, state State) (bool, er } // CanFire see CanFireCtx. -func (sm *StateMachine) CanFire(trigger Trigger, args ...any) (bool, error) { - return sm.CanFireCtx(context.Background(), trigger, args...) +func (sm *StateMachine[S, T, A]) CanFire(trigger T, arg A) (bool, error) { + return sm.CanFireCtx(context.Background(), trigger, arg) } // CanFireCtx returns true if the trigger can be fired in the current state. -func (sm *StateMachine) CanFireCtx(ctx context.Context, trigger Trigger, args ...any) (bool, error) { +func (sm *StateMachine[S, T, A]) CanFireCtx(ctx context.Context, trigger T, arg A) (bool, error) { sr, err := sm.currentState(ctx) if err != nil { return false, err } - return sr.CanHandle(ctx, trigger, args...), nil + return sr.CanHandle(ctx, trigger, arg), nil } // SetTriggerParameters specify the arguments that must be supplied when a specific trigger is fired. -func (sm *StateMachine) SetTriggerParameters(trigger Trigger, argumentTypes ...reflect.Type) { - config := triggerWithParameters{Trigger: trigger, ArgumentTypes: argumentTypes} +func (sm *StateMachine[S, T, A]) SetTriggerParameters(trigger T, argumentTypes ...reflect.Type) { + config := triggerWithParameters[T]{Trigger: trigger, ArgumentTypes: argumentTypes} if _, ok := sm.triggerConfig[config.Trigger]; ok { panic(fmt.Sprintf("stateless: Parameters for the trigger '%v' have already been configured.", trigger)) } @@ -229,8 +233,8 @@ func (sm *StateMachine) SetTriggerParameters(trigger Trigger, argumentTypes ...r } // Fire see FireCtx -func (sm *StateMachine) Fire(trigger Trigger, args ...any) error { - return sm.FireCtx(context.Background(), trigger, args...) +func (sm *StateMachine[S, T, A]) Fire(trigger T, arg A) error { + return sm.FireCtx(context.Background(), trigger, arg) } // FireCtx transition from the current state via the specified trigger. @@ -246,56 +250,56 @@ func (sm *StateMachine) Fire(trigger Trigger, args ...any) error { // // The context is passed down to all actions and callbacks called within the scope of this method. // There is no context error checking, although it may be implemented in future releases. -func (sm *StateMachine) FireCtx(ctx context.Context, trigger Trigger, args ...any) error { - return sm.internalFire(ctx, trigger, args...) +func (sm *StateMachine[S, T, A]) FireCtx(ctx context.Context, trigger T, arg A) error { + return sm.internalFire(ctx, trigger, arg) } // OnTransitioned registers a callback that will be invoked every time the state machine // successfully finishes a transitions from one state into another. -func (sm *StateMachine) OnTransitioned(fn ...TransitionFunc) { +func (sm *StateMachine[S, T, _]) OnTransitioned(fn ...TransitionFunc[S, T]) { sm.onTransitionedEvents = append(sm.onTransitionedEvents, fn...) } // OnTransitioning registers a callback that will be invoked every time the state machine // starts a transitions from one state into another. -func (sm *StateMachine) OnTransitioning(fn ...TransitionFunc) { +func (sm *StateMachine[S, T, _]) OnTransitioning(fn ...TransitionFunc[S, T]) { sm.onTransitioningEvents = append(sm.onTransitioningEvents, fn...) } // OnUnhandledTrigger override the default behaviour of returning an error when an unhandled trigger. -func (sm *StateMachine) OnUnhandledTrigger(fn UnhandledTriggerActionFunc) { +func (sm *StateMachine[S, T, _]) OnUnhandledTrigger(fn UnhandledTriggerActionFunc[S, T]) { sm.unhandledTriggerAction = fn } // Configure begin configuration of the entry/exit actions and allowed transitions // when the state machine is in a particular state. -func (sm *StateMachine) Configure(state State) *StateConfiguration { - return &StateConfiguration{sm: sm, sr: sm.stateRepresentation(state), lookup: sm.stateRepresentation} +func (sm *StateMachine[S, T, A]) Configure(state S) *StateConfiguration[S, T, A] { + return &StateConfiguration[S, T, A]{sm: sm, sr: sm.stateRepresentation(state), lookup: sm.stateRepresentation} } // Firing returns true when the state machine is processing a trigger. -func (sm *StateMachine) Firing() bool { +func (sm *StateMachine[_, _, _]) Firing() bool { return sm.mode.Firing() } // String returns a human-readable representation of the state machine. // It is not guaranteed that the order of the PermittedTriggers is the same in consecutive executions. -func (sm *StateMachine) String() string { +func (sm *StateMachine[_, T, A]) String(arg A) string { state, err := sm.State(context.Background()) if err != nil { return "" } // PermittedTriggers only returns an error if state accessor returns one, and it has already been checked. - triggers, _ := sm.PermittedTriggers() + triggers, _ := sm.PermittedTriggers(arg) return fmt.Sprintf("StateMachine {{ State = %v, PermittedTriggers = %v }}", state, triggers) } -func (sm *StateMachine) setState(ctx context.Context, state State) error { +func (sm *StateMachine[S, T, _]) setState(ctx context.Context, state S) error { return sm.stateMutator(ctx, state) } -func (sm *StateMachine) currentState(ctx context.Context) (*stateRepresentation, error) { +func (sm *StateMachine[S, T, A]) currentState(ctx context.Context) (*stateRepresentation[S, T, A], error) { state, err := sm.State(ctx) if err != nil { return nil, err @@ -303,7 +307,7 @@ func (sm *StateMachine) currentState(ctx context.Context) (*stateRepresentation, return sm.stateRepresentation(state), nil } -func (sm *StateMachine) stateRepresentation(state State) *stateRepresentation { +func (sm *StateMachine[S, T, A]) stateRepresentation(state S) *stateRepresentation[S, T, A] { sm.stateMutex.RLock() sr, ok := sm.stateConfig[state] sm.stateMutex.RUnlock() @@ -312,78 +316,81 @@ func (sm *StateMachine) stateRepresentation(state State) *stateRepresentation { defer sm.stateMutex.Unlock() // Check again, since another goroutine may have added it while we were waiting for the lock. if sr, ok = sm.stateConfig[state]; !ok { - sr = newstateRepresentation(state) + sr = newstateRepresentation[S, T, A](state) sm.stateConfig[state] = sr } } return sr } -func (sm *StateMachine) internalFire(ctx context.Context, trigger Trigger, args ...any) error { - return sm.mode.Fire(ctx, trigger, args...) +func (sm *StateMachine[S, T, A]) internalFire(ctx context.Context, trigger T, arg A) error { + return sm.mode.Fire(ctx, trigger, arg) } -func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, args ...any) error { +func (sm *StateMachine[S, T, A]) internalFireOne(ctx context.Context, trigger T, arg A) error { var ( - config triggerWithParameters - ok bool + val Validatable + ok bool ) - if config, ok = sm.triggerConfig[trigger]; ok { - config.validateParameters(args...) + if val, ok = any(arg).(Validatable); ok { + var config triggerWithParameters[T] + if config, ok = sm.triggerConfig[trigger]; ok { + config.validateParameters(val) + } } source, err := sm.State(ctx) if err != nil { return err } representativeState := sm.stateRepresentation(source) - var result triggerBehaviourResult - if result, ok = representativeState.FindHandler(ctx, trigger, args...); !ok { + var result triggerBehaviourResult[T, A] + if result, ok = representativeState.FindHandler(ctx, trigger, arg); !ok { return sm.unhandledTriggerAction(ctx, representativeState.State, trigger, result.UnmetGuardConditions) } switch t := result.Handler.(type) { - case *ignoredTriggerBehaviour: + case *ignoredTriggerBehaviour[T, A]: // ignored - case *reentryTriggerBehaviour: - transition := Transition{Source: source, Destination: t.Destination, Trigger: trigger} - err = sm.handleReentryTrigger(ctx, representativeState, transition, args...) - case *dynamicTriggerBehaviour: - var destination any - destination, err = t.Destination(ctx, args...) + case *reentryTriggerBehaviour[S, T, A]: + transition := Transition[S, T]{Source: source, Destination: t.Destination, Trigger: trigger} + err = sm.handleReentryTrigger(ctx, representativeState, transition, arg) + case *dynamicTriggerBehaviour[S, T, A]: + var destination S + destination, err = t.Destination(ctx, arg) if err == nil { - transition := Transition{Source: source, Destination: destination, Trigger: trigger} - err = sm.handleTransitioningTrigger(ctx, representativeState, transition, args...) + transition := Transition[S, T]{Source: source, Destination: destination, Trigger: trigger} + err = sm.handleTransitioningTrigger(ctx, representativeState, transition, arg) } - case *transitioningTriggerBehaviour: + case *transitioningTriggerBehaviour[S, T, A]: if source == t.Destination { // If a trigger was found on a superstate that would cause unintended reentry, don't trigger. break } - transition := Transition{Source: source, Destination: t.Destination, Trigger: trigger} - err = sm.handleTransitioningTrigger(ctx, representativeState, transition, args...) - case *internalTriggerBehaviour: - var sr *stateRepresentation + transition := Transition[S, T]{Source: source, Destination: t.Destination, Trigger: trigger} + err = sm.handleTransitioningTrigger(ctx, representativeState, transition, arg) + case *internalTriggerBehaviour[S, T, A]: + var sr *stateRepresentation[S, T, A] sr, err = sm.currentState(ctx) if err == nil { - transition := Transition{Source: source, Destination: source, Trigger: trigger} - err = sr.InternalAction(ctx, transition, args...) + transition := Transition[S, T]{Source: source, Destination: source, Trigger: trigger} + err = sr.InternalAction(ctx, transition, arg) } } return err } -func (sm *StateMachine) handleReentryTrigger(ctx context.Context, sr *stateRepresentation, transition Transition, args ...any) error { - if err := sr.Exit(ctx, transition, args...); err != nil { +func (sm *StateMachine[S, T, A]) handleReentryTrigger(ctx context.Context, sr *stateRepresentation[S, T, A], transition Transition[S, T], arg A) error { + if err := sr.Exit(ctx, transition, arg); err != nil { return err } newSr := sm.stateRepresentation(transition.Destination) if !transition.IsReentry() { - transition = Transition{Source: transition.Destination, Destination: transition.Destination, Trigger: transition.Trigger} - if err := newSr.Exit(ctx, transition, args...); err != nil { + transition = Transition[S, T]{Source: transition.Destination, Destination: transition.Destination, Trigger: transition.Trigger} + if err := newSr.Exit(ctx, transition, arg); err != nil { return err } } callEvents(sm.onTransitioningEvents, ctx, transition) - rep, err := sm.enterState(ctx, newSr, transition, args...) + rep, err := sm.enterState(ctx, newSr, transition, arg) if err != nil { return err } @@ -394,8 +401,8 @@ func (sm *StateMachine) handleReentryTrigger(ctx context.Context, sr *stateRepre return nil } -func (sm *StateMachine) handleTransitioningTrigger(ctx context.Context, sr *stateRepresentation, transition Transition, args ...any) error { - if err := sr.Exit(ctx, transition, args...); err != nil { +func (sm *StateMachine[S, T, A]) handleTransitioningTrigger(ctx context.Context, sr *stateRepresentation[S, T, A], transition Transition[S, T], arg A) error { + if err := sr.Exit(ctx, transition, arg); err != nil { return err } callEvents(sm.onTransitioningEvents, ctx, transition) @@ -403,7 +410,7 @@ func (sm *StateMachine) handleTransitioningTrigger(ctx context.Context, sr *stat return err } newSr := sm.stateRepresentation(transition.Destination) - rep, err := sm.enterState(ctx, newSr, transition, args...) + rep, err := sm.enterState(ctx, newSr, transition, arg) if err != nil { return err } @@ -413,13 +420,13 @@ func (sm *StateMachine) handleTransitioningTrigger(ctx context.Context, sr *stat return err } } - callEvents(sm.onTransitionedEvents, ctx, Transition{transition.Source, rep.State, transition.Trigger, false}) + callEvents(sm.onTransitionedEvents, ctx, Transition[S, T]{transition.Source, rep.State, transition.Trigger, false}) return nil } -func (sm *StateMachine) enterState(ctx context.Context, sr *stateRepresentation, transition Transition, args ...any) (*stateRepresentation, error) { +func (sm *StateMachine[S, T, A]) enterState(ctx context.Context, sr *stateRepresentation[S, T, A], transition Transition[S, T], arg A) (*stateRepresentation[S, T, A], error) { // Enter the new state - err := sr.Enter(ctx, transition, args...) + err := sr.Enter(ctx, transition, arg) if err != nil { return nil, err } @@ -437,10 +444,10 @@ func (sm *StateMachine) enterState(ctx context.Context, sr *stateRepresentation, if !isValidForInitialState { panic(fmt.Sprintf("stateless: The target (%v) for the initial transition is not a substate.", sr.InitialTransitionTarget)) } - initialTranslation := Transition{Source: transition.Source, Destination: sr.InitialTransitionTarget, Trigger: transition.Trigger, isInitial: true} + initialTranslation := Transition[S, T]{Source: transition.Source, Destination: sr.InitialTransitionTarget, Trigger: transition.Trigger, isInitial: true} sr = sm.stateRepresentation(sr.InitialTransitionTarget) - callEvents(sm.onTransitioningEvents, ctx, Transition{transition.Destination, initialTranslation.Destination, transition.Trigger, false}) - sr, err = sm.enterState(ctx, sr, initialTranslation, args...) + callEvents(sm.onTransitioningEvents, ctx, Transition[S, T]{transition.Destination, initialTranslation.Destination, transition.Trigger, false}) + sr, err = sm.enterState(ctx, sr, initialTranslation, arg) } return sr, err } diff --git a/statemachine_test.go b/statemachine_test.go index 393e81b..86e7a0e 100644 --- a/statemachine_test.go +++ b/statemachine_test.go @@ -23,11 +23,11 @@ const ( func TestTransition_IsReentry(t *testing.T) { tests := []struct { name string - t *Transition + t *Transition[string, string] want bool }{ - {"TransitionIsNotChange", &Transition{"1", "1", "0", false}, true}, - {"TransitionIsChange", &Transition{"1", "2", "0", false}, false}, + {"TransitionIsNotChange", &Transition[string, string]{"1", "1", "0", false}, true}, + {"TransitionIsChange", &Transition[string, string]{"1", "2", "0", false}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -39,17 +39,17 @@ func TestTransition_IsReentry(t *testing.T) { } func TestStateMachine_NewStateMachine(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) if got := sm.MustState(); got != stateA { t.Errorf("MustState() = %v, want %v", got, stateA) } } func TestStateMachine_NewStateMachineWithExternalStorage(t *testing.T) { - var state State = stateB - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { + state := stateB + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { return state, nil - }, func(_ context.Context, s State) error { + }, func(_ context.Context, s string) error { state = s return nil }, FiringImmediate) @@ -60,7 +60,7 @@ func TestStateMachine_NewStateMachineWithExternalStorage(t *testing.T) { if state != stateB { t.Errorf("expected state to be %v, got %v", stateB, state) } - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateC { t.Errorf("MustState() = %v, want %v", got, stateC) } @@ -70,7 +70,7 @@ func TestStateMachine_NewStateMachineWithExternalStorage(t *testing.T) { } func TestStateMachine_Configure_SubstateIsIncludedInCurrentState(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).SubstateOf(stateC) if ok, _ := sm.IsInState(stateC); !ok { t.Errorf("IsInState() = %v, want %v", ok, true) @@ -82,10 +82,10 @@ func TestStateMachine_Configure_SubstateIsIncludedInCurrentState(t *testing.T) { } func TestStateMachine_Configure_InSubstate_TriggerIgnoredInSuperstate_RemainsInSubstate(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).SubstateOf(stateC) sm.Configure(stateC).Ignore(triggerX) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateB { t.Errorf("MustState() = %v, want %v", got, stateB) @@ -93,24 +93,24 @@ func TestStateMachine_Configure_InSubstate_TriggerIgnoredInSuperstate_RemainsInS } func TestStateMachine_CanFire(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).Permit(triggerX, stateA) - if ok, _ := sm.CanFire(triggerX); !ok { + if ok, _ := sm.CanFire(triggerX, nil); !ok { t.Errorf("CanFire() = %v, want %v", ok, true) } - if ok, _ := sm.CanFire(triggerY); ok { + if ok, _ := sm.CanFire(triggerY, nil); ok { t.Errorf("CanFire() = %v, want %v", ok, false) } } func TestStateMachine_CanFire_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) sm.Configure(stateB).Permit(triggerX, stateA) - ok, err := sm.CanFire(triggerX) + ok, err := sm.CanFire(triggerX, nil) if ok { t.Fail() } @@ -121,9 +121,9 @@ func TestStateMachine_CanFire_StatusError(t *testing.T) { } func TestStateMachine_IsInState_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) ok, err := sm.IsInState(stateA) if ok { @@ -136,9 +136,9 @@ func TestStateMachine_IsInState_StatusError(t *testing.T) { } func TestStateMachine_Activate_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) want := "status error" if err := sm.Activate(); err == nil || err.Error() != want { @@ -150,37 +150,37 @@ func TestStateMachine_Activate_StatusError(t *testing.T) { } func TestStateMachine_PermittedTriggers_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) want := "status error" - if _, err := sm.PermittedTriggers(); err == nil || err.Error() != want { + if _, err := sm.PermittedTriggers(nil); err == nil || err.Error() != want { t.Errorf("PermittedTriggers() = %v, want %v", err, want) } } func TestStateMachine_MustState_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) assertPanic(t, func() { sm.MustState() }) } func TestStateMachine_Fire_StatusError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate) want := "status error" - if err := sm.Fire(triggerX); err == nil || err.Error() != want { + if err := sm.Fire(triggerX, nil); err == nil || err.Error() != want { t.Errorf("Fire() = %v, want %v", err, want) } } func TestStateMachine_Configure_PermittedTriggersIncludeSuperstatePermittedTriggers(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateA).Permit(triggerZ, stateB) sm.Configure(stateB).SubstateOf(stateC).Permit(triggerX, stateA) sm.Configure(stateC).Permit(triggerY, stateA) @@ -211,21 +211,21 @@ func TestStateMachine_Configure_PermittedTriggersIncludeSuperstatePermittedTrigg } func TestStateMachine_PermittedTriggers_PermittedTriggersAreDistinctValues(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).SubstateOf(stateC).Permit(triggerX, stateA) sm.Configure(stateC).Permit(triggerX, stateB) permitted, _ := sm.PermittedTriggers(context.Background()) - want := []any{triggerX} + want := []string{triggerX} if !reflect.DeepEqual(permitted, want) { t.Errorf("PermittedTriggers() = %v, want %v", permitted, want) } } func TestStateMachine_PermittedTriggers_AcceptedTriggersRespectGuards(t *testing.T) { - sm := NewStateMachine(stateB) - sm.Configure(stateB).Permit(triggerX, stateA, func(_ context.Context, _ ...any) bool { + sm := NewStateMachine[string, string, any](stateB) + sm.Configure(stateB).Permit(triggerX, stateA, func(_ context.Context, _ any) bool { return false }) @@ -237,10 +237,10 @@ func TestStateMachine_PermittedTriggers_AcceptedTriggersRespectGuards(t *testing } func TestStateMachine_PermittedTriggers_AcceptedTriggersRespectMultipleGuards(t *testing.T) { - sm := NewStateMachine(stateB) - sm.Configure(stateB).Permit(triggerX, stateA, func(_ context.Context, _ ...any) bool { + sm := NewStateMachine[string, string, any](stateB) + sm.Configure(stateB).Permit(triggerX, stateA, func(_ context.Context, _ any) bool { return true - }, func(_ context.Context, _ ...any) bool { + }, func(_ context.Context, _ any) bool { return false }) @@ -252,16 +252,16 @@ func TestStateMachine_PermittedTriggers_AcceptedTriggersRespectMultipleGuards(t } func TestStateMachine_Fire_DiscriminatedByGuard_ChoosesPermitedTransition(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB). - Permit(triggerX, stateA, func(_ context.Context, _ ...any) bool { + Permit(triggerX, stateA, func(_ context.Context, _ any) bool { return false }). - Permit(triggerX, stateC, func(_ context.Context, _ ...any) bool { + Permit(triggerX, stateC, func(_ context.Context, _ any) bool { return true }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateC { t.Errorf("MustState() = %v, want %v", got, stateC) @@ -269,15 +269,15 @@ func TestStateMachine_Fire_DiscriminatedByGuard_ChoosesPermitedTransition(t *tes } func TestStateMachine_Fire_SaveError(t *testing.T) { - sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { + sm := NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { return stateB, nil - }, func(_ context.Context, s State) error { return errors.New("status error") }, FiringImmediate) + }, func(_ context.Context, s string) error { return errors.New("status error") }, FiringImmediate) sm.Configure(stateB). Permit(triggerX, stateA) want := "status error" - if err := sm.Fire(triggerX); err == nil || err.Error() != want { + if err := sm.Fire(triggerX, nil); err == nil || err.Error() != want { t.Errorf("Fire() = %v, want %v", err, want) } if sm.MustState() != stateB { @@ -287,15 +287,15 @@ func TestStateMachine_Fire_SaveError(t *testing.T) { func TestStateMachine_Fire_TriggerIsIgnored_ActionsNotExecuted(t *testing.T) { fired := false - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { fired = true return nil }). Ignore(triggerX) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if fired { t.Error("actions were executed") @@ -304,22 +304,22 @@ func TestStateMachine_Fire_TriggerIsIgnored_ActionsNotExecuted(t *testing.T) { func TestStateMachine_Fire_SelfTransitionPermited_ActionsFire(t *testing.T) { fired := false - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { fired = true return nil }). PermitReentry(triggerX) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if !fired { t.Error("actions did not fire") } } func TestStateMachine_Fire_ImplicitReentryIsDisallowed(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) assertPanic(t, func() { sm.Configure(stateB). Permit(triggerX, stateB) @@ -327,24 +327,24 @@ func TestStateMachine_Fire_ImplicitReentryIsDisallowed(t *testing.T) { } func TestStateMachine_Fire_ErrorForInvalidTransition(t *testing.T) { - sm := NewStateMachine(stateA) - if err := sm.Fire(triggerX); err == nil { + sm := NewStateMachine[string, string, any](stateA) + if err := sm.Fire(triggerX, nil); err == nil { t.Error("error expected") } } func TestStateMachine_Fire_ErrorForInvalidTransitionMentionsGuardDescriptionIfPresent(t *testing.T) { - sm := NewStateMachine(stateA) - sm.Configure(stateA).Permit(triggerX, stateB, func(_ context.Context, _ ...any) bool { + sm := NewStateMachine[string, string, any](stateA) + sm.Configure(stateA).Permit(triggerX, stateB, func(_ context.Context, _ any) bool { return false }) - if err := sm.Fire(triggerX); err == nil { + if err := sm.Fire(triggerX, nil); err == nil { t.Error("error expected") } } func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToEntryAction(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, Args](stateB) sm.SetTriggerParameters(triggerX, reflect.TypeOf(""), reflect.TypeOf(0)) sm.Configure(stateB).Permit(triggerX, stateC) @@ -352,13 +352,13 @@ func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToEntryAction(t *tes entryArg1 string entryArg2 int ) - sm.Configure(stateC).OnEntryFrom(triggerX, func(_ context.Context, args ...any) error { + sm.Configure(stateC).OnEntryFrom(triggerX, func(_ context.Context, args Args) error { entryArg1 = args[0].(string) entryArg2 = args[1].(int) return nil }) suppliedArg1, suppliedArg2 := "something", 2 - sm.Fire(triggerX, suppliedArg1, suppliedArg2) + sm.Fire(triggerX, Args{suppliedArg1, suppliedArg2}) if entryArg1 != suppliedArg1 { t.Errorf("entryArg1 = %v, want %v", entryArg1, suppliedArg1) @@ -369,7 +369,7 @@ func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToEntryAction(t *tes } func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToExitAction(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, Args](stateB) sm.SetTriggerParameters(triggerX, reflect.TypeOf(""), reflect.TypeOf(0)) sm.Configure(stateB).Permit(triggerX, stateC) @@ -377,13 +377,13 @@ func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToExitAction(t *test entryArg1 string entryArg2 int ) - sm.Configure(stateB).OnExitWith(triggerX, func(_ context.Context, args ...any) error { + sm.Configure(stateB).OnExitWith(triggerX, func(_ context.Context, args Args) error { entryArg1 = args[0].(string) entryArg2 = args[1].(int) return nil }) suppliedArg1, suppliedArg2 := "something", 2 - sm.Fire(triggerX, suppliedArg1, suppliedArg2) + sm.Fire(triggerX, Args{suppliedArg1, suppliedArg2}) if entryArg1 != suppliedArg1 { t.Errorf("entryArg1 = %v, want %v", entryArg1, suppliedArg1) @@ -394,18 +394,18 @@ func TestStateMachine_Fire_ParametersSuppliedToFireArePassedToExitAction(t *test } func TestStateMachine_OnUnhandledTrigger_TheProvidedHandlerIsCalledWithStateAndTrigger(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) var ( - unhandledState State - unhandledTrigger Trigger + unhandledState string + unhandledTrigger string ) - sm.OnUnhandledTrigger(func(_ context.Context, state State, trigger Trigger, unmetGuards []string) error { + sm.OnUnhandledTrigger(func(_ context.Context, state string, trigger string, unmetGuards []string) error { unhandledState = state unhandledTrigger = trigger return nil }) - sm.Fire(triggerZ) + sm.Fire(triggerZ, nil) if stateB != unhandledState { t.Errorf("unhandledState = %v, want %v", unhandledState, stateB) @@ -416,7 +416,7 @@ func TestStateMachine_OnUnhandledTrigger_TheProvidedHandlerIsCalledWithStateAndT } func TestStateMachine_SetTriggerParameters_TriggerParametersAreImmutableOnceSet(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.SetTriggerParameters(triggerX, reflect.TypeOf(""), reflect.TypeOf(0)) @@ -424,7 +424,7 @@ func TestStateMachine_SetTriggerParameters_TriggerParametersAreImmutableOnceSet( } func TestStateMachine_SetTriggerParameters_Interfaces(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.SetTriggerParameters(triggerX, reflect.TypeOf((*error)(nil)).Elem()) sm.Configure(stateB).Permit(triggerX, stateA) @@ -437,27 +437,27 @@ func TestStateMachine_SetTriggerParameters_Interfaces(t *testing.T) { } func TestStateMachine_SetTriggerParameters_Invalid(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, Args](stateB) sm.SetTriggerParameters(triggerX, reflect.TypeOf(""), reflect.TypeOf(0)) sm.Configure(stateB).Permit(triggerX, stateA) - assertPanic(t, func() { sm.Fire(triggerX) }) - assertPanic(t, func() { sm.Fire(triggerX, "1", "2", "3") }) - assertPanic(t, func() { sm.Fire(triggerX, "1", "2") }) + assertPanic(t, func() { sm.Fire(triggerX, nil) }) + assertPanic(t, func() { sm.Fire(triggerX, Args{"1", "2", "3"}) }) + assertPanic(t, func() { sm.Fire(triggerX, Args{"1", "2"}) }) } func TestStateMachine_OnTransitioning_EventFires(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).Permit(triggerX, stateA) - var transition Transition - sm.OnTransitioning(func(_ context.Context, tr Transition) { + var transition Transition[string, string] + sm.OnTransitioning(func(_ context.Context, tr Transition[string, string]) { transition = tr }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) - want := Transition{ + want := Transition[string, string]{ Source: stateB, Destination: stateA, Trigger: triggerX, @@ -468,16 +468,16 @@ func TestStateMachine_OnTransitioning_EventFires(t *testing.T) { } func TestStateMachine_OnTransitioned_EventFires(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) sm.Configure(stateB).Permit(triggerX, stateA) - var transition Transition - sm.OnTransitioned(func(_ context.Context, tr Transition) { + var transition Transition[string, string] + sm.OnTransitioned(func(_ context.Context, tr Transition[string, string]) { transition = tr }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) - want := Transition{ + want := Transition[string, string]{ Source: stateB, Trigger: triggerX, Destination: stateA, @@ -488,36 +488,36 @@ func TestStateMachine_OnTransitioned_EventFires(t *testing.T) { } func TestStateMachine_OnTransitioned_EventFiresBeforeTheOnEntryEvent(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) expectedOrdering := []string{"OnExit", "OnTransitioning", "OnEntry", "OnTransitioned"} var actualOrdering []string - sm.Configure(stateB).Permit(triggerX, stateA).OnExit(func(_ context.Context, args ...any) error { + sm.Configure(stateB).Permit(triggerX, stateA).OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "OnExit") return nil }).Machine() - var transition Transition - sm.Configure(stateA).OnEntry(func(ctx context.Context, args ...any) error { + var transition Transition[string, string] + sm.Configure(stateA).OnEntry(func(ctx context.Context, _ any) error { actualOrdering = append(actualOrdering, "OnEntry") - transition = GetTransition(ctx) + transition = GetTransition[string, string](ctx) return nil }) - sm.OnTransitioning(func(_ context.Context, tr Transition) { + sm.OnTransitioning(func(_ context.Context, tr Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioning") }) - sm.OnTransitioned(func(_ context.Context, tr Transition) { + sm.OnTransitioned(func(_ context.Context, tr Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioned") }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if !reflect.DeepEqual(actualOrdering, expectedOrdering) { t.Errorf("actualOrdering = %v, want %v", actualOrdering, expectedOrdering) } - want := Transition{ + want := Transition[string, string]{ Source: stateB, Destination: stateA, Trigger: triggerX, @@ -528,25 +528,25 @@ func TestStateMachine_OnTransitioned_EventFiresBeforeTheOnEntryEvent(t *testing. } func TestStateMachine_SubstateOf_DirectCyclicConfigurationDetected(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) assertPanic(t, func() { sm.Configure(stateA).SubstateOf(stateA) }) } func TestStateMachine_SubstateOf_NestedCyclicConfigurationDetected(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateB).SubstateOf(stateA) assertPanic(t, func() { sm.Configure(stateA).SubstateOf(stateB) }) } func TestStateMachine_SubstateOf_NestedTwoLevelsCyclicConfigurationDetected(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateB).SubstateOf(stateA) sm.Configure(stateC).SubstateOf(stateB) assertPanic(t, func() { sm.Configure(stateA).SubstateOf(stateC) }) } func TestStateMachine_SubstateOf_DelayedNestedCyclicConfigurationDetected(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateB).SubstateOf(stateA) sm.Configure(stateC) sm.Configure(stateA).SubstateOf(stateC) @@ -554,18 +554,18 @@ func TestStateMachine_SubstateOf_DelayedNestedCyclicConfigurationDetected(t *tes } func TestStateMachine_Fire_IgnoreVsPermitReentry(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var calls int sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { calls += 1 return nil }). PermitReentry(triggerX). Ignore(triggerY) - sm.Fire(triggerX) - sm.Fire(triggerY) + sm.Fire(triggerX, nil) + sm.Fire(triggerY, nil) if calls != 1 { t.Errorf("calls = %d, want %d", calls, 1) @@ -573,22 +573,22 @@ func TestStateMachine_Fire_IgnoreVsPermitReentry(t *testing.T) { } func TestStateMachine_Fire_IgnoreVsPermitReentryFrom(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var calls int sm.Configure(stateA). - OnEntryFrom(triggerX, func(_ context.Context, _ ...any) error { + OnEntryFrom(triggerX, func(_ context.Context, _ any) error { calls += 1 return nil }). - OnEntryFrom(triggerY, func(_ context.Context, _ ...any) error { + OnEntryFrom(triggerY, func(_ context.Context, _ any) error { calls += 1 return nil }). PermitReentry(triggerX). Ignore(triggerY) - sm.Fire(triggerX) - sm.Fire(triggerY) + sm.Fire(triggerX, nil) + sm.Fire(triggerY, nil) if calls != 1 { t.Errorf("calls = %d, want %d", calls, 1) @@ -596,22 +596,22 @@ func TestStateMachine_Fire_IgnoreVsPermitReentryFrom(t *testing.T) { } func TestStateMachine_Fire_IgnoreVsPermitReentryExitWith(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var calls int sm.Configure(stateA). - OnExitWith(triggerX, func(_ context.Context, _ ...any) error { + OnExitWith(triggerX, func(_ context.Context, _ any) error { calls += 1 return nil }). - OnExitWith(triggerY, func(_ context.Context, _ ...any) error { + OnExitWith(triggerY, func(_ context.Context, _ any) error { calls += 1 return nil }). PermitReentry(triggerX). Ignore(triggerY) - sm.Fire(triggerX) - sm.Fire(triggerY) + sm.Fire(triggerX, nil) + sm.Fire(triggerY, nil) if calls != 1 { t.Errorf("calls = %d, want %d", calls, 1) @@ -619,27 +619,27 @@ func TestStateMachine_Fire_IgnoreVsPermitReentryExitWith(t *testing.T) { } func TestStateMachine_Fire_IfSelfTransitionPermited_ActionsFire_InSubstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var onEntryStateBfired, onExitStateBfired, onExitStateAfired bool sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { onEntryStateBfired = true return nil }). PermitReentry(triggerX). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { onExitStateBfired = true return nil }) sm.Configure(stateA). SubstateOf(stateB). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { onExitStateAfired = true return nil }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateB { t.Errorf("sm.MustState() = %v, want %v", got, stateB) @@ -656,11 +656,10 @@ func TestStateMachine_Fire_IfSelfTransitionPermited_ActionsFire_InSubstate(t *te } func TestStateMachine_Fire_TransitionWhenParameterizedGuardTrue(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateA) sm.Configure(stateA). - Permit(triggerX, stateB, func(_ context.Context, args ...any) bool { - return args[0].(int) == 2 + Permit(triggerX, stateB, func(_ context.Context, arg int) bool { + return arg == 2 }) sm.Fire(triggerX, 2) @@ -671,11 +670,10 @@ func TestStateMachine_Fire_TransitionWhenParameterizedGuardTrue(t *testing.T) { } func TestStateMachine_Fire_ErrorWhenParameterizedGuardFalse(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateA) sm.Configure(stateA). - Permit(triggerX, stateB, func(_ context.Context, args ...any) bool { - return args[0].(int) == 3 + Permit(triggerX, stateB, func(_ context.Context, arg int) bool { + return arg == 3 }) sm.Fire(triggerX, 2) @@ -685,13 +683,12 @@ func TestStateMachine_Fire_ErrorWhenParameterizedGuardFalse(t *testing.T) { } func TestStateMachine_Fire_TransitionWhenBothParameterizedGuardClausesTrue(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateA) sm.Configure(stateA). - Permit(triggerX, stateB, func(_ context.Context, args ...any) bool { - return args[0].(int) == 2 - }, func(_ context.Context, args ...any) bool { - return args[0].(int) != 3 + Permit(triggerX, stateB, func(_ context.Context, arg int) bool { + return arg == 2 + }, func(_ context.Context, arg int) bool { + return arg != 3 }) sm.Fire(triggerX, 2) @@ -702,14 +699,22 @@ func TestStateMachine_Fire_TransitionWhenBothParameterizedGuardClausesTrue(t *te } func TestStateMachine_Fire_TransitionWhenGuardReturnsTrueOnTriggerWithMultipleParameters(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(""), reflect.TypeOf(0)) + sm := NewStateMachine[string, string, struct { + s string + i int + }](stateA) sm.Configure(stateA). - Permit(triggerX, stateB, func(_ context.Context, args ...any) bool { - return args[0].(string) == "3" && args[1].(int) == 2 + Permit(triggerX, stateB, func(_ context.Context, arg struct { + s string + i int + }) bool { + return arg.s == "3" && arg.i == 2 }) - sm.Fire(triggerX, "3", 2) + sm.Fire(triggerX, struct { + s string + i int + }{"3", 2}) if got := sm.MustState(); got != stateB { t.Errorf("sm.MustState() = %v, want %v", got, stateB) @@ -717,21 +722,20 @@ func TestStateMachine_Fire_TransitionWhenGuardReturnsTrueOnTriggerWithMultiplePa } func TestStateMachine_Fire_TransitionWhenPermitDyanmicIfHasMultipleExclusiveGuards(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateA) sm.Configure(stateA). - PermitDynamic(triggerX, func(_ context.Context, args ...any) (State, error) { - if args[0].(int) == 3 { + PermitDynamic(triggerX, func(_ context.Context, arg int) (string, error) { + if arg == 3 { return stateB, nil } return stateC, nil - }, func(_ context.Context, args ...any) bool { return args[0].(int) == 3 || args[0].(int) == 5 }). - PermitDynamic(triggerX, func(_ context.Context, args ...any) (State, error) { - if args[0].(int) == 2 { + }, func(_ context.Context, arg int) bool { return arg == 3 || arg == 5 }). + PermitDynamic(triggerX, func(_ context.Context, arg int) (string, error) { + if arg == 2 { return stateC, nil } return stateD, nil - }, func(_ context.Context, args ...any) bool { return args[0].(int) == 2 || args[0].(int) == 4 }) + }, func(_ context.Context, arg int) bool { return arg == 2 || arg == 4 }) sm.Fire(triggerX, 3) @@ -741,10 +745,10 @@ func TestStateMachine_Fire_TransitionWhenPermitDyanmicIfHasMultipleExclusiveGuar } func TestStateMachine_Fire_PermitDyanmic_Error(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). - PermitDynamic(triggerX, func(_ context.Context, _ ...any) (State, error) { - return nil, errors.New("") + PermitDynamic(triggerX, func(_ context.Context, _ any) (string, error) { + return "", errors.New("") }) if err := sm.Fire(triggerX, ""); err == nil { @@ -756,37 +760,35 @@ func TestStateMachine_Fire_PermitDyanmic_Error(t *testing.T) { } func TestStateMachine_Fire_PanicsWhenPermitDyanmicIfHasMultipleNonExclusiveGuards(t *testing.T) { - sm := NewStateMachine(stateA) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateA) sm.Configure(stateA). - PermitDynamic(triggerX, func(_ context.Context, args ...any) (State, error) { - if args[0].(int) == 4 { + PermitDynamic(triggerX, func(_ context.Context, arg int) (string, error) { + if arg == 4 { return stateB, nil } return stateC, nil - }, func(_ context.Context, args ...any) bool { return args[0].(int)%2 == 0 }). - PermitDynamic(triggerX, func(_ context.Context, args ...any) (State, error) { - if args[0].(int) == 2 { + }, func(_ context.Context, arg int) bool { return arg%2 == 0 }). + PermitDynamic(triggerX, func(_ context.Context, arg int) (string, error) { + if arg == 2 { return stateC, nil } return stateD, nil - }, func(_ context.Context, args ...any) bool { return args[0].(int) == 2 }) + }, func(_ context.Context, arg int) bool { return arg == 2 }) assertPanic(t, func() { sm.Fire(triggerX, 2) }) } func TestStateMachine_Fire_TransitionWhenPermitIfHasMultipleExclusiveGuardsWithSuperStateTrue(t *testing.T) { - sm := NewStateMachine(stateB) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateB) sm.Configure(stateA). - Permit(triggerX, stateD, func(_ context.Context, args ...any) bool { - return args[0].(int) == 3 + Permit(triggerX, stateD, func(_ context.Context, arg int) bool { + return arg == 3 }) sm.Configure(stateB). SubstateOf(stateA). - Permit(triggerX, stateC, func(_ context.Context, args ...any) bool { - return args[0].(int) == 2 + Permit(triggerX, stateC, func(_ context.Context, arg int) bool { + return arg == 2 }) sm.Fire(triggerX, 3) @@ -797,17 +799,16 @@ func TestStateMachine_Fire_TransitionWhenPermitIfHasMultipleExclusiveGuardsWithS } func TestStateMachine_Fire_TransitionWhenPermitIfHasMultipleExclusiveGuardsWithSuperStateFalse(t *testing.T) { - sm := NewStateMachine(stateB) - sm.SetTriggerParameters(triggerX, reflect.TypeOf(0)) + sm := NewStateMachine[string, string, int](stateB) sm.Configure(stateA). - Permit(triggerX, stateD, func(_ context.Context, args ...any) bool { - return args[0].(int) == 3 + Permit(triggerX, stateD, func(_ context.Context, arg int) bool { + return arg == 3 }) sm.Configure(stateB). SubstateOf(stateA). - Permit(triggerX, stateC, func(_ context.Context, args ...any) bool { - return args[0].(int) == 2 + Permit(triggerX, stateC, func(_ context.Context, arg int) bool { + return arg == 2 }) sm.Fire(triggerX, 2) @@ -818,14 +819,14 @@ func TestStateMachine_Fire_TransitionWhenPermitIfHasMultipleExclusiveGuardsWithS } func TestStateMachine_Fire_TransitionToSuperstateDoesNotExitSuperstate(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) var superExit, superEntry, subExit bool sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { superEntry = true return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { superExit = true return nil }) @@ -833,12 +834,12 @@ func TestStateMachine_Fire_TransitionToSuperstateDoesNotExitSuperstate(t *testin sm.Configure(stateB). SubstateOf(stateA). Permit(triggerY, stateA). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { subExit = true return nil }) - sm.Fire(triggerY) + sm.Fire(triggerY, nil) if !subExit { t.Error("substate should exit") @@ -852,31 +853,31 @@ func TestStateMachine_Fire_TransitionToSuperstateDoesNotExitSuperstate(t *testin } func TestStateMachine_Fire_OnExitFiresOnlyOnceReentrySubstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var exitB, exitA, entryB, entryA int sm.Configure(stateA). SubstateOf(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { entryA += 1 return nil }). PermitReentry(triggerX). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { exitA += 1 return nil }) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { entryB += 1 return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { exitB += 1 return nil }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if entryB != 0 { t.Error("entryB should be 0") @@ -893,7 +894,7 @@ func TestStateMachine_Fire_OnExitFiresOnlyOnceReentrySubstate(t *testing.T) { } func TestStateMachine_Activate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) expectedOrdering := []string{"ActivatedC", "ActivatedA"} var actualOrdering []string @@ -912,10 +913,10 @@ func TestStateMachine_Activate(t *testing.T) { }) // should not be called for activation - sm.OnTransitioning(func(_ context.Context, _ Transition) { + sm.OnTransitioning(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioning") }) - sm.OnTransitioned(func(_ context.Context, _ Transition) { + sm.OnTransitioned(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioned") }) @@ -927,7 +928,7 @@ func TestStateMachine_Activate(t *testing.T) { } func TestStateMachine_Activate_Error(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string @@ -950,7 +951,7 @@ func TestStateMachine_Activate_Error(t *testing.T) { } func TestStateMachine_Activate_Idempotent(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string @@ -975,7 +976,7 @@ func TestStateMachine_Activate_Idempotent(t *testing.T) { } func TestStateMachine_Deactivate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) expectedOrdering := []string{"DeactivatedA", "DeactivatedC"} var actualOrdering []string @@ -994,10 +995,10 @@ func TestStateMachine_Deactivate(t *testing.T) { }) // should not be called for activation - sm.OnTransitioning(func(_ context.Context, _ Transition) { + sm.OnTransitioning(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioning") }) - sm.OnTransitioned(func(_ context.Context, _ Transition) { + sm.OnTransitioned(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioned") }) @@ -1010,7 +1011,7 @@ func TestStateMachine_Deactivate(t *testing.T) { } func TestStateMachine_Deactivate_NoActivated(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string @@ -1036,7 +1037,7 @@ func TestStateMachine_Deactivate_NoActivated(t *testing.T) { } func TestStateMachine_Deactivate_Error(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string @@ -1060,7 +1061,7 @@ func TestStateMachine_Deactivate_Error(t *testing.T) { } func TestStateMachine_Deactivate_Idempotent(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string @@ -1088,7 +1089,7 @@ func TestStateMachine_Deactivate_Idempotent(t *testing.T) { } func TestStateMachine_Activate_Transitioning(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var actualOrdering []string expectedOrdering := []string{"ActivatedA", "ExitedA", "OnTransitioning", "EnteredB", "OnTransitioned", @@ -1103,11 +1104,11 @@ func TestStateMachine_Activate_Transitioning(t *testing.T) { actualOrdering = append(actualOrdering, "DeactivatedA") return nil }). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnteredA") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitedA") return nil }). @@ -1122,26 +1123,26 @@ func TestStateMachine_Activate_Transitioning(t *testing.T) { actualOrdering = append(actualOrdering, "DeactivatedB") return nil }). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnteredB") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitedB") return nil }). Permit(triggerY, stateA) - sm.OnTransitioning(func(_ context.Context, _ Transition) { + sm.OnTransitioning(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioning") }) - sm.OnTransitioned(func(_ context.Context, _ Transition) { + sm.OnTransitioned(func(_ context.Context, _ Transition[string, string]) { actualOrdering = append(actualOrdering, "OnTransitioned") }) sm.Activate() - sm.Fire(triggerX) - sm.Fire(triggerY) + sm.Fire(triggerX, nil) + sm.Fire(triggerY, nil) if !reflect.DeepEqual(expectedOrdering, actualOrdering) { t.Errorf("expectedOrdering = %v, actualOrdering = %v", expectedOrdering, actualOrdering) @@ -1149,35 +1150,35 @@ func TestStateMachine_Activate_Transitioning(t *testing.T) { } func TestStateMachine_Fire_ImmediateEntryAProcessedBeforeEnterB(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringImmediate) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringImmediate) var actualOrdering []string expectedOrdering := []string{"ExitA", "ExitB", "EnterA", "EnterB"} sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnterA") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitA") return nil }). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + sm.Fire(triggerY, nil) actualOrdering = append(actualOrdering, "EnterB") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitB") return nil }). Permit(triggerY, stateA) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if !reflect.DeepEqual(expectedOrdering, actualOrdering) { t.Errorf("expectedOrdering = %v, actualOrdering = %v", expectedOrdering, actualOrdering) @@ -1185,35 +1186,35 @@ func TestStateMachine_Fire_ImmediateEntryAProcessedBeforeEnterB(t *testing.T) { } func TestStateMachine_Fire_QueuedEntryAProcessedBeforeEnterB(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringQueued) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringQueued) var actualOrdering []string expectedOrdering := []string{"ExitA", "EnterB", "ExitB", "EnterA"} sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnterA") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitA") return nil }). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + sm.Fire(triggerY, nil) actualOrdering = append(actualOrdering, "EnterB") return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitB") return nil }). Permit(triggerY, stateA) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if !reflect.DeepEqual(expectedOrdering, actualOrdering) { t.Errorf("expectedOrdering = %v, actualOrdering = %v", expectedOrdering, actualOrdering) @@ -1221,35 +1222,35 @@ func TestStateMachine_Fire_QueuedEntryAProcessedBeforeEnterB(t *testing.T) { } func TestStateMachine_Fire_QueuedEntryAsyncFire(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringQueued) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringQueued) sm.Configure(stateA). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - go sm.Fire(triggerY) - go sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + go sm.Fire(triggerY, nil) + go sm.Fire(triggerY, nil) return nil }). Permit(triggerY, stateA) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) } func TestStateMachine_Fire_Race(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringImmediate) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringImmediate) var actualOrdering []string var mu sync.Mutex sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { mu.Lock() actualOrdering = append(actualOrdering, "EnterA") mu.Unlock() return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { mu.Lock() actualOrdering = append(actualOrdering, "ExitA") mu.Unlock() @@ -1258,14 +1259,14 @@ func TestStateMachine_Fire_Race(t *testing.T) { Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + sm.Fire(triggerY, nil) mu.Lock() actualOrdering = append(actualOrdering, "EnterB") mu.Unlock() return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { mu.Lock() actualOrdering = append(actualOrdering, "ExitB") mu.Unlock() @@ -1276,11 +1277,11 @@ func TestStateMachine_Fire_Race(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { - sm.Fire(triggerX) + sm.Fire(triggerX, nil) wg.Done() }() go func() { - sm.Fire(triggerZ) + sm.Fire(triggerZ, nil) wg.Done() }() wg.Wait() @@ -1290,110 +1291,110 @@ func TestStateMachine_Fire_Race(t *testing.T) { } func TestStateMachine_Fire_Queued_ErrorExit(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringQueued) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringQueued) sm.Configure(stateA). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + sm.Fire(triggerY, nil) return nil }). - OnExit(func(_ context.Context, _ ...any) error { + OnExit(func(_ context.Context, _ any) error { return errors.New("") }). Permit(triggerY, stateA) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) - if err := sm.Fire(triggerX); err == nil { + if err := sm.Fire(triggerX, nil); err == nil { t.Error("expected error") } } func TestStateMachine_Fire_Queued_ErrorEnter(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringQueued) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringQueued) sm.Configure(stateA). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { return errors.New("") }). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { - sm.Fire(triggerY) + OnEntry(func(_ context.Context, _ any) error { + sm.Fire(triggerY, nil) return nil }). Permit(triggerY, stateA) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) - if err := sm.Fire(triggerX); err == nil { + if err := sm.Fire(triggerX, nil); err == nil { t.Error("expected error") } } func TestStateMachine_InternalTransition_StayInSameStateOneState(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateB). - InternalTransition(triggerX, func(_ context.Context, _ ...any) error { + InternalTransition(triggerX, func(_ context.Context, _ any) error { return nil }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateA { t.Errorf("expected %v, got %v", stateA, got) } } func TestStateMachine_InternalTransition_HandledOnlyOnceInSuper(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) handledIn := stateC sm.Configure(stateA). - InternalTransition(triggerX, func(_ context.Context, _ ...any) error { + InternalTransition(triggerX, func(_ context.Context, _ any) error { handledIn = stateA return nil }) sm.Configure(stateB). SubstateOf(stateA). - InternalTransition(triggerX, func(_ context.Context, _ ...any) error { + InternalTransition(triggerX, func(_ context.Context, _ any) error { handledIn = stateB return nil }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if stateA != handledIn { t.Errorf("expected %v, got %v", stateA, handledIn) } } func TestStateMachine_InternalTransition_HandledOnlyOnceInSub(t *testing.T) { - sm := NewStateMachine(stateB) + sm := NewStateMachine[string, string, any](stateB) handledIn := stateC sm.Configure(stateA). - InternalTransition(triggerX, func(_ context.Context, _ ...any) error { + InternalTransition(triggerX, func(_ context.Context, _ any) error { handledIn = stateA return nil }) sm.Configure(stateB). SubstateOf(stateA). - InternalTransition(triggerX, func(_ context.Context, _ ...any) error { + InternalTransition(triggerX, func(_ context.Context, _ any) error { handledIn = stateB return nil }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if stateB != handledIn { t.Errorf("expected %v, got %v", stateB, handledIn) } } func TestStateMachine_InitialTransition_EntersSubState(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) @@ -1404,14 +1405,14 @@ func TestStateMachine_InitialTransition_EntersSubState(t *testing.T) { sm.Configure(stateC). SubstateOf(stateB) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateC { t.Errorf("MustState() = %v, want %v", got, stateC) } } func TestStateMachine_InitialTransition_EntersSubStateofSubstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) @@ -1426,7 +1427,7 @@ func TestStateMachine_InitialTransition_EntersSubStateofSubstate(t *testing.T) { sm.Configure(stateD). SubstateOf(stateC) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateD { t.Errorf("MustState() = %v, want %v", got, stateD) } @@ -1436,37 +1437,37 @@ func TestStateMachine_InitialTransition_Ordering(t *testing.T) { var actualOrdering []string expectedOrdering := []string{"ExitA", "OnTransitioningAB", "EnterB", "OnTransitioningBC", "EnterC", "OnTransitionedAC"} - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB). - OnExit(func(c context.Context, i ...any) error { + OnExit(func(c context.Context, _ any) error { actualOrdering = append(actualOrdering, "ExitA") return nil }) sm.Configure(stateB). InitialTransition(stateC). - OnEntry(func(c context.Context, i ...any) error { + OnEntry(func(c context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnterB") return nil }) sm.Configure(stateC). SubstateOf(stateB). - OnEntry(func(c context.Context, i ...any) error { + OnEntry(func(c context.Context, _ any) error { actualOrdering = append(actualOrdering, "EnterC") return nil }) - sm.OnTransitioning(func(_ context.Context, tr Transition) { + sm.OnTransitioning(func(_ context.Context, tr Transition[string, string]) { actualOrdering = append(actualOrdering, fmt.Sprintf("OnTransitioning%v%v", tr.Source, tr.Destination)) }) - sm.OnTransitioned(func(_ context.Context, tr Transition) { + sm.OnTransitioned(func(_ context.Context, tr Transition[string, string]) { actualOrdering = append(actualOrdering, fmt.Sprintf("OnTransitioned%v%v", tr.Source, tr.Destination)) }) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateC { t.Errorf("MustState() = %v, want %v", got, stateC) } @@ -1477,7 +1478,7 @@ func TestStateMachine_InitialTransition_Ordering(t *testing.T) { } func TestStateMachine_InitialTransition_DoesNotEnterSubStateofSubstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) @@ -1490,14 +1491,14 @@ func TestStateMachine_InitialTransition_DoesNotEnterSubStateofSubstate(t *testin sm.Configure(stateD). SubstateOf(stateC) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) if got := sm.MustState(); got != stateB { t.Errorf("MustState() = %v, want %v", got, stateB) } } func TestStateMachine_InitialTransition_DoNotAllowTransitionToSelf(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) assertPanic(t, func() { sm.Configure(stateA). InitialTransition(stateA) @@ -1505,18 +1506,18 @@ func TestStateMachine_InitialTransition_DoNotAllowTransitionToSelf(t *testing.T) } func TestStateMachine_InitialTransition_WithMultipleSubStates(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA).Permit(triggerX, stateB) sm.Configure(stateB).InitialTransition(stateC) sm.Configure(stateC).SubstateOf(stateB) sm.Configure(stateD).SubstateOf(stateB) - if err := sm.Fire(triggerX); err != nil { + if err := sm.Fire(triggerX, nil); err != nil { t.Error(err) } } func TestStateMachine_InitialTransition_DoNotAllowTransitionToAnotherSuperstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) @@ -1524,11 +1525,11 @@ func TestStateMachine_InitialTransition_DoNotAllowTransitionToAnotherSuperstate( sm.Configure(stateB). InitialTransition(stateA) - assertPanic(t, func() { sm.Fire(triggerX) }) + assertPanic(t, func() { sm.Fire(triggerX, nil) }) } func TestStateMachine_InitialTransition_DoNotAllowMoreThanOneInitialTransition(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) @@ -1542,19 +1543,19 @@ func TestStateMachine_InitialTransition_DoNotAllowMoreThanOneInitialTransition(t func TestStateMachine_String(t *testing.T) { tests := []struct { name string - sm *StateMachine + sm *StateMachine[string, string, any] want string }{ - {"noTriggers", NewStateMachine(stateA), "StateMachine {{ State = A, PermittedTriggers = [] }}"}, - {"error state", NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { - return nil, errors.New("status error") - }, func(_ context.Context, s State) error { return nil }, FiringImmediate), ""}, - {"triggers", NewStateMachine(stateB).Configure(stateB).Permit(triggerX, stateA).Machine(), + {"noTriggers", NewStateMachine[string, string, any](stateA), "StateMachine {{ State = A, PermittedTriggers = [] }}"}, + {"error state", NewStateMachineWithExternalStorage[string, string, any](func(_ context.Context) (string, error) { + return "", errors.New("status error") + }, func(_ context.Context, s string) error { return nil }, FiringImmediate), ""}, + {"triggers", NewStateMachine[string, string, any](stateB).Configure(stateB).Permit(triggerX, stateA).Machine(), "StateMachine {{ State = B, PermittedTriggers = [X] }}"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.sm.String(); got != tt.want { + if got := tt.sm.String(nil); got != tt.want { t.Errorf("StateMachine.String() = %v, want %v", got, tt.want) } }) @@ -1563,33 +1564,33 @@ func TestStateMachine_String(t *testing.T) { func TestStateMachine_String_Concurrent(t *testing.T) { // Test that race mode doesn't complain about concurrent access to the state machine. - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) const n = 10 var wg sync.WaitGroup wg.Add(n) for i := 0; i < n; i++ { go func() { defer wg.Done() - _ = sm.String() + _ = sm.String(nil) }() } wg.Wait() } func TestStateMachine_Firing_Queued(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(ctx context.Context, i ...any) error { + OnEntry(func(ctx context.Context, _ any) error { if !sm.Firing() { t.Error("expected firing to be true") } return nil }) - if err := sm.Fire(triggerX); err != nil { + if err := sm.Fire(triggerX, nil); err != nil { t.Error(err) } if sm.Firing() { @@ -1598,19 +1599,19 @@ func TestStateMachine_Firing_Queued(t *testing.T) { } func TestStateMachine_Firing_Immediate(t *testing.T) { - sm := NewStateMachineWithMode(stateA, FiringImmediate) + sm := NewStateMachineWithMode[string, string, any](stateA, FiringImmediate) sm.Configure(stateA). Permit(triggerX, stateB) sm.Configure(stateB). - OnEntry(func(ctx context.Context, i ...any) error { + OnEntry(func(ctx context.Context, _ any) error { if !sm.Firing() { t.Error("expected firing to be true") } return nil }) - if err := sm.Fire(triggerX); err != nil { + if err := sm.Fire(triggerX, nil); err != nil { t.Error(err) } if sm.Firing() { @@ -1619,11 +1620,11 @@ func TestStateMachine_Firing_Immediate(t *testing.T) { } func TestStateMachine_Firing_Concurrent(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) sm.Configure(stateA). PermitReentry(triggerX). - OnEntry(func(ctx context.Context, i ...any) error { + OnEntry(func(ctx context.Context, _ any) error { if !sm.Firing() { t.Error("expected firing to be true") } @@ -1634,7 +1635,7 @@ func TestStateMachine_Firing_Concurrent(t *testing.T) { wg.Add(1000) for i := 0; i < 1000; i++ { go func() { - if err := sm.Fire(triggerX); err != nil { + if err := sm.Fire(triggerX, nil); err != nil { t.Error(err) } wg.Done() @@ -1648,7 +1649,7 @@ func TestStateMachine_Firing_Concurrent(t *testing.T) { func TestGetTransition_ContextEmpty(t *testing.T) { // It should not panic - GetTransition(context.Background()) + GetTransition[string, string](context.Background()) } func assertPanic(t *testing.T, f func()) { @@ -1662,11 +1663,11 @@ func assertPanic(t *testing.T, f func()) { } func TestStateMachineWhenInSubstate_TriggerSuperStateTwiceToSameSubstate_DoesNotReenterSubstate(t *testing.T) { - sm := NewStateMachine(stateA) + sm := NewStateMachine[string, string, any](stateA) var eCount = 0 sm.Configure(stateB). - OnEntry(func(_ context.Context, _ ...any) error { + OnEntry(func(_ context.Context, _ any) error { eCount++ return nil }). @@ -1678,8 +1679,8 @@ func TestStateMachineWhenInSubstate_TriggerSuperStateTwiceToSameSubstate_DoesNot sm.Configure(stateC). Permit(triggerX, stateB) - sm.Fire(triggerX) - sm.Fire(triggerX) + sm.Fire(triggerX, nil) + sm.Fire(triggerX, nil) if eCount != 1 { t.Errorf("expected 1, got %d", eCount) diff --git a/states.go b/states.go index f11a148..7d05bdc 100644 --- a/states.go +++ b/states.go @@ -5,16 +5,16 @@ import ( "fmt" ) -type actionBehaviour struct { - Action ActionFunc +type actionBehaviour[S State, T Trigger, A any] struct { + Action ActionFunc[A] Description invocationInfo - Trigger *Trigger + Trigger *T } -func (a actionBehaviour) Execute(ctx context.Context, transition Transition, args ...any) (err error) { +func (a actionBehaviour[S, T, A]) Execute(ctx context.Context, transition Transition[S, T], arg A) (err error) { if a.Trigger == nil || *a.Trigger == transition.Trigger { ctx = withTransition(ctx, transition) - err = a.Action(ctx, args...) + err = a.Action(ctx, arg) } return } @@ -28,57 +28,57 @@ func (a actionBehaviourSteady) Execute(ctx context.Context) error { return a.Action(ctx) } -type stateRepresentation struct { - State State - InitialTransitionTarget State - Superstate *stateRepresentation - EntryActions []actionBehaviour - ExitActions []actionBehaviour +type stateRepresentation[S State, T Trigger, A any] struct { + State S + InitialTransitionTarget S + Superstate *stateRepresentation[S, T, A] + EntryActions []actionBehaviour[S, T, A] + ExitActions []actionBehaviour[S, T, A] ActivateActions []actionBehaviourSteady DeactivateActions []actionBehaviourSteady - Substates []*stateRepresentation - TriggerBehaviours map[Trigger][]triggerBehaviour + Substates []*stateRepresentation[S, T, A] + TriggerBehaviours map[T][]triggerBehaviour[T, A] HasInitialState bool } -func newstateRepresentation(state State) *stateRepresentation { - return &stateRepresentation{ +func newstateRepresentation[S State, T Trigger, A any](state S) *stateRepresentation[S, T, A] { + return &stateRepresentation[S, T, A]{ State: state, - TriggerBehaviours: make(map[Trigger][]triggerBehaviour), + TriggerBehaviours: make(map[T][]triggerBehaviour[T, A]), } } -func (sr *stateRepresentation) SetInitialTransition(state State) { +func (sr *stateRepresentation[S, _, _]) SetInitialTransition(state S) { sr.InitialTransitionTarget = state sr.HasInitialState = true } -func (sr *stateRepresentation) state() State { +func (sr *stateRepresentation[S, _, _]) state() S { return sr.State } -func (sr *stateRepresentation) CanHandle(ctx context.Context, trigger Trigger, args ...any) (ok bool) { - _, ok = sr.FindHandler(ctx, trigger, args...) +func (sr *stateRepresentation[_, T, A]) CanHandle(ctx context.Context, trigger T, arg A) (ok bool) { + _, ok = sr.FindHandler(ctx, trigger, arg) return } -func (sr *stateRepresentation) FindHandler(ctx context.Context, trigger Trigger, args ...any) (handler triggerBehaviourResult, ok bool) { - handler, ok = sr.findHandler(ctx, trigger, args...) +func (sr *stateRepresentation[_, T, A]) FindHandler(ctx context.Context, trigger T, arg A) (handler triggerBehaviourResult[T, A], ok bool) { + handler, ok = sr.findHandler(ctx, trigger, arg) if ok || sr.Superstate == nil { return } - handler, ok = sr.Superstate.FindHandler(ctx, trigger, args...) + handler, ok = sr.Superstate.FindHandler(ctx, trigger, arg) return } -func (sr *stateRepresentation) findHandler(ctx context.Context, trigger Trigger, args ...any) (result triggerBehaviourResult, ok bool) { +func (sr *stateRepresentation[_, T, A]) findHandler(ctx context.Context, trigger T, arg A) (result triggerBehaviourResult[T, A], ok bool) { possibleBehaviours, ok := sr.TriggerBehaviours[trigger] if !ok { return } var unmet []string for _, behaviour := range possibleBehaviours { - unmet = behaviour.UnmetGuardConditions(ctx, unmet[:0], args...) + unmet = behaviour.UnmetGuardConditions(ctx, unmet[:0], arg) // , arg) if len(unmet) == 0 { if result.Handler != nil && len(result.UnmetGuardConditions) == 0 { panic(fmt.Sprintf("stateless: Multiple permitted exit transitions are configured from state '%v' for trigger '%v'. Guard clauses must be mutually exclusive.", sr.State, trigger)) @@ -94,7 +94,7 @@ func (sr *stateRepresentation) findHandler(ctx context.Context, trigger Trigger, return result, result.Handler != nil && len(result.UnmetGuardConditions) == 0 } -func (sr *stateRepresentation) Activate(ctx context.Context) error { +func (sr *stateRepresentation[S, _, _]) Activate(ctx context.Context) error { if sr.Superstate != nil { if err := sr.Superstate.Activate(ctx); err != nil { return err @@ -103,7 +103,7 @@ func (sr *stateRepresentation) Activate(ctx context.Context) error { return sr.executeActivationActions(ctx) } -func (sr *stateRepresentation) Deactivate(ctx context.Context) error { +func (sr *stateRepresentation[S, _, _]) Deactivate(ctx context.Context) error { if err := sr.executeDeactivationActions(ctx); err != nil { return err } @@ -113,51 +113,51 @@ func (sr *stateRepresentation) Deactivate(ctx context.Context) error { return nil } -func (sr *stateRepresentation) Enter(ctx context.Context, transition Transition, args ...any) error { +func (sr *stateRepresentation[S, T, A]) Enter(ctx context.Context, transition Transition[S, T], arg A) error { if transition.IsReentry() { - return sr.executeEntryActions(ctx, transition, args...) + return sr.executeEntryActions(ctx, transition, arg) } if sr.IncludeState(transition.Source) { return nil } if sr.Superstate != nil && !transition.isInitial { - if err := sr.Superstate.Enter(ctx, transition, args...); err != nil { + if err := sr.Superstate.Enter(ctx, transition, arg); err != nil { return err } } - return sr.executeEntryActions(ctx, transition, args...) + return sr.executeEntryActions(ctx, transition, arg) } -func (sr *stateRepresentation) Exit(ctx context.Context, transition Transition, args ...any) (err error) { +func (sr *stateRepresentation[S, T, A]) Exit(ctx context.Context, transition Transition[S, T], arg A) (err error) { isReentry := transition.IsReentry() if !isReentry && sr.IncludeState(transition.Destination) { return } - err = sr.executeExitActions(ctx, transition, args...) + err = sr.executeExitActions(ctx, transition, arg) // Must check if there is a superstate, and if we are leaving that superstate if err == nil && !isReentry && sr.Superstate != nil { // Check if destination is within the state list if sr.IsIncludedInState(transition.Destination) { // Destination state is within the list, exit first superstate only if it is NOT the the first if sr.Superstate.state() != transition.Destination { - err = sr.Superstate.Exit(ctx, transition, args...) + err = sr.Superstate.Exit(ctx, transition, arg) } } else { // Exit the superstate as well - err = sr.Superstate.Exit(ctx, transition, args...) + err = sr.Superstate.Exit(ctx, transition, arg) } } return } -func (sr *stateRepresentation) InternalAction(ctx context.Context, transition Transition, args ...any) error { - var internalTransition *internalTriggerBehaviour - var stateRep *stateRepresentation = sr +func (sr *stateRepresentation[S, T, A]) InternalAction(ctx context.Context, transition Transition[S, T], arg A) error { + var internalTransition *internalTriggerBehaviour[S, T, A] + var stateRep = sr for stateRep != nil { - if result, ok := stateRep.findHandler(ctx, transition.Trigger, args...); ok { + if result, ok := stateRep.findHandler(ctx, transition.Trigger, arg); ok { switch t := result.Handler.(type) { - case *internalTriggerBehaviour: + case *internalTriggerBehaviour[S, T, A]: internalTransition = t } break @@ -167,10 +167,10 @@ func (sr *stateRepresentation) InternalAction(ctx context.Context, transition Tr if internalTransition == nil { panic("stateless: The configuration is incorrect, no action assigned to this internal transition.") } - return internalTransition.Execute(ctx, transition, args...) + return internalTransition.Execute(ctx, transition, arg) } -func (sr *stateRepresentation) IncludeState(state State) bool { +func (sr *stateRepresentation[S, _, _]) IncludeState(state S) bool { if state == sr.State { return true } @@ -182,7 +182,7 @@ func (sr *stateRepresentation) IncludeState(state State) bool { return false } -func (sr *stateRepresentation) IsIncludedInState(state State) bool { +func (sr *stateRepresentation[S, _, _]) IsIncludedInState(state S) bool { if state == sr.State { return true } @@ -192,26 +192,26 @@ func (sr *stateRepresentation) IsIncludedInState(state State) bool { return false } -func (sr *stateRepresentation) AddTriggerBehaviour(tb triggerBehaviour) { +func (sr *stateRepresentation[_, T, A]) AddTriggerBehaviour(tb triggerBehaviour[T, A]) { trigger := tb.GetTrigger() sr.TriggerBehaviours[trigger] = append(sr.TriggerBehaviours[trigger], tb) } -func (sr *stateRepresentation) PermittedTriggers(ctx context.Context, args ...any) (triggers []Trigger) { +func (sr *stateRepresentation[_, T, A]) PermittedTriggers(ctx context.Context, arg A) (triggers []T) { var unmet []string for key, value := range sr.TriggerBehaviours { for _, tb := range value { - if len(tb.UnmetGuardConditions(ctx, unmet[:0], args...)) == 0 { + if len(tb.UnmetGuardConditions(ctx, unmet[:0], arg)) == 0 { triggers = append(triggers, key) break } } } if sr.Superstate != nil { - triggers = append(triggers, sr.Superstate.PermittedTriggers(ctx, args...)...) + triggers = append(triggers, sr.Superstate.PermittedTriggers(ctx, arg)...) // remove duplicated - seen := make(map[Trigger]struct{}, len(triggers)) + seen := make(map[T]struct{}, len(triggers)) j := 0 for _, v := range triggers { if _, ok := seen[v]; ok { @@ -226,7 +226,7 @@ func (sr *stateRepresentation) PermittedTriggers(ctx context.Context, args ...an return } -func (sr *stateRepresentation) executeActivationActions(ctx context.Context) error { +func (sr *stateRepresentation[S, _, _]) executeActivationActions(ctx context.Context) error { for _, a := range sr.ActivateActions { if err := a.Execute(ctx); err != nil { return err @@ -235,7 +235,7 @@ func (sr *stateRepresentation) executeActivationActions(ctx context.Context) err return nil } -func (sr *stateRepresentation) executeDeactivationActions(ctx context.Context) error { +func (sr *stateRepresentation[S, _, _]) executeDeactivationActions(ctx context.Context) error { for _, a := range sr.DeactivateActions { if err := a.Execute(ctx); err != nil { return err @@ -244,18 +244,18 @@ func (sr *stateRepresentation) executeDeactivationActions(ctx context.Context) e return nil } -func (sr *stateRepresentation) executeEntryActions(ctx context.Context, transition Transition, args ...any) error { +func (sr *stateRepresentation[S, T, A]) executeEntryActions(ctx context.Context, transition Transition[S, T], arg A) error { for _, a := range sr.EntryActions { - if err := a.Execute(ctx, transition, args...); err != nil { + if err := a.Execute(ctx, transition, arg); err != nil { return err } } return nil } -func (sr *stateRepresentation) executeExitActions(ctx context.Context, transition Transition, args ...any) error { +func (sr *stateRepresentation[S, T, A]) executeExitActions(ctx context.Context, transition Transition[S, T], arg A) error { for _, a := range sr.ExitActions { - if err := a.Execute(ctx, transition, args...); err != nil { + if err := a.Execute(ctx, transition, arg); err != nil { return err } } diff --git a/states_test.go b/states_test.go index 6cd993d..6cd5113 100644 --- a/states_test.go +++ b/states_test.go @@ -7,172 +7,172 @@ import ( "testing" ) -func createSuperSubstatePair() (*stateRepresentation, *stateRepresentation) { - super := newstateRepresentation(stateA) - sub := newstateRepresentation(stateB) +func createSuperSubstatePair() (*stateRepresentation[string, string, any], *stateRepresentation[string, string, any]) { + super := newstateRepresentation[string, string, any](stateA) + sub := newstateRepresentation[string, string, any](stateB) super.Substates = append(super.Substates, sub) sub.Superstate = super return super, sub } func Test_stateRepresentation_Includes_SameState(t *testing.T) { - sr := newstateRepresentation(stateB) + sr := newstateRepresentation[string, string, any](stateB) if !sr.IncludeState(stateB) { t.Fail() } } func Test_stateRepresentation_Includes_Substate(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.Substates = append(sr.Substates, newstateRepresentation(stateC)) + sr := newstateRepresentation[string, string, any](stateB) + sr.Substates = append(sr.Substates, newstateRepresentation[string, string, any](stateC)) if !sr.IncludeState(stateC) { t.Fail() } } func Test_stateRepresentation_Includes_UnrelatedState(t *testing.T) { - sr := newstateRepresentation(stateB) + sr := newstateRepresentation[string, string, any](stateB) if sr.IncludeState(stateC) { t.Fail() } } func Test_stateRepresentation_Includes_Superstate(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.Superstate = newstateRepresentation(stateC) + sr := newstateRepresentation[string, string, any](stateB) + sr.Superstate = newstateRepresentation[string, string, any](stateC) if sr.IncludeState(stateC) { t.Fail() } } func Test_stateRepresentation_IsIncludedInState_SameState(t *testing.T) { - sr := newstateRepresentation(stateB) + sr := newstateRepresentation[string, string, any](stateB) if !sr.IsIncludedInState(stateB) { t.Fail() } } func Test_stateRepresentation_IsIncludedInState_Substate(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.Substates = append(sr.Substates, newstateRepresentation(stateC)) + sr := newstateRepresentation[string, string, any](stateB) + sr.Substates = append(sr.Substates, newstateRepresentation[string, string, any](stateC)) if sr.IsIncludedInState(stateC) { t.Fail() } } func Test_stateRepresentation_IsIncludedInState_UnrelatedState(t *testing.T) { - sr := newstateRepresentation(stateB) + sr := newstateRepresentation[string, string, any](stateB) if sr.IsIncludedInState(stateC) { t.Fail() } } func Test_stateRepresentation_IsIncludedInState_Superstate(t *testing.T) { - sr := newstateRepresentation(stateB) + sr := newstateRepresentation[string, string, any](stateB) if sr.IsIncludedInState(stateC) { t.Fail() } } func Test_stateRepresentation_CanHandle_TransitionExists_TriggerCannotBeFired(t *testing.T) { - sr := newstateRepresentation(stateB) - if sr.CanHandle(context.Background(), triggerX) { + sr := newstateRepresentation[string, string, any](stateB) + if sr.CanHandle(context.Background(), triggerX, nil) { t.Fail() } } func Test_stateRepresentation_CanHandle_TransitionDoesNotExist_TriggerCanBeFired(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.AddTriggerBehaviour(&ignoredTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{Trigger: triggerX}}) - if !sr.CanHandle(context.Background(), triggerX) { + sr := newstateRepresentation[string, string, any](stateB) + sr.AddTriggerBehaviour(&ignoredTriggerBehaviour[string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{Trigger: triggerX}}) + if !sr.CanHandle(context.Background(), triggerX, nil) { t.Fail() } } func Test_stateRepresentation_CanHandle_TransitionExistsInSupersate_TriggerCanBeFired(t *testing.T) { super, sub := createSuperSubstatePair() - super.AddTriggerBehaviour(&ignoredTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{Trigger: triggerX}}) - if !sub.CanHandle(context.Background(), triggerX) { + super.AddTriggerBehaviour(&ignoredTriggerBehaviour[string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{Trigger: triggerX}}) + if !sub.CanHandle(context.Background(), triggerX, nil) { t.Fail() } } func Test_stateRepresentation_CanHandle_TransitionUnmetGuardConditions_TriggerCannotBeFired(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.AddTriggerBehaviour(&transitioningTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{ + sr := newstateRepresentation[string, string, any](stateB) + sr.AddTriggerBehaviour(&transitioningTriggerBehaviour[string, string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{ Trigger: triggerX, - Guard: newtransitionGuard(func(_ context.Context, _ ...any) bool { + Guard: newtransitionGuard(func(_ context.Context, _ any) bool { return true - }, func(_ context.Context, _ ...any) bool { + }, func(_ context.Context, _ any) bool { return false }), }, Destination: stateC}) - if sr.CanHandle(context.Background(), triggerX) { + if sr.CanHandle(context.Background(), triggerX, nil) { t.Fail() } } func Test_stateRepresentation_CanHandle_TransitionGuardConditionsMet_TriggerCanBeFired(t *testing.T) { - sr := newstateRepresentation(stateB) - sr.AddTriggerBehaviour(&transitioningTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{ + sr := newstateRepresentation[string, string, any](stateB) + sr.AddTriggerBehaviour(&transitioningTriggerBehaviour[string, string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{ Trigger: triggerX, - Guard: newtransitionGuard(func(_ context.Context, _ ...any) bool { + Guard: newtransitionGuard(func(_ context.Context, _ any) bool { return true - }, func(_ context.Context, _ ...any) bool { + }, func(_ context.Context, _ any) bool { return true }), }, Destination: stateC}) - if !sr.CanHandle(context.Background(), triggerX) { + if !sr.CanHandle(context.Background(), triggerX, nil) { t.Fail() } } func Test_stateRepresentation_FindHandler_TransitionExistAndSuperstateUnmetGuardConditions_FireNotPossible(t *testing.T) { super, sub := createSuperSubstatePair() - super.AddTriggerBehaviour(&transitioningTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{ + super.AddTriggerBehaviour(&transitioningTriggerBehaviour[string, string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{ Trigger: triggerX, - Guard: newtransitionGuard(func(_ context.Context, _ ...any) bool { + Guard: newtransitionGuard(func(_ context.Context, _ any) bool { return true - }, func(_ context.Context, _ ...any) bool { + }, func(_ context.Context, _ any) bool { return false }), }, Destination: stateC}) - handler, ok := sub.FindHandler(context.Background(), triggerX) + handler, ok := sub.FindHandler(context.Background(), triggerX, nil) if ok { t.Fail() } - if sub.CanHandle(context.Background(), triggerX) { + if sub.CanHandle(context.Background(), triggerX, nil) { t.Fail() } - if super.CanHandle(context.Background(), triggerX) { + if super.CanHandle(context.Background(), triggerX, nil) { t.Fail() } - if handler.Handler.GuardConditionMet(context.Background()) { + if handler.Handler.GuardConditionMet(context.Background(), nil) { t.Fail() } } func Test_stateRepresentation_FindHandler_TransitionExistSuperstateMetGuardConditions_CanBeFired(t *testing.T) { super, sub := createSuperSubstatePair() - super.AddTriggerBehaviour(&transitioningTriggerBehaviour{baseTriggerBehaviour: baseTriggerBehaviour{ + super.AddTriggerBehaviour(&transitioningTriggerBehaviour[string, string, any]{baseTriggerBehaviour: baseTriggerBehaviour[string, any]{ Trigger: triggerX, - Guard: newtransitionGuard(func(_ context.Context, _ ...any) bool { + Guard: newtransitionGuard(func(_ context.Context, _ any) bool { return true - }, func(_ context.Context, _ ...any) bool { + }, func(_ context.Context, _ any) bool { return true }), }, Destination: stateC}) - handler, ok := sub.FindHandler(context.Background(), triggerX) + handler, ok := sub.FindHandler(context.Background(), triggerX, nil) if !ok { t.Fail() } - if !sub.CanHandle(context.Background(), triggerX) { + if !sub.CanHandle(context.Background(), triggerX, nil) { t.Fail() } - if !super.CanHandle(context.Background(), triggerX) { + if !super.CanHandle(context.Background(), triggerX, nil) { t.Fail() } - if !handler.Handler.GuardConditionMet(context.Background()) { + if !handler.Handler.GuardConditionMet(context.Background(), nil) { t.Error("expected guard condition to be met") } if len(handler.UnmetGuardConditions) != 0 { @@ -181,16 +181,16 @@ func Test_stateRepresentation_FindHandler_TransitionExistSuperstateMetGuardCondi } func Test_stateRepresentation_Enter_EnteringActionsExecuted(t *testing.T) { - sr := newstateRepresentation(stateB) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.EntryActions = append(sr.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateB) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.EntryActions = append(sr.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actualTransition = transition return nil }, }) - if err := sr.Enter(context.Background(), transition); err != nil { + if err := sr.Enter(context.Background(), transition, nil); err != nil { t.Error(err) } if !reflect.DeepEqual(transition, actualTransition) { @@ -199,15 +199,15 @@ func Test_stateRepresentation_Enter_EnteringActionsExecuted(t *testing.T) { } func Test_stateRepresentation_Enter_EnteringActionsExecuted_Error(t *testing.T) { - sr := newstateRepresentation(stateB) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.EntryActions = append(sr.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateB) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.EntryActions = append(sr.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { return errors.New("") }, }) - if err := sr.Enter(context.Background(), transition); err == nil { + if err := sr.Enter(context.Background(), transition, nil); err == nil { t.Error("error expected") } if reflect.DeepEqual(transition, actualTransition) { @@ -216,17 +216,17 @@ func Test_stateRepresentation_Enter_EnteringActionsExecuted_Error(t *testing.T) } func Test_stateRepresentation_Enter_LeavingActionsNotExecuted(t *testing.T) { - sr := newstateRepresentation(stateA) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.ExitActions = append(sr.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateA) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.ExitActions = append(sr.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actualTransition = transition return nil }, }) - sr.Enter(context.Background(), transition) - if actualTransition != (Transition{}) { + sr.Enter(context.Background(), transition, nil) + if actualTransition != (Transition[string, string]{}) { t.Error("expected transition to not be passed to action") } } @@ -234,14 +234,14 @@ func Test_stateRepresentation_Enter_LeavingActionsNotExecuted(t *testing.T) { func Test_stateRepresentation_Enter_FromSubToSuperstate_SubstateEntryActionsExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - sub.EntryActions = append(sub.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sub.EntryActions = append(sub.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: super.State, Destination: sub.State, Trigger: triggerX} - sub.Enter(context.Background(), transition) + transition := Transition[string, string]{Source: super.State, Destination: sub.State, Trigger: triggerX} + sub.Enter(context.Background(), transition, nil) if !executed { t.Error("expected substate entry actions to be executed") } @@ -250,14 +250,14 @@ func Test_stateRepresentation_Enter_FromSubToSuperstate_SubstateEntryActionsExec func Test_stateRepresentation_Enter_SuperFromSubstate_SuperEntryActionsNotExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - super.EntryActions = append(super.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.EntryActions = append(super.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: super.State, Destination: sub.State, Trigger: triggerX} - sub.Enter(context.Background(), transition) + transition := Transition[string, string]{Source: super.State, Destination: sub.State, Trigger: triggerX} + sub.Enter(context.Background(), transition, nil) if executed { t.Error("expected superstate entry actions not to be executed") } @@ -266,14 +266,14 @@ func Test_stateRepresentation_Enter_SuperFromSubstate_SuperEntryActionsNotExecut func Test_stateRepresentation_Enter_Substate_SuperEntryActionsExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - super.EntryActions = append(super.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.EntryActions = append(super.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: stateC, Destination: sub.State, Trigger: triggerX} - sub.Enter(context.Background(), transition) + transition := Transition[string, string]{Source: stateC, Destination: sub.State, Trigger: triggerX} + sub.Enter(context.Background(), transition, nil) if !executed { t.Error("expected superstate entry actions to be executed") } @@ -281,21 +281,21 @@ func Test_stateRepresentation_Enter_Substate_SuperEntryActionsExecuted(t *testin func Test_stateRepresentation_Enter_ActionsExecuteInOrder(t *testing.T) { var actual []int - sr := newstateRepresentation(stateB) - sr.EntryActions = append(sr.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateB) + sr.EntryActions = append(sr.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actual = append(actual, 0) return nil }, }) - sr.EntryActions = append(sr.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr.EntryActions = append(sr.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actual = append(actual, 1) return nil }, }) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - sr.Enter(context.Background(), transition) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + sr.Enter(context.Background(), transition, nil) want := []int{0, 1} if !reflect.DeepEqual(actual, want) { t.Errorf("expected %v, got %v", want, actual) @@ -305,54 +305,54 @@ func Test_stateRepresentation_Enter_ActionsExecuteInOrder(t *testing.T) { func Test_stateRepresentation_Enter_Substate_SuperstateEntryActionsExecuteBeforeSubstate(t *testing.T) { super, sub := createSuperSubstatePair() var order, subOrder, superOrder int - super.EntryActions = append(super.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.EntryActions = append(super.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { order += 1 superOrder = order return nil }, }) - sub.EntryActions = append(sub.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sub.EntryActions = append(sub.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { order += 1 subOrder = order return nil }, }) - transition := Transition{Source: stateC, Destination: sub.State, Trigger: triggerX} - sub.Enter(context.Background(), transition) + transition := Transition[string, string]{Source: stateC, Destination: sub.State, Trigger: triggerX} + sub.Enter(context.Background(), transition, nil) if superOrder >= subOrder { t.Error("expected superstate entry actions to execute before substate entry actions") } } func Test_stateRepresentation_Exit_EnteringActionsNotExecuted(t *testing.T) { - sr := newstateRepresentation(stateB) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.EntryActions = append(sr.EntryActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateB) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.EntryActions = append(sr.EntryActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actualTransition = transition return nil }, }) - sr.Exit(context.Background(), transition) - if actualTransition != (Transition{}) { + sr.Exit(context.Background(), transition, nil) + if actualTransition != (Transition[string, string]{}) { t.Error("expected transition to not be passed to action") } } func Test_stateRepresentation_Exit_LeavingActionsExecuted(t *testing.T) { - sr := newstateRepresentation(stateA) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.ExitActions = append(sr.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateA) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.ExitActions = append(sr.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actualTransition = transition return nil }, }) - if err := sr.Exit(context.Background(), transition); err != nil { + if err := sr.Exit(context.Background(), transition, nil); err != nil { t.Error(err) } if actualTransition != transition { @@ -361,15 +361,15 @@ func Test_stateRepresentation_Exit_LeavingActionsExecuted(t *testing.T) { } func Test_stateRepresentation_Exit_LeavingActionsExecuted_Error(t *testing.T) { - sr := newstateRepresentation(stateA) - transition := Transition{Source: stateA, Destination: stateB, Trigger: triggerX} - var actualTransition Transition - sr.ExitActions = append(sr.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateA) + transition := Transition[string, string]{Source: stateA, Destination: stateB, Trigger: triggerX} + var actualTransition Transition[string, string] + sr.ExitActions = append(sr.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { return errors.New("") }, }) - if err := sr.Exit(context.Background(), transition); err == nil { + if err := sr.Exit(context.Background(), transition, nil); err == nil { t.Error("expected error") } if actualTransition == transition { @@ -380,14 +380,14 @@ func Test_stateRepresentation_Exit_LeavingActionsExecuted_Error(t *testing.T) { func Test_stateRepresentation_Exit_FromSubToSuperstate_SubstateExitActionsExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - sub.ExitActions = append(sub.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sub.ExitActions = append(sub.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: sub.State, Destination: super.State, Trigger: triggerX} - sub.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: sub.State, Destination: super.State, Trigger: triggerX} + sub.Exit(context.Background(), transition, nil) if !executed { t.Error("expected substate exit actions to be executed") } @@ -395,18 +395,18 @@ func Test_stateRepresentation_Exit_FromSubToSuperstate_SubstateExitActionsExecut func Test_stateRepresentation_Exit_FromSubToOther_SuperstateExitActionsExecuted(t *testing.T) { super, sub := createSuperSubstatePair() - supersuper := newstateRepresentation(stateC) + supersuper := newstateRepresentation[string, string, any](stateC) super.Superstate = supersuper - supersuper.Superstate = newstateRepresentation(stateD) + supersuper.Superstate = newstateRepresentation[string, string, any](stateD) executed := false - super.ExitActions = append(super.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.ExitActions = append(super.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: sub.State, Destination: stateD, Trigger: triggerX} - sub.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: sub.State, Destination: stateD, Trigger: triggerX} + sub.Exit(context.Background(), transition, nil) if !executed { t.Error("expected superstate exit actions to be executed") } @@ -415,14 +415,14 @@ func Test_stateRepresentation_Exit_FromSubToOther_SuperstateExitActionsExecuted( func Test_stateRepresentation_Exit_FromSuperToSubstate_SuperExitActionsNotExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - super.ExitActions = append(super.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.ExitActions = append(super.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: super.State, Destination: sub.State, Trigger: triggerX} - sub.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: super.State, Destination: sub.State, Trigger: triggerX} + sub.Exit(context.Background(), transition, nil) if executed { t.Error("expected superstate exit actions to not be executed") } @@ -431,14 +431,14 @@ func Test_stateRepresentation_Exit_FromSuperToSubstate_SuperExitActionsNotExecut func Test_stateRepresentation_Exit_Substate_SuperExitActionsExecuted(t *testing.T) { super, sub := createSuperSubstatePair() executed := false - super.ExitActions = append(super.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.ExitActions = append(super.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { executed = true return nil }, }) - transition := Transition{Source: sub.State, Destination: stateC, Trigger: triggerX} - sub.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: sub.State, Destination: stateC, Trigger: triggerX} + sub.Exit(context.Background(), transition, nil) if !executed { t.Error("expected superstate exit actions to be executed") } @@ -446,21 +446,21 @@ func Test_stateRepresentation_Exit_Substate_SuperExitActionsExecuted(t *testing. func Test_stateRepresentation_Exit_ActionsExecuteInOrder(t *testing.T) { var actual []int - sr := newstateRepresentation(stateB) - sr.ExitActions = append(sr.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr := newstateRepresentation[string, string, any](stateB) + sr.ExitActions = append(sr.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actual = append(actual, 0) return nil }, }) - sr.ExitActions = append(sr.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sr.ExitActions = append(sr.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { actual = append(actual, 1) return nil }, }) - transition := Transition{Source: stateB, Destination: stateC, Trigger: triggerX} - sr.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: stateB, Destination: stateC, Trigger: triggerX} + sr.Exit(context.Background(), transition, nil) want := []int{0, 1} if !reflect.DeepEqual(actual, want) { t.Errorf("expected %v, got %v", want, actual) @@ -470,22 +470,22 @@ func Test_stateRepresentation_Exit_ActionsExecuteInOrder(t *testing.T) { func Test_stateRepresentation_Exit_Substate_SubstateEntryActionsExecuteBeforeSuperstate(t *testing.T) { super, sub := createSuperSubstatePair() var order, subOrder, superOrder int - super.ExitActions = append(super.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + super.ExitActions = append(super.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { order += 1 superOrder = order return nil }, }) - sub.ExitActions = append(sub.ExitActions, actionBehaviour{ - Action: func(_ context.Context, _ ...any) error { + sub.ExitActions = append(sub.ExitActions, actionBehaviour[string, string, any]{ + Action: func(_ context.Context, _ any) error { order += 1 subOrder = order return nil }, }) - transition := Transition{Source: sub.State, Destination: stateC, Trigger: triggerX} - sub.Exit(context.Background(), transition) + transition := Transition[string, string]{Source: sub.State, Destination: stateC, Trigger: triggerX} + sub.Exit(context.Background(), transition, nil) if subOrder >= superOrder { t.Error("expected substate exit actions to execute before superstate") } diff --git a/triggers.go b/triggers.go index 6c58f87..3502c87 100644 --- a/triggers.go +++ b/triggers.go @@ -31,19 +31,19 @@ func (inv invocationInfo) String() string { return "" } -type guardCondition struct { - Guard GuardFunc +type guardCondition[A any] struct { + Guard GuardFunc[A] Description invocationInfo } -type transitionGuard struct { - Guards []guardCondition +type transitionGuard[A any] struct { + Guards []guardCondition[A] } -func newtransitionGuard(guards ...GuardFunc) transitionGuard { - tg := transitionGuard{Guards: make([]guardCondition, len(guards))} +func newtransitionGuard[A any](guards ...GuardFunc[A]) transitionGuard[A] { + tg := transitionGuard[A]{Guards: make([]guardCondition[A], len(guards))} for i, guard := range guards { - tg.Guards[i] = guardCondition{ + tg.Guards[i] = guardCondition[A]{ Guard: guard, Description: newinvocationInfo(guard), } @@ -52,97 +52,102 @@ func newtransitionGuard(guards ...GuardFunc) transitionGuard { } // GuardConditionsMet is true if all of the guard functions return true. -func (t transitionGuard) GuardConditionMet(ctx context.Context, args ...any) bool { +func (t transitionGuard[A]) GuardConditionMet(ctx context.Context, arg A) bool { for _, guard := range t.Guards { - if !guard.Guard(ctx, args...) { + if !guard.Guard(ctx, arg) { return false } } return true } -func (t transitionGuard) UnmetGuardConditions(ctx context.Context, buf []string, args ...any) []string { +func (t transitionGuard[A]) UnmetGuardConditions(ctx context.Context, buf []string, arg A) []string { if cap(buf) < len(t.Guards) { buf = make([]string, 0, len(t.Guards)) } buf = buf[:0] for _, guard := range t.Guards { - if !guard.Guard(ctx, args...) { + if !guard.Guard(ctx, arg) { buf = append(buf, guard.Description.String()) } } return buf } -type triggerBehaviour interface { - GuardConditionMet(context.Context, ...any) bool - UnmetGuardConditions(context.Context, []string, ...any) []string - GetTrigger() Trigger +type triggerBehaviour[T Trigger, A any] interface { + GuardConditionMet(context.Context, A) bool + UnmetGuardConditions(context.Context, []string, A) []string + GetTrigger() T } -type baseTriggerBehaviour struct { - Guard transitionGuard - Trigger Trigger +type baseTriggerBehaviour[T Trigger, A any] struct { + Guard transitionGuard[A] + Trigger T } -func (t *baseTriggerBehaviour) GetTrigger() Trigger { +func (t *baseTriggerBehaviour[T, A]) GetTrigger() T { return t.Trigger } -func (t *baseTriggerBehaviour) GuardConditionMet(ctx context.Context, args ...any) bool { - return t.Guard.GuardConditionMet(ctx, args...) +func (t *baseTriggerBehaviour[T, A]) GuardConditionMet(ctx context.Context, arg A) bool { + return t.Guard.GuardConditionMet(ctx, arg) } -func (t *baseTriggerBehaviour) UnmetGuardConditions(ctx context.Context, buf []string, args ...any) []string { - return t.Guard.UnmetGuardConditions(ctx, buf, args...) +func (t *baseTriggerBehaviour[T, A]) UnmetGuardConditions(ctx context.Context, buf []string, arg A) []string { + return t.Guard.UnmetGuardConditions(ctx, buf, arg) } -type ignoredTriggerBehaviour struct { - baseTriggerBehaviour +type ignoredTriggerBehaviour[T Trigger, A any] struct { + baseTriggerBehaviour[T, A] } -type reentryTriggerBehaviour struct { - baseTriggerBehaviour - Destination State +type reentryTriggerBehaviour[S State, T Trigger, A any] struct { + baseTriggerBehaviour[T, A] + Destination S } -type transitioningTriggerBehaviour struct { - baseTriggerBehaviour - Destination State +type transitioningTriggerBehaviour[S State, T Trigger, A any] struct { + baseTriggerBehaviour[T, A] + Destination S } -type dynamicTriggerBehaviour struct { - baseTriggerBehaviour - Destination func(context.Context, ...any) (State, error) +type dynamicTriggerBehaviour[S State, T Trigger, A any] struct { + baseTriggerBehaviour[T, A] + Destination func(context.Context, A) (S, error) } -type internalTriggerBehaviour struct { - baseTriggerBehaviour - Action ActionFunc +type internalTriggerBehaviour[S State, T Trigger, A any] struct { + baseTriggerBehaviour[T, A] + Action ActionFunc[A] } -func (t *internalTriggerBehaviour) Execute(ctx context.Context, transition Transition, args ...any) error { +func (t *internalTriggerBehaviour[S, T, A]) Execute(ctx context.Context, transition Transition[S, T], arg A) error { ctx = withTransition(ctx, transition) - return t.Action(ctx, args...) + return t.Action(ctx, arg) } -type triggerBehaviourResult struct { - Handler triggerBehaviour +type triggerBehaviourResult[T Trigger, A any] struct { + Handler triggerBehaviour[T, A] UnmetGuardConditions []string } +type Validatable interface { + TypeOf(int) reflect.Type + Len() int +} + // triggerWithParameters associates configured parameters with an underlying trigger value. -type triggerWithParameters struct { - Trigger Trigger +type triggerWithParameters[T Trigger] struct { + Trigger T ArgumentTypes []reflect.Type } -func (t triggerWithParameters) validateParameters(args ...any) { - if len(args) != len(t.ArgumentTypes) { - panic(fmt.Sprintf("stateless: An unexpected amount of parameters have been supplied. Expecting '%d' but got '%d'.", len(t.ArgumentTypes), len(args))) +func (t triggerWithParameters[T]) validateParameters(args Validatable) { + if args.Len() != len(t.ArgumentTypes) { + panic(fmt.Sprintf("stateless: An unexpected amount of parameters have been supplied. Expecting '%d' but got '%d'.", len(t.ArgumentTypes), args.Len())) } for i := range t.ArgumentTypes { - tp := reflect.TypeOf(args[i]) + tp := args.TypeOf(i) want := t.ArgumentTypes[i] if !tp.ConvertibleTo(want) { panic(fmt.Sprintf("stateless: The argument in position '%d' is of type '%v' but must be convertible to '%v'.", i, tp, want))