Skip to content

Commit

Permalink
Merge pull request #5723 from cjerdonek/vcs-add-git-get-branch
Browse files Browse the repository at this point in the history
Fix the "new install" case of issue #2037
  • Loading branch information
pradyunsg authored Sep 18, 2018
2 parents 6296766 + 0d81793 commit 392cb09
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 21 deletions.
1 change: 1 addition & 0 deletions news/2037.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Checkout the correct branch when doing an editable Git install.
Empty file.
51 changes: 42 additions & 9 deletions src/pip/_internal/vcs/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ def get_git_version(self):
version = '.'.join(version.split('.')[:3])
return parse_version(version)

def get_branch(self, location):
"""
Return the current branch, or None if HEAD isn't at a branch
(e.g. detached HEAD).
"""
args = ['rev-parse', '--abbrev-ref', 'HEAD']
output = self.run_command(args, show_stdout=False, cwd=location)
branch = output.strip()

if branch == 'HEAD':
return None

return branch

def export(self, location):
"""Export the Git repository at the url to the destination location"""
if not location.endswith('/'):
Expand All @@ -91,8 +105,8 @@ def export(self, location):

def get_revision_sha(self, dest, rev):
"""
Return a commit hash for the given revision if it names a remote
branch or tag. Otherwise, return None.
Return (sha_or_none, is_branch), where sha_or_none is a commit hash
if the revision names a remote branch or tag, otherwise None.
Args:
dest: the repository directory.
Expand All @@ -115,7 +129,13 @@ def get_revision_sha(self, dest, rev):
branch_ref = 'refs/remotes/origin/{}'.format(rev)
tag_ref = 'refs/tags/{}'.format(rev)

return refs.get(branch_ref) or refs.get(tag_ref)
sha = refs.get(branch_ref)
if sha is not None:
return (sha, True)

sha = refs.get(tag_ref)

return (sha, False)

def resolve_revision(self, dest, url, rev_options):
"""
Expand All @@ -126,10 +146,13 @@ def resolve_revision(self, dest, url, rev_options):
rev_options: a RevOptions object.
"""
rev = rev_options.arg_rev
sha = self.get_revision_sha(dest, rev)
sha, is_branch = self.get_revision_sha(dest, rev)

if sha is not None:
return rev_options.make_new(sha)
rev_options = rev_options.make_new(sha)
rev_options.branch_name = rev if is_branch else None

return rev_options

# Do not show a warning for the common case of something that has
# the form of a Git commit hash.
Expand Down Expand Up @@ -177,10 +200,20 @@ def fetch_new(self, dest, url, rev_options):
if rev_options.rev:
# Then a specific revision was requested.
rev_options = self.resolve_revision(dest, url, rev_options)
# Only do a checkout if the current commit id doesn't match
# the requested revision.
if not self.is_commit_id_equal(dest, rev_options.rev):
cmd_args = ['checkout', '-q'] + rev_options.to_args()
branch_name = getattr(rev_options, 'branch_name', None)
if branch_name is None:
# Only do a checkout if the current commit id doesn't match
# the requested revision.
if not self.is_commit_id_equal(dest, rev_options.rev):
cmd_args = ['checkout', '-q'] + rev_options.to_args()
self.run_command(cmd_args, cwd=dest)
elif self.get_branch(dest) != branch_name:
# Then a specific branch was requested, and that branch
# is not yet checked out.
track_branch = 'origin/{}'.format(branch_name)
cmd_args = [
'checkout', '-b', branch_name, '--track', track_branch,
]
self.run_command(cmd_args, cwd=dest)

#: repo may contain submodules
Expand Down
86 changes: 83 additions & 3 deletions tests/functional/test_install_vcs_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,32 @@
from tests.lib.local_repos import local_checkout


def _get_editable_repo_dir(script, package_name):
"""
Return the repository directory for an editable install.
"""
return script.venv_path / 'src' / package_name


def _get_editable_branch(script, package_name):
"""
Return the current branch of an editable install.
"""
repo_dir = script.venv_path / 'src' / package_name
repo_dir = _get_editable_repo_dir(script, package_name)
result = script.run(
'git', 'rev-parse', '--abbrev-ref', 'HEAD', cwd=repo_dir
)
return result.stdout.strip()


def _get_branch_remote(script, package_name, branch):
"""
"""
repo_dir = _get_editable_repo_dir(script, package_name)
result = script.run(
'git', 'config', 'branch.{}.remote'.format(branch), cwd=repo_dir
)
return result.stdout.strip()


Expand Down Expand Up @@ -363,7 +380,69 @@ def test_git_works_with_editable_non_origin_repo(script):
assert "version-pkg==0.1" in result.stdout


def test_editable_non_master_default_branch(script):
def test_editable__no_revision(script):
"""
Test a basic install in editable mode specifying no revision.
"""
version_pkg_path = _create_test_package(script)
_install_version_pkg_only(script, version_pkg_path)

branch = _get_editable_branch(script, 'version-pkg')
assert branch == 'master'

remote = _get_branch_remote(script, 'version-pkg', 'master')
assert remote == 'origin'


def test_editable__branch_with_sha_same_as_default(script):
"""
Test installing in editable mode a branch whose sha matches the sha
of the default branch, but is different from the default branch.
"""
version_pkg_path = _create_test_package(script)
# Create a second branch with the same SHA.
script.run(
'git', 'branch', 'develop', expect_stderr=True,
cwd=version_pkg_path,
)
_install_version_pkg_only(
script, version_pkg_path, rev='develop', expect_stderr=True
)

branch = _get_editable_branch(script, 'version-pkg')
assert branch == 'develop'

remote = _get_branch_remote(script, 'version-pkg', 'develop')
assert remote == 'origin'


def test_editable__branch_with_sha_different_from_default(script):
"""
Test installing in editable mode a branch whose sha is different from
the sha of the default branch.
"""
version_pkg_path = _create_test_package(script)
# Create a second branch.
script.run(
'git', 'branch', 'develop', expect_stderr=True,
cwd=version_pkg_path,
)
# Add another commit to the master branch to give it a different sha.
_change_test_package_version(script, version_pkg_path)

version = _install_version_pkg(
script, version_pkg_path, rev='develop', expect_stderr=True
)
assert version == '0.1'

branch = _get_editable_branch(script, 'version-pkg')
assert branch == 'develop'

remote = _get_branch_remote(script, 'version-pkg', 'develop')
assert remote == 'origin'


def test_editable__non_master_default_branch(script):
"""
Test the branch you get after an editable install from a remote repo
with a non-master default branch.
Expand All @@ -376,8 +455,9 @@ def test_editable_non_master_default_branch(script):
cwd=version_pkg_path,
)
_install_version_pkg_only(script, version_pkg_path)

