1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.mediav2.common.cts;
18 
19 import static android.media.MediaCodecInfo.CodecCapabilities.COLOR_FormatYUVP010;
20 import static android.mediav2.common.cts.DecodeStreamToYuv.findDecoderForFormat;
21 import static android.mediav2.common.cts.DecodeStreamToYuv.findDecoderForStream;
22 import static android.mediav2.common.cts.DecodeStreamToYuv.getFormatInStream;
23 import static android.mediav2.common.cts.DecodeStreamToYuv.getImage;
24 import static android.mediav2.common.cts.VideoErrorManager.computeMSE;
25 import static android.mediav2.common.cts.VideoErrorManager.computePSNR;
26 
27 import static org.junit.Assert.assertEquals;
28 import static org.junit.Assert.assertNotNull;
29 import static org.junit.Assert.assertTrue;
30 
31 import android.graphics.ImageFormat;
32 import android.graphics.Rect;
33 import android.media.Image;
34 import android.media.MediaCodec;
35 import android.media.MediaExtractor;
36 import android.media.MediaFormat;
37 import android.util.Log;
38 
39 import com.android.compatibility.common.util.MediaUtils;
40 
41 import org.junit.Assume;
42 
43 import java.io.File;
44 import java.io.FileInputStream;
45 import java.io.IOException;
46 import java.nio.ByteBuffer;
47 import java.util.ArrayList;
48 import java.util.List;
49 import java.util.Map;
50 
/**
 * Mutable container for the planes of a planar YUV image; each entry of {@code mData} holds the
 * raw bytes of one plane.
 */
class YUVImage {
    public ArrayList<byte[]> mData;

