#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import binascii
import contextlib
import fcntl
import os
import random
import re
from socket import *  # pylint: disable=wildcard-import
import struct
import sys
import unittest

from scapy import all as scapy

import csocket
import gki

# TODO: Move these to csocket.py.
# Socket option levels / names that the socket module does not export.
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

# SOL_SOCKET level options.
SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethertypes.
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

IPPROTO_GRE = 47

# ioctl request number used by SetInterfaceHWAddr.
SIOCSIFHWADDR = 0x8924

# Flow label manager actions (flr_action in struct in6_flowlabel_req).
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
# The kernel UAPI defines RENEW as 2; it must not alias IPV6_FL_A_PUT.
IPV6_FL_A_RENEW = 2

# Flow label manager flags (flr_flags).
IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

# Flow label sharing modes (flr_share).
IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

# Size of the interface name field in struct ifreq.
IFNAMSIZ = 16

# 8-byte ICMP / ICMPv6 echo request headers (type 8 and type 0x80).
IPV4_PING = b"\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = b"\x80\x00\x00\x00\x0a\xce\x00\x03"

IPV4_ADDR = "8.8.8.8"
IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
IPV6_ADDR2 = "2001:4860:4860::8844"

# Header line that ReadProcNetSocket expects at the top of
# /proc/net/{icmp6,raw6,udp6}.
# NOTE(review): this is compared for exact equality with the kernel's output,
# so the run lengths of spaces inside the literal matter - confirm them
# against the kernel's header before relying on this value.
IPV6_SEQ_DGRAM_HEADER = (" sl "
                         "local_address "
                         "remote_address "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         " uid timeout inode ref pointer drops\n")

UDP_HDR_LEN = 8

# Arbitrary packet payload.
# UDP_PAYLOAD is a serialized DNS AAAA query with a random transaction ID.
UDP_PAYLOAD = bytes(scapy.DNS(rd=1,
                              id=random.randint(0, 65535),
                              qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                             qtype="AAAA")))

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# The following ends up being (VERSION, PATCHLEVEL, SUBLEVEL) from top of
# kernel's Makefile
LINUX_VERSION = csocket.LinuxVersion()

LINUX_ANY_VERSION = (0, 0, 0)

# Linus always releases x.y.0-rcZ or x.y.0, any stable (incl. LTS) release will
# be x.y.1+
IS_STABLE = (LINUX_VERSION[2] > 0)

# From //system/gsid/libgsi.cpp IsGsiRunning()
IS_GSI = os.access("/metadata/gsi/dsu/booted", os.F_OK)


# NonGXI() is useful to run tests starting from a specific kernel version,
# thus allowing one to test for correctly backported fixes,
# without running the tests on non-updatable kernels (as part of GSI tests).
#
# Running vts_net_test on GSI image basically doesn't make sense, since
# it's not like the unmodified vendor image - including the kernel - can be
# realistically fixed in such a setup. Particularly problematic is GSI
# on *older* pixel vendor: newer pixel images will have the fixed kernel,
# but running newer GSI against ancient vendor will not see those fixes.
#
# Normally you'd also want to run on GKI kernels, but older release branches
# are no longer maintained, so they also need to be excluded.
# Proper GKI testing will happen at the tip of the appropriate ACK/GKI branch.
def NonGXI(major, minor):
  """Checks the kernel version is >= major.minor, and not GKI or GSI.

  Args:
    major: kernel VERSION to compare against.
    minor: kernel PATCHLEVEL to compare against.

  Returns:
    False on GSI or GKI; otherwise whether LINUX_VERSION >= (major, minor, 0).
  """
  if IS_GSI or gki.IS_GKI:
    return False
  return LINUX_VERSION >= (major, minor, 0)


def KernelAtLeast(versions):
  """Checks the kernel version matches the specified versions.

  Args:
    versions: a list of versions expressed as tuples,
        e.g., [(5, 10, 108), (5, 15, 31)]. The kernel version matches if it's
        between each specified version and the next minor version with last
        digit set to 0. In this example, the kernel version must match either:
          >= 5.10.108 and < 5.15.0
          >= 5.15.31
        While this is less flexible than matching exact tuples, it allows the
        caller to pass in fewer arguments, because Android only supports
        certain minor versions (4.19, 5.4, 5.10, ...)

  Returns:
    True if the kernel version matches, False otherwise

  Raises:
    ValueError: two entries in versions share the same (major, minor).
  """
  # Upper bound for the newest entry: effectively "no upper bound".
  maxversion = (1000, 255, 65535)
  # Walk from newest to oldest so each entry's upper bound is the x.y.0 of
  # the next newer entry.
  for version in sorted(versions, reverse=True):
    if version[:2] == maxversion[:2]:
      raise ValueError(
          "Duplicate minor version: %s %s" % (version, maxversion))
    if LINUX_VERSION >= version and LINUX_VERSION < maxversion:
      return True
    maxversion = (version[0], version[1], 0)
  return False


