-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.go
More file actions
348 lines (294 loc) · 9.51 KB
/
client.go
File metadata and controls
348 lines (294 loc) · 9.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
package tributary
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/nilpntr/tributary/tributaryhook"
)
// WorkerCallback is the function a worker registers with the Client so it can
// be told that new work may be available.
//
// Client.notifyAllWorkers invokes every registered callback in its own
// goroutine, so a callback may run concurrently with the worker loop that
// registered it and should return quickly without blocking.
type WorkerCallback func()
// Client is the main Tributary client for executing workflows and steps.
//
// A Client is created with NewClient, run with Start, and shut down with Stop.
// It contains mutexes and a WaitGroup and must therefore not be copied; all
// methods use a pointer receiver.
type Client struct {
	// config is the configuration passed to NewClient, after SetDefaults and
	// Validate have been applied to it.
	config *Config
	// pool is the PostgreSQL connection pool used for queries and for
	// beginning transactions.
	pool *pgxpool.Pool
	// Lifecycle state.
	// started is true while Start is running; it is guarded by mu.
	started bool
	// mu protects started.
	mu sync.Mutex
	// wg counts the worker goroutines launched by Start; Start and Stop both
	// wait on it.
	wg sync.WaitGroup
	// cancel cancels ctx. It is called by Stop and by Start when it unwinds.
	cancel context.CancelFunc
	// ctx is the client-lifetime context created in NewClient from
	// context.Background(); workers exit once it is done.
	ctx context.Context
	// Worker management: subscription-based notification.
	workerSubscriptions map[string]WorkerCallback // worker ID -> callback invoked by notifyAllWorkers
	subscriptionMu      sync.RWMutex              // guards workerSubscriptions
}
// NewClient builds a Client backed by pool and configured by config.
//
// Both arguments are required; a nil config or a nil pool yields an error.
// SetDefaults is applied to config in place before Validate is called, so the
// caller's Config value is modified. The returned Client carries its own
// cancellable context derived from context.Background().
func NewClient(pool *pgxpool.Pool, config *Config) (*Client, error) {
	switch {
	case config == nil:
		return nil, fmt.Errorf("config cannot be nil")
	case pool == nil:
		return nil, fmt.Errorf("pool cannot be nil")
	}

	// Fill in defaults first so validation sees the effective configuration.
	config.SetDefaults()
	if validationErr := config.Validate(); validationErr != nil {
		return nil, fmt.Errorf("invalid config: %w", validationErr)
	}

	client := &Client{
		config:              config,
		pool:                pool,
		workerSubscriptions: make(map[string]WorkerCallback),
	}
	client.ctx, client.cancel = context.WithCancel(context.Background())
	return client, nil
}
// registerWorker stores callback under workerID in the subscription table,
// replacing any callback previously stored for the same ID. Registered
// callbacks are the ones notifyAllWorkers invokes.
func (c *Client) registerWorker(workerID string, callback WorkerCallback) {
	guard := &c.subscriptionMu
	guard.Lock()
	defer guard.Unlock()

	c.workerSubscriptions[workerID] = callback
}
// unregisterWorker drops the callback stored for workerID, if any.
// Workers call this on shutdown so stale callbacks do not accumulate in the
// subscription table.
func (c *Client) unregisterWorker(workerID string) {
	c.subscriptionMu.Lock()
	// delete is a no-op for a missing key (or a nil map), so no existence
	// check is needed before removing the entry.
	delete(c.workerSubscriptions, workerID)
	c.subscriptionMu.Unlock()
}
// notifyAllWorkers fans a "work available" signal out to every registered
// worker. Each callback is launched in its own goroutine so that a slow
// callback cannot delay the others or the caller. The subscription table is
// held under a read lock for the duration of the fan-out.
func (c *Client) notifyAllWorkers() {
	c.subscriptionMu.RLock()
	defer c.subscriptionMu.RUnlock()

	for workerID := range c.workerSubscriptions {
		notify := c.workerSubscriptions[workerID]
		go notify()
	}
}
// Start runs the client: it starts the notification listener, launches
// NumWorkers worker goroutines for every configured queue, and then blocks
// until shutdown is requested.
//
// Shutdown is requested either by cancelling ctx or by calling Stop (which
// cancels the client's internal context). In both cases Start cancels the
// internal context, waits for all workers to return, and returns nil.
//
// Start returns ErrClientAlreadyStarted if the client is already running.
//
// The internal context is created once in NewClient and is not replaced in
// this file, so once it has been cancelled a subsequent Start finds it already
// done and returns promptly rather than running workers again.
func (c *Client) Start(ctx context.Context) error {
	c.mu.Lock()
	if c.started {
		c.mu.Unlock()
		return ErrClientAlreadyStarted
	}
	c.started = true
	c.mu.Unlock()

	// Clear the flag on every exit path so it never stays stuck at true.
	defer func() {
		c.mu.Lock()
		c.started = false
		c.mu.Unlock()
	}()

	// Start notification listener for PostgreSQL LISTEN/NOTIFY.
	c.startNotificationListener(ctx)

	// Start workers for each queue with subscription-based callbacks.
	for queueName, queueConfig := range c.config.Queues {
		for i := 0; i < queueConfig.NumWorkers; i++ {
			workerID := fmt.Sprintf("%s-%d", queueName, i)
			c.wg.Add(1)
			go c.runWorkerWithSubscription(ctx, queueName, workerID)
		}
	}

	// Block until either the caller's context or the client's own context is
	// done. Waiting on c.ctx as well is what lets Stop unblock Start; waiting
	// on ctx alone would leave Start hanging after Stop had already shut the
	// workers down.
	select {
	case <-ctx.Done():
	case <-c.ctx.Done():
	}

	// Signal workers to stop (a no-op if Stop already cancelled).
	c.cancel()

	// Wait for all workers to finish.
	c.wg.Wait()
	return nil
}
// Stop asks the client to shut down and waits for its workers to exit.
//
// It cancels the client's internal context and then waits on the worker
// WaitGroup. If ctx is done before the workers have finished, Stop gives up
// waiting and returns ctx.Err(); otherwise it returns nil.
func (c *Client) Stop(ctx context.Context) error {
	c.cancel()

	// Bridge the WaitGroup to a channel so the wait can be raced against ctx.
	workersExited := make(chan struct{})
	awaitWorkers := func() {
		c.wg.Wait()
		close(workersExited)
	}
	go awaitWorkers()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-workersExited:
		return nil
	}
}
// runWorkerWithSubscription is the body of a single worker goroutine for
// queueName.
//
// The worker registers a callback under workerID that performs a non-blocking
// send on a private buffered channel, then loops: it calls fetchAndExecute
// whenever that channel fires or the FetchPollInterval ticker ticks, and
// returns as soon as either ctx or the client's own context is done. On return
// it unregisters the callback and marks itself done on the client WaitGroup.
func (c *Client) runWorkerWithSubscription(ctx context.Context, queueName string, workerID string) {
	defer c.wg.Done()
	defer c.unregisterWorker(workerID)

	// Private wake-up channel; buffered so the notifier never blocks.
	wakeups := make(chan struct{}, 100)
	c.registerWorker(workerID, func() {
		select {
		case wakeups <- struct{}{}:
		default:
			// Buffer full: wake-ups are already pending, so drop this one.
		}
	})

	poll := time.NewTicker(c.config.FetchPollInterval)
	defer poll.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-c.ctx.Done():
			// Client is shutting down.
			return
		case <-wakeups:
			// Woken by a subscription notification.
		case <-poll.C:
			// Periodic poll as a backup to notifications.
		}
		c.fetchAndExecute(ctx, queueName)
	}
}
// createWorkflowExecution inserts a workflow_executions row for workflowName
// in state 'running' and returns the generated id. A query or scan failure is
// returned wrapped, with 0 as the id.
func (c *Client) createWorkflowExecution(ctx context.Context, workflowName string) (int64, error) {
	row := c.pool.QueryRow(ctx, `
INSERT INTO workflow_executions (workflow_name, state)
VALUES ($1, 'running')
RETURNING id
`, workflowName)

	var executionID int64
	if scanErr := row.Scan(&executionID); scanErr != nil {
		return 0, fmt.Errorf("failed to create workflow execution: %w", scanErr)
	}
	return executionID, nil
}
// InsertMany inserts the given steps and returns their IDs in the same order
// as params. It delegates to InsertManyTx with a nil transaction, which makes
// InsertManyTx open a transaction of its own.
func (c *Client) InsertMany(ctx context.Context, params []InsertManyParams) ([]int64, error) {
	stepIDs, err := c.InsertManyTx(ctx, nil, params)
	return stepIDs, err
}
// InsertManyTx inserts the given steps, and the dependencies between them,
// and returns the new step IDs in the same order as params.
//
// Transaction ownership:
//   - If tx is nil, InsertManyTx begins its own transaction on the pool,
//     commits it on success and rolls it back on any error.
//   - If tx is non-nil, all statements run inside the caller's transaction and
//     InsertManyTx neither commits nor rolls it back; the caller stays in
//     control and the rows become visible only once the caller commits.
//
// For each param the args are JSON-encoded and passed through the BeforeInsert
// hooks, then insert options are resolved in this order: explicit
// param.InsertOpts, options supplied by the args via StepInsertOptsProvider,
// and finally the global defaults (config MaxAttempts, queue "default",
// priority 1, scheduled now). The caller's InsertOpts value is never modified.
// AfterInsert hooks run after each row is inserted, before any commit.
//
// DependsOn entries are resolved by task name against the steps inserted in
// this same call; an unknown name is an error.
//
// An empty params slice returns an empty, non-nil slice and touches nothing.
func (c *Client) InsertManyTx(ctx context.Context, tx pgx.Tx, params []InsertManyParams) ([]int64, error) {
	if len(params) == 0 {
		return []int64{}, nil
	}

	// Only a transaction opened here may be committed or rolled back here.
	ownsTx := tx == nil
	if ownsTx {
		var err error
		tx, err = c.pool.Begin(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to begin transaction: %w", err)
		}
		// Rollback after a successful Commit only reports that the transaction
		// is already closed, so deferring it unconditionally is safe and its
		// result is deliberately ignored.
		defer func() { _ = tx.Rollback(ctx) }()
	}

	stepIDs := make([]int64, 0, len(params))
	taskNameToStepID := make(map[string]int64, len(params))

	// First pass: insert all steps.
	for _, param := range params {
		// Serialize args to JSON.
		argsBytes, err := json.Marshal(param.Args)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal args: %w", err)
		}

		// Run BeforeInsert hooks; they may replace the encoded payload.
		argsBytes, err = runBeforeInsertHooks(ctx, c.config.Hooks, param.Args, argsBytes)
		if err != nil {
			return nil, fmt.Errorf("BeforeInsert hook failed: %w", err)
		}

		// Work on a copy of the caller's options: the defaults applied below
		// must not be written back through param.InsertOpts, otherwise they
		// leak into the caller's struct and into any other param sharing it.
		var opts InsertOpts
		if param.InsertOpts != nil {
			opts = *param.InsertOpts
		}

		// Options provided by the step args fill in only the fields the
		// caller left at their zero value.
		if provider, ok := param.Args.(StepInsertOptsProvider); ok {
			stepOpts := provider.InsertOpts()
			if opts.MaxAttempts == 0 && stepOpts.MaxAttempts > 0 {
				opts.MaxAttempts = stepOpts.MaxAttempts
			}
			if opts.Priority == 0 && stepOpts.Priority > 0 {
				opts.Priority = stepOpts.Priority
			}
			if opts.Queue == "" && stepOpts.Queue != "" {
				opts.Queue = stepOpts.Queue
			}
			if opts.Timeout == 0 && stepOpts.Timeout > 0 {
				opts.Timeout = stepOpts.Timeout
			}
		}

		// Apply global defaults to whatever is still unset.
		if opts.MaxAttempts == 0 {
			opts.MaxAttempts = c.config.MaxAttempts
		}
		if opts.Queue == "" {
			opts.Queue = "default"
		}
		if opts.Priority == 0 {
			opts.Priority = 1
		}

		scheduledAt := opts.ScheduledAt
		if scheduledAt.IsZero() {
			scheduledAt = time.Now()
		}

		// A nil pointer is passed for timeout_seconds when no timeout is set.
		// The duration is truncated to whole seconds.
		var timeoutSeconds *int
		if opts.Timeout > 0 {
			seconds := int(opts.Timeout.Seconds())
			timeoutSeconds = &seconds
		}

		var stepID int64
		err = tx.QueryRow(ctx, `
INSERT INTO steps (
workflow_execution_id, task_name, kind, args,
queue, priority, max_attempts, timeout_seconds, scheduled_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id
`, param.WorkflowExecutionID, param.TaskName, param.Args.Kind(), argsBytes,
			opts.Queue, opts.Priority, opts.MaxAttempts, timeoutSeconds, scheduledAt).Scan(&stepID)
		if err != nil {
			return nil, fmt.Errorf("failed to insert step: %w", err)
		}

		stepIDs = append(stepIDs, stepID)
		taskNameToStepID[param.TaskName] = stepID

		// Run AfterInsert hooks.
		tributaryhook.RunAfterInsertHooks(ctx, c.config.Hooks, stepID, param.Args)
	}

	// Second pass: insert dependencies, now that every task name has an ID.
	for i, param := range params {
		stepID := stepIDs[i]
		for _, depTaskName := range param.DependsOn {
			depStepID, ok := taskNameToStepID[depTaskName]
			if !ok {
				return nil, fmt.Errorf("dependency %q not found for step %q", depTaskName, param.TaskName)
			}
			if _, err := tx.Exec(ctx, `
INSERT INTO step_dependencies (step_id, depends_on_step_id, workflow_execution_id)
VALUES ($1, $2, $3)
`, stepID, depStepID, param.WorkflowExecutionID); err != nil {
				return nil, fmt.Errorf("failed to insert dependency: %w", err)
			}
		}
	}

	// Commit only if we started the transaction; a caller-supplied
	// transaction is left open for the caller to commit or roll back.
	if ownsTx {
		if err := tx.Commit(ctx); err != nil {
			return nil, fmt.Errorf("failed to commit transaction: %w", err)
		}
	}

	// Signal workers that new steps are available. With a caller-supplied
	// transaction the rows are not visible until the caller commits, so this
	// signal may arrive early; workers also poll periodically as a backup.
	// NOTE(review): this file defines notifyAllWorkers, not notifyWorkers;
	// presumably notifyWorkers lives elsewhere in the package. Confirm.
	c.notifyWorkers()
	return stepIDs, nil
}
// NewWorkflow returns a Workflow bound to this client, built from opts by the
// package-level NewWorkflow constructor.
func (c *Client) NewWorkflow(opts *WorkflowOpts) *Workflow {
	workflow := NewWorkflow(c, opts)
	return workflow
}
// CreateWorkflowExecution records a new workflow execution called name and
// returns its id. It is the exported form of createWorkflowExecution and
// implements the WorkflowClient interface from tributarytype.
func (c *Client) CreateWorkflowExecution(ctx context.Context, name string) (int64, error) {
	executionID, err := c.createWorkflowExecution(ctx, name)
	return executionID, err
}