1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 import android.net.Network;
22 import android.util.ArrayMap;
23 import android.util.Pair;
24 
25 import com.android.server.connectivity.mdns.util.MdnsUtils;
26 
27 import java.io.EOFException;
28 import java.util.ArrayList;
29 import java.util.Collection;
30 import java.util.List;
31 import java.util.Set;
32 
33 /** A class that decodes mDNS responses from UDP packets. */
34 public class MdnsResponseDecoder {
35     public static final int SUCCESS = 0;
36     private static final String TAG = "MdnsResponseDecoder";
37     private final boolean allowMultipleSrvRecordsPerHost =
38             MdnsConfigs.allowMultipleSrvRecordsPerHost();
39     @Nullable private final String[] serviceType;
40     private final MdnsUtils.Clock clock;
41 
42     /** Constructs a new decoder that will extract responses for the given service type. */
MdnsResponseDecoder(@onNull MdnsUtils.Clock clock, @Nullable String[] serviceType)43     public MdnsResponseDecoder(@NonNull MdnsUtils.Clock clock, @Nullable String[] serviceType) {
44         this.clock = clock;
45         this.serviceType = serviceType;
46     }
47 
findResponseWithPointer( List<MdnsResponse> responses, String[] pointer)48     private static MdnsResponse findResponseWithPointer(
49             List<MdnsResponse> responses, String[] pointer) {
50         if (responses != null) {
51             for (MdnsResponse response : responses) {
52                 if (MdnsUtils.equalsDnsLabelIgnoreDnsCase(response.getServiceName(), pointer)) {
53                     return response;
54                 }
55             }
56         }
57         return null;
58     }
59 
findResponseWithHostName( List<MdnsResponse> responses, String[] hostName)60     private static MdnsResponse findResponseWithHostName(
61             List<MdnsResponse> responses, String[] hostName) {
62         if (responses != null) {
63             for (MdnsResponse response : responses) {
64                 MdnsServiceRecord serviceRecord = response.getServiceRecord();
65                 if (serviceRecord == null) {
66                     continue;
67                 }
68                 if (MdnsUtils.equalsDnsLabelIgnoreDnsCase(serviceRecord.getServiceHost(),
69                         hostName)) {
70                     return response;
71                 }
72             }
73         }
74         return null;
75     }
76 
77     /**
78      * Decodes all mDNS responses for the desired service type from a packet. The class does not
79      * check the responses for completeness; the caller should do that.
80      *
81      * @param recvbuf The received data buffer to read from.
82      * @param length The length of received data buffer.
83      * @return A decoded {@link MdnsPacket}.
84      * @throws MdnsPacket.ParseException if a response packet could not be parsed.
85      */
86     @NonNull
parseResponse(@onNull byte[] recvbuf, int length, @NonNull MdnsFeatureFlags mdnsFeatureFlags)87     public static MdnsPacket parseResponse(@NonNull byte[] recvbuf, int length,
88             @NonNull MdnsFeatureFlags mdnsFeatureFlags) throws MdnsPacket.ParseException {
89         final MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length, mdnsFeatureFlags);
90 
91         final MdnsPacket mdnsPacket;
92         try {
93             final int transactionId = reader.readUInt16();
94             int flags = reader.readUInt16();
95             if ((flags & MdnsConstants.FLAGS_RESPONSE_MASK) != MdnsConstants.FLAGS_RESPONSE) {
96                 throw new MdnsPacket.ParseException(
97                         MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE, "Not a response", null);
98             }
99 
100             mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags, transactionId);
101             if (mdnsPacket.answers.size() < 1) {
102                 throw new MdnsPacket.ParseException(
103                         MdnsResponseErrorCode.ERROR_NO_ANSWERS, "Response has no answers",
104                         null);
105             }
106             return mdnsPacket;
107         } catch (EOFException e) {
108             throw new MdnsPacket.ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
109                     "Reached the end of the mDNS response unexpectedly.", e);
110         }
111     }
112 
113     /**
114      * Augments a list of {@link MdnsResponse} with records from a packet. The class does not check
115      * the resulting responses for completeness; the caller should do that.
116      *
117      * @param mdnsPacket the response packet with the new records
118      * @param existingResponses list of existing responses. Will not be modified.
119      * @param interfaceIndex the network interface index (or
120      * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
121      * @param network the network at which the packet was received, or null if it is unknown.
122      * @return The pair of 1) set of response instances that were modified or newly added. *not*
123      *                      including those which records were only updated with newer receive
124      *                      timestamps.
125      *                     2) A copy of the original responses with some of them have records
126      *                     update or only contains receive time updated.
127      */
augmentResponses( @onNull MdnsPacket mdnsPacket, @NonNull Collection<MdnsResponse> existingResponses, int interfaceIndex, @Nullable Network network)128     public Pair<Set<MdnsResponse>, ArrayList<MdnsResponse>> augmentResponses(
129             @NonNull MdnsPacket mdnsPacket,
130             @NonNull Collection<MdnsResponse> existingResponses, int interfaceIndex,
131             @Nullable Network network) {
132         final ArrayList<MdnsRecord> records = new ArrayList<>(
133                 mdnsPacket.questions.size() + mdnsPacket.answers.size()
134                         + mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size());
135         records.addAll(mdnsPacket.answers);
136         records.addAll(mdnsPacket.authorityRecords);
137         records.addAll(mdnsPacket.additionalRecords);
138 
139         final Set<MdnsResponse> modified = MdnsUtils.newSet();
140         final ArrayList<MdnsResponse> responses = new ArrayList<>(existingResponses.size());
141         final ArrayMap<MdnsResponse, MdnsResponse> augmentedToOriginal = new ArrayMap<>();
142         for (MdnsResponse existing : existingResponses) {
143             final MdnsResponse copy = new MdnsResponse(existing);
144             responses.add(copy);
145             augmentedToOriginal.put(copy, existing);
146         }
147         // The response records are structured in a hierarchy, where some records reference
148         // others, as follows:
149         //
150         //        PTR
151         //        / \
152         //       /   \
153         //      TXT  SRV
154         //           / \
155         //          /   \
156         //         A   AAAA
157         //
158         // But the order in which these records appear in the response packet is completely
159         // arbitrary. This means that we need to rescan the record list to construct each level of
160         // this hierarchy.
161         //
162         // PTR: service type -> service instance name
163         //
164         // SRV: service instance name -> host name (priority, weight)
165         //
166         // TXT: service instance name -> machine readable txt entries.
167         //
168         // A: host name -> IP address
169 
170         // Loop 1: find PTR records, which identify distinct service instances.
171         long now = clock.elapsedRealtime();
172         for (MdnsRecord record : records) {
173             if (record instanceof MdnsPointerRecord) {
174                 String[] name = record.getName();
175                 if ((serviceType == null) || MdnsUtils.typeEqualsOrIsSubtype(
176                         serviceType, name)) {
177                     MdnsPointerRecord pointerRecord = (MdnsPointerRecord) record;
178                     // Group PTR records that refer to the same service instance name into a single
179                     // response.
180                     MdnsResponse response = findResponseWithPointer(responses,
181                             pointerRecord.getPointer());
182                     if (response == null) {
183                         response = new MdnsResponse(now, pointerRecord.getPointer(), interfaceIndex,
184                                 network);
185                         responses.add(response);
186                     }
187                     if (response.addPointerRecord((MdnsPointerRecord) record)) {
188                         modified.add(response);
189                     }
190                 }
191             }
192         }
193 
194         // Loop 2: find SRV and TXT records, which reference the pointer in the PTR record.
195         for (MdnsRecord record : records) {
196             if (record instanceof MdnsServiceRecord) {
197                 MdnsServiceRecord serviceRecord = (MdnsServiceRecord) record;
198                 MdnsResponse response = findResponseWithPointer(responses, serviceRecord.getName());
199                 if (response != null && response.setServiceRecord(serviceRecord)) {
200                     response.dropUnmatchedAddressRecords();
201                     modified.add(response);
202                 }
203             } else if (record instanceof MdnsTextRecord) {
204                 MdnsTextRecord textRecord = (MdnsTextRecord) record;
205                 MdnsResponse response = findResponseWithPointer(responses, textRecord.getName());
206                 if (response != null && response.setTextRecord(textRecord)) {
207                     modified.add(response);
208                 }
209             }
210         }
211 
212         // Loop 3-1: find A and AAAA records and clear addresses if the cache-flush bit set, which
213         //           reference the host name in the SRV record.
214         final List<MdnsInetAddressRecord> inetRecords = new ArrayList<>();
215         for (MdnsRecord record : records) {
216             if (record instanceof MdnsInetAddressRecord) {
217                 MdnsInetAddressRecord inetRecord = (MdnsInetAddressRecord) record;
218                 inetRecords.add(inetRecord);
219                 if (allowMultipleSrvRecordsPerHost) {
220                     List<MdnsResponse> matchingResponses =
221                             findResponsesWithHostName(responses, inetRecord.getName());
222                     for (MdnsResponse response : matchingResponses) {
223                         // Per RFC6762 10.2, clear all address records if the cache-flush bit set.
224                         // This bit, the cache-flush bit, tells neighboring hosts
225                         // that this is not a shared record type.  Instead of merging this new
226                         // record additively into the cache in addition to any previous records with
227                         // the same name, rrtype, and rrclass.
228                         // TODO: All old records with that name, rrtype, and rrclass that were
229                         //       received more than one second ago are declared invalid, and marked
230                         //       to expire from the cache in one second.
231                         if (inetRecord.getCacheFlush()) {
232                             response.clearInet4AddressRecords();
233                             response.clearInet6AddressRecords();
234                         }
235                     }
236                 } else {
237                     MdnsResponse response =
238                             findResponseWithHostName(responses, inetRecord.getName());
239                     if (response != null) {
240                         // Per RFC6762 10.2, clear all address records if the cache-flush bit set.
241                         // This bit, the cache-flush bit, tells neighboring hosts
242                         // that this is not a shared record type.  Instead of merging this new
243                         // record additively into the cache in addition to any previous records with
244                         // the same name, rrtype, and rrclass.
245                         // TODO: All old records with that name, rrtype, and rrclass that were
246                         //       received more than one second ago are declared invalid, and marked
247                         //       to expire from the cache in one second.
248                         if (inetRecord.getCacheFlush()) {
249                             response.clearInet4AddressRecords();
250                             response.clearInet6AddressRecords();
251                         }
252                     }
253                 }
254             }
255         }
256 
257         // Loop 3-2: Assign addresses, which reference the host name in the SRV record.
258         for (MdnsInetAddressRecord inetRecord : inetRecords) {
259             if (allowMultipleSrvRecordsPerHost) {
260                 List<MdnsResponse> matchingResponses =
261                         findResponsesWithHostName(responses, inetRecord.getName());
262                 for (MdnsResponse response : matchingResponses) {
263                     if (assignInetRecord(response, inetRecord)) {
264                         final MdnsResponse originalResponse = augmentedToOriginal.get(response);
265                         if (originalResponse == null
266                                 || !originalResponse.hasIdenticalRecord(inetRecord)) {
267                             modified.add(response);
268                         }
269                     }
270                 }
271             } else {
272                 MdnsResponse response =
273                         findResponseWithHostName(responses, inetRecord.getName());
274                 if (response != null) {
275                     if (assignInetRecord(response, inetRecord)) {
276                         final MdnsResponse originalResponse = augmentedToOriginal.get(response);
277                         if (originalResponse == null
278                                 || !originalResponse.hasIdenticalRecord(inetRecord)) {
279                             modified.add(response);
280                         }
281                     }
282                 }
283             }
284         }
285 
286         // Only responses that have new or modified address records were added to the modified set.
287         // Make sure responses that have lost address records are added to the set too.
288         for (int i = 0; i < augmentedToOriginal.size(); i++) {
289             final MdnsResponse augmented = augmentedToOriginal.keyAt(i);
290             final MdnsResponse original = augmentedToOriginal.valueAt(i);
291             if (augmented.getRecords().size() != original.getRecords().size()) {
292                 modified.add(augmented);
293             }
294         }
295 
296         return Pair.create(modified, responses);
297     }
298 
assignInetRecord( MdnsResponse response, MdnsInetAddressRecord inetRecord)299     private static boolean assignInetRecord(
300             MdnsResponse response, MdnsInetAddressRecord inetRecord) {
301         if (inetRecord.getInet4Address() != null) {
302             return response.addInet4AddressRecord(inetRecord);
303         } else if (inetRecord.getInet6Address() != null) {
304             return response.addInet6AddressRecord(inetRecord);
305         }
306         return false;
307     }
308 
findResponsesWithHostName( @ullable List<MdnsResponse> responses, String[] hostName)309     private static List<MdnsResponse> findResponsesWithHostName(
310             @Nullable List<MdnsResponse> responses, String[] hostName) {
311         if (responses == null || responses.isEmpty()) {
312             return List.of();
313         }
314 
315         List<MdnsResponse> result = null;
316         for (MdnsResponse response : responses) {
317             MdnsServiceRecord serviceRecord = response.getServiceRecord();
318             if (serviceRecord == null) {
319                 continue;
320             }
321             if (MdnsUtils.equalsDnsLabelIgnoreDnsCase(serviceRecord.getServiceHost(), hostName)) {
322                 if (result == null) {
323                     result = new ArrayList<>(/* initialCapacity= */ responses.size());
324                 }
325                 result.add(response);
326             }
327         }
328         return result == null ? List.of() : result;
329     }
330 }