1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server;
18 
19 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED;
20 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED;
21 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
22 import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
23 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
24 import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
25 import static android.net.NetworkCapabilities.TRANSPORT_TEST;
26 import static android.net.NetworkCapabilities.TRANSPORT_VPN;
27 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
28 import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
29 
30 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
31 
32 import static junit.framework.Assert.assertFalse;
33 import static junit.framework.Assert.assertTrue;
34 
35 import static org.junit.Assert.assertEquals;
36 import static org.junit.Assert.fail;
37 
38 import android.annotation.NonNull;
39 import android.annotation.SuppressLint;
40 import android.content.Context;
41 import android.net.ConnectivityManager;
42 import android.net.LinkProperties;
43 import android.net.Network;
44 import android.net.NetworkAgent;
45 import android.net.NetworkAgentConfig;
46 import android.net.NetworkCapabilities;
47 import android.net.NetworkProvider;
48 import android.net.NetworkScore;
49 import android.net.NetworkSpecifier;
50 import android.net.QosFilter;
51 import android.net.SocketKeepalive;
52 import android.os.ConditionVariable;
53 import android.os.HandlerThread;
54 import android.os.Message;
55 import android.util.CloseGuard;
56 import android.util.Log;
57 import android.util.Range;
58 
59 import com.android.net.module.util.ArrayTrackRecord;
60 import com.android.testutils.HandlerUtils;
61 import com.android.testutils.TestableNetworkCallback;
62 
63 import java.util.List;
64 import java.util.Objects;
65 import java.util.Set;
66 import java.util.concurrent.atomic.AtomicBoolean;
67 import java.util.function.Consumer;
68 
69 public class NetworkAgentWrapper implements TestableNetworkCallback.HasNetwork {
70     private static final long DESTROY_TIMEOUT_MS = 10_000L;
71 
72     // Note : Please do not add any new instrumentation here. If you need new instrumentation,
73     // please add it in CSAgentWrapper and use subclasses of CSTest instead of adding more
74     // tools in ConnectivityServiceTest.
75     private final NetworkCapabilities mNetworkCapabilities;
76     private final HandlerThread mHandlerThread;
77     private final CloseGuard mCloseGuard;
78     private final Context mContext;
79     private final String mLogTag;
80     private final NetworkAgentConfig mNetworkAgentConfig;
81 
82     private final ConditionVariable mDisconnected = new ConditionVariable();
83     private final ConditionVariable mPreventReconnectReceived = new ConditionVariable();
84     private final AtomicBoolean mConnected = new AtomicBoolean(false);
85     private NetworkScore mScore;
86     private NetworkAgent mNetworkAgent;
87     private int mStartKeepaliveError = SocketKeepalive.ERROR_UNSUPPORTED;
88     private int mStopKeepaliveError = SocketKeepalive.NO_KEEPALIVE;
89     // Controls how test network agent is going to wait before responding to keepalive
90     // start/stop. Useful when simulate KeepaliveTracker is waiting for response from modem.
91     private long mKeepaliveResponseDelay = 0L;
92     private Integer mExpectedKeepaliveSlot = null;
93     private final ArrayTrackRecord<CallbackType>.ReadHead mCallbackHistory =
94             new ArrayTrackRecord<CallbackType>().newReadHead();
95 
96     public static class Callbacks {
97         public final Consumer<NetworkAgent> onNetworkCreated;
98         public final Consumer<NetworkAgent> onNetworkUnwanted;
99         public final Consumer<NetworkAgent> onNetworkDestroyed;
100 
Callbacks()101         public Callbacks() {
102             this(null, null, null);
103         }
104 
Callbacks(Consumer<NetworkAgent> onNetworkCreated, Consumer<NetworkAgent> onNetworkUnwanted, Consumer<NetworkAgent> onNetworkDestroyed)105         public Callbacks(Consumer<NetworkAgent> onNetworkCreated,
106                 Consumer<NetworkAgent> onNetworkUnwanted,
107                 Consumer<NetworkAgent> onNetworkDestroyed) {
108             this.onNetworkCreated = onNetworkCreated;
109             this.onNetworkUnwanted = onNetworkUnwanted;
110             this.onNetworkDestroyed = onNetworkDestroyed;
111         }
112     }
113 
114     private final Callbacks mCallbacks;
115 
NetworkAgentWrapper(int transport, LinkProperties linkProperties, NetworkCapabilities ncTemplate, Context context)116     public NetworkAgentWrapper(int transport, LinkProperties linkProperties,
117             NetworkCapabilities ncTemplate, Context context) throws Exception {
118         this(transport, linkProperties, ncTemplate, null /* provider */,
119                 null /* callbacks */, context);
120     }
121 
NetworkAgentWrapper(int transport, LinkProperties linkProperties, NetworkCapabilities ncTemplate, NetworkProvider provider, Callbacks callbacks, Context context)122     public NetworkAgentWrapper(int transport, LinkProperties linkProperties,
123             NetworkCapabilities ncTemplate, NetworkProvider provider,
124             Callbacks callbacks, Context context) throws Exception {
125         final int type = transportToLegacyType(transport);
126         final String typeName = ConnectivityManager.getNetworkTypeName(type);
127         mNetworkCapabilities = (ncTemplate != null) ? ncTemplate : new NetworkCapabilities();
128         mNetworkCapabilities.addCapability(NET_CAPABILITY_NOT_SUSPENDED);
129         mNetworkCapabilities.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED);
130         mNetworkCapabilities.addTransportType(transport);
131         switch (transport) {
132             case TRANSPORT_BLUETOOTH:
133                 // Score for Wear companion proxy network; not BLUETOOTH tethering.
134                 mScore = new NetworkScore.Builder().setLegacyInt(100).build();
135                 break;
136             case TRANSPORT_ETHERNET:
137                 mScore = new NetworkScore.Builder().setLegacyInt(70).build();
138                 break;
139             case TRANSPORT_WIFI:
140                 mScore = new NetworkScore.Builder().setLegacyInt(60).build();
141                 break;
142             case TRANSPORT_CELLULAR:
143                 mScore = new NetworkScore.Builder().setLegacyInt(50).build();
144                 break;
145             case TRANSPORT_WIFI_AWARE:
146                 mScore = new NetworkScore.Builder().setLegacyInt(20).build();
147                 break;
148             case TRANSPORT_TEST:
149                 mScore = new NetworkScore.Builder().build();
150                 break;
151             case TRANSPORT_VPN:
152                 mNetworkCapabilities.removeCapability(NET_CAPABILITY_NOT_VPN);
153                 // VPNs deduce the SUSPENDED capability from their underlying networks and there
154                 // is no public API to let VPN services set it.
155                 mNetworkCapabilities.removeCapability(NET_CAPABILITY_NOT_SUSPENDED);
156                 mScore = new NetworkScore.Builder().setLegacyInt(101).build();
157                 break;
158             default:
159                 throw new UnsupportedOperationException("unimplemented network type");
160         }
161         mContext = context;
162         mLogTag = "Mock-" + typeName;
163         mHandlerThread = new HandlerThread(mLogTag);
164         mHandlerThread.start();
165         mCloseGuard = new CloseGuard();
166         mCloseGuard.open("destroy");
167 
168         // extraInfo is set to "" by default in NetworkAgentConfig.
169         final String extraInfo = (transport == TRANSPORT_CELLULAR) ? "internet.apn" : "";
170         mNetworkAgentConfig = new NetworkAgentConfig.Builder()
171                 .setLegacyType(type)
172                 .setLegacyTypeName(typeName)
173                 .setLegacyExtraInfo(extraInfo)
174                 .build();
175         mCallbacks = (callbacks != null) ? callbacks : new Callbacks();
176         mNetworkAgent = makeNetworkAgent(linkProperties, mNetworkAgentConfig, provider);
177     }
178 
makeNetworkAgent(LinkProperties linkProperties, final NetworkAgentConfig nac, NetworkProvider provider)179     protected InstrumentedNetworkAgent makeNetworkAgent(LinkProperties linkProperties,
180             final NetworkAgentConfig nac, NetworkProvider provider) throws Exception {
181         return new InstrumentedNetworkAgent(this, linkProperties, nac, provider);
182     }
183 
184     public static class InstrumentedNetworkAgent extends NetworkAgent {
185         private final NetworkAgentWrapper mWrapper;
186         private static final String PROVIDER_NAME = "InstrumentedNetworkAgentProvider";
187 
InstrumentedNetworkAgent(NetworkAgentWrapper wrapper, LinkProperties lp, NetworkAgentConfig nac)188         public InstrumentedNetworkAgent(NetworkAgentWrapper wrapper, LinkProperties lp,
189                 NetworkAgentConfig nac) {
190             this(wrapper, lp, nac, null /* provider */);
191         }
192 
InstrumentedNetworkAgent(NetworkAgentWrapper wrapper, LinkProperties lp, NetworkAgentConfig nac, NetworkProvider provider)193         public InstrumentedNetworkAgent(NetworkAgentWrapper wrapper, LinkProperties lp,
194                 NetworkAgentConfig nac, NetworkProvider provider) {
195             super(wrapper.mContext, wrapper.mHandlerThread.getLooper(), wrapper.mLogTag,
196                     wrapper.mNetworkCapabilities, lp, wrapper.mScore, nac,
197                     null != provider ? provider : new NetworkProvider(wrapper.mContext,
198                             wrapper.mHandlerThread.getLooper(), PROVIDER_NAME));
199             mWrapper = wrapper;
200             register();
201         }
202 
203         @Override
unwanted()204         public void unwanted() {
205             mWrapper.mDisconnected.open();
206         }
207 
208         @Override
startSocketKeepalive(Message msg)209         public void startSocketKeepalive(Message msg) {
210             int slot = msg.arg1;
211             if (mWrapper.mExpectedKeepaliveSlot != null) {
212                 assertEquals((int) mWrapper.mExpectedKeepaliveSlot, slot);
213             }
214             mWrapper.mHandlerThread.getThreadHandler().postDelayed(
215                     () -> onSocketKeepaliveEvent(slot, mWrapper.mStartKeepaliveError),
216                     mWrapper.mKeepaliveResponseDelay);
217         }
218 
219         @Override
stopSocketKeepalive(Message msg)220         public void stopSocketKeepalive(Message msg) {
221             final int slot = msg.arg1;
222             mWrapper.mHandlerThread.getThreadHandler().postDelayed(
223                     () -> onSocketKeepaliveEvent(slot, mWrapper.mStopKeepaliveError),
224                     mWrapper.mKeepaliveResponseDelay);
225         }
226 
227         @Override
onQosCallbackRegistered(final int qosCallbackId, final @NonNull QosFilter filter)228         public void onQosCallbackRegistered(final int qosCallbackId,
229                 final @NonNull QosFilter filter) {
230             Log.i(mWrapper.mLogTag, "onQosCallbackRegistered");
231             mWrapper.mCallbackHistory.add(
232                     new CallbackType.OnQosCallbackRegister(qosCallbackId, filter));
233         }
234 
235         @Override
onQosCallbackUnregistered(final int qosCallbackId)236         public void onQosCallbackUnregistered(final int qosCallbackId) {
237             Log.i(mWrapper.mLogTag, "onQosCallbackUnregistered");
238             mWrapper.mCallbackHistory.add(new CallbackType.OnQosCallbackUnregister(qosCallbackId));
239         }
240 
241         @Override
preventAutomaticReconnect()242         protected void preventAutomaticReconnect() {
243             mWrapper.mPreventReconnectReceived.open();
244         }
245 
246         @Override
addKeepalivePacketFilter(Message msg)247         protected void addKeepalivePacketFilter(Message msg) {
248             Log.i(mWrapper.mLogTag, "Add keepalive packet filter.");
249         }
250 
251         @Override
removeKeepalivePacketFilter(Message msg)252         protected void removeKeepalivePacketFilter(Message msg) {
253             Log.i(mWrapper.mLogTag, "Remove keepalive packet filter.");
254         }
255 
256         @Override
onNetworkCreated()257         public void onNetworkCreated() {
258             super.onNetworkCreated();
259             if (mWrapper.mCallbacks.onNetworkCreated != null) {
260                 mWrapper.mCallbacks.onNetworkCreated.accept(this);
261             }
262         }
263 
264         @Override
onNetworkUnwanted()265         public void onNetworkUnwanted() {
266             super.onNetworkUnwanted();
267             if (mWrapper.mCallbacks.onNetworkUnwanted != null) {
268                 mWrapper.mCallbacks.onNetworkUnwanted.accept(this);
269             }
270         }
271 
272         @Override
onNetworkDestroyed()273         public void onNetworkDestroyed() {
274             super.onNetworkDestroyed();
275             if (mWrapper.mCallbacks.onNetworkDestroyed != null) {
276                 mWrapper.mCallbacks.onNetworkDestroyed.accept(this);
277             }
278         }
279 
280     }
281 
setScore(@onNull final NetworkScore score)282     public void setScore(@NonNull final NetworkScore score) {
283         mScore = score;
284         mNetworkAgent.sendNetworkScore(score);
285     }
286 
287     // TODO : remove adjustScore and replace with the appropriate exiting flags.
adjustScore(int change)288     public void adjustScore(int change) {
289         final int newLegacyScore = mScore.getLegacyInt() + change;
290         final NetworkScore.Builder builder = new NetworkScore.Builder()
291                 .setLegacyInt(newLegacyScore);
292         if (mNetworkCapabilities.hasTransport(TRANSPORT_WIFI) && newLegacyScore < 50) {
293             builder.setExiting(true);
294         }
295         mScore = builder.build();
296         mNetworkAgent.sendNetworkScore(mScore);
297     }
298 
getScore()299     public NetworkScore getScore() {
300         return mScore;
301     }
302 
explicitlySelected(boolean explicitlySelected, boolean acceptUnvalidated)303     public void explicitlySelected(boolean explicitlySelected, boolean acceptUnvalidated) {
304         mNetworkAgent.explicitlySelected(explicitlySelected, acceptUnvalidated);
305     }
306 
addCapability(int capability)307     public void addCapability(int capability) {
308         mNetworkCapabilities.addCapability(capability);
309         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
310     }
311 
removeCapability(int capability)312     public void removeCapability(int capability) {
313         mNetworkCapabilities.removeCapability(capability);
314         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
315     }
316 
setUids(Set<Range<Integer>> uids)317     public void setUids(Set<Range<Integer>> uids) {
318         mNetworkCapabilities.setUids(uids);
319         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
320     }
321 
setSignalStrength(int signalStrength)322     public void setSignalStrength(int signalStrength) {
323         mNetworkCapabilities.setSignalStrength(signalStrength);
324         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
325     }
326 
setNetworkSpecifier(NetworkSpecifier networkSpecifier)327     public void setNetworkSpecifier(NetworkSpecifier networkSpecifier) {
328         mNetworkCapabilities.setNetworkSpecifier(networkSpecifier);
329         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
330     }
331 
setNetworkCapabilities(NetworkCapabilities nc, boolean sendToConnectivityService)332     public void setNetworkCapabilities(NetworkCapabilities nc, boolean sendToConnectivityService) {
333         mNetworkCapabilities.set(nc);
334         if (sendToConnectivityService) {
335             mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
336         }
337     }
338 
setUnderlyingNetworks(List<Network> underlyingNetworks)339     public void setUnderlyingNetworks(List<Network> underlyingNetworks) {
340         mNetworkAgent.setUnderlyingNetworks(underlyingNetworks);
341     }
342 
setOwnerUid(int uid)343     public void setOwnerUid(int uid) {
344         mNetworkCapabilities.setOwnerUid(uid);
345         mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
346     }
347 
connect()348     public void connect() {
349         if (!mConnected.compareAndSet(false /* expect */, true /* update */)) {
350             // compareAndSet returns false when the value couldn't be updated because it did not
351             // match the expected value.
352             fail("Test NetworkAgents can only be connected once");
353         }
354         mNetworkAgent.markConnected();
355     }
356 
suspend()357     public void suspend() {
358         removeCapability(NET_CAPABILITY_NOT_SUSPENDED);
359     }
360 
resume()361     public void resume() {
362         addCapability(NET_CAPABILITY_NOT_SUSPENDED);
363     }
364 
disconnect()365     public void disconnect() {
366         mNetworkAgent.unregister();
367     }
368 
369     /**
370      * Destroy the network agent and stop its looper.
371      *
372      * <p>This must always be called.
373      */
destroy()374     public void destroy() {
375         mHandlerThread.quitSafely();
376         try {
377             mHandlerThread.join(DESTROY_TIMEOUT_MS);
378         } catch (InterruptedException e) {
379             Log.e(mLogTag, "Interrupted when waiting for handler thread on destroy", e);
380         }
381         mCloseGuard.close();
382     }
383 
384     @SuppressLint("Finalize") // Follows the recommended pattern for CloseGuard
385     @Override
finalize()386     protected void finalize() throws Throwable {
387         try {
388             // Note that mCloseGuard could be null if the constructor threw.
389             if (mCloseGuard != null) {
390                 mCloseGuard.warnIfOpen();
391             }
392             destroy();
393         } finally {
394             super.finalize();
395         }
396     }
397 
398     @Override
getNetwork()399     public Network getNetwork() {
400         return mNetworkAgent.getNetwork();
401     }
402 
expectPreventReconnectReceived(long timeoutMs)403     public void expectPreventReconnectReceived(long timeoutMs) {
404         assertTrue(mPreventReconnectReceived.block(timeoutMs));
405     }
406 
expectDisconnected(long timeoutMs)407     public void expectDisconnected(long timeoutMs) {
408         assertTrue(mDisconnected.block(timeoutMs));
409     }
410 
assertNotDisconnected(long timeoutMs)411     public void assertNotDisconnected(long timeoutMs) {
412         assertFalse(mDisconnected.block(timeoutMs));
413     }
414 
sendLinkProperties(LinkProperties lp)415     public void sendLinkProperties(LinkProperties lp) {
416         mNetworkAgent.sendLinkProperties(lp);
417     }
418 
setStartKeepaliveEvent(int reason)419     public void setStartKeepaliveEvent(int reason) {
420         mStartKeepaliveError = reason;
421     }
422 
setStopKeepaliveEvent(int reason)423     public void setStopKeepaliveEvent(int reason) {
424         mStopKeepaliveError = reason;
425     }
426 
setKeepaliveResponseDelay(long delay)427     public void setKeepaliveResponseDelay(long delay) {
428         mKeepaliveResponseDelay = delay;
429     }
430 
setExpectedKeepaliveSlot(Integer slot)431     public void setExpectedKeepaliveSlot(Integer slot) {
432         mExpectedKeepaliveSlot = slot;
433     }
434 
getNetworkAgent()435     public NetworkAgent getNetworkAgent() {
436         return mNetworkAgent;
437     }
438 
getNetworkAgentConfig()439     public NetworkAgentConfig getNetworkAgentConfig() {
440         return mNetworkAgentConfig;
441     }
442 
getNetworkCapabilities()443     public NetworkCapabilities getNetworkCapabilities() {
444         return mNetworkCapabilities;
445     }
446 
getLegacyType()447     public int getLegacyType() {
448         return mNetworkAgentConfig.getLegacyType();
449     }
450 
getExtraInfo()451     public String getExtraInfo() {
452         return mNetworkAgentConfig.getLegacyExtraInfo();
453     }
454 
getCallbackHistory()455     public @NonNull ArrayTrackRecord<CallbackType>.ReadHead getCallbackHistory() {
456         return mCallbackHistory;
457     }
458 
waitForIdle(long timeoutMs)459     public void waitForIdle(long timeoutMs) {
460         HandlerUtils.waitForIdle(mHandlerThread, timeoutMs);
461     }
462 
463     abstract static class CallbackType {
464         final int mQosCallbackId;
465 
CallbackType(final int qosCallbackId)466         protected CallbackType(final int qosCallbackId) {
467             mQosCallbackId = qosCallbackId;
468         }
469 
470         static class OnQosCallbackRegister extends CallbackType {
471             final QosFilter mFilter;
OnQosCallbackRegister(final int qosCallbackId, final QosFilter filter)472             OnQosCallbackRegister(final int qosCallbackId, final QosFilter filter) {
473                 super(qosCallbackId);
474                 mFilter = filter;
475             }
476 
477             @Override
equals(final Object o)478             public boolean equals(final Object o) {
479                 if (this == o) return true;
480                 if (o == null || getClass() != o.getClass()) return false;
481                 final OnQosCallbackRegister that = (OnQosCallbackRegister) o;
482                 return mQosCallbackId == that.mQosCallbackId
483                         && Objects.equals(mFilter, that.mFilter);
484             }
485 
486             @Override
hashCode()487             public int hashCode() {
488                 return Objects.hash(mQosCallbackId, mFilter);
489             }
490         }
491 
492         static class OnQosCallbackUnregister extends CallbackType {
OnQosCallbackUnregister(final int qosCallbackId)493             OnQosCallbackUnregister(final int qosCallbackId) {
494                 super(qosCallbackId);
495             }
496 
497             @Override
equals(final Object o)498             public boolean equals(final Object o) {
499                 if (this == o) return true;
500                 if (o == null || getClass() != o.getClass()) return false;
501                 final OnQosCallbackUnregister that = (OnQosCallbackUnregister) o;
502                 return mQosCallbackId == that.mQosCallbackId;
503             }
504 
505             @Override
hashCode()506             public int hashCode() {
507                 return Objects.hash(mQosCallbackId);
508             }
509         }
510     }
511 
isBypassableVpn()512     public boolean isBypassableVpn() {
513         return mNetworkAgentConfig.isBypassableVpn();
514     }
515 
516     // Note : Please do not add any new instrumentation here. If you need new instrumentation,
517     // please add it in CSAgentWrapper and use subclasses of CSTest instead of adding more
518     // tools in ConnectivityServiceTest.
519 }
520