1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17import array
18import copy
19import functools
20import heapq
21import itertools
22import logging
23import multiprocessing
24import os
25import os.path
26import re
27import sys
28import threading
29import zlib
30from collections import deque, namedtuple, OrderedDict
31
32import common
33from images import EmptyImage
34from rangelib import RangeSet
35
36__all__ = ["BlockImageDiff"]
37
38logger = logging.getLogger(__name__)
39
40# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
41PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
42
43
44def compute_patch(srcfile, tgtfile, imgdiff=False):
45  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
46  patchfile = common.MakeTempFile(prefix='patch-')
47
48  cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
49  cmd.extend([srcfile, tgtfile, patchfile])
50
51  # Don't dump the bsdiff/imgdiff commands, which are not useful for the case
52  # here, since they contain temp filenames only.
53  proc = common.Run(cmd, verbose=False)
54  output, _ = proc.communicate()
55
56  if proc.returncode != 0:
57    raise ValueError(output)
58
59  with open(patchfile, 'rb') as f:
60    return PatchInfo(imgdiff, f.read())
61
62
63class Transfer(object):
64  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
65               src_sha1, style, by_id):
66    self.tgt_name = tgt_name
67    self.src_name = src_name
68    self.tgt_ranges = tgt_ranges
69    self.src_ranges = src_ranges
70    self.tgt_sha1 = tgt_sha1
71    self.src_sha1 = src_sha1
72    self.style = style
73
74    # We use OrderedDict rather than dict so that the output is repeatable;
75    # otherwise it would depend on the hash values of the Transfer objects.
76    self.goes_before = OrderedDict()
77    self.goes_after = OrderedDict()
78
79    self.stash_before = []
80    self.use_stash = []
81
82    self.id = len(by_id)
83    by_id.append(self)
84
85    self._patch_info = None
86
87  @property
88  def patch_info(self):
89    return self._patch_info
90
91  @patch_info.setter
92  def patch_info(self, info):
93    if info:
94      assert self.style == "diff"
95    self._patch_info = info
96
97  def NetStashChange(self):
98    return (sum(sr.size() for (_, sr) in self.stash_before) -
99            sum(sr.size() for (_, sr) in self.use_stash))
100
101  def ConvertToNew(self):
102    assert self.style != "new"
103    self.use_stash = []
104    self.style = "new"
105    self.src_ranges = RangeSet()
106    self.patch_info = None
107
108  def __str__(self):
109    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
110            " to " + str(self.tgt_ranges) + ">")
111
112
113@functools.total_ordering
114class HeapItem(object):
115  def __init__(self, item):
116    self.item = item
117    # Negate the score since python's heap is a min-heap and we want the
118    # maximum score.
119    self.score = -item.score
120
121  def clear(self):
122    self.item = None
123
124  def __bool__(self):
125    return self.item is not None
126
127  # Python 2 uses __nonzero__, while Python 3 uses __bool__.
128  __nonzero__ = __bool__
129
130  # The rest operations are generated by functools.total_ordering decorator.
131  def __eq__(self, other):
132    return self.score == other.score
133
134  def __le__(self, other):
135    return self.score <= other.score
136
137
138class ImgdiffStats(object):
139  """A class that collects imgdiff stats.
140
141  It keeps track of the files that will be applied imgdiff while generating
142  BlockImageDiff. It also logs the ones that cannot use imgdiff, with specific
143  reasons. The stats is only meaningful when imgdiff not being disabled by the
144  caller of BlockImageDiff. In addition, only files with supported types
145  (BlockImageDiff.FileTypeSupportedByImgdiff()) are allowed to be logged.
146  """
147
148  USED_IMGDIFF = "APK files diff'd with imgdiff"
149  USED_IMGDIFF_LARGE_APK = "Large APK files split and diff'd with imgdiff"
150
151  # Reasons for not applying imgdiff on APKs.
152  SKIPPED_NONMONOTONIC = "Not used imgdiff due to having non-monotonic ranges"
153  SKIPPED_SHARED_BLOCKS = "Not used imgdiff due to using shared blocks"
154  SKIPPED_INCOMPLETE = "Not used imgdiff due to incomplete RangeSet"
155
156  # The list of valid reasons, which will also be the dumped order in a report.
157  REASONS = (
158      USED_IMGDIFF,
159      USED_IMGDIFF_LARGE_APK,
160      SKIPPED_NONMONOTONIC,
161      SKIPPED_SHARED_BLOCKS,
162      SKIPPED_INCOMPLETE,
163  )
164
165  def  __init__(self):
166    self.stats = {}
167
168  def Log(self, filename, reason):
169    """Logs why imgdiff can or cannot be applied to the given filename.
170
171    Args:
172      filename: The filename string.
173      reason: One of the reason constants listed in REASONS.
174
175    Raises:
176      AssertionError: On unsupported filetypes or invalid reason.
177    """
178    assert BlockImageDiff.FileTypeSupportedByImgdiff(filename)
179    assert reason in self.REASONS
180
181    if reason not in self.stats:
182      self.stats[reason] = set()
183    self.stats[reason].add(filename)
184
185  def Report(self):
186    """Prints a report of the collected imgdiff stats."""
187
188    def print_header(header, separator):
189      logger.info(header)
190      logger.info('%s\n', separator * len(header))
191
192    print_header('  Imgdiff Stats Report  ', '=')
193    for key in self.REASONS:
194      if key not in self.stats:
195        continue
196      values = self.stats[key]
197      section_header = ' {} (count: {}) '.format(key, len(values))
198      print_header(section_header, '-')
199      logger.info(''.join(['  {}\n'.format(name) for name in values]))
200
201
202class BlockImageDiff(object):
203  """Generates the diff of two block image objects.
204
205  BlockImageDiff works on two image objects. An image object is anything that
206  provides the following attributes:
207
208     blocksize: the size in bytes of a block, currently must be 4096.
209
210     total_blocks: the total size of the partition/image, in blocks.
211
212     care_map: a RangeSet containing which blocks (in the range [0,
213       total_blocks) we actually care about; i.e. which blocks contain data.
214
215     file_map: a dict that partitions the blocks contained in care_map into
216         smaller domains that are useful for doing diffs on. (Typically a domain
217         is a file, and the key in file_map is the pathname.)
218
219     clobbered_blocks: a RangeSet containing which blocks contain data but may
220         be altered by the FS. They need to be excluded when verifying the
221         partition integrity.
222
223     ReadRangeSet(): a function that takes a RangeSet and returns the data
224         contained in the image blocks of that RangeSet. The data is returned as
225         a list or tuple of strings; concatenating the elements together should
226         produce the requested data. Implementations are free to break up the
227         data into list/tuple elements in any way that is convenient.
228
229     RangeSha1(): a function that returns (as a hex string) the SHA-1 hash of
230         all the data in the specified range.
231
232     TotalSha1(): a function that returns (as a hex string) the SHA-1 hash of
233         all the data in the image (ie, all the blocks in the care_map minus
234         clobbered_blocks, or including the clobbered blocks if
235         include_clobbered_blocks is True).
236
237  When creating a BlockImageDiff, the src image may be None, in which case the
238  list of transfers produced will never read from the original image.
239  """
240
241  def __init__(self, tgt, src=None, threads=None, version=4,
242               disable_imgdiff=False):
243    if threads is None:
244      threads = multiprocessing.cpu_count() // 2
245      if threads == 0:
246        threads = 1
247    self.threads = threads
248    self.version = version
249    self.transfers = []
250    self.src_basenames = {}
251    self.src_numpatterns = {}
252    self._max_stashed_size = 0
253    self.touched_src_ranges = RangeSet()
254    self.touched_src_sha1 = None
255    self.disable_imgdiff = disable_imgdiff
256    self.imgdiff_stats = ImgdiffStats() if not disable_imgdiff else None
257
258    assert version in (3, 4)
259
260    self.tgt = tgt
261    if src is None:
262      src = EmptyImage()
263    self.src = src
264
265    # The updater code that installs the patch always uses 4k blocks.
266    assert tgt.blocksize == 4096
267    assert src.blocksize == 4096
268
269    # The range sets in each filemap should comprise a partition of
270    # the care map.
271    self.AssertPartition(src.care_map, src.file_map.values())
272    self.AssertPartition(tgt.care_map, tgt.file_map.values())
273
274  @property
275  def max_stashed_size(self):
276    return self._max_stashed_size
277
278  @staticmethod
279  def FileTypeSupportedByImgdiff(filename):
280    """Returns whether the file type is supported by imgdiff."""
281    return filename.lower().endswith(('.apk', '.jar', '.zip'))
282
283  def CanUseImgdiff(self, name, tgt_ranges, src_ranges, large_apk=False):
284    """Checks whether we can apply imgdiff for the given RangeSets.
285
286    For files in ZIP format (e.g., APKs, JARs, etc.) we would like to use
287    'imgdiff -z' if possible. Because it usually produces significantly smaller
288    patches than bsdiff.
289
290    This is permissible if all of the following conditions hold.
291      - The imgdiff hasn't been disabled by the caller (e.g. squashfs);
292      - The file type is supported by imgdiff;
293      - The source and target blocks are monotonic (i.e. the data is stored with
294        blocks in increasing order);
295      - Both files don't contain shared blocks;
296      - Both files have complete lists of blocks;
297      - We haven't removed any blocks from the source set.
298
299    If all these conditions are satisfied, concatenating all the blocks in the
300    RangeSet in order will produce a valid ZIP file (plus possibly extra zeros
301    in the last block). imgdiff is fine with extra zeros at the end of the file.
302
303    Args:
304      name: The filename to be diff'd.
305      tgt_ranges: The target RangeSet.
306      src_ranges: The source RangeSet.
307      large_apk: Whether this is to split a large APK.
308
309    Returns:
310      A boolean result.
311    """
312    if self.disable_imgdiff or not self.FileTypeSupportedByImgdiff(name):
313      return False
314
315    if not tgt_ranges.monotonic or not src_ranges.monotonic:
316      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_NONMONOTONIC)
317      return False
318
319    if (tgt_ranges.extra.get('uses_shared_blocks') or
320        src_ranges.extra.get('uses_shared_blocks')):
321      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_SHARED_BLOCKS)
322      return False
323
324    if tgt_ranges.extra.get('incomplete') or src_ranges.extra.get('incomplete'):
325      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_INCOMPLETE)
326      return False
327
328    reason = (ImgdiffStats.USED_IMGDIFF_LARGE_APK if large_apk
329              else ImgdiffStats.USED_IMGDIFF)
330    self.imgdiff_stats.Log(name, reason)
331    return True
332
333  def Compute(self, prefix):
334    # When looking for a source file to use as the diff input for a
335    # target file, we try:
336    #   1) an exact path match if available, otherwise
337    #   2) a exact basename match if available, otherwise
338    #   3) a basename match after all runs of digits are replaced by
339    #      "#" if available, otherwise
340    #   4) we have no source for this target.
341    self.AbbreviateSourceNames()
342    self.FindTransfers()
343
344    self.FindSequenceForTransfers()
345
346    # Ensure the runtime stash size is under the limit.
347    if common.OPTIONS.cache_size is not None:
348      stash_limit = (common.OPTIONS.cache_size *
349                     common.OPTIONS.stash_threshold / self.tgt.blocksize)
350      # Ignore the stash limit and calculate the maximum simultaneously stashed
351      # blocks needed.
352      _, max_stashed_blocks = self.ReviseStashSize(ignore_stash_limit=True)
353
354      # We cannot stash more blocks than the stash limit simultaneously. As a
355      # result, some 'diff' commands will be converted to new; leading to an
356      # unintended large package. To mitigate this issue, we can carefully
357      # choose the transfers for conversion. The number '1024' can be further
358      # tweaked here to balance the package size and build time.
359      if max_stashed_blocks > stash_limit + 1024:
360        self.SelectAndConvertDiffTransfersToNew(
361            max_stashed_blocks - stash_limit)
362        # Regenerate the sequence as the graph has changed.
363        self.FindSequenceForTransfers()
364
365      # Revise the stash size again to keep the size under limit.
366      self.ReviseStashSize()
367
368    # Double-check our work.
369    self.AssertSequenceGood()
370    self.AssertSha1Good()
371
372    self.ComputePatches(prefix)
373    self.WriteTransfers(prefix)
374
375    # Report the imgdiff stats.
376    if not self.disable_imgdiff:
377      self.imgdiff_stats.Report()
378
379  def WriteTransfers(self, prefix):
380    def WriteSplitTransfers(out, style, target_blocks):
381      """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
382
383      This prevents the target size of one command from being too large; and
384      might help to avoid fsync errors on some devices."""
385
386      assert style == "new" or style == "zero"
387      blocks_limit = 1024
388      total = 0
389      while target_blocks:
390        blocks_to_write = target_blocks.first(blocks_limit)
391        out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
392        total += blocks_to_write.size()
393        target_blocks = target_blocks.subtract(blocks_to_write)
394      return total
395
396    out = []
397    total = 0
398
399    # In BBOTA v3+, it uses the hash of the stashed blocks as the stash slot
400    # id. 'stashes' records the map from 'hash' to the ref count. The stash
401    # will be freed only if the count decrements to zero.
402    stashes = {}
403    stashed_blocks = 0
404    max_stashed_blocks = 0
405
406    for xf in self.transfers:
407
408      for _, sr in xf.stash_before:
409        sh = self.src.RangeSha1(sr)
410        if sh in stashes:
411          stashes[sh] += 1
412        else:
413          stashes[sh] = 1
414          stashed_blocks += sr.size()
415          self.touched_src_ranges = self.touched_src_ranges.union(sr)
416          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))
417
418      if stashed_blocks > max_stashed_blocks:
419        max_stashed_blocks = stashed_blocks
420
421      free_string = []
422      free_size = 0
423
424      #   <# blocks> <src ranges>
425      #     OR
426      #   <# blocks> <src ranges> <src locs> <stash refs...>
427      #     OR
428      #   <# blocks> - <stash refs...>
429
430      size = xf.src_ranges.size()
431      src_str_buffer = [str(size)]
432
433      unstashed_src_ranges = xf.src_ranges
434      mapped_stashes = []
435      for _, sr in xf.use_stash:
436        unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
437        sh = self.src.RangeSha1(sr)
438        sr = xf.src_ranges.map_within(sr)
439        mapped_stashes.append(sr)
440        assert sh in stashes
441        src_str_buffer.append("%s:%s" % (sh, sr.to_string_raw()))
442        stashes[sh] -= 1
443        if stashes[sh] == 0:
444          free_string.append("free %s\n" % (sh,))
445          free_size += sr.size()
446          stashes.pop(sh)
447
448      if unstashed_src_ranges:
449        src_str_buffer.insert(1, unstashed_src_ranges.to_string_raw())
450        if xf.use_stash:
451          mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
452          src_str_buffer.insert(2, mapped_unstashed.to_string_raw())
453          mapped_stashes.append(mapped_unstashed)
454          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
455      else:
456        src_str_buffer.insert(1, "-")
457        self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
458
459      src_str = " ".join(src_str_buffer)
460
461      # version 3+:
462      #   zero <rangeset>
463      #   new <rangeset>
464      #   erase <rangeset>
465      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
466      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
467      #   move hash <tgt rangeset> <src_str>
468
469      tgt_size = xf.tgt_ranges.size()
470
471      if xf.style == "new":
472        assert xf.tgt_ranges
473        assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
474        total += tgt_size
475      elif xf.style == "move":
476        assert xf.tgt_ranges
477        assert xf.src_ranges.size() == tgt_size
478        if xf.src_ranges != xf.tgt_ranges:
479          # take into account automatic stashing of overlapping blocks
480          if xf.src_ranges.overlaps(xf.tgt_ranges):
481            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
482            if temp_stash_usage > max_stashed_blocks:
483              max_stashed_blocks = temp_stash_usage
484
485          self.touched_src_ranges = self.touched_src_ranges.union(
486              xf.src_ranges)
487
488          out.append("%s %s %s %s\n" % (
489              xf.style,
490              xf.tgt_sha1,
491              xf.tgt_ranges.to_string_raw(), src_str))
492          total += tgt_size
493      elif xf.style in ("bsdiff", "imgdiff"):
494        assert xf.tgt_ranges
495        assert xf.src_ranges
496        # take into account automatic stashing of overlapping blocks
497        if xf.src_ranges.overlaps(xf.tgt_ranges):
498          temp_stash_usage = stashed_blocks + xf.src_ranges.size()
499          if temp_stash_usage > max_stashed_blocks:
500            max_stashed_blocks = temp_stash_usage
501
502        self.touched_src_ranges = self.touched_src_ranges.union(xf.src_ranges)
503
504        out.append("%s %d %d %s %s %s %s\n" % (
505            xf.style,
506            xf.patch_start, xf.patch_len,
507            xf.src_sha1,
508            xf.tgt_sha1,
509            xf.tgt_ranges.to_string_raw(), src_str))
510        total += tgt_size
511      elif xf.style == "zero":
512        assert xf.tgt_ranges
513        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
514        assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
515        total += to_zero.size()
516      else:
517        raise ValueError("unknown transfer style '%s'\n" % xf.style)
518
519      if free_string:
520        out.append("".join(free_string))
521        stashed_blocks -= free_size
522
523      if common.OPTIONS.cache_size is not None:
524        # Validation check: abort if we're going to need more stash space than
525        # the allowed size (cache_size * threshold). There are two purposes
526        # of having a threshold here. a) Part of the cache may have been
527        # occupied by some recovery logs. b) It will buy us some time to deal
528        # with the oversize issue.
529        cache_size = common.OPTIONS.cache_size
530        stash_threshold = common.OPTIONS.stash_threshold
531        max_allowed = cache_size * stash_threshold
532        assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
533               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
534                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
535                   self.tgt.blocksize, max_allowed, cache_size,
536                   stash_threshold)
537
538    self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)
539
540    # Zero out extended blocks as a workaround for bug 20881595.
541    if self.tgt.extended:
542      assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
543              self.tgt.extended.size())
544      total += self.tgt.extended.size()
545
546    # We erase all the blocks on the partition that a) don't contain useful
547    # data in the new image; b) will not be touched by dm-verity. Out of those
548    # blocks, we erase the ones that won't be used in this update at the
549    # beginning of an update. The rest would be erased at the end. This is to
550    # work around the eMMC issue observed on some devices, which may otherwise
551    # get starving for clean blocks and thus fail the update. (b/28347095)
552    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
553    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
554    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
555
556    erase_first = new_dontcare.subtract(self.touched_src_ranges)
557    if erase_first:
558      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))
559
560    erase_last = new_dontcare.subtract(erase_first)
561    if erase_last:
562      out.append("erase %s\n" % (erase_last.to_string_raw(),))
563
564    out.insert(0, "%d\n" % (self.version,))   # format version number
565    out.insert(1, "%d\n" % (total,))
566    # v3+: the number of stash slots is unused.
567    out.insert(2, "0\n")
568    out.insert(3, str(max_stashed_blocks) + "\n")
569
570    with open(prefix + ".transfer.list", "w") as f:
571      for i in out:
572        f.write(i)
573
574    self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
575    OPTIONS = common.OPTIONS
576    if OPTIONS.cache_size is not None:
577      max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
578      logger.info(
579          "max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n",
580          max_stashed_blocks, self._max_stashed_size, max_allowed,
581          self._max_stashed_size * 100.0 / max_allowed)
582    else:
583      logger.info(
584          "max stashed blocks: %d  (%d bytes), limit: <unknown>\n",
585          max_stashed_blocks, self._max_stashed_size)
586
587  def ReviseStashSize(self, ignore_stash_limit=False):
588    """ Revises the transfers to keep the stash size within the size limit.
589
590    Iterates through the transfer list and calculates the stash size each
591    transfer generates. Converts the affected transfers to new if we reach the
592    stash limit.
593
594    Args:
595      ignore_stash_limit: Ignores the stash limit and calculates the max
596      simultaneous stashed blocks instead. No change will be made to the
597      transfer list with this flag.
598
599    Return:
600      A tuple of (tgt blocks converted to new, max stashed blocks)
601    """
602    logger.info("Revising stash size...")
603    stash_map = {}
604
605    # Create the map between a stash and its def/use points. For example, for a
606    # given stash of (raw_id, sr), stash_map[raw_id] = (sr, def_cmd, use_cmd).
607    for xf in self.transfers:
608      # Command xf defines (stores) all the stashes in stash_before.
609      for stash_raw_id, sr in xf.stash_before:
610        stash_map[stash_raw_id] = (sr, xf)
611
612      # Record all the stashes command xf uses.
613      for stash_raw_id, _ in xf.use_stash:
614        stash_map[stash_raw_id] += (xf,)
615
616    max_allowed_blocks = None
617    if not ignore_stash_limit:
618      # Compute the maximum blocks available for stash based on /cache size and
619      # the threshold.
620      cache_size = common.OPTIONS.cache_size
621      stash_threshold = common.OPTIONS.stash_threshold
622      max_allowed_blocks = cache_size * stash_threshold / self.tgt.blocksize
623
624    # See the comments for 'stashes' in WriteTransfers().
625    stashes = {}
626    stashed_blocks = 0
627    new_blocks = 0
628    max_stashed_blocks = 0
629
630    # Now go through all the commands. Compute the required stash size on the
631    # fly. If a command requires excess stash than available, it deletes the
632    # stash by replacing the command that uses the stash with a "new" command
633    # instead.
634    for xf in self.transfers:
635      replaced_cmds = []
636
637      # xf.stash_before generates explicit stash commands.
638      for stash_raw_id, sr in xf.stash_before:
639        # Check the post-command stashed_blocks.
640        stashed_blocks_after = stashed_blocks
641        sh = self.src.RangeSha1(sr)
642        if sh not in stashes:
643          stashed_blocks_after += sr.size()
644
645        if max_allowed_blocks and stashed_blocks_after > max_allowed_blocks:
646          # We cannot stash this one for a later command. Find out the command
647          # that will use this stash and replace the command with "new".
648          use_cmd = stash_map[stash_raw_id][2]
649          replaced_cmds.append(use_cmd)
650          logger.info("%10d  %9s  %s", sr.size(), "explicit", use_cmd)
651        else:
652          # Update the stashes map.
653          if sh in stashes:
654            stashes[sh] += 1
655          else:
656            stashes[sh] = 1
657          stashed_blocks = stashed_blocks_after
658          max_stashed_blocks = max(max_stashed_blocks, stashed_blocks)
659
660      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
661      # ComputePatches(), they both have the style of "diff".
662      if xf.style == "diff":
663        assert xf.tgt_ranges and xf.src_ranges
664        if xf.src_ranges.overlaps(xf.tgt_ranges):
665          if (max_allowed_blocks and
666              stashed_blocks + xf.src_ranges.size() > max_allowed_blocks):
667            replaced_cmds.append(xf)
668            logger.info("%10d  %9s  %s", xf.src_ranges.size(), "implicit", xf)
669          else:
670            # The whole source ranges will be stashed for implicit stashes.
671            max_stashed_blocks = max(max_stashed_blocks,
672                                     stashed_blocks + xf.src_ranges.size())
673
674      # Replace the commands in replaced_cmds with "new"s.
675      for cmd in replaced_cmds:
676        # It no longer uses any commands in "use_stash". Remove the def points
677        # for all those stashes.
678        for stash_raw_id, sr in cmd.use_stash:
679          def_cmd = stash_map[stash_raw_id][1]
680          assert (stash_raw_id, sr) in def_cmd.stash_before
681          def_cmd.stash_before.remove((stash_raw_id, sr))
682
683        # Add up blocks that violates space limit and print total number to
684        # screen later.
685        new_blocks += cmd.tgt_ranges.size()
686        cmd.ConvertToNew()
687
688      # xf.use_stash may generate free commands.
689      for _, sr in xf.use_stash:
690        sh = self.src.RangeSha1(sr)
691        assert sh in stashes
692        stashes[sh] -= 1
693        if stashes[sh] == 0:
694          stashed_blocks -= sr.size()
695          stashes.pop(sh)
696
697    num_of_bytes = new_blocks * self.tgt.blocksize
698    logger.info(
699        "  Total %d blocks (%d bytes) are packed as new blocks due to "
700        "insufficient cache size. Maximum blocks stashed simultaneously: %d",
701        new_blocks, num_of_bytes, max_stashed_blocks)
702    return new_blocks, max_stashed_blocks
703
704  def ComputePatches(self, prefix):
705    logger.info("Reticulating splines...")
706    diff_queue = []
707    patch_num = 0
708    with open(prefix + ".new.dat", "wb") as new_f:
709      for index, xf in enumerate(self.transfers):
710        if xf.style == "zero":
711          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
712          logger.info(
713              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
714              xf.style, xf.tgt_name, str(xf.tgt_ranges))
715
716        elif xf.style == "new":
717          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
718          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
719          logger.info(
720              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
721              xf.style, xf.tgt_name, str(xf.tgt_ranges))
722
723        elif xf.style == "diff":
724          # We can't compare src and tgt directly because they may have
725          # the same content but be broken up into blocks differently, eg:
726          #
727          #    ["he", "llo"]  vs  ["h", "ello"]
728          #
729          # We want those to compare equal, ideally without having to
730          # actually concatenate the strings (these may be tens of
731          # megabytes).
732          if xf.src_sha1 == xf.tgt_sha1:
733            # These are identical; we don't need to generate a patch,
734            # just issue copy commands on the device.
735            xf.style = "move"
736            xf.patch_info = None
737            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
738            if xf.src_ranges != xf.tgt_ranges:
739              logger.info(
740                  "%10d %10d (%6.2f%%) %7s %s %s (from %s)", tgt_size, tgt_size,
741                  100.0, xf.style,
742                  xf.tgt_name if xf.tgt_name == xf.src_name else (
743                      xf.tgt_name + " (from " + xf.src_name + ")"),
744                  str(xf.tgt_ranges), str(xf.src_ranges))
745          else:
746            if xf.patch_info:
747              # We have already generated the patch (e.g. during split of large
748              # APKs or reduction of stash size)
749              imgdiff = xf.patch_info.imgdiff
750            else:
751              imgdiff = self.CanUseImgdiff(
752                  xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
753            xf.style = "imgdiff" if imgdiff else "bsdiff"
754            diff_queue.append((index, imgdiff, patch_num))
755            patch_num += 1
756
757        else:
758          assert False, "unknown style " + xf.style
759
760    patches = self.ComputePatchesForInputList(diff_queue, False)
761
762    offset = 0
763    with open(prefix + ".patch.dat", "wb") as patch_fd:
764      for index, patch_info, _ in patches:
765        xf = self.transfers[index]
766        xf.patch_len = len(patch_info.content)
767        xf.patch_start = offset
768        offset += xf.patch_len
769        patch_fd.write(patch_info.content)
770
771        tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
772        logger.info(
773            "%10d %10d (%6.2f%%) %7s %s %s %s", xf.patch_len, tgt_size,
774            xf.patch_len * 100.0 / tgt_size, xf.style,
775            xf.tgt_name if xf.tgt_name == xf.src_name else (
776                xf.tgt_name + " (from " + xf.src_name + ")"),
777            xf.tgt_ranges, xf.src_ranges)
778
779  def AssertSha1Good(self):
780    """Check the SHA-1 of the src & tgt blocks in the transfer list.
781
782    Double check the SHA-1 value to avoid the issue in b/71908713, where
783    SparseImage.RangeSha1() messed up with the hash calculation in multi-thread
784    environment. That specific problem has been fixed by protecting the
785    underlying generator function 'SparseImage._GetRangeData()' with lock.
786    """
787    for xf in self.transfers:
788      tgt_sha1 = self.tgt.RangeSha1(xf.tgt_ranges)
789      assert xf.tgt_sha1 == tgt_sha1
790      if xf.style == "diff":
791        src_sha1 = self.src.RangeSha1(xf.src_ranges)
792        assert xf.src_sha1 == src_sha1
793
794  def AssertSequenceGood(self):
795    # Simulate the sequences of transfers we will output, and check that:
796    # - we never read a block after writing it, and
797    # - we write every block we care about exactly once.
798
799    # Start with no blocks having been touched yet.
800    touched = array.array("B", b"\0" * self.tgt.total_blocks)
801
802    # Imagine processing the transfers in order.
803    for xf in self.transfers:
804      # Check that the input blocks for this transfer haven't yet been touched.
805
806      x = xf.src_ranges
807      for _, sr in xf.use_stash:
808        x = x.subtract(sr)
809
810      for s, e in x:
811        # Source image could be larger. Don't check the blocks that are in the
812        # source image only. Since they are not in 'touched', and won't ever
813        # be touched.
814        for i in range(s, min(e, self.tgt.total_blocks)):
815          assert touched[i] == 0
816
817      # Check that the output blocks for this transfer haven't yet
818      # been touched, and touch all the blocks written by this
819      # transfer.
820      for s, e in xf.tgt_ranges:
821        for i in range(s, e):
822          assert touched[i] == 0
823          touched[i] = 1
824
825    # Check that we've written every target block.
826    for s, e in self.tgt.care_map:
827      for i in range(s, e):
828        assert touched[i] == 1
829
830  def FindSequenceForTransfers(self):
831    """Finds a sequence for the given transfers.
832
833     The goal is to minimize the violation of order dependencies between these
834     transfers, so that fewer blocks are stashed when applying the update.
835    """
836
837    # Clear the existing dependency between transfers
838    for xf in self.transfers:
839      xf.goes_before = OrderedDict()
840      xf.goes_after = OrderedDict()
841
842      xf.stash_before = []
843      xf.use_stash = []
844
845    # Find the ordering dependencies among transfers (this is O(n^2)
846    # in the number of transfers).
847    self.GenerateDigraph()
848    # Find a sequence of transfers that satisfies as many ordering
849    # dependencies as possible (heuristically).
850    self.FindVertexSequence()
851    # Fix up the ordering dependencies that the sequence didn't
852    # satisfy.
853    self.ReverseBackwardEdges()
854    self.ImproveVertexSequence()
855
856  def ImproveVertexSequence(self):
857    logger.info("Improving vertex order...")
858
859    # At this point our digraph is acyclic; we reversed any edges that
860    # were backwards in the heuristically-generated sequence.  The
861    # previously-generated order is still acceptable, but we hope to
862    # find a better order that needs less memory for stashed data.
863    # Now we do a topological sort to generate a new vertex order,
864    # using a greedy algorithm to choose which vertex goes next
865    # whenever we have a choice.
866
867    # Make a copy of the edge set; this copy will get destroyed by the
868    # algorithm.
869    for xf in self.transfers:
870      xf.incoming = xf.goes_after.copy()
871      xf.outgoing = xf.goes_before.copy()
872
873    L = []   # the new vertex order
874
875    # S is the set of sources in the remaining graph; we always choose
876    # the one that leaves the least amount of stashed data after it's
877    # executed.
878    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
879         if not u.incoming]
880    heapq.heapify(S)
881
882    while S:
883      _, _, xf = heapq.heappop(S)
884      L.append(xf)
885      for u in xf.outgoing:
886        del u.incoming[xf]
887        if not u.incoming:
888          heapq.heappush(S, (u.NetStashChange(), u.order, u))
889
890    # if this fails then our graph had a cycle.
891    assert len(L) == len(self.transfers)
892
893    self.transfers = L
894    for i, xf in enumerate(L):
895      xf.order = i
896
897  def ReverseBackwardEdges(self):
898    """Reverse unsatisfying edges and compute pairs of stashed blocks.
899
900    For each transfer, make sure it properly stashes the blocks it touches and
901    will be used by later transfers. It uses pairs of (stash_raw_id, range) to
902    record the blocks to be stashed. 'stash_raw_id' is an id that uniquely
903    identifies each pair. Note that for the same range (e.g. RangeSet("1-5")),
904    it is possible to have multiple pairs with different 'stash_raw_id's. Each
905    'stash_raw_id' will be consumed by one transfer. In BBOTA v3+, identical
906    blocks will be written to the same stash slot in WriteTransfers().
907    """
908
909    logger.info("Reversing backward edges...")
910    in_order = 0
911    out_of_order = 0
912    stash_raw_id = 0
913    stash_size = 0
914
915    for xf in self.transfers:
916      for u in xf.goes_before.copy():
917        # xf should go before u
918        if xf.order < u.order:
919          # it does, hurray!
920          in_order += 1
921        else:
922          # it doesn't, boo.  modify u to stash the blocks that it
923          # writes that xf wants to read, and then require u to go
924          # before xf.
925          out_of_order += 1
926
927          overlap = xf.src_ranges.intersect(u.tgt_ranges)
928          assert overlap
929
930          u.stash_before.append((stash_raw_id, overlap))
931          xf.use_stash.append((stash_raw_id, overlap))
932          stash_raw_id += 1
933          stash_size += overlap.size()
934
935          # reverse the edge direction; now xf must go after u
936          del xf.goes_before[u]
937          del u.goes_after[xf]
938          xf.goes_after[u] = None    # value doesn't matter
939          u.goes_before[xf] = None
940
941    logger.info(
942        "  %d/%d dependencies (%.2f%%) were violated; %d source blocks "
943        "stashed.", out_of_order, in_order + out_of_order,
944        (out_of_order * 100.0 / (in_order + out_of_order)) if (
945            in_order + out_of_order) else 0.0,
946        stash_size)
947
948  def FindVertexSequence(self):
949    logger.info("Finding vertex sequence...")
950
951    # This is based on "A Fast & Effective Heuristic for the Feedback
952    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
953    # it as starting with the digraph G and moving all the vertices to
954    # be on a horizontal line in some order, trying to minimize the
955    # number of edges that end up pointing to the left.  Left-pointing
956    # edges will get removed to turn the digraph into a DAG.  In this
957    # case each edge has a weight which is the number of source blocks
958    # we'll lose if that edge is removed; we try to minimize the total
959    # weight rather than just the number of edges.
960
961    # Make a copy of the edge set; this copy will get destroyed by the
962    # algorithm.
963    for xf in self.transfers:
964      xf.incoming = xf.goes_after.copy()
965      xf.outgoing = xf.goes_before.copy()
966      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())
967
968    # We use an OrderedDict instead of just a set so that the output
969    # is repeatable; otherwise it would depend on the hash values of
970    # the transfer objects.
971    G = OrderedDict()
972    for xf in self.transfers:
973      G[xf] = None
974    s1 = deque()  # the left side of the sequence, built from left to right
975    s2 = deque()  # the right side of the sequence, built from right to left
976
977    heap = []
978    for xf in self.transfers:
979      xf.heap_item = HeapItem(xf)
980      heap.append(xf.heap_item)
981    heapq.heapify(heap)
982
983    # Use OrderedDict() instead of set() to preserve the insertion order. Need
984    # to use 'sinks[key] = None' to add key into the set. sinks will look like
985    # { key1: None, key2: None, ... }.
986    sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
987    sources = OrderedDict.fromkeys(u for u in G if not u.incoming)
988
989    def adjust_score(iu, delta):
990      iu.score += delta
991      iu.heap_item.clear()
992      iu.heap_item = HeapItem(iu)
993      heapq.heappush(heap, iu.heap_item)
994
995    while G:
996      # Put all sinks at the end of the sequence.
997      while sinks:
998        new_sinks = OrderedDict()
999        for u in sinks:
1000          if u not in G:
1001            continue
1002          s2.appendleft(u)
1003          del G[u]
1004          for iu in u.incoming:
1005            adjust_score(iu, -iu.outgoing.pop(u))
1006            if not iu.outgoing:
1007              new_sinks[iu] = None
1008        sinks = new_sinks
1009
1010      # Put all the sources at the beginning of the sequence.
1011      while sources:
1012        new_sources = OrderedDict()
1013        for u in sources:
1014          if u not in G:
1015            continue
1016          s1.append(u)
1017          del G[u]
1018          for iu in u.outgoing:
1019            adjust_score(iu, +iu.incoming.pop(u))
1020            if not iu.incoming:
1021              new_sources[iu] = None
1022        sources = new_sources
1023
1024      if not G:
1025        break
1026
1027      # Find the "best" vertex to put next.  "Best" is the one that
1028      # maximizes the net difference in source blocks saved we get by
1029      # pretending it's a source rather than a sink.
1030
1031      while True:
1032        u = heapq.heappop(heap)
1033        if u and u.item in G:
1034          u = u.item
1035          break
1036
1037      s1.append(u)
1038      del G[u]
1039      for iu in u.outgoing:
1040        adjust_score(iu, +iu.incoming.pop(u))
1041        if not iu.incoming:
1042          sources[iu] = None
1043
1044      for iu in u.incoming:
1045        adjust_score(iu, -iu.outgoing.pop(u))
1046        if not iu.outgoing:
1047          sinks[iu] = None
1048
1049    # Now record the sequence in the 'order' field of each transfer,
1050    # and by rearranging self.transfers to be in the chosen sequence.
1051
1052    new_transfers = []
1053    for x in itertools.chain(s1, s2):
1054      x.order = len(new_transfers)
1055      new_transfers.append(x)
1056      del x.incoming
1057      del x.outgoing
1058
1059    self.transfers = new_transfers
1060
1061  def GenerateDigraph(self):
1062    logger.info("Generating digraph...")
1063
1064    # Each item of source_ranges will be:
1065    #   - None, if that block is not used as a source,
1066    #   - an ordered set of transfers.
1067    source_ranges = []
1068    for b in self.transfers:
1069      for s, e in b.src_ranges:
1070        if e > len(source_ranges):
1071          source_ranges.extend([None] * (e-len(source_ranges)))
1072        for i in range(s, e):
1073          if source_ranges[i] is None:
1074            source_ranges[i] = OrderedDict.fromkeys([b])
1075          else:
1076            source_ranges[i][b] = None
1077
1078    for a in self.transfers:
1079      intersections = OrderedDict()
1080      for s, e in a.tgt_ranges:
1081        for i in range(s, e):
1082          if i >= len(source_ranges):
1083            break
1084          # Add all the Transfers in source_ranges[i] to the (ordered) set.
1085          if source_ranges[i] is not None:
1086            for j in source_ranges[i]:
1087              intersections[j] = None
1088
1089      for b in intersections:
1090        if a is b:
1091          continue
1092
1093        # If the blocks written by A are read by B, then B needs to go before A.
1094        i = a.tgt_ranges.intersect(b.src_ranges)
1095        if i:
1096          if b.src_name == "__ZERO":
1097            # the cost of removing source blocks for the __ZERO domain
1098            # is (nearly) zero.
1099            size = 0
1100          else:
1101            size = i.size()
1102          b.goes_before[a] = size
1103          a.goes_after[b] = size
1104
1105  def ComputePatchesForInputList(self, diff_queue, compress_target):
1106    """Returns a list of patch information for the input list of transfers.
1107
1108      Args:
1109        diff_queue: a list of transfers with style 'diff'
1110        compress_target: If True, compresses the target ranges of each
1111            transfers; and save the size.
1112
1113      Returns:
1114        A list of (transfer order, patch_info, compressed_size) tuples.
1115    """
1116
1117    if not diff_queue:
1118      return []
1119
1120    if self.threads > 1:
1121      logger.info("Computing patches (using %d threads)...", self.threads)
1122    else:
1123      logger.info("Computing patches...")
1124
1125    diff_total = len(diff_queue)
1126    patches = [None] * diff_total
1127    error_messages = []
1128
1129    # Using multiprocessing doesn't give additional benefits, due to the
1130    # pattern of the code. The diffing work is done by subprocess.call, which
1131    # already runs in a separate process (not affected much by the GIL -
1132    # Global Interpreter Lock). Using multiprocess also requires either a)
1133    # writing the diff input files in the main process before forking, or b)
1134    # reopening the image file (SparseImage) in the worker processes. Doing
1135    # neither of them further improves the performance.
1136    lock = threading.Lock()
1137
1138    def diff_worker():
1139      while True:
1140        with lock:
1141          if not diff_queue:
1142            return
1143          xf_index, imgdiff, patch_index = diff_queue.pop()
1144          xf = self.transfers[xf_index]
1145
1146        message = []
1147        compressed_size = None
1148
1149        patch_info = xf.patch_info
1150        if not patch_info:
1151          src_file = common.MakeTempFile(prefix="src-")
1152          with open(src_file, "wb") as fd:
1153            self.src.WriteRangeDataToFd(xf.src_ranges, fd)
1154
1155          tgt_file = common.MakeTempFile(prefix="tgt-")
1156          with open(tgt_file, "wb") as fd:
1157            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
1158
1159          try:
1160            patch_info = compute_patch(src_file, tgt_file, imgdiff)
1161          except ValueError as e:
1162            message.append(
1163                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
1164                    "imgdiff" if imgdiff else "bsdiff",
1165                    xf.tgt_name if xf.tgt_name == xf.src_name else
1166                    xf.tgt_name + " (from " + xf.src_name + ")",
1167                    xf.tgt_ranges, xf.src_ranges, e.message))
1168
1169        if compress_target:
1170          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
1171          try:
1172            # Compresses with the default level
1173            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
1174            compressed_data = (compress_obj.compress(b"".join(tgt_data))
1175                               + compress_obj.flush())
1176            compressed_size = len(compressed_data)
1177          except zlib.error as e:
1178            message.append(
1179                "Failed to compress the data in target range {} for {}:\n"
1180                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
1181
1182        if message:
1183          with lock:
1184            error_messages.extend(message)
1185
1186        with lock:
1187          patches[patch_index] = (xf_index, patch_info, compressed_size)
1188
1189    threads = [threading.Thread(target=diff_worker)
1190               for _ in range(self.threads)]
1191    for th in threads:
1192      th.start()
1193    while threads:
1194      threads.pop().join()
1195
1196    if error_messages:
1197      logger.error('ERROR:')
1198      logger.error('\n'.join(error_messages))
1199      logger.error('\n\n\n')
1200      sys.exit(1)
1201
1202    return patches
1203
1204  def SelectAndConvertDiffTransfersToNew(self, violated_stash_blocks):
1205    """Converts the diff transfers to reduce the max simultaneous stash.
1206
1207    Since the 'new' data is compressed with deflate, we can select the 'diff'
1208    transfers for conversion by comparing its patch size with the size of the
1209    compressed data. Ideally, we want to convert the transfers with a small
1210    size increase, but using a large number of stashed blocks.
1211    """
1212    TransferSizeScore = namedtuple("TransferSizeScore",
1213                                   "xf, used_stash_blocks, score")
1214
1215    logger.info("Selecting diff commands to convert to new.")
1216    diff_queue = []
1217    for xf in self.transfers:
1218      if xf.style == "diff" and xf.src_sha1 != xf.tgt_sha1:
1219        use_imgdiff = self.CanUseImgdiff(xf.tgt_name, xf.tgt_ranges,
1220                                         xf.src_ranges)
1221        diff_queue.append((xf.order, use_imgdiff, len(diff_queue)))
1222
1223    # Remove the 'move' transfers, and compute the patch & compressed size
1224    # for the remaining.
1225    result = self.ComputePatchesForInputList(diff_queue, True)
1226
1227    conversion_candidates = []
1228    for xf_index, patch_info, compressed_size in result:
1229      xf = self.transfers[xf_index]
1230      if not xf.patch_info:
1231        xf.patch_info = patch_info
1232
1233      size_ratio = len(xf.patch_info.content) * 100.0 / compressed_size
1234      diff_style = "imgdiff" if xf.patch_info.imgdiff else "bsdiff"
1235      logger.info("%s, target size: %d blocks, style: %s, patch size: %d,"
1236                  " compression_size: %d, ratio %.2f%%", xf.tgt_name,
1237                  xf.tgt_ranges.size(), diff_style,
1238                  len(xf.patch_info.content), compressed_size, size_ratio)
1239
1240      used_stash_blocks = sum(sr.size() for _, sr in xf.use_stash)
1241      # Convert the transfer to new if the compressed size is smaller or equal.
1242      # We don't need to maintain the stash_before lists here because the
1243      # graph will be regenerated later.
1244      if len(xf.patch_info.content) >= compressed_size:
1245        # Add the transfer to the candidate list with negative score. And it
1246        # will be converted later.
1247        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1248                                                       -1))
1249      elif used_stash_blocks > 0:
1250        # This heuristic represents the size increase in the final package to
1251        # remove per unit of stashed data.
1252        score = ((compressed_size - len(xf.patch_info.content)) * 100.0
1253                 / used_stash_blocks)
1254        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1255                                                       score))
1256    # Transfers with lower score (i.e. less expensive to convert) will be
1257    # converted first.
1258    conversion_candidates.sort(key=lambda x: x.score)
1259
1260    # TODO(xunchang), improve the logic to find the transfers to convert, e.g.
1261    # convert the ones that contribute to the max stash, run ReviseStashSize
1262    # multiple times etc.
1263    removed_stashed_blocks = 0
1264    for xf, used_stash_blocks, _ in conversion_candidates:
1265      logger.info("Converting %s to new", xf.tgt_name)
1266      xf.ConvertToNew()
1267      removed_stashed_blocks += used_stash_blocks
1268      # Experiments show that we will get a smaller package size if we remove
1269      # slightly more stashed blocks than the violated stash blocks.
1270      if removed_stashed_blocks >= violated_stash_blocks:
1271        break
1272
1273    logger.info("Removed %d stashed blocks", removed_stashed_blocks)
1274
1275  def FindTransfers(self):
1276    """Parse the file_map to generate all the transfers."""
1277
1278    def AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1279                                             src_ranges, style, by_id):
1280      """Add one or multiple Transfer()s by splitting large files.
1281
1282      For BBOTA v3, we need to stash source blocks for resumable feature.
1283      However, with the growth of file size and the shrink of the cache
1284      partition source blocks are too large to be stashed. If a file occupies
1285      too many blocks, we split it into smaller pieces by getting multiple
1286      Transfer()s.
1287
1288      The downside is that after splitting, we may increase the package size
1289      since the split pieces don't align well. According to our experiments,
1290      1/8 of the cache size as the per-piece limit appears to be optimal.
1291      Compared to the fixed 1024-block limit, it reduces the overall package
1292      size by 30% for volantis, and 20% for angler and bullhead."""
1293
1294      pieces = 0
1295      while (tgt_ranges.size() > max_blocks_per_transfer and
1296             src_ranges.size() > max_blocks_per_transfer):
1297        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1298        src_split_name = "%s-%d" % (src_name, pieces)
1299        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1300        src_first = src_ranges.first(max_blocks_per_transfer)
1301
1302        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
1303                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
1304                 style, by_id)
1305
1306        tgt_ranges = tgt_ranges.subtract(tgt_first)
1307        src_ranges = src_ranges.subtract(src_first)
1308        pieces += 1
1309
1310      # Handle remaining blocks.
1311      if tgt_ranges.size() or src_ranges.size():
1312        # Must be both non-empty.
1313        assert tgt_ranges.size() and src_ranges.size()
1314        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1315        src_split_name = "%s-%d" % (src_name, pieces)
1316        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
1317                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1318                 style, by_id)
1319
1320    def AddSplitTransfers(tgt_name, src_name, tgt_ranges, src_ranges, style,
1321                          by_id):
1322      """Find all the zip files and split the others with a fixed chunk size.
1323
1324      This function will construct a list of zip archives, which will later be
1325      split by imgdiff to reduce the final patch size. For the other files,
1326      we will plainly split them based on a fixed chunk size with the potential
1327      patch size penalty.
1328      """
1329
1330      assert style == "diff"
1331
1332      # Change nothing for small files.
1333      if (tgt_ranges.size() <= max_blocks_per_transfer and
1334          src_ranges.size() <= max_blocks_per_transfer):
1335        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1336                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1337                 style, by_id)
1338        return
1339
1340      # Split large APKs with imgdiff, if possible. We're intentionally checking
1341      # file types one more time (CanUseImgdiff() checks that as well), before
1342      # calling the costly RangeSha1()s.
1343      if (self.FileTypeSupportedByImgdiff(tgt_name) and
1344          self.tgt.RangeSha1(tgt_ranges) != self.src.RangeSha1(src_ranges)):
1345        if self.CanUseImgdiff(tgt_name, tgt_ranges, src_ranges, True):
1346          large_apks.append((tgt_name, src_name, tgt_ranges, src_ranges))
1347          return
1348
1349      AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1350                                           src_ranges, style, by_id)
1351
1352    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1353                    split=False):
1354      """Wrapper function for adding a Transfer()."""
1355
1356      # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
1357      # otherwise add the Transfer() as is.
1358      if style != "diff" or not split:
1359        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1360                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1361                 style, by_id)
1362        return
1363
1364      # Handle .odex files specially to analyze the block-wise difference. If
1365      # most of the blocks are identical with only few changes (e.g. header),
1366      # we will patch the changed blocks only. This avoids stashing unchanged
1367      # blocks while patching. We limit the analysis to files without size
1368      # changes only. This is to avoid sacrificing the OTA generation cost too
1369      # much.
1370      if (tgt_name.split(".")[-1].lower() == 'odex' and
1371          tgt_ranges.size() == src_ranges.size()):
1372
1373        # 0.5 threshold can be further tuned. The tradeoff is: if only very
1374        # few blocks remain identical, we lose the opportunity to use imgdiff
1375        # that may have better compression ratio than bsdiff.
1376        crop_threshold = 0.5
1377
1378        tgt_skipped = RangeSet()
1379        src_skipped = RangeSet()
1380        tgt_size = tgt_ranges.size()
1381        tgt_changed = 0
1382        for src_block, tgt_block in zip(src_ranges.next_item(),
1383                                        tgt_ranges.next_item()):
1384          src_rs = RangeSet(str(src_block))
1385          tgt_rs = RangeSet(str(tgt_block))
1386          if self.src.ReadRangeSet(src_rs) == self.tgt.ReadRangeSet(tgt_rs):
1387            tgt_skipped = tgt_skipped.union(tgt_rs)
1388            src_skipped = src_skipped.union(src_rs)
1389          else:
1390            tgt_changed += tgt_rs.size()
1391
1392          # Terminate early if no clear sign of benefits.
1393          if tgt_changed > tgt_size * crop_threshold:
1394            break
1395
1396        if tgt_changed < tgt_size * crop_threshold:
1397          assert tgt_changed + tgt_skipped.size() == tgt_size
1398          logger.info(
1399              '%10d %10d (%6.2f%%) %s', tgt_skipped.size(), tgt_size,
1400              tgt_skipped.size() * 100.0 / tgt_size, tgt_name)
1401          AddSplitTransfers(
1402              "%s-skipped" % (tgt_name,),
1403              "%s-skipped" % (src_name,),
1404              tgt_skipped, src_skipped, style, by_id)
1405
1406          # Intentionally change the file extension to avoid being imgdiff'd as
1407          # the files are no longer in their original format.
1408          tgt_name = "%s-cropped" % (tgt_name,)
1409          src_name = "%s-cropped" % (src_name,)
1410          tgt_ranges = tgt_ranges.subtract(tgt_skipped)
1411          src_ranges = src_ranges.subtract(src_skipped)
1412
1413          # Possibly having no changed blocks.
1414          if not tgt_ranges:
1415            return
1416
1417      # Add the transfer(s).
1418      AddSplitTransfers(
1419          tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1420
1421    def ParseAndValidateSplitInfo(patch_size, tgt_ranges, src_ranges,
1422                                  split_info):
1423      """Parse the split_info and return a list of info tuples.
1424
1425      Args:
1426        patch_size: total size of the patch file.
1427        tgt_ranges: Ranges of the target file within the original image.
1428        src_ranges: Ranges of the source file within the original image.
1429        split_info format:
1430          imgdiff version#
1431          count of pieces
1432          <patch_size_1> <tgt_size_1> <src_ranges_1>
1433          ...
1434          <patch_size_n> <tgt_size_n> <src_ranges_n>
1435
1436      Returns:
1437        [patch_start, patch_len, split_tgt_ranges, split_src_ranges]
1438      """
1439
1440      version = int(split_info[0])
1441      assert version == 2
1442      count = int(split_info[1])
1443      assert len(split_info) - 2 == count
1444
1445      split_info_list = []
1446      patch_start = 0
1447      tgt_remain = copy.deepcopy(tgt_ranges)
1448      # each line has the format <patch_size>, <tgt_size>, <src_ranges>
1449      for line in split_info[2:]:
1450        info = line.split()
1451        assert len(info) == 3
1452        patch_length = int(info[0])
1453
1454        split_tgt_size = int(info[1])
1455        assert split_tgt_size % 4096 == 0
1456        assert split_tgt_size // 4096 <= tgt_remain.size()
1457        split_tgt_ranges = tgt_remain.first(split_tgt_size // 4096)
1458        tgt_remain = tgt_remain.subtract(split_tgt_ranges)
1459
1460        # Find the split_src_ranges within the image file from its relative
1461        # position in file.
1462        split_src_indices = RangeSet.parse_raw(info[2])
1463        split_src_ranges = RangeSet()
1464        for r in split_src_indices:
1465          curr_range = src_ranges.first(r[1]).subtract(src_ranges.first(r[0]))
1466          assert not split_src_ranges.overlaps(curr_range)
1467          split_src_ranges = split_src_ranges.union(curr_range)
1468
1469        split_info_list.append((patch_start, patch_length,
1470                                split_tgt_ranges, split_src_ranges))
1471        patch_start += patch_length
1472
1473      # Check that the sizes of all the split pieces add up to the final file
1474      # size for patch and target.
1475      assert tgt_remain.size() == 0
1476      assert patch_start == patch_size
1477      return split_info_list
1478
1479    def SplitLargeApks():
1480      """Split the large apks files.
1481
1482      Example: Chrome.apk will be split into
1483        src-0: Chrome.apk-0, tgt-0: Chrome.apk-0
1484        src-1: Chrome.apk-1, tgt-1: Chrome.apk-1
1485        ...
1486
1487      After the split, the target pieces are continuous and block aligned; and
1488      the source pieces are mutually exclusive. During the split, we also
1489      generate and save the image patch between src-X & tgt-X. This patch will
1490      be valid because the block ranges of src-X & tgt-X will always stay the
1491      same afterwards; but there's a chance we don't use the patch if we
1492      convert the "diff" command into "new" or "move" later.
1493      """
1494
1495      while True:
1496        with transfer_lock:
1497          if not large_apks:
1498            return
1499          tgt_name, src_name, tgt_ranges, src_ranges = large_apks.pop(0)
1500
1501        src_file = common.MakeTempFile(prefix="src-")
1502        tgt_file = common.MakeTempFile(prefix="tgt-")
1503        with open(src_file, "wb") as src_fd:
1504          self.src.WriteRangeDataToFd(src_ranges, src_fd)
1505        with open(tgt_file, "wb") as tgt_fd:
1506          self.tgt.WriteRangeDataToFd(tgt_ranges, tgt_fd)
1507
1508        patch_file = common.MakeTempFile(prefix="patch-")
1509        patch_info_file = common.MakeTempFile(prefix="split_info-")
1510        cmd = ["imgdiff", "-z",
1511               "--block-limit={}".format(max_blocks_per_transfer),
1512               "--split-info=" + patch_info_file,
1513               src_file, tgt_file, patch_file]
1514        proc = common.Run(cmd)
1515        imgdiff_output, _ = proc.communicate()
1516        assert proc.returncode == 0, \
1517            "Failed to create imgdiff patch between {} and {}:\n{}".format(
1518                src_name, tgt_name, imgdiff_output)
1519
1520        with open(patch_info_file) as patch_info:
1521          lines = patch_info.readlines()
1522
1523        patch_size_total = os.path.getsize(patch_file)
1524        split_info_list = ParseAndValidateSplitInfo(patch_size_total,
1525                                                    tgt_ranges, src_ranges,
1526                                                    lines)
1527        for index, (patch_start, patch_length, split_tgt_ranges,
1528                    split_src_ranges) in enumerate(split_info_list):
1529          with open(patch_file, 'rb') as f:
1530            f.seek(patch_start)
1531            patch_content = f.read(patch_length)
1532
1533          split_src_name = "{}-{}".format(src_name, index)
1534          split_tgt_name = "{}-{}".format(tgt_name, index)
1535          split_large_apks.append((split_tgt_name,
1536                                   split_src_name,
1537                                   split_tgt_ranges,
1538                                   split_src_ranges,
1539                                   patch_content))
1540
1541    logger.info("Finding transfers...")
1542
1543    large_apks = []
1544    split_large_apks = []
1545    cache_size = common.OPTIONS.cache_size
1546    split_threshold = 0.125
1547    assert cache_size is not None
1548    max_blocks_per_transfer = int(cache_size * split_threshold /
1549                                  self.tgt.blocksize)
1550    empty = RangeSet()
1551    for tgt_fn, tgt_ranges in sorted(self.tgt.file_map.items()):
1552      if tgt_fn == "__ZERO":
1553        # the special "__ZERO" domain is all the blocks not contained
1554        # in any file and that are filled with zeros.  We have a
1555        # special transfer style for zero blocks.
1556        src_ranges = self.src.file_map.get("__ZERO", empty)
1557        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1558                    "zero", self.transfers)
1559        continue
1560
1561      elif tgt_fn == "__COPY":
1562        # "__COPY" domain includes all the blocks not contained in any
1563        # file and that need to be copied unconditionally to the target.
1564        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1565        continue
1566
1567      elif tgt_fn == "__HASHTREE":
1568        continue
1569
1570      elif tgt_fn in self.src.file_map:
1571        # Look for an exact pathname match in the source.
1572        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1573                    "diff", self.transfers, True)
1574        continue
1575
1576      b = os.path.basename(tgt_fn)
1577      if b in self.src_basenames:
1578        # Look for an exact basename match in the source.
1579        src_fn = self.src_basenames[b]
1580        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1581                    "diff", self.transfers, True)
1582        continue
1583
1584      b = re.sub("[0-9]+", "#", b)
1585      if b in self.src_numpatterns:
1586        # Look for a 'number pattern' match (a basename match after
1587        # all runs of digits are replaced by "#").  (This is useful
1588        # for .so files that contain version numbers in the filename
1589        # that get bumped.)
1590        src_fn = self.src_numpatterns[b]
1591        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1592                    "diff", self.transfers, True)
1593        continue
1594
1595      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1596
1597    transfer_lock = threading.Lock()
1598    threads = [threading.Thread(target=SplitLargeApks)
1599               for _ in range(self.threads)]
1600    for th in threads:
1601      th.start()
1602    while threads:
1603      threads.pop().join()
1604
1605    # Sort the split transfers for large apks to generate a determinate package.
1606    split_large_apks.sort()
1607    for (tgt_name, src_name, tgt_ranges, src_ranges,
1608         patch) in split_large_apks:
1609      transfer_split = Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1610                                self.tgt.RangeSha1(tgt_ranges),
1611                                self.src.RangeSha1(src_ranges),
1612                                "diff", self.transfers)
1613      transfer_split.patch_info = PatchInfo(True, patch)
1614
1615  def AbbreviateSourceNames(self):
1616    for k in self.src.file_map.keys():
1617      b = os.path.basename(k)
1618      self.src_basenames[b] = k
1619      b = re.sub("[0-9]+", "#", b)
1620      self.src_numpatterns[b] = k
1621
1622  @staticmethod
1623  def AssertPartition(total, seq):
1624    """Assert that all the RangeSets in 'seq' form a partition of the
1625    'total' RangeSet (ie, they are nonintersecting and their union
1626    equals 'total')."""
1627
1628    so_far = RangeSet()
1629    for i in seq:
1630      assert not so_far.overlaps(i)
1631      so_far = so_far.union(i)
1632    assert so_far == total
1633