1#
2# Copyright (C) 2013 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17"""Verifying the integrity of a Chrome OS update payload.
18
19This module is used internally by the main Payload class for verifying the
20integrity of an update payload. The interface for invoking the checks is as
21follows:
22
23  checker = PayloadChecker(payload)
24  checker.Run(...)
25"""
26
27from __future__ import absolute_import
28from __future__ import print_function
29
30import array
31import base64
32import collections
33import hashlib
34import itertools
35import os
36import subprocess
37
38# pylint: disable=redefined-builtin
39from six.moves import range
40
41from update_payload import common
42from update_payload import error
43from update_payload import format_utils
44from update_payload import histogram
45import update_metadata_pb2
46
47#
48# Constants.
49#
50
51_CHECK_MOVE_SAME_SRC_DST_BLOCK = 'move-same-src-dst-block'
52_CHECK_PAYLOAD_SIG = 'payload-sig'
53CHECKS_TO_DISABLE = (
54    _CHECK_MOVE_SAME_SRC_DST_BLOCK,
55    _CHECK_PAYLOAD_SIG,
56)
57
58_TYPE_FULL = 'full'
59_TYPE_DELTA = 'delta'
60
61_DEFAULT_BLOCK_SIZE = 4096
62
63_DEFAULT_PUBKEY_BASE_NAME = 'update-payload-key.pub.pem'
64_DEFAULT_PUBKEY_FILE_NAME = os.path.join(os.path.dirname(__file__),
65                                         _DEFAULT_PUBKEY_BASE_NAME)
66
67# Supported minor version map to payload types allowed to be using them.
68_SUPPORTED_MINOR_VERSIONS = {
69    0: (_TYPE_FULL,),
70    2: (_TYPE_DELTA,),
71    3: (_TYPE_DELTA,),
72    4: (_TYPE_DELTA,),
73    5: (_TYPE_DELTA,),
74    6: (_TYPE_DELTA,),
75    7: (_TYPE_DELTA,),
76}
77
78
79#
80# Helper functions.
81#
82
83def _IsPowerOfTwo(val):
84  """Returns True iff val is a power of two."""
85  return val > 0 and (val & (val - 1)) == 0
86
87
88def _AddFormat(format_func, value):
89  """Adds a custom formatted representation to ordinary string representation.
90
91  Args:
92    format_func: A value formatter.
93    value: Value to be formatted and returned.
94
95  Returns:
96    A string 'x (y)' where x = str(value) and y = format_func(value).
97  """
98  ret = str(value)
99  formatted_str = format_func(value)
100  if formatted_str:
101    ret += ' (%s)' % formatted_str
102  return ret
103
104
def _AddHumanReadableSize(size):
  """Renders a byte count, followed by its human readable form if any.

  The human readable form is produced by format_utils.BytesToHumanReadable and
  is omitted when that formatter returns a falsy value.
  """
  to_human = format_utils.BytesToHumanReadable
  return _AddFormat(to_human, size)
108
109
110#
111# Payload report generator.
112#
113
114class _PayloadReport(object):
115  """A payload report generator.
116
117  A report is essentially a sequence of nodes, which represent data points. It
118  is initialized to have a "global", untitled section. A node may be a
119  sub-report itself.
120  """
121
122  # Report nodes: Field, sub-report, section.
123  class Node(object):
124    """A report node interface."""
125
126    @staticmethod
127    def _Indent(indent, line):
128      """Indents a line by a given indentation amount.
129
130      Args:
131        indent: The indentation amount.
132        line: The line content (string).
133
134      Returns:
135        The properly indented line (string).
136      """
137      return '%*s%s' % (indent, '', line)
138
139    def GenerateLines(self, base_indent, sub_indent, curr_section):
140      """Generates the report lines for this node.
141
142      Args:
143        base_indent: Base indentation for each line.
144        sub_indent: Additional indentation for sub-nodes.
145        curr_section: The current report section object.
146
147      Returns:
148        A pair consisting of a list of properly indented report lines and a new
149        current section object.
150      """
151      raise NotImplementedError
152
153  class FieldNode(Node):
154    """A field report node, representing a (name, value) pair."""
155
156    def __init__(self, name, value, linebreak, indent):
157      super(_PayloadReport.FieldNode, self).__init__()
158      self.name = name
159      self.value = value
160      self.linebreak = linebreak
161      self.indent = indent
162
163    def GenerateLines(self, base_indent, sub_indent, curr_section):
164      """Generates a properly formatted 'name : value' entry."""
165      report_output = ''
166      if self.name:
167        report_output += self.name.ljust(curr_section.max_field_name_len) + ' :'
168      value_lines = str(self.value).splitlines()
169      if self.linebreak and self.name:
170        report_output += '\n' + '\n'.join(
171            ['%*s%s' % (self.indent, '', line) for line in value_lines])
172      else:
173        if self.name:
174          report_output += ' '
175        report_output += '%*s' % (self.indent, '')
176        cont_line_indent = len(report_output)
177        indented_value_lines = [value_lines[0]]
178        indented_value_lines.extend(['%*s%s' % (cont_line_indent, '', line)
179                                     for line in value_lines[1:]])
180        report_output += '\n'.join(indented_value_lines)
181
182      report_lines = [self._Indent(base_indent, line + '\n')
183                      for line in report_output.split('\n')]
184      return report_lines, curr_section
185
186  class SubReportNode(Node):
187    """A sub-report node, representing a nested report."""
188
189    def __init__(self, title, report):
190      super(_PayloadReport.SubReportNode, self).__init__()
191      self.title = title
192      self.report = report
193
194    def GenerateLines(self, base_indent, sub_indent, curr_section):
195      """Recurse with indentation."""
196      report_lines = [self._Indent(base_indent, self.title + ' =>\n')]
197      report_lines.extend(self.report.GenerateLines(base_indent + sub_indent,
198                                                    sub_indent))
199      return report_lines, curr_section
200
201  class SectionNode(Node):
202    """A section header node."""
203
204    def __init__(self, title=None):
205      super(_PayloadReport.SectionNode, self).__init__()
206      self.title = title
207      self.max_field_name_len = 0
208
209    def GenerateLines(self, base_indent, sub_indent, curr_section):
210      """Dump a title line, return self as the (new) current section."""
211      report_lines = []
212      if self.title:
213        report_lines.append(self._Indent(base_indent,
214                                         '=== %s ===\n' % self.title))
215      return report_lines, self
216
217  def __init__(self):
218    self.report = []
219    self.last_section = self.global_section = self.SectionNode()
220    self.is_finalized = False
221
222  def GenerateLines(self, base_indent, sub_indent):
223    """Generates the lines in the report, properly indented.
224
225    Args:
226      base_indent: The indentation used for root-level report lines.
227      sub_indent: The indentation offset used for sub-reports.
228
229    Returns:
230      A list of indented report lines.
231    """
232    report_lines = []
233    curr_section = self.global_section
234    for node in self.report:
235      node_report_lines, curr_section = node.GenerateLines(
236          base_indent, sub_indent, curr_section)
237      report_lines.extend(node_report_lines)
238
239    return report_lines
240
241  def Dump(self, out_file, base_indent=0, sub_indent=2):
242    """Dumps the report to a file.
243
244    Args:
245      out_file: File object to output the content to.
246      base_indent: Base indentation for report lines.
247      sub_indent: Added indentation for sub-reports.
248    """
249    report_lines = self.GenerateLines(base_indent, sub_indent)
250    if report_lines and not self.is_finalized:
251      report_lines.append('(incomplete report)\n')
252
253    for line in report_lines:
254      out_file.write(line)
255
256  def AddField(self, name, value, linebreak=False, indent=0):
257    """Adds a field/value pair to the payload report.
258
259    Args:
260      name: The field's name.
261      value: The field's value.
262      linebreak: Whether the value should be printed on a new line.
263      indent: Amount of extra indent for each line of the value.
264    """
265    assert not self.is_finalized
266    if name and self.last_section.max_field_name_len < len(name):
267      self.last_section.max_field_name_len = len(name)
268    self.report.append(self.FieldNode(name, value, linebreak, indent))
269
270  def AddSubReport(self, title):
271    """Adds and returns a sub-report with a title."""
272    assert not self.is_finalized
273    sub_report = self.SubReportNode(title, type(self)())
274    self.report.append(sub_report)
275    return sub_report.report
276
277  def AddSection(self, title):
278    """Adds a new section title."""
279    assert not self.is_finalized
280    self.last_section = self.SectionNode(title)
281    self.report.append(self.last_section)
282
283  def Finalize(self):
284    """Seals the report, marking it as complete."""
285    self.is_finalized = True
286
287
288#
289# Payload verification.
290#
291
292class PayloadChecker(object):
293  """Checking the integrity of an update payload.
294
295  This is a short-lived object whose purpose is to isolate the logic used for
296  verifying the integrity of an update payload.
297  """
298
299  def __init__(self, payload, assert_type=None, block_size=0,
300               allow_unhashed=False, disabled_tests=()):
301    """Initialize the checker.
302
303    Args:
304      payload: The payload object to check.
305      assert_type: Assert that payload is either 'full' or 'delta' (optional).
306      block_size: Expected filesystem / payload block size (optional).
307      allow_unhashed: Allow operations with unhashed data blobs.
308      disabled_tests: Sequence of tests to disable.
309    """
310    if not payload.is_init:
311      raise ValueError('Uninitialized update payload.')
312
313    # Set checker configuration.
314    self.payload = payload
315    self.block_size = block_size if block_size else _DEFAULT_BLOCK_SIZE
316    if not _IsPowerOfTwo(self.block_size):
317      raise error.PayloadError(
318          'Expected block (%d) size is not a power of two.' % self.block_size)
319    if assert_type not in (None, _TYPE_FULL, _TYPE_DELTA):
320      raise error.PayloadError('Invalid assert_type value (%r).' %
321                               assert_type)
322    self.payload_type = assert_type
323    self.allow_unhashed = allow_unhashed
324
325    # Disable specific tests.
326    self.check_move_same_src_dst_block = (
327        _CHECK_MOVE_SAME_SRC_DST_BLOCK not in disabled_tests)
328    self.check_payload_sig = _CHECK_PAYLOAD_SIG not in disabled_tests
329
330    # Reset state; these will be assigned when the manifest is checked.
331    self.sigs_offset = 0
332    self.sigs_size = 0
333    self.old_part_info = {}
334    self.new_part_info = {}
335    self.new_fs_sizes = collections.defaultdict(int)
336    self.old_fs_sizes = collections.defaultdict(int)
337    self.minor_version = None
338    self.major_version = None
339
340  @staticmethod
341  def _CheckElem(msg, name, report, is_mandatory, is_submsg, convert=str,
342                 msg_name=None, linebreak=False, indent=0):
343    """Adds an element from a protobuf message to the payload report.
344
345    Checks to see whether a message contains a given element, and if so adds
346    the element value to the provided report. A missing mandatory element
347    causes an exception to be raised.
348
349    Args:
350      msg: The message containing the element.
351      name: The name of the element.
352      report: A report object to add the element name/value to.
353      is_mandatory: Whether or not this element must be present.
354      is_submsg: Whether this element is itself a message.
355      convert: A function for converting the element value for reporting.
356      msg_name: The name of the message object (for error reporting).
357      linebreak: Whether the value report should induce a line break.
358      indent: Amount of indent used for reporting the value.
359
360    Returns:
361      A pair consisting of the element value and the generated sub-report for
362      it (if the element is a sub-message, None otherwise). If the element is
363      missing, returns (None, None).
364
365    Raises:
366      error.PayloadError if a mandatory element is missing.
367    """
368    element_result = collections.namedtuple('element_result', ['msg', 'report'])
369
370    if not msg.HasField(name):
371      if is_mandatory:
372        raise error.PayloadError('%smissing mandatory %s %r.' %
373                                 (msg_name + ' ' if msg_name else '',
374                                  'sub-message' if is_submsg else 'field',
375                                  name))
376      return element_result(None, None)
377
378    value = getattr(msg, name)
379    if is_submsg:
380      return element_result(value, report and report.AddSubReport(name))
381    else:
382      if report:
383        report.AddField(name, convert(value), linebreak=linebreak,
384                        indent=indent)
385      return element_result(value, None)
386
387  @staticmethod
388  def _CheckRepeatedElemNotPresent(msg, field_name, msg_name):
389    """Checks that a repeated element is not specified in the message.
390
391    Args:
392      msg: The message containing the element.
393      field_name: The name of the element.
394      msg_name: The name of the message object (for error reporting).
395
396    Raises:
397      error.PayloadError if the repeated element is present or non-empty.
398    """
399    if getattr(msg, field_name, None):
400      raise error.PayloadError('%sfield %r not empty.' %
401                               (msg_name + ' ' if msg_name else '', field_name))
402
403  @staticmethod
404  def _CheckElemNotPresent(msg, field_name, msg_name):
405    """Checks that an element is not specified in the message.
406
407    Args:
408      msg: The message containing the element.
409      field_name: The name of the element.
410      msg_name: The name of the message object (for error reporting).
411
412    Raises:
413      error.PayloadError if the repeated element is present.
414    """
415    if msg.HasField(field_name):
416      raise error.PayloadError('%sfield %r exists.' %
417                               (msg_name + ' ' if msg_name else '', field_name))
418
419  @staticmethod
420  def _CheckMandatoryField(msg, field_name, report, msg_name, convert=str,
421                           linebreak=False, indent=0):
422    """Adds a mandatory field; returning first component from _CheckElem."""
423    return PayloadChecker._CheckElem(msg, field_name, report, True, False,
424                                     convert=convert, msg_name=msg_name,
425                                     linebreak=linebreak, indent=indent)[0]
426
427  @staticmethod
428  def _CheckOptionalField(msg, field_name, report, convert=str,
429                          linebreak=False, indent=0):
430    """Adds an optional field; returning first component from _CheckElem."""
431    return PayloadChecker._CheckElem(msg, field_name, report, False, False,
432                                     convert=convert, linebreak=linebreak,
433                                     indent=indent)[0]
434
435  @staticmethod
436  def _CheckMandatorySubMsg(msg, submsg_name, report, msg_name):
437    """Adds a mandatory sub-message; wrapper for _CheckElem."""
438    return PayloadChecker._CheckElem(msg, submsg_name, report, True, True,
439                                     msg_name)
440
441  @staticmethod
442  def _CheckOptionalSubMsg(msg, submsg_name, report):
443    """Adds an optional sub-message; wrapper for _CheckElem."""
444    return PayloadChecker._CheckElem(msg, submsg_name, report, False, True)
445
446  @staticmethod
447  def _CheckPresentIff(val1, val2, name1, name2, obj_name):
448    """Checks that val1 is None iff val2 is None.
449
450    Args:
451      val1: first value to be compared.
452      val2: second value to be compared.
453      name1: name of object holding the first value.
454      name2: name of object holding the second value.
455      obj_name: Name of the object containing these values.
456
457    Raises:
458      error.PayloadError if assertion does not hold.
459    """
460    if None in (val1, val2) and val1 is not val2:
461      present, missing = (name1, name2) if val2 is None else (name2, name1)
462      raise error.PayloadError('%r present without %r%s.' %
463                               (present, missing,
464                                ' in ' + obj_name if obj_name else ''))
465
466  @staticmethod
467  def _CheckPresentIffMany(vals, name, obj_name):
468    """Checks that a set of vals and names imply every other element.
469
470    Args:
471      vals: The set of values to be compared.
472      name: The name of the objects holding the corresponding value.
473      obj_name: Name of the object containing these values.
474
475    Raises:
476      error.PayloadError if assertion does not hold.
477    """
478    if any(vals) and not all(vals):
479      raise error.PayloadError('%r is not present in all values%s.' %
480                               (name, ' in ' + obj_name if obj_name else ''))
481
482  @staticmethod
483  def _Run(cmd, send_data=None):
484    """Runs a subprocess, returns its output.
485
486    Args:
487      cmd: Sequence of command-line argument for invoking the subprocess.
488      send_data: Data to feed to the process via its stdin.
489
490    Returns:
491      A tuple containing the stdout and stderr output of the process.
492    """
493    run_process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
494                                   stdout=subprocess.PIPE)
495    try:
496      result = run_process.communicate(input=send_data)
497    finally:
498      exit_code = run_process.wait()
499
500    if exit_code:
501      raise RuntimeError('Subprocess %r failed with code %r.' %
502                         (cmd, exit_code))
503
504    return result
505
506  @staticmethod
507  def _CheckSha256Signature(sig_data, pubkey_file_name, actual_hash, sig_name):
508    """Verifies an actual hash against a signed one.
509
510    Args:
511      sig_data: The raw signature data.
512      pubkey_file_name: Public key used for verifying signature.
513      actual_hash: The actual hash digest.
514      sig_name: Signature name for error reporting.
515
516    Raises:
517      error.PayloadError if signature could not be verified.
518    """
519    if len(sig_data) != 256:
520      raise error.PayloadError(
521          '%s: signature size (%d) not as expected (256).' %
522          (sig_name, len(sig_data)))
523    signed_data, _ = PayloadChecker._Run(
524        ['openssl', 'rsautl', '-verify', '-pubin', '-inkey', pubkey_file_name],
525        send_data=sig_data)
526
527    if len(signed_data) != len(common.SIG_ASN1_HEADER) + 32:
528      raise error.PayloadError('%s: unexpected signed data length (%d).' %
529                               (sig_name, len(signed_data)))
530
531    if not signed_data.startswith(common.SIG_ASN1_HEADER):
532      raise error.PayloadError('%s: not containing standard ASN.1 prefix.' %
533                               sig_name)
534
535    signed_hash = signed_data[len(common.SIG_ASN1_HEADER):]
536    if signed_hash != actual_hash:
537      raise error.PayloadError(
538          '%s: signed hash (%s) different from actual (%s).' %
539          (sig_name, common.FormatSha256(signed_hash),
540           common.FormatSha256(actual_hash)))
541
542  @staticmethod
543  def _CheckBlocksFitLength(length, num_blocks, block_size, length_name,
544                            block_name=None):
545    """Checks that a given length fits given block space.
546
547    This ensures that the number of blocks allocated is appropriate for the
548    length of the data residing in these blocks.
549
550    Args:
551      length: The actual length of the data.
552      num_blocks: The number of blocks allocated for it.
553      block_size: The size of each block in bytes.
554      length_name: Name of length (used for error reporting).
555      block_name: Name of block (used for error reporting).
556
557    Raises:
558      error.PayloadError if the aforementioned invariant is not satisfied.
559    """
560    # Check: length <= num_blocks * block_size.
561    if length > num_blocks * block_size:
562      raise error.PayloadError(
563          '%s (%d) > num %sblocks (%d) * block_size (%d).' %
564          (length_name, length, block_name or '', num_blocks, block_size))
565
566    # Check: length > (num_blocks - 1) * block_size.
567    if length <= (num_blocks - 1) * block_size:
568      raise error.PayloadError(
569          '%s (%d) <= (num %sblocks - 1 (%d)) * block_size (%d).' %
570          (length_name, length, block_name or '', num_blocks - 1, block_size))
571
572  def _CheckManifestMinorVersion(self, report):
573    """Checks the payload manifest minor_version field.
574
575    Args:
576      report: The report object to add to.
577
578    Raises:
579      error.PayloadError if any of the checks fail.
580    """
581    self.minor_version = self._CheckOptionalField(self.payload.manifest,
582                                                  'minor_version', report)
583    if self.minor_version in _SUPPORTED_MINOR_VERSIONS:
584      if self.payload_type not in _SUPPORTED_MINOR_VERSIONS[self.minor_version]:
585        raise error.PayloadError(
586            'Minor version %d not compatible with payload type %s.' %
587            (self.minor_version, self.payload_type))
588    elif self.minor_version is None:
589      raise error.PayloadError('Minor version is not set.')
590    else:
591      raise error.PayloadError('Unsupported minor version: %d' %
592                               self.minor_version)
593
594  def _CheckManifest(self, report, part_sizes=None):
595    """Checks the payload manifest.
596
597    Args:
598      report: A report object to add to.
599      part_sizes: Map of partition label to partition size in bytes.
600
601    Returns:
602      A tuple consisting of the partition block size used during the update
603      (integer), the signatures block offset and size.
604
605    Raises:
606      error.PayloadError if any of the checks fail.
607    """
608    self.major_version = self.payload.header.version
609
610    part_sizes = part_sizes or collections.defaultdict(int)
611    manifest = self.payload.manifest
612    report.AddSection('manifest')
613
614    # Check: block_size must exist and match the expected value.
615    actual_block_size = self._CheckMandatoryField(manifest, 'block_size',
616                                                  report, 'manifest')
617    if actual_block_size != self.block_size:
618      raise error.PayloadError('Block_size (%d) not as expected (%d).' %
619                               (actual_block_size, self.block_size))
620
621    # Check: signatures_offset <==> signatures_size.
622    self.sigs_offset = self._CheckOptionalField(manifest, 'signatures_offset',
623                                                report)
624    self.sigs_size = self._CheckOptionalField(manifest, 'signatures_size',
625                                              report)
626    self._CheckPresentIff(self.sigs_offset, self.sigs_size,
627                          'signatures_offset', 'signatures_size', 'manifest')
628
629    for part in manifest.partitions:
630      name = part.partition_name
631      self.old_part_info[name] = self._CheckOptionalSubMsg(
632          part, 'old_partition_info', report)
633      self.new_part_info[name] = self._CheckMandatorySubMsg(
634          part, 'new_partition_info', report, 'manifest.partitions')
635
636    # Check: Old-style partition infos should not be specified.
637    for _, part in common.CROS_PARTITIONS:
638      self._CheckElemNotPresent(manifest, 'old_%s_info' % part, 'manifest')
639      self._CheckElemNotPresent(manifest, 'new_%s_info' % part, 'manifest')
640
641    # Check: If old_partition_info is specified anywhere, it must be
642    # specified everywhere.
643    old_part_msgs = [part.msg for part in self.old_part_info.values() if part]
644    self._CheckPresentIffMany(old_part_msgs, 'old_partition_info',
645                              'manifest.partitions')
646
647    is_delta = any(part and part.msg for part in self.old_part_info.values())
648    if is_delta:
649      # Assert/mark delta payload.
650      if self.payload_type == _TYPE_FULL:
651        raise error.PayloadError(
652            'Apparent full payload contains old_{kernel,rootfs}_info.')
653      self.payload_type = _TYPE_DELTA
654
655      for part, (msg, part_report) in self.old_part_info.items():
656        # Check: {size, hash} present in old_{kernel,rootfs}_info.
657        field = 'old_%s_info' % part
658        self.old_fs_sizes[part] = self._CheckMandatoryField(msg, 'size',
659                                                            part_report, field)
660        self._CheckMandatoryField(msg, 'hash', part_report, field,
661                                  convert=common.FormatSha256)
662
663        # Check: old_{kernel,rootfs} size must fit in respective partition.
664        if self.old_fs_sizes[part] > part_sizes[part] > 0:
665          raise error.PayloadError(
666              'Old %s content (%d) exceed partition size (%d).' %
667              (part, self.old_fs_sizes[part], part_sizes[part]))
668    else:
669      # Assert/mark full payload.
670      if self.payload_type == _TYPE_DELTA:
671        raise error.PayloadError(
672            'Apparent delta payload missing old_{kernel,rootfs}_info.')
673      self.payload_type = _TYPE_FULL
674
675    # Check: new_{kernel,rootfs}_info present; contains {size, hash}.
676    for part, (msg, part_report) in self.new_part_info.items():
677      field = 'new_%s_info' % part
678      self.new_fs_sizes[part] = self._CheckMandatoryField(msg, 'size',
679                                                          part_report, field)
680      self._CheckMandatoryField(msg, 'hash', part_report, field,
681                                convert=common.FormatSha256)
682
683      # Check: new_{kernel,rootfs} size must fit in respective partition.
684      if self.new_fs_sizes[part] > part_sizes[part] > 0:
685        raise error.PayloadError(
686            'New %s content (%d) exceed partition size (%d).' %
687            (part, self.new_fs_sizes[part], part_sizes[part]))
688
689    # Check: minor_version makes sense for the payload type. This check should
690    # run after the payload type has been set.
691    self._CheckManifestMinorVersion(report)
692
693  def _CheckLength(self, length, total_blocks, op_name, length_name):
694    """Checks whether a length matches the space designated in extents.
695
696    Args:
697      length: The total length of the data.
698      total_blocks: The total number of blocks in extents.
699      op_name: Operation name (for error reporting).
700      length_name: Length name (for error reporting).
701
702    Raises:
703      error.PayloadError is there a problem with the length.
704    """
705    # Check: length is non-zero.
706    if length == 0:
707      raise error.PayloadError('%s: %s is zero.' % (op_name, length_name))
708
709    # Check that length matches number of blocks.
710    self._CheckBlocksFitLength(length, total_blocks, self.block_size,
711                               '%s: %s' % (op_name, length_name))
712
713  def _CheckExtents(self, extents, usable_size, block_counters, name):
714    """Checks a sequence of extents.
715
716    Args:
717      extents: The sequence of extents to check.
718      usable_size: The usable size of the partition to which the extents apply.
719      block_counters: Array of counters corresponding to the number of blocks.
720      name: The name of the extent block.
721
722    Returns:
723      The total number of blocks in the extents.
724
725    Raises:
726      error.PayloadError if any of the entailed checks fails.
727    """
728    total_num_blocks = 0
729    for ex, ex_name in common.ExtentIter(extents, name):
730      # Check: Mandatory fields.
731      start_block = PayloadChecker._CheckMandatoryField(ex, 'start_block',
732                                                        None, ex_name)
733      num_blocks = PayloadChecker._CheckMandatoryField(ex, 'num_blocks', None,
734                                                       ex_name)
735      end_block = start_block + num_blocks
736
737      # Check: num_blocks > 0.
738      if num_blocks == 0:
739        raise error.PayloadError('%s: extent length is zero.' % ex_name)
740
741      # Check: Make sure we're within the partition limit.
742      if usable_size and end_block * self.block_size > usable_size:
743        raise error.PayloadError(
744            '%s: extent (%s) exceeds usable partition size (%d).' %
745            (ex_name, common.FormatExtent(ex, self.block_size), usable_size))
746
747      # Record block usage.
748      for i in range(start_block, end_block):
749        block_counters[i] += 1
750
751      total_num_blocks += num_blocks
752
753    return total_num_blocks
754
755  def _CheckReplaceOperation(self, op, data_length, total_dst_blocks, op_name):
756    """Specific checks for REPLACE/REPLACE_BZ/REPLACE_XZ operations.
757
758    Args:
759      op: The operation object from the manifest.
760      data_length: The length of the data blob associated with the operation.
761      total_dst_blocks: Total number of blocks in dst_extents.
762      op_name: Operation name for error reporting.
763
764    Raises:
765      error.PayloadError if any check fails.
766    """
767    # Check: total_dst_blocks is not a floating point.
768    if isinstance(total_dst_blocks, float):
769      raise error.PayloadError('%s: contains invalid data type of '
770                               'total_dst_blocks.' % op_name)
771
772    # Check: Does not contain src extents.
773    if op.src_extents:
774      raise error.PayloadError('%s: contains src_extents.' % op_name)
775
776    # Check: Contains data.
777    if data_length is None:
778      raise error.PayloadError('%s: missing data_{offset,length}.' % op_name)
779
780    if op.type == common.OpType.REPLACE:
781      PayloadChecker._CheckBlocksFitLength(data_length, total_dst_blocks,
782                                           self.block_size,
783                                           op_name + '.data_length', 'dst')
784    else:
785      # Check: data_length must be smaller than the allotted dst blocks.
786      if data_length >= total_dst_blocks * self.block_size:
787        raise error.PayloadError(
788            '%s: data_length (%d) must be less than allotted dst block '
789            'space (%d * %d).' %
790            (op_name, data_length, total_dst_blocks, self.block_size))
791
792  def _CheckZeroOperation(self, op, op_name):
793    """Specific checks for ZERO operations.
794
795    Args:
796      op: The operation object from the manifest.
797      op_name: Operation name for error reporting.
798
799    Raises:
800      error.PayloadError if any check fails.
801    """
802    # Check: Does not contain src extents, data_length and data_offset.
803    if op.src_extents:
804      raise error.PayloadError('%s: contains src_extents.' % op_name)
805    if op.data_length:
806      raise error.PayloadError('%s: contains data_length.' % op_name)
807    if op.data_offset:
808      raise error.PayloadError('%s: contains data_offset.' % op_name)
809
810  def _CheckAnyDiffOperation(self, op, data_length, total_dst_blocks, op_name):
811    """Specific checks for SOURCE_BSDIFF, PUFFDIFF and BROTLI_BSDIFF
812       operations.
813
814    Args:
815      op: The operation.
816      data_length: The length of the data blob associated with the operation.
817      total_dst_blocks: Total number of blocks in dst_extents.
818      op_name: Operation name for error reporting.
819
820    Raises:
821      error.PayloadError if any check fails.
822    """
823    # Check: data_{offset,length} present.
824    if data_length is None:
825      raise error.PayloadError('%s: missing data_{offset,length}.' % op_name)
826
827    # Check: data_length is strictly smaller than the allotted dst blocks.
828    if data_length >= total_dst_blocks * self.block_size:
829      raise error.PayloadError(
830          '%s: data_length (%d) must be smaller than allotted dst space '
831          '(%d * %d = %d).' %
832          (op_name, data_length, total_dst_blocks, self.block_size,
833           total_dst_blocks * self.block_size))
834
835    # Check the existence of src_length and dst_length for legacy bsdiffs.
836    if op.type == common.OpType.SOURCE_BSDIFF and self.minor_version <= 3:
837      if not op.HasField('src_length') or not op.HasField('dst_length'):
838        raise error.PayloadError('%s: require {src,dst}_length.' % op_name)
839    else:
840      if op.HasField('src_length') or op.HasField('dst_length'):
841        raise error.PayloadError('%s: unneeded {src,dst}_length.' % op_name)
842
843  def _CheckSourceCopyOperation(self, data_offset, total_src_blocks,
844                                total_dst_blocks, op_name):
845    """Specific checks for SOURCE_COPY.
846
847    Args:
848      data_offset: The offset of a data blob for the operation.
849      total_src_blocks: Total number of blocks in src_extents.
850      total_dst_blocks: Total number of blocks in dst_extents.
851      op_name: Operation name for error reporting.
852
853    Raises:
854      error.PayloadError if any check fails.
855    """
856    # Check: No data_{offset,length}.
857    if data_offset is not None:
858      raise error.PayloadError('%s: contains data_{offset,length}.' % op_name)
859
860    # Check: total_src_blocks == total_dst_blocks.
861    if total_src_blocks != total_dst_blocks:
862      raise error.PayloadError(
863          '%s: total src blocks (%d) != total dst blocks (%d).' %
864          (op_name, total_src_blocks, total_dst_blocks))
865
866  def _CheckAnySourceOperation(self, op, total_src_blocks, op_name):
867    """Specific checks for SOURCE_* operations.
868
869    Args:
870      op: The operation object from the manifest.
871      total_src_blocks: Total number of blocks in src_extents.
872      op_name: Operation name for error reporting.
873
874    Raises:
875      error.PayloadError if any check fails.
876    """
877    # Check: total_src_blocks != 0.
878    if total_src_blocks == 0:
879      raise error.PayloadError('%s: no src blocks in a source op.' % op_name)
880
881    # Check: src_sha256_hash present in minor version >= 3.
882    if self.minor_version >= 3 and op.src_sha256_hash is None:
883      raise error.PayloadError('%s: source hash missing.' % op_name)
884
  def _CheckOperation(self, op, op_name, old_block_counters, new_block_counters,
                      old_usable_size, new_usable_size, prev_data_offset,
                      blob_hash_counts):
    """Checks a single update operation.

    First runs the checks common to every operation type:
      - src/dst extents;
      - data_offset/data_length co-presence;
      - a non-empty dst_extents;
      - {src,dst}_length;
      - the data blob hash;
      - contiguity of the data section.
    Then dispatches to the type-specific checker, provided the payload's
    minor version allows that operation type.

    Args:
      op: The operation object.
      op_name: Operation name string for error reporting.
      old_block_counters: Arrays of block read counters (the caller passes
        None when there is no old filesystem).
      new_block_counters: Arrays of block write counters.
      old_usable_size: The overall usable size for src data in bytes.
      new_usable_size: The overall usable size for dst data in bytes.
      prev_data_offset: Offset of last used data bytes; if the operation has
        a data_offset it must be exactly this value.
      blob_hash_counts: Counters for hashed/unhashed blobs, keyed 'hashed'
        and 'unhashed'; incremented in place.

    Returns:
      The length in bytes of the data blob associated with the operation, or
      0 if the operation carries no data.

    Raises:
      error.PayloadError if any check has failed.
    """
    # Check extents. The block counters are handed to _CheckExtents, which
    # presumably tallies per-block usage (see its definition).
    total_src_blocks = self._CheckExtents(
        op.src_extents, old_usable_size, old_block_counters,
        op_name + '.src_extents')
    total_dst_blocks = self._CheckExtents(
        op.dst_extents, new_usable_size, new_block_counters,
        op_name + '.dst_extents')

    # Check: data_offset present <==> data_length present.
    data_offset = self._CheckOptionalField(op, 'data_offset', None)
    data_length = self._CheckOptionalField(op, 'data_length', None)
    self._CheckPresentIff(data_offset, data_length, 'data_offset',
                          'data_length', op_name)

    # Check: At least one dst_extent.
    if not op.dst_extents:
      raise error.PayloadError('%s: dst_extents is empty.' % op_name)

    # Check {src,dst}_length, if present.
    if op.HasField('src_length'):
      self._CheckLength(op.src_length, total_src_blocks, op_name, 'src_length')
    if op.HasField('dst_length'):
      self._CheckLength(op.dst_length, total_dst_blocks, op_name, 'dst_length')

    if op.HasField('data_sha256_hash'):
      blob_hash_counts['hashed'] += 1

      # Check: Operation carries data.
      if data_offset is None:
        raise error.PayloadError(
            '%s: data_sha256_hash present but no data_{offset,length}.' %
            op_name)

      # Check: Hash verifies correctly. The whole blob is read into memory
      # to be hashed.
      actual_hash = hashlib.sha256(self.payload.ReadDataBlob(data_offset,
                                                             data_length))
      if op.data_sha256_hash != actual_hash.digest():
        raise error.PayloadError(
            '%s: data_sha256_hash (%s) does not match actual hash (%s).' %
            (op_name, common.FormatSha256(op.data_sha256_hash),
             common.FormatSha256(actual_hash.digest())))
    elif data_offset is not None:
      # A data blob without a hash is only tolerated when allow_unhashed is
      # set.
      if self.allow_unhashed:
        blob_hash_counts['unhashed'] += 1
      else:
        raise error.PayloadError('%s: unhashed operation not allowed.' %
                                 op_name)

    if data_offset is not None:
      # Check: Contiguous use of data section.
      if data_offset != prev_data_offset:
        raise error.PayloadError(
            '%s: data offset (%d) not matching amount used so far (%d).' %
            (op_name, data_offset, prev_data_offset))

    # Type-specific checks. Minimum minor versions:
    #   ZERO >= 4, SOURCE_COPY >= 2, SOURCE_BSDIFF >= 2,
    #   BROTLI_BSDIFF >= 4, PUFFDIFF >= 5.
    # REPLACE* is accepted at any version. Any other type/version combination
    # falls through to the error.
    if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
                   common.OpType.REPLACE_XZ):
      self._CheckReplaceOperation(op, data_length, total_dst_blocks, op_name)
    elif op.type == common.OpType.ZERO and self.minor_version >= 4:
      self._CheckZeroOperation(op, op_name)
    elif op.type == common.OpType.SOURCE_COPY and self.minor_version >= 2:
      self._CheckSourceCopyOperation(data_offset, total_src_blocks,
                                     total_dst_blocks, op_name)
      self._CheckAnySourceOperation(op, total_src_blocks, op_name)
    elif op.type == common.OpType.SOURCE_BSDIFF and self.minor_version >= 2:
      self._CheckAnyDiffOperation(op, data_length, total_dst_blocks, op_name)
      self._CheckAnySourceOperation(op, total_src_blocks, op_name)
    elif op.type == common.OpType.BROTLI_BSDIFF and self.minor_version >= 4:
      self._CheckAnyDiffOperation(op, data_length, total_dst_blocks, op_name)
      self._CheckAnySourceOperation(op, total_src_blocks, op_name)
    elif op.type == common.OpType.PUFFDIFF and self.minor_version >= 5:
      self._CheckAnyDiffOperation(op, data_length, total_dst_blocks, op_name)
      self._CheckAnySourceOperation(op, total_src_blocks, op_name)
    else:
      raise error.PayloadError(
          'Operation %s (type %d) not allowed in minor version %d' %
          (op_name, op.type, self.minor_version))
    # Operations without a data blob contribute 0 bytes.
    return data_length if data_length is not None else 0
985
986  def _SizeToNumBlocks(self, size):
987    """Returns the number of blocks needed to contain a given byte size."""
988    return (size + self.block_size - 1) // self.block_size
989
990  def _AllocBlockCounters(self, total_size):
991    """Returns a freshly initialized array of block counters.
992
993    Note that the generated array is not portable as is due to byte-ordering
994    issues, hence it should not be serialized.
995
996    Args:
997      total_size: The total block size in bytes.
998
999    Returns:
1000      An array of unsigned short elements initialized to zero, one for each of
1001      the blocks necessary for containing the partition.
1002    """
1003    return array.array('H',
1004                       itertools.repeat(0, self._SizeToNumBlocks(total_size)))
1005
1006  def _CheckOperations(self, operations, report, base_name, old_fs_size,
1007                       new_fs_size, old_usable_size, new_usable_size,
1008                       prev_data_offset):
1009    """Checks a sequence of update operations.
1010
1011    Args:
1012      operations: The sequence of operations to check.
1013      report: The report object to add to.
1014      base_name: The name of the operation block.
1015      old_fs_size: The old filesystem size in bytes.
1016      new_fs_size: The new filesystem size in bytes.
1017      old_usable_size: The overall usable size of the old partition in bytes.
1018      new_usable_size: The overall usable size of the new partition in bytes.
1019      prev_data_offset: Offset of last used data bytes.
1020
1021    Returns:
1022      The total data blob size used.
1023
1024    Raises:
1025      error.PayloadError if any of the checks fails.
1026    """
1027    # The total size of data blobs used by operations scanned thus far.
1028    total_data_used = 0
1029    # Counts of specific operation types.
1030    op_counts = {
1031        common.OpType.REPLACE: 0,
1032        common.OpType.REPLACE_BZ: 0,
1033        common.OpType.REPLACE_XZ: 0,
1034        common.OpType.ZERO: 0,
1035        common.OpType.SOURCE_COPY: 0,
1036        common.OpType.SOURCE_BSDIFF: 0,
1037        common.OpType.PUFFDIFF: 0,
1038        common.OpType.BROTLI_BSDIFF: 0,
1039    }
1040    # Total blob sizes for each operation type.
1041    op_blob_totals = {
1042        common.OpType.REPLACE: 0,
1043        common.OpType.REPLACE_BZ: 0,
1044        common.OpType.REPLACE_XZ: 0,
1045        # SOURCE_COPY operations don't have blobs.
1046        common.OpType.SOURCE_BSDIFF: 0,
1047        common.OpType.PUFFDIFF: 0,
1048        common.OpType.BROTLI_BSDIFF: 0,
1049    }
1050    # Counts of hashed vs unhashed operations.
1051    blob_hash_counts = {
1052        'hashed': 0,
1053        'unhashed': 0,
1054    }
1055
1056    # Allocate old and new block counters.
1057    old_block_counters = (self._AllocBlockCounters(old_usable_size)
1058                          if old_fs_size else None)
1059    new_block_counters = self._AllocBlockCounters(new_usable_size)
1060
1061    # Process and verify each operation.
1062    op_num = 0
1063    for op, op_name in common.OperationIter(operations, base_name):
1064      op_num += 1
1065
1066      # Check: Type is valid.
1067      if op.type not in op_counts:
1068        raise error.PayloadError('%s: invalid type (%d).' % (op_name, op.type))
1069      op_counts[op.type] += 1
1070
1071      curr_data_used = self._CheckOperation(
1072          op, op_name, old_block_counters, new_block_counters,
1073          old_usable_size, new_usable_size,
1074          prev_data_offset + total_data_used, blob_hash_counts)
1075      if curr_data_used:
1076        op_blob_totals[op.type] += curr_data_used
1077        total_data_used += curr_data_used
1078
1079    # Report totals and breakdown statistics.
1080    report.AddField('total operations', op_num)
1081    report.AddField(
1082        None,
1083        histogram.Histogram.FromCountDict(op_counts,
1084                                          key_names=common.OpType.NAMES),
1085        indent=1)
1086    report.AddField('total blobs', sum(blob_hash_counts.values()))
1087    report.AddField(None,
1088                    histogram.Histogram.FromCountDict(blob_hash_counts),
1089                    indent=1)
1090    report.AddField('total blob size', _AddHumanReadableSize(total_data_used))
1091    report.AddField(
1092        None,
1093        histogram.Histogram.FromCountDict(op_blob_totals,
1094                                          formatter=_AddHumanReadableSize,
1095                                          key_names=common.OpType.NAMES),
1096        indent=1)
1097
1098    # Report read/write histograms.
1099    if old_block_counters:
1100      report.AddField('block read hist',
1101                      histogram.Histogram.FromKeyList(old_block_counters),
1102                      linebreak=True, indent=1)
1103
1104    new_write_hist = histogram.Histogram.FromKeyList(
1105        new_block_counters[:self._SizeToNumBlocks(new_fs_size)])
1106    report.AddField('block write hist', new_write_hist, linebreak=True,
1107                    indent=1)
1108
1109    # Check: Full update must write each dst block once.
1110    if self.payload_type == _TYPE_FULL and new_write_hist.GetKeys() != [1]:
1111      raise error.PayloadError(
1112          '%s: not all blocks written exactly once during full update.' %
1113          base_name)
1114
1115    return total_data_used
1116
1117  def _CheckSignatures(self, report, pubkey_file_name):
1118    """Checks a payload's signature block."""
1119    sigs_raw = self.payload.ReadDataBlob(self.sigs_offset, self.sigs_size)
1120    sigs = update_metadata_pb2.Signatures()
1121    sigs.ParseFromString(sigs_raw)
1122    report.AddSection('signatures')
1123
1124    # Check: At least one signature present.
1125    if not sigs.signatures:
1126      raise error.PayloadError('Signature block is empty.')
1127
1128    # Check that we don't have the signature operation blob at the end (used to
1129    # be for major version 1).
1130    last_partition = self.payload.manifest.partitions[-1]
1131    if last_partition.operations:
1132      last_op = last_partition.operations[-1]
1133      # Check: signatures_{offset,size} must match the last (fake) operation.
1134      if (last_op.type == common.OpType.REPLACE and
1135          last_op.data_offset == self.sigs_offset and
1136          last_op.data_length == self.sigs_size):
1137        raise error.PayloadError('It seems like the last operation is the '
1138                                 'signature blob. This is an invalid payload.')
1139
1140    # Compute the checksum of all data up to signature blob.
1141    # TODO(garnold) we're re-reading the whole data section into a string
1142    # just to compute the checksum; instead, we could do it incrementally as
1143    # we read the blobs one-by-one, under the assumption that we're reading
1144    # them in order (which currently holds). This should be reconsidered.
1145    payload_hasher = self.payload.manifest_hasher.copy()
1146    common.Read(self.payload.payload_file, self.sigs_offset,
1147                offset=self.payload.data_offset, hasher=payload_hasher)
1148
1149    for sig, sig_name in common.SignatureIter(sigs.signatures, 'signatures'):
1150      sig_report = report.AddSubReport(sig_name)
1151
1152      # Check: Signature contains mandatory fields.
1153      self._CheckMandatoryField(sig, 'data', None, sig_name)
1154      sig_report.AddField('data len', len(sig.data))
1155
1156      # Check: Signatures pertains to actual payload hash.
1157      if sig.data:
1158        self._CheckSha256Signature(sig.data, pubkey_file_name,
1159                                   payload_hasher.digest(), sig_name)
1160
  def Run(self, pubkey_file_name=None, metadata_sig_file=None, metadata_size=0,
          part_sizes=None, report_out_file=None):
    """Checker entry point, invoking all checks.

    Checks, in order:
      1. The metadata size and metadata signature, when provided.
      2. The payload header version.
      3. The manifest, including the absence of deprecated fields.
      4. Every partition's operations.
      5. That the header/manifest, data blobs and signature blob account for
         the whole payload file.
      6. The payload signatures, when enabled and a signature blob exists.
    The report is dumped to report_out_file even when a check fails.

    Args:
      pubkey_file_name: Public key used for signature verification (default:
        update-payload-key.pub.pem next to this module).
      metadata_sig_file: File object holding a base64-encoded metadata
        signature, if verification is desired.
      metadata_size: Metadata size, if verification is desired; 0 (the
        default) skips the check.
      part_sizes: Mapping of partition label to size in bytes (default: infer
        based on payload type and version or filesystem). A non-zero entry
        overrides both the old and new usable size of that partition.
      report_out_file: File object to dump the report to.

    Raises:
      error.PayloadError if payload verification failed.
    """
    if not pubkey_file_name:
      pubkey_file_name = _DEFAULT_PUBKEY_FILE_NAME

    report = _PayloadReport()

    # Get payload file size by seeking to its end (whence=2), then rewind.
    self.payload.payload_file.seek(0, 2)
    payload_file_size = self.payload.payload_file.tell()
    self.payload.ResetFile()

    try:
      # Check metadata_size (if provided).
      if metadata_size and self.payload.metadata_size != metadata_size:
        raise error.PayloadError('Invalid payload metadata size in payload(%d) '
                                 'vs given(%d)' % (self.payload.metadata_size,
                                                   metadata_size))

      # Check metadata signature (if provided).
      if metadata_sig_file:
        metadata_sig = base64.b64decode(metadata_sig_file.read())
        self._CheckSha256Signature(metadata_sig, pubkey_file_name,
                                   self.payload.manifest_hasher.digest(),
                                   'metadata signature')

      # Part 1: Check the file header.
      report.AddSection('header')
      # Check: Payload version is valid.
      if self.payload.header.version not in (1, 2):
        raise error.PayloadError('Unknown payload version (%d).' %
                                 self.payload.header.version)
      report.AddField('version', self.payload.header.version)
      report.AddField('manifest len', self.payload.header.manifest_len)

      # Part 2: Check the manifest.
      self._CheckManifest(report, part_sizes)
      assert self.payload_type, 'payload type should be known by now'

      # Make sure deprecated values are not present in the payload.
      for field in ('install_operations', 'kernel_install_operations'):
        self._CheckRepeatedElemNotPresent(self.payload.manifest, field,
                                          'manifest')
      for field in ('old_kernel_info', 'old_rootfs_info',
                    'new_kernel_info', 'new_rootfs_info'):
        self._CheckElemNotPresent(self.payload.manifest, field, 'manifest')

      # Part 3: Check the operations of each partition. The running
      # total_blob_size is passed as prev_data_offset, so each partition's
      # blobs must begin exactly where the previous partition's ended.
      total_blob_size = 0
      for part, operations in ((p.partition_name, p.operations)
                               for p in self.payload.manifest.partitions):
        report.AddSection('%s operations' % part)

        # Usable sizes default to the filesystem sizes (presumably populated
        # by _CheckManifest above); a non-zero part_sizes entry overrides
        # both the old and new usable size.
        new_fs_usable_size = self.new_fs_sizes[part]
        old_fs_usable_size = self.old_fs_sizes[part]

        if part_sizes is not None and part_sizes.get(part, None):
          new_fs_usable_size = old_fs_usable_size = part_sizes[part]

        # TODO(chromium:243559) only default to the filesystem size if no
        # explicit size provided *and* the partition size is not embedded in the
        # payload; see issue for more details.
        total_blob_size += self._CheckOperations(
            operations, report, '%s_install_operations' % part,
            self.old_fs_sizes[part], self.new_fs_sizes[part],
            old_fs_usable_size, new_fs_usable_size, total_blob_size)

      # Check: Operations data reach the end of the payload file.
      used_payload_size = self.payload.data_offset + total_blob_size
      # Major versions 2 and higher have a signature at the end, so it should be
      # considered in the total size of the image.
      if self.sigs_size:
        used_payload_size += self.sigs_size

      if used_payload_size != payload_file_size:
        raise error.PayloadError(
            'Used payload size (%d) different from actual file size (%d).' %
            (used_payload_size, payload_file_size))

      # Part 4: Handle payload signatures message.
      if self.check_payload_sig and self.sigs_size:
        self._CheckSignatures(report, pubkey_file_name)

      # Part 5: Summary.
      report.AddSection('summary')
      report.AddField('update type', self.payload_type)

      report.Finalize()
    finally:
      # Dump whatever was collected, even if a check above raised.
      if report_out_file:
        report.Dump(report_out_file)
1264