Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmake/common_build_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ endmacro()

# Fixup default compiler settings
if (MSVC)
# Ensure exception handling is enabled; some CMake versions don't add /EHsc by default for all
# MSVC-compatible compilers (e.g. clang-cl). Adding it to CMAKE_CXX_FLAGS allows targets that
# need to disable exceptions (e.g. noexcept) to remove it via replace_cxx_flag.
if (NOT CMAKE_CXX_FLAGS MATCHES "/EHsc")
string(APPEND CMAKE_CXX_FLAGS " /EHsc")
endif()

add_compile_options(
# Be as strict as reasonably possible, since we want to support consumers using strict warning levels
/W4 /WX
Expand Down
17 changes: 17 additions & 0 deletions include/wil/Tracelogging.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,23 @@ class ActivityBase : public details::IFailureCallback
}
}

bool IsWatching() WI_NOEXCEPT
{
return m_callbackHolder.IsWatching();
}

// Coroutine watcher interface: pause watching during suspension
bool suspend() WI_NOEXCEPT
{
return m_callbackHolder.suspend();
}

// Coroutine watcher interface: resume watching after suspension
void resume() WI_NOEXCEPT
{
m_callbackHolder.resume();
}

// Call this API to retrieve an RAII object to watch events on the current thread. The returned
// object should only be used on the stack.

Expand Down
109 changes: 109 additions & 0 deletions include/wil/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,79 @@ struct task_base

static void __stdcall wake_by_address(void* completed);
};

// Generic awaitable wrapper that suspends/resumes a watcher across co_await.
// TPausable must provide:
// bool suspend() - called before suspension, returns true if resume() should be called
// void resume() - called after resumption if suspend() returned true
template <typename TPausable, typename TChildAwaitable>
struct coroutine_withsuspend_awaiter
{
TPausable& pausable;
TChildAwaitable child_awaitable;
bool resume_needed = false;

bool await_ready() noexcept
{
return child_awaitable.await_ready();
}

template <typename T>
auto await_suspend(T&& handle) noexcept(
noexcept(std::declval<TChildAwaitable>().await_suspend(std::forward<T>(handle))) && noexcept(pausable.suspend()))
{
resume_needed = pausable.suspend();
return child_awaitable.await_suspend(std::forward<T>(handle));
}

auto await_resume() noexcept(noexcept(std::declval<TChildAwaitable>().await_resume()))
{
if (resume_needed)
{
pausable.resume();
}
return child_awaitable.await_resume();
}
};

// Priority tags for SFINAE-based overload resolution
struct get_awaiter_priority_fallback
{
};
struct get_awaiter_priority_free_op : get_awaiter_priority_fallback
{
};
struct get_awaiter_priority_member_op : get_awaiter_priority_free_op
{
};
Comment on lines +678 to +687
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a wil::details::priority_tag<N> I added (moved) somewhat recently. May be better to just re-use it? Higher N gets higher priority.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, oops, it's not in the details namespace

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot please fix and use the standard priority_tag<> type instead


// Highest priority: member operator co_await
template <typename T>
auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_member_op) -> decltype(std::forward<T>(awaitable).operator co_await())
{
return std::forward<T>(awaitable).operator co_await();
}

// Second priority: free operator co_await
template <typename T>
auto get_awaiter_impl(T&& awaitable, get_awaiter_priority_free_op) -> decltype(operator co_await(std::forward<T>(awaitable)))
{
return operator co_await(std::forward<T>(awaitable));
}

// Fallback: return the awaitable itself
template <typename T>
T&& get_awaiter_impl(T&& awaitable, get_awaiter_priority_fallback)
{
return std::forward<T>(awaitable);
}

template <typename T>
auto get_awaiter(T&& awaitable) -> decltype(get_awaiter_impl(std::forward<T>(awaitable), get_awaiter_priority_member_op{}))
{
return get_awaiter_impl(std::forward<T>(awaitable), get_awaiter_priority_member_op{});
}

} // namespace wil::details::coro
/// @endcond

