1#!/usr/bin/python3
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from errno import *  # pylint: disable=wildcard-import,g-importing-member
18import itertools
19import os
20from socket import *  # pylint: disable=wildcard-import,g-importing-member
21import threading
22import time
23import unittest
24
25import net_test
26from scapy import all as scapy
27from tun_twister import TapTwister
28import util
29import xfrm
30import xfrm_base
31import xfrm_test
32
# Kernel-version marker for algorithms that are tested on every kernel. The
# second element of each *_ALGOS entry is compared against
# net_test.LINUX_VERSION (see AlgoEnforcedOrEnabled); presumably this value
# compares <= any kernel version — confirm in net_test.
ANY_KVER = net_test.LINUX_ANY_VERSION

# List of encryption algorithms for use in ParamTests.
# Each entry is (algorithm, kernel version from which support is enforced).
# On older kernels the algorithm is only tested if the kernel accepts it.
# Key lengths are in bits.
CRYPT_ALGOS = [
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 128)), ANY_KVER),
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 192)), ANY_KVER),
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)), ANY_KVER),
    # RFC 3686 specifies that key length must be 128, 192 or 256 bits, with
    # an additional 4 bytes (32 bits) of nonce. A fresh nonce value MUST be
    # assigned for each SA.
    # CTR-AES is enforced since kernel version 5.8
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 128+32)), (5, 8)),
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 192+32)), (5, 8)),
    (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 256+32)), (5, 8)),
]
48
# List of auth algorithms for use in ParamTests.
# Each entry is (algorithm, kernel version from which support is enforced);
# the XfrmAlgoAuth tuple is (name, key length in bits, truncation length in
# bits).
AUTH_ALGOS = [
    # RFC 2403 and RFC 2404 specify a 96-bit truncation length for HMAC-MD5
    # and HMAC-SHA1. RFC 4868 specifies that the only supported truncation
    # length for HMAC-SHA-256/384/512 is half the hash size.
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 96)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 96)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 128)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 192)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 256)), ANY_KVER),
    # Test larger truncation lengths (the full, untruncated hash size) for
    # good measure.
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 128)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 160)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 256)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 384)), ANY_KVER),
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 512)), ANY_KVER),
    # RFC 3566 specifies that the only supported truncation length
    # is 96 bits.
    # XCBC-AES is enforced since kernel version 5.8
    (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_AUTH_XCBC_AES, 128, 96)), (5, 8)),
]
69
# List of aead algorithms for use in ParamTests.
# Each entry is (algorithm, kernel version from which support is enforced);
# the XfrmAlgoAead tuple is (name, key length in bits including the salt or
# nonce, ICV length in bits).
AEAD_ALGOS = [
    # RFC 4106 specifies that key length must be 128, 192 or 256 bits,
    #   with an additional 4 bytes (32 bits) of salt. The salt must be unique
    #   for each new SA using the same key.
    # RFC 4106 specifies that ICV length must be 8, 12, or 16 bytes
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 8*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 12*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 16*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 8*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 12*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 16*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 8*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 12*8)), ANY_KVER),
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 16*8)), ANY_KVER),
    # RFC 7634 specifies that key length must be 256 bits, with an additional
    # 4 bytes (32 bits) of nonce. A fresh nonce value MUST be assigned for
    # each SA. RFC 7634 also specifies that ICV length must be 16 bytes.
    # ChaCha20-Poly1305 is enforced since kernel version 5.8
    (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_CHACHA20_POLY1305, 256+32, 16*8)),
     (5, 8)),
]
92
93
def GenerateKey(key_len):
  """Returns key_len bits of random key material as a bytes object.

  Args:
    key_len: key length in bits; must be a whole number of bytes.

  Raises:
    ValueError: if key_len is not a multiple of 8.
  """
  num_bytes, leftover_bits = divmod(key_len, 8)
  if leftover_bits:
    raise ValueError("Invalid key length in bits: " + str(key_len))
  return os.urandom(num_bytes)
