1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include <sys/types.h>
18 #include <unistd.h>
19 
20 #include <cstdint>
21 #include <limits>
22 #include <memory>
23 #include <queue>
24 #include <vector>
25 
26 #include <android-base/file.h>
27 #include <android-base/logging.h>
28 #include <android-base/parseint.h>
29 #include <android-base/strings.h>
30 #include <android-base/unique_fd.h>
31 #include <brotli/encode.h>
32 #include <libsnapshot/cow_compress.h>
33 #include <libsnapshot/cow_format.h>
34 #include <libsnapshot/cow_reader.h>
35 #include <libsnapshot/cow_writer.h>
36 #include <lz4.h>
37 #include <zlib.h>
38 #include <zstd.h>
39 
40 namespace android {
41 namespace snapshot {
42 
CompressionAlgorithmFromString(std::string_view name)43 std::optional<CowCompressionAlgorithm> CompressionAlgorithmFromString(std::string_view name) {
44     if (name == "gz") {
45         return {kCowCompressGz};
46     } else if (name == "brotli") {
47         return {kCowCompressBrotli};
48     } else if (name == "lz4") {
49         return {kCowCompressLz4};
50     } else if (name == "zstd") {
51         return {kCowCompressZstd};
52     } else if (name == "none" || name.empty()) {
53         return {kCowCompressNone};
54     } else {
55         LOG(ERROR) << "unable to determine default compression algorithm for: " << name;
56         return {};
57     }
58 }
59 
Create(CowCompression compression,const uint32_t block_size)60 std::unique_ptr<ICompressor> ICompressor::Create(CowCompression compression,
61                                                  const uint32_t block_size) {
62     switch (compression.algorithm) {
63         case kCowCompressLz4:
64             return ICompressor::Lz4(compression.compression_level, block_size);
65         case kCowCompressBrotli:
66             return ICompressor::Brotli(compression.compression_level, block_size);
67         case kCowCompressGz:
68             return ICompressor::Gz(compression.compression_level, block_size);
69         case kCowCompressZstd:
70             return ICompressor::Zstd(compression.compression_level, block_size);
71         case kCowCompressNone:
72             return nullptr;
73     }
74     return nullptr;
75 }
76 
77 // 1. Default compression level is determined by compression algorithm
78 // 2. There might be compatibility issues if a value is changed here, as  some older versions of
79 // Android will assume a different compression level, causing cow_size estimation differences that
80 // will lead to OTA failure. Ensure that the device and OTA package use the same compression level
81 // for OTA to succeed.
GetDefaultCompressionLevel(CowCompressionAlgorithm compression)82 uint32_t CompressWorker::GetDefaultCompressionLevel(CowCompressionAlgorithm compression) {
83     switch (compression) {
84         case kCowCompressGz: {
85             return Z_BEST_COMPRESSION;
86         }
87         case kCowCompressBrotli: {
88             return BROTLI_DEFAULT_QUALITY;
89         }
90         case kCowCompressLz4: {
91             break;
92         }
93         case kCowCompressZstd: {
94             return ZSTD_defaultCLevel();
95         }
96         case kCowCompressNone: {
97             break;
98         }
99     }
100     return 0;
101 }
102 
103 class GzCompressor final : public ICompressor {
104   public:
GzCompressor(int32_t compression_level,const uint32_t block_size)105     GzCompressor(int32_t compression_level, const uint32_t block_size)
106         : ICompressor(compression_level, block_size){};
107 
Compress(const void * data,size_t length) const108     std::vector<uint8_t> Compress(const void* data, size_t length) const override {
109         const auto bound = compressBound(length);
110         std::vector<uint8_t> buffer(bound, '\0');
111 
112         uLongf dest_len = bound;
113         auto rv = compress2(buffer.data(), &dest_len, reinterpret_cast<const Bytef*>(data), length,
114                             GetCompressionLevel());
115         if (rv != Z_OK) {
116             LOG(ERROR) << "compress2 returned: " << rv;
117             return {};
118         }
119         buffer.resize(dest_len);
120         return buffer;
121     };
122 };
123 
124 class Lz4Compressor final : public ICompressor {
125   public:
Lz4Compressor(int32_t compression_level,const uint32_t block_size)126     Lz4Compressor(int32_t compression_level, const uint32_t block_size)
127         : ICompressor(compression_level, block_size){};
128 
Compress(const void * data,size_t length) const129     std::vector<uint8_t> Compress(const void* data, size_t length) const override {
130         const auto bound = LZ4_compressBound(length);
131         if (!bound) {
132             LOG(ERROR) << "LZ4_compressBound returned 0";
133             return {};
134         }
135         std::vector<uint8_t> buffer(bound, '\0');
136 
137         const auto compressed_size =
138                 LZ4_compress_default(static_cast<const char*>(data),
139                                      reinterpret_cast<char*>(buffer.data()), length, buffer.size());
140         if (compressed_size <= 0) {
141             LOG(ERROR) << "LZ4_compress_default failed, input size: " << length
142                        << ", compression bound: " << bound << ", ret: " << compressed_size;
143             return {};
144         }
145         // Don't run compression if the compressed output is larger
146         if (compressed_size >= length) {
147             buffer.resize(length);
148             memcpy(buffer.data(), data, length);
149         } else {
150             buffer.resize(compressed_size);
151         }
152         return buffer;
153     };
154 };
155 
156 class BrotliCompressor final : public ICompressor {
157   public:
BrotliCompressor(int32_t compression_level,const uint32_t block_size)158     BrotliCompressor(int32_t compression_level, const uint32_t block_size)
159         : ICompressor(compression_level, block_size){};
160 
Compress(const void * data,size_t length) const161     std::vector<uint8_t> Compress(const void* data, size_t length) const override {
162         const auto bound = BrotliEncoderMaxCompressedSize(length);
163         if (!bound) {
164             LOG(ERROR) << "BrotliEncoderMaxCompressedSize returned 0";
165             return {};
166         }
167         std::vector<uint8_t> buffer(bound, '\0');
168 
169         size_t encoded_size = bound;
170         auto rv = BrotliEncoderCompress(
171                 GetCompressionLevel(), BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE, length,
172                 reinterpret_cast<const uint8_t*>(data), &encoded_size, buffer.data());
173         if (!rv) {
174             LOG(ERROR) << "BrotliEncoderCompress failed";
175             return {};
176         }
177         buffer.resize(encoded_size);
178         return buffer;
179     };
180 };
181 
182 class ZstdCompressor final : public ICompressor {
183   public:
ZstdCompressor(int32_t compression_level,const uint32_t block_size)184     ZstdCompressor(int32_t compression_level, const uint32_t block_size)
185         : ICompressor(compression_level, block_size),
186           zstd_context_(ZSTD_createCCtx(), ZSTD_freeCCtx) {
187         ZSTD_CCtx_setParameter(zstd_context_.get(), ZSTD_c_compressionLevel, compression_level);
188         ZSTD_CCtx_setParameter(zstd_context_.get(), ZSTD_c_windowLog, log2(GetBlockSize()));
189     };
190 
Compress(const void * data,size_t length) const191     std::vector<uint8_t> Compress(const void* data, size_t length) const override {
192         std::vector<uint8_t> buffer(ZSTD_compressBound(length), '\0');
193         const auto compressed_size =
194                 ZSTD_compress2(zstd_context_.get(), buffer.data(), buffer.size(), data, length);
195         if (compressed_size <= 0) {
196             LOG(ERROR) << "ZSTD compression failed " << compressed_size;
197             return {};
198         }
199         // Don't run compression if the compressed output is larger
200         if (compressed_size >= length) {
201             buffer.resize(length);
202             memcpy(buffer.data(), data, length);
203         } else {
204             buffer.resize(compressed_size);
205         }
206         return buffer;
207     };
208 
209   private:
210     std::unique_ptr<ZSTD_CCtx, decltype(&ZSTD_freeCCtx)> zstd_context_;
211 };
212 
CompressBlocks(const void * buffer,size_t num_blocks,size_t block_size,std::vector<std::vector<uint8_t>> * compressed_data)213 bool CompressWorker::CompressBlocks(const void* buffer, size_t num_blocks, size_t block_size,
214                                     std::vector<std::vector<uint8_t>>* compressed_data) {
215     return CompressBlocks(compressor_.get(), block_size, buffer, num_blocks, compressed_data);
216 }
217 
CompressBlocks(ICompressor * compressor,size_t block_size,const void * buffer,size_t num_blocks,std::vector<std::vector<uint8_t>> * compressed_data)218 bool CompressWorker::CompressBlocks(ICompressor* compressor, size_t block_size, const void* buffer,
219                                     size_t num_blocks,
220                                     std::vector<std::vector<uint8_t>>* compressed_data) {
221     const uint8_t* iter = reinterpret_cast<const uint8_t*>(buffer);
222     while (num_blocks) {
223         auto data = compressor->Compress(iter, block_size);
224         if (data.empty()) {
225             PLOG(ERROR) << "CompressBlocks: Compression failed";
226             return false;
227         }
228         if (data.size() > std::numeric_limits<uint32_t>::max()) {
229             LOG(ERROR) << "Compressed block is too large: " << data.size();
230             return false;
231         }
232 
233         compressed_data->emplace_back(std::move(data));
234         num_blocks -= 1;
235         iter += block_size;
236     }
237     return true;
238 }
239 
RunThread()240 bool CompressWorker::RunThread() {
241     while (true) {
242         // Wait for work
243         CompressWork blocks;
244         {
245             std::unique_lock<std::mutex> lock(lock_);
246             while (work_queue_.empty() && !stopped_) {
247                 cv_.wait(lock);
248             }
249 
250             if (stopped_) {
251                 return true;
252             }
253 
254             blocks = std::move(work_queue_.front());
255             work_queue_.pop();
256         }
257 
258         // Compress blocks
259         bool ret = CompressBlocks(blocks.buffer, blocks.num_blocks, blocks.block_size,
260                                   &blocks.compressed_data);
261         blocks.compression_status = ret;
262         {
263             std::lock_guard<std::mutex> lock(lock_);
264             compressed_queue_.push(std::move(blocks));
265         }
266 
267         // Notify completion
268         cv_.notify_all();
269 
270         if (!ret) {
271             LOG(ERROR) << "CompressBlocks failed";
272             return false;
273         }
274     }
275 
276     return true;
277 }
278 
EnqueueCompressBlocks(const void * buffer,size_t block_size,size_t num_blocks)279 void CompressWorker::EnqueueCompressBlocks(const void* buffer, size_t block_size,
280                                            size_t num_blocks) {
281     {
282         std::lock_guard<std::mutex> lock(lock_);
283 
284         CompressWork blocks = {};
285         blocks.buffer = buffer;
286         blocks.block_size = block_size;
287         blocks.num_blocks = num_blocks;
288         work_queue_.push(std::move(blocks));
289         total_submitted_ += 1;
290     }
291     cv_.notify_all();
292 }
293 
GetCompressedBuffers(std::vector<std::vector<uint8_t>> * compressed_buf)294 bool CompressWorker::GetCompressedBuffers(std::vector<std::vector<uint8_t>>* compressed_buf) {
295     while (true) {
296         std::unique_lock<std::mutex> lock(lock_);
297         while ((total_submitted_ != total_processed_) && compressed_queue_.empty() && !stopped_) {
298             cv_.wait(lock);
299         }
300         while (compressed_queue_.size() > 0) {
301             CompressWork blocks = std::move(compressed_queue_.front());
302             compressed_queue_.pop();
303             total_processed_ += 1;
304 
305             if (blocks.compression_status) {
306                 compressed_buf->insert(compressed_buf->end(),
307                                        std::make_move_iterator(blocks.compressed_data.begin()),
308                                        std::make_move_iterator(blocks.compressed_data.end()));
309             } else {
310                 LOG(ERROR) << "Block compression failed";
311                 return false;
312             }
313         }
314         if ((total_submitted_ == total_processed_) || stopped_) {
315             total_submitted_ = 0;
316             total_processed_ = 0;
317             return true;
318         }
319     }
320 }
321 
Brotli(const int32_t compression_level,const uint32_t block_size)322 std::unique_ptr<ICompressor> ICompressor::Brotli(const int32_t compression_level,
323                                                  const uint32_t block_size) {
324     return std::make_unique<BrotliCompressor>(compression_level, block_size);
325 }
326 
Gz(const int32_t compression_level,const uint32_t block_size)327 std::unique_ptr<ICompressor> ICompressor::Gz(const int32_t compression_level,
328                                              const uint32_t block_size) {
329     return std::make_unique<GzCompressor>(compression_level, block_size);
330 }
331 
Lz4(const int32_t compression_level,const uint32_t block_size)332 std::unique_ptr<ICompressor> ICompressor::Lz4(const int32_t compression_level,
333                                               const uint32_t block_size) {
334     return std::make_unique<Lz4Compressor>(compression_level, block_size);
335 }
336 
Zstd(const int32_t compression_level,const uint32_t block_size)337 std::unique_ptr<ICompressor> ICompressor::Zstd(const int32_t compression_level,
338                                                const uint32_t block_size) {
339     return std::make_unique<ZstdCompressor>(compression_level, block_size);
340 }
341 
Finalize()342 void CompressWorker::Finalize() {
343     {
344         std::unique_lock<std::mutex> lock(lock_);
345         stopped_ = true;
346     }
347     cv_.notify_all();
348 }
349 
CompressWorker(std::unique_ptr<ICompressor> && compressor)350 CompressWorker::CompressWorker(std::unique_ptr<ICompressor>&& compressor)
351     : compressor_(std::move(compressor)) {}
352 
353 }  // namespace snapshot
354 }  // namespace android
355