From 9037cb30eb0e960e891e245042cfbf0f4c55b298 Mon Sep 17 00:00:00 2001 From: Alex <40795980+AlexProgrammerDE@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:07:43 +0100 Subject: [PATCH] Move reusable methods to a separate helper class (#863) * Move reusable methods to a separate helper class This way we allow other apps such as Geyser LocalSession to use these currently private methods without needing to copy over the code. * Remove unused field --- .../network/helper/NettyHelper.java | 149 +++++++++++++++++ .../network/helper/TransportHelper.java | 4 +- .../network/tcp/TcpClientSession.java | 151 +----------------- .../mcprotocollib/network/tcp/TcpServer.java | 7 +- 4 files changed, 162 insertions(+), 149 deletions(-) create mode 100644 protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/NettyHelper.java diff --git a/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/NettyHelper.java b/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/NettyHelper.java new file mode 100644 index 00000000..fdbee4c7 --- /dev/null +++ b/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/NettyHelper.java @@ -0,0 +1,149 @@ +package org.geysermc.mcprotocollib.network.helper; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.dns.DefaultDnsQuestion; +import io.netty.handler.codec.dns.DefaultDnsRawRecord; +import io.netty.handler.codec.dns.DefaultDnsRecordDecoder; +import io.netty.handler.codec.dns.DnsRecordType; +import io.netty.handler.codec.dns.DnsResponse; +import io.netty.handler.codec.dns.DnsSection; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; +import io.netty.handler.proxy.HttpProxyHandler; +import io.netty.handler.proxy.Socks4ProxyHandler; +import io.netty.handler.proxy.Socks5ProxyHandler; +import io.netty.resolver.dns.DnsNameResolver; +import io.netty.resolver.dns.DnsNameResolverBuilder; +import org.geysermc.mcprotocollib.network.BuiltinFlags; +import org.geysermc.mcprotocollib.network.ProxyInfo; +import org.geysermc.mcprotocollib.network.Session; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +public class NettyHelper { + private static final Logger log = LoggerFactory.getLogger(NettyHelper.class); + private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b"; + + public static InetSocketAddress resolveAddress(Session session, EventLoop eventLoop, String host, int port) { + String name = session.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + host; + log.debug("Attempting SRV lookup for \"{}\".", name); + + if (session.getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!host.matches(IP_REGEX) && !host.equalsIgnoreCase("localhost"))) { + try (DnsNameResolver resolver = new DnsNameResolverBuilder(eventLoop) + .channelFactory(TransportHelper.TRANSPORT_TYPE.datagramChannelFactory()) + .build()) { + AddressedEnvelope envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get(); + try { + DnsResponse response = envelope.content(); + if (response.count(DnsSection.ANSWER) > 0) { + DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0); + if (record.type() == DnsRecordType.SRV) { + ByteBuf buf = record.content(); + buf.skipBytes(4); // Skip priority and weight. + + int tempPort = buf.readUnsignedShort(); + String tempHost = DefaultDnsRecordDecoder.decodeName(buf); + if (tempHost.endsWith(".")) { + tempHost = tempHost.substring(0, tempHost.length() - 1); + } + + log.debug("Found SRV record containing \"{}:{}\".", tempHost, tempPort); + + host = tempHost; + port = tempPort; + } else { + log.debug("Received non-SRV record in response."); + } + } else { + log.debug("No SRV record found."); + } + } finally { + envelope.release(); + } + } catch (Exception e) { + log.debug("Failed to resolve SRV record.", e); + } + } else { + log.debug("Not resolving SRV record for {}", host); + } + + // Resolve host here + try { + InetAddress resolved = InetAddress.getByName(host); + log.debug("Resolved {} -> {}", host, resolved.getHostAddress()); + return new InetSocketAddress(resolved, port); + } catch (UnknownHostException e) { + log.debug("Failed to resolve host, letting Netty do it instead.", e); + return InetSocketAddress.createUnresolved(host, port); + } + } + + public static void initializeHAProxySupport(Session session, Channel channel) { + InetSocketAddress clientAddress = session.getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS); + if (clientAddress == null) { + return; + } + + channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE); + channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress(); + HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6; + ctx.channel().writeAndFlush(new HAProxyMessage( + HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol, + clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(), + clientAddress.getPort(), remoteAddress.getPort() + )).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder")); + ctx.pipeline().remove(this); + + super.channelActive(ctx); + } + }); + } + + public static void addProxy(ProxyInfo proxy, ChannelPipeline pipeline) { + if (proxy == null) { + return; + } + + switch (proxy.type()) { + case HTTP -> { + if (proxy.username() != null && proxy.password() != null) { + pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password())); + } else { + pipeline.addLast("proxy", new HttpProxyHandler(proxy.address())); + } + } + case SOCKS4 -> { + if (proxy.username() != null) { + pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username())); + } else { + pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address())); + } + } + case SOCKS5 -> { + if (proxy.username() != null && proxy.password() != null) { + pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password())); + } else { + pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address())); + } + } + default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type()); + } + } +} diff --git a/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/TransportHelper.java b/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/TransportHelper.java index 7e526837..7c7e4a07 100644 --- a/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/TransportHelper.java +++ b/protocol/src/main/java/org/geysermc/mcprotocollib/network/helper/TransportHelper.java @@ -29,6 +29,8 @@ import java.util.concurrent.ThreadFactory; import java.util.function.Function; public class TransportHelper { + public static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod(); + public enum TransportMethod { NIO, EPOLL, KQUEUE, IO_URING } @@ -45,7 +47,7 @@ public class TransportHelper { boolean supportsTcpFastOpenClient) { } - public static TransportType determineTransportMethod() { + private static TransportType determineTransportMethod() { if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) { return new TransportType( TransportMethod.IO_URING, diff --git a/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpClientSession.java b/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpClientSession.java index 4c3ac62f..c5daef1d 100644 --- a/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpClientSession.java +++ b/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpClientSession.java @@ -1,55 +1,27 @@ package org.geysermc.mcprotocollib.network.tcp; import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.ByteBuf; -import io.netty.channel.AddressedEnvelope; import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; -import io.netty.handler.codec.dns.DefaultDnsQuestion; -import io.netty.handler.codec.dns.DefaultDnsRawRecord; -import io.netty.handler.codec.dns.DefaultDnsRecordDecoder; -import io.netty.handler.codec.dns.DnsRecordType; -import io.netty.handler.codec.dns.DnsResponse; -import io.netty.handler.codec.dns.DnsSection; -import io.netty.handler.codec.haproxy.HAProxyCommand; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; -import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; -import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; -import io.netty.handler.proxy.HttpProxyHandler; -import io.netty.handler.proxy.Socks4ProxyHandler; -import io.netty.handler.proxy.Socks5ProxyHandler; import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.handler.timeout.WriteTimeoutHandler; -import io.netty.resolver.dns.DnsNameResolver; -import io.netty.resolver.dns.DnsNameResolverBuilder; import io.netty.util.concurrent.DefaultThreadFactory; import org.checkerframework.checker.nullness.qual.NonNull; import org.geysermc.mcprotocollib.network.BuiltinFlags; import org.geysermc.mcprotocollib.network.ProxyInfo; import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper; +import org.geysermc.mcprotocollib.network.helper.NettyHelper; import org.geysermc.mcprotocollib.network.helper.TransportHelper; import org.geysermc.mcprotocollib.network.packet.PacketProtocol; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.net.Inet4Address; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.UnknownHostException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; public class TcpClientSession extends TcpSession { - private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod(); - private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b"; - private static final Logger log = LoggerFactory.getLogger(TcpClientSession.class); private static EventLoopGroup EVENT_LOOP_GROUP; /** @@ -94,12 +66,12 @@ public class TcpClientSession extends TcpSession { } final Bootstrap bootstrap = new Bootstrap() - .channelFactory(TRANSPORT_TYPE.socketChannelFactory()) + .channelFactory(TransportHelper.TRANSPORT_TYPE.socketChannelFactory()) .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.IP_TOS, 0x18) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getFlag(BuiltinFlags.CLIENT_CONNECT_TIMEOUT, 30) * 1000) .group(EVENT_LOOP_GROUP) - .remoteAddress(resolveAddress()) + .remoteAddress(NettyHelper.resolveAddress(this, EVENT_LOOP_GROUP.next(), getHost(), getPort())) .localAddress(bindAddress, bindPort) .handler(new ChannelInitializer<>() { @Override @@ -109,9 +81,9 @@ public class TcpClientSession extends TcpSession { ChannelPipeline pipeline = channel.pipeline(); - addProxy(pipeline); + NettyHelper.addProxy(proxy, pipeline); - initializeHAProxySupport(channel); + NettyHelper.initializeHAProxySupport(TcpClientSession.this, channel); pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30))); pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0))); @@ -127,7 +99,7 @@ public class TcpClientSession extends TcpSession { } }); - if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) { + if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TransportHelper.TRANSPORT_TYPE.supportsTcpFastOpenClient()) { bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true); } @@ -150,121 +122,12 @@ public class TcpClientSession extends TcpSession { return this.codecHelper; } - private InetSocketAddress resolveAddress() { - String name = this.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + this.getHost(); - log.debug("Attempting SRV lookup for \"{}\".", name); - - if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) { - try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next()) - .channelFactory(TRANSPORT_TYPE.datagramChannelFactory()) - .build()) { - AddressedEnvelope envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get(); - try { - DnsResponse response = envelope.content(); - if (response.count(DnsSection.ANSWER) > 0) { - DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0); - if (record.type() == DnsRecordType.SRV) { - ByteBuf buf = record.content(); - buf.skipBytes(4); // Skip priority and weight. - - int port = buf.readUnsignedShort(); - String host = DefaultDnsRecordDecoder.decodeName(buf); - if (host.endsWith(".")) { - host = host.substring(0, host.length() - 1); - } - - log.debug("Found SRV record containing \"{}:{}\".", host, port); - - this.host = host; - this.port = port; - } else { - log.debug("Received non-SRV record in response."); - } - } else { - log.debug("No SRV record found."); - } - } finally { - envelope.release(); - } - } catch (Exception e) { - log.debug("Failed to resolve SRV record.", e); - } - } else { - log.debug("Not resolving SRV record for {}", this.host); - } - - // Resolve host here - try { - InetAddress resolved = InetAddress.getByName(getHost()); - log.debug("Resolved {} -> {}", getHost(), resolved.getHostAddress()); - return new InetSocketAddress(resolved, getPort()); - } catch (UnknownHostException e) { - log.debug("Failed to resolve host, letting Netty do it instead.", e); - return InetSocketAddress.createUnresolved(getHost(), getPort()); - } - } - - private void addProxy(ChannelPipeline pipeline) { - if (proxy == null) { - return; - } - - switch (proxy.type()) { - case HTTP -> { - if (proxy.username() != null && proxy.password() != null) { - pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password())); - } else { - pipeline.addLast("proxy", new HttpProxyHandler(proxy.address())); - } - } - case SOCKS4 -> { - if (proxy.username() != null) { - pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username())); - } else { - pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address())); - } - } - case SOCKS5 -> { - if (proxy.username() != null && proxy.password() != null) { - pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password())); - } else { - pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address())); - } - } - default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type()); - } - } - - private void initializeHAProxySupport(Channel channel) { - InetSocketAddress clientAddress = getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS); - if (clientAddress == null) { - return; - } - - channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE); - channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress(); - HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6; - ctx.channel().writeAndFlush(new HAProxyMessage( - HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol, - clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(), - clientAddress.getPort(), remoteAddress.getPort() - )).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder")); - ctx.pipeline().remove(this); - - super.channelActive(ctx); - } - }); - } - private static void createTcpEventLoopGroup() { if (EVENT_LOOP_GROUP != null) { return; } - EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory()); + EVENT_LOOP_GROUP = TransportHelper.TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory()); Runtime.getRuntime().addShutdownHook(new Thread( () -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS))); diff --git a/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java b/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java index 809a88c7..fea2fbd2 100644 --- a/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java +++ b/protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java @@ -22,7 +22,6 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; public class TcpServer extends AbstractServer { - private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod(); private static final Logger log = LoggerFactory.getLogger(TcpServer.class); private EventLoopGroup group; @@ -43,10 +42,10 @@ public class TcpServer extends AbstractServer { return; } - this.group = TRANSPORT_TYPE.eventLoopGroupFactory().apply(null); + this.group = TransportHelper.TRANSPORT_TYPE.eventLoopGroupFactory().apply(null); ServerBootstrap bootstrap = new ServerBootstrap() - .channelFactory(TRANSPORT_TYPE.serverSocketChannelFactory()) + .channelFactory(TransportHelper.TRANSPORT_TYPE.serverSocketChannelFactory()) .group(this.group) .childOption(ChannelOption.TCP_NODELAY, true) .childOption(ChannelOption.IP_TOS, 0x18) @@ -76,7 +75,7 @@ public class TcpServer extends AbstractServer { } }); - if (getGlobalFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenServer()) { + if (getGlobalFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TransportHelper.TRANSPORT_TYPE.supportsTcpFastOpenServer()) { bootstrap.option(ChannelOption.TCP_FASTOPEN, 3); }