1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Unit Test for MediaSampleReaderNDK
18 
19 // #define LOG_NDEBUG 0
20 #define LOG_TAG "MediaSampleReaderNDKTests"
21 
22 #include <android-base/logging.h>
23 #include <android/binder_manager.h>
24 #include <android/binder_process.h>
25 #include <fcntl.h>
26 #include <gtest/gtest.h>
27 #include <media/MediaSampleReaderNDK.h>
28 #include <openssl/md5.h>
29 #include <utils/Timers.h>
30 
31 #include <cmath>
32 #include <mutex>
33 #include <thread>
34 
35 // TODO(b/153453392): Test more asset types (frame reordering?).
36 
37 namespace android {
38 
39 #define SEC_TO_USEC(s) ((s)*1000 * 1000)
40 
41 /** Helper class for comparing sample data using checksums. */
42 class Sample {
43 public:
Sample(uint32_t flags,int64_t timestamp,size_t size,const uint8_t * buffer)44     Sample(uint32_t flags, int64_t timestamp, size_t size, const uint8_t* buffer)
45           : mFlags{flags}, mTimestamp{timestamp}, mSize{size} {
46         initChecksum(buffer);
47     }
48 
Sample(AMediaExtractor * extractor)49     Sample(AMediaExtractor* extractor) {
50         mFlags = AMediaExtractor_getSampleFlags(extractor);
51         mTimestamp = AMediaExtractor_getSampleTime(extractor);
52         mSize = static_cast<size_t>(AMediaExtractor_getSampleSize(extractor));
53 
54         auto buffer = std::make_unique<uint8_t[]>(mSize);
55         AMediaExtractor_readSampleData(extractor, buffer.get(), mSize);
56 
57         initChecksum(buffer.get());
58     }
59 
initChecksum(const uint8_t * buffer)60     void initChecksum(const uint8_t* buffer) {
61         MD5_CTX md5Ctx;
62         MD5_Init(&md5Ctx);
63         MD5_Update(&md5Ctx, buffer, mSize);
64         MD5_Final(mChecksum, &md5Ctx);
65     }
66 
operator ==(const Sample & rhs) const67     bool operator==(const Sample& rhs) const {
68         return mSize == rhs.mSize && mFlags == rhs.mFlags && mTimestamp == rhs.mTimestamp &&
69                memcmp(mChecksum, rhs.mChecksum, MD5_DIGEST_LENGTH) == 0;
70     }
71 
72     uint32_t mFlags;
73     int64_t mTimestamp;
74     size_t mSize;
75     uint8_t mChecksum[MD5_DIGEST_LENGTH];
76 };
77 
78 /** Constant for selecting all samples. */
79 static constexpr int SAMPLE_COUNT_ALL = -1;
80 
81 /**
82  * Utility class to test different sample access patterns combined with sequential or parallel
83  * sample access modes.
84  */
85 class SampleAccessTester {
86 public:
SampleAccessTester(int sourceFd,size_t fileSize)87     SampleAccessTester(int sourceFd, size_t fileSize) {
88         mSampleReader = MediaSampleReaderNDK::createFromFd(sourceFd, 0, fileSize);
89         EXPECT_TRUE(mSampleReader);
90 
91         mTrackCount = mSampleReader->getTrackCount();
92 
93         for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
94             EXPECT_EQ(mSampleReader->selectTrack(trackIndex), AMEDIA_OK);
95         }
96 
97         mSamples.resize(mTrackCount);
98         mTrackThreads.resize(mTrackCount);
99     }
100 
getSampleInfo(int trackIndex)101     void getSampleInfo(int trackIndex) {
102         MediaSampleInfo info;
103         media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
104         EXPECT_EQ(status, AMEDIA_OK);
105     }
106 
readSamplesAsync(int trackIndex,int sampleCount)107     void readSamplesAsync(int trackIndex, int sampleCount) {
108         mTrackThreads[trackIndex] = std::thread{[this, trackIndex, sampleCount] {
109             int samplesRead = 0;
110             MediaSampleInfo info;
111             while (samplesRead < sampleCount || sampleCount == SAMPLE_COUNT_ALL) {
112                 media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
113                 if (status != AMEDIA_OK) {
114                     EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
115                     EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
116                     break;
117                 }
118                 ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
119 
120                 auto buffer = std::make_unique<uint8_t[]>(info.size);
121                 status = mSampleReader->readSampleDataForTrack(trackIndex, buffer.get(), info.size);
122                 EXPECT_EQ(status, AMEDIA_OK);
123 
124                 mSampleMutex.lock();
125                 const uint8_t* bufferPtr = buffer.get();
126                 mSamples[trackIndex].emplace_back(info.flags, info.presentationTimeUs, info.size,
127                                                   bufferPtr);
128                 mSampleMutex.unlock();
129                 ++samplesRead;
130             }
131         }};
132     }
133 
readSamplesAsync(int sampleCount)134     void readSamplesAsync(int sampleCount) {
135         for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
136             readSamplesAsync(trackIndex, sampleCount);
137         }
138     }
139 
waitForTrack(int trackIndex)140     void waitForTrack(int trackIndex) {
141         ASSERT_TRUE(mTrackThreads[trackIndex].joinable());
142         mTrackThreads[trackIndex].join();
143     }
144 
waitForTracks()145     void waitForTracks() {
146         for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
147             waitForTrack(trackIndex);
148         }
149     }
150 
setEnforceSequentialAccess(bool enforce)151     void setEnforceSequentialAccess(bool enforce) {
152         media_status_t status = mSampleReader->setEnforceSequentialAccess(enforce);
153         EXPECT_EQ(status, AMEDIA_OK);
154     }
155 
getSamples()156     std::vector<std::vector<Sample>>& getSamples() { return mSamples; }
157 
158     std::shared_ptr<MediaSampleReader> mSampleReader;
159     size_t mTrackCount;
160     std::mutex mSampleMutex;
161     std::vector<std::thread> mTrackThreads;
162     std::vector<std::vector<Sample>> mSamples;
163 };
164 
165 class MediaSampleReaderNDKTests : public ::testing::Test {
166 public:
MediaSampleReaderNDKTests()167     MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests created"; }
168 
SetUp()169     void SetUp() override {
170         LOG(DEBUG) << "MediaSampleReaderNDKTests set up";
171 
172         // Need to start a thread pool to prevent AMediaExtractor binder calls from starving
173         // (b/155663561).
174         ABinderProcess_startThreadPool();
175 
176         const char* sourcePath =
177                 "/data/local/tmp/TranscodingTestAssets/cubicle_avc_480x240_aac_24KHz.mp4";
178 
179         mSourceFd = open(sourcePath, O_RDONLY);
180         ASSERT_GT(mSourceFd, 0);
181 
182         mFileSize = lseek(mSourceFd, 0, SEEK_END);
183         lseek(mSourceFd, 0, SEEK_SET);
184 
185         mExtractor = AMediaExtractor_new();
186         ASSERT_NE(mExtractor, nullptr);
187 
188         media_status_t status =
189                 AMediaExtractor_setDataSourceFd(mExtractor, mSourceFd, 0, mFileSize);
190         ASSERT_EQ(status, AMEDIA_OK);
191 
192         mTrackCount = AMediaExtractor_getTrackCount(mExtractor);
193         for (size_t trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
194             AMediaExtractor_selectTrack(mExtractor, trackIndex);
195         }
196     }
197 
initExtractorSamples()198     void initExtractorSamples() {
199         if (mExtractorSamples.size() == mTrackCount) return;
200 
201         // Save sample information, per track, as reported by the extractor.
202         mExtractorSamples.resize(mTrackCount);
203         do {
204             const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
205             mExtractorSamples[trackIndex].emplace_back(mExtractor);
206         } while (AMediaExtractor_advance(mExtractor));
207 
208         AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
209     }
210 
getTrackBitrates()211     std::vector<int32_t> getTrackBitrates() {
212         size_t totalSize[mTrackCount];
213         memset(totalSize, 0, sizeof(totalSize));
214 
215         do {
216             const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
217             totalSize[trackIndex] += AMediaExtractor_getSampleSize(mExtractor);
218         } while (AMediaExtractor_advance(mExtractor));
219 
220         AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
221 
222         std::vector<int32_t> bitrates;
223         for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
224             int64_t durationUs;
225             AMediaFormat* trackFormat = AMediaExtractor_getTrackFormat(mExtractor, trackIndex);
226             EXPECT_NE(trackFormat, nullptr);
227             EXPECT_TRUE(AMediaFormat_getInt64(trackFormat, AMEDIAFORMAT_KEY_DURATION, &durationUs));
228             bitrates.push_back(roundf((float)totalSize[trackIndex] * 8 * 1000000 / durationUs));
229         }
230 
231         return bitrates;
232     }
233 
compareSamples(std::vector<std::vector<Sample>> & readerSamples)234     void compareSamples(std::vector<std::vector<Sample>>& readerSamples) {
235         initExtractorSamples();
236         EXPECT_EQ(readerSamples.size(), mTrackCount);
237 
238         for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
239             LOG(DEBUG) << "Track " << trackIndex << ", comparing "
240                        << readerSamples[trackIndex].size() << " samples.";
241             EXPECT_EQ(readerSamples[trackIndex].size(), mExtractorSamples[trackIndex].size());
242             for (size_t sampleIndex = 0; sampleIndex < readerSamples[trackIndex].size();
243                  sampleIndex++) {
244                 EXPECT_EQ(readerSamples[trackIndex][sampleIndex],
245                           mExtractorSamples[trackIndex][sampleIndex]);
246             }
247         }
248     }
249 
TearDown()250     void TearDown() override {
251         LOG(DEBUG) << "MediaSampleReaderNDKTests tear down";
252         AMediaExtractor_delete(mExtractor);
253         close(mSourceFd);
254     }
255 
~MediaSampleReaderNDKTests()256     ~MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests destroyed"; }
257 
258     AMediaExtractor* mExtractor = nullptr;
259     size_t mTrackCount;
260     int mSourceFd;
261     size_t mFileSize;
262     std::vector<std::vector<Sample>> mExtractorSamples;
263 };
264 
265 /** Reads all samples from all tracks in parallel. */
TEST_F(MediaSampleReaderNDKTests,TestParallelSampleAccess)266 TEST_F(MediaSampleReaderNDKTests, TestParallelSampleAccess) {
267     LOG(DEBUG) << "TestParallelSampleAccess Starts";
268 
269     SampleAccessTester tester{mSourceFd, mFileSize};
270     tester.readSamplesAsync(SAMPLE_COUNT_ALL);
271     tester.waitForTracks();
272     compareSamples(tester.getSamples());
273 }
274 
275 /** Reads all samples except the last in each track, before finishing. */
TEST_F(MediaSampleReaderNDKTests,TestLastSampleBeforeEOS)276 TEST_F(MediaSampleReaderNDKTests, TestLastSampleBeforeEOS) {
277     LOG(DEBUG) << "TestLastSampleBeforeEOS Starts";
278     initExtractorSamples();
279 
280     {  // Natural track order
281         SampleAccessTester tester{mSourceFd, mFileSize};
282         for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
283             tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() - 1);
284         }
285         tester.waitForTracks();
286         for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
287             tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
288             tester.waitForTrack(trackIndex);
289         }
290         compareSamples(tester.getSamples());
291     }
292 
293     {  // Reverse track order
294         SampleAccessTester tester{mSourceFd, mFileSize};
295         for (int trackIndex = mTrackCount - 1; trackIndex >= 0; --trackIndex) {
296             tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() - 1);
297         }
298         tester.waitForTracks();
299         for (int trackIndex = mTrackCount - 1; trackIndex >= 0; --trackIndex) {
300             tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
301             tester.waitForTrack(trackIndex);
302         }
303         compareSamples(tester.getSamples());
304     }
305 }
306 
307 /** Reads all samples from all tracks sequentially. */
TEST_F(MediaSampleReaderNDKTests,TestSequentialSampleAccess)308 TEST_F(MediaSampleReaderNDKTests, TestSequentialSampleAccess) {
309     LOG(DEBUG) << "TestSequentialSampleAccess Starts";
310 
311     SampleAccessTester tester{mSourceFd, mFileSize};
312     tester.setEnforceSequentialAccess(true);
313     tester.readSamplesAsync(SAMPLE_COUNT_ALL);
314     tester.waitForTracks();
315     compareSamples(tester.getSamples());
316 }
317 
318 /** Reads all samples from one track in parallel mode before switching to sequential mode. */
TEST_F(MediaSampleReaderNDKTests,TestMixedSampleAccessTrackEOS)319 TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccessTrackEOS) {
320     LOG(DEBUG) << "TestMixedSampleAccessTrackEOS Starts";
321 
322     for (int readSampleInfoFlag = 0; readSampleInfoFlag <= 1; readSampleInfoFlag++) {
323         for (int trackIndToEOS = 0; trackIndToEOS < mTrackCount; ++trackIndToEOS) {
324             LOG(DEBUG) << "Testing EOS of track " << trackIndToEOS;
325 
326             SampleAccessTester tester{mSourceFd, mFileSize};
327 
328             // If the flag is set, read sample info from a different track before draining the track
329             // under test to force the reader to save the extractor position.
330             if (readSampleInfoFlag) {
331                 tester.getSampleInfo((trackIndToEOS + 1) % mTrackCount);
332             }
333 
334             // Read all samples from one track before enabling sequential access
335             tester.readSamplesAsync(trackIndToEOS, SAMPLE_COUNT_ALL);
336             tester.waitForTrack(trackIndToEOS);
337             tester.setEnforceSequentialAccess(true);
338 
339             for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
340                 if (trackIndex == trackIndToEOS) continue;
341 
342                 tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
343                 tester.waitForTrack(trackIndex);
344             }
345 
346             compareSamples(tester.getSamples());
347         }
348     }
349 }
350 
351 /**
352  * Reads different combinations of sample counts from all tracks in parallel mode before switching
353  * to sequential mode and reading the rest of the samples.
354  */
TEST_F(MediaSampleReaderNDKTests,TestMixedSampleAccess)355 TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccess) {
356     LOG(DEBUG) << "TestMixedSampleAccess Starts";
357     initExtractorSamples();
358 
359     for (int trackIndToTest = 0; trackIndToTest < mTrackCount; ++trackIndToTest) {
360         for (int sampleCount = 0; sampleCount <= (mExtractorSamples[trackIndToTest].size() + 1);
361              ++sampleCount) {
362             SampleAccessTester tester{mSourceFd, mFileSize};
363 
364             for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
365                 if (trackIndex == trackIndToTest) {
366                     tester.readSamplesAsync(trackIndex, sampleCount);
367                 } else {
368                     tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() / 2);
369                 }
370             }
371 
372             tester.waitForTracks();
373             tester.setEnforceSequentialAccess(true);
374 
375             tester.readSamplesAsync(SAMPLE_COUNT_ALL);
376             tester.waitForTracks();
377 
378             compareSamples(tester.getSamples());
379         }
380     }
381 }
382 
/**
 * Verifies that the bitrate estimated by the sample reader for every track is positive, below a
 * sanity bound, and equal to the bitrate computed from the reference extractor.
 */
