diff --git a/cmake/common_build_flags.cmake b/cmake/common_build_flags.cmake index 1e69612fc..11a2eeae4 100644 --- a/cmake/common_build_flags.cmake +++ b/cmake/common_build_flags.cmake @@ -13,6 +13,13 @@ endmacro() # Fixup default compiler settings if (MSVC) + # Ensure exception handling is enabled; some CMake versions don't add /EHsc by default for all + # MSVC-compatible compilers (e.g. clang-cl). Adding it to CMAKE_CXX_FLAGS allows targets that + # need to disable exceptions (e.g. noexcept) to remove it via replace_cxx_flag. + if (NOT CMAKE_CXX_FLAGS MATCHES "/EHsc") + string(APPEND CMAKE_CXX_FLAGS " /EHsc") + endif() + add_compile_options( # Be as strict as reasonably possible, since we want to support consumers using strict warning levels /W4 /WX diff --git a/include/wil/Tracelogging.h b/include/wil/Tracelogging.h index 9124c4dc6..28e2e7e63 100644 --- a/include/wil/Tracelogging.h +++ b/include/wil/Tracelogging.h @@ -872,6 +872,23 @@ class ActivityBase : public details::IFailureCallback } } + bool IsWatching() WI_NOEXCEPT + { + return m_callbackHolder.IsWatching(); + } + + // Coroutine watcher interface: pause watching during suspension + bool suspend() WI_NOEXCEPT + { + return m_callbackHolder.suspend(); + } + + // Coroutine watcher interface: resume watching after suspension + void resume() WI_NOEXCEPT + { + m_callbackHolder.resume(); + } + // Call this API to retrieve an RAII object to watch events on the current thread. The returned // object should only be used on the stack. diff --git a/include/wil/coroutine.h b/include/wil/coroutine.h index 7381fd636..9fa8c37b1 100644 --- a/include/wil/coroutine.h +++ b/include/wil/coroutine.h @@ -640,6 +640,79 @@ struct task_base static void __stdcall wake_by_address(void* completed); }; + +// Generic awaitable wrapper that suspends/resumes a watcher across co_await. +// TPausable must provide: +// bool suspend() - called before suspension, returns true if resume() should be called +// void resume() - called after resumption if suspend() returned true +template +struct coroutine_withsuspend_awaiter +{ + TPausable& pausable; + TChildAwaitable child_awaitable; + bool resume_needed = false; + + bool await_ready() noexcept + { + return child_awaitable.await_ready(); + } + + template + auto await_suspend(T&& handle) noexcept( + noexcept(std::declval().await_suspend(std::forward(handle))) && noexcept(pausable.suspend())) + { + resume_needed = pausable.suspend(); + return child_awaitable.await_suspend(std::forward(handle)); + } + + auto await_resume() noexcept(noexcept(std::declval().await_resume())) + { + if (resume_needed) + { + pausable.resume(); + } + return child_awaitable.await_resume(); + } +}; + +// Priority tags for SFINAE-based overload resolution +struct get_awaiter_priority_fallback +{ +}; +struct get_awaiter_priority_free_op : get_awaiter_priority_fallback +{ +}; +struct get_awaiter_priority_member_op : get_awaiter_priority_free_op +{ +}; + +// Highest priority: member operator co_await +template +auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_member_op) -> decltype(std::forward(awaitable).operator co_await()) +{ + return std::forward(awaitable).operator co_await(); +} + +// Second priority: free operator co_await +template +auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_free_op) -> decltype(operator co_await(std::forward(awaitable))) +{ + return operator co_await(std::forward(awaitable)); +} + +// Fallback: return the awaitable itself +template +T&& get_awaiter_impl(T&& awaitable, get_awaiter_priority_fallback) +{ + return std::forward(awaitable); +} + +template +auto get_awaiter(T&& awaitable) -> decltype(get_awaiter_impl(std::forward(awaitable), get_awaiter_priority_member_op{})) +{ + return get_awaiter_impl(std::forward(awaitable), get_awaiter_priority_member_op{}); +} + } // namespace wil::details::coro /// @endcond @@ -700,6 +773,42 @@ template task(com_task&&) -> task; template com_task(task&&) -> com_task; + +/** + * @brief Wraps an awaitable with a suspend/resume watcher. + * + * Suspends and resumes the watcher when the coroutine is suspended and resumed. + * Thread-bound objects such as `ThreadFailureCache` or a WIL Activity are good candidates. + * + * @code + * ThreadFailureCache cache; + * auto result = co_await wil::with_watcher(cache, SomethingAsync()); + * @endcode + * + * The watcher type must provide `bool suspend()` and `void resume()` methods: + * @code + * struct MyThreadWatcher { + * bool suspend() { + * // pause temporarily; return true if resume() should be called on coroutine resume + * } + * void resume() { + * // resume after the coroutine resumes + * } + * }; + * @endcode + * + * @tparam TWatcher Type of watcher invoked on suspend/resume. + * @tparam TAwaitable Type of wrapped awaitable. + * @param watcher The watcher to suspend and resume with the coroutine. + * @param awaitable The awaitable to wrap. + */ +template +auto with_watcher(TWatcher& watcher, TAwaitable&& awaitable) +{ + using awaiter_t = std::decay_t(awaitable)))>; + return details::coro::coroutine_withsuspend_awaiter{watcher, std::forward(awaitable)}; +} + } // namespace wil template diff --git a/include/wil/result.h b/include/wil/result.h index b2846c08f..c9d2607b2 100644 --- a/include/wil/result.h +++ b/include/wil/result.h @@ -992,6 +992,24 @@ namespace details } } + bool suspend() WI_NOEXCEPT + { + const bool wasWatching = IsWatching(); + if (wasWatching) + { + StopWatching(); + } + return wasWatching; + } + + void resume() WI_NOEXCEPT + { + if (!IsWatching()) + { + StartWatching(); + } + } + static bool GetThreadContext( _Inout_ FailureInfo* pFailure, _In_opt_ ThreadFailureCallbackHolder* pCallback, @@ -1093,14 +1111,14 @@ namespace details { public: explicit ThreadFailureCallbackFn(_In_opt_ CallContextInfo* pContext, _Inout_ TLambda&& errorFunction) WI_NOEXCEPT - : m_errorFunction(wistd::move(errorFunction)), - m_callbackHolder(this, pContext) + : m_callbackHolder(this, pContext), + m_errorFunction(wistd::move(errorFunction)) { } ThreadFailureCallbackFn(_Inout_ ThreadFailureCallbackFn&& other) WI_NOEXCEPT - : m_errorFunction(wistd::move(other.m_errorFunction)), - m_callbackHolder(this, other.m_callbackHolder.CallContextInfo()) + : m_callbackHolder(this, other.m_callbackHolder.CallContextInfo()), + m_errorFunction(wistd::move(other.m_errorFunction)) { } @@ -1109,12 +1127,22 @@ namespace details return m_errorFunction(failure); } + bool suspend() WI_NOEXCEPT + { + return m_callbackHolder.suspend(); + } + + void resume() WI_NOEXCEPT + { + m_callbackHolder.resume(); + } + private: ThreadFailureCallbackFn(_In_ ThreadFailureCallbackFn const&); ThreadFailureCallbackFn& operator=(_In_ ThreadFailureCallbackFn const&); - TLambda m_errorFunction; ThreadFailureCallbackHolder m_callbackHolder; + TLambda m_errorFunction; }; // returns true if telemetry was reported for this error diff --git a/tests/CoroutineTests.cpp b/tests/CoroutineTests.cpp index dc2b5a013..4879b05ea 100644 --- a/tests/CoroutineTests.cpp +++ b/tests/CoroutineTests.cpp @@ -31,6 +31,25 @@ wil::task void_task(std::shared_ptr value) ++*value; co_return; } + +struct resume_new_cpp_thread +{ + bool await_ready() noexcept + { + return false; + } + template + void await_suspend(Handle handle) noexcept + { + std::thread([handle] { + handle(); + }).detach(); + } + void await_resume() + { + } +}; + } // namespace TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]") @@ -42,4 +61,65 @@ TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]") }).join(); } +TEST_CASE("CoroutineTests::WithWatcherBasic", "[coroutine]") +{ + // Test that wil::with_watcher wraps an awaitable and calls suspend()/resume() + // on the watcher object across co_await. + struct mock_watcher + { + int suspend_count = 0; + int resume_count = 0; + bool suspend() noexcept + { + ++suspend_count; + return true; + } + void resume() noexcept + { + ++resume_count; + } + }; + + auto test = [](mock_watcher& watcher) -> wil::task { + co_await wil::with_watcher(watcher, resume_new_cpp_thread{}); + }; + + std::thread([&] { + mock_watcher watcher; + std::move(test(watcher)).get(); + REQUIRE(watcher.suspend_count == 1); + REQUIRE(watcher.resume_count == 1); + }).join(); +} + +TEST_CASE("CoroutineTests::WithWatcherSuspendReturnsFalse", "[coroutine]") +{ + // When suspend() returns false, resume() should not be called. + struct mock_watcher + { + int suspend_count = 0; + int resume_count = 0; + bool suspend() noexcept + { + ++suspend_count; + return false; + } + void resume() noexcept + { + ++resume_count; + } + }; + + auto test = [](mock_watcher& watcher) -> wil::task { + co_await wil::with_watcher(watcher, resume_new_cpp_thread{}); + }; + + std::thread([&] { + mock_watcher watcher; + std::move(test(watcher)).get(); + REQUIRE(watcher.suspend_count == 1); + REQUIRE(watcher.resume_count == 0); + }).join(); +} + #endif // coroutines diff --git a/tests/CppWinRTTests.cpp b/tests/CppWinRTTests.cpp index 0db4e4916..fdecdbc4b 100644 --- a/tests/CppWinRTTests.cpp +++ b/tests/CppWinRTTests.cpp @@ -745,6 +745,75 @@ TEST_CASE("CppWinRTTests::ResumeForegroundTests", "[cppwinrt]") }() .get(); } + +namespace +{ +struct resume_new_cpp_thread_for_watcher +{ + bool await_ready() noexcept + { + return false; + } + template + void await_suspend(Handle handle) noexcept + { + std::thread([handle] { + handle(); + }).detach(); + } + void await_resume() + { + } +}; +} // namespace + +TEST_CASE("CppWinRTTests::WithWatcherThreadFailureCallback", "[cppwinrt][coroutine]") +{ + // Test that wil::with_watcher correctly pauses/resumes a ThreadFailureCallback across co_await. + auto test = []() -> wil::task { + auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { + return false; + }); + co_await wil::with_watcher(watcher, resume_new_cpp_thread_for_watcher{}); + }; + + std::move(test()).get(); +} + +TEST_CASE("CppWinRTTests::WithWatcherWinRTAction", "[cppwinrt][coroutine]") +{ + // Test that wil::with_watcher works with a WinRT IAsyncAction. + auto test = []() -> winrt::Windows::Foundation::IAsyncAction { + auto tid = ::GetCurrentThreadId(); + auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { + return false; + }); + co_await wil::with_watcher(watcher, winrt::resume_background()); + REQUIRE(tid != ::GetCurrentThreadId()); + }; + + test().get(); +} + +TEST_CASE("CppWinRTTests::WithWatcherWinRTOperation", "[cppwinrt][coroutine]") +{ + // Test that wil::with_watcher works with a WinRT IAsyncOperation. + auto inner = []() -> winrt::Windows::Foundation::IAsyncOperation { + co_await winrt::resume_background(); + co_return winrt::hstring(L"kittens"); + }; + + auto test = [&inner]() -> winrt::Windows::Foundation::IAsyncAction { + auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) { + return false; + }); + auto result = co_await wil::with_watcher(watcher, inner()); + REQUIRE(result == L"kittens"); + }; + + test().get(); +} + #endif // coroutines TEST_CASE("CppWinRTTests::ThrownExceptionWithMessage", "[cppwinrt]") diff --git a/tests/TraceLoggingTests.cpp b/tests/TraceLoggingTests.cpp index 9d51116c7..866c488ce 100644 --- a/tests/TraceLoggingTests.cpp +++ b/tests/TraceLoggingTests.cpp @@ -3,3 +3,36 @@ // Just verify that Tracelogging.h compiles. #define PROVIDER_CLASS_NAME TestProvider #include "TraceLoggingTests.h" + +#include "catch.hpp" +#include "common.h" + +TEST_CASE("TraceLoggingTests::ActivitySuspendResume", "[tracelogging]") +{ + // Test that Activity classes implement the suspend/resume interface for coroutine watchers. + // This interface is used by wil::with_watcher() to pause error watching during co_await. + auto activity = TestProvider::TraceloggingActivity::Start(); + + // Initially watching after Start() + REQUIRE(activity.IsRunning()); + + // suspend() should return true (was watching) and stop watching + bool wasWatching = activity.suspend(); + REQUIRE(wasWatching); + REQUIRE_FALSE(activity.IsWatching()); + + // Calling suspend() again should return false (wasn't watching) + wasWatching = activity.suspend(); + REQUIRE_FALSE(wasWatching); + REQUIRE_FALSE(activity.IsWatching()); + + // resume() should restart watching + activity.resume(); + REQUIRE(activity.IsWatching()); + + // Calling resume() when already watching is safe + activity.resume(); + REQUIRE(activity.IsWatching()); + + activity.Stop(); +}