1#!/usr/bin/env python
2#
3# Copyright (C) 2021 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""gen_sdk is a command line tool for managing the sdk extension proto db.
17
18Example usages:
19# Print a binary representation of the proto database.
20$ gen_sdk --action print_binary
21
22# Validate the database
23$ gen_sdk --action validate
24
25# Create a new SDK
26$ gen_sdk --action new_sdk --sdk 1 --modules=IPSEC,SDK_EXTENSIONS
27"""
28
29import argparse
30import google.protobuf.text_format
31import pathlib
32import sys
33
34from sdk_pb2 import ExtensionVersion
35from sdk_pb2 import ExtensionDatabase
36from sdk_pb2 import SdkModule
37from sdk_pb2 import SdkVersion
38
39
def ParseArgs(argv):
  """Parse the command line arguments of gen_sdk.

  Args:
    argv: list of command line arguments, without the program name.

  Returns:
    The argparse namespace with the attributes database (a pathlib.Path),
    action (one of print_binary, new_sdk, validate), sdk (int or None) and
    modules (comma-separated string or None).

  Exits the process via argparse if the arguments are invalid or --action is
  missing.
  """
  # The text is the tool's description. Passing it positionally would set
  # `prog` instead and replace the program name in the usage line.
  parser = argparse.ArgumentParser(description='Manage the extension SDK database')
  parser.add_argument(
    '--database',
    type=pathlib.Path,
    metavar='PATH',
    default='extensions_db.textpb',
    help='The existing text-proto database to use. (default: extensions_db.textpb)'
  )
  parser.add_argument(
    '--action',
    choices=['print_binary', 'new_sdk', 'validate'],
    metavar='ACTION',
    required=True,
    help='Which action to take (print_binary|new_sdk|validate).'
  )
  parser.add_argument(
    '--sdk',
    type=int,
    metavar='SDK',
    help='The extension SDK level to deal with (int)'
  )
  parser.add_argument(
    '--modules',
    metavar='MODULES',
    help='Comma-separated list of modules providing new APIs. Used for action=new_sdk to create a '
         'new SDK that only requires new versions of some modules.'
  )
  return parser.parse_args(argv)
69
70
def PrintBinary(database):
  """Print the binary representation of the db proto to stdout.

  Args:
    database: the proto message to serialize.
  """
  # The serialized proto is bytes, so write to the byte stream underneath
  # sys.stdout rather than the text layer.
  sys.stdout.buffer.write(database.SerializeToString())
74
75
def ValidateDatabase(database, dbname):
  """Check the extension database for consistency and exit if it is invalid.

  The checks run in this order and stop at the first failure:
    * no two extension versions share a version number,
    * no extension version lists the same module more than once,
    * walking the extension versions in increasing order, the version
      required of a module never decreases,
    * every required module is a defined SdkModule value other than UNKNOWN,
    * from extension version 9 on, AD_SERVICES and EXT_SERVICES are either
      both required or both absent.

  Args:
    database: the extension database proto to validate.
    dbname: human-readable name of the database, used in the error message.

  On failure the error is printed to stdout and the process exits with
  status 1. Nothing is returned on success.
  """
  def find_duplicate(items):
    """Return the first item that occurs twice in items, or None."""
    seen = set()
    for item in items:
      if item in seen:
        return item
      seen.add(item)
    return None

  def find_bug():
    """Return a description of the first problem found, or None."""
    dupe = find_duplicate([v.version for v in database.versions])
    # Compare against None explicitly: a duplicated value of 0 is falsy but
    # is still a duplicate.
    if dupe is not None:
      return 'Found duplicate extension version: %d' % dupe

    for version in database.versions:
      dupe = find_duplicate([r.module for r in version.requirements])
      if dupe is not None:
        return 'Found duplicate module requirement for %s in single version %s' % (dupe, version)

    # Latest requirement seen per module while walking versions in order.
    prev_requirements = {}
    for version in sorted(database.versions, key=lambda v: v.version):
      for requirement in version.requirements:
        if requirement.module in prev_requirements:
          prev = prev_requirements[requirement.module]
          if prev.version > requirement.version.version:
            return 'Found module requirement moving backwards: %s in %s' % (requirement, version)
        prev_requirements[requirement.module] = requirement.version

    for version in database.versions:
      required_modules = [r.module for r in version.requirements]
      if SdkModule.UNKNOWN in required_modules:
        return 'SDK %d has a requirement on the UNKNOWN module' % version.version
      if not all(m in SdkModule.values() for m in required_modules):
        return 'SDK %d has a requirement on an undefined module value' % version.version
      has_adservices = SdkModule.AD_SERVICES in required_modules
      has_extservices = SdkModule.EXT_SERVICES in required_modules
      # XOR is true when exactly one of the two modules is present.
      if version.version >= 9 and (has_adservices ^ has_extservices):
        return 'AD_SERVICES and EXT_SERVICES must be finalized together as of version 9'

    return None

  err = find_bug()
  if err is not None:
    print('%s not valid, aborting:\n  %s' % (dbname, err))
    sys.exit(1)
121
122
def NewSdk(database, args):
  """Append a new extension SDK level to the database.

  Each module required by an existing level carries over the requirement of
  the highest-numbered level that lists it. The modules named in args.modules
  (or every module except UNKNOWN when the flag is not given) are then
  required at the new level.

  Args:
    database: the database proto; the new level is appended to its versions.
    args: parsed command line arguments; args.sdk and args.modules are read.

  Exits with status 1 when args.sdk is not set.
  """
  new_version = args.sdk
  if not new_version:
    print('Missing required argument --sdk for action new_sdk')
    sys.exit(1)

  if args.modules:
    module_names = args.modules.split(',')
  else:
    # No explicit list: the new level requires all modules.
    module_names = [name for name in SdkModule.keys() if name != 'UNKNOWN']
  module_values = [SdkModule.Value(name) for name in module_names]

  # Walk existing levels in increasing order so that later levels overwrite
  # the requirements of earlier ones.
  requirements = {}
  for existing in sorted(database.versions, key=lambda v: v.version):
    requirements.update((req.module, req.version) for req in existing.requirements)

  # The selected modules are bumped to the level being created.
  for module in module_values:
    requirements[module] = SdkVersion(version=new_version)

  module_requirements = [
    ExtensionVersion.ModuleRequirement(module=module, version=required)
    for module, required in requirements.items()
  ]
  database.versions.append(
    ExtensionVersion(version=new_version, requirements=module_requirements))

  created_modules = ','.join(module_names)
  print('Created a new extension SDK level %d with modules %s' % (new_version, created_modules))
153
154
def main(argv):
  """Run the action selected on the command line against the database file.

  The database is parsed from its text-proto file and validated before any
  action runs. For new_sdk the modified database is validated again and then
  written back to the same file.

  Args:
    argv: command line arguments, without the program name.
  """
  args = ParseArgs(argv)
  text_format = google.protobuf.text_format

  with args.database.open('r') as db_file:
    database = text_format.Parse(db_file.read(), ExtensionDatabase())

  ValidateDatabase(database, 'Input database')

  action = args.action
  if action == 'validate':
    # Reaching this point means the input validation above passed.
    print('Validated database')
  elif action == 'print_binary':
    PrintBinary(database)
  elif action == 'new_sdk':
    NewSdk(database, args)
    # Only new_sdk modifies the database, so only it re-validates and saves.
    ValidateDatabase(database, 'Post-modification database')
    with args.database.open('w') as db_file:
      db_file.write(text_format.MessageToString(database))
  else:
    raise KeyError(action)
172
if __name__ == '__main__':
  # Drop the program name; main() expects only the arguments.
  cli_args = sys.argv[1:]
  main(cli_args)
175