1#!/usr/bin/python3
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19from scapy import all as scapy
20from socket import *  # pylint: disable=wildcard-import
21import binascii
22import struct
23import subprocess
24import threading
25import unittest
26
27import csocket
28import cstruct
29import multinetwork_base
30import net_test
31import packets
32import xfrm
33import xfrm_base
34
# Hex-encoded byte string. Not referenced in the visible part of this file;
# presumably a captured encrypted payload used by other tests - TODO confirm.
ENCRYPTED_PAYLOAD = ("b1c74998efd6326faebe2061f00f2c750e90e76001664a80c287b150"
                     "59e74bf949769cc6af71e51b539e7de3a2a14cb05a231b969e035174"
                     "d98c5aa0cef1937db98889ec0d08fa408fecf616")

# IPv6 addresses used by the tests below: TEST_ADDR1 is used as an SA
# destination and TEST_ADDR2 as the IPv6 tunnel endpoint.
TEST_ADDR1 = "2001:4860:4860::8888"
TEST_ADDR2 = "2001:4860:4860::8844"

# Path of a procfs file and the name of one counter, presumably XFRM
# statistics exported by the kernel. Neither is referenced in the visible part
# of this file - TODO confirm usage.
XFRM_STATS_PROCFILE = "/proc/net/xfrm_stat"
XFRM_STATS_OUT_NO_STATES = "XfrmOutNoStates"

# IP addresses to use for tunnel endpoints. For generality, these should be
# different from the addresses we send packets to.
TUNNEL_ENDPOINTS = {4: "8.8.4.4", 6: TEST_ADDR2}

# SPIs used for the SAs created by the tests. Two distinct values are needed by
# the tests that use different inbound/outbound SPIs or that perform a rekey.
TEST_SPI = 0x1234
TEST_SPI2 = 0x1235
51
52
53
54class XfrmFunctionalTest(xfrm_base.XfrmLazyTest):
55
56  def assertIsUdpEncapEsp(self, packet, spi, seq, length):
57    protocol = packet.nh if packet.version == 6 else packet.proto
58    self.assertEqual(IPPROTO_UDP, protocol)
59    udp_hdr = packet[scapy.UDP]
60    self.assertEqual(4500, udp_hdr.dport)
61    self.assertEqual(length, len(udp_hdr))
62    esp_hdr, _ = cstruct.Read(bytes(udp_hdr.payload), xfrm.EspHdr)
63    # FIXME: this file currently swaps SPI byte order manually, so SPI needs to
64    # be double-swapped here.
65    self.assertEqual(xfrm.EspHdr((spi, seq)), esp_hdr)
66
67  def CreateNewSa(self, localAddr, remoteAddr, spi, reqId, encap_tmpl,
68                  null_auth=False):
69    auth_algo = (
70        xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
71    self.xfrm.AddSaInfo(localAddr, remoteAddr, spi, xfrm.XFRM_MODE_TRANSPORT,
72                    reqId, xfrm_base._ALGO_CBC_AES_256, auth_algo, None,
73                    encap_tmpl, None, None)
74
75  def testAddSa(self):
76    self.CreateNewSa("::", TEST_ADDR1, TEST_SPI, 3320, None)
77    expected = (
78        "src :: dst 2001:4860:4860::8888\n"
79        "\tproto esp spi 0x00001234 reqid 3320 mode transport\n"
80        "\treplay-window 32 \n"
81        "\tauth-trunc hmac(sha1) 0x%s 96\n"
82        "\tenc cbc(aes) 0x%s\n"
83        "\tsel src ::/0 dst ::/0 \n" % (
84            binascii.hexlify(xfrm_base._AUTHENTICATION_KEY_128).decode("utf-8"),
85            binascii.hexlify(xfrm_base._ENCRYPTION_KEY_256).decode("utf-8")))
86
87    actual = subprocess.check_output("ip xfrm state".split()).decode("utf-8")
88    # Newer versions of IP also show anti-replay context. Don't choke if it's
89    # missing.
90    actual = actual.replace(
91        "\tanti-replay context: seq 0x0, oseq 0x0, bitmap 0x00000000\n", "")
92    try:
93      self.assertMultiLineEqual(expected, actual)
94    finally:
95      self.xfrm.DeleteSaInfo(TEST_ADDR1, TEST_SPI, IPPROTO_ESP)
96
97  def testFlush(self):
98    self.assertEqual(0, len(self.xfrm.DumpSaInfo()))
99    self.CreateNewSa("::", "2000::", TEST_SPI, 1234, None)
100    self.CreateNewSa("0.0.0.0", "192.0.2.1", TEST_SPI, 4321, None)
101    self.assertEqual(2, len(self.xfrm.DumpSaInfo()))
102    self.xfrm.FlushSaInfo()
103    self.assertEqual(0, len(self.xfrm.DumpSaInfo()))
104
  def _TestSocketPolicy(self, version):
    """Exercises setting, using and clearing an outbound per-socket policy.

    Sends UDP packets on a socket before a policy is applied, with a policy
    but no SA, with a matching SA, after the SA is deleted and after the policy
    is cleared, checking at each step whether the send fails or what packet
    goes out.

    Args:
      version: 4 or 6 for IPv4 or IPv6. 5 selects the variant in which the
        socket addresses carry a "::ffff:" prefix and the XFRM policy and SA
        are written in terms of IPv4.
    """
    # Open a UDP socket and connect it.
    family = net_test.GetAddressFamily(version)
    s = socket(family, SOCK_DGRAM, 0)
    netid = self.RandomNetid()
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    s.connect((remotesockaddr, 53))
    saddr, sport = s.getsockname()[:2]
    daddr, dport = s.getpeername()[:2]
    if version == 5:
      # Strip the mapped-address prefix so the expected packet is built from
      # plain IPv4 addresses.
      saddr = saddr.replace("::ffff:", "")
      daddr = daddr.replace("::ffff:", "")

    reqid = 0

    # Before any policy is applied, the packet goes out in the clear.
    desc, pkt = packets.UDP(version, saddr, daddr, sport=sport)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after socket, expected %s" % desc, pkt)

    # Using IPv4 XFRM on a dual-stack socket requires setting an AF_INET policy
    # that's written in terms of IPv4 addresses.
    xfrm_version = 4 if version == 5 else version
    xfrm_family = net_test.GetAddressFamily(xfrm_version)
    xfrm_base.ApplySocketPolicy(s, xfrm_family, xfrm.XFRM_POLICY_OUT,
                                TEST_SPI, reqid, None)

    # Because the policy has level set to "require" (the default), attempting
    # to send a packet results in an error, because there is no SA that
    # matches the socket policy we set.
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))

    # If there is a user space key manager, calling sendto() after applying the socket policy
    # creates an SA whose state is XFRM_STATE_ACQ. So this just deletes it.
    # If there is no user space key manager, deleting SA returns ESRCH as the error code.
    try:
        self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI, IPPROTO_ESP)
    except IOError as e:
        self.assertEqual(ESRCH, e.errno, "Unexpected error when deleting ACQ SA")

    # Adding a matching SA causes the packet to go out encrypted. The SA's
    # SPI must match the one in our template, and the destination address must
    # match the packet's destination address (in tunnel mode, it has to match
    # the tunnel destination).
    self.CreateNewSa(
        net_test.GetWildcardAddress(xfrm_version),
        self.GetRemoteAddress(xfrm_version), TEST_SPI, reqid, None)

    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    expected_length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TRANSPORT,
                                                version, False,
                                                net_test.UDP_PAYLOAD,
                                                xfrm_base._ALGO_HMAC_SHA1,
                                                xfrm_base._ALGO_CBC_AES_256)
    self._ExpectEspPacketOn(netid, TEST_SPI, 1, expected_length, None, None)

    # Sending to another destination doesn't work: again, no matching SA.
    remoteaddr2 = self.GetOtherRemoteSocketAddress(version)
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remoteaddr2, 53))

    # Sending on another socket without the policy applied results in an
    # unencrypted packet going out.
    s2 = socket(family, SOCK_DGRAM, 0)
    self.SelectInterface(s2, netid, "mark")
    s2.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    pkts = self.ReadAllPacketsOn(netid)
    self.assertEqual(1, len(pkts))
    packet = pkts[0]

    # Only the next-protocol field is checked: UDP rather than ESP.
    protocol = packet.nh if version == 6 else packet.proto
    self.assertEqual(IPPROTO_UDP, protocol)

    # Deleting the SA causes the first socket to return errors again.
    self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI,
                           IPPROTO_ESP)
    self.assertRaisesErrno(
        EAGAIN,
        s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))

    # Clear the socket policy and expect a cleartext packet.
    xfrm_base.SetPolicySockopt(s, family, None)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after clear, expected %s" % desc, pkt)

    # Clearing the policy twice is safe.
    xfrm_base.SetPolicySockopt(s, family, None)
    s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    self.ExpectPacketOn(netid, "Send after clear 2, expected %s" % desc, pkt)
    s.close()

    # Clearing if a policy was never set is safe.
    # NOTE(review): this socket is always AF_INET6 while `family` follows
    # `version`, so for version 4 the two differ - confirm this is intended.
    s = socket(AF_INET6, SOCK_DGRAM, 0)
    xfrm_base.SetPolicySockopt(s, family, None)

    s.close()
    s2.close()
