1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.data;
18 
19 import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE;
20 
21 import android.annotation.NonNull;
22 import android.annotation.Nullable;
23 import android.content.ContentValues;
24 import android.database.Cursor;
25 import android.database.sqlite.SQLiteDatabase;
26 
27 import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns;
28 import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
29 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
30 
31 import com.google.auto.value.AutoValue;
32 
33 import java.nio.ByteBuffer;
34 import java.util.ArrayList;
35 import java.util.List;
36 
37 /** Contains the details of a training task. */
38 @AutoValue
39 public abstract class FederatedTrainingTask {
40     private static final String TAG = FederatedTrainingTask.class.getSimpleName();
41 
42     /**
43      * @return client app package name
44      */
appPackageName()45     public abstract String appPackageName();
46 
47     /**
48      * @return the ID to use for the JobScheduler job that will run the training for this session.
49      */
jobId()50     public abstract int jobId();
51 
52     /**
53      * @return owner identifier package name
54      */
ownerPackageName()55     public abstract String ownerPackageName();
56 
57     /**
58      * @return owner identifier class name
59      */
ownerClassName()60     public abstract String ownerClassName();
61 
62     /**
63      * @return owner identifier cert digest
64      */
ownerIdCertDigest()65     public abstract String ownerIdCertDigest();
66 
67     /**
68      * @return the population name to uniquely identify the training job by.
69      */
populationName()70     public abstract String populationName();
71 
72     /**
73      * @return the remote federated compute server address that federated client need contact when
74      *     job starts.
75      */
serverAddress()76     public abstract String serverAddress();
77 
78     /**
79      * @return the byte array of training interval including scheduling mode and minimum latency.
80      *     The byte array is constructed from TrainingConstraints flatbuffer.
81      */
82     @Nullable
83     @SuppressWarnings("mutable")
intervalOptions()84     public abstract byte[] intervalOptions();
85 
86     /**
87      * @return the training interval including scheduling mode and minimum latency.
88      */
89     @Nullable
getTrainingIntervalOptions()90     public final TrainingIntervalOptions getTrainingIntervalOptions() {
91         if (intervalOptions() == null) {
92             return null;
93         }
94         return TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
95                 ByteBuffer.wrap(intervalOptions()));
96     }
97 
98     /**
99      * @return the context data that clients pass when schedule the job.
100      */
101     @Nullable
102     @SuppressWarnings("mutable")
contextData()103     public abstract byte[] contextData();
104 
105     /**
106      * @return the time the task was originally created.
107      */
creationTime()108     public abstract Long creationTime();
109 
110     /**
111      * @return the time the task was last scheduled.
112      */
lastScheduledTime()113     public abstract Long lastScheduledTime();
114 
115     /**
116      * @return the start time of the task's last run.
117      */
118     @Nullable
lastRunStartTime()119     public abstract Long lastRunStartTime();
120 
121     @NonNull
getLastRunStartTime()122     public long getLastRunStartTime() {
123         return lastRunStartTime() == null ? 0 : lastRunStartTime();
124     }
125 
126     /**
127      * @return the end time of the task's last run.
128      */
129     @Nullable
lastRunEndTime()130     public abstract Long lastRunEndTime();
131 
132     @NonNull
getLastRunEndTime()133     public long getLastRunEndTime() {
134         return lastRunEndTime() == null ? 0 : lastRunEndTime();
135     }
136 
137     /**
138      * @return the earliest time to run the task by.
139      */
earliestNextRunTime()140     public abstract Long earliestNextRunTime();
141 
142     /**
143      * @return the byte array of training constraints that should apply to this task. The byte array
144      *     is constructed from TrainingConstraints flatbuffer.
145      */
146     @SuppressWarnings("mutable")
constraints()147     public abstract byte[] constraints();
148 
149     /**
150      * @return the training constraints that should apply to this task.
151      */
getTrainingConstraints()152     public final TrainingConstraints getTrainingConstraints() {
153         return TrainingConstraints.getRootAsTrainingConstraints(ByteBuffer.wrap(constraints()));
154     }
155 
156     /**
157      * @return the reason to schedule the task.
158      */
schedulingReason()159     public abstract int schedulingReason();
160 
161     /**
162      * @return the number of rescheduling happened for this task.
163      */
rescheduleCount()164     public abstract int rescheduleCount();
165 
166     /** Builder for {@link FederatedTrainingTask} */
167     @AutoValue.Builder
168     public abstract static class Builder {
169         /** Set client application package name. */
appPackageName(String appPackageName)170         public abstract Builder appPackageName(String appPackageName);
171 
172         /** Set job scheduler Id. */
jobId(int jobId)173         public abstract Builder jobId(int jobId);
174 
175         /** Set owner package name. */
ownerPackageName(String ownerPackageName)176         public abstract Builder ownerPackageName(String ownerPackageName);
177         /** Set owner class name. */
ownerClassName(String ownerClassName)178         public abstract Builder ownerClassName(String ownerClassName);
179 
180         /** Set owner identifier cert digest. */
ownerIdCertDigest(String ownerIdCertDigest)181         public abstract Builder ownerIdCertDigest(String ownerIdCertDigest);
182 
183         /** Set population name which uniquely identify the job. */
populationName(String populationName)184         public abstract Builder populationName(String populationName);
185 
186         /** Set remote federated compute server address. */
serverAddress(String serverAddress)187         public abstract Builder serverAddress(String serverAddress);
188 
189         /** Set the training interval including scheduling mode and minimum latency. */
190         @SuppressWarnings("mutable")
intervalOptions(@ullable byte[] intervalOptions)191         public abstract Builder intervalOptions(@Nullable byte[] intervalOptions);
192 
193         /** Set the context data that clients pass when schedule job. */
194         @SuppressWarnings("mutable")
contextData(@ullable byte[] contextData)195         public abstract Builder contextData(@Nullable byte[] contextData);
196 
197         /** Set the time the task was originally created. */
creationTime(Long creationTime)198         public abstract Builder creationTime(Long creationTime);
199 
200         /** Set the time the task was last scheduled. */
lastScheduledTime(Long lastScheduledTime)201         public abstract Builder lastScheduledTime(Long lastScheduledTime);
202 
203         /** Set the start time of the task's last run. */
lastRunStartTime(@ullable Long lastRunStartTime)204         public abstract Builder lastRunStartTime(@Nullable Long lastRunStartTime);
205 
206         /** Set the end time of the task's last run. */
lastRunEndTime(@ullable Long lastRunEndTime)207         public abstract Builder lastRunEndTime(@Nullable Long lastRunEndTime);
208 
209         /** Set the earliest time to run the task by. */
earliestNextRunTime(Long earliestNextRunTime)210         public abstract Builder earliestNextRunTime(Long earliestNextRunTime);
211 
212         /** Set the training constraints that should apply to this task. */
213         @SuppressWarnings("mutable")
constraints(byte[] constraints)214         public abstract Builder constraints(byte[] constraints);
215 
216         /** Set the reason to schedule the task. */
schedulingReason(int schedulingReason)217         public abstract Builder schedulingReason(int schedulingReason);
218 
219         /** Set the count of reschedules. */
rescheduleCount(int rescheduleCount)220         public abstract Builder rescheduleCount(int rescheduleCount);
221 
222         /** Build a federated training task instance. */
223         @NonNull
build()224         public abstract FederatedTrainingTask build();
225     }
226 
227     /**
228      * @return a builder of federated training task.
229      */
toBuilder()230     public abstract Builder toBuilder();
231 
232     /**
233      * @return a generic builder.
234      */
235     @NonNull
builder()236     public static Builder builder() {
237         return new AutoValue_FederatedTrainingTask.Builder().rescheduleCount(0);
238     }
239 
addToDatabase(SQLiteDatabase db)240     boolean addToDatabase(SQLiteDatabase db) {
241         ContentValues values = new ContentValues();
242         values.put(FederatedTrainingTaskColumns.APP_PACKAGE_NAME, appPackageName());
243         values.put(FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID, jobId());
244         values.put(FederatedTrainingTaskColumns.OWNER_PACKAGE, ownerPackageName());
245         values.put(FederatedTrainingTaskColumns.OWNER_CLASS, ownerClassName());
246         values.put(
247                 FederatedTrainingTaskColumns.OWNER_ID, ownerPackageName() + "/" + ownerClassName());
248         values.put(FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST, ownerIdCertDigest());
249 
250         values.put(FederatedTrainingTaskColumns.POPULATION_NAME, populationName());
251         values.put(FederatedTrainingTaskColumns.SERVER_ADDRESS, serverAddress());
252         if (intervalOptions() != null) {
253             values.put(FederatedTrainingTaskColumns.INTERVAL_OPTIONS, intervalOptions());
254         }
255 
256         if (contextData() != null) {
257             values.put(FederatedTrainingTaskColumns.CONTEXT_DATA, contextData());
258         }
259 
260         values.put(FederatedTrainingTaskColumns.CREATION_TIME, creationTime());
261         values.put(FederatedTrainingTaskColumns.LAST_SCHEDULED_TIME, lastScheduledTime());
262         if (lastRunStartTime() != null) {
263             values.put(FederatedTrainingTaskColumns.LAST_RUN_START_TIME, lastRunStartTime());
264         }
265         if (lastRunEndTime() != null) {
266             values.put(FederatedTrainingTaskColumns.LAST_RUN_END_TIME, lastRunEndTime());
267         }
268         values.put(FederatedTrainingTaskColumns.EARLIEST_NEXT_RUN_TIME, earliestNextRunTime());
269         values.put(FederatedTrainingTaskColumns.CONSTRAINTS, constraints());
270         values.put(FederatedTrainingTaskColumns.SCHEDULING_REASON, schedulingReason());
271         values.put(FederatedTrainingTaskColumns.RESCHEDULE_COUNT, rescheduleCount());
272         long jobId =
273                 db.insertWithOnConflict(
274                         FEDERATED_TRAINING_TASKS_TABLE,
275                         "",
276                         values,
277                         SQLiteDatabase.CONFLICT_REPLACE);
278         return jobId != -1;
279     }
280 
readFederatedTrainingTasksFromDatabase( SQLiteDatabase db, String selection, String[] selectionArgs)281     static List<FederatedTrainingTask> readFederatedTrainingTasksFromDatabase(
282             SQLiteDatabase db, String selection, String[] selectionArgs) {
283         List<FederatedTrainingTask> taskList = new ArrayList<>();
284         String[] selectColumns = {
285             FederatedTrainingTaskColumns.APP_PACKAGE_NAME,
286             FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID,
287             FederatedTrainingTaskColumns.OWNER_PACKAGE,
288             FederatedTrainingTaskColumns.OWNER_CLASS,
289             FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST,
290             FederatedTrainingTaskColumns.POPULATION_NAME,
291             FederatedTrainingTaskColumns.SERVER_ADDRESS,
292             FederatedTrainingTaskColumns.INTERVAL_OPTIONS,
293             FederatedTrainingTaskColumns.CONTEXT_DATA,
294             FederatedTrainingTaskColumns.CREATION_TIME,
295             FederatedTrainingTaskColumns.LAST_SCHEDULED_TIME,
296             FederatedTrainingTaskColumns.LAST_RUN_START_TIME,
297             FederatedTrainingTaskColumns.LAST_RUN_END_TIME,
298             FederatedTrainingTaskColumns.EARLIEST_NEXT_RUN_TIME,
299             FederatedTrainingTaskColumns.CONSTRAINTS,
300             FederatedTrainingTaskColumns.SCHEDULING_REASON,
301             FederatedTrainingTaskColumns.RESCHEDULE_COUNT,
302         };
303         Cursor cursor = null;
304         try {
305             cursor =
306                     db.query(
307                             FEDERATED_TRAINING_TASKS_TABLE,
308                             selectColumns,
309                             selection,
310                             selectionArgs,
311                             null,
312                             null
313                             /* groupBy= */ ,
314                             null
315                             /* having= */ ,
316                             null
317                             /* orderBy= */ );
318             while (cursor.moveToNext()) {
319                 FederatedTrainingTask.Builder trainingTaskBuilder =
320                         FederatedTrainingTask.builder()
321                                 .appPackageName(
322                                         cursor.getString(
323                                                 cursor.getColumnIndexOrThrow(
324                                                         FederatedTrainingTaskColumns
325                                                                 .APP_PACKAGE_NAME)))
326                                 .jobId(
327                                         cursor.getInt(
328                                                 cursor.getColumnIndexOrThrow(
329                                                         FederatedTrainingTaskColumns
330                                                                 .JOB_SCHEDULER_JOB_ID)))
331                                 .ownerPackageName(
332                                         cursor.getString(
333                                                 cursor.getColumnIndexOrThrow(
334                                                         FederatedTrainingTaskColumns
335                                                                 .OWNER_PACKAGE)))
336                                 .ownerClassName(
337                                         cursor.getString(
338                                                 cursor.getColumnIndexOrThrow(
339                                                         FederatedTrainingTaskColumns.OWNER_CLASS)))
340                                 .ownerIdCertDigest(
341                                         cursor.getString(
342                                                 cursor.getColumnIndexOrThrow(
343                                                         FederatedTrainingTaskColumns
344                                                                 .OWNER_ID_CERT_DIGEST)))
345                                 .populationName(
346                                         cursor.getString(
347                                                 cursor.getColumnIndexOrThrow(
348                                                         FederatedTrainingTaskColumns
349                                                                 .POPULATION_NAME)))
350                                 .serverAddress(
351                                         cursor.getString(
352                                                 cursor.getColumnIndexOrThrow(
353                                                         FederatedTrainingTaskColumns
354                                                                 .SERVER_ADDRESS)))
355                                 .creationTime(
356                                         cursor.getLong(
357                                                 cursor.getColumnIndexOrThrow(
358                                                         FederatedTrainingTaskColumns
359                                                                 .CREATION_TIME)))
360                                 .lastScheduledTime(
361                                         cursor.getLong(
362                                                 cursor.getColumnIndexOrThrow(
363                                                         FederatedTrainingTaskColumns
364                                                                 .LAST_SCHEDULED_TIME)))
365                                 .lastRunStartTime(
366                                         cursor.getLong(
367                                                 cursor.getColumnIndexOrThrow(
368                                                         FederatedTrainingTaskColumns
369                                                                 .LAST_RUN_START_TIME)))
370                                 .lastRunEndTime(
371                                         cursor.getLong(
372                                                 cursor.getColumnIndexOrThrow(
373                                                         FederatedTrainingTaskColumns
374                                                                 .LAST_RUN_END_TIME)))
375                                 .earliestNextRunTime(
376                                         cursor.getLong(
377                                                 cursor.getColumnIndexOrThrow(
378                                                         FederatedTrainingTaskColumns
379                                                                 .EARLIEST_NEXT_RUN_TIME)))
380                                 .rescheduleCount(
381                                         cursor.getInt(
382                                                 cursor.getColumnIndexOrThrow(
383                                                         FederatedTrainingTaskColumns
384                                                                 .RESCHEDULE_COUNT)));
385                 int schedulingReason =
386                         cursor.getInt(
387                                 cursor.getColumnIndexOrThrow(
388                                         FederatedTrainingTaskColumns.SCHEDULING_REASON));
389                 if (!cursor.isNull(schedulingReason)) {
390                     trainingTaskBuilder.schedulingReason(schedulingReason);
391                 }
392                 byte[] intervalOptions =
393                         cursor.getBlob(
394                                 cursor.getColumnIndexOrThrow(
395                                         FederatedTrainingTaskColumns.INTERVAL_OPTIONS));
396                 if (intervalOptions != null) {
397                     trainingTaskBuilder.intervalOptions(intervalOptions);
398                 }
399                 byte[] contextData =
400                         cursor.getBlob(
401                                 cursor.getColumnIndexOrThrow(
402                                         FederatedTrainingTaskColumns.CONTEXT_DATA));
403                 if (contextData != null) {
404                     trainingTaskBuilder.contextData(contextData);
405                 }
406                 byte[] constraints =
407                         cursor.getBlob(
408                                 cursor.getColumnIndexOrThrow(
409                                         FederatedTrainingTaskColumns.CONSTRAINTS));
410                 trainingTaskBuilder.constraints(constraints);
411                 taskList.add(trainingTaskBuilder.build());
412             }
413         } finally {
414             if (cursor != null) {
415                 cursor.close();
416             }
417         }
418         return taskList;
419     }
420 }
421