1# Copyright 2023, The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Device update methods used to prepare the device under test."""
16
from abc import ABC, abstractmethod
from pathlib import Path
import subprocess
from subprocess import CalledProcessError
import time
from typing import List, Optional, Set, Union

from atest import atest_utils
from atest import constants
26
27
class DeviceUpdateMethod(ABC):
  """Interface for a method that prepares (updates) the device under test."""

  @abstractmethod
  def update(self, serials: Optional[List[str]] = None) -> None:
    """Updates the device.

    Args:
        serials: Serial numbers of the devices to update, or None when no
          serial was specified.

    Raises:
        Error: If the device update fails.
    """

  @abstractmethod
  def dependencies(self) -> Set[str]:
    """Returns the dependencies required by this device update method.

    NOTE(review): these are presumably names (e.g. build targets/tools) the
    caller must make available before calling update() — confirm at callers.
    """
45
46
class NoopUpdateMethod(DeviceUpdateMethod):
  """An update method that leaves the device untouched and needs nothing."""

  def update(self, serials: Optional[List[str]] = None) -> None:
    """Does nothing; the device is used as-is and `serials` is ignored."""

  def dependencies(self) -> Set[str]:
    """Returns an empty set, since this method has no dependencies."""
    return set()
54
55
class AdeviceUpdateMethod(DeviceUpdateMethod):
  """Updates the device under test by running the `adevice` tool.

  The command executed is `<adevice_path> update [--serial <serial>]`.
  """

  _TOOL = 'adevice'

  def __init__(
      self,
      adevice_path: Union[str, Path] = _TOOL,
      targets: Optional[Set[str]] = None,
  ):
    """Initializes the update method.

    Args:
        adevice_path: The adevice executable to run. The default, a bare
          'adevice', is resolved on PATH by subprocess when launched.
        targets: Names reported by dependencies() in addition to the tool
          itself. A None or empty value falls back to {'sync'}.
    """
    self._adevice_path = adevice_path
    # `or` rather than `is None`, so an empty set also gets the default.
    self._targets = targets or {'sync'}

  def update(self, serials: Optional[List[str]] = None) -> None:
    """Runs `adevice update`, targeting at most one device.

    Args:
        serials: Serial numbers of the devices to update. Only the first one
          is forwarded to adevice (a warning is printed when more are
          given); if empty or None, no --serial flag is passed.

    Raises:
        Error: If adevice cannot be launched (e.g. the executable is missing)
          or exits with a non-zero status.
    """
    print(atest_utils.mark_cyan('\nUpdating device...'))
    update_start = time.time()

    update_cmd = [self._adevice_path, 'update']
    if serials:
      if len(serials) > 1:
        atest_utils.colorful_print(
            'Warning: Device update feature can only update one '
            'device for now, but this invocation specifies more '
            'than one device. Atest will update the first device '
            'by default.',
            constants.YELLOW,
        )

      update_cmd.extend(['--serial', serials[0]])

    try:
      subprocess.check_call(update_cmd)
    except (CalledProcessError, OSError) as e:
      # OSError (e.g. FileNotFoundError for a missing binary) is wrapped too,
      # so callers only need to handle the documented Error type.
      raise Error('Failed to update the device with adevice') from e

    print(
        atest_utils.mark_cyan(
            '\nDevice update finished in '
            f'{round(time.time() - update_start, 2)}s.'
        )
    )

  def dependencies(self) -> Set[str]:
    """Returns the configured targets plus the adevice tool itself."""
    return self._targets.union({self._TOOL})
95
96
class Error(Exception):
  """Raised when a device update method fails to update the device."""
99