1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef _DNS_DNSTLSDISPATCHER_H 18 #define _DNS_DNSTLSDISPATCHER_H 19 20 #include <list> 21 #include <map> 22 #include <memory> 23 #include <mutex> 24 25 #include <android-base/thread_annotations.h> 26 #include <netdutils/Slice.h> 27 28 #include "DnsTlsServer.h" 29 #include "DnsTlsTransport.h" 30 #include "IDnsTlsSocketFactory.h" 31 #include "PrivateDnsValidationObserver.h" 32 #include "resolv_private.h" 33 34 namespace android { 35 namespace net { 36 37 // This is a singleton class that manages the collection of active DnsTlsTransports. 38 // Queries made here are dispatched to an existing or newly constructed DnsTlsTransport. 39 // TODO: PrivateDnsValidationObserver is not implemented in this class. Remove it. 40 class DnsTlsDispatcher : public PrivateDnsValidationObserver { 41 public: 42 // Constructor with dependency injection for testing. DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory)43 explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) 44 : mFactory(std::move(factory)) {} 45 46 static DnsTlsDispatcher& getInstance(); 47 48 // Enqueues |query| for resolution via the given |tlsServers| on the 49 // network indicated by |mark|; writes the response into |ans|, and stores 50 // the count of bytes written in |resplen|. Returns a success or error code. 51 // The order in which servers from |tlsServers| are queried may not be the 52 // order passed in by the caller. 53 DnsTlsTransport::Response query(const std::list<DnsTlsServer>& tlsServers, 54 ResState* _Nonnull statp, const netdutils::Slice query, 55 const netdutils::Slice ans, int* _Nonnull resplen, 56 bool dotQuickFallback); 57 58 // Given a |query|, sends it to the server on the network indicated by |mark|, 59 // and writes the response into |ans|, and indicates the number of bytes written in |resplen|. 60 // If the whole procedure above triggers (or experiences) any new connection, |connectTriggered| 61 // is set. Returns a success or error code. 62 DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned netId, unsigned mark, 63 const netdutils::Slice query, const netdutils::Slice ans, 64 int* _Nonnull resplen, bool* _Nonnull connectTriggered); 65 66 // Implement PrivateDnsValidationObserver. onValidationStateUpdate(const std::string &,Validation,uint32_t)67 void onValidationStateUpdate(const std::string&, Validation, uint32_t) override{}; 68 69 void forceCleanup(unsigned netId) EXCLUDES(sLock); 70 71 private: 72 DnsTlsDispatcher(); 73 74 // This lock is static so that it can be used to annotate the Transport struct. 75 // DnsTlsDispatcher is a singleton in practice, so making this static does not change 76 // the locking behavior. 77 static std::mutex sLock; 78 79 // Key = <mark, server> 80 typedef std::pair<unsigned, const DnsTlsServer> Key; 81 82 // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and 83 // usage monitoring so we can expire idle sessions from the cache. 84 struct Transport { TransportTransport85 Transport(const DnsTlsServer& server, unsigned mark, unsigned netId, 86 IDnsTlsSocketFactory* _Nonnull factory, int triggerThr, int unusableThr, 87 int timeout) 88 : transport(server, mark, factory), 89 mNetId(netId), 90 triggerThreshold(triggerThr), 91 unusableThreshold(unusableThr), 92 mTimeout(timeout) {} 93 94 // DnsTlsTransport is thread-safe, so it doesn't need to be guarded. 95 DnsTlsTransport transport; 96 97 // The expected network, assigned from dns_netid, to which Transport will send DNS packets. 98 const unsigned mNetId; 99 100 // This use counter and timestamp are used to ensure that only idle sessions are 101 // destroyed. 102 int useCount GUARDED_BY(sLock) = 0; 103 // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero. 104 std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock); 105 106 // If DoT revalidation is disabled, it returns true; otherwise, it returns 107 // whether or not this Transport is usable. 108 bool usable() REQUIRES(sLock); 109 110 // Used to track if this Transport is usable. 111 int continuousfailureCount GUARDED_BY(sLock) = 0; 112 113 bool checkRevalidationNecessary() REQUIRES(sLock); 114 timeoutTransport115 std::chrono::milliseconds timeout() const { return mTimeout; } 116 117 static constexpr int kDotRevalidationThreshold = -1; 118 static constexpr int kDotXportUnusableThreshold = -1; 119 static constexpr int kDotQueryTimeoutMs = -1; 120 121 private: 122 // The flag to record whether or not dot_revalidation_threshold is ever reached. 123 bool isRevalidationThresholdReached GUARDED_BY(sLock) = false; 124 125 // The flag to record whether or not dot_xport_unusable_threshold is ever reached. 126 bool isXportUnusableThresholdReached GUARDED_BY(sLock) = false; 127 128 // If the number of continuous query timeouts reaches the threshold, mark the 129 // server as unvalidated and trigger a validation. 130 // If the value is not a positive value or private DNS mode is strict mode, no threshold is 131 // set. Note that it must be at least 10, or it breaks 132 // ConnectTlsServerTimeout_ConcurrentQueries test. 133 const int triggerThreshold; 134 135 // The threshold to determine if this Transport is considered unusable. 136 // If the number of continuous query timeouts reaches the threshold, mark this 137 // Transport as unusable. An unusable Transport won't be used anymore. 138 // If the value is not a positive value or private DNS mode is strict mode, no threshold is 139 // set. 140 const int unusableThreshold; 141 142 // The time to await a future (the result of a DNS request) from the DnsTlsTransport 143 // of this Transport. 144 // To set an infinite timeout, assign the value to -1. 145 const std::chrono::milliseconds mTimeout; 146 }; 147 148 Transport* _Nullable addTransport(const DnsTlsServer& server, unsigned mark, unsigned netId) 149 REQUIRES(sLock); 150 Transport* _Nullable getTransport(const Key& key) REQUIRES(sLock); 151 152 // Cache of reusable DnsTlsTransports. Transports stay in cache as long as 153 // they are in use and for a few minutes after. 154 std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock); 155 156 // The last time we did a cleanup. For efficiency, we only perform a cleanup once every 157 // few minutes. 158 std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock); 159 160 DnsTlsTransport::Result queryInternal(Transport& transport, const netdutils::Slice query) 161 EXCLUDES(sLock); 162 163 void maybeCleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock); 164 165 // Drop any cache entries whose useCount is zero and which have not been used recently. 166 // This function performs a linear scan of mStore. 167 void cleanup(std::chrono::time_point<std::chrono::steady_clock> now, 168 std::optional<unsigned> netId) REQUIRES(sLock); 169 170 // Return a sorted list of usable DnsTlsServers in preference order. 171 std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers, 172 unsigned netId, unsigned mark); 173 174 // Trivial factory for DnsTlsSockets. Dependency injection is only used for testing. 175 std::unique_ptr<IDnsTlsSocketFactory> mFactory; 176 }; 177 178 } // end of namespace net 179 } // end of namespace android 180 181 #endif // _DNS_DNSTLSDISPATCHER_H 182