1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import contextlib
18import fcntl
19import os
20import random
21import re
22from socket import *  # pylint: disable=wildcard-import
23import struct
24import sys
25import unittest
26
27from scapy import all as scapy
28
29import binascii
30import csocket
31import gki
32
# TODO: Move these to csocket.py.
# Linux socket option levels and option names that Python's socket module does
# not (portably) export.
SOL_IPV6 = 41
IP_RECVERR = 11
IPV6_RECVERR = 25
IP_TRANSPARENT = 19
IPV6_TRANSPARENT = 75
IPV6_TCLASS = 67
IPV6_FLOWLABEL_MGR = 32
IPV6_FLOWINFO_SEND = 33

SO_BINDTODEVICE = 25
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
SO_COOKIE = 57

# Ethernet protocol numbers.
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd

IPPROTO_GRE = 47

SIOCSIFHWADDR = 0x8924

# IPV6_FLOWLABEL_MGR actions (struct in6_flowlabel_req.flr_action), from
# <linux/in6.h>. Every action has a distinct value: RENEW is 2, because 1 is
# PUT.
IPV6_FL_A_GET = 0
IPV6_FL_A_PUT = 1
IPV6_FL_A_RENEW = 2

# IPV6_FLOWLABEL_MGR flags (flr_flags).
IPV6_FL_F_CREATE = 1
IPV6_FL_F_EXCL = 2

# IPV6_FLOWLABEL_MGR sharing modes (flr_share).
IPV6_FL_S_NONE = 0
IPV6_FL_S_EXCL = 1
IPV6_FL_S_ANY = 255

# Size of an interface name buffer, including the trailing NUL.
IFNAMSIZ = 16
68
# ICMP (type 8) and ICMPv6 (type 128) echo request headers: code 0, zero
# checksum, identifier 0x0ace, sequence number 3.
IPV4_PING = b"\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = b"\x80\x00\x00\x00\x0a\xce\x00\x03"

# Well-known public remote addresses (Google Public DNS).
IPV4_ADDR = "8.8.8.8"
IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
IPV6_ADDR2 = "2001:4860:4860::8844"

# Expected first line of /proc/net/{icmp6,raw6,udp6}.
# NetworkTest.ReadProcNetSocket compares it verbatim, so every space matters.
IPV6_SEQ_DGRAM_HEADER = ("  sl  "
                         "local_address                         "
                         "remote_address                        "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         "   uid  timeout inode ref pointer drops\n")

# Length of a UDP header, in bytes.
UDP_HDR_LEN = 8

# Arbitrary packet payload: a serialized DNS AAAA query (rd=1) built with
# scapy. The DNS ID is chosen randomly once, when this module is imported.
UDP_PAYLOAD = bytes(scapy.DNS(rd=1,
                              id=random.randint(0, 65535),
                              qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                             qtype="AAAA")))

# Unix group to use if we want to open sockets as non-root.
# RunAsUidGid adds it to the supplementary groups when switching UID.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# The following ends up being (VERSION, PATCHLEVEL, SUBLEVEL) from top of kernel's Makefile
LINUX_VERSION = csocket.LinuxVersion()

# Lowest possible version tuple; presumably used as "no minimum kernel
# version" by callers -- TODO confirm.
LINUX_ANY_VERSION = (0, 0, 0)

# Linus always releases x.y.0-rcZ or x.y.0, any stable (incl. LTS) release will be x.y.1+
IS_STABLE = (LINUX_VERSION[2] > 0)

