diff --git a/tensorbay/cli/cli.py b/tensorbay/cli/cli.py index f974237b3..152c7f1a5 100644 --- a/tensorbay/cli/cli.py +++ b/tensorbay/cli/cli.py @@ -372,6 +372,9 @@ def tag(obj: ContextInfo, tbrn: str, name: str, is_delete: bool, sort: str) -> N "", "# Show text-based graphical commit logs.", "$ gas log --graph tb:[@]", + "# Show commit and open draft logs.", + "$ gas log --show-drafts tb:[@]", + "", ) ) @click.argument("tbrn", type=str) @@ -381,6 +384,7 @@ def tag(obj: ContextInfo, tbrn: str, name: str, is_delete: bool, sort: str) -> N @click.option("--oneline", is_flag=True, help="Limit commit message to oneline") @click.option("--all", "is_all", is_flag=True, help="Show all the commits of all branches") @click.option("--graph", is_flag=True, help="Show text-based graphical commits history") +@click.option("--show-drafts", is_flag=True, help="Show open drafts") @click.pass_obj def log( # pylint: disable=too-many-arguments obj: ContextInfo, @@ -389,6 +393,7 @@ def log( # pylint: disable=too-many-arguments oneline: bool, is_all: bool, graph: bool, + show_drafts: bool, ) -> None: """Show commit logs.\f @@ -399,11 +404,12 @@ def log( # pylint: disable=too-many-arguments oneline: Whether to show a commit message in oneline. is_all: Whether to show all commits of all branches. graph: Whether to show graphical commit history. + show_drafts: Whether to log open drafts. """ # noqa: D301,D415 from tensorbay.cli.log import _implement_log - _implement_log(obj, tbrn, max_count, oneline, is_all, graph) + _implement_log(obj, tbrn, max_count, oneline, is_all, graph, show_drafts) @command( diff --git a/tensorbay/cli/log.py b/tensorbay/cli/log.py index b6194f1c9..6a4a1bc95 100644 --- a/tensorbay/cli/log.py +++ b/tensorbay/cli/log.py @@ -9,9 +9,21 @@ import sys from collections import defaultdict from datetime import datetime +from functools import partial from itertools import cycle, islice, zip_longest from textwrap import indent -from typing import DefaultDict, Dict, Iterator, List, Optional, Type, Union +from typing import ( + Callable, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Optional, + Type, + TypeVar, + Union, +) import click @@ -19,20 +31,29 @@ from tensorbay.cli.tbrn import TBRN, TBRNType from tensorbay.cli.utility import ContextInfo, error, exception_handler, shorten from tensorbay.client.gas import DatasetClientType -from tensorbay.client.struct import Commit +from tensorbay.client.struct import ROOT_COMMIT_ID, Commit, Draft _LEFT_BRACKET = click.style("(", fg="yellow", reset=False) _COMMA = click.style(", ", fg="yellow", reset=False) _RIGHT_BRACKET = click.style(")", fg="yellow", reset=True) -_FULL_LOG = f"""{click.style("commit {}", fg="yellow")} +_FULL_COMMIT_MESSAGE = f"""{click.style("commit {}", fg="yellow")} Author: {{}} Date: {{}} {{}} """ -_ONELINE_LOG = f"""{click.style("{}", fg="yellow")} {{}} +_FULL_DRAFT_MESSAGE = f"""{click.style("draft {}", fg="red")} +Date: {{}} + + {{}} +""" + +_ONELINE_COMMIT_MESSAGE = f"""{click.style("{}", fg="yellow")} {{}} +""" + +_ONELINE_DRAFT_MESSAGE = f"""{click.style("{}", fg="red")} {{}} """ _LOG_COLORS = ( @@ -48,6 +69,9 @@ "bright_green", "bright_yellow", ) +_R = TypeVar("_R", bound="_RootCommitNode") +_D = TypeVar("_D", bound="_DraftNode") +_T = Union["_RootCommitNode", "_DraftNode"] @exception_handler @@ -58,6 +82,7 @@ def _implement_log( # pylint: disable=too-many-arguments oneline: bool, is_all: bool, graph: bool, + show_drafts: bool, ) -> None: gas = obj.get_gas() tbrn_info = TBRN(tbrn=tbrn) @@ -76,19 +101,14 @@ def _implement_log( # pylint: disable=too-many-arguments revisions = ( [tbrn_info.revision] if tbrn_info.revision else [dataset_client.status.branch_name] ) - - Printer: Union[Type[_GraphPrinter], Type[_CommitPrinter]] = ( - _GraphPrinter if graph else _CommitPrinter - ) - commit_generator = islice( - Printer(dataset_client, revisions, commit_id_to_branches, oneline).generate_commits(), + Printer: Union[Type[_GraphPrinter], Type[_Printer]] = _GraphPrinter if graph else _Printer + message_generator = islice( + Printer( + dataset_client, revisions, commit_id_to_branches, oneline, show_drafts=show_drafts + ).generate_commits_and_drafts_messages(), max_count, ) - if sys.platform.startswith("win"): - for item in commit_generator: - click.echo(item) - else: - click.echo_via_pager(commit_generator) + _echo_messages(message_generator) def _join_branch_names(commit_id: str, branch_names: List[str]) -> str: @@ -99,14 +119,18 @@ def _join_branch_names(commit_id: str, branch_names: List[str]) -> str: ) -def _get_oneline_log(commit: Commit, branch_names: Optional[List[str]]) -> str: +def _get_oneline_commit_message(commit: Commit, branch_names: Optional[List[str]]) -> str: commit_id = shorten(commit.commit_id) if branch_names: commit_id = _join_branch_names(commit_id, branch_names) - return _ONELINE_LOG.format(commit_id, commit.title) + return _ONELINE_COMMIT_MESSAGE.format(commit_id, commit.title) + +def _get_oneline_draft_message(draft: Draft) -> str: + return _ONELINE_DRAFT_MESSAGE.format(draft.number, draft.title) -def _get_full_log(commit: Commit, branch_names: Optional[List[str]]) -> str: + +def _get_full_commit_message(commit: Commit, branch_names: Optional[List[str]]) -> str: description = commit.description if description: description = f"\n\n{indent(description, INDENT)}" @@ -114,7 +138,7 @@ def _get_full_log(commit: Commit, branch_names: Optional[List[str]]) -> str: commit_id = commit.commit_id if branch_names: commit_id = _join_branch_names(commit_id, branch_names) - return _FULL_LOG.format( + return _FULL_COMMIT_MESSAGE.format( commit_id, commit.committer.name, datetime.fromtimestamp(commit.committer.date).strftime("%a %b %d %H:%M:%S %y"), @@ -122,14 +146,27 @@ def _get_full_log(commit: Commit, branch_names: Optional[List[str]]) -> str: ) -class _CommitPrinter: - """This class defines the structure of logging commits. +def _get_full_draft_message(draft: Draft) -> str: + description = draft.description + if description: + description = f"\n\n{indent(description, INDENT)}" + draft_message = f"{draft.title}{description}\n" + return _FULL_DRAFT_MESSAGE.format( + draft.number, + datetime.fromtimestamp(draft.updated_at).strftime("%a %b %d %H:%M:%S %y"), + draft_message, + ) + + +class _Printer: + """This class defines the structure of logging commits and open drafts. Arguments: dataset_client: The dataset that needs to be logged. revisions: The revisions that needs to be logged. oneline: Whether to log with oneline method. commit_id_to_branches: The map of commit id to branch name. + show_drafts: Whether to log open drafts. """ @@ -139,27 +176,43 @@ def __init__( revisions: List[Optional[str]], commit_id_to_branches: Dict[str, List[str]], oneline: bool, + *, + show_drafts: bool, ): - all_commits = list(map(dataset_client.list_commits, revisions)) - - if len(all_commits[0]) == 0: - error(f'Dataset "{dataset_client.name}" has no commit history') + all_commit_logs = list(map(dataset_client.list_commits, revisions)) + all_drafts: List[Draft] = [] + error_message = f'Dataset "{dataset_client.name}" has no commit history' + if show_drafts: + error_message += " or open drafts" + for revision in revisions: + all_drafts.extend(dataset_client.list_drafts(branch_name=revision)) + + if not all_commit_logs[0]: + if not all_drafts: + error(error_message) + self._sorted_commit_logs = [] + else: + # Sort logs from different branches by the date of the latest commit of each branch. + self._sorted_commit_logs = sorted(all_commit_logs, key=lambda x: x[0].committer.date) self._commit_id_to_branches = commit_id_to_branches - self._printer = _get_oneline_log if oneline else _get_full_log + self._commit_printer, self._draft_printer = ( + (_get_oneline_commit_message, _get_oneline_draft_message) + if oneline + else (_get_full_commit_message, _get_full_draft_message) + ) - # Sort commits from different branches by the date of the latest commit of each branch. - self._sorted_commits = sorted(all_commits, key=lambda x: x[0].committer.date) - self._keys = [commits[0].committer.date for commits in self._sorted_commits] + self._sorted_drafts = sorted(all_drafts, key=lambda x: x.updated_at) + self._keys = [log[0].committer.date for log in self._sorted_commit_logs] def _merge(self, latest_commit: Commit) -> bool: - if len(self._sorted_commits) <= 1: + if len(self._sorted_commit_logs) <= 1: return False date = latest_commit.committer.date commit_id = latest_commit.commit_id - for commits in islice(reversed(self._sorted_commits), 1, None): - commit = commits[0] + for log in islice(reversed(self._sorted_commit_logs), 1, None): + commit = log[0] # Traverse all the commits with the same timestamp, # if the commit id is the same as the latest commit, # then merge the branch where the latest commit is located. @@ -171,54 +224,83 @@ def _merge(self, latest_commit: Commit) -> bool: else: return False # Merge branches. - del self._sorted_commits[-1] + del self._sorted_commit_logs[-1] del self._keys[-1] return True def _sort(self) -> None: """Sort commits paging list by commit date.""" # Only one branch exists. - if len(self._sorted_commits) == 1: + if len(self._sorted_commit_logs) == 1: return # Binary insert. - commits = self._sorted_commits.pop() + log = self._sorted_commit_logs.pop() del self._keys[-1] - date = commits[0].committer.date + date = log[0].committer.date index = bisect.bisect_left(self._keys, date) - self._sorted_commits.insert(index, commits) + self._sorted_commit_logs.insert(index, log) self._keys.insert(index, date) - def _generate_commits(self) -> Iterator[Commit]: - """Get the latest commit in commit list. + def generate_commits_and_drafts_messages(self) -> Iterator[str]: + """Get the messages of all commits and open drafts. Yields: - The latest commit. + The commit or draft message. """ + latest_draft: Optional[Draft] = None while True: try: - latest_commit = self._sorted_commits[-1].pop(0) + latest_commit = self._sorted_commit_logs[-1][0] except IndexError: + if latest_draft: + yield self._draft_printer(latest_draft) + if self._sorted_drafts: + yield from map(self._draft_printer, reversed(self._sorted_drafts)) return - if self._merge(latest_commit): - continue - yield latest_commit - self._sort() - - def generate_commits(self) -> Iterator[str]: - """Get the the log of all commits. + if not latest_draft and self._sorted_drafts: + latest_draft = self._sorted_drafts.pop() + if latest_draft and latest_draft.updated_at > latest_commit.committer.date: + yield self._draft_printer(latest_draft) + latest_draft = None + else: + del self._sorted_commit_logs[-1][0] + if self._merge(latest_commit): + continue + yield self._commit_printer( + latest_commit, self._commit_id_to_branches.get(latest_commit.commit_id) + ) + self._sort() + + +class _RootCommitNode: # pylint: disable=too-many-instance-attributes + SIGN = "*" + + def __init__(self) -> None: + self.parent: Optional[_RootCommitNode] = None + self.date = 0 + self.key = ROOT_COMMIT_ID + self.available_child_num = 0 + self.get_oneline_message: Callable[ + [List[str]], str + ] = lambda _: _ONELINE_COMMIT_MESSAGE.format(shorten(ROOT_COMMIT_ID), "ROOT_COMMIT") + self.get_full_message: Callable[ + [List[str]], str + ] = lambda _: f'{click.style("ROOT_COMMIT {}", fg="yellow")}'.format(self.key) + + def add_child(self, child_node: _T) -> None: + """Save the child node for the parent node. - Yields: - The commit log. + Arguments: + child_node: The child node. """ - for commit in self._generate_commits(): - commit_id = commit.commit_id - yield self._printer(commit, self._commit_id_to_branches.get(commit_id)) + child_node.parent = self + self.available_child_num += 1 -class _CommitNode: +class _CommitNode(_RootCommitNode): """This class defines the tree struct of graphical logging commits. Arguments: @@ -227,31 +309,45 @@ class _CommitNode: """ def __init__(self, commit: Commit): + super().__init__() self.commit = commit - self.children: List[_CommitNode] = [] - self.available_child_num = 0 - self.parent: Optional[_CommitNode] = None + self.date = commit.committer.date + self.key = commit.commit_id + self.get_oneline_message = partial(_get_oneline_commit_message, commit) + self.get_full_message = partial(_get_full_commit_message, commit) - def add_child(self, child_node: "_CommitNode") -> None: - """Save the child node for the parent node. - Arguments: - child_node: The child node. +class _DraftNode: # pylint: disable=too-many-instance-attributes + """This class defines the tree struct of graphical logging drafts. - """ - child_node.parent = self - self.children.append(child_node) - self.available_child_num += 1 + Arguments: + draft: The draft that needs to be saved in the tree. + + """ + + SIGN = "#" + + def __init__(self, draft: Draft) -> None: + self.draft = draft + self.parent: Optional[_CommitNode] = None + self.date = draft.updated_at + self.available_child_num = 0 + self.key = str(draft.number) + self.get_oneline_message: Callable[[List[str]], str] = lambda _: _get_oneline_draft_message( + draft + ) + self.get_full_message: Callable[[List[str]], str] = lambda _: _get_full_draft_message(draft) class _GraphPrinter: - """This class defines the structure of logging graphical commits stack. + """This class defines the structure of logging graphical commits and open drafts stack. Arguments: dataset_client: The dataset that needs to be logged. revisions: The revisions that needs to be logged. oneline: Whether to log with oneline method. commit_id_to_branches: The map of commit id to branch name. + show_drafts: Whether to log open drafts. """ @@ -261,21 +357,28 @@ def __init__( revisions: List[Optional[str]], commit_id_to_branches: Dict[str, List[str]], oneline: bool, + *, + show_drafts: bool, ): - self._commit_id_to_branches = commit_id_to_branches - self._graph_printer = self._add_graph_oneline if oneline else self._add_graph_full - - self._sorted_leaves = self._build_commit_tree(dataset_client, revisions) + self._key_to_branches = commit_id_to_branches + self._graph_printer = self._add_oneline_graph if oneline else self._add_full_graph + + self._sorted_leaves = self._build_tree(dataset_client, revisions, show_drafts) + error_message = f'Dataset "{dataset_client.name}" has no commit history' + if show_drafts: + error_message += " or open drafts" + if not self._sorted_leaves: + error(error_message) self._pointer = 0 self._merge_pointer: Optional[int] = None self._log_colors = cycle(_LOG_COLORS) self._layer_colors: List[str] = [next(self._log_colors)] - def _get_log_node(self) -> _CommitNode: - """Get the next log commit node. + def _get_log_node(self) -> _T: + """Get the next log node. Returns: - The next log commit node. + The next log node. Raises: RuntimeError: Graphical logging algorithm error. @@ -292,7 +395,7 @@ def _get_log_node(self) -> _CommitNode: raise RuntimeError("Graphical logging algorithm error.") - def _merge_branches(self, parent: _CommitNode) -> None: + def _merge_branches(self, parent: _RootCommitNode) -> None: """Merge branches. Arguments: @@ -303,7 +406,7 @@ def _merge_branches(self, parent: _CommitNode) -> None: self._pointer, self._merge_pointer = sorted((index, self._pointer)) del self._sorted_leaves[self._merge_pointer] - def _set_next_node(self, node: _CommitNode) -> None: + def _set_next_node(self, node: _RootCommitNode) -> None: """Set the next node at the position of the printing pointer. Arguments: @@ -313,18 +416,16 @@ def _set_next_node(self, node: _CommitNode) -> None: node.available_child_num -= 1 self._sorted_leaves[self._pointer] = node - def _add_graph_oneline( - self, commit: Commit, branch_names: Optional[List[str]], original_pointer: int - ) -> str: - log = _get_oneline_log(commit, branch_names) + def _add_oneline_graph(self, node: _T, original_pointer: int) -> str: + message = node.get_oneline_message(self._key_to_branches[node.key]) prefixes = self._get_colorful_prefixes() # Don't merge branches. if self._merge_pointer is None: - return f"{self._get_title_prefix(prefixes, original_pointer)}{log}" + return f"{self._get_title_prefix(prefixes, original_pointer, node.SIGN)}{message}" # Merge branches. del self._layer_colors[self._merge_pointer] - lines = [f"{self._get_title_prefix(prefixes, original_pointer)}{log}"] + lines = [f"{self._get_title_prefix(prefixes, original_pointer, node.SIGN)}{message}"] lines.extend( f"{prefixes}\n" for prefixes in self._get_merge_prefixes( @@ -333,14 +434,13 @@ def _add_graph_oneline( ) return "".join(lines) - def _add_graph_full( - self, commit: Commit, branch_names: Optional[List[str]], original_pointer: int - ) -> str: - log = _get_full_log(commit, branch_names) - splitlines = iter(log.splitlines()) + def _add_full_graph(self, node: _T, original_pointer: int) -> str: + message = node.get_full_message(self._key_to_branches[node.key]) + splitlines = iter(message.splitlines()) original_prefixes = self._get_colorful_prefixes() lines = [ - f"{self._get_title_prefix(original_prefixes, original_pointer)}{next(splitlines)}\n" + f"{self._get_title_prefix(original_prefixes, original_pointer, node.SIGN)}" + f"{next(splitlines)}\n" ] # Don't merge branches. if self._merge_pointer is None: @@ -374,9 +474,9 @@ def _get_colorful_prefixes(self) -> List[str]: return prefixes @staticmethod - def _get_title_prefix(prefixes: List[str], original_pointer: int) -> str: + def _get_title_prefix(prefixes: List[str], original_pointer: int, sign: str) -> str: title_prefixes = prefixes.copy() - title_prefixes[2 * original_pointer] = "*" + title_prefixes[2 * original_pointer] = sign return "".join(title_prefixes) @staticmethod @@ -387,16 +487,17 @@ def _combine_details( for prefix, message in zip_longest(merge_prefixes, messages, fillvalue=fillvalue): yield f"{prefix} {message}\n" - @staticmethod - def _build_commit_tree( - dataset_client: DatasetClientType, revisions: List[Optional[str]] - ) -> List[_CommitNode]: - commit_to_node: Dict[str, _CommitNode] = {} - leaves: Dict[str, _CommitNode] = {} + def _build_tree( + self, + dataset_client: DatasetClientType, + revisions: List[Optional[str]], + show_drafts: bool, + ) -> List[_T]: + commit_to_node: Dict[str, _RootCommitNode] = {} + leaves: Dict[str, _T] = {} for revision in revisions: - commits = dataset_client.list_commits(revision) - child_node: Optional[_CommitNode] = None - for index, commit in enumerate(commits): + child_node: Optional[_RootCommitNode] = None + for commit in dataset_client.list_commits(revision): commit_id = commit.commit_id current_node = commit_to_node.get(commit_id, _CommitNode(commit)) if child_node: @@ -406,23 +507,43 @@ def _build_commit_tree( break # Save leaf node to leaf set. - if index == 0: + if not child_node: leaves[commit_id] = current_node # Save commit to commit dict. commit_to_node[commit_id] = current_node child_node = current_node + if show_drafts: + for draft in dataset_client.list_drafts(branch_name=revision): + draft_node = _DraftNode(draft) + leaves[draft_node.key] = draft_node + self._key_to_branches[draft_node.key] = [draft.branch_name] + parent_commit_id = draft.parent_commit_id + if parent_commit_id in commit_to_node: + parent_node = commit_to_node[parent_commit_id] + else: + parent_node = _RootCommitNode() + commit_to_node[parent_commit_id] = parent_node + if child_node: + parent_node.add_child(child_node) + parent_node.add_child(draft_node) + + return self._check_and_sort_leaves(leaves) + @staticmethod + def _check_and_sort_leaves(leaves: Dict[str, _T]) -> List[_T]: # Check the correction of leaf set. - for commit_id in tuple(leaves): - if leaves[commit_id].available_child_num != 0: - del leaves[commit_id] - return sorted(leaves.values(), key=lambda x: x.commit.committer.date, reverse=True) + delete_keys = [ + key for key, leaf_node in leaves.items() if leaf_node.available_child_num != 0 + ] + for key in delete_keys: + del leaves[key] + return sorted(leaves.values(), key=lambda x: x.date, reverse=True) - def generate_commits(self) -> Iterator[str]: - """Get the graphical commit log. + def generate_commits_and_drafts_messages(self) -> Iterator[str]: + """Get the graphical message. Yields: - The graphical commit log. + The graphical message. """ change_color = False @@ -440,11 +561,18 @@ def generate_commits(self) -> Iterator[str]: change_color = True self._merge_branches(parent) yield self._graph_printer( - current_node.commit, - self._commit_id_to_branches.get(current_node.commit.commit_id), + current_node, original_pointer, ) if not parent: break self._set_next_node(parent) + + +def _echo_messages(message_generator: Iterable[str]) -> None: + if sys.platform.startswith("win"): + for item in message_generator: + click.echo(item) + else: + click.echo_via_pager(message_generator)