1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Base module for multinetwork tests."""
18
19import errno
20import fcntl
21import os
22import posix
23import random
24import re
25from socket import *  # pylint: disable=wildcard-import
26import struct
27import time
28
29from scapy import all as scapy
30
31import csocket
32import iproute
33import net_test
34
35
# Flags and ioctl request used by CreateTunInterface to configure the
# tun/tap device opened from /dev/net/tun.
IFF_TUN = 1
IFF_TAP = 2
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

# SOL_SOCKET option number used by BindToDevice.
SO_BINDTODEVICE = 25

# Setsockopt values.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_HOPLIMIT = 52  # Different from IPV6_UNICAST_HOPS, this is cmsg only.


# Sysctl paths. Their existence is used below as a feature probe; the
# AUTOCONF_TABLE and MARK_REFLECT sysctls are also written during setup.
ACCEPT_RA_MIN_LFT_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_min_lft"
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
RA_HONOR_PIO_LIFE_SYSCTL = "/proc/sys/net/ipv6/conf/default/ra_honor_pio_life"

# Feature flags, computed once at import time. Each flag is true if the
# corresponding sysctl exists or the kernel version check says so.
# NOTE(review): net_test.NonGXI / KernelAtLeast are defined in net_test;
# presumably NonGXI(5, 10) matches non-GKI kernels of that version and the
# KernelAtLeast list gives the first fixed version per LTS branch — confirm.
HAVE_ACCEPT_RA_MIN_LFT = (os.path.isfile(ACCEPT_RA_MIN_LFT_SYSCTL) or
                          net_test.NonGXI(5, 10) or
                          net_test.KernelAtLeast([(5, 10, 199), (5, 15, 136),
                                                  (6, 1, 57), (6, 6, 0)]))
# When false, SendRA advertises a zero router lifetime and default routes are
# configured manually (see SendRA and setUpClass).
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
HAVE_RA_HONOR_PIO_LIFE = (os.path.isfile(RA_HONOR_PIO_LIFE_SYSCTL) or
                          net_test.KernelAtLeast([(6, 7, 0)]))
68
69
class ConfigurationError(AssertionError):
  """Raised when a command that sets up the test environment fails."""


class UnexpectedPacketError(AssertionError):
  """Raised when no received packet matches the packet a test expected."""
76
77
def MakePktInfo(version, addr, ifindex):
  """Builds a packed pktinfo structure for use as an IP(V6)_PKTINFO cmsg.

  Args:
    version: The IP version, 4 or 6.
    addr: The address as a string, or a false value (None, "") to use the
      unspecified address of that family ("0.0.0.0" or "::").
    ifindex: The interface index to put in the structure.

  Returns:
    The bytes produced by csocket.In6Pktinfo(...).Pack() for IPv6, or by
    csocket.InPktinfo(...).Pack() for IPv4.

  Raises:
    KeyError: if version is not 4 or 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # addr is always a non-empty string at this point, so convert it
  # unconditionally.
  packed_addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((packed_addr, ifindex)).Pack()
  # IPv4: (ifindex, addr, zeroed 4-byte address).
  # NOTE(review): assumes csocket.InPktinfo fields follow struct in_pktinfo
  # (ipi_ifindex, ipi_spec_dst, ipi_addr) — confirm against csocket.
  return csocket.InPktinfo((ifindex, packed_addr, b"\x00" * 4)).Pack()
