1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns;
18 
19 import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER;
20 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
21 import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
22 import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
23 
24 import static java.lang.Math.min;
25 
26 import android.annotation.NonNull;
27 import android.annotation.Nullable;
28 import android.os.Handler;
29 import android.os.Looper;
30 import android.util.ArrayMap;
31 
32 import com.android.internal.annotations.VisibleForTesting;
33 import com.android.server.connectivity.mdns.util.MdnsUtils;
34 
35 import java.util.ArrayList;
36 import java.util.Collections;
37 import java.util.Iterator;
38 import java.util.List;
39 import java.util.Objects;
40 
/**
 * The {@link MdnsServiceCache} manages the services discovered from each socket and caches
 * them to reduce duplicate queries.
 *
 * <p>This class is not thread safe; it is intended to be used only from the looper thread.
 *  The constructor is the one exception, as it is called on another thread;
 *  therefore, for thread safety, all members of this class MUST either be final or initialized
 *  to their default value (0, false or null).
 */
50 public class MdnsServiceCache {
51     static class CacheKey {
52         @NonNull final String mLowercaseServiceType;
53         @NonNull final SocketKey mSocketKey;
54 
CacheKey(@onNull String serviceType, @NonNull SocketKey socketKey)55         CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) {
56             mLowercaseServiceType = toDnsLowerCase(serviceType);
57             mSocketKey = socketKey;
58         }
59 
hashCode()60         @Override public int hashCode() {
61             return Objects.hash(mLowercaseServiceType, mSocketKey);
62         }
63 
equals(Object other)64         @Override public boolean equals(Object other) {
65             if (this == other) {
66                 return true;
67             }
68             if (!(other instanceof CacheKey)) {
69                 return false;
70             }
71             return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
72                     && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey);
73         }
74     }
    /**
     * A map of cached services. Key is composed of service type and socket. Value is the list of
     * services which are discovered from the given CacheKey.
     * When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted
     * by expiration time, with the earliest entries appearing first. This sorting allows the
     * removal process to progress through the expiration check efficiently.
     */
    @NonNull
    private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
    /**
     * A map of service expiration callbacks. Key is composed of service type and socket, and the
     * value is the callback listener. At most one callback is kept per key.
     */
    @NonNull
    private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
    /**
     * Handler created on the looper given to the constructor. It is used both to assert that
     * calls happen on the handler thread and to post expiration callbacks.
     */
    @NonNull
    private final Handler mHandler;
    /** Feature flags; {@code mIsExpiredServicesRemovalEnabled} gates the expiration handling. */
    @NonNull
    private final MdnsFeatureFlags mMdnsFeatureFlags;
    /** Time source; this class reads the current time through {@code elapsedRealtime()}. */
    @NonNull
    private final MdnsUtils.Clock mClock;
    // Earliest time at which a cached service may expire, as last computed by
    // getNextExpirationTime(); expiration checks are skipped while "now" is before it.
    // NOTE(review): this field is neither final nor initialized to a default value, which looks
    // inconsistent with the thread-safety rule in the class Javadoc (unless EXPIRATION_NEVER
    // happens to be 0) - confirm that the instance is safely published to the looper thread.
    private long mNextExpirationTime = EXPIRATION_NEVER;
97 
    /**
     * Create a cache that uses a newly created {@link MdnsUtils.Clock} as its time source.
     *
     * @param looper the looper on which the cache's handler is created.
     * @param mdnsFeatureFlags the feature flags controlling the cache behavior.
     */
    public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
        this(looper, mdnsFeatureFlags, new MdnsUtils.Clock());
    }