def ByteToHex(b):
  """Formats one byte (an int, or a 1-character str) as two hex digits."""
  return "%02x" % (ord(b) if isinstance(b, str) else b)


def GetWildcardAddress(version):
  """Returns the wildcard ("any") address string for IP version 4 or 6."""
  return {4: "0.0.0.0", 6: "::"}[version]


def GetIpHdrLength(version):
  """Returns the IP header length in bytes (20 for IPv4, 40 for IPv6)."""
  return {4: 20, 6: 40}[version]


def GetAddressFamily(version):
  """Maps a version number to an address family.

  Version 5 is the pseudo-version returned by GetAddressVersion() for
  addresses starting with "::ffff"; it uses AF_INET6 like version 6.
  """
  return {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]


def AddressLengthBits(version):
  """Returns the address length in bits (32 for IPv4, 128 for IPv6)."""
  return {4: 32, 6: 128}[version]


def GetAddressVersion(address):
  """Classifies an address string as version 4, 5 or 6.

  Returns:
    4 if the string contains no colon, 5 if it starts with "::ffff"
    (presumably an IPv4-mapped IPv6 address), 6 otherwise.
  """
  if ":" not in address:
    return 4
  if address.startswith("::ffff"):
    return 5
  return 6


def SetSocketTos(s, tos):
  """Sets the IPv4 TOS / IPv6 traffic class on socket s.

  Raises:
    KeyError: s.family is neither AF_INET nor AF_INET6.
  """
  level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
  option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
  s.setsockopt(level, option, tos)


def SetNonBlocking(fd):
  """Adds O_NONBLOCK to the file status flags of fd."""
  flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)


# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Creates a socket and applies a timeout via csocket.SetSocketTimeout.

  The value 5000 is passed through to csocket.SetSocketTimeout unchanged
  (presumably milliseconds - confirm in csocket).
  """
  s = socket(family, sock_type, protocol)
  csocket.SetSocketTimeout(s, 5000)
  return s


def PingSocket(family):
  """Creates a datagram ICMP (AF_INET) or ICMPv6 (AF_INET6) socket."""
  proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
  return Socket(family, SOCK_DGRAM, proto)


def IPv4PingSocket():
  """Creates an IPv4 ping socket."""
  return PingSocket(AF_INET)


def IPv6PingSocket():
  """Creates an IPv6 ping socket."""
  return PingSocket(AF_INET6)


def TCPSocket(family):
  """Creates a non-blocking TCP socket of the given family."""
  s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  SetNonBlocking(s.fileno())
  return s


def IPv4TCPSocket():
  """Creates a non-blocking IPv4 TCP socket."""
  return TCPSocket(AF_INET)


def IPv6TCPSocket():
  """Creates a non-blocking IPv6 TCP socket."""
  return TCPSocket(AF_INET6)


def UDPSocket(family):
  """Creates a UDP socket of the given family."""
  return Socket(family, SOCK_DGRAM, IPPROTO_UDP)


def RawGRESocket(family):
  """Creates a raw socket for IPPROTO_GRE."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)


def BindRandomPort(version, sock):
  """Binds sock to the wildcard address on a kernel-chosen port.

  TCP sockets are also put into the listening state.

  Args:
    version: 4, 5 or 6; 5 and 6 both bind to "::".
    sock: the socket to bind.

  Returns:
    The port number the socket was bound to.
  """
  addr = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((addr, 0))
  if sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP:
    sock.listen(100)
  return sock.getsockname()[1]


def EnableFinWait(sock):
  """Turns SO_LINGER off on sock."""
  # Disabling SO_LINGER causes sockets to go into FIN_WAIT on close().
  sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 0, 0))


def DisableFinWait(sock):
  """Turns SO_LINGER on with a zero timeout on sock."""
  # Enabling SO_LINGER with a timeout of zero causes close() to send RST.
  sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 0))


