1818package net .raphimc .javadowngrader .transformer .j15 ;
1919
2020import net .raphimc .javadowngrader .util .ASMUtil ;
21- import org .objectweb .asm .Label ;
22- import org .objectweb .asm .MethodVisitor ;
23- import org .objectweb .asm .Opcodes ;
24- import org .objectweb .asm .Type ;
21+ import net .raphimc .javadowngrader .util .Constants ;
22+ import org .objectweb .asm .*;
23+ import org .objectweb .asm .tree .AbstractInsnNode ;
2524import org .objectweb .asm .tree .ClassNode ;
26- import org .objectweb .asm .tree .RecordComponentNode ;
25+ import org .objectweb .asm .tree .InvokeDynamicInsnNode ;
26+ import org .objectweb .asm .tree .MethodNode ;
2727
28- import java .util .Collections ;
29- import java .util .HashMap ;
30- import java .util .Map ;
31- import java .util .Objects ;
28+ import java .util .*;
3229
3330public class RecordReplacer {
3431
@@ -61,9 +58,12 @@ public static boolean replace(final ClassNode classNode) {
6158 classNode .recordComponents = Collections .emptyList ();
6259 }
6360
64- classNode .methods .remove (ASMUtil .getMethod (classNode , "equals" , EQUALS_DESC ));
65- final MethodVisitor equals = classNode .visitMethod (Opcodes .ACC_PUBLIC , "equals" , EQUALS_DESC , null , null );
66- {
61+ final MethodNode defaultEquals = ASMUtil .getMethod (classNode , "equals" , EQUALS_DESC );
62+ if (defaultEquals == null ) throw new IllegalStateException ("Could not find default equals method" );
63+ RecordField [] equalsFields = getFields (defaultEquals );
64+ if (equalsFields != null ) {
65+ classNode .methods .remove (defaultEquals );
66+ final MethodVisitor equals = classNode .visitMethod (Opcodes .ACC_PUBLIC , "equals" , EQUALS_DESC , null , null );
6767 equals .visitCode ();
6868
6969 equals .visitVarInsn (Opcodes .ALOAD , 0 );
@@ -88,12 +88,12 @@ public static boolean replace(final ClassNode classNode) {
8888 equals .visitVarInsn (Opcodes .ASTORE , 2 );
8989
9090 final Label notEqualLabel = new Label ();
91- for (final RecordComponentNode component : classNode . recordComponents ) {
91+ for (RecordField field : equalsFields ) {
9292 equals .visitVarInsn (Opcodes .ALOAD , 0 );
93- equals .visitFieldInsn (Opcodes .GETFIELD , classNode .name , component .name , component .descriptor );
93+ equals .visitFieldInsn (Opcodes .GETFIELD , classNode .name , field .name , field .descriptor );
9494 equals .visitVarInsn (Opcodes .ALOAD , 2 );
95- equals .visitFieldInsn (Opcodes .GETFIELD , classNode .name , component .name , component .descriptor );
96- if (Type .getType (component .descriptor ).getSort () >= Type .ARRAY ) { // ARRAY or OBJECT
95+ equals .visitFieldInsn (Opcodes .GETFIELD , classNode .name , field .name , field .descriptor );
96+ if (Type .getType (field .descriptor ).getSort () >= Type .ARRAY ) { // ARRAY or OBJECT
9797 equals .visitMethodInsn (
9898 Opcodes .INVOKESTATIC ,
9999 Type .getInternalName (Objects .class ),
@@ -103,29 +103,29 @@ public static boolean replace(final ClassNode classNode) {
103103 );
104104 equals .visitJumpInsn (Opcodes .IFEQ , notEqualLabel );
105105 continue ;
106- } else if ("BSCIZ" .contains (component .descriptor )) {
106+ } else if ("BSCIZ" .contains (field .descriptor )) {
107107 equals .visitJumpInsn (Opcodes .IF_ICMPNE , notEqualLabel );
108108 continue ;
109- } else if (component .descriptor .equals ("F" )) {
109+ } else if (field .descriptor .equals ("F" )) {
110110 equals .visitMethodInsn (
111111 Opcodes .INVOKESTATIC ,
112112 Type .getInternalName (Float .class ),
113113 "compare" ,
114114 "(FF)I" ,
115115 false
116116 );
117- } else if (component .descriptor .equals ("D" )) {
117+ } else if (field .descriptor .equals ("D" )) {
118118 equals .visitMethodInsn (
119119 Opcodes .INVOKESTATIC ,
120120 Type .getInternalName (Double .class ),
121121 "compare" ,
122122 "(DD)I" ,
123123 false
124124 );
125- } else if (component .descriptor .equals ("J" )) {
125+ } else if (field .descriptor .equals ("J" )) {
126126 equals .visitInsn (Opcodes .LCMP );
127127 } else {
128- throw new AssertionError ("Unknown descriptor " + component .descriptor );
128+ throw new AssertionError ("Unknown descriptor " + field .descriptor );
129129 }
130130 equals .visitJumpInsn (Opcodes .IFNE , notEqualLabel );
131131 }
@@ -138,23 +138,26 @@ public static boolean replace(final ClassNode classNode) {
138138 equals .visitEnd ();
139139 }
140140
141- classNode .methods .remove (ASMUtil .getMethod (classNode , "hashCode" , HASHCODE_DESC ));
142- final MethodVisitor hashCode = classNode .visitMethod (Opcodes .ACC_PUBLIC , "hashCode" , HASHCODE_DESC , null , null );
143- {
141+ final MethodNode defaultHashCode = ASMUtil .getMethod (classNode , "hashCode" , HASHCODE_DESC );
142+ if (defaultHashCode == null ) throw new IllegalStateException ("Could not find default hashCode method" );
143+ RecordField [] hashCodeFields = getFields (defaultHashCode );
144+ if (hashCodeFields != null ) {
145+ classNode .methods .remove (defaultHashCode );
146+ final MethodVisitor hashCode = classNode .visitMethod (Opcodes .ACC_PUBLIC , "hashCode" , HASHCODE_DESC , null , null );
144147 hashCode .visitCode ();
145148
146149 hashCode .visitInsn (Opcodes .ICONST_0 );
147- for (final RecordComponentNode component : classNode . recordComponents ) {
150+ for (RecordField field : hashCodeFields ) {
148151 hashCode .visitIntInsn (Opcodes .BIPUSH , 31 );
149152 hashCode .visitInsn (Opcodes .IMUL );
150153 hashCode .visitVarInsn (Opcodes .ALOAD , 0 );
151- hashCode .visitFieldInsn (Opcodes .GETFIELD , classNode .name , component .name , component .descriptor );
152- final String owner = PRIMITIVE_WRAPPERS .get (component .descriptor );
154+ hashCode .visitFieldInsn (Opcodes .GETFIELD , classNode .name , field .name , field .descriptor );
155+ final String owner = PRIMITIVE_WRAPPERS .get (field .descriptor );
153156 hashCode .visitMethodInsn (
154157 Opcodes .INVOKESTATIC ,
155158 owner != null ? owner : "java/util/Objects" ,
156159 "hashCode" ,
157- "(" + (owner != null ? component .descriptor : "Ljava/lang/Object;" ) + ")I" ,
160+ "(" + (owner != null ? field .descriptor : "Ljava/lang/Object;" ) + ")I" ,
158161 false
159162 );
160163 hashCode .visitInsn (Opcodes .IADD );
@@ -164,9 +167,12 @@ public static boolean replace(final ClassNode classNode) {
164167 hashCode .visitEnd ();
165168 }
166169
167- classNode .methods .remove (ASMUtil .getMethod (classNode , "toString" , TOSTRING_DESC ));
168- final MethodVisitor toString = classNode .visitMethod (Opcodes .ACC_PUBLIC , "toString" , TOSTRING_DESC , null , null );
169- {
170+ final MethodNode defaultToString = ASMUtil .getMethod (classNode , "toString" , TOSTRING_DESC );
171+ if (defaultToString == null ) throw new IllegalStateException ("Could not find default toString method" );
172+ RecordField [] toStringFields = getFields (defaultToString );
173+ if (toStringFields != null ) {
174+ classNode .methods .remove (defaultToString );
175+ final MethodVisitor toString = classNode .visitMethod (Opcodes .ACC_PUBLIC , "toString" , TOSTRING_DESC , null , null );
170176 toString .visitCode ();
171177
172178 final StringBuilder formatString = new StringBuilder ("%s[" );
@@ -200,17 +206,17 @@ public static boolean replace(final ClassNode classNode) {
200206 );
201207 toString .visitInsn (Opcodes .AASTORE );
202208 int i = 1 ;
203- for (final RecordComponentNode component : classNode . recordComponents ) {
209+ for (RecordField field : toStringFields ) {
204210 toString .visitInsn (Opcodes .DUP );
205211 toString .visitIntInsn (Opcodes .SIPUSH , i );
206212 toString .visitVarInsn (Opcodes .ALOAD , 0 );
207- toString .visitFieldInsn (Opcodes .GETFIELD , classNode .name , component .name , component .descriptor );
208- final String owner = PRIMITIVE_WRAPPERS .get (component .descriptor );
213+ toString .visitFieldInsn (Opcodes .GETFIELD , classNode .name , field .name , field .descriptor );
214+ final String owner = PRIMITIVE_WRAPPERS .get (field .descriptor );
209215 toString .visitMethodInsn (
210216 Opcodes .INVOKESTATIC ,
211217 owner != null ? owner : "java/util/Objects" ,
212218 "toString" ,
213- "(" + (owner != null ? component .descriptor : "Ljava/lang/Object;" ) + ")Ljava/lang/String;" ,
219+ "(" + (owner != null ? field .descriptor : "Ljava/lang/Object;" ) + ")Ljava/lang/String;" ,
214220 false
215221 );
216222 toString .visitInsn (Opcodes .AASTORE );
@@ -232,4 +238,34 @@ public static boolean replace(final ClassNode classNode) {
232238 return true ;
233239 }
234240
241+ private static RecordField [] getFields (final MethodNode method ) {
242+ for (AbstractInsnNode instruction : method .instructions ) {
243+ if (!(instruction instanceof InvokeDynamicInsnNode )) continue ;
244+ final InvokeDynamicInsnNode invokeDynamic = (InvokeDynamicInsnNode ) instruction ;
245+ if (!invokeDynamic .bsm .getOwner ().equals ("java/lang/runtime/ObjectMethods" )) continue ;
246+ if (!invokeDynamic .bsm .getName ().equals ("bootstrap" )) continue ;
247+ if (!invokeDynamic .bsm .getDesc ().equals (Constants .OBJECTMETHODS_BOOTSTRAP_DESC )) continue ;
248+
249+ List <RecordField > fields = new ArrayList <>();
250+ for (int i = 2 ; i < invokeDynamic .bsmArgs .length ; i ++) {
251+ if (!(invokeDynamic .bsmArgs [i ] instanceof Handle )) throw new IllegalStateException ("bsm arg " + i + " is not a handle" );
252+ final Handle handle = (Handle ) invokeDynamic .bsmArgs [i ];
253+ fields .add (new RecordField (handle .getName (), handle .getDesc ()));
254+ }
255+ return fields .toArray (new RecordField [0 ]);
256+ }
257+ return null ; // You can override equals/hashCode/toString, we should not replace them with the default impl
258+ }
259+
260+
261+ private static class RecordField {
262+ private final String name ;
263+ private final String descriptor ;
264+
265+ private RecordField (final String name , final String descriptor ) {
266+ this .name = name ;
267+ this .descriptor = descriptor ;
268+ }
269+ }
270+
235271}
0 commit comments