1 
2 /*
3  * Copyright (C) 2024 The Android Open Source Project
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include "aconfigd.h"
19 #include "aconfigd_util.h"
20 #include "storage_files_manager.h"
21 
22 using namespace aconfig_storage;
23 
24 namespace android {
25   namespace aconfigd {
26 
27   /// get storage files object for a container
GetStorageFiles(const std::string & container)28   base::Result<StorageFiles*> StorageFilesManager::GetStorageFiles(
29       const std::string& container) {
30     if (all_storage_files_.count(container) == 0) {
31       return base::Error() << "Missing storage files object for " << container;
32     }
33     return all_storage_files_[container].get();
34   }
35 
36   /// create mapped files for a container
AddNewStorageFiles(const std::string & container,const std::string & package_map,const std::string & flag_map,const std::string & flag_val)37   base::Result<StorageFiles*> StorageFilesManager::AddNewStorageFiles(
38       const std::string& container,
39       const std::string& package_map,
40       const std::string& flag_map,
41       const std::string& flag_val) {
42     if (all_storage_files_.count(container)) {
43       return base::Error() << "Storage file object for " << container << " already exists";
44     }
45 
46     auto result = base::Result<void>({});
47     auto storage_files = std::make_unique<StorageFiles>(
48           container, package_map, flag_map, flag_val, root_dir_, result);
49 
50     if (!result.ok()) {
51       return base::Error() << "Failed to create storage file object for " << container
52                      << ": " << result.error();
53     }
54 
55     auto storage_files_ptr = storage_files.get();
56     all_storage_files_[container].reset(storage_files.release());
57     return storage_files_ptr;
58   }
59 
60   /// restore storage files object from a storage record pb entry
RestoreStorageFiles(const PersistStorageRecord & pb)61   base::Result<void> StorageFilesManager::RestoreStorageFiles(
62       const PersistStorageRecord& pb) {
63     if (all_storage_files_.count(pb.container())) {
64       return base::Error() << "Storage file object for " << pb.container()
65                      << " already exists";
66     }
67 
68     all_storage_files_[pb.container()] = std::make_unique<StorageFiles>(pb, root_dir_);
69     return {};
70   }
71 
72   /// update existing storage files object with new storage file set
UpdateStorageFiles(const std::string & container,const std::string & package_map,const std::string & flag_map,const std::string & flag_val)73   base::Result<void> StorageFilesManager::UpdateStorageFiles(
74       const std::string& container,
75       const std::string& package_map,
76       const std::string& flag_map,
77       const std::string& flag_val) {
78     if (!all_storage_files_.count(container)) {
79       return base::Error() << "Failed to update storage files object for " << container
80                      << ", it does not exist";
81     }
82 
83     // backup server and local override
84     auto storage_files = GetStorageFiles(container);
85     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
86     auto server_overrides = (**storage_files).GetServerFlagValues();
87     RETURN_IF_ERROR(server_overrides, "Failed to get existing server overrides");
88 
89     auto pb_file = (**storage_files).GetStorageRecord().local_overrides;
90     auto local_overrides = ReadPbFromFile<LocalFlagOverrides>(pb_file);
91     RETURN_IF_ERROR(local_overrides, "Failed to read local overrides from " + pb_file);
92 
93     // clean up existing storage files object and recreate
94     (**storage_files).RemoveAllPersistFiles();
95     all_storage_files_.erase(container);
96     storage_files = AddNewStorageFiles(container, package_map, flag_map, flag_val);
97     RETURN_IF_ERROR(storage_files, "Failed to add a new storage object for " + container);
98 
99     // reapply local overrides
100     auto updated_local_overrides = LocalFlagOverrides();
101     for (const auto& entry : local_overrides->overrides()) {
102       auto has_flag = (**storage_files).HasFlag(entry.package_name(), entry.flag_name());
103       RETURN_IF_ERROR(has_flag, "Failed to check if has flag for " + entry.package_name()
104                       + "/" + entry.flag_name());
105       if (*has_flag) {
106         auto context = (**storage_files).GetPackageFlagContext(
107             entry.package_name(), entry.flag_name());
108         RETURN_IF_ERROR(context, "Failed to find package flag context for " +
109                         entry.package_name() + "/" + entry.flag_name());
110 
111         auto update = (**storage_files).SetHasLocalOverride(*context, true);
112         RETURN_IF_ERROR(update, "Failed to set flag has local override");
113 
114         auto* new_override = updated_local_overrides.add_overrides();
115         new_override->set_package_name(entry.package_name());
116         new_override->set_flag_name(entry.flag_name());
117         new_override->set_flag_value(entry.flag_value());
118       }
119     }
120     auto result = WritePbToFile<LocalFlagOverrides>(updated_local_overrides, pb_file);
121 
122     // reapply server overrides
123     for (const auto& entry : *server_overrides) {
124       auto has_flag = (**storage_files).HasFlag(entry.package_name, entry.flag_name);
125       RETURN_IF_ERROR(has_flag, "Failed to check if has flag for " + entry.package_name
126                       + "/" + entry.flag_name);
127       if (*has_flag) {
128         auto context = (**storage_files).GetPackageFlagContext(
129             entry.package_name, entry.flag_name);
130         RETURN_IF_ERROR(context, "Failed to find package flag context for " +
131                         entry.package_name + "/" + entry.flag_name);
132 
133         auto update = (**storage_files).SetServerFlagValue(*context, entry.flag_value);
134         RETURN_IF_ERROR(update, "Failed to set server flag value");
135       }
136     }
137 
138     return {};
139   }
140 
141   /// add or update storage file set for a container
AddOrUpdateStorageFiles(const std::string & container,const std::string & package_map,const std::string & flag_map,const std::string & flag_val)142   base::Result<bool> StorageFilesManager::AddOrUpdateStorageFiles(
143       const std::string& container,
144       const std::string& package_map,
145       const std::string& flag_map,
146       const std::string& flag_val) {
147     bool new_container = !HasContainer(container);
148     bool update_existing_container = false;
149     if (!new_container) {
150       auto digest = GetFilesDigest({package_map, flag_map, flag_val});
151       RETURN_IF_ERROR(digest, "Failed to get digest for " + container);
152       auto storage_files = GetStorageFiles(container);
153       RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
154       if ((**storage_files).GetStorageRecord().digest != *digest) {
155         update_existing_container = true;
156       }
157     }
158 
159     // early return if no update is needed
160     if (!(new_container || update_existing_container)) {
161       return false;
162     }
163 
164     if (new_container) {
165       auto storage_files = AddNewStorageFiles(
166           container, package_map, flag_map, flag_val);
167       RETURN_IF_ERROR(storage_files, "Failed to add a new storage object for " + container);
168     } else {
169       auto storage_files = UpdateStorageFiles(
170           container, package_map, flag_map, flag_val);
171       RETURN_IF_ERROR(storage_files, "Failed to update storage object for " + container);
172     }
173 
174     return true;
175   }
176 
177   /// create boot copy
CreateStorageBootCopy(const std::string & container)178   base::Result<void> StorageFilesManager::CreateStorageBootCopy(
179       const std::string& container) {
180     if (!HasContainer(container)) {
181       return base::Error() << "Cannot create boot copy without persist copy for " << container;
182     }
183     auto storage_files = GetStorageFiles(container);
184     auto copy_result = (**storage_files).CreateBootStorageFiles();
185     RETURN_IF_ERROR(copy_result, "Failed to create boot copies for " + container);
186     return {};
187   }
188 
189   /// reset all storage
ResetAllStorage()190   base::Result<void> StorageFilesManager::ResetAllStorage() {
191     for (const auto& container : GetAllContainers()) {
192       auto storage_files = GetStorageFiles(container);
193       RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
194       bool available = (**storage_files).HasBootCopy();
195       StorageRecord record = (**storage_files).GetStorageRecord();
196 
197       (**storage_files).RemoveAllPersistFiles();
198       all_storage_files_.erase(container);
199 
200       if (available) {
201         auto storage_files = AddNewStorageFiles(
202             container, record.package_map, record.flag_map, record.flag_val);
203         RETURN_IF_ERROR(storage_files, "Failed to add a new storage object for " + container);
204       }
205     }
206     return {};
207   }
208 
209   /// get container name given flag package name
GetContainer(const std::string & package)210   base::Result<std::string> StorageFilesManager::GetContainer(
211       const std::string& package) {
212     if (package_to_container_.count(package)) {
213       return package_to_container_[package];
214     }
215 
216     for (const auto& [container, storage_files] : all_storage_files_) {
217       auto has_flag = storage_files->HasPackage(package);
218       RETURN_IF_ERROR(has_flag, "Failed to check if has flag");
219 
220       if (*has_flag) {
221         package_to_container_[package] = container;
222         return container;
223       }
224     }
225 
226     return base::Error() << "container not found";
227   }
228 
229   /// Get all storage records
GetAllStorageRecords()230   std::vector<const StorageRecord*> StorageFilesManager::GetAllStorageRecords() {
231     auto all_records = std::vector<const StorageRecord*>();
232     for (auto const& [container, files_ptr] : all_storage_files_) {
233       all_records.push_back(&(files_ptr->GetStorageRecord()));
234     }
235     return all_records;
236   }
237 
238   /// get all containers
GetAllContainers()239   std::vector<std::string> StorageFilesManager::GetAllContainers() {
240     auto containers = std::vector<std::string>();
241     for (const auto& item : all_storage_files_) {
242       containers.push_back(item.first);
243     }
244     return containers;
245   }
246 
247   /// write to persist storage records pb file
WritePersistStorageRecordsToFile(const std::string & file_name)248   base::Result<void> StorageFilesManager::WritePersistStorageRecordsToFile(
249       const std::string& file_name) {
250     auto records_pb = PersistStorageRecords();
251     for (const auto& [container, storage_files] : all_storage_files_) {
252       const auto& record = storage_files->GetStorageRecord();
253       auto* record_pb = records_pb.add_records();
254       record_pb->set_version(record.version);
255       record_pb->set_container(record.container);
256       record_pb->set_package_map(record.package_map);
257       record_pb->set_flag_map(record.flag_map);
258       record_pb->set_flag_val(record.flag_val);
259       record_pb->set_digest(record.digest);
260     }
261     return WritePbToFile<PersistStorageRecords>(records_pb, file_name);
262   }
263 
264   /// apply flag override
UpdateFlagValue(const std::string & package_name,const std::string & flag_name,const std::string & flag_value,bool is_local_override)265   base::Result<void> StorageFilesManager::UpdateFlagValue(
266       const std::string& package_name,
267       const std::string& flag_name,
268       const std::string& flag_value,
269       bool is_local_override) {
270 
271     auto container = GetContainer(package_name);
272     RETURN_IF_ERROR(container, "Failed to find owning container");
273 
274     auto storage_files = GetStorageFiles(*container);
275     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
276 
277     auto context = (**storage_files).GetPackageFlagContext(package_name, flag_name);
278     RETURN_IF_ERROR(context, "Failed to find package flag context");
279 
280     if (is_local_override) {
281       auto update = (**storage_files).SetLocalFlagValue(*context, flag_value);
282       RETURN_IF_ERROR(update, "Failed to set local flag override");
283     } else {
284       auto update =(**storage_files).SetServerFlagValue(*context, flag_value);
285       RETURN_IF_ERROR(update, "Failed to set server flag value");
286     }
287 
288     return {};
289   }
290 
291   /// apply ota flags and return remaining ota flags
ApplyOTAFlagsForContainer(const std::string & container,const std::vector<FlagOverride> & ota_flags)292   base::Result<std::vector<FlagOverride>> StorageFilesManager::ApplyOTAFlagsForContainer(
293       const std::string& container,
294       const std::vector<FlagOverride>& ota_flags) {
295     auto storage_files = GetStorageFiles(container);
296     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
297 
298     auto remaining_ota_flags = std::vector<FlagOverride>();
299     for (const auto& entry : ota_flags) {
300       auto has_flag = (**storage_files).HasPackage(entry.package_name());
301       RETURN_IF_ERROR(has_flag, "Failed to check if has flag");
302       if (*has_flag) {
303         auto result = UpdateFlagValue(entry.package_name(),
304                                       entry.flag_name(),
305                                       entry.flag_value());
306         RETURN_IF_ERROR(result, "Failed to apply staged OTA flag " + entry.package_name()
307                         + "/" + entry.flag_name());
308       } else {
309         remaining_ota_flags.push_back(entry);
310       }
311     }
312 
313     return remaining_ota_flags;
314   }
315 
316   /// remove all local overrides
RemoveAllLocalOverrides()317   base::Result<void> StorageFilesManager::RemoveAllLocalOverrides() {
318     for (const auto& [container, storage_files] : all_storage_files_) {
319       auto update = storage_files->RemoveAllLocalFlagValue();
320       RETURN_IF_ERROR(update, "Failed to remove local overrides for " + container);
321     }
322     return {};
323   }
324 
325   /// remove a local override
RemoveFlagLocalOverride(const std::string & package,const std::string & flag)326   base::Result<void> StorageFilesManager::RemoveFlagLocalOverride(
327       const std::string& package,
328       const std::string& flag) {
329     auto container = GetContainer(package);
330     RETURN_IF_ERROR(container, "Failed to find owning container");
331 
332     auto storage_files = GetStorageFiles(*container);
333     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
334 
335     auto context = (**storage_files).GetPackageFlagContext(package, flag);
336     RETURN_IF_ERROR(context, "Failed to find package flag context");
337 
338     auto removed = (**storage_files).RemoveLocalFlagValue(*context);
339     RETURN_IF_ERROR(removed, "Failed to remove local override");
340 
341     return {};
342   }
343 
344   /// list a flag
ListFlag(const std::string & package,const std::string & flag)345   base::Result<StorageFiles::FlagSnapshot> StorageFilesManager::ListFlag(
346       const std::string& package,
347       const std::string& flag) {
348     auto container = GetContainer(package);
349     RETURN_IF_ERROR(container, "Failed to find owning container");
350     auto storage_files = GetStorageFiles(*container);
351     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
352 
353     if ((**storage_files).HasBootCopy()) {
354       return (**storage_files).ListFlag(package, flag);
355     } else{
356       return base::Error() << "Container " << *container << " is currently unavailable";
357     }
358   }
359 
360   /// list flags in a package
361   base::Result<std::vector<StorageFiles::FlagSnapshot>>
ListFlagsInPackage(const std::string & package)362       StorageFilesManager::ListFlagsInPackage(const std::string& package) {
363     auto container = GetContainer(package);
364     RETURN_IF_ERROR(container, "Failed to find owning container for " + package);
365     auto storage_files = GetStorageFiles(*container);
366     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
367 
368     if ((**storage_files).HasBootCopy()) {
369       return (**storage_files).ListFlags(package);
370     } else{
371       return base::Error() << "Container " << *container << " is currently unavailable";
372     }
373   }
374 
375   /// list flags in a container
376   base::Result<std::vector<StorageFiles::FlagSnapshot>>
ListFlagsInContainer(const std::string & container)377       StorageFilesManager::ListFlagsInContainer(const std::string& container) {
378     auto storage_files = GetStorageFiles(container);
379     RETURN_IF_ERROR(storage_files, "Failed to get storage files object");
380 
381     if ((**storage_files).HasBootCopy()) {
382       return (**storage_files).ListFlags();
383     } else {
384       return base::Error() << "Container " << container << " is currently unavailable";
385     }
386   }
387 
388   /// list all available flags
389   base::Result<std::vector<StorageFiles::FlagSnapshot>>
ListAllAvailableFlags()390       StorageFilesManager::ListAllAvailableFlags() {
391     auto total_flags = std::vector<StorageFiles::FlagSnapshot>();
392     for (const auto& [container, storage_files] : all_storage_files_) {
393       if (!storage_files->HasBootCopy()) {
394         continue;
395       }
396       auto flags = storage_files->ListFlags();
397       RETURN_IF_ERROR(flags, "Failed to list flags in " + container);
398       total_flags.reserve(total_flags.size() + flags->size());
399       total_flags.insert(total_flags.end(), flags->begin(), flags->end());
400     }
401     return total_flags;
402   }
403 
404   } // namespace aconfigd
405 } // namespace android
406