#!/usr/bin/python3
#
# Copyright 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from errno import *  # pylint: disable=wildcard-import
from socket import *  # pylint: disable=wildcard-import
import ctypes
import fcntl
import os
import random
import select
import termios
import threading
import time
from scapy import all as scapy

import multinetwork_base
import net_test
import packets

SOL_TCP = net_test.SOL_TCP
SHUT_RD = net_test.SHUT_RD
SHUT_WR = net_test.SHUT_WR
SHUT_RDWR = net_test.SHUT_RDWR
SIOCINQ = termios.FIONREAD
SIOCOUTQ = termios.TIOCOUTQ

TEST_PORT = 5555

# The following constants are SOL_TCP level options and arguments.
# They are defined in the Linux kernel: include/uapi/linux/tcp.h

# SOL_TCP level options.
TCP_REPAIR = 19
TCP_REPAIR_QUEUE = 20
TCP_QUEUE_SEQ = 21

# TCP_REPAIR_{OFF, ON} is an argument to TCP_REPAIR.
TCP_REPAIR_OFF = 0
TCP_REPAIR_ON = 1

# TCP_{NO, RECV, SEND}_QUEUE is an argument to TCP_REPAIR_QUEUE.
TCP_NO_QUEUE = 0
TCP_RECV_QUEUE = 1
TCP_SEND_QUEUE = 2

# TCP sequence numbers are 32-bit and wrap around; compare them modulo 2**32.
SEQ_MASK = 0xffffffff