98
99
def HaveAlgo(crypt_algo, auth_algo, aead_algo):
  """Returns whether the kernel accepts an SA using the given algorithms.

  Probes support by installing a throwaway transport-mode SA between
  xfrm_test.TEST_ADDR1 and xfrm_test.TEST_ADDR2 with freshly generated keys.
  SAs and policies are flushed both before and after the probe.

  Args:
    crypt_algo: encryption algorithm (XfrmAlgo), or None.
    auth_algo: authentication algorithm (XfrmAlgoAuth), or None.
    aead_algo: AEAD algorithm (XfrmAlgoAead), or None.

  Returns:
    True if the SA was added, False if the kernel failed with ENOSYS. Any
    other IOError is printed and treated as True, so the caller goes on to
    run the test instead of bypassing it.
  """

  def WithFreshKey(algo):
    # (algorithm, key) pair in the form AddSaInfo takes, or None if unused.
    if not algo:
      return None
    return (algo, GenerateKey(algo.key_len))

  try:
    probe = xfrm.Xfrm()
    probe.FlushSaInfo()
    probe.FlushPolicyInfo()

    probe.AddSaInfo(
        src=xfrm_test.TEST_ADDR1,
        dst=xfrm_test.TEST_ADDR2,
        spi=xfrm_test.TEST_SPI,
        mode=xfrm.XFRM_MODE_TRANSPORT,
        reqid=100,
        encryption=WithFreshKey(crypt_algo),
        auth_trunc=WithFreshKey(auth_algo),
        aead=WithFreshKey(aead_algo),
        encap=None,
        mark=None,
        output_mark=None)

    probe.FlushSaInfo()
    probe.FlushPolicyInfo()
  except IOError as err:
    if err.errno != ENOSYS:
      print("Unexpected error:", err.errno)
      return True
    return False
  else:
    return True
132
# Per-algorithm cache, keyed by algorithm name: True if the algorithm is
# enforced on, or at least enabled in, the running kernel; False otherwise.
algoState = {}


def AlgoEnforcedOrEnabled(crypt, auth, aead, target_algo, target_kernel):
  """Returns whether target_algo should be tested on this kernel (cached).

  An algorithm is tested if the running kernel is at least target_kernel (its
  support is then enforced) or if HaveAlgo() shows the kernel supports it.
  Results are memoized in algoState under target_algo. On a cache miss,
  HaveAlgo() is only called when the version check fails.

  Args:
    crypt: encryption algorithm passed to HaveAlgo(), or None.
    auth: authentication algorithm passed to HaveAlgo(), or None.
    aead: AEAD algorithm passed to HaveAlgo(), or None.
    target_algo: cache key; callers pass the algorithm's name.
    target_kernel: kernel version from which support is enforced.
  """
  verdict = algoState.get(target_algo)
  if verdict is None:
    verdict = (net_test.LINUX_VERSION >= target_kernel
               or HaveAlgo(crypt, auth, aead))
    algoState[target_algo] = verdict
  return verdict
144
145
def AuthEnforcedOrEnabled(auth_case):
  """Returns True if this auth algorithm should be tested on this kernel.

  Args:
    auth_case: an (XfrmAlgoAuth, enforced-since kernel version) entry, as in
      AUTH_ALGOS.
  """
  auth_algo = auth_case[0]
  # Probe with the null cipher so the result reflects the auth algorithm.
  null_crypt = xfrm.XfrmAlgo((b"ecb(cipher_null)", 0))
  min_kernel = auth_case[1]
  return AlgoEnforcedOrEnabled(null_crypt, auth_algo, None, auth_algo.name,
                               min_kernel)
151
152
def CryptEnforcedOrEnabled(crypt_case):
  """Returns True if this encryption algorithm should be tested here.

  Args:
    crypt_case: an (XfrmAlgo, enforced-since kernel version) entry, as in
      CRYPT_ALGOS.
  """
  crypt_algo = crypt_case[0]
  # Probe with the null digest so the result reflects the cipher itself.
  null_auth = xfrm.XfrmAlgoAuth((b"digest_null", 0, 0))
  min_kernel = crypt_case[1]
  return AlgoEnforcedOrEnabled(crypt_algo, null_auth, None, crypt_algo.name,
                               min_kernel)
