• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/python3
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
17"""Partial Python implementation of sock_diag functionality."""
18
19# pylint: disable=g-bad-todo
20
21import errno
22import os
23from socket import *  # pylint: disable=wildcard-import
24import struct
25
26import csocket
27import cstruct
28import net_test
29import netlink
30
### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20
SOCK_DESTROY = 21

### inet_diag_constants. See include/uapi/linux/inet_diag.h
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
INET_DIAG_REQ_BYTECODE = 1

# Extensions. These double as the attribute types of dump responses (see
# SockDiag._Decode). Each value is defined exactly once; the values are
# contiguous from 0 to 15.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9
INET_DIAG_PROTOCOL = 10
INET_DIAG_SKV6ONLY = 11
INET_DIAG_LOCALS = 12
INET_DIAG_PEERS = 13
INET_DIAG_PAD = 14
INET_DIAG_MARK = 15

# Bytecode operations. Used as the "code" field of a bytecode instruction
# (see SockDiag.PackBytecode and SockDiag.DecodeBytecode).
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
INET_DIAG_BC_DEV_COND = 9
INET_DIAG_BC_MARK_COND = 10
74
# Prefixes of the constant families defined in this module. SockDiag._Decode
# names attributes by looking up constants with one of these prefixes;
# presumably this value feeds that lookup -- confirm in netlink.py.
CONSTANT_PREFIXES = netlink.MakeConstantPrefixes([
    "INET_DIAG_", "INET_DIAG_REQ_", "INET_DIAG_BC_"])

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
#
# Each definition gives a name, a format string and a space-separated field
# list. The format strings look like struct-module formats; the extra "S"
# code appears to stand for a nested struct taken from the trailing list
# argument -- confirm in cstruct.py.

# Socket identity: ports, addresses (16 bytes each, so IPv4 addresses must be
# padded -- see SockDiag.PaddedAddress), interface index and cookie.
# NOTE(review): "!" makes every field, including iface, network byte order.
# Confirm the kernel expects the interface index in that byte order.
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
# Request sent with SOCK_DIAG_BY_FAMILY / SOCK_DESTROY.
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
# Per-socket response message parsed from dumps.
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
# Payload of the INET_DIAG_MEMINFO attribute.
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
# One bytecode instruction header: opcode plus the yes/no jump offsets.
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
# Argument header of INET_DIAG_BC_S_COND / INET_DIAG_BC_D_COND; the raw
# address bytes follow it in the bytecode.
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")
# Argument of INET_DIAG_BC_MARK_COND.
InetDiagMarkcond = cstruct.Struct("InetDiagMarkcond", "=II", "mark mask")

# Payload of the INET_DIAG_SKMEMINFO attribute.
SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
# Payload of the INET_DIAG_INFO attribute for TCP sockets.
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.