# This test aims to ensure that TCP keepalive offload works correctly when it
# fetches TCP information from the kernel via TCP repair mode.
class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest):

  def assertSocketNotConnected(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock has a peer, i.e. getpeername() does not raise."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def createConnectedSocket(self, version, netid):
    """Creates a TCP socket on netid and completes a simulated handshake.

    connect() is expected to fail with EINPROGRESS; the outgoing SYN is then
    matched on the interface, a SYN+ACK with a random sequence number and a
    window of 14400 is injected, and the resulting ACK is matched.

    Args:
      version: IP version selector passed through to the test helpers
        (the tests use 4, 5 and 6).
      netid: network to bind the socket to (selected via "mark").

    Returns:
      The connected socket. Also sets self.last_sent to the matched ACK and
      self.last_received to the injected SYN+ACK.
    """
    s = net_test.TCPSocket(net_test.GetAddressFamily(version))
    net_test.DisableFinWait(s)
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    remoteaddr = self.GetRemoteAddress(version)
    self.assertRaisesErrno(EINPROGRESS, s.connect, (remotesockaddr, TEST_PORT))
    self.assertSocketNotConnected(s)

    myaddr = self.MyAddress(version, netid)
    port = s.getsockname()[1]
    self.assertNotEqual(0, port)

    desc, expect_syn = packets.SYN(TEST_PORT, version, myaddr, remoteaddr, port,
                                   seq=None)
    msg = "socket connect: expected %s" % desc
    syn = self.ExpectPacketOn(netid, msg, expect_syn)

    synack_desc, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
    # Use a random remote ISN so sequence-number checks are not trivially true.
    synack.getlayer("TCP").seq = random.getrandbits(32)
    synack.getlayer("TCP").window = 14400
    self.ReceivePacketOn(netid, synack)

    desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
    msg = "socket connect: got SYN+ACK, expected %s" % desc
    ack = self.ExpectPacketOn(netid, msg, ack)
    self.last_sent = ack
    self.last_received = synack
    return s

  def receiveFin(self, netid, version, sock):
    """Injects a FIN from the remote peer, built relative to self.last_sent."""
    self.assertSocketConnected(sock)
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    desc, fin = packets.FIN(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, fin)
    self.last_received = fin

  def sendData(self, netid, version, sock, payload):
    """Sends payload on sock and records the packet it should produce.

    The expected data packet is stored in self.last_sent; it is NOT read back
    from the interface here.
    """
    sock.send(payload)

    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    desc, send = packets.ACK(version, myaddr, remoteaddr,
                             self.last_received, payload)
    self.last_sent = send

  def receiveData(self, netid, version, payload):
    """Injects a data packet from the peer and expects our ACK for it.

    Updates self.last_sent (our ACK) and self.last_received (the injected
    data packet).
    """
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)

    desc, received = packets.ACK(version, remoteaddr, myaddr,
                                 self.last_sent, payload)
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, received)
    self.ReceivePacketOn(netid, received)
    # Give the kernel a moment to emit the ACK before looking for it.
    time.sleep(0.1)
    self.ExpectPacketOn(netid, "expecting %s" % ack_desc, ack)
    self.last_sent = ack
    self.last_received = received

  # Test the behavior of NO_QUEUE. Expect incoming data to be stored in the
  # queue, but the socket cannot be read or written in NO_QUEUE.
  def testTcpRepairInNoQueue(self):
    for version in [4, 5, 6]:
      self.tcpRepairInNoQueueTest(version)

  def tcpRepairInNoQueueTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    # In repair mode with NO_QUEUE, writes fail...
    self.assertRaisesErrno(EINVAL, sock.send, b"write test")

    # Remote data is coming.
    TEST_RECEIVED = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECEIVED)

    # In repair mode with NO_QUEUE, reads fail...
    self.assertRaisesErrno(EPERM, sock.recv, 4096)

    # ...but once repair mode is off, the queued data is readable.
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    readData = sock.recv(4096)
    self.assertEqual(readData, TEST_RECEIVED)
    sock.close()

  # Test whether TCP read/write sequence numbers can be fetched correctly
  # by TCP_QUEUE_SEQ.
  def testGetSequenceNumber(self):
    for version in [4, 5, 6]:
      self.GetSequenceNumberTest(version)

  def GetSequenceNumberTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)

    # Test write queue sequence number.
    sequence_before = self.GetWriteSequenceNumber(version, sock)
    expect_sequence = self.last_sent.getlayer("TCP").seq
    self.assertEqual(sequence_before & SEQ_MASK, expect_sequence)
    TEST_SEND = net_test.UDP_PAYLOAD
    self.sendData(netid, version, sock, TEST_SEND)
    sequence_after = self.GetWriteSequenceNumber(version, sock)
    # Compare modulo 2**32: getsockopt() returns the value as a C int, so the
    # raw numbers can change sign across a wrap even though TCP is correct.
    self.assertEqual((sequence_before + len(TEST_SEND)) & SEQ_MASK,
                     sequence_after & SEQ_MASK)

    # Test read queue sequence number.
    sequence_before = self.GetReadSequenceNumber(version, sock)
    # The SYN+ACK's SYN consumes one sequence number.
    expect_sequence = self.last_received.getlayer("TCP").seq + 1
    self.assertEqual(sequence_before & SEQ_MASK, expect_sequence & SEQ_MASK)
    TEST_READ = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_READ)
    sequence_after = self.GetReadSequenceNumber(version, sock)
    self.assertEqual((sequence_before + len(TEST_READ)) & SEQ_MASK,
                     sequence_after & SEQ_MASK)
    sock.close()

  def GetWriteSequenceNumber(self, version, sock):
    """Returns TCP_QUEUE_SEQ for the send queue, leaving repair mode off.

    `version` is unused. The raw getsockopt() value is returned unmasked.
    """
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)
    sequence = sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    return sequence

  def GetReadSequenceNumber(self, version, sock):
    """Returns TCP_QUEUE_SEQ for the receive queue, leaving repair mode off.

    `version` is unused. The raw getsockopt() value is returned unmasked.
    """
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_RECV_QUEUE)
    sequence = sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ)
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
    return sequence

  # Test whether a TCP repair socket can be poll()'ed correctly
  # in multiple threads at the same time.
  def testMultiThreadedPoll(self):
    for version in [4, 5, 6]:
      self.PollWhenShutdownTest(version)
      self.PollWhenReceiveFinTest(version)

  def PollRepairSocketInMultipleThreads(self, netid, version, expected):
    """Connects a socket, enters repair mode and starts two poll threads.

    Each thread runs fdSelect(sock, expected).

    Returns:
      (sock, threads) so the caller can trigger an event and then call
      assertThreadsStopped().
    """
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    multiThreads = []
    for _ in range(2):
      thread = SocketExceptionThread(sock,
                                     lambda sk: self.fdSelect(sk, expected))
      thread.start()
      self.assertTrue(thread.is_alive())
      multiThreads.append(thread)

    return sock, multiThreads

  def assertThreadsStopped(self, multiThreads, msg):
    """Asserts every thread finished within ~1s and recorded no exception.

    Raises:
      AssertionError: with `msg` if a thread is still alive after join(1), or
        if a thread recorded an exception (e.g. a failed event check in
        fdSelect, which would otherwise be silently lost).
    """
    for thread in multiThreads:
      if thread.is_alive():
        thread.join(1)
      if thread.is_alive():
        thread.stop()
        raise AssertionError(msg)
      if thread.exception is not None:
        raise AssertionError("%s: %s" % (msg, thread.exception))

  def PollWhenShutdownTest(self, version):
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version,
                                                               expected)
    # Test shutdown RD.
    sock.shutdown(SHUT_RD)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RD")
    sock.close()

    expected = None
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version,
                                                               expected)
    # Test shutdown WR.
    sock.shutdown(SHUT_WR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_WR")
    sock.close()

    expected = select.POLLIN | select.POLLHUP
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version,
                                                               expected)
    # Test shutdown RDWR.
    sock.shutdown(SHUT_RDWR)
    self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RDWR")
    sock.close()

  def PollWhenReceiveFinTest(self, version):
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version,
                                                               expected)
    self.receiveFin(netid, version, sock)
    self.assertThreadsStopped(multiThreads, "poll fail during FIN")
    sock.close()

  # Test whether socket idleness can be detected by SIOCINQ and SIOCOUTQ.
  def testSocketIdle(self):
    for version in [4, 5, 6]:
      self.readQueueIdleTest(version)
      self.writeQueueIdleTest(version)

  def readQueueIdleTest(self, version):
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)

    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, 0)

    TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD
    self.receiveData(netid, version, TEST_RECV_PAYLOAD)
    fcntl.ioctl(sock, SIOCINQ, buf)
    self.assertEqual(buf.value, len(TEST_RECV_PAYLOAD))
    sock.close()

  def writeQueueIdleTest(self, version):
    # Set up a connected socket; its write queue is empty.
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    buf = ctypes.c_int()
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    # Switch to repair mode with SEND_QUEUE and write some data to the queue.
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD
    sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    sock.close()

    # Set up a connected socket again.
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    # Send out some data and don't receive the ACK yet.
    self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD))
    # Receive the response ACK; the write queue should drain.
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    desc_ack, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, ack)
    fcntl.ioctl(sock, SIOCOUTQ, buf)
    self.assertEqual(buf.value, 0)
    sock.close()

  def fdSelect(self, sock, expected):
    """Polls sock for up to 500 ms and checks any returned event mask.

    Registers POLLIN|POLLPRI|POLLHUP|POLLERR|POLLNVAL. If poll() times out
    with no events, nothing is asserted.

    Raises:
      AssertionError: if the event mask for sock differs from `expected`, or
        if poll() reports an unexpected fd.
    """
    READ_ONLY = (select.POLLIN | select.POLLPRI | select.POLLHUP |
                 select.POLLERR | select.POLLNVAL)
    p = select.poll()
    p.register(sock, READ_ONLY)
    events = p.poll(500)
    for fd, event in events:
      if fd == sock.fileno():
        self.assertEqual(event, expected)
      else:
        raise AssertionError("unexpected poll fd")


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records its failure.

  An IOError or AssertionError raised by `operation` is stored in
  self.exception (None otherwise) instead of propagating.
  """

  def __init__(self, sock, operation):
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def stop(self):
    """Best-effort stop of this thread.

    The name-mangled `_Thread__stop` method only existed in Python 2's
    threading module; calling it unconditionally raises AttributeError on
    Python 3, which has no supported way to stop a running thread. On Python 3
    this is therefore a no-op; the thread is a daemon, so it does not keep the
    interpreter alive.
    """
    py2_stop = getattr(self, "_Thread__stop", None)
    if py2_stop is not None:
      py2_stop()

  def run(self):
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as e:
      self.exception = e


if __name__ == '__main__':
  unittest.main()