1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #define LOG_TAG "TunForwarder"
19 
20 #include "tun_forwarder.h"
21 
22 #include <arpa/inet.h>
23 #include <linux/if.h>
24 #include <linux/if_tun.h>
25 #include <linux/ioctl.h>
26 #include <netinet/ip6.h>
27 #include <netinet/tcp.h>
28 #include <netinet/udp.h>
29 #include <sys/eventfd.h>
30 #include <sys/poll.h>
31 
32 #include <android-base/logging.h>
33 
34 extern "C" {
35 #include <checksum.h>
36 }
37 
38 using android::base::Error;
39 using android::base::Result;
40 using android::base::unique_fd;
41 using android::netdutils::Slice;
42 
43 namespace android::net {
44 
// Largest IP packet (not counting the tun header) that handlePacket() reads in one call.
static constexpr int MAXMTU = 1500;

// Fixed sizes of the headers this file parses, stored as signed values.
static constexpr ssize_t TUN_HDRLEN = static_cast<ssize_t>(sizeof(tun_pi));
static constexpr ssize_t IP4_HDRLEN = static_cast<ssize_t>(sizeof(iphdr));
static constexpr ssize_t IP6_HDRLEN = static_cast<ssize_t>(sizeof(ip6_hdr));
static constexpr ssize_t TCP_HDRLEN = static_cast<ssize_t>(sizeof(tcphdr));
static constexpr ssize_t UDP_HDRLEN = static_cast<ssize_t>(sizeof(udphdr));
51 
namespace {

// Byte-wise comparison helpers for in6_addr, used by the v6pair comparison operators below.

// True when both addresses hold the same 16 bytes.
bool operator==(const in6_addr& x, const in6_addr& y) {
    const int diff = std::memcmp(x.s6_addr, y.s6_addr, sizeof(x.s6_addr));
    return diff == 0;
}

// True when the two addresses differ in at least one byte.
bool operator!=(const in6_addr& x, const in6_addr& y) {
    const bool same = (x == y);
    return !same;
}

// Lexicographic ordering over the 16 address bytes, as reported by memcmp().
bool operator<(const in6_addr& x, const in6_addr& y) {
    const int diff = std::memcmp(x.s6_addr, y.s6_addr, sizeof(x.s6_addr));
    return diff < 0;
}

}  // namespace
67 
makePair(const std::array<std::string,2> & addrs)68 Result<TunForwarder::v4pair> TunForwarder::v4pair::makePair(
69         const std::array<std::string, 2>& addrs) {
70     v4pair pair;
71     if (inet_pton(AF_INET, addrs[0].c_str(), &pair.src) != 1 ||
72         inet_pton(AF_INET, addrs[1].c_str(), &pair.dst) != 1) {
73         return Error() << "Failed to make v4pair";
74     }
75     return pair;
76 }
77 
operator ==(const v4pair & o) const78 bool TunForwarder::v4pair::operator==(const v4pair& o) const {
79     return std::tie(src.s_addr, dst.s_addr) == std::tie(o.src.s_addr, o.dst.s_addr);
80 }
81 
operator <(const v4pair & o) const82 bool TunForwarder::v4pair::operator<(const v4pair& o) const {
83     return std::tie(src.s_addr, dst.s_addr) < std::tie(o.src.s_addr, o.dst.s_addr);
84 }
85 
makePair(const std::array<std::string,2> & addrs)86 Result<TunForwarder::v6pair> TunForwarder::v6pair::makePair(
87         const std::array<std::string, 2>& addrs) {
88     v6pair pair;
89     if (inet_pton(AF_INET6, addrs[0].c_str(), &pair.src) != 1 ||
90         inet_pton(AF_INET6, addrs[1].c_str(), &pair.dst) != 1) {
91         return Error() << "Failed to make v6pair";
92     }
93     return pair;
94 }
95 
operator ==(const v6pair & o) const96 bool TunForwarder::v6pair::operator==(const v6pair& o) const {
97     return src == o.src && dst == o.dst;
98 }
99 
operator <(const v6pair & o) const100 bool TunForwarder::v6pair::operator<(const v6pair& o) const {
101     if (src != o.src) return src < o.src;
102     return dst < o.dst;
103 }
104 
TunForwarder(unique_fd tunFd)105 TunForwarder::TunForwarder(unique_fd tunFd) : mTunFd(std::move(tunFd)) {
106     mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
107 }
108 
~TunForwarder()109 TunForwarder::~TunForwarder() {
110     stopForwarding();
111     if (mForwarder.joinable()) {
112         mForwarder.join();
113     }
114 }
115 
startForwarding()116 bool TunForwarder::startForwarding() {
117     if (mForwarder.joinable()) return false;
118     mForwarder = std::thread(&TunForwarder::loop, this);
119     return true;
120 }
121 
stopForwarding()122 bool TunForwarder::stopForwarding() {
123     return signalEventFd();
124 }
125 
126 // Assume all of the strings in |from| and |to| are the IP addresses of the same IP version.
addForwardingRule(const std::array<std::string,2> & from,const std::array<std::string,2> & to)127 bool TunForwarder::addForwardingRule(const std::array<std::string, 2>& from,
128                                      const std::array<std::string, 2>& to) {
129     const bool isV4 = (from[0].find(':') == from[0].npos);
130     if (isV4) {
131         auto k = v4pair::makePair(from);
132         auto v = v4pair::makePair(to);
133         if (!k.ok() || !v.ok()) return false;
134         mRulesIpv4[k.value()] = v.value();
135     } else {
136         auto k = v6pair::makePair(from);
137         auto v = v6pair::makePair(to);
138         if (!k.ok() || !v.ok()) return false;
139         mRulesIpv6[k.value()] = v.value();
140     }
141     return true;
142 }
143 
createTun(const std::string & ifname)144 unique_fd TunForwarder::createTun(const std::string& ifname) {
145     unique_fd fd(open("/dev/tun", O_RDWR | O_NONBLOCK | O_CLOEXEC));
146     if (!fd.ok()) {
147         return {};
148     }
149 
150     ifreq ifr = {
151             .ifr_ifru = {.ifru_flags = IFF_TUN},
152     };
153     strlcpy(ifr.ifr_name, ifname.data(), sizeof(ifr.ifr_name));
154 
155     if (ioctl(fd.get(), TUNSETIFF, &ifr) == -1) {
156         PLOG(WARNING) << "failed to bring up tun " << ifr.ifr_name;
157         return {};
158     }
159 
160     unique_fd inet6CtrlSock(socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0));
161     ifr.ifr_flags = IFF_UP;
162     if (ioctl(inet6CtrlSock.get(), SIOCSIFFLAGS, &ifr) == -1) {
163         PLOG(WARNING) << "failed on SIOCSIFFLAGS " << ifr.ifr_name;
164         return {};
165     }
166 
167     return fd;
168 }
169 
loop()170 void TunForwarder::loop() {
171     while (true) {
172         struct pollfd wait_fd[] = {
173                 {mEventFd, POLLIN, 0},
174                 {mTunFd.get(), POLLIN, 0},
175         };
176 
177         if (int ret = poll(wait_fd, std::size(wait_fd), kPollTimeoutMs); ret <= 0) {
178             break;
179         }
180 
181         if (wait_fd[0].revents & (POLLIN | POLLERR)) {
182             uint64_t value = 0;
183             eventfd_read(mEventFd, &value);
184             break;
185         }
186         if (wait_fd[1].revents & (POLLIN | POLLERR)) {
187             handlePacket(wait_fd[1].fd);
188         }
189     }
190 }
191 
handlePacket(int fd) const192 void TunForwarder::handlePacket(int fd) const {
193     uint8_t buf[MAXMTU + TUN_HDRLEN];
194 
195     ssize_t readlen = read(fd, buf, std::size(buf));
196     if (readlen < 0) {
197         PLOG(ERROR) << "failed to read packets from tun";
198         return;
199     } else if (readlen == 0) {
200         PLOG(ERROR) << "tun interface removed";
201         return;
202     }
203 
204     // Filter the packet. Only TCP and UDP packets are allowed.
205     const Slice tunPacket(buf, readlen);
206     if (auto result = validatePacket(tunPacket); !result.ok()) {
207         LOG(DEBUG) << "validatePacket failed: " << result.error();
208         return;
209     }
210 
211     // Change the packet's source/destination address and checksum.
212     if (auto result = translatePacket(tunPacket); !result.ok()) {
213         LOG(ERROR) << "translatePacket failed: " << result.error();
214     }
215 
216     // Write the new packet to the fd, causing the kernel to receive it on the tun interface.
217     write(fd, buf, readlen);
218 }
219 
validatePacket(Slice tunPacket) const220 Result<void> TunForwarder::validatePacket(Slice tunPacket) const {
221     if (tunPacket.size() < TUN_HDRLEN) {
222         return Error() << "Too short for a tun header";
223     }
224 
225     const tun_pi* const tunHeader = reinterpret_cast<tun_pi*>(tunPacket.base());
226     if (tunHeader->flags != 0) {
227         return Error() << "Unexpected tun flags " << static_cast<int>(tunHeader->flags);
228     }
229 
230     switch (uint16_t proto = ntohs(tunHeader->proto); proto) {
231         case ETH_P_IP:
232             return validateIpv4Packet(drop(tunPacket, TUN_HDRLEN));
233         case ETH_P_IPV6:
234             return validateIpv6Packet(drop(tunPacket, TUN_HDRLEN));
235         default:
236             return Error() << "Unsupported packet type 0x" << std::hex << static_cast<int>(proto);
237     }
238 }
239 
validateIpv4Packet(Slice ipv4Packet) const240 Result<void> TunForwarder::validateIpv4Packet(Slice ipv4Packet) const {
241     if (ipv4Packet.size() < IP4_HDRLEN) {
242         return Error() << "Too short for an ip header";
243     }
244 
245     const iphdr* const ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base());
246     if (ipHeader->ihl < 5) {
247         return Error() << "IP header length set to less than 5";
248     }
249     if (ipHeader->ihl * 4 > ipv4Packet.size()) {
250         return Error() << "IP header length set too large: " << ipHeader->ihl;
251     }
252     if (ipHeader->version != 4) {
253         return Error() << "IP header version not 4: " << ipHeader->version;
254     }
255     if (mRulesIpv4.find({ipHeader->saddr, ipHeader->daddr}) == mRulesIpv4.end()) {
256         return Error() << "Can't find any v4 rule. Packet hex dump: " << toHex(ipv4Packet, 32);
257     }
258 
259     switch (ipHeader->protocol) {
260         case IPPROTO_UDP:
261             return validateUdpPacket(drop(ipv4Packet, ipHeader->ihl * 4));
262         case IPPROTO_TCP:
263             return validateTcpPacket(drop(ipv4Packet, ipHeader->ihl * 4));
264         default:
265             return Error() << "Unsupported transport protocol "
266                            << static_cast<int>(ipHeader->protocol);
267     }
268 }
269 
validateIpv6Packet(Slice ipv6Packet) const270 Result<void> TunForwarder::validateIpv6Packet(Slice ipv6Packet) const {
271     if (ipv6Packet.size() < IP6_HDRLEN) {
272         return Error() << "Too short for an ipv6 header";
273     }
274 
275     const ip6_hdr* const ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base());
276     if (mRulesIpv6.find({ipv6Header->ip6_src, ipv6Header->ip6_dst}) == mRulesIpv6.end()) {
277         return Error() << "Can't find any v6 rule. Packet hex dump: " << toHex(ipv6Packet, 32);
278     }
279 
280     switch (ipv6Header->ip6_nxt) {
281         case IPPROTO_UDP:
282             return validateUdpPacket(drop(ipv6Packet, IP6_HDRLEN));
283         case IPPROTO_TCP:
284             return validateTcpPacket(drop(ipv6Packet, IP6_HDRLEN));
285         default:
286             return Error() << "Expect TCP/UDP in ipv6 next header: "
287                            << static_cast<int>(ipv6Header->ip6_nxt);
288     }
289 }
290 
validateUdpPacket(Slice udpPacket) const291 Result<void> TunForwarder::validateUdpPacket(Slice udpPacket) const {
292     if (udpPacket.size() < UDP_HDRLEN) {
293         return Error() << "Too short for a udp header";
294     }
295     return {};
296 }
297 
validateTcpPacket(Slice tcpPacket) const298 Result<void> TunForwarder::validateTcpPacket(Slice tcpPacket) const {
299     if (tcpPacket.size() < TCP_HDRLEN) {
300         return Error() << "Too short for a tcp header";
301     }
302 
303     const tcphdr* const tcpHeader = reinterpret_cast<tcphdr*>(tcpPacket.base());
304     if (tcpHeader->doff < 5) {
305         return Error() << "TCP header length set to less than 5";
306     }
307     if (tcpHeader->doff * 4 > tcpPacket.size()) {
308         return Error() << "TCP header length set too large: " << tcpHeader->doff;
309     }
310     return {};
311 }
312 
translatePacket(Slice tunPacket) const313 Result<void> TunForwarder::translatePacket(Slice tunPacket) const {
314     const tun_pi* const tunHeader = reinterpret_cast<tun_pi*>(tunPacket.base());
315     switch (uint16_t proto = ntohs(tunHeader->proto); proto) {
316         case ETH_P_IP:
317             return translateIpv4Packet(drop(tunPacket, TUN_HDRLEN));
318         case ETH_P_IPV6:
319             return translateIpv6Packet(drop(tunPacket, TUN_HDRLEN));
320         default:
321             return Error() << "translate: Unsupported packet type 0x" << std::hex
322                            << static_cast<int>(proto);
323     }
324 }
325 
translateIpv4Packet(Slice ipv4Packet) const326 Result<void> TunForwarder::translateIpv4Packet(Slice ipv4Packet) const {
327     iphdr* ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base());
328     const size_t ipHeaderLen = ipHeader->ihl * 4;
329     const size_t transport_len = ipv4Packet.size() - ipHeaderLen;
330 
331     uint32_t oldPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len);
332     for (const auto& [from, to] : mRulesIpv4) {
333         if (ipHeader->saddr == static_cast<int>(from.src.s_addr) &&
334             ipHeader->daddr == static_cast<int>(from.dst.s_addr)) {
335             ipHeader->saddr = to.src.s_addr;
336             ipHeader->daddr = to.dst.s_addr;
337             break;
338         }
339     }
340     uint32_t newPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len);
341 
342     ipHeader->check = 0;
343     ipHeader->check = ip_checksum(ipHeader, sizeof(struct iphdr));
344 
345     switch (ipHeader->protocol) {
346         case IPPROTO_UDP:
347             translateUdpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
348             break;
349         case IPPROTO_TCP:
350             translateTcpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
351             break;
352         default:
353             return Error() << "translate: Unsupported transport protocol "
354                            << static_cast<int>(ipHeader->protocol);
355     }
356 
357     return {};
358 }
359 
translateIpv6Packet(Slice ipv6Packet) const360 Result<void> TunForwarder::translateIpv6Packet(Slice ipv6Packet) const {
361     ip6_hdr* ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base());
362     const size_t ipHeaderLen = IP6_HDRLEN;
363     const size_t transport_len = ipv6Packet.size() - ipHeaderLen;
364 
365     uint32_t oldPseudoSum =
366             ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt);
367     for (const auto& [from, to] : mRulesIpv6) {
368         if (ipv6Header->ip6_src == from.src && ipv6Header->ip6_dst == from.dst) {
369             ipv6Header->ip6_src = to.src;
370             ipv6Header->ip6_dst = to.dst;
371             break;
372         }
373     }
374     uint32_t newPseudoSum =
375             ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt);
376 
377     switch (ipv6Header->ip6_nxt) {
378         case IPPROTO_UDP:
379             translateUdpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
380             break;
381         case IPPROTO_TCP:
382             translateTcpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum);
383             break;
384         default:
385             return Error() << "transliate: Expect TCP/UDP in ipv6 next header: "
386                            << static_cast<int>(ipv6Header->ip6_nxt);
387     }
388 
389     return {};
390 }
391 
translateUdpPacket(Slice udpPacket,uint32_t oldPseudoSum,uint32_t newPseudoSum) const392 void TunForwarder::translateUdpPacket(Slice udpPacket, uint32_t oldPseudoSum,
393                                       uint32_t newPseudoSum) const {
394     udphdr* udpHeader = reinterpret_cast<udphdr*>(udpPacket.base());
395     if (udpHeader->check) {
396         udpHeader->check = ip_checksum_adjust(udpHeader->check, oldPseudoSum, newPseudoSum);
397     } else {
398         uint32_t tmp = ip_checksum_add(newPseudoSum, udpPacket.base(), udpPacket.size());
399         udpHeader->check = ip_checksum_finish(tmp);
400     }
401 
402     // RFC 768: "If the computed checksum is zero, it is transmitted as all ones (the equivalent
403     // in one's complement arithmetic)."
404     if (!udpHeader->check) {
405         udpHeader->check = 0xffff;
406     }
407 }
408 
translateTcpPacket(Slice tcpPacket,uint32_t oldPseudoSum,uint32_t newPseudoSum) const409 void TunForwarder::translateTcpPacket(Slice tcpPacket, uint32_t oldPseudoSum,
410                                       uint32_t newPseudoSum) const {
411     tcphdr* tcpHeader = reinterpret_cast<tcphdr*>(tcpPacket.base());
412     tcpHeader->check = ip_checksum_adjust(tcpHeader->check, oldPseudoSum, newPseudoSum);
413 }
414 
signalEventFd()415 bool TunForwarder::signalEventFd() {
416     return eventfd_write(mEventFd.get(), 1) == 0;
417 }
418 
419 }  // namespace android::net
420