diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java index ab73946534cd..0bff72a3814d 100644 --- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java +++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java @@ -613,13 +613,10 @@ private void executeAsync(Callable callable) throws UserCodeExecutionExcep private static void parseAndThrow(Future future, ExecutionException e) throws UserCodeExecutionException { future.cancel(true); - if (e.getCause() == null) { - throw new UserCodeExecutionException(e); + Throwable cause = e.getCause(); + if (cause instanceof UserCodeExecutionException) { + throw (UserCodeExecutionException) cause; } - Throwable cause = checkStateNotNull(e.getCause()); - if (cause instanceof UserCodeQuotaException) { - throw new UserCodeQuotaException(cause); - } - throw new UserCodeExecutionException(cause); + throw new UserCodeExecutionException(cause == null ? e : cause); } } diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java index 0e572bdd2d64..9fd6babf6604 100644 --- a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java +++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java @@ -104,6 +104,20 @@ public void givenCallerThrowsUserCodeExecutionException_emitsIntoFailurePCollect pipeline.run(); } + @Test + public void givenCallerThrowsNonUserCodeException_emitsWrappedUserCodeExecutionException() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerThrowsRuntimeException(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + + pipeline.run(); + } + @Test public void givenCallerThrowsQuotaException_emitsIntoFailurePCollection() { Result result = @@ -142,7 +156,7 @@ public void givenCallerTimeout_emitsFailurePCollection() { } @Test - public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { + public void givenCallerThrowsTimeoutException_thenPreservesExceptionType() { Result result = pipeline .apply(Create.of(new Request("a"))) @@ -150,7 +164,7 @@ public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { PCollection failures = result.getFailures(); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) - .isEqualTo(1L); + .isEqualTo(0L); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(0L); PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) .isEqualTo(1L); @@ -158,6 +172,23 @@ public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { pipeline.run(); } + @Test + public void givenCallerThrowsRemoteSystemException_thenPreservesExceptionType() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply( + Call.of(new CallerThrowsRemoteSystemException(), NON_DETERMINISTIC_RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeRemoteSystemException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + @Test public void givenSetupThrowsUserCodeExecutionException_throwsError() { pipeline @@ -376,6 +407,14 @@ public Response call(Request request) throws UserCodeExecutionException { } } + private static class CallerThrowsRuntimeException implements Caller { + + @Override + public Response call(Request request) { + throw new RuntimeException("unexpected error"); + } + } + private static class CallerThrowsTimeout implements Caller { @Override @@ -384,6 +423,14 @@ public Response call(Request request) throws UserCodeExecutionException { } } + private static class CallerThrowsRemoteSystemException implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeRemoteSystemException(""); + } + } + private static class CallerInvokesQuotaException implements Caller { @Override