TCP_TIME_WAIT = 6
# State bitmask with every bit set except the one for TCP_TIME_WAIT. This is
# the default "states" filter of SockDiag.DumpAllInetSockets.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
111
112
class SockDiag(netlink.NetlinkSocket):
  """Netlink socket that speaks the sock_diag / inet_diag protocol.

  Provides helpers to build inet_diag requests and bytecode filters, dump
  sockets matching a request, look up the kernel's view of a local socket,
  and ask the kernel to destroy a socket with SOCK_DESTROY.
  """

  # Debug categories. MaybeDebugCommand prints outgoing requests when this
  # list contains "ALL" or "SOCK".
  NL_DEBUG = []

  def __init__(self):
    """Initializes the underlying NetlinkSocket for NETLINK_SOCK_DIAG."""
    super(SockDiag, self).__init__(netlink.NETLINK_SOCK_DIAG)

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """Decodes a netlink attribute to a (name, Python value) pair.

    Args:
      command: Unused here.
      msg: The parsed message the attribute belongs to. Its family and type
        select the constant prefix used to name the attribute.
      nla_type: Integer attribute type.
      nla_data: Raw attribute payload (bytes).
      nested: Unused here.

    Returns:
      A (name, data) tuple. name is the attribute's constant name, or the raw
      integer type if the message family is neither AF_INET nor AF_INET6.
      data is the decoded payload, or the raw bytes if there is no decoder
      for the attribute.
    """
    if msg.family == AF_INET or msg.family == AF_INET6:
      if isinstance(msg, InetDiagReqV2):
        prefix = "INET_DIAG_REQ_"
      else:
        prefix = "INET_DIAG_"
      name = self._GetConstantName(__name__, nla_type, prefix)
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS",
                "INET_DIAG_SKV6ONLY"]:
      # Single-byte attributes.
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      data = nla_data.strip(b"\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    elif name == "INET_DIAG_MARK":
      data = struct.unpack("=I", nla_data)[0]
    elif name == "INET_DIAG_REQ_BYTECODE":
      data = self.DecodeBytecode(nla_data)
    elif name in ["INET_DIAG_LOCALS", "INET_DIAG_PEERS"]:
      data = []
      while nla_data:
        # The SCTP diag code always appears to copy sizeof(sockaddr_storage)
        # bytes, but does so from a union sctp_addr which is at most as long
        # as a sockaddr_in6.
        addr, nla_data = cstruct.Read(nla_data, csocket.SockaddrStorage)
        if addr.family == AF_INET:
          addr = csocket.SockaddrIn(addr.Pack())
        elif addr.family == AF_INET6:
          addr = csocket.SockaddrIn6(addr.Pack())
        data.append(addr)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, unused_flags, data):
    """Prints the request being sent if "ALL" or "SOCK" debugging is on.

    Args:
      command: Integer message type; printed by its SOCK_ constant name.
      unused_flags: Ignored.
      data: Raw netlink message, parsed as an InetDiagReqV2 for printing.
    """
    # Check first so that nothing is looked up or parsed when debugging is off.
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an InetDiagSockId with every byte set to zero."""
    return InetDiagSockId((b"\x00" * len(InetDiagSockId)))

  @staticmethod
  def PackBytecode(instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    The arg of each instruction depends on its opcode:
      NOP, JMP, AUTO: ignored.
      S_GE, S_LE, D_GE, D_LE: a port number.
      S_COND, D_COND: an (address string, prefix length, port) tuple.
      DEV_COND: an interface index.
      MARK_COND: a mark, or a (mark, mask) tuple.

    Args:
      instructions: list of instruction tuples

    Returns:
      The raw bytecode, as bytes.

    Raises:
      ValueError: A jump offset is not positive, or the opcode is unsupported.
    """
    args = []
    # positions[i] is the byte offset of instruction i in the packed output.
    positions = [0]

    for op, yes, no, arg in instructions:

      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
        arg = b""
      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
        # The port travels in the "no" field of a second, otherwise zeroed,
        # InetDiagBcOp-shaped slot. DecodeBytecode reads it back from there.
        arg = b"\x00\x00" + struct.pack("=H", arg)
      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        addr = inet_pton(family, addr)
        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
      elif op == INET_DIAG_BC_DEV_COND:
        # Mirrors DecodeBytecode: a single "=I" interface index.
        arg = struct.pack("=I", arg)
      elif op == INET_DIAG_BC_MARK_COND:
        if isinstance(arg, tuple):
          mark, mask = arg
        else:
          mark, mask = arg, 0xffffffff
        arg = InetDiagMarkcond((mark, mask)).Pack()
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      length = len(InetDiagBcOp) + len(arg)
      positions.append(positions[-1] + length)

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    # Convert the instruction-relative jumps to byte offsets and emit.
    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      yes = positions[i + yes] - positions[i]
      no = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes, no)).Pack() + args[i])

    return b"".join(packed)

  @staticmethod
  def DecodeBytecode(bytecode):
    """Decodes inet_diag bytecode for debugging output.

    Args:
      bytecode: Raw bytecode (bytes), e.g. as produced by PackBytecode.

    Returns:
      A list of (op, arg) tuples, where op is the InetDiagBcOp of the
      instruction (with jump offsets in bytes) and arg is its decoded
      argument, or None if the opcode takes none. Returns the string "???"
      if the bytecode cannot be decoded.
    """
    instructions = []
    try:
      while bytecode:
        op, rest = cstruct.Read(bytecode, InetDiagBcOp)

        if op.code in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
          arg = None
        elif op.code in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                         INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
          # The port is in the "no" field of the following pseudo-op. Read it
          # into its own name so that the real instruction is what gets
          # reported below.
          port_op, rest = cstruct.Read(rest, InetDiagBcOp)
          arg = port_op.no
        elif op.code in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
          cond, rest = cstruct.Read(rest, InetDiagHostcond)
          if cond.family == 0:
            arg = (None, cond.prefix_len, cond.port)
          else:
            addrlen = 4 if cond.family == AF_INET else 16
            addr, rest = rest[:addrlen], rest[addrlen:]
            addr = inet_ntop(cond.family, addr)
            arg = (addr, cond.prefix_len, cond.port)
        elif op.code == INET_DIAG_BC_DEV_COND:
          attrlen = struct.calcsize("=I")
          attr, rest = rest[:attrlen], rest[attrlen:]
          # unpack returns a 1-tuple; report the interface index itself.
          arg = struct.unpack("=I", attr)[0]
        elif op.code == INET_DIAG_BC_MARK_COND:
          arg, rest = cstruct.Read(rest, InetDiagMarkcond)
        else:
          raise ValueError("Unknown opcode %d" % op.code)
        instructions.append((op, arg))
        bytecode = rest

      return instructions
    except (TypeError, ValueError, struct.error):
      # Best effort: this is only used to pretty-print requests.
      return "???"

  def Dump(self, diag_req, bytecode):
    """Dumps the sockets matching a request and optional bytecode filter.

    Args:
      diag_req: The InetDiagReqV2 to send.
      bytecode: Raw filter bytecode; if empty, no filter attribute is sent.

    Returns:
      The result of the underlying _Dump call. FindSockInfoFromFd iterates
      it as (InetDiagMsg, attrs) pairs.
    """
    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
    return out

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 and IPv6 sockets matching the specified parameters.

    Args:
      protocol: Protocol number to put in the request, e.g. IPPROTO_TCP.
      bytecode: Raw filter bytecode, or empty for no filter.
      sock_id: InetDiagSockId for the request; all zeros if None.
      ext: Value of the request's ext field.
      states: Bitmask of socket states to dump.

    Returns:
      The IPv4 results followed by the IPv6 results, as one list.
    """
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    sockets = []
    for family in [AF_INET, AF_INET6]:
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a packed, possibly padded, address to an address string.

    Args:
      family: AF_INET or AF_INET6.
      addr: Packed address bytes; only the first 4 (AF_INET) or 16
        (AF_INET6) bytes are used.

    Returns:
      The address as a string.

    Raises:
      KeyError: family is neither AF_INET nor AF_INET6.
    """
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format."""
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId.

    The result is zero-padded on the right to 16 bytes, the size of the
    address fields of InetDiagSockId.
    """
    padded = SockDiag.RawAddress(addr)
    if len(padded) < 16:
      padded += b"\x00" * (16 - len(padded))
    return padded

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket.

    Args:
      s: A socket object. An unconnected socket gets a zero destination port
        and an all-zeros destination address.

    Returns:
      An InetDiagReqV2 with the socket's family, protocol, addresses, ports
      and bound interface, a zero cookie, and all state bits set.

    Raises:
      OSError: getpeername failed for a reason other than ENOTCONN, or the
        interface the socket is bound to could not be resolved to an index.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    # The name comes back as bytes, empty if the socket is not bound to a
    # device. Strip any NUL terminator before resolving it to an index.
    ifname = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                          net_test.IFNAMSIZ).rstrip(b"\x00")
    iface = if_nametoindex(ifname.decode()) if ifname else 0
    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno != errno.ENOTCONN:
        raise
      dport = 0
      dst = "::" if family == AF_INET6 else "0.0.0.0"
    src = SockDiag.PaddedAddress(src)
    dst = SockDiag.PaddedAddress(dst)
    sock_id = InetDiagSockId((sport, dport, src, dst, iface, b"\x00" * 8))
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  @staticmethod
  def GetSocketCookie(s):
    """Returns the SO_COOKIE value of the specified socket as an integer."""
    cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
    return struct.unpack("=Q", cookie)[0]

  def FindSockInfoFromFd(self, s):
    """Gets a diag_msg and attrs from the kernel for the specified socket.

    Args:
      s: A socket object.

    Returns:
      The (diag_msg, attrs) pair whose inode matches the socket.

    Raises:
      ValueError: The dump was empty or did not contain the socket's inode.
    """
    req = self.DiagReqFromSocket(s)
    # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
    # only uses them when targeting a specific socket with a cookie. Check
    # the inode number to ensure we don't mistakenly match another socket on
    # the same port but with a different IP address.
    inode = os.fstat(s.fileno()).st_ino
    results = self.Dump(req, b"")
    if not results:
      raise ValueError("Dump of %s returned no sockets" % req)
    for diag_msg, attrs in results:
      if diag_msg.inode == inode:
        return diag_msg, attrs
    raise ValueError("Dump of %s did not contain inode %d" % (req, inode))

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    return self.FindSockInfoFromFd(s)[0]

  def GetSockInfo(self, req):
    """Gets a diag_msg and attrs from the kernel for the specified request."""
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us.

    The request reuses the message's family and socket id, and selects only
    the state the message reported.
    """
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request, with an ACK requested, for req."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Looks up the specified socket and sends SOCK_DESTROY for it."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)
421
422
if __name__ == "__main__":
  # Ad-hoc demo: dump TCP sockets of both families in every state, using a
  # socket id whose destination port is 443 and requesting the TOS and TCLASS
  # extensions, then print the result.
  diag = SockDiag()
  diag.DEBUG = True

  https_id = diag._EmptyInetDiagSockId()
  https_id.dport = 443

  # An extension is requested by setting bit (extension number - 1).
  tos_bit = 1 << (INET_DIAG_TOS - 1)
  tclass_bit = 1 << (INET_DIAG_TCLASS - 1)

  print(diag.DumpAllInetSockets(IPPROTO_TCP, b"", sock_id=https_id,
                                ext=tos_bit | tclass_bit, states=0xffffffff))
433