Expand Down Expand Up @@ -700,6 +773,42 @@ template <typename T>
task(com_task<T>&&) -> task<T>;
template <typename T>
com_task(task<T>&&) -> com_task<T>;

/**
* @brief Wraps an awaitable with a suspend/resume watcher.
*
* Suspends and resumes the watcher when the coroutine is suspended and resumed.
* Thread-bound objects such as `ThreadFailureCache` or a WIL Activity are good candidates.
*
* @code
* ThreadFailureCache cache;
* auto result = co_await wil::with_watcher(cache, SomethingAsync());
* @endcode
*
* The watcher type must provide `bool suspend()` and `void resume()` methods:
* @code
* struct MyThreadWatcher {
* bool suspend() {
* // pause temporarily; return true if resume() should be called on coroutine resume
* }
* void resume() {
* // resume after the coroutine resumes
* }
* };
* @endcode
*
* @tparam TWatcher Type of watcher invoked on suspend/resume.
* @tparam TAwaitable Type of wrapped awaitable.
* @param watcher The watcher to suspend and resume with the coroutine.
* @param awaitable The awaitable to wrap.
*/
template <typename TWatcher, typename TAwaitable>
auto with_watcher(TWatcher& watcher, TAwaitable&& awaitable)
{
using awaiter_t = std::decay_t<decltype(details::coro::get_awaiter(std::forward<TAwaitable>(awaitable)))>;
return details::coro::coroutine_withsuspend_awaiter<TWatcher, awaiter_t>{watcher, std::forward<TAwaitable>(awaitable)};
}

} // namespace wil

template <typename T, typename... Args>
Expand Down
38 changes: 33 additions & 5 deletions include/wil/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,24 @@ namespace details
}
}

bool suspend() WI_NOEXCEPT
{
const bool wasWatching = IsWatching();
if (wasWatching)
{
StopWatching();
}
return wasWatching;
}

void resume() WI_NOEXCEPT
{
if (!IsWatching())
{
StartWatching();
}
}

static bool GetThreadContext(
_Inout_ FailureInfo* pFailure,
_In_opt_ ThreadFailureCallbackHolder* pCallback,
Expand Down Expand Up @@ -1093,14 +1111,14 @@ namespace details
{
public:
explicit ThreadFailureCallbackFn(_In_opt_ CallContextInfo* pContext, _Inout_ TLambda&& errorFunction) WI_NOEXCEPT
: m_errorFunction(wistd::move(errorFunction)),
m_callbackHolder(this, pContext)
: m_callbackHolder(this, pContext),
m_errorFunction(wistd::move(errorFunction))
{
}

ThreadFailureCallbackFn(_Inout_ ThreadFailureCallbackFn&& other) WI_NOEXCEPT
: m_errorFunction(wistd::move(other.m_errorFunction)),
m_callbackHolder(this, other.m_callbackHolder.CallContextInfo())
: m_callbackHolder(this, other.m_callbackHolder.CallContextInfo()),
m_errorFunction(wistd::move(other.m_errorFunction))
{
}

Expand All @@ -1109,12 +1127,22 @@ namespace details
return m_errorFunction(failure);
}

bool suspend() WI_NOEXCEPT
{
return m_callbackHolder.suspend();
}

void resume() WI_NOEXCEPT
{
m_callbackHolder.resume();
}

private:
ThreadFailureCallbackFn(_In_ ThreadFailureCallbackFn const&);
ThreadFailureCallbackFn& operator=(_In_ ThreadFailureCallbackFn const&);

TLambda m_errorFunction;
ThreadFailureCallbackHolder m_callbackHolder;
TLambda m_errorFunction;
};

// returns true if telemetry was reported for this error
Expand Down
80 changes: 80 additions & 0 deletions tests/CoroutineTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ wil::task<void> void_task(std::shared_ptr<int> value)
++*value;
co_return;
}

