1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.media.videoquality.bdrate;
18 
19 import com.google.auto.value.AutoValue;
20 import com.google.common.annotations.VisibleForTesting;
21 
22 import org.apache.commons.math3.stat.descriptive.moment.Mean;
23 
24 import java.util.ArrayList;
25 import java.util.Iterator;
26 import java.util.LinkedList;
27 import java.util.logging.Logger;
28 
29 /** Pair of two {@link RateDistortionCurve}s used for calculating a Bjontegaard-Delta (BD) value. */
30 @AutoValue
31 public abstract class RateDistortionCurvePair {
32     private static final Logger LOGGER = Logger.getLogger(RateDistortionCurvePair.class.getName());
33 
34     private static final Mean MEAN = new Mean();
35 
baseline()36     public abstract RateDistortionCurve baseline();
37 
target()38     public abstract RateDistortionCurve target();
39 
40     /**
41      * Creates a new {@link RateDistortionCurvePair} by first clustering the points to eliminate
42      * noise in the data, then validating that the remaining points are sufficient for BD
43      * calculations.
44      */
createClusteredPair( RateDistortionCurve baseline, RateDistortionCurve target)45     public static RateDistortionCurvePair createClusteredPair(
46             RateDistortionCurve baseline, RateDistortionCurve target) {
47         RateDistortionCurve clusteredBaseline = cluster(baseline);
48         RateDistortionCurve clusteredTarget = cluster(target);
49 
50         // Check for correct number of points.
51         if (clusteredBaseline.points().size() < 5) {
52             throw new BdPreconditionFailedException(
53                     "The reference curve does not have enough points.", /* isTargetCurve= */ false);
54         }
55         if (clusteredTarget.points().size() < 5) {
56             throw new BdPreconditionFailedException(
57                     "The target curve does not have enough points.", /* isTargetCurve= */ true);
58         }
59 
60         // Check for monotonicity.
61         if (!isMonotonicallyIncreasing(clusteredBaseline)) {
62             throw new BdPreconditionFailedException(
63                     "The reference curve is not monotonically increasing.",
64                     /* isTargetCurve= */ false);
65         }
66         if (!isMonotonicallyIncreasing(clusteredTarget)) {
67             throw new BdPreconditionFailedException(
68                     "The is not monotonically increasing.", /* isTargetCurve= */ true);
69         }
70 
71         return new AutoValue_RateDistortionCurvePair(clusteredBaseline, clusteredTarget);
72     }
73 
74     /** To calculate BD-RATE, the two rate-distortion curves must overlap in terms of distortion. */
canCalculateBdRate()75     public boolean canCalculateBdRate() {
76         return !(baseline().getMaxDistortion() < target().getMinDistortion())
77                 && !(target().getMaxDistortion() < baseline().getMinDistortion());
78     }
79 
80     /** To calculate BD-QUALITY, the two rate-distortion curves must overlap in terms of bitrate. */
canCalculateBdQuality()81     public boolean canCalculateBdQuality() {
82         return !(baseline().getMaxLog10Bitrate() < target().getMinLog10Bitrate())
83                 && !(target().getMaxLog10Bitrate() < baseline().getMinLog10Bitrate());
84     }
85 
86     /**
87      * Clusters provided rate-distortion points together to reduce noise when the points are close
88      * together in terms of bitrate.
89      *
90      * <p>"Clusters" are points that have a bitrate that is within 1% of the previous
91      * rate-distortion point. Such points are bucketed and then averaged to provide a single point
92      * in the same range as the cluster.
93      */
94     @VisibleForTesting
cluster(RateDistortionCurve baseCurve)95     static RateDistortionCurve cluster(RateDistortionCurve baseCurve) {
96         if (baseCurve.points().size() < 3) {
97             return baseCurve;
98         }
99 
100         RateDistortionCurve.Builder newCurve = RateDistortionCurve.builder();
101 
102         LinkedList<ArrayList<RateDistortionPoint>> buckets = new LinkedList<>();
103 
104         // Bucket the items, moving through the points pairwise.
105         buckets.add(new ArrayList<>());
106         buckets.peekLast().add(baseCurve.points().first());
107 
108         Iterator<RateDistortionPoint> pointIterator = baseCurve.points().iterator();
109         RateDistortionPoint lastPoint = pointIterator.next();
110         RateDistortionPoint currentPoint;
111 
112         double maxObservedDistortion = lastPoint.distortion();
113         while (pointIterator.hasNext()) {
114             currentPoint = pointIterator.next();
115 
116             // Cluster points that are within 10% (bitrate) of each other that would make the curve
117             // non-monotonic.
118             if (currentPoint.rate() / lastPoint.rate() > 1.1
119                     || currentPoint.distortion() > maxObservedDistortion) {
120                 buckets.add(new ArrayList<>());
121                 maxObservedDistortion = currentPoint.distortion();
122             }
123             buckets.peekLast().add(currentPoint);
124             lastPoint = currentPoint;
125         }
126 
127         for (ArrayList<RateDistortionPoint> bucket : buckets) {
128             if (bucket.size() < 2) {
129                 newCurve.addPoint(bucket.get(0));
130             }
131 
132             // For a bucket with multiple points, the new point is the average
133             // between all other points.
134             newCurve.addPoint(
135                     RateDistortionPoint.create(
136                             MEAN.evaluate(bucket.stream().mapToDouble(p -> p.rate()).toArray()),
137                             MEAN.evaluate(
138                                     bucket.stream().mapToDouble(p -> p.distortion()).toArray())));
139         }
140 
141         return newCurve.build();
142     }
143 
144     /**
145      * Returns whether a {@link RateDistortionCurve} is monotonically increasing which is required
146      * for the Cubic Spline interpolation performed during BD rate calculation.
147      */
isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve)148     private static boolean isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve) {
149         Iterator<RateDistortionPoint> pointIterator = rateDistortionCurve.points().iterator();
150 
151         RateDistortionPoint lastPoint = pointIterator.next();
152         RateDistortionPoint currentPoint;
153         while (pointIterator.hasNext()) {
154             currentPoint = pointIterator.next();
155             if (currentPoint.distortion() <= lastPoint.distortion()) {
156                 return false;
157             }
158             lastPoint = currentPoint;
159         }
160 
161         return true;
162     }
163 }
164