101 
102     @VisibleForTesting
MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, @NonNull MdnsUtils.Clock clock)103     MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags,
104             @NonNull MdnsUtils.Clock clock) {
105         mHandler = new Handler(looper);
106         mMdnsFeatureFlags = mdnsFeatureFlags;
107         mClock = clock;
108     }
109 
110     /**
111      * Get the cache services which are queried from given service type and socket.
112      *
113      * @param cacheKey the target CacheKey.
114      * @return the set of services which matches the given service type.
115      */
116     @NonNull
getCachedServices(@onNull CacheKey cacheKey)117     public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
118         ensureRunningOnHandlerThread(mHandler);
119         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
120             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
121         }
122         return mCachedServices.containsKey(cacheKey)
123                 ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
124                 : Collections.emptyList();
125     }
126 
127     /**
128      * Find a matched response for given service name
129      *
130      * @param responses the responses to be searched.
131      * @param serviceName the target service name
132      * @return the response which matches the given service name or null if not found.
133      */
findMatchedResponse(@onNull List<MdnsResponse> responses, @NonNull String serviceName)134     public static MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
135             @NonNull String serviceName) {
136         for (MdnsResponse response : responses) {
137             if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
138                 return response;
139             }
140         }
141         return null;
142     }
143 
144     /**
145      * Get the cache service.
146      *
147      * @param serviceName the target service name.
148      * @param cacheKey the target CacheKey.
149      * @return the service which matches given conditions.
150      */
151     @Nullable
getCachedService(@onNull String serviceName, @NonNull CacheKey cacheKey)152     public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
153         ensureRunningOnHandlerThread(mHandler);
154         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
155             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
156         }
157         final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
158         if (responses == null) {
159             return null;
160         }
161         final MdnsResponse response = findMatchedResponse(responses, serviceName);
162         return response != null ? new MdnsResponse(response) : null;
163     }
164 
insertResponseAndSortList( List<MdnsResponse> responses, MdnsResponse response, long now)165     static void insertResponseAndSortList(
166             List<MdnsResponse> responses, MdnsResponse response, long now) {
167         // binarySearch returns "the index of the search key, if it is contained in the list;
168         // otherwise, (-(insertion point) - 1)"
169         final int searchRes = Collections.binarySearch(responses, response,
170                 // Sort the list by ttl.
171                 (o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now)));
172         responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response);
173     }
174 
175     /**
176      * Add or update a service.
177      *
178      * @param cacheKey the target CacheKey.
179      * @param response the response of the discovered service.
180      */
addOrUpdateService(@onNull CacheKey cacheKey, @NonNull MdnsResponse response)181     public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
182         ensureRunningOnHandlerThread(mHandler);
183         final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
184                 cacheKey, key -> new ArrayList<>());
185         // Remove existing service if present.
186         final MdnsResponse existing =
187                 findMatchedResponse(responses, response.getServiceInstanceName());
188         responses.remove(existing);
189         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
190             final long now = mClock.elapsedRealtime();
191             // Insert and sort service
192             insertResponseAndSortList(responses, response, now);
193             // Update the next expiration check time when a new service is added.
194             mNextExpirationTime = getNextExpirationTime(now);
195         } else {
196             responses.add(response);
197         }
198     }
199 
200     /**
201      * Remove a service which matches the given service name, type and socket.
202      *
203      * @param serviceName the target service name.
204      * @param cacheKey the target CacheKey.
205      */
206     @Nullable
removeService(@onNull String serviceName, @NonNull CacheKey cacheKey)207     public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
208         ensureRunningOnHandlerThread(mHandler);
209         final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
210         if (responses == null) {
211             return null;
212         }
213         final Iterator<MdnsResponse> iterator = responses.iterator();
214         MdnsResponse removedResponse = null;
215         while (iterator.hasNext()) {
216             final MdnsResponse response = iterator.next();
217             if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
218                 iterator.remove();
219                 removedResponse = response;
220                 break;
221             }
222         }
223 
224         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
225             // Remove the serviceType if no response.
226             if (responses.isEmpty()) {
227                 mCachedServices.remove(cacheKey);
228             }
229             // Update the next expiration check time when a service is removed.
230             mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
231         }
232         return removedResponse;
233     }
234 
    /**
     * Register a callback to listen to service expiration.
     *
     * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
     * relies on this. The callback is stored per CacheKey, so registering a different callback
     * for a key that already has one replaces the previous callback.
     *
     * <p>Must be called on the handler thread.
     *
     * @param cacheKey the target CacheKey.
     * @param callback the callback that is notified when a service expires.
     */
    public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey,
            @NonNull ServiceExpiredCallback callback) {
        ensureRunningOnHandlerThread(mHandler);
        mCallbacks.put(cacheKey, callback);
    }
