From 700e5f16abab87740ddd62c880d0011267e33c4a Mon Sep 17 00:00:00 2001 From: Ross Nelson Date: Sat, 22 Feb 2025 22:02:26 -0500 Subject: [PATCH 1/2] Add CircuitBreaker --- circuitbreaker/circuitbreaker.go | 255 ++++++++++++++++++++++++++ circuitbreaker/circuitbreaker_test.go | 164 +++++++++++++++++ workflow/command.go | 4 +- 3 files changed, 421 insertions(+), 2 deletions(-) create mode 100644 circuitbreaker/circuitbreaker.go create mode 100644 circuitbreaker/circuitbreaker_test.go diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..d927603 --- /dev/null +++ b/circuitbreaker/circuitbreaker.go @@ -0,0 +1,255 @@ +package circuitbreaker + +import ( + "fmt" + "sync" + "time" + + "github.com/simiancreative/simiango/logger" +) + +type State int + +const ( + StateClosed State = iota + StateOpen + StateHalfOpen +) + +func (s State) String() string { + switch s { + case StateClosed: + return "CLOSED" + case StateOpen: + return "OPEN" + case StateHalfOpen: + return "HALF_OPEN" + default: + return "UNKNOWN" + } +} + +type Config struct { + FailureThreshold int + OpenTimeout time.Duration + HalfOpenMaxCalls int + OnStateChange func(from, to State) +} + +type CircuitBreaker struct { + config Config + state State + failures int + attempts int + successes int + mutex sync.RWMutex + timer *time.Timer +} + +func New(config Config) (*CircuitBreaker, error) { + if err := validateConfig(config); err != nil { + return nil, err + } + + logger.Debug("creating new circuit breaker", logger.Fields{ + "failure_threshold": config.FailureThreshold, + "open_timeout": config.OpenTimeout.String(), + "half_open_max_calls": config.HalfOpenMaxCalls, + }) + + return &CircuitBreaker{ + config: config, + state: StateClosed, + }, nil +} + +func (cb *CircuitBreaker) Allow() bool { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + + allowed := false + switch cb.state { + case StateOpen: + allowed = false + case StateHalfOpen: + allowed = cb.attempts < cb.config.HalfOpenMaxCalls + default: + allowed = true + } + + logger.Debug("circuit breaker allow check", logger.Fields{ + "state": cb.state.String(), + "allowed": allowed, + "attempts": cb.attempts, + "max_calls": cb.config.HalfOpenMaxCalls, + }) + + return allowed +} + +func (cb *CircuitBreaker) GetState() State { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + return cb.state +} + +// RecordStart marks the beginning of an attempt +func (cb *CircuitBreaker) RecordStart() bool { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + switch cb.state { + case StateOpen: + logger.Debug("attempt rejected - circuit open", logger.Fields{ + "state": cb.state.String(), + }) + return false + case StateHalfOpen: + if cb.attempts >= cb.config.HalfOpenMaxCalls { + logger.Debug("attempt rejected - max half-open calls reached", logger.Fields{ + "attempts": cb.attempts, + "max_calls": cb.config.HalfOpenMaxCalls, + }) + return false + } + } + + cb.attempts++ + logger.Debug("attempt started", logger.Fields{ + "state": cb.state.String(), + "attempts": cb.attempts, + }) + return true +} + +// RecordResult records the result of an attempt +func (cb *CircuitBreaker) RecordResult(success bool) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + logger.Debug("recording attempt result", logger.Fields{ + "success": success, + "state": cb.state.String(), + "attempts": cb.attempts, + "successes": cb.successes, + "failures": cb.failures, + }) + + if !success { + cb.recordFailure() + return + } + + switch cb.state { + case StateHalfOpen: + cb.successes++ + logger.Debug("recorded success in half-open state", logger.Fields{ + "attempts": cb.attempts, + "successes": cb.successes, + "max_calls": cb.config.HalfOpenMaxCalls, + }) + if cb.successes >= cb.config.HalfOpenMaxCalls { + cb.transitionTo(StateClosed) + } + case StateClosed: + cb.failures = 0 + logger.Debug("recorded success in closed state", logger.Fields{ + "failures": cb.failures, + }) + } +} + +func (cb *CircuitBreaker) Reset() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + logger.Debug("resetting circuit breaker", logger.Fields{ + "from_state": cb.state.String(), + }) + + if cb.timer != nil { + cb.timer.Stop() + } + + cb.transitionTo(StateClosed) + cb.failures = 0 + cb.attempts = 0 + cb.successes = 0 +} + +func (cb *CircuitBreaker) recordFailure() { + cb.failures++ + + logger.Debug("recorded failure", logger.Fields{ + "state": cb.state.String(), + "failures": cb.failures, + "threshold": cb.config.FailureThreshold, + }) + + if cb.state == StateClosed && cb.failures >= cb.config.FailureThreshold { + cb.openCircuit() + } else if cb.state == StateHalfOpen { + cb.openCircuit() + } +} + +func (cb *CircuitBreaker) openCircuit() { + logger.Debug("opening circuit", logger.Fields{ + "from_state": cb.state.String(), + "open_timeout": cb.config.OpenTimeout.String(), + }) + + if cb.timer != nil { + cb.timer.Stop() + } + + cb.transitionTo(StateOpen) + + cb.timer = time.AfterFunc(cb.config.OpenTimeout, func() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + logger.Debug("open timeout elapsed", logger.Fields{ + "current_state": cb.state.String(), + }) + + if cb.state == StateOpen { + cb.transitionTo(StateHalfOpen) + } + }) +} + +func (cb *CircuitBreaker) transitionTo(newState State) { + if cb.state == newState { + return + } + + oldState := cb.state + cb.state = newState + cb.attempts = 0 + cb.successes = 0 + + logger.Debug("state transition", logger.Fields{ + "from_state": oldState.String(), + "to_state": newState.String(), + "attempts": cb.attempts, + "successes": cb.successes, + }) + + if cb.config.OnStateChange != nil { + go cb.config.OnStateChange(oldState, newState) + } +} + +func validateConfig(config Config) error { + if config.FailureThreshold <= 0 { + return fmt.Errorf("failure threshold must be greater than 0") + } + if config.OpenTimeout <= 0 { + return fmt.Errorf("open timeout must be greater than 0") + } + if config.HalfOpenMaxCalls <= 0 { + return fmt.Errorf("half-open max calls must be greater than 0") + } + return nil +} diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 0000000..200f1a3 --- /dev/null +++ b/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,164 @@ +package circuitbreaker_test + +import ( + "os" + "testing" + "time" + + "github.com/simiancreative/simiango/circuitbreaker" + "github.com/simiancreative/simiango/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type actionType int + +const ( + allow actionType = iota + start + recordSuccess + recordFailure + checkState + wait +) + +type step struct { + action actionType + want interface{} // bool for allow/start, State for checkState + duration time.Duration // for wait action +} + +func TestCircuitBreaker(t *testing.T) { + os.Setenv("LOG_TYPE", "line") + logger.Enable() + + tests := []struct { + name string + config circuitbreaker.Config + steps []step + wantFinalState circuitbreaker.State + }{ + { + name: "remains closed when under threshold", + config: circuitbreaker.Config{ + FailureThreshold: 3, + OpenTimeout: time.Second, + HalfOpenMaxCalls: 2, + }, + steps: []step{ + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: true}, + {action: recordFailure}, + {action: allow, want: true}, + {action: checkState, want: circuitbreaker.StateClosed}, + }, + wantFinalState: circuitbreaker.StateClosed, + }, + { + name: "opens after threshold failures", + config: circuitbreaker.Config{ + FailureThreshold: 2, + OpenTimeout: time.Second, + HalfOpenMaxCalls: 2, + }, + steps: []step{ + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: false}, + {action: checkState, want: circuitbreaker.StateOpen}, + }, + wantFinalState: circuitbreaker.StateOpen, + }, + { + name: "transitions to half-open after timeout", + config: circuitbreaker.Config{ + FailureThreshold: 2, + OpenTimeout: 50 * time.Millisecond, + HalfOpenMaxCalls: 2, + }, + steps: []step{ + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: true}, + {action: recordFailure}, + {action: wait, duration: 60 * time.Millisecond}, + {action: start, want: true}, + {action: checkState, want: circuitbreaker.StateHalfOpen}, + }, + wantFinalState: circuitbreaker.StateHalfOpen, + }, + { + name: "closes after successful half-open calls", + config: circuitbreaker.Config{ + FailureThreshold: 2, + OpenTimeout: 50 * time.Millisecond, + HalfOpenMaxCalls: 2, + }, + steps: []step{ + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: true}, + {action: recordFailure}, + {action: wait, duration: 60 * time.Millisecond}, + {action: start, want: true}, + {action: recordSuccess}, + {action: start, want: true}, + {action: recordSuccess}, + {action: checkState, want: circuitbreaker.StateClosed}, + }, + wantFinalState: circuitbreaker.StateClosed, + }, + { + name: "limits calls in half-open state", + config: circuitbreaker.Config{ + FailureThreshold: 2, + OpenTimeout: 50 * time.Millisecond, + HalfOpenMaxCalls: 2, + }, + steps: []step{ + {action: start, want: true}, + {action: recordFailure}, + {action: start, want: true}, + {action: recordFailure}, + {action: wait, duration: 60 * time.Millisecond}, + {action: start, want: true}, // First call allowed + {action: start, want: true}, // Second call allowed + {action: start, want: false}, // Third call rejected + {action: checkState, want: circuitbreaker.StateHalfOpen}, + {action: recordSuccess}, + {action: recordSuccess}, + {action: checkState, want: circuitbreaker.StateClosed}, + }, + wantFinalState: circuitbreaker.StateClosed, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cb, err := circuitbreaker.New(tt.config) + require.NoError(t, err) + + for i, step := range tt.steps { + switch step.action { + case allow: + assert.Equal(t, step.want, cb.Allow(), "step %d: Allow()", i) + case start: + assert.Equal(t, step.want, cb.RecordStart(), "step %d: RecordStart()", i) + case recordSuccess: + cb.RecordResult(true) + case recordFailure: + cb.RecordResult(false) + case checkState: + assert.Equal(t, step.want, cb.GetState(), "step %d: GetState()", i) + case wait: + time.Sleep(step.duration) + } + } + + assert.Equal(t, tt.wantFinalState, cb.GetState(), "final state") + }) + } +} diff --git a/workflow/command.go b/workflow/command.go index a399764..d776218 100644 --- a/workflow/command.go +++ b/workflow/command.go @@ -88,8 +88,8 @@ func checkArgs( ) } -func buildArgs(flags *pflag.FlagSet, args []string, actionArgs ArgsList) map[string]string { - mapped := map[string]string{} +func buildArgs(flags *pflag.FlagSet, args []string, actionArgs ArgsList) Args { + mapped := Args{} for i, val := range actionArgs { if flags.Changed(val[0]) { From daaf1f1729b30e20964f472885123d843f2226d2 Mon Sep 17 00:00:00 2001 From: Ross Nelson Date: Sat, 22 Feb 2025 22:29:51 -0500 Subject: [PATCH 2/2] guard --- circuitbreaker/circuitbreaker.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index d927603..036ff87 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -186,11 +186,14 @@ func (cb *CircuitBreaker) recordFailure() { "threshold": cb.config.FailureThreshold, }) - if cb.state == StateClosed && cb.failures >= cb.config.FailureThreshold { - cb.openCircuit() - } else if cb.state == StateHalfOpen { - cb.openCircuit() + shouldOpen := cb.state == StateHalfOpen || + (cb.state == StateClosed && cb.failures >= cb.config.FailureThreshold) + + if !shouldOpen { + return } + + cb.openCircuit() } func (cb *CircuitBreaker) openCircuit() {