1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.hoststubgen.visitors
17 
18 import com.android.hoststubgen.asm.CLASS_INITIALIZER_DESC
19 import com.android.hoststubgen.asm.CLASS_INITIALIZER_NAME
20 import com.android.hoststubgen.asm.ClassNodes
21 import com.android.hoststubgen.asm.isVisibilityPrivateOrPackagePrivate
22 import com.android.hoststubgen.asm.prependArgTypeToMethodDescriptor
23 import com.android.hoststubgen.asm.writeByteCodeToPushArguments
24 import com.android.hoststubgen.asm.writeByteCodeToReturn
25 import com.android.hoststubgen.filters.FilterPolicy
26 import com.android.hoststubgen.filters.FilterPolicyWithReason
27 import com.android.hoststubgen.filters.OutputFilter
28 import com.android.hoststubgen.hosthelper.HostStubGenProcessedAsIgnore
29 import com.android.hoststubgen.hosthelper.HostStubGenProcessedAsSubstitute
30 import com.android.hoststubgen.hosthelper.HostStubGenProcessedAsThrow
31 import com.android.hoststubgen.hosthelper.HostTestUtils
32 import com.android.hoststubgen.log
33 import org.objectweb.asm.ClassVisitor
34 import org.objectweb.asm.MethodVisitor
35 import org.objectweb.asm.Opcodes
36 import org.objectweb.asm.Type
37 
38 /**
39  * An adapter that generates the "impl" class file from an input class file.
40  */
/**
 * An adapter that generates the "impl" class file from an input class file.
 *
 * Only members whose policy has [FilterPolicy.needsInImpl] set are emitted (see [shouldEmit]).
 * In addition, this adapter:
 * - replaces the body of "throw" methods with a call to `HostTestUtils.onThrowMethodCalled()`
 *   followed by a `throw`,
 * - replaces the body of "ignore" methods (which must return `void`) with a bare `return`,
 * - turns native methods into regular methods that forward to a static method of the native
 *   substitution class, when one is set,
 * - injects calls to the configured class load hooks into the class initializer, and calls to
 *   method call hooks and the non-stub method call checker at the start of methods.
 */
