1 /*
 * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.protolog.tool
18 
19 import com.android.internal.protolog.common.LogDataType
20 import com.android.internal.protolog.common.LogLevel
21 import com.github.javaparser.StaticJavaParser
22 import com.github.javaparser.ast.CompilationUnit
23 import com.github.javaparser.ast.NodeList
24 import com.github.javaparser.ast.body.VariableDeclarator
25 import com.github.javaparser.ast.expr.ArrayAccessExpr
26 import com.github.javaparser.ast.expr.CastExpr
27 import com.github.javaparser.ast.expr.Expression
28 import com.github.javaparser.ast.expr.FieldAccessExpr
29 import com.github.javaparser.ast.expr.IntegerLiteralExpr
30 import com.github.javaparser.ast.expr.LongLiteralExpr
31 import com.github.javaparser.ast.expr.MethodCallExpr
32 import com.github.javaparser.ast.expr.NameExpr
33 import com.github.javaparser.ast.expr.NullLiteralExpr
34 import com.github.javaparser.ast.expr.SimpleName
35 import com.github.javaparser.ast.expr.TypeExpr
36 import com.github.javaparser.ast.expr.VariableDeclarationExpr
37 import com.github.javaparser.ast.stmt.BlockStmt
38 import com.github.javaparser.ast.stmt.ExpressionStmt
39 import com.github.javaparser.ast.stmt.IfStmt
40 import com.github.javaparser.ast.stmt.Statement
41 import com.github.javaparser.ast.type.ArrayType
42 import com.github.javaparser.ast.type.ClassOrInterfaceType
43 import com.github.javaparser.ast.type.PrimitiveType
44 import com.github.javaparser.ast.type.Type
45 import com.github.javaparser.printer.PrettyPrinter
46 import com.github.javaparser.printer.PrettyPrinterConfiguration
47 
/**
 * Rewrites ProtoLog call sites in a Java source file.
 *
 * Each logging call reported by [ProtoLogCallProcessor] is replaced with an `if` statement that
 * checks `<impl>.Cache.<GROUP>_enabled[<level ordinal>]` and, when set, calls the implementation
 * class with the message hash, an argument-type bitmask and the message string (or `null` when
 * the group has text logging disabled). Other calls reported by the processor only have their
 * scope changed to the implementation class.
 *
 * Edits are applied to the source text rather than to the AST (see
 * [injectProcessedCallStatementInCode]), and every replacement spans as many lines as the code
 * it replaces, so line numbers in the rewritten file match the original.
 *
 * Instances keep per-file state in mutable fields while [processClass] runs, so one instance
 * must not process several files concurrently.
 *
 * @param protoLogImplClassName qualified name of the ProtoLog implementation class; it is parsed
 *   as a Java field-access expression.
 * @param protoLogCallProcessor finds the calls to rewrite and reports them to this class.
 */
