#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of iproute functionality."""

# pylint: disable=g-bad-todo

import os
import socket
import struct
import sys

import cstruct
import util

### Base netlink constants. See include/uapi/linux/netlink.h.

# Netlink protocol numbers, passed as the third argument of
# socket(AF_NETLINK, SOCK_RAW, ...).
NETLINK_ROUTE = 0
NETLINK_SOCK_DIAG = 4
NETLINK_XFRM = 6
NETLINK_GENERIC = 16

# Bits for the nlmsg_flags field of a request header. Note that NLM_F_DUMP
# (0x300) occupies the same bits as NLM_F_REPLACE | NLM_F_EXCL; per netlink.h
# the meaning of these modifier bits depends on the kind of request.
NLM_F_REQUEST = 0x001
NLM_F_ACK = 0x004
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
NLM_F_DUMP = 0x300

# Control message types (nlmsg_type values) shared by all netlink protocols.
NLMSG_ERROR = 0x2
NLMSG_DONE = 0x3

# Wire layouts, as struct-style format strings ("=": native byte order,
# standard sizes, no alignment): the message header, the error code at the
# start of an NLMSG_ERROR payload, and the attribute header.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")

# Attribute payloads are padded to a multiple of this many bytes.
NLA_ALIGNTO = 4

# Names of attributes that may legitimately occur more than once in a single
# netlink message. They don't appear to carry any data, so keeping only the
# last occurrence when parsing loses nothing.
DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]


def MakeConstantPrefixes(prefixes):
  """Returns prefixes sorted longest first, the order CONSTANT_PREFIXES needs.

  NetlinkSocket._GetConstantName relies on this order: the first prefix in the
  list that matches a name is then also the longest matching prefix.
  """
  return sorted(prefixes, key=len, reverse=True)


