diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index d75c6c979..6c450a301 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: if not patch_str or num_lines == 0: return patch_str + if type(original_file_str) == bytes: + original_file_str = original_file_str.decode('utf-8') + original_lines = original_file_str.splitlines() patch_lines = patch_str.splitlines() extended_patch_lines = [] diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index e86f461d8..dd8e4fcb1 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -1,6 +1,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED) +from enum import Enum +class EDIT_TYPE(Enum): + ADDED = 1 + DELETED = 2 + MODIFIED = 3 + RENAMED = 4 @dataclass class FilePatchInfo: @@ -9,6 +16,8 @@ class FilePatchInfo: patch: str filename: str tokens: int = -1 + edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED + old_filename: str = None class GitProvider(ABC): @@ -24,6 +33,10 @@ def publish_description(self, pr_title: str, pr_body: str): def publish_comment(self, pr_comment: str, is_temporary: bool = False): pass + @abstractmethod + def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): + pass + @abstractmethod def remove_initial_comment(self): pass diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 16ec6293d..af035c9f7 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -18,8 +18,10 @@ def __init__(self, pr_url: Optional[str] = None): self.pr_num = None self.pr = None self.github_user_id = None + self.diff_files = None if pr_url: self.set_pr(pr_url) + self.last_commit_id = list(self.pr.get_commits())[-1] def set_pr(self, pr_url: str): self.repo, self.pr_num = self._parse_pr_url(pr_url) @@ -35,6 +37,7 @@ def get_diff_files(self) -> list[FilePatchInfo]: original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename)) + self.diff_files = diff_files return diff_files def publish_description(self, pr_title: str, pr_body: str): @@ -50,6 +53,29 @@ def publish_comment(self, pr_comment: str, is_temporary: bool = False): self.pr.comments_list = [] self.pr.comments_list.append(response) + def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): + self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() + position = -1 + for file in self.diff_files: + if file.filename.strip() == relevant_file: + patch = file.patch + patch_lines = patch.splitlines() + for i, line in enumerate(patch_lines): + if relevant_line_in_file in line: + position = i + break + elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line: + # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally + # it's a context line + position = i + break + if position == -1: + if settings.config.verbosity_level >= 2: + logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + else: + path = relevant_file.strip() + self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position) + def remove_initial_comment(self): try: for comment in self.pr.comments_list: @@ -150,9 +176,9 @@ def _get_repo(self): def _get_pr(self): return self._get_repo().get_pull(self.pr_num) - def _get_pr_file_content(self, file: FilePatchInfo, sha: str): + def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str: try: - file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode() + file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()) except Exception: file_content_str = "" return file_content_str diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 485e0cf93..a04a482a0 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -1,4 +1,5 @@ import logging +import re from typing import Optional, Tuple from urllib.parse import urlparse @@ -6,7 +7,7 @@ from pr_agent.config_loader import settings -from .git_provider import FilePatchInfo, GitProvider +from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE class GitLabProvider(GitProvider): @@ -24,6 +25,7 @@ def __init__(self, merge_request_url: Optional[str] = None): self.id_project = None self.id_mr = None self.mr = None + self.diff_files = None self.temp_comments = [] self._set_merge_request(merge_request_url) @@ -35,10 +37,35 @@ def pr(self): def _set_merge_request(self, merge_request_url: str): self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.mr = self._get_merge_request() + self.last_diff = self.mr.diffs.list()[-1] + + def _get_pr_file_content(self, file_path: str, branch: str) -> str: + return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode() def get_diff_files(self) -> list[FilePatchInfo]: diffs = self.mr.changes()['changes'] - diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs] + diff_files = [] + for diff in diffs: + original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch) + new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch) + edit_type = EDIT_TYPE.MODIFIED + if diff['new_file']: + edit_type = EDIT_TYPE.ADDED + elif diff['deleted_file']: + edit_type = EDIT_TYPE.DELETED + elif diff['renamed_file']: + edit_type = EDIT_TYPE.RENAMED + try: + original_file_content_str = bytes.decode(original_file_content_str, 'utf-8') + new_file_content_str = bytes.decode(new_file_content_str, 'utf-8') + except UnicodeDecodeError: + logging.warning( + f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}") + diff_files.append( + FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'], + edit_type=edit_type, + old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'])) + self.diff_files = diff_files return diff_files def get_files(self): @@ -53,6 +80,81 @@ def publish_comment(self, mr_comment: str, is_temporary: bool = False): if is_temporary: self.temp_comments.append(comment) + def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str): + self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() + edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file, + relevant_line_in_file) + if not found: + logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") + else: + if edit_type == 'addition': + position = target_line_no - 1 + else: + position = source_line_no - 1 + d = self.last_diff + pos_obj = {'position_type': 'text', + 'new_path': target_file.filename, + 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, + 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha} + if edit_type == 'deletion': + pos_obj['old_line'] = position + elif edit_type == 'addition': + pos_obj['new_line'] = position + else: + pos_obj['new_line'] = position + pos_obj['old_line'] = position + self.mr.discussions.create({'body': body, + 'position': pos_obj}) + + def search_line(self, relevant_file, relevant_line_in_file): + RE_HUNK_HEADER = re.compile( + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + target_file = None + source_line_no = 0 + target_line_no = 0 + found = False + edit_type = self.get_edit_type(relevant_line_in_file) + for file in self.diff_files: + if file.filename == relevant_file: + target_file = file + patch = file.patch + patch_lines = patch.splitlines() + for i, line in enumerate(patch_lines): + if line.startswith('@@'): + match = RE_HUNK_HEADER.match(line) + if not match: + continue + start_old, size_old, start_new, size_new, _ = match.groups() + source_line_no = int(start_old) + target_line_no = int(start_new) + continue + if line.startswith('-'): + source_line_no += 1 + elif line.startswith('+'): + target_line_no += 1 + elif line.startswith(' '): + source_line_no += 1 + target_line_no += 1 + if relevant_line_in_file in line: + found = True + edit_type = self.get_edit_type(line) + break + elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line: + # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally + # it's a context line + found = True + edit_type = self.get_edit_type(line) + break + return edit_type, found, source_line_no, target_file, target_line_no + + def get_edit_type(self, relevant_line_in_file): + edit_type = 'context' + if relevant_line_in_file[0] == '-': + edit_type = 'deletion' + elif relevant_line_in_file[0] == '+': + edit_type = 'addition' + return edit_type + def remove_initial_comment(self): try: for comment in self.temp_comments: @@ -94,3 +196,6 @@ def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]: def _get_merge_request(self): mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr) return mr + + def get_user_id(self): + return None diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index eabf89550..e6d001e1b 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -111,39 +111,15 @@ def _prepare_pr_review(self) -> str: return markdown_text def _publish_inline_code_comments(self): - if settings.config.git_provider != 'github': # inline comments are currently only supported for github - return - review = self.prediction.strip() try: data = json.loads(review) except json.decoder.JSONDecodeError: data = try_fix_json(review) - pr = self.git_provider.pr - last_commit_id = list(pr.get_commits())[-1] - if hasattr(pr, 'diff_files'): # prevent bringing all the files again - diff_files = pr.diff_files - else: - diff_files = list(self.git_provider.get_diff_files()) - for d in data['PR Feedback']['Code suggestions']: relevant_file = d['relevant file'].strip() relevant_line_in_file = d['relevant line in file'].strip() content = d['suggestion content'] - position = -1 - for file in diff_files: - if file.filename.strip() == relevant_file: - patch = file.patch - patch_lines = patch.splitlines() - for i, line in enumerate(patch_lines): - if relevant_line_in_file in line: - position = i - break - if position == -1: - if settings.config.verbosity_level >= 2: - logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") - else: - body = content - path = relevant_file.strip() - pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position) \ No newline at end of file + + self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file) \ No newline at end of file