/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.testutils

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import android.os.Handler
import androidx.test.platform.app.InstrumentationRegistry
import com.android.testutils.RecorderCallback.CallbackEntry
import java.util.Collections
import kotlin.test.fail
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement

/**
 * A rule to file [NetworkCallback]s to request or watch networks.
 *
 * The callbacks filed in test methods are automatically unregistered when the method completes:
 * each test body is wrapped so that [unregisterAll] runs as its cleanup step.
 */
class AutoReleaseNetworkCallbackRule : NetworkCallbackHelper(), TestRule {
    override fun apply(base: Statement, description: Description): Statement =
        object : Statement() {
            override fun evaluate() {
                tryTest {
                    base.evaluate()
                } cleanup {
                    // Release every callback filed through this helper during the test.
                    unregisterAll()
                }
            }
        }
}

/**
 * Helps file [NetworkCallback]s to request or watch networks, keeping track of them for cleanup.
 *
 * Every callback filed through the registration methods of this class is remembered until it is
 * released with [unregisterNetworkCallback], [unrequestCell] or [unregisterAll].
 */
open class NetworkCallbackHelper {
    // Resolved lazily so that the instrumentation context is only queried on first use.
    private val cm by lazy {
        InstrumentationRegistry.getInstrumentation().context
            .getSystemService(ConnectivityManager::class.java)
            ?: fail("ConnectivityManager not found")
    }

    // Callbacks currently registered through this helper. Iteration must hold the set's own lock,
    // as required by Collections.synchronizedSet.
    private val cbToCleanup = Collections.synchronizedSet(mutableSetOf<NetworkCallback>())

    // Callback backing the single request made by requestCell, or null if there is none.
    private var cellRequestCb: TestableNetworkCallback? = null

    /**
     * Convenience method to request a cell network, similarly to [requestNetwork].
     *
     * This helper keeps track of a single cell network request. It can be released manually
     * using [unrequestCell]; otherwise it is released by [unregisterAll].
     *
     * Fails the test (throws) if a cell network is already requested through this method, or
     * if the request does not produce a [CallbackEntry.Available] callback.
     *
     * @return the [Network] carried by the [CallbackEntry.Available] callback.
     */
    fun requestCell(): Network {
        if (cellRequestCb != null) {
            fail("Cell network was already requested")
        }
        val cb = requestNetwork(
            NetworkRequest.Builder()
                .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                .build()
        )
        // Remember the callback before waiting. If the expectation below fails, the request can
        // still be released by unrequestCell or unregisterAll.
        cellRequestCb = cb
        return cb.expect<CallbackEntry.Available>(
            errorMsg = "Cell network not available. " +
                "Please ensure the device has working mobile data."
        ).network
    }

    /**
     * Unrequest a cell network requested through [requestCell].
     *
     * Fails the test (throws) if no cell network is currently requested.
     */
    fun unrequestCell() {
        val cb = cellRequestCb ?: fail("Cell network was not requested")
        unregisterNetworkCallback(cb)
        cellRequestCb = null
    }

    /**
     * Registers [cb] using [registrar], then tracks it for cleanup.
     *
     * The callback is only tracked if [registrar] returns normally, so a failed registration is
     * never later passed to [ConnectivityManager.unregisterNetworkCallback].
     *
     * @return [cb], for call chaining.
     */
    private fun addCallback(
        cb: TestableNetworkCallback,
        registrar: (TestableNetworkCallback) -> Unit
    ): TestableNetworkCallback {
        registrar(cb)
        cbToCleanup.add(cb)
        return cb
    }

