1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.app.tracing.coroutines
18 
19 import com.android.app.tracing.FakeTraceState.getOpenTraceSectionsOnCurrentThread
20 import com.android.systemui.Flags
21 import java.util.concurrent.CyclicBarrier
22 import java.util.concurrent.Executors
23 import java.util.concurrent.TimeUnit
24 import java.util.concurrent.atomic.AtomicInteger
25 import kotlin.coroutines.CoroutineContext
26 import kotlin.coroutines.EmptyCoroutineContext
27 import kotlinx.coroutines.CoroutineScope
28 import kotlinx.coroutines.CoroutineStart
29 import kotlinx.coroutines.channels.Channel
30 import kotlinx.coroutines.delay
31 import kotlinx.coroutines.launch
32 import kotlinx.coroutines.newSingleThreadContext
33 import kotlinx.coroutines.test.TestScope
34 import kotlinx.coroutines.test.UnconfinedTestDispatcher
35 import kotlinx.coroutines.test.runTest
36 import kotlinx.coroutines.withContext
37 import org.junit.After
38 import org.junit.Assert.assertArrayEquals
39 import org.junit.Assert.assertEquals
40 import org.junit.Assert.assertNotNull
41 import org.junit.Assert.assertNotSame
42 import org.junit.Assert.assertNull
43 import org.junit.Assert.assertSame
44 import org.junit.Assert.assertTrue
45 import org.junit.Before
46 import org.junit.Test
47 import org.junit.runner.RunWith
48 import org.junit.runners.BlockJUnit4ClassRunner
49 
/**
 * Tests for propagating trace sections through coroutines using [TraceContextElement] and
 * [traceCoroutine].
 *
 * Most tests follow a small event-ordering protocol:
 * - [expect] with an event number asserts that numbered events happen in strictly increasing
 *   order (1, 2, 3, ...). With or without a number, it also asserts which trace sections are open
 *   on the current thread, as recorded by `FakeTraceState`.
 * - [finish] records the last numbered event of a test. Calling a numbered [expect] afterwards
 *   fails.
 * - [checkFinished] runs after every test. It fails if numbered events were used but [finish] was
 *   never reached.
 */
@RunWith(BlockJUnit4ClassRunner::class)
class CoroutineTracingTest {
    /**
     * Ordering counter used by [expect] and [finish].
     *
     * It is 0 until the first numbered event and is set to [FINAL_EVENT] by [finish]. It is atomic
     * because events are recorded from several threads.
     */
    private val eventCounter = AtomicInteger(0)

    @Before
    fun setup() {
        TraceData.strictModeForTesting = true
    }

    /**
     * Fails the test if it recorded numbered [expect] events but never called [finish]. Tests that
     * only use un-numbered [expect] calls leave the counter at 0 and pass this check.
     */
    @After
    fun checkFinished() {
        val lastEvent = eventCounter.get()
        assertTrue(
            "Expected `finish(${lastEvent + 1})` to be called, but the test finished",
            lastEvent == FINAL_EVENT || lastEvent == 0,
        )
    }

