1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.ondevicepersonalization.services.federatedcompute;
18 
19 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS;
20 
21 import android.content.ComponentName;
22 import android.federatedcompute.ResultHandlingService;
23 import android.federatedcompute.common.ClientConstants;
24 import android.federatedcompute.common.ExampleConsumption;
25 import android.os.Bundle;
26 
27 import com.android.ondevicepersonalization.internal.util.LoggerFactory;
28 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors;
29 import com.android.ondevicepersonalization.services.data.events.EventState;
30 import com.android.ondevicepersonalization.services.data.events.EventsDao;
31 
32 import com.google.common.util.concurrent.FutureCallback;
33 import com.google.common.util.concurrent.Futures;
34 import com.google.common.util.concurrent.ListenableFuture;
35 
36 import java.util.ArrayList;
37 import java.util.List;
38 import java.util.Objects;
39 import java.util.function.Consumer;
40 
41 /** Implementation of ResultHandlingService for OnDevicePersonalization */
42 public class OdpResultHandlingService extends ResultHandlingService {
43     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
44     private static final String TAG = "OdpResultHandlingService";
45 
46     @Override
handleResult(Bundle params, Consumer<Integer> callback)47     public void handleResult(Bundle params, Consumer<Integer> callback) {
48         try {
49             ContextData contextData =
50                     ContextData.fromByteArray(
51                             Objects.requireNonNull(
52                                     params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA)));
53             ComponentName service =
54                     ComponentName.createRelative(
55                             contextData.getPackageName(), contextData.getClassName());
56             String populationName =
57                     Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
58             String taskId = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_ID));
59             int computationResult = params.getInt(ClientConstants.EXTRA_COMPUTATION_RESULT);
60             ArrayList<ExampleConsumption> consumptionList =
61                     Objects.requireNonNull(
62                             params.getParcelableArrayList(
63                                     ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST,
64                                     ExampleConsumption.class));
65 
66             // Just return if training failed. Next query will retry the failed examples.
67             if (computationResult != STATUS_SUCCESS) {
68                 callback.accept(ClientConstants.STATUS_SUCCESS);
69                 return;
70             }
71 
72             ListenableFuture<Boolean> result =
73                     Futures.submit(
74                             () ->
75                                     processExampleConsumptions(
76                                             consumptionList, populationName, taskId, service),
77                             OnDevicePersonalizationExecutors.getBackgroundExecutor());
78             Futures.addCallback(
79                     result,
80                     new FutureCallback<Boolean>() {
81                         @Override
82                         public void onSuccess(Boolean result) {
83                             if (result) {
84                                 callback.accept(STATUS_SUCCESS);
85                             } else {
86                                 callback.accept(ClientConstants.STATUS_INTERNAL_ERROR);
87                             }
88                         }
89 
90                         @Override
91                         public void onFailure(Throwable t) {
92                             sLogger.w(TAG + ": handleResult failed.", t);
93                             callback.accept(ClientConstants.STATUS_INTERNAL_ERROR);
94                         }
95                     },
96                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
97 
98         } catch (Exception e) {
99             sLogger.w(TAG + ": handleResult failed.", e);
100             callback.accept(ClientConstants.STATUS_INTERNAL_ERROR);
101         }
102     }
103 
processExampleConsumptions( List<ExampleConsumption> exampleConsumptions, String populationName, String taskId, ComponentName service)104     private Boolean processExampleConsumptions(
105             List<ExampleConsumption> exampleConsumptions,
106             String populationName,
107             String taskId,
108             ComponentName service) {
109         List<EventState> eventStates = new ArrayList<>();
110         for (ExampleConsumption consumption : exampleConsumptions) {
111             String taskIdentifier =
112                     consumption.getCollectionUri() != null
113                                     && !consumption.getCollectionUri().isEmpty()
114                             ? OdpExampleStoreService.getTaskIdentifier(
115                                     populationName, taskId, consumption.getCollectionUri())
116                             : OdpExampleStoreService.getTaskIdentifier(populationName, taskId);
117             byte[] resumptionToken = consumption.getResumptionToken();
118             if (resumptionToken != null) {
119                 eventStates.add(
120                         new EventState.Builder()
121                                 .setService(service)
122                                 .setTaskIdentifier(taskIdentifier)
123                                 .setToken(resumptionToken)
124                                 .build());
125             }
126         }
127         EventsDao eventsDao = EventsDao.getInstance(this);
128         return eventsDao.updateOrInsertEventStatesTransaction(eventStates);
129     }
130 }
131