1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <array>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <list>
#include <map>
#include <mutex>
#include <set>
#include <string>
#include <tuple>
#include <vector>
24 
25 #include <android-base/format.h>
26 #include <android-base/logging.h>
27 #include <android-base/result.h>
28 #include <android-base/thread_annotations.h>
29 #include <netdutils/BackoffSequence.h>
30 #include <netdutils/DumpWriter.h>
31 #include <netdutils/InternetAddresses.h>
32 #include <netdutils/Slice.h>
33 #include <stats.pb.h>
34 
35 #include "DnsTlsServer.h"
36 #include "LockedQueue.h"
37 #include "PrivateDnsValidationObserver.h"
38 #include "doh.h"
39 
40 namespace android {
41 namespace net {
42 
// Converts a PrivateDnsMode value into a PrivateDnsModes value.
// NOTE(review): PrivateDnsModes is presumably the metrics enum generated from stats.pb.h —
// confirm against the definition in the .cpp.
PrivateDnsModes convertEnumType(PrivateDnsMode mode);
44 
45 struct PrivateDnsStatus {
46     PrivateDnsMode mode;
47 
48     // TODO: change the type to std::vector<DnsTlsServer>.
49     std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap;
50 
51     std::map<netdutils::IPSockAddr, Validation> dohServersMap;
52 
validatedServersPrivateDnsStatus53     std::list<DnsTlsServer> validatedServers() const {
54         std::list<DnsTlsServer> servers;
55 
56         for (const auto& pair : dotServersMap) {
57             if (pair.second == Validation::success) {
58                 servers.push_back(pair.first);
59             }
60         }
61         return servers;
62     }
63 
hasValidatedDohServersPrivateDnsStatus64     bool hasValidatedDohServers() const {
65         for (const auto& [_, status] : dohServersMap) {
66             if (status == Validation::success) {
67                 return true;
68             }
69         }
70         return false;
71     }
72 };
73 
74 class PrivateDnsConfiguration {
75   public:
    // Default timeouts, in milliseconds (per the Ms suffix), for DoH queries and DoH probes.
    static constexpr int kDohQueryDefaultTimeoutMs = 30000;
    static constexpr int kDohProbeDefaultTimeoutMs = 60000;

    // The default value for QUIC max_idle_timeout.
    static constexpr int kDohIdleDefaultTimeoutMs = 55000;
81 
82     struct ServerIdentity {
83         const netdutils::IPSockAddr sockaddr;
84         const std::string provider;
85 
ServerIdentityServerIdentity86         explicit ServerIdentity(const DnsTlsServer& server)
87             : sockaddr(server.addr()), provider(server.provider()) {}
ServerIdentityServerIdentity88         ServerIdentity(const netdutils::IPSockAddr& addr, const std::string& host)
89             : sockaddr(addr), provider(host) {}
90 
91         bool operator<(const ServerIdentity& other) const {
92             return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
93         }
94         bool operator==(const ServerIdentity& other) const {
95             return std::tie(sockaddr, provider) == std::tie(other.sockaddr, other.provider);
96         }
97     };
98 
99     // The only instance of PrivateDnsConfiguration.
getInstance()100     static PrivateDnsConfiguration& getInstance() {
101         static PrivateDnsConfiguration instance;
102         return instance;
103     }
104 
    // Public entry points. Functions annotated EXCLUDES(mPrivateDnsLock) must be called
    // without mPrivateDnsLock held.

    // Sets the private DNS configuration of network |netId| from the unencrypted and encrypted
    // server lists, the provider hostname |name| and |caCert|.
    // NOTE(review): the meaning of the int result and of |mark| is defined in the .cpp —
    // presumably 0 / negative errno and the socket mark for validation traffic; confirm.
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& unencryptedServers,
            const std::vector<std::string>& encryptedServers, const std::string& name,
            const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    // NOTE(review): presumably creates mDohDispatcher; confirm in the .cpp.
    void initDoh() EXCLUDES(mPrivateDnsLock);

    // Returns the PrivateDnsStatus for |netId| by value.
    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
    // Returns the status of |netId| in the NetworkDnsServerSupportReported form used for metrics.
    NetworkDnsServerSupportReported getStatusForMetrics(unsigned netId) const
            EXCLUDES(mPrivateDnsLock);

    // NOTE(review): presumably drops all private DNS state kept for |netId|; confirm.
    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Issues a DoH query for |netId| with a timeout of |timeoutMs| milliseconds.
    // NOTE(review): presumably |query| is the request buffer, |answer| receives the response
    // and the result is the answer length or a negative error — confirm against doh.h.
    ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer,
                     uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock);

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestDotValidation(unsigned netId, const ServerIdentity& identity,
                                            uint32_t mark) EXCLUDES(mPrivateDnsLock);

    // Registers |observer|. The pointer is not owned by this class as far as the header shows.
    // NOTE(review): unlike its neighbours this declaration carries no lock annotation even
    // though mObserver is GUARDED_BY(mPrivateDnsLock) — confirm the locking in the .cpp.
    void setObserver(PrivateDnsValidationObserver* observer);

    // Writes a dump of this object to |dw|.
    void dump(netdutils::DumpWriter& dw) const;

    // Reports a DoH status update for |netId|; the server is named by the C strings |ipAddr|
    // and |host|.
    // NOTE(review): presumably invoked from the DoH implementation as a callback — confirm.
    void onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr, const char* host)
            EXCLUDES(mPrivateDnsLock);

    // Returns the address of the DoH server of |netId|, or an error Result.
    // NOTE(review): the exact error conditions are defined in the .cpp.
    base::Result<netdutils::IPSockAddr> getDohServer(unsigned netId) const
            EXCLUDES(mPrivateDnsLock);
134 
135   private:
    // Private: instances are obtained through getInstance() (or created by the friend test).
    PrivateDnsConfiguration() = default;

    // Helpers annotated REQUIRES(mPrivateDnsLock) must be called with mPrivateDnsLock held;
    // those annotated EXCLUDES(mPrivateDnsLock) must be called without it.

    // DoT part of the configuration; takes the same |netId|, |mark|, |name| and |caCert|
    // parameters as set().
    int setDot(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);

    // NOTE(review): presumably removes the DoT state of |netId|; confirm in the .cpp.
    void clearDot(int32_t netId) REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getDotServer(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    // Same lookup as getDotServer(), for callers that already hold the lock.
    base::Result<DnsTlsServer*> getDotServerLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    // TODO: change the return type to Result<PrivateDnsStatus>.
    PrivateDnsStatus getStatusLocked(unsigned netId) const REQUIRES(mPrivateDnsLock);

    // Launches a thread to run the validation for the DoT server |identity| on the network
    // |netId|. |isRevalidation| is true if this call is due to a revalidation request.
    void startDotValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);

    // Records the outcome |success| of a DoT validation of |identity| on |netId|.
    // NOTE(review): the meaning of the bool result is defined in the .cpp — presumably
    // whether the validation should be retried; confirm.
    bool recordDotValidation(const ServerIdentity& identity, unsigned netId, bool success,
                             bool isRevalidation) EXCLUDES(mPrivateDnsLock);

    // NOTE(review): presumably forwards the validation result to mObserver; confirm.
    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId,
                                       bool success) const REQUIRES(mPrivateDnsLock);

    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    bool needsValidation(const DnsTlsServer& server) const REQUIRES(mPrivateDnsLock);

    // Sets the validation state of the server |identity| on |netId| to |state|.
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // DoH counterparts of the helpers above, for callers that already hold the lock.
    void initDohLocked() REQUIRES(mPrivateDnsLock);
    int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);
    void clearDoh(unsigned netId) REQUIRES(mPrivateDnsLock);
