1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.app.tracing.coroutines
18 
19 import com.android.app.tracing.beginSlice
20 import com.android.app.tracing.endSlice
21 import java.util.ArrayDeque
22 
/**
 * Represents a section of code executing in a coroutine. This may be split up into multiple slices
 * on different threads as the coroutine is suspended and resumed.
 *
 * The value is the section's name: it is what gets passed to `beginSlice` each time the section is
 * (re)started on a thread.
 *
 * @see traceCoroutine
 */
typealias TraceSection = String
30 
/**
 * A [ThreadLocal] counter that reads as `0` on any thread where no value has been set yet (or
 * where the value has been removed).
 */
@PublishedApi
internal class TraceCountThreadLocal : ThreadLocal<Int>() {
    override fun initialValue(): Int = 0
}
37 
/**
 * Used for storing trace sections so that they can be added and removed from the currently running
 * thread when the coroutine is suspended and resumed.
 *
 * @property slices stack of open trace section names. New sections are pushed onto the head of the
 *   deque, so the head is the most recently begun (innermost) section and the tail is the oldest.
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceData(
    internal val slices: ArrayDeque<TraceSection> = ArrayDeque(),
) : Cloneable {

    /**
     * ThreadLocal counter for how many open trace sections there are. This is needed because it is
     * possible that on a multi-threaded dispatcher, one of the threads could be slow, and
     * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
     * modified TraceData - either adding or removing trace sections and changing the count. If we
     * did not store this thread-locally, then we would incorrectly end too many or too few trace
     * sections.
     */
    private val openSliceCount = TraceCountThreadLocal()

    /**
     * Adds current trace slices back to the current thread. Called when coroutine is resumed.
     *
     * Slices are begun from the oldest (outermost) to the newest (innermost), so they nest on this
     * thread in the same order in which they were originally opened.
     *
     * @throws ConcurrentModificationException in strict mode, if this is not the thread's current
     *   TraceData.
     */
    internal fun beginAllOnThread() {
        strictModeCheck()
        // `push` adds to the head, so iterating tail-to-head visits the oldest section first.
        slices.descendingIterator().forEach { beginSlice(it) }
        openSliceCount.set(slices.size)
    }

    /**
     * Removes all current trace slices from the current thread. Called when coroutine is suspended.
     *
     * Only as many slices as this thread recorded as open are ended. [slices] itself is not
     * modified, so the same sections can be re-begun by [beginAllOnThread] on resume.
     *
     * @throws ConcurrentModificationException in strict mode, if this is not the thread's current
     *   TraceData.
     */
    internal fun endAllOnThread() {
        strictModeCheck()
        repeat(openSliceCount.get()) { endSlice() }
        // remove() instead of set(0): a later get() still yields initialValue() == 0, but the
        // thread's ThreadLocalMap no longer keeps an entry alive for this (suspended) TraceData.
        openSliceCount.remove()
    }

    /**
     * Begins a new trace section named [name] and pushes it onto [slices]. The slice is also begun
     * on the current thread immediately. This slice will not propagate to parent coroutines, or to
     * child coroutines that have already started.
     *
     * @throws ConcurrentModificationException in strict mode, if this is not the thread's current
     *   TraceData.
     */
    @PublishedApi
    internal fun beginSpan(name: String) {
        strictModeCheck()
        slices.push(name)
        openSliceCount.set(slices.size)
        beginSlice(name)
    }

    /**
     * Ends the most recently begun trace section (see [beginSpan]) and pops it from [slices]. The
     * trace slice is also ended on the current thread immediately. This information will not
     * propagate to parent coroutines, or to child coroutines that have already started.
     *
     * Only the stack depth is checked; the popped section's name is not compared against anything.
     *
     * @throws IllegalStateException if there are no open sections and [strictModeForTesting] is
     *   enabled. Outside of strict mode such an unbalanced call is silently ignored.
     * @throws ConcurrentModificationException in strict mode, if this is not the thread's current
     *   TraceData.
     */
    @PublishedApi
    internal fun endSpan() {
        strictModeCheck()
        if (slices.isNotEmpty()) {
            slices.pop()
            openSliceCount.set(slices.size)
            endSlice()
        } else if (strictModeForTesting) {
            // Should never happen. Only fail loudly under test, rather than crash the whole
            // application in production.
            error(INVALID_SPAN_END_CALL_ERROR_MESSAGE)
        }
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     *
     * The copy gets its own [slices] deque and a fresh per-thread open-slice counter, so it starts
     * with a count of 0 on every thread.
     */
    public override fun clone(): TraceData = TraceData(slices.clone())

    // NOTE(review): Int.toHexString() belongs to Kotlin's HexFormat API, which may require opting in
    // to ExperimentalStdlibApi depending on the Kotlin version in use — confirm the build opts in.
    override fun toString(): String {
        return "TraceData@${hashCode().toHexString()}-size=${slices.size}"
    }

    /**
     * In strict mode, verifies that this instance is the one currently held by `traceThreadLocal`
     * on the calling thread.
     *
     * @throws ConcurrentModificationException if strict mode is on and the check fails.
     */
    private fun strictModeCheck() {
        if (strictModeForTesting && traceThreadLocal.get() !== this) {
            throw ConcurrentModificationException(STRICT_MODE_ERROR_MESSAGE)
        }
    }

    companion object {
        /**
         * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
         */
        internal var strictModeForTesting: Boolean = false
    }
}
134 
/**
 * Message used in strict mode when `TraceData.endSpan` is called while no trace sections are open.
 */
private const val INVALID_SPAN_END_CALL_ERROR_MESSAGE =
    "TraceData#endSpan called when there were no active trace sections."

/**
 * Message used in strict mode when a TraceData method is invoked on an instance that is not the
 * one currently held by `traceThreadLocal` on the calling thread. It names `traceThreadLocal`,
 * which is the ThreadLocal that the check actually compares against.
 */
private const val STRICT_MODE_ERROR_MESSAGE =
    "TraceData should only be accessed using " +
        "the ThreadLocal: traceThreadLocal.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."
142