Merge remote-tracking branch 'origin/master' into feature/1.21.2
Some checks failed
Java CI with Gradle / build (push) Has been cancelled

This commit is contained in:
Camotoy 2024-10-17 14:08:01 -04:00
commit f4c07f23b5
No known key found for this signature in database
GPG key ID: 7EEFB66FE798081F
34 changed files with 808 additions and 559 deletions

View file

@ -39,7 +39,7 @@ public class ClientSessionListener extends SessionAdapter {
public void connected(ConnectedEvent event) {
log.info("CLIENT Connected");
event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().send(new PingPacket("hello"));
}

View file

@ -38,7 +38,7 @@ public class ServerListener extends ServerAdapter {
public void sessionAdded(SessionAddedEvent event) {
log.info("SERVER Session Added: {}:{}", event.getSession().getHost(), event.getSession().getPort());
((TestProtocol) event.getSession().getPacketProtocol()).setSecretKey(this.key);
event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
}
@Override

View file

@ -8,7 +8,7 @@ import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.codec.PacketDefinition;
import org.geysermc.mcprotocollib.network.codec.PacketSerializer;
import org.geysermc.mcprotocollib.network.crypt.AESEncryption;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.packet.DefaultPacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
@ -23,7 +23,7 @@ public class TestProtocol extends PacketProtocol {
private static final Logger log = LoggerFactory.getLogger(TestProtocol.class);
private final PacketHeader header = new DefaultPacketHeader();
private final PacketRegistry registry = new PacketRegistry();
private AESEncryption encrypt;
private EncryptionConfig encrypt;
@SuppressWarnings("unused")
public TestProtocol() {
@ -51,7 +51,7 @@ public class TestProtocol extends PacketProtocol {
});
try {
this.encrypt = new AESEncryption(key);
this.encrypt = new EncryptionConfig(new AESEncryption(key));
} catch (GeneralSecurityException e) {
log.error("Failed to create encryption", e);
}
@ -67,7 +67,7 @@ public class TestProtocol extends PacketProtocol {
return this.header;
}
public PacketEncryption getEncryption() {
public EncryptionConfig getEncryption() {
return this.encrypt;
}
@ -82,7 +82,12 @@ public class TestProtocol extends PacketProtocol {
}
@Override
public PacketRegistry getPacketRegistry() {
public PacketRegistry getInboundPacketRegistry() {
return registry;
}
@Override
public PacketRegistry getOutboundPacketRegistry() {
return registry;
}
}

View file

@ -48,7 +48,8 @@ import java.util.BitSet;
public class MinecraftProtocolTest {
private static final Logger log = LoggerFactory.getLogger(MinecraftProtocolTest.class);
private static final boolean SPAWN_SERVER = true;
private static final boolean VERIFY_USERS = false;
private static final boolean ENCRYPT_CONNECTION = true;
private static final boolean SHOULD_AUTHENTICATE = false;
private static final String HOST = "127.0.0.1";
private static final int PORT = 25565;
private static final ProxyInfo PROXY = null;
@ -63,7 +64,8 @@ public class MinecraftProtocolTest {
Server server = new TcpServer(HOST, PORT, MinecraftProtocol::new);
server.setGlobalFlag(MinecraftConstants.SESSION_SERVICE_KEY, sessionService);
server.setGlobalFlag(MinecraftConstants.VERIFY_USERS_KEY, VERIFY_USERS);
server.setGlobalFlag(MinecraftConstants.ENCRYPT_CONNECTION, ENCRYPT_CONNECTION);
server.setGlobalFlag(MinecraftConstants.SHOULD_AUTHENTICATE, SHOULD_AUTHENTICATE);
server.setGlobalFlag(MinecraftConstants.SERVER_INFO_BUILDER_KEY, session ->
new ServerStatusInfo(
Component.text("Hello world!"),
@ -101,7 +103,7 @@ public class MinecraftProtocolTest {
))
);
server.setGlobalFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, 100);
server.setGlobalFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, 256);
server.addListener(new ServerAdapter() {
@Override
public void serverClosed(ServerClosedEvent event) {
@ -134,7 +136,7 @@ public class MinecraftProtocolTest {
@Override
public void sessionRemoved(SessionRemovedEvent event) {
MinecraftProtocol protocol = (MinecraftProtocol) event.getSession().getPacketProtocol();
if (protocol.getState() == ProtocolState.GAME) {
if (protocol.getOutboundState() == ProtocolState.GAME) {
log.info("Closing server.");
event.getServer().close(false);
}
@ -178,7 +180,7 @@ public class MinecraftProtocolTest {
private static void login() {
MinecraftProtocol protocol;
if (VERIFY_USERS) {
if (SHOULD_AUTHENTICATE) {
StepFullJavaSession.FullJavaSession fullJavaSession;
try {
fullJavaSession = MinecraftAuth.JAVA_CREDENTIALS_LOGIN.getFromInput(

View file

@ -6,8 +6,11 @@ import java.net.InetSocketAddress;
* Built-in PacketLib session flags.
*/
public class BuiltinFlags {
public static final Flag<Boolean> ENABLE_CLIENT_PROXY_PROTOCOL = new Flag<>("enable-client-proxy-protocol", Boolean.class);
/**
* Enables HAProxy protocol support.
* When this value is not null it represents the ip and port the client claims the connection is from.
*/
public static final Flag<InetSocketAddress> CLIENT_PROXIED_ADDRESS = new Flag<>("client-proxied-address", InetSocketAddress.class);
/**
@ -20,6 +23,24 @@ public class BuiltinFlags {
*/
public static final Flag<Boolean> TCP_FAST_OPEN = new Flag<>("tcp-fast-open", Boolean.class);
/**
* Connection timeout in seconds.
* Only used by the client.
*/
public static final Flag<Integer> CLIENT_CONNECT_TIMEOUT = new Flag<>("client-connect-timeout", Integer.class);
/**
* Read timeout in seconds.
* Used by both the server and client.
*/
public static final Flag<Integer> READ_TIMEOUT = new Flag<>("read-timeout", Integer.class);
/**
* Write timeout in seconds.
* Used by both the server and client.
*/
public static final Flag<Integer> WRITE_TIMEOUT = new Flag<>("write-timeout", Integer.class);
private BuiltinFlags() {
}
}

View file

@ -0,0 +1,10 @@
package org.geysermc.mcprotocollib.network;
import io.netty.util.AttributeKey;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
public class NetworkConstants {
public static final AttributeKey<CompressionConfig> COMPRESSION_ATTRIBUTE_KEY = AttributeKey.valueOf("compression");
public static final AttributeKey<EncryptionConfig> ENCRYPTION_ATTRIBUTE_KEY = AttributeKey.valueOf("encryption");
}

View file

@ -1,14 +1,17 @@
package org.geysermc.mcprotocollib.network;
import io.netty.channel.Channel;
import net.kyori.adventure.text.Component;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.event.session.SessionEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionListener;
import org.geysermc.mcprotocollib.network.packet.Packet;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.tcp.FlushHandler;
import java.net.SocketAddress;
import java.util.List;
@ -37,7 +40,7 @@ public interface Session {
* @param wait Whether to wait for the connection to be established before returning.
* @param transferring Whether the session is a client being transferred.
*/
public void connect(boolean wait, boolean transferring);
void connect(boolean wait, boolean transferring);
/**
* Gets the host the session is connected to.
@ -138,7 +141,7 @@ public interface Session {
*
* @param flags Collection of flags
*/
public void setFlags(Map<String, Object> flags);
void setFlags(Map<String, Object> flags);
/**
* Gets the listeners listening on this session.
@ -183,68 +186,21 @@ public interface Session {
void callPacketSent(Packet packet);
/**
* Gets the compression packet length threshold for this session (-1 = disabled).
* Sets the compression config for this session.
*
* @return This session's compression threshold.
* @param compressionConfig the compression to compress with,
* or null to disable compression
*/
int getCompressionThreshold();
void setCompression(@Nullable CompressionConfig compressionConfig);
/**
* Sets the compression packet length threshold for this session (-1 = disabled).
* Sets encryption for this session.
*
* @param threshold The new compression threshold.
* @param validateDecompression whether to validate that the decompression fits within size checks.
*/
void setCompressionThreshold(int threshold, boolean validateDecompression);
/**
* Enables encryption for this session.
* @param encryptionConfig the encryption to encrypt with,
* or null to disable encryption
*
* @param encryption the encryption to encrypt with
*/
void enableEncryption(PacketEncryption encryption);
/**
* Gets the connect timeout for this session in seconds.
*
* @return The session's connect timeout.
*/
int getConnectTimeout();
/**
* Sets the connect timeout for this session in seconds.
*
* @param timeout Connect timeout to set.
*/
void setConnectTimeout(int timeout);
/**
* Gets the read timeout for this session in seconds.
*
* @return The session's read timeout.
*/
int getReadTimeout();
/**
* Sets the read timeout for this session in seconds.
*
* @param timeout Read timeout to set.
*/
void setReadTimeout(int timeout);
/**
* Gets the write timeout for this session in seconds.
*
* @return The session's write timeout.
*/
int getWriteTimeout();
/**
* Sets the write timeout for this session in seconds.
*
* @param timeout Write timeout to set.
*/
void setWriteTimeout(int timeout);
void setEncryption(@Nullable EncryptionConfig encryptionConfig);
/**
* Returns true if the session is connected.
@ -258,7 +214,17 @@ public interface Session {
*
* @param packet Packet to send.
*/
void send(Packet packet);
default void send(@NonNull Packet packet) {
this.send(packet, null);
}
/**
* Sends a packet and runs the specified callback when the packet has been sent.
*
* @param packet Packet to send.
* @param onSent Callback to run when the packet has been sent.
*/
void send(@NonNull Packet packet, @Nullable Runnable onSent);
/**
* Disconnects the session.
@ -301,4 +267,48 @@ public interface Session {
* @param cause Throwable responsible for disconnecting.
*/
void disconnect(@NonNull Component reason, @Nullable Throwable cause);
/**
* Auto read in netty means that the server is automatically reading from the channel.
* Turning it off means that we won't get more packets being decoded until we turn it back on.
* We use this to hold off on reading packets until we are ready to process them.
* For example this is used for switching inbound states with {@link #switchInboundState(Runnable)}.
*
* @param autoRead Whether to enable auto read.
* Default is true.
*/
void setAutoRead(boolean autoRead);
/**
* Returns the underlying netty channel of this session.
*
* @return The netty channel
*/
Channel getChannel();
/**
* Changes the inbound state of the session and then re-enables auto read.
* This is used after a terminal packet was handled and the session is ready to receive more packets in the new state.
*
* @param switcher The runnable that switches the inbound state.
*/
default void switchInboundState(Runnable switcher) {
switcher.run();
// We switched to the new inbound state
// we can start reading again
setAutoRead(true);
}
/**
* Flushes all packets that are due to be sent and changes the outbound state of the session.
* This makes sure no other threads have scheduled packets to be sent.
*
* @param switcher The runnable that switches the outbound state.
*/
default void switchOutboundState(Runnable switcher) {
getChannel().writeAndFlush(FlushHandler.FLUSH_PACKET).syncUninterruptibly();
switcher.run();
}
}

View file

@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.compression;
public record CompressionConfig(int threshold, PacketCompression compression, boolean validateDecompression) {
}

View file

@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.crypt;
public record EncryptionConfig(PacketEncryption encryption) {
}

View file

@ -34,8 +34,11 @@ public class TransportHelper {
}
public record TransportType(TransportMethod method,
Class<? extends ServerSocketChannel> serverSocketChannelClass,
ChannelFactory<? extends ServerSocketChannel> serverSocketChannelFactory,
Class<? extends SocketChannel> socketChannelClass,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
Class<? extends DatagramChannel> datagramChannelClass,
ChannelFactory<? extends DatagramChannel> datagramChannelFactory,
Function<ThreadFactory, EventLoopGroup> eventLoopGroupFactory,
boolean supportsTcpFastOpenServer,
@ -46,8 +49,11 @@ public class TransportHelper {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) {
return new TransportType(
TransportMethod.IO_URING,
IOUringServerSocketChannel.class,
IOUringServerSocketChannel::new,
IOUringSocketChannel.class,
IOUringSocketChannel::new,
IOUringDatagramChannel.class,
IOUringDatagramChannel::new,
factory -> new IOUringEventLoopGroup(0, factory),
IOUring.isTcpFastOpenServerSideAvailable(),
@ -58,8 +64,11 @@ public class TransportHelper {
if (isClassAvailable("io.netty.channel.epoll.Epoll") && Epoll.isAvailable()) {
return new TransportType(
TransportMethod.EPOLL,
EpollServerSocketChannel.class,
EpollServerSocketChannel::new,
EpollSocketChannel.class,
EpollSocketChannel::new,
EpollDatagramChannel.class,
EpollDatagramChannel::new,
factory -> new EpollEventLoopGroup(0, factory),
Epoll.isTcpFastOpenServerSideAvailable(),
@ -70,8 +79,11 @@ public class TransportHelper {
if (isClassAvailable("io.netty.channel.kqueue.KQueue") && KQueue.isAvailable()) {
return new TransportType(
TransportMethod.KQUEUE,
KQueueServerSocketChannel.class,
KQueueServerSocketChannel::new,
KQueueSocketChannel.class,
KQueueSocketChannel::new,
KQueueDatagramChannel.class,
KQueueDatagramChannel::new,
factory -> new KQueueEventLoopGroup(0, factory),
KQueue.isTcpFastOpenServerSideAvailable(),
@ -81,8 +93,11 @@ public class TransportHelper {
return new TransportType(
TransportMethod.NIO,
NioServerSocketChannel.class,
NioServerSocketChannel::new,
NioSocketChannel.class,
NioSocketChannel::new,
NioDatagramChannel.class,
NioDatagramChannel::new,
factory -> new NioEventLoopGroup(0, factory),
false,

View file

@ -1,6 +1,7 @@
package org.geysermc.mcprotocollib.network.packet;
import io.netty.buffer.ByteBuf;
import org.geysermc.mcprotocollib.network.Session;
/**
* A network packet. Any given packet must have a constructor that takes in a {@link ByteBuf}.
@ -17,4 +18,14 @@ public interface Packet {
default boolean isPriority() {
return false;
}
/**
* Returns whether the packet is terminal. If true, this should be the last packet sent inside a protocol state.
* Subsequently, {@link Session#setAutoRead(boolean)} should be disabled when a terminal packet is received, until the session has switched into a new state and is ready to receive more packets.
*
* @return Whether the packet is terminal.
*/
default boolean isTerminal() {
return false;
}
}

View file

@ -49,9 +49,16 @@ public abstract class PacketProtocol {
public abstract void newServerSession(Server server, Session session);
/**
* Gets the packet registry for this protocol.
* Gets the inbound packet registry for this protocol.
*
* @return The protocol's packet registry.
* @return The protocol's inbound packet registry.
*/
public abstract PacketRegistry getPacketRegistry();
public abstract PacketRegistry getInboundPacketRegistry();
/**
* Gets the outbound packet registry for this protocol.
*
* @return The protocol's outbound packet registry.
*/
public abstract PacketRegistry getOutboundPacketRegistry();
}

View file

@ -0,0 +1,28 @@
package org.geysermc.mcprotocollib.network.tcp;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
/**
* Sending a {@link FlushPacket} will ensure all before were sent.
* This handler makes sure it's dropped before it reaches the encoder.
* This logic is similar to the Minecraft UnconfiguredPipelineHandler.OutboundConfigurationTask.
*/
public class FlushHandler extends ChannelOutboundHandlerAdapter {
public static final FlushPacket FLUSH_PACKET = new FlushPacket();
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg == FLUSH_PACKET) {
promise.setSuccess();
} else {
super.write(ctx, msg, promise);
}
}
public static class FlushPacket {
private FlushPacket() {
}
}
}

View file

@ -4,7 +4,6 @@ import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
@ -25,9 +24,12 @@ import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
@ -40,6 +42,7 @@ import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
@ -90,56 +93,55 @@ public class TcpClientSession extends TcpSession {
createTcpEventLoopGroup();
}
try {
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getFlag(BuiltinFlags.CLIENT_CONNECT_TIMEOUT, 30) * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(@NonNull Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);
ChannelPipeline pipeline = channel.pipeline();
ChannelPipeline pipeline = channel.pipeline();
refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);
addProxy(pipeline);
addProxy(pipeline);
initializeHAProxySupport(channel);
int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}
pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));
pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(getCodecHelper()));
addHAProxySupport(pipeline);
}
});
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
ChannelFuture future = bootstrap.connect();
if (wait) {
future.sync();
}
future.addListener((futureListener) -> {
if (!futureListener.isSuccess()) {
exceptionCaught(null, futureListener.cause());
pipeline.addLast("flow-control", new TcpFlowControlHandler());
pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("flush-handler", new FlushHandler());
pipeline.addLast("manager", TcpClientSession.this);
}
});
} catch (Throwable t) {
exceptionCaught(null, t);
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
CompletableFuture<Void> handleFuture = new CompletableFuture<>();
bootstrap.connect().addListener((futureListener) -> {
if (!futureListener.isSuccess()) {
exceptionCaught(null, futureListener.cause());
}
handleFuture.complete(null);
});
if (wait) {
handleFuture.join();
}
}
@ -153,42 +155,39 @@ public class TcpClientSession extends TcpSession {
log.debug("Attempting SRV lookup for \"{}\".", name);
if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = null;
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
try {
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.
int port = buf.readUnsignedShort();
String host = DefaultDnsRecordDecoder.decodeName(buf);
if (host.endsWith(".")) {
host = host.substring(0, host.length() - 1);
}
int port = buf.readUnsignedShort();
String host = DefaultDnsRecordDecoder.decodeName(buf);
if (host.endsWith(".")) {
host = host.substring(0, host.length() - 1);
log.debug("Found SRV record containing \"{}:{}\".", host, port);
this.host = host;
this.port = port;
} else {
log.debug("Received non-SRV record in response.");
}
log.debug("Found SRV record containing \"{}:{}\".", host, port);
this.host = host;
this.port = port;
} else {
log.debug("Received non-SRV record in response.");
log.debug("No SRV record found.");
}
} else {
log.debug("No SRV record found.");
} finally {
envelope.release();
}
} catch (Exception e) {
log.debug("Failed to resolve SRV record.", e);
} finally {
if (envelope != null) {
envelope.release();
}
}
} else {
log.debug("Not resolving SRV record for {}", this.host);
@ -206,54 +205,58 @@ public class TcpClientSession extends TcpSession {
}
private void addProxy(ChannelPipeline pipeline) {
if (proxy != null) {
switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addFirst("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addFirst("proxy", new HttpProxyHandler(proxy.address()));
}
if (proxy == null) {
return;
}
switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addFirst("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addFirst("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addFirst("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addFirst("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}
private void addHAProxySupport(ChannelPipeline pipeline) {
private void initializeHAProxySupport(Channel channel) {
InetSocketAddress clientAddress = getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (getFlag(BuiltinFlags.ENABLE_CLIENT_PROXY_PROTOCOL, false) && clientAddress != null) {
pipeline.addFirst("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
));
ctx.pipeline().remove(this);
ctx.pipeline().remove("proxy-protocol-encoder");
super.channelActive(ctx);
}
});
pipeline.addFirst("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
if (clientAddress == null) {
return;
}
channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder"));
ctx.pipeline().remove(this);
super.channelActive(ctx);
}
});
}
private static void createTcpEventLoopGroup() {
@ -264,7 +267,7 @@ public class TcpClientSession extends TcpSession {
EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());
Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
}
protected static ThreadFactory newThreadFactory() {

View file

@ -0,0 +1,20 @@
package org.geysermc.mcprotocollib.network.tcp;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.flow.FlowControlHandler;
/**
* A flow control handler for TCP connections.
* When auto-read is disabled, this will halt decoding of packets until auto-read is re-enabled.
* This is needed because auto-read still allows packets to be decoded, even if the channel is not reading anymore from the network.
* This can happen when the channel already read a packet, but the packet is not yet decoded.
* This will halt all decoding until the channel is ready to process more packets.
*/
public class TcpFlowControlHandler extends FlowControlHandler {
@Override
public void read(ChannelHandlerContext ctx) throws Exception {
if (ctx.channel().config().isAutoRead()) {
super.read(ctx);
}
}
}

View file

@ -2,17 +2,27 @@ package org.geysermc.mcprotocollib.network.tcp;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.codec.PacketDefinition;
import org.geysermc.mcprotocollib.network.event.session.PacketErrorEvent;
import org.geysermc.mcprotocollib.network.packet.Packet;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.packet.PacketRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;
import java.util.List;
public class TcpPacketCodec extends ByteToMessageCodec<Packet> {
public class TcpPacketCodec extends MessageToMessageCodec<ByteBuf, Packet> {
private static final Marker marker = MarkerFactory.getMarker("packet_logging");
private static final Logger log = LoggerFactory.getLogger(TcpPacketCodec.class);
private final Session session;
private final boolean client;
@ -23,35 +33,51 @@ public class TcpPacketCodec extends ByteToMessageCodec<Packet> {
@SuppressWarnings({"rawtypes", "unchecked"})
@Override
public void encode(ChannelHandlerContext ctx, Packet packet, ByteBuf buf) {
int initial = buf.writerIndex();
public void encode(ChannelHandlerContext ctx, Packet packet, List<Object> out) {
if (log.isTraceEnabled()) {
log.trace(marker, "Encoding packet: {}", packet.getClass().getSimpleName());
}
PacketProtocol packetProtocol = this.session.getPacketProtocol();
PacketRegistry packetRegistry = packetProtocol.getOutboundPacketRegistry();
PacketCodecHelper codecHelper = this.session.getCodecHelper();
try {
int packetId = this.client ? packetProtocol.getPacketRegistry().getServerboundId(packet) : packetProtocol.getPacketRegistry().getClientboundId(packet);
PacketDefinition definition = this.client ? packetProtocol.getPacketRegistry().getServerboundDefinition(packetId) : packetProtocol.getPacketRegistry().getClientboundDefinition(packetId);
int packetId = this.client ? packetRegistry.getServerboundId(packet) : packetRegistry.getClientboundId(packet);
PacketDefinition definition = this.client ? packetRegistry.getServerboundDefinition(packetId) : packetRegistry.getClientboundDefinition(packetId);
ByteBuf buf = ctx.alloc().buffer();
packetProtocol.getPacketHeader().writePacketId(buf, codecHelper, packetId);
definition.getSerializer().serialize(buf, codecHelper, packet);
out.add(buf);
if (log.isDebugEnabled()) {
log.debug(marker, "Encoded packet {} ({})", packet.getClass().getSimpleName(), packetId);
}
} catch (Throwable t) {
// Reset writer index to make sure incomplete data is not written out.
buf.writerIndex(initial);
log.debug(marker, "Error encoding packet", t);
PacketErrorEvent e = new PacketErrorEvent(this.session, t);
this.session.callEvent(e);
if (!e.shouldSuppress()) {
throw t;
throw new EncoderException(t);
}
}
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) {
// Vanilla also checks for 0 length
if (buf.readableBytes() == 0) {
return;
}
int initial = buf.readerIndex();
PacketProtocol packetProtocol = this.session.getPacketProtocol();
PacketRegistry packetRegistry = packetProtocol.getInboundPacketRegistry();
PacketCodecHelper codecHelper = this.session.getCodecHelper();
Packet packet = null;
try {
int id = packetProtocol.getPacketHeader().readPacketId(buf, codecHelper);
if (id == -1) {
@ -59,21 +85,35 @@ public class TcpPacketCodec extends ByteToMessageCodec<Packet> {
return;
}
Packet packet = this.client ? packetProtocol.getPacketRegistry().createClientboundPacket(id, buf, codecHelper) : packetProtocol.getPacketRegistry().createServerboundPacket(id, buf, codecHelper);
log.trace(marker, "Decoding packet with id: {}", id);
packet = this.client ? packetRegistry.createClientboundPacket(id, buf, codecHelper) : packetRegistry.createServerboundPacket(id, buf, codecHelper);
if (buf.readableBytes() > 0) {
throw new IllegalStateException("Packet \"" + packet.getClass().getSimpleName() + "\" not fully read.");
}
out.add(packet);
if (log.isDebugEnabled()) {
log.debug(marker, "Decoded packet {} ({})", packet.getClass().getSimpleName(), id);
}
} catch (Throwable t) {
log.debug(marker, "Error decoding packet", t);
// Advance buffer to end to make sure remaining data in this packet is skipped.
buf.readerIndex(buf.readerIndex() + buf.readableBytes());
PacketErrorEvent e = new PacketErrorEvent(this.session, t);
this.session.callEvent(e);
if (!e.shouldSuppress()) {
throw t;
throw new DecoderException(t);
}
} finally {
if (packet != null && packet.isTerminal()) {
// Next packets are in a different protocol state, so we must
// disable auto-read to prevent reading wrong packets.
session.setAutoRead(false);
}
}
}

View file

@ -4,44 +4,49 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.PacketCompression;
import lombok.RequiredArgsConstructor;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import java.util.List;
@RequiredArgsConstructor
public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8 * 1024 * 1024; // 8MiB
private final Session session;
private final PacketCompression compression;
private final boolean validateDecompression;
public TcpPacketCompression(Session session, PacketCompression compression, boolean validateDecompression) {
this.session = session;
this.compression = compression;
this.validateDecompression = validateDecompression;
}
private final PacketCodecHelper helper;
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
this.compression.close();
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
return;
}
config.compression().close();
}
@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}
int uncompressed = msg.readableBytes();
if (uncompressed > MAX_UNCOMPRESSED_SIZE) {
throw new IllegalArgumentException("Packet too big (is " + uncompressed + ", should be less than " + MAX_UNCOMPRESSED_SIZE + ")");
}
ByteBuf outBuf = ctx.alloc().directBuffer(uncompressed);
if (uncompressed < this.session.getCompressionThreshold()) {
if (uncompressed < config.threshold()) {
// Under the threshold, there is nothing to do.
this.session.getCodecHelper().writeVarInt(outBuf, 0);
this.helper.writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
this.session.getCodecHelper().writeVarInt(outBuf, uncompressed);
compression.deflate(msg, outBuf);
this.helper.writeVarInt(outBuf, uncompressed);
config.compression().deflate(msg, outBuf);
}
out.add(outBuf);
@ -49,15 +54,21 @@ public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
int claimedUncompressedSize = this.session.getCodecHelper().readVarInt(in);
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}
int claimedUncompressedSize = this.helper.readVarInt(in);
if (claimedUncompressedSize == 0) {
out.add(in.retain());
return;
}
if (validateDecompression) {
if (claimedUncompressedSize < this.session.getCompressionThreshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + this.session.getCompressionThreshold());
if (config.validateDecompression()) {
if (claimedUncompressedSize < config.threshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + config.threshold());
}
if (claimedUncompressedSize > MAX_UNCOMPRESSED_SIZE) {
@ -67,7 +78,7 @@ public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf
ByteBuf uncompressed = ctx.alloc().directBuffer(claimedUncompressedSize);
try {
compression.inflate(in, uncompressed, claimedUncompressedSize);
config.compression().inflate(in, uncompressed, claimedUncompressedSize);
out.add(uncompressed);
} catch (Exception e) {
uncompressed.release();

View file

@ -6,26 +6,27 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import java.util.List;
public class TcpPacketEncryptor extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private final PacketEncryption encryption;
public TcpPacketEncryptor(PacketEncryption encryption) {
this.encryption = encryption;
}
@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}
ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), msg);
int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();
try {
encryption.encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
@ -35,13 +36,19 @@ public class TcpPacketEncryptor extends MessageToMessageCodec<ByteBuf, ByteBuf>
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}
ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), in).slice();
int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();
try {
encryption.decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();

View file

@ -5,29 +5,39 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.CorruptedFrameException;
import org.geysermc.mcprotocollib.network.Session;
import lombok.RequiredArgsConstructor;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import java.util.List;
@RequiredArgsConstructor
public class TcpPacketSizer extends ByteToMessageCodec<ByteBuf> {
private final Session session;
private final int size;
public TcpPacketSizer(Session session, int size) {
this.session = session;
this.size = size;
}
private final PacketHeader header;
private final PacketCodecHelper codecHelper;
@Override
public void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
int size = header.getLengthSize();
if (size == 0) {
out.writeBytes(in);
return;
}
int length = in.readableBytes();
out.ensureWritable(this.session.getPacketProtocol().getPacketHeader().getLengthSize(length) + length);
this.session.getPacketProtocol().getPacketHeader().writeLength(out, this.session.getCodecHelper(), length);
out.ensureWritable(header.getLengthSize(length) + length);
header.writeLength(out, codecHelper, length);
out.writeBytes(in);
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) {
int size = header.getLengthSize();
if (size == 0) {
out.add(buf.retain());
return;
}
buf.markReaderIndex();
byte[] lengthBytes = new byte[size];
for (int index = 0; index < lengthBytes.length; index++) {
@ -37,8 +47,8 @@ public class TcpPacketSizer extends ByteToMessageCodec<ByteBuf> {
}
lengthBytes[index] = buf.readByte();
if ((this.session.getPacketProtocol().getPacketHeader().isLengthVariable() && lengthBytes[index] >= 0) || index == size - 1) {
int length = this.session.getPacketProtocol().getPacketHeader().readLength(Unpooled.wrappedBuffer(lengthBytes), this.session.getCodecHelper(), buf.readableBytes());
if ((header.isLengthVariable() && lengthBytes[index] >= 0) || index == size - 1) {
int length = header.readLength(Unpooled.wrappedBuffer(lengthBytes), codecHelper, buf.readableBytes());
if (buf.readableBytes() < length) {
buf.resetReaderIndex();
return;

View file

@ -2,13 +2,14 @@ package org.geysermc.mcprotocollib.network.tcp;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.util.concurrent.Future;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.geysermc.mcprotocollib.network.AbstractServer;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.helper.TransportHelper;
@ -17,6 +18,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
public class TcpServer extends AbstractServer {
@ -51,7 +53,7 @@ public class TcpServer extends AbstractServer {
.localAddress(this.getHost(), this.getPort())
.childHandler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
public void initChannel(@NonNull Channel channel) {
InetSocketAddress address = (InetSocketAddress) channel.remoteAddress();
PacketProtocol protocol = createPacketProtocol();
@ -60,15 +62,16 @@ public class TcpServer extends AbstractServer {
ChannelPipeline pipeline = channel.pipeline();
session.refreshReadTimeoutHandler(channel);
session.refreshWriteTimeoutHandler(channel);
pipeline.addLast("read-timeout", new ReadTimeoutHandler(session.getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(session.getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));
int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(session, size));
}
pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), session.getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(session.getCodecHelper()));
pipeline.addLast("flow-control", new TcpFlowControlHandler());
pipeline.addLast("codec", new TcpPacketCodec(session, false));
pipeline.addLast("flush-handler", new FlushHandler());
pipeline.addLast("manager", session);
}
});
@ -77,29 +80,22 @@ public class TcpServer extends AbstractServer {
bootstrap.option(ChannelOption.TCP_FASTOPEN, 3);
}
ChannelFuture future = bootstrap.bind();
CompletableFuture<Void> handleFuture = new CompletableFuture<>();
bootstrap.bind().addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
channel = future.channel();
if (callback != null) {
callback.run();
}
} else {
log.error("Failed to bind connection listener.", future.cause());
}
handleFuture.complete(null);
});
if (wait) {
try {
future.sync();
} catch (InterruptedException e) {
}
channel = future.channel();
if (callback != null) {
callback.run();
}
} else {
future.addListener((ChannelFutureListener) future1 -> {
if (future1.isSuccess()) {
channel = future1.channel();
if (callback != null) {
callback.run();
}
} else {
log.error("Failed to asynchronously bind connection listener.", future1.cause());
}
});
handleFuture.join();
}
}
@ -107,26 +103,21 @@ public class TcpServer extends AbstractServer {
public void closeImpl(boolean wait, final Runnable callback) {
if (this.channel != null) {
if (this.channel.isOpen()) {
ChannelFuture future = this.channel.close();
if (wait) {
try {
future.sync();
} catch (InterruptedException e) {
CompletableFuture<Void> handleFuture = new CompletableFuture<>();
this.channel.close().addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
if (callback != null) {
callback.run();
}
} else {
log.error("Failed to close connection listener.", future.cause());
}
if (callback != null) {
callback.run();
}
} else {
future.addListener((ChannelFutureListener) future1 -> {
if (future1.isSuccess()) {
if (callback != null) {
callback.run();
}
} else {
log.error("Failed to asynchronously close connection listener.", future1.cause());
}
});
handleFuture.complete(null);
});
if (wait) {
handleFuture.join();
}
}
@ -134,18 +125,17 @@ public class TcpServer extends AbstractServer {
}
if (this.group != null) {
Future<?> future = this.group.shutdownGracefully();
if (wait) {
try {
future.sync();
} catch (InterruptedException e) {
CompletableFuture<Void> handleFuture = new CompletableFuture<>();
this.group.shutdownGracefully().addListener(future -> {
if (!future.isSuccess()) {
log.debug("Failed to close connection listener.", future.cause());
}
} else {
future.addListener(future1 -> {
if (!future1.isSuccess()) {
log.debug("Failed to asynchronously close connection listener.", future1.cause());
}
});
handleFuture.complete(null);
});
if (wait) {
handleFuture.join();
}
this.group = null;

View file

@ -7,16 +7,15 @@ import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.util.concurrent.DefaultThreadFactory;
import net.kyori.adventure.text.Component;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.geysermc.mcprotocollib.network.Flag;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.ZlibCompression;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.event.session.ConnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.DisconnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.DisconnectingEvent;
@ -25,6 +24,8 @@ import org.geysermc.mcprotocollib.network.event.session.SessionEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionListener;
import org.geysermc.mcprotocollib.network.packet.Packet;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.SocketAddress;
import java.util.Collections;
@ -36,6 +37,8 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> implements Session {
private static final Logger log = LoggerFactory.getLogger(TcpSession.class);
/**
* Controls whether non-priority packets are handled in a separate event loop
*/
@ -49,11 +52,6 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
private final PacketProtocol protocol;
private final EventLoop eventLoop = createEventLoop();
private int compressionThreshold = -1;
private int connectTimeout = 30;
private int readTimeout = 30;
private int writeTimeout = 0;
private final Map<String, Object> flags = new HashMap<>();
private final List<SessionListener> listeners = new CopyOnWriteArrayList<>();
@ -193,63 +191,23 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
}
@Override
public int getCompressionThreshold() {
return this.compressionThreshold;
}
@Override
public void setCompressionThreshold(int threshold, boolean validateDecompression) {
this.compressionThreshold = threshold;
if (this.channel != null) {
if (this.compressionThreshold >= 0) {
if (this.channel.pipeline().get("compression") == null) {
this.channel.pipeline().addBefore("codec", "compression",
new TcpPacketCompression(this, new ZlibCompression(), validateDecompression));
}
} else if (this.channel.pipeline().get("compression") != null) {
this.channel.pipeline().remove("compression");
}
public void setCompression(@Nullable CompressionConfig compressionConfig) {
if (this.channel == null) {
throw new IllegalStateException("You need to connect to set the compression!");
}
log.debug("Setting compression for session {}", this);
channel.attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).set(compressionConfig);
}
@Override
public void enableEncryption(PacketEncryption encryption) {
public void setEncryption(@Nullable EncryptionConfig encryptionConfig) {
if (channel == null) {
throw new IllegalStateException("Connect the client before initializing encryption!");
throw new IllegalStateException("You need to connect to set the encryption!");
}
channel.pipeline().addBefore("sizer", "encryption", new TcpPacketEncryptor(encryption));
}
@Override
public int getConnectTimeout() {
return this.connectTimeout;
}
@Override
public void setConnectTimeout(int timeout) {
this.connectTimeout = timeout;
}
@Override
public int getReadTimeout() {
return this.readTimeout;
}
@Override
public void setReadTimeout(int timeout) {
this.readTimeout = timeout;
this.refreshReadTimeoutHandler();
}
@Override
public int getWriteTimeout() {
return this.writeTimeout;
}
@Override
public void setWriteTimeout(int timeout) {
this.writeTimeout = timeout;
this.refreshWriteTimeoutHandler();
log.debug("Setting encryption for session {}", this);
channel.attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).set(encryptionConfig);
}
@Override
@ -258,11 +216,17 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
}
@Override
public void send(Packet packet) {
public void send(@NonNull Packet packet, @Nullable Runnable onSent) {
if (this.channel == null) {
return;
}
// Same behaviour as vanilla, always offload packet sending to the event loop
if (!this.channel.eventLoop().inEventLoop()) {
this.channel.eventLoop().execute(() -> this.send(packet, onSent));
return;
}
PacketSendingEvent sendingEvent = new PacketSendingEvent(this, packet);
this.callEvent(sendingEvent);
@ -270,6 +234,10 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
final Packet toSend = sendingEvent.getPacket();
this.channel.writeAndFlush(toSend).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
if (onSent != null) {
onSent.run();
}
callPacketSent(toSend);
} else {
exceptionCaught(null, future.cause());
@ -294,6 +262,13 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
}
}
@Override
public void setAutoRead(boolean autoRead) {
if (this.channel != null) {
this.channel.config().setAutoRead(autoRead);
}
}
private @Nullable EventLoop createEventLoop() {
if (!USE_EVENT_LOOP_FOR_PACKETS) {
return null;
@ -304,55 +279,16 @@ public abstract class TcpSession extends SimpleChannelInboundHandler<Packet> imp
// daemon threads and their interaction with the runtime.
PACKET_EVENT_LOOP = new DefaultEventLoopGroup(new DefaultThreadFactory(this.getClass(), true));
Runtime.getRuntime().addShutdownHook(new Thread(
() -> PACKET_EVENT_LOOP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
() -> PACKET_EVENT_LOOP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
}
return PACKET_EVENT_LOOP.next();
}
@Override
public Channel getChannel() {
return this.channel;
}
protected void refreshReadTimeoutHandler() {
this.refreshReadTimeoutHandler(this.channel);
}
protected void refreshReadTimeoutHandler(Channel channel) {
if (channel != null) {
if (this.readTimeout <= 0) {
if (channel.pipeline().get("readTimeout") != null) {
channel.pipeline().remove("readTimeout");
}
} else {
if (channel.pipeline().get("readTimeout") == null) {
channel.pipeline().addFirst("readTimeout", new ReadTimeoutHandler(this.readTimeout));
} else {
channel.pipeline().replace("readTimeout", "readTimeout", new ReadTimeoutHandler(this.readTimeout));
}
}
}
}
protected void refreshWriteTimeoutHandler() {
this.refreshWriteTimeoutHandler(this.channel);
}
protected void refreshWriteTimeoutHandler(Channel channel) {
if (channel != null) {
if (this.writeTimeout <= 0) {
if (channel.pipeline().get("writeTimeout") != null) {
channel.pipeline().remove("writeTimeout");
}
} else {
if (channel.pipeline().get("writeTimeout") == null) {
channel.pipeline().addFirst("writeTimeout", new WriteTimeoutHandler(this.writeTimeout));
} else {
channel.pipeline().replace("writeTimeout", "writeTimeout", new WriteTimeoutHandler(this.writeTimeout));
}
}
}
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (this.disconnected || this.channel != null) {

View file

@ -7,6 +7,8 @@ import net.kyori.adventure.text.Component;
import org.geysermc.mcprotocollib.auth.GameProfile;
import org.geysermc.mcprotocollib.auth.SessionService;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.compression.ZlibCompression;
import org.geysermc.mcprotocollib.network.event.session.ConnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionAdapter;
import org.geysermc.mcprotocollib.network.packet.Packet;
@ -45,6 +47,7 @@ import javax.crypto.SecretKey;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.Objects;
/**
* Handles making initial login and status requests for clients.
@ -58,12 +61,12 @@ public class ClientListener extends SessionAdapter {
@Override
public void packetReceived(Session session, Packet packet) {
MinecraftProtocol protocol = (MinecraftProtocol) session.getPacketProtocol();
if (protocol.getState() == ProtocolState.LOGIN) {
if (protocol.getInboundState() == ProtocolState.LOGIN) {
if (packet instanceof ClientboundHelloPacket helloPacket) {
GameProfile profile = session.getFlag(MinecraftConstants.PROFILE_KEY);
String accessToken = session.getFlag(MinecraftConstants.ACCESS_TOKEN_KEY);
if (profile == null || accessToken == null) {
if ((profile == null || accessToken == null) && helloPacket.isShouldAuthenticate()) {
throw new UnexpectedEncryptionException();
}
@ -81,22 +84,29 @@ public class ClientListener extends SessionAdapter {
// TODO: Add generic error, disabled multiplayer and banned from playing online errors
try {
sessionService.joinServer(profile, accessToken, serverId);
if (helloPacket.isShouldAuthenticate()) {
sessionService.joinServer(Objects.requireNonNull(profile, "final shouldAuthenticate changed value?"), accessToken, serverId);
}
} catch (IOException e) {
session.disconnect(Component.translatable("disconnect.loginFailedInfo", Component.text(e.getMessage())), e);
return;
}
session.send(new ServerboundKeyPacket(helloPacket.getPublicKey(), key, helloPacket.getChallenge()));
session.enableEncryption(protocol.enableEncryption(key));
session.send(new ServerboundKeyPacket(helloPacket.getPublicKey(), key, helloPacket.getChallenge()),
() -> session.setEncryption(protocol.createEncryption(key)));
} else if (packet instanceof ClientboundLoginFinishedPacket) {
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.CONFIGURATION));
session.send(new ServerboundLoginAcknowledgedPacket());
session.switchOutboundState(() -> protocol.setOutboundState(ProtocolState.CONFIGURATION));
} else if (packet instanceof ClientboundLoginDisconnectPacket loginDisconnectPacket) {
session.disconnect(loginDisconnectPacket.getReason());
} else if (packet instanceof ClientboundLoginCompressionPacket loginCompressionPacket) {
session.setCompressionThreshold(loginCompressionPacket.getThreshold(), false);
int threshold = loginCompressionPacket.getThreshold();
if (threshold >= 0) {
session.setCompression(new CompressionConfig(threshold, new ZlibCompression(), false));
}
}
} else if (protocol.getState() == ProtocolState.STATUS) {
} else if (protocol.getInboundState() == ProtocolState.STATUS) {
if (packet instanceof ClientboundStatusResponsePacket statusResponsePacket) {
ServerStatusInfo info = statusResponsePacket.parseInfo();
ServerInfoHandler handler = session.getFlag(MinecraftConstants.SERVER_INFO_HANDLER_KEY);
@ -114,13 +124,15 @@ public class ClientListener extends SessionAdapter {
session.disconnect(Component.translatable("multiplayer.status.finished"));
}
} else if (protocol.getState() == ProtocolState.GAME) {
} else if (protocol.getInboundState() == ProtocolState.GAME) {
if (packet instanceof ClientboundKeepAlivePacket keepAlivePacket && session.getFlag(MinecraftConstants.AUTOMATIC_KEEP_ALIVE_MANAGEMENT, true)) {
session.send(new ServerboundKeepAlivePacket(keepAlivePacket.getPingId()));
} else if (packet instanceof ClientboundDisconnectPacket disconnectPacket) {
session.disconnect(disconnectPacket.getReason());
} else if (packet instanceof ClientboundStartConfigurationPacket) {
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.CONFIGURATION));
session.send(new ServerboundConfigurationAcknowledgedPacket());
session.switchOutboundState(() -> protocol.setOutboundState(ProtocolState.CONFIGURATION));
} else if (packet instanceof ClientboundTransferPacket transferPacket) {
if (session.getFlag(MinecraftConstants.FOLLOW_TRANSFERS, true)) {
TcpClientSession newSession = new TcpClientSession(transferPacket.getHost(), transferPacket.getPort(), session.getPacketProtocol());
@ -129,9 +141,13 @@ public class ClientListener extends SessionAdapter {
newSession.connect(true, true);
}
}
} else if (protocol.getState() == ProtocolState.CONFIGURATION) {
if (packet instanceof ClientboundFinishConfigurationPacket) {
} else if (protocol.getInboundState() == ProtocolState.CONFIGURATION) {
if (packet instanceof ClientboundKeepAlivePacket keepAlivePacket && session.getFlag(MinecraftConstants.AUTOMATIC_KEEP_ALIVE_MANAGEMENT, true)) {
session.send(new ServerboundKeepAlivePacket(keepAlivePacket.getPingId()));
} else if (packet instanceof ClientboundFinishConfigurationPacket) {
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.GAME));
session.send(new ServerboundFinishConfigurationPacket());
session.switchOutboundState(() -> protocol.setOutboundState(ProtocolState.GAME));
} else if (packet instanceof ClientboundSelectKnownPacks) {
if (session.getFlag(MinecraftConstants.SEND_BLANK_KNOWN_PACKS_RESPONSE, true)) {
session.send(new ServerboundSelectKnownPacks(Collections.emptyList()));
@ -148,38 +164,25 @@ public class ClientListener extends SessionAdapter {
}
@Override
public void packetSent(Session session, Packet packet) {
public void connected(ConnectedEvent event) {
Session session = event.getSession();
MinecraftProtocol protocol = (MinecraftProtocol) session.getPacketProtocol();
if (packet instanceof ClientIntentionPacket) {
// Once the HandshakePacket has been sent, switch to the next protocol mode.
protocol.setState(this.targetState);
ClientIntentionPacket intention = new ClientIntentionPacket(protocol.getCodec().getProtocolVersion(), session.getHost(), session.getPort(), switch (targetState) {
case LOGIN -> transferring ? HandshakeIntent.TRANSFER : HandshakeIntent.LOGIN;
case STATUS -> HandshakeIntent.STATUS;
default -> throw new IllegalStateException("Unexpected value: " + targetState);
});
if (this.targetState == ProtocolState.LOGIN) {
session.switchInboundState(() -> protocol.setInboundState(this.targetState));
session.send(intention);
session.switchOutboundState(() -> protocol.setOutboundState(this.targetState));
switch (this.targetState) {
case LOGIN -> {
GameProfile profile = session.getFlag(MinecraftConstants.PROFILE_KEY);
session.send(new ServerboundHelloPacket(profile.getName(), profile.getId()));
} else {
session.send(new ServerboundStatusRequestPacket());
}
} else if (packet instanceof ServerboundLoginAcknowledgedPacket) {
protocol.setState(ProtocolState.CONFIGURATION); // LOGIN -> CONFIGURATION
} else if (packet instanceof ServerboundFinishConfigurationPacket) {
protocol.setState(ProtocolState.GAME); // CONFIGURATION -> GAME
} else if (packet instanceof ServerboundConfigurationAcknowledgedPacket) {
protocol.setState(ProtocolState.CONFIGURATION); // GAME -> CONFIGURATION
}
}
@Override
public void connected(ConnectedEvent event) {
MinecraftProtocol protocol = (MinecraftProtocol) event.getSession().getPacketProtocol();
if (this.targetState == ProtocolState.LOGIN) {
if (this.transferring) {
event.getSession().send(new ClientIntentionPacket(protocol.getCodec().getProtocolVersion(), event.getSession().getHost(), event.getSession().getPort(), HandshakeIntent.TRANSFER));
} else {
event.getSession().send(new ClientIntentionPacket(protocol.getCodec().getProtocolVersion(), event.getSession().getHost(), event.getSession().getPort(), HandshakeIntent.LOGIN));
}
} else if (this.targetState == ProtocolState.STATUS) {
event.getSession().send(new ClientIntentionPacket(protocol.getCodec().getProtocolVersion(), event.getSession().getHost(), event.getSession().getPort(), HandshakeIntent.STATUS));
case STATUS -> session.send(new ServerboundStatusRequestPacket());
default -> throw new IllegalStateException("Unexpected value: " + targetState);
}
}
}

View file

@ -63,9 +63,14 @@ public final class MinecraftConstants {
// Server Key Constants
/**
* Session flag for determining whether to verify users. Server only.
* Session flag for determining whether to encrypt the connection. Server only.
*/
public static final Flag<Boolean> VERIFY_USERS_KEY = new Flag<>("verify-users", Boolean.class);
public static final Flag<Boolean> ENCRYPT_CONNECTION = new Flag<>("encrypt-connection", Boolean.class);
/**
* Session flag for determining whether to authenticate users with the session service. Server only.
*/
public static final Flag<Boolean> SHOULD_AUTHENTICATE = new Flag<>("should-authenticate", Boolean.class);
/**
* Session flag for determining whether to accept transferred connections. Server only.

View file

@ -10,7 +10,7 @@ import org.geysermc.mcprotocollib.auth.GameProfile;
import org.geysermc.mcprotocollib.network.Server;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.crypt.AESEncryption;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.geysermc.mcprotocollib.network.packet.PacketRegistry;
@ -18,6 +18,8 @@ import org.geysermc.mcprotocollib.protocol.codec.MinecraftCodec;
import org.geysermc.mcprotocollib.protocol.codec.MinecraftCodecHelper;
import org.geysermc.mcprotocollib.protocol.codec.PacketCodec;
import org.geysermc.mcprotocollib.protocol.data.ProtocolState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.security.GeneralSecurityException;
@ -29,6 +31,7 @@ import java.util.UUID;
* Implements the Minecraft protocol.
*/
public class MinecraftProtocol extends PacketProtocol {
private static final Logger log = LoggerFactory.getLogger(MinecraftProtocol.class);
/**
* The network codec sent from the server to the client during {@link ProtocolState#CONFIGURATION}.
@ -44,8 +47,11 @@ public class MinecraftProtocol extends PacketProtocol {
@Getter
private final PacketCodec codec;
private ProtocolState state;
private PacketRegistry stateRegistry;
private ProtocolState inboundState;
private PacketRegistry inboundStateRegistry;
private ProtocolState outboundState;
private PacketRegistry outboundStateRegistry;
private final ProtocolState targetState;
@ -84,7 +90,7 @@ public class MinecraftProtocol extends PacketProtocol {
this.codec = codec;
this.targetState = ProtocolState.STATUS;
this.setState(ProtocolState.HANDSHAKE);
resetStates();
}
/**
@ -129,7 +135,7 @@ public class MinecraftProtocol extends PacketProtocol {
this.profile = profile;
this.accessToken = accessToken;
this.setState(ProtocolState.HANDSHAKE);
resetStates();
}
@Override
@ -152,7 +158,7 @@ public class MinecraftProtocol extends PacketProtocol {
session.setFlag(MinecraftConstants.PROFILE_KEY, this.profile);
session.setFlag(MinecraftConstants.ACCESS_TOKEN_KEY, this.accessToken);
this.setState(ProtocolState.HANDSHAKE);
resetStates();
if (this.useDefaultListeners) {
session.addListener(new ClientListener(this.targetState, transferring));
@ -161,7 +167,7 @@ public class MinecraftProtocol extends PacketProtocol {
@Override
public void newServerSession(Server server, Session session) {
this.setState(ProtocolState.HANDSHAKE);
resetStates();
if (this.useDefaultListeners) {
if (DEFAULT_NETWORK_CODEC == null) {
@ -173,30 +179,61 @@ public class MinecraftProtocol extends PacketProtocol {
}
@Override
public PacketRegistry getPacketRegistry() {
return this.stateRegistry;
public PacketRegistry getInboundPacketRegistry() {
return this.inboundStateRegistry;
}
protected PacketEncryption enableEncryption(Key key) {
@Override
public PacketRegistry getOutboundPacketRegistry() {
return this.outboundStateRegistry;
}
protected EncryptionConfig createEncryption(Key key) {
try {
return new AESEncryption(key);
return new EncryptionConfig(new AESEncryption(key));
} catch (GeneralSecurityException e) {
throw new Error("Failed to enable protocol encryption.", e);
throw new IllegalStateException("Failed to create protocol encryption.", e);
}
}
/**
* Gets the current {@link ProtocolState} the client is in.
*
* @return The current {@link ProtocolState}.
* Resets the protocol states to {@link ProtocolState#HANDSHAKE}.
*/
public ProtocolState getState() {
return this.state;
public void resetStates() {
this.setInboundState(ProtocolState.HANDSHAKE);
this.setOutboundState(ProtocolState.HANDSHAKE);
}
public void setState(ProtocolState state) {
this.state = state;
this.stateRegistry = this.codec.getCodec(state);
/**
* Gets the current inbound {@link ProtocolState} we're in.
*
* @return The current inbound {@link ProtocolState}.
*/
public ProtocolState getInboundState() {
return this.inboundState;
}
/**
* Gets the current outbound {@link ProtocolState} we're in.
*
* @return The current outbound {@link ProtocolState}.
*/
public ProtocolState getOutboundState() {
return this.outboundState;
}
public void setInboundState(ProtocolState state) {
log.debug("Setting inbound protocol state to: {}", state);
this.inboundState = state;
this.inboundStateRegistry = this.codec.getCodec(state);
}
public void setOutboundState(ProtocolState state) {
log.debug("Setting outbound protocol state to: {}", state);
this.outboundState = state;
this.outboundStateRegistry = this.codec.getCodec(state);
}
public static NbtMap loadNetworkCodec() {

View file

@ -1,6 +1,6 @@
package org.geysermc.mcprotocollib.protocol;
import lombok.RequiredArgsConstructor;
import lombok.Getter;
import net.kyori.adventure.key.Key;
import net.kyori.adventure.text.Component;
import org.cloudburstmc.nbt.NbtMap;
@ -8,6 +8,8 @@ import org.cloudburstmc.nbt.NbtType;
import org.geysermc.mcprotocollib.auth.GameProfile;
import org.geysermc.mcprotocollib.auth.SessionService;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.compression.ZlibCompression;
import org.geysermc.mcprotocollib.network.event.session.ConnectedEvent;
import org.geysermc.mcprotocollib.network.event.session.DisconnectingEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionAdapter;
@ -75,9 +77,10 @@ public class ServerListener extends SessionAdapter {
private final byte[] challenge = new byte[4];
private String username = "";
private KeepAliveState keepAliveState;
private long lastPingTime = 0;
private int lastPingId = 0;
@Getter
private boolean isTransfer = false;
public ServerListener(NbtMap networkCodec) {
this.networkCodec = networkCodec;
@ -86,40 +89,33 @@ public class ServerListener extends SessionAdapter {
@Override
public void connected(ConnectedEvent event) {
event.getSession().setFlag(MinecraftConstants.PING_KEY, 0L);
Session session = event.getSession();
session.setFlag(MinecraftConstants.PING_KEY, 0L);
}
@Override
public void packetReceived(Session session, Packet packet) {
MinecraftProtocol protocol = (MinecraftProtocol) session.getPacketProtocol();
if (protocol.getState() == ProtocolState.HANDSHAKE) {
if (protocol.getInboundState() == ProtocolState.HANDSHAKE) {
if (packet instanceof ClientIntentionPacket intentionPacket) {
switch (intentionPacket.getIntent()) {
case STATUS -> protocol.setState(ProtocolState.STATUS);
case TRANSFER -> {
if (!session.getFlag(MinecraftConstants.ACCEPT_TRANSFERS_KEY, false)) {
session.disconnect(Component.translatable("multiplayer.disconnect.transfers_disabled"));
}
}
case LOGIN -> {
protocol.setState(ProtocolState.LOGIN);
if (intentionPacket.getProtocolVersion() > protocol.getCodec().getProtocolVersion()) {
session.disconnect(Component.translatable("multiplayer.disconnect.incompatible", Component.text(protocol.getCodec().getMinecraftVersion())));
} else if (intentionPacket.getProtocolVersion() < protocol.getCodec().getProtocolVersion()) {
session.disconnect(Component.translatable("multiplayer.disconnect.outdated_client", Component.text(protocol.getCodec().getMinecraftVersion())));
}
case STATUS -> {
protocol.setOutboundState(ProtocolState.STATUS);
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.STATUS));
}
case TRANSFER -> beginLogin(session, protocol, intentionPacket, true);
case LOGIN -> beginLogin(session, protocol, intentionPacket, false);
default -> throw new UnsupportedOperationException("Invalid client intent: " + intentionPacket.getIntent());
}
}
} else if (protocol.getState() == ProtocolState.LOGIN) {
} else if (protocol.getInboundState() == ProtocolState.LOGIN) {
if (packet instanceof ServerboundHelloPacket helloPacket) {
this.username = helloPacket.getUsername();
if (session.getFlag(MinecraftConstants.VERIFY_USERS_KEY, true)) {
session.send(new ClientboundHelloPacket(SERVER_ID, KEY_PAIR.getPublic(), this.challenge, true));
if (session.getFlag(MinecraftConstants.ENCRYPT_CONNECTION, true)) {
session.send(new ClientboundHelloPacket(SERVER_ID, KEY_PAIR.getPublic(), this.challenge, session.getFlag(MinecraftConstants.SHOULD_AUTHENTICATE, true)));
} else {
new Thread(new UserAuthTask(session, null)).start();
new Thread(() -> authenticate(session, false, null)).start();
}
} else if (packet instanceof ServerboundKeyPacket keyPacket) {
PrivateKey privateKey = KEY_PAIR.getPrivate();
@ -129,10 +125,16 @@ public class ServerListener extends SessionAdapter {
}
SecretKey key = keyPacket.getSecretKey(privateKey);
session.enableEncryption(protocol.enableEncryption(key));
new Thread(new UserAuthTask(session, key)).start();
session.setEncryption(protocol.createEncryption(key));
new Thread(() -> authenticate(session, session.getFlag(MinecraftConstants.SHOULD_AUTHENTICATE, true), key)).start();
} else if (packet instanceof ServerboundLoginAcknowledgedPacket) {
protocol.setState(ProtocolState.CONFIGURATION);
protocol.setOutboundState(ProtocolState.CONFIGURATION);
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.CONFIGURATION));
keepAliveState = new KeepAliveState();
if (session.getFlag(MinecraftConstants.AUTOMATIC_KEEP_ALIVE_MANAGEMENT, true)) {
// If keepalive state is null, lets assume there is no keepalive thread yet
new Thread(() -> keepAlive(session)).start();
}
// Credit ViaVersion: https://github.com/ViaVersion/ViaVersion/blob/dev/common/src/main/java/com/viaversion/viaversion/protocols/protocol1_20_5to1_20_3/rewriter/EntityPacketRewriter1_20_5.java
for (Map.Entry<String, Object> entry : networkCodec.entrySet()) {
@ -154,7 +156,7 @@ public class ServerListener extends SessionAdapter {
session.send(new ClientboundFinishConfigurationPacket());
}
} else if (protocol.getState() == ProtocolState.STATUS) {
} else if (protocol.getInboundState() == ProtocolState.STATUS) {
if (packet instanceof ServerboundStatusRequestPacket) {
ServerInfoBuilder builder = session.getFlag(MinecraftConstants.SERVER_INFO_BUILDER_KEY);
if (builder == null) {
@ -172,100 +174,132 @@ public class ServerListener extends SessionAdapter {
} else if (packet instanceof ServerboundPingRequestPacket pingRequestPacket) {
session.send(new ClientboundPongResponsePacket(pingRequestPacket.getPingTime()));
}
} else if (protocol.getState() == ProtocolState.GAME) {
} else if (protocol.getInboundState() == ProtocolState.GAME) {
if (packet instanceof ServerboundKeepAlivePacket keepAlivePacket) {
if (keepAlivePacket.getPingId() == this.lastPingId) {
long time = System.currentTimeMillis() - this.lastPingTime;
session.setFlag(MinecraftConstants.PING_KEY, time);
}
handleKeepAlive(session, keepAlivePacket);
} else if (packet instanceof ServerboundConfigurationAcknowledgedPacket) {
protocol.setState(ProtocolState.CONFIGURATION);
// The developer who sends ClientboundStartConfigurationPacket needs to setOutboundState to CONFIGURATION
// after sending the packet. We can't do it in this class because it needs to be a method call right after it was sent.
// Using nettys event loop to change outgoing state may cause differences to vanilla.
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.CONFIGURATION));
keepAliveState = new KeepAliveState();
} else if (packet instanceof ServerboundPingRequestPacket pingRequestPacket) {
session.send(new ClientboundPongResponsePacket(pingRequestPacket.getPingTime()));
session.disconnect(Component.translatable("multiplayer.status.request_handled"));
}
} else if (protocol.getState() == ProtocolState.CONFIGURATION) {
if (packet instanceof ServerboundFinishConfigurationPacket) {
protocol.setState(ProtocolState.GAME);
} else if (protocol.getInboundState() == ProtocolState.CONFIGURATION) {
if (packet instanceof ServerboundKeepAlivePacket keepAlivePacket) {
handleKeepAlive(session, keepAlivePacket);
} else if (packet instanceof ServerboundFinishConfigurationPacket) {
protocol.setOutboundState(ProtocolState.GAME);
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.GAME));
keepAliveState = new KeepAliveState();
ServerLoginHandler handler = session.getFlag(MinecraftConstants.SERVER_LOGIN_HANDLER_KEY);
if (handler != null) {
handler.loggedIn(session);
}
if (session.getFlag(MinecraftConstants.AUTOMATIC_KEEP_ALIVE_MANAGEMENT, true)) {
new Thread(new KeepAliveTask(session)).start();
}
}
}
}
@Override
public void packetSent(Session session, Packet packet) {
if (packet instanceof ClientboundLoginCompressionPacket loginCompressionPacket) {
session.setCompressionThreshold(loginCompressionPacket.getThreshold(), true);
session.send(new ClientboundLoginFinishedPacket(session.getFlag(MinecraftConstants.PROFILE_KEY)));
private void handleKeepAlive(Session session, ServerboundKeepAlivePacket keepAlivePacket) {
KeepAliveState currentKeepAliveState = this.keepAliveState;
if (currentKeepAliveState != null) {
if (currentKeepAliveState.keepAlivePending && keepAlivePacket.getPingId() == currentKeepAliveState.keepAliveChallenge) {
currentKeepAliveState.keepAlivePending = false;
session.setFlag(MinecraftConstants.PING_KEY, System.currentTimeMillis() - currentKeepAliveState.keepAliveTime);
} else {
session.disconnect(Component.translatable("disconnect.timeout"));
}
}
}
private void beginLogin(Session session, MinecraftProtocol protocol, ClientIntentionPacket packet, boolean transferred) {
isTransfer = transferred;
protocol.setOutboundState(ProtocolState.LOGIN);
if (transferred && !session.getFlag(MinecraftConstants.ACCEPT_TRANSFERS_KEY)) {
session.disconnect(Component.translatable("multiplayer.disconnect.transfers_disabled"));
} else if (packet.getProtocolVersion() > protocol.getCodec().getProtocolVersion()) {
session.disconnect(Component.translatable("multiplayer.disconnect.incompatible", Component.text(protocol.getCodec().getMinecraftVersion())));
} else if (packet.getProtocolVersion() < protocol.getCodec().getProtocolVersion()) {
session.disconnect(Component.translatable("multiplayer.disconnect.outdated_client", Component.text(protocol.getCodec().getMinecraftVersion())));
} else {
session.switchInboundState(() -> protocol.setInboundState(ProtocolState.LOGIN));
}
}
@Override
public void disconnecting(DisconnectingEvent event) {
MinecraftProtocol protocol = (MinecraftProtocol) event.getSession().getPacketProtocol();
if (protocol.getState() == ProtocolState.LOGIN) {
event.getSession().send(new ClientboundLoginDisconnectPacket(event.getReason()));
} else if (protocol.getState() == ProtocolState.GAME) {
event.getSession().send(new ClientboundDisconnectPacket(event.getReason()));
Session session = event.getSession();
MinecraftProtocol protocol = (MinecraftProtocol) session.getPacketProtocol();
if (protocol.getOutboundState() == ProtocolState.LOGIN) {
session.send(new ClientboundLoginDisconnectPacket(event.getReason()));
} else if (protocol.getOutboundState() == ProtocolState.GAME) {
session.send(new ClientboundDisconnectPacket(event.getReason()));
}
}
@RequiredArgsConstructor
private class UserAuthTask implements Runnable {
private final Session session;
private final SecretKey key;
@Override
public void run() {
GameProfile profile;
if (this.key != null) {
SessionService sessionService = this.session.getFlag(MinecraftConstants.SESSION_SERVICE_KEY, new SessionService());
try {
profile = sessionService.getProfileByServer(username, SessionService.getServerId(SERVER_ID, KEY_PAIR.getPublic(), this.key));
} catch (IOException e) {
session.disconnect(Component.translatable("multiplayer.disconnect.authservers_down"), e);
return;
}
if (profile == null) {
session.disconnect(Component.translatable("multiplayer.disconnect.unverified_username"));
return;
}
} else {
profile = new GameProfile(UUID.nameUUIDFromBytes(("OfflinePlayer:" + username).getBytes()), username);
private void authenticate(Session session, boolean shouldAuthenticate, SecretKey key) {
GameProfile profile;
if (shouldAuthenticate && key != null) {
SessionService sessionService = session.getFlag(MinecraftConstants.SESSION_SERVICE_KEY, new SessionService());
try {
profile = sessionService.getProfileByServer(username, SessionService.getServerId(SERVER_ID, KEY_PAIR.getPublic(), key));
} catch (IOException e) {
session.disconnect(Component.translatable("multiplayer.disconnect.authservers_down"), e);
return;
}
this.session.setFlag(MinecraftConstants.PROFILE_KEY, profile);
if (profile == null) {
session.disconnect(Component.translatable("multiplayer.disconnect.unverified_username"));
return;
}
} else {
profile = new GameProfile(UUID.nameUUIDFromBytes(("OfflinePlayer:" + username).getBytes()), username);
}
int threshold = session.getFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, DEFAULT_COMPRESSION_THRESHOLD);
this.session.send(new ClientboundLoginCompressionPacket(threshold));
session.setFlag(MinecraftConstants.PROFILE_KEY, profile);
int threshold = session.getFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, DEFAULT_COMPRESSION_THRESHOLD);
if (threshold >= 0) {
session.send(new ClientboundLoginCompressionPacket(threshold), () ->
session.setCompression(new CompressionConfig(threshold, new ZlibCompression(), true)));
}
session.send(new ClientboundLoginFinishedPacket(profile));
}
private void keepAlive(Session session) {
while (session.isConnected()) {
KeepAliveState currentKeepAliveState = this.keepAliveState;
if (currentKeepAliveState != null) {
if (System.currentTimeMillis() - currentKeepAliveState.keepAliveTime >= 15000L) {
if (currentKeepAliveState.keepAlivePending) {
session.disconnect(Component.translatable("disconnect.timeout"));
break;
}
long time = System.currentTimeMillis();
currentKeepAliveState.keepAlivePending = true;
currentKeepAliveState.keepAliveChallenge = time;
currentKeepAliveState.keepAliveTime = time;
session.send(new ClientboundKeepAlivePacket(currentKeepAliveState.keepAliveChallenge));
}
}
// TODO: Implement proper tick loop rather than sleeping
try {
Thread.sleep(50);
} catch (InterruptedException e) {
break;
}
}
}
@RequiredArgsConstructor
private class KeepAliveTask implements Runnable {
private final Session session;
@Override
public void run() {
while (this.session.isConnected()) {
lastPingTime = System.currentTimeMillis();
lastPingId = (int) lastPingTime;
this.session.send(new ClientboundKeepAlivePacket(lastPingId));
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
break;
}
}
}
private static class KeepAliveState {
private boolean keepAlivePending;
private long keepAliveChallenge;
private long keepAliveTime = System.currentTimeMillis();
}
}

View file

@ -15,4 +15,9 @@ public class ClientboundFinishConfigurationPacket implements MinecraftPacket {
@Override
public void serialize(ByteBuf out, MinecraftCodecHelper helper) {
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -15,4 +15,9 @@ public class ServerboundFinishConfigurationPacket implements MinecraftPacket {
public void serialize(ByteBuf buf, MinecraftCodecHelper helper) {
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -37,4 +37,9 @@ public class ClientIntentionPacket implements MinecraftPacket {
public boolean isPriority() {
return true;
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -15,4 +15,9 @@ public class ClientboundStartConfigurationPacket implements MinecraftPacket {
public void serialize(ByteBuf out, MinecraftCodecHelper helper) {
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -15,4 +15,9 @@ public class ServerboundConfigurationAcknowledgedPacket implements MinecraftPack
public void serialize(ByteBuf out, MinecraftCodecHelper helper) {
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -32,4 +32,9 @@ public class ClientboundLoginFinishedPacket implements MinecraftPacket {
public boolean isPriority() {
return true;
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -15,4 +15,9 @@ public class ServerboundLoginAcknowledgedPacket implements MinecraftPacket {
@Override
public void serialize(ByteBuf out, MinecraftCodecHelper helper) {
}
@Override
public boolean isTerminal() {
return true;
}
}

View file

@ -27,7 +27,6 @@ import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.geysermc.mcprotocollib.protocol.MinecraftConstants.*;
import static org.junit.jupiter.api.Assertions.*;
public class MinecraftProtocolTest {
@ -49,10 +48,11 @@ public class MinecraftProtocolTest {
@BeforeAll
public static void setupServer() {
server = new TcpServer(HOST, PORT, MinecraftProtocol::new);
server.setGlobalFlag(VERIFY_USERS_KEY, false);
server.setGlobalFlag(SERVER_COMPRESSION_THRESHOLD, 100);
server.setGlobalFlag(SERVER_INFO_BUILDER_KEY, session -> SERVER_INFO);
server.setGlobalFlag(SERVER_LOGIN_HANDLER_KEY, session -> {
server.setGlobalFlag(MinecraftConstants.ENCRYPT_CONNECTION, true);
server.setGlobalFlag(MinecraftConstants.SHOULD_AUTHENTICATE, false);
server.setGlobalFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, 256);
server.setGlobalFlag(MinecraftConstants.SERVER_INFO_BUILDER_KEY, session -> SERVER_INFO);
server.setGlobalFlag(MinecraftConstants.SERVER_LOGIN_HANDLER_KEY, session -> {
// Seems like in this setup the server can reply too quickly to ServerboundFinishConfigurationPacket
// before the client can transition CONFIGURATION -> GAME. There is probably something wrong here and this is just a band-aid.
try {
@ -79,7 +79,7 @@ public class MinecraftProtocolTest {
Session session = new TcpClientSession(HOST, PORT, new MinecraftProtocol());
try {
ServerInfoHandlerTest handler = new ServerInfoHandlerTest();
session.setFlag(SERVER_INFO_HANDLER_KEY, handler);
session.setFlag(MinecraftConstants.SERVER_INFO_HANDLER_KEY, handler);
session.addListener(new DisconnectListener());
session.connect();

View file

@ -0,0 +1 @@
org.slf4j.simpleLogger.defaultLogLevel=debug