/* * Copyright (C) 2023 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.server.connectivity.mdns; import static com.android.server.connectivity.mdns.MdnsResponse.EXPIRATION_NEVER; import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase; import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase; import static java.lang.Math.min; import android.annotation.NonNull; import android.annotation.Nullable; import android.os.Handler; import android.os.Looper; import android.util.ArrayMap; import com.android.internal.annotations.VisibleForTesting; import com.android.server.connectivity.mdns.util.MdnsUtils; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Objects; /** * The {@link MdnsServiceCache} manages the service which discovers from each socket and cache these * services to reduce duplicated queries. * *

This class is not thread safe, it is intended to be used only from the looper thread. * However, the constructor is an exception, as it is called on another thread; * therefore for thread safety all members of this class MUST either be final or initialized * to their default value (0, false or null). */ public class MdnsServiceCache { static class CacheKey { @NonNull final String mLowercaseServiceType; @NonNull final SocketKey mSocketKey; CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) { mLowercaseServiceType = toDnsLowerCase(serviceType); mSocketKey = socketKey; } @Override public int hashCode() { return Objects.hash(mLowercaseServiceType, mSocketKey); } @Override public boolean equals(Object other) { if (this == other) { return true; } if (!(other instanceof CacheKey)) { return false; } return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType) && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey); } } /** * A map of cached services. Key is composed of service type and socket. Value is the list of * services which are discovered from the given CacheKey. * When the MdnsFeatureFlags#NSD_EXPIRED_SERVICES_REMOVAL flag is enabled, the lists are sorted * by expiration time, with the earliest entries appearing first. This sorting allows the * removal process to progress through the expiration check efficiently. */ @NonNull private final ArrayMap> mCachedServices = new ArrayMap<>(); /** * A map of service expire callbacks. Key is composed of service type and socket and value is * the callback listener. */ @NonNull private final ArrayMap mCallbacks = new ArrayMap<>(); @NonNull private final Handler mHandler; @NonNull private final MdnsFeatureFlags mMdnsFeatureFlags; @NonNull private final MdnsUtils.Clock mClock; private long mNextExpirationTime = EXPIRATION_NEVER; public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { this(looper, mdnsFeatureFlags, new MdnsUtils.Clock()); } @VisibleForTesting MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags, @NonNull MdnsUtils.Clock clock) { mHandler = new Handler(looper); mMdnsFeatureFlags = mdnsFeatureFlags; mClock = clock; } /** * Get the cache services which are queried from given service type and socket. * * @param cacheKey the target CacheKey. * @return the set of services which matches the given service type. */ @NonNull public List getCachedServices(@NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); } return mCachedServices.containsKey(cacheKey) ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey))) : Collections.emptyList(); } /** * Find a matched response for given service name * * @param responses the responses to be searched. * @param serviceName the target service name * @return the response which matches the given service name or null if not found. */ public static MdnsResponse findMatchedResponse(@NonNull List responses, @NonNull String serviceName) { for (MdnsResponse response : responses) { if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { return response; } } return null; } /** * Get the cache service. * * @param serviceName the target service name. * @param cacheKey the target CacheKey. * @return the service which matches given conditions. */ @Nullable public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime()); } final List responses = mCachedServices.get(cacheKey); if (responses == null) { return null; } final MdnsResponse response = findMatchedResponse(responses, serviceName); return response != null ? new MdnsResponse(response) : null; } static void insertResponseAndSortList( List responses, MdnsResponse response, long now) { // binarySearch returns "the index of the search key, if it is contained in the list; // otherwise, (-(insertion point) - 1)" final int searchRes = Collections.binarySearch(responses, response, // Sort the list by ttl. (o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now))); responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response); } /** * Add or update a service. * * @param cacheKey the target CacheKey. * @param response the response of the discovered service. */ public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) { ensureRunningOnHandlerThread(mHandler); final List responses = mCachedServices.computeIfAbsent( cacheKey, key -> new ArrayList<>()); // Remove existing service if present. final MdnsResponse existing = findMatchedResponse(responses, response.getServiceInstanceName()); responses.remove(existing); if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { final long now = mClock.elapsedRealtime(); // Insert and sort service insertResponseAndSortList(responses, response, now); // Update the next expiration check time when a new service is added. mNextExpirationTime = getNextExpirationTime(now); } else { responses.add(response); } } /** * Remove a service which matches the given service name, type and socket. * * @param serviceName the target service name. * @param cacheKey the target CacheKey. */ @Nullable public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); final List responses = mCachedServices.get(cacheKey); if (responses == null) { return null; } final Iterator iterator = responses.iterator(); MdnsResponse removedResponse = null; while (iterator.hasNext()) { final MdnsResponse response = iterator.next(); if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) { iterator.remove(); removedResponse = response; break; } } if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) { // Remove the serviceType if no response. if (responses.isEmpty()) { mCachedServices.remove(cacheKey); } // Update the next expiration check time when a service is removed. mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime()); } return removedResponse; } /** * Register a callback to listen to service expiration. * *

Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient * relies on this. * * @param cacheKey the target CacheKey. * @param callback the callback that notify the service is expired. */ public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey, @NonNull ServiceExpiredCallback callback) { ensureRunningOnHandlerThread(mHandler); mCallbacks.put(cacheKey, callback); } /** * Unregister the service expired callback. * * @param cacheKey the CacheKey that is registered to listen service expiration before. */ public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); mCallbacks.remove(cacheKey); } private void notifyServiceExpired(@NonNull CacheKey cacheKey, @NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse) { final ServiceExpiredCallback callback = mCallbacks.get(cacheKey); if (callback == null) { // The cached service is no listener. return; } mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse)); } static List removeExpiredServices(@NonNull List responses, long now) { final List removedResponses = new ArrayList<>(); final Iterator iterator = responses.iterator(); while (iterator.hasNext()) { final MdnsResponse response = iterator.next(); // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's // expired. Then send service update notification. if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) { // The responses are sorted by the service record ttl time. Break out of loop // early if service is not expired or no service record. break; } // Remove the ttl expired service. iterator.remove(); removedResponses.add(response); } return removedResponses; } private long getNextExpirationTime(long now) { if (mCachedServices.isEmpty()) { return EXPIRATION_NEVER; } long minRemainingTtl = EXPIRATION_NEVER; for (int i = 0; i < mCachedServices.size(); i++) { minRemainingTtl = min(minRemainingTtl, // The empty lists are not kept in the map, so there's always at least one // element in the list. Therefore, it's fine to get the first element without a // null check. mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now)); } return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl; } /** * Check whether the ttl time is expired on each service and notify to the listeners */ private void maybeRemoveExpiredServices(CacheKey cacheKey, long now) { ensureRunningOnHandlerThread(mHandler); if (now < mNextExpirationTime) { // Skip the check if ttl time is not expired. return; } final List responses = mCachedServices.get(cacheKey); if (responses == null) { // No such services. return; } final List removedResponses = removeExpiredServices(responses, now); if (removedResponses.isEmpty()) { // No expired services. return; } for (MdnsResponse previousResponse : removedResponses) { notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */); } // Remove the serviceType if no response. if (responses.isEmpty()) { mCachedServices.remove(cacheKey); } // Update next expiration time. mNextExpirationTime = getNextExpirationTime(now); } /*** Callbacks for listening service expiration */ public interface ServiceExpiredCallback { /*** Notify the service is expired */ void onServiceRecordExpired(@NonNull MdnsResponse previousResponse, @Nullable MdnsResponse newResponse); } // TODO: Schedule a job to check ttl expiration for all services and notify to the clients. }