branch = _get_editable_branch(script, 'version-pkg')
assert 'release' == branch
assert branch == 'release'


def test_reinstalling_works_with_editable_non_master_branch(script):
Expand Down
33 changes: 27 additions & 6 deletions tests/functional/test_vcs_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def add_commits(script, dest, count):
return shas


def check_rev(repo_dir, rev, expected_sha):
def check_rev(repo_dir, rev, expected):
git = Git()
assert git.get_revision_sha(repo_dir, rev) == expected_sha
assert git.get_revision_sha(repo_dir, rev) == expected


def test_git_dir_ignored():
Expand Down Expand Up @@ -70,6 +70,27 @@ def test_git_work_tree_ignored():
git.run_command(['status', temp_dir], extra_environ=env, cwd=temp_dir)


def test_get_branch(script, tmpdir):
repo_dir = str(tmpdir)
script.run('git', 'init', cwd=repo_dir)
sha = do_commit(script, repo_dir)

git = Git()
assert git.get_branch(repo_dir) == 'master'

# Switch to a branch with the same SHA as "master" but whose name
# is alphabetically after.
script.run(
'git', 'checkout', '-b', 'release', cwd=repo_dir,
expect_stderr=True,
)
assert git.get_branch(repo_dir) == 'release'

# Also test the detached HEAD case.
script.run('git', 'checkout', sha, cwd=repo_dir, expect_stderr=True)
assert git.get_branch(repo_dir) is None


def test_get_revision_sha(script):
with TempDirectory(kind="testing") as temp:
repo_dir = temp.path
Expand Down Expand Up @@ -102,9 +123,9 @@ def test_get_revision_sha(script):
script.run('git', 'tag', 'aaa/v1.0', head_sha, cwd=repo_dir)
script.run('git', 'tag', 'zzz/v1.0', head_sha, cwd=repo_dir)

check_rev(repo_dir, 'v1.0', tag_sha)
check_rev(repo_dir, 'v2.0', tag_sha)
check_rev(repo_dir, 'origin-branch', origin_sha)
check_rev(repo_dir, 'v1.0', (tag_sha, False))
check_rev(repo_dir, 'v2.0', (tag_sha, False))
check_rev(repo_dir, 'origin-branch', (origin_sha, True))

ignored_names = [
# Local branches should be ignored.
Expand All @@ -122,7 +143,7 @@ def test_get_revision_sha(script):
'does-not-exist',
]
for name in ignored_names:
check_rev(repo_dir, name, None)
check_rev(repo_dir, name, (None, False))


@pytest.mark.network
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_vcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_git_get_src_requirements(git, dist):

@patch('pip._internal.vcs.git.Git.get_revision_sha')
def test_git_resolve_revision_rev_exists(get_sha_mock):
get_sha_mock.return_value = '123456'
get_sha_mock.return_value = ('123456', False)
git = Git()
rev_options = git.make_rev_options('develop')

Expand All @@ -120,7 +120,7 @@ def test_git_resolve_revision_rev_exists(get_sha_mock):

@patch('pip._internal.vcs.git.Git.get_revision_sha')
def test_git_resolve_revision_rev_not_found(get_sha_mock):
get_sha_mock.return_value = None
get_sha_mock.return_value = (None, False)
git = Git()
rev_options = git.make_rev_options('develop')

Expand All @@ -131,7 +131,7 @@ def test_git_resolve_revision_rev_not_found(get_sha_mock):

@patch('pip._internal.vcs.git.Git.get_revision_sha')
def test_git_resolve_revision_not_found_warning(get_sha_mock, caplog):
get_sha_mock.return_value = None
get_sha_mock.return_value = (None, False)
git = Git()

url = 'git+https://git.example.com'
Expand Down

0 comments on commit 392cb09

Please sign in to comment.