1 /*
 * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.tools.flicker.junit
18 
19 import android.app.Instrumentation
20 import android.device.collectors.util.SendToInstrumentation
21 import android.os.Bundle
22 import android.tools.Scenario
23 import android.tools.ScenarioBuilder
24 import android.tools.flicker.FlickerService
25 import android.tools.flicker.FlickerServiceResultsCollector.Companion.FLICKER_ASSERTIONS_COUNT_KEY
26 import android.tools.flicker.ScenarioInstance
27 import android.tools.flicker.Utils.captureTrace
28 import android.tools.flicker.annotation.ExpectedScenarios
29 import android.tools.flicker.annotation.FlickerConfigProvider
30 import android.tools.flicker.assertions.ScenarioAssertion
31 import android.tools.flicker.config.FlickerConfig
32 import android.tools.flicker.config.ScenarioId
33 import android.tools.io.Reader
34 import android.tools.traces.getDefaultFlickerOutputDir
35 import android.tools.traces.now
36 import androidx.test.platform.app.InstrumentationRegistry
37 import com.google.common.truth.Truth
38 import java.lang.reflect.Method
39 import org.junit.After
40 import org.junit.Before
41 import org.junit.Test
42 import org.junit.runner.Description
43 import org.junit.runners.model.FrameworkMethod
44 import org.junit.runners.model.Statement
45 import org.junit.runners.model.TestClass
46 
/**
 * Runner decorator that executes each test method of the inner decorator while capturing a
 * trace, and then injects additional test cases built from the Flicker service assertions that
 * were generated for the scenarios detected in that trace.
 *
 * The outcome of running each inner method (setup, transition, teardown) is cached, so that when
 * the runner later asks to invoke that method the cached outcome is replayed instead of running
 * the method a second time.
 *
 * @param testClass the JUnit model of the class under test.
 * @param paramString optional suffix appended to the test class name when building the scenario
 *   and to the names of the injected test cases.
 * @param skipNonBlocking forwarded as-is to every injected [FlickerServiceCachedTestCase].
 * @param inner the decorator being wrapped; [getTestMethods] fails if it is `null`.
 * @param instrumentation instrumentation used to report progress and metrics.
 * @param flickerService service used to detect scenarios; when `null` one is lazily created from
 *   the test class' [FlickerConfigProvider] method.
 */
