1"""Provides EntityDatabase, a class that keeps track of spec-defined entities and associated macros."""
2
3# Copyright (c) 2018-2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8
9from abc import ABC, abstractmethod
10
11from .shared import (CATEGORIES_WITH_VALIDITY, EXTENSION_CATEGORY,
12                     NON_EXISTENT_MACROS, EntityData)
13from .util import getElemName
14
15
def _entityToDict(data):
    """Build a plain dict from the macro, filename, category, and directory of an entity record.

    Only those four attributes of data are read; the result is suitable
    for JSON serialization as long as their values are.
    """
    exported_fields = ('macro', 'filename', 'category', 'directory')
    return {field: getattr(data, field) for field in exported_fields}
23
24
25class EntityDatabase(ABC):
26    """Parsed and processed information from the registry XML.
27
    Must be subclassed for each specific API.
29    """
30
31    ###
32    # Methods that must be implemented in subclasses.
33    ###
34    @abstractmethod
35    def makeRegistry(self):
36        """Return a Registry object that has already had loadFile() and parseTree() called.
37
38        Called only once during construction.
39        """
40        raise NotImplementedError
41
42    @abstractmethod
43    def getNamePrefix(self):
44        """Return the (two-letter) prefix of all entity names for this API.
45
46        Called only once during construction.
47        """
48        raise NotImplementedError
49
50    @abstractmethod
51    def getPlatformRequires(self):
52        """Return the 'requires' string associated with external/platform definitions.
53
54        This is the string found in the requires attribute of the XML for entities that
55        are externally defined in a platform include file, like the question marks in:
56
57        <type requires="???" name="int8_t"/>
58
59        In Vulkan, this is 'vk_platform'.
60
61        Called only once during construction.
62        """
63        raise NotImplementedError
64
65    ###
66    # Methods that it is optional to **override**
67    ###
68    def getSystemTypes(self):
69        """Return an enumerable of strings that name system types.
70
71        System types use the macro `code`, and they do not generate API/validity includes.
72
73        Called only once during construction.
74        """
75        return []
76
77    def getGeneratedDirs(self):
78        """Return a sequence of strings that are the subdirectories of generates API includes.
79
80        Called only once during construction.
81        """
82        return ['basetypes',
83                'defines',
84                'enums',
85                'flags',
86                'funcpointers',
87                'handles',
88                'protos',
89                'structs']
90
91    def populateMacros(self):
92        """Perform API-specific calls, if any, to self.addMacro() and self.addMacros().
93
94        It is recommended to implement/override this and call
95        self.addMacros(..., ..., [..., "flags"]),
96        since the base implementation, in _basicPopulateMacros(),
97        does not add any macros as pertaining to the category "flags".
98
99        Called only once during construction.
100        """
101        pass
102
103    def populateEntities(self):
104        """Perform API-specific calls, if any, to self.addEntity()."""
105        pass
106
107    def getEntitiesWithoutValidity(self):
108        """Return an enumerable of entity names that do not generate validity includes."""
109        return [self.mixed_case_name_prefix +
110                x for x in ['BaseInStructure', 'BaseOutStructure']]
111
112    def getExclusionSet(self):
113        """Return a set of "support=" attribute strings that should not be included in the database.
114
115        Called only during construction."""
116        return set(('disabled',))
117
118    ###
119    # Methods that it is optional to **extend**
120    ###
121    def handleType(self, name, info, requires):
122        """Add entities, if appropriate, for an item in registry.typedict.
123
124        Called at construction for every name, info in registry.typedict.items()
125        not immediately skipped,
126        to perform the correct associated addEntity() call, if applicable.
127        The contents of the requires attribute, if any, is passed in requires.
128
129        May be extended by API-specific code to handle some cases preferentially,
130        then calling the super implementation to handle the rest.
131        """
132        if requires == self.platform_requires:
133            # Ah, no, don't skip this, it's just in the platform header file.
134            # TODO are these code or basetype?
135            self.addEntity(name, 'code', elem=info.elem, generates=False)
136            return
137
138        protect = info.elem.get('protect')
139        if protect:
140            self.addEntity(protect, 'dlink',
141                           category='configdefines', generates=False)
142
143        alias = info.elem.get('alias')
144        if alias:
145            self.addAlias(name, alias)
146
147        cat = info.elem.get('category')
148        if cat == 'struct':
149            self.addEntity(name, 'slink', elem=info.elem)
150
151        elif cat == 'union':
152            # TODO: is this right?
153            self.addEntity(name, 'slink', elem=info.elem)
154
155        elif cat == 'enum':
156            self.addEntity(
157                name, 'elink', elem=info.elem)
158
159        elif cat == 'handle':
160            self.addEntity(name, 'slink', elem=info.elem,
161                           category='handles')
162
163        elif cat == 'bitmask':
164            self.addEntity(
165                name, 'tlink', elem=info.elem, category='flags')
166
167        elif cat == 'basetype':
168            self.addEntity(name, 'basetype',
169                           elem=info.elem)
170
171        elif cat == 'define':
172            self.addEntity(name, 'dlink', elem=info.elem)
173
174        elif cat == 'funcpointer':
175            self.addEntity(name, 'tlink', elem=info.elem)
176
177        elif cat == 'include':
178            # skip
179            return
180
181        elif cat is None:
182            self.addEntity(name, 'code', elem=info.elem, generates=False)
183
184        else:
185            raise RuntimeError('unrecognized category {}'.format(cat))
186
187    def handleCommand(self, name, info):
188        """Add entities, if appropriate, for an item in registry.cmddict.
189
190        Called at construction for every name, info in registry.cmddict.items().
191        Calls self.addEntity() accordingly.
192        """
193        self.addEntity(name, 'flink', elem=info.elem,
194                       category='commands', directory='protos')
195
196    def handleExtension(self, name, info):
197        """Add entities, if appropriate, for an item in registry.extdict.
198
199        Called at construction for every name, info in registry.extdict.items().
200        Calls self.addEntity() accordingly.
201        """
202        if info.supported in self._supportExclusionSet:
203            # Don't populate with disabled extensions.
204            return
205
206        # Only get the protect strings and name from extensions
207
208        self.addEntity(name, None, category=EXTENSION_CATEGORY,
209                       generates=False)
210        protect = info.elem.get('protect')
211        if protect:
212            self.addEntity(protect, 'dlink',
213                           category='configdefines', generates=False)
214
215    def handleEnumValue(self, name, info):
216        """Add entities, if appropriate, for an item in registry.enumdict.
217
218        Called at construction for every name, info in registry.enumdict.items().
219        Calls self.addEntity() accordingly.
220        """
221        self.addEntity(name, 'ename', elem=info.elem,
222                       category='enumvalues', generates=False)
223
224    ###
225    # END of methods intended to be implemented, overridden, or extended in child classes!
226    ###
227
228    ###
229    # Accessors
230    ###
231    def findMacroAndEntity(self, macro, entity):
232        """Look up EntityData by macro and entity pair.
233
234        Does **not** resolve aliases."""
235        return self._byMacroAndEntity.get((macro, entity))
236
237    def findEntity(self, entity):
238        """Look up EntityData by entity name (case-sensitive).
239
240        If it fails, it will try resolving aliases.
241        """
242        result = self._byEntity.get(entity)
243        if result:
244            return result
245
246        alias_set = self._aliasSetsByEntity.get(entity)
247        if alias_set:
248            for alias in alias_set:
249                if alias in self._byEntity:
250                    return self.findEntity(alias)
251
252            assert(not "Alias without main entry!")
253
254        return None
255
256    def findEntityCaseInsensitive(self, entity):
257        """Look up EntityData by entity name (case-insensitive).
258
259        Does **not** resolve aliases."""
260        return self._byLowercaseEntity.get(entity.lower())
261
262    def getMemberElems(self, commandOrStruct):
263        """Given a command or struct name, retrieve the ETree elements for each member/param.
264
265        Returns None if the entity is not found or doesn't have members/params.
266        """
267        data = self.findEntity(commandOrStruct)
268
269        if not data:
270            return None
271        if data.elem is None:
272            return None
273        if data.macro == 'slink':
274            tag = 'member'
275        else:
276            tag = 'param'
277        return data.elem.findall('.//{}'.format(tag))
278
279    def getMemberNames(self, commandOrStruct):
280        """Given a command or struct name, retrieve the names of each member/param.
281
282        Returns an empty list if the entity is not found or doesn't have members/params.
283        """
284        members = self.getMemberElems(commandOrStruct)
285        if not members:
286            return []
287        ret = []
288        for member in members:
289            name_tag = member.find('name')
290            if name_tag:
291                ret.append(name_tag.text)
292        return ret
293
294    def getEntityJson(self):
295        """Dump the internal entity dictionary to JSON for debugging."""
296        import json
297        d = {entity: _entityToDict(data)
298             for entity, data in self._byEntity.items()}
299        return json.dumps(d, sort_keys=True, indent=4)
300
301    def entityHasValidity(self, entity):
302        """Estimate if we expect to see a validity include for an entity name.
303
304        Returns None if the entity name is not known,
305        otherwise a boolean: True if a validity include is expected.
306
307        Related to Generator.isStructAlwaysValid.
308        """
309        data = self.findEntity(entity)
310        if not data:
311            return None
312
313        if entity in self.entities_without_validity:
314            return False
315
316        if data.category == 'protos':
317            # All protos have validity
318            return True
319
320        if data.category not in CATEGORIES_WITH_VALIDITY:
321            return False
322
323        # Handle structs here.
324        members = self.getMemberElems(entity)
325        if not members:
326            return None
327        for member in members:
328            member_name = getElemName(member)
329            member_type = member.find('type').text
330            member_category = member.get('category')
331
332            if member_name in ('next', 'type'):
333                return True
334
335            if member_type in ('void', 'char'):
336                return True
337
338            if member.get('noautovalidity'):
339                # Not generating validity for this member, skip it
340                continue
341
342            if member.get('len'):
343                # Array
344                return True
345
346            typetail = member.find('type').tail
347            if typetail and '*' in typetail:
348                # Pointer
349                return True
350
351            if member_category in ('handle', 'enum', 'bitmask'):
352                return True
353
354            if member.get('category') in ('struct', 'union') \
355                    and self.entityHasValidity(member_type):
356                # struct or union member - recurse
357                return True
358
359        # Got this far - no validity needed
360        return False
361
362    def entityGenerates(self, entity_name):
363        """Return True if the named entity generates include file(s)."""
364        return entity_name in self._generating_entities
365
366    @property
367    def generating_entities(self):
368        """Return a sequence of all generating entity names."""
369        return self._generating_entities.keys()
370
371    def shouldBeRecognized(self, macro, entity_name):
372        """Determine, based on the macro and the name provided, if we should expect to recognize the entity.
373
374        True if it is linked. Specific APIs may also provide additional cases where it is True."""
375        return self.isLinkedMacro(macro)
376
377    def likelyRecognizedEntity(self, entity_name):
378        """Guess (based on name prefix alone) if an entity is likely to be recognized."""
379        return entity_name.lower().startswith(self.name_prefix)
380
381    def isLinkedMacro(self, macro):
382        """Identify if a macro is considered a "linked" macro."""
383        return macro in self._linkedMacros
384
385    def isValidMacro(self, macro):
386        """Identify if a macro is known and valid."""
387        if macro not in self._categoriesByMacro:
388            return False
389
390        return macro not in NON_EXISTENT_MACROS
391
392    def getCategoriesForMacro(self, macro):
393        """Identify the categories associated with a (known, valid) macro."""
394        if macro in self._categoriesByMacro:
395            return self._categoriesByMacro[macro]
396        return None
397
398    def areAliases(self, first_entity_name, second_entity_name):
399        """Return true if the two entity names are equivalent (aliases of each other)."""
400        alias_set = self._aliasSetsByEntity.get(first_entity_name)
401        if not alias_set:
402            # If this assert fails, we have goofed in addAlias
403            assert(second_entity_name not in self._aliasSetsByEntity)
404
405            return False
406
407        return second_entity_name in alias_set
408
409    @property
410    def macros(self):
411        """Return the collection of all known entity-related markup macros."""
412        return self._categoriesByMacro.keys()
413
414    def childTypes(self, typename):
415        """Return the list of types specifying typename as their parent type."""
416        children = [childname
417                    for childname, entity in self._byEntity.items()
418                    if entity.elem is not None and entity.elem.get("parentstruct") == typename]
419        return children
420
421    ###
422    # Methods only used during initial setup/population of this data structure
423    ###
424    def addMacro(self, macro, categories, link=False):
425        """Add a single markup macro to the collection of categories by macro.
426
427        Also adds the macro to the set of linked macros if link=True.
428
429        If a macro has already been supplied to a call, later calls for that macro have no effect.
430        """
431        if macro in self._categoriesByMacro:
432            return
433        self._categoriesByMacro[macro] = categories
434        if link:
435            self._linkedMacros.add(macro)
436
437    def addMacros(self, letter, macroTypes, categories):
438        """Add markup macros associated with a leading letter to the collection of categories by macro.
439
440        Also, those macros created using 'link' in macroTypes will also be added to the set of linked macros.
441
442        Basically automates a number of calls to addMacro().
443        """
444        for macroType in macroTypes:
445            macro = letter + macroType
446            self.addMacro(macro, categories, link=(macroType == 'link'))
447
448    def addAlias(self, entityName, aliasName):
449        """Record that entityName is an alias for aliasName."""
450        # See if we already have something with this as the alias.
451        alias_set = self._aliasSetsByEntity.get(aliasName)
452        other_alias_set = self._aliasSetsByEntity.get(entityName)
453        if alias_set and other_alias_set:
454            # If this fails, we need to merge sets and update.
455            assert(alias_set is other_alias_set)
456
457        if not alias_set:
458            # Try looking by the other name.
459            alias_set = other_alias_set
460
461        if not alias_set:
462            # Nope, this is a new set.
463            alias_set = set()
464            self._aliasSets.append(alias_set)
465
466        # Add both names to the set
467        alias_set.add(entityName)
468        alias_set.add(aliasName)
469
470        # Associate the set with each name
471        self._aliasSetsByEntity[aliasName] = alias_set
472        self._aliasSetsByEntity[entityName] = alias_set
473
474    def addEntity(self, entityName, macro, category=None, elem=None,
475                  generates=None, directory=None, filename=None):
476        """Add an entity (command, structure type, enum, enum value, etc) in the database.
477
478        If an entityName has already been supplied to a call, later calls for that entityName have no effect.
479
480        Arguments:
481        entityName -- the name of the entity.
482        macro -- the macro (without the trailing colon) that should be used to refer to this entity.
483
484        Optional keyword arguments:
485        category -- If not manually specified, looked up based on the macro.
486        elem -- The ETree element associated with the entity in the registry XML.
487        generates -- Indicates whether this entity generates api and validity include files.
488                     Default depends on directory (or if not specified, category).
489        directory -- The directory that include files (under api/ and validity/) are generated in.
490                     If not specified (and generates is True), the default is the same as the category,
491                     which is almost always correct.
492        filename -- The relative filename (under api/ or validity/) where includes are generated for this.
493                    This only matters if generates is True (default). If not specified and generates is True,
494                    one will be generated based on directory and entityName.
495        """
496        # Probably dealt with in handleType(), but just in case it wasn't.
497        if elem is not None:
498            alias = elem.get('alias')
499            if alias:
500                self.addAlias(entityName, alias)
501
502        if entityName in self._byEntity:
503            # skip if already recorded.
504            return
505
506        # Look up category based on the macro, if category isn't specified.
507        if category is None:
508            category = self._categoriesByMacro.get(macro)[0]
509
510        if generates is None:
511            potential_dir = directory or category
512            generates = potential_dir in self._generated_dirs
513
514        # If directory isn't specified and this entity generates,
515        # the directory is the same as the category.
516        if directory is None and generates:
517            directory = category
518
519        # Don't generate a filename if this entity doesn't generate includes.
520        if filename is None and generates:
521            filename = f'{directory}/{entityName}.adoc'
522
523        data = EntityData(
524            entity=entityName,
525            macro=macro,
526            elem=elem,
527            filename=filename,
528            category=category,
529            directory=directory
530        )
531        if entityName.lower() not in self._byLowercaseEntity:
532            self._byLowercaseEntity[entityName.lower()] = []
533
534        self._byEntity[entityName] = data
535        self._byLowercaseEntity[entityName.lower()].append(data)
536        self._byMacroAndEntity[(macro, entityName)] = data
537        if generates and filename is not None:
538            self._generating_entities[entityName] = data
539
    def __init__(self):
        """Constructor: Do not extend or override.

        Sets up the internal lookup tables, queries the subclass hooks for
        API-specific configuration, then populates macros and entities
        (subclass-provided ones first, defaults afterwards).

        Changing the behavior of other parts of this logic should be done by
        implementing, extending, or overriding (as documented):

        - Implement makeRegistry()
        - Implement getNamePrefix()
        - Implement getPlatformRequires()
        - Override getSystemTypes()
        - Override getGeneratedDirs()
        - Override getEntitiesWithoutValidity()
        - Override getExclusionSet()
        - Override populateMacros()
        - Override populateEntities()
        - Extend handleType()
        - Extend handleCommand()
        - Extend handleExtension()
        - Extend handleEnumValue()
        """
        # Internal data that we don't want consumers of the class touching for fear of
        # breaking invariants
        # Entity name -> EntityData.
        self._byEntity = {}
        # Lower-cased entity name -> list of EntityData.
        self._byLowercaseEntity = {}
        # (macro, entity name) -> EntityData.
        self._byMacroAndEntity = {}
        # Macro -> categories, as passed to addMacro().
        self._categoriesByMacro = {}
        # Macros registered with link=True.
        self._linkedMacros = set()
        # Entity name -> the set of names it is an alias of (shared between members).
        self._aliasSetsByEntity = {}
        # Every alias set created by addAlias().
        self._aliasSets = []

        # Created lazily by the registry property.
        self._registry = None

        # Retrieve from subclass, if overridden, then store locally.
        self._supportExclusionSet = set(self.getExclusionSet())

        # Entities that get a generated/api/category/entity.adoc file.
        self._generating_entities = {}

        # Name prefix members
        # The prefix is stored lower-cased; the mixed-case variant capitalizes
        # only its first character.
        self.name_prefix = self.getNamePrefix().lower()
        self.mixed_case_name_prefix = self.name_prefix[:1].upper(
        ) + self.name_prefix[1:]
        # Regex string for the name prefix that is case-insensitive:
        # one two-character class (upper and lower case) per prefix character.
        self.case_insensitive_name_prefix_pattern = ''.join(
            ('[{}{}]'.format(c.upper(), c) for c in self.name_prefix))

        self.platform_requires = self.getPlatformRequires()

        self._generated_dirs = set(self.getGeneratedDirs())

        # Note: Default impl requires self.mixed_case_name_prefix
        self.entities_without_validity = set(self.getEntitiesWithoutValidity())

        # TODO: Where should flags actually go? Not mentioned in the style guide.
        # TODO: What about flag wildcards? There are a few such uses...

        # Optional hook (not abstract): lets the subclass define its own macros,
        # e.g. for flags. It runs first so that its definitions take precedence,
        # since addMacro() ignores repeated registrations of a macro.
        self.populateMacros()

        # Now, do default macro population
        self._basicPopulateMacros()

        # Optional hook (not abstract): lets the subclass add any "not from the
        # registry" (and not system type) entities. It runs first because
        # addEntity() ignores repeated registrations of an entity name.
        self.populateEntities()

        # Now, do default entity population
        # (accessing self.registry is what triggers makeRegistry()).
        self._basicPopulateEntities(self.registry)
