Merge pull request Codium-ai#48 from Codium-ai/hl/gitlab_fix
Inline suggestion refactor + supporting GitLab
hussam789 authored Jul 14, 2023
2 parents e48cc55 + 2dca2bf commit bcd09a7
Showing 5 changed files with 153 additions and 30 deletions.
3 changes: 3 additions & 0 deletions pr_agent/algo/git_patch_processing.py
@@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
    if not patch_str or num_lines == 0:
        return patch_str

    if type(original_file_str) == bytes:
        original_file_str = original_file_str.decode('utf-8')

    original_lines = original_file_str.splitlines()
    patch_lines = patch_str.splitlines()
    extended_patch_lines = []
13 changes: 13 additions & 0 deletions pr_agent/git_providers/git_provider.py
@@ -1,6 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum
class EDIT_TYPE(Enum):
    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4

@dataclass
class FilePatchInfo:
@@ -9,6 +16,8 @@ class FilePatchInfo:
    patch: str
    filename: str
    tokens: int = -1
    edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
    old_filename: str = None


class GitProvider(ABC):
@@ -24,6 +33,10 @@ def publish_description(self, pr_title: str, pr_body: str):
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        pass

    @abstractmethod
    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        pass

    @abstractmethod
    def remove_initial_comment(self):
        pass
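
As a quick illustration (not part of this commit), the extended FilePatchInfo can now record how a file changed; the snippet below mirrors the positional-argument style used in gitlab_provider.get_diff_files further down, and all concrete values are hypothetical:

from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo

# Hypothetical renamed file: the new fields capture the edit type and the previous path.
example = FilePatchInfo("old content", "new content", "@@ -1 +1 @@\n-old\n+new", "docs/new_name.md",
                        edit_type=EDIT_TYPE.RENAMED,
                        old_filename="docs/old_name.md")
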
30 changes: 28 additions & 2 deletions pr_agent/git_providers/github_provider.py
@@ -18,8 +18,10 @@ def __init__(self, pr_url: Optional[str] = None):
        self.pr_num = None
        self.pr = None
        self.github_user_id = None
        self.diff_files = None
        if pr_url:
            self.set_pr(pr_url)
            self.last_commit_id = list(self.pr.get_commits())[-1]

    def set_pr(self, pr_url: str):
        self.repo, self.pr_num = self._parse_pr_url(pr_url)
@@ -35,6 +37,7 @@ def get_diff_files(self) -> list[FilePatchInfo]:
            original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
            new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
            diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
        self.diff_files = diff_files
        return diff_files

    def publish_description(self, pr_title: str, pr_body: str):
@@ -50,6 +53,29 @@ def publish_comment(self, pr_comment: str, is_temporary: bool = False):
            self.pr.comments_list = []
        self.pr.comments_list.append(response)

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
        position = -1
        for file in self.diff_files:
            if file.filename.strip() == relevant_file:
                patch = file.patch
                patch_lines = patch.splitlines()
                for i, line in enumerate(patch_lines):
                    if relevant_line_in_file in line:
                        position = i
                        break
                    elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
                        # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                        # it's a context line
                        position = i
                        break
        if position == -1:
            if settings.config.verbosity_level >= 2:
                logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        else:
            path = relevant_file.strip()
            self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position)

    def remove_initial_comment(self):
        try:
            for comment in self.pr.comments_list:
@@ -150,9 +176,9 @@ def _get_repo(self):
    def _get_pr(self):
        return self._get_repo().get_pull(self.pr_num)

    def _get_pr_file_content(self, file: FilePatchInfo, sha: str):
    def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str:
        try:
            file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()
            file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode())
        except Exception:
            file_content_str = ""
        return file_content_str
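
A usage sketch (not from this commit) of the new provider-level inline-comment API; it assumes the class is named GithubProvider, that valid credentials are configured, and the PR URL, file and line values are hypothetical placeholders:

from pr_agent.git_providers.github_provider import GithubProvider

provider = GithubProvider("https://github.com/Codium-ai/pr-agent/pull/48")  # hypothetical PR URL
provider.publish_inline_comment(body="Consider decoding bytes input before splitting lines.",
                                relevant_file="pr_agent/algo/git_patch_processing.py",
                                relevant_line_in_file="original_lines = original_file_str.splitlines()")
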
109 changes: 107 additions & 2 deletions pr_agent/git_providers/gitlab_provider.py
@@ -1,12 +1,13 @@
import logging
import re
from typing import Optional, Tuple
from urllib.parse import urlparse

import gitlab

from pr_agent.config_loader import settings

from .git_provider import FilePatchInfo, GitProvider
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE


class GitLabProvider(GitProvider):
@@ -24,6 +25,7 @@ def __init__(self, merge_request_url: Optional[str] = None):
        self.id_project = None
        self.id_mr = None
        self.mr = None
        self.diff_files = None
        self.temp_comments = []
        self._set_merge_request(merge_request_url)

@@ -35,10 +37,35 @@ def pr(self):
    def _set_merge_request(self, merge_request_url: str):
        self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
        self.mr = self._get_merge_request()
        self.last_diff = self.mr.diffs.list()[-1]

    def _get_pr_file_content(self, file_path: str, branch: str) -> str:
        return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()

    def get_diff_files(self) -> list[FilePatchInfo]:
        diffs = self.mr.changes()['changes']
        diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs]
        diff_files = []
        for diff in diffs:
            original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
            new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
            edit_type = EDIT_TYPE.MODIFIED
            if diff['new_file']:
                edit_type = EDIT_TYPE.ADDED
            elif diff['deleted_file']:
                edit_type = EDIT_TYPE.DELETED
            elif diff['renamed_file']:
                edit_type = EDIT_TYPE.RENAMED
            try:
                original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
                new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
            except UnicodeDecodeError:
                logging.warning(
                    f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
            diff_files.append(
                FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
                              edit_type=edit_type,
                              old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
        self.diff_files = diff_files
        return diff_files

    def get_files(self):
@@ -53,6 +80,81 @@ def publish_comment(self, mr_comment: str, is_temporary: bool = False):
        if is_temporary:
            self.temp_comments.append(comment)

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
        self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
        edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
                                                                                         relevant_line_in_file)
        if not found:
            logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        else:
            if edit_type == 'addition':
                position = target_line_no - 1
            else:
                position = source_line_no - 1
            d = self.last_diff
            pos_obj = {'position_type': 'text',
                       'new_path': target_file.filename,
                       'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
                       'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
            if edit_type == 'deletion':
                pos_obj['old_line'] = position
            elif edit_type == 'addition':
                pos_obj['new_line'] = position
            else:
                pos_obj['new_line'] = position
                pos_obj['old_line'] = position
            self.mr.discussions.create({'body': body,
                                        'position': pos_obj})

    def search_line(self, relevant_file, relevant_line_in_file):
        RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
        target_file = None
        source_line_no = 0
        target_line_no = 0
        found = False
        edit_type = self.get_edit_type(relevant_line_in_file)
        for file in self.diff_files:
            if file.filename == relevant_file:
                target_file = file
                patch = file.patch
                patch_lines = patch.splitlines()
                for i, line in enumerate(patch_lines):
                    if line.startswith('@@'):
                        match = RE_HUNK_HEADER.match(line)
                        if not match:
                            continue
                        start_old, size_old, start_new, size_new, _ = match.groups()
                        source_line_no = int(start_old)
                        target_line_no = int(start_new)
                        continue
                    if line.startswith('-'):
                        source_line_no += 1
                    elif line.startswith('+'):
                        target_line_no += 1
                    elif line.startswith(' '):
                        source_line_no += 1
                        target_line_no += 1
                    if relevant_line_in_file in line:
                        found = True
                        edit_type = self.get_edit_type(line)
                        break
                    elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
                        # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                        # it's a context line
                        found = True
                        edit_type = self.get_edit_type(line)
                        break
        return edit_type, found, source_line_no, target_file, target_line_no

    def get_edit_type(self, relevant_line_in_file):
        edit_type = 'context'
        if relevant_line_in_file[0] == '-':
            edit_type = 'deletion'
        elif relevant_line_in_file[0] == '+':
            edit_type = 'addition'
        return edit_type

    def remove_initial_comment(self):
        try:
            for comment in self.temp_comments:
@@ -94,3 +196,6 @@ def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[int, int]:
    def _get_merge_request(self):
        mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
        return mr

    def get_user_id(self):
        return None
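
To make the line bookkeeping in search_line concrete, here is a small standalone sketch (not part of the commit) that applies the same hunk-header regex to a sample header; the header text is a hypothetical example:

import re

RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

# Hypothetical hunk header: the old-file chunk starts at line 53, the new-file chunk at line 80.
match = RE_HUNK_HEADER.match("@@ -53,6 +80,81 @@ def publish_comment(self, mr_comment: str, is_temporary: bool = False):")
start_old, size_old, start_new, size_new, _ = match.groups()
print(int(start_old), int(start_new))  # 53 80 -- search_line then advances these counters per '-', '+' and ' ' line
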
28 changes: 2 additions & 26 deletions pr_agent/tools/pr_reviewer.py
@@ -111,39 +111,15 @@ def _prepare_pr_review(self) -> str:
        return markdown_text

    def _publish_inline_code_comments(self):
        if settings.config.git_provider != 'github': # inline comments are currently only supported for github
            return

        review = self.prediction.strip()
        try:
            data = json.loads(review)
        except json.decoder.JSONDecodeError:
            data = try_fix_json(review)

        pr = self.git_provider.pr
        last_commit_id = list(pr.get_commits())[-1]
        if hasattr(pr, 'diff_files'): # prevent bringing all the files again
            diff_files = pr.diff_files
        else:
            diff_files = list(self.git_provider.get_diff_files())

        for d in data['PR Feedback']['Code suggestions']:
            relevant_file = d['relevant file'].strip()
            relevant_line_in_file = d['relevant line in file'].strip()
            content = d['suggestion content']
            position = -1
            for file in diff_files:
                if file.filename.strip() == relevant_file:
                    patch = file.patch
                    patch_lines = patch.splitlines()
                    for i, line in enumerate(patch_lines):
                        if relevant_line_in_file in line:
                            position = i
                            break
            if position == -1:
                if settings.config.verbosity_level >= 2:
                    logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
            else:
                body = content
                path = relevant_file.strip()
                pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position)

            self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
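
For reference (not part of the commit), the loop above consumes model output shaped roughly as follows; the keys come from the code, while the values are hypothetical placeholders:

data = {
    "PR Feedback": {
        "Code suggestions": [
            {
                "relevant file": "pr_agent/algo/git_patch_processing.py",
                "relevant line in file": "original_lines = original_file_str.splitlines()",
                "suggestion content": "Decode bytes input before splitting the file into lines."
            }
        ]
    }
}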
