1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.adservices.shared.testing.junit;
17 
18 import org.junit.Rule;
19 import org.junit.rules.MethodRule;
20 import org.junit.rules.TestRule;
21 import org.junit.runner.Description;
22 import org.junit.runners.BlockJUnit4ClassRunner;
23 import org.junit.runners.model.FrameworkMember;
24 import org.junit.runners.model.FrameworkMethod;
25 import org.junit.runners.model.InitializationError;
26 import org.junit.runners.model.MemberValueConsumer;
27 import org.junit.runners.model.Statement;
28 import org.junit.runners.model.TestClass;
29 
30 import java.lang.reflect.Field;
31 import java.util.ArrayList;
32 import java.util.Arrays;
33 import java.util.Comparator;
34 import java.util.List;
35 
36 /**
37  * This class extends {@link BlockJUnit4ClassRunner} and re-implement some methods that are private
38  * as protected (by copying them "as-is", unless indicated otherwise) so they can be used by custom
39  * runners.
40  */
41 public abstract class EasilyExtensibleBlockJUnit4ClassRunner extends BlockJUnit4ClassRunner {
42 
43     private static final ThreadLocal<RuleContainer> sCurrentRuleContainer = new ThreadLocal<>();
44     private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
45 
EasilyExtensibleBlockJUnit4ClassRunner(Class<?> testClass)46     protected EasilyExtensibleBlockJUnit4ClassRunner(Class<?> testClass)
47             throws InitializationError {
48         super(testClass);
49     }
50 
EasilyExtensibleBlockJUnit4ClassRunner(TestClass testClass)51     protected EasilyExtensibleBlockJUnit4ClassRunner(TestClass testClass)
52             throws InitializationError {
53         super(testClass);
54     }
55 
56     // NOTE: this is not the same as EasilyExtensibleBlockJUnit4ClassRunner's method, as that one
57     // takes an Object as testClass. But the body itself is copied.
getTestRules(TestClass testClass, Object target)58     protected List<TestRule> getTestRules(TestClass testClass, Object target) {
59         RuleCollector<TestRule> collector = new RuleCollector<>();
60         testClass.collectAnnotatedMethodValues(target, Rule.class, TestRule.class, collector);
61         testClass.collectAnnotatedFieldValues(target, Rule.class, TestRule.class, collector);
62         return collector.mResult;
63     }
64 
rules(TestClass testClass, Object target)65     protected List<MethodRule> rules(TestClass testClass, Object target) {
66         RuleCollector<MethodRule> collector = new RuleCollector<MethodRule>();
67         testClass.collectAnnotatedMethodValues(target, Rule.class, MethodRule.class, collector);
68         testClass.collectAnnotatedFieldValues(target, Rule.class, MethodRule.class, collector);
69         return collector.mResult;
70     }
71 
withRules( FrameworkMethod method, TestClass testClass, Object target, Statement statement)72     protected Statement withRules(
73             FrameworkMethod method, TestClass testClass, Object target, Statement statement) {
74         RuleContainer ruleContainer = new RuleContainer();
75         sCurrentRuleContainer.set(ruleContainer);
76         try {
77             List<TestRule> testRules = getTestRules(testClass, target);
78             for (MethodRule each : rules(testClass, target)) {
79                 if (!(each instanceof TestRule && testRules.contains(each))) {
80                     ruleContainer.add(each);
81                 }
82             }
83             for (TestRule rule : testRules) {
84                 ruleContainer.add(rule);
85             }
86         } finally {
87             sCurrentRuleContainer.remove();
88         }
89         return ruleContainer.apply(
90                 method,
91                 Description.createTestDescription(testClass.getJavaClass(), method.getName()),
92                 target,
93                 statement);
94     }
95 
96     protected static final class RuleCollector<T> implements MemberValueConsumer<T> {
97         final List<T> mResult = new ArrayList<T>();
98 
accept(FrameworkMember<?> member, T value)99         public void accept(FrameworkMember<?> member, T value) {
100             Rule rule = member.getAnnotation(Rule.class);
101             if (rule != null) {
102                 RuleContainer container = sCurrentRuleContainer.get();
103                 if (container != null) {
104                     container.setOrder(value, rule.order());
105                 }
106             }
107             mResult.add(value);
108         }
109     }
110 
getSortedDeclaredFields(Class<?> clazz)111     protected static Field[] getSortedDeclaredFields(Class<?> clazz) {
112         Field[] declaredFields = clazz.getDeclaredFields();
113         Arrays.sort(declaredFields, FIELD_COMPARATOR);
114         return declaredFields;
115     }
116 
getSuperClasses(Class<?> testClass)117     protected static List<Class<?>> getSuperClasses(Class<?> testClass) {
118         List<Class<?>> results = new ArrayList<Class<?>>();
119         Class<?> current = testClass;
120         while (current != null) {
121             results.add(current);
122             current = current.getSuperclass();
123         }
124         return results;
125     }
126 
127     private static class FieldComparator implements Comparator<Field> {
compare(Field left, Field right)128         public int compare(Field left, Field right) {
129             return left.getName().compareTo(right.getName());
130         }
131     }
132 }
133