1# Copyright (c) 2018 The Android Open Source Project
2# Copyright (c) 2018 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15from copy import copy
16
17from .common.codegen import CodeGen, VulkanAPIWrapper
18from .common.vulkantypes import \
19        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
20
21from .wrapperdefs import VulkanWrapperGenerator
22from .wrapperdefs import VULKAN_STREAM_VAR_NAME
23from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
24from .wrapperdefs import STREAM_RET_TYPE
25from .wrapperdefs import MARSHAL_INPUT_VAR_NAME
26from .wrapperdefs import UNMARSHAL_INPUT_VAR_NAME
27from .wrapperdefs import PARAMETERS_MARSHALING
28from .wrapperdefs import PARAMETERS_MARSHALING_GUEST
29from .wrapperdefs import STYPE_OVERRIDE
30from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
31from .wrapperdefs import API_PREFIX_RESERVEDMARSHAL
32from .wrapperdefs import API_PREFIX_RESERVEDUNMARSHAL
33
34from .marshalingdefs import CUSTOM_MARSHAL_TYPES
35class VulkanReservedMarshalingCodegen(VulkanTypeIterator):
36    def __init__(self,
37                 cgen,
38                 variant,
39                 streamVarName,
40                 rootTypeVarName,
41                 inputVarName,
42                 ptrVarName,
43                 marshalPrefix,
44                 handlemapPrefix,
45                 direction = "write",
46                 forApiOutput = False,
47                 dynAlloc = False,
48                 mapHandles = True,
49                 handleMapOverwrites = False,
50                 doFiltering = True,
51                 stackVar=None,
52                 stackArrSize=None):
53        self.cgen = cgen
54        self.variant = variant
55        self.direction = direction
56        self.processSimple = "write" if self.direction == "write" else "read"
57        self.forApiOutput = forApiOutput
58
59        self.checked = False
60
61        self.streamVarName = streamVarName
62        self.rootTypeVarName = rootTypeVarName
63        self.inputVarName = inputVarName
64        self.ptrVar = ptrVarName
65        self.marshalPrefix = marshalPrefix
66        self.handlemapPrefix = handlemapPrefix
67
68        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = True)
69        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
70        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
71        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.inputVarName)
72        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(
73            t, parentVarName=self.inputVarName)
74        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.inputVarName)
75
76        self.dynAlloc = dynAlloc
77        self.mapHandles = mapHandles
78        self.handleMapOverwrites = handleMapOverwrites
79        self.doFiltering = doFiltering
80
81        self.stackVar = stackVar
82        self.stackArrSize = stackArrSize
83
84    def getTypeForStreaming(self, vulkanType):
85        res = copy(vulkanType)
86
87        if not vulkanType.accessibleAsPointer():
88            res = res.getForAddressAccess()
89
90        if vulkanType.staticArrExpr:
91            res = res.getForAddressAccess()
92
93        if self.direction == "write":
94            return res
95        else:
96            return res.getForNonConstAccess()
97
98    def makeCastExpr(self, vulkanType):
99        return "(%s)" % (
100            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))
101
102    def genPtrIncr(self, sizeExpr):
103        self.cgen.stmt("*%s += %s" % (self.ptrVar, sizeExpr))
104
105    def genMemcpyAndIncr(self, varname, cast, toStreamExpr, sizeExpr, toBe = False, actualSize = 4):
106        if self.direction == "write":
107            self.cgen.stmt("memcpy(*%s, %s%s, %s)" % (varname, cast, toStreamExpr, sizeExpr))
108        else:
109            self.cgen.stmt("memcpy(%s%s, *%s, %s)" % (cast, toStreamExpr, varname, sizeExpr))
110
111        if toBe:
112            streamPrefix = "to"
113            if "read" == self.direction:
114                streamPrefix = "from"
115
116            streamMethod = streamPrefix
117
118            if 1 == actualSize:
119                streamMethod += "Byte"
120            elif 2 == actualSize:
121                streamMethod += "Be16"
122            elif 4 == actualSize:
123                streamMethod += "Be32"
124            elif 8 == actualSize:
125                streamMethod += "Be64"
126            else:
127                pass
128
129            streamNamespace = "gfxstream::guest" if self.variant == "guest" else "android::base"
130            if self.direction == "write":
131                self.cgen.stmt("%s::Stream::%s((uint8_t*)*%s)" % (streamNamespace, streamMethod, varname))
132            else:
133                self.cgen.stmt("%s::Stream::%s((uint8_t*)%s)" % (streamNamespace, streamMethod, toStreamExpr))
134
135        self.genPtrIncr(sizeExpr)
136
137    def genStreamCall(self, vulkanType, toStreamExpr, sizeExpr):
138        varname = self.ptrVar
139        cast = self.makeCastExpr(self.getTypeForStreaming(vulkanType))
140        self.genMemcpyAndIncr(varname, cast, toStreamExpr, sizeExpr)
141
142    def genPrimitiveStreamCall(self, vulkanType, access):
143        varname = self.ptrVar
144        self.cgen.memcpyPrimitive(
145            self.typeInfo,
146            "(*" + varname + ")",
147            access,
148            vulkanType,
149            self.variant,
150            direction=self.direction)
151        self.genPtrIncr(str(self.cgen.countPrimitive(
152            self.typeInfo,
153            vulkanType)))
154
155    def genHandleMappingCall(self, vulkanType, access, lenAccess, lenAccessGuard):
156
157        if lenAccess is None:
158            lenAccess = "1"
159            handle64Bytes = "8"
160        else:
161            handle64Bytes = "%s * 8" % lenAccess
162
163        handle64Var = self.cgen.var()
164        if lenAccess != "1":
165            self.cgen.beginIf(lenAccess)
166            self.cgen.stmt("uint8_t* %s_ptr = (uint8_t*)(*%s)" % (handle64Var, self.ptrVar))
167            handle64VarAccess = handle64Var
168            handle64VarType = \
169                makeVulkanTypeSimple(False, "uint64_t", 1, paramName=handle64Var)
170        else:
171            self.cgen.stmt("uint64_t %s" % handle64Var)
172            handle64VarAccess = "&%s" % handle64Var
173            handle64VarType = \
174                makeVulkanTypeSimple(False, "uint64_t", 0, paramName=handle64Var)
175
176        if "" == self.handlemapPrefix:
177            mapFunc = ("(%s)" % vulkanType.typeName)
178            mapFunc64 = ("(%s)" % "uint64_t")
179        else:
180            mapFunc = self.handlemapPrefix + vulkanType.typeName
181            mapFunc64 = mapFunc
182
183        if self.direction == "write":
184            if self.handleMapOverwrites:
185                self.cgen.stmt(
186                    "static_assert(8 == sizeof(%s), \"handle map overwrite requres %s to be 8 bytes long\")" % \
187                            (vulkanType.typeName, vulkanType.typeName))
188                if "1" == lenAccess:
189                    self.cgen.stmt("*%s = (%s)%s(*%s)" % (access, vulkanType.typeName, mapFunc, access))
190                    self.genStreamCall(vulkanType, access, "8 * %s" % lenAccess)
191                else:
192                    if lenAccessGuard is not None:
193                        self.cgen.beginIf(lenAccessGuard)
194                    self.cgen.beginFor("uint32_t k = 0", "k < %s" % lenAccess, "++k")
195                    self.cgen.stmt("%s[k] = (%s)%s(%s[k])" % (access, vulkanType.typeName, mapFunc, access))
196                    self.cgen.endFor()
197                    if lenAccessGuard is not None:
198                        self.cgen.endIf()
199                    self.genPtrIncr("8 * %s" % lenAccess)
200            else:
201                if "1" == lenAccess:
202                    self.cgen.stmt("*%s = %s((*%s))" % (handle64VarAccess, mapFunc64, access))
203                    self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
204                else:
205                    if lenAccessGuard is not None:
206                        self.cgen.beginIf(lenAccessGuard)
207                    self.cgen.beginFor("uint32_t k = 0", "k < %s" % lenAccess, "++k")
208                    self.cgen.stmt("uint64_t tmpval = %s(%s[k])" % (mapFunc64, access))
209                    self.cgen.stmt("memcpy(%s_ptr + k * 8, &tmpval, sizeof(uint64_t))" % (handle64Var))
210                    self.cgen.endFor()
211                    if lenAccessGuard is not None:
212                        self.cgen.endIf()
213                    self.genPtrIncr("8 * %s" % lenAccess)
214        else:
215            if "1" == lenAccess:
216                self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
217                self.cgen.stmt("*%s%s = (%s)%s((%s)(*%s))" % (
218                    self.makeCastExpr(vulkanType.getForNonConstAccess()), access,
219                    vulkanType.typeName, mapFunc, vulkanType.typeName, handle64VarAccess))
220            else:
221                self.genPtrIncr("8 * %s" % lenAccess)
222                if lenAccessGuard is not None:
223                    self.cgen.beginIf(lenAccessGuard)
224                self.cgen.beginFor("uint32_t k = 0", "k < %s" % lenAccess, "++k")
225                self.cgen.stmt("uint64_t tmpval; memcpy(&tmpval, %s_ptr + k * 8, sizeof(uint64_t))" % handle64Var)
226                self.cgen.stmt("*((%s%s) + k) = (%s)%s((%s)tmpval)" % (
227                    self.makeCastExpr(vulkanType.getForNonConstAccess()), access,
228                    vulkanType.typeName, mapFunc, vulkanType.typeName))
229                if lenAccessGuard is not None:
230                    self.cgen.endIf()
231                self.cgen.endFor()
232
233        if lenAccess != "1":
234            self.cgen.endIf()
235
236    def doAllocSpace(self, vulkanType):
237        if self.dynAlloc and self.direction == "read":
238            access = self.exprAccessor(vulkanType)
239            lenAccess = self.lenAccessor(vulkanType)
240            sizeof = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
241            if lenAccess:
242                bytesExpr = "%s * %s" % (lenAccess, sizeof)
243            else:
244                bytesExpr = sizeof
245                lenAccess = "1"
246
247            if self.stackVar:
248                if self.stackArrSize != lenAccess:
249                    self.cgen.beginIf("%s <= %s" % (lenAccess, self.stackArrSize))
250
251                self.cgen.stmt(
252                        "%s = %s%s" % (access, self.makeCastExpr(vulkanType.getForNonConstAccess()), self.stackVar))
253
254                if self.stackArrSize != lenAccess:
255                    self.cgen.endIf()
256                    self.cgen.beginElse()
257
258                if self.stackArrSize != lenAccess:
259                    self.cgen.stmt(
260                            "%s->alloc((void**)&%s, %s)" %
261                            (self.streamVarName,
262                                access, bytesExpr))
263
264                if self.stackArrSize != lenAccess:
265                    self.cgen.endIf()
266            else:
267                self.cgen.stmt(
268                        "%s->alloc((void**)&%s, %s)" %
269                        (self.streamVarName,
270                            access, bytesExpr))
271
272    def getOptionalStringFeatureExpr(self, vulkanType):
273        streamFeature = vulkanType.getProtectStreamFeature()
274        if streamFeature is None:
275            return None
276        return "%s->getFeatureBits() & %s" % (self.streamVarName, streamFeature)
277
278    def onCheck(self, vulkanType):
279        if self.forApiOutput:
280            return
281
282        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)
283        self.checked = True
284        access = self.exprAccessor(vulkanType)
285        needConsistencyCheck = False
286
287        self.cgen.line("// WARNING PTR CHECK")
288        if (self.dynAlloc and self.direction == "read") or self.direction == "write":
289            checkAccess = self.exprAccessor(vulkanType)
290            addrExpr = "&" + checkAccess
291            sizeExpr = self.cgen.sizeofExpr(vulkanType)
292        else:
293            checkName = "check_%s" % vulkanType.paramName
294            self.cgen.stmt("%s %s" % (
295                self.cgen.makeCTypeDecl(vulkanType, useParamName = False), checkName))
296            checkAccess = checkName
297            addrExpr = "&" + checkAccess
298            sizeExpr = self.cgen.sizeofExpr(vulkanType)
299            needConsistencyCheck = True
300
301        if featureExpr is not None:
302            self.cgen.beginIf(featureExpr)
303
304        self.genPrimitiveStreamCall(
305            vulkanType,
306            checkAccess)
307
308        if featureExpr is not None:
309            self.cgen.endIf()
310
311        if featureExpr is not None:
312            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
313        else:
314            self.cgen.beginIf(access)
315
316        if needConsistencyCheck and featureExpr is None:
317            self.cgen.beginIf("!(%s)" % checkName)
318            self.cgen.stmt(
319                "fprintf(stderr, \"fatal: %s inconsistent between guest and host\\n\")" % (access))
320            self.cgen.endIf()
321
322    def onCheckWithNullOptionalStringFeature(self, vulkanType):
323        self.cgen.beginIf("%s->getFeatureBits() & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.streamVarName)
324        self.onCheck(vulkanType)
325
326    def endCheckWithNullOptionalStringFeature(self, vulkanType):
327        self.endCheck(vulkanType)
328        self.cgen.endIf()
329        self.cgen.beginElse()
330
331    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
332        self.cgen.endElse()
333
334    def endCheck(self, vulkanType):
335
336        if self.checked:
337            self.cgen.endIf()
338            self.checked = False
339
340    def genFilterFunc(self, filterfunc, env):
341
342        def loop(expr, lambdaEnv={}):
343            def do_func(expr):
344                fnamestr = expr.name.name
345                if "not" == fnamestr:
346                    return "!(%s)" % (loop(expr.args[0], lambdaEnv))
347                if "eq" == fnamestr:
348                    return "(%s == %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
349                if "and" == fnamestr:
350                    return "(%s && %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
351                if "or" == fnamestr:
352                    return "(%s || %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
353                if "bitwise_and" == fnamestr:
354                    return "(%s & %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
355                if "getfield" == fnamestr:
356                    ptrlevels = get_ptrlevels(expr.args[0].val.name)
357                    if ptrlevels == 0:
358                        return "%s.%s" % (loop(expr.args[0], lambdaEnv), expr.args[1].val)
359                    else:
360                        return "(%s(%s)).%s" % ("*" * ptrlevels, loop(expr.args[0], lambdaEnv), expr.args[1].val)
361
362                if "if" == fnamestr:
363                    return "((%s) ? (%s) : (%s))" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv), loop(expr.args[2], lambdaEnv))
364
365                return "%s(%s)" % (fnamestr, ", ".join(map(lambda e: loop(e, lambdaEnv), expr.args)))
366
367            def do_expratom(atomname, lambdaEnv= {}):
368                if lambdaEnv.get(atomname, None) is not None:
369                    return atomname
370
371                enventry = env.get(atomname, None)
372                if None != enventry:
373                    return self.getEnvAccessExpr(atomname)
374                return atomname
375
376            def get_ptrlevels(atomname, lambdaEnv= {}):
377                if lambdaEnv.get(atomname, None) is not None:
378                    return 0
379
380                enventry = env.get(atomname, None)
381                if None != enventry:
382                    return self.getPointerIndirectionLevels(atomname)
383
384                return 0
385
386            def do_exprval(expr, lambdaEnv= {}):
387                expratom = expr.val
388
389                if Atom == type(expratom):
390                    return do_expratom(expratom.name, lambdaEnv)
391
392                return "%s" % expratom
393
394            def do_lambda(expr, lambdaEnv= {}):
395                params = expr.vs
396                body = expr.body
397                newEnv = {}
398
399                for (k, v) in lambdaEnv.items():
400                    newEnv[k] = v
401
402                for p in params:
403                    newEnv[p.name] = p.typ
404
405                return "[](%s) { return %s; }" % (", ".join(list(map(lambda p: "%s %s" % (p.typ, p.name), params))), loop(body, lambdaEnv=newEnv))
406
407            if FuncExpr == type(expr):
408                return do_func(expr)
409            if FuncLambda == type(expr):
410                return do_lambda(expr)
411            elif FuncExprVal == type(expr):
412                return do_exprval(expr)
413
414        return loop(filterfunc)
415
416    def beginFilterGuard(self, vulkanType):
417        if vulkanType.filterVar == None:
418            return
419
420        if self.doFiltering == False:
421            return
422
423        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)
424
425        filterValsExpr = None
426        filterFuncExpr = None
427        filterExpr = None
428
429        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
430
431        if None != vulkanType.filterVals:
432            filterValsExpr = " || ".join(map(lambda filterval: "(%s == %s)" % (filterval, filterVarAccess), vulkanType.filterVals))
433
434        if None != vulkanType.filterFunc:
435            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)
436
437        if None != filterValsExpr and None != filterFuncExpr:
438            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
439        elif None == filterValsExpr and None == filterFuncExpr:
440            # Assume is bool
441            self.cgen.beginIf(filterVarAccess)
442        elif None != filterValsExpr:
443            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterValsExpr))
444        elif None != filterFuncExpr:
445            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterFuncExpr))
446
447    def endFilterGuard(self, vulkanType, cleanupExpr=None):
448        if vulkanType.filterVar == None:
449            return
450
451        if self.doFiltering == False:
452            return
453
454        if cleanupExpr == None:
455            self.cgen.endIf()
456        else:
457            self.cgen.endIf()
458            self.cgen.beginElse()
459            self.cgen.stmt(cleanupExpr)
460            self.cgen.endElse()
461
462    def getEnvAccessExpr(self, varName):
463        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
464
465        if parentEnvEntry != None:
466            isParentMember = parentEnvEntry["structmember"]
467
468            if isParentMember:
469                envAccess = self.exprValueAccessor(list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0])
470            else:
471                envAccess = varName
472            return envAccess
473
474        return None
475
476    def getPointerIndirectionLevels(self, varName):
477        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
478
479        if parentEnvEntry != None:
480            isParentMember = parentEnvEntry["structmember"]
481
482            if isParentMember:
483                return list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0].pointerIndirectionLevels
484            else:
485                return 0
486            return 0
487
488        return 0
489
490
491    def onCompoundType(self, vulkanType):
492
493        access = self.exprAccessor(vulkanType)
494        lenAccess = self.lenAccessor(vulkanType)
495
496        self.beginFilterGuard(vulkanType)
497
498        if vulkanType.pointerIndirectionLevels > 0:
499            self.doAllocSpace(vulkanType)
500
501        if lenAccess is not None:
502            loopVar = "i"
503            access = "%s + %s" % (access, loopVar)
504            forInit = "uint32_t %s = 0" % loopVar
505            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
506            forIncr = "++%s" % loopVar
507            self.cgen.beginFor(forInit, forCond, forIncr)
508
509        accessWithCast = "%s(%s)" % (self.makeCastExpr(
510            self.getTypeForStreaming(vulkanType)), access)
511
512        callParams = [self.streamVarName, self.rootTypeVarName, accessWithCast, self.ptrVar]
513
514        for (bindName, localName) in vulkanType.binds.items():
515            callParams.append(self.getEnvAccessExpr(localName))
516
517        self.cgen.funcCall(None, self.marshalPrefix + vulkanType.typeName,
518                           callParams)
519
520        if lenAccess is not None:
521            self.cgen.endFor()
522
523        if self.direction == "read":
524            self.endFilterGuard(vulkanType, "%s = 0" % self.exprAccessor(vulkanType))
525        else:
526            self.endFilterGuard(vulkanType)
527
528    def onString(self, vulkanType):
529        access = self.exprAccessor(vulkanType)
530
531        if self.direction == "write":
532            self.cgen.beginBlock()
533            self.cgen.stmt("uint32_t l = %s ? strlen(%s): 0" % (access, access))
534            self.genMemcpyAndIncr(self.ptrVar, "(uint32_t*)" ,"&l", "sizeof(uint32_t)", toBe = True, actualSize = 4)
535            self.genMemcpyAndIncr(self.ptrVar, "(char*)", access, "l")
536            self.cgen.endBlock()
537        else:
538            castExpr = \
539                self.makeCastExpr( \
540                    self.getTypeForStreaming( \
541                        vulkanType.getForAddressAccess()))
542            self.cgen.stmt( \
543                "%s->loadStringInPlaceWithStreamPtr(%s&%s, %s)" % (self.streamVarName, castExpr, access, self.ptrVar))
544
545    def onStringArray(self, vulkanType):
546        access = self.exprAccessor(vulkanType)
547        lenAccess = self.lenAccessor(vulkanType)
548        lenAccessGuard = self.lenAccessorGuard(vulkanType)
549
550        if self.direction == "write":
551            self.cgen.beginBlock()
552
553            self.cgen.stmt("uint32_t c = 0")
554            if lenAccessGuard is not None:
555                self.cgen.beginIf(lenAccessGuard)
556            self.cgen.stmt("c = %s" % (lenAccess))
557            if lenAccessGuard is not None:
558                self.cgen.endIf()
559            self.genMemcpyAndIncr(self.ptrVar, "(uint32_t*)" ,"&c", "sizeof(uint32_t)", toBe = True, actualSize = 4)
560
561            self.cgen.beginFor("uint32_t i = 0", "i < c", "++i")
562            self.cgen.stmt("uint32_t l = %s ? strlen(%s[i]): 0" % (access, access))
563            self.genMemcpyAndIncr(self.ptrVar, "(uint32_t*)" ,"&l", "sizeof(uint32_t)", toBe = True, actualSize = 4)
564            self.cgen.beginIf("l")
565            self.genMemcpyAndIncr(self.ptrVar, "(char*)", "(%s[i])" % access, "l")
566            self.cgen.endIf()
567            self.cgen.endFor()
568
569            self.cgen.endBlock()
570        else:
571            castExpr = \
572                self.makeCastExpr( \
573                    self.getTypeForStreaming( \
574                        vulkanType.getForAddressAccess()))
575
576            self.cgen.stmt("%s->loadStringArrayInPlaceWithStreamPtr(%s&%s, %s)" % (self.streamVarName, castExpr, access, self.ptrVar))
577
578    def onStaticArr(self, vulkanType):
579        access = self.exprValueAccessor(vulkanType)
580        lenAccess = self.lenAccessor(vulkanType)
581        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
582        self.genStreamCall(vulkanType, access, finalLenExpr)
583
    # Older VkEncoder versions may emit sType values that conflict with newer
    # VkDecoder versions. The host decoder must not carry such stale sType
    # values into the |forUnmarshaling| struct; instead, it overwrites the
    # sType value.