176 
    // Mutex named by the GUARDED_BY/REQUIRES/EXCLUDES annotations in this class. It is
    // mutable so that const member functions can lock it.
    mutable std::mutex mPrivateDnsLock;
    // Private DNS mode per network, keyed by netId.
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);

    // Contains all servers for a network, along with their current validation status.
    // In case a server is removed due to a configuration change, it remains in this map,
    // but is marked inactive.
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, std::map<ServerIdentity, DnsTlsServer>> mDotTracker
            GUARDED_BY(mPrivateDnsLock);

    // NOTE(review): presumably notifies mObserver of the new validation state; confirm.
    void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);

    // Decides whether an event should be reported for this validation result.
    // NOTE(review): |identity| is taken by value, which copies the address and the provider
    // string on every call; changing it to a const reference needs a matching .cpp change.
    bool needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const
            REQUIRES(mPrivateDnsLock);
192 
193     // TODO: fix the reentrancy problem.
194     PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock);
195 
    // Handle to the DoH dispatcher declared in doh.h; null until one is assigned.
    // NOTE(review): presumably assigned by initDoh()/initDohLocked() — confirm.
    DohDispatcher* mDohDispatcher = nullptr;
    // NOTE(review): the header does not show which condition this variable signals or which
    // mutex it is paired with — see the .cpp.
    std::condition_variable mCv;

    // The test class gets access to the private members declared in this class.
    friend class PrivateDnsConfigurationTest;

    // Builder for the backoff sequence: initial retransmission time 60 s, maximum 3600 s.
    // It's not const because PrivateDnsConfigurationTest needs to override it.
    // TODO: make it const by dependency injection.
    netdutils::BackoffSequence<>::Builder mBackoffBuilder =
            netdutils::BackoffSequence<>::Builder()
                    .withInitialRetransmissionTime(std::chrono::seconds(60))
                    .withMaximumRetransmissionTime(std::chrono::seconds(3600));
