1#!/usr/bin/env python
2#
3# Copyright (C) 2022 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Verify that one set of hidden API flags is a subset of another."""
17import dataclasses
18import typing
19
20from itertools import chain
21
22
@dataclasses.dataclass()
class Node:
    """Common base for every node stored in the signature trie.

    Subclasses must provide append_values() and child_nodes(); the versions
    defined here only raise NotImplementedError.
    """

    # The kind of node this is.
    #
    # Leaf nodes are of type "member".
    # Interior nodes can be either "package", or "class".
    type: str

    # The selector of the node.
    #
    # A string that can be used to select this node, e.g. as the pattern
    # passed to InteriorNode.get_matching_rows().
    selector: str

    def values(self, selector):
        """Collect the values of this node and of its selected children.

        A fresh list is created and handed to append_values(), which is
        responsible for filling it in.

        :param selector: a function that is applied to a key in the nodes
            attribute to decide whether that child's values are wanted.

        :return: the list populated by append_values().
        """
        collected = []
        self.append_values(collected, selector)
        return collected

    def append_values(self, values, selector):
        """Append this node's values, and those of its children, to a list.

        For each item (key, child) in nodes the child's values are added if
        and only if the selector returns True when called on its key. A
        child's values are all the values associated with it and with all of
        its descendant nodes.

        :param values: the list that the values are appended to.
        :param selector: a function that is applied to a key in the nodes
            attribute to decide whether that child's values are wanted.
        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please Implement this method")

    def child_nodes(self):
        """Return an iterable of this node's direct children.

        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Please Implement this method")
