diff --git a/include/wil/cppwinrt_helpers.h b/include/wil/cppwinrt_helpers.h index c46a73e0..f4d5c4e3 100644 --- a/include/wil/cppwinrt_helpers.h +++ b/include/wil/cppwinrt_helpers.h @@ -41,12 +41,16 @@ namespace wil::details namespace wil::details { template using coroutine_handle = std::experimental::coroutine_handle; + using suspend_always = std::experimental::suspend_always; + using suspend_never = std::experimental::suspend_never; } #elif defined(__cpp_lib_coroutine) && (__cpp_lib_coroutine >= 201902L) #include namespace wil::details { template using coroutine_handle = std::coroutine_handle; + using suspend_always = std::suspend_always; + using suspend_never = std::suspend_never; } #endif /// @endcond @@ -312,6 +316,229 @@ namespace wil } } } + +#if defined(_RESUMABLE_FUNCTIONS_SUPPORTED) || (defined(__cpp_lib_coroutine) && (__cpp_lib_coroutine >= 201902L)) +/// @cond +namespace wil::details +{ + template + struct iterator_promise : winrt::implements< + iterator_promise, + winrt::Windows::Foundation::Collections::IIterator + > + { + private: + enum class IterationStatus + { + Producing, + Value, + Done + }; + + public: + unsigned long __stdcall Release() noexcept + { + uint32_t const remaining = this->subtract_reference(); + + if (remaining == 0) + { + std::atomic_thread_fence(std::memory_order_acquire); + coroutine_handle::from_promise(*this).destroy(); + } + + return remaining; + } + + winrt::Windows::Foundation::Collections::IIterator get_return_object() noexcept + { + return { winrt::get_abi(static_cast const&>(*this)), winrt::take_ownership_from_abi }; + } + + suspend_never initial_suspend() const noexcept + { + return {}; + } + + suspend_always final_suspend() const noexcept + { + return {}; + } + + void unhandled_exception() const + { + throw; + } + + constexpr void await_transform() = delete; + + constexpr void return_void() noexcept + { + m_status = IterationStatus::Done; + } + + template + auto yield_value(U &&value) + { + struct YieldAwaiter + { + bool m_ready; + + constexpr bool await_ready() const noexcept + { + return m_ready; + } + + constexpr void await_suspend(coroutine_handle<>) const noexcept {} + constexpr void await_resume() const noexcept {} + }; + + *m_current = std::forward(value); + + if (m_current == m_values.end() - 1) + { + if (m_current != &m_last_value) + { + m_last_value = *m_current; + } + + m_status = IterationStatus::Value; + ++m_current; + return YieldAwaiter{ false }; + } + else + { + ++m_current; + return YieldAwaiter{ true }; + } + } + +#if defined(_DEBUG) && !defined(WINRT_NO_MAKE_DETECTION) + void use_make_function_to_create_this_object() final + { + } +#endif + + uint32_t produce_values(winrt::array_view const& view) + { + if (m_status == IterationStatus::Producing) + { + return 0; + } + else if (m_status == IterationStatus::Done) + { + throw winrt::hresult_out_of_bounds(); + } + + m_values = view; + m_current = m_values.begin(); + m_status = IterationStatus::Producing; + + coroutine_handle::from_promise(*this).resume(); + + return static_cast(m_current - m_values.begin()); + } + + bool HasCurrent() const noexcept + { + return m_status == IterationStatus::Value; + } + + TResult Current() const noexcept + { + return m_last_value; + } + + uint32_t GetMany(winrt::array_view values) + { + if (!HasCurrent() || values.empty()) + { + return 0; + } + + values.front() = Current(); + + uint32_t result; + + if (values.size() == 1) + { + result = 1; + } + else + { + result = produce_values({ values.data() + 1, values.size() - 1 }) + 1; + } + + MoveNext(); + return result; + } + + bool MoveNext() + { + return produce_values({ &m_last_value, 1 }); + } + + private: + IterationStatus m_status{ IterationStatus::Producing }; + winrt::array_view m_values{ &m_last_value, 1 }; + TResult* m_current{ &m_last_value }; + TResult m_last_value{ empty() }; + }; + + template + struct iterable_iterator_helper : winrt::implements< + iterable_iterator_helper, + winrt::Windows::Foundation::Collections::IIterable + > + { + iterable_iterator_helper(Func&& func, Args&&... args) : + m_func{ std::forward(func) }, + m_args{ std::forward(args)... } + { + } + + auto First() + { + return std::apply(m_func, m_args); + } + + private: + Func m_func; + std::tuple m_args; + }; + + template + struct iterator_result; + + template + struct iterator_result> + { + using type = TResult; + }; +} +/// @endcond + +namespace wil +{ + template>::type> + winrt::Windows::Foundation::Collections::IIterable make_iterable_from_generator(Func&& func, Args&&... args) + { + return winrt::make>(std::forward(func), std::forward(args)...); + } +} + +#ifdef __cpp_lib_coroutine +namespace std +#else +namespace std::experimental +#endif +{ + template + struct coroutine_traits, Args...> + { + using promise_type = wil::details::iterator_promise; + }; +} +#endif #endif #if defined(WINRT_Windows_UI_H) && defined(_WINDOWS_UI_INTEROP_H_) && !defined(__WIL_CPPWINRT_WINDOWS_UI_INTEROP_HELPERS) diff --git a/tests/CppWinRTTests.cpp b/tests/CppWinRTTests.cpp index 210ad858..80075361 100644 --- a/tests/CppWinRTTests.cpp +++ b/tests/CppWinRTTests.cpp @@ -647,6 +647,199 @@ TEST_CASE("CppWinRTTests::ResumeForegroundTests", "[cppwinrt]") co_await wil::resume_foreground(dispatcher, test::TestDispatcherPriority::Weird); }().get(); } + +namespace test +{ + winrt::Windows::Foundation::Collections::IIterator hello_world_generator() + { + co_yield L"Hello"; + co_yield L"World!"; + } + + class set_true_on_destruction + { + public: + set_true_on_destruction(bool& value) noexcept : m_value{ &value } + { + } + + set_true_on_destruction(set_true_on_destruction const&) = delete; + set_true_on_destruction& operator=(set_true_on_destruction const&) = delete; + + set_true_on_destruction(set_true_on_destruction&& other) noexcept : m_value{ std::exchange(other.m_value, nullptr) } + { + } + + set_true_on_destruction& operator=(set_true_on_destruction&& other) noexcept + { + m_value = std::exchange(other.m_value, nullptr); + return *this; + } + + ~set_true_on_destruction() + { + if (m_value) + { + *m_value = true; + } + } + + private: + bool* m_value{nullptr}; + }; +} + + +TEST_CASE("CppWinRTTests::Generator", "[cppwinrt]") +{ + SECTION("Hello World") + { + auto iterator = test::hello_world_generator(); + REQUIRE(iterator.HasCurrent()); + REQUIRE(iterator.Current() == L"Hello"); + + REQUIRE(iterator.MoveNext()); + REQUIRE(iterator.HasCurrent()); + REQUIRE(iterator.Current() == L"World!"); + + REQUIRE(!iterator.MoveNext()); + REQUIRE(!iterator.HasCurrent()); + } + + SECTION("Value types") + { + auto iterator = []() -> winrt::Windows::Foundation::Collections::IIterator + { + for (int i = 0; i < 10; ++i) + { + co_yield i; + } + }(); + + REQUIRE(iterator.HasCurrent()); + REQUIRE(iterator.MoveNext()); + REQUIRE(iterator.Current() == 1); + + for (int i = 2; i < 10; ++i) + { + REQUIRE(iterator.MoveNext()); + REQUIRE(iterator.Current() == i); + } + + REQUIRE(!iterator.MoveNext()); + REQUIRE(!iterator.HasCurrent()); + } + + SECTION("GetMany") + { + { + auto iterator = test::hello_world_generator(); + + std::array values; + REQUIRE(iterator.GetMany(values) == 2); + REQUIRE(values[0] == L"Hello"); + REQUIRE(values[1] == L"World!"); + REQUIRE(iterator.GetMany(values) == 0); + } + + { + auto iterator = test::hello_world_generator(); + std::array values; + REQUIRE(iterator.GetMany(values) == 1); + REQUIRE(values[0] == L"Hello"); + REQUIRE(iterator.HasCurrent()); + REQUIRE(iterator.Current() == L"World!"); + + REQUIRE(iterator.GetMany(values) == 1); + REQUIRE(values[0] == L"World!"); + REQUIRE(iterator.GetMany(values) == 0); + } + } + + SECTION("MoveNext") + { + auto iterator = test::hello_world_generator(); + REQUIRE(iterator.HasCurrent()); + REQUIRE(iterator.MoveNext()); + REQUIRE(iterator.HasCurrent()); + REQUIRE(!iterator.MoveNext()); + REQUIRE_THROWS_AS(iterator.MoveNext(), winrt::hresult_out_of_bounds); + } + + SECTION("Coroutine destruction") + { + auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterator + { + co_yield L"Hello"; + co_yield L"World!"; + }; + + bool destroyed = false; + { + auto _generator = set_true_on_destruction_generator(destroyed); + } + + REQUIRE(destroyed); + } + + SECTION("Coroutine destruction with exception") + { + auto set_true_on_destruction_generator = [](test::set_true_on_destruction) -> winrt::Windows::Foundation::Collections::IIterator + { + co_yield L"Hello"; + + throw winrt::hresult_invalid_argument(); + + co_yield L"World!"; + }; + + bool destroyed = false; + auto iterator = set_true_on_destruction_generator(destroyed); + REQUIRE_THROWS_AS(iterator.MoveNext(), winrt::hresult_invalid_argument); + } + + SECTION("make_iterable_from_generator") + { + auto generator = wil::make_iterable_from_generator(&test::hello_world_generator); + auto iterator = generator.First(); + + REQUIRE(iterator.Current() == L"Hello"); + REQUIRE(iterator.MoveNext()); + REQUIRE(iterator.Current() == L"World!"); + REQUIRE(!iterator.MoveNext()); + + auto iterator2 = generator.First(); + REQUIRE(iterator2.Current() == L"Hello"); + REQUIRE(iterator2.MoveNext()); + REQUIRE(iterator2.Current() == L"World!"); + REQUIRE(!iterator2.MoveNext()); + } + + SECTION("make_iterable_from_generator with arguments") + { + auto ptr = std::make_unique(3); + auto const_ref_generator = wil::make_iterable_from_generator([](const std::unique_ptr &ptr) -> winrt::Windows::Foundation::Collections::IIterator + { + co_yield *ptr; + }, ptr); + + REQUIRE(const_ref_generator.First().Current() == 3); + *ptr = 4; + REQUIRE(const_ref_generator.First().Current() == 4); + } + + SECTION("Range-based for loop") + { + std::wstring result; + for (const auto &i : wil::make_iterable_from_generator(&test::hello_world_generator)) + { + result += i; + } + + REQUIRE(result == L"HelloWorld!"); + } +} + #endif // coroutines TEST_CASE("CppWinRTTests::ThrownExceptionWithMessage", "[cppwinrt]")