88
89
class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses and IIDs
  # (formatted as one hex byte) and in IPv4 addresses (10.0.<netid>.x).
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes. This is a
  # class attribute that SetSysctl mutates in place, so unless a subclass
  # overrides it, the dict is shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  UID_RANGE_START = 2000
  UID_RANGE_END = 9999
  UID_RANGE_SIZE = UID_RANGE_END - UID_RANGE_START + 1

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 600  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range that is routed to netid.

    [UID_RANGE_START, UID_RANGE_END] is split into equal consecutive slices,
    one per entry of NETIDS, in NETIDS order. Any remainder UIDs at the top
    of the range are not assigned to a netid.

    Raises:
      ValueError: if netid is not in NETIDS.
    """
    per_netid_range = cls.UID_RANGE_SIZE // len(cls.NETIDS)
    idx = cls.NETIDS.index(netid)
    start = cls.UID_RANGE_START + per_netid_range * idx
    return (start, start + per_netid_range - 1)

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID routed to netid, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    If per-interface autoconf tables are in use (AUTOCONF_TABLE_OFFSET, set in
    setUpClass) and netid has an interface, the table is derived from the
    interface index; otherwise the netid itself is used as the table number.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      # With the negative offset configured in setUpClass this is
      # ifindex + |offset|. Presumably this matches the table the kernel uses
      # for RA routes given accept_ra_rt_table — confirm against the kernel.
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    else:
      return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the tap interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the simulated router's MAC address on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address assigned to our interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid (link-local for IPv6).

    Raises:
      ValueError: if version is not 4 or 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns our statically-configured IPv4 address on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns our non-link-local IPv6 address on netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid.

    Args:
      version: 4 or 5 for the IPv4 address, 6 for the IPv6 address.
      netid: The netid.

    Raises:
      KeyError: if version is not 4, 5 or 6.
    """
    # Only compute the address that was asked for. Building a dict of all
    # versions would query the interface's IPv6 address (via
    # net_test.GetLinkAddress) even when only the IPv4 address is wanted.
    if version in (4, 5):
      return cls._MyIPv4Address(netid)
    if version == 6:
      return cls._MyIPv6Address(netid)
    raise KeyError(version)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but returns an IPv4-mapped IPv6 address for version 5."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    """Returns our link-local IPv6 address on netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    """Returns the prefix length of the on-link subnet for version 4 or 6."""
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    """Returns the on-link subnet prefix (without length) for netid."""
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Returns a random address under prefix.

    Args:
      prefix: An address prefix ending in its separator, e.g. "172.22." or
        "2001:db8:aaaa::". A prefix containing "." is treated as IPv4, and
        two random octets are appended; otherwise two random 16-bit hex
        groups are appended.
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Maps IP version 4/6 to AF_INET/AF_INET6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates, configures and brings up the tap interface for netid.

    The interface gets MyMacAddress(netid), has DAD disabled and accept_ra
    set to 2 (both via SetSysctl, so they are restored at teardown).

    Returns:
      The non-blocking, unbuffered file object for the tap device.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b", buffering=0)
    except IOError:
      f = open("/dev/tun", "r+b", buffering=0)
    try:
      # Build the ifreq argument: 16-byte interface name plus flags,
      # zero-padded to 40 bytes.
      ifr = struct.pack("16sH", iface.encode(), IFF_TAP | IFF_NO_PI)
      ifr += b"\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      # Set accept_ra to 2, because that's what we use.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except BaseException:
      # Don't leak the tap file descriptor if configuration fails.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0,
             routerlft=RA_VALIDITY, piolft=RA_VALIDITY, m=0, o=0,
             options=()):
    """Injects a Router Advertisement into netid's tap interface.

    The RA comes from the router's MAC and link-local address, is sent to
    Ethernet address 33:33:00:00:00:01, and carries a source link-layer
    address option plus an L=1, A=1 prefix information option for
    OnlinkPrefix(6, netid).

    Args:
      netid: The netid whose interface receives the RA.
      retranstimer: Retrans timer in milliseconds. Defaults to routerlft
        seconds, so that ND retransmits don't interfere with tests.
      reachabletime: The RA's reachable time field.
      routerlft: Router lifetime in seconds. Forced to 0 when the kernel
        does not support per-interface autoconf tables.
      piolft: Valid and preferred lifetime of the prefix information option.
      m: The RA's M (managed) flag.
      o: The RA's O (other configuration) flag.
      options: Extra scapy layers appended to the RA, in order.
    """
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations. This must be computed before
      # routerlft is possibly zeroed below.
      retranstimer = routerlft * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    if not HAVE_AUTOCONF_TABLE:
      routerlft = 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlft,
                            M=m, O=o) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1,
                                      validlifetime=piolft,
                                      preferredlifetime=piolft))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), bytes(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds (is_add=True) or deletes the rules, routes, addresses and
    neighbours for netid, for both IPv4 and IPv6."""
    # These do not depend on the IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or deletes mangle INPUT rules marking packets from iface with netid.

    Raises:
      ConfigurationError: if an iptables/ip6tables command fails.
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      # Run iptables to set up incoming packet marking.
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or deletes inbound marking rules for every tap interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Points the default-network rule at netid, or removes it if netid is
    falsy."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Removes the default-network rule."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents of the sysctl file, trailing newline included."""
    with open(sysctl, "r") as sysctl_file:
      return sysctl_file.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to the sysctl file, saving the old value for restoration."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as sysctl_file:
      sysctl_file.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets /proc/sys/net/ipv6/conf/<iface>/<sysctl> on every tap interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every saved sysctl value, best-effort."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as sysctl_file:
          sysctl_file.write(value)
      except IOError:
        # Deliberately ignored: e.g. per-interface sysctls may no longer
        # exist once the tap devices have been closed.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the ICMP rate-limit sysctl path for IP version 4 or 6."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit for IP version 4 or 6."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    """Creates the tap interfaces and configures rules, routes and sysctls."""
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    """Undoes setUpClass: removes rules and routes, closes the tap devices,
    and restores sysctls and the console log level."""
    # pop() instead of del, so teardown does not fail with KeyError if the
    # variable has already been removed.
    os.environ.pop("ANDROID_NO_USE_FWMARK_CLIENT", None)

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls.iproute.close()
    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    """Discards any packets left over from previous tests."""
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on s to netid (0 if netid is None)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK value of s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Resets the SO_MARK value of s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Binds s to iface with SO_BINDTODEVICE; a falsy iface unbinds it."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface.encode())

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) on s."""
    # Pass a packed 4-byte network-order value. Otherwise, Python thinks it's
    # a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the primary remote test address for version 4, 5 or 6."""
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    """Like GetRemoteAddress, but IPv4-mapped for version 5."""
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    """Returns the secondary remote test address for version 4, 5 or 6."""
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Selects netid for socket s using the given routing mode.

    Args:
      s: The socket.
      netid: The netid to select, or a falsy value / None to clear.
      mode: One of "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: if mode is not recognized.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket and selects netid on it.

    Args:
      version: 4 or 6; 5 creates an IPv6 socket (used with IPv4-mapped
        addresses).
      constructor: Callable taking an address family and returning a socket.
      netid: The netid to select.
      routing_mode: None to skip network selection, or a mode accepted by
        SelectInterface.

    Returns:
      The new socket.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))
    # SelectInterface handles every mode, including "uid".
    if routing_mode is not None:
      self.SelectInterface(s, netid, routing_mode)
    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen

    Raises:
      IndexError: if every netid is excluded.
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload to (dstaddr, dstport) on s with MSG_CONFIRM.

    If netid is not None, a pktinfo cmsg selecting netid's interface is sent
    in addition to cmsgs. The caller's cmsgs sequence is never modified.

    Args:
      version: 4 or 6; selects the pktinfo cmsg type.
      s: The socket to send on.
      dstaddr: Destination address.
      dstport: Destination port.
      netid: The netid to send on, or None to let routing decide.
      payload: The data to send.
      cmsgs: A sequence of (level, type, data) ancillary data tuples.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      # Build a new list rather than appending in place, so that a list the
      # caller reuses does not accumulate pktinfo entries across calls.
      cmsgs = list(cmsgs) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Writes an Ethernet frame to netid's tap device."""
    posix.write(self.tuns[netid].fileno(), bytes(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Wraps ip_packet in an Ethernet header from the router to us and
    writes it to netid's tap device."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Args:
      netid: The netid from which to read packets.
      include_multicast: A boolean. If True, multicast frames are returned
        too; by default (False) they are discarded.

    Returns:
      A list of the Ethernet payloads of the frames read. If nothing has been
      collected when the queue first runs dry, waits 10ms and retries once.

    Raises:
      OSError: for any read error other than EAGAIN.
    """
    packets = []
    retries = 0
    max_retries = 1
    while True:
      try:
        packet = posix.read(self.tuns[netid].fileno(), 4096)
        if not packet:
          break
        ether = scapy.Ether(packet)
        # Multicast frames are frames where the first byte of the destination
        # MAC address has 1 in the least-significant bit.
        if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
          packets.append(ether.payload)
      except OSError as e:
        # Anything other than EAGAIN is unexpected.
        if e.errno != errno.EAGAIN:
          raise
        # EAGAIN means there are no more packets waiting. If we didn't see
        # any packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    """Drains all tap devices until a full pass reads nothing."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches the sketch in expected.

    Raises:
      AssertionError: with a repr() diff if the packets differ.
    """
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5  # Clear DF (0x2); keep the other two flag bits.
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff and expectedudp.chksum == 0:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        # It is possible that scapy has been upgraded and this no longer triggers.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(bytes(expected))
    actual_real = actual.__class__(bytes(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns whether assertPacketMatches(expected, actual) passes."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Fails with msg if any non-multicast packet is queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    firstpacket = repr(packets[0]) if packets else ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Returns the first queued packet on netid that matches expected.

    Raises:
      AssertionError: if no packets are queued.
      UnexpectedPacketError: if packets are queued but none match.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    if ipv6 and ipv6.dst.startswith("ff"):
      include_multicast = True
    else:
      include_multicast = False

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e))) from e

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples: one for each
      pair of (destination interface ip_if, incoming interface iif). myaddr
      is our address on ip_if; remoteaddr is a random remote address chosen
      once per ip_if.
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds the assertion message for a receive/expect-reply check."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and checks for reply (or for no packets).

    Returns:
      The matching reply packet, or None if reply is falsy.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None
772
773
class InboundMarkingTest(MultiNetworkBaseTest):
  """Class that automatically sets up inbound marking.

  On top of the base environment, calls SetInboundMarks(True) after setup and
  SetInboundMarks(False) before teardown.
  """

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    try:
      cls.SetInboundMarks(True)
    except BaseException:
      # unittest does not run tearDownClass when setUpClass raises, so undo
      # the base class setup here instead of leaking it. Marking rules added
      # before the failure are not removed.
      super(InboundMarkingTest, cls).tearDownClass()
      raise

  @classmethod
  def tearDownClass(cls):
    try:
      cls.SetInboundMarks(False)
    finally:
      # Always tear down the base environment, even if removing the marking
      # rules failed.
      super(InboundMarkingTest, cls).tearDownClass()
786