1# Copyright 2017, The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utility functions for unit tests."""
16
17import os
18
19from atest import constants
20from atest import unittest_constants as uc
21
22
def assert_strict_equal(test_class, first, second):
  """Check for strict equality and strict equality of namedtuple elements.

  assertEqual considers types equal to their subtypes, but we want to
  not consider set() and frozenset() equal for testing. On top of value
  equality this also requires each value to be an instance of the other's
  type, and it recurses into namedtuple fields.

  Two lists are compared regardless of element order. Two str values only
  need to compare equal; their exact types are not checked.

  Args:
    test_class: The unittest.TestCase (or compatible object) that provides
      assertEqual and assertIsInstance.
    first: The first value to compare.
    second: The second value to compare.

  Raises:
    The test_class failure exception (AssertionError for unittest.TestCase)
    if the values or their types differ.
  """
  expected, actual = first, second
  # Allow 2 lists with different order but the same content equal. Compare
  # sorted copies so the caller's lists are not reordered as a side effect.
  if isinstance(first, list) and isinstance(second, list):
    expected, actual = sorted(first), sorted(second)
  test_class.assertEqual(expected, actual)
  # Skip the strict type check when both values are strings.
  if not (isinstance(first, str) and isinstance(second, str)):
    test_class.assertIsInstance(first, type(second))
    test_class.assertIsInstance(second, type(first))
  # Recursively check elements of namedtuples for strict equals.
  if isinstance(first, tuple) and hasattr(first, '_fields'):
    # pylint: disable=invalid-name
    for f in first._fields:
      assert_strict_equal(test_class, getattr(first, f), getattr(second, f))
43
44
def assert_equal_testinfos(test_class, test_info_a, test_info_b):
  """Assert that two TestInfo objects agree on every instance attribute.

  Attribute names come from test_info_a's instance __dict__. Each one is
  read from both objects with getattr and compared with
  test_class.assertEqual, stopping at the first mismatch. If either
  argument is None, the two arguments are compared directly instead.

  Args:
    test_class: The unittest.TestCase (or compatible object) providing
      assertEqual.
    test_info_a: The first TestInfo, or None.
    test_info_b: The second TestInfo, or None.

  Raises:
    The test_class failure exception (AssertionError for unittest.TestCase)
    on the first differing attribute. AttributeError propagates if
    test_info_b lacks an attribute that test_info_a has.
  """
  either_missing = test_info_a is None or test_info_b is None
  if either_missing:
    # Let assertEqual produce its usual failure output when None is involved.
    test_class.assertEqual(test_info_a, test_info_b)
    return

  for name in test_info_a.__dict__:
    value_a = getattr(test_info_a, name)
    value_b = getattr(test_info_b, name)
    mismatch_msg = 'TestInfo.%s mismatch: %s != %s' % (
        name, value_a, value_b)
    test_class.assertEqual(value_a, value_b, msg=mismatch_msg)
63
64
def assert_equal_testinfo_sets(test_class, test_info_set_a, test_info_set_b):
  """Check that two collections of TestInfos contain matching elements.

  Order is ignored. Each TestInfo in test_info_set_a must match, via
  assert_equal_testinfos, a distinct not-yet-matched TestInfo in
  test_info_set_b. The comparison works on copies, so neither argument is
  modified.

  Args:
    test_class: The unittest.TestCase (or compatible object) providing
      assertEqual.
    test_info_set_a: Collection (e.g. set) of TestInfos.
    test_info_set_b: Collection (e.g. set) of TestInfos to match against.

  Raises:
    AssertionError: If the sizes differ, or an element of test_info_set_a
      has no match among the remaining elements of test_info_set_b.
  """
  test_class.assertEqual(
      len(test_info_set_a),
      len(test_info_set_b),
      msg=(
          'mismatch # of TestInfos: %d != %d'
          % (len(test_info_set_a), len(test_info_set_b))
      ),
  )
  # Work on copies so the caller's collections are left intact.
  unmatched_b = list(test_info_set_b)
  for test_info_a in list(test_info_set_a):
    for index, test_info_b in enumerate(unmatched_b):
      try:
        assert_equal_testinfos(test_class, test_info_a, test_info_b)
      except AssertionError:
        # Not a match; keep looking.
        continue
      # Remove by position so each element of b is matched at most once,
      # even when the matched element is falsy (e.g. None).
      del unmatched_b[index]
      break
    else:
      # We haven't found a match, raise an assertion error.
      raise AssertionError(
          'No matching TestInfo (%s) in [%s]'
          % (test_info_a, ';'.join([str(t) for t in unmatched_b]))
      )
