co_await support for geode Task

see comment at the bottom of the header for more information
This commit is contained in:
matcool 2024-11-12 02:31:36 -03:00
parent acad3d2a8d
commit e61b2c0595

View file

@ -5,8 +5,17 @@
#include "../loader/Loader.hpp"
#include <mutex>
#include <string_view>
#include <coroutine>
namespace geode {
namespace geode_internal {
template <class T>
struct TaskPromise;
template <class T, class P>
struct TaskAwaiter;
}
/**
* Tasks represent an asynchronous operation that will be finished at some
* unknown point in the future. Tasks can report their progress, and will
@ -152,6 +161,12 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2>
friend class Task;
template <class>
friend struct geode_internal::TaskPromise;
template <class, class>
friend struct geode_internal::TaskAwaiter;
public:
Handle(PrivateMarker, std::string_view name) : m_name(name) {}
~Handle() {
@ -307,6 +322,12 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2>
friend class Task;
template <class>
friend struct geode_internal::TaskPromise;
template <class, class>
friend struct geode_internal::TaskAwaiter;
public:
// Allow default-construction
Task() : m_handle(nullptr) {}
@ -883,3 +904,117 @@ namespace geode {
static_assert(is_filter<Task<int>>, "The Task class must be a valid event filter!");
}
// - C++20 coroutine support for Task - //
// Example usage (function must return a Task):
// ```
// Task<int> someTask() {
// auto response = co_await web::WebRequest().get("https://example.com");
// co_return response.code();
// }
// ```
// This will create a Task that will finish with the response code of the
// web request.
//
// Note: If the Task the coroutine is waiting on is cancelled, the coroutine
// will be destroyed and the Task will be cancelled as well. If the Task returned
// by the coroutine is cancelled, the coroutine will be destroyed as well and execution
// stops as soon as possible.
//
// The body of the coroutine is ran in whatever thread it got called in.
// TODO: maybe guarantee main thread?
namespace geode {
namespace geode_internal {
template <class T>
struct TaskPromise {
using MyTask = Task<T>;
std::weak_ptr<typename MyTask::Handle> m_handle;
~TaskPromise() {
// does nothing if its not pending
MyTask::cancel(m_handle.lock());
}
std::suspend_never initial_suspend() noexcept { return {}; }
std::suspend_never final_suspend() noexcept { return {}; }
// TODO: do something here?
void unhandled_exception() {}
MyTask get_return_object() {
auto handle = MyTask::Handle::create("<Coroutine Task>");
m_handle = handle;
return handle;
}
void return_value(T&& x) {
MyTask::finish(m_handle.lock(), std::move(x));
}
bool isCancelled() {
if (auto p = m_handle.lock()) {
return p->is(MyTask::Status::Cancelled);
}
return true;
}
};
template <class T, class P>
struct TaskAwaiter {
Task<T, P> task;
bool await_ready() {
return task.isFinished();
}
template <class U>
void await_suspend(std::coroutine_handle<TaskPromise<U>> handle) {
if (handle.promise().isCancelled()) {
handle.destroy();
return;
}
// this should be fine because the parent task can only have
// one pending task at a time
std::shared_ptr<Task<U>::Handle> parentHandle = handle.promise().m_handle.lock();
if (!parentHandle) {
handle.destroy();
return;
}
parentHandle->m_extraData = std::make_unique<typename Task<U>::Handle::ExtraData>(
static_cast<void*>(new EventListener<Task<T, P>>(
[handle](auto* event) {
if (event->getValue()) {
handle.resume();
}
if (event->isCancelled()) {
handle.destroy();
}
},
task
)),
+[](void* ptr) {
delete static_cast<EventListener<Task<T, P>>*>(ptr);
},
+[](void* ptr) {
static_cast<EventListener<Task<T, P>>*>(ptr)->getFilter().cancel();
}
);
}
T await_resume() {
return std::move(*task.getFinishedValue());
}
};
}
}
template <class T, class P>
auto operator co_await(geode::Task<T, P> task) {
return geode::geode_internal::TaskAwaiter<T, P>{task};
}
template <class T, class... Args>
struct std::coroutine_traits<geode::Task<T>, Args...> {
using promise_type = geode::geode_internal::TaskPromise<T>;
};