1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.ondevicepersonalization.services.federatedcompute; 18 19 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; 20 21 import android.content.ComponentName; 22 import android.federatedcompute.ResultHandlingService; 23 import android.federatedcompute.common.ClientConstants; 24 import android.federatedcompute.common.ExampleConsumption; 25 import android.os.Bundle; 26 27 import com.android.ondevicepersonalization.internal.util.LoggerFactory; 28 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; 29 import com.android.ondevicepersonalization.services.data.events.EventState; 30 import com.android.ondevicepersonalization.services.data.events.EventsDao; 31 32 import com.google.common.util.concurrent.FutureCallback; 33 import com.google.common.util.concurrent.Futures; 34 import com.google.common.util.concurrent.ListenableFuture; 35 36 import java.util.ArrayList; 37 import java.util.List; 38 import java.util.Objects; 39 import java.util.function.Consumer; 40 41 /** Implementation of ResultHandlingService for OnDevicePersonalization */ 42 public class OdpResultHandlingService extends ResultHandlingService { 43 private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); 44 private static final String TAG = "OdpResultHandlingService"; 45 46 @Override handleResult(Bundle params, Consumer<Integer> callback)47 public void handleResult(Bundle params, Consumer<Integer> callback) { 48 try { 49 ContextData contextData = 50 ContextData.fromByteArray( 51 Objects.requireNonNull( 52 params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA))); 53 ComponentName service = 54 ComponentName.createRelative( 55 contextData.getPackageName(), contextData.getClassName()); 56 String populationName = 57 Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME)); 58 String taskId = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_ID)); 59 int computationResult = params.getInt(ClientConstants.EXTRA_COMPUTATION_RESULT); 60 ArrayList<ExampleConsumption> consumptionList = 61 Objects.requireNonNull( 62 params.getParcelableArrayList( 63 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, 64 ExampleConsumption.class)); 65 66 // Just return if training failed. Next query will retry the failed examples. 67 if (computationResult != STATUS_SUCCESS) { 68 callback.accept(ClientConstants.STATUS_SUCCESS); 69 return; 70 } 71 72 ListenableFuture<Boolean> result = 73 Futures.submit( 74 () -> 75 processExampleConsumptions( 76 consumptionList, populationName, taskId, service), 77 OnDevicePersonalizationExecutors.getBackgroundExecutor()); 78 Futures.addCallback( 79 result, 80 new FutureCallback<Boolean>() { 81 @Override 82 public void onSuccess(Boolean result) { 83 if (result) { 84 callback.accept(STATUS_SUCCESS); 85 } else { 86 callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); 87 } 88 } 89 90 @Override 91 public void onFailure(Throwable t) { 92 sLogger.w(TAG + ": handleResult failed.", t); 93 callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); 94 } 95 }, 96 OnDevicePersonalizationExecutors.getBackgroundExecutor()); 97 98 } catch (Exception e) { 99 sLogger.w(TAG + ": handleResult failed.", e); 100 callback.accept(ClientConstants.STATUS_INTERNAL_ERROR); 101 } 102 } 103 processExampleConsumptions( List<ExampleConsumption> exampleConsumptions, String populationName, String taskId, ComponentName service)104 private Boolean processExampleConsumptions( 105 List<ExampleConsumption> exampleConsumptions, 106 String populationName, 107 String taskId, 108 ComponentName service) { 109 List<EventState> eventStates = new ArrayList<>(); 110 for (ExampleConsumption consumption : exampleConsumptions) { 111 String taskIdentifier = 112 consumption.getCollectionUri() != null 113 && !consumption.getCollectionUri().isEmpty() 114 ? OdpExampleStoreService.getTaskIdentifier( 115 populationName, taskId, consumption.getCollectionUri()) 116 : OdpExampleStoreService.getTaskIdentifier(populationName, taskId); 117 byte[] resumptionToken = consumption.getResumptionToken(); 118 if (resumptionToken != null) { 119 eventStates.add( 120 new EventState.Builder() 121 .setService(service) 122 .setTaskIdentifier(taskIdentifier) 123 .setToken(resumptionToken) 124 .build()); 125 } 126 } 127 EventsDao eventsDao = EventsDao.getInstance(this); 128 return eventsDao.updateOrInsertEventStatesTransaction(eventStates); 129 } 130 } 131