/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.federatedcompute.services.data;

import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE;
import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE;
import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.ContentValues;
import android.content.Context;
import android.database.Cursor;
import android.database.SQLException;
import android.database.sqlite.SQLiteDatabase;
import android.database.sqlite.SQLiteException;

import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns;
import com.android.federatedcompute.services.statsd.ClientErrorLogger;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;

import java.util.ArrayList;
import java.util.List;

/**
 * DAO for accessing the training task table and the task history table.
 *
 * <p>Instances are obtained through {@link #getInstance(Context)}; a single instance is shared
 * process-wide.
 */
public class FederatedTrainingTaskDao {

    private static final String TAG = FederatedTrainingTaskDao.class.getSimpleName();

    private final FederatedComputeDbHelper mDbHelper;
    private static volatile FederatedTrainingTaskDao sSingletonInstance;

    private FederatedTrainingTaskDao(FederatedComputeDbHelper dbHelper) {
        this.mDbHelper = dbHelper;
    }

    /** Returns an instance of the FederatedTrainingTaskDao given a context. */
    @NonNull
    public static FederatedTrainingTaskDao getInstance(Context context) {
        if (sSingletonInstance == null) {
            synchronized (FederatedTrainingTaskDao.class) {
                if (sSingletonInstance == null) {
                    sSingletonInstance =
                            new FederatedTrainingTaskDao(
                                    FederatedComputeDbHelper.getInstance(context));
                }
            }
        }
        return sSingletonInstance;
    }

    /**
     * Returns the shared instance, creating it on top of the test database helper if no instance
     * exists yet. It's only public to unit test.
     */
    @VisibleForTesting
    public static FederatedTrainingTaskDao getInstanceForTest(Context context) {
        synchronized (FederatedTrainingTaskDao.class) {
            if (sSingletonInstance == null) {
                FederatedComputeDbHelper dbHelper =
                        FederatedComputeDbHelper.getInstanceForTest(context);
                sSingletonInstance = new FederatedTrainingTaskDao(dbHelper);
            }
            return sSingletonInstance;
        }
    }

