1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server
18 
19 import android.content.Context
20 import android.net.ConnectivityManager
21 import android.net.INetworkMonitor
22 import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS
23 import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTP
24 import android.net.INetworkMonitorCallbacks
25 import android.net.LinkProperties
26 import android.net.LocalNetworkConfig
27 import android.net.Network
28 import android.net.NetworkAgent
29 import android.net.NetworkAgentConfig
30 import android.net.NetworkCapabilities
31 import android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK
32 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
33 import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
34 import android.net.NetworkProvider
35 import android.net.NetworkRequest
36 import android.net.NetworkScore
37 import android.net.NetworkTestResultParcelable
38 import android.net.networkstack.NetworkStackClientBase
39 import android.os.HandlerThread
40 import com.android.testutils.RecorderCallback.CallbackEntry.Available
41 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
42 import com.android.testutils.TestableNetworkCallback
43 import java.util.concurrent.atomic.AtomicInteger
44 import kotlin.test.assertEquals
45 import kotlin.test.fail
46 import org.mockito.ArgumentCaptor
47 import org.mockito.ArgumentMatchers.any
48 import org.mockito.ArgumentMatchers.anyInt
49 import org.mockito.Mockito.doAnswer
50 import org.mockito.Mockito.doNothing
51 import org.mockito.Mockito.verify
52 import org.mockito.stubbing.Answer
53 
/** Timeout, in milliseconds, for callbacks that are expected to fire quickly. */
const val SHORT_TIMEOUT_MS: Long = 200L
55 
/** Creates a Mockito [ArgumentCaptor] for [T] without spelling out the class literal. */
private inline fun <reified T> ArgumentCaptor() = ArgumentCaptor.forClass(T::class.java)
57 
/** Shared counter backing [nextAgentId]; the first id handed out is 1. */
private val agentIdSource: AtomicInteger = AtomicInteger(1)

/** Returns a fresh id (1, 2, 3, ...) used to give each wrapped agent a distinct tag. */
private fun nextAgentId(): Int {
    return agentIdSource.getAndIncrement()
}
60 
/**
 * A wrapper for network agents, for use with CSTest.
 *
 * This class knows how to interact with CSTest and has helpful methods to make fake agents
 * that can be manipulated directly from a test. Constructing an instance registers the agent
 * and hands ConnectivityService a mock [INetworkMonitor] that answers validation requests with
 * the results configured through [setCaptivePortal] and [setProbesStatus].
 *
 * @param context context used to build the agent and to look up the [ConnectivityManager].
 * @param deps ConnectivityService dependencies; used here for the isAtLeastS/isAtLeastT checks.
 * @param csHandlerThread thread whose looper the agent is created on.
 * @param networkStack network stack client; its makeNetworkMonitor call is stubbed to capture
 *        the network and the network monitor callbacks for this agent.
 * @param nac agent config passed to [NetworkAgent].
 * @param nc initial capabilities of the agent.
 * @param lp initial link properties of the agent.
 * @param lnc optional local network config; only passed to the agent on S and above.
 * @param score agent score; only passed on S and above (older releases use a fixed score of 50).
 * @param provider optional provider passed to [NetworkAgent].
 */
