1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.examplestore;
18 
19 import static com.android.federatedcompute.services.common.Constants.TRACE_GET_EXAMPLE_STORE_ITERATOR;
20 
21 import android.content.Context;
22 import android.federatedcompute.aidl.IExampleStoreCallback;
23 import android.federatedcompute.aidl.IExampleStoreIterator;
24 import android.federatedcompute.aidl.IExampleStoreService;
25 import android.federatedcompute.common.ClientConstants;
26 import android.os.Bundle;
27 import android.os.SystemClock;
28 import android.os.Trace;
29 
30 import androidx.concurrent.futures.CallbackToFutureAdapter;
31 
32 import com.android.federatedcompute.internal.util.AbstractServiceBinder;
33 import com.android.federatedcompute.internal.util.LogUtil;
34 import com.android.federatedcompute.services.common.ExampleStats;
35 import com.android.federatedcompute.services.common.FlagsFactory;
36 import com.android.federatedcompute.services.data.FederatedTrainingTask;
37 
38 import com.google.common.util.concurrent.ListenableFuture;
39 import com.google.internal.federated.plan.ExampleSelector;
40 
41 import java.util.concurrent.ArrayBlockingQueue;
42 import java.util.concurrent.BlockingQueue;
43 import java.util.concurrent.TimeUnit;
44 
45 /** Provides {@link IExampleStoreService}. */
46 public class ExampleStoreServiceProvider {
47     private static final String TAG = ExampleStoreServiceProvider.class.getSimpleName();
48     private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder;
49 
50     /** Returns {@link IExampleStoreService}. */
getExampleStoreService(String packageName, Context context)51     public IExampleStoreService getExampleStoreService(String packageName, Context context) {
52         mExampleStoreServiceBinder =
53                 AbstractServiceBinder.getServiceBinderByIntent(
54                         context,
55                         ClientConstants.EXAMPLE_STORE_ACTION,
56                         packageName,
57                         IExampleStoreService.Stub::asInterface);
58         return mExampleStoreServiceBinder.getService(Runnable::run);
59     }
60 
61     /** Unbind from {@link IExampleStoreService}. */
unbindFromExampleStoreService()62     public void unbindFromExampleStoreService() {
63         mExampleStoreServiceBinder.unbindFromService();
64     }
65 
66     /** Returns an {@link IExampleStoreIterator} implemented by client app in synchronized call. */
getExampleIterator( IExampleStoreService exampleStoreService, FederatedTrainingTask task, String taskName, int minExample, ExampleSelector exampleSelector)67     public IExampleStoreIterator getExampleIterator(
68             IExampleStoreService exampleStoreService,
69             FederatedTrainingTask task,
70             String taskName,
71             int minExample,
72             ExampleSelector exampleSelector) {
73         try {
74             Trace.beginAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1);
75             Bundle bundle = new Bundle();
76             bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName());
77             bundle.putString(ClientConstants.EXTRA_TASK_ID, taskName);
78             bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData());
79             bundle.putInt(ClientConstants.EXTRA_ELIGIBILITY_MIN_EXAMPLE, minExample);
80             if (exampleSelector != null) {
81                 byte[] criteria = exampleSelector.getCriteria().toByteArray();
82                 byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray();
83                 bundle.putByteArray(
84                         ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken);
85                 bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria);
86                 bundle.putString(
87                         ClientConstants.EXTRA_COLLECTION_URI, exampleSelector.getCollectionUri());
88             }
89             BlockingQueue<CallbackResult> asyncResult = new ArrayBlockingQueue<>(1);
90             exampleStoreService.startQuery(
91                     bundle,
92                     new IExampleStoreCallback.Stub() {
93                         @Override
94                         public void onStartQuerySuccess(IExampleStoreIterator iterator) {
95                             LogUtil.d(TAG, "Acquired iterator");
96                             asyncResult.add(new CallbackResult(iterator, 0));
97                             Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1);
98                         }
99 
100                         @Override
101                         public void onStartQueryFailure(int errorCode) {
102                             LogUtil.e(TAG, "Could not acquire iterator: " + errorCode);
103                             asyncResult.add(new CallbackResult(null, errorCode));
104                             Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1);
105                         }
106                     });
107             CallbackResult callbackResult =
108                     asyncResult.poll(
109                             FlagsFactory.getFlags().getExampleStoreServiceCallbackTimeoutSec(),
110                             TimeUnit.SECONDS);
111             // Callback result is null if timeout.
112             if (callbackResult == null || callbackResult.mErrorCode != 0) {
113                 return null;
114             }
115             return callbackResult.mIterator;
116         } catch (Exception e) {
117             LogUtil.e(TAG, e, "Got exception when StartQuery");
118             return null;
119         }
120     }
121 
122     private static class CallbackResult {
123         final IExampleStoreIterator mIterator;
124         final int mErrorCode;
125 
CallbackResult(IExampleStoreIterator iterator, int errorCode)126         CallbackResult(IExampleStoreIterator iterator, int errorCode) {
127             mIterator = iterator;
128             mErrorCode = errorCode;
129         }
130     }
131 
runExampleStoreStartQuery( IExampleStoreService exampleStoreService, Bundle input, ExampleStats exampleStats, long startCallTimeNanos)132     private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery(
133             IExampleStoreService exampleStoreService,
134             Bundle input,
135             ExampleStats exampleStats,
136             long startCallTimeNanos) {
137         return CallbackToFutureAdapter.getFuture(
138                 completer -> {
139                     try {
140                         exampleStoreService.startQuery(
141                                 input,
142                                 new IExampleStoreCallback.Stub() {
143                                     @Override
144                                     public void onStartQuerySuccess(
145                                             IExampleStoreIterator iterator) {
146                                         LogUtil.d(TAG, "Acquired iterator");
147                                         exampleStats.mStartQueryLatencyNanos.addAndGet(
148                                                 SystemClock.elapsedRealtimeNanos()
149                                                         - startCallTimeNanos);
150                                         completer.set(iterator);
151                                         Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 0);
152                                     }
153 
154                                     @Override
155                                     public void onStartQueryFailure(int errorCode) {
156                                         LogUtil.e(TAG, "Could not acquire iterator: " + errorCode);
157                                         exampleStats.mStartQueryLatencyNanos.addAndGet(
158                                                 SystemClock.elapsedRealtimeNanos()
159                                                         - startCallTimeNanos);
160                                         completer.setException(
161                                                 new IllegalStateException(
162                                                         "StartQuery failed: " + errorCode));
163                                         Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 0);
164                                     }
165                                 });
166                     } catch (Exception e) {
167                         completer.setException(e);
168                     }
169                     return "runExampleStoreStartQuery";
170                 });
171     }
172 }
173