Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle environments with HOME set to a not-a-directory #1063

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _validate_port_spec(hosts, port):
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
else:
Expand Down Expand Up @@ -211,7 +211,7 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
addr = m.group(1)
hostspec_port = m.group(2)
else:
raise ValueError(
raise exceptions.ClientConfigurationError(
'invalid IPv6 address in the connection URI: {!r}'.format(
hostspec
)
Expand Down Expand Up @@ -240,13 +240,13 @@ def _parse_hostlist(hostlist, port, *, unquote=False):

def _parse_tls_version(tls_version):
if tls_version.startswith('SSL'):
raise ValueError(
raise exceptions.ClientConfigurationError(
f"Unsupported TLS version: {tls_version}"
)
try:
return ssl_module.TLSVersion[tls_version.replace('.', '_')]
except KeyError:
raise ValueError(
raise exceptions.ClientConfigurationError(
f"No such TLS version: {tls_version}"
)

Expand Down Expand Up @@ -274,7 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
parsed = urllib.parse.urlparse(dsn)

if parsed.scheme not in {'postgresql', 'postgres'}:
raise ValueError(
raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

Expand Down Expand Up @@ -437,11 +437,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database = user

if user is None:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not determine user name to connect with')

if database is None:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not determine database name to connect to')

if password is None:
Expand Down Expand Up @@ -477,7 +477,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
have_tcp_addrs = True

if not addrs:
raise ValueError(
raise exceptions.InternalClientError(
'could not determine the database address to connect to')

if ssl is None:
Expand All @@ -491,7 +491,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'`sslmode` parameter must be one of: {}'.format(modes))

# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
Expand All @@ -511,19 +511,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
try:
sslrootcert = _dot_postgresql_path('root.crt')
assert sslrootcert is not None
ssl.load_verify_locations(cafile=sslrootcert)
except (AssertionError, FileNotFoundError):
if sslrootcert is not None:
ssl.load_verify_locations(cafile=sslrootcert)
else:
raise exceptions.ClientConfigurationError(
'cannot determine location of user '
'PostgreSQL configuration directory'
)
except (
exceptions.ClientConfigurationError,
FileNotFoundError,
NotADirectoryError,
):
if sslmode > SSLMode.require:
if sslrootcert is None:
raise RuntimeError(
'Cannot determine home directory'
sslrootcert = '~/.postgresql/root.crt'
detail = (
'Could not determine location of user '
'home directory (HOME is either unset, '
'inaccessible, or does not point to a '
'valid directory)'
)
raise ValueError(
else:
detail = None
raise exceptions.ClientConfigurationError(
f'root certificate file "{sslrootcert}" does '
f'not exist\nEither provide the file or '
f'change sslmode to disable server '
f'certificate verification.'
f'not exist or cannot be accessed',
hint='Provide the certificate file directly '
f'or make sure "{sslrootcert}" '
'exists and is readable.',
detail=detail,
)
elif sslmode == SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
Expand All @@ -542,7 +559,10 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if sslcrl is not None:
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
except (
FileNotFoundError,
NotADirectoryError,
):
pass
else:
ssl.verify_flags |= \
Expand Down Expand Up @@ -571,7 +591,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
keyfile=sslkey,
password=lambda: sslpassword
)
except FileNotFoundError:
except (FileNotFoundError, NotADirectoryError):
pass

# OpenSSL 1.1.1 keylog file, copied from create_default_context()
Expand Down Expand Up @@ -606,7 +626,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
not isinstance(server_settings, dict) or
not all(isinstance(k, str) for k in server_settings) or
not all(isinstance(v, str) for v in server_settings.values())):
raise ValueError(
raise exceptions.ClientConfigurationError(
'server_settings is expected to be None or '
'a Dict[str, str]')

Expand All @@ -617,7 +637,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
try:
target_session_attrs = SessionAttribute(target_session_attrs)
except ValueError:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
"target_session_attrs is expected to be one of "
"{!r}"
", got {!r}".format(
Expand Down
7 changes: 6 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
'ClientConfigurationError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -220,6 +221,10 @@ def with_msg(self, msg):
)


class ClientConfigurationError(InterfaceError, ValueError):
"""An error caused by improper client configuration."""


class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""

Expand Down
32 changes: 30 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def mock_no_home_dir():
yield


@contextlib.contextmanager
def mock_dev_null_home_dir():
with unittest.mock.patch(
'pathlib.Path.home',
unittest.mock.Mock(return_value=pathlib.Path('/dev/null')),
):
yield


class TestSettings(tb.ConnectedTestCase):

async def test_get_settings_01(self):
Expand Down Expand Up @@ -1318,16 +1327,35 @@ async def test_connection_no_home_dir(self):
await con.fetchval('SELECT 42')
await con.close()

with mock_dev_null_home_dir():
con = await self.connect(
dsn='postgresql://foo/',
user='postgres',
database='postgres',
host='localhost')
await con.fetchval('SELECT 42')
await con.close()

with self.assertRaisesRegex(
RuntimeError,
'Cannot determine home directory'
exceptions.ClientConfigurationError,
r'root certificate file "~/\.postgresql/root\.crt" does not exist'
):
with mock_no_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')

with self.assertRaisesRegex(
exceptions.ClientConfigurationError,
r'root certificate file ".*" does not exist'
):
with mock_dev_null_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')


class BaseTestSSLConnection(tb.ConnectedTestCase):
@classmethod
Expand Down