#!/usr/bin/python3 -B

# Copyright 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Read the EXPECTED_UPSTREAM and  merge the files from the upstream."""
import argparse
import datetime
import logging
# pylint: disable=g-importing-member
import os.path
from pathlib import Path
import random
import re
import string
import sys
from typing import Dict, List, Optional, Set, Tuple
from typing import Sequence

# pylint: disable=g-multiple-import
from common_util import (
    ExpectedUpstreamEntry,
    ExpectedUpstreamFile,
    has_file_in_tree,
    LIBCORE_DIR,
    OjluniFinder,
    TEST_PATH,
)

from git import (
    Commit,
    DiffIndex,
    GitCommandError,
    Head,
    IndexFile,
    Repo,
)

# Configure the root logger at INFO level so that errors GitPython emits
# through the logging module are shown to the user.
logging.basicConfig(level=logging.INFO)


def validate_and_remove_unmodified_entries(
    entries: List[ExpectedUpstreamEntry],
    repo: Repo, commit: Commit) -> List[ExpectedUpstreamEntry]:
  """Returns a list of entries of which the file content needs to be updated.

  Each entry is validated by resolving its git_ref and src_path in the git
  database. An entry is kept when its dst_path is missing from `commit`, or
  when the upstream file differs from the file at dst_path in `commit`.

  Args:
    entries: the entries to validate and filter
    repo: the repository object
    commit: the commit whose tree the upstream files are compared against

  Returns:
    the entries needing an update, in their original order

  Raises:
    Any error raised while resolving an entry, after printing that entry to
    stderr.
  """
  commit_tree = commit.tree
  result: List[ExpectedUpstreamEntry] = []

  for e in entries:
    try:
      # Validate the entry by resolving its upstream ref and source path.
      upstream_commit = repo.commit(e.git_ref)
      source_blob = upstream_commit.tree.join(e.src_path)
      if not has_file_in_tree(e.dst_path, commit_tree):
        # Add the entry if the file is missing in the compared commit.
        result.append(e)
        continue

      dst_blob = commit_tree.join(e.dst_path)
      # Git object ids are hashes of the object content, so comparing them is
      # equivalent to comparing the contents without reading both blobs.
      if source_blob.binsha != dst_blob.binsha:
        result.append(e)
    except Exception:
      print(f"ERROR: reading entry: {e}", file=sys.stderr)
      raise

  return result


# Path of this script relative to LIBCORE_DIR; it is embedded in the message
# of the generated import commit.
THIS_TOOL_PATH = Path(__file__).relative_to(LIBCORE_DIR)

# Name prefix of the temporary branches created on top of the
# expected_upstream branch. main_run also uses it in -a mode to recognize the
# parent of the last merge commit that came from such a branch.
TEMP_EXPECTED_BRANCH_PREFIX = "expected_upstream_"

# Message template of the commit created in the temporary expected_upstream
# branch. It is filled with str.format using summary, files, bug and
# change_id_str. Callers pass either "Bug: <id>" or "" as bug, and either ""
# or "\nChange-Id: <id>" as change_id_str. THIS_TOOL_PATH is substituted
# immediately by the f-string part.
MSG_FIRST_COMMIT = ("Import {summary}\n"
                    "\n"
                    "List of files:\n"
                    "  {files}\n"
                    "\n"
                    f"Generated by {THIS_TOOL_PATH}\n"
                    "\n"
                    "{bug}\n"
                    "Test: N/A\n"
                    "No-Typo-Check: Imported files"
                    "{change_id_str}")

# Message template of the merge commit created on the current branch. It uses
# the same placeholders as MSG_FIRST_COMMIT.
MSG_SECOND_COMMIT = ("Merge {summary} into the "
                     "aosp/main branch\n"
                     "\n"
                     "List of files:\n"
                     "  {files}\n"
                     "\n"
                     "{bug}\n"
                     "Test: N/A"
                     "{change_id_str}")

# Returned by get_diff_entries when validation fails or nothing needs to be
# updated.
INVALID_DIFF = (None, None)

