diff --git a/statemachine.go b/statemachine.go index dcaab09..7775ca3 100644 --- a/statemachine.go +++ b/statemachine.go @@ -148,6 +148,12 @@ func (sm *StateMachine) State(ctx context.Context) (State, error) { return state, err } +// StateWithArgs returns the current state along with any arguments that were passed to the state mutator. +// This is useful when using NewStateMachineWithExternalStorageAndArgs to retain additional state information. +func (sm *StateMachine) StateWithArgs(ctx context.Context) (State, []any, error) { + return sm.stateAccessor(ctx) +} + // MustState returns the current state without the error. // It is safe to use this method when used together with NewStateMachine // or when using NewStateMachineWithExternalStorage / NewStateMachineWithExternalStorageAndArgs with a state accessor that diff --git a/statemachine_test.go b/statemachine_test.go index ac3124a..f3ae9b7 100644 --- a/statemachine_test.go +++ b/statemachine_test.go @@ -106,6 +106,104 @@ func TestStateMachine_NewStateMachineWithExternalStorageAndArgs(t *testing.T) { } } +func TestStateMachine_StateWithArgs(t *testing.T) { + sm := NewStateMachine(stateA) + sm.Configure(stateA).Permit(triggerX, stateB) + + state, args, err := sm.StateWithArgs(context.Background()) + if err != nil { + t.Errorf("StateWithArgs() error = %v, want nil", err) + } + if state != stateA { + t.Errorf("StateWithArgs() state = %v, want %v", state, stateA) + } + if args != nil { + t.Errorf("StateWithArgs() args = %v, want nil", args) + } +} + +func TestStateMachine_StateWithArgs_ExternalStorage(t *testing.T) { + var state State = stateB + sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { + return state, nil + }, func(_ context.Context, s State) error { + state = s + return nil + }, FiringImmediate) + + gotState, gotArgs, err := sm.StateWithArgs(context.Background()) + if err != nil { + t.Errorf("StateWithArgs() error = %v, want nil", err) + } + if gotState != stateB { + t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateB) + } + if gotArgs != nil { + t.Errorf("StateWithArgs() args = %v, want nil", gotArgs) + } +} + +func TestStateMachine_StateWithArgs_ExternalStorageAndArgs(t *testing.T) { + var state State = stateB + var args = []any{"arg1", 42, errors.New("test error")} + sm := NewStateMachineWithExternalStorageAndArgs(func(_ context.Context) (State, []any, error) { + return state, args, nil + }, func(_ context.Context, s State, a ...any) error { + state = s + args = a + return nil + }, FiringImmediate) + sm.Configure(stateB).Permit(triggerX, stateC) + + gotState, gotArgs, err := sm.StateWithArgs(context.Background()) + if err != nil { + t.Errorf("StateWithArgs() error = %v, want nil", err) + } + if gotState != stateB { + t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateB) + } + if !reflect.DeepEqual(gotArgs, args) { + t.Errorf("StateWithArgs() args = %v, want %v", gotArgs, args) + } + if got := gotArgs[0].(string); got != "arg1" { + t.Errorf("expected arg 0 to be %v, got %v", "arg1", got) + } + if got := gotArgs[1].(int); got != 42 { + t.Errorf("expected arg 1 to be %v, got %v", 42, got) + } + if got := gotArgs[2].(error).Error(); got != "test error" { + t.Errorf("expected arg 2 to be %v, got %v", "test error", got) + } + + // Fire a transition with new arguments + sm.Fire(triggerX, "arg2", 99) + gotState, gotArgs, err = sm.StateWithArgs(context.Background()) + if err != nil { + t.Errorf("StateWithArgs() error = %v, want nil", err) + } + if gotState != stateC { + t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateC) + } + if got := gotArgs[0].(string); got != "arg2" { + t.Errorf("expected arg 0 to be %v, got %v", "arg2", got) + } + if got := gotArgs[1].(int); got != 99 { + t.Errorf("expected arg 1 to be %v, got %v", 99, got) + } +} + +func TestStateMachine_StateWithArgs_Error(t *testing.T) { + sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) { + return nil, errors.New("state accessor error") + }, func(_ context.Context, s State) error { return nil }, FiringImmediate) + + _, _, err := sm.StateWithArgs(context.Background()) + want := "state accessor error" + if err == nil || err.Error() != want { + t.Errorf("StateWithArgs() error = %v, want %v", err, want) + } +} + func TestStateMachine_Configure_SubstateIsIncludedInCurrentState(t *testing.T) { sm := NewStateMachine(stateB) sm.Configure(stateB).SubstateOf(stateC)