class SourceTransformer(
    protoLogImplClassName: String,
    private val protoLogCallProcessor: ProtoLogCallProcessor
) {
    /** Prints an AST node on a single line: line breaks become spaces and nothing is indented. */
    private val inlinePrinter: PrettyPrinter = PrettyPrinter(
        PrettyPrinterConfiguration().apply {
            endOfLineCharacter = " "
            indentSize = 0
            tabWidth = 1
        }
    )

    /** Template for the `Object` component type of `(Object[]) null`; cloned before each use. */
    private val objectType = StaticJavaParser.parseClassOrInterfaceType("Object")

    /** Template scope for rewritten calls; cloned before each use so no node has two parents. */
    private val protoLogImplClassNode =
        StaticJavaParser.parseExpression<FieldAccessExpr>(protoLogImplClassName)
    private val protoLogImplCacheClassNode =
        StaticJavaParser.parseExpression<FieldAccessExpr>("$protoLogImplClassName.Cache")

    /** Lines of the file being processed; rewritten in place as calls are replaced. */
    private var processedCode: MutableList<String> = mutableListOf()

    /**
     * For each line of [processedCode], its current length minus its original length. AST
     * columns refer to the original text, so this is added to map them onto the rewritten text.
     */
    private var offsets: IntArray = IntArray(0)

    /** The path of the file being processed, relative to $ANDROID_BUILD_TOP */
    private var path: String = ""

    /** The path of the file being processed, relative to the root package */
    private var packagePath: String = ""

    /**
     * Rewrites all calls reported by [protoLogCallProcessor] in [code] and returns the result.
     *
     * @param code the Java source to transform.
     * @param path path of the file relative to $ANDROID_BUILD_TOP; passed to the processor and
     *   used in error reports.
     * @param packagePath path of the file relative to the root package; included in the
     *   message hash.
     * @param compilationUnit the parsed [code]; parsed from [code] when not supplied.
     * @return the transformed source, with the same number of lines as [code].
     */
    fun processClass(
        code: String,
        path: String,
        packagePath: String,
        compilationUnit: CompilationUnit =
            StaticJavaParser.parse(code)
    ): String {
        this.path = path
        this.packagePath = packagePath
        processedCode = code.split('\n').toMutableList()
        offsets = IntArray(processedCode.size)
        protoLogCallProcessor.process(compilationUnit, protoLogCallVisitor, otherCallVisitor, path)
        return processedCode.joinToString("\n")
    }

    /** Replaces each ProtoLog logging call with its guarded implementation call. */
    private val protoLogCallVisitor = object : ProtoLogCallVisitor {
        override fun processCall(
            call: MethodCallExpr,
            messageString: String,
            level: LogLevel,
            group: LogGroup
        ) {
            validateCall(call)
            val processedCallStatement =
                createProcessedCallStatement(call, group, level, messageString)
            val parentStmt = call.parentNode.get() as ExpressionStmt
            injectProcessedCallStatementInCode(processedCallStatement, parentStmt)
        }
    }

    /**
     * Checks that [call] is the expression of an [ExpressionStmt] which itself has a parent.
     *
     * Input format: `ProtoLog.e(GROUP, "msg %d", arg)`
     *
     * @throws RuntimeException if the AST around [call] does not have that shape.
     */
    private fun validateCall(call: MethodCallExpr) {
        if (!call.parentNode.isPresent) {
            // Should never happen
            throw RuntimeException("Unable to process log call $call " +
                    "- no parent node in AST")
        }
        val parentStmt = call.parentNode.get() as? ExpressionStmt
            // Should never happen
            ?: throw RuntimeException("Unable to process log call $call " +
                    "- parent node in AST is not an ExpressionStmt")
        if (!parentStmt.parentNode.isPresent) {
            // Should never happen
            throw RuntimeException("Unable to process log call $call " +
                    "- no grandparent node in AST")
        }
    }

    /**
     * Builds the statement that replaces a ProtoLog logging call. For
     * `ProtoLog.e(GROUP, "msg %d", arg)` the result has the shape
     * ```
     * if (<impl>.Cache.GROUP_enabled[<level ordinal>]) {
     *     <type> protoLogParam0 = arg;
     *     <impl>.e(GROUP, <hash>L, <type mask>, "msg %d" or null, protoLogParam0);
     * }
     * ```
     *
     * @throws InvalidProtoLogCallException if the number of format arguments does not match the
     *   number of format specifiers in [messageString].
     */
    private fun createProcessedCallStatement(
        call: MethodCallExpr,
        group: LogGroup,
        level: LogLevel,
        messageString: String
    ): Statement {
        val hash = CodeUtils.hash(packagePath, messageString, level, group)
        val argTypes = LogDataType.parseFormatString(messageString)
        // Everything after the group and the message string is a format argument. Validate
        // before touching the arguments so malformed calls get a meaningful error.
        val formatArgCount = call.arguments.size - 2
        if (argTypes.size != formatArgCount) {
            throw InvalidProtoLogCallException(
                "Number of arguments ($formatArgCount) does not match format string " +
                        "(expected ${argTypes.size}) in: $call", ParsingContext(path, call))
        }
        val typeMask = LogDataType.logDataTypesToBitMask(argTypes)

        val newCall = call.clone()
        if (!group.textEnabled) {
            // Remove message string if text logging is not enabled by default.
            // Out: ProtoLog.e(GROUP, null, arg)
            newCall.arguments[1].replace(NullLiteralExpr())
        }
        // Insert message string hash as the second argument.
        // Out: ProtoLog.e(GROUP, 1234L, null, arg)
        newCall.arguments.add(1, LongLiteralExpr("${hash}L"))
        // Insert bitmask representing which Number parameters are to be considered as
        // floating point numbers.
        // Out: ProtoLog.e(GROUP, 1234L, 0, null, arg)
        newCall.arguments.add(2, IntegerLiteralExpr(typeMask))
        // Replace the call to the stub method with a call to the actual implementation.
        // Out: ProtoLogImpl.e(GROUP, 1234L, 0, null, arg)
        newCall.setScope(protoLogImplClassNode.clone())

        val blockStmt = BlockStmt()
        if (argTypes.isEmpty()) {
            // Pass (Object[]) null as the vararg parameter to avoid allocating an empty
            // object array.
            newCall.addArgument(CastExpr(ArrayType(objectType.clone()), NullLiteralExpr()))
        } else {
            // Assign every argument to a typed local so its type is checked at compile time
            // (the assignment is optimized out by the dex tool, so there is no runtime impact).
            // Out: long protoLogParam0 = arg
            argTypes.forEachIndexed { idx, type ->
                val argIndex = FIRST_FORMAT_ARG_INDEX + idx
                val varName = "protoLogParam$idx"
                val declaration = VariableDeclarator(getASTTypeForDataType(type), varName,
                    getConversionForType(type)(newCall.arguments[argIndex].clone()))
                blockStmt.addStatement(ExpressionStmt(VariableDeclarationExpr(declaration)))
                newCall.setArgument(argIndex, NameExpr(SimpleName(varName)))
            }
        }
        blockStmt.addStatement(ExpressionStmt(newCall))

        val isLogEnabled = ArrayAccessExpr()
            .setName(NameExpr("$protoLogImplCacheClassNode.${group.name}_enabled"))
            .setIndex(IntegerLiteralExpr(level.ordinal))

        return IfStmt(isLogEnabled, blockStmt, null)
    }

    /**
     * Replaces the source text of [parentStmt] with [processedCallStatement] printed on one
     * line. Line breaks are inserted before its closing brace so the replacement spans as many
     * lines as [parentStmt], which keeps the file's line numbering intact for debugging.
     *
     * This works on the source text instead of replacing the node in the AST as a workaround for
     * a JavaParser bug: the AST is invalid after replacing a node when using
     * LexicalPreservingPrinter (https://github.com/javaparser/javaparser/issues/2290). An
     * AST-based replacement can be used once that issue is resolved.
     */
    private fun injectProcessedCallStatementInCode(
        processedCallStatement: Statement,
        parentStmt: ExpressionStmt
    ) {
        if (!parentStmt.range.isPresent) {
            // Should never happen
            throw RuntimeException("Unable to process log call in $parentStmt " +
                    "- unable to replace the call.")
        }
        val range = parentStmt.range.get()
        val printedStmt = inlinePrinter.print(processedCallStatement)
        val padding = "\n".repeat(range.end.line - range.begin.line)
        val newStmt = printedStmt.substringBeforeLast('}') + padding + '}'
        replaceSourceRange(range, newStmt)
    }

    /**
     * Redirects other calls reported by [protoLogCallProcessor] to the implementation class,
     * e.g. `ProtoLog.foo(x)` becomes `<impl>.foo(x)`.
     */
    private val otherCallVisitor = object : MethodCallVisitor {
        override fun processCall(call: MethodCallExpr) {
            val newCall = call.clone()
            newCall.setScope(protoLogImplClassNode.clone())

            val range = call.range.get()
            // Print on one line and pad with line breaks so that a call spanning several lines
            // is replaced by text spanning exactly the same lines (no stale leftover lines).
            val replacement = inlinePrinter.print(newCall) +
                    "\n".repeat(range.end.line - range.begin.line)
            replaceSourceRange(range, replacement)
        }
    }

    /**
     * Replaces the text covered by [range] in [processedCode] with [replacement] and updates
     * [offsets].
     *
     * [range] refers to the original source: 1-based lines and columns with an inclusive end
     * column. Earlier edits on the same lines are accounted for through [offsets].
     *
     * @throws IllegalStateException if [replacement] would change the number of lines, which
     *   would otherwise corrupt neighbouring lines of the file.
     */
    private fun replaceSourceRange(range: com.github.javaparser.Range, replacement: String) {
        val firstLine = range.begin.line - 1
        val lastLine = range.end.line - 1
        val oldLines = processedCode.subList(firstLine, lastLine + 1)
        val oldCode = oldLines.joinToString("\n")
        // Index in oldCode at which the last line of the range starts.
        val lastLineStart = oldCode.length - oldLines.last().length
        val newCode = oldCode.replaceRange(
            offsets[firstLine] + range.begin.column - 1,
            lastLineStart + offsets[lastLine] + range.end.column,
            replacement
        )
        val newLines = newCode.split("\n")
        check(newLines.size == oldLines.size) {
            "Rewriting lines ${range.begin.line}-${range.end.line} of $path would change the " +
                    "number of lines (${oldLines.size} -> ${newLines.size})"
        }
        newLines.forEachIndexed { idx, line ->
            offsets[firstLine + idx] += line.length - processedCode[firstLine + idx].length
            processedCode[firstLine + idx] = line
        }
    }

    companion object {
        /**
         * Index of the first format argument in a rewritten call:
         * (group, hash, type mask, message, format args...).
         */
        private const val FIRST_FORMAT_ARG_INDEX = 4

        private val stringType: ClassOrInterfaceType =
            StaticJavaParser.parseClassOrInterfaceType("String")

        /** Returns the Java type used to declare the local that holds an argument of [type]. */
        fun getASTTypeForDataType(type: Int): Type = when (type) {
            LogDataType.STRING -> stringType.clone()
            LogDataType.LONG -> PrimitiveType.longType()
            LogDataType.DOUBLE -> PrimitiveType.doubleType()
            LogDataType.BOOLEAN -> PrimitiveType.booleanType()
            // Should never happen.
            else -> throw RuntimeException("Invalid LogDataType")
        }

        /**
         * Returns a function that wraps an argument expression so it can be assigned to a local
         * of [getASTTypeForDataType]: strings are wrapped in `String.valueOf(...)`, every other
         * type is passed through unchanged.
         */
        fun getConversionForType(type: Int): (Expression) -> Expression = when (type) {
            LogDataType.STRING -> { expr ->
                MethodCallExpr(TypeExpr(stringType.clone()), SimpleName("valueOf"), NodeList(expr))
            }
            else -> { expr -> expr }
        }
    }
}
264