206
207  def testSocketPolicyIPv4(self):
208    self._TestSocketPolicy(4)
209
210  def testSocketPolicyIPv6(self):
211    self._TestSocketPolicy(6)
212
213  def testSocketPolicyMapped(self):
214    self._TestSocketPolicy(5)
215
216  # Sets up sockets and marks to correct netid
217  def _SetupUdpEncapSockets(self, version):
218    netid = self.RandomNetid()
219    myaddr = self.MyAddress(version, netid)
220    remoteaddr = self.GetRemoteAddress(version)
221    family = net_test.GetAddressFamily(version)
222
223    # Reserve a port on which to receive UDP encapsulated packets. Sending
224    # packets works without this (and potentially can send packets with a source
225    # port belonging to another application), but receiving requires the port to
226    # be bound and the encapsulation socket option enabled.
227    encap_sock = net_test.Socket(family, SOCK_DGRAM, 0)
228    encap_sock.bind((myaddr, 0))
229    encap_port = encap_sock.getsockname()[1]
230    encap_sock.setsockopt(IPPROTO_UDP, xfrm.UDP_ENCAP, xfrm.UDP_ENCAP_ESPINUDP)
231
232    # Open a socket to send traffic.
233    # TODO: test with a different family than the encap socket.
234    s = socket(family, SOCK_DGRAM, 0)
235    self.SelectInterface(s, netid, "mark")
236    s.connect((remoteaddr, 53))
237
238    return netid, myaddr, remoteaddr, encap_sock, encap_port, s
239
240  # Sets up SAs and applies socket policy to given socket
241  def _SetupUdpEncapSaPair(self, version, myaddr, remoteaddr, in_spi, out_spi,
242                           encap_port, s, use_null_auth):
243    in_reqid = 123
244    out_reqid = 456
245
246    # Create inbound and outbound SAs that specify UDP encapsulation.
247    encaptmpl = xfrm.XfrmEncapTmpl((xfrm.UDP_ENCAP_ESPINUDP, htons(encap_port),
248                                    htons(4500), 16 * b"\x00"))
249    self.CreateNewSa(myaddr, remoteaddr, out_spi, out_reqid, encaptmpl,
250                     use_null_auth)
251
252    # Add an encap template that's the mirror of the outbound one.
253    encaptmpl.sport, encaptmpl.dport = encaptmpl.dport, encaptmpl.sport
254    self.CreateNewSa(remoteaddr, myaddr, in_spi, in_reqid, encaptmpl,
255                     use_null_auth)
256
257    # Apply socket policies to s.
258    family = net_test.GetAddressFamily(version)
259    xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT, out_spi,
260                                out_reqid, None)
261
262    # TODO: why does this work without a per-socket policy applied?
263    # The received  packet obviously matches an SA, but don't inbound packets
264    # need to match a policy as well? (b/71541609)
265    xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_IN, in_spi,
266                                in_reqid, None)
267
268    # Uncomment for debugging.
269    # subprocess.call("ip xfrm state".split())
270
271  # Check that packets can be sent and received.
  def _VerifyUdpEncapSocket(self, version, netid, remoteaddr, myaddr, encap_port,
                           sock, in_spi, out_spi, null_auth, seq_num):
    """Sends a packet on sock and replays the resulting ESP packet inbound.

    Args:
      version: IP version of the traffic.
      netid: The netid on which packets are read and injected.
      remoteaddr: Address the packet is sent to, and source of the replay.
      myaddr: Destination address of the replayed packet.
      encap_port: UDP destination port of the replayed packet.
      sock: The socket to send on.
      in_spi: SPI written into the replayed packet and used to look up the SA
        whose integrity failure counter is checked.
      out_spi: SPI expected in the packet sent on sock.
      null_auth: Whether null authentication is expected instead of HMAC-SHA1.
      seq_num: ESP sequence number expected outbound and written into the
        replayed packet.
    """
    # Now send a packet.
    sock.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    srcport = sock.getsockname()[1]

    # Expect to see an UDP encapsulated packet.
    pkts = self.ReadAllPacketsOn(netid)
    self.assertEqual(1, len(pkts))
    packet = pkts[0]

    auth_algo = (
        xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
    expected_len = xfrm_base.GetEspPacketLength(
        xfrm.XFRM_MODE_TRANSPORT, version, True, net_test.UDP_PAYLOAD,
        auth_algo, xfrm_base._ALGO_CBC_AES_256)
    self.assertIsUdpEncapEsp(packet, out_spi, seq_num, expected_len)

    # Now test the receive path. Because we don't know how to decrypt packets,
    # we just play back the encrypted packet that kernel sent earlier. We swap
    # the addresses in the IP header to make the packet look like it's bound for
    # us, but we can't do that for the port numbers because the UDP header is
    # part of the integrity protected payload, which we can only replay as is.
    # So the source and destination ports are swapped and the packet appears to
    # be sent from srcport to port 53. Open another socket on that port, and
    # apply the inbound policy to it.
    # NOTE(review): no policy is applied to twisted_socket anywhere in this
    # method; the comment above may be stale - confirm.
    family = net_test.GetAddressFamily(version)
    twisted_socket = socket(family, SOCK_DGRAM, 0)
    # Presumably makes recv() fail with EAGAIN instead of blocking when no
    # packet arrives, which the assertions below rely on - TODO confirm.
    csocket.SetSocketTimeout(twisted_socket, 100)
    twisted_socket.bind((net_test.GetWildcardAddress(version), 53))

    # Save the payload of the packet so we can replay it back to ourselves, and
    # replace the SPI with our inbound SPI.
    # The first 8 bytes skipped here are the outer UDP header; a new one is
    # built below.
    payload = bytes(packet.payload)[8:]
    spi_seq = xfrm.EspHdr((in_spi, seq_num)).Pack()
    payload = spi_seq + payload[len(spi_seq):]

    # Record the counter before injecting so the change can be checked.
    sainfo = self.xfrm.FindSaInfo(in_spi)
    start_integrity_failures = sainfo.stats.integrity_failed

    # Now play back the valid packet and check that we receive it.
    ip = {4: scapy.IP, 6: scapy.IPv6}[version]
    incoming = (ip(src=remoteaddr, dst=myaddr) /
                scapy.UDP(sport=4500, dport=encap_port) / payload)
    # Round-trip through bytes so that the packet is fully built before it is
    # injected.
    incoming = ip(bytes(incoming))
    self.ReceivePacketOn(netid, incoming)

    # Re-read the SA to pick up the counters after the packet was processed.
    sainfo = self.xfrm.FindSaInfo(in_spi)

    # TODO: break this out into a separate test
    # If our SPIs are different, and we aren't using null authentication,
    # we expect the packet to be dropped. We also expect that the integrity
    # failure counter to increase, as SPIs are part of the authenticated or
    # integrity-verified portion of the packet.
    if not null_auth and in_spi != out_spi:
      self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
      self.assertEqual(start_integrity_failures + 1,
                        sainfo.stats.integrity_failed)
    else:
      data, src = twisted_socket.recvfrom(4096)
      self.assertEqual(net_test.UDP_PAYLOAD, data)
      self.assertEqual((remoteaddr, srcport), src[:2])
      self.assertEqual(start_integrity_failures, sainfo.stats.integrity_failed)

    # Check that unencrypted packets on twisted_socket are not received.
    # NOTE(review): `unencrypted` is built but never injected (there is no
    # ReceivePacketOn call for it), so the assertion below only shows that the
    # socket's queue is empty. Confirm whether the packet was meant to be sent.
    unencrypted = (
        ip(src=remoteaddr, dst=myaddr) / scapy.UDP(
            sport=srcport, dport=53) / net_test.UDP_PAYLOAD)
    self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)

    twisted_socket.close()