158
159
def AeadEnforcedOrEnabled(aead_case):
  """Returns True if this AEAD algorithm should be tested on this kernel.

  AEAD replaces both encryption and authentication, so neither is passed.

  Args:
    aead_case: an (XfrmAlgoAead, enforced-since kernel version) entry, as in
      AEAD_ALGOS.
  """
  aead_algo = aead_case[0]
  min_kernel = aead_case[1]
  return AlgoEnforcedOrEnabled(None, None, aead_algo, aead_algo.name,
                               min_kernel)
164
165
def InjectTests():
  """Generates the parameterized test methods on XfrmAlgorithmTest.

  Module-level entry point; presumably invoked by an external test runner
  that imports this module (it is also reached via the __main__ path).
  """
  test_class = XfrmAlgorithmTest
  test_class.InjectTests()
168
169
class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest):
  """ESP transport-mode tests parameterized over the supported algorithms.

  Concrete test methods are generated by InjectTests() from the ParamTest*
  methods via util.InjectParameterizedTest (presumably one test per
  parameter tuple, named with TestNameGenerator — see util).
  """

  @classmethod
  def InjectTests(cls):
    """Injects one test per IP version, socket type and algorithm combination.

    Auth + crypt and AEAD are mutually exclusive, so two parameter families
    are generated:
      - every (auth, crypt) pair from AUTH_ALGOS x CRYPT_ALGOS, with no AEAD;
      - every AEAD algorithm from AEAD_ALGOS, with no auth or crypt.
    Each is crossed with IPv4/IPv6 and UDP (SOCK_DGRAM)/TCP (SOCK_STREAM).
    Parameter tuples are (version, proto, auth_case, crypt_case, aead_case).
    """
    versions = (4, 6)
    types = (SOCK_DGRAM, SOCK_STREAM)

    # Tests all combinations of auth & crypt. Mutually exclusive with aead.
    param_list = itertools.product(versions, types, AUTH_ALGOS, CRYPT_ALGOS,
                                   [None])
    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)

    # Tests all combinations of aead. Mutually exclusive with auth/crypt.
    param_list = itertools.product(versions, types, [None], [None], AEAD_ALGOS)
    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)

  @staticmethod
  def TestNameGenerator(version, proto, auth_case, crypt_case, aead_case):
    """Returns a readable name suffix that is unique per parameter tuple.

    Components appear in the order crypt, auth, aead (each only if present),
    followed by the IP version and the transport protocol.
    """
    # Produce a unique and readable name for each test. e.g.
    #     testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP
    param_string = ""
    if crypt_case is not None:
      crypt = crypt_case[0]
      param_string += "%s_%d_" % (crypt.name.decode(), crypt.key_len)

    if auth_case is not None:
      auth = auth_case[0]
      param_string += "%s_%d_%d_" % (auth.name.decode(), auth.key_len,
                                     auth.trunc_len)

    if aead_case is not None:
      aead = aead_case[0]
      param_string += "%s_%d_%d_" % (aead.name.decode(), aead.key_len,
                                     aead.icv_len)

    param_string += "%s_%s" % ("IPv4" if version == 4 else "IPv6",
                               "UDP" if proto == SOCK_DGRAM else "TCP")
    return param_string

  def ParamTestSocketPolicySimple(self, version, proto, auth_case, crypt_case,
                                  aead_case):
    """Test two-way traffic using transport mode and socket policies.

    A "left" socket sends a request to a "right" socket, which replies. Each
    direction is protected by its own pair of transport-mode SAs, selected by
    per-socket policies. TapTwister is given AssertEncrypted as its validator,
    which fails the test on any UDP or TCP packet seen in the clear.

    Args:
      version: IP version, 4 or 6.
      proto: SOCK_DGRAM (UDP) or SOCK_STREAM (TCP).
      auth_case: (XfrmAlgoAuth, kernel version) from AUTH_ALGOS, or None.
      crypt_case: (XfrmAlgo, kernel version) from CRYPT_ALGOS, or None.
      aead_case: (XfrmAlgoAead, kernel version) from AEAD_ALGOS, or None.
    """

    # Bypass the test if any algorithm going to be tested is not enforced
    # or enabled on this kernel. Note that the test then returns (and so
    # passes) rather than being reported as skipped.
    if auth_case is not None and not AuthEnforcedOrEnabled(auth_case):
      return
    if crypt_case is not None and not CryptEnforcedOrEnabled(crypt_case):
      return
    if aead_case is not None and not AeadEnforcedOrEnabled(aead_case):
      return

    auth = auth_case[0] if auth_case else None
    crypt = crypt_case[0] if crypt_case else None
    aead = aead_case[0] if aead_case else None

    def AssertEncrypted(packet):
      # This gives a free pass to ICMP and ICMPv6 packets, which show up
      # nondeterministically in tests.
      self.assertIsNone(packet.getlayer(scapy.UDP),
                        "UDP packet sent in the clear")
      self.assertIsNone(packet.getlayer(scapy.TCP),
                        "TCP packet sent in the clear")

    # We create a pair of sockets, "left" and "right", that will talk to each
    # other using transport mode ESP. Because of TapTwister, both sockets
    # perceive each other as owning "remote_addr".
    netid = self.RandomNetid()
    family = net_test.GetAddressFamily(version)
    local_addr = self.MyAddress(version, netid)
    remote_addr = self.GetRemoteSocketAddress(version)

    # Each direction gets independently generated keys. SPIs and keys are
    # named after the receiving socket: traffic towards "right" uses
    # spi_right and the *_right keys. The algorithm structs are rebuilt from
    # the parameter's fields rather than reusing the parameter object.
    auth_left = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
                 GenerateKey(auth.key_len)) if auth else None
    auth_right = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
                  GenerateKey(auth.key_len)) if auth else None
    crypt_left = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
                  GenerateKey(crypt.key_len)) if crypt else None
    crypt_right = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
                   GenerateKey(crypt.key_len)) if crypt else None
    aead_left = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
                 GenerateKey(aead.key_len)) if aead else None
    aead_right = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
                  GenerateKey(aead.key_len)) if aead else None
    spi_left = 0xbeefface
    spi_right = 0xcafed00d
    req_ids = [100, 200, 300, 400]  # Used to match templates and SAs.

    def AddTransportSa(src, dst, spi, reqid, crypt_pair, auth_pair, aead_pair):
      # Adds one transport-mode SA; each *_pair is (algorithm, key) or None.
      self.xfrm.AddSaInfo(
          src=src,
          dst=dst,
          spi=spi,
          mode=xfrm.XFRM_MODE_TRANSPORT,
          reqid=reqid,
          encryption=crypt_pair,
          auth_trunc=auth_pair,
          aead=aead_pair,
          encap=None,
          mark=None,
          output_mark=None)

    # Left outbound SA
    AddTransportSa(local_addr, remote_addr, spi_right, req_ids[0],
                   crypt_right, auth_right, aead_right)
    # Right inbound SA
    AddTransportSa(remote_addr, local_addr, spi_right, req_ids[1],
                   crypt_right, auth_right, aead_right)
    # Right outbound SA
    AddTransportSa(local_addr, remote_addr, spi_left, req_ids[2],
                   crypt_left, auth_left, aead_left)
    # Left inbound SA
    AddTransportSa(remote_addr, local_addr, spi_left, req_ids[3],
                   crypt_left, auth_left, aead_left)

    server_ready = threading.Event()
    server_error = None  # Save exceptions thrown by the server.

    def TcpServer(sock, client_port):
      nonlocal server_error
      try:
        sock.listen(1)
        server_ready.set()
        accepted, peer = sock.accept()
        # The context manager closes the accepted socket even if an
        # assertion below fails.
        with accepted:
          self.assertEqual(remote_addr, peer[0])
          self.assertEqual(client_port, peer[1])
          data = accepted.recv(2048)
          self.assertEqual(b"hello request", data)
          accepted.send(b"hello response")
          time.sleep(0.1)
      except Exception as e:  # pylint: disable=broad-exception-caught
        server_error = e
      finally:
        sock.close()

    def UdpServer(sock, client_port):
      nonlocal server_error
      try:
        server_ready.set()
        data, peer = sock.recvfrom(2048)
        self.assertEqual(remote_addr, peer[0])
        self.assertEqual(client_port, peer[1])
        self.assertEqual(b"hello request", data)
        sock.sendto(b"hello response", peer)
      except Exception as e:  # pylint: disable=broad-exception-caught
        server_error = e
      finally:
        sock.close()

    def OpenSocket():
      # Creates a socket with a 2s timeout and SO_REUSEADDR, placed on netid
      # via SelectInterface(..., "mark"); closes it again if setup fails.
      sock = socket(family, proto, 0)
      try:
        sock.settimeout(2.0)
        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        self.SelectInterface(sock, netid, "mark")
      except Exception:
        sock.close()
        raise
      return sock

    # Make two sockets.
    sock_left = OpenSocket()
    sock_right = None
    # Once the server thread is running it owns, and closes, sock_right.
    server_owns_right = False
    try:
      sock_right = OpenSocket()

      # For TCP, set SO_LINGER to 0, to prevent TCP sockets from hanging
      # around in a TIME_WAIT state.
      if proto == SOCK_STREAM:
        net_test.DisableFinWait(sock_left)
        net_test.DisableFinWait(sock_right)

      # Apply the left outbound socket policy.
      xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
                                  spi_right, req_ids[0], None)
      # Apply right inbound socket policy.
      xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_IN,
                                  spi_right, req_ids[1], None)
      # Apply right outbound socket policy.
      xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_OUT,
                                  spi_left, req_ids[2], None)
      # Apply left inbound socket policy.
      xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_IN,
                                  spi_left, req_ids[3], None)

      # Server and client need to know each other's port numbers in advance.
      wildcard_addr = net_test.GetWildcardAddress(version)
      sock_left.bind((wildcard_addr, 0))
      sock_right.bind((wildcard_addr, 0))
      left_port = sock_left.getsockname()[1]
      right_port = sock_right.getsockname()[1]

      # Start the appropriate server type on sock_right.
      target = TcpServer if proto == SOCK_STREAM else UdpServer
      server = threading.Thread(
          target=target,
          args=(sock_right, left_port),
          name="SocketServer")
      server.start()
      server_owns_right = True
      # Wait for server to be ready before attempting to connect. TCP retries
      # hide this problem, but UDP will fail outright if the server socket has
      # not bound when we send.
      self.assertTrue(server_ready.wait(3.0),
                      "Timed out waiting for server thread")

      with TapTwister(fd=self.tuns[netid].fileno(),
                      validator=AssertEncrypted):
        sock_left.connect((remote_addr, right_port))
        sock_left.send(b"hello request")
        data = sock_left.recv(2048)
        self.assertEqual(b"hello response", data)
        sock_left.close()
        server.join(timeout=3.0)
        self.assertFalse(server.is_alive(),
                         "Timed out waiting for server exit")
    finally:
      # Release the sockets on failure paths too. Closing sock_left again
      # after the normal close is harmless; sock_right is only closed here if
      # the server thread never took ownership of it.
      sock_left.close()
      if sock_right is not None and not server_owns_right:
        sock_right.close()

    if server_error:
      raise server_error
405
406
if __name__ == "__main__":
  # Generate the parameterized test methods before handing off to unittest.
  InjectTests()
  unittest.main()
410