1from .common.codegen import CodeGen, VulkanWrapperGenerator, VulkanAPIWrapper
2from .common.vulkantypes import \
3        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, DISPATCHABLE_HANDLE_TYPES, NON_DISPATCHABLE_HANDLE_TYPES
4
5from .transform import TransformCodegen, genTransformsForVulkanType
6
7from .wrapperdefs import API_PREFIX_MARSHAL
8from .wrapperdefs import API_PREFIX_UNMARSHAL
9from .wrapperdefs import VULKAN_STREAM_TYPE
10
11from copy import copy
12from dataclasses import dataclass
13
# C++ text written first into the generated header (VulkanDecoderSnapshot.onBegin).
# It forward-declares android::base helpers and opens the public
# VkDecoderSnapshot class. The class body is deliberately left open: the
# per-command prototypes emitted by VulkanDecoderSnapshot.onGenCmd are appended
# after this text, and decoder_snapshot_decl_postamble closes the class.
decoder_snapshot_decl_preamble = """

namespace android {
namespace base {
class BumpPool;
class Stream;
} // namespace base {
} // namespace android {

class VkDecoderSnapshot {
public:
    VkDecoderSnapshot();
    ~VkDecoderSnapshot();

    void save(android::base::Stream* stream);
    void load(android::base::Stream* stream, emugl::GfxApiLogger& gfx_logger,
              emugl::HealthMonitor<>* healthMonitor);
    void createExtraHandlesForNextApi(const uint64_t* created, uint32_t count);
"""
33
# C++ text written last into the generated header (VulkanDecoderSnapshot.onEnd).
# It closes the VkDecoderSnapshot class opened by decoder_snapshot_decl_preamble.
# The private section holds only a forward-declared Impl and the owning pointer
# mImpl; all public methods forward to it in the generated .cpp.
decoder_snapshot_decl_postamble = """
private:
    class Impl;
    std::unique_ptr<Impl> mImpl;

};
"""
41
# C++ text written first into the generated implementation file
# (VulkanDecoderSnapshot.onBegin). It opens VkDecoderSnapshot::Impl, whose
# save/load/createExtraHandlesForNextApi delegate directly to mReconstruction.
# The class body is left open: the per-command Impl methods produced by
# VulkanDecoderSnapshot.onGenCmd (bodies from emit_impl) are appended after it,
# and decoder_snapshot_impl_postamble closes the class.
decoder_snapshot_impl_preamble ="""

using namespace gfxstream::vk;
using emugl::GfxApiLogger;
using emugl::HealthMonitor;

class VkDecoderSnapshot::Impl {
public:
    Impl() { }

    void save(android::base::Stream* stream) {
        mReconstruction.save(stream);
    }

    void load(android::base::Stream* stream, GfxApiLogger& gfx_logger,
              HealthMonitor<>* healthMonitor) {
        mReconstruction.load(stream, gfx_logger, healthMonitor);
    }

    void createExtraHandlesForNextApi(const uint64_t* created, uint32_t count) {
        mReconstruction.createExtraHandlesForNextApi(created, count);
    }
"""
65
# C++ text appended to the implementation file by VulkanDecoderSnapshot.onEnd.
# It closes VkDecoderSnapshot::Impl by declaring its private members: mLock,
# which the generated Impl methods acquire, and mReconstruction. It then defines
# the out-of-line VkDecoderSnapshot members that forward to mImpl. The
# destructor is defaulted here, after Impl is a complete type, presumably so
# that std::unique_ptr<Impl> can destroy it. The per-command
# VkDecoderSnapshot::<api> wrappers are emitted after this text.
decoder_snapshot_impl_postamble = """
private:
    android::base::Lock mLock;
    VkReconstruction mReconstruction;
};

VkDecoderSnapshot::VkDecoderSnapshot() :
    mImpl(new VkDecoderSnapshot::Impl()) { }

void VkDecoderSnapshot::save(android::base::Stream* stream) {
    mImpl->save(stream);
}

void VkDecoderSnapshot::load(android::base::Stream* stream, GfxApiLogger& gfx_logger,
                             HealthMonitor<>* healthMonitor) {
    mImpl->load(stream, gfx_logger, healthMonitor);
}

void VkDecoderSnapshot::createExtraHandlesForNextApi(const uint64_t* created, uint32_t count) {
    mImpl->createExtraHandlesForNextApi(created, count);
}

VkDecoderSnapshot::~VkDecoderSnapshot() = default;
"""
90
# Number of leading bookkeeping parameters on every snapshot API variant.
# NOTE(review): this matches snapshotTraceBegin, snapshotTraceBytes and pool,
# which VulkanDecoderSnapshot.onGenCmd prepends. It is not referenced in this
# file; confirm the consumer elsewhere.
AUXILIARY_SNAPSHOT_API_BASE_PARAM_COUNT = 3

