Skip to content

Commit

Permalink
Store PackageFinder.trusted_hosts instead of secure_origins.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjerdonek committed Jun 12, 2019
1 parent efaabe3 commit 4c1ccae
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/pip/_internal/build_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def install_requirements(
args.append('--no-index')
for link in finder.find_links:
args.extend(['--find-links', link])
for _, host, _ in finder.secure_origins:
for host in finder.trusted_hosts:
args.extend(['--trusted-host', host])
if finder.allow_all_prereleases:
args.append('--pre')
Expand Down
36 changes: 24 additions & 12 deletions src/pip/_internal/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
if MYPY_CHECK_RUNNING:
from logging import Logger
from typing import (
Tuple, Optional, Any, List, Union, Callable, Set, Sequence,
Iterable, MutableMapping
Any, Callable, Iterable, Iterator, List, MutableMapping, Optional,
Sequence, Set, Tuple, Union,
)
from pip._vendor.packaging.version import _BaseVersion
from pip._vendor.requests import Response
Expand Down Expand Up @@ -562,9 +562,9 @@ def __init__(
candidate_evaluator, # type: CandidateEvaluator
find_links, # type: List[str]
index_urls, # type: List[str]
secure_origins, # type: List[SecureOrigin]
session, # type: PipSession
format_control=None, # type: Optional[FormatControl]
trusted_hosts=None, # type: Optional[List[str]]
):
# type: (...) -> None
"""
Expand All @@ -577,14 +577,17 @@ def __init__(
the selection of source packages / binary packages when consulting
the index and links.
"""
if trusted_hosts is None:
trusted_hosts = []

format_control = format_control or FormatControl(set(), set())

self.candidate_evaluator = candidate_evaluator
self.find_links = find_links
self.index_urls = index_urls
self.secure_origins = secure_origins
self.session = session
self.format_control = format_control
self.trusted_hosts = trusted_hosts

# These are boring links that have already been logged somehow.
self._logged_links = set() # type: Set[Link]
Expand All @@ -595,7 +598,7 @@ def create(
find_links, # type: List[str]
index_urls, # type: List[str]
allow_all_prereleases=False, # type: bool
trusted_hosts=None, # type: Optional[Iterable[str]]
trusted_hosts=None, # type: Optional[List[str]]
session=None, # type: Optional[PipSession]
format_control=None, # type: Optional[FormatControl]
target_python=None, # type: Optional[TargetPython]
Expand Down Expand Up @@ -636,11 +639,6 @@ def create(
link = new_link
built_find_links.append(link)

secure_origins = [
("*", host, "*")
for host in (trusted_hosts if trusted_hosts else [])
] # type: List[SecureOrigin]

candidate_evaluator = CandidateEvaluator(
target_python=target_python, prefer_binary=prefer_binary,
allow_all_prereleases=allow_all_prereleases,
Expand All @@ -664,9 +662,9 @@ def create(
candidate_evaluator=candidate_evaluator,
find_links=built_find_links,
index_urls=index_urls,
secure_origins=secure_origins,
session=session,
format_control=format_control,
trusted_hosts=trusted_hosts,
)

@property
Expand All @@ -678,6 +676,20 @@ def set_allow_all_prereleases(self):
# type: () -> None
self.candidate_evaluator.allow_all_prereleases = True

def extend_trusted_hosts(self, hosts):
# type: (List[str]) -> None
for host in hosts:
if host in self.trusted_hosts:
continue
self.trusted_hosts.append(host)

def iter_secure_origins(self):
# type: () -> Iterator[SecureOrigin]
for secure_origin in SECURE_ORIGINS:
yield secure_origin
for host in self.trusted_hosts:
yield ('*', host, '*')

def get_formatted_locations(self):
# type: () -> str
lines = []
Expand Down Expand Up @@ -766,7 +778,7 @@ def _validate_secure_origin(self, logger, location):
# Determine if our origin is a secure origin by looking through our
# hardcoded list of secure origins, as well as any additional ones
# configured on this PackageFinder instance.
for secure_origin in (SECURE_ORIGINS + self.secure_origins):
for secure_origin in self.iter_secure_origins():
if protocol != secure_origin[0] and secure_origin[0] != "*":
continue

Expand Down
3 changes: 1 addition & 2 deletions src/pip/_internal/req/req_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ def process_line(
if opts.pre:
finder.set_allow_all_prereleases()
if opts.trusted_hosts:
finder.secure_origins.extend(
("*", host, "*") for host in opts.trusted_hosts)
finder.extend_trusted_hosts(opts.trusted_hosts)


def break_args_options(line):
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,44 @@ def test_create__target_python(self):
assert actual_target_python is target_python
assert actual_target_python.py_version_info == (3, 7, 3)

def test_extend_trusted_hosts(self):
trusted_hosts = ['host1', 'host2']
finder = make_test_finder(trusted_hosts=trusted_hosts)

# Check that extend_trusted_hosts() prevents duplicates.
finder.extend_trusted_hosts(['host2', 'host3', 'host2'])
assert finder.trusted_hosts == ['host1', 'host2', 'host3'], (
'actual: {}'.format(finder.trusted_hosts)
)

def test_iter_secure_origins(self):
trusted_hosts = ['host1', 'host2']
finder = make_test_finder(trusted_hosts=trusted_hosts)

actual = list(finder.iter_secure_origins())
assert len(actual) == 8
# Spot-check that SECURE_ORIGINS is included.
assert actual[0] == ('https', '*', '*')
assert actual[-2:] == [
('*', 'host1', '*'),
('*', 'host2', '*'),
]

def test_iter_secure_origins__none_trusted_hosts(self):
"""
Test iter_secure_origins() after passing trusted_hosts=None.
"""
# Use PackageFinder.create() rather than make_test_finder()
# to make sure we're really passing trusted_hosts=None.
finder = PackageFinder.create(
[], [], trusted_hosts=None, session=object(),
)

actual = list(finder.iter_secure_origins())
assert len(actual) == 6
# Spot-check that SECURE_ORIGINS is included.
assert actual[0] == ('https', '*', '*')


def test_sort_locations_file_expand_dir(data):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_req_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_set_finder_extra_index_urls(self, finder):

def test_set_finder_trusted_host(self, finder):
list(process_line("--trusted-host=url", "file", 1, finder=finder))
assert finder.secure_origins == [('*', 'url', '*')]
assert finder.trusted_hosts == ['url']

def test_noop_always_unzip(self, finder):
# noop, but confirm it can be set
Expand Down

0 comments on commit 4c1ccae

Please sign in to comment.