1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Netd"
18 
19 #include "SockDiag.h"
20 
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31 
32 #include <cinttypes>
33 
34 #include <android-base/properties.h>
35 #include <android-base/stringprintf.h>
36 #include <android-base/strings.h>
37 #include <log/log.h>
38 #include <netdutils/InternetAddresses.h>
39 #include <netdutils/Stopwatch.h>
40 
41 #include "Permission.h"
42 
43 #ifndef SOCK_DESTROY
44 #define SOCK_DESTROY 21
45 #endif
46 
47 #define INET_DIAG_BC_MARK_COND 10
48 
49 namespace android {
50 
51 using android::base::StringPrintf;
52 using netdutils::ScopedAddrinfo;
53 using netdutils::Stopwatch;
54 
55 namespace net {
56 namespace {
57 
58 static const bool isUser = (android::base::GetProperty("ro.build.type", "") == "user");
59 
getAdbPort()60 int getAdbPort() {
61     return android::base::GetIntProperty("service.adb.tcp.port", 0);
62 }
63 
isAdbSocket(const inet_diag_msg * msg,int adbPort)64 bool isAdbSocket(const inet_diag_msg *msg, int adbPort) {
65     return adbPort > 0 && msg->id.idiag_sport == htons(adbPort) &&
66         (msg->idiag_uid == AID_ROOT || msg->idiag_uid == AID_SHELL);
67 }
68 
// Checks, without blocking, whether a netlink error message is pending on |fd|.
//
// Returns:
//   - 0 if nothing is queued (recv fails with EAGAIN),
//   - -errno if the peek fails for any other reason,
//   - the error field of the pending message if it is exactly one NLMSG_ERROR ack; in that case
//     the message is consumed from the socket,
//   - 0 if anything else is queued; that data is left in place for the caller to read.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    // Peek first so that anything other than an error ack stays queued.
    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    const bool isErrorAck =
            peeked == static_cast<ssize_t>(sizeof(reply)) && reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorAck) {
        return 0;
    }

    // Remove the ack from the queue now that it has been identified.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
87 
88 }  // namespace
89 
open()90 bool SockDiag::open() {
91     if (hasSocks()) {
92         return false;
93     }
94 
95     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
96     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
97     if (!hasSocks()) {
98         closeSocks();
99         return false;
100     }
101 
102     sockaddr_nl nl = { .nl_family = AF_NETLINK };
103     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
104         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
105         closeSocks();
106         return false;
107     }
108 
109     return true;
110 }
111 
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)112 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
113                               iovec *iov, int iovcnt) {
114     struct {
115         nlmsghdr nlh;
116         inet_diag_req_v2 req;
117     } __attribute__((__packed__)) request = {
118         .nlh = {
119             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
120             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
121         },
122         .req = {
123             .sdiag_family = family,
124             .sdiag_protocol = proto,
125             .idiag_ext = extensions,
126             .idiag_states = states,
127         },
128     };
129 
130     size_t len = 0;
131     iov[0].iov_base = &request;
132     iov[0].iov_len = sizeof(request);
133     for (int i = 0; i < iovcnt; i++) {
134         len += iov[i].iov_len;
135     }
136     request.nlh.nlmsg_len = len;
137 
138     ssize_t writevRet = writev(mSock, iov, iovcnt);
139     // Don't let pointers to the stack escape.
140     iov[0] = {nullptr, 0};
141     if (writevRet != (ssize_t)len) {
142         return -errno;
143     }
144 
145     return checkError(mSock);
146 }
147 
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)148 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
149     iovec iov[] = {
150         { nullptr, 0 },
151     };
152     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
153 }
154 
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)155 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
156     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
157     addrinfo *res;
158     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
159 
160     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
161     // doing string conversions when they're not necessary.
162     int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
163     if (ret != 0) return -EINVAL;
164 
165     // So we don't have to call freeaddrinfo on every failure path.
166     ScopedAddrinfo resP(res);
167 
168     void *addr;
169     uint8_t addrlen;
170     if (res->ai_family == AF_INET && family == AF_INET) {
171         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
172         addr = &ina;
173         addrlen = sizeof(ina);
174     } else if (res->ai_family == AF_INET && family == AF_INET6) {
175         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
176         mapped.s6_addr32[3] = ina.s_addr;
177         addr = &mapped;
178         addrlen = sizeof(mapped);
179     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
180         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
181         addr = &in6a;
182         addrlen = sizeof(in6a);
183     } else {
184         return -EAFNOSUPPORT;
185     }
186 
187     uint8_t prefixlen = addrlen * 8;
188     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
189     uint8_t nojump = yesjump + 4;
190 
191     struct {
192         nlattr nla;
193         inet_diag_bc_op op;
194         inet_diag_hostcond cond;
195     } __attribute__((__packed__)) attrs = {
196         .nla = {
197             .nla_type = INET_DIAG_REQ_BYTECODE,
198         },
199         .op = {
200             INET_DIAG_BC_S_COND,
201             yesjump,
202             nojump,
203         },
204         .cond = {
205             family,
206             prefixlen,
207             -1,
208             {}
209         },
210     };
211 
212     attrs.nla.nla_len = sizeof(attrs) + addrlen;
213 
214     iovec iov[] = {
215         { nullptr,           0 },
216         { &attrs,            sizeof(attrs) },
217         { addr,              addrlen },
218     };
219 
220     uint32_t states = ~(1 << TCP_TIME_WAIT);
221     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
222 }
223 
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)224 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
225     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
226         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
227         if (shouldDestroy(proto, msg)) {
228             sockDestroy(proto, msg);
229         }
230     };
231 
232     return processNetlinkDump(mSock, callback);
233 }
234 
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)235 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
236     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
237         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
238             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
239             return;
240         }
241         Fwmark mark;
242         struct tcp_info *tcpinfo = nullptr;
243         uint32_t tcpinfoLength = 0;
244         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
245         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
246         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
247         while (RTA_OK(attr, attr_len)) {
248             if (attr->rta_type == INET_DIAG_INFO) {
249                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
250                 tcpinfoLength = RTA_PAYLOAD(attr);
251             }
252             if (attr->rta_type == INET_DIAG_MARK) {
253                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
254             }
255             attr = RTA_NEXT(attr, attr_len);
256         }
257 
258         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
259     };
260 
261     return processNetlinkDump(mSock, callback);
262 }
263 
264 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)265 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
266     switch (msg->idiag_family) {
267         case AF_INET:
268             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
269             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
270                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
271                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
272 
273         case AF_INET6: {
274             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
275             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
276             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
277                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
278                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
279                    !memcmp(src, dst, sizeof(*src));
280         }
281         default:
282             return false;
283     }
284 }
285 
sockDestroy(uint8_t proto,const inet_diag_msg * msg)286 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
287     if (msg == nullptr) {
288        return 0;
289     }
290 
291     DestroyRequest request = {
292         .nlh = {
293             .nlmsg_type = SOCK_DESTROY,
294             .nlmsg_flags = NLM_F_REQUEST,
295         },
296         .req = {
297             .sdiag_family = msg->idiag_family,
298             .sdiag_protocol = proto,
299             .idiag_states = (uint32_t) (1 << msg->idiag_state),
300             .id = msg->id,
301         },
302     };
303     request.nlh.nlmsg_len = sizeof(request);
304 
305     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
306         return -errno;
307     }
308 
309     int ret = checkError(mWriteSock);
310     if (!ret) mSocketsDestroyed++;
311     return ret;
312 }
313 
destroySockets(uint8_t proto,int family,const char * addrstr,int ifindex)314 int SockDiag::destroySockets(uint8_t proto, int family, const char* addrstr, int ifindex) {
315     if (!hasSocks()) {
316         return -EBADFD;
317     }
318 
319     if (int ret = sendDumpRequest(proto, family, addrstr)) {
320         return ret;
321     }
322 
323     // Destroy all sockets on the address, except link-local sockets where ifindex doesn't match.
324     auto shouldDestroy = [ifindex](uint8_t, const inet_diag_msg* msg) {
325         return ifindex == 0 || ifindex == (int)msg->id.idiag_if;
326     };
327 
328     return readDiagMsg(proto, shouldDestroy);
329 }
330 
destroySockets(const char * addrstr,int ifindex)331 int SockDiag::destroySockets(const char* addrstr, int ifindex) {
332     Stopwatch s;
333     mSocketsDestroyed = 0;
334 
335     std::string where = addrstr;
336     if (ifindex) where += StringPrintf(" ifindex %d", ifindex);
337 
338     if (!strchr(addrstr, ':')) {  // inet_ntop never returns something like ::ffff:192.0.2.1
339         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr, ifindex)) {
340             ALOGE("Failed to destroy IPv4 sockets on %s: %s",
341                 (isUser ? "[hidden: user build]" : where.c_str()), strerror(-ret));
342             return ret;
343         }
344     }
345     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr, ifindex)) {
346         ALOGE("Failed to destroy IPv6 sockets on %s: %s",
347             (isUser ? "[hidden: user build]" : where.c_str()), strerror(-ret));
348         return ret;
349     }
350 
351     if (mSocketsDestroyed > 0) {
352         ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed,
353             (isUser ? "[hidden: user build]" : where.c_str()), s.timeTakenUs());
354     }
355 
356     return mSocketsDestroyed;
357 }
358 
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)359 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
360                                  iovec *iov, int iovcnt) {
361     const int proto = IPPROTO_TCP;
362     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
363 
364     for (const int family : {AF_INET, AF_INET6}) {
365         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
366         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
367             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
368             return ret;
369         }
370         if (int ret = readDiagMsg(proto, destroyFilter)) {
371             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
372             return ret;
373         }
374     }
375 
376     return 0;
377 }
378 
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)379 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
380     const int proto = IPPROTO_TCP;
381     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
382     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
383 
384     iovec iov[] = {
385         { nullptr, 0 },
386     };
387 
388     for (const int family : {AF_INET, AF_INET6}) {
389         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
390         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
391             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
392             return ret;
393         }
394         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
395             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
396             return ret;
397         }
398     }
399 
400     return 0;
401 }
402 
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)403 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
404     mSocketsDestroyed = 0;
405     Stopwatch s;
406 
407     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
408         return msg != nullptr &&
409                msg->idiag_uid == uid &&
410                !(excludeLoopback && isLoopbackSocket(msg));
411     };
412 
413     for (const int family : {AF_INET, AF_INET6}) {
414         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
415         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
416         if (int ret = sendDumpRequest(proto, family, states)) {
417             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
418             return ret;
419         }
420         if (int ret = readDiagMsg(proto, shouldDestroy)) {
421             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
422             return ret;
423         }
424     }
425 
426     if (mSocketsDestroyed > 0) {
427         ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
428     }
429 
430     return 0;
431 }
432 
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)433 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
434                              bool excludeLoopback) {
435     mSocketsDestroyed = 0;
436     Stopwatch s;
437 
438     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
439         return msg != nullptr &&
440                uidRanges.hasUid(msg->idiag_uid) &&
441                skipUids.find(msg->idiag_uid) == skipUids.end() &&
442                !(excludeLoopback && isLoopbackSocket(msg)) &&
443                !isAdbSocket(msg, getAdbPort());
444     };
445 
446     iovec iov[] = {
447         { nullptr, 0 },
448     };
449 
450     if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
451         return ret;
452     }
453 
454     if (mSocketsDestroyed > 0) {
455         ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
456               uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
457               s.timeTakenUs());
458     }
459 
460     return 0;
461 }
462 
463 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
464 // 1. The opening app no longer has permission to use this network, or:
465 // 2. The opening app does have permission, but did not explicitly select this network.
466 //
467 // We destroy sockets without the explicit bit because we want to avoid the situation where a
468 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
469 // might have opened a socket on this network just because it was the default network at the
470 // time. If we don't kill these sockets, those apps could continue to use them without realizing
471 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)472 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
473                                               bool excludeLoopback) {
474     struct markmatch {
475         inet_diag_bc_op op;
476         // TODO: switch to inet_diag_markcond
477         __u32 mark;
478         __u32 mask;
479     } __attribute__((packed));
480     constexpr uint8_t matchlen = sizeof(markmatch);
481 
482     Fwmark netIdMark, netIdMask;
483     netIdMark.netId = netId;
484     netIdMask.netId = 0xffff;
485 
486     Fwmark controlMark;
487     controlMark.explicitlySelected = true;
488     controlMark.permission = permission;
489 
490     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
491     struct bytecode {
492         markmatch netIdMatch;
493         markmatch controlMatch;
494         inet_diag_bc_op controlJump;
495     } __attribute__((packed)) bytecode;
496 
497     // The length of the INET_DIAG_BC_JMP instruction.
498     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
499     // Jump exactly this far past the end of the program to reject.
500     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
501     // Total length of the program.
502     constexpr uint8_t bytecodelen = sizeof(bytecode);
503 
504     bytecode = (struct bytecode) {
505         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
506         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
507           netIdMark.intValue, netIdMask.intValue },
508 
509         // If explicit and permission bits match, go to the JMP below which rejects the socket
510         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
511         // socket (so we destroy it).
512         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
513           controlMark.intValue, controlMark.intValue },
514 
515         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
516         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
517         // is invalid because the target of every no jump must always be reachable by yes jumps.
518         // Without this JMP, the accept target is not reachable by yes jumps and the program will
519         // be rejected by the validator.
520         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
521 
522         // We have reached the end of the program. Accept the socket, and destroy it below.
523     };
524 
525     struct nlattr nla = {
526             .nla_len = sizeof(struct nlattr) + bytecodelen,
527             .nla_type = INET_DIAG_REQ_BYTECODE,
528     };
529 
530     iovec iov[] = {
531         { nullptr,   0 },
532         { &nla,      sizeof(nla) },
533         { &bytecode, bytecodelen },
534     };
535 
536     mSocketsDestroyed = 0;
537     Stopwatch s;
538 
539     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
540         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
541     };
542 
543     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
544         return ret;
545     }
546 
547     if (mSocketsDestroyed > 0) {
548         ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
549               netId, permission, s.timeTakenUs());
550     }
551 
552     return 0;
553 }
554 
555 }  // namespace net
556 }  // namespace android
557