1#!/usr/bin/python3
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import time
18from socket import *  # pylint: disable=wildcard-import
19
20import net_test
21import multinetwork_base
22import packets
23
# TCP socket states, numbered consecutively from 1 exactly as in the Linux
# kernel's include/net/tcp_states.h.
(TCP_ESTABLISHED,
 TCP_SYN_SENT,
 TCP_SYN_RECV,
 TCP_FIN_WAIT1,
 TCP_FIN_WAIT2,
 TCP_TIME_WAIT,
 TCP_CLOSE,
 TCP_CLOSE_WAIT,
 TCP_LAST_ACK,
 TCP_LISTEN,
 TCP_CLOSING,
 TCP_NEW_SYN_RECV) = range(1, 13)

# Test-only pseudo-state (negative, so it cannot collide with the states
# above): stop after the handshake completes but before accept() is called.
TCP_NOT_YET_ACCEPTED = -1
39
40
class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class for tests that need a TCP socket in a particular state.

  IncomingConnection() creates a listening socket (self.s). It then injects
  packets on a test network and checks the replies, driving an incoming
  connection to the requested state. The accepted socket, if any, is stored
  in self.accepted. Both sockets are closed in tearDown().
  """

  # End states that IncomingConnection() knows how to reach. They are checked
  # up front so an unsupported state fails before any socket or packet work.
  _SUPPORTED_END_STATES = (TCP_LISTEN, TCP_SYN_RECV, TCP_NOT_YET_ACCEPTED,
                           TCP_ESTABLISHED, TCP_CLOSE_WAIT)

  def CloseSockets(self):
    """Closes and deletes self.accepted and self.s, if they exist."""
    # Close the accepted socket first, then the listener.
    for attr in ("accepted", "s"):
      if hasattr(self, attr):
        getattr(self, attr).close()
        delattr(self, attr)

  def tearDown(self):
    """Closes the test sockets, then always runs the base-class tearDown."""
    try:
      self.CloseSockets()
    finally:
      super(TcpBaseTest, self).tearDown()

  def OpenListenSocket(self, version, netid):
    """Creates a TCP socket on netid bound to a random port.

    Args:
      version: 4 (AF_INET), 6 (AF_INET6), or 5 (an AF_INET6 socket that
        IncomingConnection() drives with IPv4 packets).
      netid: network to associate the socket with via SelectInterface().

    Returns:
      The new socket. It is later used with accept(), so it is presumably
      put into the listening state by net_test.BindRandomPort(). The chosen
      port is stored in self.port.

    Raises:
      KeyError: if version is not 4, 5 or 6.
    """
    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, netid, "mark")
    self.port = net_test.BindRandomPort(version, s)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Calls the base-class method and records its result in last_packet.

    Returns:
      Whatever the base-class method returns. IncomingConnection() uses it
      as the reply packet (e.g. the SYN+ACK).
    """
    pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
                                                             reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Passes packet to the base-class method and records it in last_packet."""
    super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def ReceiveRstPacketOn(self, netid):
    """Receives on netid a RST from the remote address, built from last_packet."""
    # self.last_packet is whatever ReceivePacketOn() or
    # _ReceiveAndExpectResponse() recorded last. Invert direction twice:
    # first build our ACK to it, then the remote's RST answering that ACK.
    _, ack = packets.ACK(self.version, self.myaddr, self.remoteaddr,
                         self.last_packet)
    _, rst = packets.RST(self.version, self.remoteaddr, self.myaddr, ack)
    # Call the base class directly, so self.last_packet is left unchanged.
    super(TcpBaseTest, self).ReceivePacketOn(netid, rst)

  def RstPacket(self):
    """Returns packets.RST() output for a RST from us answering last_packet.

    The result is a (description, packet) pair, as unpacked elsewhere in this
    class.
    """
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def FinPacket(self):
    """Returns packets.FIN() output for a FIN from us answering last_packet."""
    return packets.FIN(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def IncomingConnection(self, version, end_state, netid):
    """Opens a listening socket and drives an incoming connection to end_state.

    Sets self.s, self.end_state, self.remoteaddr, self.remotesockaddr,
    self.myaddr, self.mysockaddr, self.version and, once the connection has
    been accepted, self.accepted.

    Args:
      version: 4, 5 or 6; see OpenListenSocket(). Version 5 exchanges IPv4
        packets, so self.version is set to 4 in that case.
      end_state: one of TCP_LISTEN, TCP_SYN_RECV, TCP_NOT_YET_ACCEPTED,
        TCP_ESTABLISHED or TCP_CLOSE_WAIT.
      netid: the test network on which packets are exchanged.

    Raises:
      ValueError: if end_state is not supported. This is checked before any
        socket is opened or any packet is sent.
    """
    if end_state not in self._SUPPORTED_END_STATES:
      raise ValueError("Invalid TCP state %d specified" % end_state)

    self.s = self.OpenListenSocket(version, netid)
    self.end_state = end_state

    # Addresses are looked up with the caller's version (including 5).
    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
    self.remotesockaddr = self.GetRemoteSocketAddress(version)
    myaddr = self.myaddr = self.MyAddress(version, netid)
    self.mysockaddr = self.MySocketAddress(version, netid)

    # From here on, packets for version 5 are built as IPv4.
    if version == 5:
      version = 4
    self.version = version

    if end_state == TCP_LISTEN:
      return

    # Remote sends a SYN; we must answer with a SYN+ACK.
    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == TCP_SYN_RECV:
      return

    # Remote ACKs our SYN+ACK, completing the handshake.
    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)
    if end_state == TCP_NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    net_test.DisableFinWait(self.accepted)
    if end_state == TCP_ESTABLISHED:
      return

    # Send some data on the accepted socket and expect to see it on the wire.
    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    # Remote sends a FIN; we must ACK it, which puts us in CLOSE_WAIT.
    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    # Round-trip through bytes. Presumably this fills in computed fields
    # (lengths, checksums) before the ACK is derived from the FIN.
    fin = packets._GetIpLayer(version)(bytes(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    # end_state is TCP_CLOSE_WAIT here: it is the only supported state left.
145