#!/usr/bin/python3
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of sock_diag functionality."""

# pylint: disable=g-bad-todo

import errno
import os
from socket import *  # pylint: disable=wildcard-import
import struct

import csocket
import cstruct
import net_test
import netlink

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20
SOCK_DESTROY = 21

### inet_diag_constants. See include/uapi/linux/inet_diag.h
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
INET_DIAG_REQ_BYTECODE = 1

# Extensions.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9
INET_DIAG_PROTOCOL = 10
INET_DIAG_SKV6ONLY = 11
INET_DIAG_LOCALS = 12
INET_DIAG_PEERS = 13
INET_DIAG_PAD = 14
INET_DIAG_MARK = 15

# Bytecode operations.
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
INET_DIAG_BC_DEV_COND = 9
INET_DIAG_BC_MARK_COND = 10

# Prefixes used by netlink helpers to map numeric values back to the constant
# names above (e.g. when decoding attributes for debug output).
CONSTANT_PREFIXES = netlink.MakeConstantPrefixes([
    "INET_DIAG_", "INET_DIAG_REQ_", "INET_DIAG_BC_"])

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name

# NOTE(review): the "!" prefix makes every field big-endian, which is right for
# sport/dport/src/dst but presumably wrong for the host-order iface field on
# little-endian machines. cstruct formats take a single byte-order prefix, so
# fixing this needs a layout change -- TODO confirm against the kernel header.
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
# "S" embeds the nested struct listed in the trailing list argument.
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
# No byte-order prefix: native order and alignment. The layout (1+1+2 bytes)
# has no padding either way.
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")
InetDiagMarkcond = cstruct.Struct("InetDiagMarkcond", "=II", "mark mask")

SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.
TCP_TIME_WAIT = 6
# Bitmask of every TCP state except TIME_WAIT, for the idiag_states field.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)


class SockDiag(netlink.NetlinkSocket):
  """Netlink socket speaking the NETLINK_SOCK_DIAG protocol."""

  NL_DEBUG = []

  def __init__(self):
    super(SockDiag, self).__init__(netlink.NETLINK_SOCK_DIAG)

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """Decodes netlink attributes to Python types.

    Args:
      command: the netlink command the attribute belongs to (unused here).
      msg: the already-decoded message header; its family selects how the
        attribute type is named, and InetDiagReqV2 messages use the
        INET_DIAG_REQ_ namespace.
      nla_type: the numeric attribute type.
      nla_data: the raw attribute payload, as bytes.
      nested: unused here.

    Returns:
      A (name, data) tuple. name is the matching constant name for AF_INET and
      AF_INET6 messages, or the raw integer type otherwise. data is the decoded
      value, or the raw payload if the attribute is not understood.
    """
    if msg.family == AF_INET or msg.family == AF_INET6:
      if isinstance(msg, InetDiagReqV2):
        prefix = "INET_DIAG_REQ_"
      else:
        prefix = "INET_DIAG_"
      name = self._GetConstantName(__name__, nla_type, prefix)
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ("INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS",
                "INET_DIAG_SKV6ONLY"):
      # Single-byte attributes.
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      data = nla_data.strip(b"\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    elif name == "INET_DIAG_MARK":
      data = struct.unpack("=I", nla_data)[0]
    elif name == "INET_DIAG_REQ_BYTECODE":
      data = self.DecodeBytecode(nla_data)
    elif name in ("INET_DIAG_LOCALS", "INET_DIAG_PEERS"):
      data = []
      while nla_data:
        # The SCTP diag code always appears to copy sizeof(sockaddr_storage)
        # bytes, but does so from a union sctp_addr which is at most as long
        # as a sockaddr_in6.
        addr, nla_data = cstruct.Read(nla_data, csocket.SockaddrStorage)
        if addr.family == AF_INET:
          addr = csocket.SockaddrIn(addr.Pack())
        elif addr.family == AF_INET6:
          addr = csocket.SockaddrIn6(addr.Pack())
        data.append(addr)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, unused_flags, data):
    """Prints an outgoing request if "ALL" or "SOCK" is in NL_DEBUG."""
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an all-zero InetDiagSockId."""
    return InetDiagSockId((b"\x00" * len(InetDiagSockId)))

  @staticmethod
  def PackBytecode(instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    Args:
      instructions: list of instruction tuples

    Returns:
      bytes, the raw bytecode.

    Raises:
      ValueError: if a jump offset is not positive, a jump lands beyond the
        reject label, or an opcode is not supported.
    """
    args = []
    # positions[i] is the byte offset of instruction i. Two extra entries are
    # appended: the accept label (end of program) and the reject label.
    positions = [0]

    for op, yes, no, arg in instructions:

      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
        arg = b""
      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
        # The port travels in a second op-sized record: code=0, yes=0,
        # no=port.
        arg = b"\x00\x00" + struct.pack("=H", arg)
      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        addr = inet_pton(family, addr)
        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
      elif op == INET_DIAG_BC_MARK_COND:
        if isinstance(arg, tuple):
          mark, mask = arg
        else:
          mark, mask = arg, 0xffffffff
        arg = InetDiagMarkcond((mark, mask)).Pack()
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      length = len(InetDiagBcOp) + len(arg)
      positions.append(positions[-1] + length)

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      # Jumps may target at most the reject label (the last position).
      if i + yes >= len(positions) or i + no >= len(positions):
        raise ValueError("Jump from instruction %d goes past reject" % i)
      yes = positions[i + yes] - positions[i]
      no = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes, no)).Pack() + args[i])

    return b"".join(packed)

  @staticmethod
  def DecodeBytecode(bytecode):
    """Decodes inet_diag bytecode into a list of (op, arg) tuples.

    This is the inverse of PackBytecode, but returns absolute byte offsets in
    op.yes/op.no rather than instruction counts.

    Args:
      bytecode: raw bytecode, as bytes.

    Returns:
      A list of (InetDiagBcOp, arg) tuples, or the string "???" if the
      bytecode cannot be parsed.
    """
    instructions = []
    try:
      while bytecode:
        op, rest = cstruct.Read(bytecode, InetDiagBcOp)

        if op.code in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
          arg = None
        elif op.code in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                         INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
          # The port is in the "no" field of a second, op-sized record. Read
          # it into its own variable so the comparison op is what's recorded.
          port_op, rest = cstruct.Read(rest, InetDiagBcOp)
          arg = port_op.no
        elif op.code in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
          cond, rest = cstruct.Read(rest, InetDiagHostcond)
          if cond.family == 0:
            arg = (None, cond.prefix_len, cond.port)
          else:
            addrlen = 4 if cond.family == AF_INET else 16
            addr, rest = rest[:addrlen], rest[addrlen:]
            addr = inet_ntop(cond.family, addr)
            arg = (addr, cond.prefix_len, cond.port)
        elif op.code == INET_DIAG_BC_DEV_COND:
          attrlen = struct.calcsize("=I")
          attr, rest = rest[:attrlen], rest[attrlen:]
          arg = struct.unpack("=I", attr)[0]
        elif op.code == INET_DIAG_BC_MARK_COND:
          arg, rest = cstruct.Read(rest, InetDiagMarkcond)
        else:
          raise ValueError("Unknown opcode %d" % op.code)
        instructions.append((op, arg))
        bytecode = rest

      return instructions
    except (TypeError, ValueError, struct.error):
      # struct.error covers truncated bytecode.
      return "???"

  def Dump(self, diag_req, bytecode):
    """Sends a SOCK_DIAG_BY_FAMILY dump request and returns the replies.

    Args:
      diag_req: an InetDiagReqV2 selecting family, protocol, states, etc.
      bytecode: optional filter bytecode (e.g. from PackBytecode). If falsy,
        no INET_DIAG_REQ_BYTECODE attribute is attached.

    Returns:
      The result of netlink.NetlinkSocket._Dump for InetDiagMsg replies;
      callers in this class iterate it as (InetDiagMsg, attrs) pairs.
    """
    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
    return out

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    sockets = []
    for family in [AF_INET, AF_INET6]:
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a padded InetDiagSockId address field to a string.

    Args:
      family: AF_INET or AF_INET6; selects how many leading bytes are used.
      addr: the 16-byte src or dst field of an InetDiagSockId.

    Returns:
      The address in presentation format (e.g. "192.0.2.1" or "2001:db8::1").
    """
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format."""
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId."""
    padded = SockDiag.RawAddress(addr)
    if len(padded) < 16:
      padded += b"\x00" * (16 - len(padded))
    return padded

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket.

    Unconnected sockets (ENOTCONN from getpeername) match the wildcard peer
    address and port 0. Any other getpeername error is re-raised.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                         net_test.IFNAMSIZ)
    # The returned device name may carry NUL padding; strip it before the
    # name-to-index lookup. An empty name means "not bound to a device".
    iface = iface.split(b"\x00", 1)[0]
    iface = if_nametoindex(iface.decode()) if iface else 0
    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno != errno.ENOTCONN:
        raise
      dport = 0
      dst = "::" if family == AF_INET6 else "0.0.0.0"
    src = SockDiag.PaddedAddress(src)
    dst = SockDiag.PaddedAddress(dst)
    sock_id = InetDiagSockId((sport, dport, src, dst, iface, b"\x00" * 8))
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  @staticmethod
  def GetSocketCookie(s):
    """Returns the socket's SO_COOKIE as an integer."""
    cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
    return struct.unpack("=Q", cookie)[0]

  def FindSockInfoFromFd(self, s):
    """Gets a diag_msg and attrs from the kernel for the specified socket.

    Raises:
      ValueError: if the dump is empty or no result has the socket's inode.
    """
    req = self.DiagReqFromSocket(s)
    # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
    # only uses them when targeting a specific socket with a cookie. Check the
    # inode number to ensure we don't mistakenly match another socket on the
    # same port but with a different IP address.
    inode = os.fstat(s.fileno()).st_ino
    results = self.Dump(req, b"")
    if not results:
      raise ValueError("Dump of %s returned no sockets" % req)
    for diag_msg, attrs in results:
      if diag_msg.inode == inode:
        return diag_msg, attrs
    raise ValueError("Dump of %s did not contain inode %d" % (req, inode))

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    return self.FindSockInfoFromFd(s)[0]

  def GetSockInfo(self, req):
    """Gets a diag_msg and attrs from the kernel for the specified request."""
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us."""
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request (with ACK) for the socket in req."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Looks up socket s via sock_diag and destroys it with SOCK_DESTROY."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)


if __name__ == "__main__":
  # Dump all TCP sockets with destination port 443, requesting the TOS and
  # TCLASS extensions, and print the results.
  n = SockDiag()
  n.DEBUG = True
  sock_id = n._EmptyInetDiagSockId()
  sock_id.dport = 443
  ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
  states = 0xffffffff
  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, b"",
                                   sock_id=sock_id, ext=ext, states=states)
  print(diag_msgs)