# Auxiliary parameter names that are only present for some APIs
# (input_result is added only when the wrapped API returns non-void).
AUXILIARY_SNAPSHOT_API_PARAM_NAMES = ["input_result"]
96
# Parent/child relationships between Vulkan handle types, used when recording
# handle dependencies for snapshot reconstruction. Each entry is
# (child, parent): handles of the child type depend on a handle of the parent type.
SNAPSHOT_HANDLE_DEPENDENCIES = [
    # Explicitly listed chain, mostly dispatchable handle types
    # (VkCommandPool is the non-dispatchable exception).
    ("VkCommandBuffer", "VkCommandPool"),
    ("VkCommandPool", "VkDevice"),
    ("VkQueue", "VkDevice"),
    ("VkDevice", "VkPhysicalDevice"),
    ("VkPhysicalDevice", "VkInstance"),
] + [
    # Every non-dispatchable handle type is treated as a child of VkDevice.
    (handleType, "VkDevice") for handleType in NON_DISPATCHABLE_HANDLE_TYPES
]

# child type -> parent type. If a child type appears more than once in the
# list above, dict() keeps the last entry.
handleDependenciesDict = dict(SNAPSHOT_HANDLE_DEPENDENCIES)
109
def extract_deps_vkAllocateMemory(param, access, lenExpr, api, cgen):
    """Emit dependencies of a dedicated memory allocation on its image or buffer.

    The generated C++ looks up a VkMemoryDedicatedAllocateInfo in the
    pAllocateInfo chain. For whichever of its ``image`` / ``buffer`` members is
    set, it makes the ``lenExpr`` memory handle(s) at ``access`` depend on
    that (boxed) handle.
    """
    template = "mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s)"
    cgen.stmt("const VkMemoryDedicatedAllocateInfo* dedicatedAllocateInfo = vk_find_struct<VkMemoryDedicatedAllocateInfo>(pAllocateInfo)")
    cgen.beginIf("dedicatedAllocateInfo")
    for member, handleType in (("image", "VkImage"), ("buffer", "VkBuffer")):
        field = "dedicatedAllocateInfo->%s" % member
        cgen.beginIf(field)
        parent = "unboxed_to_boxed_non_dispatchable_%s(%s)" % (handleType, field)
        cgen.stmt(template % (access, lenExpr, parent))
        cgen.endIf()
    cgen.endIf()
122
def extract_deps_vkAllocateCommandBuffers(param, access, lenExpr, api, cgen):
    """Emit a dependency of the allocated command buffer(s) on their command pool.

    The ``lenExpr`` handles at ``access`` are linked to the boxed
    pAllocateInfo->commandPool.
    """
    commandPool = "unboxed_to_boxed_non_dispatchable_VkCommandPool(pAllocateInfo->commandPool)"
    line = "mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s)" % (
        access, lenExpr, commandPool)
    cgen.stmt(line)
126
def extract_deps_vkCreateImageView(param, access, lenExpr, api, cgen):
    """Emit a dependency of the created image view(s) on pCreateInfo->image.

    Two state arguments (CREATED, BOUND_MEMORY) are passed to
    addHandleDependency. NOTE(review): presumably this means the view's
    CREATED state requires the image to have reached BOUND_MEMORY; confirm
    against VkReconstruction::addHandleDependency.
    """
    image = "unboxed_to_boxed_non_dispatchable_VkImage(pCreateInfo->image)"
    line = ("mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s, "
            "VkReconstruction::CREATED, VkReconstruction::BOUND_MEMORY)") % (access, lenExpr, image)
    cgen.stmt(line)
130
def extract_deps_vkCreateGraphicsPipelines(param, access, lenExpr, api, cgen):
    """Emit dependencies of each created pipeline on its stages' shader modules.

    ``lenExpr`` is not used. The generated code iterates the
    ``createInfoCount`` create infos and their ``stageCount`` stages, linking
    pipeline ``i`` (one handle at a time) to each stage's module.
    """
    shaderModule = "unboxed_to_boxed_non_dispatchable_VkShaderModule(pCreateInfos[i].pStages[j].module)"
    loopHeaders = (
        ("uint32_t i = 0", "i < createInfoCount", "++i"),
        ("uint32_t j = 0", "j < pCreateInfos[i].stageCount", "++j"),
    )
    for header in loopHeaders:
        cgen.beginFor(*header)
    cgen.stmt("mReconstruction.addHandleDependency((const uint64_t*)(%s + i), %s, (uint64_t)(uintptr_t)%s)" %
              (access, 1, shaderModule))
    for _ in loopHeaders:
        cgen.endFor()
