/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "PosePredictor.h"

namespace android::media {

namespace {

// Compile-time switch for prediction verification.
// When enabled, PosePredictor::predict() feeds every predictor and records the
// error of each one at each look-ahead interval listed in kLookAheadMs.
// When disabled, only the currently selected predictor is updated and the
// look-ahead list is empty.
// NOTE(review): the element type of kLookAheadMs was restored as int; confirm
// against the declaration of mLookaheadMs in PosePredictor.h.
#ifdef ENABLE_VERIFICATION
constexpr bool kEnableVerification = true;
constexpr std::array<int, 3> kLookAheadMs{ 50, 100, 200 };
#else
constexpr bool kEnableVerification = false;
constexpr std::array<int, 0> kLookAheadMs{};
#endif

} // namespace

/**
 * Records a new pose sample.
 *
 * The sample time and pose are remembered as the "last" sample, and each
 * component (w, x, y, z) of the pose's rotation quaternion is added to its own
 * accumulator (mRw, mRx, mRy, mRz) as a (time, value) pair.
 *
 * @param atNs  timestamp of the sample in nanoseconds.
 * @param pose  the pose sample; only its rotation is fed to the accumulators.
 * @param twist ignored by this predictor.
 */
void LeastSquaresPredictor::add(int64_t atNs, const Pose3f& pose, const Twist3f& twist)
{
    (void)twist;
    mLastAtNs = atNs;
    mLastPose = pose;
    const auto q = pose.rotation();
    const double datNs = static_cast<double>(atNs);
    mRw.add({datNs, q.w()});
    mRx.add({datNs, q.x()});
    mRy.add({datNs, q.y()});
    mRz.add({datNs, q.z()});
}

/**
 * Predicts the pose at the given time.
 *
 * @param atNs time in nanoseconds for which a prediction is requested.
 * @return the last pose passed to add() if fewer than
 *         kMinimumSamplesForPrediction samples have been accumulated;
 *         otherwise a pose built from the normalized quaternion whose
 *         components are evaluated from the four accumulators at atNs.
 */
Pose3f LeastSquaresPredictor::predict(int64_t atNs) const
{
    if (mRw.getN() < kMinimumSamplesForPrediction) return mLastPose;

    /*
     * Using parametric form, we have q(t) = { w(t), x(t), y(t), z(t) }.
     * We compute the least squares prediction of w, x, y, z.
     */
    const double dLookahead = static_cast<double>(atNs);
    Eigen::Quaternionf lsq(
        mRw.getYFromX(dLookahead),
        mRx.getYFromX(dLookahead),
        mRy.getYFromX(dLookahead),
        mRz.getYFromX(dLookahead));

    /*
     * We cheat here, since the result lsq is the least squares prediction
     * in H (arbitrary quaternion), not the least squares prediction in
     * SO(3) (unit quaternion).
     *
     * In other words, the result for lsq is most likely not a unit quaternion.
     * To solve this, we normalize, thereby selecting the closest unit quaternion
     * in SO(3) to the prediction in H.
     */
    lsq.normalize();
    return Pose3f(lsq);
}

/**
 * Clears the remembered last sample and resets all four component accumulators.
 */
void LeastSquaresPredictor::reset()
{
    mLastAtNs = {};
    mLastPose = {};
    mRw.reset();
    mRx.reset();
    mRy.reset();
    mRz.reset();
}

/**
 * Returns a one-line, newline-terminated description of this predictor.
 *
 * @param index number of leading spaces used to indent the line.
 * @return the indented description containing mAlpha and the last pose.
 */
std::string LeastSquaresPredictor::toString(size_t index) const
{
    std::string s(index, ' ');
    s.append("LeastSquaresPredictor using alpha: ")
        .append(std::to_string(mAlpha))
        .append(" last pose: ")
        .append(mLastPose.toString())
        .append("\n");
    return s;
}

// Formatting

/**
 * Builds the delimiter positions for a flattened [lookahead][predictor] vector.
 *
 * Entries are laid out with `predictors` consecutive values per look-ahead
 * (see the index computation in PosePredictor::predict()), so a delimiter is
 * placed at every multiple of `predictors` that separates two look-ahead groups.
 *
 * @param predictors number of predictors (values per look-ahead group).
 * @param lookaheads number of look-ahead groups.
 * @return `lookaheads - 1` indices { predictors, 2 * predictors, ... };
 *         empty when `lookaheads` is 0 or 1.
 */
static inline std::vector<size_t> createDelimiterIdx(size_t predictors, size_t lookaheads)
{
    if (lookaheads == 0) return {};
    --lookaheads;  // n groups need n - 1 separators.
    std::vector<size_t> delimiterIdx(lookaheads);
    for (size_t i = 0; i < lookaheads; ++i) {
        delimiterIdx[i] = (i + 1) * predictors;
    }
    return delimiterIdx;
}

/**
 * Creates the predictor set, one verifier per (look-ahead, predictor) pair,
 * and the per-second and per-minute error recorders.
 *
 * NOTE(review): the concrete predictor types passed to std::make_shared were
 * restored from the PosePredictorType cases handled in getCurrentPredictor()
 * (LAST, TWIST, LEAST_SQUARES); confirm the class names against PosePredictor.h.
 */
PosePredictor::PosePredictor()
    : mPredictors{
            // First predictors must match switch in getCurrentPredictor()
            std::make_shared<LastPredictor>(),
            std::make_shared<TwistPredictor>(),
            std::make_shared<LeastSquaresPredictor>(),
            // After this, can place additional predictors here for comparison such as
            // std::make_shared<LeastSquaresPredictor>(0.25),
      }
    , mLookaheadMs(kLookAheadMs.begin(), kLookAheadMs.end())
    , mVerifiers(std::size(mLookaheadMs) * std::size(mPredictors))
    , mDelimiterIdx(createDelimiterIdx(std::size(mPredictors), std::size(mLookaheadMs)))
    , mPredictionRecorder(
        std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
        mDelimiterIdx)
    , mPredictionDurableRecorder(
        std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
        mDelimiterIdx)
{
}

/**
 * Adds a pose sample and returns the predicted pose after the given duration.
 *
 * If the gap since the previous sample exceeds
 * kMaximumSampleIntervalBeforeResetNs, all predictors are reset first and the
 * reset counter is incremented.
 *
 * @param timestampNs          timestamp of the sample in nanoseconds.
 * @param pose                 the pose sample.
 * @param twist                the twist sample, forwarded to the predictors.
 * @param predictionDurationNs how far past timestampNs to predict, in
 *                             nanoseconds; truncated to an integer.
 * @return the selected predictor's prediction at
 *         timestampNs + predictionDurationNs.
 */
Pose3f PosePredictor::predict(
        int64_t timestampNs, const Pose3f& pose, const Twist3f& twist, float predictionDurationNs)
{
    if (timestampNs - mLastTimestampNs > kMaximumSampleIntervalBeforeResetNs) {
        for (const auto& predictor : mPredictors) {
            predictor->reset();
        }
        ++mResets;
    }
    mLastTimestampNs = timestampNs;

    auto selectedPredictor = getCurrentPredictor();
    if constexpr (kEnableVerification) {
        // Update all Predictors
        for (const auto& predictor : mPredictors) {
            predictor->add(timestampNs, pose, twist);
        }

        // Update Verifiers and calculate errors
        // NOTE(review): element type restored as float; confirm against the
        // recorder's record() signature.
        std::vector<float> error(std::size(mVerifiers));
        for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
            constexpr float RADIAN_TO_DEGREES = 180 / M_PI;
            const int64_t atNs =
                    timestampNs + mLookaheadMs[i] * PosePredictorVerifier::kMillisToNanos;

            for (size_t j = 0; j < mPredictors.size(); ++j) {
                // Flattened [lookahead][predictor] index.
                const size_t idx = i * std::size(mPredictors) + j;
                mVerifiers[idx].verifyActualPose(timestampNs, pose);
                mVerifiers[idx].addPredictedPose(atNs, mPredictors[j]->predict(atNs));
                error[idx] = RADIAN_TO_DEGREES * mVerifiers[idx].lastError();
            }
        }
        // Record errors
        mPredictionRecorder.record(error);
        mPredictionDurableRecorder.record(error);
    } else /* constexpr */ {
        selectedPredictor->add(timestampNs, pose, twist);
    }

