1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.server
18
19 import android.content.Context
20 import android.net.ConnectivityManager
21 import android.net.INetworkMonitor
22 import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS
23 import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTP
24 import android.net.INetworkMonitorCallbacks
25 import android.net.LinkProperties
26 import android.net.LocalNetworkConfig
27 import android.net.Network
28 import android.net.NetworkAgent
29 import android.net.NetworkAgentConfig
30 import android.net.NetworkCapabilities
31 import android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK
32 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
33 import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
34 import android.net.NetworkProvider
35 import android.net.NetworkRequest
36 import android.net.NetworkScore
37 import android.net.NetworkTestResultParcelable
38 import android.net.networkstack.NetworkStackClientBase
39 import android.os.HandlerThread
40 import com.android.testutils.RecorderCallback.CallbackEntry.Available
41 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
42 import com.android.testutils.TestableNetworkCallback
43 import java.util.concurrent.atomic.AtomicInteger
44 import kotlin.test.assertEquals
45 import kotlin.test.fail
46 import org.mockito.ArgumentCaptor
47 import org.mockito.ArgumentMatchers.any
48 import org.mockito.ArgumentMatchers.anyInt
49 import org.mockito.Mockito.doAnswer
50 import org.mockito.Mockito.doNothing
51 import org.mockito.Mockito.verify
52 import org.mockito.stubbing.Answer
53
/**
 * Short timeout, in milliseconds, handed to [TestableNetworkCallback] where callbacks are
 * expected to arrive quickly (see [CSAgentWrapper.disconnect]).
 */
const val SHORT_TIMEOUT_MS: Long = 200
55
/**
 * Kotlin-friendly factory for Mockito captors: `ArgumentCaptor<Foo>()` is shorthand for
 * `ArgumentCaptor.forClass(Foo::class.java)`.
 */
private inline fun <reified T> ArgumentCaptor(): ArgumentCaptor<T> =
    ArgumentCaptor.forClass(T::class.java)
57
// File-wide counter used to give each wrapper a distinct log tag. Starts at 1 and is atomic so
// ids stay unique even if wrappers are created from several threads.
private val agentCounter: AtomicInteger = AtomicInteger(1)

/** Returns the current agent id and advances the counter for the next caller. */
private fun nextAgentId(): Int {
    return agentCounter.getAndIncrement()
}
60
/**
 * A wrapper for network agents, for use with CSTest.
 *
 * This class knows how to interact with CSTest and has helpful methods to make fake agents
 * that can be manipulated directly from a test.
 *
 * Constructing a wrapper creates and registers the underlying [NetworkAgent] immediately; the
 * agent is only marked connected when [connect] is called. The wrapper also plays the role of
 * the network monitor: it captures the [INetworkMonitorCallbacks] passed to
 * `networkStack.makeNetworkMonitor`, hands them a mocked [INetworkMonitor], and answers
 * validation requests with the result configured through [setProbesStatus] and
 * [setCaptivePortal].
 *
 * @param context context used to create the agent and to look up [ConnectivityManager].
 * @param deps dependencies object, consulted for the SDK-level checks (`isAtLeastS`/`isAtLeastT`).
 * @param csHandlerThread thread whose looper the agent is created on.
 * @param networkStack network stack client; its `makeNetworkMonitor` is stubbed here so the
 *   network and callbacks passed to it can be captured.
 * @param nac agent config passed to the [NetworkAgent] constructor.
 * @param nc capabilities the agent is created with. [connect] and [disconnect] also use them to
 *   build the request they listen with.
 * @param lp link properties the agent is created with.
 * @param lnc local network config; only read when `deps.isAtLeastS()` is true.
 * @param score network score; only read when `deps.isAtLeastS()` is true.
 * @param provider provider passed to the [NetworkAgent] constructor, may be null.
 */