138
def extract_deps_vkCreateFramebuffer(param, access, lenExpr, api, cgen):
    """Emit dependencies of the created framebuffer on its render pass and attachments.

    The framebuffer handle(s) at ``access`` depend on pCreateInfo->renderPass
    and, through a generated loop, on every image view in
    pCreateInfo->pAttachments.
    """
    template = "mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s)"

    def addParent(parentExpr):
        cgen.stmt(template % (access, lenExpr, parentExpr))

    addParent("unboxed_to_boxed_non_dispatchable_VkRenderPass(pCreateInfo->renderPass)")
    cgen.beginFor("uint32_t i = 0", "i < pCreateInfo->attachmentCount", "++i")
    addParent("unboxed_to_boxed_non_dispatchable_VkImageView(pCreateInfo->pAttachments[i])")
    cgen.endFor()
146
def extract_deps_vkBindImageMemory(param, access, lenExpr, api, cgen):
    """Emit the dependencies recorded when memory is bound to an image.

    Two edges are added, both tagged VkReconstruction::BOUND_MEMORY: one on
    the bound VkDeviceMemory (``memory``) and one on the image handle itself
    (``access[0]``). NOTE(review): the tag presumably selects the state of the
    dependent handle; confirm against VkReconstruction::addHandleDependency.
    """
    parents = (
        "unboxed_to_boxed_non_dispatchable_VkDeviceMemory(memory)",
        "((%s)[0])" % access,
    )
    for parent in parents:
        cgen.stmt("mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s, VkReconstruction::BOUND_MEMORY)" %
                  (access, lenExpr, parent))
152
def extract_deps_vkBindBufferMemory(param, access, lenExpr, api, cgen):
    """Emit the dependencies recorded when memory is bound to a buffer.

    Two edges are added, both tagged VkReconstruction::BOUND_MEMORY: one on
    the bound VkDeviceMemory (``memory``) and one on the buffer handle itself
    (``access[0]``). NOTE(review): the tag presumably selects the state of the
    dependent handle; confirm against VkReconstruction::addHandleDependency.
    """
    parents = (
        "unboxed_to_boxed_non_dispatchable_VkDeviceMemory(memory)",
        "((%s)[0])" % access,
    )
    for parent in parents:
        cgen.stmt("mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s, VkReconstruction::BOUND_MEMORY)" %
                  (access, lenExpr, parent))
158
# Per-API hooks that emit handle dependencies beyond the generic parent-type
# dependency derived from handleDependenciesDict. emit_impl invokes them as
# hook(param, boxed_access, lenExpr, api, cgen).
specialCaseDependencyExtractors = dict(
    vkAllocateCommandBuffers=extract_deps_vkAllocateCommandBuffers,
    vkAllocateMemory=extract_deps_vkAllocateMemory,
    vkCreateImageView=extract_deps_vkCreateImageView,
    vkCreateGraphicsPipelines=extract_deps_vkCreateGraphicsPipelines,
    vkCreateFramebuffer=extract_deps_vkCreateFramebuffer,
    vkBindImageMemory=extract_deps_vkBindImageMemory,
    vkBindBufferMemory=extract_deps_vkBindBufferMemory,
)

# Ordered API sequences keyed by their first API.
# NOTE(review): not referenced in this file; confirm its consumer elsewhere.
apiSequences = dict(
    vkAllocateMemory=["vkAllocateMemory", "vkMapMemoryIntoAddressSpaceGOOGLE"],
)
172
@dataclass(frozen=True)
class VkObjectState:
    """Immutable (parameter name, VkReconstruction state) pair.

    ``vk_object`` names an API parameter holding a handle. ``state`` is the
    C++ expression of the VkReconstruction state that handle is moved into.
    """

    vk_object: str
    state: str = "VkReconstruction::CREATED"


