1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package art.constmethodhandle;
18 
19 import java.io.*;
20 import java.util.*;
21 import java.lang.invoke.CallSite;
22 import java.lang.invoke.MethodHandle;
23 import java.lang.invoke.MethodHandles;
24 import java.lang.invoke.MethodType;
25 import java.nio.file.*;
26 import org.objectweb.asm.ClassReader;
27 import org.objectweb.asm.ClassVisitor;
28 import org.objectweb.asm.ClassWriter;
29 import org.objectweb.asm.Handle;
30 import org.objectweb.asm.MethodVisitor;
31 import org.objectweb.asm.Opcodes;
32 import org.objectweb.asm.Type;
33 
34 // This test will modify in place the compiled java files to fill in the transformed version and
35 // fill in the TestInvoker.runTest function with a load-constant of a method-handle. It will use d8
36 // (passed in as an argument) to create the dex we will transform TestInvoke into.
37 public class TestGenerator {
38 
main(String[] args)39   public static void main(String[] args) throws IOException {
40     if (args.length != 2) {
41       throw new Error("Unable to convert class to dex without d8 binary!");
42     }
43     Path base = Paths.get(args[0]);
44     String d8Bin = args[1];
45 
46     Path initTestInvoke = base.resolve(TestGenerator.class.getPackage().getName().replace('.', '/'))
47                               .resolve(TestInvoke.class.getSimpleName() + ".class");
48     byte[] initClass = new FileInputStream(initTestInvoke.toFile()).readAllBytes();
49 
50     // Make the initial version of TestInvoker
51     generateInvoker(initClass, "sayHi", new FileOutputStream(initTestInvoke.toFile()));
52 
53     // Make the final 'class' version of testInvoker
54     ByteArrayOutputStream finalClass = new ByteArrayOutputStream();
55     generateInvoker(initClass, "sayBye", finalClass);
56 
57     Path initTest1948 = base.resolve("art").resolve(art.Test1948.class.getSimpleName() + ".class");
58     byte[] finalClassBytes = finalClass.toByteArray();
59     byte[] finalDexBytes = getFinalDexBytes(d8Bin, finalClassBytes);
60     generateTestCode(
61         new FileInputStream(initTest1948.toFile()).readAllBytes(),
62         finalClassBytes,
63         finalDexBytes,
64         new FileOutputStream(initTest1948.toFile()));
65   }
66 
67   // Modify the Test1948 class bytecode so it has the transformed version of TestInvoker as a string
68   // constant.
generateTestCode( byte[] initClass, byte[] transClass, byte[] transDex, OutputStream out)69   private static void generateTestCode(
70       byte[] initClass, byte[] transClass, byte[] transDex, OutputStream out) throws IOException {
71     ClassReader cr = new ClassReader(initClass);
72     ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
73     cr.accept(
74         new ClassVisitor(Opcodes.ASM9, cw) {
75           @Override
76           public void visitEnd() {
77             generateStringAccessorMethod(
78                 cw, "getDexBase64", Base64.getEncoder().encodeToString(transDex));
79             generateStringAccessorMethod(
80                 cw, "getClassBase64", Base64.getEncoder().encodeToString(transClass));
81             super.visitEnd();
82           }
83         }, 0);
84     out.write(cw.toByteArray());
85   }
86 
87   // Insert a string accessor method so we can get the transformed versions of TestInvoker.
generateStringAccessorMethod(ClassVisitor cv, String name, String ret)88   private static void generateStringAccessorMethod(ClassVisitor cv, String name, String ret) {
89     MethodVisitor mv = cv.visitMethod(
90         Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC,
91         name, "()Ljava/lang/String;", null, null);
92     mv.visitLdcInsn(ret);
93     mv.visitInsn(Opcodes.ARETURN);
94     mv.visitMaxs(-1, -1);
95   }
96 
97   // Use d8bin to convert the classBytes into a dex file bytes. We need to do this here because we
98   // need the dex-file bytes to be used by the test class to redefine TestInvoker. We use d8 because
99   // it doesn't require setting up a directory structures or matching file names like dx does.
100   // TODO We should maybe just call d8 functions directly?
getFinalDexBytes(String d8Bin, byte[] classBytes)101   private static byte[] getFinalDexBytes(String d8Bin, byte[] classBytes) throws IOException {
102     Path tempDir = Files.createTempDirectory("FinalTestInvoker_Gen");
103     File tempInput = Files.createTempFile(tempDir, "temp_input_class", ".class").toFile();
104 
105     OutputStream tempClassStream = new FileOutputStream(tempInput);
106     tempClassStream.write(classBytes);
107     tempClassStream.close();
108     tempClassStream = null;
109 
110     Process d8Proc = new ProcessBuilder(d8Bin,
111                                         // Put classes.dex in the temp-dir we made.
112                                         "--output", tempDir.toAbsolutePath().toString(),
113                                         "--min-api", "28",  // Allow the new invoke ops.
114                                         "--no-desugaring",  // Don't try to be clever please.
115                                         tempInput.toPath().toAbsolutePath().toString())
116         .inheritIO()  // Just print to stdio.
117         .start();
118     int res;
119     try {
120       res = d8Proc.waitFor();
121     } catch (Exception e) {
122       System.out.println("Failed to dex: ".concat(e.toString()));
123       e.printStackTrace();
124       res = -123;
125     }
126     tempInput.delete();
127     try {
128       if (res == 0) {
129         byte[] out = new FileInputStream(tempDir.resolve("classes.dex").toFile()).readAllBytes();
130         tempDir.resolve("classes.dex").toFile().delete();
131         return out;
132       }
133     } finally {
134       tempDir.toFile().delete();
135     }
136     throw new Error("Failed to get dex file! " + res);
137   }
138 
generateInvoker( byte[] inputClass, String toCall, OutputStream output)139   private static void generateInvoker(
140       byte[] inputClass,
141       String toCall,
142       OutputStream output) throws IOException {
143     ClassReader cr = new ClassReader(inputClass);
144     ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
145     cr.accept(
146         new ClassVisitor(Opcodes.ASM9, cw) {
147           @Override
148           public void visitEnd() {
149             generateRunTest(cw, toCall);
150             super.visitEnd();
151           }
152         }, 0);
153     output.write(cw.toByteArray());
154   }
155 
156   // Creates the following method:
157   //   public runTest(Runnable preCall) {
158   //     preCall.run();
159   //     MethodHandle mh = <CONSTANT MH>;
160   //     mh.invokeExact();
161   //   }
generateRunTest(ClassVisitor cv, String toCall)162   private static void generateRunTest(ClassVisitor cv, String toCall) {
163     MethodVisitor mv = cv.visitMethod(Opcodes.ACC_PUBLIC,
164                                       "runTest", "(Ljava/lang/Runnable;)V", null, null);
165     MethodType mt = MethodType.methodType(Void.TYPE);
166     Handle mh = new Handle(
167         Opcodes.H_INVOKESTATIC,
168         Type.getInternalName(Responses.class),
169         toCall,
170         mt.toMethodDescriptorString(),
171         false);
172     String internalName = Type.getInternalName(Runnable.class);
173     mv.visitVarInsn(Opcodes.ALOAD, 1);
174     mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, internalName, "run", "()V", true);
175     mv.visitLdcInsn(mh);
176     mv.visitMethodInsn(
177         Opcodes.INVOKEVIRTUAL,
178         Type.getInternalName(MethodHandle.class),
179         "invokeExact",
180         "()V",
181         false);
182     mv.visitInsn(Opcodes.RETURN);
183     mv.visitMaxs(-1, -1);
184   }
185 }
186