1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.service.ondeviceintelligence;
18 
19 import static android.app.ondeviceintelligence.OnDeviceIntelligenceManager.AUGMENT_REQUEST_CONTENT_BUNDLE_KEY;
20 import static android.app.ondeviceintelligence.flags.Flags.FLAG_ENABLE_ON_DEVICE_INTELLIGENCE;
21 
22 import static com.android.internal.util.function.pooled.PooledLambda.obtainMessage;
23 
24 import android.annotation.CallbackExecutor;
25 import android.annotation.CallSuper;
26 import android.annotation.FlaggedApi;
27 import android.annotation.NonNull;
28 import android.annotation.Nullable;
29 import android.annotation.SdkConstant;
30 import android.annotation.SuppressLint;
31 import android.annotation.SystemApi;
32 import android.app.Service;
33 import android.app.ondeviceintelligence.Feature;
34 import android.app.ondeviceintelligence.IProcessingSignal;
35 import android.app.ondeviceintelligence.IResponseCallback;
36 import android.app.ondeviceintelligence.IStreamingResponseCallback;
37 import android.app.ondeviceintelligence.ITokenInfoCallback;
38 import android.app.ondeviceintelligence.OnDeviceIntelligenceException;
39 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager;
40 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager.InferenceParams;
41 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager.StateParams;
42 import android.app.ondeviceintelligence.ProcessingCallback;
43 import android.app.ondeviceintelligence.ProcessingSignal;
44 import android.app.ondeviceintelligence.StreamingProcessingCallback;
45 import android.app.ondeviceintelligence.TokenInfo;
46 import android.content.Context;
47 import android.content.Intent;
48 import android.os.Bundle;
49 import android.os.CancellationSignal;
50 import android.os.Handler;
51 import android.os.HandlerExecutor;
52 import android.os.IBinder;
53 import android.os.ICancellationSignal;
54 import android.os.IRemoteCallback;
55 import android.os.Looper;
56 import android.os.OutcomeReceiver;
57 import android.os.ParcelFileDescriptor;
58 import android.os.PersistableBundle;
59 import android.os.RemoteCallback;
60 import android.os.RemoteException;
61 import android.util.Log;
62 import android.util.Slog;
63 
64 import com.android.internal.infra.AndroidFuture;
65 
66 import java.io.FileInputStream;
67 import java.io.FileNotFoundException;
68 import java.util.HashMap;
69 import java.util.Map;
70 import java.util.Objects;
71 import java.util.concurrent.ExecutionException;
72 import java.util.concurrent.Executor;
73 import java.util.function.Consumer;
74 
75 /**
76  * Abstract base class for performing inference in a isolated process. This service exposes its
77  * methods via {@link android.app.ondeviceintelligence.OnDeviceIntelligenceManager}.
78  *
79  * <p> A service that provides methods to perform on-device inference both in streaming and
80  * non-streaming fashion. Also, provides a way to register a storage service that will be used to
81  * read-only access files from the {@link OnDeviceIntelligenceService} counterpart. </p>
82  *
83  * <p> Similar to {@link OnDeviceIntelligenceManager} class, the contracts in this service are
84  * defined to be open-ended in general, to allow interoperability. Therefore, it is recommended
85  * that implementations of this system-service expose this API to the clients via a library which
86  * has more defined contract.</p>
87  *
88  * <pre>
89  * {@literal
90  * <service android:name=".SampleSandboxedInferenceService"
91  *          android:permission="android.permission.BIND_ONDEVICE_SANDBOXED_INFERENCE_SERVICE"
92  *          android:isolatedProcess="true">
93  * </service>}
94  * </pre>
95  *
96  * @hide
97  */
98 @SystemApi
99 @FlaggedApi(FLAG_ENABLE_ON_DEVICE_INTELLIGENCE)
100 public abstract class OnDeviceSandboxedInferenceService extends Service {
101     private static final String TAG = OnDeviceSandboxedInferenceService.class.getSimpleName();
102 
103     /**
104      * @hide
105      */
106     public static final String INFERENCE_INFO_BUNDLE_KEY = "inference_info";
107 
108     /**
109      * The {@link Intent} that must be declared as handled by the service. To be supported, the
110      * service must also require the
111      * {@link android.Manifest.permission#BIND_ON_DEVICE_SANDBOXED_INFERENCE_SERVICE}
112      * permission so that other applications can not abuse it.
113      */
114     @SdkConstant(SdkConstant.SdkConstantType.SERVICE_ACTION)
115     public static final String SERVICE_INTERFACE =
116             "android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService";
117 
118     // TODO(339594686): make API
119     /**
120      * @hide
121      */
122     public static final String REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY =
123             "register_model_update_callback";
124     /**
125      * @hide
126      */
127     public static final String MODEL_LOADED_BUNDLE_KEY = "model_loaded";
128     /**
129      * @hide
130      */
131     public static final String MODEL_UNLOADED_BUNDLE_KEY = "model_unloaded";
132 
133     /**
134      * @hide
135      */
136     public static final String DEVICE_CONFIG_UPDATE_BUNDLE_KEY = "device_config_update";
137 
138     private IRemoteStorageService mRemoteStorageService;
139     private Handler mHandler;
140 
141     @CallSuper
142     @Override
onCreate()143     public void onCreate() {
144         super.onCreate();
145         mHandler = new Handler(Looper.getMainLooper(), null /* callback */, true /* async */);
146     }
147 
148     /**
149      * @hide
150      */
151     @Nullable
152     @Override
onBind(@onNull Intent intent)153     public final IBinder onBind(@NonNull Intent intent) {
154         if (SERVICE_INTERFACE.equals(intent.getAction())) {
155             return new IOnDeviceSandboxedInferenceService.Stub() {
156                 @Override
157                 public void registerRemoteStorageService(IRemoteStorageService storageService,
158                         IRemoteCallback remoteCallback) throws RemoteException {
159                     Objects.requireNonNull(storageService);
160                     mRemoteStorageService = storageService;
161                     remoteCallback.sendResult(
162                             Bundle.EMPTY); //to notify caller uid to system-server.
163                 }
164 
165                 @Override
166                 public void requestTokenInfo(int callerUid, Feature feature, Bundle request,
167                         AndroidFuture cancellationSignalFuture,
168                         ITokenInfoCallback tokenInfoCallback) {
169                     Objects.requireNonNull(feature);
170                     Objects.requireNonNull(tokenInfoCallback);
171                     ICancellationSignal transport = null;
172                     if (cancellationSignalFuture != null) {
173                         transport = CancellationSignal.createTransport();
174                         cancellationSignalFuture.complete(transport);
175                     }
176 
177                     mHandler.executeOrSendMessage(
178                             obtainMessage(
179                                     OnDeviceSandboxedInferenceService::onTokenInfoRequest,
180                                     OnDeviceSandboxedInferenceService.this,
181                                     callerUid, feature,
182                                     request,
183                                     CancellationSignal.fromTransport(transport),
184                                     wrapTokenInfoCallback(tokenInfoCallback)));
185                 }
186 
187                 @Override
188                 public void processRequestStreaming(int callerUid, Feature feature, Bundle request,
189                         int requestType,
190                         AndroidFuture cancellationSignalFuture,
191                         AndroidFuture processingSignalFuture,
192                         IStreamingResponseCallback callback) {
193                     Objects.requireNonNull(feature);
194                     Objects.requireNonNull(callback);
195 
196                     ICancellationSignal transport = null;
197                     if (cancellationSignalFuture != null) {
198                         transport = CancellationSignal.createTransport();
199                         cancellationSignalFuture.complete(transport);
200                     }
201                     IProcessingSignal processingSignalTransport = null;
202                     if (processingSignalFuture != null) {
203                         processingSignalTransport = ProcessingSignal.createTransport();
204                         processingSignalFuture.complete(processingSignalTransport);
205                     }
206 
207 
208                     mHandler.executeOrSendMessage(
209                             obtainMessage(
210                                     OnDeviceSandboxedInferenceService::onProcessRequestStreaming,
211                                     OnDeviceSandboxedInferenceService.this, callerUid,
212                                     feature,
213                                     request,
214                                     requestType,
215                                     CancellationSignal.fromTransport(transport),
216                                     ProcessingSignal.fromTransport(processingSignalTransport),
217                                     wrapStreamingResponseCallback(callback)));
218                 }
219 
220                 @Override
221                 public void processRequest(int callerUid, Feature feature, Bundle request,
222                         int requestType,
223                         AndroidFuture cancellationSignalFuture,
224                         AndroidFuture processingSignalFuture,
225                         IResponseCallback callback) {
226                     Objects.requireNonNull(feature);
227                     Objects.requireNonNull(callback);
228                     ICancellationSignal transport = null;
229                     if (cancellationSignalFuture != null) {
230                         transport = CancellationSignal.createTransport();
231                         cancellationSignalFuture.complete(transport);
232                     }
233                     IProcessingSignal processingSignalTransport = null;
234                     if (processingSignalFuture != null) {
235                         processingSignalTransport = ProcessingSignal.createTransport();
236                         processingSignalFuture.complete(processingSignalTransport);
237                     }
238                     mHandler.executeOrSendMessage(
239                             obtainMessage(
240                                     OnDeviceSandboxedInferenceService::onProcessRequest,
241                                     OnDeviceSandboxedInferenceService.this, callerUid, feature,
242                                     request, requestType,
243                                     CancellationSignal.fromTransport(transport),
244                                     ProcessingSignal.fromTransport(processingSignalTransport),
245                                     wrapResponseCallback(callback)));
246                 }
247 
248                 @Override
249                 public void updateProcessingState(Bundle processingState,
250                         IProcessingUpdateStatusCallback callback) {
251                     Objects.requireNonNull(processingState);
252                     Objects.requireNonNull(callback);
253                     mHandler.executeOrSendMessage(
254                             obtainMessage(
255                                     OnDeviceSandboxedInferenceService::onUpdateProcessingState,
256                                     OnDeviceSandboxedInferenceService.this, processingState,
257                                     wrapOutcomeReceiver(callback)));
258                 }
259             };
260         }
261         Slog.w(TAG, "Incorrect service interface, returning null.");
262         return null;
263     }
264 
265     /**
266      * Invoked when caller  wants to obtain token info related to the payload in the passed
267      * content, associated with the provided feature.
268      * The expectation from the implementation is that when processing is complete, it
269      * should provide the token info in the {@link OutcomeReceiver#onResult}.
270      *
271      * @param callerUid          UID of the caller that initiated this call chain.
272      * @param feature            feature which is associated with the request.
273      * @param request            request that requires processing.
274      * @param cancellationSignal Cancellation Signal to receive cancellation events from client and
275      *                           configure a listener to.
276      * @param callback           callback to populate failure or the token info for the provided
277      *                           request.
278      */
279     @NonNull
280     public abstract void onTokenInfoRequest(
281             int callerUid, @NonNull Feature feature,
282             @NonNull @InferenceParams Bundle request,
283             @Nullable CancellationSignal cancellationSignal,
284             @NonNull OutcomeReceiver<TokenInfo, OnDeviceIntelligenceException> callback);
285 
286     /**
287      * Invoked when caller provides a request for a particular feature to be processed in a
288      * streaming manner. The expectation from the implementation is that when processing the
289      * request,
290      * it periodically populates the {@link StreamingProcessingCallback#onPartialResult} to
291      * continuously
292      * provide partial Bundle results for the caller to utilize. Optionally the implementation can
293      * provide the complete response in the {@link StreamingProcessingCallback#onResult} upon
294      * processing completion.
295      *
296      * @param callerUid          UID of the caller that initiated this call chain.
297      * @param feature            feature which is associated with the request.
298      * @param request            request that requires processing.
299      * @param requestType        identifier representing the type of request.
300      * @param cancellationSignal Cancellation Signal to receive cancellation events from client and
301      *                           configure a listener to.
302      * @param processingSignal   Signal to receive custom action instructions from client.
303      * @param callback           callback to populate the partial responses, failure and optionally
304      *                           full response for the provided request.
305      */
306     @NonNull
307     public abstract void onProcessRequestStreaming(
308             int callerUid, @NonNull Feature feature,
309             @NonNull @InferenceParams Bundle request,
310             @OnDeviceIntelligenceManager.RequestType int requestType,
311             @Nullable CancellationSignal cancellationSignal,
312             @Nullable ProcessingSignal processingSignal,
313             @NonNull StreamingProcessingCallback callback);
314 
315     /**
316      * Invoked when caller provides a request for a particular feature to be processed in one shot
317      * completely.
318      * The expectation from the implementation is that when processing the request is complete, it
319      * should
320      * provide the complete response in the {@link OutcomeReceiver#onResult}.
321      *
322      * @param callerUid          UID of the caller that initiated this call chain.
323      * @param feature            feature which is associated with the request.
324      * @param request            request that requires processing.
325      * @param requestType        identifier representing the type of request.
326      * @param cancellationSignal Cancellation Signal to receive cancellation events from client and
327      *                           configure a listener to.
328      * @param processingSignal   Signal to receive custom action instructions from client.
329      * @param callback           callback to populate failure and full response for the provided
330      *                           request.
331      */
332     @NonNull
333     public abstract void onProcessRequest(
334             int callerUid, @NonNull Feature feature,
335             @NonNull @InferenceParams Bundle request,
336             @OnDeviceIntelligenceManager.RequestType int requestType,
337             @Nullable CancellationSignal cancellationSignal,
338             @Nullable ProcessingSignal processingSignal,
339             @NonNull ProcessingCallback callback);
340 
341 
342     /**
343      * Invoked when processing environment needs to be updated or refreshed with fresh
344      * configuration, files or state.
345      *
346      * @param processingState contains updated state and params that are to be applied to the
347      *                        processing environmment,
348      * @param callback        callback to populate the update status and if there are params
349      *                        associated with the status.
350      */
351     public abstract void onUpdateProcessingState(@NonNull @StateParams Bundle processingState,
352             @NonNull OutcomeReceiver<PersistableBundle,
353                     OnDeviceIntelligenceException> callback);
354 
355 
356     /**
357      * Overrides {@link Context#openFileInput} to read files with the given file names under the
358      * internal app storage of the {@link OnDeviceIntelligenceService}, i.e., only files stored in
359      * {@link Context#getFilesDir()} can be opened.
360      */
361     @Override
362     public final FileInputStream openFileInput(@NonNull String filename) throws
363             FileNotFoundException {
364         try {
365             AndroidFuture<ParcelFileDescriptor> future = new AndroidFuture<>();
366             mRemoteStorageService.getReadOnlyFileDescriptor(filename, future);
367             ParcelFileDescriptor pfd = future.get();
368             return new FileInputStream(pfd.getFileDescriptor());
369         } catch (RemoteException | ExecutionException | InterruptedException e) {
370             Log.w(TAG, "Cannot open file due to remote service failure");
371             throw new FileNotFoundException(e.getMessage());
372         }
373     }
374 
375     /**
376      * Provides read-only access to the internal app storage via the
377      * {@link OnDeviceIntelligenceService}. This is an asynchronous alternative for
378      * {@link #openFileInput(String)}.
379      *
380      * @param fileName       File name relative to the {@link Context#getFilesDir()}.
381      * @param resultConsumer Consumer to populate the corresponding file descriptor in.
382      */
383     public final void getReadOnlyFileDescriptor(@NonNull String fileName,
384             @NonNull @CallbackExecutor Executor executor,
385             @NonNull Consumer<ParcelFileDescriptor> resultConsumer) throws FileNotFoundException {
386         AndroidFuture<ParcelFileDescriptor> future = new AndroidFuture<>();
387         try {
388             mRemoteStorageService.getReadOnlyFileDescriptor(fileName, future);
389         } catch (RemoteException e) {
390             Log.w(TAG, "Cannot open file due to remote service failure");
391             throw new FileNotFoundException(e.getMessage());
392         }
393         future.whenCompleteAsync((pfd, err) -> {
394             if (err != null) {
395                 Log.e(TAG, "Failure when reading file: " + fileName + err);
396                 executor.execute(() -> resultConsumer.accept(null));
397             } else {
398                 executor.execute(
399                         () -> resultConsumer.accept(pfd));
400             }
401         }, executor);
402     }
403 
404     /**
405      * Provides access to all file streams required for feature via the
406      * {@link OnDeviceIntelligenceService}.
407      *
408      * @param feature        Feature for which the associated files should be fetched.
409      * @param executor       Executor to run the consumer callback on.
410      * @param resultConsumer Consumer to receive a map of filePath to the corresponding file input
411      *                       stream.
412      */
413     public final void fetchFeatureFileDescriptorMap(@NonNull Feature feature,
414             @NonNull @CallbackExecutor Executor executor,
415             @NonNull Consumer<Map<String, ParcelFileDescriptor>> resultConsumer) {
416         try {
417             mRemoteStorageService.getReadOnlyFeatureFileDescriptorMap(feature,
418                     wrapAsRemoteCallback(resultConsumer, executor));
419         } catch (RemoteException e) {
420             throw new RuntimeException(e);
421         }
422     }
423 
424 
425     /**
426      * Returns the {@link Executor} to use for incoming IPC from request sender into your service
427      * implementation. For e.g. see
428      * {@link ProcessingCallback#onDataAugmentRequest(Bundle,
429      * Consumer)} where we use the executor to populate the consumer.
430      * <p>
431      * Override this method in your {@link OnDeviceSandboxedInferenceService} implementation to
432      * provide the executor you want to use for incoming IPC.
433      *
434      * @return the {@link Executor} to use for incoming IPC from {@link OnDeviceIntelligenceManager}
435      * to {@link OnDeviceSandboxedInferenceService}.
436      */
437     @SuppressLint("OnNameExpected")
438     @NonNull
439     public Executor getCallbackExecutor() {
440         return new HandlerExecutor(Handler.createAsync(getMainLooper()));
441     }
442 
443 
444     private RemoteCallback wrapAsRemoteCallback(
445             @NonNull Consumer<Map<String, ParcelFileDescriptor>> resultConsumer,
446             @NonNull Executor executor) {
447         return new RemoteCallback(result -> {
448             if (result == null) {
449                 executor.execute(() -> resultConsumer.accept(new HashMap<>()));
450             } else {
451                 Map<String, ParcelFileDescriptor> pfdMap = new HashMap<>();
452                 result.keySet().forEach(key ->
453                         pfdMap.put(key, result.getParcelable(key,
454                                 ParcelFileDescriptor.class)));
455                 executor.execute(() -> resultConsumer.accept(pfdMap));
456             }
457         });
458     }
459 
460     private ProcessingCallback wrapResponseCallback(
461             IResponseCallback callback) {
462         return new ProcessingCallback() {
463             @Override
464             public void onResult(@androidx.annotation.NonNull Bundle result) {
465                 try {
466                     callback.onSuccess(result);
467                 } catch (RemoteException e) {
468                     Slog.e(TAG, "Error sending result: " + e);
469                 }
470             }
471 
472             @Override
473             public void onError(
474                     OnDeviceIntelligenceException exception) {
475                 try {
476                     callback.onFailure(exception.getErrorCode(), exception.getMessage(),
477                             exception.getErrorParams());
478                 } catch (RemoteException e) {
479                     Slog.e(TAG, "Error sending result: " + e);
480                 }
481             }
482 
483             @Override
484             public void onDataAugmentRequest(@NonNull Bundle content,
485                     @NonNull Consumer<Bundle> contentCallback) {
486                 try {
487                     callback.onDataAugmentRequest(content, wrapRemoteCallback(contentCallback));
488 
489                 } catch (RemoteException e) {
490                     Slog.e(TAG, "Error sending augment request: " + e);
491                 }
492             }
493         };
494     }
495 
496     private StreamingProcessingCallback wrapStreamingResponseCallback(
497             IStreamingResponseCallback callback) {
498         return new StreamingProcessingCallback() {
499             @Override
500             public void onPartialResult(@androidx.annotation.NonNull Bundle partialResult) {
501                 try {
502                     callback.onNewContent(partialResult);
503                 } catch (RemoteException e) {
504                     Slog.e(TAG, "Error sending result: " + e);
505                 }
506             }
507 
508             @Override
509             public void onResult(@androidx.annotation.NonNull Bundle result) {
510                 try {
511                     callback.onSuccess(result);
512                 } catch (RemoteException e) {
513                     Slog.e(TAG, "Error sending result: " + e);
514                 }
515             }
516 
517             @Override
518             public void onError(
519                     OnDeviceIntelligenceException exception) {
520                 try {
521                     callback.onFailure(exception.getErrorCode(), exception.getMessage(),
522                             exception.getErrorParams());
523                 } catch (RemoteException e) {
524                     Slog.e(TAG, "Error sending result: " + e);
525                 }
526             }
527 
528             @Override
529             public void onDataAugmentRequest(@NonNull Bundle content,
530                     @NonNull Consumer<Bundle> contentCallback) {
531                 try {
532                     callback.onDataAugmentRequest(content, wrapRemoteCallback(contentCallback));
533 
534                 } catch (RemoteException e) {
535                     Slog.e(TAG, "Error sending augment request: " + e);
536                 }
537             }
538         };
539     }
540 
541     private RemoteCallback wrapRemoteCallback(
542             @androidx.annotation.NonNull Consumer<Bundle> contentCallback) {
543         return new RemoteCallback(
544                 result -> {
545                     if (result != null) {
546                         getCallbackExecutor().execute(() -> contentCallback.accept(
547                                 result.getParcelable(AUGMENT_REQUEST_CONTENT_BUNDLE_KEY,
548                                         Bundle.class)));
549                     } else {
550                         getCallbackExecutor().execute(
551                                 () -> contentCallback.accept(null));
552                     }
553                 });
554     }
555 
556     private OutcomeReceiver<TokenInfo, OnDeviceIntelligenceException> wrapTokenInfoCallback(
557             ITokenInfoCallback tokenInfoCallback) {
558         return new OutcomeReceiver<>() {
559             @Override
560             public void onResult(TokenInfo tokenInfo) {
561                 try {
562                     tokenInfoCallback.onSuccess(tokenInfo);
563                 } catch (RemoteException e) {
564                     Slog.e(TAG, "Error sending result: " + e);
565                 }
566             }
567 
568             @Override
569             public void onError(
570                     OnDeviceIntelligenceException exception) {
571                 try {
572                     tokenInfoCallback.onFailure(exception.getErrorCode(), exception.getMessage(),
573                             exception.getErrorParams());
574                 } catch (RemoteException e) {
575                     Slog.e(TAG, "Error sending failure: " + e);
576                 }
577             }
578         };
579     }
580 
581     @NonNull
582     private static OutcomeReceiver<PersistableBundle, OnDeviceIntelligenceException> wrapOutcomeReceiver(
583             IProcessingUpdateStatusCallback callback) {
584         return new OutcomeReceiver<>() {
585             @Override
586             public void onResult(@NonNull PersistableBundle result) {
587                 try {
588                     callback.onSuccess(result);
589                 } catch (RemoteException e) {
590                     Slog.e(TAG, "Error sending result: " + e);
591 
592                 }
593             }
594 
595             @Override
596             public void onError(
597                     @androidx.annotation.NonNull OnDeviceIntelligenceException error) {
598                 try {
599                     callback.onFailure(error.getErrorCode(), error.getMessage());
600                 } catch (RemoteException e) {
601                     Slog.e(TAG, "Error sending exception details: " + e);
602                 }
603             }
604         };
605     }
606 
607 }
608