Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/jackc/pgx/v5 v5.6.0
github.com/klauspost/compress v1.18.0
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
Expand All @@ -29,6 +30,7 @@ require (
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.43.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.25.12
Expand Down Expand Up @@ -68,7 +70,6 @@ require (
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
Expand Down Expand Up @@ -117,7 +118,6 @@ require (
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.43.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/sync v0.18.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions pkg/errortracking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ Panics are automatically captured when using the logger's panic handlers:

```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
defer logger.CatchPanic("MyFunction")()

// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
})()

// Using HandlePanic
defer func() {
Expand Down
92 changes: 62 additions & 30 deletions pkg/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,28 @@ func CloseErrorTracking() error {
return nil
}

// extractContext attempts to find a context.Context in the given arguments.
// It returns the found context (or context.Background() if not found) and
// the remaining arguments without the context.
func extractContext(args ...interface{}) (context.Context, []interface{}) {
ctx := context.Background()
var newArgs []interface{}
found := false

for _, arg := range args {
if c, ok := arg.(context.Context); ok {
if !found {
ctx = c
found = true
}
// Ignore any additional context.Context arguments after the first one.
continue
}
newArgs = append(newArgs, arg)
}
return ctx, newArgs
}

func Info(template string, args ...interface{}) {
if Logger == nil {
log.Printf(template, args...)
Expand All @@ -84,7 +106,8 @@ func Info(template string, args ...interface{}) {
}

func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
ctx, remainingArgs := extractContext(args...)
message := fmt.Sprintf(template, remainingArgs...)
if Logger == nil {
log.Printf("%s", message)
} else {
Expand All @@ -93,14 +116,15 @@ func Warn(template string, args ...interface{}) {

// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
}
}

func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
ctx, remainingArgs := extractContext(args...)
message := fmt.Sprintf(template, remainingArgs...)
if Logger == nil {
log.Printf("%s", message)
} else {
Expand All @@ -109,7 +133,7 @@ func Error(template string, args ...interface{}) {

// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
}
Expand All @@ -124,34 +148,41 @@ func Debug(template string, args ...interface{}) {
}

// CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil {
callstack := debug.Stack()

if Logger != nil {
Error("Panic in %s : %v", location, err)
} else {
fmt.Printf("%s:PANIC->%+v", location, err)
debug.PrintStack()
}

// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}

if cb != nil {
cb(err)
// Returns a function that should be deferred to catch panics
// Example usage: defer CatchPanicCallback("MyFunction", func(err any) { /* cleanup */ })()
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) func() {
ctx, _ := extractContext(args...)
return func() {
if err := recover(); err != nil {
callstack := debug.Stack()

if Logger != nil {
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
} else {
fmt.Printf("%s:PANIC->%+v", location, err)
debug.PrintStack()
}

// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}

if cb != nil {
cb(err)
}
}
}
}

// CatchPanic - Handle panic
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
// Returns a function that should be deferred to catch panics
// Example usage: defer CatchPanic("MyFunction")()
func CatchPanic(location string, args ...interface{}) func() {
return CatchPanicCallback(location, nil, args...)
}

// HandlePanic logs a panic and returns it as an error
Expand All @@ -163,13 +194,14 @@ func CatchPanic(location string) {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
func HandlePanic(methodName string, r any, args ...interface{}) error {
ctx, _ := extractContext(args...)
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly

// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
Expand Down
4 changes: 4 additions & 0 deletions pkg/metrics/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ type Provider interface {
// UpdateEventQueueSize updates the event queue size metric
UpdateEventQueueSize(size int64)

// RecordPanic records a panic event
RecordPanic(methodName string)

// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
Handler() http.Handler
}
Expand Down Expand Up @@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
func (n *NoOpProvider) RecordPanic(methodName string) {}
func (n *NoOpProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
Expand Down
13 changes: 13 additions & 0 deletions pkg/metrics/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type PrometheusProvider struct {
cacheHits *prometheus.CounterVec
cacheMisses *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
panicsTotal *prometheus.CounterVec
}

// NewPrometheusProvider creates a new Prometheus metrics provider
Expand Down Expand Up @@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider {
},
[]string{"provider"},
),
panicsTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "panics_total",
Help: "Total number of panics",
},
[]string{"method"},
),
}
}

Expand Down Expand Up @@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
p.cacheSize.WithLabelValues(provider).Set(float64(size))
}

// RecordPanic implements the Provider interface
func (p *PrometheusProvider) RecordPanic(methodName string) {
p.panicsTotal.WithLabelValues(methodName).Inc()
}

// Handler implements Provider interface
func (p *PrometheusProvider) Handler() http.Handler {
return promhttp.Handler()
Expand Down
33 changes: 33 additions & 0 deletions pkg/middleware/panic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package middleware

import (
"net/http"

"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)

const panicMiddlewareMethodName = "PanicMiddleware"

// PanicRecovery is a middleware that recovers from panics, logs the error,
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
func PanicRecovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rcv := recover(); rcv != nil {
// Record the panic metric
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)

// Log the panic and send to error tracker
// We pass the request context so the error tracker can potentially
// link the panic to the request trace.
ctx := r.Context()
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)

// Respond with a 500 error
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
86 changes: 86 additions & 0 deletions pkg/middleware/panic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/stretchr/testify/assert"
)

// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
type mockMetricsProvider struct {
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
panicRecorded bool
methodName string
}

func (m *mockMetricsProvider) RecordPanic(methodName string) {
m.panicRecorded = true
m.methodName = methodName
}

func TestPanicRecovery(t *testing.T) {
// Initialize a mock logger to avoid actual logging output during tests
logger.Init(true)

// Setup mock metrics provider
mockProvider := &mockMetricsProvider{}
originalProvider := metrics.GetProvider()
metrics.SetProvider(mockProvider)
defer metrics.SetProvider(originalProvider) // Restore original provider after test

// 1. Test case: A handler that panics
t.Run("recovers from panic and returns 500", func(t *testing.T) {
// Reset mock state for this sub-test
mockProvider.panicRecorded = false
mockProvider.methodName = ""

panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("something went terribly wrong")
})

// Create the middleware wrapping the panicking handler
testHandler := PanicRecovery(panicHandler)

// Create a test request and response recorder
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()

// Serve the request
testHandler.ServeHTTP(rr, req)

// Assertions
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")

// Assert that the metric was recorded
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
})

// 2. Test case: A handler that does NOT panic
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
// Reset mock state for this sub-test
mockProvider.panicRecorded = false

successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})

testHandler := PanicRecovery(successHandler)

req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()

testHandler.ServeHTTP(rr, req)

// Assertions
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
})
}
4 changes: 2 additions & 2 deletions pkg/security/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName
}

func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (reflect.Value, error) {
defer logger.CatchPanic("ApplyColumnSecurity")
defer logger.CatchPanic("ApplyColumnSecurity")()

if m.ColumnSecurity == nil {
return records, fmt.Errorf("security not initialized")
Expand Down Expand Up @@ -437,7 +437,7 @@ func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema
}

func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
defer logger.CatchPanic("GetRowSecurityTemplate")
defer logger.CatchPanic("GetRowSecurityTemplate")()

if m.RowSecurity == nil {
return RowSecurity{}, fmt.Errorf("security not initialized")
Expand Down
Loading
Loading