diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java
index 7be9fcee63..219d8aba5b 100644
--- a/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java
+++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java
@@ -1,22 +1,56 @@
package io.temporal.internal.activity;
import io.temporal.activity.ActivityExecutionContext;
+import java.util.ArrayDeque;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.Map;
+import java.util.WeakHashMap;
/**
- * Thread local store of the context object passed to an activity implementation. Avoid using this
- * class directly.
+ * Thread-local / virtual-thread-aware store of the context object passed to an activity
+ * implementation. Avoid using this class directly.
*
- * @author fateev
+ *
Uses a per-thread stack so nested sets/unsets are handled correctly. Platform threads use
+ * ThreadLocal; virtual threads use a WeakHashMap keyed by Thread to avoid leaking memory when
+ * virtual threads die.
+ *
+ * @author fateev (adapted)
*/
-final class CurrentActivityExecutionContext {
+public final class CurrentActivityExecutionContext {
+
+ private static final ThreadLocal> PLATFORM_STACK =
+ ThreadLocal.withInitial(ArrayDeque::new);
- private static final ThreadLocal CURRENT = new ThreadLocal<>();
+ private static final Map> VIRTUAL_STACKS =
+ Collections.synchronizedMap(new WeakHashMap<>());
+
+ private static Deque getStackForCurrentThread() {
+ Thread t = Thread.currentThread();
+ if (isVirtualThread(t)) {
+ Deque d =
+ VIRTUAL_STACKS.computeIfAbsent(t, k -> new ArrayDeque<>());
+ return d;
+ } else {
+ return PLATFORM_STACK.get();
+ }
+ }
+
+ private static boolean isVirtualThread(Thread t) {
+ try {
+ t.getClass().getMethod("isVirtual", boolean.class);
+ return true;
+ } catch (NoSuchMethodException e) {
+ return false;
+ }
+ }
/**
* This is used by activity implementation to get access to the current ActivityExecutionContext
*/
public static ActivityExecutionContext get() {
- ActivityExecutionContext result = CURRENT.get();
+ Deque stack = getStackForCurrentThread();
+ ActivityExecutionContext result = stack.peek();
if (result == null) {
throw new IllegalStateException(
"ActivityExecutionContext can be used only inside of activity "
@@ -26,21 +60,49 @@ public static ActivityExecutionContext get() {
}
public static boolean isSet() {
- return CURRENT.get() != null;
+ Deque stack = getStackForCurrentThread();
+ return stack.peek() != null;
}
+ /**
+ * Pushes the provided context for the current thread. Null context is rejected. We allow nested
+ * sets (push semantics) to support nested interceptors / wrappers.
+ */
public static void set(ActivityExecutionContext context) {
if (context == null) {
throw new IllegalArgumentException("null context");
}
- if (CURRENT.get() != null) {
- throw new IllegalStateException("current already set");
- }
- CURRENT.set(context);
+ Deque stack = getStackForCurrentThread();
+ stack.push(context);
}
+ /**
+ * Pops the current context for the thread. If the stack becomes empty, clear the storage for the
+ * thread to allow GC (remove ThreadLocal or remove map entry for virtual threads).
+ */
public static void unset() {
- CURRENT.set(null);
+ Thread t = Thread.currentThread();
+ if (isVirtualThread(t)) {
+ synchronized (VIRTUAL_STACKS) {
+ Deque stack = VIRTUAL_STACKS.get(t);
+ if (stack == null || stack.isEmpty()) {
+ return;
+ }
+ stack.pop();
+ if (stack.isEmpty()) {
+ VIRTUAL_STACKS.remove(t);
+ }
+ }
+ } else {
+ Deque stack = PLATFORM_STACK.get();
+ if (stack == null || stack.isEmpty()) {
+ return;
+ }
+ stack.pop();
+ if (stack.isEmpty()) {
+ PLATFORM_STACK.remove();
+ }
+ }
}
private CurrentActivityExecutionContext() {}
diff --git a/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java b/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java
index fbf14d217a..776fe790cf 100644
--- a/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java
@@ -11,8 +11,6 @@
public class WorkflowRunTokenTest {
private static final ObjectWriter ow =
new ObjectMapper().registerModule(new Jdk8Module()).writer();
- private static final ObjectReader or =
- new ObjectMapper().registerModule(new Jdk8Module()).reader();
private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
@Test
diff --git a/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java b/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java
new file mode 100644
index 0000000000..407b80a8a1
--- /dev/null
+++ b/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java
@@ -0,0 +1,121 @@
+package io.temporal.internal.activity;
+
+import static org.junit.Assert.*;
+
+import io.temporal.activity.ActivityExecutionContext;
+import java.lang.reflect.InvocationHandler;
+import java.lang.reflect.Proxy;
+import java.util.concurrent.atomic.AtomicReference;
+import org.junit.Assume;
+import org.junit.Test;
+
+public class CurrentActivityExecutionContextTest {
+
+ private static ActivityExecutionContext proxyContext() {
+ InvocationHandler handler = (proxy, method, args) -> null;
+ return (ActivityExecutionContext)
+ Proxy.newProxyInstance(
+ ActivityExecutionContext.class.getClassLoader(),
+ new Class[] {ActivityExecutionContext.class},
+ handler);
+ }
+
+ @Test
+ public void platformThreadNestedSetUnsetBehavior() {
+ ActivityExecutionContext ctx1 = proxyContext();
+ ActivityExecutionContext ctx2 = proxyContext();
+
+ assertFalse(CurrentActivityExecutionContext.isSet());
+ assertThrows(IllegalStateException.class, CurrentActivityExecutionContext::get);
+
+ CurrentActivityExecutionContext.set(ctx1);
+ assertTrue(CurrentActivityExecutionContext.isSet());
+ assertSame("should return ctx1", ctx1, CurrentActivityExecutionContext.get());
+
+ CurrentActivityExecutionContext.set(ctx2);
+ assertTrue(CurrentActivityExecutionContext.isSet());
+ assertSame("should return ctx2 (top of stack)", ctx2, CurrentActivityExecutionContext.get());
+
+ CurrentActivityExecutionContext.unset();
+ assertTrue(CurrentActivityExecutionContext.isSet());
+ assertSame("after popping, should return ctx1", ctx1, CurrentActivityExecutionContext.get());
+
+ CurrentActivityExecutionContext.unset();
+ assertFalse(CurrentActivityExecutionContext.isSet());
+ assertThrows(
+ "get() should throw after final unset",
+ IllegalStateException.class,
+ CurrentActivityExecutionContext::get);
+ }
+
+ @Test
+ public void virtualThreadNestedSetUnsetBehavior_ifSupported() throws Exception {
+ boolean supportsVirtual;
+ try {
+ Thread.class.getMethod("startVirtualThread", Runnable.class);
+ supportsVirtual = true;
+ } catch (NoSuchMethodException e) {
+ supportsVirtual = false;
+ }
+
+ Assume.assumeTrue("Virtual threads not supported in this JVM; skipping", supportsVirtual);
+
+ AtomicReference failure = new AtomicReference<>(null);
+ AtomicReference seenAfterFirstSet = new AtomicReference<>(null);
+ AtomicReference seenAfterSecondSet = new AtomicReference<>(null);
+ AtomicReference seenIsSetAfterFinalUnset = new AtomicReference<>(null);
+
+ Thread vt =
+ Thread.startVirtualThread(
+ () -> {
+ try {
+ ActivityExecutionContext vctx1 = proxyContext();
+ ActivityExecutionContext vctx2 = proxyContext();
+
+ assertFalse(CurrentActivityExecutionContext.isSet());
+ try {
+ CurrentActivityExecutionContext.get();
+ fail("get() should have thrown when no context is set");
+ } catch (IllegalStateException expected) {
+ }
+
+ CurrentActivityExecutionContext.set(vctx1);
+ seenAfterFirstSet.set(CurrentActivityExecutionContext.get());
+
+ CurrentActivityExecutionContext.set(vctx2);
+ seenAfterSecondSet.set(CurrentActivityExecutionContext.get());
+
+ CurrentActivityExecutionContext.unset();
+ ActivityExecutionContext afterPop = CurrentActivityExecutionContext.get();
+ if (afterPop != vctx1) {
+ throw new AssertionError("after pop expected vctx1 but got " + afterPop);
+ }
+
+ CurrentActivityExecutionContext.unset();
+ seenIsSetAfterFinalUnset.set(CurrentActivityExecutionContext.isSet());
+ try {
+ CurrentActivityExecutionContext.get();
+ throw new AssertionError("get() should have thrown after final unset");
+ } catch (IllegalStateException expected) {
+ }
+ } catch (Throwable t) {
+ failure.set(t);
+ }
+ });
+
+ vt.join();
+
+ if (failure.get() != null) {
+ Throwable t = failure.get();
+ if (t instanceof AssertionError) {
+ throw (AssertionError) t;
+ } else {
+ throw new RuntimeException(t);
+ }
+ }
+
+ assertNotNull("virtual thread did not record first set", seenAfterFirstSet.get());
+ assertNotNull("virtual thread did not record second (nested) set", seenAfterSecondSet.get());
+ assertFalse("expected context to be unset at the end", seenIsSetAfterFinalUnset.get());
+ }
+}