From d95f3a277a28702ef3aa5027762b5b92aab9c303 Mon Sep 17 00:00:00 2001 From: "Josiah (Gaming32) Glosson" Date: Fri, 12 May 2023 20:59:08 -0500 Subject: [PATCH] Desugar record equals, hashCode, and toString --- .../viaproxy/injection/Java17ToJava8.java | 237 ++++++++++++++++-- 1 file changed, 218 insertions(+), 19 deletions(-) diff --git a/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java b/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java index ca0a756..8fbb997 100644 --- a/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java +++ b/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java @@ -20,20 +20,42 @@ package net.raphimc.viaproxy.injection; import net.lenni0451.classtransform.TransformerManager; import net.lenni0451.classtransform.transformer.IBytecodeTransformer; import net.lenni0451.classtransform.utils.ASMUtils; +import net.raphimc.viaproxy.util.logging.Logger; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; import org.objectweb.asm.tree.*; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.io.File; +import java.nio.file.Files; +import java.util.*; public class Java17ToJava8 implements IBytecodeTransformer { + private static final boolean DEBUG_DUMP = Boolean.getBoolean("viaproxy.debug.dump17to8"); + private static final char STACK_ARG_CONSTANT = '\u0001'; private static final char BSM_ARG_CONSTANT = '\u0002'; - final TransformerManager transformerManager; + private static final String EQUALS_DESC = "(Ljava/lang/Object;)Z"; + private static final String HASHCODE_DESC = "()I"; + private static final String TOSTRING_DESC = "()Ljava/lang/String;"; + private static final Map PRIMITIVE_WRAPPERS = new HashMap<>(); + + static { + PRIMITIVE_WRAPPERS.put("V", Type.getInternalName(Void.class)); + PRIMITIVE_WRAPPERS.put("Z", Type.getInternalName(Boolean.class)); + PRIMITIVE_WRAPPERS.put("B", Type.getInternalName(Byte.class)); + PRIMITIVE_WRAPPERS.put("S", Type.getInternalName(Short.class)); + PRIMITIVE_WRAPPERS.put("C", Type.getInternalName(Character.class)); + PRIMITIVE_WRAPPERS.put("I", Type.getInternalName(Integer.class)); + PRIMITIVE_WRAPPERS.put("F", Type.getInternalName(Float.class)); + PRIMITIVE_WRAPPERS.put("J", Type.getInternalName(Long.class)); + PRIMITIVE_WRAPPERS.put("D", Type.getInternalName(Double.class)); + } + + private final TransformerManager transformerManager; private final int nativeClassVersion; private final List whitelistedPackages = new ArrayList<>(); @@ -69,10 +91,20 @@ public class Java17ToJava8 implements IBytecodeTransformer { this.convertMapMethods(classNode); this.convertStreamMethods(classNode); this.convertMiscMethods(classNode); - this.removeRecords(classNode); + this.convertRecords(classNode); if (calculateStackMapFrames) { - return ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider()); + final byte[] result = ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider()); + if (DEBUG_DUMP) { + try { + final File file = new File("vp_17to8_dump", classNode.name + ".class"); + file.getParentFile().mkdirs(); + Files.write(file.toPath(), result); + } catch (Throwable e) { + Logger.LOGGER.error("Failed to dump class {}", className, e); + } + } + return result; } else { return ASMUtils.toStacklessBytes(classNode); } @@ -404,24 +436,191 @@ public class Java17ToJava8 implements IBytecodeTransformer { } } - private void removeRecords(final ClassNode node) { - if (node.superName.equals("java/lang/Record")) { - node.access &= ~Opcodes.ACC_RECORD; - node.superName = "java/lang/Object"; + private void convertRecords(final ClassNode node) { + if (!node.superName.equals("java/lang/Record")) return; - List constructors = ASMUtils.getMethodsFromCombi(node, ""); - for (MethodNode method : constructors) { - for (AbstractInsnNode insn : method.instructions.toArray()) { - if (insn.getOpcode() == Opcodes.INVOKESPECIAL) { - MethodInsnNode min = (MethodInsnNode) insn; - if (min.owner.equals("java/lang/Record")) { - min.owner = "java/lang/Object"; - break; - } + node.access &= ~Opcodes.ACC_RECORD; + node.superName = "java/lang/Object"; + + final List constructors = ASMUtils.getMethodsFromCombi(node, ""); + for (MethodNode method : constructors) { + for (AbstractInsnNode insn : method.instructions.toArray()) { + if (insn.getOpcode() == Opcodes.INVOKESPECIAL) { + MethodInsnNode min = (MethodInsnNode) insn; + if (min.owner.equals("java/lang/Record")) { + min.owner = "java/lang/Object"; + break; } } } } + + node.methods.remove(ASMUtils.getMethod(node, "equals", EQUALS_DESC)); + final MethodVisitor equals = node.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null); + { + equals.visitCode(); + + equals.visitVarInsn(Opcodes.ALOAD, 0); + equals.visitVarInsn(Opcodes.ALOAD, 1); + final Label notSameLabel = new Label(); + equals.visitJumpInsn(Opcodes.IF_ACMPNE, notSameLabel); + equals.visitInsn(Opcodes.ICONST_1); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notSameLabel); + + // Original uses Class.isInstance, but I think instanceof is more fitting here + equals.visitVarInsn(Opcodes.ALOAD, 1); + equals.visitTypeInsn(Opcodes.INSTANCEOF, node.name); + final Label notIsInstanceLabel = new Label(); + equals.visitJumpInsn(Opcodes.IFNE, notIsInstanceLabel); + equals.visitInsn(Opcodes.ICONST_0); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notIsInstanceLabel); + + equals.visitVarInsn(Opcodes.ALOAD, 1); + equals.visitTypeInsn(Opcodes.CHECKCAST, node.name); + equals.visitVarInsn(Opcodes.ASTORE, 2); + + final Label notEqualLabel = new Label(); + for (final RecordComponentNode component : node.recordComponents) { + equals.visitVarInsn(Opcodes.ALOAD, 0); + equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + equals.visitVarInsn(Opcodes.ALOAD, 2); + equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + if (Type.getType(component.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Objects.class), + "equals", + "(Ljava/lang/Object;Ljava/lang/Object;)Z", + false + ); + equals.visitJumpInsn(Opcodes.IFEQ, notEqualLabel); + continue; + } else if ("BSCIZ".contains(component.descriptor)) { + equals.visitJumpInsn(Opcodes.IF_ICMPNE, notEqualLabel); + continue; + } else if (component.descriptor.equals("F")) { + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Float.class), + "equals", + "(FF)Z", + false + ); + } else if (component.descriptor.equals("D")) { + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Double.class), + "equals", + "(DD)Z", + false + ); + } else if (component.descriptor.equals("J")) { + equals.visitInsn(Opcodes.LCMP); + } else { + throw new AssertionError("Unknown descriptor " + component.descriptor); + } + equals.visitJumpInsn(Opcodes.IFNE, notEqualLabel); + } + equals.visitInsn(Opcodes.ICONST_1); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notEqualLabel); + equals.visitInsn(Opcodes.ICONST_0); + equals.visitInsn(Opcodes.IRETURN); + + equals.visitEnd(); + } + + node.methods.remove(ASMUtils.getMethod(node, "hashCode", HASHCODE_DESC)); + final MethodVisitor hashCode = node.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null); + { + hashCode.visitCode(); + + hashCode.visitInsn(Opcodes.ICONST_0); + for (final RecordComponentNode component : node.recordComponents) { + hashCode.visitIntInsn(Opcodes.BIPUSH, 31); + hashCode.visitInsn(Opcodes.IMUL); + hashCode.visitVarInsn(Opcodes.ALOAD, 0); + hashCode.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor); + hashCode.visitMethodInsn( + Opcodes.INVOKESTATIC, + owner != null ? owner : "java/util/Objects", + "hashCode", + "(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")I", + false + ); + hashCode.visitInsn(Opcodes.IADD); + } + hashCode.visitInsn(Opcodes.IRETURN); + + hashCode.visitEnd(); + } + + node.methods.remove(ASMUtils.getMethod(node, "toString", TOSTRING_DESC)); + final MethodVisitor toString = node.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null); + { + toString.visitCode(); + + final StringBuilder formatString = new StringBuilder("%s["); + for (int i = 0; i < node.recordComponents.size(); i++) { + formatString.append(node.recordComponents.get(i).name).append("=%s"); + if (i != node.recordComponents.size() - 1) { + formatString.append(", "); + } + } + formatString.append(']'); + + toString.visitLdcInsn(formatString.toString()); + toString.visitIntInsn(Opcodes.SIPUSH, node.recordComponents.size() + 1); + toString.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object"); + toString.visitInsn(Opcodes.DUP); + toString.visitInsn(Opcodes.ICONST_0); + toString.visitVarInsn(Opcodes.ALOAD, 0); + toString.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/Object", + "getClass", + "()Ljava/lang/Class;", + false + ); + toString.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/Class", + "getSimpleName", + "()Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.AASTORE); + int i = 1; + for (final RecordComponentNode component : node.recordComponents) { + toString.visitInsn(Opcodes.DUP); + toString.visitIntInsn(Opcodes.SIPUSH, i); + toString.visitVarInsn(Opcodes.ALOAD, 0); + toString.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor); + toString.visitMethodInsn( + Opcodes.INVOKESTATIC, + owner != null ? owner : "java/util/Objects", + "toString", + "(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.AASTORE); + i++; + } + toString.visitMethodInsn( + Opcodes.INVOKESTATIC, + "java/lang/String", + "format", + "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.ARETURN); + + toString.visitEnd(); + } } private int count(final String s, final char search) {