    /**
     * File a request for a network matching [request].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with the network. Otherwise it is unrequested by [unregisterAll], which
     * [AutoReleaseNetworkCallbackRule] calls automatically after each test.
     *
     * @param cb the callback to register; a fresh [TestableNetworkCallback] by default.
     * @param handler the handler for callbacks, or null to use the [ConnectivityManager]
     *   overload without a handler.
     * @return [cb].
     * @see ConnectivityManager.requestNetwork
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) {
        if (handler == null) {
            cm.requestNetwork(request, it)
        } else {
            cm.requestNetwork(request, it, handler)
        }
    }

    /**
     * Overload of [requestNetwork] that passes [timeoutMs] to
     * [ConnectivityManager.requestNetwork].
     *
     * The returned callback is tracked for cleanup like any other request.
     *
     * @return [cb].
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        timeoutMs: Int,
    ): TestableNetworkCallback = addCallback(cb) { cm.requestNetwork(request, it, timeoutMs) }

    /**
     * File a callback listening for networks matching [request], without requesting them.
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with it. Otherwise it is unregistered by [unregisterAll], which
     * [AutoReleaseNetworkCallbackRule] calls automatically after each test.
     *
     * @return [cb].
     * @see ConnectivityManager.registerNetworkCallback
     */
    @JvmOverloads
    fun registerNetworkCallback(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback()
    ): TestableNetworkCallback = addCallback(cb) { cm.registerNetworkCallback(request, it) }

    /**
     * Registers [cb] as a default network callback and tracks it for cleanup.
     *
     * @param handler the handler for callbacks, or null to use the [ConnectivityManager]
     *   overload without a handler.
     * @return [cb].
     * @see ConnectivityManager.registerDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) {
        if (handler == null) {
            cm.registerDefaultNetworkCallback(it)
        } else {
            cm.registerDefaultNetworkCallback(it, handler)
        }
    }

    /**
     * Registers [cb] as a system default network callback and tracks it for cleanup.
     *
     * @return [cb].
     * @see ConnectivityManager.registerSystemDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerSystemDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback = addCallback(cb) {
        cm.registerSystemDefaultNetworkCallback(it, handler)
    }

    /**
     * Registers [cb] as the default network callback for [uid] and tracks it for cleanup.
     *
     * @return [cb].
     * @see ConnectivityManager.registerDefaultNetworkCallbackForUid
     */
    @JvmOverloads
    fun registerDefaultNetworkCallbackForUid(
        uid: Int,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback = addCallback(cb) {
        cm.registerDefaultNetworkCallbackForUid(uid, it, handler)
    }

    /**
     * Registers [cb] as a best-matching network callback for [request] and tracks it for cleanup.
     *
     * @return [cb].
     * @see ConnectivityManager.registerBestMatchingNetworkCallback
     */
    @JvmOverloads
    fun registerBestMatchingNetworkCallback(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback = addCallback(cb) {
        cm.registerBestMatchingNetworkCallback(request, it, handler)
    }

    /**
     * Files a background network request for [request] and tracks [cb] for cleanup.
     *
     * @return [cb].
     * @see ConnectivityManager.requestBackgroundNetwork
     */
    @JvmOverloads
    fun requestBackgroundNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback = addCallback(cb) {
        cm.requestBackgroundNetwork(request, it, handler)
    }

    /**
     * Unregister a callback filed using registration methods in this class, and stop tracking it.
     *
     * If [ConnectivityManager.unregisterNetworkCallback] throws, the exception propagates and
     * [cb] stays tracked for a later [unregisterAll].
     */
    fun unregisterNetworkCallback(cb: NetworkCallback) {
        cm.unregisterNetworkCallback(cb)
        cbToCleanup.remove(cb)
    }

    /**
     * Unregister all callbacks that were filed using registration methods in this class.
     *
     * Every tracked callback is attempted, even if some unregistrations fail. Afterwards no
     * callback remains tracked and the cell request (see [requestCell]) is forgotten. If any
     * unregistration threw, the first exception is rethrown, with the others attached as
     * suppressed exceptions.
     */
    fun unregisterAll() {
        // Collections.synchronizedSet requires holding its lock while iterating. Snapshot and
        // clear atomically, then unregister outside the lock.
        val toUnregister = synchronized(cbToCleanup) {
            cbToCleanup.toList().also { cbToCleanup.clear() }
        }
        cellRequestCb = null

        // Don't stop at the first failure, for example a callback the test already unregistered
        // directly through ConnectivityManager. Stopping early would leak the remaining
        // callbacks into subsequent tests.
        var firstError: Exception? = null
        for (cb in toUnregister) {
            try {
                cm.unregisterNetworkCallback(cb)
            } catch (e: Exception) {
                val previous = firstError
                if (previous == null) {
                    firstError = e
                } else {
                    previous.addSuppressed(e)
                }
            }
        }
        firstError?.let { throw it }
    }
}