J17 -> J8 improvements (#43)

* Desugar record equals, hashCode, and toString

* No more runtime libraries for J17 -> J8

I dropped Apache Commons IO by implementing transferTo, then implementing readAllBytes as a transferTo a ByteArrayOutputStream

* Fix lists being constructed in reverse

Well they still are, but they're reversed at the end.

* Better Stream.toList conversion
This commit is contained in:
Josiah Glosson 2023-05-13 08:25:38 -05:00 committed by GitHub
parent 565690fcb1
commit 6070c67785
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,20 +20,44 @@ package net.raphimc.viaproxy.injection;
import net.lenni0451.classtransform.TransformerManager;
import net.lenni0451.classtransform.transformer.IBytecodeTransformer;
import net.lenni0451.classtransform.utils.ASMUtils;
import net.raphimc.viaproxy.util.logging.Logger;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.io.File;
import java.nio.file.Files;
import java.util.*;
public class Java17ToJava8 implements IBytecodeTransformer {
private static final boolean DEBUG_DUMP = Boolean.getBoolean("viaproxy.debug.dump17to8");
private static final char STACK_ARG_CONSTANT = '\u0001';
private static final char BSM_ARG_CONSTANT = '\u0002';
final TransformerManager transformerManager;
private static final String TRANSFERTO_DESC = "(Ljava/io/InputStream;Ljava/io/OutputStream;)J";
private static final String EQUALS_DESC = "(Ljava/lang/Object;)Z";
private static final String HASHCODE_DESC = "()I";
private static final String TOSTRING_DESC = "()Ljava/lang/String;";
private static final Map<String, String> PRIMITIVE_WRAPPERS = new HashMap<>();
static {
PRIMITIVE_WRAPPERS.put("V", Type.getInternalName(Void.class));
PRIMITIVE_WRAPPERS.put("Z", Type.getInternalName(Boolean.class));
PRIMITIVE_WRAPPERS.put("B", Type.getInternalName(Byte.class));
PRIMITIVE_WRAPPERS.put("S", Type.getInternalName(Short.class));
PRIMITIVE_WRAPPERS.put("C", Type.getInternalName(Character.class));
PRIMITIVE_WRAPPERS.put("I", Type.getInternalName(Integer.class));
PRIMITIVE_WRAPPERS.put("F", Type.getInternalName(Float.class));
PRIMITIVE_WRAPPERS.put("J", Type.getInternalName(Long.class));
PRIMITIVE_WRAPPERS.put("D", Type.getInternalName(Double.class));
}
private final TransformerManager transformerManager;
private final int nativeClassVersion;
private final List<String> whitelistedPackages = new ArrayList<>();
@ -69,10 +93,20 @@ public class Java17ToJava8 implements IBytecodeTransformer {
this.convertMapMethods(classNode);
this.convertStreamMethods(classNode);
this.convertMiscMethods(classNode);
this.removeRecords(classNode);
this.convertRecords(classNode);
if (calculateStackMapFrames) {
return ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider());
final byte[] result = ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider());
if (DEBUG_DUMP) {
try {
final File file = new File("vp_17to8_dump", classNode.name + ".class");
file.getParentFile().mkdirs();
Files.write(file.toPath(), result);
} catch (Throwable e) {
Logger.LOGGER.error("Failed to dump class {}", className, e);
}
}
return result;
} else {
return ASMUtils.toStacklessBytes(classNode);
}
@ -144,6 +178,13 @@ public class Java17ToJava8 implements IBytecodeTransformer {
list.add(new InsnNode(Opcodes.POP));
}
list.add(new VarInsnNode(Opcodes.ALOAD, freeVarIndex));
list.add(new InsnNode(Opcodes.DUP));
list.add(new MethodInsnNode(
Opcodes.INVOKESTATIC,
"java/util/Collections",
"reverse",
"(Ljava/util/List;)V"
));
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Collections", "unmodifiableList", "(Ljava/util/List;)Ljava/util/List;"));
}
} else if (min.name.equals("copyOf")) {
@ -268,16 +309,25 @@ public class Java17ToJava8 implements IBytecodeTransformer {
final InsnList list = new InsnList();
if (min.name.equals("toList")) {
int freeVarIndex = ASMUtils.getFreeVarIndex(method);
list.add(new VarInsnNode(Opcodes.ASTORE, freeVarIndex));
list.add(new TypeInsnNode(Opcodes.NEW, "java/util/ArrayList"));
list.add(new InsnNode(Opcodes.DUP));
list.add(new VarInsnNode(Opcodes.ALOAD, freeVarIndex));
list.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, "java/util/stream/Stream", "toArray", "()[Ljava/lang/Object;"));
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Arrays", "asList", "([Ljava/lang/Object;)Ljava/util/List;"));
list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/ArrayList", "<init>", "(Ljava/util/Collection;)V"));
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Collections", "unmodifiableList", "(Ljava/util/List;)Ljava/util/List;"));
list.add(new MethodInsnNode(
Opcodes.INVOKESTATIC,
"java/util/stream/Collectors",
"toList",
"()Ljava/util/stream/Collector;"
));
list.add(new MethodInsnNode(
Opcodes.INVOKEINTERFACE,
"java/util/stream/Stream",
"collect",
"(Ljava/util/stream/Collector;)Ljava/lang/Object;"
));
list.add(new TypeInsnNode(Opcodes.CHECKCAST, "java/util/List"));
list.add(new MethodInsnNode(
Opcodes.INVOKESTATIC,
"java/util/Collections",
"unmodifiableList",
"(Ljava/util/List;)Ljava/util/List;")
);
}
if (list.size() != 0) {
@ -290,6 +340,15 @@ public class Java17ToJava8 implements IBytecodeTransformer {
}
private void convertMiscMethods(final ClassNode node) {
boolean needsTransferTo = false;
String transferToName;
{
int i = 0;
do {
transferToName = "transferTo$" + i;
} while (ASMUtils.getMethod(node, transferToName, TRANSFERTO_DESC) != null);
}
for (MethodNode method : node.methods) {
for (AbstractInsnNode insn : method.instructions.toArray()) {
if (insn instanceof MethodInsnNode) {
@ -303,7 +362,36 @@ public class Java17ToJava8 implements IBytecodeTransformer {
}
} else if (min.owner.equals("java/io/InputStream")) {
if (min.name.equals("readAllBytes") && min.getOpcode() == Opcodes.INVOKEVIRTUAL) {
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "org/apache/commons/io/IOUtils", "toByteArray", "(Ljava/io/InputStream;)[B"));
needsTransferTo = true;
list.add(new TypeInsnNode(Opcodes.NEW, "java/io/ByteArrayOutputStream"));
list.add(new InsnNode(Opcodes.DUP));
list.add(new MethodInsnNode(
Opcodes.INVOKESPECIAL,
"java/io/ByteArrayOutputStream",
"<init>",
"()V"
));
list.add(new InsnNode(Opcodes.DUP_X1));
list.add(new MethodInsnNode(
Opcodes.INVOKESTATIC,
node.name,
transferToName,
TRANSFERTO_DESC
));
list.add(new InsnNode(Opcodes.POP2));
list.add(new MethodInsnNode(
Opcodes.INVOKEVIRTUAL,
"java/io/ByteArrayOutputStream",
"toByteArray",
"()[B"
));
} else if (min.name.equals("transferTo") && min.getOpcode() == Opcodes.INVOKEVIRTUAL) {
needsTransferTo = true;
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
node.name,
transferToName,
TRANSFERTO_DESC
));
}
} else if (min.owner.equals("java/nio/file/FileSystems")) {
if (min.name.equals("newFileSystem") && min.desc.equals("(Ljava/nio/file/Path;Ljava/util/Map;Ljava/lang/ClassLoader;)Ljava/nio/file/FileSystem;")) {
@ -402,14 +490,94 @@ public class Java17ToJava8 implements IBytecodeTransformer {
}
}
}
if (needsTransferTo) {
// I compiled this by hand btw
final MethodVisitor transferTo = node.visitMethod(
Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC,
transferToName, TRANSFERTO_DESC, null, new String[] {"java/io/IOException"}
);
transferTo.visitCode();
// Objects.requireNonNull(out, "out");
transferTo.visitVarInsn(Opcodes.ALOAD, 1);
transferTo.visitLdcInsn("out");
transferTo.visitMethodInsn(
Opcodes.INVOKESTATIC,
"java/util/Objects",
"requireNonNull",
"(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;",
false
);
transferTo.visitInsn(Opcodes.POP);
// long transferred = 0;
transferTo.visitInsn(Opcodes.LCONST_0);
transferTo.visitVarInsn(Opcodes.LSTORE, 2);
// byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
transferTo.visitIntInsn(Opcodes.SIPUSH, 8192);
transferTo.visitIntInsn(Opcodes.NEWARRAY, Opcodes.T_BYTE);
transferTo.visitVarInsn(Opcodes.ASTORE, 4);
// while ((read = this.read(buffer, 0, DEFAULT_BUFFER_SIZE)) >= 0) {
final Label whileStart = new Label();
final Label whileEnd = new Label();
transferTo.visitLabel(whileStart);
transferTo.visitVarInsn(Opcodes.ALOAD, 0);
transferTo.visitVarInsn(Opcodes.ALOAD, 4);
transferTo.visitInsn(Opcodes.ICONST_0);
transferTo.visitIntInsn(Opcodes.SIPUSH, 8192);
transferTo.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/io/InputStream",
"read",
"([BII)I",
false
);
transferTo.visitInsn(Opcodes.DUP);
transferTo.visitVarInsn(Opcodes.ISTORE, 5);
transferTo.visitJumpInsn(Opcodes.IFLT, whileEnd);
// out.write(buffer, 0, read);
transferTo.visitVarInsn(Opcodes.ALOAD, 1);
transferTo.visitVarInsn(Opcodes.ALOAD, 4);
transferTo.visitInsn(Opcodes.ICONST_0);
transferTo.visitVarInsn(Opcodes.ILOAD, 5);
transferTo.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/io/OutputStream",
"write",
"([BII)V",
false
);
// transferred += read;
transferTo.visitVarInsn(Opcodes.LLOAD, 2);
transferTo.visitVarInsn(Opcodes.ILOAD, 5);
transferTo.visitInsn(Opcodes.I2L);
transferTo.visitInsn(Opcodes.LADD);
transferTo.visitVarInsn(Opcodes.LSTORE, 2);
// }
transferTo.visitJumpInsn(Opcodes.GOTO, whileStart);
transferTo.visitLabel(whileEnd);
// return transferred;
transferTo.visitVarInsn(Opcodes.LLOAD, 2);
transferTo.visitInsn(Opcodes.LRETURN);
transferTo.visitEnd();
}
}
private void removeRecords(final ClassNode node) {
if (node.superName.equals("java/lang/Record")) {
private void convertRecords(final ClassNode node) {
if (!node.superName.equals("java/lang/Record")) return;
node.access &= ~Opcodes.ACC_RECORD;
node.superName = "java/lang/Object";
List<MethodNode> constructors = ASMUtils.getMethodsFromCombi(node, "<init>");
final List<MethodNode> constructors = ASMUtils.getMethodsFromCombi(node, "<init>");
for (MethodNode method : constructors) {
for (AbstractInsnNode insn : method.instructions.toArray()) {
if (insn.getOpcode() == Opcodes.INVOKESPECIAL) {
@ -421,6 +589,172 @@ public class Java17ToJava8 implements IBytecodeTransformer {
}
}
}
node.methods.remove(ASMUtils.getMethod(node, "equals", EQUALS_DESC));
final MethodVisitor equals = node.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null);
{
equals.visitCode();
equals.visitVarInsn(Opcodes.ALOAD, 0);
equals.visitVarInsn(Opcodes.ALOAD, 1);
final Label notSameLabel = new Label();
equals.visitJumpInsn(Opcodes.IF_ACMPNE, notSameLabel);
equals.visitInsn(Opcodes.ICONST_1);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notSameLabel);
// Original uses Class.isInstance, but I think instanceof is more fitting here
equals.visitVarInsn(Opcodes.ALOAD, 1);
equals.visitTypeInsn(Opcodes.INSTANCEOF, node.name);
final Label notIsInstanceLabel = new Label();
equals.visitJumpInsn(Opcodes.IFNE, notIsInstanceLabel);
equals.visitInsn(Opcodes.ICONST_0);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notIsInstanceLabel);
equals.visitVarInsn(Opcodes.ALOAD, 1);
equals.visitTypeInsn(Opcodes.CHECKCAST, node.name);
equals.visitVarInsn(Opcodes.ASTORE, 2);
final Label notEqualLabel = new Label();
for (final RecordComponentNode component : node.recordComponents) {
equals.visitVarInsn(Opcodes.ALOAD, 0);
equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
equals.visitVarInsn(Opcodes.ALOAD, 2);
equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
if (Type.getType(component.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Objects.class),
"equals",
"(Ljava/lang/Object;Ljava/lang/Object;)Z",
false
);
equals.visitJumpInsn(Opcodes.IFEQ, notEqualLabel);
continue;
} else if ("BSCIZ".contains(component.descriptor)) {
equals.visitJumpInsn(Opcodes.IF_ICMPNE, notEqualLabel);
continue;
} else if (component.descriptor.equals("F")) {
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Float.class),
"equals",
"(FF)Z",
false
);
} else if (component.descriptor.equals("D")) {
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Double.class),
"equals",
"(DD)Z",
false
);
} else if (component.descriptor.equals("J")) {
equals.visitInsn(Opcodes.LCMP);
} else {
throw new AssertionError("Unknown descriptor " + component.descriptor);
}
equals.visitJumpInsn(Opcodes.IFNE, notEqualLabel);
}
equals.visitInsn(Opcodes.ICONST_1);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notEqualLabel);
equals.visitInsn(Opcodes.ICONST_0);
equals.visitInsn(Opcodes.IRETURN);
equals.visitEnd();
}
node.methods.remove(ASMUtils.getMethod(node, "hashCode", HASHCODE_DESC));
final MethodVisitor hashCode = node.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null);
{
hashCode.visitCode();
hashCode.visitInsn(Opcodes.ICONST_0);
for (final RecordComponentNode component : node.recordComponents) {
hashCode.visitIntInsn(Opcodes.BIPUSH, 31);
hashCode.visitInsn(Opcodes.IMUL);
hashCode.visitVarInsn(Opcodes.ALOAD, 0);
hashCode.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
hashCode.visitMethodInsn(
Opcodes.INVOKESTATIC,
owner != null ? owner : "java/util/Objects",
"hashCode",
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")I",
false
);
hashCode.visitInsn(Opcodes.IADD);
}
hashCode.visitInsn(Opcodes.IRETURN);
hashCode.visitEnd();
}
node.methods.remove(ASMUtils.getMethod(node, "toString", TOSTRING_DESC));
final MethodVisitor toString = node.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null);
{
toString.visitCode();
final StringBuilder formatString = new StringBuilder("%s[");
for (int i = 0; i < node.recordComponents.size(); i++) {
formatString.append(node.recordComponents.get(i).name).append("=%s");
if (i != node.recordComponents.size() - 1) {
formatString.append(", ");
}
}
formatString.append(']');
toString.visitLdcInsn(formatString.toString());
toString.visitIntInsn(Opcodes.SIPUSH, node.recordComponents.size() + 1);
toString.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object");
toString.visitInsn(Opcodes.DUP);
toString.visitInsn(Opcodes.ICONST_0);
toString.visitVarInsn(Opcodes.ALOAD, 0);
toString.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/lang/Object",
"getClass",
"()Ljava/lang/Class;",
false
);
toString.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/lang/Class",
"getSimpleName",
"()Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.AASTORE);
int i = 1;
for (final RecordComponentNode component : node.recordComponents) {
toString.visitInsn(Opcodes.DUP);
toString.visitIntInsn(Opcodes.SIPUSH, i);
toString.visitVarInsn(Opcodes.ALOAD, 0);
toString.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
toString.visitMethodInsn(
Opcodes.INVOKESTATIC,
owner != null ? owner : "java/util/Objects",
"toString",
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.AASTORE);
i++;
}
toString.visitMethodInsn(
Opcodes.INVOKESTATIC,
"java/lang/String",
"format",
"(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.ARETURN);
toString.visitEnd();
}
}