1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include <arpa/inet.h>
20
21 #include <chrono>
22
23 #include <android-base/logging.h>
24 #include <android-base/macros.h>
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include <netdutils/NetNativeTestBase.h>
28 #include <netdutils/Slice.h>
29
30 #include "DnsTlsDispatcher.h"
31 #include "DnsTlsQueryMap.h"
32 #include "DnsTlsServer.h"
33 #include "DnsTlsSessionCache.h"
34 #include "DnsTlsSocket.h"
35 #include "DnsTlsTransport.h"
36 #include "Experiments.h"
37 #include "IDnsTlsSocket.h"
38 #include "IDnsTlsSocketFactory.h"
39 #include "IDnsTlsSocketObserver.h"
40 #include "tests/dns_responder/dns_tls_frontend.h"
41
42 namespace android {
43 namespace net {
44
45 using netdutils::IPAddress;
46 using netdutils::IPSockAddr;
47 using netdutils::makeSlice;
48 using netdutils::Slice;
49
50 typedef std::vector<uint8_t> bytevec;
51
52 static const std::string DOT_MAXTRIES_FLAG = "dot_maxtries";
53 static const std::string SERVERNAME1 = "dns.example.com";
54 static const std::string SERVERNAME2 = "dns.example.org";
55 static const IPAddress V4ADDR1 = IPAddress::forString("192.0.2.1");
56 static const IPAddress V4ADDR2 = IPAddress::forString("192.0.2.2");
57 static const IPAddress V6ADDR1 = IPAddress::forString("2001:db8::1");
58 static const IPAddress V6ADDR2 = IPAddress::forString("2001:db8::2");
59
60 // BaseTest just provides constants that are useful for the tests.
61 class BaseTest : public NetNativeTestBase {
62 protected:
BaseTest()63 BaseTest() {
64 SERVER1.name = SERVERNAME1;
65 }
66
67 DnsTlsServer SERVER1{V4ADDR1};
68 };
69
make_query(uint16_t id,size_t size)70 bytevec make_query(uint16_t id, size_t size) {
71 bytevec vec(size);
72 vec[0] = id >> 8;
73 vec[1] = id;
74 // Arbitrarily fill the query body with unique data.
75 for (size_t i = 2; i < size; ++i) {
76 vec[i] = id + i;
77 }
78 return vec;
79 }
80
81 // Query constants
82 const unsigned NETID = 123;
83 const unsigned MARK = 123;
84 const uint16_t ID = 52;
85 const uint16_t SIZE = 22;
86 const bytevec QUERY = make_query(ID, SIZE);
87
88 template <class T>
89 class FakeSocketFactory : public IDnsTlsSocketFactory {
90 public:
FakeSocketFactory()91 FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)92 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
93 const DnsTlsServer& server ATTRIBUTE_UNUSED,
94 unsigned mark ATTRIBUTE_UNUSED,
95 IDnsTlsSocketObserver* observer,
96 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
97 return std::make_unique<T>(observer);
98 }
99 };
100
make_echo(uint16_t id,const Slice query)101 bytevec make_echo(uint16_t id, const Slice query) {
102 bytevec response(query.size() + 2);
103 response[0] = id >> 8;
104 response[1] = id;
105 // Echo the query as the fake response.
106 memcpy(response.data() + 2, query.base(), query.size());
107 return response;
108 }
109
110 // Simplest possible fake server. This just echoes the query as the response.
111 class FakeSocketEcho : public IDnsTlsSocket {
112 public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)113 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)114 bool query(uint16_t id, const Slice query) override {
115 // Return the response immediately (asynchronously).
116 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
117 return true;
118 }
startHandshake()119 bool startHandshake() override { return true; }
120
121 private:
122 IDnsTlsSocketObserver* const mObserver;
123 };
124
125 class TransportTest : public BaseTest {};
126
TEST_F(TransportTest, Query) {
    FakeSocketFactory<FakeSocketEcho> echoFactory;
    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);

    // A single query over an echoing socket must round-trip unchanged.
    std::future<DnsTlsTransport::Result> pending = transport.query(makeSlice(QUERY));
    const DnsTlsTransport::Result result = pending.get();

    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    EXPECT_EQ(QUERY, result.response);
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
136
137 // Fake Socket that echoes the observed query ID as the response body.
138 class FakeSocketId : public IDnsTlsSocket {
139 public:
FakeSocketId(IDnsTlsSocketObserver * observer)140 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query ATTRIBUTE_UNUSED)141 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
142 // Return the response immediately (asynchronously).
143 bytevec response(4);
144 // Echo the ID in the header to match the response to the query.
145 // This will be overwritten by DnsTlsQueryMap.
146 response[0] = id >> 8;
147 response[1] = id;
148 // Echo the ID in the body, so that the test can verify which ID was used by
149 // DnsTlsQueryMap.
150 response[2] = id >> 8;
151 response[3] = id;
152 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
153 return true;
154 }
startHandshake()155 bool startHandshake() override { return true; }
156
157 private:
158 IDnsTlsSocketObserver* const mObserver;
159 };
160
161 // Test that IDs are properly reused
TEST_F(TransportTest,IdReuse)162 TEST_F(TransportTest, IdReuse) {
163 FakeSocketFactory<FakeSocketId> factory;
164 DnsTlsTransport transport(SERVER1, MARK, &factory);
165 for (int i = 0; i < 100; ++i) {
166 // Send a query.
167 std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
168 // Wait for the response.
169 DnsTlsTransport::Result r = f.get();
170 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
171
172 // All queries should have an observed ID of zero, because it is returned to the ID pool
173 // after each use.
174 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
175 }
176 EXPECT_EQ(transport.getConnectCounter(), 1);
177 }
178
179 // These queries might be handled in serial or parallel as they race the
180 // responses.
TEST_F(TransportTest,RacingQueries_10000)181 TEST_F(TransportTest, RacingQueries_10000) {
182 FakeSocketFactory<FakeSocketEcho> factory;
183 DnsTlsTransport transport(SERVER1, MARK, &factory);
184 std::vector<std::future<DnsTlsTransport::Result>> results;
185 // Fewer than 65536 queries to avoid ID exhaustion.
186 const int num_queries = 10000;
187 results.reserve(num_queries);
188 for (int i = 0; i < num_queries; ++i) {
189 results.push_back(transport.query(makeSlice(QUERY)));
190 }
191 for (auto& result : results) {
192 auto r = result.get();
193 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
194 EXPECT_EQ(QUERY, r.response);
195 }
196 EXPECT_EQ(transport.getConnectCounter(), 1);
197 }
198
// A server that waits until sDelay queries are queued before responding.
// Behaviour is steered through the static knobs below; tests set them before
// creating the factory.
class FakeSocketDelay : public IDnsTlsSocket {
  public:
    explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    // Restores the static knobs to their defaults so that later tests start clean.
    // Acquiring mLock also waits for an in-progress sendResponses() to finish.
    // NOTE(review): a detached sendResponses() thread that has not yet taken mLock is
    // not waited for and would then use a destroyed |this|; the statics are shared by
    // all instances, so this per-instance mutex does not really guard them.
    ~FakeSocketDelay() {
        std::lock_guard guard(mLock);
        sDelay = 1;
        sReverse = false;
        sConnectable = true;
    }
    // Number of buffered responses that triggers delivery of the whole batch.
    inline static size_t sDelay = 1;
    // When true, a batch is delivered in reverse order of arrival.
    inline static bool sReverse = false;
    // Value returned by startHandshake(); false simulates a failed handshake.
    inline static bool sConnectable = true;