605
606    ###
607    # Methods only used internally during initial setup/population of this data structure
608    ###
609    @property
610    def registry(self):
611        """Return a Registry."""
612        if not self._registry:
613            self._registry = self.makeRegistry()
614        return self._registry
615
616    def _basicPopulateMacros(self):
617        """Contains calls to self.addMacro() and self.addMacros().
618
619        If you need to change any of these, do so in your override of populateMacros(),
620        which will be called first.
621        """
622        self.addMacro('basetype', ['basetypes'])
623        self.addMacro('code', ['code'])
624        self.addMacros('f', ['link', 'name', 'text'], ['protos'])
625        self.addMacros('s', ['link', 'name', 'text'], ['structs', 'handles'])
626        self.addMacros('e', ['link', 'name', 'text'], ['enums'])
627        self.addMacros('p', ['name', 'text'], ['parameter', 'member'])
628        self.addMacros('t', ['link', 'name'], ['funcpointers'])
629        self.addMacros('d', ['link', 'name'], ['defines', 'configdefines'])
630
631        for macro in NON_EXISTENT_MACROS:
632            # Still search for them
633            self.addMacro(macro, None)
634
635    def _basicPopulateEntities(self, registry):
636        """Contains typical calls to self.addEntity().
637
638        If you need to change any of these, do so in your override of populateEntities(),
639        which will be called first.
640        """
641        system_types = set(self.getSystemTypes())
642        for t in system_types:
643            self.addEntity(t, 'code', generates=False)
644
645        for name, info in registry.typedict.items():
646            if name in system_types:
647                # We already added these.
648                continue
649
650            requires = info.elem.get('requires')
651
652            if requires and not requires.lower().startswith(self.name_prefix):
653                # This is an externally-defined type, will skip it.
654                continue
655
656            # OK, we might actually add an entity here
657            self.handleType(name=name, info=info, requires=requires)
658
659        for name, info in registry.enumdict.items():
660            self.handleEnumValue(name, info)
661
662        for name, info in registry.cmddict.items():
663            self.handleCommand(name, info)
664
665        for name, info in registry.extdict.items():
666            self.handleExtension(name, info)
667