1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.media.mediatranscoding.cts;
18 
19 import static org.junit.Assert.assertTrue;
20 
21 
22 import android.content.ContentResolver;
23 import android.content.Context;
24 import android.content.res.AssetFileDescriptor;
25 import android.graphics.ImageFormat;
26 import android.graphics.Rect;
27 import android.media.Image;
28 import android.media.MediaCodec;
29 import android.media.MediaCodecInfo;
30 import android.media.MediaExtractor;
31 import android.media.MediaFormat;
32 import android.media.MediaMetadataRetriever;
33 import android.net.Uri;
34 import android.os.FileUtils;
35 import android.os.ParcelFileDescriptor;
36 import android.util.Log;
37 
38 import java.io.EOFException;
39 import java.io.File;
40 import java.io.FileInputStream;
41 import java.io.FileOutputStream;
42 import java.io.IOException;
43 import java.io.InputStream;
44 import java.nio.ByteBuffer;
45 import java.util.Locale;
46 
47 /* package */ class MediaTranscodingTestUtil {
48     private static final String TAG = "MediaTranscodingTestUtil";
49 
50     // Helper class to extract the information from source file and transcoded file.
51     static class VideoFileInfo {
52         String mUri;
53         int mNumVideoFrames = 0;
54         int mWidth = 0;
55         int mHeight = 0;
56         float mVideoFrameRate = 0.0f;
57         boolean mHasAudio = false;
58         int mRotationDegree = 0;
59 
toString()60         public String toString() {
61             String str = mUri;
62             str += " Width:" + mWidth;
63             str += " Height:" + mHeight;
64             str += " FrameRate:" + mWidth;
65             str += " FrameCount:" + mNumVideoFrames;
66             str += " HasAudio:" + (mHasAudio ? "Yes" : "No");
67             return str;
68         }
69     }
70 
extractVideoFileInfo(Context ctx, Uri videoUri)71     static VideoFileInfo extractVideoFileInfo(Context ctx, Uri videoUri) throws IOException {
72         VideoFileInfo info = new VideoFileInfo();
73         AssetFileDescriptor afd = null;
74         MediaMetadataRetriever retriever = null;
75 
76         try {
77             afd = ctx.getContentResolver().openAssetFileDescriptor(videoUri, "r");
78             retriever = new MediaMetadataRetriever();
79             retriever.setDataSource(afd.getFileDescriptor(), afd.getStartOffset(), afd.getLength());
80 
81             info.mUri = videoUri.getLastPathSegment();
82             Log.i(TAG, "Trying to transcode to " + info.mUri);
83             String width = retriever.extractMetadata(
84                     MediaMetadataRetriever.METADATA_KEY_VIDEO_WIDTH);
85             String height = retriever.extractMetadata(
86                     MediaMetadataRetriever.METADATA_KEY_VIDEO_HEIGHT);
87             if (width != null && height != null) {
88                 info.mWidth = Integer.parseInt(width);
89                 info.mHeight = Integer.parseInt(height);
90             }
91 
92             String frameRate = retriever.extractMetadata(
93                     MediaMetadataRetriever.METADATA_KEY_CAPTURE_FRAMERATE);
94             if (frameRate != null) {
95                 info.mVideoFrameRate = Float.parseFloat(frameRate);
96             }
97 
98             String frameCount = retriever.extractMetadata(
99                     MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT);
100             if (frameCount != null) {
101                 info.mNumVideoFrames = Integer.parseInt(frameCount);
102             }
103 
104             String hasAudio = retriever.extractMetadata(
105                     MediaMetadataRetriever.METADATA_KEY_HAS_AUDIO);
106             if (hasAudio != null) {
107                 info.mHasAudio = hasAudio.equals("yes");
108             }
109 
110             retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_VIDEO_ROTATION);
111             String degree = retriever.extractMetadata(
112                     MediaMetadataRetriever.METADATA_KEY_VIDEO_ROTATION);
113             if (degree != null) {
114                 info.mRotationDegree = Integer.parseInt(degree);
115             }
116         } finally {
117             if (retriever != null) {
118                 retriever.close();
119             }
120             if (afd != null) {
121                 afd.close();
122             }
123         }
124         return info;
125     }
126 
dumpYuvToExternal(final Context ctx, Uri yuvUri)127     static void dumpYuvToExternal(final Context ctx, Uri yuvUri) {
128         Log.i(TAG, "dumping file to external");
129         try {
130             String filename = + System.nanoTime() + "_" + yuvUri.getLastPathSegment();
131             String path = "/storage/emulated/0/Download/" + filename;
132             final File file = new File(path);
133             ParcelFileDescriptor pfd = ctx.getContentResolver().openFileDescriptor(yuvUri, "r");
134             FileInputStream fis = new FileInputStream(pfd.getFileDescriptor());
135             FileOutputStream fos = new FileOutputStream(file);
136             FileUtils.copy(fis, fos);
137         } catch (IOException e) {
138             Log.e(TAG, "Failed to copy file", e);
139         }
140     }
141 
computeStats(final Context ctx, final Uri sourceMp4, final Uri transcodedMp4, boolean debugYuv)142     static VideoTranscodingStatistics computeStats(final Context ctx, final Uri sourceMp4,
143             final Uri transcodedMp4, boolean debugYuv)
144             throws Exception {
145         // First decode the sourceMp4 to a temp yuv in yuv420p format.
146         Uri sourceYUV420PUri = Uri.parse(ContentResolver.SCHEME_FILE + "://"
147                 + ctx.getCacheDir().getAbsolutePath() + "/sourceYUV420P.yuv");
148         decodeMp4ToYuv(ctx, sourceMp4, sourceYUV420PUri);
149         VideoFileInfo srcInfo = extractVideoFileInfo(ctx, sourceMp4);
150         if (debugYuv) {
151             dumpYuvToExternal(ctx, sourceYUV420PUri);
152         }
153 
154         // Second decode the transcodedMp4 to a temp yuv in yuv420p format.
155         Uri transcodedYUV420PUri = Uri.parse(ContentResolver.SCHEME_FILE + "://"
156                 + ctx.getCacheDir().getAbsolutePath() + "/transcodedYUV420P.yuv");
157         decodeMp4ToYuv(ctx, transcodedMp4, transcodedYUV420PUri);
158         VideoFileInfo dstInfo = extractVideoFileInfo(ctx, sourceMp4);
159         if (debugYuv) {
160             dumpYuvToExternal(ctx, transcodedYUV420PUri);
161         }
162 
163         if ((srcInfo.mWidth != dstInfo.mWidth) || (srcInfo.mHeight != dstInfo.mHeight) ||
164                 (srcInfo.mNumVideoFrames != dstInfo.mNumVideoFrames) ||
165                 (srcInfo.mRotationDegree != dstInfo.mRotationDegree)) {
166             throw new UnsupportedOperationException(
167                     "Src mp4 and dst mp4 must have same width/height/frames");
168         }
169 
170         // Then Compute the psnr of transcodedYUV420PUri against sourceYUV420PUri.
171         return computePsnr(ctx, sourceYUV420PUri, transcodedYUV420PUri, srcInfo.mWidth,
172                 srcInfo.mHeight);
173     }
174 
decodeMp4ToYuv(final Context ctx, final Uri fileUri, final Uri yuvUri)175     private static void decodeMp4ToYuv(final Context ctx, final Uri fileUri, final Uri yuvUri)
176             throws Exception {
177         AssetFileDescriptor fileFd = null;
178         MediaExtractor extractor = null;
179         MediaCodec codec = null;
180         AssetFileDescriptor yuvFd = null;
181         FileOutputStream out = null;
182         int width = 0;
183         int height = 0;
184 
185         try {
186             fileFd = ctx.getContentResolver().openAssetFileDescriptor(fileUri, "r");
187             extractor = new MediaExtractor();
188             extractor.setDataSource(fileFd.getFileDescriptor(), fileFd.getStartOffset(),
189                     fileFd.getLength());
190 
191             // Selects the video track.
192             int trackCount = extractor.getTrackCount();
193             if (trackCount <= 0) {
194                 throw new IllegalArgumentException("Invalid mp4 file");
195             }
196             int videoTrackIndex = -1;
197             for (int i = 0; i < trackCount; i++) {
198                 extractor.selectTrack(i);
199                 MediaFormat format = extractor.getTrackFormat(i);
200                 if (format.getString(MediaFormat.KEY_MIME).startsWith("video/")) {
201                     videoTrackIndex = i;
202                     break;
203                 }
204                 extractor.unselectTrack(i);
205             }
206             if (videoTrackIndex == -1) {
207                 throw new IllegalArgumentException("Can not find video track");
208             }
209 
210             extractor.selectTrack(videoTrackIndex);
211             MediaFormat format = extractor.getTrackFormat(videoTrackIndex);
212             String mime = format.getString(MediaFormat.KEY_MIME);
213             format.setInteger(MediaFormat.KEY_COLOR_FORMAT,
214                     MediaCodecInfo.CodecCapabilities.COLOR_FormatYUV420Planar);
215 
216             // Opens the yuv file uri.
217             yuvFd = ctx.getContentResolver().openAssetFileDescriptor(yuvUri,
218                     "w");
219             out = new FileOutputStream(yuvFd.getFileDescriptor());
220 
221             codec = MediaCodec.createDecoderByType(mime);
222             codec.configure(format,
223                     null,  // surface
224                     null,  // crypto
225                     0);    // flags
226             codec.start();
227 
228             ByteBuffer[] inputBuffers = codec.getInputBuffers();
229             ByteBuffer[] outputBuffers = codec.getOutputBuffers();
230 
231             // start decode loop
232             MediaCodec.BufferInfo info = new MediaCodec.BufferInfo();
233 
234             final long kTimeOutUs = 1000; // 1ms timeout
235             long lastOutputTimeUs = 0;
236             boolean sawInputEOS = false;
237             boolean sawOutputEOS = false;
238             int inputNum = 0;
239             int outputNum = 0;
240             boolean advanceDone = true;
241 
242             long start = System.currentTimeMillis();
243             while (!sawOutputEOS) {
244                 // handle input
245                 if (!sawInputEOS) {
246                     int inputBufIndex = codec.dequeueInputBuffer(kTimeOutUs);
247 
248                     if (inputBufIndex >= 0) {
249                         ByteBuffer dstBuf = inputBuffers[inputBufIndex];
250                         // sample contains the buffer and the PTS offset normalized to frame index
251                         int sampleSize =
252                                 extractor.readSampleData(dstBuf, 0 /* offset */);
253                         long presentationTimeUs = extractor.getSampleTime();
254                         advanceDone = extractor.advance();
255 
256                         if (sampleSize < 0) {
257                             Log.d(TAG, "Input EOS");
258                             sawInputEOS = true;
259                             sampleSize = 0;
260                         }
261                         codec.queueInputBuffer(
262                                 inputBufIndex,
263                                 0 /* offset */,
264                                 sampleSize,
265                                 presentationTimeUs,
266                                 sawInputEOS ? MediaCodec.BUFFER_FLAG_END_OF_STREAM : 0);
267                     } else if (inputBufIndex == MediaCodec.INFO_TRY_AGAIN_LATER) {
268                         // Expected. Do nothing.
269                     }  else {
270                         Log.w(
271                                 TAG,
272                                 "Unrecognized dequeueInputBuffer() return value: " + inputBufIndex);
273                     }
274                 }
275 
276                 // handle output
277                 int outputBufIndex = codec.dequeueOutputBuffer(info, kTimeOutUs);
278 
279                 if (outputBufIndex >= 0) {
280                     if (info.size > 0) { // Disregard 0-sized buffers at the end.
281                         outputNum++;
282                         Log.i(TAG, "Output frame number: " + outputNum);
283                         Image image = codec.getOutputImage(outputBufIndex);
284                         dumpYUV420PToFile(image, out);
285                     }
286 
287                     codec.releaseOutputBuffer(outputBufIndex, false /* render */);
288                     if ((info.flags & MediaCodec.BUFFER_FLAG_END_OF_STREAM) != 0) {
289                         Log.d(TAG, "Output EOS");
290                         sawOutputEOS = true;
291                     }
292                 } else if (outputBufIndex == MediaCodec.INFO_OUTPUT_BUFFERS_CHANGED) {
293                     outputBuffers = codec.getOutputBuffers();
294                     Log.d(TAG, "Output buffers changed");
295                 } else if (outputBufIndex == MediaCodec.INFO_OUTPUT_FORMAT_CHANGED) {
296                     Log.d(TAG, "Output format changed");
297                 } else if (outputBufIndex == MediaCodec.INFO_TRY_AGAIN_LATER) {
298                     // Expected. Do nothing.
299                 } else {
300                     Log.w(
301                             TAG,
302                             "Unrecognized dequeueOutputBuffer() return value: " + outputBufIndex);
303                 }
304             }
305         } finally {
306             if (codec != null) {
307                 codec.stop();
308                 codec.release();
309             }
310             if (extractor != null) {
311                 extractor.release();
312             }
313             if (out != null) {
314                 out.close();
315             }
316             if (fileFd != null) {
317                 fileFd.close();
318             }
319             if (yuvFd != null) {
320                 yuvFd.close();
321             }
322         }
323     }
324 
dumpYUV420PToFile(Image image, FileOutputStream out)325     private static void dumpYUV420PToFile(Image image, FileOutputStream out) throws IOException {
326         int format = image.getFormat();
327 
328         if (ImageFormat.YUV_420_888 != format) {
329             throw new UnsupportedOperationException("Only supports YUV420P");
330         }
331 
332         Rect crop = image.getCropRect();
333         int cropLeft = crop.left;
334         int cropRight = crop.right;
335         int cropTop = crop.top;
336         int cropBottom = crop.bottom;
337         int imageWidth = cropRight - cropLeft;
338         int imageHeight = cropBottom - cropTop;
339         byte[] bb = new byte[imageWidth * imageHeight];
340         byte[] lb = null;
341         Image.Plane[] planes = image.getPlanes();
342         for (int i = 0; i < planes.length; ++i) {
343             ByteBuffer buf = planes[i].getBuffer();
344 
345             int width, height, rowStride, pixelStride, x, y, top, left;
346             rowStride = planes[i].getRowStride();
347             pixelStride = planes[i].getPixelStride();
348             if (i == 0) {
349                 width = imageWidth;
350                 height = imageHeight;
351                 left = cropLeft;
352                 top = cropTop;
353             } else {
354                 width = imageWidth / 2;
355                 height = imageHeight / 2;
356                 left = cropLeft / 2;
357                 top = cropTop / 2;
358             }
359 
360             if (buf.hasArray()) {
361                 byte b[] = buf.array();
362                 int offs = buf.arrayOffset();
363                 if (pixelStride == 1) {
364                     for (y = 0; y < height; ++y) {
365                         System.arraycopy(bb, y * width, b, y * rowStride + offs, width);
366                     }
367                 } else {
368                     // do it pixel-by-pixel
369                     for (y = 0; y < height; ++y) {
370                         int lineOffset = offs + y * rowStride;
371                         for (x = 0; x < width; ++x) {
372                             bb[y * width + x] = b[lineOffset + x * pixelStride];
373                         }
374                     }
375                 }
376             } else { // almost always ends up here due to direct buffers
377                 int pos = buf.position();
378                 if (pixelStride == 1) {
379                     for (y = 0; y < height; ++y) {
380                         buf.position(pos + y * rowStride);
381                         buf.get(bb, y * width, width);
382                     }
383                 } else {
384                     // Reallocate linebuffer if necessary.
385                     if (lb == null || lb.length < rowStride) {
386                         lb = new byte[rowStride];
387                     }
388                     // do it pixel-by-pixel
389                     for (y = 0; y < height; ++y) {
390                         buf.position(pos + left + (top + y) * rowStride);
391                         // we're only guaranteed to have pixelStride * (width - 1) + 1 bytes
392                         buf.get(lb, 0, pixelStride * (width - 1) + 1);
393                         for (x = 0; x < width; ++x) {
394                             bb[y * width + x] = lb[x * pixelStride];
395                         }
396                     }
397                 }
398                 buf.position(pos);
399             }
400             // Write out the buffer to the output.
401             out.write(bb, 0, width * height);
402         }
403     }
404 
405     ////////////////////////////////////////////////////////////////////////////////////////////////
406     // The following psnr code is leveraged from the following file with minor modification:
407     // cts/tests/tests/media/src/android/media/cts/VideoCodecTestBase.java
408     ////////////////////////////////////////////////////////////////////////////////////////////////
409     // TODO(hkuang): Merge this code with the code in VideoCodecTestBase to use the same one.
410     /**
411      * Calculates PSNR value between two video frames.
412      */
computePSNR(byte[] data0, byte[] data1)413     private static double computePSNR(byte[] data0, byte[] data1) {
414         long squareError = 0;
415         assertTrue(data0.length == data1.length);
416         int length = data0.length;
417         for (int i = 0; i < length; i++) {
418             int diff = ((int) data0[i] & 0xff) - ((int) data1[i] & 0xff);
419             squareError += diff * diff;
420         }
421         double meanSquareError = (double) squareError / length;
422         double psnr = 10 * Math.log10((double) 255 * 255 / meanSquareError);
423         return psnr;
424     }
425 
426     /**
427      * Calculates average and minimum PSNR values between
428      * set of reference and decoded video frames.
429      * Runs PSNR calculation for the full duration of the decoded data.
430      */
computePsnr( Context ctx, Uri referenceYuvFileUri, Uri decodedYuvFileUri, int width, int height)431     private static VideoTranscodingStatistics computePsnr(
432             Context ctx,
433             Uri referenceYuvFileUri,
434             Uri decodedYuvFileUri,
435             int width,
436             int height) throws Exception {
437         VideoTranscodingStatistics statistics = new VideoTranscodingStatistics();
438         AssetFileDescriptor referenceFd = ctx.getContentResolver().openAssetFileDescriptor(
439                 referenceYuvFileUri, "r");
440         InputStream referenceStream = new FileInputStream(referenceFd.getFileDescriptor());
441 
442         AssetFileDescriptor decodedFd = ctx.getContentResolver().openAssetFileDescriptor(
443                 decodedYuvFileUri, "r");
444         InputStream decodedStream = new FileInputStream(decodedFd.getFileDescriptor());
445 
446         int ySize = width * height;
447         int uvSize = width * height / 4;
448         byte[] yRef = new byte[ySize];
449         byte[] yDec = new byte[ySize];
450         byte[] uvRef = new byte[uvSize];
451         byte[] uvDec = new byte[uvSize];
452 
453         int frames = 0;
454         double averageYPSNR = 0;
455         double averageUPSNR = 0;
456         double averageVPSNR = 0;
457         double minimumYPSNR = Integer.MAX_VALUE;
458         double minimumUPSNR = Integer.MAX_VALUE;
459         double minimumVPSNR = Integer.MAX_VALUE;
460         int minimumPSNRFrameIndex = 0;
461 
462         while (true) {
463             // Calculate Y PSNR.
464             boolean refResult = readFully(referenceStream, yRef);
465             boolean decResult = readFully(decodedStream, yDec);
466             assertTrue(refResult == decResult);
467             if (!refResult) {
468                 // We've reached the end.
469                 break;
470             }
471             double curYPSNR = computePSNR(yRef, yDec);
472             averageYPSNR += curYPSNR;
473             minimumYPSNR = Math.min(minimumYPSNR, curYPSNR);
474             double curMinimumPSNR = curYPSNR;
475 
476             // Calculate U PSNR.
477             assertTrue(readFully(referenceStream, uvRef));
478             assertTrue(readFully(decodedStream, uvDec));
479             double curUPSNR = computePSNR(uvRef, uvDec);
480             averageUPSNR += curUPSNR;
481             minimumUPSNR = Math.min(minimumUPSNR, curUPSNR);
482             curMinimumPSNR = Math.min(curMinimumPSNR, curUPSNR);
483 
484             // Calculate V PSNR.
485             assertTrue(readFully(referenceStream, uvRef));
486             assertTrue(readFully(decodedStream, uvDec));
487             double curVPSNR = computePSNR(uvRef, uvDec);
488             averageVPSNR += curVPSNR;
489             minimumVPSNR = Math.min(minimumVPSNR, curVPSNR);
490             curMinimumPSNR = Math.min(curMinimumPSNR, curVPSNR);
491 
492             // Frame index for minimum PSNR value - help to detect possible distortions
493             if (curMinimumPSNR < statistics.mMinimumPSNR) {
494                 statistics.mMinimumPSNR = curMinimumPSNR;
495                 minimumPSNRFrameIndex = frames;
496             }
497 
498             String logStr = String.format(Locale.US, "PSNR #%d: Y: %.2f. U: %.2f. V: %.2f",
499                     frames, curYPSNR, curUPSNR, curVPSNR);
500             Log.v(TAG, logStr);
501 
502             frames++;
503         }
504 
505         averageYPSNR /= frames;
506         averageUPSNR /= frames;
507         averageVPSNR /= frames;
508         statistics.mAveragePSNR = (4 * averageYPSNR + averageUPSNR + averageVPSNR) / 6;
509 
510         Log.d(TAG, "PSNR statistics for " + frames + " frames.");
511         String logStr = String.format(Locale.US,
512                 "Average PSNR: Y: %.1f. U: %.1f. V: %.1f. Average: %.1f",
513                 averageYPSNR, averageUPSNR, averageVPSNR, statistics.mAveragePSNR);
514         Log.d(TAG, logStr);
515         logStr = String.format(Locale.US,
516                 "Minimum PSNR: Y: %.1f. U: %.1f. V: %.1f. Overall: %.1f at frame %d",
517                 minimumYPSNR, minimumUPSNR, minimumVPSNR,
518                 statistics.mMinimumPSNR, minimumPSNRFrameIndex);
519         Log.d(TAG, logStr);
520 
521         referenceStream.close();
522         decodedStream.close();
523         referenceFd.close();
524         decodedFd.close();
525         return statistics;
526     }
527 
528     /**
529      * Reads {@code out.length} of data into {@code out}. True is returned if the operation
530      * succeeds, and false is returned if {@code in} was already ended. If {@code in} was not
531      * already ended but ad fewer than {@code out.length} bytes remaining, then {@link EOFException}
532      * is thrown.
533      */
readFully(InputStream in, byte[] out)534     private static boolean readFully(InputStream in, byte[] out) throws IOException {
535         int totalBytesRead = 0;
536         while (totalBytesRead < out.length) {
537             int bytesRead = in.read(out, totalBytesRead, out.length - totalBytesRead);
538             if (bytesRead == -1) {
539                 if (totalBytesRead == 0) {
540                     // We were already at the end of the stream.
541                     return false;
542                 } else {
543                     throw new EOFException();
544                 }
545             } else {
546                 totalBytesRead += bytesRead;
547             }
548         }
549         return true;
550     }
551 
552     /**
553      * Transcoding PSNR statistics.
554      */
555     protected static class VideoTranscodingStatistics {
556         public double mAveragePSNR;
557         public double mMinimumPSNR;
558 
VideoTranscodingStatistics()559         VideoTranscodingStatistics() {
560             mMinimumPSNR = Integer.MAX_VALUE;
561         }
562     }
563 }
564