Skip to content
This repository was archived by the owner on May 12, 2024. It is now read-only.

Commit 02cf214

Browse files
committed
Fixed record replacer overwriting custom added methods
1 parent e996d2e commit 02cf214

2 files changed

Lines changed: 72 additions & 35 deletions

File tree

src/main/java/net/raphimc/javadowngrader/transformer/j15/RecordReplacer.java

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,14 @@
1818
package net.raphimc.javadowngrader.transformer.j15;
1919

2020
import net.raphimc.javadowngrader.util.ASMUtil;
21-
import org.objectweb.asm.Label;
22-
import org.objectweb.asm.MethodVisitor;
23-
import org.objectweb.asm.Opcodes;
24-
import org.objectweb.asm.Type;
21+
import net.raphimc.javadowngrader.util.Constants;
22+
import org.objectweb.asm.*;
23+
import org.objectweb.asm.tree.AbstractInsnNode;
2524
import org.objectweb.asm.tree.ClassNode;
26-
import org.objectweb.asm.tree.RecordComponentNode;
25+
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
26+
import org.objectweb.asm.tree.MethodNode;
2727

28-
import java.util.Collections;
29-
import java.util.HashMap;
30-
import java.util.Map;
31-
import java.util.Objects;
28+
import java.util.*;
3229

