1# Copyright (C) 2020 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Utility functions for atest.""" 15from __future__ import print_function 16 17import getpass 18import logging 19import os 20import pathlib 21from pathlib import Path 22from socket import socket 23import subprocess 24import time 25from typing import Any, Callable 26import uuid 27 28from atest import atest_utils 29from atest import constants 30from atest.atest_enum import DetectType 31from atest.metrics import metrics 32import httplib2 33from oauth2client import client as oauth2_client 34from oauth2client import contrib as oauth2_contrib 35from oauth2client import tools as oauth2_tools 36 37 38class RunFlowFlags: 39 """Flags for oauth2client.tools.run_flow.""" 40 41 def __init__(self, browser_auth): 42 self.auth_host_port = [8080, 8090] 43 self.auth_host_name = 'localhost' 44 self.logging_level = 'ERROR' 45 self.noauth_local_webserver = not browser_auth 46 47 48class GCPHelper: 49 """GCP bucket helper class.""" 50 51 def __init__( 52 self, 53 client_id=None, 54 client_secret=None, 55 user_agent=None, 56 scope=constants.SCOPE_BUILD_API_SCOPE, 57 ): 58 """Init stuff for GCPHelper class. 59 60 Args: 61 client_id: String, client id from the cloud project. 62 client_secret: String, client secret for the client_id. 63 user_agent: The user agent for the credential. 64 scope: String, scopes separated by space. 65 """ 66 self.client_id = client_id 67 self.client_secret = client_secret 68 self.user_agent = user_agent 69 self.scope = scope 70 71 def get_refreshed_credential_from_file(self, creds_file_path): 72 """Get refreshed credential from file. 73 74 Args: 75 creds_file_path: Credential file path. 76 77 Returns: 78 An oauth2client.OAuth2Credentials instance. 79 """ 80 credential = self.get_credential_from_file(creds_file_path) 81 if credential: 82 try: 83 credential.refresh(httplib2.Http()) 84 except oauth2_client.AccessTokenRefreshError as e: 85 logging.debug('Token refresh error: %s', e) 86 if not credential.invalid: 87 return credential 88 logging.debug('Cannot get credential.') 89 return None 90 91 def get_credential_from_file(self, creds_file_path): 92 """Get credential from file. 93 94 Args: 95 creds_file_path: Credential file path. 96 97 Returns: 98 An oauth2client.OAuth2Credentials instance. 99 """ 100 storage = oauth2_contrib.multiprocess_file_storage.get_credential_storage( 101 filename=os.path.abspath(creds_file_path), 102 client_id=self.client_id, 103 user_agent=self.user_agent, 104 scope=self.scope, 105 ) 106 return storage.get() 107 108 def get_credential_with_auth_flow(self, creds_file_path): 109 """Get Credential object from file. 110 111 Get credential object from file. Run oauth flow if haven't authorized 112 before. 113 114 Args: 115 creds_file_path: Credential file path. 116 117 Returns: 118 An oauth2client.OAuth2Credentials instance. 119 """ 120 credentials = None 121 # SSO auth 122 try: 123 token = self._get_sso_access_token() 124 credentials = oauth2_client.AccessTokenCredentials(token, 'atest') 125 if credentials: 126 return credentials 127 # pylint: disable=broad-except 128 except Exception as e: 129 logging.debug('Exception:%s', e) 130 # GCP auth flow 131 credentials = self.get_refreshed_credential_from_file(creds_file_path) 132 if not credentials: 133 storage = oauth2_contrib.multiprocess_file_storage.get_credential_storage( 134 filename=os.path.abspath(creds_file_path), 135 client_id=self.client_id, 136 user_agent=self.user_agent, 137 scope=self.scope, 138 ) 139 return self._run_auth_flow(storage) 140 return credentials 141 142 def _run_auth_flow(self, storage): 143 """Get user oauth2 credentials. 144 145 Using the loopback IP address flow for desktop clients. 146 147 Args: 148 storage: GCP storage object. 149 150 Returns: 151 An oauth2client.OAuth2Credentials instance. 152 """ 153 flags = RunFlowFlags(browser_auth=True) 154 155 # Get a free port on demand. 156 port = None 157 while not port or port < 10000: 158 with socket() as local_socket: 159 local_socket.bind(('', 0)) 160 _, port = local_socket.getsockname() 161 _localhost_port = port 162 _direct_uri = f'http://localhost:{_localhost_port}' 163 flow = oauth2_client.OAuth2WebServerFlow( 164 client_id=self.client_id, 165 client_secret=self.client_secret, 166 scope=self.scope, 167 user_agent=self.user_agent, 168 redirect_uri=f'{_direct_uri}', 169 ) 170 credentials = oauth2_tools.run_flow(flow=flow, storage=storage, flags=flags) 171 return credentials 172 173 @staticmethod 174 def _get_sso_access_token(): 175 """Use stubby command line to exchange corp sso to a scoped oauth 176 177 token. 178 179 Returns: 180 A token string. 181 """ 182 if not constants.TOKEN_EXCHANGE_COMMAND: 183 return None 184 185 request = constants.TOKEN_EXCHANGE_REQUEST.format( 186 user=getpass.getuser(), scope=constants.SCOPE 187 ) 188 # The output format is: oauth2_token: "<TOKEN>" 189 return subprocess.run( 190 constants.TOKEN_EXCHANGE_COMMAND, 191 input=request, 192 check=True, 193 text=True, 194 shell=True, 195 stdout=subprocess.PIPE, 196 ).stdout.split('"')[1] 197 198 199# TODO: The usage of build_client should be removed from this method because 200# it's not related to this module. For now, we temporarily declare the return 201# type hint for build_client_creator to be Any to avoid circular importing. 202def do_upload_flow( 203 extra_args: dict[str, str], 204 build_client_creator: Callable, 205 atest_run_id: str = None, 206) -> tuple: 207 """Run upload flow. 208 209 Asking user's decision and do the related steps. 210 211 Args: 212 extra_args: Dict of extra args to add to test run. 213 build_client_creator: A function that takes a credential and returns a 214 BuildClient object. 215 atest_run_id: The atest run ID to write into the invocation. 216 217 Return: 218 A tuple of credential object and invocation information dict. 219 """ 220 fetch_cred_start = time.time() 221 creds = fetch_credential() 222 metrics.LocalDetectEvent( 223 detect_type=DetectType.FETCH_CRED_MS, 224 result=int((time.time() - fetch_cred_start) * 1000), 225 ) 226 if creds: 227 prepare_upload_start = time.time() 228 build_client = build_client_creator(creds) 229 inv, workunit, local_build_id, build_target = _prepare_data( 230 build_client, atest_run_id or metrics.get_run_id() 231 ) 232 metrics.LocalDetectEvent( 233 detect_type=DetectType.UPLOAD_PREPARE_MS, 234 result=int((time.time() - prepare_upload_start) * 1000), 235 ) 236 extra_args[constants.INVOCATION_ID] = inv['invocationId'] 237 extra_args[constants.WORKUNIT_ID] = workunit['id'] 238 extra_args[constants.LOCAL_BUILD_ID] = local_build_id 239 extra_args[constants.BUILD_TARGET] = build_target 240 if not os.path.exists(os.path.dirname(constants.TOKEN_FILE_PATH)): 241 os.makedirs(os.path.dirname(constants.TOKEN_FILE_PATH)) 242 with open(constants.TOKEN_FILE_PATH, 'w') as token_file: 243 if creds.token_response: 244 token_file.write(creds.token_response['access_token']) 245 else: 246 token_file.write(creds.access_token) 247 return creds, inv 248 return None, None 249 250 251def fetch_credential(): 252 """Fetch the credential object.""" 253 creds_path = atest_utils.get_config_folder().joinpath( 254 constants.CREDENTIAL_FILE_NAME 255 ) 256 return GCPHelper( 257 client_id=constants.CLIENT_ID, 258 client_secret=constants.CLIENT_SECRET, 259 user_agent='atest', 260 ).get_credential_with_auth_flow(creds_path) 261 262 263def _prepare_data(client, atest_run_id: str): 264 """Prepare data for build api using. 265 266 Args: 267 build_client: The logstorage_utils.BuildClient object. 268 atest_run_id: The atest run ID to write into the invocation. 269 270 Return: 271 invocation and workunit object. 272 build id and build target of local build. 273 """ 274 try: 275 logging.disable(logging.INFO) 276 external_id = str(uuid.uuid4()) 277 branch = _get_branch(client) 278 target = _get_target(branch, client) 279 build_record = client.insert_local_build(external_id, target, branch) 280 client.insert_build_attempts(build_record) 281 invocation = client.insert_invocation(build_record, atest_run_id) 282 workunit = client.insert_work_unit(invocation) 283 return invocation, workunit, build_record['buildId'], target 284 finally: 285 logging.disable(logging.NOTSET) 286 287 288def _get_branch(build_client): 289 """Get source code tree branch. 290 291 Args: 292 build_client: The build client object. 293 294 Return: 295 "git_main" in internal git, "aosp-main" otherwise. 296 """ 297 default_branch = 'git_main' if constants.CREDENTIAL_FILE_NAME else 'aosp-main' 298 local_branch = 'git_%s' % atest_utils.get_manifest_branch() 299 branch = build_client.get_branch(local_branch) 300 return local_branch if branch else default_branch 301 302 303def _get_target(branch, build_client): 304 """Get local build selected target. 305 306 Args: 307 branch: The branch want to check. 308 build_client: The build client object. 309 310 Return: 311 The matched build target, "aosp_x86_64-trunk_staging-userdebug" 312 otherwise. 313 """ 314 default_target = 'aosp_x86_64-trunk_staging-userdebug' 315 local_target = atest_utils.get_build_target() 316 targets = [t['target'] for t in build_client.list_target(branch)['targets']] 317 return local_target if local_target in targets else default_target 318