From 085f5748148d7934b28d45904df2e5e98a9209dc Mon Sep 17 00:00:00 2001 From: James Clarke Date: Thu, 28 Oct 2021 14:17:34 +0100 Subject: [PATCH] Update connection parameter resolution (#241) --- .gitmodules | 3 + edgedb/asyncio_con.py | 10 +- edgedb/blocking_con.py | 11 +- edgedb/con_utils.py | 787 ++++++++++++++++++++++------------ tests/shared-client-testcases | 1 + tests/test_con_utils.py | 553 ++++++++---------------- 6 files changed, 712 insertions(+), 653 deletions(-) create mode 160000 tests/shared-client-testcases diff --git a/.gitmodules b/.gitmodules index 1c692350..1830ced8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "edgedb/pgproto"] path = edgedb/pgproto url = https://github.com/MagicStack/py-pgproto.git +[submodule "tests/shared-client-testcases"] + path = tests/shared-client-testcases + url = https://github.com/edgedb/shared-client-testcases.git diff --git a/edgedb/asyncio_con.py b/edgedb/asyncio_con.py index 0cbea7ba..a7d9499d 100644 --- a/edgedb/asyncio_con.py +++ b/edgedb/asyncio_con.py @@ -522,9 +522,9 @@ async def fetchone_json(self, query: str, *args, **kwargs) -> str: async def async_connect(dsn: str = None, *, + credentils_file: str = None, host: str = None, port: int = None, user: str = None, password: str = None, - admin: bool = None, database: str = None, tls_ca_file: str = None, tls_verify_hostname: bool = None, @@ -537,9 +537,9 @@ async def async_connect(dsn: str = None, *, if connection_class is None: connection_class = AsyncIOConnection - addrs, params, config = con_utils.parse_connect_arguments( - dsn=dsn, host=host, port=port, user=user, password=password, - database=database, admin=admin, timeout=timeout, + connect_config, client_config = con_utils.parse_connect_arguments( + dsn=dsn, credentials_file=credentils_file, host=host, port=port, + user=user, password=password, database=database, timeout=timeout, tls_ca_file=tls_ca_file, tls_verify_hostname=tls_verify_hostname, wait_until_available=wait_until_available, @@ -548,7 +548,7 @@ async def async_connect(dsn: str = None, *, server_settings=None) connection = connection_class( - loop, addrs, config, params, + loop, [connect_config.address], client_config, connect_config, codecs_registry=_CodecsRegistry(), query_cache=_QueryCodecsCache(), ) diff --git a/edgedb/blocking_con.py b/edgedb/blocking_con.py index e5db8083..d62e9cd8 100644 --- a/edgedb/blocking_con.py +++ b/edgedb/blocking_con.py @@ -425,18 +425,18 @@ def is_closed(self) -> bool: def connect(dsn: str = None, *, + credentials_file: str = None, host: str = None, port: int = None, user: str = None, password: str = None, - admin: bool = None, database: str = None, tls_ca_file: str = None, tls_verify_hostname: bool = None, timeout: int = 10, wait_until_available: int = 30) -> BlockingIOConnection: - addrs, params, config = con_utils.parse_connect_arguments( - dsn=dsn, host=host, port=port, user=user, password=password, - database=database, admin=admin, + connect_config, client_config = con_utils.parse_connect_arguments( + dsn=dsn, credentials_file=credentials_file, host=host, port=port, + user=user, password=password, database=database, timeout=timeout, wait_until_available=wait_until_available, tls_ca_file=tls_ca_file, tls_verify_hostname=tls_verify_hostname, @@ -446,7 +446,8 @@ def connect(dsn: str = None, *, server_settings=None) conn = BlockingIOConnection( - addrs=addrs, params=params, config=config, + addrs=[connect_config.address], params=connect_config, + config=client_config, codecs_registry=_CodecsRegistry(), query_cache=_QueryCodecsCache()) conn.ensure_connected() diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py index 50f138ae..96aef906 100644 --- a/edgedb/con_utils.py +++ b/edgedb/con_utils.py @@ -47,16 +47,6 @@ }) -class ConnectionParameters(typing.NamedTuple): - - user: str - password: str - database: str - connect_timeout: float - server_settings: typing.Mapping[str, str] - ssl_ctx: ssl.SSLContext - - class ClientConfiguration(typing.NamedTuple): connect_timeout: float @@ -129,7 +119,10 @@ def _stash_path(path): return platform.search_config_dir('projects', dir_name) -def _parse_verify_hostname(val: str) -> bool: +def _parse_verify_hostname(val: typing.Union[str, bool]) -> bool: + if isinstance(val, bool): + return val + val = val.lower() if val in {"1", "yes", "true", "y", "t", "on"}: return True @@ -141,284 +134,517 @@ def _parse_verify_hostname(val: str) -> bool: ) -def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, database, admin, - tls_ca_file, tls_verify_hostname, - connect_timeout, server_settings): - using_credentials = False - tls_ca_data = None - - if admin: - warnings.warn( - 'The "admin=True" parameter is deprecated and is scheduled to be ' - 'removed. Admin socket should never be used in applications. ' - 'Use command-line tool `edgedb` to setup proper credentials.', - DeprecationWarning, 4) - - if not ( - dsn or host or port or - os.getenv("EDGEDB_HOST") or os.getenv("EDGEDB_PORT") - ): - instance_name = os.getenv("EDGEDB_INSTANCE") - if instance_name: - dsn = instance_name - else: - toml = find_edgedb_toml() - stash_dir = _stash_path(os.path.dirname(toml)) - if os.path.exists(stash_dir): - with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: - dsn = f.read().strip() - else: - raise errors.ClientConnectionError( - f'Found `edgedb.toml` but the project is not initialized. ' - f'Run `edgedb project init`.' +class ResolvedConnectConfig: + _host = None + _host_source = None + + _port = None + _port_source = None + + _database = None + _database_source = None + + _user = None + _user_source = None + + _password = None + _password_source = None + + _tls_ca_data = None + _tls_ca_data_source = None + + _tls_verify_hostname = None + _tls_verify_hostname_source = None + + server_settings = {} + + def _set_param(self, param, value, source, validator=None): + param_name = '_' + param + if getattr(self, param_name) is None: + setattr(self, param_name + '_source', source) + if value is not None: + setattr( + self, + param_name, + validator(value) if validator else value ) - if dsn and dsn.startswith(("edgedb://", "edgedbadmin://")): - parsed = urllib.parse.urlparse(dsn) + def set_host(self, host, source): + self._set_param('host', host, source, _validate_host) - if parsed.scheme not in ('edgedb', 'edgedbadmin'): - raise ValueError( - f'invalid DSN: scheme is expected to be ' - f'"edgedb" or "edgedbadmin", got {parsed.scheme!r}') - - if parsed.scheme == 'edgedbadmin': - warnings.warn( - 'The `edgedbadmin` scheme is deprecated and is scheduled ' - 'to be removed. Admin socket should never be used in ' - 'applications. Use command-line tool `edgedb` to setup ' - 'proper credentials.', - DeprecationWarning, 4) - - if admin is None: - admin = parsed.scheme == 'edgedbadmin' - - if not host and parsed.netloc: - if '@' in parsed.netloc: - auth, _, hostspec = parsed.netloc.partition('@') - else: - hostspec = parsed.netloc - - if hostspec: - host, port = _parse_hostlist(hostspec, port) - - if parsed.path and database is None: - database = parsed.path - if database.startswith('/'): - database = database[1:] - - if parsed.username and user is None: - user = parsed.username - - if parsed.password and password is None: - password = parsed.password - - if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] - - if 'port' in query: - val = query.pop('port') - if not port and val: - port = [int(p) for p in val.split(',')] - - if 'host' in query: - val = query.pop('host') - if not host and val: - host, port = _parse_hostlist(val, port) - - if 'dbname' in query: - val = query.pop('dbname') - if database is None: - database = val - - if 'database' in query: - val = query.pop('database') - if database is None: - database = val - - if 'user' in query: - val = query.pop('user') - if user is None: - user = val - - if 'password' in query: - val = query.pop('password') - if password is None: - password = val - - if 'tls_ca_file' in query: - val = query.pop('tls_ca_file') - if tls_ca_file is None: - tls_ca_file = val - - if 'tls_verify_hostname' in query: - val = query.pop('tls_verify_hostname') - if tls_verify_hostname is None: - tls_verify_hostname = _parse_verify_hostname(val) - - if query: - if server_settings is None: - server_settings = query - else: - server_settings = {**query, **server_settings} - elif dsn: - if not dsn.isidentifier(): - raise ValueError( - f"dsn {dsn!r} is neither a edgedb:// URI " - f"nor valid instance name" - ) + def set_port(self, port, source): + self._set_param('port', port, source, _validate_port) - using_credentials = True - path = credentials.get_credentials_path(dsn) - try: - creds = credentials.read_credentials(path) - except Exception as e: - raise errors.ClientError( - f"cannot read credentials of instance {dsn!r}" - ) from e - - if port is None: - port = creds['port'] - if user is None: - user = creds['user'] - if host is None and 'host' in creds: - host = creds['host'] - if password is None and 'password' in creds: - password = creds['password'] - if database is None and 'database' in creds: - database = creds['database'] - if tls_ca_file is None and 'tls_cert_data' in creds: - tls_ca_data = creds['tls_cert_data'] - if tls_verify_hostname is None and 'tls_verify_hostname' in creds: - tls_verify_hostname = creds['tls_verify_hostname'] - - if not host: - hostspec = os.environ.get('EDGEDB_HOST') - if hostspec: - host, port = _parse_hostlist(hostspec, port) - - if not host: - if platform.IS_WINDOWS or using_credentials: - host = [] - else: - host = ['/run/edgedb', '/var/run/edgedb'] + def set_database(self, database, source): + self._set_param('database', database, source, _validate_database) - if not admin: - host.append('localhost') + def set_user(self, user, source): + self._set_param('user', user, source, _validate_user) - if not isinstance(host, list): - host = [host] + def set_password(self, password, source): + self._set_param('password', password, source) - if not port: - portspec = os.environ.get('EDGEDB_PORT') - if portspec: - if ',' in portspec: - port = [int(p) for p in portspec.split(',')] - else: - port = int(portspec) + def set_tls_ca_data(self, ca_data, source): + self._set_param('tls_ca_data', ca_data, source) + + def set_tls_ca_file(self, ca_file, source): + def read_ca_file(file_path): + with open(file_path) as f: + return f.read() + + self._set_param('tls_ca_data', ca_file, source, read_ca_file) + + def set_tls_verify_hostname(self, verify_hostname, source): + self._set_param('tls_verify_hostname', verify_hostname, source, + _parse_verify_hostname) + + def add_server_settings(self, server_settings): + _validate_server_settings(server_settings) + self.server_settings = {**server_settings, **self.server_settings} + + @property + def address(self): + return ( + self._host if self._host else 'localhost', + self._port if self._port else 5656 + ) + + @property + def database(self): + return self._database if self._database else 'edgedb' + + @property + def user(self): + return self._user if self._user else 'edgedb' + + @property + def password(self): + return self._password + + @property + def tls_verify_hostname(self): + return (self._tls_verify_hostname + if self._tls_verify_hostname is not None + else self._tls_ca_data is None) + + _ssl_ctx = None + + @property + def ssl_ctx(self): + if (self._ssl_ctx): + return self._ssl_ctx + + self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self._ssl_ctx.verify_mode = ssl.CERT_REQUIRED + if self._tls_ca_data: + self._ssl_ctx.load_verify_locations( + cadata=self._tls_ca_data + ) else: - port = EDGEDB_PORT + self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + if platform.IS_WINDOWS: + import certifi + self._ssl_ctx.load_verify_locations(cafile=certifi.where()) + self._ssl_ctx.check_hostname = self.tls_verify_hostname + self._ssl_ctx.set_alpn_protocols(['edgedb-binary']) + + return self._ssl_ctx + + +def _validate_host(host): + if '/' in host: + raise ValueError('unix socket paths not supported') + if host == '' or ',' in host: + raise ValueError(f'invalid host: "{host}"') + return host + + +def _validate_port(port): + try: + if isinstance(port, str): + port = int(port) + if not isinstance(port, int): + raise ValueError() + except Exception: + raise ValueError(f'invalid port: {port}, not an integer') + if port < 1 or port > 65535: + raise ValueError(f'invalid port: {port}, must be between 1 and 65535') + return port - elif isinstance(port, (list, tuple)): - port = [int(p) for p in port] - else: - port = int(port) - - port = _validate_port_spec(host, port) - - if user is None: - user = os.getenv('EDGEDB_USER') - if not user: - user = 'edgedb' - - if password is None: - password = os.getenv('EDGEDB_PASSWORD') - - if database is None: - database = os.getenv('EDGEDB_DATABASE') - - if database is None: - database = 'edgedb' - - if user is None: - raise errors.InterfaceError( - 'could not determine user name to connect with') - - if database is None: - raise errors.InterfaceError( - 'could not determine database name to connect to') - - have_unix_sockets = False - addrs = [] - for h, p in zip(host, port): - if h.startswith('/'): - # UNIX socket name - if '.s.EDGEDB.' not in h: - if admin: - sock_name = f'.s.EDGEDB.admin.{p}' - else: - sock_name = f'.s.EDGEDB.{p}' - h = os.path.join(h, sock_name) - have_unix_sockets = True - addrs.append(h) - elif not admin: - # TCP host/port - addrs.append((h, p)) - - if admin and not have_unix_sockets: - raise ValueError( - 'admin connections are only supported over UNIX sockets') +def _validate_database(database): + if database == '': + raise ValueError(f'invalid database name: {database}') + return database + + +def _validate_user(user): + if user == '': + raise ValueError(f'invalid user name: {user}') + return user - if not addrs: - raise ValueError( - 'could not determine the database address to connect to') - if server_settings is not None and ( - not isinstance(server_settings, dict) or - not all(isinstance(k, str) for k in server_settings) or - not all(isinstance(v, str) for v in server_settings.values())): +def _validate_server_settings(server_settings): + if ( + not isinstance(server_settings, dict) or + not all(isinstance(k, str) for k in server_settings) or + not all(isinstance(v, str) for v in server_settings.values()) + ): raise ValueError( 'server_settings is expected to be None or ' 'a Dict[str, str]') - if admin: - ssl_ctx = None - else: - ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - if tls_ca_file or tls_ca_data: - ssl_ctx.load_verify_locations( - cafile=tls_ca_file, cadata=tls_ca_data - ) - if tls_verify_hostname is None: - tls_verify_hostname = False + +def _parse_connect_dsn_and_args(*, dsn, credentials_file, host, port, user, + password, database, + tls_ca_file, tls_verify_hostname, + server_settings): + + resolved_config = ResolvedConnectConfig() + + dsn, instance_name = ( + (dsn, None) + if dsn is not None and re.match('(?i)^[a-z]+://', dsn) + else (None, dsn) + ) + + has_compound_options = _resolve_config_options( + resolved_config, + 'Cannot have more than one of the following connection options: ' + + '"dsn", "credentials_file" or "host"/"port"', + dsn=(dsn, '"dsn" option') if dsn is not None else None, + instance_name=( + (instance_name, '"dsn" option (parsed as instance name)') + if instance_name is not None else None + ), + credentials_file=( + (credentials_file, '"credentials_file" option') + if credentials_file is not None else None + ), + host=(host, '"host" option') if host is not None else None, + port=(port, '"port" option') if port is not None else None, + database=( + (database, '"database" option') + if database is not None else None + ), + user=(user, '"user" option') if user is not None else None, + password=( + (password, '"password" option') + if password is not None else None + ), + tls_ca_file=( + (tls_ca_file, '"tls_ca_file" option') + if tls_ca_file is not None else None + ), + tls_verify_hostname=( + (tls_verify_hostname, '"tls_verify_hostname" option') + if tls_verify_hostname is not None else None + ), + server_settings=( + (server_settings, '"server_settings" option') + if server_settings is not None else None + ), + ) + + if has_compound_options is False: + env_port = os.getenv("EDGEDB_PORT") + if ( + resolved_config._port is None and + env_port and env_port.startswith('tcp://') + ): + # EDGEDB_PORT is set by 'docker --link' so ignore and warn + warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' + + 'so will be ignored') + env_port = None + + env_dsn = os.getenv('EDGEDB_DSN') + env_instance = os.getenv('EDGEDB_INSTANCE') + env_credentials_file = os.getenv('EDGEDB_CREDENTIALS_FILE') + env_host = os.getenv('EDGEDB_HOST') + env_database = os.getenv('EDGEDB_DATABASE') + env_user = os.getenv('EDGEDB_USER') + env_password = os.getenv('EDGEDB_PASSWORD') + env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE') + env_tls_verify_hostname = os.getenv('EDGEDB_TLS_VERIFY_HOSTNAME') + + has_compound_options = _resolve_config_options( + resolved_config, + 'Cannot have more than one of the following connection ' + + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", ' + + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"', + dsn=( + (env_dsn, '"EDGEDB_DSN" environment variable') + if env_dsn is not None else None + ), + instance_name=( + (env_instance, '"EDGEDB_INSTANCE" environment variable') + if env_instance is not None else None + ), + credentials_file=( + (env_credentials_file, + '"EDGEDB_CREDENTIALS_FILE" environment variable') + if env_credentials_file is not None else None + ), + host=( + (env_host, '"EDGEDB_HOST" environment variable') + if env_host is not None else None + ), + port=( + (env_port, '"EDGEDB_PORT" environment variable') + if env_port is not None else None + ), + database=( + (env_database, '"EDGEDB_DATABASE" environment variable') + if env_database is not None else None + ), + user=( + (env_user, '"EDGEDB_USER" environment variable') + if env_user is not None else None + ), + password=( + (env_password, '"EDGEDB_PASSWORD" environment variable') + if env_password is not None else None + ), + tls_ca_file=( + (env_tls_ca_file, '"EDGEDB_TLS_CA_FILE" environment variable') + if env_tls_ca_file is not None else None + ), + tls_verify_hostname=( + (env_tls_verify_hostname, + '"EDGEDB_TLS_VERIFY_HOSTNAME" environment variable') + if env_tls_verify_hostname is not None else None + ), + ) + + if has_compound_options is False: + dir = find_edgedb_project_dir() + stash_dir = _stash_path(dir) + if os.path.exists(stash_dir): + with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: + instance_name = f.read().strip() + + _resolve_config_options( + resolved_config, + '', + instance_name=( + instance_name, + f'project linked instance ("{instance_name}")' + ) + ) else: - ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) - if platform.IS_WINDOWS: - import certifi - ssl_ctx.load_verify_locations(cafile=certifi.where()) - if tls_verify_hostname is None: - tls_verify_hostname = True - ssl_ctx.check_hostname = tls_verify_hostname - ssl_ctx.set_alpn_protocols(['edgedb-binary']) - - params = ConnectionParameters( - user=user, - password=password, - database=database, - connect_timeout=connect_timeout, - server_settings=server_settings, - ssl_ctx=ssl_ctx, + raise errors.ClientConnectionError( + f'Found `edgedb.toml` but the project is not initialized. ' + f'Run `edgedb project init`.' + ) + + return resolved_config + + +def _parse_dsn_into_config( + resolved_config: ResolvedConnectConfig, + dsn: typing.Tuple[str, str] +): + dsn_str, source = dsn + + try: + parsed = urllib.parse.urlparse(dsn_str) + host = parsed.hostname + port = parsed.port + database = parsed.path + user = parsed.username + password = parsed.password + except Exception as e: + raise ValueError(f'invalid DSN or instance name: {str(e)}') + + if parsed.scheme != 'edgedb': + raise ValueError( + f'invalid DSN: scheme is expected to be ' + f'"edgedb", got {parsed.scheme!r}') + + query = ( + urllib.parse.parse_qs(parsed.query, keep_blank_values=True) + if parsed.query != '' + else {} ) + for key, val in query.items(): + if isinstance(val, list): + if len(val) > 1: + raise ValueError( + f'invalid DSN: duplicate query parameter {key}') + query[key] = val[-1] + + def handle_dsn_part( + paramName, value, currentValue, setter, + formatter=lambda val: val + ): + param_values = [ + (value if value != '' else None), + query.get(paramName), + query.get(paramName + '_env'), + query.get(paramName + '_file') + ] + if len([p for p in param_values if p is not None]) > 1: + raise ValueError( + f'invalid DSN: more than one of ' + + f'{(paramName + ", ") if value else ""}' + + f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' + + f'was specified' + ) + + if currentValue is None: + param = ( + value if (value is not None and value != '') + else query.get(paramName) + ) + paramSource = source + + if param is None: + env = query.get(paramName + '_env') + if env is not None: + param = os.getenv(env) + if param is None: + raise ValueError( + f'{paramName}_env environment variable "{env}" ' + + f'doesn\'t exist') + paramSource = paramSource + f' ({paramName}_env: {env})' + if param is None: + filename = query.get(paramName + '_file') + if filename is not None: + with open(filename) as f: + param = f.read() + paramSource = ( + paramSource + f' ({paramName}_file: {filename})' + ) + + param = formatter(param) if param is not None else None + + setter(param, paramSource) + + query.pop(paramName, None) + query.pop(paramName + '_env', None) + query.pop(paramName + '_file', None) + + handle_dsn_part( + 'host', host, resolved_config._host, resolved_config.set_host + ) + + handle_dsn_part( + 'port', port, resolved_config._port, resolved_config.set_port + ) + + def strip_leading_slash(str): + return str[1:] if str.startswith('/') else str + + handle_dsn_part( + 'database', strip_leading_slash(database), + resolved_config._database, resolved_config.set_database, + strip_leading_slash + ) + + handle_dsn_part( + 'user', user, resolved_config._user, resolved_config.set_user + ) + + handle_dsn_part( + 'password', password, + resolved_config._password, resolved_config.set_password + ) + + handle_dsn_part( + 'tls_cert_file', None, + resolved_config._tls_ca_data, resolved_config.set_tls_ca_file + ) + + handle_dsn_part( + 'tls_verify_hostname', None, + resolved_config._tls_verify_hostname, + resolved_config.set_tls_verify_hostname + ) + + resolved_config.add_server_settings(query) + + +def _resolve_config_options( + resolved_config: ResolvedConnectConfig, + compound_error: str, + *, + dsn=None, + instance_name=None, + credentials_file=None, + host=None, + port=None, + database=None, + user=None, + password=None, + tls_ca_file=None, + tls_verify_hostname=None, + server_settings=None +): + if database is not None: + resolved_config.set_database(*database) + if user is not None: + resolved_config.set_user(*user) + if password is not None: + resolved_config.set_password(*password) + if tls_ca_file is not None: + resolved_config.set_tls_ca_file(*tls_ca_file) + if tls_verify_hostname is not None: + resolved_config.set_tls_verify_hostname(*tls_verify_hostname) + if server_settings is not None: + resolved_config.add_server_settings(server_settings[0]) + + compound_params = [dsn, instance_name, credentials_file, host or port] + compound_params_count = len([p for p in compound_params if p is not None]) + + if compound_params_count > 1: + raise errors.ClientConnectionError(compound_error) + + if compound_params_count == 1: + if dsn is not None or host is not None or port is not None: + if port is not None: + resolved_config.set_port(*port) + if dsn is None: + dsn = ( + 'edgedb://' + (_validate_host(host[0]) if host else ''), + host[1] if host is not None else port[1] + ) + _parse_dsn_into_config(resolved_config, dsn) + else: + if credentials_file is None: + if ( + re.match( + '^[A-Za-z_][A-Za-z_0-9]*$', + instance_name[0] + ) is None + ): + raise ValueError( + f'invalid DSN or instance name: "{instance_name[0]}"' + ) + credentials_file = ( + credentials.get_credentials_path(instance_name[0]), + instance_name[1] + ) + creds = credentials.read_credentials(credentials_file[0]) + + source = credentials_file[1] + + resolved_config.set_host(creds.get('host'), source) + resolved_config.set_port(creds.get('port'), source) + resolved_config.set_database(creds.get('database'), source) + resolved_config.set_user(creds.get('user'), source) + resolved_config.set_password(creds.get('password'), source) + resolved_config.set_tls_ca_data(creds.get('tls_cert_data'), source) + resolved_config.set_tls_verify_hostname( + creds.get('tls_verify_hostname'), + source + ) + + return True - return addrs, params + return False -def find_edgedb_toml(): +def find_edgedb_project_dir(): dir = os.getcwd() dev = os.stat(dir).st_dev @@ -442,14 +668,16 @@ def find_edgedb_toml(): dir = parent dev = parent_dev continue - return toml + return dir -def parse_connect_arguments(*, dsn, host, port, user, password, - database, admin, - tls_ca_file, tls_verify_hostname, - timeout, command_timeout, wait_until_available, - server_settings): +def parse_connect_arguments( + *, dsn, credentials_file, host, port, + database, user, password, + tls_ca_file, tls_verify_hostname, + timeout, command_timeout, wait_until_available, + server_settings +) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]: if command_timeout is not None: try: @@ -464,21 +692,20 @@ def parse_connect_arguments(*, dsn, host, port, user, password, 'expected greater than 0 float (got {!r})'.format( command_timeout)) from None - addrs, params = _parse_connect_dsn_and_args( - dsn=dsn, host=host, port=port, user=user, - password=password, admin=admin, - database=database, connect_timeout=timeout, + connect_config = _parse_connect_dsn_and_args( + dsn=dsn, credentials_file=credentials_file, host=host, port=port, + database=database, user=user, password=password, tls_ca_file=tls_ca_file, tls_verify_hostname=tls_verify_hostname, server_settings=server_settings, ) - config = ClientConfiguration( + client_config = ClientConfiguration( connect_timeout=timeout, command_timeout=command_timeout, wait_until_available=wait_until_available or 0, ) - return addrs, params, config + return connect_config, client_config def check_alpn_protocol(ssl_obj): diff --git a/tests/shared-client-testcases b/tests/shared-client-testcases new file mode 160000 index 00000000..0a8b0b7a --- /dev/null +++ b/tests/shared-client-testcases @@ -0,0 +1 @@ +Subproject commit 0a8b0b7ac96f2b13190ffaa9f6b1ab82e13a8af7 diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py index 16472b3f..680bcfd7 100644 --- a/tests/test_con_utils.py +++ b/tests/test_con_utils.py @@ -20,6 +20,7 @@ import contextlib import json import os +import sys import pathlib import tempfile import unittest @@ -32,315 +33,39 @@ class TestConUtils(unittest.TestCase): - TESTS = [ - { - 'user': 'user', - 'host': 'localhost', - 'result': ( - [("localhost", 5656)], - { - 'user': 'user', - 'database': 'edgedb', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'user', - 'EDGEDB_DATABASE': 'testdb', - 'EDGEDB_PASSWORD': 'passw', - 'EDGEDB_HOST': 'host', - 'EDGEDB_PORT': '123' - }, - 'result': ( - [('host', 123)], - { - 'user': 'user', - 'password': 'passw', - 'database': 'testdb' - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'user', - 'EDGEDB_DATABASE': 'testdb', - 'EDGEDB_PASSWORD': 'passw', - 'EDGEDB_HOST': 'host', - 'EDGEDB_PORT': '123' - }, - - 'host': 'host2', - 'port': '456', - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - - 'result': ( - [('host2', 456)], - { - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2' - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'user', - 'EDGEDB_DATABASE': 'testdb', - 'EDGEDB_PASSWORD': 'passw', - 'EDGEDB_HOST': 'host', - 'EDGEDB_PORT': '123', - 'PGSSLMODE': 'prefer' - }, - - 'dsn': 'edgedb://user3:123123@localhost/abcdef', - - 'host': 'host2', - 'port': '456', - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - 'server_settings': {'ssl': 'False'}, - - 'result': ( - [('host2', 456)], - { - 'user': 'user2', - 'password': 'passw2', - 'database': 'db2', - 'server_settings': {'ssl': 'False'}, - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'user', - 'EDGEDB_DATABASE': 'testdb', - 'EDGEDB_PASSWORD': 'passw', - 'EDGEDB_HOST': 'host', - 'EDGEDB_PORT': '123', - }, - - 'dsn': 'edgedb://user3:123123@localhost:5555/abcdef', - 'command_timeout': 10, - - 'result': ( - [('localhost', 5555)], - { - 'user': 'user3', - 'password': '123123', - 'database': 'abcdef', - }, { - 'command_timeout': 10, - 'wait_until_available': 30, - }) - }, - - { - 'dsn': 'edgedb://user3:123123@localhost:5555/abcdef', - 'result': ( - [('localhost', 5555)], - { - 'user': 'user3', - 'password': '123123', - 'database': 'abcdef' - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb://user@host1,host2/db', - 'result': ( - [('host1', 5656), ('host2', 5656)], - { - 'database': 'db', - 'user': 'user', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb://user@host1:1111,host2:2222/db', - 'result': ( - [('host1', 1111), ('host2', 2222)], - { - 'database': 'db', - 'user': 'user', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_HOST': 'host1:1111,host2:2222', - 'EDGEDB_USER': 'foo', - }, - 'dsn': 'edgedb:///db', - 'result': ( - [('host1', 1111), ('host2', 2222)], - { - 'database': 'db', - 'user': 'foo', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'foo', - }, - 'dsn': 'edgedb:///db?host=host1:1111,host2:2222', - 'result': ( - [('host1', 1111), ('host2', 2222)], - { - 'database': 'db', - 'user': 'foo', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'env': { - 'EDGEDB_USER': 'foo', - }, - 'dsn': 'edgedb:///db', - 'host': ['host1', 'host2'], - 'result': ( - [('host1', 5656), ('host2', 5656)], - { - 'database': 'db', - 'user': 'foo', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb://user3:123123@localhost:5555/' - 'abcdef?param=sss¶m=123&host=testhost&user=testuser' - '&port=2222&database=testdb', - 'host': '127.0.0.1', - 'port': '888', - 'user': 'me', - 'password': 'ask', - 'database': 'db', - 'result': ( - [('127.0.0.1', 888)], - { - 'server_settings': {'param': '123'}, - 'user': 'me', - 'password': 'ask', - 'database': 'db', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb://user3:123123@localhost:5555/' - 'abcdef?param=sss¶m=123&host=testhost&user=testuser' - '&port=2222&database=testdb', - 'host': '127.0.0.1', - 'port': '888', - 'user': 'me', - 'password': 'ask', - 'database': 'db', - 'server_settings': {'aa': 'bb'}, - 'result': ( - [('127.0.0.1', 888)], - { - 'server_settings': {'aa': 'bb', 'param': '123'}, - 'user': 'me', - 'password': 'ask', - 'database': 'db', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb:///dbname?host=/unix_sock/test&user=spam', - 'result': ( - [os.path.join('/unix_sock/test', '.s.EDGEDB.5656')], - { - 'user': 'spam', - 'database': 'dbname' - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', - 'error': ( - ValueError, - "dsn " - "'pq:///dbname\\?host=/unix_sock/test&user=spam' " - "is neither a edgedb:// URI nor valid instance name" - ) - }, - - { - 'dsn': 'edgedb://host1,host2,host3/db', - 'port': [111, 222], - 'error': ( - errors.InterfaceError, - 'could not match 2 port numbers to 3 hosts' - ) - }, - - { - 'dsn': 'edgedb://user@?port=56226&host=%2Ftmp', - 'result': ( - [os.path.join('/tmp', '.s.EDGEDB.56226')], - { - 'user': 'user', - 'database': 'edgedb', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedb://user@?host=%2Ftmp', - 'admin': True, - 'result': ( - [os.path.join('/tmp', '.s.EDGEDB.admin.5656')], - { - 'user': 'user', - 'database': 'edgedb', - }, - {'wait_until_available': 30}, - ) - }, - - { - 'dsn': 'edgedbadmin://user@?host=%2Ftmp', - 'result': ( - [os.path.join('/tmp', '.s.EDGEDB.admin.5656')], - { - 'user': 'user', - 'database': 'edgedb', - }, - {'wait_until_available': 30}, - ) - }, - ] + error_mapping = { + 'credentials_file_not_found': ( + RuntimeError, 'cannot read credentials'), + 'project_not_initialised': ( + errors.ClientConnectionError, + 'Found `edgedb.toml` but the project is not initialized'), + 'no_options_or_toml': ( + errors.ClientConnectionError, + 'no `edgedb.toml` found and no connection options specified'), + 'invalid_credentials_file': ( + RuntimeError, 'cannot read credentials'), + 'invalid_dsn_or_instance_name': ( + ValueError, 'invalid DSN or instance name'), + 'invalid_dsn': (ValueError, 'invalid DSN'), + 'unix_socket_unsupported': ( + ValueError, 'unix socket paths not supported'), + 'invalid_host': (ValueError, 'invalid host'), + 'invalid_port': (ValueError, 'invalid port'), + 'invalid_user': (ValueError, 'invalid user'), + 'invalid_database': (ValueError, 'invalid database'), + 'multiple_compound_env': ( + errors.ClientConnectionError, + 'Cannot have more than one of the following connection ' + + 'environment variables'), + 'multiple_compound_opts': ( + errors.ClientConnectionError, + 'Cannot have more than one of the following connection options'), + 'env_not_found': ( + ValueError, 'environment variable ".*" doesn\'t exist'), + 'file_not_found': (FileNotFoundError, 'No such file or directory'), + 'invalid_tls_verify_hostname': ( + ValueError, 'tls_verify_hostname can only be one of yes/no') + } @contextlib.contextmanager def environ(self, **kwargs): @@ -369,64 +94,133 @@ def run_testcase(self, testcase): env = testcase.get('env', {}) test_env = {'EDGEDB_HOST': None, 'EDGEDB_PORT': None, 'EDGEDB_USER': None, 'EDGEDB_PASSWORD': None, - 'EDGEDB_DATABASE': None, 'PGSSLMODE': None} + 'EDGEDB_DATABASE': None, 'PGSSLMODE': None, + 'XDG_CONFIG_HOME': None} test_env.update(env) - dsn = testcase.get('dsn') - user = testcase.get('user') - port = testcase.get('port') - host = testcase.get('host') - password = testcase.get('password') - database = testcase.get('database') - tls_ca_file = testcase.get('tls_ca_file') - tls_verify_hostname = testcase.get('tls_verify_hostname') - admin = testcase.get('admin') - timeout = testcase.get('timeout') - command_timeout = testcase.get('command_timeout') - server_settings = testcase.get('server_settings') - - expected = testcase.get('result') + fs = testcase.get('fs') + + opts = testcase.get('opts', {}) + dsn = opts.get('dsn') + credentials_file = opts.get('credentialsFile') + host = opts.get('host') + port = opts.get('port') + database = opts.get('database') + user = opts.get('user') + password = opts.get('password') + tls_ca_file = opts.get('tlsCAFile') + tls_verify_hostname = opts.get('tlsVerifyHostname') + server_settings = opts.get('serverSettings') + + other_opts = testcase.get('other_opts', {}) + timeout = other_opts.get('timeout') + command_timeout = other_opts.get('command_timeout') + + expected = (testcase.get('result'), testcase.get('other_results')) expected_error = testcase.get('error') - if expected is None and expected_error is None: + if expected_error and expected_error.get('type'): + expected_error = self.error_mapping.get(expected_error.get('type')) + if not expected_error: + raise RuntimeError( + f"unknown error type: {testcase.get('error').get('type')}") + + if expected[0] is None and expected_error is None: raise RuntimeError( 'invalid test case: either "result" or "error" key ' 'has to be specified') - if expected is not None and expected_error is not None: + if expected[0] is not None and expected_error is not None: raise RuntimeError( 'invalid test case: either "result" or "error" key ' 'has to be specified, got both') result = None with contextlib.ExitStack() as es: - es.enter_context(self.subTest(dsn=dsn, env=env)) + es.enter_context(self.subTest(dsn=dsn, env=env, opts=opts)) es.enter_context(self.environ(**test_env)) + stat_result = os.stat(os.getcwd()) + es.enter_context( + mock.patch('os.stat', lambda _: stat_result) + ) + + if fs: + cwd = fs.get('cwd') + homedir = fs.get('homedir') + files = fs.get('files') + + if cwd: + es.enter_context(mock.patch('os.getcwd', lambda: cwd)) + if homedir: + homedir = pathlib.Path(homedir) + es.enter_context( + mock.patch('pathlib.Path.home', lambda: homedir) + ) + if files: + es.enter_context( + mock.patch( + 'os.path.exists', + lambda filepath: str(filepath) in files + ) + ) + es.enter_context( + mock.patch( + 'os.path.isfile', + lambda filepath: str(filepath) in files + ) + ) + + es.enter_context( + mock.patch('os.path.realpath', lambda f: f) + ) + + def mocked_open(filepath, *args, **kwargs): + if str(filepath) in files: + return mock.mock_open( + read_data=files.get(str(filepath)) + )() + raise FileNotFoundError( + f"[Errno 2] No such file or directory: " + + f"'{filepath}'" + ) + es.enter_context(mock.patch('builtins.open', mocked_open)) + if expected_error: es.enter_context(self.assertRaisesRegex(*expected_error)) - addrs, params, config = con_utils.parse_connect_arguments( - dsn=dsn, host=host, port=port, user=user, password=password, - database=database, admin=admin, + connect_config, client_config = con_utils.parse_connect_arguments( + dsn=dsn, credentials_file=credentials_file, + host=host, port=port, database=database, + user=user, password=password, tls_ca_file=tls_ca_file, tls_verify_hostname=tls_verify_hostname, timeout=timeout, command_timeout=command_timeout, server_settings=server_settings, wait_until_available=30) - params = {k: v for k, v in params._asdict().items() - if v is not None} - config = {k: v for k, v in config._asdict().items() - if v is not None} - params.pop('ssl_ctx', None) - - result = (addrs, params, config) + result = ( + { + 'address': [ + connect_config.address[0], connect_config.address[1] + ], + 'database': connect_config.database, + 'user': connect_config.user, + 'password': connect_config.password, + 'tlsCAData': connect_config._tls_ca_data, + 'tlsVerifyHostname': connect_config.tls_verify_hostname, + 'serverSettings': connect_config.server_settings + }, { + k: v for k, v in client_config._asdict().items() + if v is not None + } if testcase.get('other_results') else None + ) - if expected is not None: - for k, v in expected[1].items(): - # If `expected` contains a type, allow that to "match" any - # instance of that type that `result` may contain. - if isinstance(v, type) and isinstance(result[1].get(k), v): - result[1][k] = v + if expected[0] is not None: + if (expected[1]): + for k, v in expected[1].items(): + # If `expected` contains a type, allow that to "match" any + # instance of that type that `result` may contain. + if isinstance(v, type) and isinstance(result[1].get(k), v): + result[1][k] = v self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) def test_test_connect_params_environ(self): @@ -460,18 +254,48 @@ def test_test_connect_params_run_testcase(self): with self.environ(EDGEDB_PORT='777'): self.run_testcase({ 'env': { - 'EDGEDB_USER': '__test__' + 'EDGEDB_HOST': 'abc' + }, + 'opts': { + 'user': '__test__', + }, + 'result': { + 'address': ['abc', 5656], + 'database': 'edgedb', + 'user': '__test__', + 'password': None, + 'tlsCAData': None, + 'tlsVerifyHostname': True, + 'serverSettings': {} + }, + 'other_results': { + 'wait_until_available': 30 }, - 'host': 'abc', - 'result': ( - [('abc', 5656)], - {'user': '__test__', 'database': 'edgedb'}, - {'wait_until_available': 30}, - ) }) def test_connect_params(self): - for testcase in self.TESTS: + testcases_path = os.path.abspath( + 'tests/shared-client-testcases/connection_testcases.json' + ) + try: + with open(testcases_path) as f: + testcases = json.load(f) + except FileNotFoundError as err: + raise FileNotFoundError( + f'Failed to read "connection_testcases.json": {err}.\n' + + f'Is the "shared-client-testcases" submodule initialised? ' + + f'Try running "git submodule update --init".' + ) + + for testcase in testcases: + platform = testcase.get('platform') + if testcase.get('fs') and ( + sys.platform == 'win32' or platform == 'windows' + or (platform is None and sys.platform == 'darwin') + or (platform == 'macos' and sys.platform != 'darwin') + ): + continue + self.run_testcase(testcase) @mock.patch("edgedb.platform.config_dir", @@ -517,15 +341,18 @@ def test_project_config(self): with open(instance_file, 'wt') as f: f.write('inst1') - addrs, params, _config = con_utils.parse_connect_arguments( - dsn=None, host=None, port=None, user=None, password=None, - database=None, admin=None, - tls_ca_file=None, tls_verify_hostname=None, - timeout=10, command_timeout=None, - server_settings=None, - wait_until_available=30) - - self.assertEqual(addrs, [('inst1.example.org', 12323)]) - self.assertEqual(params.user, 'inst1_user') - self.assertEqual(params.password, 'passw1') - self.assertEqual(params.database, 'inst1_db') + connect_config, client_config = ( + con_utils.parse_connect_arguments( + dsn=None, credentials_file=None, host=None, port=None, + user=None, password=None, database=None, + tls_ca_file=None, tls_verify_hostname=None, + timeout=10, command_timeout=None, + server_settings=None, + wait_until_available=30 + ) + ) + + self.assertEqual(connect_config.address, ('inst1.example.org', 12323)) + self.assertEqual(connect_config.user, 'inst1_user') + self.assertEqual(connect_config.password, 'passw1') + self.assertEqual(connect_config.database, 'inst1_db')