# True if /metadata/gsi/dsu/booted exists, i.e. a GSI is running.
# From //system/gsid/libgsi.cpp IsGsiRunning()
IS_GSI = os.access("/metadata/gsi/dsu/booted", os.F_OK)
107
# NonGXI() lets a test run only from a given kernel version onwards, so that
# correctly backported fixes are exercised, while skipping kernels that cannot
# be updated independently of the image under test (GSI and GKI).
#
# Running vts_net_test on a GSI image makes little sense: the unmodified
# vendor image, kernel included, cannot realistically be fixed in such a
# setup. GSI on an *older* Pixel vendor image is especially problematic: newer
# Pixel images ship the fixed kernel, but running a newer GSI against an
# ancient vendor image will not see those fixes.
#
# GKI kernels would normally be tested too, but older release branches are no
# longer maintained, so they are excluded as well. Proper GKI testing happens
# at the tip of the appropriate ACK/GKI branch.
def NonGXI(major, minor):
  """Returns whether the kernel is >= major.minor and is neither GKI nor GSI.

  Args:
    major: minimum kernel VERSION.
    minor: minimum kernel PATCHLEVEL.

  Returns:
    False on a GSI image or a GKI kernel. Otherwise, whether LINUX_VERSION is
    at least (major, minor, 0).
  """
  excluded = IS_GSI or gki.IS_GKI
  if not excluded:
    return LINUX_VERSION >= (major, minor, 0)
  return False
127
def KernelAtLeast(versions, kernel_version=None):
  """Checks the kernel version matches the specified versions.

  Args:
    versions: a list of versions expressed as tuples,
    e.g., [(5, 10, 108), (5, 15, 31)]. The kernel version matches if it's
    between each specified version and the next minor version with last digit
    set to 0. In this example, the kernel version must match either:
      >= 5.10.108 and < 5.15.0
      >= 5.15.31
    While this is less flexible than matching exact tuples, it allows the caller
    to pass in fewer arguments, because Android only supports certain minor
    versions (4.19, 5.4, 5.10, ...)
    kernel_version: (optional) the (VERSION, PATCHLEVEL, SUBLEVEL) tuple to
      check. Defaults to the running kernel's LINUX_VERSION.

  Returns:
    True if the kernel version matches, False otherwise

  Raises:
    ValueError: if two entries of versions share the same major.minor.
  """
  if kernel_version is None:
    kernel_version = LINUX_VERSION
  # Walk from the newest requested version down. Each entry matches up to (but
  # excluding) the .0 release of the next-higher requested minor version.
  maxversion = (1000, 255, 65535)
  for version in sorted(versions, reverse=True):
    if version[:2] == maxversion[:2]:
      # Format the message; passing the tuple as a second argument to
      # ValueError would leave the %s placeholders unfilled.
      raise ValueError("Duplicate minor version: %s %s" % (version, maxversion))
    if kernel_version >= version and kernel_version < maxversion:
      return True
    maxversion = (version[0], version[1], 0)
  return False
153
def ByteToHex(b):
  """Formats one byte (an int, or a 1-character str) as two lowercase hex digits."""
  value = b if not isinstance(b, str) else ord(b)
  return "%02x" % value
156
def GetWildcardAddress(version):
  """Returns the wildcard (any) address for an IP version.

  Version 5 denotes IPv4-mapped IPv6 (see GetAddressVersion). Such sockets are
  AF_INET6 (see GetAddressFamily), so their wildcard address is "::", the same
  one BindRandomPort uses.

  Raises:
    KeyError: if version is not 4, 5 or 6.
  """
  return {4: "0.0.0.0", 5: "::", 6: "::"}[version]
159
def GetIpHdrLength(version):
  """Returns the byte length of a base IPv4 (no options) or IPv6 header."""
  header_bytes = {4: 20, 6: 40}
  return header_bytes[version]
162
def GetAddressFamily(version):
  """Maps an IP version (4, 5 = IPv4-mapped IPv6, or 6) to a socket family."""
  families = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}
  return families[version]
165
166
def AddressLengthBits(version):
  """Returns the address length, in bits, for IP version 4 or 6."""
  bits_by_version = {4: 32, 6: 128}
  return bits_by_version[version]
169
def GetAddressVersion(address):
  """Classifies a textual IP address.

  Returns:
    4 for an address without ':', 5 for an address starting with "::ffff"
    (IPv4-mapped IPv6), and 6 for any other address containing ':'.
  """
  if ":" in address:
    return 5 if address.startswith("::ffff") else 6
  return 4
176
def SetSocketTos(s, tos):
  """Sets the IPv4 TOS byte or the IPv6 traffic class on socket s.

  Raises:
    KeyError: if s.family is neither AF_INET nor AF_INET6.
  """
  level, option = {
      AF_INET: (SOL_IP, IP_TOS),
      AF_INET6: (SOL_IPV6, IPV6_TCLASS),
  }[s.family]
  s.setsockopt(level, option, tos)
