diff --git a/CHANGELOG.md b/CHANGELOG.md index 3457c9f5..89985aa2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ All versions prior to 0.0.9 are untracked. faster and more responsive ([#283](https://github.com/trailofbits/pip-audit/pull/283)) +* CLI, Vulnerability sources: the error message used to report + connection failures to vulnerability sources was improved + ([#287](https://github.com/trailofbits/pip-audit/pull/287)) + ### Fixed * Vulnerability sources: a bug stemming from an incorrect assumption diff --git a/pip_audit/_cli.py b/pip_audit/_cli.py index e66a2772..b16684dc 100644 --- a/pip_audit/_cli.py +++ b/pip_audit/_cli.py @@ -25,6 +25,7 @@ from pip_audit._fix import ResolvedFixVersion, SkippedFixVersion, resolve_fix_versions from pip_audit._format import ColumnsFormat, CycloneDxFormat, JsonFormat, VulnerabilityFormat from pip_audit._service import OsvService, PyPIService, VulnerabilityService +from pip_audit._service.interface import ConnectionError as VulnServiceConnectionError from pip_audit._service.interface import ResolvedDependency, SkippedDependency from pip_audit._state import AuditSpinner, AuditState from pip_audit._util import assert_never @@ -400,25 +401,34 @@ def audit() -> None: skip_count = 0 vuln_ignore_count = 0 vulns_to_ignore = set(args.ignore_vulns) - for (spec, vulns) in auditor.audit(source): - if spec.is_skipped(): - spec = cast(SkippedDependency, spec) - if args.strict: - _fatal(f"{spec.name}: {spec.skip_reason}") + try: + for (spec, vulns) in auditor.audit(source): + if spec.is_skipped(): + spec = cast(SkippedDependency, spec) + if args.strict: + _fatal(f"{spec.name}: {spec.skip_reason}") + else: + state.update_state(f"Skipping {spec.name}: {spec.skip_reason}") + skip_count += 1 else: - state.update_state(f"Skipping {spec.name}: {spec.skip_reason}") - skip_count += 1 - else: - spec = cast(ResolvedDependency, spec) - state.update_state(f"Auditing {spec.name} ({spec.version})") - if vulns_to_ignore: - filtered_vulns = [v for v in vulns if not v.has_any_id(vulns_to_ignore)] - vuln_ignore_count += len(vulns) - len(filtered_vulns) - vulns = filtered_vulns - result[spec] = vulns - if len(vulns) > 0: - pkg_count += 1 - vuln_count += len(vulns) + spec = cast(ResolvedDependency, spec) + state.update_state(f"Auditing {spec.name} ({spec.version})") + if vulns_to_ignore: + filtered_vulns = [v for v in vulns if not v.has_any_id(vulns_to_ignore)] + vuln_ignore_count += len(vulns) - len(filtered_vulns) + vulns = filtered_vulns + result[spec] = vulns + if len(vulns) > 0: + pkg_count += 1 + vuln_count += len(vulns) + except VulnServiceConnectionError as e: + # The most common source of connection errors is corporate blocking, + # so we offer a bit of advice. + logger.error(str(e)) + _fatal( + "Tip: your network may be blocking this service. " + "Try another service with `-s SERVICE`" + ) # If the `--fix` flag has been applied, find a set of suitable fix versions and upgrade the # dependencies at the source diff --git a/pip_audit/_service/__init__.py b/pip_audit/_service/__init__.py index 03f2e217..39338676 100644 --- a/pip_audit/_service/__init__.py +++ b/pip_audit/_service/__init__.py @@ -3,6 +3,7 @@ """ from .interface import ( + ConnectionError, Dependency, ResolvedDependency, ServiceError, @@ -14,6 +15,7 @@ from .pypi import PyPIService __all__ = [ + "ConnectionError", "Dependency", "ResolvedDependency", "ServiceError", diff --git a/pip_audit/_service/interface.py b/pip_audit/_service/interface.py index 3978ebb9..16fa7de9 100644 --- a/pip_audit/_service/interface.py +++ b/pip_audit/_service/interface.py @@ -158,3 +158,12 @@ class ServiceError(Exception): """ pass + + +class ConnectionError(ServiceError): + """ + A specialization of `ServiceError` specifically for cases where the + vulnerability service is unreachable or offline. + """ + + pass diff --git a/pip_audit/_service/osv.py b/pip_audit/_service/osv.py index e10867dd..c7a7b8c2 100644 --- a/pip_audit/_service/osv.py +++ b/pip_audit/_service/osv.py @@ -12,6 +12,7 @@ from pip_audit._cache import caching_session from pip_audit._service.interface import ( + ConnectionError, Dependency, ResolvedDependency, ServiceError, @@ -56,17 +57,15 @@ def query(self, spec: Dependency) -> Tuple[Dependency, List[VulnerabilityResult] "package": {"name": spec.canonical_name, "ecosystem": "PyPI"}, "version": str(spec.version), } - response: requests.Response = self.session.post( - url=url, - data=json.dumps(query), - timeout=self.timeout, - ) - - results: List[VulnerabilityResult] = [] - - # Check for an unsuccessful status code try: + response: requests.Response = self.session.post( + url=url, + data=json.dumps(query), + timeout=self.timeout, + ) response.raise_for_status() + except requests.ConnectTimeout: + raise ConnectionError("Could not connect to OSV's vulnerability feed") except requests.HTTPError as http_error: raise ServiceError from http_error @@ -74,6 +73,7 @@ def query(self, spec: Dependency) -> Tuple[Dependency, List[VulnerabilityResult] # associated vulnerabilities # # In that case, return an empty list + results: List[VulnerabilityResult] = [] response_json = response.json() if not response_json: return spec, results diff --git a/pip_audit/_service/pypi.py b/pip_audit/_service/pypi.py index 3460b89b..33777b54 100644 --- a/pip_audit/_service/pypi.py +++ b/pip_audit/_service/pypi.py @@ -13,6 +13,7 @@ from pip_audit._cache import caching_session from pip_audit._service.interface import ( + ConnectionError, Dependency, ResolvedDependency, ServiceError, @@ -55,9 +56,17 @@ def query(self, spec: Dependency) -> Tuple[Dependency, List[VulnerabilityResult] spec = cast(ResolvedDependency, spec) url = f"https://pypi.org/pypi/{spec.canonical_name}/{str(spec.version)}/json" - response: requests.Response = self.session.get(url=url, timeout=self.timeout) + try: + response: requests.Response = self.session.get(url=url, timeout=self.timeout) response.raise_for_status() + except requests.ConnectTimeout: + # Apart from a normal network outage, this can happen for two main + # reasons: + # 1. PyPI's APIs are offline + # 2. The user is behind a firewall or corporate network that blocks + # PyPI (and they're probably using custom indices) + raise ConnectionError("Could not connect to PyPI's vulnerability feed") except requests.HTTPError as http_error: if response.status_code == 404: skip_reason = ( diff --git a/test/service/test_osv.py b/test/service/test_osv.py index c2729092..4723206f 100644 --- a/test/service/test_osv.py +++ b/test/service/test_osv.py @@ -3,7 +3,7 @@ import pretend # type: ignore import pytest from packaging.version import Version -from requests.exceptions import HTTPError +from requests.exceptions import ConnectTimeout, HTTPError import pip_audit._service as service @@ -95,6 +95,17 @@ def test_osv_no_vuln(): assert len(vulns) == 0 +def test_osv_connection_error(monkeypatch): + osv = service.OsvService() + monkeypatch.setattr(osv.session, "post", pretend.raiser(ConnectTimeout)) + + dep = service.ResolvedDependency("jinja2", Version("2.4.1")) + with pytest.raises( + service.ConnectionError, match="Could not connect to OSV's vulnerability feed" + ): + dict(osv.query_all(iter([dep]))) + + def test_osv_error_response(monkeypatch): def raise_for_status(): raise HTTPError diff --git a/test/service/test_pypi.py b/test/service/test_pypi.py index 9c89873c..18976233 100644 --- a/test/service/test_pypi.py +++ b/test/service/test_pypi.py @@ -48,6 +48,20 @@ def test_pypi_multiple_pkg(cache_dir): assert len(results[deps[1]]) > 0 +def test_pypi_connection_error(monkeypatch): + session = pretend.stub(get=pretend.raiser(requests.ConnectTimeout)) + caching_session = pretend.call_recorder(lambda c, **kw: session) + monkeypatch.setattr(service.pypi, "caching_session", caching_session) + + cache_dir = pretend.stub() + pypi = service.PyPIService(cache_dir) + + with pytest.raises( + service.ConnectionError, match="Could not connect to PyPI's vulnerability feed" + ): + dict(pypi.query_all(iter([service.ResolvedDependency("fakedep", Version("1.0.0"))]))) + + def test_pypi_http_notfound(monkeypatch, cache_dir): # If we get a "not found" response, that means that we're querying a package or version that # isn't known to PyPI. If that's the case, we should just log a debug message and continue on