1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.connectivity.mdns; 18 19 import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER; 20 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; 21 import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase; 22 import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase; 23 24 import static java.lang.Math.min; 25 26 import android.annotation.NonNull; 27 import android.annotation.Nullable; 28 import android.os.Handler; 29 import android.os.Looper; 30 import android.util.ArrayMap; 31 32 import com.android.internal.annotations.VisibleForTesting; 33 import com.android.server.connectivity.mdns.util.MdnsUtils; 34 35 import java.util.ArrayList; 36 import java.util.Collections; 37 import java.util.Iterator; 38 import java.util.List; 39 import java.util.Objects; 40 41 /** 42 * The {@link MdnsServiceCache} manages the service which discovers from each socket and cache these 43 * services to reduce duplicated queries. 44 * 45 * <p>This class is not thread safe, it is intended to be used only from the looper thread. 46 * However, the constructor is an exception, as it is called on another thread; 47 * therefore for thread safety all members of this class MUST either be final or initialized 48 * to their default value (0, false or null). 49 */ 50 public class MdnsServiceCache { 51 static class CacheKey { 52 @NonNull final String mLowercaseServiceType; 53 @NonNull final SocketKey mSocketKey; 54 CacheKey(@onNull String serviceType, @NonNull SocketKey socketKey)55 CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) { 56 mLowercaseServiceType = toDnsLowerCase(serviceType); 57 mSocketKey = socketKey; 58 } 59 hashCode()60 @Override public int hashCode() { 61 return Objects.hash(mLowercaseServiceType, mSocketKey); 62 } 63 equals(Object other)64 @Override public boolean equals(Object other) { 65 if (this == other) { 66 return true; 67 } 68 if (!(other instanceof CacheKey)) { 69 return false; 70 } 71 return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType) 72 && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey); 73 } 74 } 75 /** 76 * A map of cached services. Key is composed of service type and socket. Value is the list of 77 * services which are discovered from the given CacheKey. 78 * When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted 79 * by expiration time, with the earliest entries appearing first. This sorting allows the 80 * removal process to progress through the expiration check efficiently. 81 */ 82 @NonNull 83 private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>(); 84 /** 85 * A map of service expire callbacks. Key is composed of service type and socket and value is 86 * the callback listener. 87 */ 88 @NonNull 89 private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>(); 90 @NonNull 91 private final Handler mHandler; 92 @NonNull 93 private final MdnsFeatureFlags mMdnsFeatureFlags; 94 @NonNull 95 private final MdnsUtils.Clock mClock; 96 private long mNextExpirationTime = EXPIRATION_NEVER; 97 MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags)98 public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { 99 this(looper, mdnsFeatureFlags, new MdnsUtils.Clock()); 100 } 101 102 @VisibleForTesting MdnsServiceCache(@onNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, @NonNull MdnsUtils.Clock clock)103 MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, 104 @NonNull MdnsUtils.Clock clock) { 105 mHandler = new Handler(looper); 106 mMdnsFeatureFlags = mdnsFeatureFlags; 107 mClock = clock; 108 } 109 110 /** 111 * Get the cache services which are queried from given service type and socket. 112 * 113 * @param cacheKey the target CacheKey. 114 * @return the set of services which matches the given service type. 115 */ 116 @NonNull getCachedServices(@onNull CacheKey cacheKey)117 public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) { 118 ensureRunningOnHandlerThread(mHandler); 119 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 120 maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); 121 } 122 return mCachedServices.containsKey(cacheKey) 123 ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey))) 124 : Collections.emptyList(); 125 } 126 127 /** 128 * Find a matched response for given service name 129 * 130 * @param responses the responses to be searched. 131 * @param serviceName the target service name 132 * @return the response which matches the given service name or null if not found. 133 */ findMatchedResponse(@onNull List<MdnsResponse> responses, @NonNull String serviceName)134 public static MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses, 135 @NonNull String serviceName) { 136 for (MdnsResponse response : responses) { 137 if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { 138 return response; 139 } 140 } 141 return null; 142 } 143 144 /** 145 * Get the cache service. 146 * 147 * @param serviceName the target service name. 148 * @param cacheKey the target CacheKey. 149 * @return the service which matches given conditions. 150 */ 151 @Nullable getCachedService(@onNull String serviceName, @NonNull CacheKey cacheKey)152 public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { 153 ensureRunningOnHandlerThread(mHandler); 154 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 155 maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); 156 } 157 final List<MdnsResponse> responses = mCachedServices.get(cacheKey); 158 if (responses == null) { 159 return null; 160 } 161 final MdnsResponse response = findMatchedResponse(responses, serviceName); 162 return response != null ? new MdnsResponse(response) : null; 163 } 164 insertResponseAndSortList( List<MdnsResponse> responses, MdnsResponse response, long now)165 static void insertResponseAndSortList( 166 List<MdnsResponse> responses, MdnsResponse response, long now) { 167 // binarySearch returns "the index of the search key, if it is contained in the list; 168 // otherwise, (-(insertion point) - 1)" 169 final int searchRes = Collections.binarySearch(responses, response, 170 // Sort the list by ttl. 171 (o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now))); 172 responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response); 173 } 174 175 /** 176 * Add or update a service. 177 * 178 * @param cacheKey the target CacheKey. 179 * @param response the response of the discovered service. 180 */ addOrUpdateService(@onNull CacheKey cacheKey, @NonNull MdnsResponse response)181 public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) { 182 ensureRunningOnHandlerThread(mHandler); 183 final List<MdnsResponse> responses = mCachedServices.computeIfAbsent( 184 cacheKey, key -> new ArrayList<>()); 185 // Remove existing service if present. 186 final MdnsResponse existing = 187 findMatchedResponse(responses, response.getServiceInstanceName()); 188 responses.remove(existing); 189 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 190 final long now = mClock.elapsedRealtime(); 191 // Insert and sort service 192 insertResponseAndSortList(responses, response, now); 193 // Update the next expiration check time when a new service is added. 194 mNextExpirationTime = getNextExpirationTime(now); 195 } else { 196 responses.add(response); 197 } 198 } 199 200 /** 201 * Remove a service which matches the given service name, type and socket. 202 * 203 * @param serviceName the target service name. 204 * @param cacheKey the target CacheKey. 205 */ 206 @Nullable removeService(@onNull String serviceName, @NonNull CacheKey cacheKey)207 public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { 208 ensureRunningOnHandlerThread(mHandler); 209 final List<MdnsResponse> responses = mCachedServices.get(cacheKey); 210 if (responses == null) { 211 return null; 212 } 213 final Iterator<MdnsResponse> iterator = responses.iterator(); 214 MdnsResponse removedResponse = null; 215 while (iterator.hasNext()) { 216 final MdnsResponse response = iterator.next(); 217 if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { 218 iterator.remove(); 219 removedResponse = response; 220 break; 221 } 222 } 223 224 if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { 225 // Remove the serviceType if no response. 226 if (responses.isEmpty()) { 227 mCachedServices.remove(cacheKey); 228 } 229 // Update the next expiration check time when a service is removed. 230 mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime()); 231 } 232 return removedResponse; 233 } 234 235 /** 236 * Register a callback to listen to service expiration. 237 * 238 * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient 239 * relies on this. 240 * 241 * @param cacheKey the target CacheKey. 242 * @param callback the callback that notify the service is expired. 243 */ registerServiceExpiredCallback(@onNull CacheKey cacheKey, @NonNull ServiceExpiredCallback callback)244 public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey, 245 @NonNull ServiceExpiredCallback callback) { 246 ensureRunningOnHandlerThread(mHandler); 247 mCallbacks.put(cacheKey, callback); 248 } 249 250 /** 251 * Unregister the service expired callback. 252 * 253 * @param cacheKey the CacheKey that is registered to listen service expiration before. 254 */ unregisterServiceExpiredCallback(@onNull CacheKey cacheKey)255 public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) { 256 ensureRunningOnHandlerThread(mHandler); 257 mCallbacks.remove(cacheKey); 258 } 259 notifyServiceExpired(@onNull CacheKey cacheKey, @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)260 private void notifyServiceExpired(@NonNull CacheKey cacheKey, 261 @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) { 262 final ServiceExpiredCallback callback = mCallbacks.get(cacheKey); 263 if (callback == null) { 264 // The cached service is no listener. 265 return; 266 } 267 mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse)); 268 } 269 removeExpiredServices(@onNull List<MdnsResponse> responses, long now)270 static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses, 271 long now) { 272 final List<MdnsResponse> removedResponses = new ArrayList<>(); 273 final Iterator<MdnsResponse> iterator = responses.iterator(); 274 while (iterator.hasNext()) { 275 final MdnsResponse response = iterator.next(); 276 // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's 277 // expired. Then send service update notification. 278 if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) { 279 // The responses are sorted by the service record ttl time. Break out of loop 280 // early if service is not expired or no service record. 281 break; 282 } 283 // Remove the ttl expired service. 284 iterator.remove(); 285 removedResponses.add(response); 286 } 287 return removedResponses; 288 } 289 getNextExpirationTime(long now)290 private long getNextExpirationTime(long now) { 291 if (mCachedServices.isEmpty()) { 292 return EXPIRATION_NEVER; 293 } 294 295 long minRemainingTtl = EXPIRATION_NEVER; 296 for (int i = 0; i < mCachedServices.size(); i++) { 297 minRemainingTtl = min(minRemainingTtl, 298 // The empty lists are not kept in the map, so there's always at least one 299 // element in the list. Therefore, it's fine to get the first element without a 300 // null check. 301 mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now)); 302 } 303 return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl; 304 } 305 306 /** 307 * Check whether the ttl time is expired on each service and notify to the listeners 308 */ maybeRemoveExpiredServices(CacheKey cacheKey, long now)309 private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) { 310 ensureRunningOnHandlerThread(mHandler); 311 if (now < mNextExpirationTime) { 312 // Skip the check if ttl time is not expired. 313 return; 314 } 315 316 final List<MdnsResponse> responses = mCachedServices.get(cacheKey); 317 if (responses == null) { 318 // No such services. 319 return; 320 } 321 322 final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now); 323 if (removedResponses.isEmpty()) { 324 // No expired services. 325 return; 326 } 327 328 for (MdnsResponse previousResponse : removedResponses) { 329 notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */); 330 } 331 332 // Remove the serviceType if no response. 333 if (responses.isEmpty()) { 334 mCachedServices.remove(cacheKey); 335 } 336 337 // Update next expiration time. 338 mNextExpirationTime = getNextExpirationTime(now); 339 } 340 341 /*** Callbacks for listening service expiration */ 342 public interface ServiceExpiredCallback { 343 /*** Notify the service is expired */ onServiceRecordExpired(@onNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse)344 void onServiceRecordExpired(@NonNull MdnsResponse previousResponse, 345 @Nullable MdnsResponse newResponse); 346 } 347 348 // TODO: Schedule a job to check ttl expiration for all services and notify to the clients. 349 } 350