1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.ConnectivityManager
20 import android.net.ConnectivityManager.NetworkCallback
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkRequest
24 import android.os.Handler
25 import androidx.test.platform.app.InstrumentationRegistry
26 import com.android.testutils.RecorderCallback.CallbackEntry
27 import java.util.Collections
28 import kotlin.test.fail
29 import org.junit.rules.TestRule
30 import org.junit.runner.Description
31 import org.junit.runners.model.Statement
32 
/**
 * A JUnit [TestRule] that files [NetworkCallback]s to request or watch networks.
 *
 * Callbacks filed during a test method through the inherited [NetworkCallbackHelper] methods are
 * unregistered (via [NetworkCallbackHelper.unregisterAll]) in the cleanup step that runs when the
 * test method completes.
 */
class AutoReleaseNetworkCallbackRule : NetworkCallbackHelper(), TestRule {
    /**
     * Wraps [base] so that every tracked callback is unregistered once the test body has run.
     *
     * [description] is not used by this rule.
     */
    override fun apply(base: Statement, description: Description): Statement =
        object : Statement() {
            override fun evaluate() {
                // Run the test body, then release all callbacks in the cleanup clause.
                tryTest {
                    base.evaluate()
                } cleanup {
                    unregisterAll()
                }
            }
        }
}
56 
/**
 * Helps file [NetworkCallback]s to request or watch networks, keeping track of them for cleanup.
 *
 * Each callback registered through this class stays tracked until it is passed to
 * [unregisterNetworkCallback], or until [unregisterAll] unregisters every remaining one.
 */
open class NetworkCallbackHelper {
    private val cm by lazy {
        InstrumentationRegistry.getInstrumentation().context
            .getSystemService(ConnectivityManager::class.java)
            ?: fail("ConnectivityManager not found")
    }

    // Callbacks registered through this helper that have not been unregistered yet. The set is
    // synchronized for single operations, but iterating it requires holding its monitor.
    private val cbToCleanup = Collections.synchronizedSet(mutableSetOf<NetworkCallback>())

    // Callback of the outstanding requestCell() request, or null when there is none.
    private var cellRequestCb: TestableNetworkCallback? = null

    /**
     * Convenience method to request a cell network, similarly to [requestNetwork].
     *
     * This helper keeps track of a single cell network request, which can be unrequested manually
     * using [unrequestCell]; otherwise it is released by [unregisterAll].
     *
     * This will fail tests (throw) if a cell network was already requested, or if the expected
     * [CallbackEntry.Available] callback for the cell network is not received.
     *
     * @return the [Network] carried by the [CallbackEntry.Available] entry returned by `expect`.
     */
    fun requestCell(): Network {
        if (cellRequestCb != null) {
            fail("Cell network was already requested")
        }
        val cb = requestNetwork(
            NetworkRequest.Builder()
                .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                .build()
        )
        // Recorded before waiting, so the request is still released if the expectation fails.
        cellRequestCb = cb
        return cb.expect<CallbackEntry.Available>(
            errorMsg = "Cell network not available. " +
                    "Please ensure the device has working mobile data."
        ).network
    }

    /**
     * Unrequest a cell network requested through [requestCell].
     *
     * Fails the test (throws) if no cell network request is outstanding.
     */
    fun unrequestCell() {
        val cb = cellRequestCb ?: fail("Cell network was not requested")
        // unregisterNetworkCallback resets cellRequestCb once the callback is unregistered.
        unregisterNetworkCallback(cb)
    }

    /**
     * Registers [cb] by invoking [registrar] on it, then tracks it for cleanup.
     *
     * The callback is only tracked if [registrar] returns normally.
     *
     * @return [cb], so public registration methods can return it directly.
     */
    private fun addCallback(
        cb: TestableNetworkCallback,
        registrar: (TestableNetworkCallback) -> Unit
    ): TestableNetworkCallback {
        registrar(cb)
        cbToCleanup.add(cb)
        return cb
    }

