1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.android.iwlan.epdg; 18 19 import android.annotation.CallbackExecutor; 20 import android.annotation.NonNull; 21 import android.annotation.Nullable; 22 import android.net.DnsResolver; 23 import android.net.DnsResolver.DnsException; 24 import android.net.Network; 25 import android.net.ParseException; 26 import android.os.CancellationSignal; 27 import android.util.Log; 28 29 import com.android.net.module.util.DnsPacket; 30 import com.android.net.module.util.DnsPacketUtils.DnsRecordParser; 31 32 import java.net.InetAddress; 33 import java.net.UnknownHostException; 34 import java.nio.BufferUnderflowException; 35 import java.nio.ByteBuffer; 36 import java.util.ArrayList; 37 import java.util.HashMap; 38 import java.util.Iterator; 39 import java.util.LinkedHashMap; 40 import java.util.List; 41 import java.util.Map; 42 import java.util.concurrent.CompletableFuture; 43 import java.util.concurrent.ExecutionException; 44 import java.util.concurrent.Executor; 45 import java.util.concurrent.Executors; 46 47 /** 48 * A utility wrapper around android.net.DnsResolver that queries for SRV DNS Resource Records, and 49 * returns in the user callback a list of server (IP addresses, port number) combinations pertaining 50 * to the service requested. 51 * 52 * <p>The returned {@link List<SrvRecordInetAddress>} is currently not sorted according to priority 53 * and weight, in the mechanism described in RFC2782. 54 */ 55 final class SrvDnsResolver { 56 private static final String TAG = "SrvDnsResolver"; 57 58 /** 59 * An SRV Resource Record is queried to obtain the specific port number at which a service is 60 * offered. So the client is returned a combination of (INetAddress, port). 61 */ 62 static class SrvRecordInetAddress { 63 // Holds an IPv4/v6 address, obtained by querying getHostAddress(). 64 public final InetAddress mInetAddress; 65 // A 16-bit unsigned port number. 66 public final int mPort; 67 SrvRecordInetAddress(InetAddress inetAddress, int port)68 public SrvRecordInetAddress(InetAddress inetAddress, int port) { 69 mInetAddress = inetAddress; 70 mPort = port; 71 } 72 } 73 74 // Since the query type for SRV records is not defined in DnsResolver, it is defined here. 75 static final int QUERY_TYPE_SRV = 33; 76 77 /* 78 * Parses and stores an SRV record as described in RFC2782. 79 * 80 * Expects records of type QUERY_TYPE_SRV in the Queries and Answer records section, and records 81 * of type TYPE_A and TYPE_AAAA in the Additional Records section of the DnsPacket. 82 */ 83 static class SrvResponse extends DnsPacket { 84 static class SrvRecord { 85 // A 16-bit unsigned integer that determines the priority of the target host. Clients 86 // must attempt to contact the target host with the lowest-numbered priority first. 87 public final int priority; 88 89 // A 16-bit unsigned integer that specifies a relative weight for entries with the same 90 // priority. Larger weights should we given a proportionately higher probability of 91 // being selected. 92 public final int weight; 93 94 // A 16-bit unsigned integer that specifies the port on this target for this service. 95 public final int port; 96 97 // The domain name of the target host. A target of "." means that the service is 98 // decidedly not available at this domain. 99 public final String target; 100 101 private static final int MAXNAMESIZE = 255; 102 SrvRecord(byte[] srvRecordData)103 SrvRecord(byte[] srvRecordData) throws ParseException { 104 final ByteBuffer buf = ByteBuffer.wrap(srvRecordData); 105 106 try { 107 priority = Short.toUnsignedInt(buf.getShort()); 108 weight = Short.toUnsignedInt(buf.getShort()); 109 port = Short.toUnsignedInt(buf.getShort()); 110 // Although unexpected, some DNS servers do use name compression on portions of 111 // the 'target' field that overlap with the query section of the DNS packet. 112 target = 113 DnsRecordParser.parseName( 114 buf, 0, /* isNameCompressionSupported */ true); 115 if (target.length() > MAXNAMESIZE) { 116 throw new ParseException( 117 "Parse name failed, name size is too long: " + target.length()); 118 } 119 if (buf.hasRemaining()) { 120 throw new ParseException( 121 "Parsing SRV record data failed: more bytes than expected!"); 122 } 123 } catch (BufferUnderflowException e) { 124 throw new ParseException("Parsing SRV Record data failed with cause", e); 125 } 126 } 127 } 128 129 private final int mQueryType; 130 SrvResponse(@onNull byte[] data)131 SrvResponse(@NonNull byte[] data) throws ParseException { 132 super(data); 133 if (!mHeader.isResponse()) { 134 throw new ParseException("Not an answer packet"); 135 } 136 int numQueries = mHeader.getRecordCount(QDSECTION); 137 // Expects exactly one query in query section. 138 if (numQueries != 1) { 139 throw new ParseException("Unexpected query count: " + numQueries); 140 } 141 mQueryType = mRecords[QDSECTION].get(0).nsType; 142 if (mQueryType != QUERY_TYPE_SRV) { 143 throw new ParseException("Unexpected query type: " + mQueryType); 144 } 145 } 146 147 // Parses the Answers section of a DnsPacket to construct and return a mapping 148 // of Domain Name strings to their corresponding SRV record. parseSrvRecords()149 public @NonNull Map<String, SrvRecord> parseSrvRecords() throws ParseException { 150 final HashMap<String, SrvRecord> targetNameToSrvRecord = new LinkedHashMap<>(); 151 if (mHeader.getRecordCount(ANSECTION) == 0) return targetNameToSrvRecord; 152 153 for (final DnsRecord ansSec : mRecords[ANSECTION]) { 154 final int nsType = ansSec.nsType; 155 if (nsType != QUERY_TYPE_SRV) { 156 throw new ParseException("Unexpected DNS record type in ANSECTION: " + nsType); 157 } 158 final SrvRecord record = new SrvRecord(ansSec.getRR()); 159 if (targetNameToSrvRecord.containsKey(record.target)) { 160 throw new ParseException( 161 "Domain name " 162 + record.target 163 + " already encountered in DNS response!"); 164 } 165 targetNameToSrvRecord.put(record.target, record); 166 Log.d(TAG, "SrvRecord name: " + ansSec.dName + " target name: " + record.target); 167 } 168 return targetNameToSrvRecord; 169 } 170 171 /* 172 * Parses the 'Additional Records' section of a DnsPacket and expects 'Address Records' 173 * (TYPE_A and TYPE_AAAA records) to construct and return a mapping of Domain Name strings 174 * to their corresponding IP address(es). 175 */ parseIpAddresses()176 public @NonNull Map<String, List<InetAddress>> parseIpAddresses() throws ParseException { 177 final HashMap<String, List<InetAddress>> domainNameToIpAddress = new HashMap<>(); 178 if (mHeader.getRecordCount(ARSECTION) == 0) return domainNameToIpAddress; 179 180 for (final DnsRecord ansSec : mRecords[ARSECTION]) { 181 int nsType = ansSec.nsType; 182 if (nsType != DnsResolver.TYPE_A && nsType != DnsResolver.TYPE_AAAA) { 183 throw new ParseException("Unexpected DNS record type in ARSECTION: " + nsType); 184 } 185 domainNameToIpAddress.computeIfAbsent(ansSec.dName, k -> new ArrayList<>()); 186 try { 187 final InetAddress ipAddress = InetAddress.getByAddress(ansSec.getRR()); 188 Log.d( 189 TAG, 190 "Additional record name: " 191 + ansSec.dName 192 + " IP addr: " 193 + ipAddress.getHostAddress()); 194 domainNameToIpAddress.get(ansSec.dName).add(ipAddress); 195 } catch (UnknownHostException e) { 196 throw new ParseException( 197 "RR to IP address translation failed for domain: " + ansSec.dName); 198 } 199 } 200 return domainNameToIpAddress; 201 } 202 } 203 204 /** 205 * A decorator for {@link DnsResolver.Callback} that accumulates IPv4/v6 responses for SRV DNS 206 * queries and passes it up to the user callback. 207 */ 208 private static class SrvRecordAnswerAccumulator implements DnsResolver.Callback<byte[]> { 209 private static final String TAG = "SrvRecordAnswerAccum"; 210 211 private final Network mNetwork; 212 private final DnsResolver.Callback<List<SrvRecordInetAddress>> mUserCallback; 213 private final Executor mUserExecutor; 214 215 private static class LazyExecutor { 216 public static final Executor INSTANCE = Executors.newSingleThreadExecutor(); 217 } 218 getInternalExecutor()219 static Executor getInternalExecutor() { 220 return LazyExecutor.INSTANCE; 221 } 222 SrvRecordAnswerAccumulator( @onNull Network network, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback, @NonNull @CallbackExecutor Executor executor)223 SrvRecordAnswerAccumulator( 224 @NonNull Network network, 225 @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback, 226 @NonNull @CallbackExecutor Executor executor) { 227 mNetwork = network; 228 mUserCallback = callback; 229 mUserExecutor = executor; 230 } 231 232 /** 233 * Some DNS servers, when queried for an SRV record, do not return the IPv4/v6 records along 234 * with the SRV record. For those, we perform an additional blocking IPv4/v6 DNS query for 235 * each outstanding SRV record. 236 */ queryDns(String domainName)237 private List<InetAddress> queryDns(String domainName) throws DnsException { 238 final CompletableFuture<List<InetAddress>> result = new CompletableFuture(); 239 final DnsResolver.Callback<List<InetAddress>> cb = 240 new DnsResolver.Callback<List<InetAddress>>() { 241 @Override 242 public void onAnswer( 243 @NonNull final List<InetAddress> answer, final int rcode) { 244 if (rcode != 0) { 245 Log.e(TAG, "queryDNS Response Code = " + rcode); 246 } 247 result.complete(answer); 248 } 249 250 @Override 251 public void onError(@Nullable final DnsException error) { 252 Log.e(TAG, "queryDNS response with error : " + error); 253 result.completeExceptionally(error); 254 } 255 }; 256 DnsResolver.getInstance() 257 .query(mNetwork, domainName, DnsResolver.FLAG_EMPTY, Runnable::run, null, cb); 258 259 try { 260 return result.get(); 261 } catch (ExecutionException e) { 262 throw (DnsException) e.getCause(); 263 } catch (InterruptedException e) { 264 Thread.currentThread().interrupt(); // Restore the interrupted status 265 throw new DnsException(DnsResolver.ERROR_SYSTEM, e); 266 } 267 } 268 269 /** 270 * Composes the final (IP address, Port) combination for the client's SRV request. Performs 271 * additional DNS queries if necessary. The SRV records are presently not sorted according 272 * to priority and weight, as described in RFC2782- this is simply 'good enough'. 273 */ composeSrvRecordResult(SrvResponse response)274 private List<SrvRecordInetAddress> composeSrvRecordResult(SrvResponse response) 275 throws DnsPacket.ParseException, DnsException { 276 final List<SrvRecordInetAddress> srvRecordInetAddresses = new ArrayList<>(); 277 final Map<String, List<InetAddress>> domainNameToIpAddresses = 278 response.parseIpAddresses(); 279 final Map<String, SrvResponse.SrvRecord> targetNameToSrvRecords = 280 response.parseSrvRecords(); 281 282 Iterator<Map.Entry<String, SrvResponse.SrvRecord>> itr = 283 targetNameToSrvRecords.entrySet().iterator(); 284 285 // Checks if the received SRV RRs have a corresponding match in IP addresses. For the 286 // ones that do, adds the (IP address, port number) to the output field list. 287 while (itr.hasNext()) { 288 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next(); 289 String domainName = targetNameToSrvRecord.getKey(); 290 int port = targetNameToSrvRecord.getValue().port; 291 List<InetAddress> addresses = domainNameToIpAddresses.get(domainName); 292 if (addresses != null) { 293 // Found a match- add to output list and remove entry from SrvRecord collection. 294 for (InetAddress address : addresses) { 295 srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port)); 296 } 297 itr.remove(); 298 } 299 } 300 301 // For the SRV RRs that don't, spawns a separate DnsResolver query for each, and 302 // collects results using a blocking call. 303 itr = targetNameToSrvRecords.entrySet().iterator(); 304 while (itr.hasNext()) { 305 Map.Entry<String, SrvResponse.SrvRecord> targetNameToSrvRecord = itr.next(); 306 String domainName = targetNameToSrvRecord.getKey(); 307 int port = targetNameToSrvRecord.getValue().port; 308 List<InetAddress> addresses = queryDns(domainName); 309 for (InetAddress address : addresses) { 310 srvRecordInetAddresses.add(new SrvRecordInetAddress(address, port)); 311 } 312 } 313 return srvRecordInetAddresses; 314 } 315 316 @Override onAnswer(@onNull byte[] answer, int rcode)317 public void onAnswer(@NonNull byte[] answer, int rcode) { 318 try { 319 final SrvResponse response = new SrvResponse(answer); 320 final List<SrvRecordInetAddress> result = composeSrvRecordResult(response); 321 mUserExecutor.execute(() -> mUserCallback.onAnswer(result, rcode)); 322 } catch (DnsPacket.ParseException e) { 323 // Convert the com.android.net.module.util.DnsPacket.ParseException to an 324 // android.net.ParseException. This is the type that was used in Q and is implied 325 // by the public documentation of ERROR_PARSE. 326 // 327 // DnsPacket cannot throw android.net.ParseException directly because it's @hide. 328 final ParseException pe = new ParseException(e.reason, e.getCause()); 329 pe.setStackTrace(e.getStackTrace()); 330 Log.e(TAG, "ParseException", pe); 331 mUserExecutor.execute( 332 () -> mUserCallback.onError(new DnsException(DnsResolver.ERROR_PARSE, pe))); 333 } catch (DnsException e) { 334 mUserExecutor.execute(() -> mUserCallback.onError(e)); 335 } 336 } 337 338 @Override onError(@onNull DnsException error)339 public void onError(@NonNull DnsException error) { 340 Log.e(TAG, "onError: " + error); 341 mUserExecutor.execute(() -> mUserCallback.onError(error)); 342 } 343 } 344 345 /** 346 * Send an SRV DNS query with the specified name, class and query type. The answer will be 347 * provided asynchronously on the passed executor, through the provided {@link 348 * DnsResolver.Callback}. 349 * 350 * @param network {@link Network} specifying which network to query on. {@code null} for query 351 * on default network. 352 * @param domain SRV domain name to query ( in format _Service._Protocol.Name) 353 * @param cancellationSignal used by the caller to signal if the query should be cancelled. May 354 * be {@code null}. 355 * @param callback a {@link DnsResolver.Callback} which will be called on a separate thread to 356 * notify the caller of the result of the DNS query. 357 */ query( @ullable Network network, @NonNull String domain, @NonNull @CallbackExecutor Executor executor, @Nullable CancellationSignal cancellationSignal, @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback)358 public static void query( 359 @Nullable Network network, 360 @NonNull String domain, 361 @NonNull @CallbackExecutor Executor executor, 362 @Nullable CancellationSignal cancellationSignal, 363 @NonNull DnsResolver.Callback<List<SrvRecordInetAddress>> callback) { 364 final SrvRecordAnswerAccumulator srvDnsCb = 365 new SrvRecordAnswerAccumulator(network, callback, executor); 366 DnsResolver.getInstance() 367 .rawQuery( 368 network, 369 domain, 370 DnsResolver.CLASS_IN, 371 QUERY_TYPE_SRV, 372 DnsResolver.FLAG_EMPTY, 373 SrvRecordAnswerAccumulator.getInternalExecutor(), 374 cancellationSignal, 375 srvDnsCb); 376 } 377 SrvDnsResolver()378 private SrvDnsResolver() {} 379 } 380