# A C-style /* ... */ block comment, followed by optional spaces and one or
# more newlines.
LICENSE_BLOCK = r"\/\*(?:\*(?!\/)|[^*])*\*\/[ ]*\n+"
# Group 1: a block comment starting at the beginning of a line (MULTILINE);
# group 2: the single-line import statement directly following it.
REGEX_LICENSE_AND_IMPORT = re.compile(
    r"^(" + LICENSE_BLOCK + ")(import .+;)$", re.MULTILINE)


def create_commit_staging_diff(repo: Repo) -> None:
  r"""Save the working-tree EXPECTED_UPSTREAM file in a new git commit.

  The commit is parented on the current HEAD commit but HEAD is not moved, so
  the file can be retrieved manually if a later step of this script fails.

  Args:
    repo: the repository object
  """
  parent = repo.head.commit
  staging_index = IndexFile.from_tree(repo, parent)
  staging_index.add("EXPECTED_UPSTREAM")

  timestamp = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
  snapshot = staging_index.commit(
      message=f"Staging EXPECTED_UPSTREAM at {timestamp}",
      parent_commits=[parent],
      head=False)

  sha = snapshot.hexsha
  print(f"The current EXPECTED_UPSTREAM file is saved in {sha}.\n"
        "If this script fails in the later stage, please retrieve the file by:\n"
        f"  git checkout {sha} -- EXPECTED_UPSTREAM")


def create_commit_summary(diff_entries: List[ExpectedUpstreamEntry]) -> str:
  r"""Build a short description of the imported classes for a commit title.

  Entries without a class name are ignored. Non-test classes are preferred;
  test classes (names starting with "test.") are used only when there is no
  non-test class. The summary is one of:
  - the simple class name, for a single class;
  - a package name or "<common prefix>*", for several classes;
  - "files", when nothing more specific can be derived.
  When all chosen entries share one git_ref, " from <ref>" is appended, where
  <ref> has everything up to its first "/" dropped.

  Args:
    diff_entries: list of new / modified entries

  Returns:
    a string message
  """
  fallback = "files"
  named = [(entry,
            OjluniFinder.translate_ojluni_path_to_class_name(entry.dst_path))
           for entry in diff_entries]
  classified = [pair for pair in named if pair[1] is not None]

  # Prefer the non-test classes, and fall back to the test classes.
  chosen = [pair for pair in classified if not pair[1].startswith("test.")]
  if not chosen:
    chosen = [pair for pair in classified if pair[1].startswith("test.")]
  # No path is under OJLUNI_JAVA_BASE_PATH or OJLUNI_TEST_PATH.
  if not chosen:
    return fallback

  # Keep the ref only if every chosen entry comes from the same revision.
  first_ref = chosen[0][0].git_ref
  same_ref = all(entry.git_ref == first_ref for entry, _ in chosen)
  shared_ref = first_ref if same_ref else None

  class_names = [name for _, name in chosen]
  if len(class_names) == 1:
    summary = class_names[0].rsplit(".", 1)[-1]
  else:
    prefix = os.path.commonprefix(class_names)
    parts = prefix.split(".")
    if len(parts) <= 2:
      # A short prefix such as "javax." or "java.n" isn't meaningful.
      summary = fallback
    elif not parts[-1] or parts[-1][0].islower():
      # Package names are assumed not to be title-case, so the last (partial)
      # segment is dropped.
      summary = ".".join(parts[:-1])
    else:
      summary = prefix + "*"

  if shared_ref is None:
    return summary
  return f"{summary} from {shared_ref.split('/', 1)[-1]}"


