1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #define LOG_TAG "resolv"
19 
20 #include "DnsStats.h"
21 
22 #include <android-base/format.h>
23 #include <android-base/logging.h>
24 
25 namespace android::net {
26 
27 using netdutils::DumpWriter;
28 using netdutils::IPAddress;
29 using netdutils::IPSockAddr;
30 using netdutils::ScopedIndent;
31 using std::chrono::duration_cast;
32 using std::chrono::microseconds;
33 using std::chrono::milliseconds;
34 using std::chrono::seconds;
35 
36 namespace {
37 
38 static constexpr IPAddress INVALID_IPADDRESS = IPAddress();
39 
rcodeToName(int rcode)40 std::string rcodeToName(int rcode) {
41     // clang-format off
42     switch (rcode) {
43         case NS_R_NO_ERROR: return "NOERROR";
44         case NS_R_FORMERR: return "FORMERR";
45         case NS_R_SERVFAIL: return "SERVFAIL";
46         case NS_R_NXDOMAIN: return "NXDOMAIN";
47         case NS_R_NOTIMPL: return "NOTIMP";
48         case NS_R_REFUSED: return "REFUSED";
49         case NS_R_YXDOMAIN: return "YXDOMAIN";
50         case NS_R_YXRRSET: return "YXRRSET";
51         case NS_R_NXRRSET: return "NXRRSET";
52         case NS_R_NOTAUTH: return "NOTAUTH";
53         case NS_R_NOTZONE: return "NOTZONE";
54         case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR";
55         case NS_R_TIMEOUT: return "TIMEOUT";
56         default: return fmt::format("UNKNOWN({})", rcode);
57     }
58     // clang-format on
59 }
60 
ensureNoInvalidIp(const std::vector<IPSockAddr> & addrs)61 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& addrs) {
62     for (const auto& addr : addrs) {
63         if (addr.ip() == INVALID_IPADDRESS || addr.port() == 0) {
64             LOG(WARNING) << "Invalid addr: " << addr;
65             return false;
66         }
67     }
68     return true;
69 }
70 
71 }  // namespace
72 
73 // The comparison ignores the last update time.
operator ==(const StatsData & o) const74 bool StatsData::operator==(const StatsData& o) const {
75     return std::tie(sockAddr, total, rcodeCounts, latencyUs) ==
76            std::tie(o.sockAddr, o.total, o.rcodeCounts, o.latencyUs);
77 }
78 
averageLatencyMs() const79 int StatsData::averageLatencyMs() const {
80     return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
81 }
82 
toString() const83 std::string StatsData::toString() const {
84     if (total == 0) return fmt::format("{} <no data>", sockAddr.toString());
85 
86     const auto now = std::chrono::steady_clock::now();
87     const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
88     std::string buf;
89     for (const auto& [rcode, counts] : rcodeCounts) {
90         if (counts != 0) {
91             buf += fmt::format("{}:{} ", rcodeToName(rcode), counts);
92         }
93     }
94     return fmt::format("{} ({}, {}ms, [{}], {}s)", sockAddr.toString(), total, averageLatencyMs(),
95                        buf, lastUpdateSec);
96 }
97 
StatsRecords(const IPSockAddr & ipSockAddr,size_t size)98 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
99     : mCapacity(size), mStatsData(ipSockAddr) {}
100 
push(const Record & record)101 void StatsRecords::push(const Record& record) {
102     updateStatsData(record, true);
103     mRecords.push_back(record);
104 
105     if (mRecords.size() > mCapacity) {
106         updateStatsData(mRecords.front(), false);
107         mRecords.pop_front();
108     }
109 
110     // Update the quality factors.
111     mSkippedCount = 0;
112 
113     // Because failures due to no permission can't prove that the quality of DNS server is bad,
114     // skip the penalty update. The average latency, however, has been updated. For short-latency
115     // servers, it will be fine. For long-latency servers, their average latency will be
116     // decreased but the latency-based algorithm will adjust their average latency back to the
117     // right range after few attempts when network is not restricted.
118     // The check is synced from isNetworkRestricted() in res_send.cpp.
119     if (record.linux_errno != EPERM) {
120         updatePenalty(record);
121     }
122 }
123 
updateStatsData(const Record & record,const bool add)124 void StatsRecords::updateStatsData(const Record& record, const bool add) {
125     const int rcode = record.rcode;
126     if (add) {
127         mStatsData.total += 1;
128         mStatsData.rcodeCounts[rcode] += 1;
129         mStatsData.latencyUs += record.latencyUs;
130     } else {
131         mStatsData.total -= 1;
132         mStatsData.rcodeCounts[rcode] -= 1;
133         mStatsData.latencyUs -= record.latencyUs;
134     }
135     mStatsData.lastUpdate = std::chrono::steady_clock::now();
136 }
137 
updatePenalty(const Record & record)138 void StatsRecords::updatePenalty(const Record& record) {
139     switch (record.rcode) {
140         case NS_R_NO_ERROR:
141         case NS_R_NXDOMAIN:
142         case NS_R_NOTAUTH:
143             mPenalty = 0;
144             return;
145         default:
146             // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
147             if (mPenalty == 0) {
148                 mPenalty = 100;
149             } else {
150                 // The evaluated quality drops more quickly when continuous failures happen.
151                 mPenalty = std::min(mPenalty * 2, kMaxQuality);
152             }
153             return;
154     }
155 }
156 
score() const157 double StatsRecords::score() const {
158     const int avgRtt = mStatsData.averageLatencyMs();
159 
160     // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
161     //   1) when the server doesn't have any stats yet.
162     //   2) when the sorting has been disabled while it was enabled before.
163     int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
164 
165     // Normalization.
166     return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
167 }
168 
incrementSkippedCount()169 void StatsRecords::incrementSkippedCount() {
170     mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
171 }
172 
setAddrs(const std::vector<netdutils::IPSockAddr> & addrs,Protocol protocol)173 bool DnsStats::setAddrs(const std::vector<netdutils::IPSockAddr>& addrs, Protocol protocol) {
174     if (!ensureNoInvalidIp(addrs)) return false;
175 
176     StatsMap& statsMap = mStats[protocol];
177     for (const auto& addr : addrs) {
178         statsMap.try_emplace(addr, StatsRecords(addr, kLogSize));
179     }
180 
181     // Clean up the map to eliminate the nodes not belonging to the given list of servers.
182     const auto cleanup = [&](StatsMap* statsMap) {
183         StatsMap tmp;
184         for (const auto& addr : addrs) {
185             if (statsMap->find(addr) != statsMap->end()) {
186                 tmp.insert(statsMap->extract(addr));
187             }
188         }
189         statsMap->swap(tmp);
190     };
191 
192     cleanup(&statsMap);
193 
194     return true;
195 }
196 
addStats(const IPSockAddr & ipSockAddr,const DnsQueryEvent & record)197 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
198     if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
199 
200     bool added = false;
201     for (auto& [sockAddr, statsRecords] : mStats[record.protocol()]) {
202         if (sockAddr == ipSockAddr) {
203             const StatsRecords::Record rec = {
204                     .rcode = record.rcode(),
205                     .linux_errno = record.linux_errno(),
206                     .latencyUs = microseconds(record.latency_micros()),
207             };
208             statsRecords.push(rec);
209             added = true;
210         } else {
211             statsRecords.incrementSkippedCount();
212         }
213     }
214 
215     return added;
216 }
217 
getSortedServers(Protocol protocol) const218 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
219     // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
220     // while. Need to figure out if it is worth doing for DoT servers.
221     if (protocol == PROTO_DOT) return {};
222 
223     auto it = mStats.find(protocol);
224     if (it == mStats.end()) return {};
225 
226     // Sorting on insertion in decreasing order.
227     std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
228     for (const auto& [ip, statsRecords] : it->second) {
229         sortedData.insert({statsRecords.score(), ip});
230     }
231 
232     std::vector<IPSockAddr> ret;
233     ret.reserve(sortedData.size());
234     for (auto& [_, v] : sortedData) {
235         ret.push_back(v);  // IPSockAddr is trivially-copyable.
236     }
237 
238     return ret;
239 }
240 
getAverageLatencyUs(Protocol protocol) const241 std::optional<microseconds> DnsStats::getAverageLatencyUs(Protocol protocol) const {
242     const auto stats = getStats(protocol);
243 
244     int count = 0;
245     microseconds sum;
246     for (const auto& v : stats) {
247         count += v.total;
248         sum += v.latencyUs;
249     }
250 
251     if (count == 0) return std::nullopt;
252     return sum / count;
253 }
254 
getStats(Protocol protocol) const255 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
256     std::vector<StatsData> ret;
257 
258     if (mStats.find(protocol) != mStats.end()) {
259         for (const auto& [_, statsRecords] : mStats.at(protocol)) {
260             ret.push_back(statsRecords.getStatsData());
261         }
262     }
263     return ret;
264 }
265 
dump(DumpWriter & dw)266 void DnsStats::dump(DumpWriter& dw) {
267     const auto dumpStatsMap = [&](StatsMap& statsMap) {
268         ScopedIndent indentLog(dw);
269         if (statsMap.size() == 0) {
270             dw.println("<no data>");
271             return;
272         }
273         for (const auto& [_, statsRecords] : statsMap) {
274             const StatsData& data = statsRecords.getStatsData();
275             std::string str =
276                     fmt::format("{} score{{{:.1f}}}", data.toString(), statsRecords.score());
277             dw.println("%s", str.c_str());
278         }
279     };
280 
281     dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)");
282     ScopedIndent indentStats(dw);
283 
284     dw.println("over UDP");
285     dumpStatsMap(mStats[PROTO_UDP]);
286 
287     dw.println("over DOH");
288     dumpStatsMap(mStats[PROTO_DOH]);
289 
290     dw.println("over TLS");
291     dumpStatsMap(mStats[PROTO_DOT]);
292 
293     dw.println("over TCP");
294     dumpStatsMap(mStats[PROTO_TCP]);
295 
296     dw.println("over MDNS");
297     dumpStatsMap(mStats[PROTO_MDNS]);
298 }
299 
300 }  // namespace android::net
301