Only retain buffer on receiving side (#3446)

(cherry picked from commit e3b6950386)
This commit is contained in:
deirn 2023-11-27 17:21:38 +07:00 committed by modmuss50
parent 00e49842c5
commit 901470e4e3
7 changed files with 103 additions and 25 deletions

View file

@ -48,6 +48,8 @@ public final class NetworkingImpl {
*/ */
public static final Identifier UNREGISTER_CHANNEL = new Identifier("minecraft", "unregister"); public static final Identifier UNREGISTER_CHANNEL = new Identifier("minecraft", "unregister");
public static final ThreadLocal<Boolean> FACTORY_RETAIN = ThreadLocal.withInitial(() -> Boolean.FALSE);
public static boolean isReservedCommonChannel(Identifier channelName) { public static boolean isReservedCommonChannel(Identifier channelName) {
return channelName.equals(REGISTER_CHANNEL) || channelName.equals(UNREGISTER_CHANNEL); return channelName.equals(REGISTER_CHANNEL) || channelName.equals(UNREGISTER_CHANNEL);
} }

View file

@ -17,6 +17,7 @@
package net.fabricmc.fabric.impl.networking.payload; package net.fabricmc.fabric.impl.networking.payload;
import net.minecraft.network.PacketByteBuf; import net.minecraft.network.PacketByteBuf;
import net.minecraft.util.Identifier;
import net.fabricmc.fabric.api.networking.v1.PacketByteBufs; import net.fabricmc.fabric.api.networking.v1.PacketByteBufs;
@ -26,15 +27,31 @@ public class PayloadHelper {
} }
public static PacketByteBuf read(PacketByteBuf byteBuf, int maxSize) { public static PacketByteBuf read(PacketByteBuf byteBuf, int maxSize) {
int size = byteBuf.readableBytes(); assertSize(byteBuf, maxSize);
if (size < 0 || size > maxSize) {
throw new IllegalArgumentException("Payload may not be larger than %d bytes".formatted(maxSize));
}
PacketByteBuf newBuf = PacketByteBufs.create(); PacketByteBuf newBuf = PacketByteBufs.create();
newBuf.writeBytes(byteBuf.copy()); newBuf.writeBytes(byteBuf.copy());
byteBuf.skipBytes(byteBuf.readableBytes()); byteBuf.skipBytes(byteBuf.readableBytes());
return newBuf; return newBuf;
} }
public static ResolvablePayload readCustom(Identifier id, PacketByteBuf buf, int maxSize, boolean retain) {
assertSize(buf, maxSize);
if (retain) {
RetainedPayload payload = new RetainedPayload(id, PacketByteBufs.retainedSlice(buf));
buf.skipBytes(buf.readableBytes());
return payload;
} else {
return new UntypedPayload(id, read(buf, maxSize));
}
}
private static void assertSize(PacketByteBuf buf, int maxSize) {
int size = buf.readableBytes();
if (size < 0 || size > maxSize) {
throw new IllegalArgumentException("Payload may not be larger than " + maxSize + " bytes");
}
}
} }

View file

@ -28,8 +28,8 @@ import net.minecraft.network.packet.CustomPayload;
import net.minecraft.network.packet.c2s.common.CustomPayloadC2SPacket; import net.minecraft.network.packet.c2s.common.CustomPayloadC2SPacket;
import net.minecraft.util.Identifier; import net.minecraft.util.Identifier;
import net.fabricmc.fabric.api.networking.v1.PacketByteBufs; import net.fabricmc.fabric.impl.networking.NetworkingImpl;
import net.fabricmc.fabric.impl.networking.payload.RetainedPayload; import net.fabricmc.fabric.impl.networking.payload.PayloadHelper;
@Mixin(CustomPayloadC2SPacket.class) @Mixin(CustomPayloadC2SPacket.class)
public class CustomPayloadC2SPacketMixin { public class CustomPayloadC2SPacketMixin {
@ -43,13 +43,6 @@ public class CustomPayloadC2SPacketMixin {
cancellable = true cancellable = true
) )
private static void readPayload(Identifier id, PacketByteBuf buf, CallbackInfoReturnable<CustomPayload> cir) { private static void readPayload(Identifier id, PacketByteBuf buf, CallbackInfoReturnable<CustomPayload> cir) {
int size = buf.readableBytes(); cir.setReturnValue(PayloadHelper.readCustom(id, buf, MAX_PAYLOAD_SIZE, NetworkingImpl.FACTORY_RETAIN.get()));
if (size < 0 || size > MAX_PAYLOAD_SIZE) {
throw new IllegalArgumentException("Payload may not be larger than " + MAX_PAYLOAD_SIZE + " bytes");
}
cir.setReturnValue(new RetainedPayload(id, PacketByteBufs.retainedSlice(buf)));
buf.skipBytes(size);
} }
} }

View file

