1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include <arpa/inet.h>
20 
21 #include <chrono>
22 
23 #include <android-base/logging.h>
24 #include <android-base/macros.h>
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include <netdutils/NetNativeTestBase.h>
28 #include <netdutils/Slice.h>
29 
30 #include "DnsTlsDispatcher.h"
31 #include "DnsTlsQueryMap.h"
32 #include "DnsTlsServer.h"
33 #include "DnsTlsSessionCache.h"
34 #include "DnsTlsSocket.h"
35 #include "DnsTlsTransport.h"
36 #include "Experiments.h"
37 #include "IDnsTlsSocket.h"
38 #include "IDnsTlsSocketFactory.h"
39 #include "IDnsTlsSocketObserver.h"
40 #include "tests/dns_responder/dns_tls_frontend.h"
41 
42 namespace android {
43 namespace net {
44 
45 using netdutils::IPAddress;
46 using netdutils::IPSockAddr;
47 using netdutils::makeSlice;
48 using netdutils::Slice;
49 
50 typedef std::vector<uint8_t> bytevec;
51 
52 static const std::string DOT_MAXTRIES_FLAG = "dot_maxtries";
53 static const std::string SERVERNAME1 = "dns.example.com";
54 static const std::string SERVERNAME2 = "dns.example.org";
55 static const IPAddress V4ADDR1 = IPAddress::forString("192.0.2.1");
56 static const IPAddress V4ADDR2 = IPAddress::forString("192.0.2.2");
57 static const IPAddress V6ADDR1 = IPAddress::forString("2001:db8::1");
58 static const IPAddress V6ADDR2 = IPAddress::forString("2001:db8::2");
59 
60 // BaseTest just provides constants that are useful for the tests.
61 class BaseTest : public NetNativeTestBase {
62   protected:
BaseTest()63     BaseTest() {
64         SERVER1.name = SERVERNAME1;
65     }
66 
67     DnsTlsServer SERVER1{V4ADDR1};
68 };
69 
make_query(uint16_t id,size_t size)70 bytevec make_query(uint16_t id, size_t size) {
71     bytevec vec(size);
72     vec[0] = id >> 8;
73     vec[1] = id;
74     // Arbitrarily fill the query body with unique data.
75     for (size_t i = 2; i < size; ++i) {
76         vec[i] = id + i;
77     }
78     return vec;
79 }
80 
81 // Query constants
82 const unsigned NETID = 123;
83 const unsigned MARK = 123;
84 const uint16_t ID = 52;
85 const uint16_t SIZE = 22;
86 const bytevec QUERY = make_query(ID, SIZE);
87 
88 template <class T>
89 class FakeSocketFactory : public IDnsTlsSocketFactory {
90   public:
FakeSocketFactory()91     FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)92     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
93             const DnsTlsServer& server ATTRIBUTE_UNUSED,
94             unsigned mark ATTRIBUTE_UNUSED,
95             IDnsTlsSocketObserver* observer,
96             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
97         return std::make_unique<T>(observer);
98     }
99 };
100 
make_echo(uint16_t id,const Slice query)101 bytevec make_echo(uint16_t id, const Slice query) {
102     bytevec response(query.size() + 2);
103     response[0] = id >> 8;
104     response[1] = id;
105     // Echo the query as the fake response.
106     memcpy(response.data() + 2, query.base(), query.size());
107     return response;
108 }
109 
110 // Simplest possible fake server.  This just echoes the query as the response.
111 class FakeSocketEcho : public IDnsTlsSocket {
112   public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)113     explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)114     bool query(uint16_t id, const Slice query) override {
115         // Return the response immediately (asynchronously).
116         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
117         return true;
118     }
startHandshake()119     bool startHandshake() override { return true; }
120 
121   private:
122     IDnsTlsSocketObserver* const mObserver;
123 };
124 
125 class TransportTest : public BaseTest {};
126 
TEST_F(TransportTest,Query)127 TEST_F(TransportTest, Query) {
128     FakeSocketFactory<FakeSocketEcho> factory;
129     DnsTlsTransport transport(SERVER1, MARK, &factory);
130     auto r = transport.query(makeSlice(QUERY)).get();
131 
132     EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
133     EXPECT_EQ(QUERY, r.response);
134     EXPECT_EQ(transport.getConnectCounter(), 1);
135 }
136 
137 // Fake Socket that echoes the observed query ID as the response body.
138 class FakeSocketId : public IDnsTlsSocket {
139   public:
FakeSocketId(IDnsTlsSocketObserver * observer)140     explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query ATTRIBUTE_UNUSED)141     bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
142         // Return the response immediately (asynchronously).
143         bytevec response(4);
144         // Echo the ID in the header to match the response to the query.
145         // This will be overwritten by DnsTlsQueryMap.
146         response[0] = id >> 8;
147         response[1] = id;
148         // Echo the ID in the body, so that the test can verify which ID was used by
149         // DnsTlsQueryMap.
150         response[2] = id >> 8;
151         response[3] = id;
152         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
153         return true;
154     }
startHandshake()155     bool startHandshake() override { return true; }
156 
157   private:
158     IDnsTlsSocketObserver* const mObserver;
159 };
160 
161 // Test that IDs are properly reused
TEST_F(TransportTest,IdReuse)162 TEST_F(TransportTest, IdReuse) {
163     FakeSocketFactory<FakeSocketId> factory;
164     DnsTlsTransport transport(SERVER1, MARK, &factory);
165     for (int i = 0; i < 100; ++i) {
166         // Send a query.
167         std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
168         // Wait for the response.
169         DnsTlsTransport::Result r = f.get();
170         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
171 
172         // All queries should have an observed ID of zero, because it is returned to the ID pool
173         // after each use.
174         EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
175     }
176     EXPECT_EQ(transport.getConnectCounter(), 1);
177 }
178 
179 // These queries might be handled in serial or parallel as they race the
180 // responses.
TEST_F(TransportTest,RacingQueries_10000)181 TEST_F(TransportTest, RacingQueries_10000) {
182     FakeSocketFactory<FakeSocketEcho> factory;
183     DnsTlsTransport transport(SERVER1, MARK, &factory);
184     std::vector<std::future<DnsTlsTransport::Result>> results;
185     // Fewer than 65536 queries to avoid ID exhaustion.
186     const int num_queries = 10000;
187     results.reserve(num_queries);
188     for (int i = 0; i < num_queries; ++i) {
189         results.push_back(transport.query(makeSlice(QUERY)));
190     }
191     for (auto& result : results) {
192         auto r = result.get();
193         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
194         EXPECT_EQ(QUERY, r.response);
195     }
196     EXPECT_EQ(transport.getConnectCounter(), 1);
197 }
198 
199 // A server that waits until sDelay queries are queued before responding.
200 class FakeSocketDelay : public IDnsTlsSocket {
201   public:
FakeSocketDelay(IDnsTlsSocketObserver * observer)202     explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
~FakeSocketDelay()203     ~FakeSocketDelay() {
204         std::lock_guard guard(mLock);
205         sDelay = 1;
206         sReverse = false;
207         sConnectable = true;
208     }
209     inline static size_t sDelay = 1;
210     inline static bool sReverse = false;
211     inline static bool sConnectable = true;
212 
query(uint16_t id,const Slice query)213     bool query(uint16_t id, const Slice query) override {
214         LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
215         std::lock_guard guard(mLock);
216         // Check for duplicate IDs.
217         EXPECT_EQ(0U, mIds.count(id));
218         mIds.insert(id);
219 
220         // Store response.
221         mResponses.push_back(make_echo(id, query));
222 
223         LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
224         if (mResponses.size() == sDelay) {
225             std::thread(&FakeSocketDelay::sendResponses, this).detach();
226         }
227         return true;
228     }
startHandshake()229     bool startHandshake() override { return sConnectable; }
230 
231   private:
sendResponses()232     void sendResponses() {
233         std::lock_guard guard(mLock);
234         if (sReverse) {
235             std::reverse(std::begin(mResponses), std::end(mResponses));
236         }
237         for (auto& response : mResponses) {
238             mObserver->onResponse(response);
239         }
240         mIds.clear();
241         mResponses.clear();
242     }
243 
244     std::mutex mLock;
245     IDnsTlsSocketObserver* const mObserver;
246     std::set<uint16_t> mIds GUARDED_BY(mLock);
247     std::vector<bytevec> mResponses GUARDED_BY(mLock);
248 };
249 
// Ten identical queries (same original ID) are outstanding at once; the fake
// server answers only after all ten have been queued.
TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Fewer than 65536 queries to avoid ID exhaustion.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t sent = 0; sent < FakeSocketDelay::sDelay; ++sent) {
        pending.emplace_back(transport.query(makeSlice(QUERY)));
    }

    for (size_t idx = 0; idx < pending.size(); ++idx) {
        DnsTlsTransport::Result r = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(QUERY, r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
268 
// Same as ParallelColliding, but with the maximum number of outstanding
// queries.
TEST_F(TransportTest, ParallelColliding_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Exactly 65536 queries should still be possible in parallel,
    // even if they all have the same original ID.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t sent = 0; sent < FakeSocketDelay::sDelay; ++sent) {
        pending.emplace_back(transport.query(makeSlice(QUERY)));
    }

    for (size_t idx = 0; idx < pending.size(); ++idx) {
        DnsTlsTransport::Result r = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(QUERY, r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
288 
// Ten distinct queries are outstanding at once; each future must resolve to
// the response for its own query.
TEST_F(TransportTest, ParallelUnique) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // The query buffers are kept for the comparison in the second loop.
    std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        queries[i] = make_query(i, SIZE);
        pending.emplace_back(transport.query(makeSlice(queries[i])));
    }

    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        DnsTlsTransport::Result r = pending[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
308 
// Same as ParallelUnique, but with the maximum number of outstanding queries.
TEST_F(TransportTest, ParallelUnique_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Exactly 65536 queries should still be possible in parallel,
    // and they should all be mapped correctly back to the original ID.
    std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        queries[i] = make_query(i, SIZE);
        pending.emplace_back(transport.query(makeSlice(queries[i])));
    }

    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        DnsTlsTransport::Result r = pending[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
330 
// Fills the whole ID space with unanswered queries and checks that one more
// query is rejected while the earlier ones stay pending.
TEST_F(TransportTest, IdExhaustion) {
    const int num_queries = 65536;
    // A delay of 65537 is unreachable, because the maximum number
    // of outstanding queries is 65536.
    FakeSocketDelay::sDelay = num_queries + 1;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Issue the maximum number of queries.
    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(num_queries);
    for (int sent = 0; sent < num_queries; ++sent) {
        results.emplace_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so subsequent queries should fail immediately.
    DnsTlsTransport::Result r = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
    EXPECT_TRUE(r.response.empty());

    // All other queries should remain outstanding.
    for (auto& result : results) {
        EXPECT_EQ(std::future_status::timeout,
                  result.wait_for(std::chrono::duration<int>::zero()));
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
358 
359 // Responses can come back from the server in any order.  This should have no
360 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)361 TEST_F(TransportTest, ReverseOrder) {
362     FakeSocketDelay::sDelay = 10;
363     FakeSocketDelay::sReverse = true;
364     FakeSocketFactory<FakeSocketDelay> factory;
365     DnsTlsTransport transport(SERVER1, MARK, &factory);
366     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
367     std::vector<std::future<DnsTlsTransport::Result>> results;
368     results.reserve(FakeSocketDelay::sDelay);
369     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
370         queries[i] = make_query(i, SIZE);
371         results.push_back(transport.query(makeSlice(queries[i])));
372     }
373     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
374         auto r = results[i].get();
375         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
376         EXPECT_EQ(queries[i], r.response);
377     }
378     EXPECT_EQ(transport.getConnectCounter(), 1);
379 }
380 
// Same as ReverseOrder, but with 65536 queries outstanding before the fake
// server delivers its responses in reverse.
TEST_F(TransportTest, ReverseOrder_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // The query buffers are kept for the comparison in the second loop.
    std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        queries[i] = make_query(i, SIZE);
        pending.emplace_back(transport.query(makeSlice(queries[i])));
    }

    for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
        DnsTlsTransport::Result r = pending[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
400 
401 // Returning null from the factory indicates a connection failure.
402 class NullSocketFactory : public IDnsTlsSocketFactory {
403   public:
NullSocketFactory()404     NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)405     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
406             const DnsTlsServer& server ATTRIBUTE_UNUSED,
407             unsigned mark ATTRIBUTE_UNUSED,
408             IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
409             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
410         return nullptr;
411     }
412 };
413 
// Both ways of failing to connect must surface as a network error with an
// empty response and a single connect attempt.
TEST_F(TransportTest, ConnectFail) {
    // Failure on creating socket.
    NullSocketFactory factory1;
    DnsTlsTransport transport1(SERVER1, MARK, &factory1);
    DnsTlsTransport::Result r = transport1.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport1.getConnectCounter(), 1);

    // Failure on handshaking: the socket is created, but startHandshake()
    // reports failure.
    FakeSocketDelay::sConnectable = false;
    FakeSocketFactory<FakeSocketDelay> factory2;
    DnsTlsTransport transport2(SERVER1, MARK, &factory2);
    r = transport2.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport2.getConnectCounter(), 1);
}
434 
435 // Simulate a socket that connects but then immediately receives a server
436 // close notification.
437 class FakeSocketClose : public IDnsTlsSocket {
438   public:
FakeSocketClose(IDnsTlsSocketObserver * observer)439     explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
440         : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()441     ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)442     bool query(uint16_t id ATTRIBUTE_UNUSED,
443                const Slice query ATTRIBUTE_UNUSED) override {
444         return true;
445     }
startHandshake()446     bool startHandshake() override { return true; }
447 
448   private:
449     std::thread mCloser;
450 };
451 
// Every socket closes right after connecting, so the query ends as a network
// error with an empty response.
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    DnsTlsTransport::Result r = transport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());

    // Reconnections might be triggered depending on the flag.
    EXPECT_EQ(transport.getConnectCounter(),
              Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
}
464 
465 // Simulate a server that occasionally closes the connection and silently
466 // drops some queries.
467 class FakeSocketLimited : public IDnsTlsSocket {
468   public:
469     static int sLimit;  // Number of queries to answer per socket.
470     static size_t sMaxSize;  // Silently discard queries greater than this size.
FakeSocketLimited(IDnsTlsSocketObserver * observer)471     explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
472         : mObserver(observer), mQueries(0) {}
~FakeSocketLimited()473     ~FakeSocketLimited() {
474         {
475             LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
476             std::lock_guard guard(mLock);
477             LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
478             for (auto& thread : mThreads) {
479                 LOG(DEBUG) << "~FakeSocketLimited joining response thread";
480                 thread.join();
481                 LOG(DEBUG) << "~FakeSocketLimited joined response thread";
482             }
483             mThreads.clear();
484         }
485 
486         if (mCloser) {
487             LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
488             mCloser->join();
489             LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
490         }
491     }
query(uint16_t id,const Slice query)492     bool query(uint16_t id, const Slice query) override {
493         LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
494         std::lock_guard guard(mLock);
495         LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
496         ++mQueries;
497 
498         if (mQueries <= sLimit) {
499             LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
500             if (query.size() <= sMaxSize) {
501                 // Return the response immediately (asynchronously).
502                 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
503             }
504         }
505         if (mQueries == sLimit) {
506             mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
507         }
508         return mQueries <= sLimit;
509     }
startHandshake()510     bool startHandshake() override { return true; }
511 
512   private:
sendClose()513     void sendClose() {
514         {
515             LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
516             std::lock_guard guard(mLock);
517             LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
518             for (auto& thread : mThreads) {
519                 LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
520                 thread.join();
521                 LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
522             }
523             mThreads.clear();
524         }
525         mObserver->onClosed();
526     }
527     std::mutex mLock;
528     IDnsTlsSocketObserver* const mObserver;
529     int mQueries GUARDED_BY(mLock);
530     std::vector<std::thread> mThreads GUARDED_BY(mLock);
531     std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
532 };
533 
534 int FakeSocketLimited::sLimit;
535 size_t FakeSocketLimited::sMaxSize;
536 
// The server drops every query and closes after ten of them; all queries must
// end as network errors.
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;   // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 10 queries.  They will all be ignored, and after the 10th,
    // the socket will close.  Transport will retry them all, until they
    // all hit the retry limit and expire.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketLimited::sLimit);
    for (int sent = 0; sent < FakeSocketLimited::sLimit; ++sent) {
        pending.emplace_back(transport.query(makeSlice(QUERY)));
    }
    for (size_t idx = 0; idx < pending.size(); ++idx) {
        DnsTlsTransport::Result r = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
        EXPECT_TRUE(r.response.empty());
    }

    // Reconnections might be triggered depending on the flag.
    EXPECT_EQ(transport.getConnectCounter(),
              Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
}
561 
// The server answers "short" queries and drops "long" ones, closing the socket
// after every ten queries; all short queries must still succeed.
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;          // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 100 queries, alternating "short" which will be served and "long"
    // which will be dropped.
    const int num_queries = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> queries(num_queries);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(num_queries);
    for (int i = 0; i < num_queries; ++i) {
        // Odd indices get one extra byte, which makes them "long".
        queries[i] = make_query(i, SIZE + (i % 2));
        pending.emplace_back(transport.query(makeSlice(queries[i])));
    }

    // Just check the short queries, which are at the even indices.
    for (int i = 0; i < num_queries; i += 2) {
        DnsTlsTransport::Result r = pending[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }

    // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
    // connect attempts for this test.
}
588 
// The connect counter starts at zero and increases once for every socket the
// transport needs to serve the queries.
TEST_F(TransportTest, ConnectCounter) {
    FakeSocketLimited::sLimit = 2;       // Close the socket after 2 queries.
    FakeSocketLimited::sMaxSize = SIZE;  // No query drops.
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Connecting on demand.
    EXPECT_EQ(transport.getConnectCounter(), 0);

    const int num_queries = 10;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(num_queries);
    // Reconnections take place every two queries.
    for (int sent = 0; sent < num_queries; ++sent) {
        pending.emplace_back(transport.query(makeSlice(QUERY)));
    }
    for (int idx = 0; idx < num_queries; ++idx) {
        DnsTlsTransport::Result r = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    }

    EXPECT_EQ(transport.getConnectCounter(), num_queries / FakeSocketLimited::sLimit);
}
612 
613 // Simulate a malfunctioning server that injects extra miscellaneous
614 // responses to queries that were not asked.  This will cause wrong answers but
615 // must not crash the Transport.
616 class FakeSocketGarbage : public IDnsTlsSocket {
617   public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)618     explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
619         // Inject a garbage event.
620         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
621     }
~FakeSocketGarbage()622     ~FakeSocketGarbage() {
623         std::lock_guard guard(mLock);
624         for (auto& thread : mThreads) {
625             thread.join();
626         }
627     }
query(uint16_t id,const Slice query)628     bool query(uint16_t id, const Slice query) override {
629         std::lock_guard guard(mLock);
630         // Return the response twice.
631         auto echo = make_echo(id, query);
632         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
633         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
634         // Also return some other garbage
635         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
636         return true;
637     }
startHandshake()638     bool startHandshake() override { return true; }
639 
640   private:
641     std::mutex mLock;
642     std::vector<std::thread> mThreads GUARDED_BY(mLock);
643     IDnsTlsSocketObserver* const mObserver;
644 };
645 
// Ten sequential queries against the garbage-injecting server must all report
// success on a single connection.
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    for (int attempt = 0; attempt < 10; ++attempt) {
        DnsTlsTransport::Result r = transport.query(makeSlice(QUERY)).get();

        // Don't check the response because this server is malfunctioning.
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
657 
658 // Dispatcher tests
659 class DispatcherTest : public BaseTest {};
660 
TEST_F(DispatcherTest,Query)661 TEST_F(DispatcherTest, Query) {
662     bytevec ans(4096);
663     int resplen = 0;
664     bool connectTriggered = false;
665 
666     auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
667     DnsTlsDispatcher dispatcher(std::move(factory));
668     auto r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
669                               &connectTriggered);
670 
671     EXPECT_EQ(DnsTlsTransport::Response::success, r);
672     EXPECT_EQ(int(QUERY.size()), resplen);
673     EXPECT_TRUE(connectTriggered);
674     ans.resize(resplen);
675     EXPECT_EQ(QUERY, ans);
676 
677     // Expect to reuse the connection.
678     r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
679                          &connectTriggered);
680     EXPECT_EQ(DnsTlsTransport::Response::success, r);
681     EXPECT_FALSE(connectTriggered);
682 }
683 
// An answer buffer one byte shorter than the echoed response must yield
// limit_error, while the connection is still reported as triggered.
TEST_F(DispatcherTest, AnswerTooLarge) {
    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());

    bytevec ans(SIZE - 1);  // Too small to hold the answer
    int resplen = 0;
    bool connectTriggered = false;
    auto r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
                              &connectTriggered);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
    EXPECT_TRUE(connectTriggered);
}
697 
698 template<class T>
699 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
700   public:
TrackingFakeSocketFactory()701     TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)702     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
703             const DnsTlsServer& server,
704             unsigned mark,
705             IDnsTlsSocketObserver* observer,
706             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
707         std::lock_guard guard(mLock);
708         keys.emplace(mark, server);
709         return std::make_unique<T>(observer);
710     }
711     std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
712 
713   private:
714     std::mutex mLock;
715 };
716 
// Queries for four (mark, server) combinations are issued from parallel
// threads; the dispatcher must create exactly one socket per combination.
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Populate a vector of two servers and two socket marks, four combinations
    // in total.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // Do several queries on each server.  They should all succeed.
    std::vector<std::thread> threads;
    for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
        auto key = keys[i % keys.size()];
        // Each worker owns its copy of the key and the query index.
        auto worker = [key, i](DnsTlsDispatcher* dispatcher) {
            const DnsTlsServer& server = key.second;
            unsigned mark = key.first;
            unsigned netId = key.first;
            auto q = make_query(i, SIZE);
            bytevec ans(4096);
            int resplen = 0;
            bool connectTriggered = false;
            auto r = dispatcher->query(server, netId, mark, makeSlice(q), makeSlice(ans), &resplen,
                                       &connectTriggered);
            EXPECT_EQ(DnsTlsTransport::Response::success, r);
            EXPECT_EQ(int(q.size()), resplen);
            ans.resize(resplen);
            EXPECT_EQ(q, ans);
        };
        threads.emplace_back(std::move(worker), &dispatcher);
    }
    for (std::thread& thread : threads) {
        thread.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), weak_factory->keys.size());
    for (auto& key : keys) {
        EXPECT_EQ(1U, weak_factory->keys.count(key));
    }
}
761 
762 // Check DnsTlsServer's comparison logic.
763 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)764 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
765     bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
766     bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
767     EXPECT_FALSE(cmp1 && cmp2);
768     return !cmp1 && !cmp2;
769 }
770 
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)771 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
772     EXPECT_TRUE(s1 == s1);
773     EXPECT_TRUE(s2 == s2);
774     EXPECT_TRUE(isAddressEqual(s1, s1));
775     EXPECT_TRUE(isAddressEqual(s2, s2));
776 
777     EXPECT_TRUE(s1 < s2 ^ s2 < s1);
778     EXPECT_FALSE(s1 == s2);
779     EXPECT_FALSE(s2 == s1);
780 }
781 
checkEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)782 void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
783     EXPECT_TRUE(s1 == s1);
784     EXPECT_TRUE(s2 == s2);
785     EXPECT_TRUE(isAddressEqual(s1, s1));
786     EXPECT_TRUE(isAddressEqual(s2, s2));
787 
788     EXPECT_FALSE(s1 < s2);
789     EXPECT_FALSE(s2 < s1);
790     EXPECT_TRUE(s1 == s2);
791     EXPECT_TRUE(s2 == s1);
792 }
793 
794 class ServerTest : public BaseTest {};
795 
// Two different IPv4 addresses are unequal servers with unequal addresses.
TEST_F(ServerTest, IPv4) {
    const DnsTlsServer first(V4ADDR1);
    const DnsTlsServer second(V4ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
800 
// Two different IPv6 addresses are unequal servers with unequal addresses.
TEST_F(ServerTest, IPv6) {
    const DnsTlsServer first(V6ADDR1);
    const DnsTlsServer second(V6ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
805 
// An IPv6 server and an IPv4 server are unequal, including by address.
TEST_F(ServerTest, MixedAddressFamily) {
    const DnsTlsServer v6Server(V6ADDR1);
    const DnsTlsServer v4Server(V4ADDR1);
    checkUnequal(v6Server, v4Server);
    EXPECT_FALSE(isAddressEqual(v6Server, v4Server));
}
810 
// The same link-local address with different scope IDs yields unequal servers
// whose addresses are also unequal.
TEST_F(ServerTest, IPv6ScopeId) {
    DnsTlsServer scope1(IPAddress::forString("fe80::1%1"));
    DnsTlsServer scope2(IPAddress::forString("fe80::1%2"));

    checkUnequal(scope1, scope2);
    EXPECT_FALSE(isAddressEqual(scope1, scope2));

    // Neither server is expected to report itself as explicitly configured.
    EXPECT_FALSE(scope1.wasExplicitlyConfigured());
    EXPECT_FALSE(scope2.wasExplicitlyConfigured());
}
820 
// Servers that differ only by port are unequal, yet address-equal, and
// toIpString() returns the address without the port.
TEST_F(ServerTest, Port) {
    // IPv4: same address, ports 853 and 854.
    DnsTlsServer v4Port853(IPSockAddr::toIPSockAddr("192.0.2.1", 853));
    DnsTlsServer v4Port854(IPSockAddr::toIPSockAddr("192.0.2.1", 854));
    checkUnequal(v4Port853, v4Port854);
    EXPECT_TRUE(isAddressEqual(v4Port853, v4Port854));
    EXPECT_EQ(v4Port853.toIpString(), "192.0.2.1");
    EXPECT_EQ(v4Port854.toIpString(), "192.0.2.1");

    // IPv6: same address, ports 853 and 854.
    DnsTlsServer v6Port853(IPSockAddr::toIPSockAddr("2001:db8::1", 853));
    DnsTlsServer v6Port854(IPSockAddr::toIPSockAddr("2001:db8::1", 854));
    checkUnequal(v6Port853, v6Port854);
    EXPECT_TRUE(isAddressEqual(v6Port853, v6Port854));
    EXPECT_EQ(v6Port853.toIpString(), "2001:db8::1");
    EXPECT_EQ(v6Port854.toIpString(), "2001:db8::1");

    // The IPv4 servers are expected to report as not explicitly configured.
    EXPECT_FALSE(v4Port853.wasExplicitlyConfigured());
    EXPECT_FALSE(v4Port854.wasExplicitlyConfigured());
}
839 
// Servers on the same address become unequal once their names differ, while
// remaining address-equal; a server with a name reports as explicitly configured.
TEST_F(ServerTest, Name) {
    DnsTlsServer named1(V4ADDR1);
    DnsTlsServer named2(V4ADDR1);

    // A name on one side only is enough to make the servers unequal.
    named1.name = SERVERNAME1;
    checkUnequal(named1, named2);

    // Two different names are unequal as well; the address still matches.
    named2.name = SERVERNAME2;
    checkUnequal(named1, named2);
    EXPECT_TRUE(isAddressEqual(named1, named2));

    EXPECT_TRUE(named1.wasExplicitlyConfigured());
    EXPECT_TRUE(named2.wasExplicitlyConfigured());
}
851 
// Validation state and the active flag do not affect server equality, and the
// getters return what the setters stored.
TEST_F(ServerTest, State) {
    DnsTlsServer lhs(V4ADDR1);
    DnsTlsServer rhs(V4ADDR1);
    checkEqual(lhs, rhs);

    // Change one piece of state at a time; equality must hold after every step.
    lhs.setValidationState(Validation::success);
    checkEqual(lhs, rhs);
    rhs.setValidationState(Validation::fail);
    checkEqual(lhs, rhs);
    lhs.setActive(true);
    checkEqual(lhs, rhs);
    rhs.setActive(false);
    checkEqual(lhs, rhs);

    // Each server keeps its own state.
    EXPECT_EQ(lhs.validationState(), Validation::success);
    EXPECT_EQ(rhs.validationState(), Validation::fail);
    EXPECT_TRUE(lhs.active());
    EXPECT_FALSE(rhs.active());
}
869 
870 class QueryMapTest : public NetNativeTestBase {};
871 
TEST_F(QueryMapTest,Basic)872 TEST_F(QueryMapTest, Basic) {
873     DnsTlsQueryMap map;
874 
875     EXPECT_TRUE(map.empty());
876 
877     bytevec q0 = make_query(999, SIZE);
878     bytevec q1 = make_query(888, SIZE);
879     bytevec q2 = make_query(777, SIZE);
880 
881     auto f0 = map.recordQuery(makeSlice(q0));
882     auto f1 = map.recordQuery(makeSlice(q1));
883     auto f2 = map.recordQuery(makeSlice(q2));
884 
885     // Check return values of recordQuery
886     EXPECT_EQ(0, f0->query.newId);
887     EXPECT_EQ(1, f1->query.newId);
888     EXPECT_EQ(2, f2->query.newId);
889 
890     // Check side effects of recordQuery
891     EXPECT_FALSE(map.empty());
892 
893     auto all = map.getAll();
894     EXPECT_EQ(3U, all.size());
895 
896     EXPECT_EQ(0, all[0].newId);
897     EXPECT_EQ(1, all[1].newId);
898     EXPECT_EQ(2, all[2].newId);
899 
900     EXPECT_EQ(q0, all[0].query);
901     EXPECT_EQ(q1, all[1].query);
902     EXPECT_EQ(q2, all[2].query);
903 
904     bytevec a0 = make_query(0, SIZE);
905     bytevec a1 = make_query(1, SIZE);
906     bytevec a2 = make_query(2, SIZE);
907 
908     // Return responses out of order
909     map.onResponse(a2);
910     map.onResponse(a0);
911     map.onResponse(a1);
912 
913     EXPECT_TRUE(map.empty());
914 
915     auto r0 = f0->result.get();
916     auto r1 = f1->result.get();
917     auto r2 = f2->result.get();
918 
919     EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
920     EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
921     EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
922 
923     const bytevec& d0 = r0.response;
924     const bytevec& d1 = r1.response;
925     const bytevec& d2 = r2.response;
926 
927     // The ID should match the query
928     EXPECT_EQ(999, d0[0] << 8 | d0[1]);
929     EXPECT_EQ(888, d1[0] << 8 | d1[1]);
930     EXPECT_EQ(777, d2[0] << 8 | d2[1]);
931     // The body should match the answer
932     EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
933     EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
934     EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
935 }
936 
// Fills the map until recordQuery fails, answers one query, and checks that
// the next recorded query is given the ID that was just released.
TEST_F(QueryMapTest, FillHole) {
    // Number of queries this test expects the map to accept before it is full.
    constexpr size_t kCapacity = UINT16_MAX + 1;
    // The ID answered mid-test to open a hole in the full map.
    constexpr int kReleasedId = 40000;

    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(kCapacity);
    for (uint32_t id = 0; id <= UINT16_MAX; ++id) {
        futures[id] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[id]);  // futures[id] should be nonnull.
        EXPECT_EQ(id, futures[id]->query.newId);
    }

    // Every slot is taken.
    EXPECT_EQ(kCapacity, map.getAll().size());

    // With the map full, one more query must be rejected.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Answer the query that was assigned kReleasedId.
    auto answer = make_query(kReleasedId, SIZE);
    map.onResponse(answer);
    auto result = futures[kReleasedId]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // One slot is free again, and the next query takes the released ID.
    EXPECT_EQ(kCapacity - 1, map.getAll().size());
    auto refill = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(refill);
    EXPECT_EQ(kReleasedId, refill->query.newId);

    // Full once more: further queries are rejected.
    EXPECT_EQ(kCapacity, map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
971 
972 class DnsTlsSocketTest : public NetNativeTestBase {
973   protected:
974     class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
975       public:
976         MOCK_METHOD(void, onClosed, (), (override));
977         MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
978     };
979 
makeDnsTlsSocket(IDnsTlsSocketObserver * observer)980     std::unique_ptr<DnsTlsSocket> makeDnsTlsSocket(IDnsTlsSocketObserver* observer) {
981         return std::make_unique<DnsTlsSocket>(this->server, MARK, observer, &this->cache);
982     }
983 
enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket> & socket)984     void enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket>& socket) {
985         ASSERT_TRUE(socket);
986         DnsTlsSocket* delegate = socket.get();
987         std::lock_guard guard(delegate->mLock);
988         delegate->mAsyncHandshake = true;
989     }
990 
991     static constexpr char kTlsAddr[] = "127.0.0.3";
992     static constexpr char kTlsPort[] = "8530";  // High-numbered port so root isn't required.
993     static constexpr char kBackendAddr[] = "192.0.2.1";
994     static constexpr char kBackendPort[] = "8531";  // High-numbered port so root isn't required.
995 
996     test::DnsTlsFrontend tls{kTlsAddr, kTlsPort, kBackendAddr, kBackendPort};
997 
998     const DnsTlsServer server{IPSockAddr::toIPSockAddr(kTlsAddr, std::stoi(kTlsPort))};
999     DnsTlsSessionCache cache;
1000 };
1001 
// Destroying a socket whose handshake has been started must finish quickly.
TEST_F(DnsTlsSocketTest, SlowDestructor) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);

    ASSERT_TRUE(socket->initialize());
    ASSERT_TRUE(socket->startHandshake());

    // Test: Measure how long the socket destructor takes.  It should be fast.
    auto start = std::chrono::steady_clock::now();
    EXPECT_CALL(observer, onClosed);
    socket.reset();
    auto finish = std::chrono::steady_clock::now();
    auto elapsed = finish - start;
    LOG(DEBUG) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";
    // Shutdown should complete in milliseconds.  If the shutdown signal is lost,
    // the destructor waits for the timeout instead, which is expected to take
    // 20 seconds.
    EXPECT_LT(elapsed, std::chrono::seconds{5});
}
1022 
// initialize() and startHandshake() each succeed once, in that order, and fail
// when called out of order or repeated.
TEST_F(DnsTlsSocketTest, StartHandshake) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);

    // The handshake cannot be started before initialize().
    EXPECT_FALSE(socket->startHandshake());

    // After initialize(), the handshake can be started.
    EXPECT_TRUE(socket->initialize());
    EXPECT_TRUE(socket->startHandshake());

    // A second round of both calls is rejected.
    EXPECT_FALSE(socket->initialize());
    EXPECT_FALSE(socket->startHandshake());

    // onClosed should be called when the loop thread is joined during |socket|
    // destruction, which happens as this test body returns.
    EXPECT_CALL(observer, onClosed);
}
1043 
// With the frontend set to hang on handshake, resetting a socket in async
// handshake mode must return within a second, with or without queued queries.
TEST_F(DnsTlsSocketTest, ShutdownSignal) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    std::unique_ptr<DnsTlsSocket> socket;

    // Replaces |socket| with a new one in async handshake mode and starts its handshake.
    const auto beginHandshake = [&]() {
        socket = makeDnsTlsSocket(&observer);
        EXPECT_TRUE(socket->initialize());
        enableAsyncHandshake(socket);
        EXPECT_TRUE(socket->startHandshake());
    };
    // Destroys |socket|, expecting one onClosed call and a shutdown under one second.
    const auto resetAndTime = [&](const std::string& traceLog) {
        SCOPED_TRACE(traceLog);
        auto start = std::chrono::steady_clock::now();
        EXPECT_CALL(observer, onClosed);
        socket.reset();
        auto finish = std::chrono::steady_clock::now();
        auto elapsed = finish - start;
        LOG(INFO) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";
        EXPECT_LT(elapsed, std::chrono::seconds{1});
    };

    tls.setHangOnHandshakeForTesting(true);

    // Test 1: Reset the DnsTlsSocket while it is doing the handshake.
    beginHandshake();
    resetAndTime("Shutdown handshake w/o query requests");

    // Test 2: Reset the DnsTlsSocket while it is doing the handshake and has
    // query requests queued.
    beginHandshake();

    // DnsTlsSocket doesn't report the status of pending queries.  Whether a query
    // request is marked as failed is decided in DnsTlsTransport.
    EXPECT_CALL(observer, onResponse).Times(0);
    EXPECT_TRUE(socket->query(1, makeSlice(QUERY)));
    EXPECT_TRUE(socket->query(2, makeSlice(QUERY)));
    resetAndTime("Shutdown handshake w/ query requests");
}
1083 
1084 } // end of namespace net
1085 } // end of namespace android
1086