1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <array> 20 #include <list> 21 #include <map> 22 #include <mutex> 23 #include <vector> 24 25 #include <android-base/format.h> 26 #include <android-base/logging.h> 27 #include <android-base/result.h> 28 #include <android-base/thread_annotations.h> 29 #include <netdutils/BackoffSequence.h> 30 #include <netdutils/DumpWriter.h> 31 #include <netdutils/InternetAddresses.h> 32 #include <netdutils/Slice.h> 33 #include <stats.pb.h> 34 35 #include "DnsTlsServer.h" 36 #include "LockedQueue.h" 37 #include "PrivateDnsValidationObserver.h" 38 #include "doh.h" 39 40 namespace android { 41 namespace net { 42 43 PrivateDnsModes convertEnumType(PrivateDnsMode mode); 44 45 struct PrivateDnsStatus { 46 PrivateDnsMode mode; 47 48 // TODO: change the type to std::vector<DnsTlsServer>. 49 std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap; 50 51 std::map<netdutils::IPSockAddr, Validation> dohServersMap; 52 validatedServersPrivateDnsStatus53 std::list<DnsTlsServer> validatedServers() const { 54 std::list<DnsTlsServer> servers; 55 56 for (const auto& pair : dotServersMap) { 57 if (pair.second == Validation::success) { 58 servers.push_back(pair.first); 59 } 60 } 61 return servers; 62 } 63 hasValidatedDohServersPrivateDnsStatus64 bool hasValidatedDohServers() const { 65 for (const auto& [_, status] : dohServersMap) { 66 if (status == Validation::success) { 67 return true; 68 } 69 } 70 return false; 71 } 72 }; 73 74 class PrivateDnsConfiguration { 75 public: 76 static constexpr int kDohQueryDefaultTimeoutMs = 30000; 77 static constexpr int kDohProbeDefaultTimeoutMs = 60000; 78 79 // The default value for QUIC max_idle_timeout. 80 static constexpr int kDohIdleDefaultTimeoutMs = 55000; 81 82 struct ServerIdentity { 83 const netdutils::IPSockAddr sockaddr; 84 const std::string provider; 85 ServerIdentityServerIdentity86 explicit ServerIdentity(const DnsTlsServer& server) 87 : sockaddr(server.addr()), provider(server.provider()) {} ServerIdentityServerIdentity88 ServerIdentity(const netdutils::IPSockAddr& addr, const std::string& host) 89 : sockaddr(addr), provider(host) {} 90 91 bool operator<(const ServerIdentity& other) const { 92 return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider); 93 } 94 bool operator==(const ServerIdentity& other) const { 95 return std::tie(sockaddr, provider) == std::tie(other.sockaddr, other.provider); 96 } 97 }; 98 99 // The only instance of PrivateDnsConfiguration. getInstance()100 static PrivateDnsConfiguration& getInstance() { 101 static PrivateDnsConfiguration instance; 102 return instance; 103 } 104 105 int set(int32_t netId, uint32_t mark, const std::vector<std::string>& unencryptedServers, 106 const std::vector<std::string>& encryptedServers, const std::string& name, 107 const std::string& caCert) EXCLUDES(mPrivateDnsLock); 108 109 void initDoh() EXCLUDES(mPrivateDnsLock); 110 111 PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock); 112 NetworkDnsServerSupportReported getStatusForMetrics(unsigned netId) const 113 EXCLUDES(mPrivateDnsLock); 114 115 void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); 116 117 ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer, 118 uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock); 119 120 // Request the server to be revalidated on a connection tagged with |mark|. 121 // Returns a Result to indicate if the request is accepted. 122 base::Result<void> requestDotValidation(unsigned netId, const ServerIdentity& identity, 123 uint32_t mark) EXCLUDES(mPrivateDnsLock); 124 125 void setObserver(PrivateDnsValidationObserver* observer); 126 127 void dump(netdutils::DumpWriter& dw) const; 128 129 void onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr, const char* host) 130 EXCLUDES(mPrivateDnsLock); 131 132 base::Result<netdutils::IPSockAddr> getDohServer(unsigned netId) const 133 EXCLUDES(mPrivateDnsLock); 134 135 private: 136 PrivateDnsConfiguration() = default; 137 138 int setDot(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, 139 const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock); 140 141 void clearDot(int32_t netId) REQUIRES(mPrivateDnsLock); 142 143 // For testing. 144 base::Result<DnsTlsServer*> getDotServer(const ServerIdentity& identity, unsigned netId) 145 EXCLUDES(mPrivateDnsLock); 146 147 base::Result<DnsTlsServer*> getDotServerLocked(const ServerIdentity& identity, unsigned netId) 148 REQUIRES(mPrivateDnsLock); 149 150 // TODO: change the return type to Result<PrivateDnsStatus>. 151 PrivateDnsStatus getStatusLocked(unsigned netId) const REQUIRES(mPrivateDnsLock); 152 153 // Launchs a thread to run the validation for the DoT server |server| on the network |netId|. 154 // |isRevalidation| is true if this call is due to a revalidation request. 155 void startDotValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation) 156 REQUIRES(mPrivateDnsLock); 157 158 bool recordDotValidation(const ServerIdentity& identity, unsigned netId, bool success, 159 bool isRevalidation) EXCLUDES(mPrivateDnsLock); 160 161 void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, 162 bool success) const REQUIRES(mPrivateDnsLock); 163 164 // Decide if a validation for |server| is needed. Note that servers that have failed 165 // multiple validation attempts but for which there is still a validating 166 // thread running are marked as being Validation::in_process. 167 bool needsValidation(const DnsTlsServer& server) const REQUIRES(mPrivateDnsLock); 168 169 void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) 170 REQUIRES(mPrivateDnsLock); 171 172 void initDohLocked() REQUIRES(mPrivateDnsLock); 173 int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, 174 const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock); 175 void clearDoh(unsigned netId) REQUIRES(mPrivateDnsLock); 176 177 mutable std::mutex mPrivateDnsLock; 178 std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock); 179 180 // Contains all servers for a network, along with their current validation status. 181 // In case a server is removed due to a configuration change, it remains in this map, 182 // but is marked inactive. 183 // Any pending validation threads will continue running because we have no way to cancel them. 184 std::map<unsigned, std::map<ServerIdentity, DnsTlsServer>> mDotTracker 185 GUARDED_BY(mPrivateDnsLock); 186 187 void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation, 188 uint32_t netId) const REQUIRES(mPrivateDnsLock); 189 190 bool needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const 191 REQUIRES(mPrivateDnsLock); 192 193 // TODO: fix the reentrancy problem. 194 PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock); 195 196 DohDispatcher* mDohDispatcher = nullptr; 197 std::condition_variable mCv; 198 199 friend class PrivateDnsConfigurationTest; 200 201 // It's not const because PrivateDnsConfigurationTest needs to override it. 202 // TODO: make it const by dependency injection. 203 netdutils::BackoffSequence<>::Builder mBackoffBuilder = 204 netdutils::BackoffSequence<>::Builder() 205 .withInitialRetransmissionTime(std::chrono::seconds(60)) 206 .withMaximumRetransmissionTime(std::chrono::seconds(3600)); 207 208 struct DohIdentity { 209 std::string httpsTemplate; 210 std::string ipAddr; 211 std::string host; 212 Validation status; 213 bool operator<(const DohIdentity& other) const { 214 return std::tie(ipAddr, host) < std::tie(other.ipAddr, other.host); 215 } 216 bool operator==(const DohIdentity& other) const { 217 return std::tie(ipAddr, host) == std::tie(other.ipAddr, other.host); 218 } 219 bool operator<(const ServerIdentity& other) const { 220 std::string otherIp = other.sockaddr.ip().toString(); 221 return std::tie(ipAddr, host) < std::tie(otherIp, other.provider); 222 } 223 bool operator==(const ServerIdentity& other) const { 224 std::string otherIp = other.sockaddr.ip().toString(); 225 return std::tie(ipAddr, host) == std::tie(otherIp, other.provider); 226 } 227 }; 228 229 struct DohProviderEntry { 230 std::string provider; 231 std::set<std::string> ips; 232 std::string host; 233 std::string httpsTemplate; 234 bool requireRootPermission; 235 getDohIdentityDohProviderEntry236 base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& sortedValidIps, 237 const std::string& host) const { 238 // If the private DNS hostname is known, `sortedValidIps` are the IP addresses 239 // resolved from the hostname, and hostname verification will be performed during 240 // TLS handshake to ensure the validity of the server, so it's not necessary to 241 // check the IP address. 242 if (!host.empty()) { 243 if (this->host != host) return Errorf("host {} not matched", host); 244 if (!sortedValidIps.empty()) { 245 const auto& ip = sortedValidIps[0]; 246 LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host); 247 return DohIdentity{httpsTemplate, ip, host, Validation::in_process}; 248 } 249 } 250 for (const auto& ip : sortedValidIps) { 251 if (ips.find(ip) == ips.end()) continue; 252 LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host); 253 return DohIdentity{httpsTemplate, ip, host, Validation::in_process}; 254 } 255 return Errorf("server not matched"); 256 }; 257 }; 258 259 // TODO: Move below DoH relevant stuff into Rust implementation. 260 std::map<unsigned, DohIdentity> mDohTracker GUARDED_BY(mPrivateDnsLock); 261 std::array<DohProviderEntry, 5> mAvailableDoHProviders = {{ 262 {"Google", 263 {"2001:4860:4860::8888", "2001:4860:4860::8844", "8.8.8.8", "8.8.4.4"}, 264 "dns.google", 265 "https://dns.google/dns-query", 266 false}, 267 {"Google DNS64", 268 {"2001:4860:4860::64", "2001:4860:4860::6464"}, 269 "dns64.dns.google", 270 "https://dns64.dns.google/dns-query", 271 false}, 272 {"Cloudflare", 273 {"2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "104.16.248.249", "104.16.249.249"}, 274 "cloudflare-dns.com", 275 "https://cloudflare-dns.com/dns-query", 276 false}, 277 278 // The DoH providers for testing only. 279 // Using ResolverTestProvider requires that the DnsResolver is configured by someone 280 // who has root permission, which should be run by tests only. 281 {"ResolverTestProvider", 282 {"127.0.0.3", "::1"}, 283 "example.com", 284 "https://example.com/dns-query", 285 true}, 286 {"AndroidTesting", 287 {"192.0.2.100"}, 288 "dns.androidtesting.org", 289 "https://dns.androidtesting.org/dns-query", 290 false}, 291 }}; 292 293 // Makes a DohIdentity by looking up the `mAvailableDoHProviders` by `servers` and `name`. 294 base::Result<DohIdentity> makeDohIdentity(const std::vector<std::string>& servers, 295 const std::string& name) const 296 REQUIRES(mPrivateDnsLock); 297 298 // For the metrics. Store the current DNS server list in the same order as what is passed 299 // in setResolverConfiguration(). 300 std::map<unsigned, std::vector<std::string>> mUnorderedDnsTracker GUARDED_BY(mPrivateDnsLock); 301 std::map<unsigned, std::vector<std::string>> mUnorderedDotTracker GUARDED_BY(mPrivateDnsLock); 302 std::map<unsigned, std::vector<std::string>> mUnorderedDohTracker GUARDED_BY(mPrivateDnsLock); 303 304 struct RecordEntry { RecordEntryRecordEntry305 RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state) 306 : netId(netId), serverIdentity(identity), state(state) {} 307 308 const uint32_t netId; 309 const ServerIdentity serverIdentity; 310 const Validation state; 311 const std::chrono::system_clock::time_point timestamp = std::chrono::system_clock::now(); 312 }; 313 314 LockedRingBuffer<RecordEntry> mPrivateDnsLog{100}; 315 }; 316 317 } // namespace net 318 } // namespace android 319