1#!/usr/bin/env python3
2#
3# Copyright 2022, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Integration tests for the Atest Bazel mode feature."""
18
19# pylint: disable=invalid-name
20# pylint: disable=missing-class-docstring
21# pylint: disable=missing-function-docstring
22
23import dataclasses
24import os
25from pathlib import Path
26import shutil
27import subprocess
28import tempfile
29from typing import Any, Dict, List, Tuple
30import unittest
31
32
# Name of the environment variable holding the root of the Android source tree.
_ENV_BUILD_TOP = 'ANDROID_BUILD_TOP'
# Java class and test-method names baked into the generated JUnit sources; the
# tests also look for them in atest's output.
_PASSING_CLASS_NAME = 'PassingHostTest'
_FAILING_CLASS_NAME = 'FailingHostTest'
_PASSING_METHOD_NAME = 'testPass'
_FAILING_METHOD_NAME = 'testFAIL'
38
39
@dataclasses.dataclass(eq=True, frozen=True)
class JavaSourceFile:
  """An immutable pairing of a Java class name with its full source text."""

  # Simple name of the public class; callers derive the `.java` file name
  # from it.
  class_name: str
  # Complete contents to write into the `.java` file.
  src_body: str
44
45
class BazelModeTest(unittest.TestCase):
  """End-to-end tests for `atest --bazel-mode` run inside an Android tree.

  Each test writes throw-away Soong modules under
  `$ANDROID_BUILD_TOP/atest_bazel_mode_test`, then invokes atest through a
  shell that sources `build/envsetup.sh` and runs `lunch`, with `OUT_DIR`
  pointed at a per-test temporary directory.
  """

  def setUp(self):
    self.src_root_path = Path(os.environ[_ENV_BUILD_TOP])
    self.test_dir = self.src_root_path.joinpath('atest_bazel_mode_test')
    # Start from a clean slate in case an earlier run left modules behind.
    if self.test_dir.exists():
      shutil.rmtree(self.test_dir)
    self.out_dir_path = Path(tempfile.mkdtemp())
    self.test_env = self.setup_test_env()

  def tearDown(self):
    # test_dir only exists once a module has been written, so guard its
    # removal: a test that fails earlier should report its own error rather
    # than a FileNotFoundError from here. The temporary OUT_DIR is removed in
    # `finally` so it is not leaked if the first removal raises.
    try:
      if self.test_dir.exists():
        shutil.rmtree(self.test_dir)
    finally:
      shutil.rmtree(self.out_dir_path)

  def test_passing_test_returns_zero_exit_code(self):
    module_name = 'passing_java_host_test'
    self.add_passing_test(module_name)

    completed_process = self.run_shell_command(
        f'atest -c -m --bazel-mode {module_name}'
    )

    self.assertEqual(completed_process.returncode, 0)

  def test_failing_test_returns_nonzero_exit_code(self):
    module_name = 'failing_java_host_test'
    self.add_failing_test(module_name)

    completed_process = self.run_shell_command(
        f'atest -c -m --bazel-mode {module_name}'
    )

    self.assertNotEqual(completed_process.returncode, 0)

  def test_passing_test_is_cached_when_rerun(self):
    module_name = 'passing_java_host_test'
    self.add_passing_test(module_name)

    completed_process = self.run_shell_command(
        f'atest -c -m --bazel-mode {module_name} && '
        f'atest --bazel-mode {module_name}'
    )

    self.assert_in_stdout(
        f':{module_name}_host (cached) PASSED', completed_process
    )

  def test_cached_test_reruns_when_modified(self):
    module_name = 'passing_java_host_test'
    java_test_file, _ = self.write_java_test_module(
        module_name, passing_java_test_source()
    )
    self.run_shell_command(f'atest -c -m --bazel-mode {module_name}')

    # Swap in a failing body while keeping the class (and thus file) name,
    # which should invalidate the previously cached passing result.
    java_test_file.write_text(
        failing_java_test_source(test_class_name=_PASSING_CLASS_NAME).src_body,
        encoding='utf8',
    )
    completed_process = self.run_shell_command(
        f'atest --bazel-mode {module_name}'
    )

    self.assert_in_stdout(f':{module_name}_host FAILED', completed_process)

  def test_only_supported_test_run_with_bazel(self):
    module_name = 'passing_java_host_test'
    unsupported_module_name = 'unsupported_passing_java_test'
    self.add_passing_test(module_name)
    self.add_unsupported_passing_test(unsupported_module_name)

    completed_process = self.run_shell_command(
        f'atest -c -m --host --bazel-mode {module_name} '
        f'{unsupported_module_name}'
    )

    self.assert_in_stdout(f':{module_name}_host PASSED', completed_process)
    self.assert_in_stdout(
        f'{_PASSING_CLASS_NAME}#{_PASSING_METHOD_NAME}: PASSED',
        completed_process,
    )

  def test_defaults_to_device_variant(self):
    module_name = 'passing_cc_host_test'
    self.write_cc_test_module(module_name, passing_cc_test_source())

    completed_process = self.run_shell_command(
        f'atest -c -m --bazel-mode {module_name}'
    )

    self.assert_in_stdout('AtestTradefedTestRunner:', completed_process)

  def test_runs_host_variant_when_requested(self):
    module_name = 'passing_cc_host_test'
    self.write_cc_test_module(module_name, passing_cc_test_source())

    completed_process = self.run_shell_command(
        f'atest -c -m --host --bazel-mode {module_name}'
    )

    self.assert_in_stdout(f':{module_name}_host   PASSED', completed_process)

  def test_ignores_host_arg_for_device_only_test(self):
    module_name = 'passing_cc_device_test'
    self.write_cc_test_module(
        module_name, passing_cc_test_source(), host_supported=False
    )

    completed_process = self.run_shell_command(
        f'atest -c -m --host --bazel-mode {module_name}'
    )

    self.assert_in_stdout(
        'Specified --host, but the following tests are device-only',
        completed_process,
    )

  def test_supports_extra_tradefed_reporters(self):
    test_module_name = 'passing_java_host_test'
    self.add_passing_test(test_module_name)

    reporter_module_name = 'test-result-reporter'
    reporter_class_name = 'TestResultReporter'
    expected_output_string = '0xFEEDF00D'

    self.write_java_reporter_module(
        reporter_module_name,
        java_reporter_source(reporter_class_name, expected_output_string),
    )

    # Build the reporter jar, then do a dry run, which is expected to generate
    # the Bazel workspace under OUT_DIR that the jar is copied into.
    self.run_shell_command(f'm {reporter_module_name}', check=True)
    self.run_shell_command(
        f'atest -c -m --bazel-mode {test_module_name} --dry-run', check=True
    )
    self.run_shell_command(
        f'cp ${{ANDROID_HOST_OUT}}/framework/{reporter_module_name}.jar '
        f'{self.out_dir_path}/atest_bazel_workspace/tools/asuite/atest/'
        'bazel/reporter/bazel-result-reporter/host/framework/.',
        check=True,
    )

    completed_process = self.run_shell_command(
        f'atest --bazel-mode {test_module_name} --bazel-arg='
        '--//bazel/rules:extra_tradefed_result_reporters=android.'
        f'{reporter_class_name} --bazel-arg=--test_output=all',
        check=True,
    )

    self.assert_in_stdout(expected_output_string, completed_process)

  def setup_test_env(self) -> Dict[str, Any]:
    """Returns the minimal environment used for every shell command.

    Only PATH and HOME are inherited from the caller; OUT_DIR is redirected
    to this test's temporary directory.
    """
    test_env = {
        'PATH': os.environ['PATH'],
        'HOME': os.environ['HOME'],
        'OUT_DIR': str(self.out_dir_path),
    }
    return test_env

  def run_shell_command(
      self, shell_command: str, check: bool = False
  ) -> subprocess.CompletedProcess:
    """Runs `shell_command` in a fresh shell configured for the lunch target.

    The command runs from the source root after sourcing build/envsetup.sh and
    lunching aosp_cf_x86_64_pc-userdebug, with `self.test_env` as the only
    environment.

    Args:
      shell_command: Command line appended after the setup steps.
      check: If true, raise subprocess.CalledProcessError on non-zero exit.

    Returns:
      The completed process; stderr is merged into the captured stdout bytes.
    """
    return subprocess.run(
        '. build/envsetup.sh && '
        'lunch aosp_cf_x86_64_pc-userdebug && '
        f'{shell_command}',
        env=self.test_env,
        cwd=self.src_root_path,
        shell=True,
        check=check,
        stderr=subprocess.STDOUT,
        stdout=subprocess.PIPE,
    )

  def add_passing_test(self, module_name: str):
    self.write_java_test_module(module_name, passing_java_test_source())

  def add_failing_test(self, module_name: str):
    self.write_java_test_module(module_name, failing_java_test_source())

  def add_unsupported_passing_test(self, module_name: str):
    # unit_test=False makes the module one that Bazel mode does not run.
    self.write_java_test_module(
        module_name, passing_java_test_source(), unit_test=False
    )

  def write_java_test_module(
      self,
      module_name: str,
      test_src: JavaSourceFile,
      unit_test: bool = True,
  ) -> Tuple[Path, Path]:
    """Writes a `java_test_host` module whose only source is `test_src`.

    Returns:
      The (Java source path, Android.bp path) that were written.
    """
    src_file_name = f'{test_src.class_name}.java'
    return self._write_module(
        module_name,
        src_file_name,
        test_src.src_body,
        android_bp(
            java_test_host(
                name=module_name,
                srcs=[src_file_name],
                unit_test=unit_test,
            ),
        ),
    )

  def write_cc_test_module(
      self,
      module_name: str,
      test_src: str,
      host_supported: bool = True,
  ) -> Tuple[Path, Path]:
    """Writes a `cc_test` module whose only source is `<module_name>.cpp`.

    Returns:
      The (C++ source path, Android.bp path) that were written.
    """
    src_file_name = f'{module_name}.cpp'
    return self._write_module(
        module_name,
        src_file_name,
        test_src,
        android_bp(
            cc_test(
                name=module_name,
                srcs=[src_file_name],
                host_supported=host_supported,
            ),
        ),
    )

  def write_java_reporter_module(
      self,
      module_name: str,
      reporter_src: JavaSourceFile,
  ) -> Tuple[Path, Path]:
    """Writes a host Java library module whose only source is `reporter_src`.

    Returns:
      The (Java source path, Android.bp path) that were written.
    """
    src_file_name = f'{reporter_src.class_name}.java'
    return self._write_module(
        module_name,
        src_file_name,
        reporter_src.src_body,
        android_bp(
            java_library(
                name=module_name,
                srcs=[src_file_name],
            ),
        ),
    )

  def _write_module(
      self,
      module_name: str,
      src_file_name: str,
      src_body: str,
      bp_body: str,
  ) -> Tuple[Path, Path]:
    """Writes one source file plus an Android.bp into `test_dir/module_name`.

    The module directory (and any missing parents) is created if needed, and
    existing files are overwritten. Both files are written as UTF-8.

    Returns:
      The (source path, Android.bp path) that were written.
    """
    module_dir = self.test_dir.joinpath(module_name)
    module_dir.mkdir(parents=True, exist_ok=True)

    src_file_path = module_dir.joinpath(src_file_name)
    src_file_path.write_text(src_body, encoding='utf8')

    bp_file_path = module_dir.joinpath('Android.bp')
    bp_file_path.write_text(bp_body, encoding='utf8')
    return (src_file_path, bp_file_path)

  def assert_in_stdout(
      self,
      message: str,
      completed_process: subprocess.CompletedProcess,
  ):
    """Asserts that `message` occurs in the process's decoded stdout."""
    self.assertIn(message, completed_process.stdout.decode())
