diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index e2df9fc73d..8a01c5bbf5 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -31,7 +31,12 @@ def value(self): @property def error_code(self): - return self.value.error_code if hasattr(self.value, "error_code") else self._ERROR_CODE + if hasattr(self.value, "error_code"): + return self.value.error_code + elif hasattr(type(self.value), "error_code"): + return type(self.value).error_code + else: + return self._ERROR_CODE class FlyteTypeException(FlyteUserException, TypeError): diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 70e8ea5b4e..ecb6f450c8 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -33,7 +33,7 @@ from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.base import FlyteException from flytekit.exceptions.scopes import system_entry_point -from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException +from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException, FlyteUserException from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models, execution from flytekit.models.core import execution as execution_models @@ -128,19 +128,24 @@ def verify_output(*args, **kwargs): _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix") assert mock_write_to_file.call_count == 1 +class CustomException(FlyteUserException): + _ERROR_CODE = "USER:CustomError" + @pytest.mark.parametrize( "exception_value", [ - FlyteException("exception", timestamp=1), - FlyteException("exception"), - Exception("exception"), + [FlyteException("exception", timestamp=1), FlyteException.error_code], + [FlyteException("exception"), FlyteException.error_code], + [Exception("exception"), FlyteUserRuntimeException.error_code], + [CustomException("exception"), CustomException.error_code], ] ) @mock.patch("flytekit.core.utils.load_proto_from_file") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") @mock.patch("flytekit.core.utils.write_proto_to_file") -def test_dispatch_execute_exception_with_multi_error_files(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto, exception_value: Exception, monkeypatch): +def test_dispatch_execute_exception_with_multi_error_files(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto, exception_value: typing.Tuple[Exception, str], monkeypatch): + exception_value, error_code = exception_value monkeypatch.setenv("_F_DES", "1") monkeypatch.setenv("_F_WN", "worker") @@ -170,7 +175,7 @@ def verify_output(*args, **kwargs): assert error_filename_base.startswith("error-") uuid.UUID(hex=error_filename_base[6:], version=4) assert error_filename_ext == ".pb" - assert container_error.code == "USER:RuntimeError" + assert container_error.code == error_code mock_write_to_file.side_effect = verify_output _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix")