1 /* 2 * Copyright (C) 2024 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.examplestore; 18 19 import static com.android.federatedcompute.services.common.Constants.TRACE_GET_EXAMPLE_STORE_ITERATOR; 20 21 import android.content.Context; 22 import android.federatedcompute.aidl.IExampleStoreCallback; 23 import android.federatedcompute.aidl.IExampleStoreIterator; 24 import android.federatedcompute.aidl.IExampleStoreService; 25 import android.federatedcompute.common.ClientConstants; 26 import android.os.Bundle; 27 import android.os.SystemClock; 28 import android.os.Trace; 29 30 import androidx.concurrent.futures.CallbackToFutureAdapter; 31 32 import com.android.federatedcompute.internal.util.AbstractServiceBinder; 33 import com.android.federatedcompute.internal.util.LogUtil; 34 import com.android.federatedcompute.services.common.ExampleStats; 35 import com.android.federatedcompute.services.common.FlagsFactory; 36 import com.android.federatedcompute.services.data.FederatedTrainingTask; 37 38 import com.google.common.util.concurrent.ListenableFuture; 39 import com.google.internal.federated.plan.ExampleSelector; 40 41 import java.util.concurrent.ArrayBlockingQueue; 42 import java.util.concurrent.BlockingQueue; 43 import java.util.concurrent.TimeUnit; 44 45 /** Provides {@link IExampleStoreService}. */ 46 public class ExampleStoreServiceProvider { 47 private static final String TAG = ExampleStoreServiceProvider.class.getSimpleName(); 48 private AbstractServiceBinder<IExampleStoreService> mExampleStoreServiceBinder; 49 50 /** Returns {@link IExampleStoreService}. */ getExampleStoreService(String packageName, Context context)51 public IExampleStoreService getExampleStoreService(String packageName, Context context) { 52 mExampleStoreServiceBinder = 53 AbstractServiceBinder.getServiceBinderByIntent( 54 context, 55 ClientConstants.EXAMPLE_STORE_ACTION, 56 packageName, 57 IExampleStoreService.Stub::asInterface); 58 return mExampleStoreServiceBinder.getService(Runnable::run); 59 } 60 61 /** Unbind from {@link IExampleStoreService}. */ unbindFromExampleStoreService()62 public void unbindFromExampleStoreService() { 63 mExampleStoreServiceBinder.unbindFromService(); 64 } 65 66 /** Returns an {@link IExampleStoreIterator} implemented by client app in synchronized call. */ getExampleIterator( IExampleStoreService exampleStoreService, FederatedTrainingTask task, String taskName, int minExample, ExampleSelector exampleSelector)67 public IExampleStoreIterator getExampleIterator( 68 IExampleStoreService exampleStoreService, 69 FederatedTrainingTask task, 70 String taskName, 71 int minExample, 72 ExampleSelector exampleSelector) { 73 try { 74 Trace.beginAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1); 75 Bundle bundle = new Bundle(); 76 bundle.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName()); 77 bundle.putString(ClientConstants.EXTRA_TASK_ID, taskName); 78 bundle.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData()); 79 bundle.putInt(ClientConstants.EXTRA_ELIGIBILITY_MIN_EXAMPLE, minExample); 80 if (exampleSelector != null) { 81 byte[] criteria = exampleSelector.getCriteria().toByteArray(); 82 byte[] resumptionToken = exampleSelector.getResumptionToken().toByteArray(); 83 bundle.putByteArray( 84 ClientConstants.EXTRA_EXAMPLE_ITERATOR_RESUMPTION_TOKEN, resumptionToken); 85 bundle.putByteArray(ClientConstants.EXTRA_EXAMPLE_ITERATOR_CRITERIA, criteria); 86 bundle.putString( 87 ClientConstants.EXTRA_COLLECTION_URI, exampleSelector.getCollectionUri()); 88 } 89 BlockingQueue<CallbackResult> asyncResult = new ArrayBlockingQueue<>(1); 90 exampleStoreService.startQuery( 91 bundle, 92 new IExampleStoreCallback.Stub() { 93 @Override 94 public void onStartQuerySuccess(IExampleStoreIterator iterator) { 95 LogUtil.d(TAG, "Acquired iterator"); 96 asyncResult.add(new CallbackResult(iterator, 0)); 97 Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1); 98 } 99 100 @Override 101 public void onStartQueryFailure(int errorCode) { 102 LogUtil.e(TAG, "Could not acquire iterator: " + errorCode); 103 asyncResult.add(new CallbackResult(null, errorCode)); 104 Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 1); 105 } 106 }); 107 CallbackResult callbackResult = 108 asyncResult.poll( 109 FlagsFactory.getFlags().getExampleStoreServiceCallbackTimeoutSec(), 110 TimeUnit.SECONDS); 111 // Callback result is null if timeout. 112 if (callbackResult == null || callbackResult.mErrorCode != 0) { 113 return null; 114 } 115 return callbackResult.mIterator; 116 } catch (Exception e) { 117 LogUtil.e(TAG, e, "Got exception when StartQuery"); 118 return null; 119 } 120 } 121 122 private static class CallbackResult { 123 final IExampleStoreIterator mIterator; 124 final int mErrorCode; 125 CallbackResult(IExampleStoreIterator iterator, int errorCode)126 CallbackResult(IExampleStoreIterator iterator, int errorCode) { 127 mIterator = iterator; 128 mErrorCode = errorCode; 129 } 130 } 131 runExampleStoreStartQuery( IExampleStoreService exampleStoreService, Bundle input, ExampleStats exampleStats, long startCallTimeNanos)132 private ListenableFuture<IExampleStoreIterator> runExampleStoreStartQuery( 133 IExampleStoreService exampleStoreService, 134 Bundle input, 135 ExampleStats exampleStats, 136 long startCallTimeNanos) { 137 return CallbackToFutureAdapter.getFuture( 138 completer -> { 139 try { 140 exampleStoreService.startQuery( 141 input, 142 new IExampleStoreCallback.Stub() { 143 @Override 144 public void onStartQuerySuccess( 145 IExampleStoreIterator iterator) { 146 LogUtil.d(TAG, "Acquired iterator"); 147 exampleStats.mStartQueryLatencyNanos.addAndGet( 148 SystemClock.elapsedRealtimeNanos() 149 - startCallTimeNanos); 150 completer.set(iterator); 151 Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 0); 152 } 153 154 @Override 155 public void onStartQueryFailure(int errorCode) { 156 LogUtil.e(TAG, "Could not acquire iterator: " + errorCode); 157 exampleStats.mStartQueryLatencyNanos.addAndGet( 158 SystemClock.elapsedRealtimeNanos() 159 - startCallTimeNanos); 160 completer.setException( 161 new IllegalStateException( 162 "StartQuery failed: " + errorCode)); 163 Trace.endAsyncSection(TRACE_GET_EXAMPLE_STORE_ITERATOR, 0); 164 } 165 }); 166 } catch (Exception e) { 167 completer.setException(e); 168 } 169 return "runExampleStoreStartQuery"; 170 }); 171 } 172 } 173