1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef _DNS_DNSTLSDISPATCHER_H
18 #define _DNS_DNSTLSDISPATCHER_H
19 
#include <chrono>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>

#include <android-base/thread_annotations.h>
#include <netdutils/Slice.h>

#include "DnsTlsServer.h"
#include "DnsTlsTransport.h"
#include "IDnsTlsSocketFactory.h"
#include "PrivateDnsValidationObserver.h"
#include "resolv_private.h"
33 
34 namespace android {
35 namespace net {
36 
37 // This is a singleton class that manages the collection of active DnsTlsTransports.
38 // Queries made here are dispatched to an existing or newly constructed DnsTlsTransport.
39 // TODO: PrivateDnsValidationObserver is not implemented in this class. Remove it.
40 class DnsTlsDispatcher : public PrivateDnsValidationObserver {
41   public:
42     // Constructor with dependency injection for testing.
DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory)43     explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory)
44         : mFactory(std::move(factory)) {}
45 
46     static DnsTlsDispatcher& getInstance();
47 
48     // Enqueues |query| for resolution via the given |tlsServers| on the
49     // network indicated by |mark|; writes the response into |ans|, and stores
50     // the count of bytes written in |resplen|. Returns a success or error code.
51     // The order in which servers from |tlsServers| are queried may not be the
52     // order passed in by the caller.
53     DnsTlsTransport::Response query(const std::list<DnsTlsServer>& tlsServers,
54                                     ResState* _Nonnull statp, const netdutils::Slice query,
55                                     const netdutils::Slice ans, int* _Nonnull resplen,
56                                     bool dotQuickFallback);
57 
58     // Given a |query|, sends it to the server on the network indicated by |mark|,
59     // and writes the response into |ans|, and indicates the number of bytes written in |resplen|.
60     // If the whole procedure above triggers (or experiences) any new connection, |connectTriggered|
61     // is set. Returns a success or error code.
62     DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned netId, unsigned mark,
63                                     const netdutils::Slice query, const netdutils::Slice ans,
64                                     int* _Nonnull resplen, bool* _Nonnull connectTriggered);
65 
66     // Implement PrivateDnsValidationObserver.
onValidationStateUpdate(const std::string &,Validation,uint32_t)67     void onValidationStateUpdate(const std::string&, Validation, uint32_t) override{};
68 
69     void forceCleanup(unsigned netId) EXCLUDES(sLock);
70 
71   private:
72     DnsTlsDispatcher();
73 
74     // This lock is static so that it can be used to annotate the Transport struct.
75     // DnsTlsDispatcher is a singleton in practice, so making this static does not change
76     // the locking behavior.
77     static std::mutex sLock;
78 
79     // Key = <mark, server>
80     typedef std::pair<unsigned, const DnsTlsServer> Key;
81 
82     // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and
83     // usage monitoring so we can expire idle sessions from the cache.
84     struct Transport {
TransportTransport85         Transport(const DnsTlsServer& server, unsigned mark, unsigned netId,
86                   IDnsTlsSocketFactory* _Nonnull factory, int triggerThr, int unusableThr,
87                   int timeout)
88             : transport(server, mark, factory),
89               mNetId(netId),
90               triggerThreshold(triggerThr),
91               unusableThreshold(unusableThr),
92               mTimeout(timeout) {}
93 
94         // DnsTlsTransport is thread-safe, so it doesn't need to be guarded.
95         DnsTlsTransport transport;
96 
97         // The expected network, assigned from dns_netid, to which Transport will send DNS packets.
98         const unsigned mNetId;
99 
100         // This use counter and timestamp are used to ensure that only idle sessions are
101         // destroyed.
102         int useCount GUARDED_BY(sLock) = 0;
103         // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero.
104         std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock);
105 
106         // If DoT revalidation is disabled, it returns true; otherwise, it returns
107         // whether or not this Transport is usable.
108         bool usable() REQUIRES(sLock);
109 
110         // Used to track if this Transport is usable.
111         int continuousfailureCount GUARDED_BY(sLock) = 0;
112 
113         bool checkRevalidationNecessary() REQUIRES(sLock);
114 
timeoutTransport115         std::chrono::milliseconds timeout() const { return mTimeout; }
116 
117         static constexpr int kDotRevalidationThreshold = -1;
118         static constexpr int kDotXportUnusableThreshold = -1;
119         static constexpr int kDotQueryTimeoutMs = -1;
120 
121       private:
122         // The flag to record whether or not dot_revalidation_threshold is ever reached.
123         bool isRevalidationThresholdReached GUARDED_BY(sLock) = false;
124 
125         // The flag to record whether or not dot_xport_unusable_threshold is ever reached.
126         bool isXportUnusableThresholdReached GUARDED_BY(sLock) = false;
127 
128         // If the number of continuous query timeouts reaches the threshold, mark the
129         // server as unvalidated and trigger a validation.
130         // If the value is not a positive value or private DNS mode is strict mode, no threshold is
131         // set. Note that it must be at least 10, or it breaks
132         // ConnectTlsServerTimeout_ConcurrentQueries test.
133         const int triggerThreshold;
134 
135         // The threshold to determine if this Transport is considered unusable.
136         // If the number of continuous query timeouts reaches the threshold, mark this
137         // Transport as unusable. An unusable Transport won't be used anymore.
138         // If the value is not a positive value or private DNS mode is strict mode, no threshold is
139         // set.
140         const int unusableThreshold;
141 
142         // The time to await a future (the result of a DNS request) from the DnsTlsTransport
143         // of this Transport.
144         // To set an infinite timeout, assign the value to -1.
145         const std::chrono::milliseconds mTimeout;
146     };
147 
148     Transport* _Nullable addTransport(const DnsTlsServer& server, unsigned mark, unsigned netId)
149             REQUIRES(sLock);
150     Transport* _Nullable getTransport(const Key& key) REQUIRES(sLock);
151 
152     // Cache of reusable DnsTlsTransports.  Transports stay in cache as long as
153     // they are in use and for a few minutes after.
154     std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock);
155 
156     // The last time we did a cleanup.  For efficiency, we only perform a cleanup once every
157     // few minutes.
158     std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock);
159 
160     DnsTlsTransport::Result queryInternal(Transport& transport, const netdutils::Slice query)
161             EXCLUDES(sLock);
162 
163     void maybeCleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
164 
165     // Drop any cache entries whose useCount is zero and which have not been used recently.
166     // This function performs a linear scan of mStore.
167     void cleanup(std::chrono::time_point<std::chrono::steady_clock> now,
168                  std::optional<unsigned> netId) REQUIRES(sLock);
169 
170     // Return a sorted list of usable DnsTlsServers in preference order.
171     std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers,
172                                                           unsigned netId, unsigned mark);
173 
174     // Trivial factory for DnsTlsSockets.  Dependency injection is only used for testing.
175     std::unique_ptr<IDnsTlsSocketFactory> mFactory;
176 };
177 
178 }  // end of namespace net
179 }  // end of namespace android
180 
181 #endif  // _DNS_DNSTLSDISPATCHER_H
182