1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.android.iwlan.epdg;
18 
19 import android.annotation.CallbackExecutor;
20 import android.annotation.NonNull;
21 import android.annotation.Nullable;
22 import android.net.DnsResolver;
23 import android.net.DnsResolver.DnsException;
24 import android.net.Network;
25 import android.net.ParseException;
26 import android.os.CancellationSignal;
27 import android.util.Log;
28 
29 import com.android.net.module.util.DnsPacket;
30 import com.android.net.module.util.DnsPacketUtils.DnsRecordParser;
31 
32 import java.net.InetAddress;
33 import java.net.UnknownHostException;
34 import java.nio.BufferUnderflowException;
35 import java.nio.ByteBuffer;
36 import java.util.ArrayList;
37 import java.util.HashMap;
38 import java.util.Iterator;
39 import java.util.LinkedHashMap;
40 import java.util.List;
41 import java.util.Map;
42 import java.util.concurrent.CompletableFuture;
43 import java.util.concurrent.ExecutionException;
44 import java.util.concurrent.Executor;
45 import java.util.concurrent.Executors;
46 
47 /**
48  * A utility wrapper around android.net.DnsResolver that queries for SRV DNS Resource Records, and
49  * returns in the user callback a list of server (IP addresses, port number) combinations pertaining
50  * to the service requested.
51  *
 * <p>The returned {@code List<SrvRecordInetAddress>} is currently not sorted by priority and
 * weight using the mechanism described in RFC 2782.
54  */
55 final class SrvDnsResolver {
56     private static final String TAG = "SrvDnsResolver";
57 
58     /**
59      * An SRV Resource Record is queried to obtain the specific port number at which a service is
60      * offered. So the client is returned a combination of (INetAddress, port).
61      */
62     static class SrvRecordInetAddress {
63         // Holds an IPv4/v6 address, obtained by querying getHostAddress().
64         public final InetAddress mInetAddress;
65         // A 16-bit unsigned port number.
66         public final int mPort;
67 
SrvRecordInetAddress(InetAddress inetAddress, int port)68         public SrvRecordInetAddress(InetAddress inetAddress, int port) {
69             mInetAddress = inetAddress;
70             mPort = port;
71         }
72     }
73 
74     // Since the query type for SRV records is not defined in DnsResolver, it is defined here.
75     static final int QUERY_TYPE_SRV = 33;
76 
77     /*
78      * Parses and stores an SRV record as described in RFC2782.
79      *
80      * Expects records of type QUERY_TYPE_SRV in the Queries and Answer records section, and records
81      * of type TYPE_A and TYPE_AAAA in the Additional Records section of the DnsPacket.
82      */
83     static class SrvResponse extends DnsPacket {
84         static class SrvRecord {
85             // A 16-bit unsigned integer that determines the priority of the target host. Clients
86             // must attempt to contact the target host with the lowest-numbered priority first.
87             public final int priority;
88 
89             // A 16-bit unsigned integer that specifies a relative weight for entries with the same
90             // priority. Larger weights should we given a proportionately higher probability of
91             // being selected.
92             public final int weight;
93 
94             // A 16-bit unsigned integer that specifies the port on this target for this service.
95             public final int port;
96 
97             // The domain name of the target host. A target of "." means that the service is
98             // decidedly not available at this domain.
99             public final String target;
100 
101             private static final int MAXNAMESIZE = 255;
102 
SrvRecord(byte[] srvRecordData)103             SrvRecord(byte[] srvRecordData) throws ParseException {
104                 final ByteBuffer buf = ByteBuffer.wrap(srvRecordData);
105 
106                 try {
107                     priority = Short.toUnsignedInt(buf.getShort());
108                     weight = Short.toUnsignedInt(buf.getShort());
109                     port = Short.toUnsignedInt(buf.getShort());
110                     // Although unexpected, some DNS servers do use name compression on portions of
111                     // the 'target' field that overlap with the query section of the DNS packet.
112                     target =
113                             DnsRecordParser.parseName(
114                                     buf, 0, /* isNameCompressionSupported */ true);
115                     if (target.length() > MAXNAMESIZE) {
116                         throw new ParseException(
117                                 "Parse name failed, name size is too long: " + target.length());
118                     }
119                     if (buf.hasRemaining()) {
120                         throw new ParseException(
121                                 "Parsing SRV record data failed: more bytes than expected!");
122                     }
123                 } catch (BufferUnderflowException e) {
124                     throw new ParseException("Parsing SRV Record data failed with cause", e);
125                 }
126             }
127         }
128 
129         private final int mQueryType;
130 
SrvResponse(@onNull byte[] data)131         SrvResponse(@NonNull byte[] data) throws ParseException {
132             super(data);
133             if (!mHeader.isResponse()) {
134                 throw new ParseException("Not an answer packet");
135             }
136             int numQueries = mHeader.getRecordCount(QDSECTION);
137             // Expects exactly one query in query section.
138             if (numQueries != 1) {
139                 throw new ParseException("Unexpected query count: " + numQueries);
140             }
141             mQueryType = mRecords[QDSECTION].get(0).nsType;
142             if (mQueryType != QUERY_TYPE_SRV) {
143                 throw new ParseException("Unexpected query type: " + mQueryType);
144             }
145         }
146 
147         // Parses the Answers section of a DnsPacket to construct and return a mapping
148         // of Domain Name strings to their corresponding SRV record.
parseSrvRecords()149         public @NonNull Map<String, SrvRecord> parseSrvRecords() throws ParseException {
150             final HashMap<String, SrvRecord> targetNameToSrvRecord = new LinkedHashMap<>();
151             if (mHeader.getRecordCount(ANSECTION) == 0) return targetNameToSrvRecord;
152 
153             for (final DnsRecord ansSec : mRecords[ANSECTION]) {
154                 final int nsType = ansSec.nsType;
155                 if (nsType != QUERY_TYPE_SRV) {
156                     throw new ParseException("Unexpected DNS record type in ANSECTION: " + nsType);
157                 }
158                 final SrvRecord record = new SrvRecord(ansSec.getRR());
159                 if (targetNameToSrvRecord.containsKey(record.target)) {
160                     throw new ParseException(
161                             "Domain name "
162                                     + record.target
163                                     + " already encountered in DNS response!");
164                 }
165                 targetNameToSrvRecord.put(record.target, record);
166                 Log.d(TAG, "SrvRecord name: " + ansSec.dName + " target name: " + record.target);
167             }
168             return targetNameToSrvRecord;
169         }
170 
171         /*
172          * Parses the 'Additional Records' section of a DnsPacket and expects 'Address Records'
173          * (TYPE_A and TYPE_AAAA records) to construct and return a mapping of Domain Name strings
174          * to their corresponding IP address(es).
175          */
parseIpAddresses()176         public @NonNull Map<String, List<InetAddress>> parseIpAddresses() throws ParseException {
177             final HashMap<String, List<InetAddress>> domainNameToIpAddress = new HashMap<>();
178             if (mHeader.getRecordCount(ARSECTION) == 0) return domainNameToIpAddress;
179 
180             for (final DnsRecord ansSec : mRecords[ARSECTION]) {
181                 int nsType = ansSec.nsType;
182                 if (nsType != DnsResolver.TYPE_A && nsType != DnsResolver.TYPE_AAAA) {
183                     throw new ParseException("Unexpected DNS record type in ARSECTION: " + nsType);
184                 }
185                 domainNameToIpAddress.computeIfAbsent(ansSec.dName, k -> new ArrayList<>());
186                 try {
187                     final InetAddress ipAddress = InetAddress.getByAddress(ansSec.getRR());
188                     Log.d(
189                             TAG,
190                             "Additional record name: "
191                                     + ansSec.dName
192                                     + " IP addr: "
193                                     + ipAddress.getHostAddress());
194                     domainNameToIpAddress.get(ansSec.dName).add(ipAddress);
195                 } catch (UnknownHostException e) {
196                     throw new ParseException(
197                             "RR to IP address translation failed for domain: " + ansSec.dName);
198                 }
199             }
200             return domainNameToIpAddress;
201         }
202     }
203 
204     /**
205      * A decorator for {@link DnsResolver.Callback} that accumulates IPv4/v6 responses for SRV DNS
206      * queries and passes it up to the user callback.
207      */
208     private static class SrvRecordAnswerAccumulator implements DnsResolver.Callback<byte[]> {
209         private static final String TAG = "SrvRecordAnswerAccum";
210 
211         private final Network mNetwork;
212         private final DnsResolver.Callback<List<SrvRecordInetAddress>> mUserCallback;
213         private final Executor mUserExecutor;
214 
215         private static class LazyExecutor {
216             public static final Executor INSTANCE = Executors.newSingleThreadExecutor();
217         }
218 
getInternalExecutor()219         static Executor getInternalExecutor() {
220             return LazyExecutor.INSTANCE;
221         }
222 
SrvRecordAnswerAccumulator( @onNull Network network, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback, @NonNull @CallbackExecutor Executor executor)223         SrvRecordAnswerAccumulator(
224                 @NonNull Network network,
225                 @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback,
226                 @NonNull @CallbackExecutor Executor executor) {
227             mNetwork = network;
228             mUserCallback = callback;
229             mUserExecutor = executor;
230         }
231 
232         /**
233          * Some DNS servers, when queried for an SRV record, do not return the IPv4/v6 records along
234          * with the SRV record. For those, we perform an additional blocking IPv4/v6 DNS query for
235          * each outstanding SRV record.
236          */
queryDns(String domainName)237         private List<InetAddress> queryDns(String domainName) throws DnsException {
238             final CompletableFuture<List<InetAddress>> result = new CompletableFuture();
239             final DnsResolver.Callback<List<InetAddress>> cb =
240                     new DnsResolver.Callback<List<InetAddress>>() {
241                         @Override
242                         public void onAnswer(
243                                 @NonNull final List<InetAddress> answer, final int rcode) {
244                             if (rcode != 0) {
245                                 Log.e(TAG, "queryDNS Response Code = " + rcode);
246                             }
247                             result.complete(answer);
248                         }
249 
250                         @Override
251                         public void onError(@Nullable final DnsException error) {
252                             Log.e(TAG, "queryDNS response with error : " + error);
253                             result.completeExceptionally(error);
254                         }
255                     };
256             DnsResolver.getInstance()
257                     .query(mNetwork, domainName, DnsResolver.FLAG_EMPTY, Runnable::run, null, cb);
258 
259             try {
260                 return result.get();
261             } catch (ExecutionException e) {
262                 throw (DnsException) e.getCause();
263             } catch (InterruptedException e) {
264                 Thread.currentThread().interrupt(); // Restore the interrupted status
265                 throw new DnsException(DnsResolver.ERROR_SYSTEM, e);
266             }
267         }
268 
269         /**
270          * Composes the final (IP address, Port) combination for the client's SRV request. Performs
271          * additional DNS queries if necessary. The SRV records are presently not sorted according
272          * to priority and weight, as described in RFC2782- this is simply 'good enough'.
273          */
composeSrvRecordResult(SrvResponse response)274         private List<SrvRecordInetAddress> composeSrvRecordResult(SrvResponse response)
275                 throws DnsPacket.ParseException, DnsException {
276             final List<SrvRecordInetAddress> srvRecordInetAddresses = new ArrayList<>();
277             final Map<String, List<InetAddress>> domainNameToIpAddresses =
278                     response.parseIpAddresses();
279             final Map<String, SrvResponse.SrvRecord> targetNameToSrvRecords =
280                     response.parseSrvRecords();
281 
282             Iterator<Map.Entry<String, SrvResponse.SrvRecord>> itr =
283                     targetNameToSrvRecords.entrySet().iterator();
284 
285             // Checks if the received SRV RRs have a corresponding match in IP addresses. For the
286             // ones that do, adds the (IP address, port number) to the output field list.
287             while (itr.hasNext()) {
288                 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next();
289                 String domainName = targetNameToSrvRecord.getKey();
290                 int port = targetNameToSrvRecord.getValue().port;
291                 List<InetAddress> addresses = domainNameToIpAddresses.get(domainName);
292                 if (addresses != null) {
293                     // Found a match- add to output list and remove entry from SrvRecord collection.
294                     for (InetAddress address : addresses) {
295                         srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port));
296                     }
297                     itr.remove();
298                 }
299             }
300 
301             // For the SRV RRs that don't, spawns a separate DnsResolver query for each, and
302             // collects results using a blocking call.
303             itr = targetNameToSrvRecords.entrySet().iterator();
304             while (itr.hasNext()) {
305                 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next();
306                 String domainName = targetNameToSrvRecord.getKey();
307                 int port = targetNameToSrvRecord.getValue().port;
308                 List<InetAddress> addresses = queryDns(domainName);
309                 for (InetAddress address : addresses) {
310                     srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port));
311                 }
312             }
313             return srvRecordInetAddresses;
314         }
315 
316         @Override
onAnswer(@onNull byte[] answer, int rcode)317         public void onAnswer(@NonNull byte[] answer, int rcode) {
318             try {
319                 final SrvResponse response = new SrvResponse(answer);
320                 final List<SrvRecordInetAddress> result = composeSrvRecordResult(response);
321                 mUserExecutor.execute(() -> mUserCallback.onAnswer(result, rcode));
322             } catch (DnsPacket.ParseException e) {
323                 // Convert the com.android.net.module.util.DnsPacket.ParseException to an
324                 // android.net.ParseException. This is the type that was used in Q and is implied
325                 // by the public documentation of ERROR_PARSE.
326                 //
327                 // DnsPacket cannot throw android.net.ParseException directly because it's @hide.
328                 final ParseException pe = new ParseException(e.reason, e.getCause());
329                 pe.setStackTrace(e.getStackTrace());
330                 Log.e(TAG, "ParseException", pe);
331                 mUserExecutor.execute(
332                         () -> mUserCallback.onError(new DnsException(DnsResolver.ERROR_PARSE, pe)));
333             } catch (DnsException e) {
334                 mUserExecutor.execute(() -> mUserCallback.onError(e));
335             }
336         }
337 
338         @Override
onError(@onNull DnsException error)339         public void onError(@NonNull DnsException error) {
340             Log.e(TAG, "onError: " + error);
341             mUserExecutor.execute(() -> mUserCallback.onError(error));
342         }
343     }
344 
345     /**
346      * Send an SRV DNS query with the specified name, class and query type. The answer will be
347      * provided asynchronously on the passed executor, through the provided {@link
348      * DnsResolver.Callback}.
349      *
350      * @param network {@link Network} specifying which network to query on. {@code null} for query
351      *     on default network.
352      * @param domain SRV domain name to query ( in format _Service._Protocol.Name)
353      * @param cancellationSignal used by the caller to signal if the query should be cancelled. May
354      *     be {@code null}.
355      * @param callback a {@link DnsResolver.Callback} which will be called on a separate thread to
356      *     notify the caller of the result of the DNS query.
357      */
query( @ullable Network network, @NonNull String domain, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback)358     public static void query(
359             @Nullable Network network,
360             @NonNull String domain,
361             @NonNull @CallbackExecutor Executor executor,
362             @Nullable CancellationSignal cancellationSignal,
363             @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback) {
364         final SrvRecordAnswerAccumulator srvDnsCb =
365                 new SrvRecordAnswerAccumulator(network, callback, executor);
366         DnsResolver.getInstance()
367                 .rawQuery(
368                         network,
369                         domain,
370                         DnsResolver.CLASS_IN,
371                         QUERY_TYPE_SRV,
372                         DnsResolver.FLAG_EMPTY,
373                         SrvRecordAnswerAccumulator.getInternalExecutor(),
374                         cancellationSignal,
375                         srvDnsCb);
376     }
377 
SrvDnsResolver()378     private SrvDnsResolver() {}
379 }
380