Skip to content

Commit

Permalink
Improving error messages, fixing license command and fixing proxy issue
Browse files Browse the repository at this point in the history
  • Loading branch information
yeisonvargasf committed Jul 8, 2022
1 parent c839544 commit be80854
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 66 deletions.
14 changes: 8 additions & 6 deletions safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,26 @@
from safety.safety import get_packages, read_vulnerabilities
from safety.util import get_proxy_dict, get_packages_licenses, output_exception, \
MutuallyExclusiveOption, DependentOption, transform_ignore, SafetyPolicyFile, active_color_if_needed, \
get_processed_options, get_safety_version, json_alias, bare_alias
get_processed_options, get_safety_version, json_alias, bare_alias, SafetyContext

LOG = logging.getLogger(__name__)


@click.group()
@click.option('--debug/--no-debug', default=False)
@click.option('--telemetry/--disable-telemetry', default=True)
@click.option('--telemetry/--disable-telemetry', default=True, hidden=True)
@click.option('--disable-optional-telemetry-data', default=False, cls=MutuallyExclusiveOption,
mutually_exclusive=["telemetry", "disable-telemetry"], is_flag=True, show_default=True)
@click.version_option(version=get_safety_version())
@click.pass_context
def cli(ctx, debug, telemetry):
def cli(ctx, debug, telemetry, disable_optional_telemetry_data):
"""
Safety checks Python dependencies for known security vulnerabilities and suggests the proper
remediations for vulnerabilities detected. Safety can be run on developer machines, in CI/CD pipelines and
on production systems.
"""
ctx.telemetry = telemetry
SafetyContext().safety_source = 'cli'
ctx.telemetry = telemetry and not disable_optional_telemetry_data
level = logging.CRITICAL
if debug:
level = logging.DEBUG
Expand Down Expand Up @@ -225,7 +228,6 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp
"""
LOG.info('Running license command')
packages = get_packages(files, False)
ctx.obj = packages

proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport)
announcements = []
Expand All @@ -244,7 +246,7 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp
exception = e if isinstance(e, SafetyException) else SafetyException(info=e)
output_exception(exception, exit_code_output=False)

filtered_packages_licenses = get_packages_licenses(packages, licenses_db)
filtered_packages_licenses = get_packages_licenses(packages=packages, licenses_db=licenses_db)

output_report = SafetyFormatter(output=output).render_licenses(announcements, filtered_packages_licenses)

Expand Down
19 changes: 11 additions & 8 deletions safety/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, reason=None, fetched_from="server",
message="Sorry, something went wrong.\n" +
"Safety CLI can not read the data fetched from {fetched_from} because is malformed.\n"):
info = "Reason, {reason}".format(reason=reason)
self.message = message.format(fetched_from=fetched_from) + info if reason else ""
self.message = message.format(fetched_from=fetched_from) + (info if reason else "")
super().__init__(self.message)

def get_exit_code(self):
Expand Down Expand Up @@ -58,10 +58,12 @@ def get_exit_code(self):

class InvalidKeyError(DatabaseFetchError):

def __init__(self, key=None, message="Your API Key '{key}' is invalid. See {link}"):
def __init__(self, key=None, message="Your API Key '{key}' is invalid. See {link}.", reason=None):
self.key = key
self.link = 'https://bit.ly/3OY2wEI'
self.message = message.format(key=key, link=self.link) if key else message
info = f" Reason: {reason}"
self.message = self.message + (info if reason else "")
super().__init__(self.message)

def get_exit_code(self):
Expand All @@ -70,9 +72,10 @@ def get_exit_code(self):

class TooManyRequestsError(DatabaseFetchError):

def __init__(self,
message="Unable to load database (Too many requests, please wait a while before to make another request)"):
self.message = message
def __init__(self, reason=None,
message="Too many requests."):
info = f" Reason: {reason}"
self.message = message + (info if reason else "")
super().__init__(self.message)

def get_exit_code(self):
Expand All @@ -81,7 +84,7 @@ def get_exit_code(self):

class NetworkConnectionError(DatabaseFetchError):

def __init__(self, message="Check your network connection, unable to reach the server"):
def __init__(self, message="Check your network connection, unable to reach the server."):
self.message = message
super().__init__(self.message)

Expand All @@ -98,6 +101,6 @@ class ServerError(DatabaseFetchError):
def __init__(self, reason=None,
message="Sorry, something went wrong.\n" + "Safety CLI can not connect to the server.\n" +
"Our engineers are working quickly to resolve the issue."):
info = " Reason: {reason}".format(reason=reason)
self.message = message + info if reason else ""
info = f" Reason: {reason}"
self.message = message + (info if reason else "")
super().__init__(self.message)
7 changes: 3 additions & 4 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from packaging.utils import canonicalize_name
from packaging.version import parse as parse_version, Version, LegacyVersion, parse

from .constants import (API_MIRRORS, CACHE_FILE, OPEN_MIRRORS,
REQUEST_TIMEOUT, API_BASE_URL)
from .constants import (API_MIRRORS, CACHE_FILE, OPEN_MIRRORS, REQUEST_TIMEOUT, API_BASE_URL)
from .errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError, TooManyRequestsError, NetworkConnectionError,
RequestTimeoutError, ServerError, MalformedDatabase)
Expand Down Expand Up @@ -123,10 +122,10 @@ def fetch_database_url(mirror, db_name, key, cached, proxy, telemetry=True):
raise DatabaseFetchError()

if r.status_code == 403:
raise InvalidKeyError(key=key)
raise InvalidKeyError(key=key, reason=r.text)

if r.status_code == 429:
raise TooManyRequestsError()
raise TooManyRequestsError(reason=r.text)

if r.status_code != 200:
raise ServerError(reason=r.reason)
Expand Down
106 changes: 58 additions & 48 deletions safety/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def read_requirements(fh, resolve=False):

def get_proxy_dict(proxy_protocol, proxy_host, proxy_port):
if proxy_protocol and proxy_host and proxy_port:
return {proxy_protocol: f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"}
# Safety only uses https request, so only https dict will be passed to requests
return {'https': f"{proxy_protocol}://{proxy_host}:{str(proxy_port)}"}
return None


Expand All @@ -132,51 +133,6 @@ def get_license_name_by_id(license_id, db):
return None


def get_packages_licenses(packages, licenses_db):
"""Get the licenses for the specified packages based on their version.
:param packages: packages list
:param licenses_db: the licenses db in the raw form.
:return: list of objects with the packages and their respectives licenses.
"""
packages_licenses_db = licenses_db.get('packages', {})
filtered_packages_licenses = []

for pkg in packages:
# Ignore recursive files not resolved
if isinstance(pkg, RequirementFile):
continue
# normalize the package name
pkg_name = canonicalize_name(pkg.name)
# packages may have different licenses depending their version.
pkg_licenses = packages_licenses_db.get(pkg_name, [])
version_requested = parse_version(pkg.version)
license_id = None
license_name = None
for pkg_version in pkg_licenses:
license_start_version = parse_version(pkg_version['start_version'])
# Stops and return the previous stored license when a new
# license starts on a version above the requested one.
if version_requested >= license_start_version:
license_id = pkg_version['license_id']
else:
# We found the license for the version requested
break

if license_id:
license_name = get_license_name_by_id(license_id, licenses_db)
if not license_id or not license_name:
license_name = "unknown"

filtered_packages_licenses.append({
"package": pkg_name,
"version": pkg.version,
"license": license_name
})

return filtered_packages_licenses


def get_flags_from_context():
flags = {}
context = click.get_current_context(silent=True)
Expand Down Expand Up @@ -245,6 +201,7 @@ def build_telemetry_data(telemetry=True):
} if telemetry else {}

body['safety_version'] = get_safety_version()
body['safety_source'] = os.environ.get("SAFETY_SOURCE", None) or context.safety_source

LOG.debug(f'Telemetry body built: {body}')

Expand Down Expand Up @@ -309,8 +266,7 @@ def handle_parse_result(self, ctx, opts, args):
if option_used and (not self.with_values or exclusive_value_used):
options = ', '.join(self.opts)
prohibited = ''.join(["\n * --{0} with {1}".format(item, self.with_values.get(
item)) if item in self.with_values else item for item in self.mutually_exclusive])

item)) if item in self.with_values else f"\n * {item}" for item in self.mutually_exclusive])
raise click.UsageError(
f"Illegal usage: `{options}` is mutually exclusive with: {prohibited}"
)
Expand Down Expand Up @@ -626,6 +582,7 @@ class SafetyContext(metaclass=SingletonMeta):
command = None
review = None
params = {}
safety_source = 'code'


def sync_safety_context(f):
Expand All @@ -639,3 +596,56 @@ def new_func(*args, **kwargs):
return f(*args, **kwargs)

return new_func


@sync_safety_context
def get_packages_licenses(packages=None, licenses_db=None):
"""Get the licenses for the specified packages based on their version.
:param packages: packages list
:param licenses_db: the licenses db in the raw form.
:return: list of objects with the packages and their respectives licenses.
"""
SafetyContext().command = 'license'

if not packages:
packages = []
if not licenses_db:
licenses_db = {}

packages_licenses_db = licenses_db.get('packages', {})
filtered_packages_licenses = []

for pkg in packages:
# Ignore recursive files not resolved
if isinstance(pkg, RequirementFile):
continue
# normalize the package name
pkg_name = canonicalize_name(pkg.name)
# packages may have different licenses depending their version.
pkg_licenses = packages_licenses_db.get(pkg_name, [])
version_requested = parse_version(pkg.version)
license_id = None
license_name = None
for pkg_version in pkg_licenses:
license_start_version = parse_version(pkg_version['start_version'])
# Stops and return the previous stored license when a new
# license starts on a version above the requested one.
if version_requested >= license_start_version:
license_id = pkg_version['license_id']
else:
# We found the license for the version requested
break

if license_id:
license_name = get_license_name_by_id(license_id, licenses_db)
if not license_id or not license_name:
license_name = "unknown"

filtered_packages_licenses.append({
"package": pkg_name,
"version": pkg.version,
"license": license_name
})

return filtered_packages_licenses

0 comments on commit be80854

Please sign in to comment.