Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ func (sm *StateMachine) State(ctx context.Context) (State, error) {
return state, err
}

// StateWithArgs returns the current state along with any arguments that were passed to the state mutator.
// This is useful when using NewStateMachineWithExternalStorageAndArgs to retain additional state information.
func (sm *StateMachine) StateWithArgs(ctx context.Context) (State, []any, error) {
return sm.stateAccessor(ctx)
}

// MustState returns the current state without the error.
// It is safe to use this method when used together with NewStateMachine
// or when using NewStateMachineWithExternalStorage / NewStateMachineWithExternalStorageAndArgs with a state accessor that
Expand Down
98 changes: 98 additions & 0 deletions statemachine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,104 @@ func TestStateMachine_NewStateMachineWithExternalStorageAndArgs(t *testing.T) {
}
}

func TestStateMachine_StateWithArgs(t *testing.T) {
sm := NewStateMachine(stateA)
sm.Configure(stateA).Permit(triggerX, stateB)

state, args, err := sm.StateWithArgs(context.Background())
if err != nil {
t.Errorf("StateWithArgs() error = %v, want nil", err)
}
if state != stateA {
t.Errorf("StateWithArgs() state = %v, want %v", state, stateA)
}
if args != nil {
t.Errorf("StateWithArgs() args = %v, want nil", args)
}
}

func TestStateMachine_StateWithArgs_ExternalStorage(t *testing.T) {
var state State = stateB
sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) {
return state, nil
}, func(_ context.Context, s State) error {
state = s
return nil
}, FiringImmediate)

gotState, gotArgs, err := sm.StateWithArgs(context.Background())
if err != nil {
t.Errorf("StateWithArgs() error = %v, want nil", err)
}
if gotState != stateB {
t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateB)
}
if gotArgs != nil {
t.Errorf("StateWithArgs() args = %v, want nil", gotArgs)
}
}

func TestStateMachine_StateWithArgs_ExternalStorageAndArgs(t *testing.T) {
var state State = stateB
var args = []any{"arg1", 42, errors.New("test error")}
sm := NewStateMachineWithExternalStorageAndArgs(func(_ context.Context) (State, []any, error) {
return state, args, nil
}, func(_ context.Context, s State, a ...any) error {
state = s
args = a
return nil
}, FiringImmediate)
sm.Configure(stateB).Permit(triggerX, stateC)

gotState, gotArgs, err := sm.StateWithArgs(context.Background())
if err != nil {
t.Errorf("StateWithArgs() error = %v, want nil", err)
}
if gotState != stateB {
t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateB)
}
if !reflect.DeepEqual(gotArgs, args) {
t.Errorf("StateWithArgs() args = %v, want %v", gotArgs, args)
}
if got := gotArgs[0].(string); got != "arg1" {
t.Errorf("expected arg 0 to be %v, got %v", "arg1", got)
}
if got := gotArgs[1].(int); got != 42 {
t.Errorf("expected arg 1 to be %v, got %v", 42, got)
}
if got := gotArgs[2].(error).Error(); got != "test error" {
t.Errorf("expected arg 2 to be %v, got %v", "test error", got)
}

// Fire a transition with new arguments
sm.Fire(triggerX, "arg2", 99)
gotState, gotArgs, err = sm.StateWithArgs(context.Background())
if err != nil {
t.Errorf("StateWithArgs() error = %v, want nil", err)
}
if gotState != stateC {
t.Errorf("StateWithArgs() state = %v, want %v", gotState, stateC)
}
if got := gotArgs[0].(string); got != "arg2" {
t.Errorf("expected arg 0 to be %v, got %v", "arg2", got)
}
if got := gotArgs[1].(int); got != 99 {
t.Errorf("expected arg 1 to be %v, got %v", 99, got)
}
}

func TestStateMachine_StateWithArgs_Error(t *testing.T) {
sm := NewStateMachineWithExternalStorage(func(_ context.Context) (State, error) {
return nil, errors.New("state accessor error")
}, func(_ context.Context, s State) error { return nil }, FiringImmediate)

_, _, err := sm.StateWithArgs(context.Background())
want := "state accessor error"
if err == nil || err.Error() != want {
t.Errorf("StateWithArgs() error = %v, want %v", err, want)
}
}

func TestStateMachine_Configure_SubstateIsIncludedInCurrentState(t *testing.T) {
sm := NewStateMachine(stateB)
sm.Configure(stateB).SubstateOf(stateC)
Expand Down