allow sending progress from Task coroutines by using co_yield

This commit is contained in:
matcool 2024-11-12 12:21:09 -03:00
parent 0ee9aebdee
commit ab196b9adf

View file

@ -9,7 +9,7 @@
namespace geode { namespace geode {
namespace geode_internal { namespace geode_internal {
template <class T> template <class T, class P>
struct TaskPromise; struct TaskPromise;
template <class T, class P> template <class T, class P>
@ -161,7 +161,7 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2> template <std::move_constructible T2, std::move_constructible P2>
friend class Task; friend class Task;
template <class> template <class, class>
friend struct geode_internal::TaskPromise; friend struct geode_internal::TaskPromise;
template <class, class> template <class, class>
@ -322,7 +322,7 @@ namespace geode {
template <std::move_constructible T2, std::move_constructible P2> template <std::move_constructible T2, std::move_constructible P2>
friend class Task; friend class Task;
template <class> template <class, class>
friend struct geode_internal::TaskPromise; friend struct geode_internal::TaskPromise;
template <class, class> template <class, class>
@ -924,12 +924,22 @@ namespace geode {
// //
// The body of the coroutine is ran in whatever thread it got called in. // The body of the coroutine is ran in whatever thread it got called in.
// TODO: maybe guarantee main thread? // TODO: maybe guarantee main thread?
//
// The coroutine can also yield progress values using `co_yield`:
// ```
// Task<std::string, int> someTask() {
// for (int i = 0; i < 10; i++) {
// co_yield i;
// }
// co_return "done!";
// }
// ```
namespace geode { namespace geode {
namespace geode_internal { namespace geode_internal {
template <class T> template <class T, class P>
struct TaskPromise { struct TaskPromise {
using MyTask = Task<T>; using MyTask = Task<T, P>;
std::weak_ptr<typename MyTask::Handle> m_handle; std::weak_ptr<typename MyTask::Handle> m_handle;
~TaskPromise() { ~TaskPromise() {
@ -948,10 +958,15 @@ namespace geode {
return handle; return handle;
} }
void return_value(T&& x) { void return_value(T x) {
MyTask::finish(m_handle.lock(), std::move(x)); MyTask::finish(m_handle.lock(), std::move(x));
} }
std::suspend_never yield_value(P value) {
MyTask::progress(m_handle.lock(), std::move(value));
return {};
}
bool isCancelled() { bool isCancelled() {
if (auto p = m_handle.lock()) { if (auto p = m_handle.lock()) {
return p->is(MyTask::Status::Cancelled); return p->is(MyTask::Status::Cancelled);
@ -968,8 +983,8 @@ namespace geode {
return task.isFinished(); return task.isFinished();
} }
template <class U> template <class U, class V>
void await_suspend(std::coroutine_handle<TaskPromise<U>> handle) { void await_suspend(std::coroutine_handle<TaskPromise<U, V>> handle) {
if (handle.promise().isCancelled()) { if (handle.promise().isCancelled()) {
handle.destroy(); handle.destroy();
return; return;
@ -981,7 +996,7 @@ namespace geode {
handle.destroy(); handle.destroy();
return; return;
} }
parentHandle->m_extraData = std::make_unique<typename Task<U>::Handle::ExtraData>( parentHandle->m_extraData = std::make_unique<typename Task<U, V>::Handle::ExtraData>(
static_cast<void*>(new EventListener<Task<T, P>>( static_cast<void*>(new EventListener<Task<T, P>>(
[handle](auto* event) { [handle](auto* event) {
if (event->getValue()) { if (event->getValue()) {
@ -1014,7 +1029,7 @@ auto operator co_await(geode::Task<T, P> task) {
return geode::geode_internal::TaskAwaiter<T, P>{task}; return geode::geode_internal::TaskAwaiter<T, P>{task};
} }
template <class T, class... Args> template <class T, class P, class... Args>
struct std::coroutine_traits<geode::Task<T>, Args...> { struct std::coroutine_traits<geode::Task<T, P>, Args...> {
using promise_type = geode::geode_internal::TaskPromise<T>; using promise_type = geode::geode_internal::TaskPromise<T, P>;
}; };