316
317
def passing_java_test_source(
    test_class_name: str = _PASSING_CLASS_NAME,
) -> JavaSourceFile:
  """Returns a JUnit4 test class whose single test method always passes.

  Args:
    test_class_name: Name of the generated Java class. Defaults to
      _PASSING_CLASS_NAME; the parameter mirrors failing_java_test_source so
      either source can be generated under the other's class name.

  Returns:
    A JavaSourceFile for a class with one `_PASSING_METHOD_NAME` test.
  """
  return java_test_source(
      test_class_name=test_class_name,
      test_method_name=_PASSING_METHOD_NAME,
      test_method_body='Assert.assertEquals("Pass", "Pass");',
  )
324
325
def failing_java_test_source(
    test_class_name: str = _FAILING_CLASS_NAME,
) -> JavaSourceFile:
  """Returns a JUnit4 test class whose single test method always fails.

  Args:
    test_class_name: Name of the generated Java class; overriding it lets a
      caller replace an existing passing source file in place.

  Returns:
    A JavaSourceFile for a class with one `_FAILING_METHOD_NAME` test.
  """
  mismatched_assertion = 'Assert.assertEquals("Pass", "Fail");'
  return java_test_source(
      test_class_name=test_class_name,
      test_method_name=_FAILING_METHOD_NAME,
      test_method_body=mismatched_assertion,
  )