    /**
     * Deletes the training tasks matching the selection from the FederatedTrainingTask table.
     *
     * <p>If no writable database is available, the failure is reported through {@link
     * ClientErrorLogger} and nothing is deleted.
     */
    private void deleteFederatedTrainingTask(String selection, String[] selectionArgs) {
        SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
        if (db == null) {
            ClientErrorLogger.getInstance()
                    .logError(
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return;
        }
        db.delete(FEDERATED_TRAINING_TASKS_TABLE, selection, selectionArgs);
    }

    /**
     * Insert a training task or update it if task already exists.
     *
     * @return the result of {@code trainingTask.addToDatabase}, or {@code false} if no writable
     *     database is available or a {@link SQLException} is thrown.
     */
    public boolean updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask) {
        try {
            SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
            if (db == null) {
                return false;
            }
            return trainingTask.addToDatabase(db);
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to persist federated training task %s",
                    trainingTask.populationName());
            return false;
        }
    }

    /**
     * Get the list of tasks that match select conditions.
     *
     * @return the matching tasks, or {@code null} if no readable database is available.
     */
    @Nullable
    public List<FederatedTrainingTask> getFederatedTrainingTask(
            String selection, String[] selectionArgs) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        if (db == null) {
            return null;
        }
        return FederatedTrainingTask.readFederatedTrainingTasksFromDatabase(
                db, selection, selectionArgs);
    }

    /**
     * Delete a task from table based on job scheduler id.
     *
     * @return the task that matched, or {@code null} if nothing matched, the database is
     *     unavailable, or the delete threw a {@link SQLException}.
     */
    @Nullable
    public FederatedTrainingTask findAndRemoveTaskByJobId(int jobId) {
        String selection = FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID + " = ?";
        String[] selectionArgs = selectionArgs(jobId);
        FederatedTrainingTask task = findOnlyTask(selection, selectionArgs);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(TAG, e, "Failed to delete federated training task by job id %d", jobId);
            reportDeleteFailure(e);
            return null;
        }
    }

    /**
     * Delete a task from table based on population name.
     *
     * @return the task that matched, or {@code null} if nothing matched, the database is
     *     unavailable, or the delete threw a {@link SQLException}.
     */
    @Nullable
    public FederatedTrainingTask findAndRemoveTaskByPopulationName(String populationName) {
        String selection = FederatedTrainingTaskColumns.POPULATION_NAME + " = ?";
        String[] selectionArgs = {populationName};
        FederatedTrainingTask task = findOnlyTask(selection, selectionArgs);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s",
                    populationName);
            reportDeleteFailure(e);
            return null;
        }
    }

    /**
     * Delete a task from table based on population name and calling package.
     *
     * @return the task that matched, or {@code null} if nothing matched, the database is
     *     unavailable, or the delete threw a {@link SQLException}.
     */
    @Nullable
    public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndCallingPackage(
            String populationName, String callingPackage) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.APP_PACKAGE_NAME
                        + " = ?";
        String[] selectionArgs = {populationName, callingPackage};
        FederatedTrainingTask task = findOnlyTask(selection, selectionArgs);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by "
                            + "population name %s and calling package: %s",
                    populationName,
                    callingPackage);
            reportDeleteFailure(e);
            return null;
        }
    }

    /**
     * Delete a task from table based on population name and owner Id (package name, class name
     * and certificate digest).
     *
     * @return the task that matched, or {@code null} if nothing matched, the database is
     *     unavailable, or the delete threw a {@link SQLException}.
     */
    @Nullable
    public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndOwnerId(
            String populationName, String ownerPackage, String ownerClass, String ownerCertDigest) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_PACKAGE
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_CLASS
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST
                        + " = ?";
        String[] selectionArgs = {populationName, ownerPackage, ownerClass, ownerCertDigest};
        FederatedTrainingTask task = findOnlyTask(selection, selectionArgs);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s and ATP: %s/%s",
                    populationName,
                    ownerPackage,
                    ownerClass);
            reportDeleteFailure(e);
            return null;
        }
    }

    /**
     * Delete a task from table based on population name and job scheduler id.
     *
     * @return the task that matched, or {@code null} if nothing matched, the database is
     *     unavailable, or the delete threw a {@link SQLException}.
     */
    @Nullable
    public FederatedTrainingTask findAndRemoveTaskByPopulationAndJobId(
            String populationName, int jobId) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID
                        + " = ?";
        String[] selectionArgs = {populationName, String.valueOf(jobId)};
        FederatedTrainingTask task = findOnlyTask(selection, selectionArgs);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s and job id %d",
                    populationName,
                    jobId);
            reportDeleteFailure(e);
            return null;
        }
    }

    /**
     * Returns number of tasks already belongs to given owners package.
     *
     * @return the row count, or 0 if no readable database is available.
     */
    public int getTotalTrainingTaskPerOwnerPackage(String packageName) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        if (db == null) {
            return 0;
        }
        final String query =
                "SELECT COUNT(*) FROM "
                        + FEDERATED_TRAINING_TASKS_TABLE
                        + " WHERE "
                        + FederatedTrainingTaskColumns.OWNER_PACKAGE
                        + " = ?";
        try (Cursor cursor = db.rawQuery(query, new String[] {packageName})) {
            if (cursor.moveToFirst()) {
                return cursor.getInt(0);
            } else {
                return 0; // No matching tasks found
            }
        }
    }

    /**
     * Insert a training task history record or update it if task already exists.
     *
     * @return {@code true} if the row was written; {@code false} if no writable database is
     *     available, the insert reported -1, or a {@link SQLException} was thrown.
     */
    public boolean updateOrInsertTaskHistory(TaskHistory taskHistory) {
        try {
            SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
            if (db == null) {
                // Same guard the other writers in this class apply; avoids an NPE below.
                return false;
            }
            ContentValues values = new ContentValues();
            values.put(TaskHistoryContract.TaskHistoryEntry.JOB_ID, taskHistory.getJobId());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME,
                    taskHistory.getPopulationName());
            values.put(TaskHistoryContract.TaskHistoryEntry.TASK_ID, taskHistory.getTaskId());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
                    taskHistory.getContributionRound());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
                    taskHistory.getContributionTime());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION,
                    taskHistory.getTotalParticipation());
            return db.insertWithOnConflict(
                            TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
                            null,
                            values,
                            SQLiteDatabase.CONFLICT_REPLACE)
                    != -1;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to update or insert task history %s %s",
                    taskHistory.getPopulationName(),
                    taskHistory.getTaskId());
        }
        return false;
    }

    /**
     * Gets the list of task history based on provided job id, population name and task id.
     *
     * @return the matching records ordered by contribution time, newest first, or {@code null} if
     *     the database is unavailable or the read threw a {@link SQLiteException}.
     */
    @Nullable
    public List<TaskHistory> getTaskHistoryList(int jobId, String populationName, String taskId) {
        return getTaskHistory(jobId, populationName, taskId, false);
    }

    /**
     * Get the latest record of task history based on job id, population name and task name.
     *
     * @return the record with the greatest contribution time, or {@code null} if there is none or
     *     the history could not be read.
     */
    @Nullable
    public TaskHistory getLatestTaskHistory(int jobId, String populationName, String taskId) {
        List<TaskHistory> taskList = getTaskHistory(jobId, populationName, taskId, true);
        // getTaskHistory reports read failures as null, so guard before touching the list.
        if (taskList == null || taskList.isEmpty()) {
            return null;
        }
        return taskList.get(0);
    }

    /**
     * Reads task history rows for the given key, newest contribution first.
     *
     * @param latest when {@code true}, only the first (newest) row is read.
     * @return the rows read, or {@code null} if the database is unavailable or a {@link
     *     SQLiteException} is thrown.
     */
    @Nullable
    private List<TaskHistory> getTaskHistory(
            int jobId, String populationName, String taskId, boolean latest) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        if (db == null) {
            return null;
        }
        String selection =
                TaskHistoryContract.TaskHistoryEntry.JOB_ID
                        + " = ? AND "
                        + TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME
                        + " = ? AND "
                        + TaskHistoryContract.TaskHistoryEntry.TASK_ID
                        + " = ?";
        String[] selectionArgs = {String.valueOf(jobId), populationName, taskId};
        String orderBy = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " DESC";
        String[] projection = {
            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
            TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION
        };
        List<TaskHistory> taskList = new ArrayList<>();
        // try-with-resources closes the cursor on every exit path; no explicit close() needed.
        try (Cursor cursor =
                db.query(
                        TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
                        projection,
                        selection,
                        selectionArgs,
                        /* groupBy= */ null,
                        /* having= */ null,
                        /* orderBy= */ orderBy,
                        /* limit= */ latest ? "1" : null)) {
            int timeIndex =
                    cursor.getColumnIndexOrThrow(
                            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME);
            int roundIndex =
                    cursor.getColumnIndexOrThrow(
                            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND);
            int participationIndex =
                    cursor.getColumnIndexOrThrow(
                            TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION);
            while (cursor.moveToNext()) {
                taskList.add(
                        new TaskHistory.Builder()
                                .setJobId(jobId)
                                .setTaskId(taskId)
                                .setPopulationName(populationName)
                                .setContributionRound(cursor.getLong(roundIndex))
                                .setContributionTime(cursor.getLong(timeIndex))
                                .setTotalParticipation(cursor.getLong(participationIndex))
                                .build());
            }
            return taskList;
        } catch (SQLiteException e) {
            LogUtil.e(TAG, e, "Failed to read TaskHistory db");
        }
        return null;
    }

    /**
     * Batch delete expired task history records.
     *
     * @param deleteTime records whose contribution time is strictly less than this are deleted.
     * @return the number of rows deleted.
     * @throws SQLiteException if no writable database is available.
     */
    public int deleteExpiredTaskHistory(long deleteTime) {
        SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
        if (db == null) {
            throw new SQLiteException(TAG + ": Failed to open database.");
        }
        String whereClause = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " < ?";
        String[] whereArgs = {String.valueOf(deleteTime)};
        int deletedRows =
                db.delete(TaskHistoryContract.TaskHistoryEntry.TABLE_NAME, whereClause, whereArgs);
        LogUtil.d(TAG, "Deleted %d expired task history records", deletedRows);
        return deletedRows;
    }

    /**
     * Looks up the single task matching the selection.
     *
     * @return the task, or {@code null} if nothing matched or the task list could not be read.
     * @throws IllegalArgumentException if more than one task matches (from {@link
     *     Iterables#getOnlyElement}).
     */
    @Nullable
    private FederatedTrainingTask findOnlyTask(String selection, String[] selectionArgs) {
        List<FederatedTrainingTask> tasks = getFederatedTrainingTask(selection, selectionArgs);
        if (tasks == null) {
            // getFederatedTrainingTask is @Nullable; Iterables.getOnlyElement rejects null.
            return null;
        }
        return Iterables.getOnlyElement(tasks, null);
    }

    /** Reports a failed task deletion, with exception details, to the client error logger. */
    private void reportDeleteFailure(SQLException e) {
        ClientErrorLogger.getInstance()
                .logErrorWithExceptionInfo(
                        e,
                        AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                        AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
    }

    /** Converts numeric arguments to the string form expected by SQLite selection args. */
    private String[] selectionArgs(Number... args) {
        String[] values = new String[args.length];
        for (int i = 0; i < args.length; i++) {
            values[i] = String.valueOf(args[i]);
        }
        return values;
    }
}