    /**
     * File a request for a network matching [request], using
     * [ConnectivityManager.requestNetwork].
     *
     * When [handler] is null the request is filed without a handler argument; otherwise
     * callbacks are delivered using [handler].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with the request; otherwise it is released by [unregisterAll].
     *
     * @return [cb], now registered and tracked.
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) {
        if (handler == null) {
            cm.requestNetwork(request, it)
        } else {
            cm.requestNetwork(request, it, handler)
        }
    }

    /**
     * Overload of [requestNetwork] that passes [timeoutMs] to
     * [ConnectivityManager.requestNetwork]; see that method for the timeout semantics.
     *
     * @return [cb], now registered and tracked.
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        timeoutMs: Int,
    ): TestableNetworkCallback = addCallback(cb) { cm.requestNetwork(request, it, timeoutMs) }

    /**
     * File a callback listening for networks matching [request], using
     * [ConnectivityManager.registerNetworkCallback].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with it; otherwise it is unregistered by [unregisterAll].
     *
     * @return [cb], now registered and tracked.
     */
    @JvmOverloads
    fun registerNetworkCallback(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback()
    ): TestableNetworkCallback = addCallback(cb) { cm.registerNetworkCallback(request, it) }

    /**
     * Tracked wrapper around [ConnectivityManager.registerDefaultNetworkCallback]; the handler
     * argument is only passed when [handler] is non-null.
     *
     * @see ConnectivityManager.registerDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) {
        if (handler == null) {
            cm.registerDefaultNetworkCallback(it)
        } else {
            cm.registerDefaultNetworkCallback(it, handler)
        }
    }

    /**
     * Tracked wrapper around [ConnectivityManager.registerSystemDefaultNetworkCallback].
     *
     * @see ConnectivityManager.registerSystemDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerSystemDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerSystemDefaultNetworkCallback(it, handler) }

    /**
     * Tracked wrapper around [ConnectivityManager.registerDefaultNetworkCallbackForUid].
     *
     * @see ConnectivityManager.registerDefaultNetworkCallbackForUid
     */
    @JvmOverloads
    fun registerDefaultNetworkCallbackForUid(
        uid: Int,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerDefaultNetworkCallbackForUid(uid, it, handler) }

    /**
     * Tracked wrapper around [ConnectivityManager.registerBestMatchingNetworkCallback].
     *
     * @see ConnectivityManager.registerBestMatchingNetworkCallback
     */
    @JvmOverloads
    fun registerBestMatchingNetworkCallback(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerBestMatchingNetworkCallback(request, it, handler) }

    /**
     * Tracked wrapper around [ConnectivityManager.requestBackgroundNetwork].
     *
     * @see ConnectivityManager.requestBackgroundNetwork
     */
    @JvmOverloads
    fun requestBackgroundNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.requestBackgroundNetwork(request, it, handler) }

    /**
     * Unregister a callback filed using registration methods in this class.
     *
     * If [cb] is the callback filed by [requestCell], the cell request is considered released,
     * so [requestCell] may be called again. Exceptions thrown by
     * [ConnectivityManager.unregisterNetworkCallback] propagate, and [cb] then stays tracked.
     */
    fun unregisterNetworkCallback(cb: NetworkCallback) {
        cm.unregisterNetworkCallback(cb)
        cbToCleanup.remove(cb)
        // Keep the cell-request bookkeeping consistent; otherwise requestCell() would wrongly
        // report "already requested" and unrequestCell() would unregister the callback twice.
        if (cb === cellRequestCb) {
            cellRequestCb = null
        }
    }

    /**
     * Unregister all callbacks that were filed using registration methods in this class and
     * are still tracked.
     *
     * Every callback is attempted even if an earlier one fails. Afterwards the first failure is
     * rethrown, with any later failures attached as suppressed exceptions. Tracking, including
     * the [requestCell] bookkeeping, is reset in all cases.
     */
    fun unregisterAll() {
        // Iterating a Collections.synchronizedSet requires holding its monitor: snapshot and
        // clear atomically so concurrent add/remove cannot corrupt the iteration.
        val callbacks = synchronized(cbToCleanup) {
            cbToCleanup.toList().also { cbToCleanup.clear() }
        }
        cellRequestCb = null

        var failure: Exception? = null
        for (cb in callbacks) {
            try {
                cm.unregisterNetworkCallback(cb)
            } catch (e: Exception) {
                // Keep going so one bad callback does not leak the others into later tests.
                val first = failure
                if (first == null) failure = e else first.addSuppressed(e)
            }
        }
        failure?.let { throw it }
    }
}
228