TFO support (#793)

This commit is contained in:
Alex 2024-05-01 00:40:13 +02:00 committed by GitHub
parent bc8526b267
commit 114ebbdcf2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 138 additions and 120 deletions

View file

@ -20,6 +20,11 @@ public class BuiltinFlags {
*/
public static final Flag<Boolean> ATTEMPT_SRV_RESOLVE = new Flag<>("attempt-srv-resolve", Boolean.class);
/**
* When set to true, the client or server will attempt to use TCP Fast Open if supported.
*/
public static final Flag<Boolean> TCP_FAST_OPEN = new Flag<>("tcp-fast-open", Boolean.class);
private BuiltinFlags() {
}
}

View file

@ -1,19 +1,93 @@
package org.geysermc.mcprotocollib.network.helper;
import io.netty.channel.ChannelFactory;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueServerSocketChannel;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.incubator.channel.uring.IOUring;
import io.netty.incubator.channel.uring.IOUringDatagramChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringServerSocketChannel;
import io.netty.incubator.channel.uring.IOUringSocketChannel;
import java.util.concurrent.ThreadFactory;
import java.util.function.Function;
public class TransportHelper {
public enum TransportMethod {
NIO, EPOLL, KQUEUE, IO_URING
}
public static TransportMethod determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) return TransportMethod.IO_URING;
if (isClassAvailable("io.netty.channel.epoll.Epoll") && Epoll.isAvailable()) return TransportMethod.EPOLL;
if (isClassAvailable("io.netty.channel.kqueue.KQueue") && KQueue.isAvailable()) return TransportMethod.KQUEUE;
return TransportMethod.NIO;
public record TransportType(TransportMethod method,
ChannelFactory<? extends ServerSocketChannel> serverSocketChannelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
ChannelFactory<? extends DatagramChannel> datagramChannelFactory,
Function<ThreadFactory, EventLoopGroup> eventLoopGroupFactory,
boolean supportsTcpFastOpenServer,
boolean supportsTcpFastOpenClient) {
}
public static TransportType determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) {
return new TransportType(
TransportMethod.IO_URING,
IOUringServerSocketChannel::new,
IOUringSocketChannel::new,
IOUringDatagramChannel::new,
factory -> new IOUringEventLoopGroup(0, factory),
IOUring.isTcpFastOpenServerSideAvailable(),
IOUring.isTcpFastOpenClientSideAvailable()
);
}
if (isClassAvailable("io.netty.channel.epoll.Epoll") && Epoll.isAvailable()) {
return new TransportType(
TransportMethod.EPOLL,
EpollServerSocketChannel::new,
EpollSocketChannel::new,
EpollDatagramChannel::new,
factory -> new EpollEventLoopGroup(0, factory),
Epoll.isTcpFastOpenServerSideAvailable(),
Epoll.isTcpFastOpenClientSideAvailable()
);
}
if (isClassAvailable("io.netty.channel.kqueue.KQueue") && KQueue.isAvailable()) {
return new TransportType(
TransportMethod.KQUEUE,
KQueueServerSocketChannel::new,
KQueueSocketChannel::new,
KQueueDatagramChannel::new,
factory -> new KQueueEventLoopGroup(0, factory),
KQueue.isTcpFastOpenServerSideAvailable(),
KQueue.isTcpFastOpenClientSideAvailable()
);
}
return new TransportType(
TransportMethod.NIO,
NioServerSocketChannel::new,
NioSocketChannel::new,
NioDatagramChannel::new,
factory -> new NioEventLoopGroup(0, factory),
false,
false
);
}
/**

View file

@ -4,7 +4,6 @@ import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@ -12,16 +11,6 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
@ -36,9 +25,6 @@ import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.incubator.channel.uring.IOUringDatagramChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringSocketChannel;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
@ -56,9 +42,8 @@ import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
public class TcpClientSession extends TcpSession {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";
private static Class<? extends Channel> CHANNEL_CLASS;
private static Class<? extends DatagramChannel> DATAGRAM_CHANNEL_CLASS;
private static EventLoopGroup EVENT_LOOP_GROUP;
/**
@ -100,51 +85,47 @@ public class TcpClientSession extends TcpSession {
boolean debug = getFlag(BuiltinFlags.PRINT_DEBUG, false);
if (CHANNEL_CLASS == null) {
if (EVENT_LOOP_GROUP == null) {
createTcpEventLoopGroup();
}
try {
final Bootstrap bootstrap = new Bootstrap();
bootstrap.channel(CHANNEL_CLASS);
bootstrap.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);
channel.config().setOption(ChannelOption.IP_TOS, 0x18);
try {
channel.config().setOption(ChannelOption.TCP_NODELAY, true);
} catch (ChannelException e) {
if (debug) {
System.out.println("Exception while trying to set TCP_NODELAY");
e.printStackTrace();
ChannelPipeline pipeline = channel.pipeline();
refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);
addProxy(pipeline);
int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}
pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
addHAProxySupport(pipeline);
}
}
});
ChannelPipeline pipeline = channel.pipeline();
refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);
addProxy(pipeline);
int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}
pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
addHAProxySupport(pipeline);
}
}).group(EVENT_LOOP_GROUP).option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000);
InetSocketAddress remoteAddress = resolveAddress();
bootstrap.remoteAddress(remoteAddress);
bootstrap.localAddress(bindAddress, bindPort);
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
ChannelFuture future = bootstrap.connect();
if (wait) {
@ -177,7 +158,7 @@ public class TcpClientSession extends TcpSession {
if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = null;
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelType(DATAGRAM_CHANNEL_CLASS)
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
@ -294,32 +275,11 @@ public class TcpClientSession extends TcpSession {
}
private static void createTcpEventLoopGroup() {
if (CHANNEL_CLASS != null) {
if (EVENT_LOOP_GROUP != null) {
return;
}
switch (TransportHelper.determineTransportMethod()) {
case IO_URING -> {
EVENT_LOOP_GROUP = new IOUringEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = IOUringSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = IOUringDatagramChannel.class;
}
case EPOLL -> {
EVENT_LOOP_GROUP = new EpollEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = EpollSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = EpollDatagramChannel.class;
}
case KQUEUE -> {
EVENT_LOOP_GROUP = new KQueueEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = KQueueSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = KQueueDatagramChannel.class;
}
case NIO -> {
EVENT_LOOP_GROUP = new NioEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = NioSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = NioDatagramChannel.class;
}
}
EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());
Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));

View file

@ -2,22 +2,12 @@ package org.geysermc.mcprotocollib.network.tcp;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueServerSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringServerSocketChannel;
import io.netty.util.concurrent.Future;
import org.geysermc.mcprotocollib.network.AbstractServer;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
@ -28,8 +18,8 @@ import java.net.InetSocketAddress;
import java.util.function.Supplier;
public class TcpServer extends AbstractServer {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private EventLoopGroup group;
private Class<? extends ServerSocketChannel> serverSocketChannel;
private Channel channel;
public TcpServer(String host, int port, Supplier<? extends PacketProtocol> protocol) {
@ -47,26 +37,15 @@ public class TcpServer extends AbstractServer {
return;
}
switch (TransportHelper.determineTransportMethod()) {
case IO_URING -> {
this.group = new IOUringEventLoopGroup();
this.serverSocketChannel = IOUringServerSocketChannel.class;
}
case EPOLL -> {
this.group = new EpollEventLoopGroup();
this.serverSocketChannel = EpollServerSocketChannel.class;
}
case KQUEUE -> {
this.group = new KQueueEventLoopGroup();
this.serverSocketChannel = KQueueServerSocketChannel.class;
}
case NIO -> {
this.group = new NioEventLoopGroup();
this.serverSocketChannel = NioServerSocketChannel.class;
}
}
this.group = TRANSPORT_TYPE.eventLoopGroupFactory().apply(null);
ChannelFuture future = new ServerBootstrap().channel(this.serverSocketChannel).childHandler(new ChannelInitializer<>() {
ServerBootstrap bootstrap = new ServerBootstrap()
.channelFactory(TRANSPORT_TYPE.serverSocketChannelFactory())
.group(this.group)
.childOption(ChannelOption.TCP_NODELAY, true)
.childOption(ChannelOption.IP_TOS, 0x18)
.localAddress(this.getHost(), this.getPort())
.childHandler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
InetSocketAddress address = (InetSocketAddress) channel.remoteAddress();
@ -75,12 +54,6 @@ public class TcpServer extends AbstractServer {
TcpSession session = new TcpServerSession(address.getHostName(), address.getPort(), protocol, TcpServer.this);
session.getPacketProtocol().newServerSession(TcpServer.this, session);
channel.config().setOption(ChannelOption.IP_TOS, 0x18);
try {
channel.config().setOption(ChannelOption.TCP_NODELAY, true);
} catch (ChannelException ignored) {
}
ChannelPipeline pipeline = channel.pipeline();
session.refreshReadTimeoutHandler(channel);
@ -94,7 +67,13 @@ public class TcpServer extends AbstractServer {
pipeline.addLast("codec", new TcpPacketCodec(session, false));
pipeline.addLast("manager", session);
}
}).group(this.group).localAddress(this.getHost(), this.getPort()).bind();
});
if (getGlobalFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenServer()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN, 3);
}
ChannelFuture future = bootstrap.bind();
if (wait) {
try {