/* * Copyright (C) 2023 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace android { struct TfLiteMotionPredictorSample { // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample. struct Point { float x; float y; } position; // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION. float pressure; float tilt; float orientation; }; inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs, const TfLiteMotionPredictorSample::Point& rhs) { return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y}; } class TfLiteMotionPredictorModel; // Buffer storage for a TfLiteMotionPredictorModel. class TfLiteMotionPredictorBuffers { public: // Creates buffer storage for a model with the given input length. TfLiteMotionPredictorBuffers(size_t inputLength); // Adds a motion sample to the buffers. void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample); // Returns true if the buffers are complete enough to generate a prediction. bool isReady() const { // Predictions can't be applied unless there are at least two points to determine // the direction to apply them in. return mAxisFrom && mAxisTo; } // Resets all buffers to their initial state. void reset(); // Copies the buffers to those of a model for prediction. void copyTo(TfLiteMotionPredictorModel& model) const; // Returns the current axis of the buffer's samples. Only valid if isReady(). TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; } TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; } // Returns the timestamp of the last sample. int64_t lastTimestamp() const { return mTimestamp; } private: int64_t mTimestamp = 0; RingBuffer mInputR; RingBuffer mInputPhi; RingBuffer mInputPressure; RingBuffer mInputTilt; RingBuffer mInputOrientation; // The samples defining the current polar axis. std::optional mAxisFrom; std::optional mAxisTo; }; // A TFLite model for generating motion predictions. class TfLiteMotionPredictorModel { public: struct Config { // The time between predictions. nsecs_t predictionInterval = 0; // The noise floor for predictions. // Distances (r) less than this should be discarded as noise. float distanceNoiseFloor = 0; // Low and high jerk thresholds (with normalized dt = 1) for predictions. // High jerk means more predictions will be pruned, vice versa for low. float lowJerk = 0; float highJerk = 0; }; // Creates a model from an encoded Flatbuffer model. static std::unique_ptr create(); ~TfLiteMotionPredictorModel(); // Returns the length of the model's input buffers. size_t inputLength() const; // Returns the length of the model's output buffers. size_t outputLength() const; const Config& config() const { return mConfig; } // Executes the model. // Returns true if the model successfully executed and the output tensors can be read. bool invoke(); // Returns mutable buffers to the input tensors of inputLength() elements. std::span inputR(); std::span inputPhi(); std::span inputPressure(); std::span inputOrientation(); std::span inputTilt(); // Returns immutable buffers to the output tensors of identical length. Only valid after a // successful call to invoke(). std::span outputR() const; std::span outputPhi() const; std::span outputPressure() const; private: explicit TfLiteMotionPredictorModel(std::unique_ptr model, Config config); void allocateTensors(); void attachInputTensors(); void attachOutputTensors(); TfLiteTensor* mInputR = nullptr; TfLiteTensor* mInputPhi = nullptr; TfLiteTensor* mInputPressure = nullptr; TfLiteTensor* mInputTilt = nullptr; TfLiteTensor* mInputOrientation = nullptr; const TfLiteTensor* mOutputR = nullptr; const TfLiteTensor* mOutputPhi = nullptr; const TfLiteTensor* mOutputPressure = nullptr; std::unique_ptr mFlatBuffer; std::unique_ptr mErrorReporter; std::unique_ptr mModel; std::unique_ptr mInterpreter; tflite::SignatureRunner* mRunner = nullptr; const Config mConfig = {}; }; } // namespace android