181
182
def SetNonBlocking(fd):
  """Adds O_NONBLOCK to fd's file status flags, keeping the other flags."""
  current_flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
  new_flags = current_flags | os.O_NONBLOCK
  fcntl.fcntl(fd, fcntl.F_SETFL, new_flags)
186
187
# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
  """Creates a socket and applies a 5000-unit timeout via csocket.

  The unit is whatever csocket.SetSocketTimeout expects (presumably
  milliseconds -- TODO confirm).
  """
  new_sock = socket(family, sock_type, protocol)
  csocket.SetSocketTimeout(new_sock, 5000)
  return new_sock
193
194
def PingSocket(family):
  """Creates an ICMP (AF_INET) or ICMPv6 (AF_INET6) datagram socket.

  Raises:
    KeyError: if family is neither AF_INET nor AF_INET6.
  """
  icmp_protocols = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}
  return Socket(family, SOCK_DGRAM, icmp_protocols[family])


def IPv4PingSocket():
  """Returns a new IPv4 ICMP datagram socket."""
  return PingSocket(family=AF_INET)


def IPv6PingSocket():
  """Returns a new IPv6 ICMPv6 datagram socket."""
  return PingSocket(family=AF_INET6)
206
207
def TCPSocket(family):
  """Creates a non-blocking TCP socket for the given address family."""
  tcp_sock = Socket(family, SOCK_STREAM, IPPROTO_TCP)
  SetNonBlocking(tcp_sock.fileno())
  return tcp_sock


def IPv4TCPSocket():
  """Returns a new non-blocking IPv4 TCP socket."""
  return TCPSocket(family=AF_INET)


def IPv6TCPSocket():
  """Returns a new non-blocking IPv6 TCP socket."""
  return TCPSocket(family=AF_INET6)
220
221
def UDPSocket(family):
  """Creates a UDP socket for the given address family."""
  return Socket(family, SOCK_DGRAM, IPPROTO_UDP)


def RawGRESocket(family):
  """Creates a raw socket for IP protocol 47 (GRE)."""
  return Socket(family, SOCK_RAW, IPPROTO_GRE)
229
230
def BindRandomPort(version, sock):
  """Binds sock to a kernel-chosen port on the wildcard address.

  SO_REUSEADDR is set before binding. TCP sockets are also put into the
  listening state with a backlog of 100.

  Args:
    version: 4, 5 (IPv4-mapped IPv6) or 6; selects the wildcard address.
    sock: an unbound socket of the matching family.

  Returns:
    The port number sock is bound to.
  """
  wildcard = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
  sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  sock.bind((wildcard, 0))
  protocol = sock.getsockopt(SOL_SOCKET, SO_PROTOCOL)
  if protocol == IPPROTO_TCP:
    sock.listen(100)
  return sock.getsockname()[1]
239
240
def EnableFinWait(sock):
  """Turns SO_LINGER off so that close() goes through FIN_WAIT."""
  linger_off = struct.pack("ii", 0, 0)  # l_onoff = 0, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_off)


def DisableFinWait(sock):
  """Turns SO_LINGER on with a zero timeout so that close() sends a RST."""
  linger_zero = struct.pack("ii", 1, 0)  # l_onoff = 1, l_linger = 0
  sock.setsockopt(SOL_SOCKET, SO_LINGER, linger_zero)
