1#!/usr/bin/env python3
2#
3# Copyright 2023, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Preserves and restores the state of a repository.
18
19This module includes a `Snapshot` class that provides methods to:
20- Take snapshots of a directory, including or excluding specified paths.
21- Preserve environment variables.
22- Restore the directory state from previously taken snapshots, managing file and
23directory deletions and replacements.
24"""
25
26import functools
27import glob
28import hashlib
29import json
30import logging
31import os
32import pathlib
33import threading
34from typing import Any, Optional
35
36
def _synchronized(func):
  """Returns a wrapper that runs `func` while holding a dedicated lock.

  One `threading.Lock` is created per decorated function, so concurrent calls
  to the same wrapper execute one at a time.
  """
  call_guard = threading.Lock()

  def _locked_call(*args, **kwargs):
    call_guard.acquire()
    try:
      return func(*args, **kwargs)
    finally:
      # Always release, even when `func` raises.
      call_guard.release()

  # Copy the name, docstring and other metadata of `func` onto the wrapper.
  return functools.update_wrapper(_locked_call, func)
47
48
class Snapshot:
  """Provides functionality to take and restore snapshots of a directory.

  A snapshot is made of three parts that are all stored under the same storage
  directory: the directory content, a subset of environment variables and a
  dictionary of JSON-serializable objects. `Snapshot` objects created with
  equal storage directories share one threading lock, which serializes their
  take and restore operations within the current process.
  """

  def __init__(self, storage_dir: pathlib.Path):
    """Initializes a Snapshot object.

    Args:
        storage_dir: The directory where snapshots will be stored.
    """
    self._dir_snapshot = _DirSnapshot(storage_dir)
    self._env_snapshot = _EnvSnapshot(storage_dir)
    self._obj_snapshot = _ObjectSnapshot(storage_dir)
    self._lock = self._get_threading_lock(storage_dir)

  @_synchronized
  def _get_threading_lock(
      self,
      name: pathlib.Path,
  ):
    """Gets the threading lock associated with the given key.

    The locks are kept in a dictionary stored as an attribute of this
    (decorated) function, so one lock per key is shared by all instances. The
    `_synchronized` decorator serializes the look-up-or-create step.

    Args:
        name: The hashable key identifying the lock; the constructor passes
          the storage directory.

    Returns:
        The lock registered for `name`, created on first use.
    """
    locks_dict_attr_name = 'threading_locks'
    current_function = self._get_threading_lock.__func__
    if not hasattr(current_function, locks_dict_attr_name):
      setattr(current_function, locks_dict_attr_name, {})
    locks = getattr(current_function, locks_dict_attr_name)
    if name not in locks:
      locks[name] = threading.Lock()
    return locks[name]

  # pylint: disable=too-many-arguments
  def take_snapshot(
      self,
      name: str,
      root_path: str,
      include_paths: list[str],
      exclude_paths: Optional[list[str]] = None,
      env_keys: Optional[list[str]] = None,
      env: Optional[dict[str, str]] = None,
      objs: Optional[dict[str, Any]] = None,
  ) -> None:
    """Takes a snapshot of the directory at the given path.

    Args:
        name: The name of the snapshot.
        root_path: The path to the directory to snapshot.
        include_paths: A list of relative paths to include in the snapshot.
        exclude_paths: A list of relative paths to exclude from the snapshot.
        env_keys: A list of environment variable keys to save.
        env: Variables passed to the directory snapshot, which uses them to
          expand `$NAME` references in the include and exclude paths.
        objs: A dictionary of objects to save. The current implementation limits
          the type of objects to the types that can be serialized by the json
          module.
    """
    with self._lock:
      self._dir_snapshot.take_snapshot(
          name, root_path, include_paths, exclude_paths, env
      )
      self._env_snapshot.take_snapshot(name, env_keys)
      self._obj_snapshot.take_snapshot(name, objs)

  def restore_snapshot(
      self,
      name: str,
      root_path: str,
      exclude_paths: Optional[list[str]] = None,
  ) -> tuple[dict[str, str], dict[str, Any]]:
    """Restores directory at given path to a snapshot with the given name.

    Args:
        name: The name of the snapshot.
        root_path: The path to the target directory.
        exclude_paths: A list of paths to ignore during restore.

    Returns:
        A tuple of restored environment variables and object dictionary.
    """
    # Read the environment and object files under the lock as well, so that a
    # concurrent take_snapshot() on the same storage directory cannot be
    # observed while it is still writing them.
    with self._lock:
      env = self._env_snapshot.restore_snapshot(name, root_path)
      objs = self._obj_snapshot.restore_snapshot(name)
      self._dir_snapshot.restore_snapshot(name, root_path, exclude_paths, env)
    return env, objs
129
130
class _ObjectSnapshot:
  """Save and restore a dictionary of objects through json.

  Each snapshot is stored in the storage directory in a file named after the
  snapshot with the `.objs.json` suffix.
  """

  def __init__(self, storage_path: pathlib.Path):
    self._storage_path = storage_path

  def take_snapshot(
      self,
      name: str,
      objs: Optional[dict[str, Any]] = None,
  ) -> None:
    """Save a dictionary of objects in snapshot.

    Args:
        name: The name of the snapshot
        objs: A dictionary of objects to snapshot. Note: The current
          implementation limits the type of objects to the types that can be
          serialized by the json module. `None` is saved as an empty
          dictionary.

    Raises:
        TypeError: If `objs` holds a value the json module cannot serialize.
          An already existing snapshot file is left untouched in that case.
    """
    if objs is None:
      objs = {}
    # Serialize before opening the file: opening in 'w' mode truncates it, so
    # a serialization error raised afterwards would leave a corrupt snapshot
    # and destroy a previously saved one.
    serialized = json.dumps(objs)
    with open(
        self._get_objs_file_path(name),
        'w',
        encoding='utf-8',
    ) as f:
      f.write(serialized)

  def restore_snapshot(self, name: str) -> dict[str, Any]:
    """Restore saved objects from snapshot.

    Args:
        name: The name of the snapshot.

    Returns:
        The dictionary loaded from the snapshot's json file.
    """
    with open(
        self._get_objs_file_path(name),
        'r',
        encoding='utf-8',
    ) as f:
      return json.load(f)

  def _get_objs_file_path(self, name: str) -> pathlib.Path:
    """Get the path of the json file holding the objects of a snapshot."""
    return self._storage_path.joinpath(f'{name}.objs.json')
167
168
class _EnvSnapshot:
  """Save and restore environment variables.

  Occurrences of the repository root (the value of the `ANDROID_BUILD_TOP`
  environment variable) inside saved values are replaced with a placeholder,
  and the placeholder is replaced with the target root path when restoring.
  """

  _repo_root_placeholder = '<repo_root_placeholder>'

  def __init__(self, storage_path: pathlib.Path):
    self._storage_path = storage_path

  def take_snapshot(
      self,
      name: str,
      env_keys: Optional[list[str]] = None,
  ) -> None:
    """Save a subset of environment variables.

    Args:
        name: The name of the snapshot.
        env_keys: The names of the environment variables to save. Names that
          are not set in the current environment are skipped.
    """
    if env_keys is None:
      env_keys = []
    # Work on a single copy so that every value comes from the same state of
    # the environment.
    current_env = os.environ.copy()
    saved_env = {
        key: current_env[key] for key in env_keys if key in current_env
    }
    repo_root = current_env.get('ANDROID_BUILD_TOP')
    # Skip the substitution when the repository root is unknown or empty:
    # str.replace() with an empty search string would insert the placeholder
    # between every character of the saved values.
    if repo_root:
      saved_env = {
          key: value.replace(repo_root, self._repo_root_placeholder)
          for key, value in saved_env.items()
      }
    with self._get_env_file_path(name).open('w', encoding='utf-8') as f:
      json.dump(saved_env, f)

  def restore_snapshot(self, name: str, root_path: str) -> dict[str, str]:
    """Load saved environment variables.

    Args:
        name: The name of the snapshot.
        root_path: The path substituted for the repository root placeholder.

    Returns:
        The saved variables with the placeholder replaced by `root_path`. When
        the current environment defines PATH, it is appended to the restored
        PATH, or used as the PATH value if the snapshot does not have one.
    """
    with self._get_env_file_path(name).open('r', encoding='utf-8') as f:
      loaded_env = json.load(f)
    restored_env = {
        key: value.replace(
            self._repo_root_placeholder,
            root_path,
        )
        for key, value in loaded_env.items()
    }
    if 'PATH' in os.environ:
      if 'PATH' in restored_env:
        restored_env['PATH'] = restored_env['PATH'] + ':' + os.environ['PATH']
      else:
        restored_env['PATH'] = os.environ['PATH']
    return restored_env

  def _get_env_file_path(self, name: str) -> pathlib.Path:
    """Get environment file path."""
    return self._storage_path / (name + '_env.json')
219
220
class _FileInfo:
  """An object to save file information.

  One instance describes a single path recorded by `_DirSnapshot`: a regular
  file, a directory or a symbolic link. `_DirSnapshot` serializes instances to
  json through their `__dict__` and rebuilds them by passing the loaded
  dictionary as keyword arguments, so the constructor parameter names have to
  match the attribute names.
  """

  # pylint: disable=too-many-arguments
  def __init__(
      self,
      path: str,
      timestamp: Optional[float],
      content_hash: Optional[str],
      permissions: Optional[int],
      symlink_target: Optional[str],
      is_directory: bool,
      is_target_in_workspace: bool = False,
  ):
    """Initializes a _FileInfo object.

    Args:
        path: The path relative to the snapshot root, in posix form.
        timestamp: The modification time of a regular file; `None` for
          directories and symbolic links.
        content_hash: The blob store hash of the content of a regular file;
          `None` for empty files, directories and symbolic links.
        permissions: The `st_mode` of a regular file or directory; `None` for
          symbolic links.
        symlink_target: The target of a symbolic link in posix form; `None`
          for regular files and directories.
        is_directory: Whether the path is a (non link) directory.
        is_target_in_workspace: Whether `symlink_target` was located under the
          snapshot root, in which case it is stored relative to that root.
    """
    self.path = path
    self.timestamp = timestamp
    self.content_hash = content_hash
    self.permissions = permissions
    self.symlink_target = symlink_target
    self.is_directory = is_directory
    self.is_target_in_workspace = is_target_in_workspace
242
243
class _BlobStore:
  """Class to save and load file content.

  Contents are stored once per sha256 hash, in a file named after the hash
  (the first two hex digits being used as a sub-directory). A json cache maps
  already hashed (path, timestamp) pairs to their hash so that unchanged files
  are not read again.
  """

  def __init__(self, path: pathlib.Path):
    """Initializes the store.

    Args:
        path: The directory holding the blobs and the cache file; a string is
          accepted as well.
    """
    self.path = pathlib.Path(path)
    self.cache = self._load_cache()

  def add(self, path: pathlib.Path, timestamp: float) -> str:
    """Add a file path to the store.

    Args:
        path: The file whose content is saved.
        timestamp: The modification time of the file, used with the path as
          the cache key.

    Returns:
        The sha256 hex digest identifying the content of the file.
    """
    # The separator keeps the key unambiguous: plain concatenation maps both
    # ('f1', 2.0) and ('f', 12.0) to 'f12.0'. Entries written with the old
    # key format are simply never matched and get recomputed.
    cache_key = f'{path.as_posix()}@{timestamp}'
    if cache_key in self.cache:
      return self.cache[cache_key]
    content = path.read_bytes()
    content_hash = hashlib.sha256(content).hexdigest()
    content_path = self.path.joinpath(content_hash[:2], content_hash[2:])
    if not content_path.exists():
      content_path.parent.mkdir(parents=True, exist_ok=True)
      content_path.write_bytes(content)
    self.cache[cache_key] = content_hash
    return content_hash

  def get(self, content_hash: str) -> Optional[bytes]:
    """Read file content from a content hash.

    Args:
        content_hash: A hash previously returned by `add`.

    Returns:
        The stored content, or `None` if the store has no blob for the hash.
    """
    file_path = self.path.joinpath(content_hash[:2], content_hash[2:])
    try:
      return file_path.read_bytes()
    except FileNotFoundError:
      return None

  def dump_cache(self) -> None:
    """Dump the saved file path cache to speed up next run."""
    cache_path = self._get_cache_path()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open('w', encoding='utf-8') as f:
      json.dump(self.cache, f)

  def _load_cache(self) -> dict[str, str]:
    """Load the cache written by `dump_cache`, or an empty one if missing."""
    cache_path = self._get_cache_path()
    if not cache_path.exists():
      return {}
    with cache_path.open('r', encoding='utf-8') as f:
      return json.load(f)

  def _get_cache_path(self) -> pathlib.Path:
    """Get the path of the json cache file."""
    return self.path.joinpath('cache.json')
286
287
class _DirSnapshot:
  """Class to take and restore snapshot for a directory path.

  The metadata of every captured path is saved as json in a file named after
  the snapshot with the `_metadata.json` suffix, while file contents are kept
  in a `_BlobStore` located in the `blobs` sub-directory of the storage path.
  """

  def __init__(self, storage_path: pathlib.Path):
    self._storage_path = storage_path
    self._blob_store = _BlobStore(self._storage_path.joinpath('blobs'))

  def _expand_vars_paths(
      self, paths: list[str], variables: Optional[dict[str, str]]
  ) -> list[str]:
    """Expand variables in paths with the given environment variables.

    This function is similar to os.path.expandvars(path) which relies on
    os.environ.

    Args:
        paths: A list of paths that might contain variables to expand.
        variables: A dictionary of variable names and values.

    Returns:
        A list containing paths whose variables have been expanded if known.
        The input list itself is returned when there are no variables.
    """
    if not variables:
      return paths
    # Substitute longer names first so that `$OUT_DIR` is not partially
    # replaced by the value of `$OUT`.
    ordered_variables = sorted(
        variables.items(), key=lambda item: len(item[0]), reverse=True
    )
    expanded_paths = []
    for path in paths:
      for key, val in ordered_variables:
        path = path.replace(f'${key}', val)
      expanded_paths.append(path)
    return expanded_paths

  def _expand_wildcard_paths(
      self,
      root_path: str,
      paths: list[str],
      env: Optional[dict[str, str]] = None,
  ) -> list[str]:
    """Expand wildcard paths.

    Args:
        root_path: The directory relative paths are resolved against.
        paths: Paths or glob patterns, possibly containing `$NAME` variables.
        env: The variables used to expand `$NAME` references.

    Returns:
        The existing paths matched by the patterns, in pattern order.
    """
    expanded_paths = []
    for path in self._expand_vars_paths(paths, env):
      if not os.path.isabs(path):
        path = os.path.join(root_path, path)
      expanded_paths.extend(glob.glob(path, recursive=True))
    return expanded_paths

  def _is_excluded(self, path: str, exclude_paths: list[str]) -> bool:
    """Check whether a path should be excluded.

    A path is excluded when it starts with one of the exclude paths.
    NOTE(review): this is a plain string prefix match, so excluding `a/out`
    also excludes `a/output` - confirm that this is intended.
    """
    return bool(exclude_paths) and any(
        path.startswith(exclude_path) for exclude_path in exclude_paths
    )

  def _filter_excluded_paths(
      self,
      root: str,
      paths: list[str],
      exclude_paths: list[str],
  ) -> None:
    """Filter a list of paths with a list of exclude paths.

    The list is modified in place, which is what lets callers prune the
    directory list yielded by os.walk.

    Args:
        root: The directory containing the entries of `paths`.
        paths: The entry names to filter, modified in place.
        exclude_paths: The paths to exclude.
    """
    paths[:] = [
        path
        for path in paths
        if not self._is_excluded(os.path.join(root, path), exclude_paths)
    ]

  def take_snapshot(
      self,
      name: str,
      root_path: str,
      include_paths: list[str],
      exclude_paths: Optional[list[str]] = None,
      env: Optional[dict[str, str]] = None,
  ) -> 'dict[str, _FileInfo]':
    """Creates a snapshot of the directory at the given path.

    Args:
        name: The name of the snapshot.
        root_path: The path to the root directory.
        include_paths: A list of relative paths to include in the snapshot.
        exclude_paths: A list of relative paths to exclude from the snapshot.
        env: Variables used to expand `$NAME` references in the include and
          exclude paths.

    Returns:
        A dictionary of _FileInfo objects keyed by their relative path within
        the directory.
    """
    include_paths = (
        self._expand_wildcard_paths(root_path, include_paths, env)
        if include_paths
        else []
    )
    exclude_paths = (
        self._expand_wildcard_paths(root_path, exclude_paths, env)
        if exclude_paths
        else []
    )

    file_infos = {}

    def process_directory(path: pathlib.Path) -> None:
      if path.is_symlink():
        process_link(path)
        return
      relative_path = path.relative_to(root_path).as_posix()
      # The root directory itself is not recorded.
      if relative_path == '.':
        return
      file_infos[relative_path] = _FileInfo(
          relative_path,
          timestamp=None,
          content_hash=None,
          permissions=path.stat().st_mode,
          symlink_target=None,
          is_directory=True,
      )

    def process_file(path: pathlib.Path) -> None:
      if path.is_symlink():
        process_link(path)
        return
      relative_path = path.relative_to(root_path).as_posix()
      # Stat once so that all recorded fields describe the same file state.
      stat_result = path.stat()
      timestamp = stat_result.st_mtime
      file_infos[relative_path] = _FileInfo(
          relative_path,
          timestamp=timestamp,
          # Empty files have no blob; they are recreated with touch().
          content_hash=self._blob_store.add(path, timestamp)
          if stat_result.st_size
          else None,
          permissions=stat_result.st_mode,
          symlink_target=None,
          is_directory=False,
      )

    def process_link(path: pathlib.Path) -> None:
      relative_path = path.relative_to(root_path).as_posix()
      symlink_target = path.readlink()
      is_target_in_workspace = False
      # Targets under the root are stored relative to it so that they can be
      # re-rooted when restoring.
      if symlink_target.is_relative_to(root_path):
        symlink_target = symlink_target.relative_to(root_path)
        is_target_in_workspace = True
      file_infos[relative_path] = _FileInfo(
          relative_path,
          timestamp=None,
          content_hash=None,
          permissions=None,
          symlink_target=symlink_target.as_posix(),
          is_target_in_workspace=is_target_in_workspace,
          is_directory=False,
      )

    def process_path(path: pathlib.Path) -> None:
      if self._is_excluded(path.as_posix(), exclude_paths):
        return
      if path.is_symlink():
        process_link(path)
      elif path.is_file():
        process_file(path)
      elif path.is_dir():
        process_directory(path)
        for root, directories, files in os.walk(path):
          # Filtering in place also stops os.walk from descending into
          # excluded directories.
          self._filter_excluded_paths(root, directories, exclude_paths)
          self._filter_excluded_paths(root, files, exclude_paths)
          for directory in directories:
            process_directory(pathlib.Path(root).joinpath(directory))
          for file in files:
            process_file(pathlib.Path(root).joinpath(file))
      else:
        # We are not throwing error here because it might be just a
        # corner case which likely doesn't affect the test process.
        logging.error('Unexpected path type: %s', path.as_posix())

    for path in include_paths:
      process_path(pathlib.Path(path))

    snapshot_path = self._storage_path.joinpath(name + '_metadata.json')
    snapshot_path.parent.mkdir(parents=True, exist_ok=True)
    with snapshot_path.open('w', encoding='utf-8') as f:
      json.dump(file_infos, f, default=lambda o: o.__dict__)

    self._blob_store.dump_cache()

    return file_infos

  def restore_snapshot(
      self,
      name: str,
      root_path: str,
      exclude_paths: Optional[list[str]] = None,
      env: Optional[dict[str, str]] = None,
  ) -> tuple[list[str], list[str]]:
    """Restores directory at given path to snapshot with given name.

    Args:
        name: The name of the snapshot.
        root_path: The path to the root directory.
        exclude_paths: A list of relative paths to ignore during restoring.
        env: Variables used to expand `$NAME` references in the exclude paths.

    Returns:
        A tuple containing 2 lists:
            - Files that were deleted.
            - Files that were replaced.
    """
    metadata_path = self._storage_path.joinpath(name + '_metadata.json')
    with metadata_path.open('r', encoding='utf-8') as f:
      file_infos_dict = {
          key: _FileInfo(**val) for key, val in json.load(f).items()
      }

    exclude_paths = (
        self._expand_wildcard_paths(root_path, exclude_paths, env)
        if exclude_paths
        else []
    )

    deleted = self._remove_extra_files(
        file_infos_dict, root_path, exclude_paths
    )
    self._restore_directories(file_infos_dict, root_path, exclude_paths)
    replaced = self._restore_files(file_infos_dict, root_path, exclude_paths)

    return deleted, replaced

  def _remove_extra_files(
      self,
      file_infos_dict: 'dict[str, _FileInfo]',
      root_path: str,
      exclude_paths: list[str],
  ) -> list[str]:
    """Internal method to remove extra files during snapshot restore.

    Every non excluded symbolic link under the root is removed (links that
    belong to the snapshot are recreated later), as well as every regular
    file that is not part of the snapshot.

    Returns:
        The paths of the regular files that were deleted.
    """
    deleted = []
    for root, directories, files in os.walk(root_path):
      self._filter_excluded_paths(root, directories, exclude_paths)
      self._filter_excluded_paths(root, files, exclude_paths)
      for directory in directories:
        dir_path = pathlib.Path(root).joinpath(directory)
        # Ignore non link directories because complicated to deal
        # with file paths in include filters and unnecessary
        if dir_path.is_symlink():
          dir_path.unlink()
      for file in files:
        file_path = pathlib.Path(root).joinpath(file)
        if file_path.is_symlink():
          file_path.unlink()
        elif file_path.relative_to(root_path).as_posix() not in file_infos_dict:
          file_path.unlink()
          deleted.append(file_path.as_posix())
    return deleted

  def _restore_directories(
      self,
      file_infos_dict: 'dict[str, _FileInfo]',
      root_path: str,
      exclude_paths: list[str],
  ) -> None:
    """Internal method to restore directories during snapshot restore."""
    for relative_path, file_info in file_infos_dict.items():
      if not file_info.is_directory:
        continue
      dir_path = pathlib.Path(root_path).joinpath(relative_path)
      if self._is_excluded(dir_path.as_posix(), exclude_paths):
        continue
      dir_path.mkdir(parents=True, exist_ok=True)
      os.chmod(dir_path, file_info.permissions)

  def _restore_files(
      self,
      file_infos_dict: 'dict[str, _FileInfo]',
      root_path: str,
      exclude_paths: list[str],
  ) -> list[str]:
    """Internal method to restore files during snapshot restore.

    Symbolic links are recreated, and regular files whose modification time
    differs from the snapshot are rewritten from the blob store.

    Returns:
        The paths of the regular files that were written.

    Raises:
        FileNotFoundError: If the content of a file to restore is missing
          from the blob store.
    """
    replaced = []
    for relative_path, file_info in file_infos_dict.items():
      file_path = pathlib.Path(root_path).joinpath(relative_path)
      if self._is_excluded(file_path.as_posix(), exclude_paths):
        continue
      if file_info.symlink_target:
        target = file_info.symlink_target
        if file_info.is_target_in_workspace:
          target = pathlib.Path(root_path).joinpath(target)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # A regular file (or a link that was not cleaned up) may now occupy
        # the location of the link; symlink_to() would fail on it.
        if file_path.is_symlink() or file_path.is_file():
          file_path.unlink()
        file_path.symlink_to(target)
        continue

      if file_info.is_directory:
        continue

      # A file with an unchanged modification time is considered up to date.
      if (
          file_path.exists()
          and file_path.stat().st_mtime == file_info.timestamp
      ):
        continue

      content = None
      if file_info.content_hash:
        # Fetch the content before touching the existing file so that a
        # missing blob does not leave the path deleted.
        content = self._blob_store.get(file_info.content_hash)
        if content is None:
          raise FileNotFoundError(
              f'Missing blob {file_info.content_hash} for {file_path}'
          )

      file_path.parent.mkdir(parents=True, exist_ok=True)
      file_path.unlink(missing_ok=True)
      if content is None:
        file_path.touch()
      else:
        file_path.write_bytes(content)
      os.utime(file_path, (file_info.timestamp, file_info.timestamp))
      os.chmod(file_path, file_info.permissions)
      replaced.append(file_path.as_posix())
    return replaced
604