diff --git a/.github/workflows/docker/dynamic-config-custom.yaml b/.github/workflows/docker/dynamic-config-custom.yaml index 49dddbb77..acaddf3c5 100644 --- a/.github/workflows/docker/dynamic-config-custom.yaml +++ b/.github/workflows/docker/dynamic-config-custom.yaml @@ -49,4 +49,8 @@ history.enableChasm: history.enableTransitionHistory: - value: true component.nexusoperations.useSystemCallbackURL: - - value: false \ No newline at end of file + - value: false +frontend.WorkerHeartbeatsEnabled: + - value: true +frontend.ListWorkersEnabled: + - value: true diff --git a/contrib/resourcetuner/cgroups.go b/contrib/sysinfo/cgroups.go similarity index 93% rename from contrib/resourcetuner/cgroups.go rename to contrib/sysinfo/cgroups.go index ab5012250..acfb1f3dd 100644 --- a/contrib/resourcetuner/cgroups.go +++ b/contrib/sysinfo/cgroups.go @@ -1,6 +1,6 @@ //go:build linux -package resourcetuner +package sysinfo import ( "errors" @@ -43,11 +43,11 @@ func (p *cGroupInfoImpl) GetLastCPUUsage() float64 { func (p *cGroupInfoImpl) updateCGroupStats() error { control, err := cgroup2.Load("/") if err != nil { - return fmt.Errorf("failed to get cgroup mem stats %w", err) + return fmt.Errorf("failed to load cgroup: %w", err) } metrics, err := control.Stat() if err != nil { - return fmt.Errorf("failed to get cgroup mem stats %w", err) + return fmt.Errorf("failed to get cgroup stats: %w", err) } // Only update if a limit has been set if metrics.Memory.UsageLimit != 0 { @@ -56,7 +56,7 @@ func (p *cGroupInfoImpl) updateCGroupStats() error { err = p.cgroupCpuCalc.updateCpuUsage(metrics) if err != nil { - return fmt.Errorf("failed to get cgroup cpu usage %w", err) + return fmt.Errorf("failed to get cgroup cpu usage: %w", err) } return nil } diff --git a/contrib/resourcetuner/cgroups_common.go b/contrib/sysinfo/cgroups_common.go similarity index 95% rename from contrib/resourcetuner/cgroups_common.go rename to contrib/sysinfo/cgroups_common.go index f4fd3d244..b314a4599 100644 --- a/contrib/resourcetuner/cgroups_common.go +++ b/contrib/sysinfo/cgroups_common.go @@ -1,4 +1,4 @@ -package resourcetuner +package sysinfo import ( "errors" diff --git a/contrib/resourcetuner/cgroups_notlinux.go b/contrib/sysinfo/cgroups_notlinux.go similarity index 94% rename from contrib/resourcetuner/cgroups_notlinux.go rename to contrib/sysinfo/cgroups_notlinux.go index 068e4220f..d89de073b 100644 --- a/contrib/resourcetuner/cgroups_notlinux.go +++ b/contrib/sysinfo/cgroups_notlinux.go @@ -1,6 +1,6 @@ //go:build !linux -package resourcetuner +package sysinfo import "errors" diff --git a/contrib/resourcetuner/cgroups_test.go b/contrib/sysinfo/cgroups_test.go similarity index 98% rename from contrib/resourcetuner/cgroups_test.go rename to contrib/sysinfo/cgroups_test.go index 57c69f005..a0102b32e 100644 --- a/contrib/resourcetuner/cgroups_test.go +++ b/contrib/sysinfo/cgroups_test.go @@ -1,4 +1,4 @@ -package resourcetuner +package sysinfo import ( "errors" diff --git a/contrib/resourcetuner/go.mod b/contrib/sysinfo/go.mod similarity index 94% rename from contrib/resourcetuner/go.mod rename to contrib/sysinfo/go.mod index ac80a5858..ea15d3d0d 100644 --- a/contrib/resourcetuner/go.mod +++ b/contrib/sysinfo/go.mod @@ -1,4 +1,4 @@ -module go.temporal.io/sdk/contrib/resourcetuner +module go.temporal.io/sdk/contrib/sysinfo go 1.23.0 @@ -8,7 +8,6 @@ require ( github.com/containerd/cgroups/v3 v3.0.3 github.com/shirou/gopsutil/v4 v4.24.8 github.com/stretchr/testify v1.10.0 - go.einride.tech/pid v0.1.3 go.temporal.io/sdk v1.29.1 ) @@ -31,7 
+30,7 @@ require ( github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/robfig/cron v1.2.0 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/contrib/resourcetuner/go.sum b/contrib/sysinfo/go.sum similarity index 97% rename from contrib/resourcetuner/go.sum rename to contrib/sysinfo/go.sum index 9e4126df1..6478bd8bf 100644 --- a/contrib/resourcetuner/go.sum +++ b/contrib/sysinfo/go.sum @@ -54,8 +54,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= @@ -141,5 +141,3 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= -gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/contrib/sysinfo/sysinfo.go b/contrib/sysinfo/sysinfo.go new file mode 100644 index 000000000..204b6ea63 --- /dev/null +++ b/contrib/sysinfo/sysinfo.go @@ -0,0 +1,111 @@ +package sysinfo + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/shirou/gopsutil/v4/cpu" + "github.com/shirou/gopsutil/v4/mem" + "go.temporal.io/sdk/worker" +) + +var sysInfoProvider = sync.OnceValue(func() *psUtilSystemInfoSupplier { + return &psUtilSystemInfoSupplier{ + cGroupInfo: newCGroupInfo(), + } +}) + +// SysInfoProvider returns a shared SysInfoProvider using gopsutil. +// Supports cgroup metrics in containerized Linux environments. +func SysInfoProvider() worker.SysInfoProvider { + return sysInfoProvider() +} + +type psUtilSystemInfoSupplier struct { + mu sync.Mutex + lastRefresh atomic.Int64 // UnixNano, atomic for lock-free reads in maybeRefresh + + lastMemStat *mem.VirtualMemoryStat + lastCpuUsage float64 + + stopTryingToGetCGroupInfo bool + cGroupInfo cGroupInfo +} + +type cGroupInfo interface { + // Update requests an update of the cgroup stats. This is a no-op if not in a cgroup. Returns + // true if cgroup stats should continue to be updated, false if not in a cgroup or the returned + // error is considered unrecoverable. + Update() (bool, error) + // GetLastMemUsage returns last known memory usage as a fraction of the cgroup limit. 
0 if not + // in a cgroup or limit is not set. + GetLastMemUsage() float64 + // GetLastCPUUsage returns last known CPU usage as a fraction of the cgroup limit. 0 if not in a + // cgroup or limit is not set. + GetLastCPUUsage() float64 +} + +func (p *psUtilSystemInfoSupplier) MemoryUsage(infoContext *worker.SysInfoContext) (float64, error) { + if err := p.maybeRefresh(infoContext); err != nil { + return 0, err + } + p.mu.Lock() + defer p.mu.Unlock() + lastCGroupMem := p.cGroupInfo.GetLastMemUsage() + if lastCGroupMem != 0 { + return lastCGroupMem, nil + } + return p.lastMemStat.UsedPercent / 100, nil +} + +func (p *psUtilSystemInfoSupplier) CpuUsage(infoContext *worker.SysInfoContext) (float64, error) { + if err := p.maybeRefresh(infoContext); err != nil { + return 0, err + } + p.mu.Lock() + defer p.mu.Unlock() + lastCGroupCPU := p.cGroupInfo.GetLastCPUUsage() + if lastCGroupCPU != 0 { + return lastCGroupCPU, nil + } + return p.lastCpuUsage / 100, nil +} + +func (p *psUtilSystemInfoSupplier) maybeRefresh(infoContext *worker.SysInfoContext) error { + if time.Since(time.Unix(0, p.lastRefresh.Load())) < 100*time.Millisecond { + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + // Double check refresh is still needed + if time.Since(time.Unix(0, p.lastRefresh.Load())) < 100*time.Millisecond { + return nil + } + ctx, cancelFn := context.WithTimeout(context.Background(), 1*time.Second) + defer cancelFn() + memStat, err := mem.VirtualMemoryWithContext(ctx) + if err != nil { + return err + } + cpuUsage, err := cpu.PercentWithContext(ctx, 0, false) + if err != nil { + return err + } + + p.lastMemStat = memStat + p.lastCpuUsage = cpuUsage[0] + + if runtime.GOOS == "linux" && !p.stopTryingToGetCGroupInfo { + continueUpdates, err := p.cGroupInfo.Update() + if err != nil { + infoContext.Logger.Warn("Failed to get cgroup stats", "error", err) + } + p.stopTryingToGetCGroupInfo = !continueUpdates + } + + p.lastRefresh.Store(time.Now().UnixNano()) + return nil +} diff --git a/contrib/sysinfo/sysinfo_test.go b/contrib/sysinfo/sysinfo_test.go new file mode 100644 index 000000000..8e8bfb6be --- /dev/null +++ b/contrib/sysinfo/sysinfo_test.go @@ -0,0 +1,42 @@ +package sysinfo + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/internal/log" + "go.temporal.io/sdk/worker" +) + +func TestGetMemoryCpuUsage(t *testing.T) { + supplier := SysInfoProvider() + ctx := &worker.SysInfoContext{Logger: log.NewNopLogger()} + + usage, err := supplier.MemoryUsage(ctx) + require.NoError(t, err) + assert.GreaterOrEqual(t, usage, 0.0) + assert.LessOrEqual(t, usage, 1.0) + + usage, err = supplier.CpuUsage(ctx) + require.NoError(t, err) + assert.GreaterOrEqual(t, usage, 0.0) + assert.LessOrEqual(t, usage, 1.0) +} + +func TestMaybeRefreshRateLimiting(t *testing.T) { + supplier := SysInfoProvider().(*psUtilSystemInfoSupplier) + ctx := &worker.SysInfoContext{Logger: log.NewNopLogger()} + + // First call should refresh + firstUsage, err := supplier.MemoryUsage(ctx) + require.NoError(t, err) + firstRefresh := supplier.lastRefresh.Load() + + // Immediate second call should not refresh (rate limited) + secondUsage, err := supplier.MemoryUsage(ctx) + require.NoError(t, err) + assert.Equal(t, firstRefresh, supplier.lastRefresh.Load()) + + assert.Equal(t, firstUsage, secondUsage) +} diff --git a/internal/client.go b/internal/client.go index 31c5850f3..430c77859 100644 --- a/internal/client.go +++ b/internal/client.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" 
"fmt" + "github.com/google/uuid" "sync/atomic" "time" @@ -601,6 +602,14 @@ type ( // // NOTE: Experimental Plugins []ClientPlugin + + // WorkerHeartbeatInterval is the interval at which the worker will send heartbeats to the server. + // Interval must be between 1s and 60s, inclusive, or a negative value to disable. + // + // default: 0 defaults to 60s interval. + // + // NOTE: Experimental + WorkerHeartbeatInterval time.Duration } // HeadersProvider returns a map of gRPC headers that should be used on every request. @@ -1198,7 +1207,9 @@ func NewServiceClient(workflowServiceClient workflowservice.WorkflowServiceClien // Collect set of applicable worker plugins and interceptors var workerPlugins []WorkerPlugin + var clientPluginNames []string for _, plugin := range options.Plugins { + clientPluginNames = append(clientPluginNames, plugin.Name()) if workerPlugin, _ := plugin.(WorkerPlugin); workerPlugin != nil { workerPlugins = append(workerPlugins, workerPlugin) } @@ -1210,6 +1221,18 @@ func NewServiceClient(workflowServiceClient workflowservice.WorkflowServiceClien } } + var heartbeatInterval time.Duration + if options.WorkerHeartbeatInterval < 0 { + heartbeatInterval = 0 + } else if options.WorkerHeartbeatInterval == 0 { + heartbeatInterval = 60 * time.Second + } else { + if options.WorkerHeartbeatInterval < time.Second || options.WorkerHeartbeatInterval > 60*time.Second { + panic("WorkerHeartbeatInterval must be between 1 second and 60 seconds") + } + heartbeatInterval = options.WorkerHeartbeatInterval + } + client := &WorkflowClient{ workflowService: workflowServiceClient, conn: conn, @@ -1223,11 +1246,18 @@ func NewServiceClient(workflowServiceClient workflowservice.WorkflowServiceClien contextPropagators: options.ContextPropagators, workerPlugins: workerPlugins, workerInterceptors: workerInterceptors, + clientPluginNames: clientPluginNames, excludeInternalFromRetry: options.ConnectionOptions.excludeInternalFromRetry, eagerDispatcher: &eagerWorkflowDispatcher{ workersByTaskQueue: make(map[string]map[eagerWorker]struct{}), }, - getSystemInfoTimeout: options.ConnectionOptions.GetSystemInfoTimeout, + getSystemInfoTimeout: options.ConnectionOptions.GetSystemInfoTimeout, + workerHeartbeatInterval: heartbeatInterval, + workerGroupingKey: uuid.NewString(), + } + + if heartbeatInterval > 0 { + client.heartbeatManager = newHeartbeatManager(client, heartbeatInterval, client.logger) } // Create outbound interceptor by wrapping backwards through chain diff --git a/internal/cmd/build/main.go b/internal/cmd/build/main.go index c949c056b..80376e1eb 100644 --- a/internal/cmd/build/main.go +++ b/internal/cmd/build/main.go @@ -159,7 +159,9 @@ func (b *builder) integrationTest() error { "--dynamic-config-value", "history.enableChasm=true", "--dynamic-config-value", "history.enableTransitionHistory=true", "--dynamic-config-value", `component.nexusoperations.useSystemCallbackURL=false`, - "--dynamic-config-value", `component.nexusoperations.callback.endpoint.template="http://localhost:7243/namespaces/{{.NamespaceName}}/nexus/callback"`}, + "--dynamic-config-value", `component.nexusoperations.callback.endpoint.template="http://localhost:7243/namespaces/{{.NamespaceName}}/nexus/callback"`, + "--dynamic-config-value", "frontend.ListWorkersEnabled=true", + }, }) if err != nil { return fmt.Errorf("failed starting dev server: %w", err) diff --git a/internal/internal_nexus_task_poller.go b/internal/internal_nexus_task_poller.go index 6537f3454..857fd56e7 100644 --- a/internal/internal_nexus_task_poller.go +++ 
b/internal/internal_nexus_task_poller.go @@ -42,6 +42,7 @@ func newNexusTaskPoller( useBuildIDVersioning: params.UseBuildIDForVersioning, workerDeploymentVersion: params.DeploymentOptions.Version, capabilities: params.capabilities, + pollTimeTracker: params.pollTimeTracker, }, taskHandler: taskHandler, service: service, @@ -90,11 +91,9 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) { return nil, nil } - return &nexusTask{task: response}, nil -} + ntp.pollTimeTracker.recordPollSuccess(metrics.PollerTypeNexusTask) -func (ntp *nexusTaskPoller) Cleanup() error { - return nil + return &nexusTask{task: response}, nil } // PollTask polls a new task diff --git a/internal/internal_nexus_worker.go b/internal/internal_nexus_worker.go index ba38bd6ef..d5f844ed5 100644 --- a/internal/internal_nexus_worker.go +++ b/internal/internal_nexus_worker.go @@ -78,10 +78,6 @@ func newNexusWorker(opts nexusWorkerOptions) (*nexusWorker, error) { // Start the worker. func (w *nexusWorker) Start() error { - err := verifyNamespaceExist(w.workflowService, w.executionParameters.MetricsHandler, w.executionParameters.Namespace, w.worker.logger) - if err != nil { - return err - } w.worker.Start() return nil } diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index dba946879..c5e6dff2b 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -13,13 +13,10 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/wrapperspb" - "github.com/google/uuid" - commonpb "go.temporal.io/api/common/v1" deploymentpb "go.temporal.io/api/deployment/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" - "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" @@ -53,9 +50,6 @@ type ( taskPoller interface { // PollTask polls for one new task PollTask() (taskForWorker, error) - // Called when the poller will no longer be polled. Presently only useful for - // workflow workers. - Cleanup() error } // taskProcessor interface to process tasks @@ -85,6 +79,10 @@ type ( workerDeploymentVersion WorkerDeploymentVersion // Server's capabilities capabilities *workflowservice.GetSystemInfoResponse_Capabilities + // tracks timestamps of the last successful polls, for worker heartbeating + pollTimeTracker *pollTimeTracker + // Unique identifier for worker + workerInstanceKey string } // numPollerMetric tracks the number of active pollers and publishes a metric on it. 
@@ -208,6 +206,9 @@ type ( ) func newNumPollerMetric(metricsHandler metrics.Handler, pollerType string) *numPollerMetric { + if heartbeatHandler, isHeartbeat := metricsHandler.(*heartbeatMetricsHandler); isHeartbeat { + metricsHandler = heartbeatHandler.forPoller(pollerType) + } return &numPollerMetric{ gauge: metricsHandler.WithTags(metrics.PollerTags(pollerType)).Gauge(metrics.NumPoller), } @@ -315,6 +316,7 @@ func newWorkflowTaskProcessor( contextManager WorkflowContextManager, service workflowservice.WorkflowServiceClient, params workerExecutionParameters, + stickyUUID string, ) *workflowTaskProcessor { return &workflowTaskProcessor{ basePoller: basePoller{ @@ -324,6 +326,8 @@ func newWorkflowTaskProcessor( useBuildIDVersioning: params.UseBuildIDForVersioning, workerDeploymentVersion: params.DeploymentOptions.Version, capabilities: params.capabilities, + pollTimeTracker: params.pollTimeTracker, + workerInstanceKey: params.workerInstanceKey, }, service: service, namespace: params.Namespace, @@ -334,7 +338,7 @@ func newWorkflowTaskProcessor( logger: params.Logger, dataConverter: params.DataConverter, failureConverter: params.FailureConverter, - stickyUUID: uuid.NewString(), + stickyUUID: stickyUUID, StickyScheduleToStartTimeout: params.StickyScheduleToStartTimeout, stickyCacheSize: params.cache.MaxWorkflowCacheSize(), eagerActivityExecutor: params.eagerActivityExecutor, @@ -343,36 +347,6 @@ func newWorkflowTaskProcessor( } } -// Best-effort attempt to indicate to Matching service that this workflow task -// poller's sticky queue will no longer be polled. Should be called when the -// poller is stopping. Failure to call ShutdownWorker is logged, but otherwise -// ignored. -func (wtp *workflowTaskPoller) Cleanup() error { - ctx := context.Background() - grpcCtx, cancel := newGRPCContext(ctx, grpcMetricsHandler(wtp.metricsHandler)) - defer cancel() - - _, err := wtp.service.ShutdownWorker(grpcCtx, &workflowservice.ShutdownWorkerRequest{ - Namespace: wtp.namespace, - StickyTaskQueue: getWorkerTaskQueue(wtp.stickyUUID), - Identity: wtp.identity, - Reason: "graceful shutdown", - }) - - // we ignore unimplemented - if _, isUnimplemented := err.(*serviceerror.Unimplemented); isUnimplemented { - return nil - } - - if err != nil { - traceLog(func() { - wtp.logger.Debug("ShutdownWorker failed.", tagError, err) - }) - } - - return err -} - // PollTask polls a new task func (wtp *workflowTaskPoller) PollTask() (taskForWorker, error) { // Get the task. 
@@ -737,10 +711,6 @@ func newLocalActivityPoller( } } -func (latp *localActivityTaskPoller) Cleanup() error { - return nil -} - func (latp *localActivityTaskPoller) PollTask() (taskForWorker, error) { return latp.laTunnel.getTask(), nil } @@ -965,6 +935,7 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po wtp.useBuildIDVersioning, wtp.workerDeploymentVersion, ), + WorkerInstanceKey: wtp.workerInstanceKey, } if wtp.getCapabilities().BuildIdBasedVersioning { //lint:ignore SA1019 ignore deprecated versioning APIs @@ -1008,6 +979,12 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error) return &workflowTask{}, nil } + if request.TaskQueue.GetKind() == enumspb.TASK_QUEUE_KIND_STICKY { + wtp.pollTimeTracker.recordPollSuccess(metrics.PollerTypeWorkflowStickyTask) + } else { + wtp.pollTimeTracker.recordPollSuccess(metrics.PollerTypeWorkflowTask) + } + wtp.updateBacklog(request.TaskQueue.GetKind(), response.GetBacklogCountHint()) task := wtp.toWorkflowTask(response) @@ -1155,6 +1132,8 @@ func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserv useBuildIDVersioning: params.UseBuildIDForVersioning, workerDeploymentVersion: params.DeploymentOptions.Version, capabilities: params.capabilities, + pollTimeTracker: params.pollTimeTracker, + workerInstanceKey: params.workerInstanceKey, }, taskHandler: taskHandler, service: service, @@ -1194,6 +1173,7 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error) atp.useBuildIDVersioning, atp.workerDeploymentVersion, ), + WorkerInstanceKey: atp.workerInstanceKey, } response, err := atp.pollActivityTaskQueue(ctx, request) @@ -1206,6 +1186,8 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error) return &activityTask{}, nil } + atp.pollTimeTracker.recordPollSuccess(metrics.PollerTypeActivityTask) + workflowType := response.WorkflowType.GetName() activityType := response.ActivityType.GetName() metricsHandler := atp.metricsHandler.WithTags(metrics.ActivityTags(workflowType, activityType, atp.taskQueueName)) @@ -1216,10 +1198,6 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error) return &activityTask{task: response}, nil } -func (atp *activityTaskPoller) Cleanup() error { - return nil -} - // PollTask polls a new task func (atp *activityTaskPoller) PollTask() (taskForWorker, error) { // Get the task. 
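The poller changes above retire the per-poller Cleanup() hook (sticky-queue shutdown moves to AggregatedWorker.shutdownWorker later in this diff) and instead thread two pieces of heartbeat state through every poller: a workerInstanceKey attached to poll requests, and per-poller-type success timestamps. A minimal, self-contained sketch of that timestamp pattern, mirroring the pollTimeTracker type added in internal_worker_heartbeat.go below (the poller-type strings are illustrative stand-ins for the metrics.PollerType* constants):

package main

import (
	"fmt"
	"sync"
	"time"
)

// pollTimeTracker mirrors the type added later in this diff: a sync.Map
// keyed by poller type, holding the last successful poll time as UnixNano.
type pollTimeTracker struct {
	times sync.Map // pollerType (string) -> int64 (UnixNano)
}

func (p *pollTimeTracker) recordPollSuccess(pollerType string) {
	p.times.Store(pollerType, time.Now().UnixNano())
}

func (p *pollTimeTracker) getLastPollTime(pollerType string) time.Time {
	if v, ok := p.times.Load(pollerType); ok {
		return time.Unix(0, v.(int64))
	}
	return time.Time{} // zero time: no successful poll yet
}

func main() {
	tracker := &pollTimeTracker{}
	// Each poller stamps its type after a non-empty poll response, as the
	// workflow, activity, and nexus pollers now do in the hunks above.
	tracker.recordPollSuccess("workflow_task")
	fmt.Println(tracker.getLastPollTime("workflow_task"))
	// Never-polled types keep the zero time; buildPollerInfo later in this
	// diff maps that to a nil LastSuccessfulPollTime in the heartbeat proto.
	fmt.Println(tracker.getLastPollTime("nexus_task").IsZero()) // true
}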
diff --git a/internal/internal_task_pollers_test.go b/internal/internal_task_pollers_test.go index af70a53f6..8fde8791e 100644 --- a/internal/internal_task_pollers_test.go +++ b/internal/internal_task_pollers_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "errors" + "github.com/google/uuid" "sync/atomic" "testing" "time" @@ -97,7 +98,7 @@ func TestWFTRacePrevention(t *testing.T) { return &workflowservice.RespondWorkflowTaskFailedResponse{}, nil }) - poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params) + poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params, uuid.NewString()) t.Log("Issue task0") go func() { resultsChan <- poller.processWorkflowTask(&task0) }() @@ -188,7 +189,7 @@ func TestWFTCorruption(t *testing.T) { return nil, errors.New("Failure responding to workflow task") }) - poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params) + poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params, uuid.NewString()) processTaskDone := make(chan struct{}) go func() { require.Error(t, poller.processWorkflowTask(&task0)) @@ -329,7 +330,7 @@ func TestWFTReset(t *testing.T) { client.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any()). Return(&workflowservice.RespondWorkflowTaskCompletedResponse{}, nil) - poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params) + poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params, uuid.NewString()) // Send a full history as part of the speculative WFT require.NoError(t, poller.processWorkflowTask(&task0)) originalCachedExecution := cache.getWorkflowContext(runID) @@ -403,7 +404,7 @@ func TestWFTPanicInTaskHandler(t *testing.T) { task0 = workflowTask{task: &pollResp0} ) - poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params) + poller := newWorkflowTaskProcessor(taskHandler, contextManager, client, params, uuid.NewString()) require.Error(t, poller.processWorkflowTask(&task0)) // Workflow should not be in cache require.Nil(t, cache.getWorkflowContext(runID)) diff --git a/internal/internal_worker.go b/internal/internal_worker.go index 8bae608fc..f21c63a66 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -6,11 +6,15 @@ import ( "context" "errors" "fmt" + workerpb "go.temporal.io/api/worker/v1" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" "io" "math" "os" "reflect" "runtime" + "sort" "strconv" "strings" "sync" @@ -24,6 +28,7 @@ import ( deploymentpb "go.temporal.io/api/deployment/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" + "go.temporal.io/api/serviceerror" "go.temporal.io/api/temporalproto" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/api/workflowservicemock/v1" @@ -81,6 +86,7 @@ type ( identity string stopC chan struct{} localActivityStopC chan struct{} + stickyUUID string // Used for ShutdownWorker call } // ActivityWorker wraps the code for hosting activity types. @@ -208,6 +214,10 @@ type ( eagerActivityExecutor *eagerActivityExecutor capabilities *workflowservice.GetSystemInfoResponse_Capabilities + + pollTimeTracker *pollTimeTracker + + workerInstanceKey string } // HistoryJSONOptions are options for HistoryFromJSON. 
@@ -264,6 +274,9 @@ func ensureRequiredParams(params *workerExecutionParameters) { NumNexusSlots: defaultMaxConcurrentTaskExecutionSize, }) } + if params.pollTimeTracker == nil { + params.pollTimeTracker = &pollTimeTracker{} + } } // getBuildID returns either the user-defined build ID if it was provided, or an autogenerated one @@ -280,23 +293,6 @@ func (params *workerExecutionParameters) isInternalWorker() bool { return params.Namespace == "temporal-system" || params.TaskQueue == "temporal-sys-per-ns-tq" } -// verifyNamespaceExist does a DescribeNamespace operation on the specified namespace with backoff/retry -func verifyNamespaceExist( - client workflowservice.WorkflowServiceClient, - metricsHandler metrics.Handler, - namespace string, - logger log.Logger, -) error { - ctx := context.Background() - if namespace == "" { - return errors.New("namespace cannot be empty") - } - grpcCtx, cancel := newGRPCContext(ctx, grpcMetricsHandler(metricsHandler), defaultGrpcRetryParameters(ctx)) - defer cancel() - _, err := client.DescribeNamespace(grpcCtx, &workflowservice.DescribeNamespaceRequest{Namespace: namespace}) - return err -} - func newWorkflowWorkerInternal(client *WorkflowClient, params workerExecutionParameters, ppMgr pressurePointMgr, overrides *workerOverrides, registry *registry) *workflowWorker { workerStopChannel := make(chan struct{}) params.WorkerStopChannel = getReadOnlyChannel(workerStopChannel) @@ -324,7 +320,9 @@ func newWorkflowTaskWorkerInternal( if client != nil { service = client.workflowService } - taskProcessor := newWorkflowTaskProcessor(taskHandler, contextManager, service, params) + // Generate stickyUUID here so it can be stored in workflowWorker for ShutdownWorker call + stickyUUID := uuid.NewString() + taskProcessor := newWorkflowTaskProcessor(taskHandler, contextManager, service, params, stickyUUID) var scalableTaskPollers []scalableTaskPoller switch params.WorkflowTaskPollerBehavior.(type) { @@ -414,15 +412,12 @@ func newWorkflowTaskWorkerInternal( identity: params.Identity, stopC: stopC, localActivityStopC: laStopChannel, + stickyUUID: stickyUUID, } } // Start the worker. func (ww *workflowWorker) Start() error { - err := verifyNamespaceExist(ww.workflowService, ww.executionParameters.MetricsHandler, ww.executionParameters.Namespace, ww.worker.logger) - if err != nil { - return err - } ww.localActivityWorker.Start() ww.worker.Start() return nil // TODO: propagate error @@ -566,10 +561,6 @@ func newActivityWorker( // Start the worker. func (aw *activityWorker) Start() error { - err := verifyNamespaceExist(aw.workflowService, aw.executionParameters.MetricsHandler, aw.executionParameters.Namespace, aw.worker.logger) - if err != nil { - return err - } aw.worker.Start() return nil // TODO: propagate errors } @@ -1169,6 +1160,7 @@ type AggregatedWorker struct { registry *registry // Stores a boolean indicating whether the worker has already been started. 
started atomic.Bool + shuttingDown atomic.Bool stopC chan struct{} fatalErr error fatalErrLock sync.Mutex @@ -1177,6 +1169,9 @@ type AggregatedWorker struct { workerInstanceKey string plugins []WorkerPlugin pluginRegistryOptions *WorkerPluginConfigureWorkerRegistryOptions // Never nil + + heartbeatMetrics *heartbeatMetricsHandler + heartbeatCallback func() *workerpb.WorkerHeartbeat } // RegisterWorkflow registers workflow implementation with the AggregatedWorker @@ -1286,6 +1281,10 @@ func (aw *AggregatedWorker) start() error { } proto.Merge(aw.capabilities, capabilities) + if _, err := aw.client.loadNamespaceCapabilities(aw.executionParams.MetricsHandler); err != nil { + return err + } + if !util.IsInterfaceNil(aw.workflowWorker) { if err := aw.workflowWorker.Start(); err != nil { return err @@ -1353,6 +1352,12 @@ func (aw *AggregatedWorker) start() error { return fmt.Errorf("failed to start a nexus worker: %w", err) } } + + if aw.client.workerHeartbeatInterval > 0 { + if err := aw.registerHeartbeatWorker(); err != nil { + return fmt.Errorf("failed to register heartbeat worker: %w", err) + } + } aw.logger.Info("Started Worker") return nil } @@ -1441,6 +1446,8 @@ func (aw *AggregatedWorker) Stop() { close(aw.stopC) } + aw.shutdownWorker() + // Issue stop through plugins stop := func(context.Context, WorkerPluginStopWorkerOptions) { if !util.IsInterfaceNil(aw.workflowWorker) { @@ -1470,9 +1477,66 @@ func (aw *AggregatedWorker) Stop() { WorkerInstanceKey: aw.workerInstanceKey, }) + aw.unregisterHeartbeatWorker() + aw.logger.Info("Stopped Worker") } +func (aw *AggregatedWorker) registerHeartbeatWorker() error { + if aw.client.heartbeatManager == nil { + return nil + } + return aw.client.heartbeatManager.registerWorker(aw) +} + +func (aw *AggregatedWorker) unregisterHeartbeatWorker() { + if aw.client.heartbeatManager == nil { + return + } + aw.client.heartbeatManager.unregisterWorker(aw) +} + +// shutdownWorker sends a ShutdownWorker RPC to notify the server that this worker is shutting down. +// When StickyTaskQueue is non-empty, this is a best-effort attempt to indicate to Matching service +// that this workflow task poller's sticky queue will no longer be polled. +// +// NOTE: errors are logged but don't fail the shutdown. 
+func (aw *AggregatedWorker) shutdownWorker() { + aw.shuttingDown.Store(true) + + ctx := context.Background() + grpcCtx, cancel := newGRPCContext(ctx, grpcMetricsHandler(aw.executionParams.MetricsHandler)) + defer cancel() + + var heartbeat *workerpb.WorkerHeartbeat + if aw.heartbeatCallback != nil { + heartbeat = aw.heartbeatCallback() + } + + var stickyTaskQueue string + if aw.workflowWorker != nil && aw.workflowWorker.stickyUUID != "" { + stickyTaskQueue = getWorkerTaskQueue(aw.workflowWorker.stickyUUID) + } + + _, err := aw.client.workflowService.ShutdownWorker(grpcCtx, &workflowservice.ShutdownWorkerRequest{ + Namespace: aw.executionParams.Namespace, + StickyTaskQueue: stickyTaskQueue, + Identity: aw.executionParams.Identity, + Reason: "graceful shutdown", + WorkerHeartbeat: heartbeat, + WorkerInstanceKey: aw.workerInstanceKey, + }) + + // Ignore unimplemented (server doesn't support it) + if _, isUnimplemented := err.(*serviceerror.Unimplemented); isUnimplemented { + return + } + + if err != nil { + aw.logger.Debug("ShutdownWorker rpc errored during worker shutdown.", tagError, err) + } +} + // WorkflowReplayer is used to replay workflow code from an event history type WorkflowReplayer struct { registry *registry @@ -2026,6 +2090,17 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke // should take a pointer to this struct and wait for it to be populated when the worker is run. var capabilities workflowservice.GetSystemInfoResponse_Capabilities + baseMetricsHandler := client.metricsHandler.WithTags(metrics.TaskQueueTags(taskQueue)) + var metricsHandler metrics.Handler + var heartbeatMetrics *heartbeatMetricsHandler + + if client.workerHeartbeatInterval != 0 { + heartbeatMetrics = newHeartbeatMetricsHandler(baseMetricsHandler) + metricsHandler = heartbeatMetrics + } else { + metricsHandler = baseMetricsHandler + } + cache := NewWorkerCache() workerParams := workerExecutionParameters{ Namespace: client.namespace, @@ -2037,7 +2112,7 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke WorkerBuildID: options.BuildID, UseBuildIDForVersioning: options.UseBuildIDForVersioning || options.DeploymentOptions.UseVersioning, DeploymentOptions: options.DeploymentOptions, - MetricsHandler: client.metricsHandler.WithTags(metrics.TaskQueueTags(taskQueue)), + MetricsHandler: metricsHandler, Logger: client.logger, EnableLoggingInReplay: options.EnableLoggingInReplay, BackgroundContext: backgroundActivityContext, @@ -2059,7 +2134,9 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke taskQueue: taskQueue, maxConcurrent: options.MaxConcurrentEagerActivityExecutionSize, }), - capabilities: &capabilities, + capabilities: &capabilities, + pollTimeTracker: &pollTimeTracker{}, + workerInstanceKey: workerInstanceKey, } if options.MaxConcurrentWorkflowTaskPollers != 0 { @@ -2147,6 +2224,102 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke }) } + // Get SysInfoProvider from tuner's slot supplier if it implements HasSysInfoProvider. + // If not available, heartbeats will report 0 for CPU/memory usage. 
+ var sysInfoProvider SysInfoProvider + if sis, ok := options.Tuner.GetWorkflowTaskSlotSupplier().(HasSysInfoProvider); ok { + sysInfoProvider = sis.SysInfoProvider() + } + + var heartbeatCallback func() *workerpb.WorkerHeartbeat + if client.workerHeartbeatInterval != 0 { + startTime := timestamppb.New(time.Now()) + hostname, _ := os.Hostname() + pid := strconv.Itoa(os.Getpid()) + previousHeartbeatTime := time.Now() + pluginInfos := collectPluginInfos(client.clientPluginNames, plugins) + + var prevWorkflowProcessed, prevWorkflowFailed int64 + var prevActivityProcessed, prevActivityFailed int64 + var prevLocalActivityProcessed, prevLocalActivityFailed int64 + var prevNexusProcessed, prevNexusFailed int64 + + populateOpts := &populateHeartbeatOptions{ + workflowPollerBehavior: workerParams.WorkflowTaskPollerBehavior, + activityPollerBehavior: workerParams.ActivityTaskPollerBehavior, + nexusPollerBehavior: workerParams.NexusTaskPollerBehavior, + prevWorkflowProcessed: &prevWorkflowProcessed, + prevWorkflowFailed: &prevWorkflowFailed, + prevActivityProcessed: &prevActivityProcessed, + prevActivityFailed: &prevActivityFailed, + prevLocalActivityProcessed: &prevLocalActivityProcessed, + prevLocalActivityFailed: &prevLocalActivityFailed, + prevNexusProcessed: &prevNexusProcessed, + prevNexusFailed: &prevNexusFailed, + pollTimeTracker: workerParams.pollTimeTracker, + } + + var deploymentVersion *deploymentpb.WorkerDeploymentVersion + if options.DeploymentOptions.UseVersioning { + deploymentVersion = &deploymentpb.WorkerDeploymentVersion{ + DeploymentName: options.DeploymentOptions.Version.DeploymentName, + BuildId: options.DeploymentOptions.Version.BuildID, + } + } + + // The callback can be invoked concurrently from the heartbeat worker goroutine and the shutdown path + var mu sync.Mutex + heartbeatCallback = func() *workerpb.WorkerHeartbeat { + cpuUsage := getCpuUsage(sysInfoProvider, workerParams.Logger) + memUsage := getMemUsage(sysInfoProvider, workerParams.Logger) + + mu.Lock() + defer mu.Unlock() + if aw.workflowWorker != nil { + populateOpts.workflowSlotSupplierKind = aw.workflowWorker.worker.slotSupplier.GetSlotSupplierKind() + populateOpts.localActivitySlotSupplierKind = aw.workflowWorker.localActivityWorker.slotSupplier.GetSlotSupplierKind() + } + if aw.activityWorker != nil { + populateOpts.activitySlotSupplierKind = aw.activityWorker.worker.slotSupplier.GetSlotSupplierKind() + } + if aw.nexusWorker != nil { + populateOpts.nexusSlotSupplierKind = aw.nexusWorker.worker.slotSupplier.GetSlotSupplierKind() + } + heartbeatTime := time.Now() + elapsedSinceLastHeartbeat := heartbeatTime.Sub(previousHeartbeatTime) + previousHeartbeatTime = heartbeatTime + + status := enumspb.WORKER_STATUS_RUNNING + if aw.shuttingDown.Load() { + status = enumspb.WORKER_STATUS_SHUTTING_DOWN + } + + hb := &workerpb.WorkerHeartbeat{ + WorkerInstanceKey: aw.workerInstanceKey, + WorkerIdentity: aw.executionParams.Identity, + HostInfo: &workerpb.WorkerHostInfo{ + HostName: hostname, + WorkerGroupingKey: aw.client.workerGroupingKey, + ProcessId: pid, + CurrentHostCpuUsage: cpuUsage, + CurrentHostMemUsage: memUsage, + }, + TaskQueue: aw.executionParams.TaskQueue, + DeploymentVersion: deploymentVersion, + SdkName: SDKName, + SdkVersion: SDKVersion, + Status: status, + StartTime: startTime, + HeartbeatTime: timestamppb.New(heartbeatTime), + ElapsedSinceLastHeartbeat: durationpb.New(elapsedSinceLastHeartbeat), + Plugins: pluginInfos, + } + aw.heartbeatMetrics.PopulateHeartbeat(hb, populateOpts) + + return hb + } + } + aw 
= &AggregatedWorker{ client: client, workflowWorker: workflowWorker, @@ -2160,6 +2333,8 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke workerInstanceKey: workerInstanceKey, plugins: plugins, pluginRegistryOptions: &pluginRegistryOptions, + heartbeatMetrics: heartbeatMetrics, + heartbeatCallback: heartbeatCallback, } // Set memoized start as a once-value that invokes plugins first @@ -2478,3 +2653,52 @@ func workerDeploymentVersionFromProtoOrString(wd *deploymentpb.WorkerDeploymentV BuildID: wd.BuildId, } } + +func getCpuUsage(supplier SysInfoProvider, logger log.Logger) float32 { + if supplier == nil { + return 0 + } + cpu, err := supplier.CpuUsage(&SysInfoContext{Logger: logger}) + if err != nil { + logger.Warn("Failed to get CPU usage for heartbeat", "error", err) + return 0 + } + return float32(cpu) +} + +func getMemUsage(supplier SysInfoProvider, logger log.Logger) float32 { + if supplier == nil { + return 0 + } + mem, err := supplier.MemoryUsage(&SysInfoContext{Logger: logger}) + if err != nil { + logger.Warn("Failed to get memory usage for heartbeat", "error", err) + return 0 + } + return float32(mem) +} + +// collectPluginInfos collects plugin names from client and worker plugins, +// deduplicates them, and returns a slice of PluginInfo for heartbeat reporting. +func collectPluginInfos(clientPluginNames []string, workerPlugins []WorkerPlugin) []*workerpb.PluginInfo { + set := make(map[string]struct{}, len(clientPluginNames)+len(workerPlugins)) + result := make([]*workerpb.PluginInfo, 0, len(clientPluginNames)+len(workerPlugins)) + for _, name := range clientPluginNames { + if _, found := set[name]; !found { + set[name] = struct{}{} + result = append(result, &workerpb.PluginInfo{Name: name}) + } + } + for _, plugin := range workerPlugins { + if _, found := set[plugin.Name()]; !found { + set[plugin.Name()] = struct{}{} + result = append(result, &workerpb.PluginInfo{Name: plugin.Name()}) + } + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + + return result +} diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 4f2ae1493..20b9133ec 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -324,6 +324,9 @@ func newBaseWorker( ) *baseWorker { ctx, cancel := context.WithCancel(context.Background()) logger := log.With(options.logger, tagWorkerType, options.workerType) + if heartbeatHandler, isHeartbeat := options.metricsHandler.(*heartbeatMetricsHandler); isHeartbeat { + options.metricsHandler = heartbeatHandler.forWorker(options.workerType) + } metricsHandler := options.metricsHandler.WithTags(metrics.WorkerTags(options.workerType)) tss := newTrackingSlotSupplier(options.slotSupplier, trackingSlotSupplierOptions{ logger: logger, @@ -692,13 +695,6 @@ func (bw *baseWorker) Stop() { close(bw.stopCh) bw.limiterContextCancel() - for _, taskWorker := range bw.options.taskPollers { - err := taskWorker.taskPoller.Cleanup() - if err != nil { - bw.logger.Error("Couldn't cleanup task worker", tagError, err) - } - } - if success := awaitWaitGroup(&bw.stopWG, bw.options.stopTimeout); !success { traceLog(func() { bw.logger.Info("Worker graceful stop timed out.", "Stop timeout", bw.options.stopTimeout) diff --git a/internal/internal_worker_base_test.go b/internal/internal_worker_base_test.go index e900dbba7..a9de195d1 100644 --- a/internal/internal_worker_base_test.go +++ b/internal/internal_worker_base_test.go @@ -242,12 +242,6 @@ func (p 
*semaphoreProbeTaskPoller) PollTask() (taskForWorker, error) { return nil, nil } -// Cleanup implements taskPoller. -func (p *semaphoreProbeTaskPoller) Cleanup() error { - p.Close() - return nil -} - func (p *semaphoreProbeTaskPoller) Allow(n int) { for range n { for { diff --git a/internal/internal_worker_heartbeat.go b/internal/internal_worker_heartbeat.go new file mode 100644 index 000000000..74ddb9f73 --- /dev/null +++ b/internal/internal_worker_heartbeat.go @@ -0,0 +1,201 @@ +package internal + +import ( + "context" + "fmt" + ilog "go.temporal.io/sdk/internal/log" + "sync" + "sync/atomic" + "time" + + workerpb "go.temporal.io/api/worker/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/log" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// heartbeatManager manages heartbeat workers across namespaces for a client. +type heartbeatManager struct { + client *WorkflowClient + interval time.Duration + logger log.Logger + + workersMutex sync.Mutex + workers map[string]*sharedNamespaceWorker // namespace -> worker +} + +// newHeartbeatManager creates a new heartbeatManager. +func newHeartbeatManager(client *WorkflowClient, interval time.Duration, logger log.Logger) *heartbeatManager { + if logger == nil { + logger = ilog.NewDefaultLogger() + } + return &heartbeatManager{ + client: client, + interval: interval, + logger: logger, + workers: make(map[string]*sharedNamespaceWorker), + } +} + +// registerWorker registers a worker's heartbeat callback with the shared heartbeat worker for the namespace. +func (m *heartbeatManager) registerWorker( + worker *AggregatedWorker, +) error { + capabilities, err := m.client.loadNamespaceCapabilities(worker.heartbeatMetrics) + if err != nil { + return fmt.Errorf("failed to get namespace capabilities: %w", err) + } + if !capabilities.GetWorkerHeartbeats() { + if m.logger != nil { + m.logger.Debug("Worker heartbeating configured, but server version does not support it.") + } + return nil + } + + namespace := worker.executionParams.Namespace + m.workersMutex.Lock() + defer m.workersMutex.Unlock() + + hw, ok := m.workers[namespace] + // If this is the first worker on the namespace, start a new shared namespace worker. + if !ok { + hw = &sharedNamespaceWorker{ + client: m.client, + namespace: namespace, + interval: m.interval, + callbacks: make(map[string]func() *workerpb.WorkerHeartbeat), + stopC: make(chan struct{}), + stoppedC: make(chan struct{}), + logger: m.logger, + } + m.workers[namespace] = hw + if hw.started.Swap(true) { + panic("heartbeat worker already started") + } + go hw.run() + } + + hw.callbacksMutex.Lock() + hw.callbacks[worker.workerInstanceKey] = worker.heartbeatCallback + hw.callbacksMutex.Unlock() + + return nil +} + +// unregisterWorker removes a worker's heartbeat callback. If no callbacks remain for the namespace, +// the shared heartbeat worker is stopped. +func (m *heartbeatManager) unregisterWorker(worker *AggregatedWorker) { + m.workersMutex.Lock() + defer m.workersMutex.Unlock() + + namespace := worker.executionParams.Namespace + hw, ok := m.workers[namespace] + if !ok { + return + } + + hw.callbacksMutex.Lock() + delete(hw.callbacks, worker.workerInstanceKey) + remaining := len(hw.callbacks) + hw.callbacksMutex.Unlock() + + if remaining == 0 { + hw.stop() + delete(m.workers, namespace) + } +} + +// sharedNamespaceWorker handles heartbeating for all workers in a specific namespace for a specific client. 
+type sharedNamespaceWorker struct { + client *WorkflowClient + namespace string + interval time.Duration + logger log.Logger + + // callbacksMutex guards callbacks. When held together with heartbeatManager.workersMutex, workersMutex must be acquired first. + callbacksMutex sync.RWMutex + callbacks map[string]func() *workerpb.WorkerHeartbeat // workerInstanceKey -> callback + + stopC chan struct{} + stoppedC chan struct{} + started atomic.Bool +} + +func (hw *sharedNamespaceWorker) run() { + defer close(hw.stoppedC) + + ticker := time.NewTicker(hw.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := hw.sendHeartbeats(); err != nil { + hw.logger.Warn("Stopping heartbeat worker", "error", err) + return + } + case <-hw.stopC: + return + } + } +} + +func (hw *sharedNamespaceWorker) sendHeartbeats() error { + hw.callbacksMutex.RLock() + callbacks := make([]func() *workerpb.WorkerHeartbeat, 0, len(hw.callbacks)) + for _, cb := range hw.callbacks { + callbacks = append(callbacks, cb) + } + hw.callbacksMutex.RUnlock() + + if len(callbacks) == 0 { + return nil + } + + heartbeats := make([]*workerpb.WorkerHeartbeat, 0, len(callbacks)) + for _, cb := range callbacks { + hb := cb() + heartbeats = append(heartbeats, hb) + } + + _, err := hw.client.recordWorkerHeartbeat(context.Background(), &workflowservice.RecordWorkerHeartbeatRequest{ + Namespace: hw.namespace, + WorkerHeartbeat: heartbeats, + }) + + if err != nil { + if status.Code(err) == codes.Unimplemented { + // Server doesn't support heartbeats; return error to stop the worker. + return fmt.Errorf("server does not support worker heartbeats: %w", err) + } + // For other errors, log and continue heartbeating + hw.logger.Warn("Failed to send heartbeat", "error", err) + } + return nil +} + +func (hw *sharedNamespaceWorker) stop() { + if !hw.started.CompareAndSwap(true, false) { + return + } + + close(hw.stopC) + <-hw.stoppedC +} + +// pollTimeTracker tracks the last successful poll time for each poller type. +type pollTimeTracker struct { + times sync.Map // pollerType (string) -> time.Time (stored as int64 nanos) +} + +func (p *pollTimeTracker) recordPollSuccess(pollerType string) { + p.times.Store(pollerType, time.Now().UnixNano()) +} + +func (p *pollTimeTracker) getLastPollTime(pollerType string) time.Time { + if v, ok := p.times.Load(pollerType); ok { + return time.Unix(0, v.(int64)) + } + return time.Time{} +} diff --git a/internal/internal_worker_heartbeat_metrics.go b/internal/internal_worker_heartbeat_metrics.go new file mode 100644 index 000000000..f829d5503 --- /dev/null +++ b/internal/internal_worker_heartbeat_metrics.go @@ -0,0 +1,323 @@ +package internal + +import ( + "sync" + "sync/atomic" + "time" + + workerpb "go.temporal.io/api/worker/v1" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.temporal.io/sdk/internal/common/metrics" +) + +// Metrics we capture for heartbeat reporting. +var ( + capturedCounters = map[string]struct{}{ + metrics.StickyCacheHit: {}, + metrics.StickyCacheMiss: {}, + metrics.WorkflowTaskExecutionFailureCounter: {}, + metrics.ActivityExecutionFailedCounter: {}, + metrics.LocalActivityExecutionFailedCounter: {}, + metrics.NexusTaskExecutionFailedCounter: {}, + } + + // Timer recordings are counted (not their latencies) to track tasks processed. 
+ capturedTimers = map[string]struct{}{ + metrics.WorkflowTaskExecutionLatency: {}, + metrics.ActivityExecutionLatency: {}, + metrics.LocalActivityExecutionLatency: {}, + metrics.NexusTaskExecutionLatency: {}, + } +) + +// heartbeatMetricsHandler wraps a metrics handler and captures specific metrics +// in memory for worker heartbeats. +type heartbeatMetricsHandler struct { + underlying metrics.Handler + workerType string + pollerType string + + // Keys are metric names, or "metricName:workerType" / "metricName:pollerType" for typed metrics. + metrics *sync.Map +} + +// newHeartbeatMetricsHandler creates a new handler that captures specific metrics +// for worker heartbeats while passing all metrics to the underlying handler. +func newHeartbeatMetricsHandler(underlying metrics.Handler) *heartbeatMetricsHandler { + return &heartbeatMetricsHandler{ + underlying: underlying, + metrics: &sync.Map{}, + } +} + +// forWorker creates a new handler that captures metrics specific to a worker type, for worker heartbeating. +// This should be called explicitly before calling WithTags on the returned handler. +func (h *heartbeatMetricsHandler) forWorker(workerType string) metrics.Handler { + cpy := *h + cpy.workerType = workerType + return &cpy +} + +// forPoller creates a new handler that captures metrics specific to a poller type, for worker heartbeating. +// This should be called explicitly before calling WithTags on the returned handler. +func (h *heartbeatMetricsHandler) forPoller(pollerType string) metrics.Handler { + cpy := *h + cpy.pollerType = pollerType + return &cpy +} + +func (h *heartbeatMetricsHandler) WithTags(tags map[string]string) metrics.Handler { + cpy := *h + cpy.underlying = h.underlying.WithTags(tags) + return &cpy +} + +func (h *heartbeatMetricsHandler) Counter(name string) metrics.Counter { + underlying := h.underlying.Counter(name) + if _, ok := capturedCounters[name]; ok { + return &capturingCounter{ + underlying: underlying, + value: h.getOrCreate(name), + } + } + return underlying +} + +func (h *heartbeatMetricsHandler) Gauge(name string) metrics.Gauge { + underlying := h.underlying.Gauge(name) + + switch name { + case metrics.StickyCacheSize: + return &capturingGauge{ + underlying: underlying, + value: h.getOrCreate(name), + } + case metrics.WorkerTaskSlotsAvailable, metrics.WorkerTaskSlotsUsed: + if h.workerType != "" { + return &capturingGauge{ + underlying: underlying, + value: h.getOrCreate(name + ":" + h.workerType), + } + } + case metrics.NumPoller: + if h.pollerType != "" { + return &capturingGauge{ + underlying: underlying, + value: h.getOrCreate(name + ":" + h.pollerType), + } + } + } + + return underlying +} + +func (h *heartbeatMetricsHandler) Timer(name string) metrics.Timer { + underlying := h.underlying.Timer(name) + if _, ok := capturedTimers[name]; ok { + return &capturingTimer{ + underlying: underlying, + counter: h.getOrCreate(name), + } + } + return underlying +} + +func (h *heartbeatMetricsHandler) getOrCreate(key string) *atomic.Int64 { + if v, ok := h.metrics.Load(key); ok { + return v.(*atomic.Int64) + } + v := new(atomic.Int64) + actual, _ := h.metrics.LoadOrStore(key, v) + return actual.(*atomic.Int64) +} + +func (h *heartbeatMetricsHandler) get(key string) int64 { + if v, ok := h.metrics.Load(key); ok { + return v.(*atomic.Int64).Load() + } + return 0 +} + +// populateHeartbeatOptions contains extra information needed to populate heartbeats. 
+type populateHeartbeatOptions struct { + workflowSlotSupplierKind string + activitySlotSupplierKind string + localActivitySlotSupplierKind string + nexusSlotSupplierKind string + + workflowPollerBehavior PollerBehavior + activityPollerBehavior PollerBehavior + nexusPollerBehavior PollerBehavior + + // For delta calculations between heartbeats (mutated by PopulateHeartbeat). + prevWorkflowProcessed *int64 + prevWorkflowFailed *int64 + prevActivityProcessed *int64 + prevActivityFailed *int64 + prevLocalActivityProcessed *int64 + prevLocalActivityFailed *int64 + prevNexusProcessed *int64 + prevNexusFailed *int64 + + pollTimeTracker *pollTimeTracker +} + +// PopulateHeartbeat fills in the metrics-related fields of the WorkerHeartbeat proto. +func (h *heartbeatMetricsHandler) PopulateHeartbeat(hb *workerpb.WorkerHeartbeat, opts *populateHeartbeatOptions) { + hb.TotalStickyCacheHit = int32(h.get(metrics.StickyCacheHit)) + hb.TotalStickyCacheMiss = int32(h.get(metrics.StickyCacheMiss)) + hb.CurrentStickyCacheSize = int32(h.get(metrics.StickyCacheSize)) + + if opts.workflowSlotSupplierKind != "" { + hb.WorkflowTaskSlotsInfo = buildSlotsInfo( + opts.workflowSlotSupplierKind, + int32(h.get(metrics.WorkerTaskSlotsAvailable+":"+"WorkflowWorker")), + int32(h.get(metrics.WorkerTaskSlotsUsed+":"+"WorkflowWorker")), + h.get(metrics.WorkflowTaskExecutionLatency), + h.get(metrics.WorkflowTaskExecutionFailureCounter), + opts.prevWorkflowProcessed, + opts.prevWorkflowFailed, + ) + } + + if opts.activitySlotSupplierKind != "" { + hb.ActivityTaskSlotsInfo = buildSlotsInfo( + opts.activitySlotSupplierKind, + int32(h.get(metrics.WorkerTaskSlotsAvailable+":"+"ActivityWorker")), + int32(h.get(metrics.WorkerTaskSlotsUsed+":"+"ActivityWorker")), + h.get(metrics.ActivityExecutionLatency), + h.get(metrics.ActivityExecutionFailedCounter), + opts.prevActivityProcessed, + opts.prevActivityFailed, + ) + } + + if opts.localActivitySlotSupplierKind != "" { + hb.LocalActivitySlotsInfo = buildSlotsInfo( + opts.localActivitySlotSupplierKind, + int32(h.get(metrics.WorkerTaskSlotsAvailable+":"+"LocalActivityWorker")), + int32(h.get(metrics.WorkerTaskSlotsUsed+":"+"LocalActivityWorker")), + h.get(metrics.LocalActivityExecutionLatency), + h.get(metrics.LocalActivityExecutionFailedCounter), + opts.prevLocalActivityProcessed, + opts.prevLocalActivityFailed, + ) + } + + if opts.nexusSlotSupplierKind != "" { + hb.NexusTaskSlotsInfo = buildSlotsInfo( + opts.nexusSlotSupplierKind, + int32(h.get(metrics.WorkerTaskSlotsAvailable+":"+"NexusWorker")), + int32(h.get(metrics.WorkerTaskSlotsUsed+":"+"NexusWorker")), + h.get(metrics.NexusTaskExecutionLatency), + h.get(metrics.NexusTaskExecutionFailedCounter), + opts.prevNexusProcessed, + opts.prevNexusFailed, + ) + } + + hb.WorkflowPollerInfo = buildPollerInfo( + int32(h.get(metrics.NumPoller+":"+metrics.PollerTypeWorkflowTask)), + opts.pollTimeTracker.getLastPollTime(metrics.PollerTypeWorkflowTask), + opts.workflowPollerBehavior, + ) + hb.WorkflowStickyPollerInfo = buildPollerInfo( + int32(h.get(metrics.NumPoller+":"+metrics.PollerTypeWorkflowStickyTask)), + opts.pollTimeTracker.getLastPollTime(metrics.PollerTypeWorkflowStickyTask), + opts.workflowPollerBehavior, + ) + hb.ActivityPollerInfo = buildPollerInfo( + int32(h.get(metrics.NumPoller+":"+metrics.PollerTypeActivityTask)), + opts.pollTimeTracker.getLastPollTime(metrics.PollerTypeActivityTask), + opts.activityPollerBehavior, + ) + hb.NexusPollerInfo = buildPollerInfo( + int32(h.get(metrics.NumPoller+":"+metrics.PollerTypeNexusTask)), + 
opts.pollTimeTracker.getLastPollTime(metrics.PollerTypeNexusTask), + opts.nexusPollerBehavior, + ) +} + +func (h *heartbeatMetricsHandler) Unwrap() metrics.Handler { + return h.underlying +} + +func buildSlotsInfo( + supplierKind string, + slotsAvailable int32, + slotsUsed int32, + totalProcessed int64, + totalFailed int64, + prevProcessed *int64, + prevFailed *int64, +) *workerpb.WorkerSlotsInfo { + intervalProcessed := totalProcessed - *prevProcessed + intervalFailed := totalFailed - *prevFailed + + *prevProcessed = totalProcessed + *prevFailed = totalFailed + + return &workerpb.WorkerSlotsInfo{ + CurrentAvailableSlots: slotsAvailable, + CurrentUsedSlots: slotsUsed, + SlotSupplierKind: supplierKind, + TotalProcessedTasks: int32(totalProcessed), + TotalFailedTasks: int32(totalFailed), + LastIntervalProcessedTasks: int32(intervalProcessed), + LastIntervalFailureTasks: int32(intervalFailed), + } +} + +func buildPollerInfo(currentPollers int32, lastSuccessfulPollTime time.Time, pollerBehavior PollerBehavior) *workerpb.WorkerPollerInfo { + var isAutoscaling bool + switch pollerBehavior.(type) { + case *pollerBehaviorAutoscaling: + isAutoscaling = true + } + var pollTime *timestamppb.Timestamp + if !lastSuccessfulPollTime.IsZero() { + pollTime = timestamppb.New(lastSuccessfulPollTime) + } + + return &workerpb.WorkerPollerInfo{ + CurrentPollers: currentPollers, + LastSuccessfulPollTime: pollTime, + IsAutoscaling: isAutoscaling, + } +} + +// capturingCounter wraps a counter and captures its value in memory. +type capturingCounter struct { + underlying metrics.Counter + value *atomic.Int64 +} + +func (c *capturingCounter) Inc(delta int64) { + c.underlying.Inc(delta) + if delta > 0 { + c.value.Add(delta) + } +} + +// capturingGauge wraps a gauge and captures its value in memory. +type capturingGauge struct { + underlying metrics.Gauge + value *atomic.Int64 +} + +func (g *capturingGauge) Update(f float64) { + g.underlying.Update(f) + g.value.Store(int64(f)) +} + +// capturingTimer wraps a timer and increments a counter each time Record is called. +type capturingTimer struct { + underlying metrics.Timer + counter *atomic.Int64 +} + +func (t *capturingTimer) Record(d time.Duration) { + t.underlying.Record(d) + t.counter.Add(1) +} diff --git a/internal/internal_worker_interfaces_test.go b/internal/internal_worker_interfaces_test.go index 87361c7b6..58a491382 100644 --- a/internal/internal_worker_interfaces_test.go +++ b/internal/internal_worker_interfaces_test.go @@ -212,7 +212,6 @@ func (s *InterfacesTestSuite) TestInterface() { s.service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.PollWorkflowTaskQueueResponse{}, nil).AnyTimes() s.service.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() s.service.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.StartWorkflowExecutionResponse{}, nil).AnyTimes() - s.service.EXPECT().ShutdownWorker(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.ShutdownWorkerResponse{}, nil).Times(1) registry := newRegistry() // Launch worker. 
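The heartbeat metrics plumbing above is a tee: captured counters and timers still reach the configured metrics handler, but a copy accumulates in in-memory atomics that PopulateHeartbeat later snapshots, and buildSlotsInfo converts the running totals into per-interval deltas via its prev* pointers. A stripped-down, self-contained sketch of that capture-and-delta pattern (the sink and variable names are illustrative, not SDK API):

package main

import (
	"fmt"
	"sync/atomic"
)

// capturingCounter forwards increments to an underlying sink while also
// accumulating a running total in memory, as the wrapper above does.
type capturingCounter struct {
	sink  func(int64) // stands in for the wrapped metrics.Counter
	total *atomic.Int64
}

func (c *capturingCounter) Inc(delta int64) {
	c.sink(delta)
	if delta > 0 { // non-positive deltas pass through but are not captured
		c.total.Add(delta)
	}
}

// intervalDelta turns a running total into a per-heartbeat delta by
// mutating prev, the same trick buildSlotsInfo plays with its
// prevProcessed/prevFailed pointers.
func intervalDelta(total *atomic.Int64, prev *int64) int64 {
	t := total.Load()
	d := t - *prev
	*prev = t
	return d
}

func main() {
	var failed atomic.Int64
	c := &capturingCounter{sink: func(int64) { /* real metrics emission happens here */ }, total: &failed}

	var prevFailed int64
	c.Inc(3)
	fmt.Println(intervalDelta(&failed, &prevFailed)) // 3: total in the first interval
	c.Inc(2)
	fmt.Println(intervalDelta(&failed, &prevFailed)) // 2: delta since the last heartbeat
}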
diff --git a/internal/internal_workers_test.go b/internal/internal_workers_test.go index acb799a52..8c6509d62 100644 --- a/internal/internal_workers_test.go +++ b/internal/internal_workers_test.go @@ -73,10 +73,8 @@ func TestWorkersTestSuite(t *testing.T) { } func (s *WorkersTestSuite) TestWorkflowWorker() { - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.PollWorkflowTaskQueueResponse{}, nil).AnyTimes() s.service.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - s.service.EXPECT().ShutdownWorker(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.ShutdownWorkerResponse{}, nil).Times(1) ctx, cancel := context.WithCancelCause(context.Background()) executionParameters := workerExecutionParameters{ @@ -157,7 +155,6 @@ func (s *WorkersTestSuite) TestWorkflowWorkerSlotSupplier() { unblockPollCh := make(chan struct{}) pollRespondedCh := make(chan struct{}) - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). Do(func(ctx, in interface{}, opts ...interface{}) { <-unblockPollCh @@ -168,7 +165,6 @@ func (s *WorkersTestSuite) TestWorkflowWorkerSlotSupplier() { pollRespondedCh <- struct{}{} }). Return(nil, nil).AnyTimes() - s.service.EXPECT().ShutdownWorker(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.ShutdownWorkerResponse{}, nil).Times(1) ctx, cancel := context.WithCancelCause(context.Background()) wfCss := &CountingSlotSupplier{} @@ -223,7 +219,6 @@ func (s *WorkersTestSuite) TestActivityWorkerSlotSupplier() { unblockPollCh := make(chan struct{}) pollRespondedCh := make(chan struct{}) - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). Do(func(ctx, in interface{}, opts ...interface{}) { <-unblockPollCh @@ -303,7 +298,6 @@ func (s *WorkersTestSuite) TestErrorProneSlotSupplier() { unblockPollCh := make(chan struct{}) pollRespondedCh := make(chan struct{}) - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). 
Do(func(ctx, in interface{}, opts ...interface{}) { <-unblockPollCh @@ -348,7 +342,6 @@ func (s *WorkersTestSuite) TestErrorProneSlotSupplier() { } func (s *WorkersTestSuite) TestActivityWorker() { - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.PollActivityTaskQueueResponse{}, nil).AnyTimes() s.service.EXPECT().RespondActivityTaskCompleted(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.RespondActivityTaskCompletedResponse{}, nil).AnyTimes() @@ -394,7 +387,6 @@ func (s *WorkersTestSuite) TestActivityWorkerStop() { WorkflowNamespace: "namespace", } - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Return(pats, nil).AnyTimes() s.service.EXPECT().RespondActivityTaskCompleted(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.RespondActivityTaskCompletedResponse{}, nil).AnyTimes() @@ -442,9 +434,7 @@ func (s *WorkersTestSuite) TestActivityWorkerStop() { } func (s *WorkersTestSuite) TestPollWorkflowTaskQueue_InternalServiceError() { - s.service.EXPECT().DescribeNamespace(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) s.service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.PollWorkflowTaskQueueResponse{}, serviceerror.NewInternal("")).AnyTimes() - s.service.EXPECT().ShutdownWorker(gomock.Any(), gomock.Any(), gomock.Any()).Return(&workflowservice.ShutdownWorkerResponse{}, nil).Times(1) executionParameters := workerExecutionParameters{ Namespace: DefaultNamespace, diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go index b7bf35126..32646224e 100644 --- a/internal/internal_workflow_client.go +++ b/internal/internal_workflow_client.go @@ -21,6 +21,7 @@ import ( commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" + namespacepb "go.temporal.io/api/namespace/v1" "go.temporal.io/api/operatorservice/v1" querypb "go.temporal.io/api/query/v1" "go.temporal.io/api/sdk/v1" @@ -59,24 +60,30 @@ const ( type ( // WorkflowClient is the client for starting a workflow execution. 
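// Besides starting workflows, it also carries the client-level worker
// heartbeat state added below (workerHeartbeatInterval, workerGroupingKey,
// and the shared heartbeatManager).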
WorkflowClient struct { - workflowService workflowservice.WorkflowServiceClient - conn *grpc.ClientConn - namespace string - registry *registry - logger log.Logger - metricsHandler metrics.Handler - identity string - dataConverter converter.DataConverter - failureConverter converter.FailureConverter - contextPropagators []ContextPropagator - workerPlugins []WorkerPlugin - workerInterceptors []WorkerInterceptor - interceptor ClientOutboundInterceptor - excludeInternalFromRetry *atomic.Bool - capabilities *workflowservice.GetSystemInfoResponse_Capabilities - capabilitiesLock sync.RWMutex - eagerDispatcher *eagerWorkflowDispatcher - getSystemInfoTimeout time.Duration + workflowService workflowservice.WorkflowServiceClient + conn *grpc.ClientConn + namespace string + registry *registry + logger log.Logger + metricsHandler metrics.Handler + identity string + dataConverter converter.DataConverter + failureConverter converter.FailureConverter + contextPropagators []ContextPropagator + workerPlugins []WorkerPlugin + workerInterceptors []WorkerInterceptor + clientPluginNames []string + interceptor ClientOutboundInterceptor + excludeInternalFromRetry *atomic.Bool + capabilities *workflowservice.GetSystemInfoResponse_Capabilities + capabilitiesLock sync.RWMutex + namespaceCapabilities *namespacepb.NamespaceInfo_Capabilities + namespaceCapabilitiesLock sync.RWMutex + eagerDispatcher *eagerWorkflowDispatcher + getSystemInfoTimeout time.Duration + workerHeartbeatInterval time.Duration + workerGroupingKey string + heartbeatManager *heartbeatManager // The pointer value is shared across multiple clients. If non-nil, only // access/mutate atomically. @@ -1391,6 +1398,36 @@ func (wc *WorkflowClient) loadCapabilities(ctx context.Context) (*workflowservic return capabilities, nil } +// Get namespace capabilities, lazily fetching from server if not already obtained. 
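+// The fast path takes only the read lock; on a miss, DescribeNamespace is
+// called and the result is published under the write lock. Concurrent callers
+// may race to fetch, but each stores an equivalent value, and an Unimplemented
+// error from older servers degrades to empty capabilities.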
+func (wc *WorkflowClient) loadNamespaceCapabilities(metricsHandler metrics.Handler) (*namespacepb.NamespaceInfo_Capabilities, error) { + ctx := contextWithNewHeader(context.Background()) + wc.namespaceCapabilitiesLock.RLock() + capabilities := wc.namespaceCapabilities + wc.namespaceCapabilitiesLock.RUnlock() + if capabilities != nil { + return capabilities, nil + } + + grpcCtx, cancel := newGRPCContext(ctx, grpcMetricsHandler(metricsHandler), defaultGrpcRetryParameters(ctx)) + defer cancel() + resp, err := wc.workflowService.DescribeNamespace(grpcCtx, &workflowservice.DescribeNamespaceRequest{Namespace: wc.namespace}) + var unimplemented *serviceerror.Unimplemented + if err != nil && !errors.As(err, &unimplemented) { + return nil, fmt.Errorf("failed reaching server: %w", err) + } + if resp != nil { + capabilities = resp.GetNamespaceInfo().GetCapabilities() + } + if capabilities == nil { + capabilities = &namespacepb.NamespaceInfo_Capabilities{} + } + + wc.namespaceCapabilitiesLock.Lock() + wc.namespaceCapabilities = capabilities + wc.namespaceCapabilitiesLock.Unlock() + return capabilities, nil +} + func (wc *WorkflowClient) ensureInitialized(ctx context.Context) error { // Just loading the capabilities is enough _, err := wc.loadCapabilities(ctx) @@ -1418,6 +1455,21 @@ func (wc *WorkflowClient) WorkerDeploymentClient() WorkerDeploymentClient { } } +func (wc *WorkflowClient) recordWorkerHeartbeat(ctx context.Context, request *workflowservice.RecordWorkerHeartbeatRequest) (*workflowservice.RecordWorkerHeartbeatResponse, error) { + if err := wc.ensureInitialized(ctx); err != nil { + return nil, err + } + + grpcCtx, cancel := newGRPCContext(ctx, defaultGrpcRetryParameters(ctx)) + defer cancel() + resp, err := wc.workflowService.RecordWorkerHeartbeat(grpcCtx, request) + if err != nil { + return nil, err + } + + return resp, nil +} + // Close client and clean up underlying resources. func (wc *WorkflowClient) Close() { // If there's a set of unclosed clients, we have to decrement it and then diff --git a/contrib/resourcetuner/resourcetuner.go b/internal/resource_tuner.go similarity index 57% rename from contrib/resourcetuner/resourcetuner.go rename to internal/resource_tuner.go index 1145d3cd7..34f199b38 100644 --- a/contrib/resourcetuner/resourcetuner.go +++ b/internal/resource_tuner.go @@ -1,18 +1,13 @@ -package resourcetuner +package internal import ( "context" "errors" - "runtime" "sync" "time" - "github.com/shirou/gopsutil/v4/cpu" - "github.com/shirou/gopsutil/v4/mem" - "go.einride.tech/pid" - "go.temporal.io/sdk/client" + "go.temporal.io/sdk/internal/common/metrics" "go.temporal.io/sdk/log" - "go.temporal.io/sdk/worker" ) // Metric names emitted by the resource-based tuner @@ -21,6 +16,37 @@ const ( resourceSlotsMemUsage = "temporal_resource_slots_mem_usage" ) +// SysInfoProvider implementations provide information about system resources. +// +// Exposed as: [go.temporal.io/sdk/worker.SysInfoProvider] +type SysInfoProvider interface { + // MemoryUsage returns the current system memory usage as a fraction of total memory between + // 0 and 1. + MemoryUsage(infoContext *SysInfoContext) (float64, error) + // CpuUsage returns the current system CPU usage as a fraction of total CPU usage between 0 + // and 1. + CpuUsage(infoContext *SysInfoContext) (float64, error) +} + +// SysInfoContext provides context for SysInfoProvider calls. 
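+// Implementations can use the Logger to surface non-fatal problems, such as
+// the cgroup stat failures the gopsutil-based supplier warns about on Linux.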
+// +// Exposed as: [go.temporal.io/sdk/worker.SysInfoContext] +type SysInfoContext struct { + Logger log.Logger +} + +// HasSysInfoProvider is an optional interface that SlotSupplier implementations can implement +// to expose their SysInfoProvider. This allows the SDK to access system metrics (CPU/memory) +// for features like worker heartbeats without coupling to specific SlotSupplier implementations. +// +// Exposed as: [go.temporal.io/sdk/worker.HasSysInfoProvider] +type HasSysInfoProvider interface { + SysInfoProvider() SysInfoProvider +} + +// ResourceBasedTunerOptions configures a resource-based tuner. +// +// Exposed as: [go.temporal.io/sdk/worker.ResourceBasedTunerOptions] type ResourceBasedTunerOptions struct { // TargetMem is the target overall system memory usage as value 0 and 1 that the controller will // attempt to maintain. Must be set nonzero. @@ -28,6 +54,9 @@ type ResourceBasedTunerOptions struct { // TargetCpu is the target overall system CPU usage as value 0 and 1 that the controller will // attempt to maintain. Must be set nonzero. TargetCpu float64 + // InfoSupplier provides CPU and memory usage information. This is required. + // Use contrib/sysinfo.SysInfoProvider() for a gopsutil-based implementation. + InfoSupplier SysInfoProvider // Passed to ResourceBasedSlotSupplierOptions.RampThrottle for activities. // If not set, the default value is 50ms. ActivityRampThrottle time.Duration @@ -38,11 +67,22 @@ type ResourceBasedTunerOptions struct { // NewResourceBasedTuner creates a WorkerTuner that dynamically adjusts the number of slots based // on system resources. Specify the target CPU and memory usage as a value between 0 and 1. -func NewResourceBasedTuner(opts ResourceBasedTunerOptions) (worker.WorkerTuner, error) { - options := DefaultResourceControllerOptions() - options.MemTargetPercent = opts.TargetMem - options.CpuTargetPercent = opts.TargetCpu - controller := NewResourceController(options) +// +// InfoSupplier is required - use contrib/sysinfo.SysInfoProvider() for a gopsutil-based +// implementation, or provide your own. +// +// Exposed as: [go.temporal.io/sdk/worker.NewResourceBasedTuner] +func NewResourceBasedTuner(opts ResourceBasedTunerOptions) (WorkerTuner, error) { + if opts.InfoSupplier == nil { + return nil, errors.New("InfoSupplier is required for resource-based tuning") + } + + controllerOpts := DefaultResourceControllerOptions() + controllerOpts.MemTargetPercent = opts.TargetMem + controllerOpts.CpuTargetPercent = opts.TargetCpu + controllerOpts.InfoSupplier = opts.InfoSupplier + controller := NewResourceController(controllerOpts) + wfSS := &ResourceBasedSlotSupplier{controller: controller, options: DefaultWorkflowResourceBasedSlotSupplierOptions()} if opts.WorkflowRampThrottle != 0 { @@ -62,20 +102,19 @@ func NewResourceBasedTuner(opts ResourceBasedTunerOptions) (worker.WorkerTuner, options: DefaultWorkflowResourceBasedSlotSupplierOptions()} sessSS := &ResourceBasedSlotSupplier{controller: controller, options: DefaultActivityResourceBasedSlotSupplierOptions()} - compositeTuner, err := worker.NewCompositeTuner(worker.CompositeTunerOptions{ + + return NewCompositeTuner(CompositeTunerOptions{ WorkflowSlotSupplier: wfSS, ActivitySlotSupplier: actSS, LocalActivitySlotSupplier: laSS, NexusSlotSupplier: nexusSS, SessionActivitySlotSupplier: sessSS, }) - if err != nil { - return nil, err - } - return compositeTuner, nil } // ResourceBasedSlotSupplierOptions configures a particular ResourceBasedSlotSupplier. 
+// +// Exposed as: [go.temporal.io/sdk/worker.ResourceBasedSlotSupplierOptions] type ResourceBasedSlotSupplierOptions struct { // MinSlots is minimum number of slots that will be issued without any resource checks. MinSlots int @@ -87,6 +126,9 @@ type ResourceBasedSlotSupplierOptions struct { RampThrottle time.Duration } +// DefaultWorkflowResourceBasedSlotSupplierOptions returns default options for workflow slot suppliers. +// +// Exposed as: [go.temporal.io/sdk/worker.DefaultWorkflowResourceBasedSlotSupplierOptions] func DefaultWorkflowResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplierOptions { return ResourceBasedSlotSupplierOptions{ MinSlots: 5, @@ -94,6 +136,10 @@ func DefaultWorkflowResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplier RampThrottle: 0 * time.Second, } } + +// DefaultActivityResourceBasedSlotSupplierOptions returns default options for activity slot suppliers. +// +// Exposed as: [go.temporal.io/sdk/worker.DefaultActivityResourceBasedSlotSupplierOptions] func DefaultActivityResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplierOptions { return ResourceBasedSlotSupplierOptions{ MinSlots: 1, @@ -102,8 +148,9 @@ func DefaultActivityResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplier } } -// ResourceBasedSlotSupplier is a worker.SlotSupplier that issues slots based on system resource -// usage. +// ResourceBasedSlotSupplier is a SlotSupplier that issues slots based on system resource usage. +// +// Exposed as: [go.temporal.io/sdk/worker.ResourceBasedSlotSupplier] type ResourceBasedSlotSupplier struct { controller *ResourceController options ResourceBasedSlotSupplierOptions @@ -115,12 +162,14 @@ type ResourceBasedSlotSupplier struct { // NewResourceBasedSlotSupplier creates a ResourceBasedSlotSupplier given the provided // ResourceController and ResourceBasedSlotSupplierOptions. All ResourceBasedSlotSupplier instances // must use the same ResourceController. +// +// Exposed as: [go.temporal.io/sdk/worker.NewResourceBasedSlotSupplier] func NewResourceBasedSlotSupplier( controller *ResourceController, options ResourceBasedSlotSupplierOptions, ) (*ResourceBasedSlotSupplier, error) { if options.MinSlots < 0 || options.MaxSlots < 0 || options.MinSlots > options.MaxSlots { - return nil, errors.New("MinSlots and Max slots must be non-negative and MinSlots must be less than or equal to MaxSlots") + return nil, errors.New("MinSlots and MaxSlots must be non-negative and MinSlots must be less than or equal to MaxSlots") } if options.RampThrottle < 0 { return nil, errors.New("RampThrottle must be non-negative") @@ -128,16 +177,14 @@ func NewResourceBasedSlotSupplier( return &ResourceBasedSlotSupplier{controller: controller, options: options}, nil } -func (r *ResourceBasedSlotSupplier) ReserveSlot(ctx context.Context, info worker.SlotReservationInfo) (*worker.SlotPermit, error) { +func (r *ResourceBasedSlotSupplier) ReserveSlot(ctx context.Context, info SlotReservationInfo) (*SlotPermit, error) { for { if info.NumIssuedSlots() < r.options.MinSlots { - return &worker.SlotPermit{}, nil + return &SlotPermit{}, nil } if r.options.RampThrottle > 0 { r.lastIssuedMu.Lock() mustWaitFor := r.options.RampThrottle - time.Since(r.lastSlotIssuedAt) - // Deal with last issued possibly being unset, or, on windows seemingly sometimes can - // have zero values if called rapidly enough. 
if mustWaitFor > 0 { select { case <-time.After(mustWaitFor): @@ -157,7 +204,7 @@ func (r *ResourceBasedSlotSupplier) ReserveSlot(ctx context.Context, info worker } } -func (r *ResourceBasedSlotSupplier) TryReserveSlot(info worker.SlotReservationInfo) *worker.SlotPermit { +func (r *ResourceBasedSlotSupplier) TryReserveSlot(info SlotReservationInfo) *SlotPermit { r.lastIssuedMu.Lock() defer r.lastIssuedMu.Unlock() @@ -171,35 +218,28 @@ func (r *ResourceBasedSlotSupplier) TryReserveSlot(info worker.SlotReservationIn } if decision { r.lastSlotIssuedAt = time.Now() - return &worker.SlotPermit{} + return &SlotPermit{} } } return nil } -func (r *ResourceBasedSlotSupplier) MarkSlotUsed(worker.SlotMarkUsedInfo) {} -func (r *ResourceBasedSlotSupplier) ReleaseSlot(worker.SlotReleaseInfo) {} +func (r *ResourceBasedSlotSupplier) MarkSlotUsed(SlotMarkUsedInfo) {} +func (r *ResourceBasedSlotSupplier) ReleaseSlot(SlotReleaseInfo) {} func (r *ResourceBasedSlotSupplier) MaxSlots() int { return 0 } -// SystemInfoSupplier implementations provide information about system resources. -type SystemInfoSupplier interface { - // GetMemoryUsage returns the current system memory usage as a fraction of total memory between - // 0 and 1. - GetMemoryUsage(infoContext *SystemInfoContext) (float64, error) - // GetCpuUsage returns the current system CPU usage as a fraction of total CPU usage between 0 - // and 1. - GetCpuUsage(infoContext *SystemInfoContext) (float64, error) -} - -type SystemInfoContext struct { - Logger log.Logger +// GetSysInfoProvider returns the SysInfoProvider used by this slot supplier's controller. +func (r *ResourceBasedSlotSupplier) SysInfoProvider() SysInfoProvider { + return r.controller.infoSupplier } // ResourceControllerOptions contains configurable parameters for a ResourceController. // It is recommended to use DefaultResourceControllerOptions to create a ResourceControllerOptions // and only modify the mem/cpu target percent fields. +// +// Exposed as: [go.temporal.io/sdk/worker.ResourceControllerOptions] type ResourceControllerOptions struct { // MemTargetPercent is the target overall system memory usage as value 0 and 1 that the // controller will attempt to maintain. @@ -207,9 +247,8 @@ type ResourceControllerOptions struct { // CpuTargetPercent is the target overall system CPU usage as value 0 and 1 that the controller // will attempt to maintain. CpuTargetPercent float64 - // SystemInfoSupplier is the supplier that the controller will use to get system resources. - // Leave this nil to use the default implementation. - InfoSupplier SystemInfoSupplier + // InfoSupplier is the supplier that the controller will use to get system resources. + InfoSupplier SysInfoProvider MemOutputThreshold float64 CpuOutputThreshold float64 @@ -223,6 +262,8 @@ type ResourceControllerOptions struct { } // DefaultResourceControllerOptions returns a ResourceControllerOptions with default values. +// +// Exposed as: [go.temporal.io/sdk/worker.DefaultResourceControllerOptions] func DefaultResourceControllerOptions() ResourceControllerOptions { return ResourceControllerOptions{ MemTargetPercent: 0.8, @@ -238,60 +279,73 @@ func DefaultResourceControllerOptions() ResourceControllerOptions { } } -// A ResourceController is used by ResourceBasedSlotSupplier to make decisions about whether slots +// pidController implements a simple PID controller for resource-based tuning. 
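+// It replaces the previous external go.einride.tech/pid dependency with a
+// small inline implementation.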
+// This is the standard PID formula: output = Kp*error + Ki*integral + Kd*derivative +type pidController struct { + pGain, iGain, dGain float64 + + prevError float64 + integral float64 + controlSignal float64 +} + +func (c *pidController) update(reference, actual float64, dt time.Duration) { + err := reference - actual + c.integral += err * dt.Seconds() + derivative := (err - c.prevError) / dt.Seconds() + c.controlSignal = c.pGain*err + c.iGain*c.integral + c.dGain*derivative + c.prevError = err +} + +// ResourceController is used by ResourceBasedSlotSupplier to make decisions about whether slots // should be issued based on system resource usage. +// +// Exposed as: [go.temporal.io/sdk/worker.ResourceController] type ResourceController struct { options ResourceControllerOptions mu sync.Mutex - infoSupplier SystemInfoSupplier + infoSupplier SysInfoProvider lastRefresh time.Time - memPid *pid.Controller - cpuPid *pid.Controller + memPid *pidController + cpuPid *pidController } // NewResourceController creates a new ResourceController with the provided options. // WARNING: It is important that you do not create multiple ResourceController instances. Since // the controller looks at overall system resources, multiple instances with different configs can // only conflict with one another. +// +// Exposed as: [go.temporal.io/sdk/worker.NewResourceController] func NewResourceController(options ResourceControllerOptions) *ResourceController { - var infoSupplier SystemInfoSupplier if options.InfoSupplier == nil { - infoSupplier = &psUtilSystemInfoSupplier{ - cGroupInfo: newCGroupInfo(), - } - } else { - infoSupplier = options.InfoSupplier + panic("InfoSupplier is required - use contrib/sysinfo.SysInfoProvider() or provide your own") } return &ResourceController{ options: options, - infoSupplier: infoSupplier, - memPid: &pid.Controller{ - Config: pid.ControllerConfig{ - ProportionalGain: options.MemPGain, - IntegralGain: options.MemIGain, - DerivativeGain: options.MemDGain, - }, + infoSupplier: options.InfoSupplier, + memPid: &pidController{ + pGain: options.MemPGain, + iGain: options.MemIGain, + dGain: options.MemDGain, }, - cpuPid: &pid.Controller{ - Config: pid.ControllerConfig{ - ProportionalGain: options.CpuPGain, - IntegralGain: options.CpuIGain, - DerivativeGain: options.CpuDGain, - }, + cpuPid: &pidController{ + pGain: options.CpuPGain, + iGain: options.CpuIGain, + dGain: options.CpuDGain, }, } } -func (rc *ResourceController) pidDecision(logger log.Logger, metricsHandler client.MetricsHandler) (bool, error) { +func (rc *ResourceController) pidDecision(logger log.Logger, metricsHandler metrics.Handler) (bool, error) { rc.mu.Lock() defer rc.mu.Unlock() - memUsage, err := rc.infoSupplier.GetMemoryUsage(&SystemInfoContext{Logger: logger}) + memUsage, err := rc.infoSupplier.MemoryUsage(&SysInfoContext{Logger: logger}) if err != nil { return false, err } - cpuUsage, err := rc.infoSupplier.GetCpuUsage(&SystemInfoContext{Logger: logger}) + cpuUsage, err := rc.infoSupplier.CpuUsage(&SysInfoContext{Logger: logger}) if err != nil { return false, err } @@ -306,110 +360,18 @@ func (rc *ResourceController) pidDecision(logger log.Logger, metricsHandler clie if elapsedTime <= 0 { elapsedTime = 1 * time.Millisecond } - rc.memPid.Update(pid.ControllerInput{ - ReferenceSignal: rc.options.MemTargetPercent, - ActualSignal: memUsage, - SamplingInterval: elapsedTime, - }) - rc.cpuPid.Update(pid.ControllerInput{ - ReferenceSignal: rc.options.CpuTargetPercent, - ActualSignal: cpuUsage, - SamplingInterval: 
elapsedTime, - }) + rc.memPid.update(rc.options.MemTargetPercent, memUsage, elapsedTime) + rc.cpuPid.update(rc.options.CpuTargetPercent, cpuUsage, elapsedTime) rc.lastRefresh = time.Now() - return rc.memPid.State.ControlSignal > rc.options.MemOutputThreshold && - rc.cpuPid.State.ControlSignal > rc.options.CpuOutputThreshold, nil + return rc.memPid.controlSignal > rc.options.MemOutputThreshold && + rc.cpuPid.controlSignal > rc.options.CpuOutputThreshold, nil } -func (rc *ResourceController) publishResourceMetrics(metricsHandler client.MetricsHandler, memUsage, cpuUsage float64) { +func (rc *ResourceController) publishResourceMetrics(metricsHandler metrics.Handler, memUsage, cpuUsage float64) { if metricsHandler == nil { return } metricsHandler.Gauge(resourceSlotsMemUsage).Update(memUsage * 100) metricsHandler.Gauge(resourceSlotsCPUUsage).Update(cpuUsage * 100) } - -type psUtilSystemInfoSupplier struct { - logger log.Logger - mu sync.Mutex - lastRefresh time.Time - - lastMemStat *mem.VirtualMemoryStat - lastCpuUsage float64 - - stopTryingToGetCGroupInfo bool - cGroupInfo cGroupInfo -} - -type cGroupInfo interface { - // Update requests an update of the cgroup stats. This is a no-op if not in a cgroup. Returns - // true if cgroup stats should continue to be updated, false if not in a cgroup or the returned - // error is considered unrecoverable. - Update() (bool, error) - // GetLastMemUsage returns last known memory usage as a fraction of the cgroup limit. 0 if not - // in a cgroup or limit is not set. - GetLastMemUsage() float64 - // GetLastCPUUsage returns last known CPU usage as a fraction of the cgroup limit. 0 if not in a - // cgroup or limit is not set. - GetLastCPUUsage() float64 -} - -func (p *psUtilSystemInfoSupplier) GetMemoryUsage(infoContext *SystemInfoContext) (float64, error) { - if err := p.maybeRefresh(infoContext); err != nil { - return 0, err - } - lastCGroupMem := p.cGroupInfo.GetLastMemUsage() - if lastCGroupMem != 0 { - return lastCGroupMem, nil - } - return p.lastMemStat.UsedPercent / 100, nil -} - -func (p *psUtilSystemInfoSupplier) GetCpuUsage(infoContext *SystemInfoContext) (float64, error) { - if err := p.maybeRefresh(infoContext); err != nil { - return 0, err - } - - lastCGroupCPU := p.cGroupInfo.GetLastCPUUsage() - if lastCGroupCPU != 0 { - return lastCGroupCPU, nil - } - return p.lastCpuUsage / 100, nil -} - -func (p *psUtilSystemInfoSupplier) maybeRefresh(infoContext *SystemInfoContext) error { - if time.Since(p.lastRefresh) < 100*time.Millisecond { - return nil - } - p.mu.Lock() - defer p.mu.Unlock() - // Double check refresh is still needed - if time.Since(p.lastRefresh) < 100*time.Millisecond { - return nil - } - ctx, cancelFn := context.WithTimeout(context.Background(), 1*time.Second) - defer cancelFn() - memStat, err := mem.VirtualMemoryWithContext(ctx) - if err != nil { - return err - } - cpuUsage, err := cpu.PercentWithContext(ctx, 0, false) - if err != nil { - return err - } - - p.lastMemStat = memStat - p.lastCpuUsage = cpuUsage[0] - - if runtime.GOOS == "linux" && !p.stopTryingToGetCGroupInfo { - continueUpdates, err := p.cGroupInfo.Update() - if err != nil { - infoContext.Logger.Warn("Failed to get cgroup stats", "error", err) - } - p.stopTryingToGetCGroupInfo = !continueUpdates - } - - p.lastRefresh = time.Now() - return nil -} diff --git a/contrib/resourcetuner/resourcetuner_test.go b/internal/resource_tuner_test.go similarity index 86% rename from contrib/resourcetuner/resourcetuner_test.go rename to internal/resource_tuner_test.go index 
603d53da8..d97554401 100644 --- a/contrib/resourcetuner/resourcetuner_test.go +++ b/internal/resource_tuner_test.go @@ -1,12 +1,10 @@ -package resourcetuner +package internal import ( - "testing" - "github.com/stretchr/testify/assert" - "go.temporal.io/sdk/client" "go.temporal.io/sdk/internal/common/metrics" "go.temporal.io/sdk/internal/log" + "testing" ) type FakeSystemInfoSupplier struct { @@ -14,17 +12,17 @@ type FakeSystemInfoSupplier struct { cpuUse float64 } -func (f FakeSystemInfoSupplier) GetMemoryUsage(_ *SystemInfoContext) (float64, error) { +func (f FakeSystemInfoSupplier) MemoryUsage(_ *SysInfoContext) (float64, error) { return f.memUse, nil } -func (f FakeSystemInfoSupplier) GetCpuUsage(_ *SystemInfoContext) (float64, error) { +func (f FakeSystemInfoSupplier) CpuUsage(_ *SysInfoContext) (float64, error) { return f.cpuUse, nil } func TestPidDecisions(t *testing.T) { logger := &log.NoopLogger{} - metricsHandler := client.MetricsNopHandler + metricsHandler := metrics.NopHandler fakeSupplier := &FakeSystemInfoSupplier{memUse: 0.5, cpuUse: 0.5} rcOpts := DefaultResourceControllerOptions() rcOpts.MemTargetPercent = 0.8 @@ -37,8 +35,8 @@ func TestPidDecisions(t *testing.T) { assert.NoError(t, err) assert.True(t, decision) - assert.InDelta(t, 1.5, rc.memPid.State.ControlSignal, 0.001) - assert.InDelta(t, 2.0, rc.cpuPid.State.ControlSignal, 0.001) + assert.InDelta(t, 1.5, rc.memPid.controlSignal, 0.001) + assert.InDelta(t, 2.0, rc.cpuPid.controlSignal, 0.001) } fakeSupplier.memUse = 0.8 diff --git a/internal/tuning.go b/internal/tuning.go index 8d146800f..f6597680f 100644 --- a/internal/tuning.go +++ b/internal/tuning.go @@ -130,6 +130,17 @@ type SlotSupplier interface { MaxSlots() int } +func getSlotSupplierKind(s SlotSupplier) string { + switch s.(type) { + case *FixedSizeSlotSupplier: + return "Fixed" + case *ResourceBasedSlotSupplier: + return "ResourceBased" + default: + return "Custom" + } +} + // CompositeTuner allows you to build a tuner from multiple slot suppliers. 
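// One supplier is provided per task type: workflow, activity, local activity,
// nexus, and session activity (see CompositeTunerOptions).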
type CompositeTuner struct { workflowSlotSupplier SlotSupplier @@ -478,6 +489,7 @@ func (t *trackingSlotSupplier) ReleaseSlot(permit *SlotPermit, reason SlotReleas if permit.extraReleaseCallback != nil { permit.extraReleaseCallback() } + t.publishMetrics(usedSlots) } @@ -487,3 +499,7 @@ func (t *trackingSlotSupplier) publishMetrics(usedSlots int) { } t.taskSlotsUsedGauge.Update(float64(usedSlots)) } + +func (t *trackingSlotSupplier) GetSlotSupplierKind() string { + return getSlotSupplierKind(t.inner) +} diff --git a/test/go.mod b/test/go.mod index 9e158cd68..c3454ecf4 100644 --- a/test/go.mod +++ b/test/go.mod @@ -17,9 +17,9 @@ require ( go.opentelemetry.io/otel/trace v1.28.0 go.temporal.io/api v1.62.1 go.temporal.io/sdk v1.29.1 + go.temporal.io/sdk/contrib/sysinfo v0.0.0-00010101000000-000000000000 go.temporal.io/sdk/contrib/opentelemetry v0.0.0-00010101000000-000000000000 go.temporal.io/sdk/contrib/opentracing v0.0.0-00010101000000-000000000000 - go.temporal.io/sdk/contrib/resourcetuner v0.0.0-00010101000000-000000000000 go.temporal.io/sdk/contrib/tally v0.0.0-00010101000000-000000000000 go.uber.org/goleak v1.1.12 google.golang.org/grpc v1.67.1 @@ -46,13 +46,12 @@ require ( github.com/robfig/cron v1.2.0 // indirect github.com/shirou/gopsutil/v4 v4.24.8 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twmb/murmur3 v1.1.5 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect - go.einride.tech/pid v0.1.3 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect go.uber.org/atomic v1.9.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect @@ -68,8 +67,8 @@ require ( replace ( go.temporal.io/sdk => ../ + go.temporal.io/sdk/contrib/sysinfo => ../contrib/sysinfo go.temporal.io/sdk/contrib/opentelemetry => ../contrib/opentelemetry go.temporal.io/sdk/contrib/opentracing => ../contrib/opentracing - go.temporal.io/sdk/contrib/resourcetuner => ../contrib/resourcetuner go.temporal.io/sdk/contrib/tally => ../contrib/tally ) diff --git a/test/go.sum b/test/go.sum index 6962bedd8..fa7890939 100644 --- a/test/go.sum +++ b/test/go.sum @@ -135,8 +135,8 @@ github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnj github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= @@ -160,8 +160,6 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yusufpapurcu/wmi v1.2.4 
h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= -go.einride.tech/pid v0.1.3 h1:yWAKSmD2Z10jxd4gYFhOjbBNqXeIQwAtnCO/XKCT7sQ= -go.einride.tech/pid v0.1.3/go.mod h1:33JSUbKrH/4v8DZf/0K8IC8Enjd92wB2birp+bCYQso= go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= @@ -288,5 +286,3 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= -gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/test/integration_test.go b/test/integration_test.go index 9ce4821fc..763e41b0c 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "go.temporal.io/sdk/contrib/sysinfo" "math" "math/rand" "os" @@ -37,7 +38,6 @@ import ( "go.temporal.io/sdk/contrib/opentelemetry" sdkopentracing "go.temporal.io/sdk/contrib/opentracing" - "go.temporal.io/sdk/contrib/resourcetuner" "go.temporal.io/sdk/converter" "go.temporal.io/sdk/test" @@ -177,10 +177,11 @@ func (ts *IntegrationTestSuite) SetupTest() { NewKeysPropagator([]string{testContextKey1}), NewKeysPropagator([]string{testContextKey2}), }, - MetricsHandler: metricsHandler, - TrafficController: trafficController, - Interceptors: clientInterceptors, - ConnectionOptions: client.ConnectionOptions{TLS: ts.config.TLS}, + MetricsHandler: metricsHandler, + TrafficController: trafficController, + Interceptors: clientInterceptors, + ConnectionOptions: client.ConnectionOptions{TLS: ts.config.TLS}, + WorkerHeartbeatInterval: -1, }) ts.NoError(err) @@ -242,9 +243,10 @@ func (ts *IntegrationTestSuite) SetupTest() { options.MaxConcurrentLocalActivityExecutionSize = 2 } if strings.Contains(ts.T().Name(), "ResourceBasedSlotSupplier") { - tuner, err := resourcetuner.NewResourceBasedTuner(resourcetuner.ResourceBasedTunerOptions{ - TargetMem: 0.9, - TargetCpu: 0.9, + tuner, err := worker.NewResourceBasedTuner(worker.ResourceBasedTunerOptions{ + TargetMem: 0.9, + TargetCpu: 0.9, + InfoSupplier: sysinfo.SysInfoProvider(), }) ts.NoError(err) options.Tuner = tuner @@ -3572,7 +3574,6 @@ func (ts *IntegrationTestSuite) TestSlotSupplierWFTFailMetrics() { run, err := ts.client.ExecuteWorkflow(ctx, wfOptions, waitsToProceedWorkflow) ts.NoError(err) ts.NotNil(run) - ts.NoError(err) <-actStarted // The workflow task will fail once and then pass diff --git a/test/test_utils_test.go b/test/test_utils_test.go index 644b5514b..ece9100e1 100644 --- a/test/test_utils_test.go +++ b/test/test_utils_test.go @@ -238,6 +238,7 @@ func (ts *ConfigAndClientSuiteBase) newClient() (client.Client, error) { TLS: ts.config.TLS, GetSystemInfoTimeout: ctxTimeout, }, + WorkerHeartbeatInterval: -1, }) } diff --git a/test/worker_heartbeat_test.go b/test/worker_heartbeat_test.go new file mode 100644 index 000000000..d920bd596 --- /dev/null +++ b/test/worker_heartbeat_test.go @@ -0,0 +1,915 @@ +package test_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + 
+ "github.com/google/uuid" + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.temporal.io/api/enums/v1" + workerpb "go.temporal.io/api/worker/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/contrib/sysinfo" + "go.temporal.io/sdk/internal" + ilog "go.temporal.io/sdk/internal/log" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/worker" + "go.temporal.io/sdk/workflow" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type WorkerHeartbeatTestSuite struct { + *require.Assertions + suite.Suite + ConfigAndClientSuiteBase + worker worker.Worker +} + +func TestWorkerHeartbeatSuite(t *testing.T) { + suite.Run(t, new(WorkerHeartbeatTestSuite)) +} + +func (ts *WorkerHeartbeatTestSuite) SetupSuite() { + ts.Assertions = require.New(ts.T()) + ts.NoError(ts.InitConfigAndNamespace()) +} + +func (ts *WorkerHeartbeatTestSuite) TearDownSuite() { + ts.Assertions = require.New(ts.T()) +} + +func (ts *WorkerHeartbeatTestSuite) SetupTest() { + var err error + // Create a client with heartbeating enabled + ts.client, err = client.Dial(client.Options{ + HostPort: ts.config.ServiceAddr, + Namespace: ts.config.Namespace, + Logger: ilog.NewDefaultLogger(), + WorkerHeartbeatInterval: 1 * time.Second, + ConnectionOptions: client.ConnectionOptions{TLS: ts.config.TLS}, + Identity: "WorkerHeartbeatTest", + }) + ts.NoError(err) + + ts.taskQueueName = taskQueuePrefix + "-" + ts.T().Name() +} + +func (ts *WorkerHeartbeatTestSuite) TearDownTest() { + if ts.worker != nil { + ts.worker.Stop() + ts.worker = nil + } + if ts.client != nil { + ts.client.Close() + ts.client = nil + } +} + +// assertRecentTimestamp asserts the timestamp is within maxAge of now +func (ts *WorkerHeartbeatTestSuite) assertRecentTimestamp(timestamp *timestamppb.Timestamp, maxAge time.Duration, name string) { + ts.NotNil(timestamp, "%s should not be nil", name) + ts.False(timestamp.AsTime().IsZero(), "%s should not be zero", name) + ts.WithinDuration(time.Now(), timestamp.AsTime(), maxAge, "%s should be recent", name) +} + +// TestWorkerHeartbeat verifies that worker heartbeats are sent to the server +// and can be queried via ListWorkers and DescribeWorker APIs +func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatBasic() { + workerStartTime := time.Now() + + worker.SetStickyWorkflowCacheSize(5) + ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{ + MaxConcurrentWorkflowTaskExecutionSize: 5, + MaxConcurrentActivityExecutionSize: 5, + DisableEagerActivities: true, + }) + ts.worker.RegisterWorkflow(workflowWithBlockingActivity) + ts.worker.RegisterActivity(blockingActivity) + // Register a nexus service so the nexus worker is created and slot info is populated + nexusService := nexus.NewService("test-heartbeat") + ts.NoError(nexusService.Register(noopNexusOp)) + ts.worker.RegisterNexusService(nexusService) + ts.Nil(ts.worker.Start()) + + ctx := context.Background() + wfOptions := ts.startWorkflowOptions("test-worker-heartbeat") + + run, err := ts.client.ExecuteWorkflow(ctx, wfOptions, workflowWithBlockingActivity) + ts.NoError(err) + ts.NotNil(run) + // Wait for activity to start + select { + case <-blockingActivityStarted: + ts.T().Log("Activity started") + case <-time.After(5 * time.Second): + ts.Fail("Timeout waiting for activity to start") + } + + var workerInfo *workerpb.WorkerHeartbeat + // Wait for heartbeat to capture the in-flight activity + ts.Eventually(func() 
bool { + workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName) + return workerInfo != nil && workerInfo.ActivityTaskSlotsInfo != nil && + workerInfo.ActivityTaskSlotsInfo.CurrentUsedSlots >= 1 + }, 5*time.Second, 200*time.Millisecond, "Should find worker with activity slot used") + + ts.Equal(enums.WORKER_STATUS_RUNNING, workerInfo.Status) + + workflowTaskSlots := workerInfo.WorkflowTaskSlotsInfo + ts.Equal(int32(1), workflowTaskSlots.TotalProcessedTasks) + ts.Equal(int32(5), workflowTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), workflowTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", workflowTaskSlots.SlotSupplierKind) + activityTaskSlots := workerInfo.ActivityTaskSlotsInfo + ts.Equal(int32(0), activityTaskSlots.TotalProcessedTasks) + ts.Equal(int32(4), activityTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(1), activityTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", activityTaskSlots.SlotSupplierKind) + nexusTaskSlots := workerInfo.NexusTaskSlotsInfo + ts.NotNil(nexusTaskSlots) + ts.Equal(int32(0), nexusTaskSlots.TotalProcessedTasks) + ts.Equal(int32(1000), nexusTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), nexusTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", nexusTaskSlots.SlotSupplierKind) + localActivityTaskSlots := workerInfo.LocalActivitySlotsInfo + ts.Equal(int32(0), localActivityTaskSlots.TotalProcessedTasks) + ts.Equal(int32(1000), localActivityTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), localActivityTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", localActivityTaskSlots.SlotSupplierKind) + + workflowPollerInfo := workerInfo.WorkflowPollerInfo + ts.NotEqual(int32(0), workflowPollerInfo.CurrentPollers) + nexusPollerInfo := workerInfo.NexusPollerInfo + ts.NotEqual(int32(0), nexusPollerInfo.CurrentPollers) + activityPollerInfo := workerInfo.ActivityPollerInfo + ts.NotEqual(int32(0), activityPollerInfo.CurrentPollers) + + if ts.config.maxWorkflowCacheSize > 0 { + stickyPollerInfo := workerInfo.WorkflowStickyPollerInfo + ts.NotEqual(int32(0), stickyPollerInfo.CurrentPollers) + ts.GreaterOrEqual(workerInfo.CurrentStickyCacheSize, int32(1)) + } + + ts.assertRecentTimestamp(workerInfo.StartTime, 10*time.Second, "StartTime") + ts.assertRecentTimestamp(workerInfo.HeartbeatTime, 5*time.Second, "HeartbeatTime") + + ts.WithinDuration(workerStartTime, workerInfo.StartTime.AsTime(), 5*time.Second, + "StartTime should match worker creation time") + + ts.True(workerInfo.HeartbeatTime.AsTime().After(workerInfo.StartTime.AsTime()) || + workerInfo.HeartbeatTime.AsTime().Equal(workerInfo.StartTime.AsTime()), + "HeartbeatTime should be >= StartTime") + + ts.NotNil(workerInfo.ElapsedSinceLastHeartbeat) + elapsed := workerInfo.ElapsedSinceLastHeartbeat.AsDuration() + ts.True(elapsed <= 5*time.Second, + "ElapsedSinceLastHeartbeat should be <= 5s (got %v)", elapsed) + + ts.assertRecentTimestamp(workerInfo.WorkflowPollerInfo.LastSuccessfulPollTime, 5*time.Second, + "WorkflowPollerInfo.LastSuccessfulPollTime") + ts.assertRecentTimestamp(workerInfo.ActivityPollerInfo.LastSuccessfulPollTime, 5*time.Second, + "ActivityPollerInfo.LastSuccessfulPollTime") + + // Store values to compare after shutdown + firstStartTime := workerInfo.StartTime.AsTime() + firstHeartbeatTime := workerInfo.HeartbeatTime.AsTime() + + // Signal activity to complete + blockingActivityComplete <- struct{}{} + + ts.NoError(run.Get(ctx, nil)) + ts.worker.Stop() + + workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName) + ts.NotNil(workerInfo, "Should find worker in ListWorkers/DescribeWorker") + + // After shutdown 
checks + ts.Equal("WorkerHeartbeatTest", workerInfo.WorkerIdentity) + hostInfo := workerInfo.HostInfo + ts.NotEqual("", hostInfo.HostName) + ts.NotEqual("", hostInfo.ProcessId) + ts.NotEqual("", hostInfo.WorkerGroupingKey) + + ts.GreaterOrEqual(hostInfo.CurrentHostCpuUsage, float32(0.0)) + ts.GreaterOrEqual(hostInfo.CurrentHostMemUsage, float32(0.0)) + + ts.Equal(ts.taskQueueName, workerInfo.TaskQueue) + ts.Equal(internal.SDKName, workerInfo.SdkName) + ts.Equal(internal.SDKVersion, workerInfo.SdkVersion) + ts.Equal(enums.WORKER_STATUS_SHUTTING_DOWN, workerInfo.Status) + + // Timestamp validations - second heartbeat check (after shutdown) + // StartTime should be unchanged + ts.Equal(firstStartTime, workerInfo.StartTime.AsTime()) + + // HeartbeatTime should have advanced + ts.True(workerInfo.HeartbeatTime.AsTime().After(firstHeartbeatTime)) + + workflowTaskSlots = workerInfo.WorkflowTaskSlotsInfo + ts.Equal(int32(2), workflowTaskSlots.TotalProcessedTasks) + ts.Equal("Fixed", workflowTaskSlots.SlotSupplierKind) + activityTaskSlots = workerInfo.ActivityTaskSlotsInfo + ts.Equal(int32(1), activityTaskSlots.TotalProcessedTasks) + ts.Equal(int32(5), activityTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), activityTaskSlots.CurrentUsedSlots) + ts.Equal(int32(1), activityTaskSlots.LastIntervalProcessedTasks) + ts.Equal("Fixed", activityTaskSlots.SlotSupplierKind) + nexusTaskSlots = workerInfo.NexusTaskSlotsInfo + ts.NotNil(nexusTaskSlots) + ts.Equal(int32(0), nexusTaskSlots.TotalProcessedTasks) + ts.Equal(int32(1000), nexusTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), nexusTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", nexusTaskSlots.SlotSupplierKind) + localActivityTaskSlots = workerInfo.LocalActivitySlotsInfo + ts.Equal(int32(0), localActivityTaskSlots.TotalProcessedTasks) + ts.Equal(int32(1000), localActivityTaskSlots.CurrentAvailableSlots) + ts.Equal(int32(0), localActivityTaskSlots.CurrentUsedSlots) + ts.Equal("Fixed", localActivityTaskSlots.SlotSupplierKind) + + workflowPollerInfo = workerInfo.WorkflowPollerInfo + ts.NotEqual(int32(0), workflowPollerInfo.CurrentPollers) + ts.False(workflowPollerInfo.IsAutoscaling) + ts.assertRecentTimestamp(workflowPollerInfo.LastSuccessfulPollTime, 10*time.Second, + "WorkflowPollerInfo.LastSuccessfulPollTime after shutdown") + + if ts.config.maxWorkflowCacheSize > 0 { + stickyPollerInfo := workerInfo.WorkflowStickyPollerInfo + ts.NotEqual(int32(0), stickyPollerInfo.CurrentPollers) + ts.False(stickyPollerInfo.IsAutoscaling) + ts.assertRecentTimestamp(stickyPollerInfo.LastSuccessfulPollTime, 10*time.Second, + "WorkflowStickyPollerInfo.LastSuccessfulPollTime after shutdown") + } + + nexusPollerInfo = workerInfo.NexusPollerInfo + ts.NotEqual(int32(0), nexusPollerInfo.CurrentPollers) + ts.False(nexusPollerInfo.IsAutoscaling) + // Nexus poller has no successful polls since we didn't execute any nexus operations + + activityPollerInfo = workerInfo.ActivityPollerInfo + ts.NotEqual(int32(0), activityPollerInfo.CurrentPollers) + ts.False(activityPollerInfo.IsAutoscaling) + ts.assertRecentTimestamp(activityPollerInfo.LastSuccessfulPollTime, 10*time.Second, + "ActivityPollerInfo.LastSuccessfulPollTime after shutdown") + + if ts.config.maxWorkflowCacheSize > 0 { + ts.GreaterOrEqual(workerInfo.TotalStickyCacheHit, int32(1)) + } +} + +// TestWorkerHeartbeatDeploymentVersion verifies that deployment version info is +// included in heartbeats when versioning is enabled. 
This test doesn't run workflows +// since versioned workers require additional server-side setup for task routing. +func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatDeploymentVersion() { + ctx := context.Background() + + taskQueue := ts.taskQueueName + "-deployment-version" + + w := worker.New(ts.client, taskQueue, worker.Options{ + DeploymentOptions: worker.DeploymentOptions{ + UseVersioning: true, + Version: worker.WorkerDeploymentVersion{ + DeploymentName: "test-deployment", + BuildID: "test_build_id", + }, + DefaultVersioningBehavior: internal.VersioningBehaviorAutoUpgrade, + }, + }) + w.RegisterWorkflow(simpleWorkflow) + ts.NoError(w.Start()) + defer w.Stop() + + // Wait for heartbeat to be sent + var workerInfo *workerpb.WorkerHeartbeat + ts.Eventually(func() bool { + workerInfo = ts.getWorkerInfo(ctx, taskQueue) + return workerInfo != nil && workerInfo.DeploymentVersion != nil + }, 5*time.Second, 200*time.Millisecond, "Should find worker with deployment version") + + ts.NotNil(workerInfo.DeploymentVersion) + ts.Equal("test_build_id", workerInfo.DeploymentVersion.BuildId) + ts.Equal("test-deployment", workerInfo.DeploymentVersion.DeploymentName) +} + +// TestWorkerHeartbeatDisabled verifies that when heartbeating is disabled, +// workers should not appear in ListWorkers +func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatDisabled() { + ctx := context.Background() + + // Create a separate client with heartbeating disabled + clientNoHeartbeat, err := client.Dial(client.Options{ + HostPort: ts.config.ServiceAddr, + Namespace: ts.config.Namespace, + Logger: ilog.NewDefaultLogger(), + WorkerHeartbeatInterval: -1, + ConnectionOptions: client.ConnectionOptions{TLS: ts.config.TLS}, + }) + ts.NoError(err) + defer clientNoHeartbeat.Close() + + taskQueueNoHeartbeat := taskQueuePrefix + "-no-heartbeat-" + ts.T().Name() + + // Create and start worker with no heartbeating + workerNoHeartbeat := worker.New(clientNoHeartbeat, taskQueueNoHeartbeat, worker.Options{}) + workerNoHeartbeat.RegisterWorkflow(simpleWorkflow) + ts.NoError(workerNoHeartbeat.Start()) + defer workerNoHeartbeat.Stop() + + // Wait a bit + time.Sleep(2 * time.Second) + + // Get the internal client + internalClient := clientNoHeartbeat.(internal.Client) + workflowClient := internalClient.(*internal.WorkflowClient) + + // List workers - should not find the worker without heartbeating + listResp, err := workflowClient.WorkflowService().ListWorkers(ctx, &workflowservice.ListWorkersRequest{ + Namespace: ts.config.Namespace, + Query: fmt.Sprintf(`TaskQueue="%s"`, taskQueueNoHeartbeat), + PageSize: 10, + }) + + ts.NoError(err, "ListWorkers failed") + foundWorker := false + for _, workerInfo := range listResp.WorkersInfo { + if workerInfo.WorkerHeartbeat.TaskQueue == taskQueueNoHeartbeat { + foundWorker = true + break + } + } + ts.False(foundWorker, "Should not find worker without heartbeating enabled") +} + +// Get worker info from the server +func (ts *WorkerHeartbeatTestSuite) getWorkerInfo(ctx context.Context, taskQueue string) *workerpb.WorkerHeartbeat { + // Get the internal client to access the workflow service directly + internalClient := ts.client.(internal.Client) + workflowClient := internalClient.(*internal.WorkflowClient) + + // List workers in this namespace + listResp, err := workflowClient.WorkflowService().ListWorkers(ctx, &workflowservice.ListWorkersRequest{ + Namespace: ts.config.Namespace, + Query: fmt.Sprintf(`TaskQueue="%s"`, taskQueue), + PageSize: 10, + }) + if err != nil { + ts.T().Logf("ListWorkers 
failed: %v (may not be implemented on this server)", err) + return nil + } + + if len(listResp.WorkersInfo) == 0 { + return nil + } + + // Find our worker in the list + var workerInstanceKey string + for _, workerInfo := range listResp.WorkersInfo { + if workerInfo.WorkerHeartbeat.TaskQueue == taskQueue { + workerInstanceKey = workerInfo.WorkerHeartbeat.WorkerInstanceKey + break + } + } + + if workerInstanceKey == "" { + ts.T().Logf("Could not find worker with task queue %s in list", taskQueue) + return nil + } + + // Describe the specific worker + describeResp, err := workflowClient.WorkflowService().DescribeWorker(ctx, &workflowservice.DescribeWorkerRequest{ + Namespace: ts.config.Namespace, + WorkerInstanceKey: workerInstanceKey, + }) + if err != nil { + ts.T().Logf("DescribeWorker failed: %v", err) + return nil + } + + return describeResp.WorkerInfo.WorkerHeartbeat +} + +// Simple workflow for testing +func simpleWorkflow(ctx workflow.Context) (string, error) { + return "hello", nil +} + +// Simple nexus operation for testing - just returns immediately +var noopNexusOp = nexus.NewSyncOperation("noop", func(ctx context.Context, input nexus.NoValue, opts nexus.StartOperationOptions) (nexus.NoValue, error) { + return nil, nil +}) + +var ( + blockingActivityStarted = make(chan struct{}, 10) + blockingActivityComplete = make(chan struct{}, 10) +) + +func blockingActivity(ctx context.Context) (string, error) { + // Signal that activity has started + select { + case blockingActivityStarted <- struct{}{}: + default: + } + + // Wait for signal to complete + select { + case <-blockingActivityComplete: + return "done", nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func workflowWithBlockingActivity(ctx workflow.Context) (string, error) { + ao := workflow.ActivityOptions{ + StartToCloseTimeout: 30 * time.Second, + } + ctx = workflow.WithActivityOptions(ctx, ao) + + var result string + err := workflow.ExecuteActivity(ctx, blockingActivity).Get(ctx, &result) + return result, err +} + +var failingActivityCallCount atomic.Int32 + +func failingActivity(ctx context.Context) error { + failingActivityCallCount.Add(1) + return temporal.NewApplicationError("intentional failure", "TEST_ERROR") +} + +// Workflow that executes a failing activity with limited retries +func workflowWithFailingActivity(ctx workflow.Context) error { + ao := workflow.ActivityOptions{ + StartToCloseTimeout: 10 * time.Second, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 1, + }, + } + ctx = workflow.WithActivityOptions(ctx, ao) + + return workflow.ExecuteActivity(ctx, failingActivity).Get(ctx, nil) +} + +// Workflow that panics to simulate a workflow task failure. The flag controls +// whether it panics, allowing tests to toggle it off so the workflow can +// eventually complete after the server retries the task. 
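+// The flag lives in a package-level atomic.Bool so the test goroutine and the
+// workflow goroutine can flip and read it without a data race.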
+var failingWorkflowShouldFail atomic.Bool + +func failingWorkflow(ctx workflow.Context) (string, error) { + if failingWorkflowShouldFail.Load() { + panic("intentional workflow task failure") + } + return "success", nil +} + +// TestWorkerHeartbeatWithActivityInFlight verifies that activity slots are tracked +// correctly when activities are in flight +func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatWithActivityInFlight() { + ctx := context.Background() + + blockingActivityStarted = make(chan struct{}, 10) + blockingActivityComplete = make(chan struct{}, 10) + + ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{ + MaxConcurrentActivityExecutionSize: 5, + }) + ts.worker.RegisterWorkflow(workflowWithBlockingActivity) + ts.worker.RegisterActivity(blockingActivity) + ts.NoError(ts.worker.Start()) + + workflowOptions := client.StartWorkflowOptions{ + ID: "test-activity-in-flight-" + uuid.NewString(), + TaskQueue: ts.taskQueueName, + } + + run, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, workflowWithBlockingActivity) + ts.NoError(err) + + // Wait for activity to start + select { + case <-blockingActivityStarted: + ts.T().Log("Activity started") + case <-time.After(10 * time.Second): + ts.Fail("Timeout waiting for activity to start") + } + + var workerInfo *workerpb.WorkerHeartbeat + ts.Eventually(func() bool { + workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName) + return workerInfo != nil && workerInfo.ActivityTaskSlotsInfo != nil && + workerInfo.ActivityTaskSlotsInfo.CurrentUsedSlots >= 1 + }, 5*time.Second, 200*time.Millisecond, "Should have at least 1 activity slot used") + + ts.T().Logf("Activity slots used: %d, available: %d", + workerInfo.ActivityTaskSlotsInfo.CurrentUsedSlots, + workerInfo.ActivityTaskSlotsInfo.CurrentAvailableSlots) + ts.GreaterOrEqual(workerInfo.ActivityTaskSlotsInfo.CurrentAvailableSlots, int32(0)) + + blockingActivityComplete <- struct{}{} + + var result string + err = run.Get(ctx, &result) + ts.NoError(err) + ts.Equal("done", result) + + ts.Eventually(func() bool { + workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName) + return workerInfo != nil && workerInfo.ActivityTaskSlotsInfo != nil && + workerInfo.ActivityTaskSlotsInfo.CurrentUsedSlots == 0 + }, 5*time.Second, 200*time.Millisecond, "Activity slot should be released after completion") + + ts.T().Logf("After completion - Activity slots used: %d, available: %d", + workerInfo.ActivityTaskSlotsInfo.CurrentUsedSlots, + workerInfo.ActivityTaskSlotsInfo.CurrentAvailableSlots) + ts.GreaterOrEqual(workerInfo.ActivityTaskSlotsInfo.TotalProcessedTasks, int32(1)) +} + +func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatStickyCacheMiss() { + if ts.config.maxWorkflowCacheSize == 0 { + ts.T().Skip("Sticky cache disabled") + } + ctx := context.Background() + + activityStarted := make(chan struct{}, 1) + activityComplete := make(chan struct{}, 1) + + cacheMissActivity := func(ctx context.Context) (string, error) { + select { + case activityStarted <- struct{}{}: + default: + } + select { + case <-activityComplete: + return "done", nil + case <-ctx.Done(): + return "", ctx.Err() + } + } + + cacheMissWorkflow := func(ctx workflow.Context) (string, error) { + ao := workflow.ActivityOptions{ + StartToCloseTimeout: 30 * time.Second, + } + ctx = workflow.WithActivityOptions(ctx, ao) + var result string + err := workflow.ExecuteActivity(ctx, cacheMissActivity).Get(ctx, &result) + return result, err + } + + ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{ + 
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatStickyCacheMiss() {
+	if ts.config.maxWorkflowCacheSize == 0 {
+		ts.T().Skip("Sticky cache disabled")
+	}
+	ctx := context.Background()
+
+	activityStarted := make(chan struct{}, 1)
+	activityComplete := make(chan struct{}, 1)
+
+	cacheMissActivity := func(ctx context.Context) (string, error) {
+		select {
+		case activityStarted <- struct{}{}:
+		default:
+		}
+		select {
+		case <-activityComplete:
+			return "done", nil
+		case <-ctx.Done():
+			return "", ctx.Err()
+		}
+	}
+
+	cacheMissWorkflow := func(ctx workflow.Context) (string, error) {
+		ao := workflow.ActivityOptions{
+			StartToCloseTimeout: 30 * time.Second,
+		}
+		ctx = workflow.WithActivityOptions(ctx, ao)
+		var result string
+		err := workflow.ExecuteActivity(ctx, cacheMissActivity).Get(ctx, &result)
+		return result, err
+	}
+
+	ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{
+		DisableEagerActivities: true,
+	})
+	ts.worker.RegisterWorkflow(cacheMissWorkflow)
+	ts.worker.RegisterActivity(cacheMissActivity)
+	ts.NoError(ts.worker.Start())
+
+	wfOptions := client.StartWorkflowOptions{
+		ID:        "test-sticky-miss-" + uuid.NewString(),
+		TaskQueue: ts.taskQueueName,
+	}
+	run, err := ts.client.ExecuteWorkflow(ctx, wfOptions, cacheMissWorkflow)
+	ts.NoError(err)
+
+	select {
+	case <-activityStarted:
+		ts.T().Log("Activity started")
+	case <-time.After(10 * time.Second):
+		ts.Fail("Timeout waiting for activity to start")
+	}
+
+	// Purge the cache so the workflow's sticky task triggers a cache miss on resume
+	worker.PurgeStickyWorkflowCache()
+
+	activityComplete <- struct{}{}
+	var result string
+	ts.NoError(run.Get(ctx, &result))
+	ts.Equal("done", result)
+
+	// Wait for heartbeat to capture sticky cache miss
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.TotalStickyCacheMiss >= 1
+	}, 5*time.Second, 200*time.Millisecond, "Should have at least 1 sticky cache miss")
+}
+
+// TestWorkerHeartbeatMultipleWorkers verifies that multiple workers can heartbeat
+// simultaneously and be tracked separately
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatMultipleWorkers() {
+	ctx := context.Background()
+
+	taskQueue1 := ts.taskQueueName + "-worker1"
+	taskQueue2 := ts.taskQueueName + "-worker2"
+
+	worker1 := worker.New(ts.client, taskQueue1, worker.Options{})
+	worker1.RegisterWorkflow(simpleWorkflow)
+	ts.NoError(worker1.Start())
+	defer worker1.Stop()
+
+	worker2 := worker.New(ts.client, taskQueue2, worker.Options{})
+	worker2.RegisterWorkflow(simpleWorkflow)
+	ts.NoError(worker2.Start())
+	defer worker2.Stop()
+
+	// Run workflow on each worker
+	var wg sync.WaitGroup
+	for i, tq := range []string{taskQueue1, taskQueue2} {
+		wg.Add(1)
+		go func(idx int, taskQueue string) {
+			defer wg.Done()
+			workflowOptions := client.StartWorkflowOptions{
+				ID:        fmt.Sprintf("test-multi-worker-%d-%s", idx, uuid.NewString()),
+				TaskQueue: taskQueue,
+			}
+			run, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, simpleWorkflow)
+			ts.NoError(err)
+			err = run.Get(ctx, nil)
+			ts.NoError(err)
+		}(i, tq)
+	}
+	wg.Wait()
+
+	// Verify both workers are tracked
+	var workerInfo1, workerInfo2 *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo1 = ts.getWorkerInfo(ctx, taskQueue1)
+		workerInfo2 = ts.getWorkerInfo(ctx, taskQueue2)
+		return workerInfo1 != nil && workerInfo2 != nil
+	}, 5*time.Second, 200*time.Millisecond, "Should find both workers")
+
+	ts.NotEqual(workerInfo1.WorkerInstanceKey, workerInfo2.WorkerInstanceKey,
+		"Different workers should have different instance keys")
+
+	ts.Equal(taskQueue1, workerInfo1.TaskQueue)
+	ts.Equal(taskQueue2, workerInfo2.TaskQueue)
+
+	ts.Equal(workerInfo1.HostInfo.WorkerGroupingKey, workerInfo2.HostInfo.WorkerGroupingKey,
+		"Workers should share the same client and worker grouping key")
+}
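
(Editor's note: every test in this suite polls the server the same way — ts.Eventually wrapped around getWorkerInfo with a predicate on the returned heartbeat. A generic helper would remove that repetition; a sketch under that assumption, reusing only names already defined in this file — waitForHeartbeat itself is a hypothetical name, not part of this change:)

// waitForHeartbeat polls getWorkerInfo until pred accepts a non-nil heartbeat
// or the usual 5s/200ms Eventually budget runs out, returning the last
// heartbeat seen.
func (ts *WorkerHeartbeatTestSuite) waitForHeartbeat(
	ctx context.Context,
	taskQueue, msg string,
	pred func(*workerpb.WorkerHeartbeat) bool,
) *workerpb.WorkerHeartbeat {
	var hb *workerpb.WorkerHeartbeat
	ts.Eventually(func() bool {
		hb = ts.getWorkerInfo(ctx, taskQueue)
		return hb != nil && pred(hb)
	}, 5*time.Second, 200*time.Millisecond, msg)
	return hb
}

(Each Eventually block in the tests below would then collapse to a single waitForHeartbeat call.)
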
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatFailureMetrics() {
+	ctx := context.Background()
+
+	// Reset call counter
+	failingActivityCallCount.Store(0)
+
+	ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{})
+	ts.worker.RegisterWorkflow(workflowWithFailingActivity)
+	ts.worker.RegisterActivity(failingActivity)
+	ts.NoError(ts.worker.Start())
+
+	// Run workflow that will have a failing activity
+	workflowOptions := client.StartWorkflowOptions{
+		ID:        "test-failure-metrics-" + uuid.NewString(),
+		TaskQueue: ts.taskQueueName,
+	}
+
+	run, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, workflowWithFailingActivity)
+	ts.NoError(err)
+
+	// Wait for workflow to complete (will fail due to activity failure)
+	err = run.Get(ctx, nil)
+	ts.Error(err)
+
+	// Wait for heartbeat to capture failure metrics
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.ActivityTaskSlotsInfo != nil &&
+			workerInfo.ActivityTaskSlotsInfo.TotalFailedTasks >= 1
+	}, 5*time.Second, 200*time.Millisecond, "Should have tracked at least 1 activity task failure")
+
+	ts.GreaterOrEqual(workerInfo.ActivityTaskSlotsInfo.LastIntervalFailureTasks, int32(1))
+
+	// Last interval should go back to 0 on next heartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.ActivityTaskSlotsInfo != nil &&
+			workerInfo.ActivityTaskSlotsInfo.LastIntervalFailureTasks == 0
+	}, 5*time.Second, 200*time.Millisecond, "Last interval failure count should reset to 0")
+}
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatWorkflowTaskFailureMetrics() {
+	ctx := context.Background()
+
+	failingWorkflowShouldFail.Store(true)
+	defer failingWorkflowShouldFail.Store(false)
+
+	ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{})
+	ts.worker.RegisterWorkflow(failingWorkflow)
+	ts.NoError(ts.worker.Start())
+
+	workflowOptions := client.StartWorkflowOptions{
+		ID:        "test-wf-task-failure-" + uuid.NewString(),
+		TaskQueue: ts.taskQueueName,
+	}
+
+	_, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, failingWorkflow)
+	ts.NoError(err)
+
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.TotalFailedTasks >= 1
+	}, 5*time.Second, 200*time.Millisecond, "Should have tracked at least 1 workflow task failure")
+
+	ts.GreaterOrEqual(workerInfo.WorkflowTaskSlotsInfo.TotalFailedTasks, int32(1))
+	ts.GreaterOrEqual(workerInfo.WorkflowTaskSlotsInfo.LastIntervalFailureTasks, int32(1))
+
+	// Stop panicking so the workflow can complete on the next retry
+	failingWorkflowShouldFail.Store(false)
+
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.TotalProcessedTasks >= 1
+	}, 5*time.Second, 200*time.Millisecond, "Should have processed at least 1 workflow task after recovery")
+
+	// Last interval failure count should reset to 0 on a subsequent heartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.LastIntervalFailureTasks == 0
+	}, 5*time.Second, 200*time.Millisecond, "Last interval failure count should reset to 0")
+}
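
(Editor's note: the assertions above encode the counter contract these tests rely on — Total* fields are cumulative over the worker's lifetime, while LastInterval* fields cover only the most recent heartbeat interval. A concrete timeline for one workflow task failure, assuming the 1-second WorkerHeartbeatInterval used elsewhere in this suite:)

// t ≈ 0s : workflow task panics once
// t ≈ 1s : heartbeat reports TotalFailedTasks == 1, LastIntervalFailureTasks == 1
// t ≈ 2s : heartbeat reports TotalFailedTasks == 1, LastIntervalFailureTasks == 0
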
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatWorkflowTaskProcessed() {
+	ctx := context.Background()
+
+	ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{})
+	ts.worker.RegisterWorkflow(simpleWorkflow)
+	ts.NoError(ts.worker.Start())
+
+	numWorkflows := 3
+	for i := 0; i < numWorkflows; i++ {
+		workflowOptions := client.StartWorkflowOptions{
+			ID:        fmt.Sprintf("test-wf-processed-%d-%s", i, uuid.NewString()),
+			TaskQueue: ts.taskQueueName,
+		}
+		run, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, simpleWorkflow)
+		ts.NoError(err)
+		err = run.Get(ctx, nil)
+		ts.NoError(err)
+	}
+
+	// Wait for heartbeat to capture processed tasks
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.TotalProcessedTasks == int32(numWorkflows)
+	}, 5*time.Second, 200*time.Millisecond, "Should have processed all workflow tasks")
+
+	ts.GreaterOrEqual(workerInfo.WorkflowTaskSlotsInfo.LastIntervalProcessedTasks, int32(1))
+
+	// Last interval should go back to 0 on next heartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.LastIntervalProcessedTasks == 0
+	}, 5*time.Second, 200*time.Millisecond, "Last interval processed count should reset to 0")
+}
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatResourceBasedTuner() {
+	ctx := context.Background()
+
+	tuner, err := worker.NewResourceBasedTuner(worker.ResourceBasedTunerOptions{
+		TargetMem:    0.8,
+		TargetCpu:    0.9,
+		InfoSupplier: sysinfo.SysInfoProvider(),
+	})
+	ts.NoError(err)
+
+	tunerWorkflow := func(ctx workflow.Context) error {
+		ao := workflow.ActivityOptions{
+			StartToCloseTimeout: 10 * time.Second,
+		}
+		ctx = workflow.WithActivityOptions(ctx, ao)
+		return workflow.ExecuteActivity(ctx, "tunerActivity").Get(ctx, nil)
+	}
+
+	tunerActivity := func(ctx context.Context) error {
+		activity.GetLogger(ctx).Info("tunerActivity executed")
+		return nil
+	}
+
+	autoscalingBehavior := worker.NewPollerBehaviorAutoscaling(worker.PollerBehaviorAutoscalingOptions{
+		InitialNumberOfPollers: 5,
+		MinimumNumberOfPollers: 1,
+		MaximumNumberOfPollers: 200,
+	})
+
+	ts.worker = worker.New(ts.client, ts.taskQueueName, worker.Options{
+		Tuner:                      tuner,
+		WorkflowTaskPollerBehavior: autoscalingBehavior,
+		ActivityTaskPollerBehavior: autoscalingBehavior,
+		NexusTaskPollerBehavior:    autoscalingBehavior,
+	})
+	ts.worker.RegisterWorkflowWithOptions(tunerWorkflow, workflow.RegisterOptions{Name: "tunerWorkflow"})
+	ts.worker.RegisterActivityWithOptions(tunerActivity, activity.RegisterOptions{Name: "tunerActivity"})
+	ts.NoError(ts.worker.Start())
+
+	// Run a workflow
+	workflowOptions := client.StartWorkflowOptions{
+		ID:        "test-resource-tuner-" + uuid.NewString(),
+		TaskQueue: ts.taskQueueName,
+	}
+	run, err := ts.client.ExecuteWorkflow(ctx, workflowOptions, "tunerWorkflow")
+	ts.NoError(err)
+	ts.NoError(run.Get(ctx, nil))
+
+	// Wait for heartbeat with resource-based tuner info
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && workerInfo.WorkflowTaskSlotsInfo != nil &&
+			workerInfo.WorkflowTaskSlotsInfo.SlotSupplierKind == "ResourceBased"
+	}, 5*time.Second, 200*time.Millisecond, "Should find worker with ResourceBased slot supplier")
+
+	ts.NotNil(workerInfo.ActivityTaskSlotsInfo)
+	ts.Equal("ResourceBased", workerInfo.ActivityTaskSlotsInfo.SlotSupplierKind)
+
+	ts.NotNil(workerInfo.LocalActivitySlotsInfo)
+	ts.Equal("ResourceBased", workerInfo.LocalActivitySlotsInfo.SlotSupplierKind)
+
+	ts.NotNil(workerInfo.WorkflowPollerInfo)
+	ts.True(workerInfo.WorkflowPollerInfo.IsAutoscaling)
+
+	if ts.config.maxWorkflowCacheSize > 0 {
+		ts.NotNil(workerInfo.WorkflowStickyPollerInfo)
+		ts.True(workerInfo.WorkflowStickyPollerInfo.IsAutoscaling)
+	}
+
+	ts.NotNil(workerInfo.ActivityPollerInfo)
+	ts.True(workerInfo.ActivityPollerInfo.IsAutoscaling)
+}
+
+func (ts *WorkerHeartbeatTestSuite) TestWorkerHeartbeatPlugins() {
+	ctx := context.Background()
+
+	clientPlugin, err := temporal.NewSimplePlugin(temporal.SimplePluginOptions{
+		Name: "test-client-plugin",
+	})
+	ts.NoError(err)
+
+	workerPlugin, err := temporal.NewSimplePlugin(temporal.SimplePluginOptions{
+		Name: "test-worker-plugin",
+	})
+	ts.NoError(err)
+
+	duplicatePlugin, err := temporal.NewSimplePlugin(temporal.SimplePluginOptions{
+		Name: "test-client-plugin",
+	})
+	ts.NoError(err)
+
+	// Create a new client with the plugin
+	pluginClient, err := client.Dial(client.Options{
+		HostPort:                ts.config.ServiceAddr,
+		Namespace:               ts.config.Namespace,
+		Logger:                  ilog.NewDefaultLogger(),
+		WorkerHeartbeatInterval: 1 * time.Second,
+		ConnectionOptions:       client.ConnectionOptions{TLS: ts.config.TLS},
+		Identity:                "PluginTest",
+		Plugins:                 []client.Plugin{clientPlugin},
+	})
+	ts.NoError(err)
+	defer pluginClient.Close()
+
+	// Create worker with additional plugins (including duplicate)
+	ts.worker = worker.New(pluginClient, ts.taskQueueName, worker.Options{
+		Plugins: []worker.Plugin{workerPlugin, duplicatePlugin},
+	})
+	ts.worker.RegisterWorkflow(simpleWorkflow)
+	ts.NoError(ts.worker.Start())
+
+	workflowOptions := client.StartWorkflowOptions{
+		ID:        "test-plugins-" + uuid.NewString(),
+		TaskQueue: ts.taskQueueName,
+	}
+	run, err := pluginClient.ExecuteWorkflow(ctx, workflowOptions, simpleWorkflow)
+	ts.NoError(err)
+	ts.NoError(run.Get(ctx, nil))
+
+	// Wait for heartbeat with plugin info
+	var workerInfo *workerpb.WorkerHeartbeat
+	ts.Eventually(func() bool {
+		workerInfo = ts.getWorkerInfo(ctx, ts.taskQueueName)
+		return workerInfo != nil && len(workerInfo.Plugins) == 2
+	}, 5*time.Second, 200*time.Millisecond, "Should have 2 unique plugins (duplicates deduped)")
+
+	pluginNames := make(map[string]bool)
+	for _, plugin := range workerInfo.Plugins {
+		pluginNames[plugin.Name] = true
+	}
+	ts.True(pluginNames["test-client-plugin"])
+	ts.True(pluginNames["test-worker-plugin"])
+}
diff --git a/test/worker_tuner_test.go b/test/worker_tuner_test.go
index fd1baf387..227e4effc 100644
--- a/test/worker_tuner_test.go
+++ b/test/worker_tuner_test.go
@@ -4,11 +4,10 @@ import (
 	"context"
 	"testing"
 
-	"go.temporal.io/sdk/worker"
-
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
-	"go.temporal.io/sdk/contrib/resourcetuner"
+	"go.temporal.io/sdk/contrib/sysinfo"
+	"go.temporal.io/sdk/worker"
 )
 
 type WorkerTunerTestSuite struct {
@@ -58,12 +57,13 @@ func (ts *WorkerTunerTestSuite) TestCompositeWorkerTuner() {
 	wfSS, err := worker.NewFixedSizeSlotSupplier(10)
 	ts.NoError(err)
 
-	controllerOpts := resourcetuner.DefaultResourceControllerOptions()
+	controllerOpts := worker.DefaultResourceControllerOptions()
 	controllerOpts.MemTargetPercent = 0.8
 	controllerOpts.CpuTargetPercent = 0.9
-	controller := resourcetuner.NewResourceController(controllerOpts)
-	actSS, err := resourcetuner.NewResourceBasedSlotSupplier(controller,
-		resourcetuner.ResourceBasedSlotSupplierOptions{
+	controllerOpts.InfoSupplier = sysinfo.SysInfoProvider()
+	controller := worker.NewResourceController(controllerOpts)
+	actSS, err := worker.NewResourceBasedSlotSupplier(controller,
+		worker.ResourceBasedSlotSupplierOptions{
 			MinSlots:     10,
 			MaxSlots:     20,
 			RampThrottle: 0,
@@ -112,12 +112,13 @@ func (ts *WorkerTunerTestSuite) TestResourceBasedSmallSlots() {
 	wfSS, err := worker.NewFixedSizeSlotSupplier(10)
 	ts.NoError(err)
 
-	controllerOpts := resourcetuner.DefaultResourceControllerOptions()
+	controllerOpts := worker.DefaultResourceControllerOptions()
 	controllerOpts.MemTargetPercent = 0.8
 	controllerOpts.CpuTargetPercent = 0.9
-	controller := resourcetuner.NewResourceController(controllerOpts)
-	actSS, err := resourcetuner.NewResourceBasedSlotSupplier(controller,
-		resourcetuner.ResourceBasedSlotSupplierOptions{
+	controllerOpts.InfoSupplier = sysinfo.SysInfoProvider()
+	controller := worker.NewResourceController(controllerOpts)
+	actSS, err := worker.NewResourceBasedSlotSupplier(controller,
+		worker.ResourceBasedSlotSupplierOptions{
 			MinSlots:     1,
 			MaxSlots:     4,
 			RampThrottle: 0,
diff --git a/worker/tuning.go b/worker/tuning.go
index a04800c5c..7b7362dd1 100644
--- a/worker/tuning.go
+++ b/worker/tuning.go
@@ -46,3 +46,79 @@ func NewCompositeTuner(options CompositeTunerOptions) (WorkerTuner, error) {
 func NewFixedSizeSlotSupplier(numSlots int) (SlotSupplier, error) {
 	return internal.NewFixedSizeSlotSupplier(numSlots)
 }
+
+// SysInfoProvider implementations provide information about system resources.
+// Use contrib/sysinfo.SysInfoProvider() for a gopsutil-based implementation,
+// or provide your own.
+type SysInfoProvider = internal.SysInfoProvider
+
+// SysInfoContext provides context for SysInfoProvider calls.
+type SysInfoContext = internal.SysInfoContext
+
+// HasSysInfoProvider is an optional interface that SlotSupplier implementations can implement
+// to expose their SysInfoProvider.
+type HasSysInfoProvider = internal.HasSysInfoProvider
+
+// ResourceBasedTunerOptions configures a resource-based tuner.
+type ResourceBasedTunerOptions = internal.ResourceBasedTunerOptions
+
+// NewResourceBasedTuner creates a WorkerTuner that dynamically adjusts the number of slots based
+// on system resources. Specify the target CPU and memory usage as a value between 0 and 1.
+//
+// InfoSupplier is required - use contrib/sysinfo.SysInfoProvider() for a gopsutil-based
+// implementation, or provide your own.
+func NewResourceBasedTuner(opts ResourceBasedTunerOptions) (WorkerTuner, error) {
+	return internal.NewResourceBasedTuner(opts)
+}
+
+// ResourceBasedSlotSupplierOptions configures a particular ResourceBasedSlotSupplier.
+type ResourceBasedSlotSupplierOptions = internal.ResourceBasedSlotSupplierOptions
+
+// ResourceBasedSlotSupplier is a SlotSupplier that issues slots based on system resource usage.
+type ResourceBasedSlotSupplier = internal.ResourceBasedSlotSupplier
+
+// NewResourceBasedSlotSupplier creates a ResourceBasedSlotSupplier given the provided
+// ResourceController and ResourceBasedSlotSupplierOptions. All ResourceBasedSlotSupplier instances
+// must use the same ResourceController.
+func NewResourceBasedSlotSupplier(
+	controller *ResourceController,
+	options ResourceBasedSlotSupplierOptions,
+) (*ResourceBasedSlotSupplier, error) {
+	return internal.NewResourceBasedSlotSupplier(controller, options)
+}
+
+// ResourceControllerOptions contains configurable parameters for a ResourceController.
+// It is recommended to use DefaultResourceControllerOptions to create a ResourceControllerOptions
+// and only modify the mem/cpu target percent fields.
+type ResourceControllerOptions = internal.ResourceControllerOptions
+
+// ResourceController is used by ResourceBasedSlotSupplier to make decisions about whether slots
+// should be issued based on system resource usage.
+type ResourceController = internal.ResourceController
+
+// NewResourceController creates a new ResourceController with the provided options.
+//
+// InfoSupplier is required - use contrib/sysinfo.SysInfoProvider() for a gopsutil-based
+// implementation, or provide your own.
+//
+// WARNING: It is important that you do not create multiple ResourceController instances. Since
+// the controller looks at overall system resources, multiple instances with different configs
+// can only conflict with one another.
+func NewResourceController(options ResourceControllerOptions) *ResourceController {
+	return internal.NewResourceController(options)
+}
+
+// DefaultResourceControllerOptions returns a ResourceControllerOptions with default values.
+func DefaultResourceControllerOptions() ResourceControllerOptions {
+	return internal.DefaultResourceControllerOptions()
+}
+
+// DefaultWorkflowResourceBasedSlotSupplierOptions returns default options for workflow slot suppliers.
+func DefaultWorkflowResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplierOptions {
+	return internal.DefaultWorkflowResourceBasedSlotSupplierOptions()
+}
+
+// DefaultActivityResourceBasedSlotSupplierOptions returns default options for activity slot suppliers.
+func DefaultActivityResourceBasedSlotSupplierOptions() ResourceBasedSlotSupplierOptions {
+	return internal.DefaultActivityResourceBasedSlotSupplierOptions()
+}
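
(Editor's note: pulled together from the tests above, the intended wiring of this new surface looks roughly like the following sketch. The task queue name and registration steps are placeholders; the option fields, constructors, and shared sysinfo supplier all appear verbatim in this diff:)

package main

import (
	"log"

	"go.temporal.io/sdk/client"
	"go.temporal.io/sdk/contrib/sysinfo"
	"go.temporal.io/sdk/worker"
)

func main() {
	c, err := client.Dial(client.Options{})
	if err != nil {
		log.Fatalln("unable to create client", err)
	}
	defer c.Close()

	// One resource-based tuner driving all slot types, as in
	// TestWorkerHeartbeatResourceBasedTuner above.
	tuner, err := worker.NewResourceBasedTuner(worker.ResourceBasedTunerOptions{
		TargetMem:    0.8,
		TargetCpu:    0.9,
		InfoSupplier: sysinfo.SysInfoProvider(), // required; shared gopsutil-based supplier
	})
	if err != nil {
		log.Fatalln("unable to create tuner", err)
	}

	w := worker.New(c, "example-task-queue", worker.Options{Tuner: tuner})
	// Register workflows and activities here, then run the worker.
	if err := w.Run(worker.InterruptCh()); err != nil {
		log.Fatalln("worker exited", err)
	}
}

(For mixed setups — fixed-size workflow slots plus resource-based activity slots — the TestCompositeWorkerTuner hunk above shows the lower-level path: DefaultResourceControllerOptions with InfoSupplier set, NewResourceController once per process per the warning above, then NewResourceBasedSlotSupplier feeding NewCompositeTuner.)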