    // Deliver prediction
    const int64_t predictionTimeNs =
            timestampNs + static_cast<int64_t>(predictionDurationNs);
    return selectedPredictor->predict(predictionTimeNs);
}

/**
 * Selects which predictor is used to deliver predictions.
 *
 * Invalid types and requests equal to the currently set type are ignored.
 * AUTO is mapped to LEAST_SQUARES. When the effective type changes and
 * verification is disabled, the newly selected predictor is reset.
 *
 * @param type the requested predictor type.
 */
void PosePredictor::setPosePredictorType(PosePredictorType type)
{
    if (!isValidPosePredictorType(type)) return;
    if (type == mSetType) return;
    mSetType = type;
    if (type == android::media::PosePredictorType::AUTO) {
        type = android::media::PosePredictorType::LEAST_SQUARES;
    }
    if (type != mCurrentType) {
        mCurrentType = type;
        if constexpr (!kEnableVerification) {
            // Verification keeps all predictors up-to-date.
            // If we don't enable verification, we must reset the current predictor.
            getCurrentPredictor()->reset();
        }
    }
}

/**
 * Returns a multi-line dump of the predictor state.
 *
 * The dump contains the current prediction type, the reset count and the
 * current predictor's own description. When verification is enabled it also
 * contains the cumulative average error per verifier and the per-minute and
 * per-second error histories.
 *
 * @param index number of leading spaces used to indent each line.
 * @return the formatted dump.
 */