def create_commit_at_expected_upstream(
    repo: Repo, head: Head, new_entries: List[ExpectedUpstreamEntry],
    removed_paths: Set[str], bug_id: str,
    last_expected_change_id: str, discard_working_tree: bool) -> Head:
  r"""Create a new commit importing the given files at the head.

  The upstream content of each new entry is written into the working tree
  under LIBCORE_DIR. EXPECTED_UPSTREAM is rewritten from the entries at
  `head`, overlaid with new_entries and without removed_paths. These changes
  are committed on top of head.commit, and `head` is moved to the new commit.

  Args:
    repo: the repository object
    head: the temp expected_upstream branch
    new_entries: entries whose upstream files are imported
    removed_paths: paths to remove from the branch and from EXPECTED_UPSTREAM
    bug_id: bug id for the "Bug:" line, or None to leave it out
    last_expected_change_id: Gerrit's Change-Id to reuse, or a falsy value to
      leave the Change-Id line out
    discard_working_tree: if true, HEAD is switched to `head` and the changes
      are staged in the repository index (main_run requires this when paths
      are removed). Otherwise they are staged in a separate index built from
      head.commit.

  Returns:
    the value of head.set_commit(commit), i.e. the branch moved to the new
    commit
  """
  affected_paths = [e.dst_path for e in new_entries] + list(removed_paths)
  str_affected_paths = "\n  ".join(affected_paths)

  for entry in new_entries:
    ref = entry.git_ref
    upstream_commit = repo.commit(ref)
    src_blob = upstream_commit.tree[entry.src_path]
    # Write into the file system directly because GitPython provides no API
    # writing into the index in memory. IndexFile.move doesn't help here,
    # because the API requires the file on the working tree too.
    # However, it's fine, because we later reset the HEAD.
    absolute_dst_path = Path(LIBCORE_DIR, entry.dst_path)
    absolute_dst_path.parent.mkdir(parents=True, exist_ok=True)
    with absolute_dst_path.open("wb") as file:
      file.write(src_blob.data_stream.read())

  # Start from the EXPECTED_UPSTREAM committed at `head`, not the working-tree
  # version, then apply the new and removed entries.
  entries = ExpectedUpstreamFile(head.commit.tree["EXPECTED_UPSTREAM"]
                                 .data_stream.read()).read_all_entries()
  entries = overlay_entries(entries, new_entries)
  entries = list(filter(lambda e: e.dst_path not in removed_paths, entries))
  # Write the entries to the file system.
  ExpectedUpstreamFile().sort_and_write_all_entries(entries)

  if discard_working_tree:
    # Point HEAD at the temp branch and reset only the index to it; the
    # working tree, which holds the files written above, is left as is.
    repo.head.reference = head
    repo.head.reset(index=True)
    index = repo.index
  else:
    # Stage into a separate index so the current branch's index is untouched.
    index = IndexFile.from_tree(repo, head.commit)
  index.add("EXPECTED_UPSTREAM")
  for entry in new_entries:
    index.add(entry.dst_path)

  for p in removed_paths:
    index.remove(p)

  summary_msg = create_commit_summary(new_entries)
  str_bug = "" if bug_id is None else f"Bug: {bug_id}"
  change_id_str = ""
  if last_expected_change_id:
    change_id_str = f"\nChange-Id: {last_expected_change_id}"
  msg = MSG_FIRST_COMMIT.format(summary=summary_msg, files=str_affected_paths,
                                bug=str_bug, change_id_str=change_id_str)
  # head=False: create the commit without moving HEAD; the branch is moved
  # explicitly below.
  commit = index.commit(message=msg, parent_commits=[head.commit], head=False)
  new_head = head.set_commit(commit)

  print(f"Create a new commit {commit.hexsha} at {head.name}")

  return new_head


def overlay_entries(
    existing_entries: List[ExpectedUpstreamEntry],
    new_entries: List[ExpectedUpstreamEntry]) -> List[ExpectedUpstreamEntry]:
  r"""Return a list of entries after overlaying the new_entries.

  Entries are keyed by dst_path. An entry in new_entries replaces the
  existing entry with the same dst_path while keeping that entry's position.
  Entries with a previously unseen dst_path are appended in order.

  Args:
    existing_entries: current entries
    new_entries: entries being overlaid
  Returns:
    a list of entries
  """
  # Later entries win; dicts keep the position of the first insertion.
  entries_map = {e.dst_path: e for e in [*existing_entries, *new_entries]}
  return list(entries_map.values())


