1 /*
<lambda>null2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.app.tracing.coroutines
18
19 import com.android.app.tracing.FakeTraceState.getOpenTraceSectionsOnCurrentThread
20 import com.android.systemui.Flags
21 import java.util.concurrent.CyclicBarrier
22 import java.util.concurrent.Executors
23 import java.util.concurrent.TimeUnit
24 import java.util.concurrent.atomic.AtomicInteger
25 import kotlin.coroutines.CoroutineContext
26 import kotlin.coroutines.EmptyCoroutineContext
27 import kotlinx.coroutines.CoroutineScope
28 import kotlinx.coroutines.CoroutineStart
29 import kotlinx.coroutines.channels.Channel
30 import kotlinx.coroutines.delay
31 import kotlinx.coroutines.launch
32 import kotlinx.coroutines.newSingleThreadContext
33 import kotlinx.coroutines.test.TestScope
34 import kotlinx.coroutines.test.UnconfinedTestDispatcher
35 import kotlinx.coroutines.test.runTest
36 import kotlinx.coroutines.withContext
37 import org.junit.After
38 import org.junit.Assert.assertArrayEquals
39 import org.junit.Assert.assertEquals
40 import org.junit.Assert.assertNotNull
41 import org.junit.Assert.assertNotSame
42 import org.junit.Assert.assertNull
43 import org.junit.Assert.assertSame
44 import org.junit.Assert.assertTrue
45 import org.junit.Before
46 import org.junit.Test
47 import org.junit.runner.RunWith
48 import org.junit.runners.BlockJUnit4ClassRunner
49
/**
 * Tests for propagating trace sections across coroutines with [TraceContextElement],
 * [traceCoroutine] and the span-naming `launch` overloads.
 *
 * Each test inspects the trace sections that are open on the current thread, as recorded by the
 * fake used for `android.os.Trace` API calls, via [expect] and [finish]. When an event number is
 * passed to [expect]/[finish], the order in which those checkpoints are reached is verified too.
 */
@RunWith(BlockJUnit4ClassRunner::class)
class CoroutineTracingTest {
    @Before
    fun setup() {
        TraceData.strictModeForTesting = true
    }

    /**
     * Verifies that a test which recorded ordered events also reached its [finish] call. Tests
     * that never pass an event number leave the counter at 0 and are not checked.
     */
    @After
    fun checkFinished() {
        val lastEvent = eventCounter.get()
        assertTrue(
            "Expected `finish(${lastEvent + 1})` to be called, but the test finished",
            lastEvent == FINAL_EVENT || lastEvent == 0,
        )
    }

