1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include "PosePredictorVerifier.h"

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include <audio_utils/Statistics.h>
#include <media/PosePredictorType.h>
#include <media/Twist.h>
#include <media/VectorRecorder.h>
25 
26 namespace android::media {
27 
// Interface for generic pose predictors.
//
// A predictor is fed timestamped pose/twist samples through add() and can then
// be queried for an estimated pose at a given time through predict().
class PredictorBase {
public:
    virtual ~PredictorBase() = default;

    // Supplies one sample: the pose and twist associated with time atNs.
    // Implementations may ignore some of the arguments (LastPredictor, for
    // example, uses only the pose).
    virtual void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) = 0;

    // Returns the estimated pose at time atNs based on the samples added so far.
    virtual Pose3f predict(int64_t atNs) const = 0;

    // Discards accumulated sample state; the implementations visible in this
    // header restore every member to its default value.
    virtual void reset() = 0;

    // Returns a short identifier for the predictor (e.g. "LAST", "TWIST").
    virtual std::string name() const = 0;

    // Returns a printable, newline-terminated description of the predictor's
    // state. The implementations in this header prefix it with `index` spaces.
    virtual std::string toString(size_t index) const = 0;
};
38 
39 /**
40  * LastPredictor uses the last sample Pose for prediction
41  *
42  * This class is not thread-safe.
43  */
44 class LastPredictor : public PredictorBase {
45 public:
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)46     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
47         (void)atNs;
48         (void)twist;
49         mLastPose = pose;
50     }
51 
predict(int64_t atNs)52     Pose3f predict(int64_t atNs) const override {
53         (void)atNs;
54         return mLastPose;
55     }
56 
reset()57     void reset() override {
58         mLastPose = {};
59     }
60 
name()61     std::string name() const override {
62         return "LAST";
63     }
64 
toString(size_t index)65     std::string toString(size_t index) const override {
66         std::string s(index, ' ');
67         s.append("LastPredictor using last pose: ")
68             .append(mLastPose.toString())
69             .append("\n");
70         return s;
71     }
72 
73 private:
74     Pose3f mLastPose;
75 };
76 
77 /**
78  * TwistPredictor uses the last sample Twist and Pose for prediction
79  *
80  * This class is not thread-safe.
81  */
82 class TwistPredictor : public PredictorBase {
83 public:
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)84     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
85         mLastAtNs = atNs;
86         mLastPose = pose;
87         mLastTwist = twist;
88     }
89 
predict(int64_t atNs)90     Pose3f predict(int64_t atNs) const override {
91         return mLastPose * integrate(mLastTwist, atNs - mLastAtNs);
92     }
93 
reset()94     void reset() override {
95         mLastAtNs = {};
96         mLastPose = {};
97         mLastTwist = {};
98     }
99 
name()100     std::string name() const override {
101         return "TWIST";
102     }
103 
toString(size_t index)104     std::string toString(size_t index) const override {
105         std::string s(index, ' ');
106         s.append("TwistPredictor using last pose: ")
107             .append(mLastPose.toString())
108             .append(" last twist: ")
109             .append(mLastTwist.toString())
110             .append("\n");
111         return s;
112     }
113 
114 private:
115     int64_t mLastAtNs{};
116     Pose3f mLastPose;
117     Twist3f mLastTwist;
118 };
119 
120 
121 /**
122  * LeastSquaresPredictor uses the Pose history for prediction.
123  *
124  * A exponential weighted least squares is used.
125  *
126  * This class is not thread-safe.
127  */
128 class LeastSquaresPredictor : public PredictorBase {
129 public:
130     // alpha is the exponential decay.
131     LeastSquaresPredictor(double alpha = kDefaultAlphaEstimator)
mAlpha(alpha)132         : mAlpha(alpha)
133         , mRw(alpha)
134         , mRx(alpha)
135         , mRy(alpha)
136         , mRz(alpha)
137         {}
138 
139     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override;
140     Pose3f predict(int64_t atNs) const override;
141     void reset() override;
name()142     std::string name() const override {
143         return "LEAST_SQUARES(" + std::to_string(mAlpha) + ")";
144     }
145     std::string toString(size_t index) const override;
146 
147 private:
148     const double mAlpha;
149     int64_t mLastAtNs{};
150     Pose3f mLastPose;
151     static constexpr double kDefaultAlphaEstimator = 0.2;
152     static constexpr size_t kMinimumSamplesForPrediction = 4;
153     audio_utils::LinearLeastSquaresFit<double> mRw;
154     audio_utils::LinearLeastSquaresFit<double> mRx;
155     audio_utils::LinearLeastSquaresFit<double> mRy;
156     audio_utils::LinearLeastSquaresFit<double> mRz;
157 };
158 
/*
 * PosePredictor predicts the pose given sensor input at a time in the future.
 *
 * It holds a fixed set of PredictorBase implementations (mPredictors).
 * getCurrentPredictor() presumably selects the one matching mCurrentType.
 * It also holds verifiers and recorders, presumably used to compare the
 * predictors at several lookahead times; confirm in the .cpp.
 *
 * This class is not thread safe.
 */
class PosePredictor {
public:
    PosePredictor();

    // Presumably records the sample (pose and twist at timestampNs) and returns
    // the pose predicted predictionDurationNs later. Implemented in the .cpp.
    Pose3f predict(int64_t timestampNs, const Pose3f& pose, const Twist3f& twist,
            float predictionDurationNs);

    // Sets the externally requested predictor type (stored as mSetType).
    // AUTO is a permitted value.
    void setPosePredictorType(PosePredictorType type);

    // convert predictions to a printable string
    // (index is presumably the indentation depth, as in PredictorBase::toString)
    std::string toString(size_t index) const;

private:
    // 300'000'000 ns = 300 ms. Presumably a gap between samples longer than
    // this triggers a reset (see mResets); confirm in the .cpp.
    static constexpr int64_t kMaximumSampleIntervalBeforeResetNs =
            300'000'000;

    // Predictors
    const std::vector<std::shared_ptr<PredictorBase>> mPredictors;

    // Verifiers, create one for an array of future lookaheads for comparison.
    const std::vector<int> mLookaheadMs;

    std::vector<PosePredictorVerifier> mVerifiers;

    const std::vector<size_t> mDelimiterIdx;

    // Recorders
    // NOTE: these default member initializers read mVerifiers and mDelimiterIdx.
    // Members are initialized in declaration order, so those two must stay
    // declared above the recorders. This applies whenever the constructor's
    // mem-initializer list does not override the recorders.
    media::VectorRecorder mPredictionRecorder{
        std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
        mDelimiterIdx};
    media::VectorRecorder mPredictionDurableRecorder{
        std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
        mDelimiterIdx};

    // Status

    // SetType is the externally set predictor type.  It may include AUTO.
    PosePredictorType mSetType = PosePredictorType::LEAST_SQUARES;

    // CurrentType is the actual predictor type used by this class.
    // It does not include AUTO because that metatype means the class
    // chooses the best predictor type based on sensor statistics.
    PosePredictorType mCurrentType = PosePredictorType::LEAST_SQUARES;

    // Reset counter; presumably incremented when the predictors are reset.
    int64_t mResets{};
    // Presumably the timestamp of the most recent sample given to predict().
    int64_t mLastTimestampNs{};

    // Returns current predictor
    // (presumably the entry of mPredictors corresponding to mCurrentType).
    std::shared_ptr<PredictorBase> getCurrentPredictor() const;
};
214 
215 }  // namespace android::media
216