diff --git a/stream/stream.go b/stream/stream.go index 94e54b2..0d6df19 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -365,7 +365,15 @@ func FanOut[A any, B any](stm Stream[A], handlers ...func(Stream[A]) io.IO[B]) i ch := make(chan io.GoResult[A]) channels = append(channels, ch) stmCh := UnfoldGoResult(FromChannel(ch), Fail[A]) - return handler(stmCh) + return io.Recover( + handler(stmCh), + func(err error) io.IO[B] { + return io.AndThen( + io.CloseChannel(ch), // !! Closing channel on the reader side! + io.Fail[B](err), + ) + }, + ) }) channelsIn := slice.Map(channels, func(ch chan io.GoResult[A]) chan<- io.GoResult[A] { return ch @@ -376,7 +384,9 @@ func FanOut[A any, B any](stm Stream[A], handlers ...func(Stream[A]) io.IO[B]) i iosParallelIOCompatible := io.Map(iosParallelIO, either.Right[fun.Unit, []B]) both := io.Parallel(toChannelsIOCompatible, iosParallelIOCompatible) onlyRight := io.Map(both, func(eithers []either.Either[fun.Unit, []B]) []B { - return slice.Flatten(slice.Collect(eithers, either.GetRight[fun.Unit, []B])) + return slice.Flatten( + slice.Collect(eithers, either.GetRight[fun.Unit, []B]), + ) }) return onlyRight } diff --git a/stream/stream_test.go b/stream/stream_test.go index 1d554b7..69a7ceb 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -133,3 +133,42 @@ func TestFailedStream(t *testing.T) { assert.Equal(t, expectedError, err1) } } + +// TODO: support FanOut when one of the sibling streams terminate early (Head). +func NoTestFanOutSiblingPreliminaryTermination(t *testing.T) { + expectedError := errors.New("expected error 2") + + twoStmIO := stream.FanOut( + nats10, + func(stm stream.Stream[int]) io.IO[int] { + return stream.Head(stream.Sum(stm)) + }, + func(stm stream.Stream[int]) io.IO[int] { + return stream.Head(stm) + }, + ) + _, err1 := io.UnsafeRunSync(twoStmIO) + if assert.Error(t, err1) { + assert.Equal(t, expectedError, err1) + } + +} + +func TestFanOutSiblingFailure(t *testing.T) { + expectedError := errors.New("expected error 3") + + twoStmIO := stream.FanOut( + nats10, + func(stm stream.Stream[int]) io.IO[int] { + return stream.Head(stream.Sum(stm)) + }, + func(stm stream.Stream[int]) io.IO[int] { + return stream.Head(stream.Fail[int](expectedError)) + }, + ) + _, err1 := io.UnsafeRunSync(twoStmIO) + if assert.Error(t, err1) { + assert.Contains(t, err1.Error(), "send on closed channel") + } + +}