# Matches a whole "Change-Id: I<lowercase hex>" line; group 1 is the
# Change-Id. Used with re.search and re.M by extract_change_id.
REGEX_CHANGE_ID = r"^Change-Id: (I[0-9a-f]+)$"
# Matches a whole "Bug: <digits>" line; group 1 is the numeric bug id.
REGEX_BUG_ID = r"^Bug: ([0-9]+)$"


def extract_change_id(commit: Commit) -> Optional[str]:
  r"""Extract gerrit's Change-Id from a commit message.

  Args:
     commit: commit

  Returns:
    the first Change-Id line's id in the commit message, or None if the
    message has no Change-Id line
  """
  match = re.search(REGEX_CHANGE_ID, commit.message, re.M)
  return match.group(1) if match else None


def extract_bug_id(commit: Commit) -> Optional[str]:
  r"""Extract the bug id from a commit message.

  Args:
     commit: commit

  Returns:
    the Buganizer id from the first "Bug: <digits>" line, or None if the
    message has no such line
  """
  match = re.search(REGEX_BUG_ID, commit.message, re.M)
  return match.group(1) if match else None


def get_diff_entries(repo: Repo, base_expected_commit: Commit) -> Tuple[
    Optional[List[ExpectedUpstreamEntry]], Optional[Set[str]]]:
  """Get a list of entries different from the head commit.

  Checks that the current branch tracks aosp/main and that the only change
  between HEAD and the working tree is in the EXPECTED_UPSTREAM file. It then
  compares the HEAD and working-tree versions of that file.

  Args:
    repo: Repo
    base_expected_commit: the commit in the expected_upstream branch that the
      changed entries are validated and compared against

  Returns:
    A (diff_entries, removed_paths) tuple holding the new or modified entries
    and the removed paths. Returns INVALID_DIFF, i.e. (None, None), when a
    check fails or there is nothing to update.
  """
  current_tracking_branch = repo.active_branch.tracking_branch()
  # tracking_branch() returns None when the branch has no upstream.
  if (current_tracking_branch is None or
      current_tracking_branch.name != "aosp/main"):
    print("This script should only run on aosp/main branch. "
          f"Currently, this is on branch {repo.active_branch} "
          f"tracking {current_tracking_branch}", file=sys.stderr)
    return INVALID_DIFF

  print("Reading EXPECTED_UPSTREAM file...")
  head_commit = repo.head.commit
  # Diff HEAD against the working tree.
  diff_index = head_commit.diff(None)
  num_changed_files = len(diff_index)
  if num_changed_files == 0:
    print("Can't find any EXPECTED_UPSTREAM file change", file=sys.stderr)
    return INVALID_DIFF
  elif (num_changed_files > 1 or
        diff_index[0].a_rawpath != b"EXPECTED_UPSTREAM"):
    print("Expect modification in the EXPECTED_UPSTREAM file only.\n"
          "Please remove / commit the other changes. The below file changes "
          "are detected: ", file=sys.stderr)
    print_diff_index(diff_index, file=sys.stderr)
    return INVALID_DIFF

  prev_file = ExpectedUpstreamFile(head_commit.tree["EXPECTED_UPSTREAM"]
                                   .data_stream.read())
  curr_file = ExpectedUpstreamFile()
  diff_entries = prev_file.get_new_or_modified_entries(curr_file)
  removed_paths = prev_file.get_removed_paths(curr_file)

  # Validates every changed entry; raises if a ref or source path is invalid.
  modified_entries = validate_and_remove_unmodified_entries(
      diff_entries, repo, base_expected_commit)

  if not modified_entries and not removed_paths:
    print("No need to update. All files are updated.")
    return INVALID_DIFF

  print("The following entries will be updated from upstream")
  for e in modified_entries:
    print(f"  {e.dst_path}")
  for p in removed_paths:
    print(f"  {p}")

  # NOTE(review): this returns the unfiltered diff_entries, not
  # modified_entries, so entries whose content already matches
  # base_expected_commit (e.g. a git_ref-only change) are imported too.
  # compute_absorbed_diff_entries filters instead. Confirm which is intended.
  return diff_entries, removed_paths