TEST_F(MediaSampleReaderNDKTests, TestEstimatedBitrateAccuracy) {
    // Just put a somewhat reasonable upper bound on the estimated bitrate expected in our test
    // assets. This is mostly to make sure the estimation is not way off.
    static constexpr int32_t kMaxEstimatedBitrate = 100 * 1000 * 1000;  // 100 Mbps

    auto sampleReader = MediaSampleReaderNDK::createFromFd(mSourceFd, 0, mFileSize);
    ASSERT_TRUE(sampleReader);

    const std::vector<int32_t> actualTrackBitrates = getTrackBitrates();
    // Fatal: the loop below indexes actualTrackBitrates by track.
    ASSERT_EQ(actualTrackBitrates.size(), mTrackCount);

    const int trackCount = static_cast<int>(mTrackCount);
    for (int trackIndex = 0; trackIndex < trackCount; ++trackIndex) {
        EXPECT_EQ(sampleReader->selectTrack(trackIndex), AMEDIA_OK);

        // Initialized so that a failed query makes the expectations below fail on a defined
        // value instead of reading an indeterminate one.
        int32_t bitrate = 0;
        EXPECT_EQ(sampleReader->getEstimatedBitrateForTrack(trackIndex, &bitrate), AMEDIA_OK);
        EXPECT_GT(bitrate, 0);
        EXPECT_LT(bitrate, kMaxEstimatedBitrate);

        // Note: The test asset currently used in this test is shorter than the sampling duration
        // used to estimate the bitrate in the sample reader. So for now the estimation should be
        // exact but if/when a longer asset is used a reasonable delta needs to be defined.
        EXPECT_EQ(bitrate, actualTrackBitrates[trackIndex]);
    }
}
406 
/** Expects reader creation to fail for file descriptor 0 and for a negative file descriptor. */
TEST_F(MediaSampleReaderNDKTests, TestInvalidFd) {
    const std::shared_ptr<MediaSampleReader> readerForFdZero =
            MediaSampleReaderNDK::createFromFd(0, 0, mFileSize);
    ASSERT_TRUE(readerForFdZero == nullptr);

    const std::shared_ptr<MediaSampleReader> readerForNegativeFd =
            MediaSampleReaderNDK::createFromFd(-1, 0, mFileSize);
    ASSERT_TRUE(readerForNegativeFd == nullptr);
}
415 
/** Expects reader creation to fail when the source size is zero. */
TEST_F(MediaSampleReaderNDKTests, TestZeroSize) {
    const auto zeroSizeReader = MediaSampleReaderNDK::createFromFd(mSourceFd, 0, 0);
    ASSERT_TRUE(zeroSizeReader == nullptr);
}
421 
/** Expects reader creation to fail when the offset equals the file size. */
TEST_F(MediaSampleReaderNDKTests, TestInvalidOffset) {
    const auto offsetAtEndReader =
            MediaSampleReaderNDK::createFromFd(mSourceFd, mFileSize, mFileSize);
    ASSERT_TRUE(offsetAtEndReader == nullptr);
}
427 
428 }  // namespace android
429 
main(int argc,char ** argv)430 int main(int argc, char** argv) {
431     ::testing::InitGoogleTest(&argc, argv);
432     return RUN_ALL_TESTS();
433 }
434