1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import static com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR;
20 import static com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR;
21 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
22 
23 import android.annotation.NonNull;
24 import android.annotation.RequiresApi;
25 import android.os.Build;
26 import android.os.Handler;
27 import android.os.Looper;
28 import android.os.Message;
29 import android.util.ArrayMap;
30 import android.util.ArraySet;
31 
32 import com.android.internal.annotations.VisibleForTesting;
33 import com.android.net.module.util.SharedLog;
34 import com.android.server.connectivity.mdns.util.MdnsUtils;
35 
36 import java.io.IOException;
37 import java.net.DatagramPacket;
38 import java.net.Inet4Address;
39 import java.net.Inet6Address;
40 import java.net.InetSocketAddress;
41 import java.net.MulticastSocket;
42 import java.util.ArrayList;
43 import java.util.Collections;
44 import java.util.Map;
45 import java.util.Set;
46 
47 /**
48  * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
49  * to be sent after some delay.
50  *
51  * TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
52  */
53 @RequiresApi(Build.VERSION_CODES.TIRAMISU)
54 public class MdnsReplySender {
55     private static final int MSG_SEND = 1;
56     private static final int PACKET_NOT_SENT = 0;
57     private static final int PACKET_SENT = 1;
58 
59     @NonNull
60     private final MdnsInterfaceSocket mSocket;
61     @NonNull
62     private final Handler mHandler;
63     @NonNull
64     private final byte[] mPacketCreationBuffer;
65     @NonNull
66     private final SharedLog mSharedLog;
67     private final boolean mEnableDebugLog;
68     @NonNull
69     private final Dependencies mDependencies;
70     // RFC6762 15.2. Multipacket Known-Answer lists
71     // Multicast DNS responders associate the initial truncated query with its
72     // continuation packets by examining the source IP address in each packet.
73     private final Map<InetSocketAddress, MdnsReplyInfo> mSrcReplies = new ArrayMap<>();
74     @NonNull
75     private final MdnsFeatureFlags mMdnsFeatureFlags;
76 
77     /**
78      * Dependencies of MdnsReplySender, for injection in tests.
79      */
80     @VisibleForTesting
81     public static class Dependencies {
82         /**
83          * @see Handler#sendMessageDelayed(Message, long)
84          */
sendMessageDelayed(@onNull Handler handler, @NonNull Message message, long delayMillis)85         public void sendMessageDelayed(@NonNull Handler handler, @NonNull Message message,
86                 long delayMillis) {
87             handler.sendMessageDelayed(message, delayMillis);
88         }
89 
90         /**
91          * @see Handler#removeMessages(int)
92          */
removeMessages(@onNull Handler handler, int what)93         public void removeMessages(@NonNull Handler handler, int what) {
94             handler.removeMessages(what);
95         }
96 
97         /**
98          * @see Handler#removeMessages(int)
99          */
removeMessages(@onNull Handler handler, int what, @NonNull Object object)100         public void removeMessages(@NonNull Handler handler, int what, @NonNull Object object) {
101             handler.removeMessages(what, object);
102         }
103     }
104 
MdnsReplySender(@onNull Looper looper, @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog, boolean enableDebugLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags)105     public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket,
106             @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog,
107             boolean enableDebugLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
108         this(looper, socket, packetCreationBuffer, sharedLog, enableDebugLog, new Dependencies(),
109                 mdnsFeatureFlags);
110     }
111 
112     @VisibleForTesting
MdnsReplySender(@onNull Looper looper, @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog, boolean enableDebugLog, @NonNull Dependencies dependencies, @NonNull MdnsFeatureFlags mdnsFeatureFlags)113     public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket,
114             @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog,
115             boolean enableDebugLog, @NonNull Dependencies dependencies,
116             @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
117         mHandler = new SendHandler(looper);
118         mSocket = socket;
119         mPacketCreationBuffer = packetCreationBuffer;
120         mSharedLog = sharedLog;
121         mEnableDebugLog = enableDebugLog;
122         mDependencies = dependencies;
123         mMdnsFeatureFlags = mdnsFeatureFlags;
124     }
125 
getReplyDestination(@onNull InetSocketAddress queuingDest, @NonNull InetSocketAddress incomingDest)126     static InetSocketAddress getReplyDestination(@NonNull InetSocketAddress queuingDest,
127             @NonNull InetSocketAddress incomingDest) {
128         // The queuing reply is multicast, just use the current destination.
129         if (queuingDest.equals(IPV4_SOCKET_ADDR) || queuingDest.equals(IPV6_SOCKET_ADDR)) {
130             return queuingDest;
131         }
132 
133         // The incoming reply is multicast, change the reply from unicast to multicast since
134         // replying unicast when the query requests unicast reply is optional.
135         if (incomingDest.equals(IPV4_SOCKET_ADDR) || incomingDest.equals(IPV6_SOCKET_ADDR)) {
136             return incomingDest;
137         }
138 
139         return queuingDest;
140     }
141 
142     /**
143      * Queue a reply to be sent when its send delay expires.
144      */
queueReply(@onNull MdnsReplyInfo reply)145     public void queueReply(@NonNull MdnsReplyInfo reply) {
146         ensureRunningOnHandlerThread(mHandler);
147 
148         if (mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()) {
149             mDependencies.removeMessages(mHandler, MSG_SEND, reply.source);
150 
151             final MdnsReplyInfo queuingReply = mSrcReplies.remove(reply.source);
152             final ArraySet<MdnsRecord> answers = new ArraySet<>();
153             final Set<MdnsRecord> additionalAnswers = new ArraySet<>();
154             final Set<MdnsRecord> knownAnswers = new ArraySet<>();
155             if (queuingReply != null) {
156                 answers.addAll(queuingReply.answers);
157                 additionalAnswers.addAll(queuingReply.additionalAnswers);
158                 knownAnswers.addAll(queuingReply.knownAnswers);
159             }
160             answers.addAll(reply.answers);
161             additionalAnswers.addAll(reply.additionalAnswers);
162             knownAnswers.addAll(reply.knownAnswers);
163             // RFC6762 7.2. Multipacket Known-Answer Suppression
164             // If the responder sees any of its answers listed in the Known-Answer
165             // lists of subsequent packets from the querying host, it MUST delete
166             // that answer from the list of answers it is planning to give.
167             for (MdnsRecord knownAnswer : knownAnswers) {
168                 final int idx = answers.indexOf(knownAnswer);
169                 if (idx >= 0 && knownAnswer.getTtl() > answers.valueAt(idx).getTtl() / 2) {
170                     answers.removeAt(idx);
171                 }
172             }
173 
174             if (answers.size() == 0) {
175                 return;
176             }
177 
178             final MdnsReplyInfo newReply = new MdnsReplyInfo(
179                     new ArrayList<>(answers),
180                     new ArrayList<>(additionalAnswers),
181                     reply.sendDelayMs,
182                     queuingReply == null ? reply.destination
183                             : getReplyDestination(queuingReply.destination, reply.destination),
184                     reply.source,
185                     new ArrayList<>(knownAnswers));
186 
187             mSrcReplies.put(newReply.source, newReply);
188             mDependencies.sendMessageDelayed(mHandler,
189                     mHandler.obtainMessage(MSG_SEND, newReply.source), newReply.sendDelayMs);
190         } else {
191             mDependencies.sendMessageDelayed(
192                     mHandler, mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
193         }
194 
195         if (mEnableDebugLog) {
196             mSharedLog.v("Scheduling " + reply);
197         }
198     }
199 
200     /**
201      * Send a packet immediately.
202      *
203      * Must be called on the looper thread used by the {@link MdnsReplySender}.
204      */
sendNow(@onNull MdnsPacket packet, @NonNull InetSocketAddress destination)205     public int sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
206             throws IOException {
207         ensureRunningOnHandlerThread(mHandler);
208         if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
209                 || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
210             // Skip sending if the socket has not joined the v4/v6 group (there was no address)
211             return PACKET_NOT_SENT;
212         }
213         final byte[] outBuffer = MdnsUtils.createRawDnsPacket(mPacketCreationBuffer, packet);
214         mSocket.send(new DatagramPacket(outBuffer, 0, outBuffer.length, destination));
215         return PACKET_SENT;
216     }
217 
218     /**
219      * Cancel all pending sends.
220      */
cancelAll()221     public void cancelAll() {
222         ensureRunningOnHandlerThread(mHandler);
223         mDependencies.removeMessages(mHandler, MSG_SEND);
224     }
225 
226     private class SendHandler extends Handler {
SendHandler(@onNull Looper looper)227         SendHandler(@NonNull Looper looper) {
228             super(looper);
229         }
230 
231         @Override
handleMessage(@onNull Message msg)232         public void handleMessage(@NonNull Message msg) {
233             final MdnsReplyInfo replyInfo;
234             if (mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()) {
235                 // Retrieve the MdnsReplyInfo from the map via a source address, as the reply info
236                 // will be combined or updated.
237                 final InetSocketAddress source = (InetSocketAddress) msg.obj;
238                 replyInfo = mSrcReplies.remove(source);
239             } else {
240                 replyInfo = (MdnsReplyInfo) msg.obj;
241             }
242 
243             if (replyInfo == null) {
244                 mSharedLog.wtf("Unknown reply info.");
245                 return;
246             }
247 
248             if (mEnableDebugLog) mSharedLog.v("Sending " + replyInfo);
249 
250             final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
251             final MdnsPacket packet = new MdnsPacket(flags,
252                     Collections.emptyList() /* questions */,
253                     replyInfo.answers,
254                     Collections.emptyList() /* authorityRecords */,
255                     replyInfo.additionalAnswers);
256 
257             try {
258                 sendNow(packet, replyInfo.destination);
259             } catch (IOException e) {
260                 mSharedLog.e("Error sending MDNS response", e);
261             }
262         }
263     }
264 }
265