def compute_absorbed_diff_entries(
    repo: Repo, base_commit: Commit, commit: Commit, overlaid_entries: List[
        ExpectedUpstreamEntry], removed_paths: Set[
            str]) -> Tuple[List[ExpectedUpstreamEntry], Set[str]]:
  r"""Compute the combined entries after absorbing the new changes.

  The entries changed between base_commit and commit (the previous import)
  are overlaid with overlaid_entries. Any entry whose dst_path is in
  removed_paths is dropped. A removed path stays a removal only if the file
  exists in base_commit's tree. A file that only the previous import
  introduced is therefore simply not imported.

  Args:
    repo: Repo
    base_commit: the base commit in the expected_upstream
    commit: The commit diff-ed against from the base_commit
    overlaid_entries: Additional entries overlaid on top of the diff.
    removed_paths: removed paths

  Returns:
    A (entries, removed_paths) tuple: the combined entries whose content
    differs from base_commit, and the paths to remove from base_commit.
  """
  prev_file = ExpectedUpstreamFile(base_commit.tree["EXPECTED_UPSTREAM"]
                                   .data_stream.read())
  curr_file = ExpectedUpstreamFile(commit.tree["EXPECTED_UPSTREAM"]
                                   .data_stream.read())
  diff_entries = prev_file.get_new_or_modified_entries(curr_file)
  diff_entries = overlay_entries(diff_entries, overlaid_entries)

  removed = set(removed_paths)
  # Compare by path: a removed path must not be imported again.
  diff_entries = [e for e in diff_entries if e.dst_path not in removed]
  # Only files present in the base tree need an explicit removal. A path
  # missing from the base has nothing to remove from the new commit's index.
  base_tree = base_commit.tree
  new_removed_paths = {p for p in removed if has_file_in_tree(p, base_tree)}
  return validate_and_remove_unmodified_entries(
      diff_entries, repo, base_commit), new_removed_paths


