# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import threading
from hashlib import sha1

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "FileImage"]


class Image(object):
  """Interface shared by the block-based image wrappers in this module.

  Every method raises NotImplementedError here; subclasses override them.
  `ranges` arguments are iterables of (start, end) block pairs, with `end`
  exclusive (a RangeSet such as `care_map` is one accepted value).
  """

  def RangeSha1(self, ranges):
    """Returns the hex SHA-1 digest of the blocks covered by `ranges`."""
    raise NotImplementedError

  def ReadRangeSet(self, ranges):
    """Returns the contents of the blocks covered by `ranges`."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Returns the hex SHA-1 digest of the image's data."""
    raise NotImplementedError

  def WriteRangeDataToFd(self, ranges, fd):
    """Writes the contents of the blocks covered by `ranges` to `fd`."""
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""

  def __init__(self):
    self.blocksize = 4096
    self.care_map = RangeSet()
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()
    self.total_blocks = 0
    self.file_map = {}
    self.hashtree_info = None

  def RangeSha1(self, ranges):
    # There is no data, so any range hashes to the digest of empty input.
    return sha1().hexdigest()

  def ReadRangeSet(self, ranges):
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    raise ValueError("Can't write data from EmptyImage to file")


class DataImage(Image):
  """An image wrapped around a single string of data.

  Args:
    data: The image contents. bytes (or bytearray) is expected; a str is only
        hashable on Python 2, where str and bytes are the same type.
    trim: If True, drop a trailing partial block.
    pad: If True, NUL-pad a trailing partial block to a full block; that
        block is then recorded as clobbered and mapped to "__COPY".

  Raises:
    ValueError: If len(data) is not a multiple of the block size and neither
        trim nor pad is set.
  """

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    # Use a NUL of the same type as `data`, so that padding and the zero-block
    # comparison below work for binary data on Python 3 (where bytes != str).
    nul = b'\0' if isinstance(self.data, (bytes, bytearray)) else '\0'

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += nul * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Store a RangeSet, as EmptyImage and FileImage do: TotalSha1() passes this
    # value to RangeSet.subtract(), which a plain list does not satisfy.
    self.clobbered_blocks = RangeSet(data=clobbered_blocks)
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = nul * self.blocksize

    # A padded last block is excluded here; it is mapped to "__COPY" below.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = {}
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def _GetRangeData(self, ranges):
    """Yields one slice of `data` per (start, end) pair, end exclusive."""
    for s, e in ranges:
      yield self.data[s*self.blocksize:e*self.blocksize]

  def RangeSha1(self, ranges):
    h = sha1()
    for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    """Returns a list with one data slice per range."""
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    """Hashes the whole image, skipping clobbered blocks unless asked not to."""
    if not include_clobbered_blocks:
      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
    return sha1(self.data).hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
      fd.write(data)


class FileImage(Image):
  """An image wrapped around a raw image file.

  Args:
    path: Path to the raw image file; its size must be a multiple of the
        4096-byte block size.

  Raises:
    ValueError: If the file size is not a multiple of the block size.
  """

  def __init__(self, path):
    self.path = path
    self.blocksize = 4096
    self._file_size = os.path.getsize(self.path)

    # Validate before opening, so a rejected file never leaves a handle open.
    # The format arguments must be a tuple: `"..." % a, b, c` would apply `%`
    # to `a` alone and fail with a TypeError instead of this ValueError.
    if self._file_size % self.blocksize != 0:
      raise ValueError("Size of file %s must be multiple of %d bytes, but is %d"
                       % (self.path, self.blocksize, self._file_size))

    self._file = open(self.path, 'rb')

    self.total_blocks = self._file_size // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    self.generator_lock = threading.Lock()

    zero_blocks = []
    nonzero_blocks = []
    # The file is read in binary mode, so the reference must be bytes; a str
    # reference never compares equal to bytes on Python 3.
    reference = b'\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self._file.read(self.blocksize)
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks

    self.file_map = {}
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)

  def __del__(self):
    # `_file` is absent if __init__ raised before opening the file.
    f = getattr(self, '_file', None)
    if f is not None:
      f.close()

  def _GetRangeData(self, ranges):
    """Yields the file's blocks in `ranges`, one block-sized read at a time."""
    # Use a lock to protect the generator so that we will not run two
    # instances of this generator on the same object simultaneously.
    # NOTE: the lock stays held while the generator is suspended, so callers
    # should exhaust (or close) it, as the methods below do.
    with self.generator_lock:
      for s, e in ranges:
        self._file.seek(s * self.blocksize)
        for _ in range(s, e):
          yield self._file.read(self.blocksize)

  def RangeSha1(self, ranges):
    h = sha1()
    for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    """Returns a list with one block-sized chunk per block in `ranges`."""
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    # clobbered_blocks is always empty for FileImage (see __init__), so
    # include_clobbered_blocks can be ignored.
    assert not self.clobbered_blocks
    return self.RangeSha1(self.care_map)

  def WriteRangeDataToFd(self, ranges, fd):
    for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
      fd.write(data)