1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.tools.flicker.junit
18 
19 import java.util.Collections
20 import java.util.IdentityHashMap
21 import org.junit.Rule
22 import org.junit.rules.MethodRule
23 import org.junit.rules.TestRule
24 import org.junit.runner.Description
25 import org.junit.runners.model.FrameworkMethod
26 import org.junit.runners.model.Statement
27 
28 /**
29  * Data structure for ordering of [TestRule]/[MethodRule] instances.
30  *
31  * @since 4.13
32  */
33 internal class RuleContainer {
34     private val orderValues = IdentityHashMap<Any, Int>()
35     private val testRules = mutableListOf<TestRule>()
36     private val methodRules = mutableListOf<MethodRule>()
37 
38     /** Sets order value for the specified rule. */
setOrdernull39     fun setOrder(rule: Any, order: Int) {
40         orderValues[rule] = order
41     }
42 
addnull43     fun add(methodRule: MethodRule) {
44         methodRules.add(methodRule)
45     }
46 
addnull47     fun add(testRule: TestRule) {
48         testRules.add(testRule)
49     }
50 
51     /** Returns entries in the order how they should be applied, i.e. inner-to-outer. */
52     private val sortedEntries: List<RuleEntry>
53         get() {
54             val ruleEntries: MutableList<RuleEntry> = ArrayList(methodRules.size + testRules.size)
55             for (rule in methodRules) {
56                 ruleEntries.add(RuleEntry(rule, RuleEntry.TYPE_METHOD_RULE, orderValues[rule]))
57             }
58             for (rule in testRules) {
59                 ruleEntries.add(RuleEntry(rule, RuleEntry.TYPE_TEST_RULE, orderValues[rule]))
60             }
61             Collections.sort(ruleEntries, ENTRY_COMPARATOR)
62             return ruleEntries
63         }
64 
65     /** Applies all the rules ordered accordingly to the specified `statement`. */
applynull66     fun apply(
67         method: FrameworkMethod?,
68         description: Description?,
69         target: Any?,
70         statement: Statement
71     ): Statement {
72         if (methodRules.isEmpty() && testRules.isEmpty()) {
73             return statement
74         }
75         var result = statement
76         for (ruleEntry in sortedEntries) {
77             result =
78                 if (ruleEntry.type == RuleEntry.TYPE_TEST_RULE) {
79                     (ruleEntry.rule as TestRule).apply(result, description)
80                 } else {
81                     (ruleEntry.rule as MethodRule).apply(result, method, target)
82                 }
83         }
84         return result
85     }
86 
87     /**
88      * Returns rule instances in the order how they should be applied, i.e. inner-to-outer.
89      * VisibleForTesting
90      */
91     val sortedRules: List<Any>
92         get() {
93             val result = mutableListOf<Any>()
94             for (entry in sortedEntries) {
95                 result.add(entry.rule)
96             }
97             return result
98         }
99 
100     internal class RuleEntry(val rule: Any, val type: Int, order: Int?) {
101         val order: Int
102 
103         init {
104             this.order = order ?: Rule.DEFAULT_ORDER
105         }
106 
107         companion object {
108             const val TYPE_TEST_RULE = 1
109             const val TYPE_METHOD_RULE = 0
110         }
111     }
112 
113     companion object {
114         val ENTRY_COMPARATOR: Comparator<RuleEntry> =
115             object : Comparator<RuleEntry> {
comparenull116                 override fun compare(o1: RuleEntry, o2: RuleEntry): Int {
117                     val result = compareInt(o1.order, o2.order)
118                     return if (result != 0) result else o1.type - o2.type
119                 }
120 
compareIntnull121                 private fun compareInt(a: Int, b: Int): Int {
122                     return if (a < b) 1 else if (a == b) 0 else -1
123                 }
124             }
125     }
126 }
127