diff --git a/oncetask/once_task_firestore.go b/oncetask/once_task_firestore.go index e7fa513..23f0b55 100644 --- a/oncetask/once_task_firestore.go +++ b/oncetask/once_task_firestore.go @@ -274,7 +274,7 @@ func (m *firestoreOnceTaskManager[TaskKind]) runLoop( cancellationHandler := getCancellationHandler[TaskKind](config) for _, task := range cancelledTasks { ctx := withTaskContext(m.ctx, task.Id, task.ResourceKey) - result, execErr := cancellationHandler(ctx, &task) + result, execErr := SafeExecute(ctx, cancellationHandler, &task) if err := m.completeBatch(ctx, []OnceTask[TaskKind]{task}, execErr, result, config); err != nil { slog.ErrorContext(ctx, "Failed to complete cancelled task", "error", err, "taskId", task.Id) } @@ -293,10 +293,10 @@ func (m *firestoreOnceTaskManager[TaskKind]) runLoop( slog.ErrorContext(m.ctx, "Single task handler claimed multiple tasks", "taskType", taskType, "count", len(normalTasks)) execErr = fmt.Errorf("expected 1 task, got %d", len(normalTasks)) } else { - result, execErr = taskHandler(withSingleTaskContext(m.ctx, normalTasks), &normalTasks[0]) + result, execErr = SafeExecute(withSingleTaskContext(m.ctx, normalTasks), taskHandler, &normalTasks[0]) } } else if hasResource { - result, execErr = resourceHandler(withResourceKeyTaskContext(m.ctx, normalTasks), normalTasks) + result, execErr = SafeExecute(withResourceKeyTaskContext(m.ctx, normalTasks), resourceHandler, normalTasks) } if err := m.completeBatch(m.ctx, normalTasks, execErr, result, config); err != nil { diff --git a/oncetask/panic_recovery.go b/oncetask/panic_recovery.go new file mode 100644 index 0000000..a014411 --- /dev/null +++ b/oncetask/panic_recovery.go @@ -0,0 +1,32 @@ +package oncetask + +import ( + "context" + "fmt" + "log/slog" + "runtime/debug" +) + +// SafeExecute wraps a function execution with panic recovery. +// If the function panics, the panic is recovered and converted to an error. +// The stack trace is logged via slog.ErrorContext for debugging. +// +// Example usage: +// +// result, err := SafeExecute(ctx, handler, task) +// +// Returns: +// - (result, nil) if fn completes successfully +// - (nil, error) if fn returns an error +// - (nil, error) if fn panics (panic converted to error) +func SafeExecute[P any, R any](ctx context.Context, fn func(context.Context, P) (R, error), p P) (result R, err error) { + defer func() { + if r := recover(); r != nil { + stack := string(debug.Stack()) + slog.ErrorContext(ctx, "handler panicked", "panic", r, "stack", stack) + err = fmt.Errorf("panic: %v", r) + } + }() + + return fn(ctx, p) +} diff --git a/oncetask/panic_recovery_test.go b/oncetask/panic_recovery_test.go new file mode 100644 index 0000000..c7f9879 --- /dev/null +++ b/oncetask/panic_recovery_test.go @@ -0,0 +1,222 @@ +package oncetask + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestSafeExecute(t *testing.T) { + tests := []struct { + handler func(context.Context, string) (string, error) + name string + input string + wantResult string + wantErrContain string + wantErr bool + }{ + { + name: "success returns result", + handler: func(ctx context.Context, input string) (string, error) { + return "got: " + input, nil + }, + input: "test", + wantResult: "got: test", + wantErr: false, + }, + { + name: "error is passed through", + handler: func(ctx context.Context, input string) (string, error) { + return "", errors.New("handler error") + }, + input: "test", + wantResult: "", + wantErr: true, + wantErrContain: "handler error", + }, + { + name: "panic with string is recovered", + handler: func(ctx context.Context, input string) (string, error) { + panic("something went wrong") + }, + input: "test", + wantResult: "", + wantErr: true, + wantErrContain: "something went wrong", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := SafeExecute(context.Background(), tt.handler, tt.input) + + if (err != nil) != tt.wantErr { + t.Errorf("SafeExecute() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && tt.wantErrContain != "" && !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("SafeExecute() error = %v, want containing %q", err, tt.wantErrContain) + } + if result != tt.wantResult { + t.Errorf("SafeExecute() result = %v, want %v", result, tt.wantResult) + } + }) + } +} + +func TestSafeExecute_PanicRecovery(t *testing.T) { + tests := []struct { + name string + panicValue any + wantErrContain string + }{ + { + name: "string panic", + panicValue: "panic message", + wantErrContain: "panic message", + }, + { + name: "error panic", + panicValue: errors.New("error as panic"), + wantErrContain: "error as panic", + }, + { + name: "int panic", + panicValue: 42, + wantErrContain: "42", + }, + { + name: "nil panic", + panicValue: nil, + wantErrContain: "panic:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := func(ctx context.Context, input any) (any, error) { + if tt.panicValue == nil { + panic(nil) //nolint:govet // Intentionally testing panic(nil) recovery + } + panic(tt.panicValue) + } + + _, err := SafeExecute(context.Background(), handler, nil) + + if err == nil { + t.Fatal("expected error from panic, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("error = %q, want containing %q", err.Error(), tt.wantErrContain) + } + }) + } +} + +func TestSafeExecute_PreservesContext(t *testing.T) { + type ctxKey string + key := ctxKey("test-key") + + var capturedValue any + handler := func(ctx context.Context, input string) (string, error) { + capturedValue = ctx.Value(key) + return "", nil + } + + ctx := context.WithValue(context.Background(), key, "test-value") + _, _ = SafeExecute(ctx, handler, "input") + + if capturedValue != "test-value" { + t.Errorf("context value = %v, want %v", capturedValue, "test-value") + } +} + +func TestSafeExecute_PassesParameter(t *testing.T) { + var capturedParam int + handler := func(ctx context.Context, input int) (int, error) { + capturedParam = input + return input * 2, nil + } + + result, err := SafeExecute(context.Background(), handler, 21) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if capturedParam != 21 { + t.Errorf("captured param = %d, want 21", capturedParam) + } + if result != 42 { + t.Errorf("result = %d, want 42", result) + } +} + +// testTaskKind for testing with actual handler types +type testTaskKind string + +func TestSafeExecute_WithHandlerTypes(t *testing.T) { + t.Run("task handler success", func(t *testing.T) { + handler := func(ctx context.Context, task *OnceTask[testTaskKind]) (any, error) { + return "processed: " + task.Id, nil + } + + task := &OnceTask[testTaskKind]{Id: "task-123"} + result, err := SafeExecute(context.Background(), handler, task) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != "processed: task-123" { + t.Errorf("result = %v, want %v", result, "processed: task-123") + } + }) + + t.Run("task handler panic", func(t *testing.T) { + handler := func(ctx context.Context, task *OnceTask[testTaskKind]) (any, error) { + panic("handler panic") + } + + task := &OnceTask[testTaskKind]{Id: "task-123"} + _, err := SafeExecute(context.Background(), handler, task) + + if err == nil { + t.Fatal("expected error from panic") + } + if !strings.Contains(err.Error(), "handler panic") { + t.Errorf("error = %v, want containing %q", err, "handler panic") + } + }) + + t.Run("resource key handler success", func(t *testing.T) { + handler := func(ctx context.Context, tasks []OnceTask[testTaskKind]) (any, error) { + return len(tasks), nil + } + + tasks := []OnceTask[testTaskKind]{{Id: "1"}, {Id: "2"}, {Id: "3"}} + result, err := SafeExecute(context.Background(), handler, tasks) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != 3 { + t.Errorf("result = %v, want 3", result) + } + }) + + t.Run("resource key handler panic", func(t *testing.T) { + handler := func(ctx context.Context, tasks []OnceTask[testTaskKind]) (any, error) { + panic("resource handler panic") + } + + tasks := []OnceTask[testTaskKind]{{Id: "1"}} + _, err := SafeExecute(context.Background(), handler, tasks) + + if err == nil { + t.Fatal("expected error from panic") + } + if !strings.Contains(err.Error(), "resource handler panic") { + t.Errorf("error = %v, want containing %q", err, "resource handler panic") + } + }) +}