1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity;
18 
19 import static android.system.OsConstants.*;
20 
21 import static com.android.net.module.util.NetworkStackConstants.DNS_OVER_TLS_PORT;
22 import static com.android.net.module.util.NetworkStackConstants.ETHER_MTU;
23 import static com.android.net.module.util.NetworkStackConstants.ICMP_HEADER_LEN;
24 import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
25 import static com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN;
26 import static com.android.net.module.util.NetworkStackConstants.IPV6_MIN_MTU;
27 import static com.android.net.module.util.NetworkStackConstants.IPV6_MTU;
28 import static com.android.net.module.util.NetworkStackConstants.IP_MTU;
29 
30 import android.annotation.NonNull;
31 import android.annotation.Nullable;
32 import android.annotation.TargetApi;
33 import android.net.InetAddresses;
34 import android.net.LinkAddress;
35 import android.net.LinkProperties;
36 import android.net.Network;
37 import android.net.RouteInfo;
38 import android.net.TrafficStats;
39 import android.net.shared.PrivateDnsConfig;
40 import android.net.util.NetworkConstants;
41 import android.os.Build;
42 import android.os.SystemClock;
43 import android.system.ErrnoException;
44 import android.system.Os;
45 import android.system.StructTimeval;
46 import android.text.TextUtils;
47 import android.util.Log;
48 import android.util.Pair;
49 
50 import com.android.internal.util.IndentingPrintWriter;
51 import com.android.net.module.util.NetworkStackConstants;
52 
53 import libcore.io.IoUtils;
54 
55 import java.io.Closeable;
56 import java.io.DataInputStream;
57 import java.io.DataOutputStream;
58 import java.io.FileDescriptor;
59 import java.io.IOException;
60 import java.io.InterruptedIOException;
61 import java.net.Inet4Address;
62 import java.net.Inet6Address;
63 import java.net.InetAddress;
64 import java.net.InetSocketAddress;
65 import java.net.NetworkInterface;
66 import java.net.SocketAddress;
67 import java.net.SocketException;
68 import java.net.UnknownHostException;
69 import java.nio.ByteBuffer;
70 import java.nio.charset.StandardCharsets;
71 import java.util.ArrayList;
72 import java.util.Collections;
73 import java.util.HashMap;
74 import java.util.List;
75 import java.util.Map;
76 import java.util.Random;
77 import java.util.concurrent.CountDownLatch;
78 import java.util.concurrent.TimeUnit;
79 
80 import javax.net.ssl.SNIHostName;
81 import javax.net.ssl.SNIServerName;
82 import javax.net.ssl.SSLParameters;
83 import javax.net.ssl.SSLSocket;
84 import javax.net.ssl.SSLSocketFactory;
85 
86 /**
87  * NetworkDiagnostics
88  *
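 * A simple class to diagnose network connectivity fundamentals.  Current
 * checks performed are:
 *     - ICMPv4/v6 echo requests for all routers
 *     - ICMPv6 echo requests to IPv6 default routers from each preferred
 *       global IPv6 source address
 *     - ICMPv4/v6 echo requests for all DNS servers
 *     - DNS UDP queries to all DNS servers
 *     - DNS-over-TLS probes to all DNS servers and to the reachable addresses
 *       of the strict-mode private DNS hostname
 *     ICMP echo requests to routers and DNS servers are sent with an empty
 *     payload and with a payload that fills the MTU (and, for IPv6, also
 *     with one that fills the IPv6 minimum MTU).
 *
 * Currently unimplemented checks include:
 *     - report ARP/ND data about on-link neighbors
 *     - DNS TCP queries to all DNS servers
 *     - HTTP DIRECT and PROXY checks
 *     - port 443 blocking/TLS intercept checks
 *     - QUIC reachability checks
 *     - MTU checks beyond the fixed-size ICMP echo requests above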
102  *
103  * The supplied timeout bounds the entire diagnostic process.  Each specific
104  * check class must implement this upper bound on measurements in whichever
105  * manner is most appropriate and effective.
106  *
107  * @hide
108  */
109 public class NetworkDiagnostics {
    // Log tag; also printed as the header line of dump().
    private static final String TAG = "NetworkDiagnostics";

    // Off-link DNS server addresses. The constructor adds them to the DNS servers under test
    // (when the link appears able to reach them) so that off-link connectivity is exercised.
    private static final InetAddress TEST_DNS4 = InetAddresses.parseNumericAddress("8.8.8.8");
    private static final InetAddress TEST_DNS6 = InetAddresses.parseNumericAddress(
            "2001:4860:4860::8888");
115 
116     // For brevity elsewhere.
now()117     private static final long now() {
118         return SystemClock.elapsedRealtime();
119     }
120 
121     // Values from RFC 1035 section 4.1.1, names from <arpa/nameser.h>.
122     // Should be a member of DnsUdpCheck, but "compiler says no".
123     public static enum DnsResponseCode { NOERROR, FORMERR, SERVFAIL, NXDOMAIN, NOTIMP, REFUSED };
124 
    // Network under test; every measurement socket is bound to it.
    private final Network mNetwork;
    // Link properties of mNetwork. The constructor adds the off-link test DNS servers to it.
    private final LinkProperties mLinkProperties;
    // Private DNS configuration; its hostname and IPs drive the strict-mode DoT probes.
    private final PrivateDnsConfig mPrivateDnsCfg;
    // Index of the link's interface (from NetworkInterface#getIndex), or null if it could not
    // be looked up. Used as the scope ID of link-local IPv6 targets.
    private final Integer mInterfaceIndex;

    // Overall time budget, and its start and end on the now() clock, in milliseconds.
    private final long mTimeoutMs;
    private final long mStartTime;
    private final long mDeadlineTime;

    // A counter, initialized to the total number of measurements,
    // so callers can wait for completion.
    private final CountDownLatch mCountDownLatch;
137 
138     public class Measurement {
139         private static final String SUCCEEDED = "SUCCEEDED";
140         private static final String FAILED = "FAILED";
141 
142         private boolean succeeded;
143 
144         // Package private.  TODO: investigate better encapsulation.
145         String description = "";
146         long startTime;
147         long finishTime;
148         String result = "";
149         Thread thread;
150 
checkSucceeded()151         public boolean checkSucceeded() { return succeeded; }
152 
recordSuccess(String msg)153         void recordSuccess(String msg) {
154             maybeFixupTimes();
155             succeeded = true;
156             result = SUCCEEDED + ": " + msg;
157             if (mCountDownLatch != null) {
158                 mCountDownLatch.countDown();
159             }
160         }
161 
recordFailure(String msg)162         void recordFailure(String msg) {
163             maybeFixupTimes();
164             succeeded = false;
165             result = FAILED + ": " + msg;
166             if (mCountDownLatch != null) {
167                 mCountDownLatch.countDown();
168             }
169         }
170 
maybeFixupTimes()171         private void maybeFixupTimes() {
172             // Allows the caller to just set success/failure and not worry
173             // about also setting the correct finishing time.
174             if (finishTime == 0) { finishTime = now(); }
175 
176             // In cases where, for example, a failure has occurred before the
177             // measurement even began, fixup the start time to reflect as much.
178             if (startTime == 0) { startTime = finishTime; }
179         }
180 
181         @Override
toString()182         public String toString() {
183             return description + ": " + result + " (" + (finishTime - startTime) + "ms)";
184         }
185     }
186 
    // ICMP echo checks, keyed by (target address, ICMP payload length in bytes).
    private final Map<Pair<InetAddress, Integer>, Measurement> mIcmpChecks = new HashMap<>();
    // ICMP echo checks sent from an explicit source address, keyed by (source, target).
    private final Map<Pair<InetAddress, InetAddress>, Measurement> mExplicitSourceIcmpChecks =
            new HashMap<>();
    // DNS-over-UDP checks, keyed by DNS server address.
    private final Map<InetAddress, Measurement> mDnsUdpChecks = new HashMap<>();
    // DNS-over-TLS checks, keyed by DNS server address (a later entry replaces an earlier one).
    private final Map<InetAddress, Measurement> mDnsTlsChecks = new HashMap<>();
    // One-line summary of the network under test, printed by dump().
    private final String mDescription;
193 
194 
NetworkDiagnostics(Network network, LinkProperties lp, @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs)195     public NetworkDiagnostics(Network network, LinkProperties lp,
196             @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs) {
197         mNetwork = network;
198         mLinkProperties = lp;
199         mPrivateDnsCfg = privateDnsCfg;
200         mInterfaceIndex = getInterfaceIndex(mLinkProperties.getInterfaceName());
201         mTimeoutMs = timeoutMs;
202         mStartTime = now();
203         mDeadlineTime = mStartTime + mTimeoutMs;
204 
205         // Hardcode measurements to TEST_DNS4 and TEST_DNS6 in order to test off-link connectivity.
206         // We are free to modify mLinkProperties with impunity because ConnectivityService passes us
207         // a copy and not the original object. It's easier to do it this way because we don't need
208         // to check whether the LinkProperties already contains these DNS servers because
209         // LinkProperties#addDnsServer checks for duplicates.
210         if (mLinkProperties.isReachable(TEST_DNS4)) {
211             mLinkProperties.addDnsServer(TEST_DNS4);
212         }
213         // TODO: we could use mLinkProperties.isReachable(TEST_DNS6) here, because we won't set any
214         // DNS servers for which isReachable() is false, but since this is diagnostic code, be extra
215         // careful.
216         if (mLinkProperties.hasGlobalIpv6Address() || mLinkProperties.hasIpv6DefaultRoute()) {
217             mLinkProperties.addDnsServer(TEST_DNS6);
218         }
219 
220         for (RouteInfo route : mLinkProperties.getRoutes()) {
221             if (route.getType() == RouteInfo.RTN_UNICAST && route.hasGateway()) {
222                 final InetAddress gateway = route.getGateway();
223                 prepareIcmpMeasurements(gateway);
224                 if (route.isIPv6Default()) {
225                     prepareExplicitSourceIcmpMeasurements(gateway);
226                 }
227             }
228         }
229 
230         for (InetAddress nameserver : mLinkProperties.getDnsServers()) {
231             prepareIcmpMeasurements(nameserver);
232             prepareDnsMeasurement(nameserver);
233 
234             // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode,
235             // DoT probes to the DNS servers will fail if certificate validation fails.
236             prepareDnsTlsMeasurement(null /* hostname */, nameserver);
237         }
238 
239         for (InetAddress tlsNameserver : mPrivateDnsCfg.ips) {
240             // Reachability check is necessary since when resolving the strict mode hostname,
241             // NetworkMonitor always queries for both A and AAAA records, even if the network
242             // is IPv4-only or IPv6-only.
243             if (mLinkProperties.isReachable(tlsNameserver)) {
244                 // If there are IPs, there must have been a name that resolved to them.
245                 prepareDnsTlsMeasurement(mPrivateDnsCfg.hostname, tlsNameserver);
246             }
247         }
248 
249         mCountDownLatch = new CountDownLatch(totalMeasurementCount());
250 
251         startMeasurements();
252 
253         mDescription = "ifaces{" + TextUtils.join(",", mLinkProperties.getAllInterfaceNames()) + "}"
254                 + " index{" + mInterfaceIndex + "}"
255                 + " network{" + mNetwork + "}"
256                 + " nethandle{" + mNetwork.getNetworkHandle() + "}";
257     }
258 
getInterfaceIndex(String ifname)259     private static Integer getInterfaceIndex(String ifname) {
260         try {
261             NetworkInterface ni = NetworkInterface.getByName(ifname);
262             return ni.getIndex();
263         } catch (NullPointerException | SocketException e) {
264             return null;
265         }
266     }
267 
socketAddressToString(@onNull SocketAddress sockAddr)268     private static String socketAddressToString(@NonNull SocketAddress sockAddr) {
269         // The default toString() implementation is not the prettiest.
270         InetSocketAddress inetSockAddr = (InetSocketAddress) sockAddr;
271         InetAddress localAddr = inetSockAddr.getAddress();
272         return String.format(
273                 (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"),
274                 localAddr.getHostAddress(), inetSockAddr.getPort());
275     }
276 
getHeaderLen(@onNull InetAddress target)277     private static int getHeaderLen(@NonNull InetAddress target) {
278         // Convert IPv4 mapped v6 address to v4 if any.
279         try {
280             final InetAddress addr = InetAddress.getByAddress(target.getAddress());
281             // An ICMPv6 header is technically 4 bytes, but the implementation in IcmpCheck#run()
282             // will always fill in another 4 bytes padding in the v6 diagnostic packets, so the size
283             // before icmp data is always 8 bytes in the implementation of ICMP diagnostics for both
284             // v4 and v6 packets. Thus, it's fine to use the v4 header size in the length
285             // calculation.
286             if (addr instanceof Inet6Address) {
287                 return IPV6_HEADER_LEN + ICMP_HEADER_LEN;
288             } else {
289                 return IPV4_HEADER_MIN_LEN + ICMP_HEADER_LEN;
290             }
291         } catch (UnknownHostException e) {
292             throw new AssertionError("Create InetAddress fail(" + target + ")", e);
293         }
294     }
295 
prepareIcmpMeasurements(@onNull InetAddress target)296     private void prepareIcmpMeasurements(@NonNull InetAddress target) {
297         int mtu = getMtuForTarget(target);
298         // If getMtuForTarget fails, it doesn't matter what mtu is used because connect can't
299         // succeed anyway
300         if (mtu <= 0) mtu = mLinkProperties.getMtu();
301         if (mtu <= 0) mtu = ETHER_MTU;
302         // Test with different size payload ICMP.
303         // 1. Test with 0 payload.
304         addPayloadIcmpMeasurement(target, 0);
305         final int header = getHeaderLen(target);
306         // 2. Test with full size MTU.
307         addPayloadIcmpMeasurement(target, mtu - header);
308         // 3. If v6, make another measurement with the full v6 min MTU, unless that's what
309         //    was done above.
310         if ((target instanceof Inet6Address) && (mtu != IPV6_MIN_MTU)) {
311             addPayloadIcmpMeasurement(target, IPV6_MIN_MTU - header);
312         }
313     }
314 
addPayloadIcmpMeasurement(@onNull InetAddress target, int payloadLen)315     private void addPayloadIcmpMeasurement(@NonNull InetAddress target, int payloadLen) {
316         // This can happen if the there is no mtu filled(which is 0) in the link property.
317         // The value becomes negative after minus header length.
318         if (payloadLen < 0) return;
319 
320         final Pair<InetAddress, Integer> lenTarget =
321                 new Pair<>(target, Integer.valueOf(payloadLen));
322         if (!mIcmpChecks.containsKey(lenTarget)) {
323             final Measurement measurement = new Measurement();
324             measurement.thread = new Thread(new IcmpCheck(target, payloadLen, measurement));
325             mIcmpChecks.put(lenTarget, measurement);
326         }
327     }
328 
329     /**
330      * Open a socket to the target address and return the mtu from that socket
331      *
332      * If the MTU can't be obtained for some reason (e.g. the target is unreachable) this will
333      * return -1.
334      *
335      * @param target the destination address
336      * @return the mtu to that destination, or -1
337      */
338     // getsockoptInt is S+, but this service code and only installs on S, so it's safe to ignore
339     // the lint warnings by using @TargetApi.
340     @TargetApi(Build.VERSION_CODES.S)
getMtuForTarget(InetAddress target)341     private int getMtuForTarget(InetAddress target) {
342         final int family = target instanceof Inet4Address ? AF_INET : AF_INET6;
343         FileDescriptor socket = null;
344         try {
345             socket = Os.socket(family, SOCK_DGRAM, 0);
346             mNetwork.bindSocket(socket);
347             Os.connect(socket, target, 0);
348             if (family == AF_INET) {
349                 return Os.getsockoptInt(socket, IPPROTO_IP, IP_MTU);
350             } else {
351                 return Os.getsockoptInt(socket, IPPROTO_IPV6, IPV6_MTU);
352             }
353         } catch (ErrnoException | IOException e) {
354             Log.e(TAG, "Can't get MTU for destination " + target, e);
355             return -1;
356         } finally {
357             IoUtils.closeQuietly(socket);
358         }
359     }
360 
prepareExplicitSourceIcmpMeasurements(InetAddress target)361     private void prepareExplicitSourceIcmpMeasurements(InetAddress target) {
362         for (LinkAddress l : mLinkProperties.getLinkAddresses()) {
363             InetAddress source = l.getAddress();
364             if (source instanceof Inet6Address && l.isGlobalPreferred()) {
365                 Pair<InetAddress, InetAddress> srcTarget = new Pair<>(source, target);
366                 if (!mExplicitSourceIcmpChecks.containsKey(srcTarget)) {
367                     Measurement measurement = new Measurement();
368                     measurement.thread = new Thread(new IcmpCheck(source, target, 0, measurement));
369                     mExplicitSourceIcmpChecks.put(srcTarget, measurement);
370                 }
371             }
372         }
373     }
374 
prepareDnsMeasurement(InetAddress target)375     private void prepareDnsMeasurement(InetAddress target) {
376         if (!mDnsUdpChecks.containsKey(target)) {
377             Measurement measurement = new Measurement();
378             measurement.thread = new Thread(new DnsUdpCheck(target, measurement));
379             mDnsUdpChecks.put(target, measurement);
380         }
381     }
382 
prepareDnsTlsMeasurement(@ullable String hostname, @NonNull InetAddress target)383     private void prepareDnsTlsMeasurement(@Nullable String hostname, @NonNull InetAddress target) {
384         // This might overwrite an existing entry in mDnsTlsChecks, because |target| can be an IP
385         // address configured by the network as well as an IP address learned by resolving the
386         // strict mode DNS hostname. If the entry is overwritten, the overwritten measurement
387         // thread will not execute.
388         Measurement measurement = new Measurement();
389         measurement.thread = new Thread(new DnsTlsCheck(hostname, target, measurement));
390         mDnsTlsChecks.put(target, measurement);
391     }
392 
totalMeasurementCount()393     private int totalMeasurementCount() {
394         return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size()
395                 + mDnsTlsChecks.size();
396     }
397 
startMeasurements()398     private void startMeasurements() {
399         for (Measurement measurement : mIcmpChecks.values()) {
400             measurement.thread.start();
401         }
402         for (Measurement measurement : mExplicitSourceIcmpChecks.values()) {
403             measurement.thread.start();
404         }
405         for (Measurement measurement : mDnsUdpChecks.values()) {
406             measurement.thread.start();
407         }
408         for (Measurement measurement : mDnsTlsChecks.values()) {
409             measurement.thread.start();
410         }
411     }
412 
waitForMeasurements()413     public void waitForMeasurements() {
414         try {
415             mCountDownLatch.await(mDeadlineTime - now(), TimeUnit.MILLISECONDS);
416         } catch (InterruptedException ignored) {}
417     }
418 
getMeasurements()419     public List<Measurement> getMeasurements() {
420         // TODO: Consider moving waitForMeasurements() in here to minimize the
421         // chance of caller errors.
422 
423         ArrayList<Measurement> measurements = new ArrayList(totalMeasurementCount());
424 
425         // Sort measurements IPv4 first.
426         for (Map.Entry<Pair<InetAddress, Integer>, Measurement> entry : mIcmpChecks.entrySet()) {
427             if (entry.getKey().first instanceof Inet4Address) {
428                 measurements.add(entry.getValue());
429             }
430         }
431         for (Map.Entry<Pair<InetAddress, InetAddress>, Measurement> entry :
432                 mExplicitSourceIcmpChecks.entrySet()) {
433             if (entry.getKey().first instanceof Inet4Address) {
434                 measurements.add(entry.getValue());
435             }
436         }
437         for (Map.Entry<InetAddress, Measurement> entry : mDnsUdpChecks.entrySet()) {
438             if (entry.getKey() instanceof Inet4Address) {
439                 measurements.add(entry.getValue());
440             }
441         }
442         for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
443             if (entry.getKey() instanceof Inet4Address) {
444                 measurements.add(entry.getValue());
445             }
446         }
447 
448         // IPv6 measurements second.
449         for (Map.Entry<Pair<InetAddress, Integer>, Measurement> entry : mIcmpChecks.entrySet()) {
450             if (entry.getKey().first instanceof Inet6Address) {
451                 measurements.add(entry.getValue());
452             }
453         }
454         for (Map.Entry<Pair<InetAddress, InetAddress>, Measurement> entry :
455                 mExplicitSourceIcmpChecks.entrySet()) {
456             if (entry.getKey().first instanceof Inet6Address) {
457                 measurements.add(entry.getValue());
458             }
459         }
460         for (Map.Entry<InetAddress, Measurement> entry : mDnsUdpChecks.entrySet()) {
461             if (entry.getKey() instanceof Inet6Address) {
462                 measurements.add(entry.getValue());
463             }
464         }
465         for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
466             if (entry.getKey() instanceof Inet6Address) {
467                 measurements.add(entry.getValue());
468             }
469         }
470 
471         return measurements;
472     }
473 
dump(IndentingPrintWriter pw)474     public void dump(IndentingPrintWriter pw) {
475         pw.println(TAG + ":" + mDescription);
476         final long unfinished = mCountDownLatch.getCount();
477         if (unfinished > 0) {
478             // This can't happen unless a caller forgets to call waitForMeasurements()
479             // or a measurement isn't implemented to correctly honor the timeout.
480             pw.println("WARNING: countdown wait incomplete: "
481                     + unfinished + " unfinished measurements");
482         }
483 
484         pw.increaseIndent();
485 
486         String prefix;
487         for (Measurement m : getMeasurements()) {
488             prefix = m.checkSucceeded() ? "." : "F";
489             pw.println(prefix + "  " + m.toString());
490         }
491 
492         pw.decreaseIndent();
493     }
494 
495 
496     private class SimpleSocketCheck implements Closeable {
497         protected final InetAddress mSource;  // Usually null.
498         protected final InetAddress mTarget;
499         protected final int mAddressFamily;
500         protected final Measurement mMeasurement;
501         protected FileDescriptor mFileDescriptor;
502         protected SocketAddress mSocketAddress;
503 
SimpleSocketCheck( InetAddress source, InetAddress target, Measurement measurement)504         protected SimpleSocketCheck(
505                 InetAddress source, InetAddress target, Measurement measurement) {
506             mMeasurement = measurement;
507 
508             if (target instanceof Inet6Address) {
509                 Inet6Address targetWithScopeId = null;
510                 if (target.isLinkLocalAddress() && mInterfaceIndex != null) {
511                     try {
512                         targetWithScopeId = Inet6Address.getByAddress(
513                                 null, target.getAddress(), mInterfaceIndex);
514                     } catch (UnknownHostException e) {
515                         mMeasurement.recordFailure(e.toString());
516                     }
517                 }
518                 mTarget = (targetWithScopeId != null) ? targetWithScopeId : target;
519                 mAddressFamily = AF_INET6;
520             } else {
521                 mTarget = target;
522                 mAddressFamily = AF_INET;
523             }
524 
525             // We don't need to check the scope ID here because we currently only do explicit-source
526             // measurements from global IPv6 addresses.
527             mSource = source;
528         }
529 
SimpleSocketCheck(InetAddress target, Measurement measurement)530         protected SimpleSocketCheck(InetAddress target, Measurement measurement) {
531             this(null, target, measurement);
532         }
533 
setupSocket( int sockType, int protocol, long writeTimeout, long readTimeout, int dstPort)534         protected void setupSocket(
535                 int sockType, int protocol, long writeTimeout, long readTimeout, int dstPort)
536                 throws ErrnoException, IOException {
537             final int oldTag = TrafficStats.getAndSetThreadStatsTag(
538                     NetworkStackConstants.TAG_SYSTEM_PROBE);
539             try {
540                 mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
541             } finally {
542                 // TODO: The tag should remain set until all traffic is sent and received.
543                 // Consider tagging the socket after the measurement thread is started.
544                 TrafficStats.setThreadStatsTag(oldTag);
545             }
546             // Setting SNDTIMEO is purely for defensive purposes.
547             Os.setsockoptTimeval(mFileDescriptor,
548                     SOL_SOCKET, SO_SNDTIMEO, StructTimeval.fromMillis(writeTimeout));
549             Os.setsockoptTimeval(mFileDescriptor,
550                     SOL_SOCKET, SO_RCVTIMEO, StructTimeval.fromMillis(readTimeout));
551             // TODO: Use IP_RECVERR/IPV6_RECVERR, pending OsContants availability.
552             mNetwork.bindSocket(mFileDescriptor);
553             if (mSource != null) {
554                 Os.bind(mFileDescriptor, mSource, 0);
555             }
556             Os.connect(mFileDescriptor, mTarget, dstPort);
557             mSocketAddress = Os.getsockname(mFileDescriptor);
558         }
559 
ensureMeasurementNecessary()560         protected boolean ensureMeasurementNecessary() {
561             if (mMeasurement.finishTime == 0) return false;
562 
563             // Countdown latch was not decremented when the measurement failed during setup.
564             mCountDownLatch.countDown();
565             return true;
566         }
567 
568         @Override
close()569         public void close() {
570             IoUtils.closeQuietly(mFileDescriptor);
571         }
572     }
573 
574 
575     private class IcmpCheck extends SimpleSocketCheck implements Runnable {
576         private static final int TIMEOUT_SEND = 100;
577         private static final int TIMEOUT_RECV = 300;
578         private static final int PACKET_BUFSIZE = 512;
579         private final int mProtocol;
580         private final int mIcmpType;
581         private final int mPayloadSize;
582         // The length parameter is effectively the -s flag to ping/ping6 to specify the number of
583         // data bytes to be sent.
IcmpCheck(InetAddress source, InetAddress target, int length, Measurement measurement)584         IcmpCheck(InetAddress source, InetAddress target, int length, Measurement measurement) {
585 
586             super(source, target, measurement);
587 
588             if (mAddressFamily == AF_INET6) {
589                 mProtocol = IPPROTO_ICMPV6;
590                 mIcmpType = NetworkConstants.ICMPV6_ECHO_REQUEST_TYPE;
591                 mMeasurement.description = "ICMPv6";
592             } else {
593                 mProtocol = IPPROTO_ICMP;
594                 mIcmpType = NetworkConstants.ICMPV4_ECHO_REQUEST_TYPE;
595                 mMeasurement.description = "ICMPv4";
596             }
597             mPayloadSize = length;
598             mMeasurement.description += " payloadLength{" + mPayloadSize  + "}"
599                     + " dst{" + mTarget.getHostAddress() + "}";
600         }
601 
IcmpCheck(InetAddress target, int length, Measurement measurement)602         IcmpCheck(InetAddress target, int length, Measurement measurement) {
603             this(null, target, length, measurement);
604         }
605 
606         @Override
run()607         public void run() {
608             if (ensureMeasurementNecessary()) return;
609 
610             try {
611                 setupSocket(SOCK_DGRAM, mProtocol, TIMEOUT_SEND, TIMEOUT_RECV, 0);
612             } catch (ErrnoException | IOException e) {
613                 mMeasurement.recordFailure(e.toString());
614                 return;
615             }
616             mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}";
617 
618             // Build a trivial ICMP packet.
619             // The v4 ICMP header ICMP_HEADER_LEN (which is 8) and v6 is only 4 bytes (4 bytes
620             // message body followed by header before the payload).
621             // Use 8 bytes for both v4 and v6 for simplicity.
622             final byte[] icmpPacket = new byte[ICMP_HEADER_LEN + mPayloadSize];
623             icmpPacket[0] = (byte) mIcmpType;
624 
625             int count = 0;
626             mMeasurement.startTime = now();
627             while (now() < mDeadlineTime - (TIMEOUT_SEND + TIMEOUT_RECV)) {
628                 count++;
629                 icmpPacket[icmpPacket.length - 1] = (byte) count;
630                 try {
631                     Os.write(mFileDescriptor, icmpPacket, 0, icmpPacket.length);
632                 } catch (ErrnoException | InterruptedIOException e) {
633                     mMeasurement.recordFailure(e.toString());
634                     break;
635                 }
636 
637                 try {
638                     ByteBuffer reply = ByteBuffer.allocate(PACKET_BUFSIZE);
639                     Os.read(mFileDescriptor, reply);
640                     // TODO: send a few pings back to back to guesstimate packet loss.
641                     mMeasurement.recordSuccess("1/" + count);
642                     break;
643                 } catch (ErrnoException | InterruptedIOException e) {
644                     continue;
645                 }
646             }
647             if (mMeasurement.finishTime == 0) {
648                 mMeasurement.recordFailure("0/" + count);
649             }
650 
651             close();
652         }
653     }
654 
655 
656     private class DnsUdpCheck extends SimpleSocketCheck implements Runnable {
        // Socket send/receive timeouts, in milliseconds (passed to setupSocket()).
        private static final int TIMEOUT_SEND = 100;
        private static final int TIMEOUT_RECV = 500;
        // DNS resource record types to query: A (IPv4 address) and AAAA (IPv6 address).
        private static final int RR_TYPE_A = 1;
        private static final int RR_TYPE_AAAA = 28;
        // Size of the buffer a reply is read into.
        private static final int PACKET_BUFSIZE = 512;

        // Source of the six random digits placed into each query (see run()).
        protected final Random mRandom = new Random();
664 
665         // Should be static, but the compiler mocks our puny, human attempts at reason.
responseCodeStr(int rcode)666         protected String responseCodeStr(int rcode) {
667             try {
668                 return DnsResponseCode.values()[rcode].toString();
669             } catch (IndexOutOfBoundsException e) {
670                 return String.valueOf(rcode);
671             }
672         }
673 
674         protected final int mQueryType;
675 
DnsUdpCheck(InetAddress target, Measurement measurement)676         public DnsUdpCheck(InetAddress target, Measurement measurement) {
677             super(target, measurement);
678 
679             // TODO: Ideally, query the target for both types regardless of address family.
680             if (mAddressFamily == AF_INET6) {
681                 mQueryType = RR_TYPE_AAAA;
682             } else {
683                 mQueryType = RR_TYPE_A;
684             }
685 
686             mMeasurement.description = "DNS UDP dst{" + mTarget.getHostAddress() + "}";
687         }
688 
689         @Override
run()690         public void run() {
691             if (ensureMeasurementNecessary()) return;
692 
693             try {
694                 setupSocket(SOCK_DGRAM, IPPROTO_UDP, TIMEOUT_SEND, TIMEOUT_RECV,
695                         NetworkConstants.DNS_SERVER_PORT);
696             } catch (ErrnoException | IOException e) {
697                 mMeasurement.recordFailure(e.toString());
698                 return;
699             }
700 
701             // This needs to be fixed length so it can be dropped into the pre-canned packet.
702             final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
703             appendDnsToMeasurementDescription(sixRandomDigits, mSocketAddress);
704 
705             // Build a trivial DNS packet.
706             final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
707 
708             int count = 0;
709             mMeasurement.startTime = now();
710             while (now() < mDeadlineTime - (TIMEOUT_RECV + TIMEOUT_RECV)) {
711                 count++;
712                 try {
713                     Os.write(mFileDescriptor, dnsPacket, 0, dnsPacket.length);
714                 } catch (ErrnoException | InterruptedIOException e) {
715                     mMeasurement.recordFailure(e.toString());
716                     break;
717                 }
718 
719                 try {
720                     ByteBuffer reply = ByteBuffer.allocate(PACKET_BUFSIZE);
721                     Os.read(mFileDescriptor, reply);
722                     // TODO: more correct and detailed evaluation of the response,
723                     // possibly adding the returned IP address(es) to the output.
724                     final String rcodeStr = (reply.limit() > 3)
725                             ? " " + responseCodeStr((int) (reply.get(3)) & 0x0f)
726                             : "";
727                     mMeasurement.recordSuccess("1/" + count + rcodeStr);
728                     break;
729                 } catch (ErrnoException | InterruptedIOException e) {
730                     continue;
731                 }
732             }
733             if (mMeasurement.finishTime == 0) {
734                 mMeasurement.recordFailure("0/" + count);
735             }
736 
737             close();
738         }
739 
getDnsQueryPacket(String sixRandomDigits)740         protected byte[] getDnsQueryPacket(String sixRandomDigits) {
741             byte[] rnd = sixRandomDigits.getBytes(StandardCharsets.US_ASCII);
742             return new byte[] {
743                 (byte) mRandom.nextInt(), (byte) mRandom.nextInt(),  // [0-1]   query ID
744                 1, 0,  // [2-3]   flags; byte[2] = 1 for recursion desired (RD).
745                 0, 1,  // [4-5]   QDCOUNT (number of queries)
746                 0, 0,  // [6-7]   ANCOUNT (number of answers)
747                 0, 0,  // [8-9]   NSCOUNT (number of name server records)
748                 0, 0,  // [10-11] ARCOUNT (number of additional records)
749                 17, rnd[0], rnd[1], rnd[2], rnd[3], rnd[4], rnd[5],
750                         '-', 'a', 'n', 'd', 'r', 'o', 'i', 'd', '-', 'd', 's',
751                 6, 'm', 'e', 't', 'r', 'i', 'c',
752                 7, 'g', 's', 't', 'a', 't', 'i', 'c',
753                 3, 'c', 'o', 'm',
754                 0,  // null terminator of FQDN (root TLD)
755                 0, (byte) mQueryType,  // QTYPE
756                 0, 1  // QCLASS, set to 1 = IN (Internet)
757             };
758         }
759 
appendDnsToMeasurementDescription( String sixRandomDigits, SocketAddress sockAddr)760         protected void appendDnsToMeasurementDescription(
761                 String sixRandomDigits, SocketAddress sockAddr) {
762             mMeasurement.description += " src{" + socketAddressToString(sockAddr) + "}"
763                     + " qtype{" + mQueryType + "}"
764                     + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
765         }
766     }
767 
    // TODO: Make this inherit directly from SimpleSocketCheck, and move the common DNS helpers out
    // of DnsUdpCheck.
770     private class DnsTlsCheck extends DnsUdpCheck {
771         private static final int TCP_CONNECT_TIMEOUT_MS = 2500;
772         private static final int TCP_TIMEOUT_MS = 2000;
773         private static final int DNS_HEADER_SIZE = 12;
774 
775         private final String mHostname;
776 
DnsTlsCheck(@ullable String hostname, @NonNull InetAddress target, @NonNull Measurement measurement)777         public DnsTlsCheck(@Nullable String hostname, @NonNull InetAddress target,
778                 @NonNull Measurement measurement) {
779             super(target, measurement);
780 
781             mHostname = hostname;
782             mMeasurement.description = "DNS TLS dst{" + mTarget.getHostAddress() + "} hostname{"
783                     + (mHostname == null ? "" : mHostname) + "}";
784         }
785 
setupSSLSocket()786         private SSLSocket setupSSLSocket() throws IOException {
787             // A TrustManager will be created and initialized with a KeyStore containing system
788             // CaCerts. During SSL handshake, it will be used to validate the certificates from
789             // the server.
790             SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket();
791             sslSocket.setSoTimeout(TCP_TIMEOUT_MS);
792 
793             if (!TextUtils.isEmpty(mHostname)) {
794                 // Set SNI.
795                 final List<SNIServerName> names =
796                         Collections.singletonList(new SNIHostName(mHostname));
797                 SSLParameters params = sslSocket.getSSLParameters();
798                 params.setServerNames(names);
799                 sslSocket.setSSLParameters(params);
800             }
801 
802             mNetwork.bindSocket(sslSocket);
803             return sslSocket;
804         }
805 
sendDoTProbe(@ullable SSLSocket sslSocket)806         private void sendDoTProbe(@Nullable SSLSocket sslSocket) throws IOException {
807             final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
808             final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
809 
810             mMeasurement.startTime = now();
811             sslSocket.connect(new InetSocketAddress(mTarget, DNS_OVER_TLS_PORT),
812                     TCP_CONNECT_TIMEOUT_MS);
813 
814             // Synchronous call waiting for the TLS handshake complete.
815             sslSocket.startHandshake();
816             appendDnsToMeasurementDescription(sixRandomDigits, sslSocket.getLocalSocketAddress());
817 
818             final DataOutputStream output = new DataOutputStream(sslSocket.getOutputStream());
819             output.writeShort(dnsPacket.length);
820             output.write(dnsPacket, 0, dnsPacket.length);
821 
822             final DataInputStream input = new DataInputStream(sslSocket.getInputStream());
823             final int replyLength = Short.toUnsignedInt(input.readShort());
824             final byte[] reply = new byte[replyLength];
825             int bytesRead = 0;
826             while (bytesRead < replyLength) {
827                 bytesRead += input.read(reply, bytesRead, replyLength - bytesRead);
828             }
829 
830             if (bytesRead > DNS_HEADER_SIZE && bytesRead == replyLength) {
831                 mMeasurement.recordSuccess("1/1 " + responseCodeStr((int) (reply[3]) & 0x0f));
832             } else {
833                 mMeasurement.recordFailure("1/1 Read " + bytesRead + " bytes while expected to be "
834                         + replyLength + " bytes");
835             }
836         }
837 
838         @Override
run()839         public void run() {
840             if (ensureMeasurementNecessary()) return;
841 
842             // No need to restore the tag, since this thread is only used for this measurement.
843             TrafficStats.getAndSetThreadStatsTag(NetworkStackConstants.TAG_SYSTEM_PROBE);
844 
845             try (SSLSocket sslSocket = setupSSLSocket()) {
846                 sendDoTProbe(sslSocket);
847             } catch (IOException e) {
848                 mMeasurement.recordFailure(e.toString());
849             }
850         }
851     }
852 }
853