1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.data; 18 19 import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE; 20 21 import android.annotation.NonNull; 22 import android.annotation.Nullable; 23 import android.content.ContentValues; 24 import android.database.Cursor; 25 import android.database.sqlite.SQLiteDatabase; 26 27 import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns; 28 import com.android.federatedcompute.services.data.fbs.TrainingConstraints; 29 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; 30 31 import com.google.auto.value.AutoValue; 32 33 import java.nio.ByteBuffer; 34 import java.util.ArrayList; 35 import java.util.List; 36 37 /** Contains the details of a training task. */ 38 @AutoValue 39 public abstract class FederatedTrainingTask { 40 private static final String TAG = FederatedTrainingTask.class.getSimpleName(); 41 42 /** 43 * @return client app package name 44 */ appPackageName()45 public abstract String appPackageName(); 46 47 /** 48 * @return the ID to use for the JobScheduler job that will run the training for this session. 49 */ jobId()50 public abstract int jobId(); 51 52 /** 53 * @return owner identifier package name 54 */ ownerPackageName()55 public abstract String ownerPackageName(); 56 57 /** 58 * @return owner identifier class name 59 */ ownerClassName()60 public abstract String ownerClassName(); 61 62 /** 63 * @return owner identifier cert digest 64 */ ownerIdCertDigest()65 public abstract String ownerIdCertDigest(); 66 67 /** 68 * @return the population name to uniquely identify the training job by. 69 */ populationName()70 public abstract String populationName(); 71 72 /** 73 * @return the remote federated compute server address that federated client need contact when 74 * job starts. 75 */ serverAddress()76 public abstract String serverAddress(); 77 78 /** 79 * @return the byte array of training interval including scheduling mode and minimum latency. 80 * The byte array is constructed from TrainingConstraints flatbuffer. 81 */ 82 @Nullable 83 @SuppressWarnings("mutable") intervalOptions()84 public abstract byte[] intervalOptions(); 85 86 /** 87 * @return the training interval including scheduling mode and minimum latency. 88 */ 89 @Nullable getTrainingIntervalOptions()90 public final TrainingIntervalOptions getTrainingIntervalOptions() { 91 if (intervalOptions() == null) { 92 return null; 93 } 94 return TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 95 ByteBuffer.wrap(intervalOptions())); 96 } 97 98 /** 99 * @return the context data that clients pass when schedule the job. 100 */ 101 @Nullable 102 @SuppressWarnings("mutable") contextData()103 public abstract byte[] contextData(); 104 105 /** 106 * @return the time the task was originally created. 107 */ creationTime()108 public abstract Long creationTime(); 109 110 /** 111 * @return the time the task was last scheduled. 112 */ lastScheduledTime()113 public abstract Long lastScheduledTime(); 114 115 /** 116 * @return the start time of the task's last run. 117 */ 118 @Nullable lastRunStartTime()119 public abstract Long lastRunStartTime(); 120 121 @NonNull getLastRunStartTime()122 public long getLastRunStartTime() { 123 return lastRunStartTime() == null ? 0 : lastRunStartTime(); 124 } 125 126 /** 127 * @return the end time of the task's last run. 128 */ 129 @Nullable lastRunEndTime()130 public abstract Long lastRunEndTime(); 131 132 @NonNull getLastRunEndTime()133 public long getLastRunEndTime() { 134 return lastRunEndTime() == null ? 0 : lastRunEndTime(); 135 } 136 137 /** 138 * @return the earliest time to run the task by. 139 */ earliestNextRunTime()140 public abstract Long earliestNextRunTime(); 141 142 /** 143 * @return the byte array of training constraints that should apply to this task. The byte array 144 * is constructed from TrainingConstraints flatbuffer. 145 */ 146 @SuppressWarnings("mutable") constraints()147 public abstract byte[] constraints(); 148 149 /** 150 * @return the training constraints that should apply to this task. 151 */ getTrainingConstraints()152 public final TrainingConstraints getTrainingConstraints() { 153 return TrainingConstraints.getRootAsTrainingConstraints(ByteBuffer.wrap(constraints())); 154 } 155 156 /** 157 * @return the reason to schedule the task. 158 */ schedulingReason()159 public abstract int schedulingReason(); 160 161 /** 162 * @return the number of rescheduling happened for this task. 163 */ rescheduleCount()164 public abstract int rescheduleCount(); 165 166 /** Builder for {@link FederatedTrainingTask} */ 167 @AutoValue.Builder 168 public abstract static class Builder { 169 /** Set client application package name. */ appPackageName(String appPackageName)170 public abstract Builder appPackageName(String appPackageName); 171 172 /** Set job scheduler Id. */ jobId(int jobId)173 public abstract Builder jobId(int jobId); 174 175 /** Set owner package name. */ ownerPackageName(String ownerPackageName)176 public abstract Builder ownerPackageName(String ownerPackageName); 177 /** Set owner class name. */ ownerClassName(String ownerClassName)178 public abstract Builder ownerClassName(String ownerClassName); 179 180 /** Set owner identifier cert digest. */ ownerIdCertDigest(String ownerIdCertDigest)181 public abstract Builder ownerIdCertDigest(String ownerIdCertDigest); 182 183 /** Set population name which uniquely identify the job. */ populationName(String populationName)184 public abstract Builder populationName(String populationName); 185 186 /** Set remote federated compute server address. */ serverAddress(String serverAddress)187 public abstract Builder serverAddress(String serverAddress); 188 189 /** Set the training interval including scheduling mode and minimum latency. */ 190 @SuppressWarnings("mutable") intervalOptions(@ullable byte[] intervalOptions)191 public abstract Builder intervalOptions(@Nullable byte[] intervalOptions); 192 193 /** Set the context data that clients pass when schedule job. */ 194 @SuppressWarnings("mutable") contextData(@ullable byte[] contextData)195 public abstract Builder contextData(@Nullable byte[] contextData); 196 197 /** Set the time the task was originally created. */ creationTime(Long creationTime)198 public abstract Builder creationTime(Long creationTime); 199 200 /** Set the time the task was last scheduled. */ lastScheduledTime(Long lastScheduledTime)201 public abstract Builder lastScheduledTime(Long lastScheduledTime); 202 203 /** Set the start time of the task's last run. */ lastRunStartTime(@ullable Long lastRunStartTime)204 public abstract Builder lastRunStartTime(@Nullable Long lastRunStartTime); 205 206 /** Set the end time of the task's last run. */ lastRunEndTime(@ullable Long lastRunEndTime)207 public abstract Builder lastRunEndTime(@Nullable Long lastRunEndTime); 208 209 /** Set the earliest time to run the task by. */ earliestNextRunTime(Long earliestNextRunTime)210 public abstract Builder earliestNextRunTime(Long earliestNextRunTime); 211 212 /** Set the training constraints that should apply to this task. */ 213 @SuppressWarnings("mutable") constraints(byte[] constraints)214 public abstract Builder constraints(byte[] constraints); 215 216 /** Set the reason to schedule the task. */ schedulingReason(int schedulingReason)217 public abstract Builder schedulingReason(int schedulingReason); 218 219 /** Set the count of reschedules. */ rescheduleCount(int rescheduleCount)220 public abstract Builder rescheduleCount(int rescheduleCount); 221 222 /** Build a federated training task instance. */ 223 @NonNull build()224 public abstract FederatedTrainingTask build(); 225 } 226 227 /** 228 * @return a builder of federated training task. 229 */ toBuilder()230 public abstract Builder toBuilder(); 231 232 /** 233 * @return a generic builder. 234 */ 235 @NonNull builder()236 public static Builder builder() { 237 return new AutoValue_FederatedTrainingTask.Builder().rescheduleCount(0); 238 } 239 addToDatabase(SQLiteDatabase db)240 boolean addToDatabase(SQLiteDatabase db) { 241 ContentValues values = new ContentValues(); 242 values.put(FederatedTrainingTaskColumns.APP_PACKAGE_NAME, appPackageName()); 243 values.put(FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID, jobId()); 244 values.put(FederatedTrainingTaskColumns.OWNER_PACKAGE, ownerPackageName()); 245 values.put(FederatedTrainingTaskColumns.OWNER_CLASS, ownerClassName()); 246 values.put( 247 FederatedTrainingTaskColumns.OWNER_ID, ownerPackageName() + "/" + ownerClassName()); 248 values.put(FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST, ownerIdCertDigest()); 249 250 values.put(FederatedTrainingTaskColumns.POPULATION_NAME, populationName()); 251 values.put(FederatedTrainingTaskColumns.SERVER_ADDRESS, serverAddress()); 252 if (intervalOptions() != null) { 253 values.put(FederatedTrainingTaskColumns.INTERVAL_OPTIONS, intervalOptions()); 254 } 255 256 if (contextData() != null) { 257 values.put(FederatedTrainingTaskColumns.CONTEXT_DATA, contextData()); 258 } 259 260 values.put(FederatedTrainingTaskColumns.CREATION_TIME, creationTime()); 261 values.put(FederatedTrainingTaskColumns.LAST_SCHEDULED_TIME, lastScheduledTime()); 262 if (lastRunStartTime() != null) { 263 values.put(FederatedTrainingTaskColumns.LAST_RUN_START_TIME, lastRunStartTime()); 264 } 265 if (lastRunEndTime() != null) { 266 values.put(FederatedTrainingTaskColumns.LAST_RUN_END_TIME, lastRunEndTime()); 267 } 268 values.put(FederatedTrainingTaskColumns.EARLIEST_NEXT_RUN_TIME, earliestNextRunTime()); 269 values.put(FederatedTrainingTaskColumns.CONSTRAINTS, constraints()); 270 values.put(FederatedTrainingTaskColumns.SCHEDULING_REASON, schedulingReason()); 271 values.put(FederatedTrainingTaskColumns.RESCHEDULE_COUNT, rescheduleCount()); 272 long jobId = 273 db.insertWithOnConflict( 274 FEDERATED_TRAINING_TASKS_TABLE, 275 "", 276 values, 277 SQLiteDatabase.CONFLICT_REPLACE); 278 return jobId != -1; 279 } 280 readFederatedTrainingTasksFromDatabase( SQLiteDatabase db, String selection, String[] selectionArgs)281 static List<FederatedTrainingTask> readFederatedTrainingTasksFromDatabase( 282 SQLiteDatabase db, String selection, String[] selectionArgs) { 283 List<FederatedTrainingTask> taskList = new ArrayList<>(); 284 String[] selectColumns = { 285 FederatedTrainingTaskColumns.APP_PACKAGE_NAME, 286 FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID, 287 FederatedTrainingTaskColumns.OWNER_PACKAGE, 288 FederatedTrainingTaskColumns.OWNER_CLASS, 289 FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST, 290 FederatedTrainingTaskColumns.POPULATION_NAME, 291 FederatedTrainingTaskColumns.SERVER_ADDRESS, 292 FederatedTrainingTaskColumns.INTERVAL_OPTIONS, 293 FederatedTrainingTaskColumns.CONTEXT_DATA, 294 FederatedTrainingTaskColumns.CREATION_TIME, 295 FederatedTrainingTaskColumns.LAST_SCHEDULED_TIME, 296 FederatedTrainingTaskColumns.LAST_RUN_START_TIME, 297 FederatedTrainingTaskColumns.LAST_RUN_END_TIME, 298 FederatedTrainingTaskColumns.EARLIEST_NEXT_RUN_TIME, 299 FederatedTrainingTaskColumns.CONSTRAINTS, 300 FederatedTrainingTaskColumns.SCHEDULING_REASON, 301 FederatedTrainingTaskColumns.RESCHEDULE_COUNT, 302 }; 303 Cursor cursor = null; 304 try { 305 cursor = 306 db.query( 307 FEDERATED_TRAINING_TASKS_TABLE, 308 selectColumns, 309 selection, 310 selectionArgs, 311 null, 312 null 313 /* groupBy= */ , 314 null 315 /* having= */ , 316 null 317 /* orderBy= */ ); 318 while (cursor.moveToNext()) { 319 FederatedTrainingTask.Builder trainingTaskBuilder = 320 FederatedTrainingTask.builder() 321 .appPackageName( 322 cursor.getString( 323 cursor.getColumnIndexOrThrow( 324 FederatedTrainingTaskColumns 325 .APP_PACKAGE_NAME))) 326 .jobId( 327 cursor.getInt( 328 cursor.getColumnIndexOrThrow( 329 FederatedTrainingTaskColumns 330 .JOB_SCHEDULER_JOB_ID))) 331 .ownerPackageName( 332 cursor.getString( 333 cursor.getColumnIndexOrThrow( 334 FederatedTrainingTaskColumns 335 .OWNER_PACKAGE))) 336 .ownerClassName( 337 cursor.getString( 338 cursor.getColumnIndexOrThrow( 339 FederatedTrainingTaskColumns.OWNER_CLASS))) 340 .ownerIdCertDigest( 341 cursor.getString( 342 cursor.getColumnIndexOrThrow( 343 FederatedTrainingTaskColumns 344 .OWNER_ID_CERT_DIGEST))) 345 .populationName( 346 cursor.getString( 347 cursor.getColumnIndexOrThrow( 348 FederatedTrainingTaskColumns 349 .POPULATION_NAME))) 350 .serverAddress( 351 cursor.getString( 352 cursor.getColumnIndexOrThrow( 353 FederatedTrainingTaskColumns 354 .SERVER_ADDRESS))) 355 .creationTime( 356 cursor.getLong( 357 cursor.getColumnIndexOrThrow( 358 FederatedTrainingTaskColumns 359 .CREATION_TIME))) 360 .lastScheduledTime( 361 cursor.getLong( 362 cursor.getColumnIndexOrThrow( 363 FederatedTrainingTaskColumns 364 .LAST_SCHEDULED_TIME))) 365 .lastRunStartTime( 366 cursor.getLong( 367 cursor.getColumnIndexOrThrow( 368 FederatedTrainingTaskColumns 369 .LAST_RUN_START_TIME))) 370 .lastRunEndTime( 371 cursor.getLong( 372 cursor.getColumnIndexOrThrow( 373 FederatedTrainingTaskColumns 374 .LAST_RUN_END_TIME))) 375 .earliestNextRunTime( 376 cursor.getLong( 377 cursor.getColumnIndexOrThrow( 378 FederatedTrainingTaskColumns 379 .EARLIEST_NEXT_RUN_TIME))) 380 .rescheduleCount( 381 cursor.getInt( 382 cursor.getColumnIndexOrThrow( 383 FederatedTrainingTaskColumns 384 .RESCHEDULE_COUNT))); 385 int schedulingReason = 386 cursor.getInt( 387 cursor.getColumnIndexOrThrow( 388 FederatedTrainingTaskColumns.SCHEDULING_REASON)); 389 if (!cursor.isNull(schedulingReason)) { 390 trainingTaskBuilder.schedulingReason(schedulingReason); 391 } 392 byte[] intervalOptions = 393 cursor.getBlob( 394 cursor.getColumnIndexOrThrow( 395 FederatedTrainingTaskColumns.INTERVAL_OPTIONS)); 396 if (intervalOptions != null) { 397 trainingTaskBuilder.intervalOptions(intervalOptions); 398 } 399 byte[] contextData = 400 cursor.getBlob( 401 cursor.getColumnIndexOrThrow( 402 FederatedTrainingTaskColumns.CONTEXT_DATA)); 403 if (contextData != null) { 404 trainingTaskBuilder.contextData(contextData); 405 } 406 byte[] constraints = 407 cursor.getBlob( 408 cursor.getColumnIndexOrThrow( 409 FederatedTrainingTaskColumns.CONSTRAINTS)); 410 trainingTaskBuilder.constraints(constraints); 411 taskList.add(trainingTaskBuilder.build()); 412 } 413 } finally { 414 if (cursor != null) { 415 cursor.close(); 416 } 417 } 418 return taskList; 419 } 420 } 421