1from .common.codegen import CodeGen, VulkanWrapperGenerator
2from .common.vulkantypes import VulkanAPI, iterateVulkanType, VulkanType
3
4from .reservedmarshaling import VulkanReservedMarshalingCodegen
5from .transform import TransformCodegen
6
7from .wrapperdefs import API_PREFIX_RESERVEDUNMARSHAL
8from .wrapperdefs import MAX_PACKET_LENGTH
9from .wrapperdefs import ROOT_TYPE_DEFAULT_VALUE
10
11
# Preamble text for the generated decoder declaration. Currently just a
# newline; not referenced in the visible part of this file.
decoder_decl_preamble = """
"""

# Preamble text for the generated decoder implementation. Currently just a
# newline; not referenced in the visible part of this file.
decoder_impl_preamble = """
"""

# Prefix handed to cgen.vkApiCall (as customPrefix and/or globalStatePrefix)
# for calls that are routed through the decoder's global-state handlers.
global_state_prefix = "this->on_"

# Name of the C++ stream variable the generated code reads packets from; it
# matches the `readStream` parameter of the generated subDecode() function.
READ_STREAM = "readStream"
# Name of a C++ write-stream variable; not referenced in the visible part of
# this file.
WRITE_STREAM = "vkStream"

# Driver workarounds for APIs that don't work well multithreaded.
# emit_dispatch_call() brackets calls to these APIs with lock()/unlock().
driver_workarounds_global_lock_apis = [
    "vkCreatePipelineLayout",
    "vkDestroyPipelineLayout",
]

# Value of the generated `#define MAX_STACK_ITEMS`, which sizes the
# `stack_<param>` arrays declared for pointer parameters with a length.
MAX_STACK_ITEMS = "16"
30
31
def emit_param_decl_for_reading(param, cgen):
    """Emit the C declarations needed before `param` can be read.

    Always declares the parameter itself. For pointer parameters it also
    declares a `stack_<paramName>` array: one element when the parameter has
    no length expression (or the length is the literal "1"), otherwise
    MAX_STACK_ITEMS elements.

    Args:
        param: parameter description (reads staticArrExpr,
            pointerIndirectionLevels, typeName and paramName).
        cgen: code generator receiving the statements.
    """
    # Static arrays are declared from their getForNonConstAccess() variant.
    declared = param.getForNonConstAccess() if param.staticArrExpr else param
    cgen.stmt(cgen.makeRichCTypeDecl(declared))

    if param.pointerIndirectionLevels <= 0:
        return

    lengthExpr = cgen.generalLengthAccess(param) or "1"
    slotCount = "1" if lengthExpr == "1" else "MAX_STACK_ITEMS"

    # `void` cannot be an array element type, so void pointees use uint8_t*.
    elementType = "uint8_t*" if param.typeName == "void" else param.typeName
    extraStars = "*" * (param.pointerIndirectionLevels - 1)
    cgen.stmt("%s%s stack_%s[%s]" % (
        elementType, extraStars, param.paramName, slotCount))
49
50
def emit_unmarshal(typeInfo, param, cgen, output=False, destroy=False, noUnbox=False):
    """Emit code that reads `param` from the read stream.

    Args:
        typeInfo: type database forwarded to iterateVulkanType.
        param: parameter to read (reads paramName and typeName).
        cgen: code generator receiving the emitted code.
        output: when true, the value is read without the "unbox_" prefix.
        destroy: when true, `param` is a handle (or handle array) destroyed
            by the API. The boxed value is read as-is, saved into
            `boxed_<paramName>_preserve` (which the caller must already have
            declared) and then replaced by `unbox_<typeName>(...)`.
        noUnbox: when true, a marker comment is emitted and the value is read
            without the "unbox_" prefix.
    """
    if destroy:
        iterateVulkanType(typeInfo, param, VulkanReservedMarshalingCodegen(
            cgen,
            "host",
            READ_STREAM,
            ROOT_TYPE_DEFAULT_VALUE,
            param.paramName,
            "readStreamPtrPtr",
            API_PREFIX_RESERVEDUNMARSHAL,
            "",
            direction="read",
            dynAlloc=True))
        lenAccess = cgen.generalLengthAccess(param)
        lenAccessGuard = cgen.generalLengthAccessGuard(param)

        if lenAccess is None or lenAccess == "1":
            # Single handle: preserve the boxed value, then unbox in place.
            cgen.stmt("boxed_%s_preserve = %s" %
                      (param.paramName, param.paramName))
            cgen.stmt("%s = unbox_%s(%s)" %
                      (param.paramName, param.typeName, param.paramName))
            return

        # Handle array: preserve and unbox element by element, inside the
        # length guard when one exists. This is a module-level function, so
        # the guard must be emitted through the `cgen` argument (there is no
        # `self` here).
        if lenAccessGuard is not None:
            cgen.beginIf(lenAccessGuard)
        cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
        cgen.stmt("boxed_%s_preserve[i] = %s[i]" %
                  (param.paramName, param.paramName))
        cgen.stmt("((%s*)(%s))[i] = unbox_%s(%s[i])" % (
            param.typeName, param.paramName, param.typeName, param.paramName))
        cgen.endFor()
        if lenAccessGuard is not None:
            cgen.endIf()
        return

    if noUnbox:
        cgen.line("// No unbox for %s" % (param.paramName))

    # Must agree with the stack array size chosen by
    # emit_param_decl_for_reading for the same parameter.
    lenAccess = cgen.generalLengthAccess(param) or "1"
    arrSize = "1" if lenAccess == "1" else "MAX_STACK_ITEMS"

    iterateVulkanType(typeInfo, param, VulkanReservedMarshalingCodegen(
        cgen,
        "host",
        READ_STREAM,
        ROOT_TYPE_DEFAULT_VALUE,
        param.paramName,
        "readStreamPtrPtr",
        API_PREFIX_RESERVEDUNMARSHAL,
        "" if (output or noUnbox) else "unbox_",
        direction="read",
        dynAlloc=True,
        stackVar="stack_%s" % param.paramName,
        stackArrSize=arrSize))
104
105
def emit_dispatch_unmarshal(typeInfo, param, cgen, globalWrapped):
    """Emit code that reads the dispatchable handle parameter `param`.

    The handle is always read boxed (no unbox prefix). When `globalWrapped`
    is false, the generated code additionally defines
    `unboxed_<paramName>` and a local `vk` obtained from
    `dispatch_<typeName>(...)`.

    Args:
        typeInfo: type database forwarded to iterateVulkanType.
        param: the dispatch handle parameter (reads paramName and typeName).
        cgen: code generator receiving the emitted code.
        globalWrapped: whether the call goes through a global-state wrapper.
    """
    handle = param.paramName

    if globalWrapped:
        banner = "// Begin global wrapped dispatchable handle unboxing for %s"
    else:
        banner = "// Begin non wrapped dispatchable handle unboxing for %s"
    cgen.stmt(banner % handle)

    reader = VulkanReservedMarshalingCodegen(
        cgen,
        "host",
        READ_STREAM,
        ROOT_TYPE_DEFAULT_VALUE,
        handle,
        "readStreamPtrPtr",
        API_PREFIX_RESERVEDUNMARSHAL,
        "",
        direction="read",
        dynAlloc=True)
    iterateVulkanType(typeInfo, param, reader)

    if globalWrapped:
        return

    handleType = param.typeName
    cgen.stmt("auto unboxed_%s = unbox_%s(%s)" % (handle, handleType, handle))
    cgen.stmt("auto vk = dispatch_%s(%s)" % (handleType, handle))
    cgen.stmt("// End manual dispatchable handle unboxing for %s" % handle)
142
143
def emit_transform(typeInfo, param, cgen, variant="tohost"):
    """Emit the `transform_<variant>_` handling for `param`.

    If iterateVulkanType returns a falsy result for this parameter, a
    `(void)<paramName>` statement is emitted instead.

    Args:
        typeInfo: type database forwarded to iterateVulkanType.
        param: parameter to transform (reads paramName).
        cgen: code generator receiving the emitted code.
        variant: transform direction; it is used in the generated prefix.
    """
    transformer = TransformCodegen(
        cgen, param.paramName, "globalstate", "transform_%s_" % variant, variant)
    if iterateVulkanType(typeInfo, param, transformer):
        return
    cgen.stmt("(void)%s" % param.paramName)
150
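# The decoding helpers below skip each API's first parameter (the command
# buffer); the generated subDecode() receives it separately as the dispatch handle.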
152
153
class DecodingParameters(object):
    """Classifies the parameters of `api`, excluding the first, for decoding.

    Attributes (all hold the parameter objects taken from `api.parameters`):
        params: every parameter after the first, in order.
        toRead: parameters to read from the stream (every one of them).
        toWrite: parameters for which possiblyOutput() is true.

    Side effect: construction sets flags on the parameter objects themselves
    (dispatchHandle and the four *HandleCreate / *HandleDestroy flags). Flags
    are only ever set to True here, never reset.
    """

    # (flag to set, handle-kind predicate, lifecycle predicate taking the api)
    _HANDLE_FLAGS = (
        ("nonDispatchableHandleCreate", "isNonDispatchableHandleType", "isCreatedBy"),
        ("nonDispatchableHandleDestroy", "isNonDispatchableHandleType", "isDestroyedBy"),
        ("dispatchableHandleCreate", "isDispatchableHandleType", "isCreatedBy"),
        ("dispatchableHandleDestroy", "isDispatchableHandleType", "isDestroyedBy"),
    )

    def __init__(self, api: VulkanAPI):
        self.params: list[VulkanType] = []
        self.toRead: list[VulkanType] = []
        self.toWrite: list[VulkanType] = []

        for position, param in enumerate(api.parameters[1:]):
            # Only the first remaining parameter can be the dispatch handle.
            if position == 0 and param.isDispatchableHandleType():
                param.dispatchHandle = True

            for flag, kindCheck, lifecycleCheck in self._HANDLE_FLAGS:
                if getattr(param, kindCheck)() and getattr(param, lifecycleCheck)(api):
                    setattr(param, flag, True)

            self.toRead.append(param)

            if param.possiblyOutput():
                self.toWrite.append(param)

            self.params.append(param)
182
183
def emit_call_log(api, cgen):
    """Prepare the pieces of a per-call debug log line.

    The statement that would emit the log is disabled, so this currently
    generates no code; `cgen` is unused. Constructing DecodingParameters is
    the only remaining effect.

    Args:
        api: API whose parameters (after the first) would be logged.
        cgen: code generator (unused while logging is disabled).
    """
    loggedParams = DecodingParameters(api).toRead

    # One pointer for the boxed dispatch handle, then one hex slot per param.
    paramLogFormat = "%p" + "0x%llx " * len(loggedParams)

    paramLogArgs = ["(void*)boxed_dispatchHandle"]
    paramLogArgs.extend(
        "(unsigned long long)%s" % (p.paramName) for p in loggedParams)

    # Disabled emission (was wrapped in `if (m_logCalls)`):
    # cgen.stmt("fprintf(stderr, \"substream %%p: call %s %s\\n\", readStream, %s)" % (api.name, paramLogFormat, ", ".join(paramLogArgs)))
198
199
def emit_decode_parameters(typeInfo, api, cgen, globalWrapped=False):
    """Emit the code that declares, reads and transforms an API's parameters.

    Emission order: declarations for every parameter, then the read of each
    parameter, then the "tohost" transform of each parameter, then the call
    log hook.

    Args:
        typeInfo: type database forwarded to the per-parameter emitters.
        api: API being decoded; its first parameter is skipped.
        cgen: code generator receiving the emitted code.
        globalWrapped: forwarded to emit_dispatch_unmarshal for the dispatch
            handle parameter.
    """
    readable = DecodingParameters(api).toRead

    for p in readable:
        emit_param_decl_for_reading(p, cgen)

    for p in readable:
        lenAccess = cgen.generalLengthAccess(p)

        if p.dispatchHandle:
            emit_dispatch_unmarshal(typeInfo, p, cgen, globalWrapped)
            continue

        destroy = p.nonDispatchableHandleDestroy or p.dispatchableHandleDestroy
        if destroy:
            destroy = True
            cgen.stmt(
                "// Begin manual non dispatchable handle destroy unboxing for %s" % p.paramName)
            # Storage that emit_unmarshal fills with the boxed handle(s):
            # a scalar for one handle, a stream-allocated array otherwise.
            if lenAccess is None or lenAccess == "1":
                preserveDecl = "%s boxed_%s_preserve" % (p.typeName, p.paramName)
            else:
                preserveDecl = (
                    "%s* boxed_%s_preserve; %s->alloc((void**)&boxed_%s_preserve, %s * sizeof(%s))" %
                    (p.typeName, p.paramName, READ_STREAM, p.paramName, lenAccess, p.typeName))
            cgen.stmt(preserveDecl)

        isOutput = p.possiblyOutput()
        if isOutput:
            cgen.stmt(
                "// Begin manual dispatchable handle unboxing for %s" % p.paramName)
            cgen.stmt("%s->unsetHandleMapping()" % READ_STREAM)

        emit_unmarshal(typeInfo, p, cgen, output=isOutput,
                       destroy=destroy, noUnbox=False)

    for p in readable:
        emit_transform(typeInfo, p, cgen, variant="tohost")

    emit_call_log(api, cgen)
243
244
def emit_dispatch_call(api, cgen):
    """Emit the call of `api` through the generated code's `vk->` dispatch.

    The first argument is always `(VkCommandBuffer)dispatchHandle`. A
    parameter flagged as the dispatch handle is passed as
    `unboxed_<paramName>`; every other parameter is passed by name. APIs
    listed in driver_workarounds_global_lock_apis are bracketed by
    lock()/unlock() statements.

    Args:
        api: API to call; its first parameter is replaced as described above.
        cgen: code generator receiving the emitted code.
    """
    decoded = DecodingParameters(api).params

    customParams = ["(VkCommandBuffer)dispatchHandle"] + [
        "unboxed_%s" % original.paramName if flagged.dispatchHandle
        else original.paramName
        for original, flagged in zip(api.parameters[1:], decoded)
    ]

    needsGlobalLock = api.name in driver_workarounds_global_lock_apis

    if needsGlobalLock:
        cgen.stmt("lock()")

    cgen.vkApiCall(
        api,
        customPrefix="vk->",
        customParameters=customParams,
        checkForDeviceLost=True,
        globalStatePrefix=global_state_prefix,
        checkForOutOfMemory=True)

    if needsGlobalLock:
        cgen.stmt("unlock()")
266
267
def emit_global_state_wrapped_call(api, cgen, context=False):
    """Emit the call of `api` through the global-state handler prefix.

    The argument list is `pool`, the boxed dispatch handle cast to
    VkCommandBuffer, every API parameter after the first by name, and
    finally `context` when requested.

    Args:
        api: API to call.
        cgen: code generator receiving the emitted code.
        context: when true, append the `context` argument to the call.
    """
    customParams = ["pool", "(VkCommandBuffer)(boxed_dispatchHandle)"]
    customParams.extend(p.paramName for p in api.parameters[1:])
    if context:
        customParams.append("context")

    cgen.vkApiCall(
        api,
        customPrefix=global_state_prefix,
        customParameters=customParams,
        checkForDeviceLost=True,
        checkForOutOfMemory=True,
        globalStatePrefix=global_state_prefix)
276
277
def emit_default_decoding(typeInfo, api, cgen):
    """Emit the default decode path for `api`.

    Decodes the parameters, then calls the API through the `vk->` dispatch
    (see emit_dispatch_call). Used for every API without an entry in
    custom_decodes.
    """
    emit_decode_parameters(typeInfo, api, cgen)
    emit_dispatch_call(api, cgen)
281
282
def emit_global_state_wrapped_decoding(typeInfo, api, cgen):
    """Emit the decode path for an API handled by the global state.

    Decodes the parameters with globalWrapped=True, then calls the
    global-state handler without the extra `context` argument.
    """
    emit_decode_parameters(typeInfo, api, cgen, globalWrapped=True)
    emit_global_state_wrapped_call(api, cgen)
286
def emit_global_state_wrapped_decoding_with_context(typeInfo, api, cgen):
    """Emit the decode path for a global-state API that also takes `context`.

    Same as emit_global_state_wrapped_decoding, except that the generated
    call passes `context` as its last argument.
    """
    emit_decode_parameters(typeInfo, api, cgen, globalWrapped=True)
    emit_global_state_wrapped_call(api, cgen, context=True)
290
# Per-API overrides of the decode emitter, keyed by API name and looked up by
# VulkanSubDecoder.onGenCmd. Each value is called as (typeInfo, api, cgen).
# APIs that are not listed here fall back to emit_default_decoding.
custom_decodes = {
    "vkCmdCopyBufferToImage": emit_global_state_wrapped_decoding_with_context,
    "vkCmdCopyImage": emit_global_state_wrapped_decoding,
    "vkCmdCopyImageToBuffer": emit_global_state_wrapped_decoding,
    "vkCmdCopyBufferToImage2": emit_global_state_wrapped_decoding_with_context,
    "vkCmdCopyImage2": emit_global_state_wrapped_decoding,
    "vkCmdCopyImageToBuffer2": emit_global_state_wrapped_decoding,
    "vkCmdCopyBufferToImage2KHR": emit_global_state_wrapped_decoding_with_context,
    "vkCmdCopyImage2KHR": emit_global_state_wrapped_decoding,
    "vkCmdCopyImageToBuffer2KHR": emit_global_state_wrapped_decoding,
    "vkCmdExecuteCommands": emit_global_state_wrapped_decoding,
    "vkBeginCommandBuffer": emit_global_state_wrapped_decoding_with_context,
    "vkEndCommandBuffer": emit_global_state_wrapped_decoding_with_context,
    "vkResetCommandBuffer": emit_global_state_wrapped_decoding,
    "vkCmdPipelineBarrier": emit_global_state_wrapped_decoding,
    "vkCmdBindPipeline": emit_global_state_wrapped_decoding,
    "vkCmdBindDescriptorSets": emit_global_state_wrapped_decoding,
    "vkCmdCopyQueryPoolResults": emit_global_state_wrapped_decoding,
    "vkBeginCommandBufferAsyncGOOGLE": emit_global_state_wrapped_decoding_with_context,
    "vkEndCommandBufferAsyncGOOGLE": emit_global_state_wrapped_decoding_with_context,
    "vkResetCommandBufferAsyncGOOGLE": emit_global_state_wrapped_decoding,
    "vkCommandBufferHostSyncGOOGLE": emit_global_state_wrapped_decoding,
    "vkCmdBeginRenderPass" : emit_global_state_wrapped_decoding,
    "vkCmdBeginRenderPass2" : emit_global_state_wrapped_decoding,
    "vkCmdBeginRenderPass2KHR" : emit_global_state_wrapped_decoding,
}
317
318
class VulkanSubDecoder(VulkanWrapperGenerator):
    """Generates the C++ `subDecode` function.

    The generated function walks a buffer of packets. Each packet starts with
    a 4-byte opcode and a 4-byte packet length, with the payload read from
    offset 8. It switches on the opcode and contains one case per API whose
    first parameter is named `commandBuffer`.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)
        self.typeInfo = typeInfo
        self.cgen = CodeGen()

    def onBegin(self):
        """Emit the macros, the function signature and everything up to the
        opening brace of the opcode switch."""
        module = self.module
        cgen = self.cgen

        prologue = (
            "#define MAX_STACK_ITEMS %s\n" % MAX_STACK_ITEMS,
            "#define MAX_PACKET_LENGTH %s\n" % MAX_PACKET_LENGTH,
            "size_t subDecode(VulkanMemReadingStream* readStream, VulkanDispatch* vk, void* boxed_dispatchHandle, void* dispatchHandle, VkDeviceSize dataSize, const void* pData, const VkDecoderContext& context)\n",
        )
        for chunk in prologue:
            module.appendImpl(chunk)

        cgen.beginBlock()  # function body

        functionLocals = (
            "auto& metricsLogger = *context.metricsLogger",
            "uint32_t count = 0",
            "unsigned char *buf = (unsigned char *)pData",
            "android::base::BumpPool* pool = readStream->pool()",
            "unsigned char *ptr = (unsigned char *)pData",
            "const unsigned char* const end = (const unsigned char*)buf + dataSize",
            "VkDecoderGlobalState* globalstate = VkDecoderGlobalState::get()",
        )
        for statement in functionLocals:
            cgen.stmt(statement)

        # Keep looping while at least a full packet header (8 bytes) remains.
        cgen.line("while (end - ptr >= 8)")
        cgen.beginBlock()  # while loop

        cgen.stmt("uint32_t opcode = *(uint32_t *)ptr")
        cgen.stmt("uint32_t packetLen = *(uint32_t *)(ptr + 4)")
        cgen.line("""
        // packetLen should be at least 8 (op code and packet length) and should not be excessively large
        if (packetLen < 8 || packetLen > MAX_PACKET_LENGTH) {
            WARN("Bad packet length %d detected, subdecode may fail", packetLen);
            metricsLogger.logMetricEvent(MetricEventBadPacketLength{ .len = packetLen });
        }
        """)
        # A packet that is not fully in the buffer ends decoding early; the
        # generated function then returns the number of bytes consumed.
        cgen.stmt("if (end - ptr < packetLen) return ptr - (unsigned char*)buf")

        cgen.stmt("%s->setBuf((uint8_t*)(ptr + 8))" % READ_STREAM)
        cgen.stmt(
            "uint8_t* readStreamPtr = %s->getBuf(); uint8_t** readStreamPtrPtr = &readStreamPtr" % READ_STREAM)
        cgen.line("switch (opcode)")
        cgen.beginBlock()  # switch stmt

        module.appendImpl(cgen.swapCode())

    def onGenCmd(self, cmdinfo, name, alias):
        """Emit the `case OP_<name>` block for one API.

        APIs whose first parameter is not named `commandBuffer` are skipped.
        """
        api = self.typeInfo.apis[name]
        if api.parameters[0].paramName != "commandBuffer":
            return

        cgen = self.cgen
        cgen.line("case OP_%s:" % name)
        cgen.beginBlock()
        cgen.stmt("android::base::beginTrace(\"%s subdecode\")" % name)

        # Use the per-API override when there is one, else the default path.
        decode = custom_decodes.get(api.name, emit_default_decoding)
        decode(self.typeInfo, api, cgen)

        cgen.stmt("android::base::endTrace()")
        cgen.stmt("break")
        cgen.endBlock()
        self.module.appendImpl(cgen.swapCode())

    def onEnd(self):
        """Emit the default case and close the switch, loop and function."""
        cgen = self.cgen

        cgen.line("default:")
        cgen.beginBlock()
        cgen.stmt(
            "GFXSTREAM_ABORT(::emugl::FatalError(::emugl::ABORT_REASON_OTHER)) << \"Unrecognized opcode \" << opcode")
        cgen.endBlock()

        cgen.endBlock()  # switch stmt

        # The generated loop frees the pool every 1000 packets and advances
        # to the next packet.
        cgen.stmt("++count; if (count % 1000 == 0) { pool->freeAll(); }")
        cgen.stmt("ptr += packetLen")
        cgen.endBlock()  # while loop

        cgen.stmt("pool->freeAll()")
        cgen.stmt("return ptr - (unsigned char*)buf;")
        cgen.endBlock()  # function body
        self.module.appendImpl(cgen.swapCode())
409