std::string PosePredictor::toString(size_t index) const
{
    std::string prefixSpace(index, ' ');
    std::string ss(prefixSpace);
    ss.append("PosePredictor:\n")
        .append(prefixSpace)
        .append(" Current Prediction Type: ")
        .append(android::media::toString(mCurrentType))
        .append("\n")
        .append(prefixSpace)
        .append(" Resets: ")
        .append(std::to_string(mResets))
        .append("\n")
        .append(getCurrentPredictor()->toString(index + 1));
    if constexpr (kEnableVerification) {
        // dump verification
        ss.append(prefixSpace)
            .append(" Prediction abs error (L1) degrees [ type (");
        for (size_t i = 0; i < mPredictors.size(); ++i) {
            if (i > 0) ss.append(" , ");
            ss.append(mPredictors[i]->name());
        }
        ss.append(" ) x ( ");
        for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
            if (i > 0) ss.append(" : ");
            ss.append(std::to_string(mLookaheadMs[i]));
        }
        // NOTE(review): element type restored as float; confirm against
        // VectorRecorder::toString().
        std::vector<float> cumulativeAverageErrors(std::size(mVerifiers));
        for (size_t i = 0; i < cumulativeAverageErrors.size(); ++i) {
            cumulativeAverageErrors[i] = mVerifiers[i].cumulativeAverageError();
        }
        ss.append(" ) ms ]\n")
            .append(prefixSpace)
            .append(" Cumulative Average Error:\n")
            .append(prefixSpace)
            .append(" ")
            .append(VectorRecorder::toString(cumulativeAverageErrors, mDelimiterIdx, "%.3g"))
            .append("\n")
            .append(prefixSpace)
            .append(" PerMinuteHistory:\n")
            .append(mPredictionDurableRecorder.toString(index + 3))
            .append(prefixSpace)
            .append(" PerSecondHistory:\n")
            .append(mPredictionRecorder.toString(index + 3));
    }
    return ss;
}

/**
 * Returns the predictor that corresponds to mCurrentType.
 *
 * The mapping relies on the order in which the constructor fills mPredictors:
 * index 0 for LAST (and any unrecognized value), 1 for TWIST, 2 for
 * LEAST_SQUARES and AUTO.
 *
 * NOTE(review): the pointee type was restored as PredictorBase; confirm
 * against the declaration in PosePredictor.h.
 */
std::shared_ptr<PredictorBase> PosePredictor::getCurrentPredictor() const
{
    // we don't use a map here, we look up directly
    switch (mCurrentType) {
    default:
    case android::media::PosePredictorType::LAST:
        return mPredictors[0];
    case android::media::PosePredictorType::TWIST:
        return mPredictors[1];
    case android::media::PosePredictorType::AUTO: // shouldn't occur here.
    case android::media::PosePredictorType::LEAST_SQUARES:
        return mPredictors[2];
    }
}

} // namespace android::media