1# Copyright (C) 2020 The Android Open Source Project
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utility functions for atest."""
15from __future__ import print_function
17import getpass
18import logging
19import os
20import pathlib
21from pathlib import Path
22from socket import socket
23import subprocess
24import time
25from typing import Any, Callable
26import uuid
28from atest import atest_utils
29from atest import constants
30from atest.atest_enum import DetectType
31from atest.metrics import metrics
32import httplib2
33from oauth2client import client as oauth2_client
34from oauth2client import contrib as oauth2_contrib
35from oauth2client import tools as oauth2_tools
class RunFlowFlags:
  """Plain attribute holder passed as `flags` to oauth2client.tools.run_flow.

  Attributes are set per instance, so every object owns its own port list.
  """

  def __init__(self, browser_auth):
    """Populate the flag attributes.

    Args:
        browser_auth: Truthy to authorize through a local web server; the
          stored `noauth_local_webserver` flag is the negation of this value.
    """
    # The no-webserver flag is simply the inverse of the caller's request.
    self.noauth_local_webserver = not browser_auth
    self.logging_level = 'ERROR'
    self.auth_host_name = 'localhost'
    self.auth_host_port = [8080, 8090]
class GCPHelper:
  """Helper that obtains OAuth2 credentials for GCP access.

  Credentials are looked up in this order: an SSO token exchange, a stored
  credential file (refreshed), and finally an interactive OAuth2 flow.
  """

  def __init__(
      self,
      client_id=None,
      client_secret=None,
      user_agent=None,
      scope=constants.SCOPE_BUILD_API_SCOPE,
  ):
    """Init stuff for GCPHelper class.

    Args:
        client_id: String, client id from the cloud project.
        client_secret: String, client secret for the client_id.
        user_agent: The user agent for the credential.
        scope: String, scopes separated by space.
    """
    self.client_id = client_id
    self.client_secret = client_secret
    self.user_agent = user_agent
    self.scope = scope

  def _get_credential_storage(self, creds_file_path):
    """Build the credential storage object backing `creds_file_path`.

    Args:
        creds_file_path: Credential file path; it is made absolute before
          being handed to oauth2client.

    Returns:
        The storage object returned by
        oauth2client.contrib.multiprocess_file_storage.get_credential_storage.
    """
    return oauth2_contrib.multiprocess_file_storage.get_credential_storage(
        filename=os.path.abspath(creds_file_path),
        client_id=self.client_id,
        user_agent=self.user_agent,
        scope=self.scope,
    )

  def get_refreshed_credential_from_file(self, creds_file_path):
    """Get refreshed credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance, or None when no credential
        is stored or the stored credential is still invalid after the refresh
        attempt.
    """
    credential = self.get_credential_from_file(creds_file_path)
    if credential:
      try:
        credential.refresh(httplib2.Http())
      except oauth2_client.AccessTokenRefreshError as e:
        # A failed refresh is not fatal here; validity is checked below.
        logging.debug('Token refresh error: %s', e)
      if not credential.invalid:
        return credential
    logging.debug('Cannot get credential.')
    return None

  def get_credential_from_file(self, creds_file_path):
    """Get credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        Whatever the credential storage's get() returns for this file
        (an oauth2client.OAuth2Credentials instance when one is stored).
    """
    return self._get_credential_storage(creds_file_path).get()

  def get_credential_with_auth_flow(self, creds_file_path):
    """Get a credential object, running the oauth flow if necessary.

    Tries the SSO token exchange first. If that yields no token (or fails),
    falls back to the stored credential file, and runs the interactive oauth
    flow if no usable credential is stored.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance.
    """
    # SSO auth
    try:
      token = self._get_sso_access_token()
      # Only wrap a real token: a credentials object built from an empty or
      # missing token is still truthy and would hide the GCP fallback below.
      if token:
        return oauth2_client.AccessTokenCredentials(token, 'atest')
    # pylint: disable=broad-except
    except Exception as e:
      # Best effort: any SSO failure falls through to the GCP auth flow.
      logging.debug('Exception:%s', e)
    # GCP auth flow
    credentials = self.get_refreshed_credential_from_file(creds_file_path)
    if credentials:
      return credentials
    return self._run_auth_flow(self._get_credential_storage(creds_file_path))

  def _run_auth_flow(self, storage):
    """Get user oauth2 credentials.

    Using the loopback IP address flow for desktop clients.

    Args:
        storage: Credential storage object that run_flow is given.

    Returns:
        An oauth2client.OAuth2Credentials instance.
    """
    flags = RunFlowFlags(browser_auth=True)

    # Get a free port on demand: bind to port 0 so the OS assigns one, and
    # retry until the assigned port is at least 10000.
    port = None
    while not port or port < 10000:
      with socket() as local_socket:
        local_socket.bind(('', 0))
        _, port = local_socket.getsockname()
    # NOTE(review): oauth2client.tools.run_flow appears to choose its own port
    # from flags.auth_host_port and to reset flow.redirect_uri -- confirm
    # whether the port selected above is actually the one being used.
    redirect_uri = f'http://localhost:{port}'
    flow = oauth2_client.OAuth2WebServerFlow(
        client_id=self.client_id,
        client_secret=self.client_secret,
        scope=self.scope,
        user_agent=self.user_agent,
        redirect_uri=redirect_uri,
    )
    return oauth2_tools.run_flow(flow=flow, storage=storage, flags=flags)

  @staticmethod
  def _get_sso_access_token():
    """Exchange the SSO login for a scoped oauth token via a shell command.

    Runs constants.TOKEN_EXCHANGE_COMMAND with the formatted
    constants.TOKEN_EXCHANGE_REQUEST on its stdin.

    Returns:
        A token string, or None when no token exchange command is configured.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
          status.
        IndexError: If the command output contains no double-quoted value.
    """
    if not constants.TOKEN_EXCHANGE_COMMAND:
      return None

    request = constants.TOKEN_EXCHANGE_REQUEST.format(
        user=getpass.getuser(), scope=constants.SCOPE
    )
    # The command string comes from the constants module and is run through
    # the shell.
    completed = subprocess.run(
        constants.TOKEN_EXCHANGE_COMMAND,
        input=request,
        check=True,
        text=True,
        shell=True,
        stdout=subprocess.PIPE,
    )
    # The output format is: oauth2_token: "<TOKEN>"
    return completed.stdout.split('"')[1]
# TODO: Remove the build_client usage from this function; it is unrelated to
# this module. Until then, build_client_creator is annotated as a bare
# Callable (its return type is left unspecified) to avoid a circular import.
def do_upload_flow(
    extra_args: dict[str, str],
    build_client_creator: Callable,
    atest_run_id: str = None,
) -> tuple:
  """Run upload flow.

  Fetches a credential and, when one is available, creates the build-API
  records for this run, stores their ids in `extra_args` and writes the access
  token to constants.TOKEN_FILE_PATH. The time spent fetching the credential
  and preparing the upload is reported through metrics.LocalDetectEvent.

  Args:
      extra_args: Dict of extra args to add to test run. Updated in place with
        the invocation id, workunit id, local build id and build target.
      build_client_creator: A function that takes a credential and returns a
        BuildClient object.
      atest_run_id: The atest run ID to write into the invocation. When falsy,
        metrics.get_run_id() is used instead.

  Return:
      A tuple of credential object and invocation information dict, or
      (None, None) when no credential could be fetched.
  """
  fetch_cred_start = time.time()
  creds = fetch_credential()
  metrics.LocalDetectEvent(
      detect_type=DetectType.FETCH_CRED_MS,
      result=int((time.time() - fetch_cred_start) * 1000),
  )
  if not creds:
    return None, None

  prepare_upload_start = time.time()
  build_client = build_client_creator(creds)
  inv, workunit, local_build_id, build_target = _prepare_data(
      build_client, atest_run_id or metrics.get_run_id()
  )
  metrics.LocalDetectEvent(
      detect_type=DetectType.UPLOAD_PREPARE_MS,
      result=int((time.time() - prepare_upload_start) * 1000),
  )
  extra_args[constants.INVOCATION_ID] = inv['invocationId']
  extra_args[constants.WORKUNIT_ID] = workunit['id']
  extra_args[constants.LOCAL_BUILD_ID] = local_build_id
  extra_args[constants.BUILD_TARGET] = build_target

  # exist_ok avoids the check-then-create race when several runs start at
  # once; the guard skips makedirs('') for a path with no directory part.
  token_dir = os.path.dirname(constants.TOKEN_FILE_PATH)
  if token_dir:
    os.makedirs(token_dir, exist_ok=True)
  with open(constants.TOKEN_FILE_PATH, 'w') as token_file:
    # Prefer the token from the token response when the credential has one.
    if creds.token_response:
      token_file.write(creds.token_response['access_token'])
    else:
      token_file.write(creds.access_token)
  return creds, inv
def fetch_credential():
  """Fetch the credential object.

  The credential file is constants.CREDENTIAL_FILE_NAME inside the atest
  config folder.

  Returns:
      The result of GCPHelper.get_credential_with_auth_flow for that file.
  """
  config_dir = atest_utils.get_config_folder()
  credential_file = config_dir.joinpath(constants.CREDENTIAL_FILE_NAME)
  helper = GCPHelper(
      client_id=constants.CLIENT_ID,
      client_secret=constants.CLIENT_SECRET,
      user_agent='atest',
  )
  return helper.get_credential_with_auth_flow(credential_file)
def _prepare_data(client, atest_run_id: str):
  """Create the build-API records that an upload run is attached to.

  Inserts a local build, its build attempts, an invocation and a work unit
  through `client`. Log records of severity INFO and below are disabled while
  this runs; the disable threshold is reset to NOTSET on the way out, whether
  or not a call fails.

  Args:
      client: The build client object used to look up the branch and target
        and to insert the records.
      atest_run_id: The atest run ID to write into the invocation.

  Return:
      A 4-tuple of (invocation, workunit, local build id, build target).
  """
  try:
    logging.disable(logging.INFO)
    build_external_id = str(uuid.uuid4())
    branch_name = _get_branch(client)
    build_target = _get_target(branch_name, client)
    local_build = client.insert_local_build(
        build_external_id, build_target, branch_name
    )
    client.insert_build_attempts(local_build)
    inv = client.insert_invocation(local_build, atest_run_id)
    work_unit = client.insert_work_unit(inv)
    prepared = (inv, work_unit, local_build['buildId'], build_target)
    return prepared
  finally:
    logging.disable(logging.NOTSET)
def _get_branch(build_client):
  """Get source code tree branch.

  Args:
      build_client: The build client object.

  Return:
      The local manifest branch prefixed with "git_" when the build client
      finds it; otherwise "git_main" if constants.CREDENTIAL_FILE_NAME is
      set, else "aosp-main".
  """
  manifest_branch = f'git_{atest_utils.get_manifest_branch()}'
  if build_client.get_branch(manifest_branch):
    return manifest_branch
  # Fall back to a default branch when the lookup finds nothing.
  if constants.CREDENTIAL_FILE_NAME:
    return 'git_main'
  return 'aosp-main'
def _get_target(branch, build_client):
  """Get local build selected target.

  Args:
      branch: The branch whose target list is checked.
      build_client: The build client object.

  Return:
      The local build target when the branch lists it,
      "aosp_x86_64-trunk_staging-userdebug" otherwise.
  """
  selected_target = atest_utils.get_build_target()
  listing = build_client.list_target(branch)
  branch_targets = [entry['target'] for entry in listing['targets']]
  if selected_target in branch_targets:
    return selected_target
  return 'aosp_x86_64-trunk_staging-userdebug'