From be80854fcd6b3b59ee226be26ad8a0710ceaaa2c Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Thu, 7 Jul 2022 23:53:09 -0500 Subject: [PATCH] Improving error messages, fixing license command and fixing proxy issue --- safety/cli.py | 14 ++++--- safety/errors.py | 19 +++++---- safety/safety.py | 7 ++-- safety/util.py | 106 ++++++++++++++++++++++++++--------------------- 4 files changed, 80 insertions(+), 66 deletions(-) diff --git a/safety/cli.py b/safety/cli.py index b0329c33..b2efe0b6 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -16,23 +16,26 @@ from safety.safety import get_packages, read_vulnerabilities from safety.util import get_proxy_dict, get_packages_licenses, output_exception, \ MutuallyExclusiveOption, DependentOption, transform_ignore, SafetyPolicyFile, active_color_if_needed, \ - get_processed_options, get_safety_version, json_alias, bare_alias + get_processed_options, get_safety_version, json_alias, bare_alias, SafetyContext LOG = logging.getLogger(__name__) @click.group() @click.option('--debug/--no-debug', default=False) -@click.option('--telemetry/--disable-telemetry', default=True) +@click.option('--telemetry/--disable-telemetry', default=True, hidden=True) +@click.option('--disable-optional-telemetry-data', default=False, cls=MutuallyExclusiveOption, + mutually_exclusive=["telemetry", "disable-telemetry"], is_flag=True, show_default=True) @click.version_option(version=get_safety_version()) @click.pass_context -def cli(ctx, debug, telemetry): +def cli(ctx, debug, telemetry, disable_optional_telemetry_data): """ Safety checks Python dependencies for known security vulnerabilities and suggests the proper remediations for vulnerabilities detected. Safety can be run on developer machines, in CI/CD pipelines and on production systems. """ - ctx.telemetry = telemetry + SafetyContext().safety_source = 'cli' + ctx.telemetry = telemetry and not disable_optional_telemetry_data level = logging.CRITICAL if debug: level = logging.DEBUG @@ -225,7 +228,6 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp """ LOG.info('Running license command') packages = get_packages(files, False) - ctx.obj = packages proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport) announcements = [] @@ -244,7 +246,7 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp exception = e if isinstance(e, SafetyException) else SafetyException(info=e) output_exception(exception, exit_code_output=False) - filtered_packages_licenses = get_packages_licenses(packages, licenses_db) + filtered_packages_licenses = get_packages_licenses(packages=packages, licenses_db=licenses_db) output_report = SafetyFormatter(output=output).render_licenses(announcements, filtered_packages_licenses) diff --git a/safety/errors.py b/safety/errors.py index af9bf3be..8e8a3a36 100644 --- a/safety/errors.py +++ b/safety/errors.py @@ -28,7 +28,7 @@ def __init__(self, reason=None, fetched_from="server", message="Sorry, something went wrong.\n" + "Safety CLI can not read the data fetched from {fetched_from} because is malformed.\n"): info = "Reason, {reason}".format(reason=reason) - self.message = message.format(fetched_from=fetched_from) + info if reason else "" + self.message = message.format(fetched_from=fetched_from) + (info if reason else "") super().__init__(self.message) def get_exit_code(self): @@ -58,10 +58,12 @@ def get_exit_code(self): class InvalidKeyError(DatabaseFetchError): - def __init__(self, key=None, message="Your API Key '{key}' is invalid. See {link}"): + def __init__(self, key=None, message="Your API Key '{key}' is invalid. See {link}.", reason=None): self.key = key self.link = 'https://bit.ly/3OY2wEI' self.message = message.format(key=key, link=self.link) if key else message + info = f" Reason: {reason}" + self.message = self.message + (info if reason else "") super().__init__(self.message) def get_exit_code(self): @@ -70,9 +72,10 @@ def get_exit_code(self): class TooManyRequestsError(DatabaseFetchError): - def __init__(self, - message="Unable to load database (Too many requests, please wait a while before to make another request)"): - self.message = message + def __init__(self, reason=None, + message="Too many requests."): + info = f" Reason: {reason}" + self.message = message + (info if reason else "") super().__init__(self.message) def get_exit_code(self): @@ -81,7 +84,7 @@ def get_exit_code(self): class NetworkConnectionError(DatabaseFetchError): - def __init__(self, message="Check your network connection, unable to reach the server"): + def __init__(self, message="Check your network connection, unable to reach the server."): self.message = message super().__init__(self.message) @@ -98,6 +101,6 @@ class ServerError(DatabaseFetchError): def __init__(self, reason=None, message="Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" + "Our engineers are working quickly to resolve the issue."): - info = " Reason: {reason}".format(reason=reason) - self.message = message + info if reason else "" + info = f" Reason: {reason}" + self.message = message + (info if reason else "") super().__init__(self.message) diff --git a/safety/safety.py b/safety/safety.py index 628e6192..1fe720c6 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -13,8 +13,7 @@ from packaging.utils import canonicalize_name from packaging.version import parse as parse_version, Version, LegacyVersion, parse -from .constants import (API_MIRRORS, CACHE_FILE, OPEN_MIRRORS, - REQUEST_TIMEOUT, API_BASE_URL) +from .constants import (API_MIRRORS, CACHE_FILE, OPEN_MIRRORS, REQUEST_TIMEOUT, API_BASE_URL) from .errors import (DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError, TooManyRequestsError, NetworkConnectionError, RequestTimeoutError, ServerError, MalformedDatabase) @@ -123,10 +122,10 @@ def fetch_database_url(mirror, db_name, key, cached, proxy, telemetry=True): raise DatabaseFetchError() if r.status_code == 403: - raise InvalidKeyError(key=key) + raise InvalidKeyError(key=key, reason=r.text) if r.status_code == 429: - raise TooManyRequestsError() + raise TooManyRequestsError(reason=r.text) if r.status_code != 200: raise ServerError(reason=r.reason) diff --git a/safety/util.py b/safety/util.py index 20a3c073..7a5b5df3 100644 --- a/safety/util.py +++ b/safety/util.py @@ -120,7 +120,8 @@ def read_requirements(fh, resolve=False): def get_proxy_dict(proxy_protocol, proxy_host, proxy_port): if proxy_protocol and proxy_host and proxy_port: - return {proxy_protocol: f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"} + # Safety only uses https request, so only https dict will be passed to requests + return {'https': f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"} return None @@ -132,51 +133,6 @@ def get_license_name_by_id(license_id, db): return None -def get_packages_licenses(packages, licenses_db): - """Get the licenses for the specified packages based on their version. - - :param packages: packages list - :param licenses_db: the licenses db in the raw form. - :return: list of objects with the packages and their respectives licenses. - """ - packages_licenses_db = licenses_db.get('packages', {}) - filtered_packages_licenses = [] - - for pkg in packages: - # Ignore recursive files not resolved - if isinstance(pkg, RequirementFile): - continue - # normalize the package name - pkg_name = canonicalize_name(pkg.name) - # packages may have different licenses depending their version. - pkg_licenses = packages_licenses_db.get(pkg_name, []) - version_requested = parse_version(pkg.version) - license_id = None - license_name = None - for pkg_version in pkg_licenses: - license_start_version = parse_version(pkg_version['start_version']) - # Stops and return the previous stored license when a new - # license starts on a version above the requested one. - if version_requested >= license_start_version: - license_id = pkg_version['license_id'] - else: - # We found the license for the version requested - break - - if license_id: - license_name = get_license_name_by_id(license_id, licenses_db) - if not license_id or not license_name: - license_name = "unknown" - - filtered_packages_licenses.append({ - "package": pkg_name, - "version": pkg.version, - "license": license_name - }) - - return filtered_packages_licenses - - def get_flags_from_context(): flags = {} context = click.get_current_context(silent=True) @@ -245,6 +201,7 @@ def build_telemetry_data(telemetry=True): } if telemetry else {} body['safety_version'] = get_safety_version() + body['safety_source'] = os.environ.get("SAFETY_SOURCE", None) or context.safety_source LOG.debug(f'Telemetry body built: {body}') @@ -309,8 +266,7 @@ def handle_parse_result(self, ctx, opts, args): if option_used and (not self.with_values or exclusive_value_used): options = ', '.join(self.opts) prohibited = ''.join(["\n * --{0} with {1}".format(item, self.with_values.get( - item)) if item in self.with_values else item for item in self.mutually_exclusive]) - + item)) if item in self.with_values else f"\n * {item}" for item in self.mutually_exclusive]) raise click.UsageError( f"Illegal usage: `{options}` is mutually exclusive with: {prohibited}" ) @@ -626,6 +582,7 @@ class SafetyContext(metaclass=SingletonMeta): command = None review = None params = {} + safety_source = 'code' def sync_safety_context(f): @@ -639,3 +596,56 @@ def new_func(*args, **kwargs): return f(*args, **kwargs) return new_func + + +@sync_safety_context +def get_packages_licenses(packages=None, licenses_db=None): + """Get the licenses for the specified packages based on their version. + + :param packages: packages list + :param licenses_db: the licenses db in the raw form. + :return: list of objects with the packages and their respectives licenses. + """ + SafetyContext().command = 'license' + + if not packages: + packages = [] + if not licenses_db: + licenses_db = {} + + packages_licenses_db = licenses_db.get('packages', {}) + filtered_packages_licenses = [] + + for pkg in packages: + # Ignore recursive files not resolved + if isinstance(pkg, RequirementFile): + continue + # normalize the package name + pkg_name = canonicalize_name(pkg.name) + # packages may have different licenses depending their version. + pkg_licenses = packages_licenses_db.get(pkg_name, []) + version_requested = parse_version(pkg.version) + license_id = None + license_name = None + for pkg_version in pkg_licenses: + license_start_version = parse_version(pkg_version['start_version']) + # Stops and return the previous stored license when a new + # license starts on a version above the requested one. + if version_requested >= license_start_version: + license_id = pkg_version['license_id'] + else: + # We found the license for the version requested + break + + if license_id: + license_name = get_license_name_by_id(license_id, licenses_db) + if not license_id or not license_name: + license_name = "unknown" + + filtered_packages_licenses.append({ + "package": pkg_name, + "version": pkg.version, + "license": license_name + }) + + return filtered_packages_licenses