1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.data;
18 
19 import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE;
20 import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE;
21 import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE;
22 
23 import android.annotation.NonNull;
24 import android.annotation.Nullable;
25 import android.content.ContentValues;
26 import android.content.Context;
27 import android.database.Cursor;
28 import android.database.SQLException;
29 import android.database.sqlite.SQLiteDatabase;
30 import android.database.sqlite.SQLiteException;
31 
32 import com.android.federatedcompute.internal.util.LogUtil;
33 import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns;
34 import com.android.federatedcompute.services.statsd.ClientErrorLogger;
35 
36 import com.google.common.annotations.VisibleForTesting;
37 import com.google.common.collect.Iterables;
38 
39 import java.util.ArrayList;
40 import java.util.List;
41 
42 /** DAO for accessing training task table. */
43 public class FederatedTrainingTaskDao {
44 
45     private static final String TAG = FederatedTrainingTaskDao.class.getSimpleName();
46 
47     private final FederatedComputeDbHelper mDbHelper;
48     private static volatile FederatedTrainingTaskDao sSingletonInstance;
49 
FederatedTrainingTaskDao(FederatedComputeDbHelper dbHelper)50     private FederatedTrainingTaskDao(FederatedComputeDbHelper dbHelper) {
51         this.mDbHelper = dbHelper;
52     }
53 
54     /** Returns an instance of the FederatedTrainingTaskDao given a context. */
55     @NonNull
getInstance(Context context)56     public static FederatedTrainingTaskDao getInstance(Context context) {
57         if (sSingletonInstance == null) {
58             synchronized (FederatedTrainingTaskDao.class) {
59                 if (sSingletonInstance == null) {
60                     sSingletonInstance =
61                             new FederatedTrainingTaskDao(
62                                     FederatedComputeDbHelper.getInstance(context));
63                 }
64             }
65         }
66         return sSingletonInstance;
67     }
68 
69     /** It's only public to unit test. */
70     @VisibleForTesting
getInstanceForTest(Context context)71     public static FederatedTrainingTaskDao getInstanceForTest(Context context) {
72         synchronized (FederatedTrainingTaskDao.class) {
73             if (sSingletonInstance == null) {
74                 FederatedComputeDbHelper dbHelper =
75                         FederatedComputeDbHelper.getInstanceForTest(context);
76                 sSingletonInstance = new FederatedTrainingTaskDao(dbHelper);
77             }
78             return sSingletonInstance;
79         }
80     }
81 
82     /** Deletes a training task in FederatedTrainingTask table. */
deleteFederatedTrainingTask(String selection, String[] selectionArgs)83     private void deleteFederatedTrainingTask(String selection, String[] selectionArgs) {
84         SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
85         if (db == null) {
86             ClientErrorLogger.getInstance()
87                     .logError(
88                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
89                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
90             return;
91         }
92         db.delete(FEDERATED_TRAINING_TASKS_TABLE, selection, selectionArgs);
93     }
94 
95     /** Insert a training task or update it if task already exists. */
updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask)96     public boolean updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask) {
97         try {
98             SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
99             if (db == null) {
100                 return false;
101             }
102             return trainingTask.addToDatabase(db);
103         } catch (SQLException e) {
104             LogUtil.e(
105                     TAG,
106                     e,
107                     "Failed to persist federated training task %s",
108                     trainingTask.populationName());
109             return false;
110         }
111     }
112 
113     /** Get the list of tasks that match select conditions. */
114     @Nullable
getFederatedTrainingTask( String selection, String[] selectionArgs)115     public List<FederatedTrainingTask> getFederatedTrainingTask(
116             String selection, String[] selectionArgs) {
117         SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
118         if (db == null) {
119             return null;
120         }
121         return FederatedTrainingTask.readFederatedTrainingTasksFromDatabase(
122                 db, selection, selectionArgs);
123     }
124 
125     /** Delete a task from table based on job scheduler id. */
findAndRemoveTaskByJobId(int jobId)126     public FederatedTrainingTask findAndRemoveTaskByJobId(int jobId) {
127         String selection = FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID + " = ?";
128         String[] selectionArgs = selectionArgs(jobId);
129         FederatedTrainingTask task =
130                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
131         try {
132             if (task != null) {
133                 deleteFederatedTrainingTask(selection, selectionArgs);
134             }
135             return task;
136         } catch (SQLException e) {
137             LogUtil.e(TAG, e, "Failed to delete federated training task by job id %d", jobId);
138             ClientErrorLogger.getInstance()
139                     .logErrorWithExceptionInfo(
140                             e,
141                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
142                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
143             return null;
144         }
145     }
146 
147     /** Delete a task from table based on population name. */
findAndRemoveTaskByPopulationName(String populationName)148     public FederatedTrainingTask findAndRemoveTaskByPopulationName(String populationName) {
149         String selection = FederatedTrainingTaskColumns.POPULATION_NAME + " = ?";
150         String[] selectionArgs = {populationName};
151         FederatedTrainingTask task =
152                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
153         try {
154             if (task != null) {
155                 deleteFederatedTrainingTask(selection, selectionArgs);
156             }
157             return task;
158         } catch (SQLException e) {
159             LogUtil.e(
160                     TAG,
161                     e,
162                     "Failed to delete federated training task by population name %s",
163                     populationName);
164             ClientErrorLogger.getInstance()
165                     .logErrorWithExceptionInfo(
166                             e,
167                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
168                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
169             return null;
170         }
171     }
172 
173     /** Delete a task from table based on population name and calling package. */
findAndRemoveTaskByPopulationNameAndCallingPackage( String populationName, String callingPackage)174     public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndCallingPackage(
175             String populationName, String callingPackage) {
176         String selection =
177                 FederatedTrainingTaskColumns.POPULATION_NAME
178                         + " = ? AND "
179                         + FederatedTrainingTaskColumns.APP_PACKAGE_NAME
180                         + " = ?";
181         String[] selectionArgs = {populationName, callingPackage};
182         FederatedTrainingTask task =
183                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
184         try {
185             if (task != null) {
186                 deleteFederatedTrainingTask(selection, selectionArgs);
187             }
188             return task;
189         } catch (SQLException e) {
190             LogUtil.e(
191                     TAG,
192                     e,
193                     "Failed to delete federated training task by "
194                             + "population name %s and calling package: %s",
195                     populationName,
196                     callingPackage);
197             ClientErrorLogger.getInstance()
198                     .logErrorWithExceptionInfo(
199                             e,
200                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
201                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
202             return null;
203         }
204     }
205 
206     /** Delete a task from table based on population name and owner Id (package and class name). */
findAndRemoveTaskByPopulationNameAndOwnerId( String populationName, String ownerPackage, String ownerClass, String ownerCertDigest)207     public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndOwnerId(
208             String populationName, String ownerPackage, String ownerClass, String ownerCertDigest) {
209         String selection =
210                 FederatedTrainingTaskColumns.POPULATION_NAME
211                         + " = ? AND "
212                         + FederatedTrainingTaskColumns.OWNER_PACKAGE
213                         + " = ? AND "
214                         + FederatedTrainingTaskColumns.OWNER_CLASS
215                         + " = ? AND "
216                         + FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST
217                         + " = ?";
218         String[] selectionArgs = {populationName, ownerPackage, ownerClass, ownerCertDigest};
219         FederatedTrainingTask task =
220                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
221         try {
222             if (task != null) {
223                 deleteFederatedTrainingTask(selection, selectionArgs);
224             }
225             return task;
226         } catch (SQLException e) {
227             LogUtil.e(
228                     TAG,
229                     e,
230                     "Failed to delete federated training task by population name %s and ATP: %s/%s",
231                     populationName,
232                     ownerPackage,
233                     ownerClass);
234             ClientErrorLogger.getInstance()
235                     .logErrorWithExceptionInfo(
236                             e,
237                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
238                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
239             return null;
240         }
241     }
242 
243     /** Delete a task from table based on population name and job scheduler id. */
findAndRemoveTaskByPopulationAndJobId( String populationName, int jobId)244     public FederatedTrainingTask findAndRemoveTaskByPopulationAndJobId(
245             String populationName, int jobId) {
246         String selection =
247                 FederatedTrainingTaskColumns.POPULATION_NAME
248                         + " = ? AND "
249                         + FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID
250                         + " = ?";
251         String[] selectionArgs = {populationName, String.valueOf(jobId)};
252         FederatedTrainingTask task =
253                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
254         try {
255             if (task != null) {
256                 deleteFederatedTrainingTask(selection, selectionArgs);
257             }
258             return task;
259         } catch (SQLException e) {
260             LogUtil.e(
261                     TAG,
262                     e,
263                     "Failed to delete federated training task by population name %s and job id %d",
264                     populationName,
265                     jobId);
266             ClientErrorLogger.getInstance()
267                     .logErrorWithExceptionInfo(
268                             e,
269                             AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
270                             AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
271             return null;
272         }
273     }
274 
275     /** Returns number of tasks already belongs to given owners package. */
getTotalTrainingTaskPerOwnerPackage(String packageName)276     public int getTotalTrainingTaskPerOwnerPackage(String packageName) {
277         SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
278         if (db == null) {
279             return 0;
280         }
281         final String query =
282                 "SELECT COUNT(*) FROM "
283                         + FEDERATED_TRAINING_TASKS_TABLE
284                         + " WHERE "
285                         + FederatedTrainingTaskColumns.OWNER_PACKAGE
286                         + " = ?";
287         try (Cursor cursor = db.rawQuery(query, new String[] {packageName})) {
288             if (cursor.moveToFirst()) {
289                 return cursor.getInt(0);
290             } else {
291                 return 0; // No matching tasks found
292             }
293         }
294     }
295 
296     /** Insert a training task history record or update it if task already exists. */
updateOrInsertTaskHistory(TaskHistory taskHistory)297     public boolean updateOrInsertTaskHistory(TaskHistory taskHistory) {
298         try {
299             SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
300             ContentValues values = new ContentValues();
301             values.put(TaskHistoryContract.TaskHistoryEntry.JOB_ID, taskHistory.getJobId());
302             values.put(
303                     TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME,
304                     taskHistory.getPopulationName());
305             values.put(TaskHistoryContract.TaskHistoryEntry.TASK_ID, taskHistory.getTaskId());
306             values.put(
307                     TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
308                     taskHistory.getContributionRound());
309             values.put(
310                     TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
311                     taskHistory.getContributionTime());
312             values.put(
313                     TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION,
314                     taskHistory.getTotalParticipation());
315             return db.insertWithOnConflict(
316                             TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
317                             null,
318                             values,
319                             SQLiteDatabase.CONFLICT_REPLACE)
320                     != -1;
321         } catch (SQLException e) {
322             LogUtil.e(
323                     TAG,
324                     "Failed to update or insert task history %s %s",
325                     taskHistory.getPopulationName(),
326                     taskHistory.getTaskId());
327         }
328         return false;
329     }
330 
331     /** Gets the list of task history based on provided job id, population name and task id. */
getTaskHistoryList(int jobId, String populationName, String taskId)332     public List<TaskHistory> getTaskHistoryList(int jobId, String populationName, String taskId) {
333         return getTaskHistory(jobId, populationName, taskId, false);
334     }
335 
336     /** Get the latest record of task history based on job id, population name and task name. */
getLatestTaskHistory(int jobId, String populationName, String taskId)337     public TaskHistory getLatestTaskHistory(int jobId, String populationName, String taskId) {
338         List<TaskHistory> taskList = getTaskHistory(jobId, populationName, taskId, true);
339         return taskList.isEmpty() ? null : taskList.get(0);
340     }
341 
getTaskHistory( int jobId, String populationName, String taskId, boolean latest)342     private List<TaskHistory> getTaskHistory(
343             int jobId, String populationName, String taskId, boolean latest) {
344         SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
345         String selection =
346                 TaskHistoryContract.TaskHistoryEntry.JOB_ID
347                         + " = ? AND "
348                         + TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME
349                         + " = ? AND "
350                         + TaskHistoryContract.TaskHistoryEntry.TASK_ID
351                         + " = ?";
352         String[] selectionArgs = {String.valueOf(jobId), populationName, taskId};
353         String orderBy = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " DESC";
354         String[] projection = {
355             TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
356             TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
357             TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION
358         };
359         List<TaskHistory> taskList = new ArrayList<>();
360         try (Cursor cursor =
361                 db.query(
362                         TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
363                         projection,
364                         selection,
365                         selectionArgs,
366                         /* groupBy= */ null,
367                         /* having= */ null,
368                         /* orderBy= */ orderBy)) {
369             while (cursor.moveToNext()) {
370                 long contributionTime =
371                         cursor.getLong(
372                                 cursor.getColumnIndexOrThrow(
373                                         TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME));
374                 long contributionRound =
375                         cursor.getLong(
376                                 cursor.getColumnIndexOrThrow(
377                                         TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND));
378                 long totalParticipation =
379                         cursor.getLong(
380                                 cursor.getColumnIndexOrThrow(
381                                         TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION));
382                 taskList.add(
383                         new TaskHistory.Builder()
384                                 .setJobId(jobId)
385                                 .setTaskId(taskId)
386                                 .setPopulationName(populationName)
387                                 .setContributionRound(contributionRound)
388                                 .setContributionTime(contributionTime)
389                                 .setTotalParticipation(totalParticipation)
390                                 .build());
391                 if (latest) {
392                     cursor.close();
393                     return taskList;
394                 }
395             }
396             cursor.close();
397             return taskList;
398         } catch (SQLiteException e) {
399             LogUtil.e(TAG, e, "Failed to read TaskHistory db");
400         }
401         return null;
402     }
403 
404     /** Batch delete expired task history records. */
deleteExpiredTaskHistory(long deleteTime)405     public int deleteExpiredTaskHistory(long deleteTime) {
406         SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
407         if (db == null) {
408             throw new SQLiteException(TAG + ": Failed to open database.");
409         }
410         String whereClause = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " < ?";
411         String[] whereArgs = {String.valueOf(deleteTime)};
412         int deletedRows =
413                 db.delete(TaskHistoryContract.TaskHistoryEntry.TABLE_NAME, whereClause, whereArgs);
414         LogUtil.d(TAG, "Deleted %d expired tokens", deletedRows);
415         return deletedRows;
416     }
417 
selectionArgs(Number... args)418     private String[] selectionArgs(Number... args) {
419         String[] values = new String[args.length];
420         for (int i = 0; i < args.length; i++) {
421             values[i] = String.valueOf(args[i]);
422         }
423         return values;
424     }
425 }
426