94
95
def assert_equal_testinfo_lists(test_class, test_info_list_a, test_info_list_b):
  """Check that two lists of TestInfos are equal element by element.

  If either argument is None, the two arguments are compared directly.
  Otherwise both lists must have the same length, and each pair of TestInfos
  at the same position must satisfy assert_equal_testinfos.

  Args:
    test_class: The unittest.TestCase (or compatible object) providing
      assertEqual.
    test_info_list_a: Iterable of TestInfos, or None.
    test_info_list_b: Iterable of TestInfos, or None.

  Raises:
    The test_class failure exception (AssertionError for unittest.TestCase)
    if exactly one argument is None, the lengths differ, or any positional
    pair of TestInfos differs.
  """
  # Use unittest.assertEqual to do checks when None is involved; compare the
  # two arguments with each other so a single None fails.
  if test_info_list_a is None or test_info_list_b is None:
    test_class.assertEqual(test_info_list_a, test_info_list_b)
    return

  # Materialize so len() works for any iterable input.
  list_a = list(test_info_list_a)
  list_b = list(test_info_list_b)
  # Check lengths first so extra trailing entries on either side are
  # reported as a test failure rather than ignored or raised as IndexError.
  test_class.assertEqual(
      len(list_a),
      len(list_b),
      msg='mismatch # of TestInfos: %d != %d' % (len(list_a), len(list_b)),
  )
  for test_info_a, test_info_b in zip(list_a, list_b):
    assert_equal_testinfos(test_class, test_info_a, test_info_b)
105
106
def isfile_side_effect(value):
  """Fake os.path.isfile that reports the unit-test fixture paths as files.

  The argument is converted with str() before any comparison.

  Args:
    value: The path handed to the mocked os.path.isfile.

  Returns:
    True if the path exactly equals '/<dir>/<constants.MODULE_CONFIG>' for
    uc.CC_MODULE_DIR or uc.MODULE_DIR. Also True if it ends with one of:
    '.cc', '.cpp', '.java', '.kt', uc.INT_NAME + '.xml',
    uc.GTF_INT_NAME + '.xml', or one of the ANDTEST / SINGLE / MULTIPLE
    config-path suffixes listed below. False otherwise.
  """
  path = str(value)
  exact_paths = (
      '/%s/%s' % (uc.CC_MODULE_DIR, constants.MODULE_CONFIG),
      '/%s/%s' % (uc.MODULE_DIR, constants.MODULE_CONFIG),
  )
  # str.endswith accepts a tuple and returns True if any suffix matches.
  file_suffixes = (
      '.cc',
      '.cpp',
      '.java',
      '.kt',
      uc.INT_NAME + '.xml',
      uc.GTF_INT_NAME + '.xml',
      '/%s/%s' % (uc.ANDTEST_CONFIG_PATH, constants.MODULE_CONFIG),
      '/%s/%s' % (uc.SINGLE_CONFIG_PATH, uc.SINGLE_CONFIG_NAME),
      '/%s/%s' % (uc.MULTIPLE_CONFIG_PATH, uc.MAIN_CONFIG_NAME),
      '/%s/%s' % (uc.MULTIPLE_CONFIG_PATH, uc.SUB_CONFIG_NAME_2),
  )
  return path in exact_paths or path.endswith(file_suffixes)
138
139
def realpath_side_effect(path):
  """Fake os.path.realpath that anchors the given path under uc.ROOT.

  Args:
    path: The path handed to the mocked os.path.realpath.

  Returns:
    os.path.join(uc.ROOT, path). Per os.path.join semantics, an absolute
    path argument replaces the root.
  """
  test_root = uc.ROOT
  resolved = os.path.join(test_root, path)
  return resolved
143