# APIs that move an existing (not newly created) handle parameter into a new
# reconstruction state; consulted by is_state_change_operation/get_target_state.
# TODO: add vkBindImageMemory2 and vkBindBufferMemory2 into this list
apiChangeState = {
    apiName: VkObjectState(vk_object=paramName, state="VkReconstruction::BOUND_MEMORY")
    for apiName, paramName in (
        ("vkBindImageMemory", "image"),
        ("vkBindBufferMemory", "buffer"),
    )
}
183
# API name -> handle parameters that the API records as "modified"
# (in addition to the commandBuffer of every vkCmd* call; see is_modify_operation).
apiModifies = dict(
    vkMapMemoryIntoAddressSpaceGOOGLE=["memory"],
    vkGetBlobGOOGLE=["memory"],
    vkBeginCommandBuffer=["commandBuffer"],
    vkEndCommandBuffer=["commandBuffer"],
)

# API name -> handle parameters whose recorded modify-APIs are cleared.
apiClearModifiers = dict(
    vkResetCommandBuffer=["commandBuffer"],
)

# Destroy APIs for which emit_impl removes handles non-recursively
# (removeHandles(..., false)).
delayedDestroys = ["vkDestroyShaderModule"]

# Handle types whose "create" commands are not snapshotted, because these
# handles are created and cached by other commands.
skipCreatorSnapshotTypes = [
    "VkQueue",          # created by vkCreateDevice
    "VkDescriptorSet",  # created by vkCreateDescriptorPool
]
205
def is_state_change_operation(api, param):
    """Return True if ``api`` moves the handle in ``param`` into a new reconstruction state.

    That holds when ``api`` creates ``param``, unless its type is in
    skipCreatorSnapshotTypes. It also holds when apiChangeState names
    ``param`` as the object ``api`` transitions.
    """
    createdHere = param.isCreatedBy(api) and param.typeName not in skipCreatorSnapshotTypes
    if createdHere:
        return True
    transition = apiChangeState.get(api.name)
    return transition is not None and param.paramName == transition.vk_object
213
def get_target_state(api, param):
    """Return the C++ VkReconstruction state ``param`` reaches after ``api``.

    Returns "VkReconstruction::CREATED" when ``api`` creates ``param``. Returns
    the state from apiChangeState when that table names ``param`` for ``api``.
    Otherwise returns None.
    """
    if param.isCreatedBy(api):
        return "VkReconstruction::CREATED"
    transition = apiChangeState.get(api.name)
    if transition is None or param.paramName != transition.vk_object:
        return None
    return transition.state
221
def is_modify_operation(api, param):
    """Return True if ``api`` should be recorded as modifying the handle in ``param``.

    Parameters listed for ``api`` in apiModifies qualify. So does the
    ``commandBuffer`` parameter of any vkCmd* command.
    """
    if param.paramName in apiModifies.get(api.name, ()):
        return True
    return api.name.startswith('vkCmd') and param.paramName == 'commandBuffer'
229
def is_clear_modifier_operation(api, param):
    """Return True if ``api`` clears the recorded modify-APIs of the handle in ``param``.

    Driven by apiClearModifiers. Like the other is_* predicates, this always
    returns a bool (False when ``api`` or ``param`` is not listed) rather than
    falling through to an implicit None.
    """
    return param.paramName in apiClearModifiers.get(api.name, ())
234
235
def _unboxed_to_boxed_fn(param):
    """Return the generated C++ helper name that boxes an unboxed handle of ``param``'s type.

    Non-dispatchable and dispatchable handles use differently named helpers.
    """
    if param.isNonDispatchableHandleType():
        return "unboxed_to_boxed_non_dispatchable_%s" % param.typeName
    return "unboxed_to_boxed_%s" % param.typeName