def CreateSocketPair(family, socktype, addr):
  """Creates two sockets connected to each other over addr.

  For SOCK_STREAM the pair is (client, accepted) and the temporary listening
  socket is closed; both ends get DisableFinWait applied. For other socket
  types the second socket is bound to addr and the two are connect()ed to
  each other.

  Args:
    family: address family for both sockets.
    socktype: SOCK_STREAM or a datagram socket type.
    addr: local address to bind the second socket to (port is kernel-chosen).

  Returns:
    A (clientsock, acceptedsock) tuple.

  Raises:
    OSError: any socket call failed; all sockets created so far are closed.
  """
  clientsock = socket(family, socktype, 0)
  listensock = None
  acceptedsock = None
  try:
    listensock = socket(family, socktype, 0)
    listensock.bind((addr, 0))
    if socktype == SOCK_STREAM:
      listensock.listen(1)
    clientsock.connect(listensock.getsockname())
    if socktype == SOCK_STREAM:
      acceptedsock, _ = listensock.accept()
      DisableFinWait(clientsock)
      DisableFinWait(acceptedsock)
      listensock.close()
    else:
      listensock.connect(clientsock.getsockname())
      acceptedsock = listensock
  except Exception:
    # Don't leak file descriptors when setup fails half-way through.
    # close() is safe to call more than once on the same socket object.
    for sock in (clientsock, listensock, acceptedsock):
      if sock is not None:
        sock.close()
    raise
  return clientsock, acceptedsock


def GetInterfaceIndex(ifname):
  """Returns the interface index of ifname, via the SIOCGIFINDEX ioctl."""
  ifreq_fmt = "%dsi" % IFNAMSIZ
  with UDPSocket(AF_INET) as s:
    ifr = struct.pack(ifreq_fmt, ifname.encode(), 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
  return struct.unpack(ifreq_fmt, ifr)[1]


def SetInterfaceHWAddr(ifname, hwaddr):
  """Sets the hardware address of ifname, via the SIOCSIFHWADDR ioctl.

  Args:
    ifname: interface name.
    hwaddr: hex string of 6 bytes, with or without ":" separators.

  Raises:
    ValueError: hwaddr does not decode to exactly 6 bytes.
  """
  with UDPSocket(AF_INET) as s:
    hwaddr = binascii.unhexlify(hwaddr.replace(":", ""))
    if len(hwaddr) != 6:
      raise ValueError("Unknown hardware address length %d" % len(hwaddr))
    ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname.encode(),
                      scapy.ARPHDR_ETHER, hwaddr)
    fcntl.ioctl(s, SIOCSIFHWADDR, ifr)


def SetInterfaceState(ifname, up):
  """Sets or clears IFF_UP on ifname, preserving its other flags."""
  ifname_bytes = ifname.encode()
  ifreq_fmt = "%dsH" % IFNAMSIZ
  with UDPSocket(AF_INET) as s:
    # Read the current flags so that only IFF_UP is modified.
    ifr = struct.pack(ifreq_fmt, ifname_bytes, 0)
    ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
    _, flags = struct.unpack(ifreq_fmt, ifr)
    if up:
      flags |= scapy.IFF_UP
    else:
      flags &= ~scapy.IFF_UP
    ifr = struct.pack(ifreq_fmt, ifname_bytes, flags)
    fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)


def SetInterfaceUp(ifname):
  """Brings ifname up."""
  return SetInterfaceState(ifname, True)


def SetInterfaceDown(ifname):
  """Brings ifname down."""
  return SetInterfaceState(ifname, False)


def CanonicalizeIPv6Address(addr):
  """Round-trips addr through inet_pton/inet_ntop to normalize its text."""
  return inet_ntop(AF_INET6, inet_pton(AF_INET6, addr))


def FormatProcAddress(unformatted):
  """Converts a run of hex digits into a compressed IPv6 address string.

  Args:
    unformatted: hex digits without separators, e.g. the address column of
        /proc/net/if_inet6.

  Returns:
    The address with ":" inserted every 4 digits, then canonicalized.
  """
  groups = [unformatted[i:i+4] for i in range(0, len(unformatted), 4)]
  # Compress the address.
  return CanonicalizeIPv6Address(":".join(groups))


def FormatSockStatAddress(address):
  """Formats an IP address string as uppercase hex 32-bit words.

  Each 4-byte chunk of the binary address is read as a native-endian
  unsigned 32-bit integer and printed as 8 uppercase hex digits.
  """
  family = AF_INET6 if ":" in address else AF_INET
  binary = inet_pton(family, address)
  return "".join("%08X" % struct.unpack("=L", binary[i:i+4])
                 for i in range(0, len(binary), 4))


def GetLinkAddress(ifname, linklocal):
  """Returns the first IPv6 address of ifname listed in /proc/net/if_inet6.

  Args:
    ifname: interface name to look for.
    linklocal: if true, only consider addresses whose hex form starts with
        "fe80"; if false, only consider addresses that do not.

  Returns:
    The address as a formatted string, or None if nothing matched.
  """
  with open("/proc/net/if_inet6") as if_inet6:
    addresses = if_inet6.readlines()
  for address in addresses:
    address = [s for s in address.strip().split(" ") if s]
    if address[5] != ifname:
      continue
    if bool(linklocal) == address[0].startswith("fe80"):
      # Convert the address from raw hex to something with colons in it.
      return FormatProcAddress(address[0])
  return None


