# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
import re
import shutil
import subprocess
import tempfile
import pathlib
from typing import TextIO, List, Dict, Generator, Optional

from absl import logging

def _git_long(verb: str, *args, git_flag: Optional[str] = None, cwd=None, input=None) -> str:
    """Run a git command and return everything it printed.

    Args:
        verb: The git sub-command, e.g. 'log' or 'show'.
        *args: Arguments placed after the sub-command.
        git_flag: Optional single flag inserted between 'git' and the
            sub-command.
        cwd: Working directory handed to subprocess.run.
        input: Text handed to subprocess.run as the process's stdin.

    Returns:
        The captured stdout of the process; stderr is redirected into it.

    Raises:
        subprocess.CalledProcessError: if git exits with a non-zero status.
    """
    logging.debug('Running\ngit %s %s\n with input: %s', verb, ' '.join(args), input)
    # The optional flag has to come before the sub-command.
    flag_part = [git_flag] if git_flag else []
    command = ['git', *flag_part, verb, *args]
    print(' '.join(command))
    completed = subprocess.run(
        command,
        cwd=cwd,
        input=input,
        text=True,
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return completed.stdout

def _git(verb: str, *args, git_flag: Optional[str] = None, cwd=None, input=None) -> str:
    """Run a git command through _git_long and log its output at INFO level.

    Takes the same arguments as _git_long and returns the same text.
    """
    output = str(_git_long(verb, *args, git_flag=git_flag, cwd=cwd, input=input))
    logging.info('git %s stdout: %s', verb, output)
    return output

class CommitMessage:
    """Wrapper around a commit-message string."""

    def __init__(self, commit_message: str) -> None:
        # The text is stored exactly as given.
        self.commit_message = commit_message

class Commit:
    """Wrapper around the raw text of a commit."""

    def __init__(self, raw_commit: str) -> None:
        # The text is stored exactly as given.
        self.raw_commit = raw_commit

class Git:
    """Runs git commands against a single repository directory."""

    def __init__(self, git_dir: str) -> None:
        # Directory used as the working directory of (almost) every command.
        self._git_dir = git_dir

    def __call__(self, verb: str, *args) -> str:
        """Run an arbitrary git sub-command inside the repository directory."""
        output = _git(verb, *args, cwd=self._git_dir)
        return output

    def clone(self, remote, *args) -> str:
        """Clone `remote` into the repository directory.

        Unlike the other methods this does not set a working directory; the
        repository directory is passed to git as the clone destination.
        """
        clone_args = [*args, '--', remote, self._git_dir]
        return _git('clone', *clone_args)

    def am(self, patch_contents: str) -> str:
        """Run `git am`, feeding `patch_contents` on stdin."""
        output = _git('am', cwd=self._git_dir, input=patch_contents)
        return output

    def push(self, remote_branch: str) -> str:
        """Run `git push -u origin <remote_branch>`."""
        push_args = ('-u', 'origin', remote_branch)
        return _git('push', *push_args, cwd=self._git_dir)

    def config(self, config: str, option: str) -> str:
        """Run `git config --local <config> <option>`."""
        config_args = ('--local', config, option)
        return _git('config', *config_args, cwd=self._git_dir)

    def commit(self, *args) -> str:
        """Run `git commit` with the given arguments."""
        output = _git('commit', *args, cwd=self._git_dir)
        return output

    def log(self, *args) -> str:
        """Run `git --no-pager log` with the given arguments.

        Goes through _git_long, so the output is not logged at INFO level.
        """
        output = _git_long('log', *args, git_flag='--no-pager', cwd=self._git_dir)
        return output

    def log_no_page(self, commit_hash: str) -> CommitMessage:
        """Return the `git log` output for one commit as a CommitMessage."""
        raw_log = _git('log', '-n 1', commit_hash, git_flag='--no-pager', cwd=self._git_dir)
        return CommitMessage(raw_log)

    def show(self, *args) -> Commit:
        """Run `git --no-pager show` and wrap the output in a Commit."""
        raw_show = _git('show', *args, git_flag='--no-pager', cwd=self._git_dir)
        return Commit(raw_show)

class BugFixChecker:
    """Base bug-fix checker; this default implementation matches no commit."""

    def bug_fix_log_grep(self) -> str:
        """Return the text used to pre-filter commits; empty in the base class."""
        return ''

    def is_bug_fix(self, commit_message: CommitMessage) -> bool:
        """Report whether `commit_message` is a bug fix; always False here."""
        return False

class UpstreamKernelBugFixChecker(BugFixChecker):
    """Treats commits carrying a `Fixes: <hex> (` tag line as bug fixes."""

    # Raw string: '\s' and '\(' are regex escapes, not Python string escapes.
    # Matches a line that starts (after optional whitespace) with
    # 'Fixes: ', a hexadecimal hash, a space and an opening parenthesis.
    _FIXES_TAG_RE = re.compile(r'^\s*Fixes: ([0-9a-fA-F]+) \(', re.MULTILINE)

    def bug_fix_log_grep(self) -> str:
        """Return the text used to pre-filter candidate commits."""
        return 'Fixes:'

    def is_bug_fix(self, commit_message: CommitMessage) -> bool:
        """Return True if the message contains a well-formed `Fixes:` tag line.

        Args:
            commit_message: The message to inspect; its `commit_message`
                attribute is searched.
        """
        return self._FIXES_TAG_RE.search(commit_message.commit_message) is not None

def find_bug_fixes(git: Git, bug_fix_checker: BugFixChecker) -> Generator[str, None, None]:
    """Yield the abbreviated hash of every commit the checker accepts.

    Candidates come from `git log --grep=<checker pattern>`; each one is then
    confirmed by handing its log output to `bug_fix_checker.is_bug_fix`.
    """
    grep_option = '--grep=' + bug_fix_checker.bug_fix_log_grep()
    candidate_hashes = git.log(grep_option, '--pretty=format:%h').splitlines()
    for candidate in candidate_hashes:
        if not bug_fix_checker.is_bug_fix(git.log_no_page(candidate)):
            continue
        print('Found bug: ', candidate)
        yield candidate

class PathToBugListMap:
    """Maps a file path to the list of bug-fix hashes recorded for it."""

    def __init__(self) -> None:
        # path -> list of bug-fix hashes, in insertion order.
        self._path_to_bug_list_map = {}

    def add_bug_to_path(self, path: str, bug_fix_hash: str) -> None:
        """Append `bug_fix_hash` to the list kept for `path`."""
        hashes = self._path_to_bug_list_map.get(path)
        if hashes is None:
            hashes = self._path_to_bug_list_map[path] = []
        hashes.append(bug_fix_hash)

    def add_bug_to_paths(self, path_list: List[str], bug_fix_hash: str) -> None:
        """Record `bug_fix_hash` against every path in `path_list`."""
        for each_path in path_list:
            self.add_bug_to_path(each_path, bug_fix_hash)

    def serialize(self, file_name) -> None:
        """Write one '<path> <hash> ' line per (path, hash) pair to `file_name`."""
        with open(file_name, 'w') as out:
            out.writelines(
                ' '.join((path, bug_fix_hash, '\n'))
                for path, hashes in self._path_to_bug_list_map.items()
                for bug_fix_hash in hashes
            )

def path_to_bug_list_map_from_file(file_name: str) -> PathToBugListMap:
    """Load a PathToBugListMap from a whitespace-separated text file.

    Each non-blank line must hold exactly two whitespace-separated fields:
    a path followed by a bug-fix hash. Blank lines are ignored.

    Args:
        file_name: Path of the file to read.

    Returns:
        A map holding every (path, hash) pair in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    path_to_bug_list_map = PathToBugListMap()
    with open(file_name, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline added by hand)
                # instead of failing on tuple unpacking.
                continue
            if len(fields) != 2:
                raise ValueError(
                    '{}:{}: expected "<path> <hash>", got {!r}'.format(
                        file_name, line_number, line))
            path, bug_fix_hash = fields
            path_to_bug_list_map.add_bug_to_path(path, bug_fix_hash)
    return path_to_bug_list_map

def find_affected_files_in_commit(commit: Commit) -> List[str]:
    """Return the files touched by `commit` that currently exist on disk.

    The `diff --git a/<old> b/<new>` header lines of `commit.raw_commit` are
    scanned and the `<new>` path of each is taken.

    NOTE(review): os.path.isfile resolves relative paths against the process's
    current working directory, so this presumably expects to be run from the
    repository root -- confirm against the caller.

    Args:
        commit: The commit whose `raw_commit` text is scanned.

    Returns:
        The `<new>` paths, in order of appearance, for which os.path.isfile
        is true.
    """
    results = []
    # Raw string so that '\S' reaches the regex engine without relying on
    # Python passing unknown string escapes through. '/' needs no escaping.
    diff_header_pattern = r'^diff --git a/\S+ b/(\S+)$'
    for match in re.finditer(diff_header_pattern, commit.raw_commit, re.MULTILINE):
        affected_file = match.group(1)
        print('Found file: ', affected_file)
        if os.path.isfile(affected_file):
            results.append(affected_file)
    return results

def build_path_to_bug_list_map(git: Git, bug_fix_checker: BugFixChecker) -> PathToBugListMap:
    """Build a map from file path to the bug-fix commits that touched it.

    Every hash produced by find_bug_fixes is shown with `git.show`, and the
    hash is recorded against each file find_affected_files_in_commit returns.
    """
    result = PathToBugListMap()
    for bug_fix_hash in find_bug_fixes(git, bug_fix_checker):
        touched_paths = find_affected_files_in_commit(git.show(bug_fix_hash))
        result.add_bug_to_paths(touched_paths, bug_fix_hash)
    return result

def generate_path_to_bug_list(git: Git, path_to_bug_list_file: str, bug_fix_checker=UpstreamKernelBugFixChecker()) -> None:
    """Build the path-to-bug-list map for `git` and write it to a file.

    Args:
        git: Repository wrapper used to find and inspect commits.
        path_to_bug_list_file: File the serialized map is written to.
        bug_fix_checker: Decides which commits count as bug fixes.
    """
    build_path_to_bug_list_map(git, bug_fix_checker).serialize(path_to_bug_list_file)

class BugFixFileTreeNode:
    """Base node of the bug-fix file tree; the defaults describe an empty node."""

    def get_bugs_of_children(self) -> List[str]:
        """Return the bug-fix hashes held by this node; none in the base class."""
        return list()

    def get_number_of_child_bugs(self) -> int:
        """Return how many bug-fix hashes get_bugs_of_children reports."""
        child_bugs = self.get_bugs_of_children()
        return len(child_bugs)

    def sort(self) -> None:
        """Order the node's contents; nothing to do in the base class."""
        return None

    def print(self, indent: str = '') -> None:
        """Print the node; the base class prints nothing."""
        return None

    def add_path_to_bug_mapping(self, path: pathlib.Path, bug_hash: str) -> None:
        """Record `bug_hash` for `path`; ignored by the base class."""
        return None

class BugFixFileNode(BugFixFileTreeNode):
    """Leaf node: a single file and the bug-fix hashes recorded for it."""

    def __init__(self, file_name: str, bug_list: Optional[List[str]] = None) -> None:
        self.file_name = file_name
        # A falsy `bug_list` (None or empty) is replaced by a fresh list.
        self._bug_list = bug_list or []

    def get_bugs_of_children(self) -> List[str]:
        """Return the list of bug-fix hashes recorded for this file."""
        return self._bug_list

    def add_path_to_bug_mapping(self, path: pathlib.Path, bug_hash: str) -> None:
        """Record `bug_hash` against this file.

        Args:
            path: Must be exactly this node's file name, with no directory
                component. Both `str` and path-like objects are accepted.
            bug_hash: The bug-fix hash to append.
        """
        # Normalise to str first: comparing a pathlib.Path with a str is
        # always unequal, so the check below could never pass for a Path.
        name = os.fspath(path)
        assert os.path.basename(name) == name and name == self.file_name
        self._bug_list.append(bug_hash)

    def sort(self) -> None:
        """Sort the recorded hashes in place."""
        self._bug_list.sort()

    def print(self, indent: str = '') -> None:
        """Print the file name and its number of recorded hashes."""
        print(indent, self.file_name, ': ', str(len(self._bug_list)))

class BugFixDirectoryNode(BugFixFileTreeNode):
    """Interior node: a directory holding sub-directory and file nodes."""

    def __init__(self,
                 directory_name: str,
                 child_directory_nodes: Optional[Dict[str, 'BugFixDirectoryNode']] = None,
                 child_file_nodes: Optional[Dict[str, BugFixFileNode]] = None) -> None:
        self.directory_name = directory_name
        # Directories and files are kept in separate name -> node dicts.
        self._child_directory_nodes = child_directory_nodes or {}
        self._child_file_nodes = child_file_nodes or {}
        # Both of these are only populated by sort().
        self._child_nodes_sorted = []
        self._number_of_child_bugs = 0

    def get_bugs_of_children(self) -> List[str]:
        """Return the bug-fix hashes of every descendant, directories first."""
        results = []
        # Iterate the node objects (not the dict keys) and call the method.
        for child in self._child_directory_nodes.values():
            results += child.get_bugs_of_children()
        for child in self._child_file_nodes.values():
            results += child.get_bugs_of_children()
        return results

    def get_number_of_child_bugs(self) -> int:
        """Return the bug count computed by the most recent sort() call."""
        return self._number_of_child_bugs

    def add_path_to_bug_mapping(self, path: pathlib.Path, bug_hash: str) -> None:
        """Record `bug_hash` for the file at `path`, creating nodes as needed.

        Args:
            path: Path relative to this directory; its first part names the
                child, the remaining parts are handed down to that child.
            bug_hash: The bug-fix hash to record.
        """
        child_name = path.parts[0]
        if len(path.parts) > 1:
            # More parts follow, so the first part is a sub-directory.
            directory_node = self._child_directory_nodes.get(child_name)
            if directory_node is None:
                # Create the node only when it is actually missing.
                directory_node = BugFixDirectoryNode(child_name)
                self._child_directory_nodes[child_name] = directory_node
            remaining_path = pathlib.Path(*path.parts[1:])
            directory_node.add_path_to_bug_mapping(remaining_path, bug_hash)
        else:
            file_node = self._child_file_nodes.get(child_name)
            if file_node is None:
                file_node = BugFixFileNode(child_name)
                self._child_file_nodes[child_name] = file_node
            # The file node is handed the bare name as a str.
            file_node.add_path_to_bug_mapping(child_name, bug_hash)

    def sort(self) -> None:
        """Sort all descendants and order the children by descending bug count.

        Also recomputes the total returned by get_number_of_child_bugs().
        """
        self._child_nodes_sorted = (
            list(self._child_directory_nodes.values())
            + list(self._child_file_nodes.values()))
        self._number_of_child_bugs = 0
        for child in self._child_nodes_sorted:
            # A child's count is only valid after the child itself is sorted.
            child.sort()
            self._number_of_child_bugs += child.get_number_of_child_bugs()
        self._child_nodes_sorted.sort(key=lambda child: child.get_number_of_child_bugs(), reverse=True)

    def print(self, indent: str = '') -> None:
        """Print this directory and, one tab deeper, its sorted children.

        Uses the order and count produced by the most recent sort() call.
        """
        print(indent, self.directory_name, ': ', str(self._number_of_child_bugs))
        indent += '\t'
        for child in self._child_nodes_sorted:
            child.print(indent)

class BugFixFileTree:
    """Directory tree of bug-fix counts, rooted at an unnamed directory node."""

    def __init__(self) -> None:
        self._root_directory_node = BugFixDirectoryNode('')

    def build_from_path_to_bug_list_map(self, path_to_bug_list_map: PathToBugListMap) -> None:
        """Insert every (path, hash) pair of the map into the tree."""
        root = self._root_directory_node
        entries = path_to_bug_list_map._path_to_bug_list_map
        for file_path, hashes in entries.items():
            for one_hash in hashes:
                root.add_path_to_bug_mapping(pathlib.Path(file_path), one_hash)

    def print(self) -> None:
        """Sort the tree, then print it starting from the root."""
        root = self._root_directory_node
        root.sort()
        root.print()

def print_heat_map(git: Git, path_to_bug_list_file: str) -> None:
    """Load a serialized path-to-bug-list map and print it as a tree.

    Args:
        git: Not used by this function.
        path_to_bug_list_file: File previously written by
            PathToBugListMap.serialize.
    """
    tree = BugFixFileTree()
    tree.build_from_path_to_bug_list_map(
        path_to_bug_list_map_from_file(path_to_bug_list_file))
    tree.print()

def print_file_paths_with_bug_numbers(git: Git, path_to_bug_list_file: str) -> None:
    """Print '<path> ,  <count>' lines, highest bug count first.

    Args:
        git: Not used by this function.
        path_to_bug_list_file: File previously written by
            PathToBugListMap.serialize.
    """
    loaded_map = path_to_bug_list_map_from_file(path_to_bug_list_file)
    # Paths with no recorded hashes are left out entirely.
    counts = {
        path: len(bug_list)
        for path, bug_list in loaded_map._path_to_bug_list_map.items()
        if bug_list
    }
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    for path, bug_number in ranked:
        print(path, ', ', bug_number)