class FlickerServiceDecorator(
    testClass: TestClass,
    val paramString: String?,
    private val skipNonBlocking: Boolean,
    inner: IFlickerJUnitDecorator?,
    instrumentation: Instrumentation = InstrumentationRegistry.getInstrumentation(),
    flickerService: FlickerService? = null
) : AbstractFlickerRunnerDecorator(testClass, inner, instrumentation) {
    // Lazy so that the config provider is only invoked if no service was supplied and the
    // service is actually needed.
    private val flickerService by lazy { flickerService ?: FlickerService(getFlickerConfig()) }

    /** Scenario identifying this test class (plus [paramString], if any). */
    private val testClassName =
        ScenarioBuilder().forClass("${testClass.name}${paramString ?: ""}").build()

    /** Flicker service test cases injected for each inner method that ran successfully. */
    private val flickerServiceMethodsFor =
        mutableMapOf<FrameworkMethod, Collection<InjectedTestCase>>()

    /**
     * Cached outcome of each inner method that was already executed by [getTestMethods]: `null`
     * means the method (and the computation of its Flicker service tests) succeeded, otherwise
     * the value is the failure to replay from [getMethodInvoker].
     */
    private val innerMethodsResults = mutableMapOf<FrameworkMethod, Throwable?>()

    /**
     * Builds the description for methods injected by this decorator and delegates every other
     * method to the inner decorator.
     */
    override fun getChildDescription(method: FrameworkMethod): Description {
        return if (isMethodHandledByDecorator(method)) {
            Description.createTestDescription(testClass.javaClass, method.name, *method.annotations)
        } else {
            inner?.getChildDescription(method) ?: error("No child descriptor found")
        }
    }

    /**
     * Returns the inner decorator's test methods followed by the Flicker service test cases
     * injected for each of them.
     *
     * Unless this is a dry run (see [shouldComputeTestMethods]), every inner method that has not
     * been executed yet is run here, while a trace is captured, in order to compute the injected
     * test cases. Injected test cases are only added for inner methods whose cached result is a
     * success.
     */
    override fun getTestMethods(test: Any): List<FrameworkMethod> {
        val innerMethods =
            inner?.getTestMethods(test)
                ?: error("FlickerServiceDecorator requires a non-null inner decorator")
        val testMethods = innerMethods.toMutableList()

        if (!shouldComputeTestMethods()) {
            return testMethods
        }

        for (method in innerMethods) {
            if (!innerMethodsResults.containsKey(method)) {
                runMethodAndComputeFlickerServiceTests(test, method)
            }

            if (innerMethodsResults[method] == null) {
                val injectedTestCases =
                    checkNotNull(flickerServiceMethodsFor[method]) {
                        "No Flicker service tests were computed for ${method.name} even though " +
                            "it was recorded as successful"
                    }
                testMethods.addAll(injectedTestCases)
            }
        }

        return testMethods
    }

    /**
     * Runs the `@Before` methods, [method] and the `@After` methods on [test] while capturing a
     * trace, records the outcome in [innerMethodsResults] and, on success, stores the Flicker
     * service test cases computed from the captured trace in [flickerServiceMethodsFor].
     */
    private fun runMethodAndComputeFlickerServiceTests(test: Any, method: FrameworkMethod) {
        var methodResult: Throwable? = null // TODO: Maybe don't use null but wrap in another object
        val reader =
            captureTrace(testClassName, getDefaultFlickerOutputDir()) { writer ->
                try {
                    Utils.notifyRunnerProgress(testClassName, "Running setup", instrumentation)
                    val befores = testClass.getAnnotatedMethods(Before::class.java)
                    befores.forEach { it.invokeExplosively(test) }

                    Utils.notifyRunnerProgress(
                        testClassName,
                        "Running transition",
                        instrumentation
                    )
                    writer.setTransitionStartTime(now())
                    method.invokeExplosively(test)
                    writer.setTransitionEndTime(now())

                    // NOTE(review): teardown is skipped when setup or the transition throws;
                    // confirm whether @After methods should still run in that case.
                    Utils.notifyRunnerProgress(testClassName, "Running teardown", instrumentation)
                    val afters = testClass.getAnnotatedMethods(After::class.java)
                    afters.forEach { it.invokeExplosively(test) }
                } catch (e: Throwable) {
                    methodResult = e
                } finally {
                    innerMethodsResults[method] = methodResult
                }
            }

        if (methodResult != null) {
            return
        }

        Utils.notifyRunnerProgress(
            testClassName,
            "Computing Flicker service tests",
            instrumentation
        )
        try {
            flickerServiceMethodsFor[method] =
                computeFlickerServiceTests(reader, testClassName, method)
        } catch (e: Throwable) {
            // Failed to compute flicker service methods; report it as the method's failure.
            innerMethodsResults[method] = e
        }
    }

    // TODO: Common with LegacyFlickerServiceDecorator, might be worth extracting this up
    /**
     * Returns `false` when the current call stack shows that test methods are being requested for
     * validation or by a non-executing (dry-run) runner, in which case the transition must not
     * be executed.
     */
    private fun shouldComputeTestMethods(): Boolean {
        // Don't compute when called from validateInstanceMethods since this will fail
        // as the parameters will not be set. And AndroidLogOnlyBuilder is a non-executing runner
        // used to run tests in dry-run mode, so we don't want to execute the flicker transition
        // in that case either.
        val isDryRun =
            Thread.currentThread().stackTrace.any { frame ->
                frame.methodName == VALIDATE_INSTANCE_METHODS_NAME ||
                    frame.className in NON_EXECUTING_RUNNER_CLASS_NAMES
            }

        return !isDryRun
    }

    /**
     * Returns a statement that executes injected test cases directly, replays the cached result
     * of inner methods already executed by [getTestMethods], and otherwise delegates to the inner
     * decorator.
     */
    override fun getMethodInvoker(method: FrameworkMethod, test: Any): Statement {
        return object : Statement() {
            @Throws(Throwable::class)
            override fun evaluate() {
                val description = getChildDescription(method)
                when {
                    isMethodHandledByDecorator(method) ->
                        (method as InjectedTestCase).execute(description)
                    // Already executed while computing the test methods: rethrow the cached
                    // failure, if any, instead of running the method again.
                    innerMethodsResults.containsKey(method) ->
                        innerMethodsResults[method]?.let { throw it }
                    else -> inner?.getMethodInvoker(method, test)?.evaluate()
                }
            }
        }
    }

    /**
     * Adds the following checks to the errors reported by the superclass: at most one `@Test`
     * method, exactly one public static [FlickerConfigProvider] method returning a
     * [FlickerConfig], and every scenario named in [ExpectedScenarios] being registered in that
     * config.
     */
    override fun doValidateInstanceMethods(): List<Throwable> {
        val errors = super.doValidateInstanceMethods().toMutableList()

        val testMethods = testClass.getAnnotatedMethods(Test::class.java)
        if (testMethods.size > 1) {
            errors.add(IllegalArgumentException("Only one @Test annotated method is supported"))
        }

        // Validate Registry provider
        val configProviderFunctions =
            testClass.getAnnotatedMethods(FlickerConfigProvider::class.java).filter {
                it.isStatic && it.isPublic
            }
        when {
            configProviderFunctions.isEmpty() ->
                errors.add(
                    IllegalArgumentException(
                        "A public static function returning a " +
                            "${FlickerConfig::class.simpleName} annotated with " +
                            "@${FlickerConfigProvider::class.simpleName} should be provided."
                    )
                )
            configProviderFunctions.size > 1 ->
                errors.add(
                    IllegalArgumentException(
                        "Only one @${FlickerConfigProvider::class.simpleName} " +
                            "annotated method is supported."
                    )
                )
            configProviderFunctions.first().returnType.name !=
                FlickerConfig::class.qualifiedName ->
                errors.add(
                    IllegalArgumentException(
                        "Expected method annotated with " +
                            "@${FlickerConfigProvider::class.simpleName} to return " +
                            "${FlickerConfig::class.qualifiedName} but was " +
                            "${configProviderFunctions.first().returnType.name} instead."
                    )
                )
            else -> errors.addAll(validateExpectedScenarios())
        }

        return errors
    }

    /**
     * Returns one error for every scenario listed in an [ExpectedScenarios] annotation that is
     * not registered in the test's [FlickerConfig].
     */
    private fun validateExpectedScenarios(): List<Throwable> {
        val expectedScenarioAnnotations =
            testClass.getAnnotatedMethods(ExpectedScenarios::class.java).map {
                it.getAnnotation(ExpectedScenarios::class.java)
            }
        // NOTE(review): getFlickerConfig() throws (rather than adding to the returned errors)
        // when the class does not have exactly one @ExpectedScenarios method.
        val registeredScenarios = getFlickerConfig().getEntries().map { it.scenarioId.name }

        val errors = mutableListOf<Throwable>()
        for (expectedScenarioAnnotation in expectedScenarioAnnotations) {
            for (expectedScenario in expectedScenarioAnnotation.expectedScenarios) {
                if (!registeredScenarios.contains(expectedScenario)) {
                    errors.add(
                        IllegalArgumentException(
                            "Provided scenarios that are not registered to " +
                                "@${ExpectedScenarios::class.simpleName} annotation. " +
                                "$expectedScenario is not registered in the " +
                                "${FlickerConfig::class.simpleName}. Available scenarios " +
                                "are [${registeredScenarios.joinToString()}]."
                        )
                    )
                }
            }
        }
        return errors
    }

    /**
     * Invokes the test class' [FlickerConfigProvider] method and returns its result.
     *
     * @throws IllegalArgumentException if the test class does not have exactly one method
     *   annotated with [ExpectedScenarios].
     */
    private fun getFlickerConfig(): FlickerConfig {
        require(testClass.getAnnotatedMethods(ExpectedScenarios::class.java).size == 1) {
            "@ExpectedScenarios missing. " +
                "getFlickerConfig() may have been called before validation."
        }

        val configProviderFunction =
            testClass.getAnnotatedMethods(FlickerConfigProvider::class.java).first()
        // TODO: Pass the correct target
        return configProviderFunction.invokeExplosively(testClass) as FlickerConfig
    }

    /** Always `false`: `@Before` methods are run by [getTestMethods] around the transition. */
    override fun shouldRunBeforeOn(method: FrameworkMethod): Boolean {
        return false
    }

    /** Always `false`: `@After` methods are run by [getTestMethods] around the transition. */
    override fun shouldRunAfterOn(method: FrameworkMethod): Boolean {
        return false
    }

    /** Whether [method] is a test case that was injected by this decorator instance. */
    private fun isMethodHandledByDecorator(method: FrameworkMethod): Boolean {
        return method is InjectedTestCase && method.injectedBy == this
    }

    /**
     * Builds the injected test cases for [method] from the trace exposed by [reader], using the
     * scenarios listed in the method's [ExpectedScenarios] annotation (none if it is absent).
     */
    private fun computeFlickerServiceTests(
        reader: Reader,
        testScenario: Scenario,
        method: FrameworkMethod
    ): Collection<InjectedTestCase> {
        val expectedScenarioNames =
            method.annotations.filterIsInstance<ExpectedScenarios>().firstOrNull()?.expectedScenarios
                ?: emptyArray()
        val expectedScenarios = expectedScenarioNames.map { ScenarioId(it) }.toSet()

        return getFaasTestCases(
            testScenario,
            expectedScenarios,
            paramString ?: "",
            reader,
            flickerService,
            instrumentation,
            this,
            skipNonBlocking
        )
    }

    companion object {
        /** Name of the validation method during which test methods must not be computed. */
        private const val VALIDATE_INSTANCE_METHODS_NAME = "validateInstanceMethods"

        /** Classes whose presence on the call stack indicates a dry run. */
        private val NON_EXECUTING_RUNNER_CLASS_NAMES =
            setOf(
                "androidx.test.internal.runner.AndroidLogOnlyBuilder",
                "androidx.test.internal.runner.NonExecutingRunner",
            )

        /** Returns the distinct scenario types detected for [testScenario]. */
        private fun getDetectedScenarios(
            testScenario: Scenario,
            reader: Reader,
            flickerService: FlickerService
        ): Collection<ScenarioId> {
            val groupedAssertions = getGroupedAssertions(testScenario, reader, flickerService)
            return groupedAssertions.keys.map { it.type }.distinct()
        }

        /** Reflective handle on `InjectedTestCase.execute(Description)`. */
        private fun getCachedResultMethod(): Method {
            return InjectedTestCase::class.java.getMethod("execute", Description::class.java)
        }

        /**
         * Returns the assertions of every scenario instance detected for [testScenario].
         *
         * The scenarios are only detected (and their assertions generated) if the data store has
         * no Flicker service result for [testScenario] yet; the result is always read back from
         * the data store.
         */
        private fun getGroupedAssertions(
            testScenario: Scenario,
            reader: Reader,
            flickerService: FlickerService,
        ): Map<ScenarioInstance, Collection<ScenarioAssertion>> {
            if (
                !android.tools.flicker.datastore.DataStore.containsFlickerServiceResult(
                    testScenario
                )
            ) {
                val detectedScenarios = flickerService.detectScenarios(reader)
                val groupedAssertions = detectedScenarios.associateWith { it.generateAssertions() }
                android.tools.flicker.datastore.DataStore.addFlickerServiceAssertions(
                    testScenario,
                    groupedAssertions
                )
            }

            return android.tools.flicker.datastore.DataStore.getFlickerServiceAssertions(
                testScenario
            )
        }

        /**
         * Creates one [FlickerServiceCachedTestCase] per assertion of every scenario instance
         * detected for [testScenario], followed by a single test case that reports the number of
         * assertions and checks that all [expectedScenarios] were detected.
         *
         * When several instances of the same scenario type are detected, the 1-based index of the
         * instance is appended to [paramString] for its test cases. Only the very last assertion
         * test case is created with `isLast = true`.
         *
         * @param caller decorator recorded as the injector of the created test cases.
         */
        internal fun getFaasTestCases(
            testScenario: Scenario,
            expectedScenarios: Set<ScenarioId>,
            paramString: String,
            reader: Reader,
            flickerService: FlickerService,
            instrumentation: Instrumentation,
            caller: IFlickerJUnitDecorator,
            skipNonBlocking: Boolean,
        ): Collection<InjectedTestCase> {
            val groupedAssertions = getGroupedAssertions(testScenario, reader, flickerService)
            val organizedScenarioInstances = groupedAssertions.keys.groupBy { it.type }

            // Resolve the reflective method once rather than once per assertion.
            val cachedResultMethod = getCachedResultMethod()
            // Every key of groupedAssertions is visited below, so this is the number of test
            // cases that will be created.
            val totalAssertionCount = groupedAssertions.values.sumOf { it.size }

            val faasTestCases = mutableListOf<FlickerServiceCachedTestCase>()
            for (scenarioInstancesOfSameType in organizedScenarioInstances.values) {
                scenarioInstancesOfSameType.forEachIndexed { scenarioInstanceIndex, scenarioInstance
                    ->
                    val assertionsForScenarioInstance =
                        checkNotNull(groupedAssertions[scenarioInstance]) {
                            "No assertions found for a detected scenario instance"
                        }
                    val instanceSuffix =
                        if (scenarioInstancesOfSameType.size > 1) {
                            "_${scenarioInstanceIndex + 1}"
                        } else {
                            ""
                        }

                    assertionsForScenarioInstance.forEach { assertion ->
                        faasTestCases.add(
                            FlickerServiceCachedTestCase(
                                assertion = assertion,
                                method = cachedResultMethod,
                                skipNonBlocking = skipNonBlocking,
                                // Indices are 0-based, so the last test case is the one created
                                // when all but one have already been added.
                                isLast = faasTestCases.size == totalAssertionCount - 1,
                                injectedBy = caller,
                                paramString = "$paramString$instanceSuffix",
                                instrumentation = instrumentation,
                            )
                        )
                    }
                }
            }

            val detectedScenarioTestCase =
                AnonymousInjectedTestCase(
                    cachedResultMethod,
                    "FaaS_DetectedExpectedScenarios$paramString",
                    injectedBy = caller
                ) {
                    val metricBundle = Bundle()
                    metricBundle.putString(FLICKER_ASSERTIONS_COUNT_KEY, "${faasTestCases.size}")
                    SendToInstrumentation.sendBundle(instrumentation, metricBundle)

                    Truth.assertThat(getDetectedScenarios(testScenario, reader, flickerService))
                        .containsAtLeastElementsIn(expectedScenarios)
                }

            return faasTestCases + listOf(detectedScenarioTestCase)
        }
    }
}
396