def emit_impl(typeInfo, api, cgen):
    """Emit the body of the VkDecoderSnapshot::Impl method for ``api`` into ``cgen``.

    For every handle-typed parameter of ``api``, the generated code records
    what the call means for snapshot reconstruction:

    * state change (creation, or a transition listed in apiChangeState):
      register created handles, add handle dependencies, and attach the
      API trace to the handles for their target state;
    * destruction: remove the handles (non-recursively for delayedDestroys);
    * modification / modifier reset: append the API trace to, or clear, the
      handles' modify-API list.

    ``typeInfo`` is unused and kept for interface compatibility.

    NOTE(review): each branch declares ``lock`` at function scope. The
    create/modify branches also declare ``apiHandle``/``apiInfo``. An API
    where more than one branch fires, for one parameter or across
    parameters, would therefore generate duplicate C++ declarations.
    Presumably no current API does; confirm before extending the tables.
    """
    for p in api.parameters:
        if not p.isHandleType:
            continue

        lenExpr = cgen.generalLengthAccess(p)
        lenAccessGuard = cgen.generalLengthAccessGuard(p)

        if lenExpr is None:
            lenExpr = "1"

        # Note that in vkCreate*, the last parameter (the output) is boxed. But all input parameters are unboxed.

        # Address the handle(s) uniformly as a pointer expression.
        if p.pointerIndirectionLevels > 0:
            access = p.paramName
        else:
            access = "(&%s)" % p.paramName

        if is_state_change_operation(api, p):
            if p.isCreatedBy(api):
                # Output handles are already boxed.
                boxed_access = access
            else:
                # Input handles are unboxed: box a local copy before handing it
                # to mReconstruction. Pick the helper by handle kind, as the
                # modify branch below does.
                cgen.stmt("%s boxed_%s = %s(%s[0])" % (p.typeName, p.typeName, _unboxed_to_boxed_fn(p), access))
                boxed_access = "&boxed_%s" % p.typeName
            if p.pointerIndirectionLevels > 0:
                cgen.stmt("if (!%s) return" % access)
            cgen.stmt("android::base::AutoLock lock(mLock)")
            cgen.line("// %s create" % p.paramName)
            if p.isCreatedBy(api):
                cgen.stmt("mReconstruction.addHandles((const uint64_t*)%s, %s)" % (boxed_access, lenExpr))

            # Generic dependency on the parent handle type, taken from every
            # parameter of that type.
            if p.isCreatedBy(api) and p.typeName in handleDependenciesDict:
                dependsOnType = handleDependenciesDict[p.typeName]
                for p2 in api.parameters:
                    if p2.typeName == dependsOnType:
                        cgen.stmt("mReconstruction.addHandleDependency((const uint64_t*)%s, %s, (uint64_t)(uintptr_t)%s)" % (boxed_access, lenExpr, p2.paramName))
            if api.name in specialCaseDependencyExtractors:
                specialCaseDependencyExtractors[api.name](p, boxed_access, lenExpr, api, cgen)

            cgen.stmt("auto apiHandle = mReconstruction.createApiInfo()")
            cgen.stmt("auto apiInfo = mReconstruction.getApiInfo(apiHandle)")
            cgen.stmt("mReconstruction.setApiTrace(apiInfo, OP_%s, snapshotTraceBegin, snapshotTraceBytes)" % api.name)
            if lenAccessGuard is not None:
                cgen.beginIf(lenAccessGuard)
            cgen.stmt(f"mReconstruction.forEachHandleAddApi((const uint64_t*){boxed_access}, {lenExpr}, apiHandle, {get_target_state(api, p)})")
            if p.isCreatedBy(api):
                cgen.stmt("mReconstruction.setCreatedHandlesForApi(apiHandle, (const uint64_t*)%s, %s)" % (boxed_access, lenExpr))
            if lenAccessGuard is not None:
                cgen.endIf()

        if p.isDestroyedBy(api):
            cgen.stmt("android::base::AutoLock lock(mLock)")
            cgen.line("// %s destroy" % p.paramName)
            if lenAccessGuard is not None:
                cgen.beginIf(lenAccessGuard)
            shouldRecursiveDestroy = "false" if api.name in delayedDestroys else "true"
            cgen.stmt("mReconstruction.removeHandles((const uint64_t*)%s, %s, %s)" % (access, lenExpr, shouldRecursiveDestroy))
            if lenAccessGuard is not None:
                cgen.endIf()

        if is_modify_operation(api, p) or is_clear_modifier_operation(api, p):
            cgen.stmt("android::base::AutoLock lock(mLock)")
            cgen.line("// %s modify" % p.paramName)
            cgen.stmt("auto apiHandle = mReconstruction.createApiInfo()")
            cgen.stmt("auto apiInfo = mReconstruction.getApiInfo(apiHandle)")
            cgen.stmt("mReconstruction.setApiTrace(apiInfo, OP_%s, snapshotTraceBegin, snapshotTraceBytes)" % api.name)
            if lenAccessGuard is not None:
                cgen.beginIf(lenAccessGuard)
            cgen.beginFor("uint32_t i = 0", "i < %s" % lenExpr, "++i")
            cgen.stmt("%s boxed = %s(%s[i])" % (p.typeName, _unboxed_to_boxed_fn(p), access))
            if is_modify_operation(api, p):
                cgen.stmt("mReconstruction.forEachHandleAddModifyApi((const uint64_t*)(&boxed), 1, apiHandle)")
            else:  # is clear modifier operation
                cgen.stmt("mReconstruction.forEachHandleClearModifyApi((const uint64_t*)(&boxed), 1)")
            cgen.endFor()
            if lenAccessGuard is not None:
                cgen.endIf()