249 
    /**
     * Unregister the service expired callback.
     *
     * <p>Must be called on the handler thread.
     *
     * @param cacheKey the CacheKey that was previously registered to listen for service
     *                 expiration.
     */
    public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) {
        ensureRunningOnHandlerThread(mHandler);
        mCallbacks.remove(cacheKey);
    }
259 
notifyServiceExpired(@onNull CacheKey cacheKey, @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)260     private void notifyServiceExpired(@NonNull CacheKey cacheKey,
261             @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) {
262         final ServiceExpiredCallback callback = mCallbacks.get(cacheKey);
263         if (callback == null) {
264             // The cached service is no listener.
265             return;
266         }
267         mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
268     }
269 
removeExpiredServices(@onNull List<MdnsResponse> responses, long now)270     static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses,
271             long now) {
272         final List<MdnsResponse> removedResponses = new ArrayList<>();
273         final Iterator<MdnsResponse> iterator = responses.iterator();
274         while (iterator.hasNext()) {
275             final MdnsResponse response = iterator.next();
276             // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
277             //  expired. Then send service update notification.
278             if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) {
279                 // The responses are sorted by the service record ttl time. Break out of loop
280                 // early if service is not expired or no service record.
281                 break;
282             }
283             // Remove the ttl expired service.
284             iterator.remove();
285             removedResponses.add(response);
286         }
287         return removedResponses;
288     }
289 
getNextExpirationTime(long now)290     private long getNextExpirationTime(long now) {
291         if (mCachedServices.isEmpty()) {
292             return EXPIRATION_NEVER;
293         }
294 
295         long minRemainingTtl = EXPIRATION_NEVER;
296         for (int i = 0; i < mCachedServices.size(); i++) {
297             minRemainingTtl = min(minRemainingTtl,
298                     // The empty lists are not kept in the map, so there's always at least one
299                     // element in the list. Therefore, it's fine to get the first element without a
300                     // null check.
301                     mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now));
302         }
303         return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
304     }
305 
306     /**
307      * Check whether the ttl time is expired on each service and notify to the listeners
308      */
maybeRemoveExpiredServices(CacheKey cacheKey, long now)309     private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) {
310         ensureRunningOnHandlerThread(mHandler);
311         if (now < mNextExpirationTime) {
312             // Skip the check if ttl time is not expired.
313             return;
314         }
315 
316         final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
317         if (responses == null) {
318             // No such services.
319             return;
320         }
321 
322         final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now);
323         if (removedResponses.isEmpty()) {
324             // No expired services.
325             return;
326         }
327 
328         for (MdnsResponse previousResponse : removedResponses) {
329             notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */);
330         }
331 
332         // Remove the serviceType if no response.
333         if (responses.isEmpty()) {
334             mCachedServices.remove(cacheKey);
335         }
336 
337         // Update next expiration time.
338         mNextExpirationTime = getNextExpirationTime(now);
339     }
340 
    /** Callback for listening to service expiration. */
    public interface ServiceExpiredCallback {
        /**
         * Notify that a service has expired.
         *
         * @param previousResponse the response that expired.
         * @param newResponse the response replacing it, or null.
         */
        void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
                @Nullable MdnsResponse newResponse);
    }
347 
348     // TODO: Schedule a job to check ttl expiration for all services and notify to the clients.
349 }
350