588    def overwriteSType(self, vulkanType):
589        if self.direction == "read":
590            sTypeParam = copy(vulkanType)
591            sTypeParam.paramName = "sType"
592            sTypeAccess = self.exprAccessor(sTypeParam)
593
594            typeName = vulkanType.parent.typeName
595            if typeName in STYPE_OVERRIDE:
596                self.cgen.stmt("%s = %s" %
597                               (sTypeAccess, STYPE_OVERRIDE[typeName]))
598
    def onStructExtension(self, vulkanType):
        """Emit (un)marshaling of the struct-extension member (presumably `pNext`).

        Sequence of emitted code:
          1. overwriteSType() (read side only, if an override exists).
          2. If the root-type variable is still VK_STRUCTURE_TYPE_MAX_ENUM,
             set it to this struct's sType.
          3. Read side with dynAlloc: read a uint32 `<member>_size`, null the
             member, and when the size is non-zero: allocate room for a
             VkStructureType, read it from the buffer, reallocate the member
             to the size returned by EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME,
             restore the sType, and call `<marshalPrefix>extension_struct`.
          4. Otherwise: call `<marshalPrefix>extension_struct` directly.
        The extension pointer is passed as `(void*)` when reading.
        """
        self.overwriteSType(vulkanType)

        # Accessor for the sibling `sType` field of the same struct.
        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        # Name of the emitted local holding the streamed extension size.
        sizeVar = "%s_size" % vulkanType.paramName

        if self.direction == "read":
            castedAccessExpr = "(%s)(%s)" % ("void*", access)
        else:
            castedAccessExpr = access

        # Record the outermost struct's sType as the root type if unset.
        sTypeAccess = self.exprAccessor(sTypeParam)
        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" % self.rootTypeVarName)
        self.cgen.stmt("%s = %s" % (self.rootTypeVarName, sTypeAccess))
        self.cgen.endIf()

        if self.direction == "read" and self.dynAlloc:
            self.cgen.stmt("uint32_t %s" % sizeVar)

            self.genMemcpyAndIncr(self.ptrVar, "(uint32_t*)", "&" + sizeVar, "sizeof(uint32_t)", toBe = True, actualSize = 4)

            # A zero size leaves the member null.
            self.cgen.stmt("%s = nullptr" % access)
            self.cgen.beginIf(sizeVar)
            # First allocate just enough to read the extension's sType ...
            self.cgen.stmt( \
                    "%s->alloc((void**)&%s, sizeof(VkStructureType))" %
                    (self.streamVarName, access))

            self.genStreamCall(vulkanType, access, "sizeof(VkStructureType)")
            self.cgen.stmt("VkStructureType extType = *(VkStructureType*)(%s)" % access)
            # ... then reallocate to the full extension size for that sType
            # and put the sType back into the new storage.
            self.cgen.stmt( \
                    "%s->alloc((void**)&%s, %s(%s->getFeatureBits(), %s, %s))" %
                    (self.streamVarName, access, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, self.streamVarName, self.rootTypeVarName, access))
            self.cgen.stmt("*(VkStructureType*)%s = extType" % access)

            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                [self.streamVarName, self.rootTypeVarName, castedAccessExpr, self.ptrVar])
            self.cgen.endIf()
        else:

            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                [self.streamVarName, self.rootTypeVarName, castedAccessExpr, self.ptrVar])