class CSAgentWrapper(
        val context: Context,
        val deps: ConnectivityService.Dependencies,
        csHandlerThread: HandlerThread,
        networkStack: NetworkStackClientBase,
        nac: NetworkAgentConfig,
        val nc: NetworkCapabilities,
        val lp: LinkProperties,
        val lnc: FromS<LocalNetworkConfig>?,
        val score: FromS<NetworkScore>,
        val provider: NetworkProvider?
) : TestableNetworkCallback.HasNetwork {
    private val TAG = "CSAgent${nextAgentId()}"
    private val agent: NetworkAgent
    private val nmCallbacks: INetworkMonitorCallbacks
    val networkMonitor = mock<INetworkMonitor>()

    // Values the simulated network monitor reports the next time validation is requested.
    private var nmValidationRedirectUrl: String? = null
    private var nmValidationResult = NO_PROBE_RESULT
    private var nmProbesCompleted = NO_PROBE_RESULT
    private var nmProbesSucceeded = NO_PROBE_RESULT

    override val network: Network
        get() = checkNotNull(agent.network) { "$TAG has no network" }

    /** Looked up on every access, like the original per-call lookups. */
    private val connectivityManager: ConnectivityManager
        get() = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager

    init {
        // Simulate the network monitor: whenever ConnectivityService asks it to (re)validate,
        // post the configured results on the CSTest executor instead of probing anything.
        val validateAnswer = Answer {
            CSTest.CSTestExecutor.execute { onValidationRequested() }
            null
        }
        doAnswer(validateAnswer).`when`(networkMonitor).notifyNetworkConnected(any(), any())
        doAnswer(validateAnswer).`when`(networkMonitor).notifyNetworkConnectedParcel(any())
        doAnswer(validateAnswer).`when`(networkMonitor).forceReevaluation(anyInt())

        // Capture what is handed to the network stack when a network monitor is requested.
        val nmNetworkCaptor = ArgumentCaptor<Network>()
        val nmCbCaptor = ArgumentCaptor<INetworkMonitorCallbacks>()
        doNothing().`when`(networkStack).makeNetworkMonitor(
                nmNetworkCaptor.capture(),
                any() /* name */,
                nmCbCaptor.capture())

        // Create the actual agent. NetworkAgent is abstract, so make an anonymous subclass.
        agent = if (deps.isAtLeastS()) {
            object : NetworkAgent(context, csHandlerThread.looper, TAG,
                    nc, lp, lnc?.value, score.value, nac, provider) {}
        } else {
            // Older constructor: no local network config and an integer score.
            object : NetworkAgent(context, csHandlerThread.looper, TAG,
                    nc, lp, 50 /* score */, nac, provider) {}
        }
        agent.register()
        // register() is expected to have requested a network monitor for this very network;
        // complete that request with the mock so validation goes through the simulation above.
        assertEquals(agent.network!!.netId, nmNetworkCaptor.value.netId)
        nmCallbacks = nmCbCaptor.value
        nmCallbacks.onNetworkMonitorCreated(networkMonitor)
    }

    /**
     * Simulates one validation pass: reports that no probe has completed yet, then reports the
     * result configured via [setCaptivePortal] / [setProbesStatus].
     */
    private fun onValidationRequested() {
        // Check that the connect notification matching this release's API was used; Mockito's
        // verify() throws if it was not called exactly once.
        if (deps.isAtLeastT()) {
            verify(networkMonitor).notifyNetworkConnectedParcel(any())
        } else {
            verify(networkMonitor).notifyNetworkConnected(any(), any())
        }
        nmCallbacks.notifyProbeStatusChanged(0 /* completed */, 0 /* succeeded */)
        val p = NetworkTestResultParcelable().apply {
            result = nmValidationResult
            probesAttempted = nmProbesCompleted
            probesSucceeded = nmProbesSucceeded
            redirectUrl = nmValidationRedirectUrl
            timestampMillis = VALIDATION_TIMESTAMP
        }
        nmCallbacks.notifyNetworkTestedWithExtras(p)
    }

    /**
     * Builds a request on this agent's first transport (if any), adding
     * NET_CAPABILITY_LOCAL_NETWORK when the agent has it, so that a callback filed with it is
     * expected to see this agent.
     */
    private fun makeAgentRequest(): NetworkRequest = NetworkRequest.Builder().apply {
        clearCapabilities()
        if (nc.transportTypes.isNotEmpty()) addTransportType(nc.transportTypes[0])
        if (nc.hasCapability(NET_CAPABILITY_LOCAL_NETWORK)) {
            addCapability(NET_CAPABILITY_LOCAL_NETWORK)
        }
    }.build()

    /**
     * Marks the agent connected and waits for it to become available, failing the test if it
     * does not. The temporary callback is always unregistered, even on failure.
     */
    fun connect() {
        val mgr = connectivityManager
        val cb = TestableNetworkCallback()
        mgr.registerNetworkCallback(makeAgentRequest(), cb)
        try {
            agent.markConnected()
            if (null == cb.poll { it is Available && agent.network == it.network }) {
                if (!nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED) &&
                        nc.hasTransport(TRANSPORT_CELLULAR)) {
                    // ConnectivityService adds NOT_SUSPENDED by default to all non-cell agents.
                    // An agent without NOT_SUSPENDED will not connect, instead going into the
                    // SUSPENDED state, so this call will not terminate.
                    // Instead of forcefully adding NOT_SUSPENDED to all agents like older tools
                    // did, it's better to let the developer manage it as they see fit but help
                    // them debug if they forget.
                    fail("Could not connect the agent. Did you forget to add " +
                            "NET_CAPABILITY_NOT_SUSPENDED ?")
                }
                fail("Could not connect the agent. Instrumentation failure ?")
            }
        } finally {
            mgr.unregisterNetworkCallback(cb)
        }
    }

    /**
     * Unregisters the agent and waits for the network to be lost. The temporary callback is
     * always unregistered so repeated calls do not accumulate outstanding requests.
     */
    fun disconnect() {
        val mgr = connectivityManager
        val cb = TestableNetworkCallback(timeoutMs = SHORT_TIMEOUT_MS)
        mgr.registerNetworkCallback(makeAgentRequest(), cb)
        try {
            cb.eventuallyExpect<Available> { it.network == agent.network }
            agent.unregister()
            cb.eventuallyExpect<Lost> { it.network == agent.network }
        } finally {
            mgr.unregisterNetworkCallback(cb)
        }
    }

    fun unregisterAfterReplacement(timeoutMs: Int) = agent.unregisterAfterReplacement(timeoutMs)

    fun sendLocalNetworkConfig(lnc: LocalNetworkConfig) = agent.sendLocalNetworkConfig(lnc)
    fun sendNetworkCapabilities(nc: NetworkCapabilities) = agent.sendNetworkCapabilities(nc)
    fun sendLinkProperties(lp: LinkProperties) = agent.sendLinkProperties(lp)

    /** Configures a captive portal at [redirectUrl], then [connect]s the agent. */
    fun connectWithCaptivePortal(redirectUrl: String) {
        setCaptivePortal(redirectUrl)
        connect()
    }

    /** Sets the probe bitmasks reported by subsequent simulated validations. */
    fun setProbesStatus(probesCompleted: Int, probesSucceeded: Int) {
        nmProbesCompleted = probesCompleted
        nmProbesSucceeded = probesSucceeded
    }

    /** Makes subsequent simulated validations report a captive portal at [redirectUrl]. */
    fun setCaptivePortal(redirectUrl: String) {
        nmValidationResult = VALIDATION_RESULT_INVALID
        nmValidationRedirectUrl = redirectUrl
        // Pretend the portal was found by the initial NETWORK_VALIDATION_PROBE_HTTP probe.
        // HTTP is the decisive probe for captive portals, so NETWORK_VALIDATION_PROBE_HTTPS has
        // not run yet: report only the DNS and HTTP probes as completed, and none as succeeded.
        setProbesStatus(
            NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP /* probesCompleted */,
            NO_PROBE_RESULT /* probesSucceeded */)
    }

    private companion object {
        // Validation result reported for an invalid network (used for captive portals).
        const val VALIDATION_RESULT_INVALID = 0
        // Empty probe bitmask; also the initial validation result.
        const val NO_PROBE_RESULT = 0
        // Fixed timestamp attached to every simulated validation result.
        const val VALIDATION_TIMESTAMP = 1234L
    }
}
212