    @Test
    fun simpleTraceSection() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") { expect(2, "hello") }
        finish(3)
    }

    @Test
    fun simpleNestedTraceSection() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") {
            expect(2, "hello")
            traceCoroutine("world") { expect(3, "hello", "world") }
            expect(4, "hello")
        }
        finish(5)
    }

    @Test
    fun simpleLaunch() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") {
            expect(2, "hello")
            launch { finish(4, "hello") }
        }
        expect(3)
    }

    @Test
    fun launchWithSuspendingLambda() = runTestWithTraceContext {
        val fetchData: suspend () -> String = {
            expect(3, "span-for-launch")
            delay(1L)
            traceCoroutine("span-for-fetchData") {
                expect(4, "span-for-launch", "span-for-fetchData")
            }
            "stuff"
        }
        expect(1)
        launch("span-for-launch") {
            assertEquals("stuff", fetchData())
            finish(5, "span-for-launch")
        }
        expect(2)
    }

    @Test
    fun nestedUpdateAndRestoreOnSingleThread_unconfinedDispatcher() = runTestWithTraceContext {
        traceCoroutine("parent-span") {
            expect(1, "parent-span")
            launch(UnconfinedTestDispatcher(scheduler = testScheduler)) {
                // While this may appear unusual, it is actually expected behavior:
                //   1) The parent has an open trace section called "parent-span".
                //   2) The child launches, it inherits from its parent, and it is resumed
                //      immediately due to its use of the unconfined dispatcher.
                //   3) The child emits all the trace sections known to its scope. The parent
                //      does not have an opportunity to restore its context yet.
                traceCoroutine("child-span") {
                    // [parent's active trace]
                    //           \  [trace section inherited from parent]
                    //            \                 |    [new trace section in child scope]
                    //             \                |             /
                    expect(2, "parent-span", "parent-span", "child-span")
                    delay(1) // <-- delay will give parent a chance to restore its context
                    // After a delay, the parent resumes, finishing its trace section, so we are
                    // left with only those in the child's scope
                    finish(4, "parent-span", "child-span")
                }
            }
        }
        expect(3)
    }

    /** @see nestedUpdateAndRestoreOnSingleThread_unconfinedDispatcher */
    @Test
    fun nestedUpdateAndRestoreOnSingleThread_undispatchedLaunch() = runTestWithTraceContext {
        traceCoroutine("parent-span") {
            launch(start = CoroutineStart.UNDISPATCHED) {
                traceCoroutine("child-span") {
                    expect(1, "parent-span", "parent-span", "child-span")
                    delay(1) // <-- delay will give parent a chance to restore its context
                    finish(3, "parent-span", "child-span")
                }
            }
        }
        expect(2)
    }

    @Test
    fun launchOnSeparateThread_defaultDispatcher() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val bgThread = newSingleThreadContext("thread-#1")
        expect()
        traceCoroutine("hello") {
            expect(1, "hello")
            launch(bgThread) {
                expect(2, "hello")
                traceCoroutine("world") {
                    expect("hello", "world")
                    channel.send(1)
                    expect(3, "hello", "world")
                }
            }
            expect("hello")
        }
        expect()
        assertEquals(1, channel.receive())
        finish(4)
    }

    @Test
    fun testTraceStorage() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val fetchData: suspend () -> String = {
            traceCoroutine("span-for-fetchData") {
                channel.receive()
                expect("span-for-launch", "span-for-fetchData")
            }
            "stuff"
        }
        val threadContexts =
            listOf(
                newSingleThreadContext("thread-#1"),
                newSingleThreadContext("thread-#2"),
                newSingleThreadContext("thread-#3"),
                newSingleThreadContext("thread-#4"),
            )

        val finishedLaunches = Channel<Int>()

        // Start 1000 coroutines waiting on [channel]
        val job = launch {
            repeat(1000) {
                launch("span-for-launch", threadContexts[it % threadContexts.size]) {
                    assertNotNull(traceThreadLocal.get())
                    assertEquals("stuff", fetchData())
                    expect("span-for-launch")
                    assertNotNull(traceThreadLocal.get())
                    expect("span-for-launch")
                    finishedLaunches.send(it)
                }
                expect()
            }
        }
        // Resume half the coroutines that are waiting on this channel...
        repeat(500) { channel.send(1) }
        // ...wait until each resumed coroutine has finished its assertions...
        repeat(500) { finishedLaunches.receive() }
        // ...and cancel the rest
        job.cancel()
    }

    /**
     * Shared body for the `nestedTraceSectionsMultiThreaded*` tests.
     *
     * It launches a traced coroutine on a new "thread-#1" dispatcher, which in turn launches a
     * traced child on a new "thread-#2" dispatcher. It asserts the open trace sections at each
     * step. [thread1Context] and [thread2Context] are added to the respective dispatchers so that
     * callers can vary where the [TraceContextElement] comes from.
     */
    private fun CoroutineScope.testTraceSectionsMultiThreaded(
        thread1Context: CoroutineContext,
        thread2Context: CoroutineContext
    ) {
        val fetchData1: suspend () -> String = {
            expect("span-for-launch-1")
            delay(1L)
            traceCoroutine("span-for-fetchData-1") {
                expect("span-for-launch-1", "span-for-fetchData-1")
            }
            expect("span-for-launch-1")
            "stuff-1"
        }

        val fetchData2: suspend () -> String = {
            expect(
                "span-for-launch-1",
                "span-for-launch-2",
            )
            delay(1L)
            traceCoroutine("span-for-fetchData-2") {
                expect("span-for-launch-1", "span-for-launch-2", "span-for-fetchData-2")
            }
            expect(
                "span-for-launch-1",
                "span-for-launch-2",
            )
            "stuff-2"
        }

        val thread1 = newSingleThreadContext("thread-#1") + thread1Context
        val thread2 = newSingleThreadContext("thread-#2") + thread2Context

        launch("span-for-launch-1", thread1) {
            assertEquals("stuff-1", fetchData1())
            expect("span-for-launch-1")
            launch("span-for-launch-2", thread2) {
                assertEquals("stuff-2", fetchData2())
                expect("span-for-launch-1", "span-for-launch-2")
            }
            expect("span-for-launch-1")
        }
        expect()

        // Launching without the trace extension won't result in traces
        launch(thread1) { expect() }
        launch(thread2) { expect() }
    }

    @Test
    fun nestedTraceSectionsMultiThreaded1() = runTestWithTraceContext {
        // Thread-#1 and Thread-#2 inherit TraceContextElement from the test's CoroutineContext.
        testTraceSectionsMultiThreaded(
            thread1Context = EmptyCoroutineContext,
            thread2Context = EmptyCoroutineContext
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded2() = runTest {
        // Thread-#2 inherits the TraceContextElement from Thread-#1. The test's CoroutineContext
        // does not need a TraceContextElement because it does not do any tracing.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = EmptyCoroutineContext
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded3() = runTest {
        // Thread-#2 overrides the TraceContextElement from Thread-#1, but the merging context
        // should be fine; it is essentially a no-op. The test's CoroutineContext does not need the
        // trace context because it does not do any tracing.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = TraceContextElement(TraceData())
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded4() = runTestWithTraceContext {
        // TraceContextElement is merged on each context switch, which should have no effect on the
        // trace results.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = TraceContextElement(TraceData())
        )
    }

    @Test
    fun missingTraceContextObjects() = runTest {
        val channel = Channel<Int>()
        // Thread-#1 is missing a TraceContextElement, so some of the trace sections get dropped.
        // The resulting trace sections will be different than the 4 tests above.
        val fetchData1: suspend () -> String = {
            expect()
            channel.receive()
            traceCoroutine("span-for-fetchData-1") { expect() }
            expect()
            "stuff-1"
        }

        val fetchData2: suspend () -> String = {
            expect(
                "span-for-launch-2",
            )
            channel.receive()
            traceCoroutine("span-for-fetchData-2") {
                expect("span-for-launch-2", "span-for-fetchData-2")
            }
            expect(
                "span-for-launch-2",
            )
            "stuff-2"
        }

        val thread1 = newSingleThreadContext("thread-#1")
        val thread2 = newSingleThreadContext("thread-#2") + TraceContextElement(TraceData())

        launch("span-for-launch-1", thread1) {
            assertEquals("stuff-1", fetchData1())
            expect()
            launch("span-for-launch-2", thread2) {
                assertEquals("stuff-2", fetchData2())
                expect("span-for-launch-2")
            }
            expect()
        }
        expect()

        channel.send(1)
        channel.send(2)

        // Launching without the trace extension won't result in traces
        launch(thread1) { expect() }
        launch(thread2) { expect() }
    }

    /**
     * Tests interleaving:
     * ```
     * Thread #1 | [updateThreadContext]....^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]...........^[restoreThreadContext]
     * ```
     *
     * This test checks for issues with concurrent modification of the trace state. For example, the
     * test should fail if [TraceData.endAllOnThread] uses the size of the slices array as follows
     * instead of using the ThreadLocal count:
     * ```
     * class TraceData {
     *   ...
     *   fun endAllOnThread() {
     *     repeat(slices.size) {
     *       // THIS WOULD BE AN ERROR. If the thread is slow, the TraceData object could have been
     *       // modified by another thread
     *       endSlice()
     *     }
     *   ...
     *   }
     * }
     * ```
     */
    @Test
    fun coroutineMachinery() {
        assertNull(traceThreadLocal.get())
        val traceContext = TraceContextElement()
        assertNull(traceThreadLocal.get())

        val thread1ResumptionPoint = CyclicBarrier(2)
        val thread1SuspensionPoint = CyclicBarrier(2)

        val thread1 = Executors.newSingleThreadExecutor()
        val thread2 = Executors.newSingleThreadExecutor()
        val slicesForThread1 = listOf("a", "c", "e", "g")
        val slicesForThread2 = listOf("b", "d", "f", "h")

        // Record Throwable, not just Error. Several failures here are Exceptions:
        //   - barrier timeouts (TimeoutException)
        //   - broken barriers (BrokenBarrierException)
        //   - interrupts (InterruptedException)
        //   - failed `check`s (IllegalStateException)
        // An exception escaping an executor task does not fail the test, so every failure must be
        // captured here and rethrown on the test thread.
        var failureOnThread1: Throwable? = null
        var failureOnThread2: Throwable? = null

        val expectedTraceForThread1 = arrayOf("1:a", "2:b", "1:c", "2:d", "1:e", "2:f", "1:g")
        thread1.execute {
            try {
                slicesForThread1.forEachIndexed { index, sliceName ->
                    assertNull(traceThreadLocal.get())
                    val oldTrace = traceContext.updateThreadContext(EmptyCoroutineContext)
                    // await() AFTER updateThreadContext, thus thread #1 always resumes the
                    // coroutine before thread #2
                    assertSame(traceThreadLocal.get(), traceContext.traceData)

                    // coroutine body start {
                    traceThreadLocal.get()?.beginSpan("1:$sliceName")

                    // At the end, verify the interleaved trace sections look correct:
                    if (index == slicesForThread1.size - 1) {
                        expect(*expectedTraceForThread1)
                    }

                    // Simulate a slow thread: wait to call restoreThreadContext until after
                    // thread #2 has resumed the coroutine.
                    thread1SuspensionPoint.await(3, TimeUnit.SECONDS)
                    Thread.sleep(500)
                    // } coroutine body end

                    traceContext.restoreThreadContext(EmptyCoroutineContext, oldTrace)
                    thread1ResumptionPoint.await(3, TimeUnit.SECONDS)
                    assertNull(traceThreadLocal.get())
                }
            } catch (e: Throwable) {
                failureOnThread1 = e
            }
        }

        val expectedTraceForThread2 =
            arrayOf("1:a", "2:b", "1:c", "2:d", "1:e", "2:f", "1:g", "2:h")
        thread2.execute {
            try {
                slicesForThread2.forEachIndexed { i, n ->
                    assertNull(traceThreadLocal.get())
                    thread1SuspensionPoint.await(3, TimeUnit.SECONDS)

                    val oldTrace: TraceData? =
                        traceContext.updateThreadContext(EmptyCoroutineContext)

                    // coroutine body start {
                    traceThreadLocal.get()?.beginSpan("2:$n")

                    // At the end, verify the interleaved trace sections look correct:
                    if (i == slicesForThread2.size - 1) {
                        expect(*expectedTraceForThread2)
                    }
                    // } coroutine body end

                    traceContext.restoreThreadContext(EmptyCoroutineContext, oldTrace)
                    thread1ResumptionPoint.await(3, TimeUnit.SECONDS)
                    assertNull(traceThreadLocal.get())
                }
            } catch (e: Throwable) {
                failureOnThread2 = e
            }
        }

        thread1.shutdown()
        thread2.shutdown()
        // awaitTermination returns false on timeout. A hung worker must fail the test rather than
        // being silently ignored.
        val thread1Terminated = thread1.awaitTermination(5, TimeUnit.SECONDS)
        val thread2Terminated = thread2.awaitTermination(5, TimeUnit.SECONDS)

        // Report failures captured on the workers first, keeping the original stack trace as the
        // cause. They are more informative than a termination timeout.
        failureOnThread1?.let { cause ->
            throw AssertionError("Failure executing coroutine on thread-#1.", cause)
        }
        failureOnThread2?.let { cause ->
            throw AssertionError("Failure executing coroutine on thread-#2.", cause)
        }
        assertTrue("thread-#1 did not terminate within the timeout.", thread1Terminated)
        assertTrue("thread-#2 did not terminate within the timeout.", thread2Terminated)
    }

    @Test
    fun scopeReentry_withContextFastPath() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val bgThread = newSingleThreadContext("bg-thread #1")
        val job =
            launch("#1", bgThread) {
                expect("#1")
                var i = 0
                while (true) {
                    expect("#1")
                    channel.send(i++)
                    expect("#1")
                    // when withContext is passed the same scope, it takes a fast path, dispatching
                    // immediately. This means that in subsequent loops, if we do not handle reentry
                    // correctly in TraceContextElement, the trace may become deeply nested:
                    // "#1", "#1", "#1", ... "#2"
                    withContext(bgThread) {
                        expect("#1")
                        traceCoroutine("#2") {
                            expect("#1", "#2")
                            channel.send(i++)
                            expect("#1", "#2")
                        }
                        expect("#1")
                    }
                }
            }
        repeat(1000) {
            expect()
            traceCoroutine("receive") {
                expect("receive")
                val receivedVal = channel.receive()
                assertEquals(it, receivedVal)
                expect("receive")
            }
            expect()
        }
        job.cancel()
    }

    @Test
    fun traceContextIsCopied() = runTest {
        expect()
        val traceContext = TraceContextElement()
        expect()
        withContext(traceContext) {
            // Not the same object because it should be copied into the current context
            assertNotSame(traceThreadLocal.get(), traceContext.traceData)
            assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
            expect()
            traceCoroutine("hello") {
                assertNotSame(traceThreadLocal.get(), traceContext.traceData)
                assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
                assertArrayEquals(arrayOf("hello"), traceThreadLocal.get()?.slices?.toArray())
            }
            assertNotSame(traceThreadLocal.get(), traceContext.traceData)
            assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
            expect()
        }
        expect()
    }

    @Test
    fun tracingDisabled() = runTest {
        // NOTE(review): this flips a global flag, and nothing in this class re-enables it.
        // Confirm that the fake `Flags` is reset between tests so later tests still trace.
        Flags.disableCoroutineTracing()
        assertNull(traceThreadLocal.get())
        withContext(createCoroutineTracingContext()) {
            assertNull(traceThreadLocal.get())
            traceCoroutine("hello") { // should not crash
                assertNull(traceThreadLocal.get())
            }
        }
    }

    /** Asserts only the trace sections open on the current thread; records no ordering event. */
    private fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [expectedEvent] is not null.
     *
     * @throws IllegalStateException if [expectedEvent] is not the next event in sequence, or if
     *   [finish] was already called.
     */
    private fun expect(expectedEvent: Int? = null, vararg expectedOpenTraceSections: String) {
        if (expectedEvent != null) {
            val previousEvent = eventCounter.getAndIncrement()
            val currentEvent = previousEvent + 1
            check(expectedEvent == currentEvent) {
                if (previousEvent == FINAL_EVENT) {
                    "Expected event=$expectedEvent, but finish() was already called"
                } else {
                    "Expected event=$expectedEvent," +
                        " but the event counter is currently at $currentEvent"
                }
            }
        }
        assertOpenTraceSections(expectedOpenTraceSections)
    }

    /**
     * Same as [expect], except that no more [expect] statements can be called after it.
     *
     * @throws IllegalStateException if [expectedEvent] is not the next event in sequence, or if
     *   [finish] was already called.
     */
    private fun finish(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        val currentEvent = previousEvent + 1
        check(expectedEvent == currentEvent) {
            if (previousEvent == FINAL_EVENT) {
                "finish() was called more than once"
            } else {
                "Finished with event=$expectedEvent," +
                    " but the event counter is currently $currentEvent"
            }
        }
        assertOpenTraceSections(expectedOpenTraceSections)
    }

    /**
     * Inspects the trace output recorded by the fake that records android.os.Trace API calls, and
     * asserts that exactly [expected] are open on the current thread, in order.
     */
    private fun assertOpenTraceSections(expected: Array<out String>) {
        assertArrayEquals(expected, getOpenTraceSectionsOnCurrentThread())
    }

    companion object {
        /** Sentinel stored in the event counter once [finish] has been called. */
        const val FINAL_EVENT = Int.MIN_VALUE
    }
}
595 
/**
 * Runs [testBody] with [runTest], adding a fresh [TraceContextElement] (backed by a new
 * [TraceData]) to the test's coroutine context.
 *
 * This helper exists for formatting. Writing the context inline, as in
 * `fun testStuff() = runTest(TraceContextElement(TraceData())) { ... }`, would need extra
 * indentation under the style guide. `fun testStuff() = runTestWithTraceContext { ... }` does not.
 */
private fun runTestWithTraceContext(testBody: suspend TestScope.() -> Unit) =
    runTest(context = TraceContextElement(TraceData()), testBody = testBody)
603