249
250
def CreateSocketPair(family, socktype, addr):
  """Creates a pair of sockets connected to each other over addr.

  For SOCK_STREAM, a temporary listening socket accepts the client's
  connection and is then closed. Both ends get DisableFinWait() so that
  close() sends a RST. For other socket types, the two sockets are simply
  connect()ed to each other.

  Args:
    family: address family, e.g. AF_INET or AF_INET6.
    socktype: SOCK_STREAM, or a connectionless type such as SOCK_DGRAM.
    addr: local address to bind the listening/peer socket to.

  Returns:
    (clientsock, acceptedsock). If any step fails, every socket opened so far
    is closed before the exception propagates.
  """
  opened = []
  try:
    clientsock = socket(family, socktype, 0)
    opened.append(clientsock)
    listensock = socket(family, socktype, 0)
    opened.append(listensock)
    listensock.bind((addr, 0))
    if socktype == SOCK_STREAM:
      listensock.listen(1)
    clientsock.connect(listensock.getsockname())
    if socktype == SOCK_STREAM:
      acceptedsock, _ = listensock.accept()
      opened.append(acceptedsock)
      DisableFinWait(clientsock)
      DisableFinWait(acceptedsock)
      listensock.close()
    else:
      listensock.connect(clientsock.getsockname())
      acceptedsock = listensock
  except BaseException:
    # Don't leak file descriptors when binding/connecting fails part way.
    for sock in opened:
      sock.close()
    raise
  return clientsock, acceptedsock
268
269
def GetInterfaceIndex(ifname):
  """Returns the interface index of ifname, looked up with SIOCGIFINDEX."""
  with UDPSocket(AF_INET) as ioctl_sock:
    ifreq_fmt = "%dsi" % IFNAMSIZ
    request = struct.pack(ifreq_fmt, ifname.encode(), 0)
    reply = fcntl.ioctl(ioctl_sock, scapy.SIOCGIFINDEX, request)
    _, ifindex = struct.unpack(ifreq_fmt, reply)
    return ifindex
275
276
def SetInterfaceHWAddr(ifname, hwaddr):
  """Sets the Ethernet hardware address of ifname with SIOCSIFHWADDR.

  Args:
    ifname: interface name.
    hwaddr: hex digits, optionally colon-separated (e.g. "02:00:00:00:00:01").

  Raises:
    ValueError: if hwaddr does not decode to exactly 6 bytes (including the
      binascii.Error subclass for non-hex input).
  """
  with UDPSocket(AF_INET) as ioctl_sock:
    mac = binascii.unhexlify(hwaddr.replace(":", ""))
    if len(mac) != 6:
      raise ValueError("Unknown hardware address length %d" % len(mac))
    request = struct.pack("%dsH6s" % IFNAMSIZ, ifname.encode(),
                          scapy.ARPHDR_ETHER, mac)
    fcntl.ioctl(ioctl_sock, SIOCSIFHWADDR, request)
286
287
def SetInterfaceState(ifname, up):
  """Sets or clears IFF_UP on ifname by reading and rewriting its flags.

  Args:
    ifname: interface name.
    up: if true, set IFF_UP; otherwise clear it.
  """
  name = ifname.encode()
  with UDPSocket(AF_INET) as ioctl_sock:
    ifreq_fmt = "%dsH" % IFNAMSIZ
    reply = fcntl.ioctl(ioctl_sock, scapy.SIOCGIFFLAGS,
                        struct.pack(ifreq_fmt, name, 0))
    flags = struct.unpack(ifreq_fmt, reply)[1]
    if up:
      new_flags = flags | scapy.IFF_UP
    else:
      new_flags = flags & ~scapy.IFF_UP
    fcntl.ioctl(ioctl_sock, scapy.SIOCSIFFLAGS,
                struct.pack(ifreq_fmt, name, new_flags))