643
644    def onPointer(self, vulkanType):
645        access = self.exprAccessor(vulkanType)
646
647        lenAccess = self.lenAccessor(vulkanType)
648        lenAccessGuard = self.lenAccessorGuard(vulkanType)
649
650        self.beginFilterGuard(vulkanType)
651        self.doAllocSpace(vulkanType)
652
653        if vulkanType.filterVar != None:
654            print("onPointer Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
655
656        if vulkanType.isHandleType() and self.mapHandles:
657            self.genHandleMappingCall(
658                vulkanType, access, lenAccess, lenAccessGuard)
659        else:
660            if self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
661                if lenAccess is not None:
662                    if lenAccessGuard is not None:
663                        self.cgen.beginIf(lenAccessGuard)
664                    self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
665                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "%s[i]" % access)
666                    self.cgen.endFor()
667                    if lenAccessGuard is not None:
668                        self.cgen.endIf()
669                else:
670                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "(*%s)" % access)
671            else:
672                if lenAccess is not None:
673                    finalLenExpr = "%s * %s" % (
674                        lenAccess, self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
675                else:
676                    finalLenExpr = "%s" % (
677                        self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
678                self.genStreamCall(vulkanType, access, finalLenExpr)
679
680        if self.direction == "read":
681            self.endFilterGuard(vulkanType, "%s = 0" % access)
682        else:
683            self.endFilterGuard(vulkanType)
684
685    def onValue(self, vulkanType):
686        self.beginFilterGuard(vulkanType)
687
688        if vulkanType.isHandleType() and self.mapHandles:
689            access = self.exprAccessor(vulkanType)
690            if vulkanType.filterVar != None:
691                print("onValue Needs filter: %s filterVar %s" % (access, vulkanType.filterVar))
692            self.genHandleMappingCall(
693                vulkanType.getForAddressAccess(), access, "1", None)
694        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
695            access = self.exprPrimitiveValueAccessor(vulkanType)
696            self.genPrimitiveStreamCall(vulkanType, access)
697        else:
698            access = self.exprAccessor(vulkanType)
699            self.genStreamCall(vulkanType, access, self.cgen.sizeofExpr(vulkanType))
700
701        self.endFilterGuard(vulkanType)
702
703    def streamLetParameter(self, structInfo, letParamInfo):
704        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
705        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))
706
707        self.cgen.beginIf(filterFeature)
708
709        if self.direction == "write":
710            bodyExpr = self.currentStructInfo.environment[letParamInfo.paramName]["body"]
711            self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, self.currentStructInfo.environment)))
712
713        self.genPrimitiveStreamCall(letParamInfo, letParamInfo.paramName)
714
715        self.cgen.endIf()
716
717class VulkanReservedMarshaling(VulkanWrapperGenerator):
718
719    def __init__(self, module, typeInfo, variant="host"):
720        VulkanWrapperGenerator.__init__(self, module, typeInfo)
721
722        self.cgenHeader = CodeGen()
723        self.cgenImpl = CodeGen()
724
725        self.variant = variant
726
727        self.currentFeature = None
728        self.apiOpcodes = {}
729        self.dynAlloc = self.variant != "guest"
730        self.ptrVarName = "ptr"
731        self.ptrVarType = makeVulkanTypeSimple(False, "uint8_t", 2, self.ptrVarName)
732        self.ptrVarTypeUnmarshal = makeVulkanTypeSimple(False, "uint8_t", 2, self.ptrVarName)
733
734        if self.variant == "guest":
735            self.marshalingParams = PARAMETERS_MARSHALING_GUEST
736        else:
737            self.marshalingParams = PARAMETERS_MARSHALING
738
739        self.writeCodegen = \
740            VulkanReservedMarshalingCodegen(
741                None,
742                self.variant,
743                VULKAN_STREAM_VAR_NAME,
744                ROOT_TYPE_VAR_NAME,
745                MARSHAL_INPUT_VAR_NAME,
746                self.ptrVarName,
747                API_PREFIX_RESERVEDMARSHAL,
748                "get_host_u64_" if "guest" == self.variant else "",
749                direction = "write")
750
751        self.readCodegen = \
752            VulkanReservedMarshalingCodegen(
753                None,
754                self.variant,
755                VULKAN_STREAM_VAR_NAME,
756                ROOT_TYPE_VAR_NAME,
757                UNMARSHAL_INPUT_VAR_NAME,
758                self.ptrVarName,
759                API_PREFIX_RESERVEDUNMARSHAL,
760                "unbox_" if "host" == self.variant else "",
761                direction = "read",
762                dynAlloc=self.dynAlloc)
763
764        self.knownDefs = {}
765
766        self.extensionMarshalPrototype = \
767            VulkanAPI(API_PREFIX_RESERVEDMARSHAL + "extension_struct",
768                      STREAM_RET_TYPE,
769                      self.marshalingParams +
770                      [STRUCT_EXTENSION_PARAM, self.ptrVarType])
771
772        self.extensionUnmarshalPrototype = \
773            VulkanAPI(API_PREFIX_RESERVEDUNMARSHAL + "extension_struct",
774                      STREAM_RET_TYPE,
775                      self.marshalingParams +
776                      [STRUCT_EXTENSION_PARAM_FOR_WRITE, self.ptrVarTypeUnmarshal])
777
778    def onBegin(self,):
779        VulkanWrapperGenerator.onBegin(self)
780        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionMarshalPrototype))
781        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionUnmarshalPrototype))
782
783    def onBeginFeature(self, featureName, featureType):
784        VulkanWrapperGenerator.onBeginFeature(self, featureName, featureType)
785        self.currentFeature = featureName
786
787    def onGenType(self, typeXml, name, alias):
788        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)
789
790        if name in self.knownDefs:
791            return
792
793        category = self.typeInfo.categoryOf(name)
794
795        if category in ["struct", "union"] and alias:
796            if self.variant != "host":
797                self.module.appendHeader(
798                    self.cgenHeader.makeFuncAlias(API_PREFIX_RESERVEDMARSHAL + name,
799                                                  API_PREFIX_RESERVEDMARSHAL + alias))
800            if self.variant != "guest":
801                self.module.appendHeader(
802                    self.cgenHeader.makeFuncAlias(API_PREFIX_RESERVEDUNMARSHAL + name,
803                                                  API_PREFIX_RESERVEDUNMARSHAL + alias))
804
805        if category in ["struct", "union"] and not alias:
806
807            structInfo = self.typeInfo.structs[name]
808
809            marshalParams = self.marshalingParams + \
810                [makeVulkanTypeSimple(True, name, 1, MARSHAL_INPUT_VAR_NAME),
811                        self.ptrVarType]
812
813            freeParams = []
814            letParams = []
815
816            for (envname, bindingInfo) in list(sorted(structInfo.environment.items(), key = lambda kv: kv[0])):
817                if None == bindingInfo["binding"]:
818                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
819                else:
820                    if not bindingInfo["structmember"]:
821                        letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
822
823            marshalPrototype = \
824                VulkanAPI(API_PREFIX_RESERVEDMARSHAL + name,
825                          STREAM_RET_TYPE,
826                          marshalParams + freeParams)
827
828            marshalPrototypeNoFilter = \
829                VulkanAPI(API_PREFIX_RESERVEDMARSHAL + name,
830                          STREAM_RET_TYPE,
831                          marshalParams)
832
833            def structMarshalingCustom(cgen):
834                self.writeCodegen.cgen = cgen
835                self.writeCodegen.currentStructInfo = structInfo
836                marshalingCode = \
837                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
838                    CUSTOM_MARSHAL_TYPES[name]["reservedmarshaling"].format(
839                        streamVarName=self.writeCodegen.streamVarName,
840                        rootTypeVarName=self.writeCodegen.rootTypeVarName,
841                        inputVarName=self.writeCodegen.inputVarName,
842                        newInputVarName=self.writeCodegen.inputVarName + "_new")
843                for line in marshalingCode.split('\n'):
844                    cgen.line(line)
845
846            def structMarshalingDef(cgen):
847                self.writeCodegen.cgen = cgen
848                self.writeCodegen.currentStructInfo = structInfo
849                self.writeCodegen.cgen.stmt("(void)%s" % VULKAN_STREAM_VAR_NAME)
850                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
851
852                if category == "struct":
853                    # marshal 'let' parameters first
854                    for letp in letParams:
855                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
856
857                    for member in structInfo.members:
858                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
859                if category == "union":
860                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
861
862            def structMarshalingDefNoFilter(cgen):
863                self.writeCodegen.cgen = cgen
864                self.writeCodegen.currentStructInfo = structInfo
865                self.writeCodegen.doFiltering = False
866                self.writeCodegen.cgen.stmt("(void)%s" % VULKAN_STREAM_VAR_NAME)
867                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
868
869                if category == "struct":
870                    # marshal 'let' parameters first
871                    for letp in letParams:
872                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
873
874                    for member in structInfo.members:
875                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
876                if category == "union":
877                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
878                self.writeCodegen.doFiltering = True
879
880            if self.variant != "host":
881                self.module.appendHeader(
882                    self.cgenHeader.makeFuncDecl(marshalPrototype))
883
884                if name in CUSTOM_MARSHAL_TYPES:
885                    self.module.appendImpl(
886                        self.cgenImpl.makeFuncImpl(
887                            marshalPrototype, structMarshalingCustom))
888                else:
889                    self.module.appendImpl(
890                        self.cgenImpl.makeFuncImpl(
891                            marshalPrototype, structMarshalingDef))
892
893                if freeParams != []:
894                    self.module.appendHeader(
895                        self.cgenHeader.makeFuncDecl(marshalPrototypeNoFilter))
896                    self.module.appendImpl(
897                        self.cgenImpl.makeFuncImpl(
898                            marshalPrototypeNoFilter, structMarshalingDefNoFilter))
899
900            unmarshalPrototype = \
901                VulkanAPI(API_PREFIX_RESERVEDUNMARSHAL + name,
902                          STREAM_RET_TYPE,
903                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME), self.ptrVarTypeUnmarshal] + freeParams)
904
905            unmarshalPrototypeNoFilter = \
906                VulkanAPI(API_PREFIX_RESERVEDUNMARSHAL + name,
907                          STREAM_RET_TYPE,
908                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME), self.ptrVarTypeUnmarshal])
909
910            def structUnmarshalingCustom(cgen):
911                self.readCodegen.cgen = cgen
912                self.readCodegen.currentStructInfo = structInfo
913                unmarshalingCode = \
914                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
915                    CUSTOM_MARSHAL_TYPES[name]["reservedunmarshaling"].format(
916                        streamVarName=self.readCodegen.streamVarName,
917                        rootTypeVarName=self.readCodegen.rootTypeVarName,
918                        inputVarName=self.readCodegen.inputVarName,
919                        newInputVarName=self.readCodegen.inputVarName + "_new")
920                for line in unmarshalingCode.split('\n'):
921                    cgen.line(line)
922
923            def structUnmarshalingDef(cgen):
924                self.readCodegen.cgen = cgen
925                self.readCodegen.currentStructInfo = structInfo
926                if category == "struct":
927                    # unmarshal 'let' parameters first
928                    for letp in letParams:
929                        self.readCodegen.streamLetParameter(self.typeInfo, letp)
930
931                    for member in structInfo.members:
932                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
933                if category == "union":
934                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
935
936            def structUnmarshalingDefNoFilter(cgen):
937                self.readCodegen.cgen = cgen
938                self.readCodegen.currentStructInfo = structInfo
939                self.readCodegen.doFiltering = False
940                if category == "struct":
941                    # unmarshal 'let' parameters first
942                    for letp in letParams:
943                        iterateVulkanType(self.typeInfo, letp, self.readCodegen)
944                    for member in structInfo.members:
945                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
946                if category == "union":
947                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
948                self.readCodegen.doFiltering = True
949
950            if self.variant != "guest":
951                self.module.appendHeader(
952                    self.cgenHeader.makeFuncDecl(unmarshalPrototype))
953
954                if name in CUSTOM_MARSHAL_TYPES:
955                    self.module.appendImpl(
956                        self.cgenImpl.makeFuncImpl(
957                            unmarshalPrototype, structUnmarshalingCustom))
958                else:
959                    self.module.appendImpl(
960                        self.cgenImpl.makeFuncImpl(
961                            unmarshalPrototype, structUnmarshalingDef))
962
963                if freeParams != []:
964                    self.module.appendHeader(
965                        self.cgenHeader.makeFuncDecl(unmarshalPrototypeNoFilter))
966                    self.module.appendImpl(
967                        self.cgenImpl.makeFuncImpl(
968                            unmarshalPrototypeNoFilter, structUnmarshalingDefNoFilter))
969
970    def onGenCmd(self, cmdinfo, name, alias):
971        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)
972
973    def doExtensionStructMarshalingCodegen(self, cgen, retType, extParam, forEach, funcproto, direction):
974        accessVar = "structAccess"
975        sizeVar = "currExtSize"
976        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
977        cgen.stmt("uint32_t %s = %s(%s->getFeatureBits(), %s, %s)" % (sizeVar, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, extParam.paramName))
978
979        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))
980
981        cgen.line("// unknown struct extension; skip and call on its pNext field");
982        cgen.funcCall(None, funcproto.name, ["vkStream", ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar, self.ptrVarName])
983        cgen.stmt("return")
984
985        cgen.endIf()
986        cgen.beginElse()
987
988        cgen.line("// known or null extension struct")
989
990        streamNamespace = "gfxstream::guest" if self.variant == "guest" else "android::base"
991
992        if direction == "write":
993            cgen.stmt("memcpy(*%s, &%s, sizeof(uint32_t));" % (self.ptrVarName, sizeVar))
994            cgen.stmt("%s::Stream::toBe32((uint8_t*)*%s); *%s += sizeof(uint32_t)" % (streamNamespace, self.ptrVarName, self.ptrVarName))
995        elif not self.dynAlloc:
996            cgen.stmt("memcpy(&%s, *%s, sizeof(uint32_t));" % (sizeVar, self.ptrVarName))
997            cgen.stmt("%s::Stream::fromBe32((uint8_t*)&%s); *%s += sizeof(uint32_t)" % (streamNamespace, sizeVar, self.ptrVarName))
998
999        cgen.beginIf("!%s" % (sizeVar))
1000        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
1001        cgen.stmt("return")
1002        cgen.endIf()
1003
1004        cgen.endIf()
1005
1006        # Now we can do stream stuff
1007        if direction == "write":
1008            cgen.stmt("memcpy(*%s, %s, sizeof(VkStructureType)); *%s += sizeof(VkStructureType)" % (self.ptrVarName, extParam.paramName, self.ptrVarName))
1009        elif not self.dynAlloc:
1010            cgen.stmt("uint64_t pNext_placeholder")
1011            placeholderAccess = "(&pNext_placeholder)"
1012            cgen.stmt("memcpy(%s, *%s, sizeof(VkStructureType)); *%s += sizeof(VkStructureType)" % (placeholderAccess, self.ptrVarName, self.ptrVarName))
1013            cgen.stmt("(void)pNext_placeholder")
1014
1015        def fatalDefault(cgen):
1016            cgen.line("// fatal; the switch is only taken if the extension struct is known");
1017            cgen.stmt("abort()")
1018            pass
1019
1020        self.emitForEachStructExtension(
1021            cgen,
1022            retType,
1023            extParam,
1024            forEach,
1025            defaultEmit=fatalDefault,
1026            rootTypeVar=ROOT_TYPE_PARAM)
1027
1028    def onEnd(self,):
1029        VulkanWrapperGenerator.onEnd(self)
1030
1031        def forEachExtensionMarshal(ext, castedAccess, cgen):
1032            cgen.funcCall(None, API_PREFIX_RESERVEDMARSHAL + ext.name,
1033                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess, self.ptrVarName])
1034
1035        def forEachExtensionUnmarshal(ext, castedAccess, cgen):
1036            cgen.funcCall(None, API_PREFIX_RESERVEDUNMARSHAL + ext.name,
1037                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess, self.ptrVarName])
1038
1039        if self.variant != "host":
1040            self.module.appendImpl(
1041                self.cgenImpl.makeFuncImpl(
1042                    self.extensionMarshalPrototype,
1043                    lambda cgen: self.doExtensionStructMarshalingCodegen(
1044                        cgen,
1045                        STREAM_RET_TYPE,
1046                        STRUCT_EXTENSION_PARAM,
1047                        forEachExtensionMarshal,
1048                        self.extensionMarshalPrototype,
1049                        "write")))
1050
1051        if self.variant != "guest":
1052            self.module.appendImpl(
1053                self.cgenImpl.makeFuncImpl(
1054                    self.extensionUnmarshalPrototype,
1055                    lambda cgen: self.doExtensionStructMarshalingCodegen(
1056                        cgen,
1057                        STREAM_RET_TYPE,
1058                        STRUCT_EXTENSION_PARAM_FOR_WRITE,
1059                        forEachExtensionUnmarshal,
1060                        self.extensionUnmarshalPrototype,
1061                        "read")))
1062