207 
208     struct DohIdentity {
209         std::string httpsTemplate;
210         std::string ipAddr;
211         std::string host;
212         Validation status;
213         bool operator<(const DohIdentity& other) const {
214             return std::tie(ipAddr, host) < std::tie(other.ipAddr, other.host);
215         }
216         bool operator==(const DohIdentity& other) const {
217             return std::tie(ipAddr, host) == std::tie(other.ipAddr, other.host);
218         }
219         bool operator<(const ServerIdentity& other) const {
220             std::string otherIp = other.sockaddr.ip().toString();
221             return std::tie(ipAddr, host) < std::tie(otherIp, other.provider);
222         }
223         bool operator==(const ServerIdentity& other) const {
224             std::string otherIp = other.sockaddr.ip().toString();
225             return std::tie(ipAddr, host) == std::tie(otherIp, other.provider);
226         }
227     };
228 
229     struct DohProviderEntry {
230         std::string provider;
231         std::set<std::string> ips;
232         std::string host;
233         std::string httpsTemplate;
234         bool requireRootPermission;
235 
getDohIdentityDohProviderEntry236         base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& sortedValidIps,
237                                                  const std::string& host) const {
238             // If the private DNS hostname is known, `sortedValidIps` are the IP addresses
239             // resolved from the hostname, and hostname verification will be performed during
240             // TLS handshake to ensure the validity of the server, so it's not necessary to
241             // check the IP address.
242             if (!host.empty()) {
243                 if (this->host != host) return Errorf("host {} not matched", host);
244                 if (!sortedValidIps.empty()) {
245                     const auto& ip = sortedValidIps[0];
246                     LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host);
247                     return DohIdentity{httpsTemplate, ip, host, Validation::in_process};
248                 }
249             }
250             for (const auto& ip : sortedValidIps) {
251                 if (ips.find(ip) == ips.end()) continue;
252                 LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host);
253                 return DohIdentity{httpsTemplate, ip, host, Validation::in_process};
254             }
255             return Errorf("server not matched");
256         };
257     };
258 
    // TODO: Move below DoH relevant stuff into Rust implementation.
    // DoH identity per network, keyed by netId.
    std::map<unsigned, DohIdentity> mDohTracker GUARDED_BY(mPrivateDnsLock);
    // Built-in table of DoH providers. Each entry lists, in DohProviderEntry member order:
    // provider name, known IP addresses, hostname, HTTPS template, requireRootPermission.
    std::array<DohProviderEntry, 5> mAvailableDoHProviders = {{
            {"Google",
             {"2001:4860:4860::8888", "2001:4860:4860::8844", "8.8.8.8", "8.8.4.4"},
             "dns.google",
             "https://dns.google/dns-query",
             false},
            {"Google DNS64",
             {"2001:4860:4860::64", "2001:4860:4860::6464"},
             "dns64.dns.google",
             "https://dns64.dns.google/dns-query",
             false},
            {"Cloudflare",
             {"2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "104.16.248.249", "104.16.249.249"},
             "cloudflare-dns.com",
             "https://cloudflare-dns.com/dns-query",
             false},

            // The DoH providers for testing only.
            // Using ResolverTestProvider requires that the DnsResolver is configured by someone
            // who has root permission, which should be run by tests only.
            {"ResolverTestProvider",
             {"127.0.0.3", "::1"},
             "example.com",
             "https://example.com/dns-query",
             true},
            // NOTE(review): listed under the testing-only providers, yet requireRootPermission
            // is false for this entry — confirm that this is intended.
            {"AndroidTesting",
             {"192.0.2.100"},
             "dns.androidtesting.org",
             "https://dns.androidtesting.org/dns-query",
             false},
    }};

    // Makes a DohIdentity by looking up the `mAvailableDoHProviders` by `servers` and `name`.
    // Must be called with mPrivateDnsLock held.
    base::Result<DohIdentity> makeDohIdentity(const std::vector<std::string>& servers,
                                              const std::string& name) const
            REQUIRES(mPrivateDnsLock);

    // For the metrics. Store the current DNS server list in the same order as what is passed
    // in setResolverConfiguration(). Each map is keyed by netId.
    std::map<unsigned, std::vector<std::string>> mUnorderedDnsTracker GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, std::vector<std::string>> mUnorderedDotTracker GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, std::vector<std::string>> mUnorderedDohTracker GUARDED_BY(mPrivateDnsLock);
303 
304     struct RecordEntry {
RecordEntryRecordEntry305         RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state)
306             : netId(netId), serverIdentity(identity), state(state) {}
307 
308         const uint32_t netId;
309         const ServerIdentity serverIdentity;
310         const Validation state;
311         const std::chrono::system_clock::time_point timestamp = std::chrono::system_clock::now();
312     };
313 
    // Ring buffer of RecordEntry items, constructed with the argument 100.
    // NOTE(review): 100 is presumably the buffer capacity — confirm against LockedQueue.h.
    LockedRingBuffer<RecordEntry> mPrivateDnsLog{100};
315 };
316 
317 }  // namespace net
318 }  // namespace android
319