@ -28,8 +28,8 @@ import net.minecraft.network.packet.CustomPayload;
import net.minecraft.network.packet.s2c.common.CustomPayloadS2CPacket; import net.minecraft.network.packet.s2c.common.CustomPayloadS2CPacket;
import net.minecraft.util.Identifier; import net.minecraft.util.Identifier;
import net.fabricmc.fabric.api.networking.v1.PacketByteBufs; import net.fabricmc.fabric.impl.networking.NetworkingImpl;
import net.fabricmc.fabric.impl.networking.payload.RetainedPayload; import net.fabricmc.fabric.impl.networking.payload.PayloadHelper;
@Mixin(CustomPayloadS2CPacket.class) @Mixin(CustomPayloadS2CPacket.class)
public class CustomPayloadS2CPacketMixin { public class CustomPayloadS2CPacketMixin {
@ -43,13 +43,6 @@ public class CustomPayloadS2CPacketMixin {
cancellable = true cancellable = true
) )
private static void readPayload(Identifier id, PacketByteBuf buf, CallbackInfoReturnable<CustomPayload> cir) { private static void readPayload(Identifier id, PacketByteBuf buf, CallbackInfoReturnable<CustomPayload> cir) {
int size = buf.readableBytes(); cir.setReturnValue(PayloadHelper.readCustom(id, buf, MAX_PAYLOAD_SIZE, NetworkingImpl.FACTORY_RETAIN.get()));
if (size < 0 || size > MAX_PAYLOAD_SIZE) {
throw new IllegalArgumentException("Payload may not be larger than " + MAX_PAYLOAD_SIZE + " bytes");
}
cir.setReturnValue(new RetainedPayload(id, PacketByteBufs.retainedSlice(buf)));
buf.skipBytes(size);
} }
} }

View file

@ -0,0 +1,64 @@
/*
* Copyright (c) 2016, 2017, 2018, 2019 FabricMC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.fabricmc.fabric.mixin.networking;
import java.util.function.Function;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.ModifyVariable;
import net.minecraft.network.PacketByteBuf;
import net.minecraft.network.packet.Packet;
import net.minecraft.network.packet.c2s.common.CustomPayloadC2SPacket;
import net.minecraft.network.packet.s2c.common.CustomPayloadS2CPacket;
import net.fabricmc.fabric.impl.networking.NetworkingImpl;
import net.fabricmc.fabric.impl.networking.payload.RetainedPayload;
import net.fabricmc.fabric.impl.networking.payload.UntypedPayload;
@Mixin(targets = "net.minecraft.network.NetworkState$InternalPacketHandler")
public class NetworkStateInternalPacketHandlerMixin {
/**
* Only retain custom packet buffer to {@link RetainedPayload} on the receiving side,
* otherwise resolve to {@link UntypedPayload}.
*/
@ModifyVariable(method = "register", at = @At("HEAD"), argsOnly = true)
private Function<PacketByteBuf, Packet<?>> replaceCustomPayloadFactory(Function<PacketByteBuf, Packet<?>> original, Class<?> type) {
if (type == CustomPayloadC2SPacket.class) {
return buf -> {
try {
NetworkingImpl.FACTORY_RETAIN.set(true);
return new CustomPayloadC2SPacket(buf);
} finally {
NetworkingImpl.FACTORY_RETAIN.set(false);
}
};
} else if (type == CustomPayloadS2CPacket.class) {
return buf -> {
try {
NetworkingImpl.FACTORY_RETAIN.set(true);
return new CustomPayloadS2CPacket(buf);
} finally {
NetworkingImpl.FACTORY_RETAIN.set(false);
}
};
}
return original;
}
}

View file

@ -10,6 +10,7 @@
"EntityTrackerEntryMixin", "EntityTrackerEntryMixin",
"LoginQueryRequestS2CPacketMixin", "LoginQueryRequestS2CPacketMixin",
"LoginQueryResponseC2SPacketMixin", "LoginQueryResponseC2SPacketMixin",
"NetworkStateInternalPacketHandlerMixin",
"PlayerManagerMixin", "PlayerManagerMixin",
"ServerCommonNetworkHandlerMixin", "ServerCommonNetworkHandlerMixin",
"ServerConfigurationNetworkHandlerMixin", "ServerConfigurationNetworkHandlerMixin",

View file

@ -29,6 +29,7 @@ import com.mojang.brigadier.arguments.StringArgumentType;
import net.minecraft.network.PacketByteBuf; import net.minecraft.network.PacketByteBuf;
import net.minecraft.network.listener.ClientPlayPacketListener; import net.minecraft.network.listener.ClientPlayPacketListener;
import net.minecraft.network.packet.Packet; import net.minecraft.network.packet.Packet;
import net.minecraft.network.packet.s2c.common.CustomPayloadS2CPacket;
import net.minecraft.network.packet.s2c.play.BundleS2CPacket; import net.minecraft.network.packet.s2c.play.BundleS2CPacket;
import net.minecraft.server.command.ServerCommandSource; import net.minecraft.server.command.ServerCommandSource;
import net.minecraft.server.network.ServerPlayerEntity; import net.minecraft.server.network.ServerPlayerEntity;
@ -70,6 +71,13 @@ public final class NetworkingPlayPacketTest implements ModInitializer {
sendToUnknownChannel(ctx.getSource().getPlayer()); sendToUnknownChannel(ctx.getSource().getPlayer());
return Command.SINGLE_SUCCESS; return Command.SINGLE_SUCCESS;
})) }))
.then(literal("bufctor").executes(ctx -> {
PacketByteBuf buf = PacketByteBufs.create();
buf.writeIdentifier(TEST_CHANNEL);
buf.writeText(Text.literal("bufctor"));
ctx.getSource().getPlayer().networkHandler.sendPacket(new CustomPayloadS2CPacket(buf));
return Command.SINGLE_SUCCESS;
}))
.then(literal("bundled").executes(ctx -> { .then(literal("bundled").executes(ctx -> {
PacketByteBuf buf1 = PacketByteBufs.create(); PacketByteBuf buf1 = PacketByteBufs.create();
buf1.writeText(Text.literal("bundled #1")); buf1.writeText(Text.literal("bundled #1"));