struct resume_new_cpp_thread
{
bool await_ready() noexcept
{
return false;
}
template <typename Handle>
void await_suspend(Handle handle) noexcept
{
std::thread([handle] {
handle();
}).detach();
}
void await_resume()
{
}
};

} // namespace

TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]")
Expand All @@ -42,4 +61,65 @@ TEST_CASE("CppWinRTTests::SimpleNoCOMTaskTest", "[cppwinrt]")
}).join();
}

TEST_CASE("CoroutineTests::WithWatcherBasic", "[coroutine]")
{
// Test that wil::with_watcher wraps an awaitable and calls suspend()/resume()
// on the watcher object across co_await.
struct mock_watcher
{
int suspend_count = 0;
int resume_count = 0;
bool suspend() noexcept
{
++suspend_count;
return true;
}
void resume() noexcept
{
++resume_count;
}
};

auto test = [](mock_watcher& watcher) -> wil::task<void> {
co_await wil::with_watcher(watcher, resume_new_cpp_thread{});
};

std::thread([&] {
mock_watcher watcher;
std::move(test(watcher)).get();
REQUIRE(watcher.suspend_count == 1);
REQUIRE(watcher.resume_count == 1);
}).join();
}

TEST_CASE("CoroutineTests::WithWatcherSuspendReturnsFalse", "[coroutine]")
{
// When suspend() returns false, resume() should not be called.
struct mock_watcher
{
int suspend_count = 0;
int resume_count = 0;
bool suspend() noexcept
{
++suspend_count;
return false;
}
void resume() noexcept
{
++resume_count;
}
};

auto test = [](mock_watcher& watcher) -> wil::task<void> {
co_await wil::with_watcher(watcher, resume_new_cpp_thread{});
};

std::thread([&] {
mock_watcher watcher;
std::move(test(watcher)).get();
REQUIRE(watcher.suspend_count == 1);
REQUIRE(watcher.resume_count == 0);
}).join();
}

#endif // coroutines
69 changes: 69 additions & 0 deletions tests/CppWinRTTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,75 @@ TEST_CASE("CppWinRTTests::ResumeForegroundTests", "[cppwinrt]")
}()
.get();
}

namespace
{
struct resume_new_cpp_thread_for_watcher
{
bool await_ready() noexcept
{
return false;
}
template <typename Handle>
void await_suspend(Handle handle) noexcept
{
std::thread([handle] {
handle();
}).detach();
}
void await_resume()
{
}
};
} // namespace

TEST_CASE("CppWinRTTests::WithWatcherThreadFailureCallback", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher correctly pauses/resumes a ThreadFailureCallback across co_await.
auto test = []() -> wil::task<void> {
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) {
return false;
});
co_await wil::with_watcher(watcher, resume_new_cpp_thread_for_watcher{});
};

std::move(test()).get();
}

TEST_CASE("CppWinRTTests::WithWatcherWinRTAction", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher works with a WinRT IAsyncAction.
auto test = []() -> winrt::Windows::Foundation::IAsyncAction {
auto tid = ::GetCurrentThreadId();
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) {
return false;
});
co_await wil::with_watcher(watcher, winrt::resume_background());
REQUIRE(tid != ::GetCurrentThreadId());
};

test().get();
}

TEST_CASE("CppWinRTTests::WithWatcherWinRTOperation", "[cppwinrt][coroutine]")
{
// Test that wil::with_watcher works with a WinRT IAsyncOperation.
auto inner = []() -> winrt::Windows::Foundation::IAsyncOperation<winrt::hstring> {
co_await winrt::resume_background();
co_return winrt::hstring(L"kittens");
};

auto test = [&inner]() -> winrt::Windows::Foundation::IAsyncAction {
auto watcher = wil::ThreadFailureCallback([](wil::FailureInfo const&) {
return false;
});
auto result = co_await wil::with_watcher(watcher, inner());
REQUIRE(result == L"kittens");
};

test().get();
}

#endif // coroutines

TEST_CASE("CppWinRTTests::ThrownExceptionWithMessage", "[cppwinrt]")
Expand Down
Loading