69
70
71# pylint: disable=line-too-long
@dataclasses.dataclass()
class InteriorNode(Node):
    """An interior node in a trie.

    Each interior node has a dict that maps from an element of a signature to
    either another interior node or a leaf. Each interior node represents either
    a package, class or nested class. Class members are represented by a Leaf.

    An element is a (type, value) tuple, e.g. ("package", "java"). In the
    diagrams below it is abbreviated to "type:value".

    Associating the set of flags [public-api] with the signature
    "Ljava/lang/Object;->String()Ljava/lang/String;" will cause the following
    nodes to be created:
    Node()
    ^- package:java -> Node()
       ^- package:lang -> Node()
           ^- class:Object -> Node()
              ^- member:String()Ljava/lang/String; -> Leaf([public-api])

    Associating the set of flags [blocked,core-platform-api] with the signature
    "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;"
    will cause the following nodes to be created:
    Node()
    ^- package:java -> Node()
       ^- package:lang -> Node()
           ^- class:Character -> Node()
              ^- class:UnicodeScript -> Node()
                 ^- member:of(I)Ljava/lang/Character$UnicodeScript;
                    -> Leaf([blocked,core-platform-api])
    """

    # pylint: enable=line-too-long

    # A dict from an element of the signature, i.e. a (type, value) tuple, to
    # the Node/Leaf containing the next element/value.
    nodes: typing.Dict[str, Node] = dataclasses.field(default_factory=dict)

    # pylint: disable=line-too-long
    @staticmethod
    def signature_to_elements(signature):
        """Split a signature or a prefix into a list of (type, value) tuples:

        1. The packages (excluding the leading L preceding the first package).
        2. The class names, from outermost to innermost.
        3. The member signature.
        e.g.
        Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
        will be broken down into these elements:
        1. ("package", "java")
        2. ("package", "lang")
        3. ("class", "Character")
        4. ("class", "UnicodeScript")
        5. ("member", "of(I)Ljava/lang/Character$UnicodeScript;")

        If the last "/" separated part is "*" or "**" then it is returned as a
        trailing ("wildcard", ...) element instead of any class elements.

        :param signature: the signature, class, package or package wildcard.
        :return: a non-empty list of (type, value) tuples.
        :raises Exception: if the wildcard is anything other than "*" or "**",
            or if a wildcard is combined with a member signature.
        """
        # Remove the leading L.
        #  - java/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
        text = signature.removeprefix("L")
        # Split the signature between qualified class name and the class member
        # signature.
        #  0 - java/lang/Character$UnicodeScript
        #  1 - of(I)Ljava/lang/Character$UnicodeScript;
        parts = text.split(";->")
        # If there is no member then this will be an empty list.
        member = parts[1:]
        # Split the qualified class name into packages, and class name.
        #  0 - java
        #  1 - lang
        #  2 - Character$UnicodeScript
        elements = parts[0].split("/")
        # Everything before the last part is a package, whether the last part
        # is a class name or a wildcard.
        packages = elements[:-1]
        last_element = elements[-1]
        wildcard = []
        classes = []
        if "*" in last_element:
            if last_element not in ("*", "**"):
                raise Exception(f"Invalid signature '{signature}': invalid "
                                f"wildcard '{last_element}'")
            # Cannot specify a wildcard and target a specific member
            if member:
                raise Exception(f"Invalid signature '{signature}': contains "
                                f"wildcard '{last_element}' and "
                                f"member signature '{member[0]}'")
            wildcard = [last_element]
        else:
            # Split the class name into outer / inner classes, dropping the
            # trailing ";" that is present when a class is given on its own.
            #  0 - Character
            #  1 - UnicodeScript
            classes = last_element.removesuffix(";").split("$")

        # Assemble the parts into a single list, tagging each part with its
        # type. If a wildcard is provided then it looks something like this:
        #  0 - ("package", "java")
        #  1 - ("package", "lang")
        #  2 - ("wildcard", "*")
        #
        # Otherwise, it looks something like this:
        #  0 - ("package", "java")
        #  1 - ("package", "lang")
        #  2 - ("class", "Character")
        #  3 - ("class", "UnicodeScript")
        #  4 - ("member", "of(I)Ljava/lang/Character$UnicodeScript;")
        return list(
            chain([("package", x) for x in packages],
                  [("class", x) for x in classes],
                  [("member", x) for x in member],
                  [("wildcard", x) for x in wildcard]))

    # pylint: enable=line-too-long

    @staticmethod
    def split_element(element):
        """Unpack an element into its (type, value) pair.

        :param element: a two item iterable as produced by
            signature_to_elements().
        :return: a tuple of the element's type and its value.
        """
        element_type, element_value = element
        return element_type, element_value

    @staticmethod
    def element_type(element):
        """Return just the type of an element, e.g. "package" or "member"."""
        element_type, _ = InteriorNode.split_element(element)
        return element_type

    @staticmethod
    def elements_to_selector(elements):
        """Compute a selector for a set of elements.

        A selector uniquely identifies a specific Node in the trie. It is
        essentially a prefix of a signature (without the leading L).

        e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
        would contain nodes with the following selectors:
        * "java"
        * "java/lang"
        * "java/lang/Object"
        * "java/lang/Object;->String()Ljava/lang/String;"

        :param elements: an iterable of (type, value) tuples.
        :return: the selector string; empty if there are no elements.
        """
        signature = ""
        preceding_type = ""
        for element in elements:
            element_type, element_value = InteriorNode.split_element(element)
            if element_type == "member":
                separator = ";->"
            elif element_type == "class" and preceding_type == "class":
                # A nested class is separated from its outer class by a $.
                separator = "$"
            elif element_type in ("package", "class", "wildcard"):
                separator = "/"
            else:
                separator = ""

            # No separator is wanted before the very first part.
            if signature:
                signature += separator

            signature += element_value

            preceding_type = element_type

        return signature

    def add(self, signature, value, only_if_matches=False):
        """Associate the value with the specific signature.

        The signature is validated before any node is created, so a rejected
        signature leaves the trie untouched.

        :param signature: the member signature
        :param value: the value to associated with the signature
        :param only_if_matches: True if the value is added only if the first
             element of the signature, e.g. its top level package, is already
             present in this node. If it is not then nothing is added and no
             validation of the signature is performed.
        :return: n/a
        :raises Exception: if the signature does not identify a specific
            member, or if a value has already been added for it.
        """
        # Split the signature into elements.
        elements = self.signature_to_elements(signature)
        parent_elements = elements[:-1]
        last_element = elements[-1]

        # Silently ignore signatures whose first element is unknown when the
        # caller only wants to extend what is already in the trie.
        if (only_if_matches and parent_elements and
                parent_elements[0] not in self.nodes):
            return

        # Reject signatures that do not end in a member before creating any
        # interior nodes, otherwise a failed add would leave empty package
        # and class nodes behind.
        last_element_type = self.element_type(last_element)
        if last_element_type != "member":
            raise Exception(
                f"Invalid signature: {signature}, does not identify a "
                "specific member")

        # Find, creating as necessary, the Node associated with the deepest
        # class.
        node = self
        for index, element in enumerate(parent_elements):
            child = node.nodes.get(element)
            if child is None:
                selector = self.elements_to_selector(elements[0:index + 1])
                child = InteriorNode(
                    type=InteriorNode.element_type(element), selector=selector)
                node.nodes[element] = child
            node = child

        # Add a Leaf containing the value and associate it with the member
        # signature within the class.
        if last_element in node.nodes:
            raise Exception(f"Duplicate signature: {signature}")
        node.nodes[last_element] = Leaf(
            type=last_element_type,
            selector=signature,
            value=value,
        )

    def get_matching_rows(self, pattern):
        """Get the values (plural) associated with the pattern.

        e.g. If the pattern is a full signature then this will return a list
        containing the value associated with that signature.

        If the pattern is a class then this will return a list containing the
        values associated with all members of that class.

        If the pattern ends with "*" then the preceding part is treated as a
        package and this will return a list containing the values associated
        with all the members of all the classes in that package.

        If the pattern ends with "**" then the preceding part is treated
        as a package and this will return a list containing the values
        associated with all the members of all the classes in that package and
        all sub-packages.

        :param pattern: the pattern which could be a complete signature or a
        class, or package wildcard.
        :return: a list containing all the values associated with the
        pattern; empty if nothing in the trie matches it.
        :raises Exception: if the pattern is rejected by
        signature_to_elements().
        """
        elements = self.signature_to_elements(pattern)
        node = self

        # Include all values from this node and all its children.
        selector = lambda x: True

        last_element = elements[-1]
        last_element_type, last_element_value = self.split_element(last_element)
        if last_element_type == "wildcard":
            # The wildcard is not itself a key in the trie, it only affects
            # which children of the preceding package are included.
            elements = elements[:-1]
            if last_element_value == "*":
                # Do not include values from sub-packages.
                selector = lambda x: InteriorNode.element_type(x) != "package"

        for element in elements:
            if element in node.nodes:
                node = node.nodes[element]
            else:
                return []

        return node.values(selector)

    def append_values(self, values, selector):
        """Append the values of the selected children, and their descendants.

        The selector is only applied to the keys of this node's direct
        children; once a child is selected all of its descendants are
        included.

        :param values: the list that the values are appended to.
        :param selector: a function called with each key in nodes, returning
            True if the values below that key are wanted.
        """
        # Created once rather than once per child.
        include_all = lambda x: True
        for key, node in self.nodes.items():
            if selector(key):
                node.append_values(values, include_all)

    def child_nodes(self):
        """Return a view of the nodes directly below this one."""
        return self.nodes.values()
324
325
@dataclasses.dataclass()
class Leaf(Node):
    """A terminal node of the trie; it holds a value and has no children."""

    # The value associated with this leaf.
    value: typing.Any

    def append_values(self, values, selector):
        """Append this leaf's single value to values.

        :param values: the list that the value is appended to.
        :param selector: not used, a leaf has no children to select from.
        """
        values.append(self.value)

    def child_nodes(self):
        """Return a new empty list as a leaf has no children."""
        return list()
338
339
def signature_trie():
    """Create the root node of a new, empty signature trie.

    :return: an InteriorNode of type "root" with an empty selector.
    """
    root = InteriorNode(type="root", selector="")
    return root
342