1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from copy import copy
17import hashlib, sys
18
19from .common.codegen import CodeGen, VulkanAPIWrapper
20from .common.vulkantypes import \
21        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
22
23from .wrapperdefs import VulkanWrapperGenerator
24from .wrapperdefs import VULKAN_STREAM_VAR_NAME
25from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
26from .wrapperdefs import STREAM_RET_TYPE
27from .wrapperdefs import MARSHAL_INPUT_VAR_NAME
28from .wrapperdefs import UNMARSHAL_INPUT_VAR_NAME
29from .wrapperdefs import PARAMETERS_MARSHALING
30from .wrapperdefs import PARAMETERS_MARSHALING_GUEST
31from .wrapperdefs import STYPE_OVERRIDE
32from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
33from .wrapperdefs import API_PREFIX_MARSHAL
34from .wrapperdefs import API_PREFIX_UNMARSHAL
35
36from .marshalingdefs import KNOWN_FUNCTION_OPCODES, CUSTOM_MARSHAL_TYPES
37
38class VulkanMarshalingCodegen(VulkanTypeIterator):
39
40    def __init__(self,
41                 cgen,
42                 streamVarName,
43                 rootTypeVarName,
44                 inputVarName,
45                 marshalPrefix,
46                 direction = "write",
47                 forApiOutput = False,
48                 dynAlloc = False,
49                 mapHandles = True,
50                 handleMapOverwrites = False,
51                 doFiltering = True):
52        self.cgen = cgen
53        self.direction = direction
54        self.processSimple = "write" if self.direction == "write" else "read"
55        self.forApiOutput = forApiOutput
56
57        self.checked = False
58
59        self.streamVarName = streamVarName
60        self.rootTypeVarName = rootTypeVarName
61        self.inputVarName = inputVarName
62        self.marshalPrefix = marshalPrefix
63
64        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = True)
65        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
66        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
67        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.inputVarName)
68        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(
69            t, parentVarName=self.inputVarName)
70        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.inputVarName)
71
72        self.dynAlloc = dynAlloc
73        self.mapHandles = mapHandles
74        self.handleMapOverwrites = handleMapOverwrites
75        self.doFiltering = doFiltering
76
77    def getTypeForStreaming(self, vulkanType):
78        res = copy(vulkanType)
79
80        if not vulkanType.accessibleAsPointer():
81            res = res.getForAddressAccess()
82
83        if vulkanType.staticArrExpr:
84            res = res.getForAddressAccess()
85
86        if self.direction == "write":
87            return res
88        else:
89            return res.getForNonConstAccess()
90
91    def makeCastExpr(self, vulkanType):
92        return "(%s)" % (
93            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))
94
95    def genStreamCall(self, vulkanType, toStreamExpr, sizeExpr):
96        varname = self.streamVarName
97        func = self.processSimple
98        cast = self.makeCastExpr(self.getTypeForStreaming(vulkanType))
99
100        self.cgen.stmt(
101            "%s->%s(%s%s, %s)" % (varname, func, cast, toStreamExpr, sizeExpr))
102
103    def genPrimitiveStreamCall(self, vulkanType, access):
104        varname = self.streamVarName
105
106        self.cgen.streamPrimitive(
107            self.typeInfo,
108            varname,
109            access,
110            vulkanType,
111            direction=self.direction)
112
113    def genHandleMappingCall(self, vulkanType, access, lenAccess):
114
115        if lenAccess is None:
116            lenAccess = "1"
117            handle64Bytes = "8"
118        else:
119            handle64Bytes = "%s * 8" % lenAccess
120
121        handle64Var = self.cgen.var()
122        if lenAccess != "1":
123            self.cgen.beginIf(lenAccess)
124            self.cgen.stmt("uint64_t* %s" % handle64Var)
125            self.cgen.stmt(
126                "%s->alloc((void**)&%s, %s * 8)" % \
127                (self.streamVarName, handle64Var, lenAccess))
128            handle64VarAccess = handle64Var
129            handle64VarType = \
130                makeVulkanTypeSimple(False, "uint64_t", 1, paramName=handle64Var)
131        else:
132            self.cgen.stmt("uint64_t %s" % handle64Var)
133            handle64VarAccess = "&%s" % handle64Var
134            handle64VarType = \
135                makeVulkanTypeSimple(False, "uint64_t", 0, paramName=handle64Var)
136
137        if self.direction == "write":
138            if self.handleMapOverwrites:
139                self.cgen.stmt(
140                    "static_assert(8 == sizeof(%s), \"handle map overwrite requires %s to be 8 bytes long\")" % \
141                            (vulkanType.typeName, vulkanType.typeName))
142                self.cgen.stmt(
143                    "%s->handleMapping()->mapHandles_%s((%s*)%s, %s)" %
144                    (self.streamVarName, vulkanType.typeName, vulkanType.typeName,
145                    access, lenAccess))
146                self.genStreamCall(vulkanType, access, "8 * %s" % lenAccess)
147            else:
148                self.cgen.stmt(
149                    "%s->handleMapping()->mapHandles_%s_u64(%s, %s, %s)" %
150                    (self.streamVarName, vulkanType.typeName,
151                    access,
152                    handle64VarAccess, lenAccess))
153                self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
154        else:
155            self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
156            self.cgen.stmt(
157                "%s->handleMapping()->mapHandles_u64_%s(%s, %s%s, %s)" %
158                (self.streamVarName, vulkanType.typeName,
159                handle64VarAccess,
160                self.makeCastExpr(vulkanType.getForNonConstAccess()), access,
161                lenAccess))
162
163        if lenAccess != "1":
164            self.cgen.endIf()
165
166    def doAllocSpace(self, vulkanType):
167        if self.dynAlloc and self.direction == "read":
168            access = self.exprAccessor(vulkanType)
169            lenAccess = self.lenAccessor(vulkanType)
170            sizeof = self.cgen.sizeofExpr( \
171                         vulkanType.getForValueAccess())
172            if lenAccess:
173                bytesExpr = "%s * %s" % (lenAccess, sizeof)
174            else:
175                bytesExpr = sizeof
176
177            self.cgen.stmt( \
178                "%s->alloc((void**)&%s, %s)" %
179                    (self.streamVarName,
180                     access, bytesExpr))
181
182    def getOptionalStringFeatureExpr(self, vulkanType):
183        streamFeature = vulkanType.getProtectStreamFeature()
184        if streamFeature is None:
185            return None
186        return "%s->getFeatureBits() & %s" % (self.streamVarName, streamFeature)
187
188    def onCheck(self, vulkanType):
189
190        if self.forApiOutput:
191            return
192
193        featureExpr = self.getOptionalStringFeatureExpr(vulkanType);
194
195        self.checked = True
196
197        access = self.exprAccessor(vulkanType)
198
199        needConsistencyCheck = False
200
201        self.cgen.line("// WARNING PTR CHECK")
202        if (self.dynAlloc and self.direction == "read") or self.direction == "write":
203            checkAccess = self.exprAccessor(vulkanType)
204            addrExpr = "&" + checkAccess
205            sizeExpr = self.cgen.sizeofExpr(vulkanType)
206        else:
207            checkName = "check_%s" % vulkanType.paramName
208            self.cgen.stmt("%s %s" % (
209                self.cgen.makeCTypeDecl(vulkanType, useParamName = False), checkName))
210            checkAccess = checkName
211            addrExpr = "&" + checkAccess
212            sizeExpr = self.cgen.sizeofExpr(vulkanType)
213            needConsistencyCheck = True
214
215        if featureExpr is not None:
216            self.cgen.beginIf(featureExpr)
217
218        self.genPrimitiveStreamCall(
219            vulkanType,
220            checkAccess)
221
222        if featureExpr is not None:
223            self.cgen.endIf()
224
225        if featureExpr is not None:
226            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
227        else:
228            self.cgen.beginIf(access)
229
230        if needConsistencyCheck and featureExpr is None:
231            self.cgen.beginIf("!(%s)" % checkName)
232            self.cgen.stmt(
233                "fprintf(stderr, \"fatal: %s inconsistent between guest and host\\n\")" % (access))
234            self.cgen.endIf()
235
236
237    def onCheckWithNullOptionalStringFeature(self, vulkanType):
238        self.cgen.beginIf("%s->getFeatureBits() & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.streamVarName)
239        self.onCheck(vulkanType)
240
241    def endCheckWithNullOptionalStringFeature(self, vulkanType):
242        self.endCheck(vulkanType)
243        self.cgen.endIf()
244        self.cgen.beginElse()
245
246    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
247        self.cgen.endElse()
248
249    def endCheck(self, vulkanType):
250
251        if self.checked:
252            self.cgen.endIf()
253            self.checked = False
254
255    def genFilterFunc(self, filterfunc, env):
256
257        def loop(expr, lambdaEnv={}):
258            def do_func(expr):
259                fnamestr = expr.name.name
260                if "not" == fnamestr:
261                    return "!(%s)" % (loop(expr.args[0], lambdaEnv))
262                if "eq" == fnamestr:
263                    return "(%s == %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
264                if "and" == fnamestr:
265                    return "(%s && %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
266                if "or" == fnamestr:
267                    return "(%s || %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
268                if "bitwise_and" == fnamestr:
269                    return "(%s & %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
270                if "getfield" == fnamestr:
271                    ptrlevels = get_ptrlevels(expr.args[0].val.name)
272                    if ptrlevels == 0:
273                        return "%s.%s" % (loop(expr.args[0], lambdaEnv), expr.args[1].val)
274                    else:
275                        return "(%s(%s)).%s" % ("*" * ptrlevels, loop(expr.args[0], lambdaEnv), expr.args[1].val)
276
277                if "if" == fnamestr:
278                    return "((%s) ? (%s) : (%s))" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv), loop(expr.args[2], lambdaEnv))
279
280                return "%s(%s)" % (fnamestr, ", ".join(map(lambda e: loop(e, lambdaEnv), expr.args)))
281
282            def do_expratom(atomname, lambdaEnv= {}):
283                if lambdaEnv.get(atomname, None) is not None:
284                    return atomname
285
286                enventry = env.get(atomname, None)
287                if None != enventry:
288                    return self.getEnvAccessExpr(atomname)
289                return atomname
290
291            def get_ptrlevels(atomname, lambdaEnv= {}):
292                if lambdaEnv.get(atomname, None) is not None:
293                    return 0
294
295                enventry = env.get(atomname, None)
296                if None != enventry:
297                    return self.getPointerIndirectionLevels(atomname)
298
299                return 0
300
301            def do_exprval(expr, lambdaEnv= {}):
302                expratom = expr.val
303
304                if Atom == type(expratom):
305                    return do_expratom(expratom.name, lambdaEnv)
306
307                return "%s" % expratom
308
309            def do_lambda(expr, lambdaEnv= {}):
310                params = expr.vs
311                body = expr.body
312                newEnv = {}
313
314                for (k, v) in lambdaEnv.items():
315                    newEnv[k] = v
316
317                for p in params:
318                    newEnv[p.name] = p.typ
319
320                return "[](%s) { return %s; }" % (", ".join(list(map(lambda p: "%s %s" % (p.typ, p.name), params))), loop(body, lambdaEnv=newEnv))
321
322            if FuncExpr == type(expr):
323                return do_func(expr)
324            if FuncLambda == type(expr):
325                return do_lambda(expr)
326            elif FuncExprVal == type(expr):
327                return do_exprval(expr)
328
329        return loop(filterfunc)
330
331    def beginFilterGuard(self, vulkanType):
332        if vulkanType.filterVar == None:
333            return
334
335        if self.doFiltering == False:
336            return
337
338        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)
339
340        filterValsExpr = None
341        filterFuncExpr = None
342        filterExpr = None
343
344        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
345
346        if None != vulkanType.filterVals:
347            filterValsExpr = " || ".join(map(lambda filterval: "(%s == %s)" % (filterval, filterVarAccess), vulkanType.filterVals))
348
349        if None != vulkanType.filterFunc:
350            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)
351
352        if None != filterValsExpr and None != filterFuncExpr:
353            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
354        elif None == filterValsExpr and None == filterFuncExpr:
355            # Assume is bool
356            self.cgen.beginIf(filterVarAccess)
357        elif None != filterValsExpr:
358            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterValsExpr))
359        elif None != filterFuncExpr:
360            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterFuncExpr))
361
362    def endFilterGuard(self, vulkanType, cleanupExpr=None):
363        if vulkanType.filterVar == None:
364            return
365
366        if self.doFiltering == False:
367            return
368
369        if cleanupExpr == None:
370            self.cgen.endIf()
371        else:
372            self.cgen.endIf()
373            self.cgen.beginElse()
374            self.cgen.stmt(cleanupExpr)
375            self.cgen.endElse()
376
377    def getEnvAccessExpr(self, varName):
378        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
379
380        if parentEnvEntry != None:
381            isParentMember = parentEnvEntry["structmember"]
382
383            if isParentMember:
384                envAccess = self.exprValueAccessor(list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0])
385            else:
386                envAccess = varName
387            return envAccess
388
389        return None
390
391    def getPointerIndirectionLevels(self, varName):
392        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
393
394        if parentEnvEntry != None:
395            isParentMember = parentEnvEntry["structmember"]
396
397            if isParentMember:
398                return list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0].pointerIndirectionLevels
399            else:
400                return 0
401            return 0
402
403        return 0
404
405
406    def onCompoundType(self, vulkanType):
407
408        access = self.exprAccessor(vulkanType)
409        lenAccess = self.lenAccessor(vulkanType)
410        lenAccessGuard = self.lenAccessorGuard(vulkanType)
411
412        self.beginFilterGuard(vulkanType)
413
414        if vulkanType.pointerIndirectionLevels > 0:
415            self.doAllocSpace(vulkanType)
416
417        if lenAccess is not None:
418            if lenAccessGuard is not None:
419                self.cgen.beginIf(lenAccessGuard)
420            loopVar = "i"
421            access = "%s + %s" % (access, loopVar)
422            forInit = "uint32_t %s = 0" % loopVar
423            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
424            forIncr = "++%s" % loopVar
425            self.cgen.beginFor(forInit, forCond, forIncr)
426
427        accessWithCast = "%s(%s)" % (self.makeCastExpr(
428            self.getTypeForStreaming(vulkanType)), access)
429
430        callParams = [self.streamVarName, self.rootTypeVarName, accessWithCast]
431
432        for (bindName, localName) in vulkanType.binds.items():
433            callParams.append(self.getEnvAccessExpr(localName))
434
435        self.cgen.funcCall(None, self.marshalPrefix + vulkanType.typeName,
436                           callParams)
437
438        if lenAccess is not None:
439            self.cgen.endFor()
440            if lenAccessGuard is not None:
441                self.cgen.endIf()
442
443        if self.direction == "read":
444            self.endFilterGuard(vulkanType, "%s = 0" % self.exprAccessor(vulkanType))
445        else:
446            self.endFilterGuard(vulkanType)
447
448    def onString(self, vulkanType):
449
450        access = self.exprAccessor(vulkanType)
451
452        if self.direction == "write":
453            self.cgen.stmt("%s->putString(%s)" % (self.streamVarName, access))
454        else:
455            castExpr = \
456                self.makeCastExpr( \
457                    self.getTypeForStreaming( \
458                        vulkanType.getForAddressAccess()))
459
460            self.cgen.stmt( \
461                "%s->loadStringInPlace(%s&%s)" % (self.streamVarName, castExpr, access))
462
463    def onStringArray(self, vulkanType):
464
465        access = self.exprAccessor(vulkanType)
466        lenAccess = self.lenAccessor(vulkanType)
467
468        if self.direction == "write":
469            self.cgen.stmt("saveStringArray(%s, %s, %s)" % (self.streamVarName,
470                                                            access, lenAccess))
471        else:
472            castExpr = \
473                self.makeCastExpr( \
474                    self.getTypeForStreaming( \
475                        vulkanType.getForAddressAccess()))
476
477            self.cgen.stmt("%s->loadStringArrayInPlace(%s&%s)" % (self.streamVarName, castExpr, access))
478
479    def onStaticArr(self, vulkanType):
480        access = self.exprValueAccessor(vulkanType)
481        lenAccess = self.lenAccessor(vulkanType)
482        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
483        self.genStreamCall(vulkanType, access, finalLenExpr)
484
485    # Old version VkEncoder may have some sType values conflict with VkDecoder
486    # of new versions. For host decoder, it should not carry the incorrect old
487    # sType values to the |forUnmarshaling| struct. Instead it should overwrite
488    # the sType value.
489    def overwriteSType(self, vulkanType):
490        if self.direction == "read":
491            sTypeParam = copy(vulkanType)
492            sTypeParam.paramName = "sType"
493            sTypeAccess = self.exprAccessor(sTypeParam)
494
495            typeName = vulkanType.parent.typeName
496            if typeName in STYPE_OVERRIDE:
497                self.cgen.stmt("%s = %s" %
498                               (sTypeAccess, STYPE_OVERRIDE[typeName]))
499
500    def onStructExtension(self, vulkanType):
501        self.overwriteSType(vulkanType)
502
503        sTypeParam = copy(vulkanType)
504        sTypeParam.paramName = "sType"
505
506        access = self.exprAccessor(vulkanType)
507        sizeVar = "%s_size" % vulkanType.paramName
508
509        if self.direction == "read":
510            castedAccessExpr = "(%s)(%s)" % ("void*", access)
511        else:
512            castedAccessExpr = access
513
514        sTypeAccess = self.exprAccessor(sTypeParam)
515        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
516                          self.rootTypeVarName)
517        self.cgen.stmt("%s = %s" % (self.rootTypeVarName, sTypeAccess))
518        self.cgen.endIf()
519
520        if self.direction == "read" and self.dynAlloc:
521            self.cgen.stmt("size_t %s" % sizeVar)
522            self.cgen.stmt("%s = %s->getBe32()" % \
523                (sizeVar, self.streamVarName))
524            self.cgen.stmt("%s = nullptr" % access)
525            self.cgen.beginIf(sizeVar)
526            self.cgen.stmt( \
527                    "%s->alloc((void**)&%s, sizeof(VkStructureType))" %
528                    (self.streamVarName, access))
529
530            self.genStreamCall(vulkanType, access, "sizeof(VkStructureType)")
531            self.cgen.stmt("VkStructureType extType = *(VkStructureType*)(%s)" % access)
532            self.cgen.stmt( \
533                "%s->alloc((void**)&%s, %s(%s->getFeatureBits(), %s, %s))" %
534                (self.streamVarName, access, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, self.streamVarName, self.rootTypeVarName, access))
535            self.cgen.stmt("*(VkStructureType*)%s = extType" % access)
536
537            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
538                               [self.streamVarName, self.rootTypeVarName, castedAccessExpr])
539            self.cgen.endIf()
540        else:
541
542            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
543                               [self.streamVarName, self.rootTypeVarName, castedAccessExpr])
544
545
546    def onPointer(self, vulkanType):
547        access = self.exprAccessor(vulkanType)
548
549        lenAccess = self.lenAccessor(vulkanType)
550        lenAccessGuard = self.lenAccessorGuard(vulkanType)
551
552        self.beginFilterGuard(vulkanType)
553        self.doAllocSpace(vulkanType)
554
555        if vulkanType.filterVar != None:
556            print("onPointer Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
557
558        if vulkanType.isHandleType() and self.mapHandles:
559            self.genHandleMappingCall(vulkanType, access, lenAccess)
560        else:
561            if self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
562                if lenAccess is not None:
563                    if lenAccessGuard is not None:
564                        self.cgen.beginIf(lenAccessGuard)
565                    self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
566                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "%s[i]" % access)
567                    self.cgen.endFor()
568                    if lenAccessGuard is not None:
569                        self.cgen.endIf()
570                else:
571                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "(*%s)" % access)
572            else:
573                if lenAccess is not None:
574                    finalLenExpr = "%s * %s" % (
575                        lenAccess, self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
576                else:
577                    finalLenExpr = "%s" % (
578                        self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
579                self.genStreamCall(vulkanType, access, finalLenExpr)
580
581        if self.direction == "read":
582            self.endFilterGuard(vulkanType, "%s = 0" % access)
583        else:
584            self.endFilterGuard(vulkanType)
585
586    def onValue(self, vulkanType):
587        self.beginFilterGuard(vulkanType)
588
589        if vulkanType.isHandleType() and self.mapHandles:
590            access = self.exprAccessor(vulkanType)
591            if vulkanType.filterVar != None:
592                print("onValue Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
593            self.genHandleMappingCall(
594                vulkanType.getForAddressAccess(), access, "1")
595        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
596            access = self.exprPrimitiveValueAccessor(vulkanType)
597            self.genPrimitiveStreamCall(vulkanType, access)
598        else:
599            access = self.exprAccessor(vulkanType)
600            self.genStreamCall(vulkanType, access, self.cgen.sizeofExpr(vulkanType))
601
602        self.endFilterGuard(vulkanType)
603
604    def streamLetParameter(self, structInfo, letParamInfo):
605        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
606        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))
607
608        self.cgen.beginIf(filterFeature)
609
610        if self.direction == "write":
611            bodyExpr = self.currentStructInfo.environment[letParamInfo.paramName]["body"]
612            self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, self.currentStructInfo.environment)))
613
614        self.genPrimitiveStreamCall(letParamInfo, letParamInfo.paramName)
615
616        self.cgen.endIf()
617
618
619class VulkanMarshaling(VulkanWrapperGenerator):
620
621    def __init__(self, module, typeInfo, variant="host"):
622        VulkanWrapperGenerator.__init__(self, module, typeInfo)
623
624        self.cgenHeader = CodeGen()
625        self.cgenImpl = CodeGen()
626
627        self.variant = variant
628
629        self.currentFeature = None
630        self.apiOpcodes = {}
631        self.dynAlloc = self.variant != "guest"
632
633        if self.variant == "guest":
634            self.marshalingParams = PARAMETERS_MARSHALING_GUEST
635        else:
636            self.marshalingParams = PARAMETERS_MARSHALING
637
638        self.writeCodegen = \
639            VulkanMarshalingCodegen(
640                None,
641                VULKAN_STREAM_VAR_NAME,
642                ROOT_TYPE_VAR_NAME,
643                MARSHAL_INPUT_VAR_NAME,
644                API_PREFIX_MARSHAL,
645                direction = "write")
646
647        self.readCodegen = \
648            VulkanMarshalingCodegen(
649                None,
650                VULKAN_STREAM_VAR_NAME,
651                ROOT_TYPE_VAR_NAME,
652                UNMARSHAL_INPUT_VAR_NAME,
653                API_PREFIX_UNMARSHAL,
654                direction = "read",
655                dynAlloc=self.dynAlloc)
656
657        self.knownDefs = {}
658
659        # Begin Vulkan API opcodes from something high
660        # that is not going to interfere with renderControl
661        # opcodes
662        self.beginOpcodeOld = 20000
663        self.endOpcodeOld = 30000
664
665        self.beginOpcode = 200000000
666        self.endOpcode = 300000000
667        self.knownOpcodes = set()
668
669        self.extensionMarshalPrototype = \
670            VulkanAPI(API_PREFIX_MARSHAL + "extension_struct",
671                      STREAM_RET_TYPE,
672                      self.marshalingParams +
673                      [STRUCT_EXTENSION_PARAM])
674
675        self.extensionUnmarshalPrototype = \
676            VulkanAPI(API_PREFIX_UNMARSHAL + "extension_struct",
677                      STREAM_RET_TYPE,
678                      self.marshalingParams +
679                      [STRUCT_EXTENSION_PARAM_FOR_WRITE])
680
681    def onBegin(self,):
682        VulkanWrapperGenerator.onBegin(self)
683        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionMarshalPrototype))
684        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionUnmarshalPrototype))
685
686    def onBeginFeature(self, featureName, featureType):
687        VulkanWrapperGenerator.onBeginFeature(self, featureName, featureType)
688        self.currentFeature = featureName
689
690    def onGenType(self, typeXml, name, alias):
691        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)
692
693        if name in self.knownDefs:
694            return
695
696        category = self.typeInfo.categoryOf(name)
697
698        if category in ["struct", "union"] and alias:
699            self.module.appendHeader(
700                self.cgenHeader.makeFuncAlias(API_PREFIX_MARSHAL + name,
701                                              API_PREFIX_MARSHAL + alias))
702            self.module.appendHeader(
703                self.cgenHeader.makeFuncAlias(API_PREFIX_UNMARSHAL + name,
704                                              API_PREFIX_UNMARSHAL + alias))
705
706        if category in ["struct", "union"] and not alias:
707
708            structInfo = self.typeInfo.structs[name]
709
710            marshalParams = self.marshalingParams + \
711                [makeVulkanTypeSimple(True, name, 1, MARSHAL_INPUT_VAR_NAME)]
712
713            freeParams = []
714            letParams = []
715
716            for (envname, bindingInfo) in list(sorted(structInfo.environment.items(), key = lambda kv: kv[0])):
717                if None == bindingInfo["binding"]:
718                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
719                else:
720                    if not bindingInfo["structmember"]:
721                        letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
722
723            marshalPrototype = \
724                VulkanAPI(API_PREFIX_MARSHAL + name,
725                          STREAM_RET_TYPE,
726                          marshalParams + freeParams)
727
728            marshalPrototypeNoFilter = \
729                VulkanAPI(API_PREFIX_MARSHAL + name,
730                          STREAM_RET_TYPE,
731                          marshalParams)
732
733            def structMarshalingCustom(cgen):
734                self.writeCodegen.cgen = cgen
735                self.writeCodegen.currentStructInfo = structInfo
736                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
737
738                marshalingCode = \
739                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
740                    CUSTOM_MARSHAL_TYPES[name]["marshaling"].format(
741                        streamVarName=self.writeCodegen.streamVarName,
742                        rootTypeVarName=self.writeCodegen.rootTypeVarName,
743                        inputVarName=self.writeCodegen.inputVarName,
744                        newInputVarName=self.writeCodegen.inputVarName + "_new")
745                for line in marshalingCode.split('\n'):
746                    cgen.line(line)
747
748            def structMarshalingDef(cgen):
749                self.writeCodegen.cgen = cgen
750                self.writeCodegen.currentStructInfo = structInfo
751                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
752
753                if category == "struct":
754                    # marshal 'let' parameters first
755                    for letp in letParams:
756                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
757
758                    for member in structInfo.members:
759                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
760                if category == "union":
761                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
762
763            def structMarshalingDefNoFilter(cgen):
764                self.writeCodegen.cgen = cgen
765                self.writeCodegen.currentStructInfo = structInfo
766                self.writeCodegen.doFiltering = False
767                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
768
769                if category == "struct":
770                    # marshal 'let' parameters first
771                    for letp in letParams:
772                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
773
774                    for member in structInfo.members:
775                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
776                if category == "union":
777                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
778                self.writeCodegen.doFiltering = True
779
780            self.module.appendHeader(
781                self.cgenHeader.makeFuncDecl(marshalPrototype))
782
783            if name in CUSTOM_MARSHAL_TYPES:
784                self.module.appendImpl(
785                    self.cgenImpl.makeFuncImpl(
786                        marshalPrototype, structMarshalingCustom))
787            else:
788                self.module.appendImpl(
789                    self.cgenImpl.makeFuncImpl(
790                        marshalPrototype, structMarshalingDef))
791
792            if freeParams != []:
793                self.module.appendHeader(
794                    self.cgenHeader.makeFuncDecl(marshalPrototypeNoFilter))
795                self.module.appendImpl(
796                    self.cgenImpl.makeFuncImpl(
797                        marshalPrototypeNoFilter, structMarshalingDefNoFilter))
798
799            unmarshalPrototype = \
800                VulkanAPI(API_PREFIX_UNMARSHAL + name,
801                          STREAM_RET_TYPE,
802                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)] + freeParams)
803
804            unmarshalPrototypeNoFilter = \
805                VulkanAPI(API_PREFIX_UNMARSHAL + name,
806                          STREAM_RET_TYPE,
807                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)])
808
809            def structUnmarshalingCustom(cgen):
810                self.readCodegen.cgen = cgen
811                self.readCodegen.currentStructInfo = structInfo
812                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
813
814                unmarshalingCode = \
815                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
816                    CUSTOM_MARSHAL_TYPES[name]["unmarshaling"].format(
817                        streamVarName=self.readCodegen.streamVarName,
818                        rootTypeVarName=self.readCodegen.rootTypeVarName,
819                        inputVarName=self.readCodegen.inputVarName,
820                        newInputVarName=self.readCodegen.inputVarName + "_new")
821                for line in unmarshalingCode.split('\n'):
822                    cgen.line(line)
823
824            def structUnmarshalingDef(cgen):
825                self.readCodegen.cgen = cgen
826                self.readCodegen.currentStructInfo = structInfo
827                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
828
829                if category == "struct":
830                    # unmarshal 'let' parameters first
831                    for letp in letParams:
832                        self.readCodegen.streamLetParameter(self.typeInfo, letp)
833
834                    for member in structInfo.members:
835                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
836                if category == "union":
837                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
838
839            def structUnmarshalingDefNoFilter(cgen):
840                self.readCodegen.cgen = cgen
841                self.readCodegen.currentStructInfo = structInfo
842                self.readCodegen.doFiltering = False
843                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
844
845                if category == "struct":
846                    # unmarshal 'let' parameters first
847                    for letp in letParams:
848                        iterateVulkanType(self.typeInfo, letp, self.readCodegen)
849                    for member in structInfo.members:
850                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
851                if category == "union":
852                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
853                self.readCodegen.doFiltering = True
854
855            self.module.appendHeader(
856                self.cgenHeader.makeFuncDecl(unmarshalPrototype))
857
858            if name in CUSTOM_MARSHAL_TYPES:
859                self.module.appendImpl(
860                    self.cgenImpl.makeFuncImpl(
861                        unmarshalPrototype, structUnmarshalingCustom))
862            else:
863                self.module.appendImpl(
864                    self.cgenImpl.makeFuncImpl(
865                        unmarshalPrototype, structUnmarshalingDef))
866
867            if freeParams != []:
868                self.module.appendHeader(
869                    self.cgenHeader.makeFuncDecl(unmarshalPrototypeNoFilter))
870                self.module.appendImpl(
871                    self.cgenImpl.makeFuncImpl(
872                        unmarshalPrototypeNoFilter, structUnmarshalingDefNoFilter))
873
874    def onGenCmd(self, cmdinfo, name, alias):
875        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)
876        if name in KNOWN_FUNCTION_OPCODES:
877            opcode = KNOWN_FUNCTION_OPCODES[name]
878        else:
879            hashCode = hashlib.sha256(name.encode()).hexdigest()[:8]
880            hashInt = int(hashCode, 16)
881            opcode = self.beginOpcode + hashInt % (self.endOpcode - self.beginOpcode)
882            hasHashCollision = False
883            while opcode in self.knownOpcodes:
884                hasHashCollision = True
885                opcode += 1
886            if hasHashCollision:
887                print("Hash collision occurred on function '{}'. "
888                      "Please add the following line to marshalingdefs.py:".format(name), file=sys.stderr)
889                print("----------------------", file=sys.stderr)
890                print("    \"{}\": {},".format(name, opcode), file=sys.stderr)
891                print("----------------------", file=sys.stderr)
892
893        self.module.appendHeader(
894            "#define OP_%s %d\n" % (name, opcode))
895        self.apiOpcodes[name] = (opcode, self.currentFeature)
896        self.knownOpcodes.add(opcode)
897
898    def doExtensionStructMarshalingCodegen(self, cgen, retType, extParam, forEach, funcproto, direction):
899        accessVar = "structAccess"
900        sizeVar = "currExtSize"
901        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
902        cgen.stmt("size_t %s = %s(%s->getFeatureBits(), %s, %s)" % (sizeVar,
903                                                                    EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, extParam.paramName))
904
905        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))
906
907        cgen.line("// unknown struct extension; skip and call on its pNext field");
908        cgen.funcCall(None, funcproto.name, [
909                      "vkStream", ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar])
910        cgen.stmt("return")
911
912        cgen.endIf()
913        cgen.beginElse()
914
915        cgen.line("// known or null extension struct")
916
917        if direction == "write":
918            cgen.stmt("vkStream->putBe32(%s)" % sizeVar)
919        elif not self.dynAlloc:
920            cgen.stmt("vkStream->getBe32()");
921
922        cgen.beginIf("!%s" % (sizeVar))
923        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
924        cgen.stmt("return")
925        cgen.endIf()
926
927        cgen.endIf()
928
929        # Now we can do stream stuff
930        if direction == "write":
931            cgen.stmt("vkStream->write(%s, sizeof(VkStructureType))" % extParam.paramName)
932        elif not self.dynAlloc:
933            cgen.stmt("uint64_t pNext_placeholder")
934            placeholderAccess = "(&pNext_placeholder)"
935            cgen.stmt("vkStream->read((void*)(&pNext_placeholder), sizeof(VkStructureType))")
936            cgen.stmt("(void)pNext_placeholder")
937
938        def fatalDefault(cgen):
939            cgen.line("// fatal; the switch is only taken if the extension struct is known");
940            cgen.stmt("abort()")
941            pass
942
943        self.emitForEachStructExtension(
944            cgen,
945            retType,
946            extParam,
947            forEach,
948            defaultEmit=fatalDefault,
949            rootTypeVar=ROOT_TYPE_PARAM)
950
951    def onEnd(self,):
952        VulkanWrapperGenerator.onEnd(self)
953
954        def forEachExtensionMarshal(ext, castedAccess, cgen):
955            cgen.funcCall(None, API_PREFIX_MARSHAL + ext.name,
956                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
957
958        def forEachExtensionUnmarshal(ext, castedAccess, cgen):
959            cgen.funcCall(None, API_PREFIX_UNMARSHAL + ext.name,
960                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
961
962        self.module.appendImpl(
963            self.cgenImpl.makeFuncImpl(
964                self.extensionMarshalPrototype,
965                lambda cgen: self.doExtensionStructMarshalingCodegen(
966                    cgen,
967                    STREAM_RET_TYPE,
968                    STRUCT_EXTENSION_PARAM,
969                    forEachExtensionMarshal,
970                    self.extensionMarshalPrototype,
971                    "write")))
972
973        self.module.appendImpl(
974            self.cgenImpl.makeFuncImpl(
975                self.extensionUnmarshalPrototype,
976                lambda cgen: self.doExtensionStructMarshalingCodegen(
977                    cgen,
978                    STREAM_RET_TYPE,
979                    STRUCT_EXTENSION_PARAM_FOR_WRITE,
980                    forEachExtensionUnmarshal,
981                    self.extensionUnmarshalPrototype,
982                    "read")))
983
984        opcode2stringPrototype = \
985            VulkanAPI("api_opcode_to_string",
986                          makeVulkanTypeSimple(True, "char", 1, "none"),
987                          [ makeVulkanTypeSimple(True, "uint32_t", 0, "opcode") ])
988
989        self.module.appendHeader(
990            self.cgenHeader.makeFuncDecl(opcode2stringPrototype))
991
992        def emitOpcode2StringImpl(apiOpcodes, cgen):
993            cgen.line("switch(opcode)")
994            cgen.beginBlock()
995
996            currFeature = None
997
998            for (name, (opcodeNum, feature)) in sorted(apiOpcodes.items(), key = lambda x : x[1][0]):
999                if not currFeature:
1000                    cgen.leftline("#ifdef %s" % feature)
1001                    currFeature = feature
1002
1003                if currFeature and feature != currFeature:
1004                    cgen.leftline("#endif")
1005                    cgen.leftline("#ifdef %s" % feature)
1006                    currFeature = feature
1007
1008                cgen.line("case OP_%s:" % name)
1009                cgen.beginBlock()
1010                cgen.stmt("return \"OP_%s\"" % name)
1011                cgen.endBlock()
1012
1013            if currFeature:
1014                cgen.leftline("#endif")
1015
1016            cgen.line("default:")
1017            cgen.beginBlock()
1018            cgen.stmt("return \"OP_UNKNOWN_API_CALL\"")
1019            cgen.endBlock()
1020
1021            cgen.endBlock()
1022
1023        self.module.appendImpl(
1024            self.cgenImpl.makeFuncImpl(
1025                opcode2stringPrototype,
1026                lambda cgen: emitOpcode2StringImpl(self.apiOpcodes, cgen)))
1027
1028        self.module.appendHeader(
1029            "#define OP_vkFirst_old %d\n" % (self.beginOpcodeOld))
1030        self.module.appendHeader(
1031            "#define OP_vkLast_old %d\n" % (self.endOpcodeOld))
1032        self.module.appendHeader(
1033            "#define OP_vkFirst %d\n" % (self.beginOpcode))
1034        self.module.appendHeader(
1035            "#define OP_vkLast %d\n" % (self.endOpcode))
1036