1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module, except:
21... # - S: Nested Struct.
22... # - A: NULL-padded ASCII string. Like s, but printing ignores contiguous
23... #      trailing NULL blocks at the end.
24... import cstruct
25>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
26>>>
27>>>
28>>> # Create instances from a tuple of values, raw bytes, zero-initialized, or
29>>> # using keywords.
30... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
31>>> print(n1)
32NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
33>>>
34>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
35...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
36>>> print(n2)
37NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
38>>>
>>> n3 = NLMsgHdr() # Zero-initialized
40>>> print(n3)
41NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0)
42>>>
>>> n4 = NLMsgHdr(length=44, type=33) # Other fields zero-initialized
44>>> print(n4)
45NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0)
46>>>
47>>> # Serialize to raw bytes.
48... print(n1.Pack().encode("hex"))
492c0000002000020000000000eb010000
50>>>
51>>> # Parse the beginning of a byte stream as a struct, and return the struct
52... # and the remainder of the stream for further reading.
53... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
54...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
55...         "more data")
56>>> cstruct.Read(data, NLMsgHdr)
57(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
58>>>
59>>> # Structs can contain one or more nested structs. The nested struct types
60... # are specified in a list as an optional last argument. Nested structs may
61... # contain nested structs.
62... S = cstruct.Struct("S", "=BI", "byte1 int2")
>>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s4", [S, S])
64>>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
65>>> nn = NN((S((1, 25000)), -29876, N((55, S((5, 6)), 1111, S((7, 8))))))
66>>> nn.n3.s2.int2 = 5
67>>>
68"""
69
70import binascii
71import ctypes
72import string
73import struct
74import re
75
76
77def _PythonFormat(fmt):
78  if "A" in fmt:
79    fmt = fmt.replace("A", "s")
80  return re.split('\d+$', fmt)[0]
81
def CalcSize(fmt):
  """Returns the number of bytes that fmt packs to.

  fmt is translated with _PythonFormat before struct.calcsize measures it.
  """
  python_format = _PythonFormat(fmt)
  return struct.calcsize(python_format)
84
def CalcNumElements(fmt):
  """Returns the number of fields described by fmt, nested structs included."""
  python_format = _PythonFormat(fmt)
  # The struct module knows nothing about "S", so nested structs are counted
  # separately. Whatever is left is counted by unpacking an all-zero buffer of
  # the right size and looking at how many values come back.
  num_structs = python_format.count("S")
  plain_format = python_format.replace("S", "")
  zeros = b"\x00" * struct.calcsize(plain_format)
  return len(struct.unpack(plain_format, zeros)) + num_structs
93
94
class StructMetaclass(type):
  """Metaclass for struct classes.

  Makes len(cls) return the class's _length attribute and names the class
  after the "_name" entry of its namespace.
  """

  def __len__(cls):
    """Returns the value of the class's _length attribute."""
    return cls._length

  def __init__(cls, unused_name, unused_bases, namespace):
    """Initializes the class and renames it to namespace["_name"].

    Args:
      unused_name: Name from the class statement. Replaced by _name.
      unused_bases: Tuple of base classes.
      namespace: The class namespace. Must contain a "_name" entry.

    Raises:
      KeyError: If namespace has no "_name" entry.
    """
    struct_name = namespace["_name"]
    type.__init__(cls, unused_name, unused_bases, namespace)
    # Make the class object have the name that's passed in. type.__init__
    # ignores the name it is given (the name is fixed by type.__new__), so the
    # name has to be assigned explicitly. python3 prints __qualname__ in the
    # repr of a class; on python2 that assignment is a harmless extra attribute.
    cls.__name__ = str(struct_name)
    cls.__qualname__ = str(struct_name)
103
104
def Struct(name, fmt, fieldnames, substructs=None):
  """Returns a new class representing a C-like structure.

  Args:
    name: Name of the struct. Used when printing and in error messages.
    fmt: Format of the struct. Uses the format characters of the struct
      module, plus "S" for a nested struct and "A" for a NULL-padded ASCII
      string.
    fieldnames: String containing the field names, separated by single spaces.
    substructs: Optional sequence of struct classes, one for every "S" in fmt,
      in the order in which they appear in fmt.

  Returns:
    The new struct class.

  Raises:
    ValueError: If the number of field names does not match the number of
      fields described by fmt, or if fmt contains more "S" characters than
      there are entries in substructs.
  """
  if substructs is None:
    substructs = ()

  field_names = fieldnames.split(" ")

  # Everything below is computed in function scope rather than in the class
  # body, so that temporaries such as loop indices do not become attributes of
  # the struct class, where they would shadow fields of the same name.

  # Parse fmt into python_format, converting any S format characters to "XXs",
  # where XX is the length of the struct type's packed representation.
  nested = {}
  asciiz = set()
  format_parts = []
  structs_seen = 0
  for pos, code in enumerate(fmt):
    if code == "S":
      # Nested struct. Record the index in our struct it should go into.
      index = CalcNumElements(fmt[:pos])
      try:
        nested[index] = substructs[structs_seen]
      except (IndexError, KeyError):
        raise ValueError(
            "Invalid cstruct: \"%s\" has more nested structs than the %d "
            "substructs given." % (fmt, len(substructs)))
      structs_seen += 1
      format_parts.append("%ds" % len(nested[index]))
    elif code == "A":
      # Null-terminated ASCII string. Remove digits before the A, so we don't
      # call CalcNumElements on an (invalid) format that ends with a digit.
      start = pos
      while start > 0 and fmt[start - 1].isdigit():
        start -= 1
      asciiz.add(CalcNumElements(fmt[:start]))
      format_parts.append("s")
    else:
      # Standard struct format character.
      format_parts.append(code)
  python_format = "".join(format_parts)

  length = CalcSize(python_format)

  # Check that the number of field names matches the number of fields.
  numfields = len(struct.unpack(python_format, b"\x00" * length))
  if len(field_names) != numfields:
    raise ValueError("Invalid cstruct: \"%s\" has %d elements, \"%s\" has %d."
                     % (fmt, numfields, fieldnames, len(field_names)))

  # Work out the offset of every field. The struct module only ever inserts
  # alignment padding in front of an item, so an item starts at the size of the
  # format up to and including the item, minus the size of the item by itself.
  # Pad bytes ("x") take up space but are not fields, and an item with a repeat
  # count (other than a string) is one field per repetition.
  if python_format[:1] in ("@", "=", "<", ">", "!"):
    byte_order = python_format[0]
  else:
    byte_order = ""
  offsets = []
  consumed = byte_order
  items = re.findall(r"(\d*)([^\d\s])", python_format[len(byte_order):])
  for count, code in items:
    item = count + code
    consumed += item
    start = struct.calcsize(consumed) - struct.calcsize(byte_order + item)
    if code == "x":
      continue
    if code in ("s", "p"):
      # For strings, the count is a length, not a number of fields.
      offsets.append(start)
      continue
    item_size = struct.calcsize(byte_order + code)
    repeat = int(count) if count else 1
    offsets.extend(start + k * item_size for k in range(repeat))

  # Hack to make struct classes use the StructMetaclass class on both python2 and
  # python3. This is needed because in python2 the metaclass is assigned in the
  # class definition, but in python3 it's passed into the constructor via
  # keyword argument. Works by making all structs subclass CStructSuperclass,
  # whose __new__ method uses StructMetaclass as its metaclass.
  #
  # A better option would be to use six.with_metaclass, but the existing python2
  # VM image doesn't have the six module.
  CStructSuperclass = type.__new__(StructMetaclass, 'unused', (), {})

  class CStruct(CStructSuperclass):
    """Class representing a C-like structure."""

    # Name of the struct.
    _name = name
    # List of field names.
    _fieldnames = field_names
    # Dict mapping field indices to nested struct classes.
    _nested = nested
    # Set of indices of the string fields that are ASCII strings.
    _asciiz = asciiz
    # Format of the struct, in the syntax of the struct module.
    _format = python_format
    # Size of the packed struct, in bytes.
    _length = length
    # A dictionary that maps field names to their offsets in the struct.
    _offsets = dict(zip(field_names, offsets))

    def _SetValues(self, values):
      # Replace self._values with the given list. We can't do direct assignment
      # because of the __setattr__ overload on this class.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      """Sets the field values from the first len(self) bytes of data."""
      data = data[:self._length]
      values = list(struct.unpack(self._format, data))
      for index, value in enumerate(values):
        # Nested structs are unpacked as raw bytes; turn them into instances.
        if isinstance(value, bytes) and index in self._nested:
          values[index] = self._nested[index](value)
      self._SetValues(values)

    def __init__(self, tuple_or_bytes=None, **kwargs):
      """Construct an instance of this Struct.

      1. With no args, the whole struct is zero-initialized.
      2. With keyword args, the matching fields are populated; rest are zeroed.
      3. With one tuple as the arg, the fields are assigned based on position.
      4. With one bytes arg, the Struct is parsed from bytes.

      Raises:
        TypeError: If both a positional and keyword args are given, if the
          bytes are too short, or if the tuple has the wrong number of values.
        AttributeError: If a keyword arg does not name a field.
      """
      if tuple_or_bytes is not None and kwargs:
        raise TypeError(
            "%s: cannot specify both a tuple and keyword args" % self._name)

      if tuple_or_bytes is None:
        # Default construct from null bytes.
        self._Parse(b"\x00" * len(self))
        # If any keywords were supplied, set those fields.
        for k, v in kwargs.items():
          setattr(self, k, v)
      elif isinstance(tuple_or_bytes, bytes):
        # Initializing from bytes.
        if len(tuple_or_bytes) < self._length:
          raise TypeError("%s requires a bytes object of length %d, got %d" %
                          (self._name, self._length, len(tuple_or_bytes)))
        self._Parse(tuple_or_bytes)
      else:
        # Initializing from a tuple.
        if len(tuple_or_bytes) != len(self._fieldnames):
          raise TypeError("%s has exactly %d fieldnames: (%s), %d given: (%s)" %
                          (self._name, len(self._fieldnames),
                           ", ".join(self._fieldnames), len(tuple_or_bytes),
                           ", ".join(str(x) for x in tuple_or_bytes)))
        self._SetValues(tuple_or_bytes)

    def _FieldIndex(self, attr):
      """Returns the index of the field named attr, or raises AttributeError."""
      try:
        return self._fieldnames.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      # Only called when normal attribute lookup fails, i.e., for field names.
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      # TODO: check value type against self._format and throw here, or else
      # callers get an unhelpful exception when they call Pack().
      self._values[self._FieldIndex(name)] = value

    def offset(self, name):
      """Returns the offset in bytes of the named field in the packed struct."""
      if "." in name:
        raise NotImplementedError("offset() on nested field")
      return self._offsets[name]

    @classmethod
    def __len__(cls):
      return cls._length

    def __ne__(self, other):
      return not self.__eq__(other)

    def __eq__(self, other):
      return (isinstance(other, self.__class__) and
              self._name == other._name and
              self._fieldnames == other._fieldnames and
              self._values == other._values)

    @staticmethod
    def _MaybePackStruct(value):
      """Returns value packed to bytes if it is a struct, else value itself."""
      if isinstance(type(value), StructMetaclass):
        return value.Pack()
      else:
        return value

    def Pack(self):
      """Returns the struct serialized to bytes."""
      values = [self._MaybePackStruct(v) for v in self._values]
      return struct.pack(self._format, *values)

    def __str__(self):

      def HasNonPrintableChar(s):
        for c in s:
          # Iterating over bytes yields chars in python2 but ints in python3.
          if isinstance(c, int): c = chr(c)
          if c not in string.printable: return True
        return False

      def FieldDesc(index, name, value):
        if isinstance(value, bytes):
          if index in self._asciiz:
            # TODO: use "backslashreplace" when python 2 is no longer supported.
            value = value.rstrip(b"\x00").decode(errors="ignore")
          elif HasNonPrintableChar(value):
            value = binascii.hexlify(value).decode()
        elif isinstance(value, str):
          # Only reachable on python3, where str is text and not bytes. Such a
          # value cannot be packed, but printing the struct must not fail, or
          # the offending field could never be inspected.
          if index in self._asciiz:
            value = value.rstrip("\x00")
          elif HasNonPrintableChar(value):
            value = value.encode("unicode_escape").decode("ascii")
        return "%s=%s" % (name, str(value))

      descriptions = [
          FieldDesc(i, n, v) for i, (n, v) in
          enumerate(zip(self._fieldnames, self._values))]

      return "%s(%s)" % (self._name, ", ".join(descriptions))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns a C pointer to the serialized structure."""
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  return CStruct
299
300
def Read(data, struct_type):
  """Parses the start of data as a struct_type.

  Args:
    data: The bytes to parse.
    struct_type: The struct class to parse the beginning of data as.

  Returns:
    A tuple of the parsed struct and whatever follows it in data.
  """
  consumed = len(struct_type)
  parsed = struct_type(data)
  remainder = data[consumed:]
  return parsed, remainder
304