From bd906e77daf41e443f7ad0e49cb60796610f0d3e Mon Sep 17 00:00:00 2001
From: i509VCB <git@i509.me>
Date: Sat, 5 Dec 2020 13:03:30 -0600
Subject: [PATCH] Implement entity unload event on server (#1191)

* Start toying with server entity unload event

* Implement testmod stuff

* Implement shutdown implementation of unload entities

* Update fabric-lifecycle-events-v1/src/testmod/java/net/fabricmc/fabric/test/event/lifecycle/ServerEntityLifecycleTests.java

Co-authored-by: Pyrofab <redstoneinfire@gmail.com>

* Comment suggestion

Co-authored-by: Pyrofab <redstoneinfire@gmail.com>
---
 .../lifecycle/v1/ClientEntityEvents.java      |  4 +-
 .../lifecycle/v1/ServerEntityEvents.java      | 31 ++++++++++++-
 .../event/lifecycle/LifecycleEventsImpl.java  |  8 +++-
 .../ServerWorldEntityLoaderMixin.java         |  6 ++-
 .../lifecycle/ServerEntityLifecycleTests.java | 46 +++++++++++++++++++
 5 files changed, 89 insertions(+), 6 deletions(-)

diff --git a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/client/event/lifecycle/v1/ClientEntityEvents.java b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/client/event/lifecycle/v1/ClientEntityEvents.java
index f4ee6d837..4713d761b 100644
--- a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/client/event/lifecycle/v1/ClientEntityEvents.java
+++ b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/client/event/lifecycle/v1/ClientEntityEvents.java
@@ -57,12 +57,12 @@ public final class ClientEntityEvents {
 	/**
 	 * Called when an Entity is about to be unloaded from a ClientWorld.
 	 *
-	 * <p>When this event is called, the entity is still present in the world.
+	 * <p>This event is called before the entity is unloaded from the world.
 	 */
 	public static final Event<ClientEntityEvents.Unload> ENTITY_UNLOAD = EventFactory.createArrayBacked(ClientEntityEvents.Unload.class, callbacks -> (entity, world) -> {
 		if (EventFactory.isProfilingEnabled()) {
 			final Profiler profiler = world.getProfiler();
-			profiler.push("fabricClientEntityLoad");
+			profiler.push("fabricClientEntityUnload");
 
 			for (ClientEntityEvents.Unload callback : callbacks) {
 				profiler.push(EventFactory.getHandlerName(callback));
diff --git a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/event/lifecycle/v1/ServerEntityEvents.java b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/event/lifecycle/v1/ServerEntityEvents.java
index ab0333cb4..42693528a 100644
--- a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/event/lifecycle/v1/ServerEntityEvents.java
+++ b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/api/event/lifecycle/v1/ServerEntityEvents.java
@@ -31,8 +31,6 @@ public final class ServerEntityEvents {
 	 * Called when an Entity is loaded into a ServerWorld.
 	 *
 	 * <p>When this event is called, the entity is already in the world.
-	 *
-	 * <p>Note there is no corresponding unload event because entity unloads cannot be reliably tracked.
 	 */
 	public static final Event<ServerEntityEvents.Load> ENTITY_LOAD = EventFactory.createArrayBacked(ServerEntityEvents.Load.class, callbacks -> (entity, world) -> {
 		if (EventFactory.isProfilingEnabled()) {
@@ -53,8 +51,37 @@ public final class ServerEntityEvents {
 		}
 	});
 
+	/**
+	 * Called when an Entity is unloaded from a ServerWorld.
+	 *
+	 * <p>This event is called before the entity is removed from the world.
+	 */
+	public static final Event<ServerEntityEvents.Unload> ENTITY_UNLOAD = EventFactory.createArrayBacked(ServerEntityEvents.Unload.class, callbacks -> (entity, world) -> {
+		if (EventFactory.isProfilingEnabled()) {
+			final Profiler profiler = world.getProfiler();
+			profiler.push("fabricServerEntityUnload");
+
+			for (ServerEntityEvents.Unload callback : callbacks) {
+				profiler.push(EventFactory.getHandlerName(callback));
+				callback.onUnload(entity, world);
+				profiler.pop();
+			}
+
+			profiler.pop();
+		} else {
+			for (ServerEntityEvents.Unload callback : callbacks) {
+				callback.onUnload(entity, world);
+			}
+		}
+	});
+
 	@FunctionalInterface
 	public interface Load {
 		void onLoad(Entity entity, ServerWorld world);
 	}
+
+	@FunctionalInterface
+	public interface Unload {
+		void onUnload(Entity entity, ServerWorld world);
+	}
 }
diff --git a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/impl/event/lifecycle/LifecycleEventsImpl.java b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/impl/event/lifecycle/LifecycleEventsImpl.java
index ffa5c7005..d87dcd076 100644
--- a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/impl/event/lifecycle/LifecycleEventsImpl.java
+++ b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/impl/event/lifecycle/LifecycleEventsImpl.java
@@ -17,11 +17,13 @@
 package net.fabricmc.fabric.impl.event.lifecycle;
 
 import net.minecraft.block.entity.BlockEntity;
+import net.minecraft.entity.Entity;
 import net.minecraft.world.chunk.WorldChunk;
 
 import net.fabricmc.api.ModInitializer;
 import net.fabricmc.fabric.api.event.lifecycle.v1.ServerBlockEntityEvents;
 import net.fabricmc.fabric.api.event.lifecycle.v1.ServerChunkEvents;
+import net.fabricmc.fabric.api.event.lifecycle.v1.ServerEntityEvents;
 import net.fabricmc.fabric.api.event.lifecycle.v1.ServerWorldEvents;
 
 public final class LifecycleEventsImpl implements ModInitializer {
@@ -44,13 +46,17 @@ public final class LifecycleEventsImpl implements ModInitializer {
 			}
 		});
 
-		// We use the world unload event so worlds that are dynamically hot(un)loaded get block entity unload events fired when shut down.
+		// We use the world unload event so worlds that are dynamically hot(un)loaded get (block) entity unload events fired when shut down.
 		ServerWorldEvents.UNLOAD.register((server, world) -> {
 			for (WorldChunk chunk : ((LoadedChunksCache) world).fabric_getLoadedChunks()) {
 				for (BlockEntity blockEntity : chunk.getBlockEntities().values()) {
 					ServerBlockEntityEvents.BLOCK_ENTITY_UNLOAD.invoker().onUnload(blockEntity, world);
 				}
 			}
+
+			for (Entity entity : world.iterateEntities()) {
+				ServerEntityEvents.ENTITY_UNLOAD.invoker().onUnload(entity, world);
+			}
 		});
 	}
 }
diff --git a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/mixin/event/lifecycle/ServerWorldEntityLoaderMixin.java b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/mixin/event/lifecycle/ServerWorldEntityLoaderMixin.java
index d253b22f0..7e777b3d3 100644
--- a/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/mixin/event/lifecycle/ServerWorldEntityLoaderMixin.java
+++ b/fabric-lifecycle-events-v1/src/main/java/net/fabricmc/fabric/mixin/event/lifecycle/ServerWorldEntityLoaderMixin.java
@@ -36,9 +36,13 @@ abstract class ServerWorldEntityLoaderMixin {
 	@Final
 	private ServerWorld field_26936;
 
-	// onLoadEntity
 	@Inject(method = "onLoadEntity(Lnet/minecraft/entity/Entity;)V", at = @At("TAIL"))
 	private void invokeEntityLoadEvent(Entity entity, CallbackInfo ci) {
 		ServerEntityEvents.ENTITY_LOAD.invoker().onLoad(entity, this.field_26936);
 	}
+
+	@Inject(method = "onUnloadEntity(Lnet/minecraft/entity/Entity;)V", at = @At("HEAD"))
+	private void invokeEntityUnloadEvent(Entity entity, CallbackInfo info) {
+		ServerEntityEvents.ENTITY_UNLOAD.invoker().onUnload(entity, this.field_26936);
+	}
 }
diff --git a/fabric-lifecycle-events-v1/src/testmod/java/net/fabricmc/fabric/test/event/lifecycle/ServerEntityLifecycleTests.java b/fabric-lifecycle-events-v1/src/testmod/java/net/fabricmc/fabric/test/event/lifecycle/ServerEntityLifecycleTests.java
index a4f6adefd..65050a93a 100644
--- a/fabric-lifecycle-events-v1/src/testmod/java/net/fabricmc/fabric/test/event/lifecycle/ServerEntityLifecycleTests.java
+++ b/fabric-lifecycle-events-v1/src/testmod/java/net/fabricmc/fabric/test/event/lifecycle/ServerEntityLifecycleTests.java
@@ -19,12 +19,16 @@ package net.fabricmc.fabric.test.event.lifecycle;
 import java.util.ArrayList;
 import java.util.List;
 
+import com.google.common.collect.Iterables;
 import org.apache.logging.log4j.Logger;
 
 import net.minecraft.entity.Entity;
+import net.minecraft.server.world.ServerWorld;
 
 import net.fabricmc.api.ModInitializer;
 import net.fabricmc.fabric.api.event.lifecycle.v1.ServerEntityEvents;
+import net.fabricmc.fabric.api.event.lifecycle.v1.ServerLifecycleEvents;
+import net.fabricmc.fabric.api.event.lifecycle.v1.ServerTickEvents;
 
 /**
  * Tests related to the lifecycle of entities.
@@ -32,6 +36,7 @@ import net.fabricmc.fabric.api.event.lifecycle.v1.ServerEntityEvents;
 public final class ServerEntityLifecycleTests implements ModInitializer {
 	private static final boolean PRINT_SERVER_ENTITY_MESSAGES = System.getProperty("fabric-lifecycle-events-testmod.printServerEntityMessages") != null;
 	private final List<Entity> serverEntities = new ArrayList<>();
+	private int serverTicks = 0;
 
 	@Override
 	public void onInitialize() {
@@ -44,5 +49,46 @@ public final class ServerEntityLifecycleTests implements ModInitializer {
 				logger.info("[SERVER] LOADED " + entity.toString() + " - Entities: " + this.serverEntities.size());
 			}
 		});
+
+		ServerEntityEvents.ENTITY_UNLOAD.register((entity, world) -> {
+			this.serverEntities.remove(entity);
+
+			if (PRINT_SERVER_ENTITY_MESSAGES) {
+				logger.info("[SERVER] UNLOADED " + entity.toString() + " - Entities: " + this.serverEntities.size());
+			}
+		});
+
+		ServerTickEvents.END_SERVER_TICK.register(server -> {
+			if (this.serverTicks++ % 200 == 0) {
+				int entities = 0;
+
+				for (ServerWorld world : server.getWorlds()) {
+					final int worldEntities = Iterables.size(world.iterateEntities());
+
+					if (PRINT_SERVER_ENTITY_MESSAGES) {
+						logger.info("[SERVER] Tracked Entities in " + world.getRegistryKey().toString() + " - " + worldEntities);
+					}
+
+					entities += worldEntities;
+				}
+
+				if (PRINT_SERVER_ENTITY_MESSAGES) {
+					logger.info("[SERVER] Actual Total Entities: " + entities);
+				}
+
+				if (entities != this.serverEntities.size()) {
+					// Always print mismatches
+					logger.error("[SERVER] Mismatch in tracked entities and actual entities");
+				}
+			}
+		});
+
+		ServerLifecycleEvents.SERVER_STOPPED.register(server -> {
+			logger.info("[SERVER] Disconnected. Tracking: " + this.serverEntities.size() + " entities");
+
+			if (this.serverEntities.size() != 0) {
+				logger.error("[SERVER] Mismatch in tracked entities, expected 0");
+			}
+		});
 	}
 }