    // Buffers an echo of |query|. When the buffer reaches exactly sDelay entries, a
    // detached thread is started to deliver the batch. Fails the test (EXPECT) if
    // |id| is already outstanding on this socket.
    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
        std::lock_guard guard(mLock);
        // Check for duplicate IDs.
        EXPECT_EQ(0U, mIds.count(id));
        mIds.insert(id);

        // Store response.
        mResponses.push_back(make_echo(id, query));

        LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
        if (mResponses.size() == sDelay) {
            std::thread(&FakeSocketDelay::sendResponses, this).detach();
        }
        return true;
    }
    bool startHandshake() override { return sConnectable; }

  private:
    // Delivers every buffered response to the observer while holding mLock
    // (reversed first if sReverse), then forgets the delivered IDs and responses.
    void sendResponses() {
        std::lock_guard guard(mLock);
        if (sReverse) {
            std::reverse(std::begin(mResponses), std::end(mResponses));
        }
        for (auto& response : mResponses) {
            mObserver->onResponse(response);
        }
        mIds.clear();
        mResponses.clear();
    }

    std::mutex mLock;
    IDnsTlsSocketObserver* const mObserver;
    // IDs of buffered (not yet delivered) queries.
    std::set<uint16_t> mIds GUARDED_BY(mLock);
    // Echo responses waiting to be delivered as one batch.
    std::vector<bytevec> mResponses GUARDED_BY(mLock);
};
249
// Sends a batch of identical queries (same original ID). FakeSocketDelay only
// answers once the whole batch is queued, so all of them are outstanding at once.
TEST_F(TransportTest, ParallelColliding) {
    // Snapshot the batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay to 1,
    // so the loops below must not re-read the static.
    const size_t num_queries = 10;
    FakeSocketDelay::sDelay = num_queries;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    // Fewer than 65536 queries to avoid ID exhaustion.
    results.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& result : results) {
        auto r = result.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(QUERY, r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
268
// Like ParallelColliding, but with a batch that fills the entire 16-bit ID space.
TEST_F(TransportTest, ParallelColliding_Max) {
    // Snapshot the batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay to 1,
    // so the loops below must not re-read the static.
    const size_t num_queries = 65536;
    FakeSocketDelay::sDelay = num_queries;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    // Exactly 65536 queries should still be possible in parallel,
    // even if they all have the same original ID.
    results.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        results.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& result : results) {
        auto r = result.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(QUERY, r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
288
// Sends a batch of queries with distinct original IDs, all outstanding at once, and
// checks that each response is routed back to the query that produced it.
TEST_F(TransportTest, ParallelUnique) {
    // Snapshot the batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay to 1,
    // so the loops below must not re-read the static.
    const size_t num_queries = 10;
    FakeSocketDelay::sDelay = num_queries;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    std::vector<bytevec> queries(num_queries);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        queries[i] = make_query(i, SIZE);
        results.push_back(transport.query(makeSlice(queries[i])));
    }
    for (size_t i = 0; i < num_queries; ++i) {
        auto r = results[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
308
// Like ParallelUnique, but with a batch that fills the entire 16-bit ID space.
TEST_F(TransportTest, ParallelUnique_Max) {
    // Snapshot the batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay to 1,
    // so the loops below must not re-read the static.
    const size_t num_queries = 65536;
    FakeSocketDelay::sDelay = num_queries;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    std::vector<bytevec> queries(num_queries);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    // Exactly 65536 queries should still be possible in parallel,
    // and they should all be mapped correctly back to the original ID.
    results.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        queries[i] = make_query(i, SIZE);
        results.push_back(transport.query(makeSlice(queries[i])));
    }
    for (size_t i = 0; i < num_queries; ++i) {
        auto r = results[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
330
TEST_F(TransportTest, IdExhaustion) {
    constexpr int num_queries = 65536;
    // A delay of 65537 is unreachable, because the maximum number
    // of outstanding queries is 65536.
    FakeSocketDelay::sDelay = num_queries + 1;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Issue the maximum number of queries.
    std::vector<std::future<DnsTlsTransport::Result>> outstanding;
    outstanding.reserve(num_queries);
    for (int n = 0; n < num_queries; ++n) {
        outstanding.push_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so subsequent queries should fail immediately.
    const DnsTlsTransport::Result overflow = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, overflow.code);
    EXPECT_TRUE(overflow.response.empty());

    // All other queries should remain outstanding: polling them must time out.
    const auto noWait = std::chrono::duration<int>::zero();
    for (std::future<DnsTlsTransport::Result>& f : outstanding) {
        EXPECT_EQ(std::future_status::timeout, f.wait_for(noWait));
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
358
359 // Responses can come back from the server in any order. This should have no
360 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)361 TEST_F(TransportTest, ReverseOrder) {
362 FakeSocketDelay::sDelay = 10;
363 FakeSocketDelay::sReverse = true;
364 FakeSocketFactory<FakeSocketDelay> factory;
365 DnsTlsTransport transport(SERVER1, MARK, &factory);
366 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
367 std::vector<std::future<DnsTlsTransport::Result>> results;
368 results.reserve(FakeSocketDelay::sDelay);
369 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
370 queries[i] = make_query(i, SIZE);
371 results.push_back(transport.query(makeSlice(queries[i])));
372 }
373 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
374 auto r = results[i].get();
375 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
376 EXPECT_EQ(queries[i], r.response);
377 }
378 EXPECT_EQ(transport.getConnectCounter(), 1);
379 }
380
// Like ReverseOrder, but with a batch that fills the entire 16-bit ID space.
TEST_F(TransportTest, ReverseOrder_Max) {
    // Snapshot the batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay to 1,
    // so the loops below must not re-read the static.
    const size_t num_queries = 65536;
    FakeSocketDelay::sDelay = num_queries;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    std::vector<bytevec> queries(num_queries);
    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        queries[i] = make_query(i, SIZE);
        results.push_back(transport.query(makeSlice(queries[i])));
    }
    for (size_t i = 0; i < num_queries; ++i) {
        auto r = results[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
400
401 // Returning null from the factory indicates a connection failure.
402 class NullSocketFactory : public IDnsTlsSocketFactory {
403 public:
NullSocketFactory()404 NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)405 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
406 const DnsTlsServer& server ATTRIBUTE_UNUSED,
407 unsigned mark ATTRIBUTE_UNUSED,
408 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
409 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
410 return nullptr;
411 }
412 };
413
TEST_F(TransportTest, ConnectFail) {
    // Failure on creating socket.
    NullSocketFactory nullFactory;
    DnsTlsTransport noSocketTransport(SERVER1, MARK, &nullFactory);
    const DnsTlsTransport::Result noSocket = noSocketTransport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, noSocket.code);
    EXPECT_TRUE(noSocket.response.empty());
    EXPECT_EQ(noSocketTransport.getConnectCounter(), 1);

    // Failure on handshaking. ~FakeSocketDelay restores sConnectable afterwards.
    FakeSocketDelay::sConnectable = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport noHandshakeTransport(SERVER1, MARK, &delayFactory);
    const DnsTlsTransport::Result noHandshake =
            noHandshakeTransport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, noHandshake.code);
    EXPECT_TRUE(noHandshake.response.empty());
    EXPECT_EQ(noHandshakeTransport.getConnectCounter(), 1);
}
434
435 // Simulate a socket that connects but then immediately receives a server
436 // close notification.
437 class FakeSocketClose : public IDnsTlsSocket {
438 public:
FakeSocketClose(IDnsTlsSocketObserver * observer)439 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
440 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()441 ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)442 bool query(uint16_t id ATTRIBUTE_UNUSED,
443 const Slice query ATTRIBUTE_UNUSED) override {
444 return true;
445 }
startHandshake()446 bool startHandshake() override { return true; }
447
448 private:
449 std::thread mCloser;
450 };
451
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> closingFactory;
    DnsTlsTransport transport(SERVER1, MARK, &closingFactory);

    // Every connection closes straight away, so the query must end in a network error.
    const DnsTlsTransport::Result result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());

    // Reconnections might be triggered depending on the flag.
    EXPECT_EQ(transport.getConnectCounter(),
              Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
}
464
// Simulate a server that occasionally closes the connection and silently
// drops some queries.
class FakeSocketLimited : public IDnsTlsSocket {
  public:
    static int sLimit;  // Number of queries to answer per socket.
    static size_t sMaxSize;  // Silently discard queries greater than this size.
    explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
        : mObserver(observer), mQueries(0) {}
    // Joins every pending response thread while holding mLock. After releasing the
    // lock, it joins the closer thread if query() started one.
    // NOTE(review): mCloser is annotated GUARDED_BY(mLock) but is read below without
    // holding mLock.
    ~FakeSocketLimited() {
        {
            LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "~FakeSocketLimited joining response thread";
                thread.join();
                LOG(DEBUG) << "~FakeSocketLimited joined response thread";
            }
            mThreads.clear();
        }

        if (mCloser) {
            LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
            mCloser->join();
            LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
        }
    }
    // Counts every query sent to this socket. Among the first sLimit queries, those
    // no larger than sMaxSize are echoed on a new thread and larger ones are dropped
    // silently. The sLimit-th query also starts the sendClose() thread.
    // Returns false once more than sLimit queries have been sent.
    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
        std::lock_guard guard(mLock);
        LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
        ++mQueries;

        if (mQueries <= sLimit) {
            LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
            if (query.size() <= sMaxSize) {
                // Return the response immediately (asynchronously).
                mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
            }
        }
        if (mQueries == sLimit) {
            mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
        }
        return mQueries <= sLimit;
    }
    bool startHandshake() override { return true; }

  private:
    // Waits, under mLock, for every response thread started so far. It then releases
    // the lock and tells the observer that the connection was closed.
    void sendClose() {
        {
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
                thread.join();
                LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
            }
            mThreads.clear();
        }
        mObserver->onClosed();
    }
    std::mutex mLock;
    IDnsTlsSocketObserver* const mObserver;
    // Number of queries seen by this socket, including rejected ones.
    int mQueries GUARDED_BY(mLock);
    // Threads delivering echo responses; joined by sendClose() or the destructor.
    std::vector<std::thread> mThreads GUARDED_BY(mLock);
    // Thread running sendClose(); null until the sLimit-th query.
    std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
};
533
534 int FakeSocketLimited::sLimit;
535 size_t FakeSocketLimited::sMaxSize;
536
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;   // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 10 queries. They will all be ignored, and after the 10th,
    // the socket will close. Transport will retry them all, until they
    // all hit the retry limit and expire.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketLimited::sLimit);
    for (int sent = 0; sent < FakeSocketLimited::sLimit; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }
    for (std::future<DnsTlsTransport::Result>& f : pending) {
        const DnsTlsTransport::Result r = f.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
        EXPECT_TRUE(r.response.empty());
    }

    // Reconnections might be triggered depending on the flag.
    EXPECT_EQ(transport.getConnectCounter(),
              Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
}
561
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;          // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 100 queries, alternating "short" which will be served and "long"
    // which will be dropped.
    const int num_queries = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> queries;
    queries.reserve(num_queries);  // No reallocation, so each query buffer stays put.
    std::vector<std::future<DnsTlsTransport::Result>> results;
    results.reserve(num_queries);
    for (int i = 0; i < num_queries; ++i) {
        queries.push_back(make_query(i, SIZE + (i % 2)));
        results.push_back(transport.query(makeSlice(queries.back())));
    }

    // Just check the short queries, which are at the even indices.
    for (int i = 0; i < num_queries; i += 2) {
        const DnsTlsTransport::Result r = results[i].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
        EXPECT_EQ(queries[i], r.response);
    }

    // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
    // connect attempts for this test.
}
588
TEST_F(TransportTest, ConnectCounter) {
    FakeSocketLimited::sLimit = 2;       // Close the socket after 2 queries.
    FakeSocketLimited::sMaxSize = SIZE;  // No query drops.
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Connecting on demand: nothing has been sent yet.
    EXPECT_EQ(transport.getConnectCounter(), 0);

    // Reconnections take place every two queries.
    const int num_queries = 10;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(num_queries);
    for (int sent = 0; sent < num_queries; sent++) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }
    for (std::future<DnsTlsTransport::Result>& f : pending) {
        EXPECT_EQ(DnsTlsTransport::Response::success, f.get().code);
    }

    EXPECT_EQ(transport.getConnectCounter(), num_queries / FakeSocketLimited::sLimit);
}
612
613 // Simulate a malfunctioning server that injects extra miscellaneous
614 // responses to queries that were not asked. This will cause wrong answers but
615 // must not crash the Transport.
616 class FakeSocketGarbage : public IDnsTlsSocket {
617 public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)618 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
619 // Inject a garbage event.
620 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
621 }
~FakeSocketGarbage()622 ~FakeSocketGarbage() {
623 std::lock_guard guard(mLock);
624 for (auto& thread : mThreads) {
625 thread.join();
626 }
627 }
query(uint16_t id,const Slice query)628 bool query(uint16_t id, const Slice query) override {
629 std::lock_guard guard(mLock);
630 // Return the response twice.
631 auto echo = make_echo(id, query);
632 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
633 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
634 // Also return some other garbage
635 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
636 return true;
637 }
startHandshake()638 bool startHandshake() override { return true; }
639
640 private:
641 std::mutex mLock;
642 std::vector<std::thread> mThreads GUARDED_BY(mLock);
643 IDnsTlsSocketObserver* const mObserver;
644 };
645
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> garbageFactory;
    DnsTlsTransport transport(SERVER1, MARK, &garbageFactory);

    for (int attempt = 0; attempt < 10; ++attempt) {
        // Don't check the response because this server is malfunctioning; only the
        // status is expected to be success.
        const DnsTlsTransport::Result r = transport.query(makeSlice(QUERY)).get();
        EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
657
658 // Dispatcher tests
659 class DispatcherTest : public BaseTest {};
660
// A query through the dispatcher opens a connection and returns the echoed answer;
// a second query to the same server reuses that connection.
TEST_F(DispatcherTest, Query) {
    bytevec ans(4096);
    int resplen = 0;
    bool connectTriggered = false;

    auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
    DnsTlsDispatcher dispatcher(std::move(factory));
    auto r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(ans), &resplen,
                              &connectTriggered);

    EXPECT_EQ(DnsTlsTransport::Response::success, r);
    EXPECT_TRUE(connectTriggered);
    // ASSERT (not EXPECT): a bogus resplen would make the resize below blow up.
    ASSERT_EQ(int(QUERY.size()), resplen);
    ans.resize(resplen);
    EXPECT_EQ(QUERY, ans);

    // Expect to reuse the connection. Use a fresh full-size buffer: |ans| was shrunk
    // to the first answer above, so reusing it would only work because the echo
    // happens to be exactly that long.
    bytevec ans2(4096);
    resplen = 0;
    r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(ans2), &resplen,
                         &connectTriggered);
    EXPECT_EQ(DnsTlsTransport::Response::success, r);
    EXPECT_FALSE(connectTriggered);
    ASSERT_EQ(int(QUERY.size()), resplen);
    ans2.resize(resplen);
    EXPECT_EQ(QUERY, ans2);
}
683
TEST_F(DispatcherTest, AnswerTooLarge) {
    // One byte too small to hold the echoed answer.
    bytevec tooSmall(SIZE - 1);
    int resplen = 0;
    bool connectTriggered = false;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto r = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(tooSmall),
                                    &resplen, &connectTriggered);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
    EXPECT_TRUE(connectTriggered);
}
697
698 template<class T>
699 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
700 public:
TrackingFakeSocketFactory()701 TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)702 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
703 const DnsTlsServer& server,
704 unsigned mark,
705 IDnsTlsSocketObserver* observer,
706 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
707 std::lock_guard guard(mLock);
708 keys.emplace(mark, server);
709 return std::make_unique<T>(observer);
710 }
711 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
712
713 private:
714 std::mutex mLock;
715 };
716
// Concurrent queries to several (mark, server) keys must each succeed, and the
// dispatcher must create exactly one socket per key.
TEST_F(DispatcherTest, Dispatching) {
    // Snapshot the per-key batch size: ~FakeSocketDelay resets FakeSocketDelay::sDelay
    // to 1, so the loop bound below must not re-read the static.
    const size_t queries_per_key = 5;
    FakeSocketDelay::sDelay = queries_per_key;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Populate a vector of two servers and two socket marks, four combinations
    // in total.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // Do several queries on each server. They should all succeed.
    const size_t num_queries = queries_per_key * keys.size();
    std::vector<std::thread> threads;
    threads.reserve(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
        auto key = keys[i % keys.size()];
        threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
            auto q = make_query(i, SIZE);
            bytevec ans(4096);
            int resplen = 0;
            bool connectTriggered = false;
            unsigned mark = key.first;
            unsigned netId = key.first;
            const DnsTlsServer& server = key.second;
            auto r = dispatcher->query(server, netId, mark, makeSlice(q), makeSlice(ans), &resplen,
                                       &connectTriggered);
            EXPECT_EQ(DnsTlsTransport::Response::success, r);
            // ASSERT (not EXPECT): a bogus resplen would make the resize below blow up.
            ASSERT_EQ(int(q.size()), resplen);
            ans.resize(resplen);
            EXPECT_EQ(q, ans);
        }, &dispatcher);
    }
    for (auto& thread : threads) {
        thread.join();
    }
    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), weak_factory->keys.size());
    for (auto& key : keys) {
        EXPECT_EQ(1U, weak_factory->keys.count(key));
    }
}
761
762 // Check DnsTlsServer's comparison logic.
763 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)764 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
765 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
766 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
767 EXPECT_FALSE(cmp1 && cmp2);
768 return !cmp1 && !cmp2;
769 }
770
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)771 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
772 EXPECT_TRUE(s1 == s1);
773 EXPECT_TRUE(s2 == s2);
774 EXPECT_TRUE(isAddressEqual(s1, s1));
775 EXPECT_TRUE(isAddressEqual(s2, s2));
776
777 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
778 EXPECT_FALSE(s1 == s2);
779 EXPECT_FALSE(s2 == s1);
780 }
781
checkEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)782 void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
783 EXPECT_TRUE(s1 == s1);
784 EXPECT_TRUE(s2 == s2);
785 EXPECT_TRUE(isAddressEqual(s1, s1));
786 EXPECT_TRUE(isAddressEqual(s2, s2));
787
788 EXPECT_FALSE(s1 < s2);
789 EXPECT_FALSE(s2 < s1);
790 EXPECT_TRUE(s1 == s2);
791 EXPECT_TRUE(s2 == s1);
792 }
793
794 class ServerTest : public BaseTest {};
795
TEST_F(ServerTest, IPv4) {
    const DnsTlsServer first(V4ADDR1);
    const DnsTlsServer second(V4ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
800
TEST_F(ServerTest, IPv6) {
    const DnsTlsServer first(V6ADDR1);
    const DnsTlsServer second(V6ADDR2);
    checkUnequal(first, second);
    EXPECT_FALSE(isAddressEqual(first, second));
}
805
TEST_F(ServerTest, MixedAddressFamily) {
    const DnsTlsServer v6Server(V6ADDR1);
    const DnsTlsServer v4Server(V4ADDR1);
    checkUnequal(v6Server, v4Server);
    EXPECT_FALSE(isAddressEqual(v6Server, v4Server));
}
810
TEST_F(ServerTest, IPv6ScopeId) {
    // The same link-local address on two different scopes.
    const IPAddress scope1 = IPAddress::forString("fe80::1%1");
    const IPAddress scope2 = IPAddress::forString("fe80::1%2");
    DnsTlsServer s1(scope1);
    DnsTlsServer s2(scope2);

    checkUnequal(s1, s2);
    EXPECT_FALSE(isAddressEqual(s1, s2));

    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
}
820
// Servers on the same IP but different ports are distinct under operator==/operator<, yet
// address-equal, and toIpString() reports only the IP. Covered for both IPv4 and IPv6.
TEST_F(ServerTest, Port) {
    DnsTlsServer s1(IPSockAddr::toIPSockAddr("192.0.2.1", 853));
    DnsTlsServer s2(IPSockAddr::toIPSockAddr("192.0.2.1", 854));
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));
    EXPECT_EQ(s1.toIpString(), "192.0.2.1");
    EXPECT_EQ(s2.toIpString(), "192.0.2.1");

    DnsTlsServer s3(IPSockAddr::toIPSockAddr("2001:db8::1", 853));
    DnsTlsServer s4(IPSockAddr::toIPSockAddr("2001:db8::1", 854));
    checkUnequal(s3, s4);
    EXPECT_TRUE(isAddressEqual(s3, s4));
    EXPECT_EQ(s3.toIpString(), "2001:db8::1");
    EXPECT_EQ(s4.toIpString(), "2001:db8::1");

    // None of these servers was given a name. Like the other unnamed servers in this file,
    // none of them should report being explicitly configured. The IPv6 pair is checked too,
    // so both address families get the same coverage.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
    EXPECT_FALSE(s3.wasExplicitlyConfigured());
    EXPECT_FALSE(s4.wasExplicitlyConfigured());
}
839
// Servers with the same address but different (or missing) names are distinct yet
// address-equal. Once both carry a name, both report being explicitly configured.
TEST_F(ServerTest, Name) {
    DnsTlsServer named(V4ADDR1);
    DnsTlsServer other(V4ADDR1);

    // One named server versus one unnamed server.
    named.name = SERVERNAME1;
    checkUnequal(named, other);

    // Two servers with different names.
    other.name = SERVERNAME2;
    checkUnequal(named, other);
    EXPECT_TRUE(isAddressEqual(named, other));

    EXPECT_TRUE(named.wasExplicitlyConfigured());
    EXPECT_TRUE(other.wasExplicitlyConfigured());
}
851
// Validation state and the active flag do not affect DnsTlsServer comparison. Servers that
// differ only in those fields stay equal, while the setters still take effect.
TEST_F(ServerTest, State) {
    DnsTlsServer lhs(V4ADDR1);
    DnsTlsServer rhs(V4ADDR1);
    checkEqual(lhs, rhs);

    // Diverge on validation state.
    lhs.setValidationState(Validation::success);
    checkEqual(lhs, rhs);
    rhs.setValidationState(Validation::fail);
    checkEqual(lhs, rhs);

    // Diverge on the active flag.
    lhs.setActive(true);
    checkEqual(lhs, rhs);
    rhs.setActive(false);
    checkEqual(lhs, rhs);

    // The mutations above really were applied.
    EXPECT_EQ(lhs.validationState(), Validation::success);
    EXPECT_EQ(rhs.validationState(), Validation::fail);
    EXPECT_TRUE(lhs.active());
    EXPECT_FALSE(rhs.active());
}
869
870 class QueryMapTest : public NetNativeTestBase {};
871
// Records three queries, answers them out of order, and checks each resolved future:
//  - its first two bytes (the ID) match the original query's ID;
//  - the rest of the message matches the answer.
TEST_F(QueryMapTest, Basic) {
    DnsTlsQueryMap map;

    EXPECT_TRUE(map.empty());

    bytevec q0 = make_query(999, SIZE);
    bytevec q1 = make_query(888, SIZE);
    bytevec q2 = make_query(777, SIZE);

    auto f0 = map.recordQuery(makeSlice(q0));
    auto f1 = map.recordQuery(makeSlice(q1));
    auto f2 = map.recordQuery(makeSlice(q2));

    // recordQuery can return null (e.g. when the map is full); stop rather than
    // dereference a null future below.
    ASSERT_TRUE(f0);
    ASSERT_TRUE(f1);
    ASSERT_TRUE(f2);

    // Check return values of recordQuery: new IDs are assigned sequentially from 0.
    EXPECT_EQ(0, f0->query.newId);
    EXPECT_EQ(1, f1->query.newId);
    EXPECT_EQ(2, f2->query.newId);

    // Check side effects of recordQuery
    EXPECT_FALSE(map.empty());

    auto all = map.getAll();
    // Fatal: all[0..2] are indexed immediately below.
    ASSERT_EQ(3U, all.size());

    EXPECT_EQ(0, all[0].newId);
    EXPECT_EQ(1, all[1].newId);
    EXPECT_EQ(2, all[2].newId);

    EXPECT_EQ(q0, all[0].query);
    EXPECT_EQ(q1, all[1].query);
    EXPECT_EQ(q2, all[2].query);

    bytevec a0 = make_query(0, SIZE);
    bytevec a1 = make_query(1, SIZE);
    bytevec a2 = make_query(2, SIZE);

    // Return responses out of order
    map.onResponse(a2);
    map.onResponse(a0);
    map.onResponse(a1);

    EXPECT_TRUE(map.empty());

    auto r0 = f0->result.get();
    auto r1 = f1->result.get();
    auto r2 = f2->result.get();

    EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
    EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
    EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);

    const bytevec& d0 = r0.response;
    const bytevec& d1 = r1.response;
    const bytevec& d2 = r2.response;

    // The checks below read bytes [0] and [1] and then the body from offset 2. A short
    // (e.g. failed) response would make that out-of-bounds, so stop first.
    ASSERT_GE(d0.size(), 2U);
    ASSERT_GE(d1.size(), 2U);
    ASSERT_GE(d2.size(), 2U);

    // The ID should match the query
    EXPECT_EQ(999, d0[0] << 8 | d0[1]);
    EXPECT_EQ(888, d1[0] << 8 | d1[1]);
    EXPECT_EQ(777, d2[0] << 8 | d2[1]);
    // The body should match the answer
    EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
    EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
    EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
}
936
// Fills all 65536 query IDs and checks the map then rejects another query. Answers query
// 40000 to free its ID, checks the next recorded query reuses exactly that ID, and checks
// the map is full again afterwards.
TEST_F(QueryMapTest, FillHole) {
    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
        futures[i] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[i]);  // futures[i] should be nonnull.
        EXPECT_EQ(i, futures[i]->query.newId);
    }

    // The map should now be full.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());

    // Trying to add another query should fail because the map is full.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Send an answer to query 40000
    auto answer = make_query(40000, SIZE);
    map.onResponse(answer);
    auto result = futures[40000]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    // The checks below read bytes [0] and [1] and then the body from offset 2; stop rather
    // than index out of bounds if the response is shorter than that.
    ASSERT_GE(result.response.size(), 2U);
    // The restored ID should be the original query's ID (the file-level ID constant,
    // presumably QUERY's ID), not the mapped ID 40000.
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // There should now be room in the map.
    EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    auto f = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(f);
    EXPECT_EQ(40000, f->query.newId);

    // The map should now be full again.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