316
def emit_passthrough_to_impl(typeInfo, api, cgen):
    """Emit a call that forwards ``api`` to the same-named method on ``mImpl``.

    ``typeInfo`` is unused and kept for a uniform callback signature.
    """
    implPrefix = "mImpl->"
    cgen.vkApiCall(api, customPrefix=implPrefix)
319
class VulkanDecoderSnapshot(VulkanWrapperGenerator):
    """Code generator for the VkDecoderSnapshot C++ header and implementation.

    onBegin/onEnd write the fixed preamble/postamble text. onGenCmd emits,
    per Vulkan command, a prototype into the header and a
    VkDecoderSnapshot::Impl method (body from emit_impl) into the
    implementation. onEnd finally emits the public VkDecoderSnapshot::<api>
    wrappers that forward to mImpl, each wrapped in ``#ifdef <feature>``
    when a feature name is known.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.typeInfo = typeInfo

        # Prototypes land inside the class body opened by
        # decoder_snapshot_decl_preamble, hence one level of indentation.
        self.cgenHeader = CodeGen()
        self.cgenHeader.incrIndent()

        self.cgenImpl = CodeGen()

        # Name passed to the most recent onBeginFeature call.
        self.currentFeature = None

        # (feature, snapshot API) pairs collected by onGenCmd for onEnd.
        self.feature_apis = []

    def onBegin(self):
        """Write the opening header and implementation text."""
        self.module.appendHeader(decoder_snapshot_decl_preamble)
        self.module.appendImpl(decoder_snapshot_impl_preamble)

    def onBeginFeature(self, featureName, featureType):
        """Track the current feature so onEnd can guard its wrappers."""
        VulkanWrapperGenerator.onBeginFeature(self, featureName, featureType)
        self.currentFeature = featureName

    def onGenCmd(self, cmdinfo, name, alias):
        """Emit the header prototype and Impl method for Vulkan command ``name``.

        The snapshot variant returns void. It takes snapshotTraceBegin,
        snapshotTraceBytes, pool and, for non-void APIs, input_result
        (presumably the original call's result) ahead of the original
        parameters.
        """
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

        api = self.typeInfo.apis[name]

        additionalParams = [
            makeVulkanTypeSimple(True, "uint8_t", 1, "snapshotTraceBegin"),
            makeVulkanTypeSimple(False, "size_t", 0, "snapshotTraceBytes"),
            makeVulkanTypeSimple(False, "android::base::BumpPool", 1, "pool"),
        ]

        if api.retType.typeName != "void":
            additionalParams.append(
                makeVulkanTypeSimple(False, api.retType.typeName, 0, "input_result"))

        apiForSnapshot = api.withCustomParameters(
            additionalParams + api.parameters).withCustomReturnType(
                makeVulkanTypeSimple(False, "void", 0, "void"))

        self.feature_apis.append((self.currentFeature, apiForSnapshot))

        self.cgenHeader.stmt(self.cgenHeader.makeFuncProto(apiForSnapshot))
        self.module.appendHeader(self.cgenHeader.swapCode())

        self.cgenImpl.emitFuncImpl(
            apiForSnapshot, lambda cgen: emit_impl(self.typeInfo, apiForSnapshot, cgen))
        self.module.appendImpl(self.cgenImpl.swapCode())

    def onEnd(self):
        """Close both generated files, then emit the forwarding wrappers."""
        self.module.appendHeader(decoder_snapshot_decl_postamble)
        self.module.appendImpl(decoder_snapshot_impl_postamble)
        self.cgenHeader.decrIndent()

        for feature, api in self.feature_apis:
            if feature is not None:
                self.cgenImpl.line("#ifdef %s" % feature)

            apiImplShell = api.withModifiedName("VkDecoderSnapshot::" + api.name)

            # Bind this iteration's `api` as a default argument. A plain closure
            # would read `api` when called, which is only correct if
            # emitFuncImpl invokes the callback before the loop advances.
            self.cgenImpl.emitFuncImpl(
                apiImplShell,
                lambda cgen, api=api: emit_passthrough_to_impl(self.typeInfo, api, cgen))

            if feature is not None:
                self.cgenImpl.line("#endif")

        self.module.appendImpl(self.cgenImpl.swapCode())
391
392