1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from .common.codegen import CodeGen
17from .common.vulkantypes import \
18        VulkanCompoundType, VulkanAPI, makeVulkanTypeSimple, vulkanTypeNeedsTransform, vulkanTypeGetNeededTransformTypes, VulkanTypeIterator, iterateVulkanType, vulkanTypeforEachSubType, TRIVIAL_TRANSFORMED_TYPES, NON_TRIVIAL_TRANSFORMED_TYPES, TRANSFORMED_TYPES
19
20from .wrapperdefs import VulkanWrapperGenerator
21from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE
22
def deviceMemoryTransform(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit resource-tracker ``deviceMemoryTransform_*`` calls for a struct or API.

    For every entry of ``structOrApiInfo.deviceMemoryInfoParameterIndices``
    one statement ``<tracker>->deviceMemoryTransform_tohost(...)`` is emitted
    (``_fromhost`` for any other ``variant``). The call always receives five
    ``(cast)pointer, count`` argument pairs in the fixed order handle, offset,
    size, typeIndex, typeBits. A field whose index matches no member of
    ``structOrApiInfo`` is passed as ``nullptr, 0``.

    Args:
        resourceTrackerVarName: C++ expression naming the resource tracker.
        structOrApiInfo: struct/API description exposing
            ``deviceMemoryInfoParameterIndices``.
        getExpr: callable mapping a member's VulkanType to its access expression.
        getLen: callable mapping a member's VulkanType to its length
            expression, or None (then the count is emitted as ``1``).
        cgen: code generator that receives the emitted statement.
        variant: ``"tohost"`` selects the tohost entry point; any other value
            selects fromhost.
    """
    fieldOrder = ("handle", "offset", "size", "typeIndex", "typeBits")
    pointerTypes = {
        "handle": "VkDeviceMemory*",
        "offset": "VkDeviceSize*",
        "size": "VkDeviceSize*",
        "typeIndex": "uint32_t*",
        "typeBits": "uint32_t*",
    }

    if variant == "tohost":
        template = "%s->deviceMemoryTransform_tohost(%s)"
    else:
        template = "%s->deviceMemoryTransform_fromhost(%s)"

    for info in structOrApiInfo.deviceMemoryInfoParameterIndices.values():
        # Member index recorded for each field of this device-memory group.
        memberIndexOf = vars(info)
        exprs = dict.fromkeys(fieldOrder, "nullptr")
        counts = dict.fromkeys(fieldOrder, "0")

        def recordMember(memberIndex, memberType):
            # Both accessors are evaluated for every member, matched or not.
            memberExpr = getExpr(memberType)
            memberLen = getLen(memberType)
            for field in fieldOrder:
                if memberIndex != memberIndexOf[field]:
                    continue
                exprs[field] = memberExpr
                counts[field] = "1" if memberLen is None else memberLen

        vulkanTypeforEachSubType(structOrApiInfo, recordMember)

        args = ", ".join(
            "(%s)%s, %s" % (pointerTypes[field], exprs[field], counts[field])
            for field in fieldOrder)
        cgen.stmt(template % (resourceTrackerVarName, args))
83
def directTransform(resourceTrackerVarName, vulkanType, getExpr, getLen, cgen, variant="tohost"):
    """Emit a ``<tracker>->transformImpl_<TypeName>_<variant>(ptr, count)`` call.

    Args:
        resourceTrackerVarName: C++ expression naming the resource tracker.
        vulkanType: type whose ``typeName`` selects the transformImpl function.
        getExpr: callable returning the access expression for ``vulkanType``.
        getLen: callable returning the element-count expression; any falsy
            result (None, empty string) is emitted as ``1``.
        cgen: code generator receiving the statement.
        variant: suffix of the transformImpl function, e.g. ``"tohost"``.
    """
    target = getExpr(vulkanType)
    count = getLen(vulkanType) or "1"
    cgen.stmt("%s->transformImpl_%s_%s(%s, %s)" % (
        resourceTrackerVarName, vulkanType.typeName, variant, target, count))
95
def genTransformsForVulkanType(resourceTrackerVarName, structOrApiInfo, getExpr, getLen, cgen, variant="tohost"):
    """Emit the whole-struct transforms required by ``structOrApiInfo``.

    Only the ``"devicememory"`` transform kind is handled (via
    ``deviceMemoryTransform``). Any other kind reported by
    ``vulkanTypeGetNeededTransformTypes`` is ignored.
    """
    requiredKinds = vulkanTypeGetNeededTransformTypes(structOrApiInfo)
    for kind in requiredKinds:
        if kind != "devicememory":
            continue
        deviceMemoryTransform(resourceTrackerVarName, structOrApiInfo,
                              getExpr, getLen, cgen, variant=variant)
103
class TransformCodegen(VulkanTypeIterator):
    """Type iterator that emits the body of a ``<prefix><Struct>`` transform function.

    Each struct member is dispatched to one of the ``on*`` callbacks below.
    Only two member kinds produce code:

    - Compound members get a recursive call to ``<prefix><TypeName>``. A
      resource-tracker ``transformImpl_*`` call precedes it when the type is
      marked ``isTransformed``.
    - Struct-extension members are forwarded to ``<prefix>extension_struct``.

    Strings, string arrays, static arrays, plain pointers and values emit
    nothing.
    """

    def __init__(self, cgen, inputVar, resourceTrackerVarName, prefix, variant):
        """
        Args:
            cgen: code generator to emit into. May be None at construction
                and assigned later (the accessors below read it lazily).
            inputVar: C++ name of the struct variable being transformed.
            resourceTrackerVarName: C++ name of the resource tracker.
            prefix: name prefix of the recursive transform calls,
                e.g. ``"transform_tohost_"``.
            variant: suffix used for ``transformImpl_<Type>_<variant>``.
        """
        self.cgen = cgen
        self.inputVar = inputVar
        self.prefix = prefix
        self.resourceTrackerVarName = resourceTrackerVarName

        # These lambdas look up self.cgen at call time rather than binding it
        # now, because cgen may still be None when the iterator is created.
        def makeAccess(varName, asPtr=True):
            return lambda t: self.cgen.generalAccess(t, parentVarName=varName, asPtr=asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(t, parentVarName=varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName)

        self.exprAccessor = makeAccess(self.inputVar)
        self.exprAccessorValue = makeAccess(self.inputVar, asPtr=False)
        self.lenAccessor = makeLengthAccess(self.inputVar)
        self.lenAccessorGuard = makeLengthAccessGuard(self.inputVar)

        # Not read anywhere in this class; kept for compatibility.
        self.checked = False

        self.variant = variant

    def makeCastExpr(self, vulkanType):
        """Return a C-style cast prefix ``(T)`` for ``vulkanType`` (no param name)."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def asNonConstCast(self, access, vulkanType):
        """Wrap ``access`` in a cast to the non-const form of ``vulkanType``.

        Static arrays, and types not accessible as a pointer, are cast to their
        address-of form. Pointer-accessible types are cast directly.
        """
        if vulkanType.staticArrExpr or not vulkanType.accessibleAsPointer():
            targetType = vulkanType.getForAddressAccess().getForNonConstAccess()
        else:
            targetType = vulkanType.getForNonConstAccess()
        return "%s(%s)" % (self.makeCastExpr(targetType), access)

    def onCheck(self, vulkanType):
        pass

    def endCheck(self, vulkanType):
        pass

    def onCompoundType(self, vulkanType):
        """Emit the transform for a struct/union member (single or array).

        The emitted code is nested inside up to three scopes:

        - an ``if`` on the length guard, when one exists;
        - an ``if`` on the member pointer, when the member is a pointer;
        - a ``for`` over the elements, when a length expression exists.
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        isPtr = vulkanType.pointerIndirectionLevels > 0

        if lenAccessGuard is not None:
            self.cgen.beginIf(lenAccessGuard)

        if isPtr:
            self.cgen.beginIf(access)

        # directTransform is handed the base pointer and the full element
        # count, so it covers the whole range in one call. Emit it once here,
        # outside the per-element loop. Emitting it inside the loop would
        # re-transform the entire array once per element.
        if vulkanType.isTransformed:
            directTransform(self.resourceTrackerVarName, vulkanType,
                            self.exprAccessor, self.lenAccessor, self.cgen,
                            variant=self.variant)

        if lenAccess is not None:
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar

            self.cgen.beginFor(forInit, forCond, forIncr)

        accessCasted = self.asNonConstCast(access, vulkanType)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [self.resourceTrackerVarName, accessCasted])

        if lenAccess is not None:
            self.cgen.endFor()

        if isPtr:
            self.cgen.endIf()

        if lenAccessGuard is not None:
            self.cgen.endIf()

    def onString(self, vulkanType):
        pass

    def onStringArray(self, vulkanType):
        pass

    def onStaticArr(self, vulkanType):
        pass

    def onStructExtension(self, vulkanType):
        """Forward a non-null extension chain to ``<prefix>extension_struct``."""
        access = self.exprAccessor(vulkanType)

        castedAccessExpr = "(%s)(%s)" % ("void*", access)
        self.cgen.beginIf(access)
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [self.resourceTrackerVarName, castedAccessExpr])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        pass

    def onValue(self, vulkanType):
        pass
212
213
class VulkanTransform(VulkanWrapperGenerator):
    """Generates ``transform_tohost_*`` / ``transform_fromhost_*`` functions.

    For each variant, this generator emits the following:

    - For every non-aliased struct/union it sees: a declaration appended to
      the module header and a definition appended to the module
      implementation.
    - For aliased structs/unions: header function aliases instead.
    - In ``onEnd``: the ``transform_<variant>_extension_struct`` dispatcher
      definitions, whose prototypes are declared in ``onBegin``.
    """

    def __init__(self, module, typeInfo, resourceTrackerTypeName="ResourceTracker", resourceTrackerVarName="resourceTracker"):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        self.transformPrefix = "transform_"

        self.tohostpart = "tohost"
        self.fromhostpart = "fromhost"
        self.variants = [self.tohostpart, self.fromhostpart]

        # Name of the struct parameter of every generated transform function.
        self.toTransformVar = "toTransform"
        self.resourceTrackerTypeName = resourceTrackerTypeName
        self.resourceTrackerVarName = resourceTrackerVarName
        # Leading parameter of every generated function (non-const, one level
        # of indirection -- presumably `ResourceTracker* resourceTracker`).
        self.transformParam = \
            makeVulkanTypeSimple(False, self.resourceTrackerTypeName, 1,
                                 self.resourceTrackerVarName)
        self.voidType = makeVulkanTypeSimple(False, "void", 0)

        # One extension-struct dispatcher prototype per variant, kept in the
        # same order as self.variants (onEnd zips the two lists).
        self.extensionTransformPrototypes = [
            VulkanAPI(self.transformPrefix + variant + "_extension_struct",
                      self.voidType,
                      [self.transformParam, STRUCT_EXTENSION_PARAM_FOR_WRITE])
            for variant in self.variants
        ]

        self.knownStructs = {}
        self.needsTransform = set()

    def onBegin(self):
        VulkanWrapperGenerator.onBegin(self)
        # Set up convenience X-macros listing the transformed structs
        # and forward-declare the resource tracker class.
        self.codegen.stmt("class %s" % self.resourceTrackerTypeName)
        self.codegen.line("#define LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        for name in NON_TRIVIAL_TRANSFORMED_TYPES:
            self.codegen.line("f(%s) \\" % name)
        self.codegen.line("")

        self.codegen.line("#define LIST_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("LIST_NON_TRIVIAL_TRANSFORMED_TYPES(f) \\")
        self.codegen.line("")

        self.module.appendHeader(self.codegen.swapCode())

        for prototype in self.extensionTransformPrototypes:
            self.module.appendImpl(self.codegen.makeFuncDecl(
                prototype))

    def onGenType(self, typeXml, name, alias):
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownStructs:
            return

        category = self.typeInfo.categoryOf(name)
        if category not in ["struct", "union"]:
            return

        if alias:
            # An aliased struct reuses the transform functions of its target.
            for variant in self.variants:
                self.module.appendHeader(
                    self.codegen.makeFuncAlias(self.transformPrefix + variant + "_" + name,
                                               self.transformPrefix + variant + "_" + alias))
            return

        structInfo = self.typeInfo.structs[name]
        self.knownStructs[name] = structInfo

        for variant in self.variants:
            api = VulkanAPI(
                self.transformPrefix + variant + "_" + name,
                self.voidType,
                [self.transformParam] +
                [makeVulkanTypeSimple(
                    False, name, 1, self.toTransformVar)])

            transformer = TransformCodegen(
                None,
                self.toTransformVar,
                self.resourceTrackerVarName,
                self.transformPrefix + variant + "_",
                variant)

            self.module.appendHeader(
                self.codegen.makeFuncDecl(api))
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    api,
                    self._makeStructTransformBody(api, structInfo, transformer, variant)))

    def _makeStructTransformBody(self, api, structInfo, transformer, variant):
        """Return the body generator for one struct's transform function.

        The arguments are bound here rather than captured from the caller's
        loop, so the generator stays correct no matter when it is invoked.
        """
        def funcDefGenerator(cgen):
            transformer.cgen = cgen
            # `(void)param;` marks each parameter as used in the generated code.
            for p in api.parameters:
                cgen.stmt("(void)%s" % p.paramName)

            genTransformsForVulkanType(
                self.resourceTrackerVarName,
                structInfo,
                transformer.exprAccessor,
                transformer.lenAccessor,
                cgen,
                variant=variant)

            for member in structInfo.members:
                iterateVulkanType(
                    self.typeInfo, member,
                    transformer)

        return funcDefGenerator

    def onGenCmd(self, cmdinfo, name, alias):
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        VulkanWrapperGenerator.onEnd(self)

        for variant, prototype in zip(self.variants, self.extensionTransformPrototypes):
            self.module.appendImpl(
                self.codegen.makeFuncImpl(
                    prototype,
                    self._makeExtensionTransformBody(variant)))

    def _makeExtensionTransformBody(self, variant):
        """Return the body generator of ``transform_<variant>_extension_struct``.

        ``variant`` is bound per call instead of being captured from the loop
        in ``onEnd``.
        """
        def forEachExtensionTransform(ext, castedAccess, cgen):
            if ext.isTransformed:
                directTransform(self.resourceTrackerVarName, ext,
                                lambda _: castedAccess, lambda _: "1", cgen, variant)
            cgen.funcCall(None, self.transformPrefix + variant + "_" + ext.name,
                          [self.resourceTrackerVarName, castedAccess])

        def body(cgen):
            self.emitForEachStructExtension(
                cgen,
                self.voidType,
                STRUCT_EXTENSION_PARAM_FOR_WRITE,
                forEachExtensionTransform)

        return body
349