From 636be07435cedd42e5607ed99ddb0d18bf4d26aa Mon Sep 17 00:00:00 2001
From: altalk23 <45172705+altalk23@users.noreply.github.com>
Date: Thu, 20 Jun 2024 16:17:02 +0300
Subject: [PATCH] fix event pool race condition in a hacky way

---
 loader/include/Geode/loader/Dispatch.hpp |  4 +-
 loader/include/Geode/loader/Event.hpp    | 31 +++++++++++--
 loader/src/loader/Event.cpp              | 57 ++++++++++++++++--------
 3 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/loader/include/Geode/loader/Dispatch.hpp b/loader/include/Geode/loader/Dispatch.hpp
index ee66264e..1ce7bb9c 100644
--- a/loader/include/Geode/loader/Dispatch.hpp
+++ b/loader/include/Geode/loader/Dispatch.hpp
@@ -31,7 +31,7 @@ namespace geode {
 
         EventListenerPool* getPool() const override {
             if (dispatchPools().count(m_id) == 0) {
-                dispatchPools()[m_id] = new DefaultEventListenerPool();
+                dispatchPools()[m_id] = DefaultEventListenerPool::create();
             }
             return dispatchPools()[m_id];
         }
@@ -48,7 +48,7 @@ namespace geode {
 
         EventListenerPool* getPool() const {
             if (dispatchPools().count(m_id) == 0) {
-                dispatchPools()[m_id] = new DefaultEventListenerPool();
+                dispatchPools()[m_id] = DefaultEventListenerPool::create();
             }
             return dispatchPools()[m_id];
         }
diff --git a/loader/include/Geode/loader/Event.hpp b/loader/include/Geode/loader/Event.hpp
index 8b321bb3..c10f9a0a 100644
--- a/loader/include/Geode/loader/Event.hpp
+++ b/loader/include/Geode/loader/Event.hpp
@@ -5,6 +5,8 @@
 
 #include <Geode/DefaultInclude.hpp>
 #include <type_traits>
+#include <mutex>
+#include <deque>
 #include <unordered_set>
 
 namespace geode {
@@ -29,12 +31,29 @@ namespace geode {
         EventListenerPool(EventListenerPool const&) = delete;
         EventListenerPool(EventListenerPool&&) = delete;
     };
+
+    template <class... Args>
+    class DispatchEvent;
+
+    template <class... Args>
+    class DispatchFilter;
     
     class GEODE_DLL DefaultEventListenerPool : public EventListenerPool {
     protected:
-        std::atomic_size_t m_locked = 0;
-        std::vector<EventListenerProtocol*> m_listeners;
-        std::vector<EventListenerProtocol*> m_toAdd;
+        // fix this in Geode 4.0.0
+        struct Data {
+            std::atomic_size_t m_locked = 0;
+            std::mutex m_mutex;
+            std::deque<EventListenerProtocol*> m_listeners;
+            std::vector<EventListenerProtocol*> m_toAdd;
+            std::vector<EventListenerProtocol*> m_toRemove;
+        };
+        std::unique_ptr<Data> m_data;
+
+        DefaultEventListenerPool();
+
+    private:
+        static DefaultEventListenerPool* create();
 
     public:
         bool add(EventListenerProtocol* listener) override;
@@ -42,6 +61,12 @@ namespace geode {
         ListenerResult handle(Event* event) override;
 
         static DefaultEventListenerPool* get();
+
+        template <class... Args>
+        friend class DispatchEvent;
+
+        template <class... Args>
+        friend class DispatchFilter;
     };
 
     class GEODE_DLL EventListenerProtocol {
diff --git a/loader/src/loader/Event.cpp b/loader/src/loader/Event.cpp
index 1e6a8946..f2e42c8e 100644
--- a/loader/src/loader/Event.cpp
+++ b/loader/src/loader/Event.cpp
@@ -4,50 +4,71 @@
 
 using namespace geode::prelude;
 
+DefaultEventListenerPool::DefaultEventListenerPool() : m_data(new Data) {}
+
 bool DefaultEventListenerPool::add(EventListenerProtocol* listener) {
-    if (m_locked) {
-        m_toAdd.push_back(listener);
+    if (!m_data) m_data = std::make_unique<Data>();
+
+    std::unique_lock lock(m_data->m_mutex);
+    if (m_data->m_locked) {
+        m_data->m_toAdd.push_back(listener);
+        ranges::remove(m_data->m_toRemove, listener);
     }
     else {
         // insert listeners at the start so new listeners get priority
-        m_listeners.insert(m_listeners.begin(), listener);
+        m_data->m_listeners.push_front(listener);
     }
     return true;
 }
 
 void DefaultEventListenerPool::remove(EventListenerProtocol* listener) {
-    for (size_t i = 0; i < m_listeners.size(); i++) {
-        if (m_listeners[i] == listener) {
-            m_listeners[i] = nullptr;
-        }
+    if (!m_data) m_data = std::make_unique<Data>();
+
+    std::unique_lock lock(m_data->m_mutex);
+    if (m_data->m_locked) {
+        m_data->m_toRemove.push_back(listener);
+        ranges::remove(m_data->m_toAdd, listener);
+    }
+    else {
+        ranges::remove(m_data->m_listeners, listener);
     }
-    ranges::remove(m_toAdd, listener);
 }
 
 ListenerResult DefaultEventListenerPool::handle(Event* event) {
+    if (!m_data) m_data = std::make_unique<Data>();
+
     auto res = ListenerResult::Propagate;
-    m_locked += 1;
-    for (auto h : m_listeners) {
-        // if an event listener gets destroyed in the middle of this loop, it 
-        // gets set to null
+    m_data->m_locked += 1;
+    std::unique_lock lock(m_data->m_mutex);
+    for (auto h : m_data->m_listeners) {
+        lock.unlock();
         if (h && h->handle(event) == ListenerResult::Stop) {
             res = ListenerResult::Stop;
+            lock.lock();
             break;
         }
+        lock.lock();
     }
-    m_locked -= 1;
+    m_data->m_locked -= 1;
     // only mutate listeners once nothing is iterating 
     // (if there are recursive handle calls)
-    if (m_locked == 0) {
-        ranges::remove(m_listeners, nullptr);
-        for (auto listener : m_toAdd) {
-            m_listeners.insert(m_listeners.begin(), listener);
+    if (m_data->m_locked == 0) {
+        for (auto listener : m_data->m_toRemove) {
+            ranges::remove(m_data->m_listeners, listener);
         }
-        m_toAdd.clear();
+        for (auto listener : m_data->m_toAdd) {
+            m_data->m_listeners.push_front(listener);
+        }
+        m_data->m_toAdd.clear();
+        m_data->m_toRemove.clear();
     }
     return res;
 }
 
+DefaultEventListenerPool* DefaultEventListenerPool::create() {
+    return new DefaultEventListenerPool();
+}
+
 DefaultEventListenerPool* DefaultEventListenerPool::get() {
     static auto inst = new DefaultEventListenerPool();
     return inst;