300
301
def SetInterfaceUp(ifname):
  """Sets IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=True)


def SetInterfaceDown(ifname):
  """Clears IFF_UP on ifname."""
  return SetInterfaceState(ifname, up=False)
308
309
def CanonicalizeIPv6Address(addr):
  """Returns addr re-rendered by inet_ntop (lowercase, zero runs compressed)."""
  packed = inet_pton(AF_INET6, addr)
  return inet_ntop(AF_INET6, packed)
312
313
def FormatProcAddress(unformatted):
  """Converts a raw hex IPv6 address (as printed in /proc/net) to text form.

  The hex string is split into 4-digit groups joined by ':' and then
  compressed with CanonicalizeIPv6Address.
  """
  groups = [unformatted[pos:pos + 4] for pos in range(0, len(unformatted), 4)]
  return CanonicalizeIPv6Address(":".join(groups))
322
323
def FormatSockStatAddress(address):
  """Formats an IP address the way /proc/net socket tables print it.

  Each 32-bit word of the packed address is read in native byte order and
  printed as 8 uppercase hex digits.
  """
  family = AF_INET6 if ":" in address else AF_INET
  packed = inet_pton(family, address)
  words = (packed[pos:pos + 4] for pos in range(0, len(packed), 4))
  return "".join("%08X" % struct.unpack("=L", word) for word in words)
334
335
def GetLinkAddress(ifname, linklocal):
  """Returns an IPv6 address of ifname, read from /proc/net/if_inet6.

  Args:
    ifname: interface name, matched against the sixth column of the file.
    linklocal: if true, return the first address starting with "fe80";
      otherwise return the first address that does not.

  Returns:
    The address in compressed text form, or None if nothing matches.
  """
  with open("/proc/net/if_inet6") as if_inet6:
    entries = if_inet6.readlines()
  want_linklocal = bool(linklocal)
  for entry in entries:
    fields = [field for field in entry.strip().split(" ") if field]
    if fields[5] != ifname:
      continue
    if fields[0].startswith("fe80") == want_linklocal:
      # The file stores addresses as raw hex; add colons and compress.
      return FormatProcAddress(fields[0])
  return None
347
348
349def GetDefaultRoute(version=6):
350  if version == 6:
351    with open("/proc/net/ipv6_route") as ipv6_route:
352      routes = ipv6_route.readlines()
353    for route in routes:
354      route = [s for s in route.strip().split(" ") if s]
355      if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
356          # Routes in non-default tables end up in /proc/net/ipv6_route!!!
357          and route[9] != "lo" and not route[9].startswith("nettest")):
358        return FormatProcAddress(route[4]), route[9]
359    raise ValueError("No IPv6 default route found")
360  elif version == 4:
361    with open("/proc/net/route") as ipv4_route:
362      routes = ipv4_route.readlines()
363    for route in routes:
364      route = [s for s in route.strip().split("\t") if s]
365      if route[1] == "00000000" and route[7] == "00000000":
366        gw, iface = route[2], route[0]
367        gw = inet_ntop(AF_INET, binascii.unhexlify(gw)[::-1])
368        return gw, iface
369    raise ValueError("No IPv4 default route found")
370  else:
371    raise ValueError("Don't know about IPv%s" % version)
372
373
def GetDefaultRouteInterface():
  """Returns the interface name of the IPv6 default route (see GetDefaultRoute)."""
  return GetDefaultRoute()[1]
377
378
def MakeFlowLabelOption(addr, label):
  """Builds a struct in6_flowlabel_req that gets or creates a flow label.

  The kernel structure being packed is:

    struct in6_flowlabel_req {
            struct in6_addr flr_dst;
            __be32  flr_label;
            __u8    flr_action;
            __u8    flr_share;
            __u16   flr_flags;
            __u16   flr_expires;
            __u16   flr_linger;
            __u32   __flr_pad;
            /* Options in format of IPV6_PKTOPTIONS */
    };

  Args:
    addr: IPv6 destination address, as text.
    label: flow label; only the low 20 bits are used.

  Returns:
    The 32-byte option value for IPV6_FLOWLABEL_MGR.
  """
  fmt = "16sIBBHHH4s"
  assert struct.calcsize(fmt) == 32
  dst = inet_pton(AF_INET6, addr)
  assert len(dst) == 16
  fields = (
      dst,                       # flr_dst
      htonl(label & 0xfffff),    # flr_label, in network byte order
      IPV6_FL_A_GET,             # flr_action
      IPV6_FL_S_ANY,             # flr_share
      IPV6_FL_F_CREATE,          # flr_flags
      0,                         # flr_expires
      0,                         # flr_linger
      b"\x00" * 4,               # __flr_pad
  )
  return struct.pack(fmt, *fields)
401
402
def SetFlowLabel(s, addr, label):
  """Requests flow label `label` towards addr on socket s via IPV6_FLOWLABEL_MGR.

  To actually send the label, the caller must also enable it with
  s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
  """
  flowlabel_req = MakeFlowLabelOption(addr, label)
  s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, flowlabel_req)
407
408
def GetIptablesBinaryPath(version):
  """Returns the path of an executable iptables (v4) or ip6tables (v6) binary.

  Candidates are tried in order: the -legacy variant before the default one,
  and /sbin before /system/bin.

  Raises:
    ValueError: if version is not 4 or 6.
    FileNotFoundError: if none of the candidate paths is executable.
  """
  candidates = {
      4: (
          "/sbin/iptables-legacy",
          "/sbin/iptables",
          "/system/bin/iptables-legacy",
          "/system/bin/iptables",
      ),
      6: (
          "/sbin/ip6tables-legacy",
          "/sbin/ip6tables",
          "/system/bin/ip6tables-legacy",
          "/system/bin/ip6tables",
      ),
  }
  try:
    paths = candidates[version]
  except KeyError:
    # Previously an unknown version crashed with UnboundLocalError.
    raise ValueError("Don't know about IPv%s" % version) from None
  for iptables_path in paths:
    if os.access(iptables_path, os.X_OK):
      return iptables_path
  raise FileNotFoundError(
      "iptables binary for IPv{} not found".format(version) +
      ", checked: {}".format(", ".join(paths)))
430
431
def RunIptablesCommand(version, args):
  """Runs (ip6)iptables with "-w" and the space-separated args; waits for it.

  Returns:
    The exit status reported by os.spawnvp(os.P_WAIT, ...).
  """
  binary = GetIptablesBinaryPath(version)
  argv = [binary, "-w"]
  argv.extend(args.split(" "))
  return os.spawnvp(os.P_WAIT, binary, argv)
437
# Determine network configuration: whether IPv4/IPv6 default routes exist.
def _HasDefaultRoute(version):
  """Returns True unless GetDefaultRoute(version) raises ValueError."""
  try:
    GetDefaultRoute(version=version)
  except ValueError:
    return False
  return True


HAVE_IPV4 = _HasDefaultRoute(4)
HAVE_IPV6 = _HasDefaultRoute(6)
450
class RunAsUidGid(object):
  """Context guard to run a code block as a given GID and UID.

  A falsy uid or gid (e.g. 0) leaves that ID unchanged. When the UID is
  switched, AID_INET is added to the supplementary groups so the block can
  still open sockets as non-root. Everything is restored on exit.
  """

  def __init__(self, uid, gid):
    self.uid = uid
    self.gid = gid

  def __enter__(self):
    if self.gid:
      self.saved_gid = os.getgid()
      os.setgid(self.gid)
    if self.uid:
      groups_changed = False
      try:
        self.saved_uids = os.getresuid()
        self.saved_groups = os.getgroups()
        os.setgroups(self.saved_groups + [AID_INET])
        groups_changed = True
        # Keep the original real UID as the saved set-user-ID so that
        # __exit__ is allowed to switch back.
        os.setresuid(self.uid, self.uid, self.saved_uids[0])
      except OSError:
        # __exit__ does not run when __enter__ raises, so undo the partial
        # switch here instead of leaking it into the rest of the process.
        if groups_changed:
          os.setgroups(self.saved_groups)
        if self.gid:
          os.setgid(self.saved_gid)
        raise

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Restore the UID first: changing groups and GID needs the privileges.
    if self.uid:
      os.setresuid(*self.saved_uids)
      os.setgroups(self.saved_groups)
    if self.gid:
      os.setgid(self.saved_gid)

class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID, keeping the GID."""

  def __init__(self, uid):
    RunAsUidGid.__init__(self, uid, 0)
