1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.media.videoquality.bdrate; 18 19 import com.google.auto.value.AutoValue; 20 import com.google.common.annotations.VisibleForTesting; 21 22 import org.apache.commons.math3.stat.descriptive.moment.Mean; 23 24 import java.util.ArrayList; 25 import java.util.Iterator; 26 import java.util.LinkedList; 27 import java.util.logging.Logger; 28 29 /** Pair of two {@link RateDistortionCurve}s used for calculating a Bjontegaard-Delta (BD) value. */ 30 @AutoValue 31 public abstract class RateDistortionCurvePair { 32 private static final Logger LOGGER = Logger.getLogger(RateDistortionCurvePair.class.getName()); 33 34 private static final Mean MEAN = new Mean(); 35 baseline()36 public abstract RateDistortionCurve baseline(); 37 target()38 public abstract RateDistortionCurve target(); 39 40 /** 41 * Creates a new {@link RateDistortionCurvePair} by first clustering the points to eliminate 42 * noise in the data, then validating that the remaining points are sufficient for BD 43 * calculations. 44 */ createClusteredPair( RateDistortionCurve baseline, RateDistortionCurve target)45 public static RateDistortionCurvePair createClusteredPair( 46 RateDistortionCurve baseline, RateDistortionCurve target) { 47 RateDistortionCurve clusteredBaseline = cluster(baseline); 48 RateDistortionCurve clusteredTarget = cluster(target); 49 50 // Check for correct number of points. 51 if (clusteredBaseline.points().size() < 5) { 52 throw new BdPreconditionFailedException( 53 "The reference curve does not have enough points.", /* isTargetCurve= */ false); 54 } 55 if (clusteredTarget.points().size() < 5) { 56 throw new BdPreconditionFailedException( 57 "The target curve does not have enough points.", /* isTargetCurve= */ true); 58 } 59 60 // Check for monotonicity. 61 if (!isMonotonicallyIncreasing(clusteredBaseline)) { 62 throw new BdPreconditionFailedException( 63 "The reference curve is not monotonically increasing.", 64 /* isTargetCurve= */ false); 65 } 66 if (!isMonotonicallyIncreasing(clusteredTarget)) { 67 throw new BdPreconditionFailedException( 68 "The is not monotonically increasing.", /* isTargetCurve= */ true); 69 } 70 71 return new AutoValue_RateDistortionCurvePair(clusteredBaseline, clusteredTarget); 72 } 73 74 /** To calculate BD-RATE, the two rate-distortion curves must overlap in terms of distortion. */ canCalculateBdRate()75 public boolean canCalculateBdRate() { 76 return !(baseline().getMaxDistortion() < target().getMinDistortion()) 77 && !(target().getMaxDistortion() < baseline().getMinDistortion()); 78 } 79 80 /** To calculate BD-QUALITY, the two rate-distortion curves must overlap in terms of bitrate. */ canCalculateBdQuality()81 public boolean canCalculateBdQuality() { 82 return !(baseline().getMaxLog10Bitrate() < target().getMinLog10Bitrate()) 83 && !(target().getMaxLog10Bitrate() < baseline().getMinLog10Bitrate()); 84 } 85 86 /** 87 * Clusters provided rate-distortion points together to reduce noise when the points are close 88 * together in terms of bitrate. 89 * 90 * <p>"Clusters" are points that have a bitrate that is within 1% of the previous 91 * rate-distortion point. Such points are bucketed and then averaged to provide a single point 92 * in the same range as the cluster. 93 */ 94 @VisibleForTesting cluster(RateDistortionCurve baseCurve)95 static RateDistortionCurve cluster(RateDistortionCurve baseCurve) { 96 if (baseCurve.points().size() < 3) { 97 return baseCurve; 98 } 99 100 RateDistortionCurve.Builder newCurve = RateDistortionCurve.builder(); 101 102 LinkedList<ArrayList<RateDistortionPoint>> buckets = new LinkedList<>(); 103 104 // Bucket the items, moving through the points pairwise. 105 buckets.add(new ArrayList<>()); 106 buckets.peekLast().add(baseCurve.points().first()); 107 108 Iterator<RateDistortionPoint> pointIterator = baseCurve.points().iterator(); 109 RateDistortionPoint lastPoint = pointIterator.next(); 110 RateDistortionPoint currentPoint; 111 112 double maxObservedDistortion = lastPoint.distortion(); 113 while (pointIterator.hasNext()) { 114 currentPoint = pointIterator.next(); 115 116 // Cluster points that are within 10% (bitrate) of each other that would make the curve 117 // non-monotonic. 118 if (currentPoint.rate() / lastPoint.rate() > 1.1 119 || currentPoint.distortion() > maxObservedDistortion) { 120 buckets.add(new ArrayList<>()); 121 maxObservedDistortion = currentPoint.distortion(); 122 } 123 buckets.peekLast().add(currentPoint); 124 lastPoint = currentPoint; 125 } 126 127 for (ArrayList<RateDistortionPoint> bucket : buckets) { 128 if (bucket.size() < 2) { 129 newCurve.addPoint(bucket.get(0)); 130 } 131 132 // For a bucket with multiple points, the new point is the average 133 // between all other points. 134 newCurve.addPoint( 135 RateDistortionPoint.create( 136 MEAN.evaluate(bucket.stream().mapToDouble(p -> p.rate()).toArray()), 137 MEAN.evaluate( 138 bucket.stream().mapToDouble(p -> p.distortion()).toArray()))); 139 } 140 141 return newCurve.build(); 142 } 143 144 /** 145 * Returns whether a {@link RateDistortionCurve} is monotonically increasing which is required 146 * for the Cubic Spline interpolation performed during BD rate calculation. 147 */ isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve)148 private static boolean isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve) { 149 Iterator<RateDistortionPoint> pointIterator = rateDistortionCurve.points().iterator(); 150 151 RateDistortionPoint lastPoint = pointIterator.next(); 152 RateDistortionPoint currentPoint; 153 while (pointIterator.hasNext()) { 154 currentPoint = pointIterator.next(); 155 if (currentPoint.distortion() <= lastPoint.distortion()) { 156 return false; 157 } 158 lastPoint = currentPoint; 159 } 160 161 return true; 162 } 163 } 164