Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions Control/Concurrent/Async/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE DeriveDataTypeable #-}
#endif
{-# OPTIONS -Wall #-}
{-# LANGUAGE ScopedTypeVariables #-}

-----------------------------------------------------------------------------
-- |
Expand Down Expand Up @@ -182,10 +183,24 @@ withAsyncUsing doFork = \action inner -> do
let a = Async t (readTMVar var)
r <- restore (inner a) `catchAll` \e -> do
uninterruptibleCancel a
throwIO e
rethrowIO' e
uninterruptibleCancel a
return r


-- | This function attempts at rethrowing while keeping the context
-- This is internal and only working with GHC >=9.12, otherwise it fallsback to
-- standard 'throwIO'
rethrowIO' :: SomeException -> IO a
#if MIN_VERSION_base(4,21,0)
rethrowIO' e =
case fromException e of
Just (e' :: ExceptionWithContext SomeException) -> rethrowIO e'
Nothing -> throwIO e
#else
rethrowIO' = throwIO
#endif

-- | Wait for an asynchronous action to complete, and return its
-- value. If the asynchronous action threw an exception, then the
-- exception is re-thrown by 'wait'.
Expand Down Expand Up @@ -228,7 +243,20 @@ poll = atomically . pollSTM
waitSTM :: Async a -> STM a
waitSTM a = do
r <- waitCatchSTM a
either throwSTM return r
either (rethrowSTM) return r

-- | This function attempts at rethrowing while keeping the context
-- This is internal and only working with GHC >=9.12, otherwise it fallsback to
-- standard 'throwSTM'
rethrowSTM :: SomeException -> STM a
#if MIN_VERSION_base(4,21,0)
rethrowSTM e =
case fromException e of
Just (e' :: ExceptionWithContext SomeException) -> throwSTM (NoBacktrace e')
Nothing -> throwSTM e
#else
rethrowSTM = throwSTM
#endif

-- | A version of 'waitCatch' that can be used inside an STM transaction.
--
Expand Down Expand Up @@ -613,7 +641,7 @@ race left right = concurrently' left right collect
collect m = do
e <- m
case e of
Left ex -> throwIO ex
Left ex -> rethrowIO' ex
Right r -> return r

-- race_ :: IO a -> IO b -> IO ()
Expand All @@ -627,7 +655,7 @@ concurrently left right = concurrently' left right (collect [])
collect xs m = do
e <- m
case e of
Left ex -> throwIO ex
Left ex -> rethrowIO' ex
Right r -> collect (r:xs) m

-- concurrentlyE :: IO (Either e a) -> IO (Either e b) -> IO (Either e (a, b))
Expand All @@ -640,7 +668,7 @@ concurrentlyE left right = concurrently' left right (collect [])
collect xs m = do
e <- m
case e of
Left ex -> throwIO ex
Left ex -> rethrowIO' ex
Right r -> collect (r:xs) m

concurrently' :: IO a -> IO b
Expand Down Expand Up @@ -699,7 +727,7 @@ concurrently_ left right = concurrently' left right (collect 0)
collect i m = do
e <- m
case e of
Left ex -> throwIO ex
Left ex -> rethrowIO' ex
Right _ -> collect (i + 1 :: Int) m


Expand Down
86 changes: 86 additions & 0 deletions test/test-async.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE CPP,ScopedTypeVariables,DeriveDataTypeable #-}
{-# LANGUAGE DeriveAnyClass #-}
module Main where

import Test.Framework (defaultMain, testGroup)
Expand All @@ -19,6 +20,10 @@ import Data.Foldable (foldMap)
import Data.Maybe

import Prelude hiding (catch)
import Control.Exception.Annotation (ExceptionAnnotation(..))
import Control.Exception.Context (displayExceptionContext, getExceptionAnnotations)
import GHC.Stack (HasCallStack)
import Control.Exception.Backtrace (Backtraces, displayBacktraces)

main = defaultMain tests

Expand Down Expand Up @@ -65,7 +70,11 @@ tests = [
, testCase "concurrentlyE_Monoid" concurrentlyE_Monoid
, testCase "concurrentlyE_Monoid_fail" concurrentlyE_Monoid_fail
#endif

#if MIN_VERSION_base(4,9,0)
, testGroup "exception rethrow" exception_rethrow
]
#endif
]

value = 42 :: Int
Expand Down Expand Up @@ -459,3 +468,80 @@ concurrentlyE_Monoid_fail = do
r :: Either Char [Char] <- runConcurrentlyE $ foldMap ConcurrentlyE $ current
assertEqual "The earliest failure" (Left 'u') r
#endif


#if MIN_VERSION_base(4,9,0)
-- The following regroups tests of exception context propagation to ensure that
-- exception rethrown by async keep the initial backtrace.

-- | This is a dummy exception that we can throw
data Exc = Exc
deriving (Show, Exception)

action_wrapper :: HasCallStack => (IO x -> IO y) -> IO y
action_wrapper op = op action

action :: HasCallStack => IO x
action = throwIO Exc


-- | From an exception, extract two lines of context, ignoring the header and
-- the remaining lines.
--
-- For example, when calling the above 'action_wrapper (\x -> x)', in GHC 9.12, the current callstack looks like:
--
--
-- HasCallStack backtrace:
-- throwIO, called at test/test-async.hs:485:11 in async-2.2.5-inplace-test-async:Main
-- action, called at test/test-async.hs:482:10 in async-2.2.5-inplace-test-async:Main
-- action_wrapper, called at <interactive>:2:1 in interactive:Ghci1
--
-- We drop the header (e.g. HasCallStack backtrace:) and only keep the two
-- lines showing the callstack inside "action".
--
-- Note that it does not show where action_wrapper was called, but the idea
-- is that action_wrapper will do the call to the async primitive (e.g.
-- 'concurrently') and will hence keep the trace of where 'concurrently' was
-- called.
extractThrowOrigin :: ExceptionWithContext Exc -> [String]
extractThrowOrigin (ExceptionWithContext ctx e) = do
let backtraces :: [Backtraces] = getExceptionAnnotations ctx
case backtraces of
[backtrace] -> take 2 $ drop 1 (lines (displayBacktraces backtrace))
_ -> error "more than one backtrace"

-- | Run 'action' through a wrapper (using 'action_wrapper') and with a naive
-- wrapper and show that the wrapper returns the same callstack when the
-- exception in 'action' is raised.
compareTwoExceptions op = do
Left direct_exception <- tryWithContext (action_wrapper (\x -> x))
let direct_origin = extractThrowOrigin direct_exception

Left indirect_exception <- tryWithContext (action_wrapper op)
let indirect_origin = extractThrowOrigin indirect_exception

assertEqual "The exception origins" direct_origin indirect_origin

doNothing = pure ()
doForever = doForever

exception_rethrow = [
testCase "concurrentlyL" $ compareTwoExceptions (\action -> concurrently action doNothing),
testCase "concurrentlyR" $ compareTwoExceptions (\action -> concurrently doNothing action),
testCase "concurrently_L" $ compareTwoExceptions (\action -> concurrently_ action doNothing),
testCase "concurrently_R" $ compareTwoExceptions (\action -> concurrently_ doNothing action),
testCase "raceL" $ compareTwoExceptions (\action -> race action doForever),
testCase "raceR" $ compareTwoExceptions (\action -> race doForever action),
testCase "race_L" $ compareTwoExceptions (\action -> race_ action doForever),
testCase "race_R" $ compareTwoExceptions (\action -> race_ doForever action),
testCase "mapConcurrently" $ compareTwoExceptions (\action -> mapConcurrently (\() -> action) [(), (), ()]),
testCase "forConcurrently" $ compareTwoExceptions (\action -> forConcurrently [(), (), ()] (\() -> action)),
testCase "withAsync wait" $ compareTwoExceptions $ \action -> do
withAsync action $ \a -> do
wait a,
testCase "withAsync inside" $ compareTwoExceptions $ \action -> do
withAsync doForever $ \a -> do
action
]
#endif