1#!/usr/bin/python3
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Classes for generic netlink."""
18
19import collections
20from socket import *  # pylint: disable=wildcard-import
21import struct
22
23import cstruct
24import netlink
25
### Generic netlink constants. See include/uapi/linux/genetlink.h.
# The generic netlink control family.
# GetFamily sends its CTRL_CMD_GETFAMILY request to this family.
GENL_ID_CTRL = 16

# Commands.
# The only control command this module issues (see GetFamily).
CTRL_CMD_GETFAMILY = 3

# Attributes.
# _Decode maps a numeric attribute type back to one of these names by looking
# up module constants that start with "CTRL_ATTR_".
CTRL_ATTR_FAMILY_ID = 1
CTRL_ATTR_FAMILY_NAME = 2
CTRL_ATTR_VERSION = 3
CTRL_ATTR_HDRSIZE = 4
CTRL_ATTR_MAXATTR = 5
CTRL_ATTR_OPS = 6
CTRL_ATTR_MCAST_GROUPS = 7

# Attributes nested inside CTRL_ATTR_OPS.
# NOTE(review): these share the "CTRL_ATTR_" prefix and the values 1 and 2 with
# CTRL_ATTR_FAMILY_ID and CTRL_ATTR_FAMILY_NAME above; confirm how
# _GetConstantName resolves that ambiguity for top-level attributes.
CTRL_ATTR_OP_ID = 1
CTRL_ATTR_OP_FLAGS = 2


# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Header with the two fields "cmd" and "version". Assuming cstruct uses the
# struct module's format characters, "BBxx" is two unsigned bytes followed by
# two padding bytes -- TODO confirm against cstruct.
Genlmsghdr = cstruct.Struct("genlmsghdr", "BBxx", "cmd version")
50
51
class GenericNetlink(netlink.NetlinkSocket):
  """Common parent of the generic netlink socket classes.

  Initialises the underlying netlink.NetlinkSocket with NETLINK_GENERIC and
  supplies send/dump helpers that build a Genlmsghdr from a command and
  version before delegating to the base class.
  """

  # Presumably consulted by netlink.NetlinkSocket to decide which commands to
  # log; left empty here -- confirm against netlink.py.
  NL_DEBUG = []

  def __init__(self):
    super().__init__(netlink.NETLINK_GENERIC)

  def _SendCommand(self, family, command, version, data, flags):
    """Sends a request whose payload is a packed Genlmsghdr followed by data.

    Args:
      family: Forwarded unchanged as the first argument of _SendNlRequest.
      command: Value for the header's "cmd" field.
      version: Value for the header's "version" field.
      data: Already-serialized payload appended after the packed header.
      flags: Forwarded unchanged as the last argument of _SendNlRequest.
    """
    payload = Genlmsghdr((command, version)).Pack() + data
    self._SendNlRequest(family, payload, flags)

  def _Dump(self, family, command, version):
    """Delegates to the base class _Dump with a Genlmsghdr for this command.

    Args:
      family: Forwarded unchanged as the first argument of the base _Dump.
      command: Value for the header's "cmd" field.
      version: Value for the header's "version" field.

    Returns:
      Whatever the base class _Dump returns.
    """
    header = Genlmsghdr((command, version))
    return super()._Dump(family, header, Genlmsghdr)
67
68
class GenericNetlinkControl(GenericNetlink):
  """Generic netlink control class.

  This interface is used to manage other generic netlink families. We currently
  use it only to find the family ID for address families of interest.
  """

  def _DecodeOps(self, data):
    """Decodes the payload of a CTRL_ATTR_OPS attribute.

    Each entry is expected to be a 4-byte nest marker followed by a
    CTRL_ATTR_OP_ID attribute and then a CTRL_ATTR_OP_FLAGS attribute, each
    carrying one native-endian unsigned 32-bit value.

    Args:
      data: The raw bytes of the CTRL_ATTR_OPS payload.

    Returns:
      A list of Op(id, flags) namedtuples, one per entry, in payload order.

    Raises:
      ValueError: An entry's attributes are not CTRL_ATTR_OP_ID followed by
        CTRL_ATTR_OP_FLAGS.
    """
    ops = []
    Op = collections.namedtuple("Op", ["id", "flags"])
    # TODO: call _ParseAttributes on the nested data instead of manual parsing.
    while data:
      # Skip the nest marker (2 bytes of length, 2 bytes of index); neither
      # value is needed because entries are consumed strictly in sequence.
      data = data[4:]

      nla, nla_data, data = self._ReadNlAttr(data)
      if nla.nla_type != CTRL_ATTR_OP_ID:
        raise ValueError("Expected CTRL_ATTR_OP_ID, got %d" % nla.nla_type)
      op_id = struct.unpack("=I", nla_data)[0]

      nla, nla_data, data = self._ReadNlAttr(data)
      if nla.nla_type != CTRL_ATTR_OP_FLAGS:
        # The attribute's type field is nla_type (as in the checks above);
        # reading a nonexistent field here would hide this ValueError behind
        # an AttributeError.
        raise ValueError("Expected CTRL_ATTR_OP_FLAGS, got %d" % nla.nla_type)
      op_flags = struct.unpack("=I", nla_data)[0]

      ops.append(Op(op_id, op_flags))
    return ops

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """Decodes generic netlink control attributes to human-readable format.

    Args:
      command: Unused here; part of the decoder signature.
      msg: Unused here; part of the decoder signature.
      nla_type: Numeric attribute type, resolved to a CTRL_ATTR_* name.
      nla_data: The attribute's raw payload bytes.
      nested: Unused here; part of the decoder signature.

    Returns:
      A (name, data) tuple, where data is:
        - an int for FAMILY_ID (16-bit) and VERSION/HDRSIZE/MAXATTR (32-bit),
        - bytes stripped of leading/trailing NULs for FAMILY_NAME,
        - a list of Op namedtuples for OPS,
        - the raw payload unchanged for anything else.
    """
    name = self._GetConstantName(__name__, nla_type, "CTRL_ATTR_")

    if name == "CTRL_ATTR_FAMILY_ID":
      data = struct.unpack("=H", nla_data)[0]
    elif name == "CTRL_ATTR_FAMILY_NAME":
      data = nla_data.strip(b"\x00")
    elif name in ("CTRL_ATTR_VERSION", "CTRL_ATTR_HDRSIZE",
                  "CTRL_ATTR_MAXATTR"):
      data = struct.unpack("=I", nla_data)[0]
    elif name == "CTRL_ATTR_OPS":
      data = self._DecodeOps(nla_data)
    else:
      data = nla_data

    return name, data

  def GetFamily(self, name):
    """Returns the family ID for the specified family name.

    Sends CTRL_CMD_GETFAMILY to the control family with the name as a
    CTRL_ATTR_FAMILY_NAME attribute and reads one reply message.

    Args:
      name: The generic netlink family name to look up, e.g. "tcp_metrics".

    Returns:
      The value of the reply's CTRL_ATTR_FAMILY_ID attribute.

    Raises:
      KeyError: The reply carries no CTRL_ATTR_FAMILY_ID attribute.
    """
    data = self._NlAttrStr(CTRL_ATTR_FAMILY_NAME, name)
    self._SendCommand(GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 0, data,
                      netlink.NLM_F_REQUEST)
    # Only the attributes are of interest; the reply header is discarded.
    _, attrs = self._GetMsg(Genlmsghdr)
    return attrs["CTRL_ATTR_FAMILY_ID"]
120
121
if __name__ == "__main__":
  # When run as a script, look up the "tcp_metrics" family and print its ID.
  control = GenericNetlinkControl()
  family_id = control.GetFamily("tcp_metrics")
  print(family_id)
125