def main_run(
    repo: Repo, expected_upstream_base: Optional[str],
    bug_id: Optional[str], use_rerere: bool, is_absorbed: bool,
    discard_working_tree: bool) -> None:
  """Create the commits importing files according to the EXPECTED_UPSTREAM.

  A commit importing the changed files is created on a new temporary branch
  based on the expected_upstream branch. That commit is then merged into the
  current branch with `git merge`. Most errors are printed to stderr and end
  the run early. A failed merge is printed and left for the user to resolve.

  With is_absorbed, HEAD must be a merge commit created by a previous run.
  Both of that merge's commits are re-created from their parents with the
  previous and new EXPECTED_UPSTREAM changes combined. Their Change-Ids are
  reused, as is the previous bug id unless bug_id is given.

  Args:
    repo: Repo
    expected_upstream_base: the base commit in the expected_upstream branch;
      defaults to aosp/expected_upstream when None.
    bug_id: bug id
    use_rerere: Reuses the recorded resolution from git
    is_absorbed: Absorb the new changes from EXPECTED_UPSTREAM into the
      existing commits created by this script
    discard_working_tree: discard working tree flag.
  """
  # Imported here because it is only needed to report an unknown ref.
  from git.exc import BadName, BadObject  # pylint: disable=g-import-not-at-top

  last_master_commit = repo.head.commit
  last_master_change_id = None
  last_expected_change_id = None
  last_expected_commit = None
  if is_absorbed:
    head = repo.head
    if len(head.commit.parents) != 2:
      print("Error: HEAD isn't a merge commit.", file=sys.stderr)
      return

    last_branch = None
    for commit in head.commit.parents:
      name_rev: list[str] = commit.name_rev.split(" ", 1)
      if (len(name_rev) > 1 and  # name_rev[1] is usually the branch name
          name_rev[1].startswith(TEMP_EXPECTED_BRANCH_PREFIX)):
        last_branch = name_rev[1]
        last_expected_commit = commit
      else:
        last_master_commit = commit

    if last_branch is None:
      print("Error: Can't find the last commit in the expected_upstream "
            "branch.", file=sys.stderr)
      return

    if len(last_expected_commit.parents) != 1:
      print(f"Error: The head commit at {last_branch} isn't in the expected "
            "state.", file=sys.stderr)
      return

    base_expected_branch_commit = last_expected_commit.parents[0]
    last_expected_change_id = extract_change_id(last_expected_commit)
    last_master_change_id = extract_change_id(head.commit)
    if bug_id is None:
      bug_id = extract_bug_id(last_expected_commit)
  else:
    if expected_upstream_base is None:
      expected_upstream_base = "aosp/expected_upstream"
    try:
      base_expected_branch_commit = repo.commit(expected_upstream_base)
    except (BadName, BadObject, ValueError):
      # repo.commit raises for an unknown ref instead of returning None.
      print(f"{expected_upstream_base} is not found in this repository.",
            file=sys.stderr)
      return

  diff_entries, removed_paths = get_diff_entries(repo,
                                                 base_expected_branch_commit)
  if not diff_entries and not removed_paths:
    return

  if is_absorbed:
    diff_entries, removed_paths = compute_absorbed_diff_entries(
        repo, base_expected_branch_commit, last_expected_commit, diff_entries,
        removed_paths)

  # Due to a limitation in GitPython, index.remove requires switching branch
  # and discard the working tree.
  if removed_paths and not discard_working_tree:
    print("-r option is required to discard the current working tree.",
          file=sys.stderr)
    return

  create_commit_staging_diff(repo)

  master_head = repo.active_branch
  branch_name = create_random_branch_name()
  new_branch = repo.create_head(branch_name, base_expected_branch_commit.hexsha)
  new_branch.set_tracking_branch(repo.remotes.aosp.refs.expected_upstream)
  new_branch = create_commit_at_expected_upstream(
      repo, new_branch, diff_entries, removed_paths, bug_id,
      last_expected_change_id, discard_working_tree)

  # Clean the working tree before merging branch
  if discard_working_tree:
    repo.head.reference = master_head

  repo.head.reset(commit=last_master_commit, working_tree=True)
  # Files imported above that don't exist on this branch are untracked
  # leftovers, which would block the merge; delete them.
  for e in diff_entries:
    if not has_file_in_tree(e.dst_path, repo.head.commit.tree):
      path = Path(LIBCORE_DIR, e.dst_path)
      path.unlink(missing_ok=True)

  affected_paths = [e.dst_path for e in diff_entries] + list(removed_paths)
  str_affected_paths = "\n  ".join(affected_paths)
  summary_msg = create_commit_summary(diff_entries)
  str_bug = "" if bug_id is None else f"Bug: {bug_id}"
  change_id_str = ""
  if last_master_change_id:
    change_id_str = f"\nChange-Id: {last_master_change_id}"
  msg = MSG_SECOND_COMMIT.format(
      summary=summary_msg, files=str_affected_paths, bug=str_bug,
      change_id_str=change_id_str)
  rerere_str = "rerere.enabled="
  rerere_str += "true" if use_rerere else "false"

  # Map each imported test file to its package, derived from its class name.
  test_dst_paths = {}
  for e in diff_entries:
    if e.dst_path.startswith(TEST_PATH):
      class_name = OjluniFinder.translate_ojluni_path_to_class_name(e.dst_path)
      if class_name is not None:
        package_name = class_name[:class_name.rfind(".")]
        test_dst_paths[e.dst_path] = package_name

  # Run git-merge command here, and will let the user to handle
  # any errors and merge conflicts
  try:
    repo.git.execute(["git", "-c", rerere_str, "merge",
                      new_branch.commit.hexsha, "-m", msg])
  except GitCommandError as err:
    print(f"Error: {err}", file=sys.stderr)

  insert_package_name_to_tests(test_dst_paths)