334
335
def java_test_source(
    test_class_name: str,
    test_method_name: str,
    test_method_body: str,
) -> JavaSourceFile:
  """Renders a JUnit4 test class in package `android` with one @Test method.

  Args:
    test_class_name: Name of the public class.
    test_method_name: Name of the single test method.
    test_method_body: Java statements inserted verbatim as the method body.

  Returns:
    A JavaSourceFile pairing the class name with the rendered source.
  """
  src_body = f"""\
package android;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runners.JUnit4;
import org.junit.runner.RunWith;

@RunWith(JUnit4.class)
public final class {test_class_name} {{

    @Test
    public void {test_method_name}() {{
        {test_method_body}
    }}
}}
"""
  return JavaSourceFile(class_name=test_class_name, src_body=src_body)
361
362
def java_reporter_source(
    reporter_class_name: str,
    output_string: str,
) -> JavaSourceFile:
  """Renders a Tradefed result reporter that prints a marker when done.

  The generated class implements ITestInvocationListener and prints
  `output_string` from `invocationEnded`, giving tests a marker to look for
  in the captured output.

  Args:
    reporter_class_name: Name of the public Java class in package `android`.
    output_string: Text to print. It is escaped so that quotes, backslashes
      and line breaks still yield a valid Java string literal.

  Returns:
    A JavaSourceFile pairing the class name with the rendered source.
  """
  # Escape backslashes first so the escapes added afterwards are not doubled.
  java_literal = (
      output_string.replace('\\', '\\\\')
      .replace('"', '\\"')
      .replace('\n', '\\n')
      .replace('\r', '\\r')
  )
  return JavaSourceFile(
      reporter_class_name,
      f"""\
package android;

import com.android.tradefed.result.ITestInvocationListener;

public final class {reporter_class_name} implements ITestInvocationListener {{

    @Override
    public void invocationEnded(long elapsedTime) {{
        System.out.println("{java_literal}");
    }}
}}
""",
  )