    @Test
    fun simpleTraceSection() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") { expect(2, "hello") }
        finish(3)
    }

    @Test
    fun simpleNestedTraceSection() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") {
            expect(2, "hello")
            traceCoroutine("world") { expect(3, "hello", "world") }
            expect(4, "hello")
        }
        finish(5)
    }

    @Test
    fun simpleLaunch() = runTestWithTraceContext {
        expect(1)
        traceCoroutine("hello") {
            expect(2, "hello")
            launch { finish(4, "hello") }
        }
        expect(3)
    }

    @Test
    fun launchWithSuspendingLambda() = runTestWithTraceContext {
        val fetchData: suspend () -> String = {
            expect(3, "span-for-launch")
            delay(1L)
            traceCoroutine("span-for-fetchData") {
                expect(4, "span-for-launch", "span-for-fetchData")
            }
            "stuff"
        }
        expect(1)
        launch("span-for-launch") {
            assertEquals("stuff", fetchData())
            finish(5, "span-for-launch")
        }
        expect(2)
    }

    @Test
    fun nestedUpdateAndRestoreOnSingleThread_unconfinedDispatcher() = runTestWithTraceContext {
        traceCoroutine("parent-span") {
            expect(1, "parent-span")
            launch(UnconfinedTestDispatcher(scheduler = testScheduler)) {
                // While this may appear unusual, it is actually expected behavior:
                //   1) The parent has an open trace section called "parent-span".
                //   2) The child launches, it inherits from its parent, and it is resumed
                //      immediately due to its use of the unconfined dispatcher.
                //   3) The child emits all the trace sections known to its scope. The parent
                //      does not have an opportunity to restore its context yet.
                traceCoroutine("child-span") {
                    // [parent's active trace]
                    //           \  [trace section inherited from parent]
                    //            \          |   [new trace section in child scope]
                    //             \         |    /
                    expect(2, "parent-span", "parent-span", "child-span")
                    delay(1) // <-- delay will give parent a chance to restore its context
                    // After a delay, the parent resumes, finishing its trace section, so we are
                    // left with only those in the child's scope
                    finish(4, "parent-span", "child-span")
                }
            }
        }
        expect(3)
    }

    /** @see nestedUpdateAndRestoreOnSingleThread_unconfinedDispatcher */
    @Test
    fun nestedUpdateAndRestoreOnSingleThread_undispatchedLaunch() = runTestWithTraceContext {
        traceCoroutine("parent-span") {
            launch(start = CoroutineStart.UNDISPATCHED) {
                traceCoroutine("child-span") {
                    expect(1, "parent-span", "parent-span", "child-span")
                    delay(1) // <-- delay will give parent a chance to restore its context
                    finish(3, "parent-span", "child-span")
                }
            }
        }
        expect(2)
    }

    @Test
    fun launchOnSeparateThread_defaultDispatcher() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val bgThread = newSingleThreadContext("thread-#1")
        expect()
        traceCoroutine("hello") {
            expect(1, "hello")
            launch(bgThread) {
                expect(2, "hello")
                traceCoroutine("world") {
                    expect("hello", "world")
                    channel.send(1)
                    expect(3, "hello", "world")
                }
            }
            expect("hello")
        }
        expect()
        assertEquals(1, channel.receive())
        finish(4)
    }

    @Test
    fun testTraceStorage() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val fetchData: suspend () -> String = {
            traceCoroutine("span-for-fetchData") {
                channel.receive()
                expect("span-for-launch", "span-for-fetchData")
            }
            "stuff"
        }
        val threadContexts =
            listOf(
                newSingleThreadContext("thread-#1"),
                newSingleThreadContext("thread-#2"),
                newSingleThreadContext("thread-#3"),
                newSingleThreadContext("thread-#4"),
            )

        val finishedLaunches = Channel<Int>()

        // Start 1000 coroutines waiting on [channel]
        val job = launch {
            repeat(1000) {
                launch("span-for-launch", threadContexts[it % threadContexts.size]) {
                    assertNotNull(traceThreadLocal.get())
                    assertEquals("stuff", fetchData())
                    expect("span-for-launch")
                    assertNotNull(traceThreadLocal.get())
                    expect("span-for-launch")
                    finishedLaunches.send(it)
                }
                expect()
            }
        }
        // Resume half the coroutines that are waiting on this channel...
        repeat(500) { channel.send(1) }
        // ...wait until each of those 500 launches has run to completion...
        repeat(500) { finishedLaunches.receive() }
        // ...and cancel the rest
        job.cancel()
    }

    /**
     * Launches nested traced coroutines on two single-thread contexts and checks the trace
     * sections seen at each step.
     *
     * @param thread1Context extra context added to the "thread-#1" dispatcher.
     * @param thread2Context extra context added to the "thread-#2" dispatcher.
     */
    private fun CoroutineScope.testTraceSectionsMultiThreaded(
        thread1Context: CoroutineContext,
        thread2Context: CoroutineContext,
    ) {
        val fetchData1: suspend () -> String = {
            expect("span-for-launch-1")
            delay(1L)
            traceCoroutine("span-for-fetchData-1") {
                expect("span-for-launch-1", "span-for-fetchData-1")
            }
            expect("span-for-launch-1")
            "stuff-1"
        }

        val fetchData2: suspend () -> String = {
            expect(
                "span-for-launch-1",
                "span-for-launch-2",
            )
            delay(1L)
            traceCoroutine("span-for-fetchData-2") {
                expect("span-for-launch-1", "span-for-launch-2", "span-for-fetchData-2")
            }
            expect(
                "span-for-launch-1",
                "span-for-launch-2",
            )
            "stuff-2"
        }

        val thread1 = newSingleThreadContext("thread-#1") + thread1Context
        val thread2 = newSingleThreadContext("thread-#2") + thread2Context

        launch("span-for-launch-1", thread1) {
            assertEquals("stuff-1", fetchData1())
            expect("span-for-launch-1")
            launch("span-for-launch-2", thread2) {
                assertEquals("stuff-2", fetchData2())
                expect("span-for-launch-1", "span-for-launch-2")
            }
            expect("span-for-launch-1")
        }
        expect()

        // Launching without the trace extension won't result in traces
        launch(thread1) { expect() }
        launch(thread2) { expect() }
    }

    @Test
    fun nestedTraceSectionsMultiThreaded1() = runTestWithTraceContext {
        // Thread-#1 and Thread-#2 inherit TraceContextElement from the test's CoroutineContext.
        testTraceSectionsMultiThreaded(
            thread1Context = EmptyCoroutineContext,
            thread2Context = EmptyCoroutineContext,
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded2() = runTest {
        // Thread-#2 inherits the TraceContextElement from Thread-#1. The test's CoroutineContext
        // does not need a TraceContextElement because it does not do any tracing.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = EmptyCoroutineContext,
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded3() = runTest {
        // Thread-#2 overrides the TraceContextElement from Thread-#1, but the merging context
        // should be fine; it is essentially a no-op. The test's CoroutineContext does not need the
        // trace context because it does not do any tracing.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = TraceContextElement(TraceData()),
        )
    }

    @Test
    fun nestedTraceSectionsMultiThreaded4() = runTestWithTraceContext {
        // TraceContextElement is merged on each context switch, which should have no effect on the
        // trace results.
        testTraceSectionsMultiThreaded(
            thread1Context = TraceContextElement(TraceData()),
            thread2Context = TraceContextElement(TraceData()),
        )
    }

    @Test
    fun missingTraceContextObjects() = runTest {
        val channel = Channel<Int>()
        // Thread-#1 is missing a TraceContextElement, so some of the trace sections get dropped.
        // The resulting trace sections will be different from those in the 4 tests above.
        val fetchData1: suspend () -> String = {
            expect()
            channel.receive()
            traceCoroutine("span-for-fetchData-1") { expect() }
            expect()
            "stuff-1"
        }

        val fetchData2: suspend () -> String = {
            expect(
                "span-for-launch-2",
            )
            channel.receive()
            traceCoroutine("span-for-fetchData-2") {
                expect("span-for-launch-2", "span-for-fetchData-2")
            }
            expect(
                "span-for-launch-2",
            )
            "stuff-2"
        }

        val thread1 = newSingleThreadContext("thread-#1")
        val thread2 = newSingleThreadContext("thread-#2") + TraceContextElement(TraceData())

        launch("span-for-launch-1", thread1) {
            assertEquals("stuff-1", fetchData1())
            expect()
            launch("span-for-launch-2", thread2) {
                assertEquals("stuff-2", fetchData2())
                expect("span-for-launch-2")
            }
            expect()
        }
        expect()

        channel.send(1)
        channel.send(2)

        // Launching without the trace extension won't result in traces
        launch(thread1) { expect() }
        launch(thread2) { expect() }
    }

    /**
     * Tests interleaving:
     * ```
     * Thread #1 | [updateThreadContext]....^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]...........^[restoreThreadContext]
     * ```
     *
     * This test checks for issues with concurrent modification of the trace state. For example, the
     * test should fail if [TraceData.endAllOnThread] uses the size of the slices array as follows
     * instead of using the ThreadLocal count:
     * ```
     * class TraceData {
     *   ...
     *   fun endAllOnThread() {
     *     repeat(slices.size) {
     *       // THIS WOULD BE AN ERROR. If the thread is slow, the TraceData object could have been
     *       // modified by another thread
     *       endSlice()
     *     }
     *   ...
     *   }
     * }
     * ```
     */
    @Test
    fun coroutineMachinery() {
        assertNull(traceThreadLocal.get())
        val traceContext = TraceContextElement()
        assertNull(traceThreadLocal.get())

        val thread1ResumptionPoint = CyclicBarrier(2)
        val thread1SuspensionPoint = CyclicBarrier(2)

        val thread1 = Executors.newSingleThreadExecutor()
        val thread2 = Executors.newSingleThreadExecutor()
        val slicesForThread1 = listOf("a", "c", "e", "g")
        val slicesForThread2 = listOf("b", "d", "f", "h")
        // Throwable rather than Error: CyclicBarrier.await(timeout) throws TimeoutException or
        // BrokenBarrierException. If those escaped the Runnable, the executor thread would die
        // silently and the assertions at the end would see no failure.
        var failureOnThread1: Throwable? = null
        var failureOnThread2: Throwable? = null

        val expectedTraceForThread1 = arrayOf("1:a", "2:b", "1:c", "2:d", "1:e", "2:f", "1:g")
        thread1.execute {
            try {
                slicesForThread1.forEachIndexed { index, sliceName ->
                    assertNull(traceThreadLocal.get())
                    val oldTrace = traceContext.updateThreadContext(EmptyCoroutineContext)
                    // await() AFTER updateThreadContext, thus thread #1 always resumes the
                    // coroutine before thread #2
                    assertSame(traceThreadLocal.get(), traceContext.traceData)

                    // coroutine body start {
                    traceThreadLocal.get()?.beginSpan("1:$sliceName")

                    // At the end, verify the interleaved trace sections look correct:
                    if (index == slicesForThread1.size - 1) {
                        expect(*expectedTraceForThread1)
                    }

                    // Simulate a slow thread: wait to call restoreThreadContext until after
                    // thread #2 has resumed.
                    thread1SuspensionPoint.await(3, TimeUnit.SECONDS)
                    Thread.sleep(500)
                    // } coroutine body end

                    traceContext.restoreThreadContext(EmptyCoroutineContext, oldTrace)
                    thread1ResumptionPoint.await(3, TimeUnit.SECONDS)
                    assertNull(traceThreadLocal.get())
                }
            } catch (e: Throwable) {
                failureOnThread1 = e
            }
        }

        val expectedTraceForThread2 =
            arrayOf("1:a", "2:b", "1:c", "2:d", "1:e", "2:f", "1:g", "2:h")
        thread2.execute {
            try {
                slicesForThread2.forEachIndexed { i, n ->
                    assertNull(traceThreadLocal.get())
                    thread1SuspensionPoint.await(3, TimeUnit.SECONDS)

                    val oldTrace: TraceData? =
                        traceContext.updateThreadContext(EmptyCoroutineContext)

                    // coroutine body start {
                    traceThreadLocal.get()?.beginSpan("2:$n")

                    // At the end, verify the interleaved trace sections look correct:
                    if (i == slicesForThread2.size - 1) {
                        expect(*expectedTraceForThread2)
                    }
                    // } coroutine body end

                    traceContext.restoreThreadContext(EmptyCoroutineContext, oldTrace)
                    thread1ResumptionPoint.await(3, TimeUnit.SECONDS)
                    assertNull(traceThreadLocal.get())
                }
            } catch (e: Throwable) {
                failureOnThread2 = e
            }
        }

        thread1.shutdown()
        val thread1Terminated = thread1.awaitTermination(5, TimeUnit.SECONDS)
        thread2.shutdown()
        val thread2Terminated = thread2.awaitTermination(5, TimeUnit.SECONDS)

        assertNull("Failure executing coroutine on thread-#1.", failureOnThread1)
        assertNull("Failure executing coroutine on thread-#2.", failureOnThread2)
        // A hung worker would leave its failure field null. Without these checks, the test would
        // pass without the interleaving having run to completion.
        assertTrue("thread-#1 did not finish within the timeout.", thread1Terminated)
        assertTrue("thread-#2 did not finish within the timeout.", thread2Terminated)
    }

    @Test
    fun scopeReentry_withContextFastPath() = runTestWithTraceContext {
        val channel = Channel<Int>()
        val bgThread = newSingleThreadContext("bg-thread #1")
        val job =
            launch("#1", bgThread) {
                expect("#1")
                var i = 0
                while (true) {
                    expect("#1")
                    channel.send(i++)
                    expect("#1")
                    // when withContext is passed the same scope, it takes a fast path, dispatching
                    // immediately. This means that in subsequent loops, if we do not handle reentry
                    // correctly in TraceContextElement, the trace may become deeply nested:
                    // "#1", "#1", "#1", ... "#2"
                    withContext(bgThread) {
                        expect("#1")
                        traceCoroutine("#2") {
                            expect("#1", "#2")
                            channel.send(i++)
                            expect("#1", "#2")
                        }
                        expect("#1")
                    }
                }
            }
        repeat(1000) {
            expect()
            traceCoroutine("receive") {
                expect("receive")
                val receivedVal = channel.receive()
                assertEquals(it, receivedVal)
                expect("receive")
            }
            expect()
        }
        job.cancel()
    }

    @Test
    fun traceContextIsCopied() = runTest {
        expect()
        val traceContext = TraceContextElement()
        expect()
        withContext(traceContext) {
            // Not the same object because it should be copied into the current context
            assertNotSame(traceThreadLocal.get(), traceContext.traceData)
            assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
            expect()
            traceCoroutine("hello") {
                assertNotSame(traceThreadLocal.get(), traceContext.traceData)
                assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
                assertArrayEquals(arrayOf("hello"), traceThreadLocal.get()?.slices?.toArray())
            }
            assertNotSame(traceThreadLocal.get(), traceContext.traceData)
            assertNotSame(traceThreadLocal.get()?.slices, traceContext.traceData?.slices)
            expect()
        }
        expect()
    }

    @Test
    fun tracingDisabled() = runTest {
        // NOTE(review): this changes Flags and nothing in this class restores it afterwards.
        // Confirm that the Flags fake is reset between tests so other tests are not affected.
        Flags.disableCoroutineTracing()
        assertNull(traceThreadLocal.get())
        withContext(createCoroutineTracingContext()) {
            assertNull(traceThreadLocal.get())
            traceCoroutine("hello") { // should not crash
                assertNull(traceThreadLocal.get())
            }
        }
    }

    /** Checks the trace sections currently open on this thread, without recording an event. */
    private fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [expectedEvent] is not null.
     *
     * @throws IllegalStateException if [expectedEvent] is not the next event number, or if
     *   [finish] was already called.
     */
    private fun expect(expectedEvent: Int? = null, vararg expectedOpenTraceSections: String) {
        if (expectedEvent != null) {
            val previousEvent = eventCounter.getAndIncrement()
            val currentEvent = previousEvent + 1
            check(expectedEvent == currentEvent) {
                if (previousEvent == FINAL_EVENT) {
                    "Expected event=$expectedEvent, but finish() was already called"
                } else {
                    "Expected event=$expectedEvent," +
                        " but the event counter is currently at $currentEvent"
                }
            }
        }

        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        assertArrayEquals(expectedOpenTraceSections, getOpenTraceSectionsOnCurrentThread())
    }

    /**
     * Same as [expect], except that no more [expect] statements can be called after it. The event
     * counter is set to [FINAL_EVENT], which [checkFinished] then accepts.
     */
    private fun finish(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        val currentEvent = previousEvent + 1
        check(expectedEvent == currentEvent) {
            if (previousEvent == FINAL_EVENT) {
                "finish() was called more than once"
            } else {
                "Finished with event=$expectedEvent," +
                    " but the event counter is currently $currentEvent"
            }
        }

        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        assertArrayEquals(expectedOpenTraceSections, getOpenTraceSectionsOnCurrentThread())
    }

    /**
     * Last event number recorded by [expect]/[finish]. It is atomic because checkpoints are hit
     * from several threads.
     */
    private val eventCounter = AtomicInteger(0)

    companion object {
        /** Sentinel stored in [eventCounter] once [finish] has been called. */
        const val FINAL_EVENT = Int.MIN_VALUE
    }
}
595
/**
 * Runs [testBody] through [runTest] with a fresh [TraceContextElement], backed by a new
 * [TraceData], installed in the test's coroutine context.
 *
 * This wrapper exists for formatting. Test methods can be written as
 * `fun testStuff() = runTestWithTraceContext {}` instead of passing the context to `runTest {}`
 * directly, which would need extra indentation under our style guide.
 */
private fun runTestWithTraceContext(testBody: suspend TestScope.() -> Unit) =
    TraceData().let { freshTraceData ->
        runTest(context = TraceContextElement(freshTraceData), testBody = testBody)
    }
603