diff --git a/temporalio/lib/temporalio/activity/context.rb b/temporalio/lib/temporalio/activity/context.rb index 9d51769a..edaef5e3 100644 --- a/temporalio/lib/temporalio/activity/context.rb +++ b/temporalio/lib/temporalio/activity/context.rb @@ -48,6 +48,12 @@ def info raise NotImplementedError end + # @return [Object, nil] Activity class instance. This should always be present except for advanced cases where the + # definition was manually created without any instance getter/creator. + def instance + raise NotImplementedError + end + # Record a heartbeat on the activity. # # Heartbeats should be used for all non-immediately-returning, non-local activities and they are required to diff --git a/temporalio/lib/temporalio/activity/definition.rb b/temporalio/lib/temporalio/activity/definition.rb index 58d8641a..7e06a585 100644 --- a/temporalio/lib/temporalio/activity/definition.rb +++ b/temporalio/lib/temporalio/activity/definition.rb @@ -105,7 +105,10 @@ class Info # @return [String, Symbol, nil] Name of the activity, or nil if the activity is dynamic. attr_reader :name - # @return [Proc] Proc for the activity. + # @return [Object, Proc, nil] The pre-created instance or the proc to create/return it. + attr_reader :instance + + # @return [Proc] Proc for the activity. Should use {Context#instance} to access the instance. attr_reader :proc # @return [Symbol] Name of the executor. Default is `:default`. @@ -134,18 +137,20 @@ def self.from_activity(activity) details = activity._activity_definition_details new( name: details[:activity_name], + instance: proc { activity.new }, executor: details[:activity_executor], cancel_raise: details[:activity_cancel_raise], raw_args: details[:activity_raw_args] - ) { |*args| activity.new.execute(*args) } # Instantiate and call + ) { |*args| Context.current.instance&.execute(*args) } when Definition details = activity.class._activity_definition_details new( name: details[:activity_name], + instance: activity, executor: details[:activity_executor], cancel_raise: details[:activity_cancel_raise], raw_args: details[:activity_raw_args] - ) { |*args| activity.execute(*args) } # Just and call + ) { |*args| Context.current.instance&.execute(*args) } when Info activity else @@ -156,12 +161,21 @@ def self.from_activity(activity) # Manually create activity definition info. Most users will use an instance/class of {Definition}. # # @param name [String, Symbol, nil] Name of the activity or nil for dynamic activity. + # @param instance [Object, Proc, nil] The pre-created instance or the proc to create/return it. # @param executor [Symbol] Name of the executor. # @param cancel_raise [Boolean] Whether to raise in thread/fiber on cancellation. # @param raw_args [Boolean] Whether to use {Converters::RawValue}s as arguments. # @yield Use this block as the activity. - def initialize(name:, executor: :default, cancel_raise: true, raw_args: false, &block) + def initialize( + name:, + instance: nil, + executor: :default, + cancel_raise: true, + raw_args: false, + &block + ) @name = name + @instance = instance raise ArgumentError, 'Must give block' unless block_given? @proc = block diff --git a/temporalio/lib/temporalio/internal/worker/activity_worker.rb b/temporalio/lib/temporalio/internal/worker/activity_worker.rb index 46cd0f1f..595a63e6 100644 --- a/temporalio/lib/temporalio/internal/worker/activity_worker.rb +++ b/temporalio/lib/temporalio/internal/worker/activity_worker.rb @@ -210,7 +210,7 @@ def execute_activity(task_token, defn, start) ) Activity::Context._current_executor&.set_activity_context(defn, activity) set_running_activity(task_token, activity) - run_activity(activity, input) + run_activity(defn, activity, input) rescue Exception => e # rubocop:disable Lint/RescueException We are intending to catch everything here @scoped_logger.warn("Failed starting or sending completion for activity #{start.activity_type}") @scoped_logger.warn(e) @@ -236,8 +236,11 @@ def execute_activity(task_token, defn, start) remove_running_activity(task_token) end - def run_activity(activity, input) + def run_activity(defn, activity, input) result = begin + # Create the instance. We choose to do this before interceptors so that it is available in the interceptor. + activity.instance = defn.instance.is_a?(Proc) ? defn.instance.call : defn.instance # steep:ignore + # Build impl with interceptors # @type var impl: Temporalio::Worker::Interceptor::Activity::Inbound impl = InboundImplementation.new(self) @@ -293,7 +296,7 @@ def run_activity(activity, input) class RunningActivity < Activity::Context attr_reader :info, :cancellation, :worker_shutdown_cancellation, :payload_converter, :logger - attr_accessor :_outbound_impl, :_server_requested_cancel + attr_accessor :instance, :_outbound_impl, :_server_requested_cancel def initialize( # rubocop:disable Lint/MissingSuper info:, diff --git a/temporalio/lib/temporalio/internal/worker/workflow_instance/context.rb b/temporalio/lib/temporalio/internal/worker/workflow_instance/context.rb index 6736bd4f..eb094970 100644 --- a/temporalio/lib/temporalio/internal/worker/workflow_instance/context.rb +++ b/temporalio/lib/temporalio/internal/worker/workflow_instance/context.rb @@ -122,6 +122,10 @@ def info @instance.info end + def instance + @instance.instance + end + def initialize_continue_as_new_error(error) @outbound.initialize_continue_as_new_error( Temporalio::Worker::Interceptor::Workflow::InitializeContinueAsNewErrorInput.new(error:) diff --git a/temporalio/lib/temporalio/testing/activity_environment.rb b/temporalio/lib/temporalio/testing/activity_environment.rb index 68fc2c72..6d504ac3 100644 --- a/temporalio/lib/temporalio/testing/activity_environment.rb +++ b/temporalio/lib/temporalio/testing/activity_environment.rb @@ -80,6 +80,8 @@ def run(activity, *args) Activity::Context._current_executor = executor executor.set_activity_context(defn, Context.new( info: @info.dup, + instance: + defn.instance.is_a?(Proc) ? defn.instance.call : defn.instance, on_heartbeat: @on_heartbeat, cancellation: @cancellation, worker_shutdown_cancellation: @worker_shutdown_cancellation, @@ -102,17 +104,19 @@ def run(activity, *args) # @!visibility private class Context < Activity::Context - attr_reader :info, :cancellation, :worker_shutdown_cancellation, :payload_converter, :logger + attr_reader :info, :instance, :cancellation, :worker_shutdown_cancellation, :payload_converter, :logger def initialize( # rubocop:disable Lint/MissingSuper - info: ActivityEnvironment.default_info, - on_heartbeat: nil, - cancellation: Cancellation.new, - worker_shutdown_cancellation: Cancellation.new, - payload_converter: Converters::PayloadConverter.default, - logger: Logger.new(nil) + info:, + instance:, + on_heartbeat:, + cancellation:, + worker_shutdown_cancellation:, + payload_converter:, + logger: ) @info = info + @instance = instance @on_heartbeat = on_heartbeat @cancellation = cancellation @worker_shutdown_cancellation = worker_shutdown_cancellation diff --git a/temporalio/lib/temporalio/workflow.rb b/temporalio/lib/temporalio/workflow.rb index 2b7e9d6f..c13da19c 100644 --- a/temporalio/lib/temporalio/workflow.rb +++ b/temporalio/lib/temporalio/workflow.rb @@ -220,6 +220,12 @@ def self.info _current.info end + # @return [Definition, nil] Workflow class instance. This should always be present except in + # {Worker::Interceptor::Workflow::Inbound.init} where it will be nil. + def self.instance + _current.instance + end + # @return [Logger] Logger for the workflow. This is a scoped logger that automatically appends workflow details to # every log and takes care not to log during replay. def self.logger diff --git a/temporalio/sig/temporalio/activity/context.rbs b/temporalio/sig/temporalio/activity/context.rbs index 5aec9089..930522b4 100644 --- a/temporalio/sig/temporalio/activity/context.rbs +++ b/temporalio/sig/temporalio/activity/context.rbs @@ -9,6 +9,7 @@ module Temporalio def self._current_executor=: (Worker::ActivityExecutor? executor) -> void def info: -> Info + def instance: -> Definition? def heartbeat: (*Object? details) -> void def cancellation: -> Cancellation def worker_shutdown_cancellation: -> Cancellation diff --git a/temporalio/sig/temporalio/activity/definition.rbs b/temporalio/sig/temporalio/activity/definition.rbs index b838fb7b..9b46d29e 100644 --- a/temporalio/sig/temporalio/activity/definition.rbs +++ b/temporalio/sig/temporalio/activity/definition.rbs @@ -18,6 +18,7 @@ module Temporalio class Info attr_reader name: String | Symbol | nil + attr_reader instance: Object | Proc | nil attr_reader proc: Proc attr_reader executor: Symbol attr_reader cancel_raise: bool @@ -27,6 +28,7 @@ module Temporalio def initialize: ( name: String | Symbol | nil, + ?instance: Object | Proc | nil, ?executor: Symbol, ?cancel_raise: bool, ?raw_args: bool diff --git a/temporalio/sig/temporalio/internal/worker/activity_worker.rbs b/temporalio/sig/temporalio/internal/worker/activity_worker.rbs index af6364e2..651f3caf 100644 --- a/temporalio/sig/temporalio/internal/worker/activity_worker.rbs +++ b/temporalio/sig/temporalio/internal/worker/activity_worker.rbs @@ -21,11 +21,13 @@ module Temporalio def execute_activity: (String task_token, Activity::Definition::Info defn, untyped start) -> void def run_activity: ( + Activity::Definition::Info defn, RunningActivity activity, Temporalio::Worker::Interceptor::Activity::ExecuteInput input ) -> void class RunningActivity < Activity::Context + attr_accessor instance: Activity::Definition? attr_accessor _outbound_impl: Temporalio::Worker::Interceptor::Activity::Outbound? attr_accessor _server_requested_cancel: bool diff --git a/temporalio/sig/temporalio/internal/worker/workflow_instance.rbs b/temporalio/sig/temporalio/internal/worker/workflow_instance.rbs index f4504ad0..19716d40 100644 --- a/temporalio/sig/temporalio/internal/worker/workflow_instance.rbs +++ b/temporalio/sig/temporalio/internal/worker/workflow_instance.rbs @@ -40,7 +40,7 @@ module Temporalio def add_command: (untyped command) -> void - def instance: -> Object + def instance: -> Temporalio::Workflow::Definition def search_attributes: -> SearchAttributes @@ -58,7 +58,7 @@ module Temporalio def activate_internal: (untyped activation) -> untyped - def create_instance: -> Object + def create_instance: -> Temporalio::Workflow::Definition def apply: (untyped job) -> void diff --git a/temporalio/sig/temporalio/internal/worker/workflow_instance/context.rbs b/temporalio/sig/temporalio/internal/worker/workflow_instance/context.rbs index 0dbf5eae..552fab8d 100644 --- a/temporalio/sig/temporalio/internal/worker/workflow_instance/context.rbs +++ b/temporalio/sig/temporalio/internal/worker/workflow_instance/context.rbs @@ -53,6 +53,8 @@ module Temporalio def info: -> Workflow::Info + def instance: -> Temporalio::Workflow::Definition? + def initialize_continue_as_new_error: (Workflow::ContinueAsNewError error) -> void def logger: -> ReplaySafeLogger diff --git a/temporalio/sig/temporalio/workflow.rbs b/temporalio/sig/temporalio/workflow.rbs index 9a55ff72..2fded9bd 100644 --- a/temporalio/sig/temporalio/workflow.rbs +++ b/temporalio/sig/temporalio/workflow.rbs @@ -64,6 +64,8 @@ module Temporalio def self.info: -> Info + def self.instance: -> Definition? + def self.logger: -> ScopedLogger def self.memo: -> Hash[String, Object?] diff --git a/temporalio/test/worker_activity_test.rb b/temporalio/test/worker_activity_test.rb index 5ed29788..8c870066 100644 --- a/temporalio/test/worker_activity_test.rb +++ b/temporalio/test/worker_activity_test.rb @@ -857,6 +857,50 @@ def test_dynamic_activity_raw_args execute_activity(DynamicActivityRawArgs, 'arg1', nil, 123, override_name: 'does-not-exist') end + class ContextInstanceInterceptor + include Temporalio::Worker::Interceptor::Activity + + def intercept_activity(next_interceptor) + Inbound.new(next_interceptor) + end + + class Inbound < Temporalio::Worker::Interceptor::Activity::Inbound + def init(outbound) + Temporalio::Activity::Context.current.instance.events&.<< 'interceptor-init' # steep:ignore + super + end + + def execute(input) + Temporalio::Activity::Context.current.instance.events&.<< 'interceptor-execute' # steep:ignore + super + end + end + end + + class ContextInstanceActivity < Temporalio::Activity::Definition + def events + @events ||= [] + end + + def execute + events << 'execute' # steep:ignore + end + end + + def test_context_instance + # Instance-per-attempt (twice) + assert_equal %w[interceptor-init interceptor-execute execute], + execute_activity(ContextInstanceActivity, interceptors: [ContextInstanceInterceptor.new]) + assert_equal %w[interceptor-init interceptor-execute execute], + execute_activity(ContextInstanceActivity, interceptors: [ContextInstanceInterceptor.new]) + # Shared instance + shared_instance = ContextInstanceActivity.new + assert_equal %w[interceptor-init interceptor-execute execute], + execute_activity(shared_instance, interceptors: [ContextInstanceInterceptor.new]) + assert_equal %w[interceptor-init interceptor-execute execute interceptor-init interceptor-execute execute], + execute_activity(shared_instance, interceptors: [ContextInstanceInterceptor.new]) + end + # steep:ignore def execute_activity( activity, diff --git a/temporalio/test/worker_workflow_test.rb b/temporalio/test/worker_workflow_test.rb index ff6491ca..50a5100f 100644 --- a/temporalio/test/worker_workflow_test.rb +++ b/temporalio/test/worker_workflow_test.rb @@ -1738,6 +1738,37 @@ def test_confirm_garbage_collect end end + class ContextInstanceInterceptor + include Temporalio::Worker::Interceptor::Workflow + + def intercept_workflow(next_interceptor) + Inbound.new(next_interceptor) + end + + class Inbound < Temporalio::Worker::Interceptor::Workflow::Inbound + def execute(input) + Temporalio::Workflow.instance.events << 'interceptor-execute' + super + end + end + end + + class ContextInstanceWorkflow < Temporalio::Workflow::Definition + def execute + events << 'execute' + end + + workflow_query + def events + @events ||= [] + end + end + + def test_context_instance + assert_equal %w[interceptor-execute execute], + execute_workflow(ContextInstanceWorkflow, interceptors: [ContextInstanceInterceptor.new]) + end + # TODO(cretz): To test # * Common # * Eager workflow start diff --git a/temporalio/test/workflow_utils.rb b/temporalio/test/workflow_utils.rb index 92d896d4..61afc0b2 100644 --- a/temporalio/test/workflow_utils.rb +++ b/temporalio/test/workflow_utils.rb @@ -27,7 +27,8 @@ def execute_workflow( workflow_payload_codec_thread_pool: nil, id_conflict_policy: Temporalio::WorkflowIDConflictPolicy::UNSPECIFIED, max_heartbeat_throttle_interval: 60.0, - task_timeout: nil + task_timeout: nil, + interceptors: [] ) worker = Temporalio::Worker.new( client:, @@ -38,7 +39,8 @@ def execute_workflow( max_cached_workflows:, logger: logger || client.options.logger, workflow_payload_codec_thread_pool:, - max_heartbeat_throttle_interval: + max_heartbeat_throttle_interval:, + interceptors: ) worker.run do handle = client.start_workflow(