1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.service.ondeviceintelligence; 18 19 import static android.app.ondeviceintelligence.OnDeviceIntelligenceManager.AUGMENT_REQUEST_CONTENT_BUNDLE_KEY; 20 import static android.app.ondeviceintelligence.flags.Flags.FLAG_ENABLE_ON_DEVICE_INTELLIGENCE; 21 22 import static com.android.internal.util.function.pooled.PooledLambda.obtainMessage; 23 24 import android.annotation.CallbackExecutor; 25 import android.annotation.CallSuper; 26 import android.annotation.FlaggedApi; 27 import android.annotation.NonNull; 28 import android.annotation.Nullable; 29 import android.annotation.SdkConstant; 30 import android.annotation.SuppressLint; 31 import android.annotation.SystemApi; 32 import android.app.Service; 33 import android.app.ondeviceintelligence.Feature; 34 import android.app.ondeviceintelligence.IProcessingSignal; 35 import android.app.ondeviceintelligence.IResponseCallback; 36 import android.app.ondeviceintelligence.IStreamingResponseCallback; 37 import android.app.ondeviceintelligence.ITokenInfoCallback; 38 import android.app.ondeviceintelligence.OnDeviceIntelligenceException; 39 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager; 40 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager.InferenceParams; 41 import android.app.ondeviceintelligence.OnDeviceIntelligenceManager.StateParams; 42 import android.app.ondeviceintelligence.ProcessingCallback; 43 import android.app.ondeviceintelligence.ProcessingSignal; 44 import android.app.ondeviceintelligence.StreamingProcessingCallback; 45 import android.app.ondeviceintelligence.TokenInfo; 46 import android.content.Context; 47 import android.content.Intent; 48 import android.os.Bundle; 49 import android.os.CancellationSignal; 50 import android.os.Handler; 51 import android.os.HandlerExecutor; 52 import android.os.IBinder; 53 import android.os.ICancellationSignal; 54 import android.os.IRemoteCallback; 55 import android.os.Looper; 56 import android.os.OutcomeReceiver; 57 import android.os.ParcelFileDescriptor; 58 import android.os.PersistableBundle; 59 import android.os.RemoteCallback; 60 import android.os.RemoteException; 61 import android.util.Log; 62 import android.util.Slog; 63 64 import com.android.internal.infra.AndroidFuture; 65 66 import java.io.FileInputStream; 67 import java.io.FileNotFoundException; 68 import java.util.HashMap; 69 import java.util.Map; 70 import java.util.Objects; 71 import java.util.concurrent.ExecutionException; 72 import java.util.concurrent.Executor; 73 import java.util.function.Consumer; 74 75 /** 76 * Abstract base class for performing inference in a isolated process. This service exposes its 77 * methods via {@link android.app.ondeviceintelligence.OnDeviceIntelligenceManager}. 78 * 79 * <p> A service that provides methods to perform on-device inference both in streaming and 80 * non-streaming fashion. Also, provides a way to register a storage service that will be used to 81 * read-only access files from the {@link OnDeviceIntelligenceService} counterpart. </p> 82 * 83 * <p> Similar to {@link OnDeviceIntelligenceManager} class, the contracts in this service are 84 * defined to be open-ended in general, to allow interoperability. Therefore, it is recommended 85 * that implementations of this system-service expose this API to the clients via a library which 86 * has more defined contract.</p> 87 * 88 * <pre> 89 * {@literal 90 * <service android:name=".SampleSandboxedInferenceService" 91 * android:permission="android.permission.BIND_ONDEVICE_SANDBOXED_INFERENCE_SERVICE" 92 * android:isolatedProcess="true"> 93 * </service>} 94 * </pre> 95 * 96 * @hide 97 */ 98 @SystemApi 99 @FlaggedApi(FLAG_ENABLE_ON_DEVICE_INTELLIGENCE) 100 public abstract class OnDeviceSandboxedInferenceService extends Service { 101 private static final String TAG = OnDeviceSandboxedInferenceService.class.getSimpleName(); 102 103 /** 104 * @hide 105 */ 106 public static final String INFERENCE_INFO_BUNDLE_KEY = "inference_info"; 107 108 /** 109 * The {@link Intent} that must be declared as handled by the service. To be supported, the 110 * service must also require the 111 * {@link android.Manifest.permission#BIND_ON_DEVICE_SANDBOXED_INFERENCE_SERVICE} 112 * permission so that other applications can not abuse it. 113 */ 114 @SdkConstant(SdkConstant.SdkConstantType.SERVICE_ACTION) 115 public static final String SERVICE_INTERFACE = 116 "android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService"; 117 118 // TODO(339594686): make API 119 /** 120 * @hide 121 */ 122 public static final String REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY = 123 "register_model_update_callback"; 124 /** 125 * @hide 126 */ 127 public static final String MODEL_LOADED_BUNDLE_KEY = "model_loaded"; 128 /** 129 * @hide 130 */ 131 public static final String MODEL_UNLOADED_BUNDLE_KEY = "model_unloaded"; 132 133 /** 134 * @hide 135 */ 136 public static final String DEVICE_CONFIG_UPDATE_BUNDLE_KEY = "device_config_update"; 137 138 private IRemoteStorageService mRemoteStorageService; 139 private Handler mHandler; 140 141 @CallSuper 142 @Override onCreate()143 public void onCreate() { 144 super.onCreate(); 145 mHandler = new Handler(Looper.getMainLooper(), null /* callback */, true /* async */); 146 } 147 148 /** 149 * @hide 150 */ 151 @Nullable 152 @Override onBind(@onNull Intent intent)153 public final IBinder onBind(@NonNull Intent intent) { 154 if (SERVICE_INTERFACE.equals(intent.getAction())) { 155 return new IOnDeviceSandboxedInferenceService.Stub() { 156 @Override 157 public void registerRemoteStorageService(IRemoteStorageService storageService, 158 IRemoteCallback remoteCallback) throws RemoteException { 159 Objects.requireNonNull(storageService); 160 mRemoteStorageService = storageService; 161 remoteCallback.sendResult( 162 Bundle.EMPTY); //to notify caller uid to system-server. 163 } 164 165 @Override 166 public void requestTokenInfo(int callerUid, Feature feature, Bundle request, 167 AndroidFuture cancellationSignalFuture, 168 ITokenInfoCallback tokenInfoCallback) { 169 Objects.requireNonNull(feature); 170 Objects.requireNonNull(tokenInfoCallback); 171 ICancellationSignal transport = null; 172 if (cancellationSignalFuture != null) { 173 transport = CancellationSignal.createTransport(); 174 cancellationSignalFuture.complete(transport); 175 } 176 177 mHandler.executeOrSendMessage( 178 obtainMessage( 179 OnDeviceSandboxedInferenceService::onTokenInfoRequest, 180 OnDeviceSandboxedInferenceService.this, 181 callerUid, feature, 182 request, 183 CancellationSignal.fromTransport(transport), 184 wrapTokenInfoCallback(tokenInfoCallback))); 185 } 186 187 @Override 188 public void processRequestStreaming(int callerUid, Feature feature, Bundle request, 189 int requestType, 190 AndroidFuture cancellationSignalFuture, 191 AndroidFuture processingSignalFuture, 192 IStreamingResponseCallback callback) { 193 Objects.requireNonNull(feature); 194 Objects.requireNonNull(callback); 195 196 ICancellationSignal transport = null; 197 if (cancellationSignalFuture != null) { 198 transport = CancellationSignal.createTransport(); 199 cancellationSignalFuture.complete(transport); 200 } 201 IProcessingSignal processingSignalTransport = null; 202 if (processingSignalFuture != null) { 203 processingSignalTransport = ProcessingSignal.createTransport(); 204 processingSignalFuture.complete(processingSignalTransport); 205 } 206 207 208 mHandler.executeOrSendMessage( 209 obtainMessage( 210 OnDeviceSandboxedInferenceService::onProcessRequestStreaming, 211 OnDeviceSandboxedInferenceService.this, callerUid, 212 feature, 213 request, 214 requestType, 215 CancellationSignal.fromTransport(transport), 216 ProcessingSignal.fromTransport(processingSignalTransport), 217 wrapStreamingResponseCallback(callback))); 218 } 219 220 @Override 221 public void processRequest(int callerUid, Feature feature, Bundle request, 222 int requestType, 223 AndroidFuture cancellationSignalFuture, 224 AndroidFuture processingSignalFuture, 225 IResponseCallback callback) { 226 Objects.requireNonNull(feature); 227 Objects.requireNonNull(callback); 228 ICancellationSignal transport = null; 229 if (cancellationSignalFuture != null) { 230 transport = CancellationSignal.createTransport(); 231 cancellationSignalFuture.complete(transport); 232 } 233 IProcessingSignal processingSignalTransport = null; 234 if (processingSignalFuture != null) { 235 processingSignalTransport = ProcessingSignal.createTransport(); 236 processingSignalFuture.complete(processingSignalTransport); 237 } 238 mHandler.executeOrSendMessage( 239 obtainMessage( 240 OnDeviceSandboxedInferenceService::onProcessRequest, 241 OnDeviceSandboxedInferenceService.this, callerUid, feature, 242 request, requestType, 243 CancellationSignal.fromTransport(transport), 244 ProcessingSignal.fromTransport(processingSignalTransport), 245 wrapResponseCallback(callback))); 246 } 247 248 @Override 249 public void updateProcessingState(Bundle processingState, 250 IProcessingUpdateStatusCallback callback) { 251 Objects.requireNonNull(processingState); 252 Objects.requireNonNull(callback); 253 mHandler.executeOrSendMessage( 254 obtainMessage( 255 OnDeviceSandboxedInferenceService::onUpdateProcessingState, 256 OnDeviceSandboxedInferenceService.this, processingState, 257 wrapOutcomeReceiver(callback))); 258 } 259 }; 260 } 261 Slog.w(TAG, "Incorrect service interface, returning null."); 262 return null; 263 } 264 265 /** 266 * Invoked when caller wants to obtain token info related to the payload in the passed 267 * content, associated with the provided feature. 268 * The expectation from the implementation is that when processing is complete, it 269 * should provide the token info in the {@link OutcomeReceiver#onResult}. 270 * 271 * @param callerUid UID of the caller that initiated this call chain. 272 * @param feature feature which is associated with the request. 273 * @param request request that requires processing. 274 * @param cancellationSignal Cancellation Signal to receive cancellation events from client and 275 * configure a listener to. 276 * @param callback callback to populate failure or the token info for the provided 277 * request. 278 */ 279 @NonNull 280 public abstract void onTokenInfoRequest( 281 int callerUid, @NonNull Feature feature, 282 @NonNull @InferenceParams Bundle request, 283 @Nullable CancellationSignal cancellationSignal, 284 @NonNull OutcomeReceiver<TokenInfo, OnDeviceIntelligenceException> callback); 285 286 /** 287 * Invoked when caller provides a request for a particular feature to be processed in a 288 * streaming manner. The expectation from the implementation is that when processing the 289 * request, 290 * it periodically populates the {@link StreamingProcessingCallback#onPartialResult} to 291 * continuously 292 * provide partial Bundle results for the caller to utilize. Optionally the implementation can 293 * provide the complete response in the {@link StreamingProcessingCallback#onResult} upon 294 * processing completion. 295 * 296 * @param callerUid UID of the caller that initiated this call chain. 297 * @param feature feature which is associated with the request. 298 * @param request request that requires processing. 299 * @param requestType identifier representing the type of request. 300 * @param cancellationSignal Cancellation Signal to receive cancellation events from client and 301 * configure a listener to. 302 * @param processingSignal Signal to receive custom action instructions from client. 303 * @param callback callback to populate the partial responses, failure and optionally 304 * full response for the provided request. 305 */ 306 @NonNull 307 public abstract void onProcessRequestStreaming( 308 int callerUid, @NonNull Feature feature, 309 @NonNull @InferenceParams Bundle request, 310 @OnDeviceIntelligenceManager.RequestType int requestType, 311 @Nullable CancellationSignal cancellationSignal, 312 @Nullable ProcessingSignal processingSignal, 313 @NonNull StreamingProcessingCallback callback); 314 315 /** 316 * Invoked when caller provides a request for a particular feature to be processed in one shot 317 * completely. 318 * The expectation from the implementation is that when processing the request is complete, it 319 * should 320 * provide the complete response in the {@link OutcomeReceiver#onResult}. 321 * 322 * @param callerUid UID of the caller that initiated this call chain. 323 * @param feature feature which is associated with the request. 324 * @param request request that requires processing. 325 * @param requestType identifier representing the type of request. 326 * @param cancellationSignal Cancellation Signal to receive cancellation events from client and 327 * configure a listener to. 328 * @param processingSignal Signal to receive custom action instructions from client. 329 * @param callback callback to populate failure and full response for the provided 330 * request. 331 */ 332 @NonNull 333 public abstract void onProcessRequest( 334 int callerUid, @NonNull Feature feature, 335 @NonNull @InferenceParams Bundle request, 336 @OnDeviceIntelligenceManager.RequestType int requestType, 337 @Nullable CancellationSignal cancellationSignal, 338 @Nullable ProcessingSignal processingSignal, 339 @NonNull ProcessingCallback callback); 340 341 342 /** 343 * Invoked when processing environment needs to be updated or refreshed with fresh 344 * configuration, files or state. 345 * 346 * @param processingState contains updated state and params that are to be applied to the 347 * processing environmment, 348 * @param callback callback to populate the update status and if there are params 349 * associated with the status. 350 */ 351 public abstract void onUpdateProcessingState(@NonNull @StateParams Bundle processingState, 352 @NonNull OutcomeReceiver<PersistableBundle, 353 OnDeviceIntelligenceException> callback); 354 355 356 /** 357 * Overrides {@link Context#openFileInput} to read files with the given file names under the 358 * internal app storage of the {@link OnDeviceIntelligenceService}, i.e., only files stored in 359 * {@link Context#getFilesDir()} can be opened. 360 */ 361 @Override 362 public final FileInputStream openFileInput(@NonNull String filename) throws 363 FileNotFoundException { 364 try { 365 AndroidFuture<ParcelFileDescriptor> future = new AndroidFuture<>(); 366 mRemoteStorageService.getReadOnlyFileDescriptor(filename, future); 367 ParcelFileDescriptor pfd = future.get(); 368 return new FileInputStream(pfd.getFileDescriptor()); 369 } catch (RemoteException | ExecutionException | InterruptedException e) { 370 Log.w(TAG, "Cannot open file due to remote service failure"); 371 throw new FileNotFoundException(e.getMessage()); 372 } 373 } 374 375 /** 376 * Provides read-only access to the internal app storage via the 377 * {@link OnDeviceIntelligenceService}. This is an asynchronous alternative for 378 * {@link #openFileInput(String)}. 379 * 380 * @param fileName File name relative to the {@link Context#getFilesDir()}. 381 * @param resultConsumer Consumer to populate the corresponding file descriptor in. 382 */ 383 public final void getReadOnlyFileDescriptor(@NonNull String fileName, 384 @NonNull @CallbackExecutor Executor executor, 385 @NonNull Consumer<ParcelFileDescriptor> resultConsumer) throws FileNotFoundException { 386 AndroidFuture<ParcelFileDescriptor> future = new AndroidFuture<>(); 387 try { 388 mRemoteStorageService.getReadOnlyFileDescriptor(fileName, future); 389 } catch (RemoteException e) { 390 Log.w(TAG, "Cannot open file due to remote service failure"); 391 throw new FileNotFoundException(e.getMessage()); 392 } 393 future.whenCompleteAsync((pfd, err) -> { 394 if (err != null) { 395 Log.e(TAG, "Failure when reading file: " + fileName + err); 396 executor.execute(() -> resultConsumer.accept(null)); 397 } else { 398 executor.execute( 399 () -> resultConsumer.accept(pfd)); 400 } 401 }, executor); 402 } 403 404 /** 405 * Provides access to all file streams required for feature via the 406 * {@link OnDeviceIntelligenceService}. 407 * 408 * @param feature Feature for which the associated files should be fetched. 409 * @param executor Executor to run the consumer callback on. 410 * @param resultConsumer Consumer to receive a map of filePath to the corresponding file input 411 * stream. 412 */ 413 public final void fetchFeatureFileDescriptorMap(@NonNull Feature feature, 414 @NonNull @CallbackExecutor Executor executor, 415 @NonNull Consumer<Map<String, ParcelFileDescriptor>> resultConsumer) { 416 try { 417 mRemoteStorageService.getReadOnlyFeatureFileDescriptorMap(feature, 418 wrapAsRemoteCallback(resultConsumer, executor)); 419 } catch (RemoteException e) { 420 throw new RuntimeException(e); 421 } 422 } 423 424 425 /** 426 * Returns the {@link Executor} to use for incoming IPC from request sender into your service 427 * implementation. For e.g. see 428 * {@link ProcessingCallback#onDataAugmentRequest(Bundle, 429 * Consumer)} where we use the executor to populate the consumer. 430 * <p> 431 * Override this method in your {@link OnDeviceSandboxedInferenceService} implementation to 432 * provide the executor you want to use for incoming IPC. 433 * 434 * @return the {@link Executor} to use for incoming IPC from {@link OnDeviceIntelligenceManager} 435 * to {@link OnDeviceSandboxedInferenceService}. 436 */ 437 @SuppressLint("OnNameExpected") 438 @NonNull 439 public Executor getCallbackExecutor() { 440 return new HandlerExecutor(Handler.createAsync(getMainLooper())); 441 } 442 443 444 private RemoteCallback wrapAsRemoteCallback( 445 @NonNull Consumer<Map<String, ParcelFileDescriptor>> resultConsumer, 446 @NonNull Executor executor) { 447 return new RemoteCallback(result -> { 448 if (result == null) { 449 executor.execute(() -> resultConsumer.accept(new HashMap<>())); 450 } else { 451 Map<String, ParcelFileDescriptor> pfdMap = new HashMap<>(); 452 result.keySet().forEach(key -> 453 pfdMap.put(key, result.getParcelable(key, 454 ParcelFileDescriptor.class))); 455 executor.execute(() -> resultConsumer.accept(pfdMap)); 456 } 457 }); 458 } 459 460 private ProcessingCallback wrapResponseCallback( 461 IResponseCallback callback) { 462 return new ProcessingCallback() { 463 @Override 464 public void onResult(@androidx.annotation.NonNull Bundle result) { 465 try { 466 callback.onSuccess(result); 467 } catch (RemoteException e) { 468 Slog.e(TAG, "Error sending result: " + e); 469 } 470 } 471 472 @Override 473 public void onError( 474 OnDeviceIntelligenceException exception) { 475 try { 476 callback.onFailure(exception.getErrorCode(), exception.getMessage(), 477 exception.getErrorParams()); 478 } catch (RemoteException e) { 479 Slog.e(TAG, "Error sending result: " + e); 480 } 481 } 482 483 @Override 484 public void onDataAugmentRequest(@NonNull Bundle content, 485 @NonNull Consumer<Bundle> contentCallback) { 486 try { 487 callback.onDataAugmentRequest(content, wrapRemoteCallback(contentCallback)); 488 489 } catch (RemoteException e) { 490 Slog.e(TAG, "Error sending augment request: " + e); 491 } 492 } 493 }; 494 } 495 496 private StreamingProcessingCallback wrapStreamingResponseCallback( 497 IStreamingResponseCallback callback) { 498 return new StreamingProcessingCallback() { 499 @Override 500 public void onPartialResult(@androidx.annotation.NonNull Bundle partialResult) { 501 try { 502 callback.onNewContent(partialResult); 503 } catch (RemoteException e) { 504 Slog.e(TAG, "Error sending result: " + e); 505 } 506 } 507 508 @Override 509 public void onResult(@androidx.annotation.NonNull Bundle result) { 510 try { 511 callback.onSuccess(result); 512 } catch (RemoteException e) { 513 Slog.e(TAG, "Error sending result: " + e); 514 } 515 } 516 517 @Override 518 public void onError( 519 OnDeviceIntelligenceException exception) { 520 try { 521 callback.onFailure(exception.getErrorCode(), exception.getMessage(), 522 exception.getErrorParams()); 523 } catch (RemoteException e) { 524 Slog.e(TAG, "Error sending result: " + e); 525 } 526 } 527 528 @Override 529 public void onDataAugmentRequest(@NonNull Bundle content, 530 @NonNull Consumer<Bundle> contentCallback) { 531 try { 532 callback.onDataAugmentRequest(content, wrapRemoteCallback(contentCallback)); 533 534 } catch (RemoteException e) { 535 Slog.e(TAG, "Error sending augment request: " + e); 536 } 537 } 538 }; 539 } 540 541 private RemoteCallback wrapRemoteCallback( 542 @androidx.annotation.NonNull Consumer<Bundle> contentCallback) { 543 return new RemoteCallback( 544 result -> { 545 if (result != null) { 546 getCallbackExecutor().execute(() -> contentCallback.accept( 547 result.getParcelable(AUGMENT_REQUEST_CONTENT_BUNDLE_KEY, 548 Bundle.class))); 549 } else { 550 getCallbackExecutor().execute( 551 () -> contentCallback.accept(null)); 552 } 553 }); 554 } 555 556 private OutcomeReceiver<TokenInfo, OnDeviceIntelligenceException> wrapTokenInfoCallback( 557 ITokenInfoCallback tokenInfoCallback) { 558 return new OutcomeReceiver<>() { 559 @Override 560 public void onResult(TokenInfo tokenInfo) { 561 try { 562 tokenInfoCallback.onSuccess(tokenInfo); 563 } catch (RemoteException e) { 564 Slog.e(TAG, "Error sending result: " + e); 565 } 566 } 567 568 @Override 569 public void onError( 570 OnDeviceIntelligenceException exception) { 571 try { 572 tokenInfoCallback.onFailure(exception.getErrorCode(), exception.getMessage(), 573 exception.getErrorParams()); 574 } catch (RemoteException e) { 575 Slog.e(TAG, "Error sending failure: " + e); 576 } 577 } 578 }; 579 } 580 581 @NonNull 582 private static OutcomeReceiver<PersistableBundle, OnDeviceIntelligenceException> wrapOutcomeReceiver( 583 IProcessingUpdateStatusCallback callback) { 584 return new OutcomeReceiver<>() { 585 @Override 586 public void onResult(@NonNull PersistableBundle result) { 587 try { 588 callback.onSuccess(result); 589 } catch (RemoteException e) { 590 Slog.e(TAG, "Error sending result: " + e); 591 592 } 593 } 594 595 @Override 596 public void onError( 597 @androidx.annotation.NonNull OnDeviceIntelligenceException error) { 598 try { 599 callback.onFailure(error.getErrorCode(), error.getMessage()); 600 } catch (RemoteException e) { 601 Slog.e(TAG, "Error sending exception details: " + e); 602 } 603 } 604 }; 605 } 606 607 } 608