    YUVImage() {
        mData = new ArrayList<>();
    }
}
57 
58 /**
59  * Utility class for video encoder tests to validate the encoded output.
60  * <p>
61  * The class computes the PSNR between encoders output and input. As the input to an encoder can
62  * be raw yuv buffer or the output of a decoder that is connected to the encoder, the test
63  * accepts YUV as well as compressed streams for validation.
64  * <p>
65  * Before validation, the class checks if the input and output have same width, height and bitdepth.
66  */
67 public class CompareStreams extends CodecDecoderTestBase {
68     private static final String LOG_TAG = CompareStreams.class.getSimpleName();
69 
70     private final RawResource mRefYuv;
71     private final MediaFormat mStreamFormat;
72     private final ByteBuffer mStreamBuffer;
73     private final ArrayList<MediaCodec.BufferInfo> mStreamBufferInfos;
74     private final boolean mAllowRefResize;
75     private final boolean mAllowRefLoopBack;
76     private final Map<Long, List<Rect>> mFrameCropRects;
77     private final double[] mGlobalMSE = {0.0, 0.0, 0.0};
78     private final double[] mMinimumMSE = {Float.MAX_VALUE, Float.MAX_VALUE, Float.MAX_VALUE};
79     private final double[] mGlobalPSNR = new double[3];
80     private final double[] mMinimumPSNR = new double[3];
81     private final double[] mAvgPSNR = {0.0, 0.0, 0.0};
82     private final ArrayList<double[]> mFramesPSNR = new ArrayList<>();
83     private final List<List<double[]>> mFramesCropRectPSNR = new ArrayList<>();
84 
85     private final ArrayList<String> mTmpFiles = new ArrayList<>();
86     private boolean mGenerateStats;
87     private int mFileOffset;
88     private int mFileSize;
89     private int mFrameSize;
90     private byte[] mInputData;
91 
CompareStreams(RawResource refYuv, MediaFormat testFormat, ByteBuffer testBuffer, ArrayList<MediaCodec.BufferInfo> testBufferInfos, boolean allowRefResize, boolean allowRefLoopBack)92     private CompareStreams(RawResource refYuv, MediaFormat testFormat, ByteBuffer testBuffer,
93             ArrayList<MediaCodec.BufferInfo> testBufferInfos, boolean allowRefResize,
94             boolean allowRefLoopBack) {
95         super(findDecoderForFormat(testFormat), testFormat.getString(MediaFormat.KEY_MIME), null,
96                 LOG_TAG);
97         mRefYuv = refYuv;
98         mStreamFormat = testFormat;
99         mStreamBuffer = testBuffer;
100         mStreamBufferInfos = testBufferInfos;
101         mAllowRefResize = allowRefResize;
102         mAllowRefLoopBack = allowRefLoopBack;
103         mFrameCropRects = null;
104     }
105 
CompareStreams(RawResource refYuv, String testMediaType, String testFile, boolean allowRefResize, boolean allowRefLoopBack)106     public CompareStreams(RawResource refYuv, String testMediaType, String testFile,
107             boolean allowRefResize, boolean allowRefLoopBack) throws IOException {
108         super(findDecoderForStream(testMediaType, testFile), testMediaType, testFile, LOG_TAG);
109         mRefYuv = refYuv;
110         mStreamFormat = null;
111         mStreamBuffer = null;
112         mStreamBufferInfos = null;
113         mAllowRefResize = allowRefResize;
114         mAllowRefLoopBack = allowRefLoopBack;
115         mFrameCropRects = null;
116     }
117 
CompareStreams(MediaFormat refFormat, ByteBuffer refBuffer, ArrayList<MediaCodec.BufferInfo> refBufferInfos, MediaFormat testFormat, ByteBuffer testBuffer, ArrayList<MediaCodec.BufferInfo> testBufferInfos, boolean allowRefResize, boolean allowRefLoopBack)118     public CompareStreams(MediaFormat refFormat, ByteBuffer refBuffer,
119             ArrayList<MediaCodec.BufferInfo> refBufferInfos, MediaFormat testFormat,
120             ByteBuffer testBuffer, ArrayList<MediaCodec.BufferInfo> testBufferInfos,
121             boolean allowRefResize, boolean allowRefLoopBack) {
122         this(new DecodeStreamToYuv(refFormat, refBuffer, refBufferInfos).getDecodedYuv(),
123                 testFormat, testBuffer, testBufferInfos, allowRefResize, allowRefLoopBack);
124         mTmpFiles.add(mRefYuv.mFileName);
125     }
126 
CompareStreams(String refMediaType, String refFile, String testMediaType, String testFile, boolean allowRefResize, boolean allowRefLoopBack)127     public CompareStreams(String refMediaType, String refFile, String testMediaType,
128             String testFile, boolean allowRefResize, boolean allowRefLoopBack) throws IOException {
129         this(new DecodeStreamToYuv(refMediaType, refFile).getDecodedYuv(), testMediaType, testFile,
130                 allowRefResize, allowRefLoopBack);
131         mTmpFiles.add(mRefYuv.mFileName);
132     }
133 
CompareStreams(RawResource refYuv, String testMediaType, String testFile, Map<Long, List<Rect>> frameCropRects, boolean allowRefResize, boolean allowRefLoopBack)134     public CompareStreams(RawResource refYuv, String testMediaType, String testFile,
135             Map<Long, List<Rect>> frameCropRects, boolean allowRefResize, boolean allowRefLoopBack)
136             throws IOException {
137         super(findDecoderForStream(testMediaType, testFile), testMediaType, testFile, LOG_TAG);
138         mRefYuv = refYuv;
139         mStreamFormat = null;
140         mStreamBuffer = null;
141         mStreamBufferInfos = null;
142         mAllowRefResize = allowRefResize;
143         mAllowRefLoopBack = allowRefLoopBack;
144         mFrameCropRects = frameCropRects;
145     }
146 
fillByteArray(int tgtFrameWidth, int tgtFrameHeight, int bytesPerSample, int inpFrameWidth, int inpFrameHeight, byte[] inputData)147     static YUVImage fillByteArray(int tgtFrameWidth, int tgtFrameHeight,
148             int bytesPerSample, int inpFrameWidth, int inpFrameHeight, byte[] inputData) {
149         YUVImage yuvImage = new YUVImage();
150         int inOffset = 0;
151         for (int plane = 0; plane < 3; plane++) {
152             int width, height, tileWidth, tileHeight;
153             if (plane != 0) {
154                 width = tgtFrameWidth / 2;
155                 height = tgtFrameHeight / 2;
156                 tileWidth = inpFrameWidth / 2;
157                 tileHeight = inpFrameHeight / 2;
158             } else {
159                 width = tgtFrameWidth;
160                 height = tgtFrameHeight;
161                 tileWidth = inpFrameWidth;
162                 tileHeight = inpFrameHeight;
163             }
164             byte[] outputData = new byte[width * height * bytesPerSample];
165             for (int k = 0; k < height; k += tileHeight) {
166                 int rowsToCopy = Math.min(height - k, tileHeight);
167                 for (int j = 0; j < rowsToCopy; j++) {
168                     for (int i = 0; i < width; i += tileWidth) {
169                         int colsToCopy = Math.min(width - i, tileWidth);
170                         System.arraycopy(inputData,
171                                 inOffset + j * tileWidth * bytesPerSample,
172                                 outputData,
173                                 (k + j) * width * bytesPerSample + i * bytesPerSample,
174                                 colsToCopy * bytesPerSample);
175                     }
176                 }
177             }
178             inOffset += tileWidth * tileHeight * bytesPerSample;
179             yuvImage.mData.add(outputData);
180         }
181         return yuvImage;
182     }
183 
dequeueOutput(int bufferIndex, MediaCodec.BufferInfo info)184     protected void dequeueOutput(int bufferIndex, MediaCodec.BufferInfo info) {
185         if (info.size > 0) {
186             Image img = mCodec.getOutputImage(bufferIndex);
187             assertNotNull(img);
188             YUVImage yuvImage = getImage(img);
189             MediaFormat format = mCodec.getOutputFormat();
190             int width = getWidth(format);
191             int height = getHeight(format);
192             if (mOutputCount == 0) {
193                 int imgFormat = img.getFormat();
194                 int bytesPerSample = (ImageFormat.getBitsPerPixel(imgFormat) * 2) / (8 * 3);
195                 if (mRefYuv.mBytesPerSample != bytesPerSample) {
196                     String msg = String.format(
197                             "Reference file bytesPerSample and Test file bytesPerSample are not "
198                                     + "same. Reference bytesPerSample : %d, Test bytesPerSample :"
199                                     + " %d", mRefYuv.mBytesPerSample, bytesPerSample);
200                     throw new IllegalArgumentException(msg);
201                 }
202                 if (!mAllowRefResize && (mRefYuv.mWidth != width || mRefYuv.mHeight != height)) {
203                     String msg = String.format(
204                             "Reference file attributes and Test file attributes are not same. "
205                                     + "Reference width : %d, height : %d, bytesPerSample : %d, "
206                                     + "Test width : %d, height : %d, bytesPerSample : %d",
207                             mRefYuv.mWidth, mRefYuv.mHeight, mRefYuv.mBytesPerSample, width,
208                             height, bytesPerSample);
209                     throw new IllegalArgumentException(msg);
210                 }
211                 mFileOffset = 0;
212                 mFileSize = (int) new File(mRefYuv.mFileName).length();
213                 mFrameSize = mRefYuv.mWidth * mRefYuv.mHeight * mRefYuv.mBytesPerSample * 3 / 2;
214                 mInputData = new byte[mFrameSize];
215             }
216             try (FileInputStream fInp = new FileInputStream(mRefYuv.mFileName)) {
217                 assertEquals(mFileOffset, fInp.skip(mFileOffset));
218                 assertEquals(mFrameSize, fInp.read(mInputData));
219                 mFileOffset += mFrameSize;
220                 if (mAllowRefLoopBack && mFileOffset == mFileSize) mFileOffset = 0;
221                 YUVImage yuvRefImage = fillByteArray(width, height, mRefYuv.mBytesPerSample,
222                         mRefYuv.mWidth, mRefYuv.mHeight, mInputData);
223                 List<Rect> frameCropRects =
224                         mFrameCropRects != null ? mFrameCropRects.get(info.presentationTimeUs) :
225                                 null;
226                 updateErrorStats(yuvRefImage.mData.get(0), yuvRefImage.mData.get(1),
227                         yuvRefImage.mData.get(2), yuvImage.mData.get(0), yuvImage.mData.get(1),
228                         yuvImage.mData.get(2), width, height, frameCropRects);
229 
230             } catch (IOException e) {
231                 throw new RuntimeException(e);
232             }
233             mOutputCount++;
234         }
235         if ((info.flags & MediaCodec.BUFFER_FLAG_END_OF_STREAM) != 0) {
236             mSawOutputEOS = true;
237             mGenerateStats = true;
238             finalizerErrorStats();
239         }
240         if (info.size > 0 && (info.flags & MediaCodec.BUFFER_FLAG_CODEC_CONFIG) == 0) {
241             mOutputBuff.saveOutPTS(info.presentationTimeUs);
242             mOutputCount++;
243         }
244         mCodec.releaseOutputBuffer(bufferIndex, false);
245     }
246 
clamp(int val, int min, int max)247     private int clamp(int val, int min, int max) {
248         return Math.max(min, Math.min(max, val));
249     }
250 
updateErrorStats(byte[] yRef, byte[] uRef, byte[] vRef, byte[] yTest, byte[] uTest, byte[] vTest, int imgWidth, int imgHeight, List<Rect> cropRectList)251     private void updateErrorStats(byte[] yRef, byte[] uRef, byte[] vRef, byte[] yTest,
252             byte[] uTest, byte[] vTest, int imgWidth, int imgHeight, List<Rect> cropRectList) {
253         if (cropRectList == null || cropRectList.isEmpty()) {
254             cropRectList = new ArrayList<>();
255             cropRectList.add(new Rect(0, 0, imgWidth, imgHeight));
256         }
257         double sumYMSE = 0;
258         double sumUMSE = 0;
259         double sumVMSE = 0;
260         Rect frameRect = new Rect(0, 0, imgWidth, imgHeight);
261         ArrayList<double[]> frameCropRectPSNR = new ArrayList<>();
262 
263         for (int i = 0; i < cropRectList.size(); i++) {
264             Rect cropRect = new Rect(cropRectList.get(i));
265             cropRect.left = clamp(cropRect.left, 0, imgWidth);
266             cropRect.top = clamp(cropRect.top, 0, imgHeight);
267             cropRect.right = clamp(cropRect.right, 0, imgWidth);
268             cropRect.bottom = clamp(cropRect.bottom, 0, imgHeight);
269             assertTrue("invalid cropRect, " + cropRect,
270                     IS_AT_LEAST_T ? cropRect.isValid()
271                             : cropRect.left <= cropRect.right && cropRect.top <= cropRect.bottom);
272             assertTrue(String.format("cropRect %s exceeds frameRect %s", cropRect, frameRect),
273                     frameRect.contains(cropRect));
274             double curYMSE = computeMSE(yRef, yTest, mRefYuv.mBytesPerSample, imgWidth, imgHeight,
275                     cropRect);
276             sumYMSE += curYMSE;
277 
278             cropRect.left = cropRect.left / 2;   // for uv
279             cropRect.top = cropRect.top / 2;
280             cropRect.right = cropRect.right / 2;
281             cropRect.bottom = cropRect.bottom / 2;
282 
283             double curUMSE = computeMSE(uRef, uTest, mRefYuv.mBytesPerSample, imgWidth / 2,
284                     imgHeight / 2, cropRect);
285             sumUMSE += curUMSE;
286 
287             double curVMSE = computeMSE(vRef, vTest, mRefYuv.mBytesPerSample, imgWidth / 2,
288                     imgHeight / 2, cropRect);
289             sumVMSE += curVMSE;
290 
291             double yCurrCropRectPSNR = computePSNR(curYMSE, mRefYuv.mBytesPerSample);
292             double uCurrCropRectPSNR = computePSNR(curUMSE, mRefYuv.mBytesPerSample);
293             double vCurrCropRectPSNR = computePSNR(curVMSE, mRefYuv.mBytesPerSample);
294 
295             frameCropRectPSNR.add(new double[]{yCurrCropRectPSNR, uCurrCropRectPSNR,
296                     vCurrCropRectPSNR});
297         }
298         mFramesCropRectPSNR.add(frameCropRectPSNR);
299         mGlobalMSE[0] += sumYMSE;
300         mGlobalMSE[1] += sumUMSE;
301         mGlobalMSE[2] += sumVMSE;
302         mMinimumMSE[0] = Math.min(mMinimumMSE[0], sumYMSE);
303         mMinimumMSE[1] = Math.min(mMinimumMSE[1], sumUMSE);
304         mMinimumMSE[2] = Math.min(mMinimumMSE[2], sumVMSE);
305         double yFramePSNR = computePSNR(sumYMSE, mRefYuv.mBytesPerSample);
306         double uFramePSNR = computePSNR(sumUMSE, mRefYuv.mBytesPerSample);
307         double vFramePSNR = computePSNR(sumVMSE, mRefYuv.mBytesPerSample);
308         mAvgPSNR[0] += yFramePSNR;
309         mAvgPSNR[1] += uFramePSNR;
310         mAvgPSNR[2] += vFramePSNR;
311         mFramesPSNR.add(new double[]{yFramePSNR, uFramePSNR, vFramePSNR});
312     }
313 
finalizerErrorStats()314     private void finalizerErrorStats() {
315         for (int i = 0; i < mGlobalPSNR.length; i++) {
316             mGlobalMSE[i] /= mFramesPSNR.size();
317             mGlobalPSNR[i] = computePSNR(mGlobalMSE[i], mRefYuv.mBytesPerSample);
318             mMinimumPSNR[i] = computePSNR(mMinimumMSE[i], mRefYuv.mBytesPerSample);
319             mAvgPSNR[i] /= mFramesPSNR.size();
320         }
321         if (ENABLE_LOGS) {
322             String msg = String.format(
323                     "global_psnr_y:%.2f, global_psnr_u:%.2f, global_psnr_v:%.2f, min_psnr_y:%"
324                             + ".2f, min_psnr_u:%.2f, min_psnr_v:%.2f avg_psnr_y:%.2f, "
325                             + "avg_psnr_u:%.2f, avg_psnr_v:%.2f",
326                     mGlobalPSNR[0], mGlobalPSNR[1], mGlobalPSNR[2], mMinimumPSNR[0],
327                     mMinimumPSNR[1], mMinimumPSNR[2], mAvgPSNR[0], mAvgPSNR[1], mAvgPSNR[2]);
328             Log.v(LOG_TAG, msg);
329         }
330     }
331 
generateErrorStats()332     private void generateErrorStats() throws IOException, InterruptedException {
333         if (!mGenerateStats) {
334             if (MediaUtils.isTv()) {
335                 // Some TV devices support HDR10 display with VO instead of GPU. In this case,
336                 // COLOR_FormatYUVP010 may not be supported.
337                 MediaFormat format = mStreamFormat != null ? mStreamFormat :
338                         getFormatInStream(mMediaType, mTestFile);
339                 ArrayList<MediaFormat> formatList = new ArrayList<>();
340                 formatList.add(format);
341                 boolean isHBD = doesAnyFormatHaveHDRProfile(mMediaType, formatList);
342                 if (isHBD || mTestFile.contains("10bit")) {
343                     if (!hasSupportForColorFormat(mCodecName, mMediaType, COLOR_FormatYUVP010)) {
344                         Assume.assumeTrue("Could not validate the encoded output as"
345                                 + " COLOR_FormatYUVP010 is not supported by the decoder", false);
346                     }
347                 }
348             }
349             if (mStreamFormat != null) {
350                 decodeToMemory(mStreamBuffer, mStreamBufferInfos, mStreamFormat, mCodecName);
351             } else {
352                 decodeToMemory(mTestFile, mCodecName, 0, MediaExtractor.SEEK_TO_CLOSEST_SYNC,
353                         Integer.MAX_VALUE);
354             }
355         }
356     }
357 
358     /**
359      * @see VideoErrorManager#getGlobalPSNR()
360      */
getGlobalPSNR()361     public double[] getGlobalPSNR() throws IOException, InterruptedException {
362         generateErrorStats();
363         return mGlobalPSNR;
364     }
365 
366     /**
367      * @see VideoErrorManager#getMinimumPSNR()
368      */
getMinimumPSNR()369     public double[] getMinimumPSNR() throws IOException, InterruptedException {
370         generateErrorStats();
371         return mMinimumPSNR;
372     }
373 
374     /**
375      * @see VideoErrorManager#getFramesPSNR()
376      */
getFramesPSNR()377     public ArrayList<double[]> getFramesPSNR() throws IOException, InterruptedException {
378         generateErrorStats();
379         return mFramesPSNR;
380     }
381 
382     /**
383      * @see VideoErrorManager#getAvgPSNR()
384      */
getAvgPSNR()385     public double[] getAvgPSNR() throws IOException, InterruptedException {
386         generateErrorStats();
387         return mAvgPSNR;
388     }
389 
getFramesPSNRForRect()390     public List<List<double[]>> getFramesPSNRForRect() throws IOException, InterruptedException {
391         generateErrorStats();
392         return mFramesCropRectPSNR;
393     }
394 
cleanUp()395     public void cleanUp() {
396         for (String tmpFile : mTmpFiles) {
397             File tmp = new File(tmpFile);
398             if (tmp.exists()) tmp.delete();
399         }
400         mTmpFiles.clear();
401     }
402 }
403