class ImplGeneratingAdapter(
        classes: ClassNodes,
        nextVisitor: ClassVisitor,
        filter: OutputFilter,
        options: Options,
) : BaseAdapter(classes, nextVisitor, filter, options) {

    override fun shouldEmit(policy: FilterPolicy): Boolean {
        return policy.needsInImpl
    }

    /**
     * Class load hooks of the class currently being visited, set in [visit].
     * An empty list means the class initializer doesn't need to call any hook.
     */
    private var classLoadHooks: List<String> = emptyList()

    /**
     * ASM [Type] of the class currently being visited. Pushing it with `LDC` loads the
     * corresponding `java.lang.Class` object onto the operand stack.
     */
    private fun currentClassType(): Type = Type.getType("L$currentClassName;")

    override fun visit(
        version: Int,
        access: Int,
        name: String,
        signature: String?,
        superName: String?,
        interfaces: Array<String>
    ) {
        super.visit(version, access, name, signature, superName, interfaces)

        classLoadHooks = filter.getClassLoadHooks(currentClassName)

        // If there are class load hooks, the class initializer needs to call them.
        // If the target class already has a class initializer, the calls are injected into it
        // when it's visited (see visitMethodInner()). Otherwise, we need to create one here.
        if (classLoadHooks.isNotEmpty()) {
            log.d("  ClassLoadHooks: $classLoadHooks")
            if (!classes.hasClassInitializer(currentClassName)) {
                injectClassLoadHook()
            }
        }
    }

    /**
     * Emits a new class initializer whose only job is to call the class load hooks.
     *
     * Only called (from [visit]) when the input class has no class initializer of its own.
     */
    private fun injectClassLoadHook() {
        writeRawMembers {
            // Create a class initializer to call onClassLoaded().
            // Each class can only have at most one class initializer, but the base class
            // StaticInitMerger will merge it with the existing one, if any.
            val clinit = checkNotNull(
                visitMethod(
                    Opcodes.ACC_PRIVATE or Opcodes.ACC_STATIC,
                    CLASS_INITIALIZER_NAME,
                    "()V",
                    null,
                    null
                )
            ) { "Unable to create a class initializer for $currentClassName" }

            // Method prologue
            clinit.visitCode()

            writeClassLoadHookCalls(clinit)
            clinit.visitInsn(Opcodes.RETURN)

            // Method epilogue
            clinit.visitMaxs(0, 0)
            clinit.visitEnd()
        }
    }

    /**
     * Writes to [mv], for each class load hook, a call to
     * `HostTestUtils.onClassLoaded(Class, String)` with the current class and the hook name.
     */
    private fun writeClassLoadHookCalls(mv: MethodVisitor) {
        classLoadHooks.forEach { classLoadHook ->
            // First argument: the class type.
            mv.visitLdcInsn(currentClassType())

            // Second argument: the hook name.
            mv.visitLdcInsn(classLoadHook)

            // Call HostTestUtils.onClassLoaded().
            mv.visitMethodInsn(
                Opcodes.INVOKESTATIC,
                HostTestUtils.CLASS_INTERNAL_NAME,
                "onClassLoaded",
                "(Ljava/lang/Class;Ljava/lang/String;)V",
                false
            )
        }
    }

    /**
     * Clears `ACC_NATIVE` on native methods when a native substitution class is set, because
     * [NativeSubstitutingMethodAdapter] gives them a real body. Other flags are returned as-is.
     */
    override fun updateAccessFlags(
            access: Int,
            name: String,
            descriptor: String,
    ): Int {
        if ((access and Opcodes.ACC_NATIVE) != 0 && nativeSubstitutionClass != null) {
            return access and Opcodes.ACC_NATIVE.inv()
        }
        return access
    }

    /**
     * Builds the chain of method visitors used to rewrite one method.
     *
     * Hook/check injectors are stacked first (each wrapping the previous one). Then, depending
     * on [policy] and the native flag, the body is replaced with throwing, native-forwarding or
     * no-op code, and the method is tagged with the matching `HostStubGenProcessedAs*`
     * annotation.
     *
     * @throws RuntimeException if the policy is [FilterPolicy.Ignore] but the method does not
     *   return `void`.
     */
    override fun visitMethodInner(
            access: Int,
            name: String,
            descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            policy: FilterPolicyWithReason,
            substituted: Boolean,
            superVisitor: MethodVisitor?,
    ): MethodVisitor? {
        var innerVisitor = superVisitor

        // If method call hooks are configured for this method, inject calls to them.
        val methodCallHooks = filter.getMethodCallHooks(currentClassName, name, descriptor)
        if (methodCallHooks.isNotEmpty()) {
            innerVisitor = MethodCallHookInjectingAdapter(
                access,
                name,
                descriptor,
                signature,
                exceptions,
                innerVisitor,
                methodCallHooks,
            )
        }

        // If this class already has a class initializer and a class load hook is needed, then
        // we inject code into it. (Otherwise visit() has already created one.)
        if (classLoadHooks.isNotEmpty() &&
            name == CLASS_INITIALIZER_NAME &&
            descriptor == CLASS_INITIALIZER_DESC) {
            innerVisitor = ClassLoadHookInjectingMethodAdapter(
                access,
                name,
                descriptor,
                signature,
                exceptions,
                innerVisitor,
            )
        }

        // If non-stub method call detection is enabled, then inject a call to the checker.
        if (options.enableNonStubMethodCallDetection && doesMethodNeedNonStubCallCheck(
                access, name, descriptor, policy)) {
            innerVisitor = NonStubMethodCallDetectingAdapter(
                    access,
                    name,
                    descriptor,
                    signature,
                    exceptions,
                    innerVisitor,
            )
        }

        // Adds a visible annotation to the method and returns the same visitor.
        fun MethodVisitor.withAnnotation(annotationDescriptor: String): MethodVisitor {
            this.visitAnnotation(annotationDescriptor, true)
            return this
        }

        log.withIndent {
            var willThrow = false
            if (policy.policy == FilterPolicy.Throw) {
                log.v("Making method throw...")
                willThrow = true
                innerVisitor = ThrowingMethodAdapter(
                    access, name, descriptor, signature, exceptions, innerVisitor)
                    .withAnnotation(HostStubGenProcessedAsThrow.CLASS_DESCRIPTOR)
            }
            // Native substitution takes precedence and wraps a throwing adapter, if any.
            if ((access and Opcodes.ACC_NATIVE) != 0 && nativeSubstitutionClass != null) {
                log.v("Rewriting native method...")
                return NativeSubstitutingMethodAdapter(
                        access, name, descriptor, signature, exceptions, innerVisitor)
                    .withAnnotation(HostStubGenProcessedAsSubstitute.CLASS_DESCRIPTOR)
            }
            if (willThrow) {
                return innerVisitor
            }

            if (policy.policy == FilterPolicy.Ignore) {
                when (Type.getReturnType(descriptor)) {
                    Type.VOID_TYPE -> {
                        log.v("Making method ignored...")
                        return IgnoreMethodAdapter(
                                access, name, descriptor, signature, exceptions, innerVisitor)
                            .withAnnotation(HostStubGenProcessedAsIgnore.CLASS_DESCRIPTOR)
                    }
                    else -> {
                        throw RuntimeException("Ignored policy only allowed for void methods")
                    }
                }
            }
        }
        if (substituted) {
            innerVisitor?.withAnnotation(HostStubGenProcessedAsSubstitute.CLASS_DESCRIPTOR)
        }

        return innerVisitor
    }

    /**
     * Returns whether a call to the non-stub method call checker should be injected into the
     * given method. Stub methods and private/package-private methods are not checked.
     */
    fun doesMethodNeedNonStubCallCheck(
            access: Int,
            name: String,
            descriptor: String,
            policy: FilterPolicyWithReason,
    ): Boolean {
        // If a method is in the stub, then no need to check.
        if (policy.policy.needsInStub) {
            return false
        }
        // If a method is private or package-private, no need to check.
        // Technically test code can use framework package name, so it's a bit too lenient.
        if (isVisibilityPrivateOrPackagePrivate(access)) {
            return false
        }
        // TODO: If the method overrides a method that's accessible by tests, then we shouldn't
        // do the check. (e.g. overrides a stub method or java standard method.)

        return true
    }

    /**
     * A method adapter that replaces the method body with a HostTestUtils.onThrowMethodCalled()
     * call.
     */
    private inner class ThrowingMethodAdapter(
            access: Int,
            val name: String,
            descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            next: MethodVisitor?
    ) : BodyReplacingMethodVisitor(access, name, descriptor, signature, exceptions, next) {
        override fun emitNewCode() {
            visitMethodInsn(Opcodes.INVOKESTATIC,
                    HostTestUtils.CLASS_INTERNAL_NAME,
                    "onThrowMethodCalled",
                    "()V",
                    false)

            // We still need a terminating instruction that works for any return type.
            // Rather than emitting a type-specific RETURN, just inject a `throw`.
            visitTypeInsn(Opcodes.NEW, "java/lang/RuntimeException")
            visitInsn(Opcodes.DUP)
            visitLdcInsn("Unreachable")
            visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/RuntimeException",
                    "<init>", "(Ljava/lang/String;)V", false)
            visitInsn(Opcodes.ATHROW)

            visitMaxs(0, 0) // We let ASM figure them out.
        }
    }

    /**
     * A method adapter that replaces the method body with a no-op return.
     * Only used for methods returning `void` (see [visitMethodInner]).
     */
    private inner class IgnoreMethodAdapter(
            access: Int,
            val name: String,
            descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            next: MethodVisitor?
    ) : BodyReplacingMethodVisitor(access, name, descriptor, signature, exceptions, next) {
        override fun emitNewCode() {
            visitInsn(Opcodes.RETURN)
            visitMaxs(0, 0) // We let ASM figure them out.
        }
    }

    /**
     * A method adapter that replaces a native method call with a call to the "native substitution"
     * class.
     *
     * A native method has no code, so visitCode() is never expected. The new body is generated
     * in visitEnd(). It pushes `this` (for instance methods) and all arguments, invokes the
     * static method with the same name on the substitution class, and returns its result.
     */
    private inner class NativeSubstitutingMethodAdapter(
            val access: Int,
            private val name: String,
            private val descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            next: MethodVisitor?
    ) : MethodVisitor(OPCODE_VERSION, next) {
        override fun visitCode() {
            throw RuntimeException("NativeSubstitutingMethodAdapter should be called on " +
                    "native method, where visitCode() shouldn't be called.")
        }

        override fun visitEnd() {
            super.visitCode()

            var targetDescriptor = descriptor
            var argOffset = 0

            // For non-static native method, we need to tweak it a bit.
            if ((access and Opcodes.ACC_STATIC) == 0) {
                // Push `this` as the first argument.
                this.visitVarInsn(Opcodes.ALOAD, 0)

                // Update the descriptor -- add this class's type as the first argument
                // to the method descriptor.
                targetDescriptor = prependArgTypeToMethodDescriptor(
                        descriptor,
                        currentClassType(),
                )

                // Shift the original arguments by one.
                argOffset = 1
            }

            writeByteCodeToPushArguments(descriptor, this, argOffset)

            visitMethodInsn(Opcodes.INVOKESTATIC,
                    nativeSubstitutionClass,
                    name,
                    targetDescriptor,
                    false)

            writeByteCodeToReturn(descriptor, this)

            // NOTE(review): the comment says ASM computes the maxs, in which case these values
            // are ignored; confirm and consider using (0, 0) like the other adapters.
            visitMaxs(99, 0) // We let ASM figure them out.
            super.visitEnd()
        }
    }

    /**
     * Inject calls to the method call hooks.
     *
     * For each hook, a call to
     * `HostTestUtils.callMethodCallHook(Class, String, String, String)` is emitted with the
     * current class, the method name, the method descriptor and the hook name.
     *
     * Note, when the target method is a constructor, it may contain calls to `super(...)` or
     * `this(...)`. The hook calls will be injected *before* such calls.
     */
    private inner class MethodCallHookInjectingAdapter(
            access: Int,
            val name: String,
            val descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            next: MethodVisitor?,
            val hooks: List<String>,
    ) : MethodVisitor(OPCODE_VERSION, next) {
        override fun visitCode() {
            super.visitCode()

            hooks.forEach { hook ->
                // Emit through `this` (not the raw `mv` delegate) like every other instruction
                // here, so a null delegate is handled by MethodVisitor's own null check.
                visitLdcInsn(currentClassType())
                visitLdcInsn(name)
                visitLdcInsn(descriptor)
                visitLdcInsn(hook)

                visitMethodInsn(
                    Opcodes.INVOKESTATIC,
                    HostTestUtils.CLASS_INTERNAL_NAME,
                    "callMethodCallHook",
                    "(Ljava/lang/Class;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V",
                    false
                )
            }
        }
    }

    /**
     * Inject the class load hook calls at the start of an existing class initializer.
     */
    private inner class ClassLoadHookInjectingMethodAdapter(
        access: Int,
        val name: String,
        val descriptor: String,
        signature: String?,
        exceptions: Array<String>?,
        next: MethodVisitor?
    ) : MethodVisitor(OPCODE_VERSION, next) {
        override fun visitCode() {
            super.visitCode()

            writeClassLoadHookCalls(this)
        }
    }

    /**
     * A method adapter that detects calls to non-stub methods.
     *
     * At the start of the method it calls
     * `HostTestUtils.onNonStubMethodCalled(className, methodName, descriptor, callerClass)`.
     * The caller class is obtained via `HostTestUtils.getStackWalker().getCallerClass()`.
     */
    private inner class NonStubMethodCallDetectingAdapter(
            access: Int,
            val name: String,
            val descriptor: String,
            signature: String?,
            exceptions: Array<String>?,
            next: MethodVisitor?
    ) : MethodVisitor(OPCODE_VERSION, next) {
        override fun visitCode() {
            super.visitCode()

            // First three arguments to HostTestUtils.onNonStubMethodCalled().
            visitLdcInsn(currentClassName)
            visitLdcInsn(name)
            visitLdcInsn(descriptor)

            // Call: HostTestUtils.getStackWalker().getCallerClass().
            // This pushes the caller Class onto the stack.
            visitMethodInsn(Opcodes.INVOKESTATIC,
                    HostTestUtils.CLASS_INTERNAL_NAME,
                    "getStackWalker",
                    "()Ljava/lang/StackWalker;",
                    false)
            visitMethodInsn(Opcodes.INVOKEVIRTUAL,
                    "java/lang/StackWalker",
                    "getCallerClass",
                    "()Ljava/lang/Class;",
                    false)

            // Then call onNonStubMethodCalled().
            visitMethodInsn(Opcodes.INVOKESTATIC,
                    HostTestUtils.CLASS_INTERNAL_NAME,
                    "onNonStubMethodCalled",
                    "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/Class;)V",
                    false)
        }
    }
}
455