971
972 class DnsTlsSocketTest : public NetNativeTestBase {
973 protected:
974 class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
975 public:
976 MOCK_METHOD(void, onClosed, (), (override));
977 MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
978 };
979
makeDnsTlsSocket(IDnsTlsSocketObserver * observer)980 std::unique_ptr<DnsTlsSocket> makeDnsTlsSocket(IDnsTlsSocketObserver* observer) {
981 return std::make_unique<DnsTlsSocket>(this->server, MARK, observer, &this->cache);
982 }
983
enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket> & socket)984 void enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket>& socket) {
985 ASSERT_TRUE(socket);
986 DnsTlsSocket* delegate = socket.get();
987 std::lock_guard guard(delegate->mLock);
988 delegate->mAsyncHandshake = true;
989 }
990
991 static constexpr char kTlsAddr[] = "127.0.0.3";
992 static constexpr char kTlsPort[] = "8530"; // High-numbered port so root isn't required.
993 static constexpr char kBackendAddr[] = "192.0.2.1";
994 static constexpr char kBackendPort[] = "8531"; // High-numbered port so root isn't required.
995
996 test::DnsTlsFrontend tls{kTlsAddr, kTlsPort, kBackendAddr, kBackendPort};
997
998 const DnsTlsServer server{IPSockAddr::toIPSockAddr(kTlsAddr, std::stoi(kTlsPort))};
999 DnsTlsSessionCache cache;
1000 };
1001
// Destroying a socket whose handshake has started should return promptly, not stall until
// a timeout.
TEST_F(DnsTlsSocketTest, SlowDestructor) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);
    ASSERT_TRUE(socket->initialize());
    ASSERT_TRUE(socket->startHandshake());

    // Test: Time the socket destructor. This should be fast.
    const auto start = std::chrono::steady_clock::now();
    EXPECT_CALL(observer, onClosed);
    socket.reset();
    const auto elapsed = std::chrono::steady_clock::now() - start;
    LOG(DEBUG) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";

    // Shutdown should complete in milliseconds. If the shutdown signal is lost, it will wait
    // for the timeout instead, which is expected to take 20 seconds.
    EXPECT_LT(elapsed, std::chrono::seconds{5});
}
1022
// initialize() and startHandshake() each succeed only once, and only in that order.
TEST_F(DnsTlsSocketTest, StartHandshake) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);

    // Out of order: a handshake before initialize() is rejected.
    EXPECT_FALSE(socket->startHandshake());

    // In order: both calls succeed.
    EXPECT_TRUE(socket->initialize());
    EXPECT_TRUE(socket->startHandshake());

    // Repeated calls are rejected.
    EXPECT_FALSE(socket->initialize());
    EXPECT_FALSE(socket->startHandshake());

    // onClosed should fire when |socket| is destroyed at scope exit and its loop thread is
    // joined. |observer| is declared first, so it outlives |socket|.
    EXPECT_CALL(observer, onClosed);
}
1043
// A socket in the middle of an asynchronous handshake must still shut down within a second
// when destroyed, whether or not queries are pending. The frontend is configured with
// setHangOnHandshakeForTesting(true), presumably so the handshake never completes.
TEST_F(DnsTlsSocketTest, ShutdownSignal) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    std::unique_ptr<DnsTlsSocket> socket;

    // (Re)creates |socket|, enables the async handshake, and starts it.
    const auto setupAndStartHandshake = [&]() {
        socket = makeDnsTlsSocket(&observer);
        EXPECT_TRUE(socket->initialize());
        enableAsyncHandshake(socket);
        EXPECT_TRUE(socket->startHandshake());
    };

    // Destroys |socket|, expecting exactly one onClosed callback and a sub-second teardown.
    const auto triggerShutdown = [&](const std::string& traceLog) {
        SCOPED_TRACE(traceLog);
        const auto start = std::chrono::steady_clock::now();
        EXPECT_CALL(observer, onClosed);
        socket.reset();
        const auto elapsed = std::chrono::steady_clock::now() - start;
        LOG(INFO) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";
        EXPECT_LT(elapsed, std::chrono::seconds{1});
    };

    tls.setHangOnHandshakeForTesting(true);

    // Test 1: Reset the DnsTlsSocket while it is doing the handshake.
    setupAndStartHandshake();
    triggerShutdown("Shutdown handshake w/o query requests");

    // Test 2: Same, but with query requests queued on the socket.
    setupAndStartHandshake();

    // DnsTlsSocket doesn't report the status of pending queries. The decision whether to mark
    // a query request as failed or not is made in DnsTlsTransport.
    EXPECT_CALL(observer, onResponse).Times(0);
    EXPECT_TRUE(socket->query(1, makeSlice(QUERY)));
    EXPECT_TRUE(socket->query(2, makeSlice(QUERY)));
    triggerShutdown("Shutdown handshake w/ query requests");
}
1083
1084 } // end of namespace net
1085 } // end of namespace android
1086