def GetDefaultRoute(version=6):
  """Finds the default route in /proc/net/ipv6_route or /proc/net/route.

  Args:
    version: 4 or 6.

  Returns:
    A (gateway address string, interface name) tuple.

  Raises:
    ValueError: no default route was found, or version is not 4 or 6.
  """
  if version == 6:
    with open("/proc/net/ipv6_route") as ipv6_route:
      routes = ipv6_route.readlines()
    for route in routes:
      route = [s for s in route.strip().split(" ") if s]
      if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
          # Routes in non-default tables end up in /proc/net/ipv6_route!!!
          and route[9] != "lo" and not route[9].startswith("nettest")):
        return FormatProcAddress(route[4]), route[9]
    raise ValueError("No IPv6 default route found")
  elif version == 4:
    with open("/proc/net/route") as ipv4_route:
      routes = ipv4_route.readlines()
    for route in routes:
      route = [s for s in route.strip().split("\t") if s]
      # Destination (column 1) and mask (column 7) both zero.
      if route[1] == "00000000" and route[7] == "00000000":
        gw, iface = route[2], route[0]
        # The gateway bytes are reversed before being formatted as text.
        gw = inet_ntop(AF_INET, binascii.unhexlify(gw)[::-1])
        return gw, iface
    raise ValueError("No IPv4 default route found")
  else:
    raise ValueError("Don't know about IPv%s" % version)


def GetDefaultRouteInterface():
  """Returns the interface name of the IPv6 default route."""
  unused_gw, iface = GetDefaultRoute()
  return iface


def MakeFlowLabelOption(addr, label):
  """Builds a packed struct in6_flowlabel_req for IPV6_FLOWLABEL_MGR.

  The request uses action IPV6_FL_A_GET, sharing mode IPV6_FL_S_ANY and flag
  IPV6_FL_F_CREATE, with zero expiry and linger.

  Args:
    addr: destination IPv6 address string.
    label: flow label; only the low 20 bits are used.

  Returns:
    The 32-byte packed request.
  """
  # struct in6_flowlabel_req {
  #         struct in6_addr flr_dst;
  #         __be32  flr_label;
  #         __u8    flr_action;
  #         __u8    flr_share;
  #         __u16   flr_flags;
  #         __u16   flr_expires;
  #         __u16   flr_linger;
  #         __u32   __flr_pad;
  #         /* Options in format of IPV6_PKTOPTIONS */
  # };
  fmt = "16sIBBHHH4s"
  assert struct.calcsize(fmt) == 32
  addr = inet_pton(AF_INET6, addr)
  assert len(addr) == 16
  # flr_label is big-endian on the wire; "I" packs in native order.
  label = htonl(label & 0xfffff)
  action = IPV6_FL_A_GET
  share = IPV6_FL_S_ANY
  flags = IPV6_FL_F_CREATE
  pad = b"\x00" * 4
  return struct.pack(fmt, addr, label, action, share, flags, 0, 0, pad)


def SetFlowLabel(s, addr, label):
  """Requests flow label `label` towards addr on socket s."""
  opt = MakeFlowLabelOption(addr, label)
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, opt)
  # Caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).


def GetIptablesBinaryPath(version):
  """Returns the path of the first executable iptables/ip6tables binary.

  Args:
    version: 4 for iptables, 6 for ip6tables.

  Returns:
    The first candidate path for which os.access(path, os.X_OK) is true.

  Raises:
    ValueError: version is neither 4 nor 6.
    FileNotFoundError: none of the candidate paths is executable.
  """
  if version == 4:
    paths = (
        "/sbin/iptables-legacy",
        "/sbin/iptables",
        "/system/bin/iptables-legacy",
        "/system/bin/iptables",
    )
  elif version == 6:
    paths = (
        "/sbin/ip6tables-legacy",
        "/sbin/ip6tables",
        "/system/bin/ip6tables-legacy",
        "/system/bin/ip6tables",
    )
  else:
    # Previously fell through and failed on an unbound `paths` variable.
    raise ValueError("Don't know about IPv%s" % version)
  for iptables_path in paths:
    if os.access(iptables_path, os.X_OK):
      return iptables_path
  raise FileNotFoundError(
      "iptables binary for IPv{} not found".format(version) +
      ", checked: {}".format(", ".join(paths)))