class NetlinkSocket(object):
  """A basic netlink socket object.

  Provides request encoding, reply parsing and dump handling. Subclasses can
  override _Decode (attribute decoding) and MaybeDebugCommand (request
  logging); the versions defined here are no-ops.
  """

  BUFSIZE = 65536
  DEBUG = False
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints s if DEBUG is enabled."""
    if self.DEBUG:
      print(s)

  def _NlAttr(self, nla_type, data):
    """Encodes one netlink attribute: header, payload and alignment padding.

    Args:
      nla_type: An integer, the attribute type.
      data: A bytes object, the attribute payload.

    Returns:
      The encoded attribute as bytes. nla_len covers the header and payload
      but not the trailing padding.
    """
    assert isinstance(data, bytes)
    datalen = len(data)
    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
    padding = b"\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrIPAddress(self, nla_type, family, address):
    """Encodes a textual IP address of the given family as an attribute."""
    return self._NlAttr(nla_type, socket.inet_pton(family, address))

  def _NlAttrStr(self, nla_type, value):
    """Encodes a str as a NUL-terminated UTF-8 attribute."""
    value = value + "\x00"
    return self._NlAttr(nla_type, value.encode("UTF-8"))

  def _NlAttrU32(self, nla_type, value):
    """Encodes an integer as a native-endian unsigned 32-bit attribute."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  @staticmethod
  def _GetConstantName(module, value, prefix):
    """Returns the name of a constant in module that has the given value.

    Args:
      module: A string, the name of a module in sys.modules to search.
      value: The value to look up.
      prefix: A string; only names starting with it are considered.

    Returns:
      The first (in dir() order) all-upper-case name starting with prefix
      whose value equals value, or value itself if there is no such name.
    """
    def FirstMatching(name, prefixlist):
      for candidate in prefixlist:
        if name.startswith(candidate):
          return candidate
      return None

    thismodule = sys.modules[module]
    constant_prefixes = getattr(thismodule, "CONSTANT_PREFIXES", [])
    for name in dir(thismodule):
      # Check the cheap name test first so that non-constants (dunders,
      # functions, classes) are never compared against value.
      if not name.isupper() or value != getattr(thismodule, name):
        continue
      # If the module explicitly specifies prefixes, only return this name if
      # the passed-in prefix is the longest prefix that matches the name.
      # This ensures, for example, that passing in a prefix of "IFA_" and a
      # value of 1 returns "IFA_ADDRESS" instead of "IFA_F_SECONDARY".
      # The longest matching prefix is always the first matching prefix because
      # CONSTANT_PREFIXES must be sorted longest first.
      if constant_prefixes and prefix != FirstMatching(name, constant_prefixes):
        continue
      if name.startswith(prefix):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """No-op, nonspecific version of decode."""
    return nla_type, nla_data

  def _ReadNlAttr(self, data):
    """Reads one attribute from the front of data.

    Args:
      data: A byte string starting with an NLAttr header.

    Returns:
      A tuple (nla, nla_data, rest): the NLAttr header, the attribute payload
      without padding, and the bytes that follow the padded attribute.

    Raises:
      ValueError: nla_len is smaller than the attribute header itself.
    """
    # Read the nlattr header.
    nla, data = cstruct.Read(data, NLAttr)

    # A length shorter than the header would give a negative payload length,
    # and the negative slice offsets below would not consume the input
    # correctly.
    if nla.nla_len < len(nla):
      raise ValueError("Invalid attribute length %d" % nla.nla_len)

    # Read the data.
    datalen = nla.nla_len - len(nla)
    padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
    nla_data, data = data[:datalen], data[padded_len:]

    return nla, nla_data, data

  def _ParseAttributes(self, command, msg, data, nested):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes each one using _Decode,
    and returns the result in a dict keyed by the name _Decode returns for it
    (for this base class, the raw integer attribute type).

    Args:
      command: An integer, the rtnetlink command being carried out.
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.
      nested: A list, outermost first, of each of the attributes the NLAttrs are
        nested inside. Empty for non-nested attributes.

    Returns:
      A dictionary mapping decoded attribute names to decoded values.

    Raises:
      ValueError: There was a duplicate attribute type not listed in
        DUP_ATTRS_OK, or an attribute had an invalid length.
    """
    attributes = {}
    while data:
      nla, nla_data, data = self._ReadNlAttr(data)

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data,
                                        nested)

      if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      if not nested:
        self._Debug(" %s" % (str((nla_name, nla_data))))

    return attributes

  def _OpenNetlinkSocket(self, family, groups):
    """Opens a netlink socket of the given protocol connected to the kernel.

    Args:
      family: An integer netlink protocol, e.g., NETLINK_ROUTE.
      groups: The multicast groups to bind to, or a false value to skip the
        explicit bind.

    Returns:
      The connected socket.
    """
    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
    try:
      if groups:
        sock.bind((0, groups))
      sock.connect((0, 0))  # The kernel.
    except BaseException:
      # Don't leak the file descriptor if bind or connect fails.
      sock.close()
      raise
    return sock

  def __init__(self, family, groups=None):
    # Global sequence number.
    self.seq = 0
    self.sock = self._OpenNetlinkSocket(family, groups)
    # Local port ID reported by getsockname(); sent as nlmsg_pid in requests.
    self.pid = self.sock.getsockname()[1]

  def close(self):
    """Closes the socket. Calling close() more than once is harmless."""
    if self.sock is not None:
      self.sock.close()
      self.sock = None

  def __del__(self):
    # sock may not exist if __init__ raised before assigning it.
    if getattr(self, "sock", None):
      self.close()

  def MaybeDebugCommand(self, command, flags, data):
    # Default no-op implementation to be overridden by subclasses.
    pass

  def _Send(self, msg):
    """Sends msg (bytes) and advances the sequence number."""
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Receives up to BUFSIZE bytes from the socket."""
    return self.sock.recv(self.BUFSIZE)

  def _ExpectDone(self):
    """Reads one response and raises ValueError unless it is NLMSG_DONE."""
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that response starts with a successful NLMSG_ERROR (an ack).

    Raises:
      IOError: The message carries a nonzero error code.
      ValueError: The message is not an NLMSG_ERROR.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      # The kernel reports errors as negative errno values.
      error = -NLMsgErr(data).error
      if error:
        raise IOError(error, os.strerror(error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Reads one response and checks it with _ParseAck."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request; waits for an ack if flags has NLM_F_ACK."""
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses a Netlink message into a header and a dictionary of attributes.

    Returns:
      A tuple ((nlmsg, attributes), rest). For NLMSG_ERROR and NLMSG_DONE
      messages the first element is (None, None) and rest is the data
      following the netlink header.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug(" %s" % nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      self._Debug("done")
      return (None, None), data

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug(" %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen], [])
    data = data[attrlen:]
    return (nlmsg, attributes), data

  def _GetMsg(self, msgtype):
    """Reads one response and parses its first message as msgtype.

    Raises IOError (via _ParseAck) if the response is an error.
    """
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _ParseMsgsUntilEnd(self, data, msgtype):
    """Parses consecutive messages from one buffer up to a terminator.

    Parsing stops at the end of the buffer or at the first NLMSG_DONE or
    NLMSG_ERROR message. A terminating message is not assumed to start its
    own receive buffer; it may follow data messages in the same buffer.

    Args:
      data: A byte string containing zero or more netlink messages.
      msgtype: A cstruct.Struct, the type to parse data messages as.

    Returns:
      A tuple (msgs, end_type, rest): msgs is a list of (msg, attrs) tuples;
      end_type is NLMSG_DONE or NLMSG_ERROR if a terminating message was
      found, else None; rest is the data starting at the terminating message
      (empty when end_type is None).
    """
    msgs = []
    while data:
      msg_type = NLMsgHdr(data).type
      if msg_type in (NLMSG_DONE, NLMSG_ERROR):
        return msgs, msg_type, data
      msg, data = self._ParseNLMsg(data, msgtype)
      msgs.append(msg)
    return msgs, None, data

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses the messages in data, stopping at NLMSG_DONE or NLMSG_ERROR.

    Args:
      msgtype: A cstruct.Struct, the type to parse the messages as.
      data: A byte string containing netlink messages.
      expect_done: If True and data did not itself contain an NLMSG_DONE,
        read one more response and require it to be NLMSG_DONE.

    Returns:
      A list of (msg, attrs) tuples.
    """
    msgs, end_type, _ = self._ParseMsgsUntilEnd(data, msgtype)
    if expect_done and end_type != NLMSG_DONE:
      self._ExpectDone()
    return msgs

  def _Dump(self, command, msg, msgtype, attrs=b""):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A bytes object, the raw bytes of any request attributes to
        include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: The kernel answered with a nonzero error.
      ValueError: The socket returned an empty response.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = b"" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading netlink messages until we see NLMSG_DONE or NLMSG_ERROR,
    # which may arrive in the same buffer as the last data messages.
    out = []
    while True:
      data = self._Recv()
      if not data:
        # Without this check an empty read would make us loop forever.
        raise ValueError("Unexpected empty netlink response during dump")
      msgs, end_type, rest = self._ParseMsgsUntilEnd(data, msgtype)
      out.extend(msgs)
      if end_type == NLMSG_ERROR:
        # Likely means that the kernel didn't like our dump request.
        # _ParseAck raises if the error code is nonzero.
        self._ParseAck(rest)
      if end_type is not None:
        return out