new Task class to replace Promises, test seems to indicate it's working

This commit is contained in:
HJfod 2024-04-22 00:08:10 +03:00
parent 4be910bda4
commit a09ba5c67c
9 changed files with 1142 additions and 77 deletions

View file

@ -77,6 +77,7 @@ file(GLOB SOURCES CONFIGURE_DEPENDS
src/ui/mods/list/*.cpp
src/ui/mods/popups/*.cpp
src/ui/mods/sources/*.cpp
src/ui/mods/test/*.cpp
src/ui/*.cpp
src/c++stl/*.cpp
hash/hash.cpp

View file

@ -219,4 +219,130 @@ namespace geode {
virtual ~Event();
};
// template <is_filter F, std::move_constructible T>
// class [[nodiscard]] EventMapper final {
// public:
// using Value = T;
// class Handle final {
// std::optional<EventListener<F>> m_listener;
// class PrivateMarker final {};
// static std::shared_ptr<Handle> create() {
// return std::make_shared<Handle>(PrivateMarker());
// }
// friend class EventMapper;
// public:
// Handle(PrivateMarker) {}
// };
// class Event final : public geode::Event {
// private:
// std::shared_ptr<Handle> m_handle;
// T m_value;
// Event(std::shared_ptr<Handle> handle, T&& value)
// : m_handle(handle), m_value(std::move(value)) {}
// friend class EventMapper;
// public:
// T& getValue() & {
// return m_value;
// }
// T const& getValue() const& {
// return m_value;
// }
// T&& getValue() && {
// return std::move(m_value);
// }
// operator T*() const {
// return m_value;
// }
// T* operator*() const {
// return m_value;
// }
// T* operator->() const {
// return m_value;
// }
// };
// using Mapper = utils::MiniFunction<T(typename F::Event*)>;
// using Callback = void(Event*);
// private:
// EventListenerProtocol* m_listener = nullptr;
// std::shared_ptr<Handle> m_handle;
// EventMapper(std::shared_ptr<Handle> handle) : m_handle(handle) {}
// public:
// EventMapper() : m_handle(nullptr) {}
// static EventMapper immediate(T&& value) {
// auto emapper = EventMapper(Handle::create());
// Loader::get()->queueInMainThread([handle = emapper.m_handle, value = std::move(value)]() mutable {
// EventMapper::Event(handle, std::move(value)).post();
// });
// return emapper;
// }
// static EventMapper create(F&& filter, Mapper&& mapper) {
// auto emapper = EventMapper(Handle::create());
// emapper.m_handle->m_listener.emplace(EventListener(
// // The event listener should not own itself (circular ref = memory leak!!)
// [handle = std::weak_ptr(emapper.m_handle), mapper = std::move(mapper)](F::Event* event) {
// if (auto lock = handle.lock()) {
// EventMapper::Event(lock, mapper(event)).post();
// }
// },
// std::move(filter)
// ));
// return emapper;
// }
// template <class NewMapper>
// auto map(NewMapper&& mapper) {
// using T2 = decltype(mapper(std::declval<T*>()));
// return mapEvent(*this, [mapper = std::move(mapper)](Event* event) -> T2 {
// return mapper(&event->getValue());
// });
// }
// ListenerResult handle(utils::MiniFunction<Callback> fn, Event* e) {
// if (e->m_handle == m_handle) {
// fn(e);
// }
// return ListenerResult::Propagate;
// }
// // todo: i believe alk wanted these to be in their own pool
// EventListenerPool* getPool() const {
// return DefaultEventListenerPool::get();
// }
// void setListener(EventListenerProtocol* listener) {
// m_listener = listener;
// }
// EventListenerProtocol* getListener() const {
// return m_listener;
// }
// };
// template <is_filter F, class Mapper>
// static auto mapEvent(F&& filter, Mapper&& mapper) {
// using T = decltype(mapper(std::declval<typename F::Event*>()));
// return EventMapper<F, T>::create(std::move(filter), std::move(mapper));
// }
// template <is_filter F, class Mapper>
// requires std::copy_constructible<F>
// static auto mapEvent(F const& filter, Mapper&& mapper) {
// using T = decltype(mapper(std::declval<typename F::Event*>()));
// return EventMapper<F, T>::create(F(filter), std::move(mapper));
// }
}

View file

@ -6,23 +6,52 @@
#include "ranges.hpp"
namespace geode {
namespace impl {
struct DefaultProgress {
std::string message;
std::optional<uint8_t> percentage;
struct DefaultProgress {
std::string message;
std::optional<uint8_t> percentage;
DefaultProgress() = default;
DefaultProgress(std::string const& msg) : message(msg) {}
DefaultProgress(auto msg, uint8_t percentage) : message(msg), percentage(percentage) {}
DefaultProgress() = default;
DefaultProgress(std::string const& msg) : message(msg) {}
DefaultProgress(auto msg, uint8_t percentage) : message(msg), percentage(percentage) {}
};
namespace impl {
template <size_t Ty>
struct LogBug {
static inline size_t COUNT = 0;
static const char* ty() {
return Ty ? "Promise" : "Data";
}
LogBug() {
// log::info("created {} that holds {}, {}", ty(), fmt::ptr(this), ++COUNT);
}
LogBug& operator=(LogBug&&) {
// log::info("moved {} that holds {}, {}", ty(), fmt::ptr(this), ++COUNT);
return *this;
}
LogBug& operator=(LogBug const&) {
// log::info("copied {} that holds {}, {}", ty(), fmt::ptr(this), ++COUNT);
return *this;
}
LogBug(LogBug&&) {
// log::info("moved {} that holds {}, {}", ty(), fmt::ptr(this), ++COUNT);
}
LogBug(LogBug const&) {
// log::info("copied {} that holds {}, {}", ty(), fmt::ptr(this), ++COUNT);
}
~LogBug() {
// log::info("destroyed {} that holds {}, {}", ty(), fmt::ptr(this), --COUNT);
}
};
}
struct CancelledState final {};
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = impl::DefaultProgress>
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = DefaultProgress>
class PromiseEventFilter;
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = impl::DefaultProgress>
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = DefaultProgress>
class Promise final {
public:
using Value = T;
@ -62,21 +91,25 @@ namespace geode {
if constexpr (std::is_same_v<T, T2>) {
return Promise<T2, E2, P2>::State::make_value(std::move(std::move(*this).take_value()));
}
log::error("THIS CODE PATH SHOULD BE UNREACHABLE!!!!");
}
if (this->has_error()) {
if constexpr (std::is_same_v<E, E2>) {
return Promise<T2, E2, P2>::State::make_error(std::move(std::move(*this).take_error()));
}
log::error("THIS CODE PATH SHOULD BE UNREACHABLE!!!!");
}
if (this->has_progress()) {
if constexpr (std::is_same_v<P, P2>) {
return Promise<T2, E2, P2>::State::make_progress(std::move(std::move(*this).take_progress()));
}
log::error("THIS CODE PATH SHOULD BE UNREACHABLE!!!!");
}
return Promise<T2, E2, P2>::State::make_cancelled();
if (this->is_cancelled()) {
return Promise<T2, E2, P2>::State::make_cancelled();
}
geode::utils::unreachable(
"Promise::State::convert called on a State that isn't in a convertible state! "
"All code should verify before calling convert() that the State holds a value "
"which is trivially convertible (holds the same type)"
);
}
bool has_value() { return m_value.index() == 0; }
@ -96,14 +129,14 @@ namespace geode {
using OnStateChange = utils::MiniFunction<void(State)>;
Promise() : m_data(std::make_shared<Data>()) {}
Promise() : m_data(nullptr) {}
Promise(utils::MiniFunction<void(OnResolved, OnRejected)> source, bool threaded = true)
Promise(utils::MiniFunction<void(OnResolved, OnRejected)> source)
: Promise([source](auto resolve, auto reject, auto, auto const&) {
source(resolve, reject);
}, threaded) {}
}) {}
Promise(utils::MiniFunction<void(OnResolved, OnRejected, OnProgress, std::atomic_bool const&)> source, bool threaded = true)
Promise(utils::MiniFunction<void(OnResolved, OnRejected, OnProgress, std::atomic_bool const&)> source)
: Promise([source](auto onStateChanged, auto const& cancelled) {
source(
[onStateChanged](auto&& value) {
@ -117,18 +150,27 @@ namespace geode {
},
cancelled
);
}, threaded, std::monostate()) {}
}, std::monostate()) {}
Promise(utils::MiniFunction<void(OnStateChange, std::atomic_bool const&)> source, bool threaded, std::monostate tag) : m_data(std::make_shared<Data>()) {
m_data->shouldStartThreaded = threaded;
if (threaded) {
std::thread([source = std::move(source), data = m_data]() mutable {
Promise::invoke_source(std::move(source), data);
}).detach();
}
else {
Promise::invoke_source(std::move(source), m_data);
}
Promise(utils::MiniFunction<void(OnStateChange, std::atomic_bool const&)> source, std::monostate tag)
: m_data(std::make_shared<Data>())
{
std::thread([source = std::move(source), data = m_data]() mutable {
log::info("start invoke_source");
source(
[data = std::weak_ptr(data)](auto&& state) {
if (auto d = data.lock()) {
log::info("callback from invoke_source");
invoke_callback(std::move(state), d.get());
}
else {
log::info("tried to callback from invoke_source but deleted");
}
},
data->cancelled
);
log::info("end invoke_source");
}).detach();
}
Promise then(utils::MiniFunction<void(Value)>&& callback) {
@ -260,14 +302,17 @@ namespace geode {
}
void resolve(Value&& value) {
invoke_callback(State::make_value(std::move(value)), m_data);
if (!m_data) return;
invoke_callback(State::make_value(std::move(value)), m_data.get());
}
void reject(Error&& error) {
invoke_callback(State::make_error(std::move(error)), m_data);
if (!m_data) return;
invoke_callback(State::make_error(std::move(error)), m_data.get());
}
void cancel() {
if (!m_data) return;
m_data->cancelled = true;
invoke_callback(State::make_cancelled(), m_data);
invoke_callback(State::make_cancelled(), m_data.get());
}
/**
@ -356,16 +401,15 @@ namespace geode {
std::vector<OnStateChange> callbacks;
std::optional<std::variant<Value, Error>> result;
std::atomic_bool cancelled;
std::atomic_bool shouldStartThreaded;
impl::LogBug<0> log;
};
std::shared_ptr<Data> m_data;
impl::LogBug<1> log;
template <class T2, class E2, class P2>
static Promise<T2, E2, P2> make_fwd(
auto&& transformState,
std::shared_ptr<Data> data
) {
static Promise<T2, E2, P2> make_fwd(auto&& transformState, std::shared_ptr<Data> data) {
return Promise<T2, E2, P2>([data, transformState](auto fwdStateToNextPromise, auto const&) {
if (!data) return;
Promise::set_callback(
[fwdStateToNextPromise, transformState](auto&& state) {
// Map the state
@ -373,12 +417,12 @@ namespace geode {
// Forward the value to the next Promise
fwdStateToNextPromise(std::move(mapped));
},
data
data.get()
);
}, data->shouldStartThreaded, std::monostate());
}, std::monostate());
}
static void set_callback(OnStateChange&& callback, std::shared_ptr<Data> data) {
static void set_callback(OnStateChange&& callback, Data* data) {
std::unique_lock lock(data->mutex);
data->callbacks.emplace_back(std::move(callback));
@ -397,12 +441,12 @@ namespace geode {
}
}
static void invoke_callback(State&& state, std::shared_ptr<Data> data) {
static void invoke_callback(State&& state, Data* data) {
std::unique_lock lock(data->mutex);
invoke_callback_no_lock(std::move(state), data);
}
static void invoke_callback_no_lock(State&& state, std::shared_ptr<Data> data) {
static void invoke_callback_no_lock(State&& state, Data* data) {
// Run callbacks in the main thread
Loader::get()->queueInMainThread([callbacks = data->callbacks, state = State(state)]() {
for (auto&& callback : std::move(callbacks)) {
@ -421,15 +465,6 @@ namespace geode {
data->cancelled = true;
}
}
static void invoke_source(utils::MiniFunction<void(OnStateChange, std::atomic_bool const&)>&& source, std::shared_ptr<Data> data) {
source(
[data](auto&& state) {
invoke_callback(std::move(state), data);
},
data->cancelled
);
}
};
/**
@ -439,13 +474,14 @@ namespace geode {
* whereas with event listeners being RAII, they are automatically
* removed from layers, avoiding use-after-free errors
*/
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = impl::DefaultProgress>
template <class T = impl::DefaultValue, class E = impl::DefaultError, class P = DefaultProgress>
class PromiseEvent : public Event {
protected:
size_t m_id;
std::shared_ptr<void> m_handle;
std::variant<T, E, P> m_value;
PromiseEvent(size_t id, std::variant<T, E, P>&& value) : m_id(id), m_value(value) {}
PromiseEvent(std::shared_ptr<void> handle, std::variant<T, E, P>&& value)
: m_handle(handle), m_value(value) {}
friend class Promise<T, E, P>;
friend class PromiseEventFilter<T, E, P>;
@ -463,18 +499,18 @@ namespace geode {
using Callback = void(PromiseEvent<T, E, P>*);
protected:
size_t m_id;
std::shared_ptr<void> m_handle;
friend class Promise<T, E, P>;
PromiseEventFilter(size_t id) : m_id(id) {}
PromiseEventFilter(std::shared_ptr<void> handle) : m_handle(handle) {}
public:
PromiseEventFilter() : m_id(0) {}
PromiseEventFilter() : m_handle(nullptr) {}
ListenerResult handle(utils::MiniFunction<Callback> fn, PromiseEvent<T, E, P>* event) {
// log::debug("Event mod filter: {}, {}, {}, {}", m_mod, static_cast<int>(m_type), event->getMod(), static_cast<int>(event->getType()));
if (m_id == event->m_id) {
if (m_handle == event->m_handle) {
fn(event);
}
return ListenerResult::Propagate;
@ -483,26 +519,22 @@ namespace geode {
template <class T, class E, class P>
PromiseEventFilter<T, E, P> Promise<T, E, P>::listen() {
// After 4 billion promises this will overflow and start producing
// the same IDs again, so technically if some promise takes
// literally forever then this could cause issues later on
static size_t ID_COUNTER = 0;
ID_COUNTER += 1;
// Reserve 0 for PromiseEventFilter not listening to anything
if (ID_COUNTER == 0) {
ID_COUNTER += 1;
}
size_t id = ID_COUNTER;
this
->then([id](auto&& value) {
PromiseEvent<T, E, P>(id, std::variant<T, E, P> { std::in_place_index<0>, std::forward<T>(value) }).post();
->then([data = std::weak_ptr(m_data)](auto&& value) {
if (auto d = std::static_pointer_cast<void>(data.lock())) {
PromiseEvent<T, E, P>(d, std::variant<T, E, P> { std::in_place_index<0>, std::forward<T>(value) }).post();
}
})
.expect([id](auto&& error) {
PromiseEvent<T, E, P>(id, std::variant<T, E, P> { std::in_place_index<1>, std::forward<E>(error) }).post();
.expect([data = std::weak_ptr(m_data)](auto&& error) {
if (auto d = std::static_pointer_cast<void>(data.lock())) {
PromiseEvent<T, E, P>(d, std::variant<T, E, P> { std::in_place_index<1>, std::forward<E>(error) }).post();
}
})
.progress([id](auto&& prog) {
PromiseEvent<T, E, P>(id, std::variant<T, E, P> { std::in_place_index<2>, std::forward<P>(prog) }).post();
.progress([data = std::weak_ptr(m_data)](auto&& prog) {
if (auto d = std::static_pointer_cast<void>(data.lock())) {
PromiseEvent<T, E, P>(d, std::variant<T, E, P> { std::in_place_index<2>, std::forward<P>(prog) }).post();
}
});
return PromiseEventFilter<T, E, P>(id);
return PromiseEventFilter<T, E, P>(m_data);
}
}

View file

@ -0,0 +1,317 @@
#pragma once
#include "general.hpp"
#include "MiniFunction.hpp"
#include "../loader/Event.hpp"
#include "../loader/Loader.hpp"
namespace geode {
template <std::move_constructible T, std::move_constructible P = std::monostate>
class [[nodiscard]] Task final {
public:
struct [[nodiscard]] Cancel final {};
class Result final {
private:
std::variant<T, Cancel> m_value;
public:
Result(Result const&) = delete;
Result(T&& value) : m_value(std::in_place_index<0>, std::forward<T>(value)) {}
Result(Cancel const&) : m_value(std::in_place_index<1>, Cancel()) {}
template <class V>
Result(V&& value) requires std::is_constructible_v<T, V&&>
: m_value(std::in_place_index<0>, std::forward<V>(value))
{}
std::optional<T> getValue() && {
if (m_value.index() == 0) {
return std::optional(std::move(std::get<0>(std::move(m_value))));
}
return std::nullopt;
}
bool isCancelled() const {
return m_value.index() == 1;
}
};
public:
enum class Status {
Pending,
Finished,
Cancelled,
};
class Handle final {
private:
std::recursive_mutex m_mutex;
Status m_status = Status::Pending;
std::optional<T> m_resultValue;
bool m_finalEventPosted = false;
std::unique_ptr<void, void(*)(void*)> m_mapListener = { nullptr, +[](void*) {} };
class PrivateMarker final {};
static std::shared_ptr<Handle> create() {
return std::make_shared<Handle>(PrivateMarker());
}
bool is(Status status) {
std::unique_lock<std::recursive_mutex> lock(m_mutex);
return m_status == status;
}
friend class Task;
public:
Handle(PrivateMarker) {}
};
class Event final : public geode::Event {
private:
std::shared_ptr<Handle> m_handle;
std::variant<T*, P*, Cancel> m_value;
EventListenerProtocol* m_for = nullptr;
Event(std::shared_ptr<Handle> handle, std::variant<T*, P*, Cancel>&& value)
: m_handle(handle), m_value(std::move(value)) {}
static Event createFinished(std::shared_ptr<Handle> handle, T* value) {
return Event(handle, std::variant<T*, P*, Cancel>(std::in_place_index<0>, value));
}
static Event createProgressed(std::shared_ptr<Handle> handle, P* value) {
return Event(handle, std::variant<T*, P*, Cancel>(std::in_place_index<1>, value));
}
static Event createCancelled(std::shared_ptr<Handle> handle) {
return Event(handle, std::variant<T*, P*, Cancel>(std::in_place_index<2>, Cancel()));
}
friend class Task;
public:
T* getValue() {
return m_value.index() == 0 ? std::get<0>(m_value) : nullptr;
}
T const* getValue() const {
return m_value.index() == 0 ? std::get<0>(m_value) : nullptr;
}
P* getProgress() {
return m_value.index() == 1 ? std::get<1>(m_value) : nullptr;
}
P const* getProgress() const {
return m_value.index() == 1 ? std::get<1>(m_value) : nullptr;
}
bool isCancelled() const {
return m_value.index() == 2;
}
void cancel() {
Task::cancel(m_handle);
}
};
using PostProgress = utils::MiniFunction<void(P)>;
using HasBeenCancelled = utils::MiniFunction<bool()>;
using Run = utils::MiniFunction<Result(PostProgress, HasBeenCancelled)>;
using Callback = void(Event*);
private:
EventListenerProtocol* m_listener = nullptr;
std::shared_ptr<Handle> m_handle;
Task(std::shared_ptr<Handle> handle) : m_handle(handle) {}
static void finish(std::shared_ptr<Handle> handle, T&& value) {
if (!handle) return;
std::unique_lock<std::recursive_mutex> lock(handle->m_mutex);
if (handle->m_status == Status::Pending) {
handle->m_status = Status::Finished;
handle->m_resultValue = std::move(value);
Loader::get()->queueInMainThread([handle, value = &*handle->m_resultValue]() mutable {
Event::createFinished(handle, value).post();
std::unique_lock<std::recursive_mutex> lock(handle->m_mutex);
handle->m_finalEventPosted = true;
});
}
}
static void progress(std::shared_ptr<Handle> handle, P&& value) {
if (!handle) return;
std::unique_lock<std::recursive_mutex> lock(handle->m_mutex);
if (handle->m_status == Status::Pending) {
Loader::get()->queueInMainThread([handle, value = std::move(value)]() mutable {
Event::createProgressed(handle, &value).post();
});
}
}
static void cancel(std::shared_ptr<Handle> handle) {
if (!handle) return;
std::unique_lock<std::recursive_mutex> lock(handle->m_mutex);
if (handle->m_status == Status::Pending) {
handle->m_status = Status::Cancelled;
Loader::get()->queueInMainThread([handle]() mutable {
Event::createCancelled(handle).post();
std::unique_lock<std::recursive_mutex> lock(handle->m_mutex);
handle->m_finalEventPosted = true;
});
}
}
template <std::move_constructible T2, std::move_constructible P2>
friend class Task;
public:
Task() : m_handle(nullptr) {}
Task(Task const& other) : m_handle(other.m_handle) {}
Task(Task&& other) : m_handle(std::move(other.m_handle)) {}
Task& operator=(Task const& other) {
m_handle = other.m_handle;
return *this;
}
Task& operator=(Task&& other) {
m_handle = std::move(other.m_handle);
return *this;
}
T* getFinishedValue() {
if (m_handle && m_handle->m_resultValue) {
return &*m_handle->m_resultValue;
}
return nullptr;
}
void cancel() {
Task::cancel(m_handle);
}
bool isPending() {
return m_handle && m_handle->is(Status::Pending);
}
bool isFinished() {
return m_handle && m_handle->is(Status::Finished);
}
bool isCancelled() {
return m_handle && m_handle->is(Status::Cancelled);
}
static Task immediate(T&& value) {
auto task = Task(Handle::create());
Task::finish(task.m_handle, std::move(value));
return task;
}
static Task run(Run&& body) {
auto task = Task(Handle::create());
std::thread([handle = std::weak_ptr(task.m_handle), body = std::move(body)] {
utils::thread::setName(fmt::format("Task @{}", fmt::ptr(handle.lock())));
auto result = body(
[handle](P progress) {
Task::progress(handle.lock(), std::move(progress));
},
[handle]() -> bool {
// The task has been cancelled if the user has explicitly cancelled it,
// or if there is no one listening anymore
auto lock = handle.lock();
return !(lock && lock->is(Status::Pending));
}
);
if (result.isCancelled()) {
Task::cancel(handle.lock());
}
else {
Task::finish(handle.lock(), std::move(*std::move(result).getValue()));
}
}).detach();
return task;
}
template <class ResultMapper, class ProgressMapper>
auto map(ResultMapper&& resultMapper, ProgressMapper&& progressMapper) {
using T2 = decltype(resultMapper(std::declval<T*>()));
using P2 = decltype(progressMapper(std::declval<P*>()));
auto task = Task<T2, P2>(Task<T2, P2>::Handle::create());
// Lock the current task until we have managed to create our new one
std::unique_lock<std::recursive_mutex> lock(m_handle->m_mutex);
// If the current task is cancelled, cancel the new one immediately
if (m_handle->m_status == Status::Cancelled) {
Task<T2, P2>::cancel(task.m_handle);
}
// If the current task is finished, immediately map the value and post that
else if (m_handle->m_status == Status::Finished) {
Task<T2, P2>::finish(task.m_handle, resultMapper(&*m_handle->m_resultValue));
}
// Otherwise start listening and waiting for the current task to finish
else {
task.m_handle->m_mapListener = std::unique_ptr<void, void(*)(void*)>(
static_cast<void*>(new EventListener<Task>(
[
handle = std::weak_ptr(task.m_handle),
resultMapper = std::move(resultMapper),
progressMapper = std::move(progressMapper)
](Event* event) {
if (auto v = event->getValue()) {
Task<T2, P2>::finish(handle.lock(), resultMapper(v));
}
else if (auto p = event->getProgress()) {
Task<T2, P2>::progress(handle.lock(), progressMapper(p));
}
else if (event->isCancelled()) {
Task<T2, P2>::cancel(handle.lock());
}
},
*this
)),
+[](void* ptr) {
delete static_cast<EventListener<Task>*>(ptr);
}
);
}
return task;
}
ListenerResult handle(utils::MiniFunction<Callback> fn, Event* e) {
if (e->m_handle == m_handle && (!e->m_for || e->m_for == m_listener)) {
fn(e);
}
return ListenerResult::Propagate;
}
// todo: i believe alk wanted tasks to be in their own pool
EventListenerPool* getPool() const {
return DefaultEventListenerPool::get();
}
void setListener(EventListenerProtocol* listener) {
m_listener = listener;
if (!m_handle) return;
// If this task has already been finished and the finish event
// isn't pending in the event queue, immediately queue up a
// finish event for this listener
std::unique_lock<std::recursive_mutex> lock(m_handle->m_mutex);
if (m_handle->m_finalEventPosted) {
if (m_handle->m_status == Status::Finished) {
Loader::get()->queueInMainThread([handle = m_handle, listener = m_listener, value = &*m_handle->m_resultValue]() {
auto ev = Event::createFinished(handle, value);
ev.m_for = listener;
ev.post();
});
}
else {
Loader::get()->queueInMainThread([handle = m_handle, listener = m_listener]() {
auto ev = Event::createCancelled(handle);
ev.m_for = listener;
ev.post();
});
}
}
}
EventListenerProtocol* getListener() const {
return m_listener;
}
};
static_assert(is_filter<Task<int>>, "The Task class must be a valid event filter!");
}

View file

@ -3,6 +3,7 @@
#include <matjson.hpp>
#include "Result.hpp"
#include "Promise.hpp"
#include "Task.hpp"
#include <chrono>
namespace geode::utils::web {
@ -20,6 +21,7 @@ namespace geode::utils::web {
// Must be default-constructible for use in Promise
WebResponse();
bool ok() const;
int code() const;
Result<std::string> string() const;
@ -53,6 +55,7 @@ namespace geode::utils::web {
std::optional<float> uploadProgress() const;
};
using WebTask = Task<WebResponse, WebProgress>;
using WebPromise = Promise<WebResponse, WebError, WebProgress>;
class GEODE_DLL WebRequest final {
@ -65,6 +68,8 @@ namespace geode::utils::web {
WebRequest();
~WebRequest();
WebTask send2(std::string_view method, std::string_view url);
WebPromise send(std::string_view method, std::string_view url);
WebPromise post(std::string_view url);
WebPromise get(std::string_view url);

View file

@ -7,6 +7,135 @@ using namespace server;
#define GEODE_GD_VERSION_STR GEODE_STR(GEODE_GD_VERSION)
template <class K, class V>
class CacheMap final {
private:
// I know this looks like a goofy choice over just
// `std::unordered_map`, but hear me out:
//
// This needs preserved insertion order (so shrinking the cache
// to match size limits doesn't just have to erase random
// elements)
//
// If this used a map for values and another vector for storing
// insertion order, it would have a pretty big memory footprint
// (two copies of Query, one for order, one for map + two heap
// allocations on top of that)
//
// In addition, it would be a bad idea to have a cache of 1000s
// of items in any case (since that would likely take up a ton
// of memory, which we want to avoid since it's likely many
// crashes with the old index were due to too much memory
// usage)
//
// Linear searching a vector of at most a couple dozen items is
// lightning-fast (🚀), and besides the main performance benefit
// comes from the lack of a web request - not how many extra
// milliseconds we can squeeze out of a map access
std::vector<std::pair<K, V>> m_values;
size_t m_sizeLimit = 20;
public:
std::optional<V> get(K const& key) {
auto it = std::find_if(m_values.begin(), m_values.end(), [key](auto const& q) {
return q.first == key;
});
if (it != m_values.end()) {
return it->second;
}
return std::nullopt;
}
void add(K&& key, V&& value) {
auto pair = std::make_pair(std::move(key), std::move(value));
// Shift and replace last element if we're at cache size limit
if (m_values.size() >= m_sizeLimit) {
std::shift_left(m_values.begin(), m_values.end(), 1);
m_values.back() = std::move(pair);
}
// Otherwise append at end
else {
m_values.emplace_back(std::move(pair));
}
}
void remove(K const& key) {
ranges::remove(m_values, [&key](auto const& q) { return q.first == key; });
}
void clear() {
m_values.clear();
}
void limit(size_t size) {
m_sizeLimit = size;
m_values.clear();
}
size_t size() const {
return m_values.size();
}
size_t limit() const {
return m_sizeLimit;
}
};
template <class Q, class V>
using ServerFuncQ = V(*)(Q const&, bool);
template <class V>
using ServerFuncNQ = V(*)(bool);
template <class F>
struct ExtractFun;
template <class Q, class V>
struct ExtractFun<ServerRequest<V>(*)(Q const&, bool)> {
using Query = Q;
using Value = V;
static ServerRequest<V> invoke(auto&& func, Query const& query) {
return func(query, false);
}
};
template <class V>
struct ExtractFun<ServerRequest<V>(*)(bool)> {
using Query = std::monostate;
using Value = V;
static ServerRequest<V> invoke(auto&& func, Query const&) {
return func(false);
}
};
template <auto F>
class FunCache final {
public:
using Extract = ExtractFun<decltype(F)>;
using Query = typename Extract::Query;
using Value = typename Extract::Value;
private:
CacheMap<Query, ServerRequest<Value>> m_cache;
public:
FunCache() = default;
FunCache(FunCache const&) = delete;
FunCache(FunCache&&) = delete;
ServerRequest<Value> get(Query const& query = Query()) {
if (auto v = m_cache.get(query)) {
return *v;
}
auto f = Extract::invoke(F, query);
m_cache.add(Query(query), ServerRequest<Value>(f));
return f;
}
void clear() {
m_cache.clear();
}
};
template <auto F>
FunCache<F>& getCache() {
static auto inst = FunCache<F>();
return inst;
}
static const char* jsonTypeToString(matjson::Type const& type) {
switch (type) {
case matjson::Type::Object: return "object";
@ -19,7 +148,7 @@ static const char* jsonTypeToString(matjson::Type const& type) {
}
}
static Result<matjson::Value, ServerError> parseServerPayload(web::WebResponse&& response) {
static Result<matjson::Value, ServerError> parseServerPayload(web::WebResponse const& response) {
auto asJson = response.json();
if (!asJson) {
return Err(ServerError(response.code(), "Response was not valid JSON: {}", asJson.unwrapErr()));
@ -35,7 +164,7 @@ static Result<matjson::Value, ServerError> parseServerPayload(web::WebResponse&&
return Ok(obj["payload"]);
}
static ServerError parseServerError(auto error) {
static ServerError parseServerError(web::WebResponse const& error) {
// The server should return errors as `{ "error": "...", "payload": "" }`
if (auto asJson = error.json()) {
auto json = asJson.unwrap();
@ -58,7 +187,7 @@ static ServerError parseServerError(auto error) {
}
}
static ServerProgress parseServerProgress(auto prog, auto msg) {
static ServerProgress parseServerProgress(web::WebProgress const& prog, auto msg) {
if (auto per = prog.downloadProgress()) {
return ServerProgress(msg, static_cast<uint8_t>(*per));
}
@ -348,6 +477,76 @@ std::string server::getServerUserAgent() {
);
}
ServerRequest<ServerModsList> server::getMods2(ModsQuery const& query, bool useCache) {
if (useCache) {
return getCache<getMods2>().get(query);
}
auto req = web::WebRequest();
req.userAgent(getServerUserAgent());
// Always target current GD version and Loader version
req.param("gd", GEODE_GD_VERSION_STR);
req.param("geode", Loader::get()->getVersion().toString());
// Add search params
if (query.query) {
req.param("query", *query.query);
}
if (query.platforms.size()) {
std::string plats = "";
bool first = true;
for (auto plat : query.platforms) {
if (!first) plats += ",";
plats += PlatformID::toShortString(plat.m_value);
first = false;
}
req.param("platforms", plats);
}
if (query.tags.size()) {
req.param("tags", ranges::join(query.tags, ","));
}
if (query.featured) {
req.param("featured", query.featured.value() ? "true" : "false");
}
req.param("sort", sortToString(query.sorting));
if (query.developer) {
req.param("developer", *query.developer);
}
// Paging (1-based on server, 0-based locally)
req.param("page", std::to_string(query.page + 1));
req.param("per_page", std::to_string(query.pageSize));
return req.send2("GET", getServerAPIBaseURL() + "/mods").map(
[](web::WebResponse* response) -> Result<ServerModsList, ServerError> {
if (response->ok()) {
// Parse payload
auto payload = parseServerPayload(*response);
if (!payload) {
return Err(payload.unwrapErr());
}
// Parse response
auto list = ServerModsList::parse(payload.unwrap());
if (!list) {
return Err(ServerError(response->code(), "Unable to parse response: {}", list.unwrapErr()));
}
return Ok(list.unwrap());
}
else {
// Treat a 404 as empty mods list
if (response->code() == 404) {
return Ok(ServerModsList());
}
return Err(parseServerError(*response));
}
},
[](web::WebProgress* progress) {
return parseServerProgress(*progress, "Downloading mods");
}
);
}
ServerPromise<ServerModsList> server::getMods(ModsQuery const& query) {
auto req = web::WebRequest();
req.userAgent(getServerUserAgent());
@ -540,3 +739,15 @@ ServerPromise<std::vector<ServerModUpdate>> server::checkUpdates(std::vector<std
return parseServerProgress(prog, "Checking updates for mods");
});
}
void server::clearServerCaches2(bool clearGlobalCaches) {
getCache<&getMods2>().clear();
// getCache<&getMod>().clear();
// getCache<&getModLogo>().clear();
// Only clear global caches if explicitly requested
if (clearGlobalCaches) {
// getCache<&getTags>().clear();
// getCache<&checkUpdates>().clear();
}
}

View file

@ -107,8 +107,12 @@ namespace server {
template <class T>
using ServerPromise = Promise<T, ServerError, ServerProgress>;
template <class T>
using ServerRequest = Task<Result<T, ServerError>, ServerProgress>;
std::string getServerAPIBaseURL();
std::string getServerUserAgent();
ServerRequest<ServerModsList> getMods2(ModsQuery const& query, bool useCache = true);
ServerPromise<ServerModsList> getMods(ModsQuery const& query);
ServerPromise<ServerModMetadata> getMod(std::string const& id);
ServerPromise<ByteVector> getModLogo(std::string const& id);
@ -401,4 +405,6 @@ namespace server {
ServerResultCache<&checkUpdates>::shared().invalidateAll();
}
}
void clearServerCaches2(bool clearGlobalCaches = false);
}

View file

@ -0,0 +1,194 @@
#define GEODE_UI_TEST
#ifdef GEODE_UI_TEST
#include <Geode/modify/MenuLayer.hpp>
#include <Geode/ui/Popup.hpp>
#include <Geode/ui/BasedButtonSprite.hpp>
#include <Geode/utils/web2.hpp>
#include <server/Server.hpp>
using namespace geode::prelude;
using StrTask = Task<std::string>;
class GUITestPopup : public Popup<> {
protected:
CCLabelBMFont* m_rawTaskState;
CCMenuItemSpriteExtra* m_cancelTaskBtn;
CCMenuItemSpriteExtra* m_cancelServerTaskBtn;
EventListener<web::WebTask> m_rawListener;
EventListener<StrTask> m_strListener;
EventListener<server::ServerRequest<server::ServerModsList>> m_serListener;
EventListener<server::ServerRequest<server::ServerModsList>> m_serListener2;
bool setup() override {
m_noElasticity = true;
this->setTitle("GUI Test Popup");
auto startPromiseSpr = ButtonSprite::create(
"Promise Test", "bigFont.fnt", "GJ_button_05.png", .8f
);
startPromiseSpr->setScale(.5f);
auto startPromiseBtn = CCMenuItemSpriteExtra::create(
startPromiseSpr, this, menu_selector(GUITestPopup::onPromiseTest)
);
m_buttonMenu->addChildAtPosition(startPromiseBtn, Anchor::Center, ccp(-40, 40));
auto cancelPromiseSpr = ButtonSprite::create(
"Cancel", "bigFont.fnt", "GJ_button_06.png", .8f
);
cancelPromiseSpr->setScale(.5f);
m_cancelTaskBtn = CCMenuItemSpriteExtra::create(
cancelPromiseSpr, this, menu_selector(GUITestPopup::onPromiseCancel)
);
m_cancelTaskBtn->setVisible(false);
m_buttonMenu->addChildAtPosition(m_cancelTaskBtn, Anchor::Center, ccp(50, 40));
m_rawTaskState = CCLabelBMFont::create("Task not started", "bigFont.fnt");
m_rawTaskState->setScale(.5f);
m_mainLayer->addChildAtPosition(m_rawTaskState, Anchor::Center, ccp(0, 10));
auto serverPromiseSpr = ButtonSprite::create(
"Server Request", "bigFont.fnt", "GJ_button_05.png", .8f
);
serverPromiseSpr->setScale(.5f);
auto serverPromiseBtn = CCMenuItemSpriteExtra::create(
serverPromiseSpr, this, menu_selector(GUITestPopup::onServerReq)
);
m_buttonMenu->addChildAtPosition(serverPromiseBtn, Anchor::Center, ccp(-40, -40));
auto cancelServerPromiseSpr = ButtonSprite::create(
"Cancel", "bigFont.fnt", "GJ_button_06.png", .8f
);
cancelServerPromiseSpr->setScale(.5f);
m_cancelServerTaskBtn = CCMenuItemSpriteExtra::create(
cancelServerPromiseSpr, this, menu_selector(GUITestPopup::onServerCancel)
);
m_cancelServerTaskBtn->setVisible(false);
m_buttonMenu->addChildAtPosition(m_cancelServerTaskBtn, Anchor::Center, ccp(50, -40));
auto clearServerCacheSpr = ButtonSprite::create(
"Clear Caches", "bigFont.fnt", "GJ_button_01.png", .8f
);
clearServerCacheSpr->setScale(.5f);
auto clearServerCacheBtn = CCMenuItemSpriteExtra::create(
clearServerCacheSpr, this, menu_selector(GUITestPopup::onServerCacheClear)
);
m_buttonMenu->addChildAtPosition(clearServerCacheBtn, Anchor::Center, ccp(0, -70));
m_rawListener.bind(this, &GUITestPopup::onRawTask);
m_strListener.bind(this, &GUITestPopup::onStrTask);
m_serListener.bind(this, &GUITestPopup::onServerTask);
m_serListener2.bind(this, &GUITestPopup::onServerTask);
return true;
}
void onRawTask(web::WebTask::Event* event) {
m_cancelTaskBtn->setVisible(event->getProgress());
if (event->isCancelled()) {
m_rawTaskState->setString("Cancelled!");
}
if (auto value = event->getValue()) {
m_rawTaskState->setString(fmt::format("Finished with code {}", value->code()).c_str());
}
if (auto progress = event->getProgress()) {
m_rawTaskState->setString(fmt::format(
"Progress: {}/{}", progress->downloaded(), progress->downloadTotal()
).c_str());
}
}
void onStrTask(StrTask::Event* event) {
if (event->isCancelled()) {
log::info("str task cancelled :(");
}
else if (auto value = event->getValue()) {
log::info("str task done: {}", *value);
}
}
void onServerTask(server::ServerRequest<server::ServerModsList>::Event* event) {
m_cancelServerTaskBtn->setVisible(event->getProgress());
if (auto value = event->getValue()) {
if (value->isOk()) {
auto mods = value->unwrap();
log::info("got a mods list with {}/{} mods!!!!", mods.mods.size(), mods.totalModCount);
}
else {
log::info("epic mods list fail L: {}", value->unwrapErr().details);
}
}
else if (auto prog = event->getProgress()) {
log::info("mods list progress: {}", prog->message);
}
else if (event->isCancelled()) {
log::info("mods list cancelled L");
}
}
void makeRequest() {
auto task = web::WebRequest().send2("GET", "https://api.geode-sdk.org/");
m_rawListener.setFilter(task);
m_strListener.setFilter(task.map(
[](auto* result) {
return fmt::format("finish with code {} :3", result->code());
},
[](auto* progress) {
return std::monostate();
}
));
}
void onPromiseCancel(CCObject*) {
m_rawListener.getFilter().cancel();
}
void onPromiseTest(CCObject*) {
m_cancelTaskBtn->setVisible(true);
this->makeRequest();
}
void onServerReq(CCObject*) {
m_cancelServerTaskBtn->setVisible(true);
m_serListener.setFilter(server::getMods2(server::ModsQuery()));
m_serListener2.setFilter(server::getMods2(server::ModsQuery()));
}
void onServerCancel(CCObject*) {
m_serListener.getFilter().cancel();
}
void onServerCacheClear(CCObject*) {
server::clearServerCaches2();
}
public:
static GUITestPopup* create() {
auto ret = new GUITestPopup();
if (ret && ret->initAnchored(320, 280)) {
ret->autorelease();
return ret;
}
CC_SAFE_DELETE(ret);
return nullptr;
}
};
class $modify(GUILayer, MenuLayer) {
bool init() {
if (!MenuLayer::init())
return false;
auto btn = CCMenuItemSpriteExtra::create(
CrossButtonSprite::create(CCLabelBMFont::create("test", "goldFont.fnt")),
this, menu_selector(GUILayer::onTestPopup)
);
this->getChildByID("main-menu")->addChild(btn);
this->getChildByID("main-menu")->updateLayout();
return true;
}
void onTestPopup(CCObject*) {
GUITestPopup::create()->show();
}
};
#endif

View file

@ -15,6 +15,9 @@ public:
WebResponse::WebResponse() : m_impl(std::make_shared<Impl>()) {}
bool WebResponse::ok() const {
return 200 <= m_impl->m_code && m_impl->m_code < 300;
}
int WebResponse::code() const {
return m_impl->m_code;
}
@ -112,6 +115,176 @@ std::string urlParamEncode(std::string_view const input) {
return ss.str();
}
WebTask WebRequest::send2(std::string_view method, std::string_view url) {
m_impl->m_method = method;
m_impl->m_url = url;
return WebTask::run([impl = m_impl](auto progress, auto hasBeenCancelled) -> WebTask::Result {
// Init Curl
auto curl = curl_easy_init();
if (!curl) {
return impl->makeError(-1, "Curl not initialized");
}
// todo: in the future, we might want to support downloading directly into
// files / in-memory streams like the old AsyncWebRequest class
// Struct that holds values for the curl callbacks
struct ResponseData {
WebResponse response;
Impl* impl;
WebTask::PostProgress progress;
WebTask::HasBeenCancelled hasBeenCancelled;
} responseData = {
.response = WebResponse(),
.impl = impl.get(),
.progress = progress,
.hasBeenCancelled = hasBeenCancelled,
};
// Store downloaded response data into a byte vector
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseData);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, +[](char* data, size_t size, size_t nmemb, void* ptr) {
auto& target = static_cast<ResponseData*>(ptr)->response.m_impl->m_data;
target.insert(target.end(), data, data + size * nmemb);
return size * nmemb;
});
// Set headers
curl_slist* headers = nullptr;
for (auto& [name, value] : impl->m_headers) {
// Sanitize header name
auto header = name;
header.erase(std::remove_if(header.begin(), header.end(), [](char c) {
return c == '\r' || c == '\n';
}), header.end());
// Append value
header += ": " + value;
headers = curl_slist_append(headers, header.c_str());
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
// Add parameters to the URL and pass it to curl
auto url = impl->m_url;
bool first = true;
for (auto param : impl->m_urlParameters) {
url += (first ? "?" : "&") + urlParamEncode(param.first) + "=" + urlParamEncode(param.second);
first = false;
}
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
// Set request method
if (impl->m_method != "GET") {
if (impl->m_method == "POST") {
curl_easy_setopt(curl, CURLOPT_POST, 1L);
}
else {
curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, impl->m_method.c_str());
}
}
// Set body if provided
if (impl->m_body) {
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, impl->m_body->data());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, impl->m_body->size());
}
// No need to verify SSL, we trust our domains :-)
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0);
curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0);
// Set user agent if provided
if (impl->m_userAgent) {
curl_easy_setopt(curl, CURLOPT_USERAGENT, impl->m_userAgent->c_str());
}
// Set timeout
if (impl->m_timeout) {
curl_easy_setopt(curl, CURLOPT_TIMEOUT, impl->m_timeout->count());
}
// Track progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0);
// Follow redirects
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1);
// Do not fail if response code is 4XX or 5XX
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 0L);
// Get headers from the response
curl_easy_setopt(curl, CURLOPT_HEADERDATA, &responseData);
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, (+[](char* buffer, size_t size, size_t nitems, void* ptr) {
auto& headers = static_cast<ResponseData*>(ptr)->response.m_impl->m_headers;
std::string line;
std::stringstream ss(std::string(buffer, size * nitems));
while (std::getline(ss, line)) {
auto colon = line.find(':');
if (colon == std::string::npos) continue;
auto key = line.substr(0, colon);
auto value = line.substr(colon + 2);
if (value.ends_with('\r')) {
value = value.substr(0, value.size() - 1);
}
headers.insert_or_assign(key, value);
}
return size * nitems;
}));
// Track & post progress on the Promise
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, &responseData);
curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, +[](void* ptr, double dtotal, double dnow, double utotal, double unow) -> int {
auto data = static_cast<ResponseData*>(ptr);
// Check for cancellation and abort if so
if (data->hasBeenCancelled()) {
return 1;
}
// Post progress to Promise listener
auto progress = WebProgress();
progress.m_impl->m_downloadTotal = dtotal;
progress.m_impl->m_downloadCurrent = dnow;
progress.m_impl->m_uploadTotal = utotal;
progress.m_impl->m_uploadCurrent = unow;
data->progress(std::move(progress));
// Continue as normal
return 0;
});
// Make the actual web request
auto curlResponse = curl_easy_perform(curl);
// Get the response code; note that this will be invalid if the
// curlResponse is not CURLE_OK
long code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &code);
responseData.response.m_impl->m_code = static_cast<int>(code);
// Free up curl memory
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
// Check if the request failed on curl's side or because of cancellation
if (curlResponse != CURLE_OK) {
if (hasBeenCancelled()) {
return WebTask::Cancel();
}
else {
return impl->makeError(-1, "Curl failed: " + std::string(curl_easy_strerror(curlResponse)));
}
}
// Check if the response was an error code
if (code >= 400 && code <= 600) {
return std::move(responseData.response);
}
// Otherwise resolve with success :-)
return std::move(responseData.response);
});
}
WebPromise WebRequest::send(std::string_view method, std::string_view url) {
m_impl->m_method = method;
m_impl->m_url = url;