480
class NetworkTest(unittest.TestCase):
  """unittest.TestCase with errno assertions and /proc parsing helpers.

  Defines no test_* methods itself; presumably subclassed by the actual
  networking tests.
  """

  @contextlib.contextmanager
  def _errnoCheck(self, err_num):
    """Context manager: asserts the body raises EnvironmentError with err_num."""
    with self.assertRaises(EnvironmentError) as context:
      yield context
    # Only reached once assertRaises has captured the expected exception.
    self.assertEqual(context.exception.errno, err_num)

  def assertRaisesErrno(self, err_num, f=None, *args):
    """Asserts that an EnvironmentError (OSError) with errno err_num is raised.

    This works similarly to unittest.TestCase.assertRaises. You can call it as
    an assertion, or use it as a context manager.
    e.g.
        self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    or
        with self.assertRaisesErrno(errno.ENOENT):
          do_things(arg1, arg2)

    Args:
      err_num: an errno constant
      f: (optional) A callable that should result in error
      *args: arguments passed to f

    Returns:
      A context manager if f is None; otherwise None.
    """
    if f is None:
      return self._errnoCheck(err_num)
    else:
      with self._errnoCheck(err_num):
        f(*args)

  def ReadProcNetSocket(self, protocol):
    """Parses /proc/net/<protocol> into one list of string fields per socket.

    Args:
      protocol: file name under /proc/net, e.g. "tcp", "tcp6", "udp6", "raw6"
        or "icmp6". Names ending in "6" are parsed with 32-hex-digit
        addresses, all others with 8.

    Returns:
      A list of [src, dst, state, mem, uid, refcnt, extra] lists, all strings.
      src and dst are "ADDR:PORT" in hex; mem is "tx_queue:rx_queue". For TCP,
      extra holds the trailing five fields (empty for timewait sockets); for
      icmp/udp/raw it is the drops count.

    Raises:
      ValueError: if protocol is not tcp*/icmp*/udp*/raw*, or a line does not
        match the expected format.
      AssertionError: (via assertEqual) if an icmp6/raw6/udp6 header differs
        from IPV6_SEQ_DGRAM_HEADER.
    """
    # Read file.
    filename = "/proc/net/%s" % protocol
    with open(filename) as f:
      lines = f.readlines()

    # Possibly check, and strip, header.
    if protocol in ["icmp6", "raw6", "udp6"]:
      self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    lines = lines[1:]

    # Check contents.
    if protocol.endswith("6"):
      addrlen = 32
    else:
      addrlen = 8

    if protocol.startswith("tcp"):
      # Real sockets have 5 extra numbers, timewait sockets have none.
      end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
    elif re.match("icmp|udp|raw", protocol):
      # Drops.
      end_regexp = " +([0-9]+) *$"
    else:
      raise ValueError("Don't know how to parse %s" % filename)

    # The adjacent literals are concatenated at compile time, then "%" fills
    # in the address widths and the protocol-specific tail.
    regexp = re.compile(r" *(\d+): "                    # bucket (sl)
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # srcaddr, port
                        "([0-9A-F]{%d}:[0-9A-F]{4}) "   # dstaddr, port
                        "([0-9A-F][0-9A-F]) "           # state (st)
                        "([0-9A-F]{8}:[0-9A-F]{8}) "    # mem (tx_queue:rx_queue)
                        "([0-9A-F]{2}:[0-9A-F]{8}) "    # tr:tm->when
                        "([0-9A-F]{8}) +"               # retrnsmt
                        "([0-9]+) +"                    # uid
                        "([0-9]+) +"                    # timeout
                        "([0-9]+) +"                    # inode
                        "([0-9]+) +"                    # refcnt (ref)
                        "([0-9a-f]+)"                   # sp (pointer)
                        "%s"                            # icmp has spaces
                        % (addrlen, addrlen, end_regexp))
    # Return, per socket, [src, dst, state, mem, uid, refcnt, extra].
    # TODO: consider returning a dict or namedtuple instead.
    out = []
    for line in lines:
      m = regexp.match(line)
      if m is None:
        raise ValueError("Failed match on [%s]" % line)
      (_, src, dst, state, mem,
       _, _, uid, _, _, refcnt, _, extra) = m.groups()
      out.append([src, dst, state, mem, uid, refcnt, extra])
    return out

  @staticmethod
  def GetConsoleLogLevel():
    """Returns the first field of /proc/sys/kernel/printk as an int."""
    with open("/proc/sys/kernel/printk") as printk:
      return int(printk.readline().split()[0])

  @staticmethod
  def SetConsoleLogLevel(level):
    """Writes level to /proc/sys/kernel/printk; returns the characters written."""
    with open("/proc/sys/kernel/printk", "w") as printk:
      return printk.write("%s\n" % level)
572
573
if __name__ == "__main__":
  # Runs any TestCase classes found in this module. NetworkTest itself
  # defines no test_* methods.
  unittest.main()
576