343
344  def _RunEncapSocketPolicyTest(self, version, in_spi, out_spi, use_null_auth):
345    netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
346        self._SetupUdpEncapSockets(version)
347
348    self._SetupUdpEncapSaPair(version, myaddr, remoteaddr, in_spi, out_spi,
349                              encap_port, s, use_null_auth)
350
351    # Check that UDP encap sockets work with socket policy and given SAs
352    self._VerifyUdpEncapSocket(version, netid, remoteaddr, myaddr, encap_port,
353                               s, in_spi, out_spi, use_null_auth, 1)
354    encap_sock.close()
355    s.close()
356
357  # TODO: Add tests for ESP (non-encap) sockets.
358  def testUdpEncapSameSpisNullAuth(self):
359    # Use the same SPI both inbound and outbound because this lets us receive
360    # encrypted packets by simply replaying the packets the kernel sends
361    # without having to disable authentication
362    self._RunEncapSocketPolicyTest(4, TEST_SPI, TEST_SPI, True)
363
364  def testUdpEncapSameSpis(self):
365    self._RunEncapSocketPolicyTest(4, TEST_SPI, TEST_SPI, False)
366
367  def testUdpEncapDifferentSpisNullAuth(self):
368    self._RunEncapSocketPolicyTest(4, TEST_SPI, TEST_SPI2, True)
369
370  def testUdpEncapDifferentSpis(self):
371    self._RunEncapSocketPolicyTest(4, TEST_SPI, TEST_SPI2, False)
372
373  def testUdpEncapRekey(self):
374    # Select the two SPIs that will be used
375    start_spi = TEST_SPI
376    rekey_spi = TEST_SPI2
377
378    # Setup sockets
379    netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
380        self._SetupUdpEncapSockets(4)
381
382    # The SAs must use null authentication, since we change SPIs on the fly
383    # Without null authentication, this would result in an ESP authentication
384    # error since the SPI is part of the authenticated section. The packet
385    # would then be dropped
386    self._SetupUdpEncapSaPair(4, myaddr, remoteaddr, start_spi, start_spi,
387                              encap_port, s, True)
388
389    # Check that UDP encap sockets work with socket policy and given SAs
390    self._VerifyUdpEncapSocket(4, netid, remoteaddr, myaddr, encap_port, s,
391                               start_spi, start_spi, True, 1)
392
393    # Rekey this socket using the make-before-break paradigm. First we create
394    # new SAs, update the per-socket policies, and only then remove the old SAs
395    #
396    # This allows us to switch to the new SA without breaking the outbound path.
397    self._SetupUdpEncapSaPair(4, myaddr, remoteaddr, rekey_spi, rekey_spi,
398                              encap_port, s, True)
399
400    # Check that UDP encap socket works with updated socket policy, sending
401    # using new SA, but receiving on both old and new SAs
402    self._VerifyUdpEncapSocket(4, netid, remoteaddr, myaddr, encap_port, s,
403                               rekey_spi, rekey_spi, True, 1)
404    self._VerifyUdpEncapSocket(4, netid, remoteaddr, myaddr, encap_port, s,
405                               start_spi, rekey_spi, True, 2)
406
407    # Delete old SAs
408    self.xfrm.DeleteSaInfo(remoteaddr, start_spi, IPPROTO_ESP)
409    self.xfrm.DeleteSaInfo(myaddr, start_spi, IPPROTO_ESP)
410
411    # Check that UDP encap socket works with updated socket policy and new SAs
412    self._VerifyUdpEncapSocket(4, netid, remoteaddr, myaddr, encap_port, s,
413                               rekey_spi, rekey_spi, True, 3)
414    encap_sock.close()
415    s.close()
416
417  def _CheckUDPEncapRecv(self, version, mode):
418    netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
419        self._SetupUdpEncapSockets(version)
420
421    # Create inbound and outbound SAs that specify UDP encapsulation.
422    reqid = 123
423    encaptmpl = xfrm.XfrmEncapTmpl((xfrm.UDP_ENCAP_ESPINUDP, htons(encap_port),
424                                    htons(4500), 16 * b"\x00"))
425    self.xfrm.AddSaInfo(remoteaddr, myaddr, TEST_SPI, mode, reqid,
426                    xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None,
427                    encaptmpl, None, None)
428
429    sainfo = self.xfrm.FindSaInfo(TEST_SPI)
430    self.assertEqual(0, sainfo.curlft.packets)
431    self.assertEqual(0, sainfo.curlft.bytes)
432    self.assertEqual(0, sainfo.stats.integrity_failed)
433
434    IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
435    if mode == xfrm.XFRM_MODE_TRANSPORT:
436      # Due to a bug in the IPv6 UDP encap code, there must be at least 32
437      # bytes after the ESP header or the packet will be dropped.
438      # 8 (UDP header) + 18 (payload) + 2 (ESP trailer) = 28, dropped
439      # 8 (UDP header) + 19 (payload) + 4 (ESP trailer) = 32, received
440      # There is a similar bug in IPv4 encap, but the minimum is only 12 bytes,
441      # which is much less likely to occur. This doesn't affect tunnel mode
442      # because IP headers are always at least 20 bytes long.
443      data = 19 * b"a"
444      datalen = len(data)
445      # TODO: update scapy and use scapy.ESP instead of manually generating ESP header.
446      inner_pkt = xfrm.EspHdr(spi=TEST_SPI, seqnum=1).Pack() + bytes(
447          scapy.UDP(sport=443, dport=32123) / data) + bytes(
448          xfrm_base.GetEspTrailer(len(data), IPPROTO_UDP))
449      input_pkt = (IpType(src=remoteaddr, dst=myaddr) /
450                   scapy.UDP(sport=4500, dport=encap_port) /
451                   inner_pkt)
452    else:
453      # TODO: test IPv4 in IPv6 encap and vice versa.
454      data = b""  # Empty UDP payload
455      datalen = {4: 20, 6: 40}[version] + len(data)
456      # TODO: update scapy and use scapy.ESP instead of manually generating ESP header.
457      inner_pkt = xfrm.EspHdr(spi=TEST_SPI, seqnum=1).Pack() + bytes(
458          IpType(src=remoteaddr, dst=myaddr) /
459          scapy.UDP(sport=443, dport=32123) / data) + bytes(
460          xfrm_base.GetEspTrailer(len(data), {4: IPPROTO_IPIP, 6: IPPROTO_IPV6}[version]))
461      input_pkt = (IpType(src=remoteaddr, dst=myaddr) /
462                   scapy.UDP(sport=4500, dport=encap_port) /
463                   inner_pkt)
464
465    # input_pkt.show2()
466    self.ReceivePacketOn(netid, input_pkt)
467
468    sainfo = self.xfrm.FindSaInfo(TEST_SPI)
469    self.assertEqual(1, sainfo.curlft.packets)
470    self.assertEqual(datalen + 8, sainfo.curlft.bytes)
471    self.assertEqual(0, sainfo.stats.integrity_failed)
472
473    # Uncomment for debugging.
474    # subprocess.call("ip -s xfrm state".split())
475
476    encap_sock.close()
477    s.close()
478
479  def testIPv4UDPEncapRecvTransport(self):
480    self._CheckUDPEncapRecv(4, xfrm.XFRM_MODE_TRANSPORT)
481
482  def testIPv4UDPEncapRecvTunnel(self):
483    self._CheckUDPEncapRecv(4, xfrm.XFRM_MODE_TUNNEL)
484
485  # IPv6 UDP encap is broken between:
486  # 4db4075f92af ("esp6: fix check on ipv6_skip_exthdr's return value") and
487  # 5f9c55c8066b ("ipv6: check return value of ipv6_skip_exthdr")
488  @unittest.skipUnless(net_test.KernelAtLeast([(5, 10, 108), (5, 15, 31)]) or
489                       net_test.NonGXI(5, 10),
490                       reason="Unsupported or broken on current kernel")
491  def testIPv6UDPEncapRecvTransport(self):
492    self._CheckUDPEncapRecv(6, xfrm.XFRM_MODE_TRANSPORT)
493
494  @unittest.skipUnless(net_test.KernelAtLeast([(5, 10, 108), (5, 15, 31)]) or
495                       net_test.NonGXI(5, 10),
496                       reason="Unsupported or broken on current kernel")
497  def testIPv6UDPEncapRecvTunnel(self):
498    self._CheckUDPEncapRecv(6, xfrm.XFRM_MODE_TUNNEL)
499
500  def testAllocSpecificSpi(self):
501    spi = 0xABCD
502    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
503    self.assertEqual(spi, new_sa.id.spi)
504
505  def testAllocSpecificSpiUnavailable(self):
506    """Attempt to allocate the same SPI twice."""
507    spi = 0xABCD
508    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
509    self.assertEqual(spi, new_sa.id.spi)
510    with self.assertRaisesErrno(ENOENT):
511      new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
512
513  def testAllocRangeSpi(self):
514    start, end = 0xABCD0, 0xABCDF
515    new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
516    spi = new_sa.id.spi
517    self.assertGreaterEqual(spi, start)
518    self.assertLessEqual(spi, end)
519
520  def testAllocRangeSpiUnavailable(self):
521    """Attempt to allocate N+1 SPIs from a range of size N."""
522    start, end = 0xABCD0, 0xABCDF
523    range_size = end - start + 1
524    spis = set()
525    # Assert that allocating SPI fails when none are available.
526    with self.assertRaisesErrno(ENOENT):
527      # Allocating range_size + 1 SPIs is guaranteed to fail.  Due to the way
528      # kernel picks random SPIs, this has a high probability of failing before
529      # reaching that limit.
530      for i in range(range_size + 1):
531        new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
532        spi = new_sa.id.spi
533        self.assertNotIn(spi, spis)
534        spis.add(spi)
535
536  def testSocketPolicyDstCacheV6(self):
537    self._TestSocketPolicyDstCache(6)
538
539  def testSocketPolicyDstCacheV4(self):
540    self._TestSocketPolicyDstCache(4)
541
542  def _TestSocketPolicyDstCache(self, version):
543    """Test that destination cache is cleared with socket policy.
544
545    This relies on the fact that connect() on a UDP socket populates the
546    destination cache.
547    """
548
549    # Create UDP socket.
550    family = net_test.GetAddressFamily(version)
551    netid = self.RandomNetid()
552    s = socket(family, SOCK_DGRAM, 0)
553    self.SelectInterface(s, netid, "mark")
554
555    # Populate the socket's destination cache.
556    remote = self.GetRemoteAddress(version)
557    s.connect((remote, 53))
558
559    # Apply a policy to the socket. Should clear dst cache.
560    reqid = 123
561    xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT,
562                                TEST_SPI, reqid, None)
563
564    # Policy with no matching SA should result in EAGAIN. If destination cache
565    # failed to clear, then the UDP packet will be sent normally.
566    with self.assertRaisesErrno(EAGAIN):
567      s.send(net_test.UDP_PAYLOAD)
568    self.ExpectNoPacketsOn(netid, "Packet not blocked by policy")
569    s.close()
570
571  def _CheckNullEncryptionTunnelMode(self, version):
572    family = net_test.GetAddressFamily(version)
573    netid = self.RandomNetid()
574    local_addr = self.MyAddress(version, netid)
575    remote_addr = self.GetRemoteAddress(version)
576
577    # Borrow the address of another netId as the source address of the tunnel
578    tun_local = self.MyAddress(version, self.RandomNetid(netid))
579    # For generality, pick a tunnel endpoint that's not the address we
580    # connect the socket to.
581    tun_remote = TUNNEL_ENDPOINTS[version]
582
583    # Output
584    self.xfrm.AddSaInfo(
585        tun_local, tun_remote, 0xABCD, xfrm.XFRM_MODE_TUNNEL, 123,
586        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
587        None, None, None, netid)
588    # Input
589    self.xfrm.AddSaInfo(
590        tun_remote, tun_local, 0x9876, xfrm.XFRM_MODE_TUNNEL, 456,
591        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
592        None, None, None, None)
593
594    sock = net_test.UDPSocket(family)
595    self.SelectInterface(sock, netid, "mark")
596    sock.bind((local_addr, 0))
597    local_port = sock.getsockname()[1]
598    remote_port = 5555
599
600    xfrm_base.ApplySocketPolicy(
601        sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123,
602        (tun_local, tun_remote))
603    xfrm_base.ApplySocketPolicy(
604        sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456,
605        (tun_remote, tun_local))
606
607    # Create and receive an ESP packet.
608    IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
609    input_pkt = (IpType(src=remote_addr, dst=local_addr) /
610                 scapy.UDP(sport=remote_port, dport=local_port) /
611                 b"input hello")
612    input_pkt = IpType(bytes(input_pkt)) # Compute length, checksum.
613    input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876,
614                                                1, (tun_remote, tun_local))
615
616    self.ReceivePacketOn(netid, input_pkt)
617    msg, addr = sock.recvfrom(1024)
618    self.assertEqual(b"input hello", msg)
619    self.assertEqual((remote_addr, remote_port), addr[:2])
620
621    # Send and capture a packet.
622    sock.sendto(b"output hello", (remote_addr, remote_port))
623    packets = self.ReadAllPacketsOn(netid)
624    self.assertEqual(1, len(packets))
625    output_pkt = packets[0]
626    output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
627    self.assertEqual(output_pkt[scapy.UDP].len, len(b"output_hello") + 8)
628    self.assertEqual(remote_addr, output_pkt.dst)
629    self.assertEqual(remote_port, output_pkt[scapy.UDP].dport)
630    # length of the payload plus the UDP header
631    self.assertEqual(b"output hello", bytes(output_pkt[scapy.UDP].payload))
632    self.assertEqual(0xABCD, esp_hdr.spi)
633    sock.close()
634
635  def testNullEncryptionTunnelMode(self):
636    """Verify null encryption in tunnel mode.
637
638    This test verifies both manual assembly and disassembly of UDP packets
639    with ESP in IPsec tunnel mode.
640    """
641    for version in [4, 6]:
642      self._CheckNullEncryptionTunnelMode(version)
643
644  def _CheckNullEncryptionTransportMode(self, version):
645    family = net_test.GetAddressFamily(version)
646    netid = self.RandomNetid()
647    local_addr = self.MyAddress(version, netid)
648    remote_addr = self.GetRemoteAddress(version)
649
650    # Output
651    self.xfrm.AddSaInfo(
652        local_addr, remote_addr, 0xABCD, xfrm.XFRM_MODE_TRANSPORT, 123,
653        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
654        None, None, None, None)
655    # Input
656    self.xfrm.AddSaInfo(
657        remote_addr, local_addr, 0x9876, xfrm.XFRM_MODE_TRANSPORT, 456,
658        xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
659        None, None, None, None)
660
661    sock = net_test.UDPSocket(family)
662    self.SelectInterface(sock, netid, "mark")
663    sock.bind((local_addr, 0))
664    local_port = sock.getsockname()[1]
665    remote_port = 5555
666
667    xfrm_base.ApplySocketPolicy(
668        sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123, None)
669    xfrm_base.ApplySocketPolicy(
670        sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456, None)
671
672    # Create and receive an ESP packet.
673    IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
674    input_pkt = (IpType(src=remote_addr, dst=local_addr) /
675                 scapy.UDP(sport=remote_port, dport=local_port) /
676                 b"input hello")
677    input_pkt = IpType(bytes(input_pkt)) # Compute length, checksum.
678    input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876, 1, None)
679
680    self.ReceivePacketOn(netid, input_pkt)
681    msg, addr = sock.recvfrom(1024)
682    self.assertEqual(b"input hello", msg)
683    self.assertEqual((remote_addr, remote_port), addr[:2])
684
685    # Send and capture a packet.
686    sock.sendto(b"output hello", (remote_addr, remote_port))
687    packets = self.ReadAllPacketsOn(netid)
688    self.assertEqual(1, len(packets))
689    output_pkt = packets[0]
690    output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
691    # length of the payload plus the UDP header
692    self.assertEqual(output_pkt[scapy.UDP].len, len(b"output_hello") + 8)
693    self.assertEqual(remote_addr, output_pkt.dst)
694    self.assertEqual(remote_port, output_pkt[scapy.UDP].dport)
695    self.assertEqual(b"output hello", bytes(output_pkt[scapy.UDP].payload))
696    self.assertEqual(0xABCD, esp_hdr.spi)
697    sock.close()
698
699  def testNullEncryptionTransportMode(self):
700    """Verify null encryption in transport mode.
701
702    This test verifies both manual assembly and disassembly of UDP packets
703    with ESP in IPsec transport mode.
704    """
705    for version in [4, 6]:
706      self._CheckNullEncryptionTransportMode(version)
707
708  def _CheckGlobalPoliciesByMark(self, version):
709    """Tests that global policies may differ by only the mark."""
710    family = net_test.GetAddressFamily(version)
711    sel = xfrm.EmptySelector(family)
712    # Pick 2 arbitrary mark values.
713    mark1 = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
714    mark2 = xfrm.XfrmMark(mark=0xf00d, mask=xfrm_base.MARK_MASK_ALL)
715    # Create a global policy.
716    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
717    tmpl = xfrm.UserTemplate(AF_UNSPEC, 0xfeed, 0, None)
718    # Create the policy with the first mark.
719    self.xfrm.AddPolicyInfo(policy, tmpl, mark1)
720    # Create the same policy but with the second (different) mark.
721    self.xfrm.AddPolicyInfo(policy, tmpl, mark2)
722    # Delete the policies individually
723    self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark1)
724    self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark2)
725
726  def testGlobalPoliciesByMarkV4(self):
727    self._CheckGlobalPoliciesByMark(4)
728
729  def testGlobalPoliciesByMarkV6(self):
730    self._CheckGlobalPoliciesByMark(6)
731
732  def _CheckUpdatePolicy(self, version):
733    """Tests that we can can update the template on a policy."""
734    family = net_test.GetAddressFamily(version)
735    tmpl1 = xfrm.UserTemplate(family, 0xdead, 0, None)
736    tmpl2 = xfrm.UserTemplate(family, 0xbeef, 0, None)
737    sel = xfrm.EmptySelector(family)
738    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
739    mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
740
741    def _CheckTemplateMatch(tmpl):
742      """Dump the SPD and match a single template on a single policy."""
743      dump = self.xfrm.DumpPolicyInfo()
744      self.assertEqual(1, len(dump))
745      _, attributes = dump[0]
746      self.assertEqual(attributes['XFRMA_TMPL'], tmpl)
747
748    # Create a new policy using update.
749    self.xfrm.UpdatePolicyInfo(policy, tmpl1, mark, None)
750    # NEWPOLICY will not update the existing policy. This checks both that
751    # UPDPOLICY created a policy and that NEWPOLICY will not perform updates.
752    _CheckTemplateMatch(tmpl1)
753    with self.assertRaisesErrno(EEXIST):
754      self.xfrm.AddPolicyInfo(policy, tmpl2, mark, None)
755    # Update the policy using UPDPOLICY.
756    self.xfrm.UpdatePolicyInfo(policy, tmpl2, mark, None)
757    # There should only be one policy after update, and it should have the
758    # updated template.
759    _CheckTemplateMatch(tmpl2)
760
761  def testUpdatePolicyV4(self):
762    self._CheckUpdatePolicy(4)
763
764  def testUpdatePolicyV6(self):
765    self._CheckUpdatePolicy(6)
766
767  def _CheckPolicyDifferByDirection(self,version):
768    """Tests that policies can differ only by direction."""
769    family = net_test.GetAddressFamily(version)
770    tmpl = xfrm.UserTemplate(family, 0xdead, 0, None)
771    sel = xfrm.EmptySelector(family)
772    mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
773    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
774    self.xfrm.AddPolicyInfo(policy, tmpl, mark)
775    policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_IN, sel)
776    self.xfrm.AddPolicyInfo(policy, tmpl, mark)
777
778  def testPolicyDifferByDirectionV4(self):
779    self._CheckPolicyDifferByDirection(4)
780
781  def testPolicyDifferByDirectionV6(self):
782    self._CheckPolicyDifferByDirection(6)
783
class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest):
  """Tests for SA output marks, SA algorithm validation and SA updates."""

  def _CheckTunnelModeOutputMark(self, version, tunsrc, mark, expected_netid):
    """Tests sending UDP packets to tunnel mode SAs with output marks.

    Opens a UDP socket and binds it to a random netid, then sets up a tunnel
    mode SA with an output_mark of mark and sets a socket policy to use the SA.
    Then checks that sending on that SA sends a packet on expected_netid,
    or, if expected_netid is None, checks that sending fails with ENETUNREACH.

    Args:
      version: 4 or 6.
      tunsrc: A string, the source address of the tunnel.
      mark: An integer, the output_mark to set in the SA.
      expected_netid: An integer, the netid to expect the kernel to send the
          packet on. If None, expect that sendto will fail with ENETUNREACH.
    """
    family = net_test.GetAddressFamily(version)
    s = socket(family, SOCK_DGRAM, 0)
    # Close the socket even if a step below fails, so that a failing test
    # does not leak the file descriptor.
    try:
      # Bind the socket to a random netid.
      self.SelectInterface(s, self.RandomNetid(), "mark")

      # For generality, pick a tunnel endpoint that's not the address we
      # connect the socket to.
      tundst = TUNNEL_ENDPOINTS[version]
      tun_addrs = (tunsrc, tundst)

      # Create a tunnel mode SA whose output mark is set to mark.
      spi = TEST_SPI * mark
      reqid = 100 + spi
      self.xfrm.AddSaInfo(tunsrc, tundst, spi, xfrm.XFRM_MODE_TUNNEL, reqid,
                          xfrm_base._ALGO_CBC_AES_256,
                          xfrm_base._ALGO_HMAC_SHA1,
                          None, None, None, mark)

      # Set a socket policy to use it.
      xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT, spi, reqid,
                                  tun_addrs)

      remoteaddr = self.GetRemoteAddress(version)
      packetlen = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL, version,
                                               False, net_test.UDP_PAYLOAD,
                                               xfrm_base._ALGO_HMAC_SHA1,
                                               xfrm_base._ALGO_CBC_AES_256)

      if expected_netid is not None:
        # Send a packet and check that we see it on the wire.
        s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
        self._ExpectEspPacketOn(expected_netid, spi, 1, packetlen, tunsrc,
                                tundst)
      else:
        with self.assertRaisesErrno(ENETUNREACH):
          s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    finally:
      s.close()

  def testTunnelModeOutputMarkIPv4(self):
    """Checks that an output mark equal to a netid sends on that netid."""
    for netid in self.NETIDS:
      tunsrc = self.MyAddress(4, netid)
      self._CheckTunnelModeOutputMark(4, tunsrc, netid, netid)

  def testTunnelModeOutputMarkIPv6(self):
    """Checks that an output mark equal to a netid sends on that netid."""
    for netid in self.NETIDS:
      tunsrc = self.MyAddress(6, netid)
      self._CheckTunnelModeOutputMark(6, tunsrc, netid, netid)

  def testTunnelModeOutputNoMarkIPv4(self):
    """Checks that an output mark of 0 makes sending fail with ENETUNREACH."""
    tunsrc = self.MyAddress(4, self.RandomNetid())
    self._CheckTunnelModeOutputMark(4, tunsrc, 0, None)

  def testTunnelModeOutputNoMarkIPv6(self):
    """Checks that an output mark of 0 makes sending fail with ENETUNREACH."""
    tunsrc = self.MyAddress(6, self.RandomNetid())
    self._CheckTunnelModeOutputMark(6, tunsrc, 0, None)

  def testTunnelModeOutputInvalidMarkIPv4(self):
    """Checks that output mark 9999 makes sending fail with ENETUNREACH."""
    tunsrc = self.MyAddress(4, self.RandomNetid())
    self._CheckTunnelModeOutputMark(4, tunsrc, 9999, None)

  def testTunnelModeOutputInvalidMarkIPv6(self):
    """Checks that output mark 9999 makes sending fail with ENETUNREACH."""
    tunsrc = self.MyAddress(6, self.RandomNetid())
    self._CheckTunnelModeOutputMark(6, tunsrc, 9999, None)

  def testTunnelModeOutputMarkAttributes(self):
    """Checks that a dumped SA reports the output mark it was created with."""
    mark = 1234567
    self.xfrm.AddSaInfo(TEST_ADDR1, TUNNEL_ENDPOINTS[6], 0x1234,
                        xfrm.XFRM_MODE_TUNNEL, 100, xfrm_base._ALGO_CBC_AES_256,
                        xfrm_base._ALGO_HMAC_SHA1, None, None, None, mark)
    dump = self.xfrm.DumpSaInfo()
    self.assertEqual(1, len(dump))
    _, attributes = dump[0]
    self.assertEqual(mark, attributes["XFRMA_OUTPUT_MARK"])

  def testInvalidAlgorithms(self):
    """Checks that adding an SA with an unknown algorithm fails with ENOSYS."""
    key = binascii.unhexlify("af442892cdcd0ef650e9c299f9a8436a")
    invalid_auth = (xfrm.XfrmAlgoAuth((b"invalid(algo)", 128, 96)), key)
    invalid_crypt = (xfrm.XfrmAlgo((b"invalid(algo)", 128)), key)
    # Valid encryption, invalid authentication.
    with self.assertRaisesErrno(ENOSYS):
      self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
                          xfrm.XFRM_MODE_TRANSPORT, 0,
                          xfrm_base._ALGO_CBC_AES_256,
                          invalid_auth, None, None, None, 0)
    # Invalid encryption, valid authentication.
    with self.assertRaisesErrno(ENOSYS):
      self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
                          xfrm.XFRM_MODE_TRANSPORT, 0, invalid_crypt,
                          xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)

  def testUpdateSaAddMark(self):
    """Test that an embryonic SA can be updated to add a mark."""
    for version in [4, 6]:
      spi = 0xABCD
      wildcard = net_test.GetWildcardAddress(version)
      # Test that an SA created with ALLOCSPI can be updated with the mark.
      self.xfrm.AllocSpi(wildcard, IPPROTO_ESP, spi, spi)
      mark = xfrm.ExactMatchMark(0xf00d)
      self.xfrm.AddSaInfo(wildcard, wildcard,
                          spi, xfrm.XFRM_MODE_TUNNEL, 0,
                          xfrm_base._ALGO_CBC_AES_256,
                          xfrm_base._ALGO_HMAC_SHA1,
                          None, None, mark, 0, is_update=True)
      dump = self.xfrm.DumpSaInfo()
      # A single SA means the update modified the allocated SA rather than
      # adding a second one.
      self.assertEqual(1, len(dump))
      _, attributes = dump[0]
      self.assertEqual(mark, attributes["XFRMA_MARK"])
      self.xfrm.DeleteSaInfo(wildcard, spi, IPPROTO_ESP, mark)

  def getXfrmStat(self, statName):
    """Returns the value of one counter from XFRM_STATS_PROCFILE.

    Args:
      statName: A string, the name of the counter, e.g. "XfrmOutNoStates".

    Returns:
      The counter value as an integer, or 0 if no line for statName exists.
    """
    with open(XFRM_STATS_PROCFILE, 'r') as f:
      for line in f:
        fields = line.split()
        # Compare the whole name field so that a counter whose name merely
        # contains statName is never picked up by mistake.
        if fields and fields[0] == statName:
          return int(fields[1])
    return 0

  def _SetTunnelSa(self, local, remote, mark, output_mark, is_update):
    """Adds or updates the tunnel mode SA used by testUpdateActiveSaMarks.

    Args:
      local: A string, the source address of the tunnel.
      remote: A string, the destination address of the tunnel.
      mark: The mark to attach to the SA.
      output_mark: An integer, the output mark to set in the SA.
      is_update: A boolean, passed to AddSaInfo as is_update.
    """
    self.xfrm.AddSaInfo(local, remote,
                        TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0,
                        xfrm_base._ALGO_CBC_AES_256,
                        xfrm_base._ALGO_HMAC_SHA1,
                        None, None, mark, output_mark, is_update=is_update)

  def testUpdateActiveSaMarks(self):
    """Test that the OUTPUT_MARK can be updated on an ACTIVE SA."""
    for version in [4, 6]:
      family = net_test.GetAddressFamily(version)
      netid = self.RandomNetid()
      remote = self.GetRemoteAddress(version)
      local = self.MyAddress(version, netid)
      s = socket(family, SOCK_DGRAM, 0)
      # Close the socket even if a step below fails, so that a failing test
      # does not leak the file descriptor.
      try:
        self.SelectInterface(s, netid, "mark")
        # Create a mark that we will apply to the policy and later the SA.
        mark = xfrm.ExactMatchMark(netid)

        # Create a global policy that selects using the mark.
        sel = xfrm.EmptySelector(family)
        policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
        tmpl = xfrm.UserTemplate(family, 0, 0, (local, remote))
        self.xfrm.AddPolicyInfo(policy, tmpl, mark)

        # Read the baseline counter value from /proc/net/xfrm_stat.
        outNoStateCount = self.getXfrmStat(XFRM_STATS_OUT_NO_STATES)

        # No SA has been added for the policy yet, so this send should
        # increment XfrmOutNoStates.
        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))

        # Check to make sure XfrmOutNoStates is incremented by exactly 1.
        self.assertEqual(outNoStateCount + 1,
                         self.getXfrmStat(XFRM_STATS_OUT_NO_STATES))

        length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL,
                                              version, False,
                                              net_test.UDP_PAYLOAD,
                                              xfrm_base._ALGO_HMAC_SHA1,
                                              xfrm_base._ALGO_CBC_AES_256)

        # Add an SA with an output mark of 0, which routes to nowhere. If
        # the add reports that an SA already exists, update that SA instead.
        # NOTE(review): presumably the existing SA is a larval one left by
        # the send above -- confirm.
        try:
          self._SetTunnelSa(local, remote, mark, 0, is_update=False)
        except IOError as e:
          self.assertEqual(EEXIST, e.errno, "SA exists")
          self._SetTunnelSa(local, remote, mark, 0, is_update=True)

        with self.assertRaisesErrno(ENETUNREACH):
          s.sendto(net_test.UDP_PAYLOAD, (remote, 53))

        # Update the SA to route to a valid netid.
        self._SetTunnelSa(local, remote, mark, netid, is_update=True)

        # Now the payload routes to the updated netid.
        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
        self._ExpectEspPacketOn(netid, TEST_SPI, 1, length, None, None)

        # Get a new netid and reroute the packets to the new netid.
        reroute_netid = self.RandomNetid(netid)
        # Update the SA to change the output mark.
        self._SetTunnelSa(local, remote, mark, reroute_netid, is_update=True)

        s.sendto(net_test.UDP_PAYLOAD, (remote, 53))
        self._ExpectEspPacketOn(reroute_netid, TEST_SPI, 2, length, None, None)

        dump = self.xfrm.DumpSaInfo()

        # A single SA means every update modified the same SA.
        self.assertEqual(1, len(dump))
        _, attributes = dump[0]
        self.assertEqual(reroute_netid, attributes["XFRMA_OUTPUT_MARK"])

        self.xfrm.DeleteSaInfo(remote, TEST_SPI, IPPROTO_ESP, mark)
        self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark)
      finally:
        s.close()
1009
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
1012