def RunIptablesCommand(version, args):
  """Runs iptables/ip6tables with "-w" plus the given arguments.

  Args:
    version: 4 or 6, selects the binary.
    args: a single string of arguments; it is split on single spaces.

  Returns:
    The return value of os.spawnvp(os.P_WAIT, ...).
  """
  iptables_path = GetIptablesBinaryPath(version)
  return os.spawnvp(
      os.P_WAIT, iptables_path,
      [iptables_path, "-w"] + args.split(" "))


# Determine network configuration.
try:
  GetDefaultRoute(version=4)
  HAVE_IPV4 = True
except ValueError:
  HAVE_IPV4 = False

try:
  GetDefaultRoute(version=6)
  HAVE_IPV6 = True
except ValueError:
  HAVE_IPV6 = False


class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and GID.

  A uid or gid of 0 leaves that ID unchanged. When uid is non-zero, AID_INET
  is added to the supplementary groups for the duration of the block.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    # Change the GID first, before the UID change below.
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    if self.uid:
      self.saved_uids = os.getresuid()
      self.saved_groups = os.getgroups()
      os.setgroups(self.saved_groups + [AID_INET])
      # Keep the original real UID as the saved UID so __exit__ can restore.
      os.setresuid(self.uid, self.uid, self.saved_uids[0])

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Undo in reverse order: UIDs and groups first, then the GID.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)


class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID."""

  def __init__(self, uid):
    RunAsUidGid.__init__(self, uid, 0)


class NetworkTest(unittest.TestCase):
  """Base TestCase with errno assertions and /proc/net parsing helpers."""

  @contextlib.contextmanager
  def _errnoCheck(self, err_num):
    """Context manager asserting the body raises EnvironmentError err_num."""
    with self.assertRaises(EnvironmentError) as context:
      yield context
    self.assertEqual(context.exception.errno, err_num)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Test that the system returns an errno error.

    This works similarly to unittest.TestCase.assertRaises. You can call it as
    an assertion, or use it as a context manager.
    e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f
    """
    if f is None:
      return self._errnoCheck(err_num)
    else:
      with self._errnoCheck(err_num):
        f(*args)

  def ReadProcNetSocket(self, protocol):
    """Parses /proc/net/<protocol> into a list of socket entries.

    Args:
      protocol: file name under /proc/net; must start with "tcp", "icmp",
          "udp" or "raw". Names ending in "6" use 32-digit addresses.

    Returns:
      A list of [src, dst, state, mem, uid, refcnt, extra] lists of strings,
      one per socket line.

    Raises:
      ValueError: the protocol is not supported, or a line does not match.
    """
    # Read file.
    filename = "/proc/net/%s" % protocol
    with open(filename) as f:
      lines = f.readlines()

    # Possibly check, and strip, header.
    if protocol in ["icmp6", "raw6", "udp6"]:
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    lines = lines[1:]

    # Check contents.
    if protocol.endswith("6"):
      addrlen = 32
    else:
      addrlen = 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      end_regexp = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    regexp = re.compile(r" *(\d+): "                    # bucket
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # srcaddr, port
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # dstaddr, port
                        "([0-9A-F][0-9A-F]) "           # state
                        "([0-9A-F]{8}:[0-9A-F]{8}) "    # mem
                        "([0-9A-F]{2}:[0-9A-F]{8}) "    # ?
                        "([0-9A-F]{8}) +"               # ?
                        "([0-9]+) +"                    # uid
                        "([0-9]+) +"                    # timeout
                        "([0-9]+) +"                    # inode
                        "([0-9]+) +"                    # refcnt
                        "([0-9a-f]+)"                   # sp
                        "%s"                            # icmp has spaces
                        % (addrlen, addrlen, end_regexp))
    # Return a list of lists with only source / dest addresses for now.
    # TODO: consider returning a dict or namedtuple instead.
    out = []
    for line in lines:
      m = regexp.match(line)
      if m is None:
        raise ValueError("Failed match on [%s]" % line)
      (_, src, dst, state, mem,
       _, _, uid, _, _, refcnt, _, extra) = m.groups()
      out.append([src, dst, state, mem, uid, refcnt, extra])
    return out

  @staticmethod
  def GetConsoleLogLevel():
    """Returns the first value of /proc/sys/kernel/printk as an int."""
    with open("/proc/sys/kernel/printk") as printk:
      return int(printk.readline().split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Writes level to /proc/sys/kernel/printk; returns write()'s result."""
    with open("/proc/sys/kernel/printk", "w") as printk:
      return printk.write("%s\n" % level)


if __name__ == "__main__":
  unittest.main()