class CSAgentWrapper(
    val context: Context,
    val deps: ConnectivityService.Dependencies,
    csHandlerThread: HandlerThread,
    networkStack: NetworkStackClientBase,
    nac: NetworkAgentConfig,
    val nc: NetworkCapabilities,
    val lp: LinkProperties,
    val lnc: FromS<LocalNetworkConfig>?,
    val score: FromS<NetworkScore>,
    val provider: NetworkProvider?
) : TestableNetworkCallback.HasNetwork {
    // Log tag handed to the agent; the numeric suffix is unique per wrapper.
    private val TAG = "CSAgent${nextAgentId()}"
    // Validation result reported until a test configures something else.
    private val VALIDATION_RESULT_INVALID = 0
    // Empty probe bitmask: no probe completed / no probe succeeded.
    private val NO_PROBE_RESULT = 0
    // Fixed timestamp placed in every simulated validation result.
    private val VALIDATION_TIMESTAMP = 1234L
    private val agent: NetworkAgent
    private val nmCallbacks: INetworkMonitorCallbacks
    val networkMonitor = mock<INetworkMonitor>()

    // State reported by the simulated network monitor on each validation request.
    private var nmValidationRedirectUrl: String? = null
    private var nmValidationResult = VALIDATION_RESULT_INVALID
    private var nmProbesCompleted = NO_PROBE_RESULT
    private var nmProbesSucceeded = NO_PROBE_RESULT

    /** The [Network] of the wrapped agent. Throws if the agent has no network. */
    override val network: Network get() = agent.network!!

    init {
        // Capture network monitor callbacks and simulate network monitor
        val validateAnswer = Answer {
            CSTest.CSTestExecutor.execute { onValidationRequested() }
            null
        }
        doAnswer(validateAnswer).`when`(networkMonitor).notifyNetworkConnected(any(), any())
        doAnswer(validateAnswer).`when`(networkMonitor).notifyNetworkConnectedParcel(any())
        doAnswer(validateAnswer).`when`(networkMonitor).forceReevaluation(anyInt())
        val nmNetworkCaptor = ArgumentCaptor<Network>()
        val nmCbCaptor = ArgumentCaptor<INetworkMonitorCallbacks>()
        doNothing().`when`(networkStack).makeNetworkMonitor(
                nmNetworkCaptor.capture(),
                any() /* name */,
                nmCbCaptor.capture())

        // Create the actual agent. NetworkAgent is abstract, so make an anonymous subclass.
        if (deps.isAtLeastS()) {
            agent = object : NetworkAgent(context, csHandlerThread.looper, TAG,
                    nc, lp, lnc?.value, score.value, nac, provider) {}
        } else {
            agent = object : NetworkAgent(context, csHandlerThread.looper, TAG,
                    nc, lp, 50 /* score */, nac, provider) {}
        }
        agent.register()
        // The stubbed makeNetworkMonitor must have been called for this agent's network.
        assertEquals(agent.network!!.netId, nmNetworkCaptor.value.netId)
        nmCallbacks = nmCbCaptor.value
        nmCallbacks.onNetworkMonitorCreated(networkMonitor)
    }

    /**
     * Simulates a validation run of the network monitor.
     *
     * Checks that the connection was notified through the method matching the SDK level, then
     * reports an empty probe status followed by a test result built from the currently
     * configured validation state.
     */
    private fun onValidationRequested() {
        if (deps.isAtLeastT()) {
            verify(networkMonitor).notifyNetworkConnectedParcel(any())
        } else {
            verify(networkMonitor).notifyNetworkConnected(any(), any())
        }
        nmCallbacks.notifyProbeStatusChanged(0 /* completed */, 0 /* succeeded */)
        val p = NetworkTestResultParcelable()
        p.result = nmValidationResult
        p.probesAttempted = nmProbesCompleted
        p.probesSucceeded = nmProbesSucceeded
        p.redirectUrl = nmValidationRedirectUrl
        p.timestampMillis = VALIDATION_TIMESTAMP
        nmCallbacks.notifyNetworkTestedWithExtras(p)
    }

    /**
     * Builds the request [connect] and [disconnect] listen with.
     *
     * All default capabilities are cleared; the request then carries the first transport of
     * [nc], if any, plus NET_CAPABILITY_LOCAL_NETWORK when [nc] has it. It is built from the
     * capabilities given at construction, not from any sent later via [sendNetworkCapabilities].
     */
    private fun makeMatchingRequest(): NetworkRequest = NetworkRequest.Builder().apply {
        clearCapabilities()
        if (nc.transportTypes.isNotEmpty()) addTransportType(nc.transportTypes[0])
        if (nc.hasCapability(NET_CAPABILITY_LOCAL_NETWORK)) {
            addCapability(NET_CAPABILITY_LOCAL_NETWORK)
        }
    }.build()

    /**
     * Marks the agent connected and waits until an Available callback is seen for its network.
     *
     * Fails the test if the callback does not arrive. The temporary callback is unregistered
     * whether or not the wait succeeds.
     */
    fun connect() {
        val mgr = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
        val cb = TestableNetworkCallback()
        mgr.registerNetworkCallback(makeMatchingRequest(), cb)
        try {
            agent.markConnected()
            if (null == cb.poll { it is Available && agent.network == it.network }) {
                if (!nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED) &&
                        nc.hasTransport(TRANSPORT_CELLULAR)) {
                    // ConnectivityService adds NOT_SUSPENDED by default to all non-cell agents.
                    // An agent without NOT_SUSPENDED will not connect, instead going into the
                    // SUSPENDED state, so this call will not terminate.
                    // Instead of forcefully adding NOT_SUSPENDED to all agents like older tools
                    // did, it's better to let the developer manage it as they see fit but help
                    // them debug if they forget.
                    fail("Could not connect the agent. Did you forget to add " +
                            "NET_CAPABILITY_NOT_SUSPENDED ?")
                }
                fail("Could not connect the agent. Instrumentation failure ?")
            }
        } finally {
            // Also runs when fail() throws, so a failed connect does not leave cb registered.
            mgr.unregisterNetworkCallback(cb)
        }
    }

    /**
     * Unregisters the agent and waits until a Lost callback is seen for its network.
     *
     * The agent's network must first be reported Available to the temporary callback, so the
     * agent is expected to be connected when this is called. Each wait uses [SHORT_TIMEOUT_MS].
     * The temporary callback is unregistered whether or not the waits succeed.
     */
    fun disconnect() {
        val mgr = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
        val cb = TestableNetworkCallback(timeoutMs = SHORT_TIMEOUT_MS)
        mgr.registerNetworkCallback(makeMatchingRequest(), cb)
        try {
            cb.eventuallyExpect<Available> { it.network == agent.network }
            agent.unregister()
            cb.eventuallyExpect<Lost> { it.network == agent.network }
        } finally {
            // Without this, every disconnect() leaves one more callback registered.
            mgr.unregisterNetworkCallback(cb)
        }
    }

    /** Forwards to [NetworkAgent.unregisterAfterReplacement] on the wrapped agent. */
    fun unregisterAfterReplacement(timeoutMs: Int) = agent.unregisterAfterReplacement(timeoutMs)

    /** Forwards to [NetworkAgent.sendLocalNetworkConfig] on the wrapped agent. */
    fun sendLocalNetworkConfig(lnc: LocalNetworkConfig) = agent.sendLocalNetworkConfig(lnc)

    /** Forwards to [NetworkAgent.sendNetworkCapabilities] on the wrapped agent. */
    fun sendNetworkCapabilities(nc: NetworkCapabilities) = agent.sendNetworkCapabilities(nc)

    /** Forwards to [NetworkAgent.sendLinkProperties] on the wrapped agent. */
    fun sendLinkProperties(lp: LinkProperties) = agent.sendLinkProperties(lp)

    /** Configures a captive portal result with [redirectUrl], then calls [connect]. */
    fun connectWithCaptivePortal(redirectUrl: String) {
        setCaptivePortal(redirectUrl)
        connect()
    }

    /**
     * Sets the probe bitmasks reported by the next simulated validation.
     *
     * @param probesCompleted bitmask of NETWORK_VALIDATION_PROBE_* values reported as attempted.
     * @param probesSucceeded bitmask of NETWORK_VALIDATION_PROBE_* values reported as succeeded.
     */
    fun setProbesStatus(probesCompleted: Int, probesSucceeded: Int) {
        nmProbesCompleted = probesCompleted
        nmProbesSucceeded = probesSucceeded
    }

    /**
     * Makes the next simulated validation report a captive portal redirecting to [redirectUrl].
     *
     * Only changes the stored state; a result is sent when validation is next requested.
     */
    fun setCaptivePortal(redirectUrl: String) {
        nmValidationResult = VALIDATION_RESULT_INVALID
        nmValidationRedirectUrl = redirectUrl
        // Suppose the portal is found when NetworkMonitor probes NETWORK_VALIDATION_PROBE_HTTP
        // at the beginning. NETWORK_VALIDATION_PROBE_HTTP is the decisive probe for a captive
        // portal, so treat NETWORK_VALIDATION_PROBE_HTTPS as not yet attempted: only the DNS and
        // HTTP probes are reported as completed, and none as succeeded.
        setProbesStatus(
                NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP /* probesCompleted */,
                NO_PROBE_RESULT /* probesSucceeded */)
    }
}
212