383
384
def passing_cc_test_source() -> str:
  """Returns gtest source for `TestSuite.PassingTest`, whose body is empty."""
  suite_name, case_name = 'TestSuite', 'PassingTest'
  return cc_test_source(
      test_suite_name=suite_name,
      test_name=case_name,
      test_body='',
  )
389
390
def cc_test_source(
    test_suite_name: str,
    test_name: str,
    test_body: str,
) -> str:
  """Renders a C++ file containing a single gtest `TEST(suite, name)` case.

  Args:
    test_suite_name: First argument of the TEST macro.
    test_name: Second argument of the TEST macro.
    test_body: C++ statements inserted verbatim, indented by four spaces.

  Returns:
    The source text, terminated by a newline.
  """
  source_lines = [
      '#include <gtest/gtest.h>',
      '',
      f'TEST({test_suite_name}, {test_name}) {{',
      f'    {test_body}',
      '}',
  ]
  return '\n'.join(source_lines) + '\n'
403
404
def android_bp(
    modules: str = '',
) -> str:
  """Wraps module definitions in an Android.bp with the Apache-2.0 package.

  Args:
    modules: Pre-rendered module definitions appended after the package block.

  Returns:
    The Android.bp text, terminated by a newline.
  """
  package_block = (
      'package {\n'
      '    default_applicable_licenses: ["Android-Apache-2.0"],\n'
      '}\n'
  )
  return f'{package_block}\n{modules}\n'
415
416
def cc_test(
    name: str,
    srcs: List[str],
    host_supported: bool,
    unit_test: bool = True,
) -> str:
  """Renders a `cc_test` Soong module definition.

  Args:
    name: Module name.
    srcs: Source file names, relative to the module directory.
    host_supported: Value for `host_supported`.
    unit_test: Value for `test_options.unit_test`; defaults to True, which was
      previously hard-coded.

  Returns:
    The module definition text, terminated by a newline.
  """
  # One indented, comma-terminated entry per source keeps the list well-formed
  # for zero, one, or many sources (joining with ',\n' left a lone ',' for an
  # empty list and unindented follow-on entries).
  src_entries = ''.join(f'        "{src}",\n' for src in srcs)

  return f"""\
cc_test {{
    name: "{name}",
    srcs: [
{src_entries}    ],
    test_options: {{
        unit_test: {str(unit_test).lower()},
    }},
    host_supported: {str(host_supported).lower()},
}}
"""
436
437
def java_test_host(
    name: str,
    srcs: List[str],
    unit_test: bool,
) -> str:
  """Renders a `java_test_host` Soong module that statically links junit.

  Args:
    name: Module name.
    srcs: Source file names, relative to the module directory.
    unit_test: Value for `test_options.unit_test`.

  Returns:
    The module definition text, terminated by a newline.
  """
  # One indented, comma-terminated entry per source keeps the list well-formed
  # for zero, one, or many sources.
  src_entries = ''.join(f'        "{src}",\n' for src in srcs)

  return f"""\
java_test_host {{
    name: "{name}",
    srcs: [
{src_entries}    ],
    test_options: {{
        unit_test: {str(unit_test).lower()},
    }},
    static_libs: [
        "junit",
    ],
}}
"""
459
460
def java_library(
    name: str,
    srcs: List[str],
) -> str:
  """Renders a host-side Java library module compiled against tradefed.

  Note: despite this helper's name, the emitted module type is
  `java_library_host`.

  Args:
    name: Module name.
    srcs: Source file names, relative to the module directory.

  Returns:
    The module definition text, terminated by a newline.
  """
  # One indented, comma-terminated entry per source keeps the list well-formed
  # for zero, one, or many sources.
  src_entries = ''.join(f'        "{src}",\n' for src in srcs)

  return f"""\
java_library_host {{
    name: "{name}",
    srcs: [
{src_entries}    ],
    libs: [
        "tradefed",
    ],
}}
"""
478
479
if __name__ == '__main__':
  # verbosity=2 makes unittest report each test by name as it runs.
  unittest.main(verbosity=2)
482