1# Copyright 2018, The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""SUITE Tradefed test runner class."""
16
17import copy
18import logging
19import os
20from typing import Any, Dict, List
21
22from atest import atest_utils
23from atest import constants
24from atest.atest_enum import ExitCode
25from atest.logstorage import logstorage_utils
26from atest.metrics import metrics
27from atest.test_finders import test_info
28from atest.test_runners import atest_tf_test_runner
29
30
class SuitePlanTestRunner(atest_tf_test_runner.AtestTradefedTestRunner):
  """Suite Plan Test Runner class.

  Runs suite plans through the suite-specific ``<suite>-tradefed`` launcher
  (the ``*ts-tradefed`` family) using ``run commandAndExit``.
  """

  NAME = 'SuitePlanTestRunner'
  # Formatted with TestInfo.suite to build the launcher name.
  EXECUTABLE = '%s-tradefed'
  _RUN_CMD = '{exe} run commandAndExit {test} {args}'

  def __init__(self, results_dir: str, extra_args: Dict[str, Any], **kwargs):
    """Init stuff for suite tradefed runner class.

    Args:
        results_dir: Forwarded unchanged to the base Tradefed runner.
        extra_args: Dict of extra args, forwarded to the base Tradefed runner.
        **kwargs: Forwarded unchanged to the base Tradefed runner.
    """
    super().__init__(results_dir, extra_args, **kwargs)
    # Placeholder values for the _RUN_CMD template; a copy is filled in for
    # every test by generate_run_commands().
    self.run_cmd_dict = {'exe': '', 'test': '', 'args': ''}

  def get_test_runner_build_reqs(self, test_infos: List[test_info.TestInfo]):
    """Return the build requirements.

    This runner adds nothing beyond what the base Tradefed runner requires.

    Args:
        test_infos: List of TestInfo.

    Returns:
        Set of build targets.
    """
    return set(super().get_test_runner_build_reqs(test_infos))

  def run_tests(self, test_infos, extra_args, reporter):
    """Run the list of test_infos.

    Every generated command is run sequentially. The exit codes are OR-ed
    together, so any non-zero exit code makes the overall result non-zero.
    When result upload is enabled, the invocation is marked completed after
    each command, even if running that command raised.

    Args:
        test_infos: List of TestInfo.
        extra_args: Dict of extra args to add to test run.
        reporter: An instance of result_report.ResultReporter.

    Returns:
        Return code of the process for running tests.
    """
    # This runner does not stream structured results back to the reporter.
    reporter.register_unsupported_runner(self.NAME)
    creds, inv = (
        logstorage_utils.do_upload_flow(extra_args)
        if logstorage_utils.is_upload_enabled(extra_args)
        else (None, None)
    )

    run_cmds = self.generate_run_commands(test_infos, extra_args)
    ret_code = ExitCode.SUCCESS
    for run_cmd in run_cmds:
      try:
        proc = super().run(
            run_cmd,
            output_to_stdout=True,
            env_vars=self.generate_env_vars(extra_args),
        )
        ret_code |= self.wait_for_subprocess(proc)
      finally:
        if inv:
          try:
            # Suppress INFO-and-below log records while updating the
            # invocation; logging is restored in the inner finally.
            logging.disable(logging.INFO)
            # Always set invocation status to completed because ATest
            # handles the whole process by itself.
            inv['schedulerState'] = 'completed'
            logstorage_utils.BuildClient(creds).update_invocation(inv)
            reporter.test_result_link = (
                constants.RESULT_LINK % inv['invocationId']
            )
          finally:
            logging.disable(logging.NOTSET)
    return ret_code

  # pylint: disable=arguments-differ
  def _parse_extra_args(self, extra_args):
    """Convert the extra args into something *ts-tf can understand.

    Transforms the top-level args from atest into the specific args that
    *ts-tradefed supports. Entries under CUSTOM_ARGS are passed through
    verbatim, since that is what the user intentionally wants to hand to the
    test runner. Args that have no translation are logged and dropped.

    Args:
        extra_args: Dict of args.

    Returns:
        List of args to append.
    """
    args_to_append = []
    args_not_supported = []
    for arg, value in extra_args.items():
      if arg == constants.SERIAL:
        # Accept either a single serial or a sequence of serials; Tradefed
        # takes one '--serial' flag per device.
        serials = [value] if isinstance(value, str) else list(value)
        for serial in serials:
          args_to_append.extend(['--serial', serial])
      elif arg == constants.CUSTOM_ARGS:
        args_to_append.extend(value)
      elif arg == constants.INVOCATION_ID:
        args_to_append.append('--invocation-data invocation_id=%s' % value)
      elif arg == constants.WORKUNIT_ID:
        args_to_append.append('--invocation-data work_unit_id=%s' % value)
      elif arg in (constants.DRY_RUN, constants.REQUEST_UPLOAD_RESULT):
        # Handled by atest itself; nothing to forward.
        pass
      elif arg == constants.TF_DEBUG:
        # The debug port itself is exported via generate_env_vars().
        port = value if value else constants.DEFAULT_DEBUG_PORT
        print('Please attach process to your IDE...(%s)' % port)
      else:
        args_not_supported.append(arg)
    if args_not_supported:
      atest_utils.print_and_log_info(
          '%s does not support the following args: %s',
          self.EXECUTABLE,
          args_not_supported,
      )
    return args_to_append

  # pylint: disable=arguments-differ
  def generate_run_commands(self, test_infos, extra_args):
    """Generate a list of run commands from TestInfos.

    Args:
        test_infos: List of TestInfo tests to run.
        extra_args: Dict of extra args to add to test run.

    Returns:
        A List of strings that contains the run command
        which *ts-tradefed supports.
    """
    args = list(self._parse_extra_args(extra_args))
    args.extend(atest_utils.get_result_server_args())
    # The argument string is identical for every test; build it once.
    joined_args = ' '.join(args)

    cmds = []
    for tinfo in test_infos:
      cmd_dict = copy.deepcopy(self.run_cmd_dict)
      cmd_dict['test'] = tinfo.test_name
      cmd_dict['args'] = joined_args
      cmd_dict['exe'] = self.EXECUTABLE % tinfo.suite
      cmds.append(self._RUN_CMD.format(**cmd_dict))
      # Record a metrics event for suites that have a registered detect type.
      if constants.DETECT_TYPE_XTS_SUITE:
        xts_detect_type = constants.DETECT_TYPE_XTS_SUITE.get(tinfo.suite, '')
        if xts_detect_type:
          metrics.LocalDetectEvent(detect_type=xts_detect_type, result=1)
    return cmds

  def generate_env_vars(self, extra_args):
    """Convert extra args into env vars.

    Args:
        extra_args: Dict of extra args; only TF_DEBUG is read here.

    Returns:
        A copy of os.environ with the Tradefed debug and global-config
        variables set when applicable.
    """
    env_vars = os.environ.copy()
    debug_port = extra_args.get(constants.TF_DEBUG, '')
    if debug_port:
      env_vars['TF_DEBUG'] = 'true'
      env_vars['TF_DEBUG_PORT'] = str(debug_port)
    if constants.TF_GLOBAL_CONFIG:
      env_vars['TF_GLOBAL_CONFIG'] = constants.TF_GLOBAL_CONFIG
    return env_vars
188