3330
public class RecordReplacer {
3431

@@ -61,9 +58,12 @@ public static boolean replace(final ClassNode classNode) {
6158
classNode.recordComponents = Collections.emptyList();
6259
}
6360

64-
classNode.methods.remove(ASMUtil.getMethod(classNode, "equals", EQUALS_DESC));
65-
final MethodVisitor equals = classNode.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null);
66-
{
61+
final MethodNode defaultEquals = ASMUtil.getMethod(classNode, "equals", EQUALS_DESC);
62+
if (defaultEquals == null) throw new IllegalStateException("Could not find default equals method");
63+
RecordField[] equalsFields = getFields(defaultEquals);
64+
if (equalsFields != null) {
65+
classNode.methods.remove(defaultEquals);
66+
final MethodVisitor equals = classNode.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null);
6767
equals.visitCode();
6868

6969
equals.visitVarInsn(Opcodes.ALOAD, 0);
@@ -88,12 +88,12 @@ public static boolean replace(final ClassNode classNode) {
8888
equals.visitVarInsn(Opcodes.ASTORE, 2);
8989

9090
final Label notEqualLabel = new Label();
91-
for (final RecordComponentNode component : classNode.recordComponents) {
91+
for (RecordField field : equalsFields) {
9292
equals.visitVarInsn(Opcodes.ALOAD, 0);
93-
equals.visitFieldInsn(Opcodes.GETFIELD, classNode.name, component.name, component.descriptor);
93+
equals.visitFieldInsn(Opcodes.GETFIELD, classNode.name, field.name, field.descriptor);
9494
equals.visitVarInsn(Opcodes.ALOAD, 2);
95-
equals.visitFieldInsn(Opcodes.GETFIELD, classNode.name, component.name, component.descriptor);
96-
if (Type.getType(component.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT
95+
equals.visitFieldInsn(Opcodes.GETFIELD, classNode.name, field.name, field.descriptor);
96+
if (Type.getType(field.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT
9797
equals.visitMethodInsn(
9898
Opcodes.INVOKESTATIC,
9999
Type.getInternalName(Objects.class),
@@ -103,29 +103,29 @@ public static boolean replace(final ClassNode classNode) {
103103
);
104104
equals.visitJumpInsn(Opcodes.IFEQ, notEqualLabel);
105105
continue;
106-
} else if ("BSCIZ".contains(component.descriptor)) {
106+
} else if ("BSCIZ".contains(field.descriptor)) {
107107
equals.visitJumpInsn(Opcodes.IF_ICMPNE, notEqualLabel);
108108
continue;
109-
} else if (component.descriptor.equals("F")) {
109+
} else if (field.descriptor.equals("F")) {
110110
equals.visitMethodInsn(
111111
Opcodes.INVOKESTATIC,
112112
Type.getInternalName(Float.class),
113113
"compare",
114114
"(FF)I",
115115
false
116116
);
117-
} else if (component.descriptor.equals("D")) {
117+
} else if (field.descriptor.equals("D")) {
118118
equals.visitMethodInsn(
119119
Opcodes.INVOKESTATIC,
120120
Type.getInternalName(Double.class),
121121
"compare",
122122
"(DD)I",
123123
false
124124
);
125-
} else if (component.descriptor.equals("J")) {
125+
} else if (field.descriptor.equals("J")) {
126126
equals.visitInsn(Opcodes.LCMP);
127127
} else {
128-
throw new AssertionError("Unknown descriptor " + component.descriptor);
128+
throw new AssertionError("Unknown descriptor " + field.descriptor);
129129
}
130130
equals.visitJumpInsn(Opcodes.IFNE, notEqualLabel);
131131
}
@@ -138,23 +138,26 @@ public static boolean replace(final ClassNode classNode) {
138138
equals.visitEnd();
139139
}
140140

141-
classNode.methods.remove(ASMUtil.getMethod(classNode, "hashCode", HASHCODE_DESC));
142-
final MethodVisitor hashCode = classNode.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null);
143-
{
141+
final MethodNode defaultHashCode = ASMUtil.getMethod(classNode, "hashCode", HASHCODE_DESC);
142+
if (defaultHashCode == null) throw new IllegalStateException("Could not find default hashCode method");
143+
RecordField[] hashCodeFields = getFields(defaultHashCode);
144+
if (hashCodeFields != null) {
145+
classNode.methods.remove(defaultHashCode);
146+
final MethodVisitor hashCode = classNode.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null);
144147
hashCode.visitCode();
145148

146149
hashCode.visitInsn(Opcodes.ICONST_0);
147-
for (final RecordComponentNode component : classNode.recordComponents) {
150+
for (RecordField field : hashCodeFields) {
148151
hashCode.visitIntInsn(Opcodes.BIPUSH, 31);
149152
hashCode.visitInsn(Opcodes.IMUL);
150153
hashCode.visitVarInsn(Opcodes.ALOAD, 0);
151-
hashCode.visitFieldInsn(Opcodes.GETFIELD, classNode.name, component.name, component.descriptor);
152-
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
154+
hashCode.visitFieldInsn(Opcodes.GETFIELD, classNode.name, field.name, field.descriptor);
155+
final String owner = PRIMITIVE_WRAPPERS.get(field.descriptor);
153156
hashCode.visitMethodInsn(
154157
Opcodes.INVOKESTATIC,
155158
owner != null ? owner : "java/util/Objects",
156159
"hashCode",
157-
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")I",
160+
"(" + (owner != null ? field.descriptor : "Ljava/lang/Object;") + ")I",
158161
false
159162
);
160163
hashCode.visitInsn(Opcodes.IADD);
@@ -164,9 +167,12 @@ public static boolean replace(final ClassNode classNode) {
164167
hashCode.visitEnd();
165168
}
166169

167-
classNode.methods.remove(ASMUtil.getMethod(classNode, "toString", TOSTRING_DESC));
168-
final MethodVisitor toString = classNode.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null);
169-
{
170+
final MethodNode defaultToString = ASMUtil.getMethod(classNode, "toString", TOSTRING_DESC);
171+
if (defaultToString == null) throw new IllegalStateException("Could not find default toString method");
172+
RecordField[] toStringFields = getFields(defaultToString);
173+
if (toStringFields != null) {
174+
classNode.methods.remove(defaultToString);
175+
final MethodVisitor toString = classNode.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null);
170176
toString.visitCode();
171177

172178
final StringBuilder formatString = new StringBuilder("%s[");
@@ -200,17 +206,17 @@ public static boolean replace(final ClassNode classNode) {
200206
);
201207
toString.visitInsn(Opcodes.AASTORE);
202208
int i = 1;
203-
for (final RecordComponentNode component : classNode.recordComponents) {
209+
for (RecordField field : toStringFields) {
204210
toString.visitInsn(Opcodes.DUP);
205211
toString.visitIntInsn(Opcodes.SIPUSH, i);
206212
toString.visitVarInsn(Opcodes.ALOAD, 0);
207-
toString.visitFieldInsn(Opcodes.GETFIELD, classNode.name, component.name, component.descriptor);
208-
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
213+
toString.visitFieldInsn(Opcodes.GETFIELD, classNode.name, field.name, field.descriptor);
214+
final String owner = PRIMITIVE_WRAPPERS.get(field.descriptor);
209215
toString.visitMethodInsn(
210216
Opcodes.INVOKESTATIC,
211217
owner != null ? owner : "java/util/Objects",
212218
"toString",
213-
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;",
219+
"(" + (owner != null ? field.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;",
214220
false
215221
);
216222
toString.visitInsn(Opcodes.AASTORE);
@@ -232,4 +238,34 @@ public static boolean replace(final ClassNode classNode) {
232238
return true;
233239
}
234240

241+
private static RecordField[] getFields(final MethodNode method) {
242+
for (AbstractInsnNode instruction : method.instructions) {
243+
if (!(instruction instanceof InvokeDynamicInsnNode)) continue;
244+
final InvokeDynamicInsnNode invokeDynamic = (InvokeDynamicInsnNode) instruction;
245+
if (!invokeDynamic.bsm.getOwner().equals("java/lang/runtime/ObjectMethods")) continue;
246+
if (!invokeDynamic.bsm.getName().equals("bootstrap")) continue;
247+
if (!invokeDynamic.bsm.getDesc().equals(Constants.OBJECTMETHODS_BOOTSTRAP_DESC)) continue;
248+
249+
List<RecordField> fields = new ArrayList<>();
250+
for (int i = 2; i < invokeDynamic.bsmArgs.length; i++) {
251+
if (!(invokeDynamic.bsmArgs[i] instanceof Handle)) throw new IllegalStateException("bsm arg " + i + " is not a handle");
252+
final Handle handle = (Handle) invokeDynamic.bsmArgs[i];
253+
fields.add(new RecordField(handle.getName(), handle.getDesc()));
254+
}
255+
return fields.toArray(new RecordField[0]);
256+
}
257+
return null; // You can override equals/hashCode/toString, we should not replace them with the default impl
258+
}
259+
260+
261+
private static class RecordField {
262+
private final String name;
263+
private final String descriptor;
264+
265+
private RecordField(final String name, final String descriptor) {
266+
this.name = name;
267+
this.descriptor = descriptor;
268+
}
269+
}
270+
235271
}

src/main/java/net/raphimc/javadowngrader/util/Constants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ public class Constants {
2323
public static final String JAVADOWNGRADER_RUNTIME_ROOT = "RuntimeRoot.class";
2424

2525
public static final String METAFACTORY_DESC = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;";
26+
public static final String OBJECTMETHODS_BOOTSTRAP_DESC = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/TypeDescriptor;Ljava/lang/Class;Ljava/lang/String;[Ljava/lang/invoke/MethodHandle;)Ljava/lang/Object;";
2627

2728
}

0 commit comments

Comments
 (0)