def insert_package_name_to_tests(test_dst_paths: Dict[str, str]):
  """Insert package name into the test file before the java import statement.

  In each file, the first import statement directly following a /* ... */
  block comment (typically the license header) gets a "package <name>;"
  statement inserted before it. The match uses REGEX_LICENSE_AND_IMPORT.
  Files without such a match are left untouched.

  Args:
    test_dst_paths: Map the file path to package names. Relative paths are
      resolved against LIBCORE_DIR.
  """
  for dst_path, package_name in test_dst_paths.items():
    # Resolve against LIBCORE_DIR like the rest of this script, so the result
    # doesn't depend on the current working directory.
    path = Path(LIBCORE_DIR, dst_path)
    # surrogateescape round-trips any bytes that aren't valid UTF-8 unchanged.
    with path.open("r", encoding="utf-8", errors="surrogateescape") as file:
      src = file.read()
    replacement = r"\1package " + package_name + r";\n\n\2"
    modified = REGEX_LICENSE_AND_IMPORT.sub(replacement, src, count=1)
    if modified == src:
      continue
    with path.open("w", encoding="utf-8", errors="surrogateescape") as out:
      out.write(modified)


def create_random_branch_name() -> str:
  """Return TEMP_EXPECTED_BRANCH_PREFIX followed by 10 random characters.

  The characters are drawn from lowercase ASCII letters and digits.
  """
  alphabet = string.ascii_lowercase + string.digits
  suffix_chars = []
  for _ in range(10):
    suffix_chars.append(random.choice(alphabet))
  return TEMP_EXPECTED_BRANCH_PREFIX + "".join(suffix_chars)


def print_diff_index(index: DiffIndex, file=None) -> None:
  """Print the path of each diff in index, one per line, indented by 2 spaces.

  Args:
    index: the diffs to print
    file: the output stream. None (the default) means the sys.stdout at call
      time, so later redirection of stdout is respected.
  """
  for diff in index:
    # a_path is the decoded path. It may be None, e.g. for an added file, so
    # fall back to b_path. Printing a_rawpath would show a bytes repr.
    print(f"  {diff.a_path or diff.b_path}", file=file)


def main(argv: Sequence[str]) -> None:
  """Parse the command-line flags and run the import / merge workflow.

  Args:
    argv: command-line arguments, without the program name.
  """
  parser = argparse.ArgumentParser(
      description=("Read the EXPECTED_UPSTREAM and update the files from the "
                   "OpenJDK. This script imports the files from OpenJDK into "
                   "the expected_upstream branch and merges it into the "
                   "current branch."))
  parser.add_argument(
      "-a", "--absorbed-to-last-merge", action="store_true",
      help=("Import more files but absorb them into the last "
            "commits created by this script."))
  parser.add_argument(
      "--disable-rerere", action="store_true",
      help="Do not re-use the recorded resolution from git.")
  parser.add_argument(
      "-r", "--reset", action="store_true",
      help=("Discard the current working tree. Experimental flag to "
            "support file removal from ojluni/."))
  parser.add_argument("-b", "--bug", nargs="?", help="Buganizer Id")
  parser.add_argument(
      "-e", "--expected_upstream_base", nargs="?",
      help="The base commit in the expected_upstream branch")

  opts = parser.parse_args(argv)

  # -a derives the base from the last merge, so an explicit base conflicts.
  if opts.absorbed_to_last_merge and opts.expected_upstream_base is not None:
    print("Error: -a and -e options can't be used together.", file=sys.stderr)
    return

  repo = Repo(LIBCORE_DIR.as_posix())
  try:
    main_run(repo,
             expected_upstream_base=opts.expected_upstream_base,
             bug_id=opts.bug,
             use_rerere=not opts.disable_rerere,
             is_absorbed=opts.absorbed_to_last_merge,
             discard_working_tree=opts.reset)
  finally:
    repo.close()


if __name__ == "__main__":
  # Pass the command-line arguments without the program name.
  main(sys.argv[1:])