diff --git a/CHANGELOG.md b/CHANGELOG.md index c353f8f..076acf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ In Development -------------- +* The library now includes an asynchronous version of the main method named validate_email_async, which can be called with await, that runs DNS-based deliverability checks asynchronously. * A new option to parse `My Name ` strings, i.e. a display name plus an email address in angle brackets, is now available. It is off by default. * When a domain name has no MX record but does have an A or AAAA record, if none of the IP addresses in the response are globally reachable (i.e. not Private-Use, Loopback, etc.), the response is treated as if there was no A/AAAA response and the email address will fail the deliverability check. * When a domain name has no MX record but does have an A or AAAA record, the mx field in the object returned by validate_email incorrectly held the IP addresses rather than the domain itself. diff --git a/README.md b/README.md index 7b71ee4..71b08b6 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Key features: can display to end-users. * Checks deliverability (optional): Does the domain name resolve? (You can override the default DNS resolver to add query caching.) +* Can be called asynchronously with `await`. * Supports internationalized domain names (like `@ツ.life`), internationalized local parts (like `ツ@example.com`), and optionally parses display names (e.g. `"My Name" `). @@ -83,6 +84,9 @@ This validates the address and gives you its normalized form. You should checking if an address is in your database. When using this in a login form, set `check_deliverability` to `False` to avoid unnecessary DNS queries. +See below for examples for caching DNS queries and calling the library +asynchronously with `await`. 
+ Usage ----- @@ -163,6 +167,30 @@ while True: validate_email(email, dns_resolver=resolver) ``` +### Asynchronous call + +The library has an alternative, asynchronous method named `validate_email_async` which must be called with `await`. This method uses an [asynchronous DNS resolver](https://dnspython.readthedocs.io/en/latest/async.html) so that multiple DNS-based deliverability checks can be performed in parallel. + +Here is how to use it. In this example, `import ... as` is used to alias the async method to the usual method name `validate_email`. + +```python +from email_validator import validate_email_async as validate_email, \ + EmailNotValidError, caching_async_resolver + +resolver = caching_async_resolver(timeout=10) + +email = "my+address@example.org" +try: + emailinfo = await validate_email(email, dns_resolver=resolver) + email = emailinfo.normalized +except EmailNotValidError as e: + print(str(e)) +``` + +Note that to create a caching asynchronous resolver, use `caching_async_resolver`. As with the synchronous version, creating a resolver is optional. + +When processing batches of email addresses, I found that chunking around 25 email addresses at a time (using e.g. `asyncio.gather()`) resulted in the highest performance. I tested on a residential Internet connection with valid addresses. + ### Test addresses This library rejects email addresses that use the [Special Use Domain Names](https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml) `invalid`, `localhost`, `test`, and some others by raising `EmailSyntaxError`. This is to protect your system from abuse: You probably don't want a user to be able to cause an email to be sent to `localhost` (although they might be able to still do so via a malicious MX record). However, in your non-production test environments you may want to use `@test` or `@myname.test` email addresses. 
There are three ways you can allow this: diff --git a/email_validator/__init__.py b/email_validator/__init__.py index 626aa00..6f718a4 100644 --- a/email_validator/__init__.py +++ b/email_validator/__init__.py @@ -3,13 +3,14 @@ # Export the main method, helper methods, and the public data types. from .exceptions_types import ValidatedEmail, EmailNotValidError, \ EmailSyntaxError, EmailUndeliverableError -from .validate_email import validate_email +from .validate_email import validate_email_sync as validate_email, validate_email_async from .version import __version__ -__all__ = ["validate_email", +__all__ = ["validate_email", "validate_email_async", "ValidatedEmail", "EmailNotValidError", "EmailSyntaxError", "EmailUndeliverableError", - "caching_resolver", "__version__"] + "caching_resolver", "caching_async_resolver", + "__version__"] if TYPE_CHECKING: from .deliverability import caching_resolver @@ -21,6 +22,13 @@ def caching_resolver(*args, **kwargs): return caching_resolver(*args, **kwargs) +def caching_async_resolver(*args, **kwargs): + # Lazy load `deliverability` as it is slow to import (due to dns.resolver) + from .deliverability import caching_async_resolver + + return caching_async_resolver(*args, **kwargs) + + # These global attributes are a part of the library's API and can be # changed by library users. diff --git a/email_validator/__main__.py b/email_validator/__main__.py index 52791c7..fc7123a 100644 --- a/email_validator/__main__.py +++ b/email_validator/__main__.py @@ -5,26 +5,88 @@ # python -m email_validator test@example.org # python -m email_validator < LIST_OF_ADDRESSES.TXT # -# Provide email addresses to validate either as a command-line argument -# or in STDIN separated by newlines. Validation errors will be printed for -# invalid email addresses. When passing an email address on the command -# line, if the email address is valid, information about it will be printed. -# When using STDIN, no output will be given for valid email addresses. 
+# Provide email addresses to validate either as a single command-line argument +# or on STDIN separated by newlines. +# +# When passing an email address on the command line, if the email address +# is valid, information about it will be printed to STDOUT. If the email +# address is invalid, an error message will be printed to STDOUT and +# the exit code will be set to 1. +# +# When passing email addresses on STDIN, validation errors will be printed +# for invalid email addresses. No output is given for valid email addresses. +# Validation errors are preceded by the email address that failed and a tab +# character. It is the user's responsibility to ensure email addresses +# do not contain tab or newline characters. # # Keyword arguments to validate_email can be set in environment variables # of the same name but upprcase (see below). +import itertools import json import os import sys -from typing import Any, Dict, Optional +from typing import Any, Dict -from .validate_email import validate_email, _Resolver -from .deliverability import caching_resolver +from .deliverability import caching_async_resolver from .exceptions_types import EmailNotValidError -def main(dns_resolver: Optional[_Resolver] = None) -> None: +def main_command_line(email_address, options, dns_resolver): + # Validate the email address passed on the command line. + + from . import validate_email + + try: + result = validate_email(email_address, dns_resolver=dns_resolver, **options) + print(json.dumps(result.as_dict(), indent=2, sort_keys=True, ensure_ascii=False)) + return True + except EmailNotValidError as e: + print(e) + return False + + +async def main_stdin(options, dns_resolver): + # Validate the email addresses passed line-by-line on STDIN. + # Chunk the addresses and call the async version of validate_email + # for all the addresses in the chunk, and wait for the chunk + # to complete. + + import asyncio + + from . 
import validate_email_async as validate_email + + dns_resolver = dns_resolver or caching_async_resolver() + + # https://stackoverflow.com/a/312467 + def split_seq(iterable, size): + it = iter(iterable) + item = list(itertools.islice(it, size)) + while item: + yield item + item = list(itertools.islice(it, size)) + + CHUNK_SIZE = 25 + + async def process_line(line): + email = line.strip() + try: + await validate_email(email, dns_resolver=dns_resolver, **options) + # If the email was valid, do nothing. + return None + except EmailNotValidError as e: + return (email, e) + + chunks = split_seq(sys.stdin, CHUNK_SIZE) + for chunk in chunks: + awaitables = [process_line(line) for line in chunk] + errors = await asyncio.gather(*awaitables) + for error in errors: + if error is not None: + print(*error, sep='\t') + + +def main(dns_resolver=None): # The dns_resolver argument is for tests. # Set options from environment variables. @@ -37,24 +99,14 @@ def main(dns_resolver: Optional[_Resolver] = None) -> None: if varname in os.environ: options[varname.lower()] = float(os.environ[varname]) - if len(sys.argv) == 1: - # Validate the email addresses pased line-by-line on STDIN. - dns_resolver = dns_resolver or caching_resolver() - for line in sys.stdin: - email = line.strip() - try: - validate_email(email, dns_resolver=dns_resolver, **options) - except EmailNotValidError as e: - print(f"{email} {e}") + if len(sys.argv) == 2: + return main_command_line(sys.argv[1], options, dns_resolver) else: - # Validate the email address passed on the command line. 
- email = sys.argv[1] - try: - result = validate_email(email, dns_resolver=dns_resolver, **options) - print(json.dumps(result.as_dict(), indent=2, sort_keys=True, ensure_ascii=False)) - except EmailNotValidError as e: - print(e) + import asyncio + asyncio.run(main_stdin(options, dns_resolver)) + return True if __name__ == "__main__": - main() + if not main(): + sys.exit(1) diff --git a/email_validator/deliverability.py b/email_validator/deliverability.py index 90f5f9a..56ef0e8 100644 --- a/email_validator/deliverability.py +++ b/email_validator/deliverability.py @@ -5,6 +5,7 @@ from .exceptions_types import EmailUndeliverableError import dns.resolver +import dns.asyncresolver import dns.exception @@ -25,30 +26,73 @@ def caching_resolver(*, timeout: Optional[int] = None, cache: Any = None, dns_re }, total=False) -def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Optional[int] = None, dns_resolver: Optional[dns.resolver.Resolver] = None) -> DeliverabilityInfo: +def caching_async_resolver(*, timeout: Optional[int] = None, cache=None, dns_resolver=None): + if timeout is None: + from . import DEFAULT_TIMEOUT + timeout = DEFAULT_TIMEOUT + resolver = dns_resolver or dns.asyncresolver.Resolver() + resolver.cache = cache or dns.resolver.LRUCache() # type: ignore + resolver.lifetime = timeout # type: ignore # timeout, in seconds + return resolver + + +async def validate_email_deliverability( + domain: str, + domain_i18n: str, + timeout: Optional[int] = None, + dns_resolver: Optional[dns.resolver.Resolver] = None, + async_loop: Optional[bool] = None +) -> DeliverabilityInfo: # Check that the domain resolves to an MX record. If there is no MX record, # try an A or AAAA record which is a deprecated fallback for deliverability. # Raises an EmailUndeliverableError on failure. On success, returns a dict # with deliverability information. 
+ # When async_loop is None, the caller drives the coroutine manually to get + # the result synchronously, and consequently this call must not yield execution. + # It can use 'await' so long as the callee does not yield execution either. + # Otherwise, if async_loop is not None, there is no restriction on 'await' calls. + # If no dns.resolver.Resolver was given, get dnspython's default resolver. - # Override the default resolver's timeout. This may affect other uses of - # dnspython in this process. + # Use the asyncresolver if async_loop is not None. if dns_resolver is None: + if not async_loop: + dns_resolver = dns.resolver.get_default_resolver() + else: + dns_resolver = dns.asyncresolver.get_default_resolver() + + # Override the default resolver's timeout. This may affect other uses of + # dnspython in this process. from . import DEFAULT_TIMEOUT if timeout is None: timeout = DEFAULT_TIMEOUT - dns_resolver = dns.resolver.get_default_resolver() dns_resolver.lifetime = timeout + elif timeout is not None: raise ValueError("It's not valid to pass both timeout and dns_resolver.") - deliverability_info: DeliverabilityInfo = {} + # Define a resolve function that works with a regular or + # asynchronous dns.resolver.Resolver instance. + async def resolve(qname, rtype): + # When called non-asynchronously, expect a regular + # resolver that returns synchronously. Or if async_loop + # is not None but the caller didn't pass a + # dns.asyncresolver.Resolver, call it synchronously. + if not async_loop or not isinstance(dns_resolver, dns.asyncresolver.Resolver): + return dns_resolver.resolve(qname, rtype) + + # When async_loop is not None and if given a + # dns.asyncresolver.Resolver, call it asynchronously. + else: + return await dns_resolver.resolve(qname, rtype) + + # Collect successful deliverability information here. + deliverability_info = DeliverabilityInfo() try: try: # Try resolving for MX records (RFC 5321 Section 5). 
- response = dns_resolver.resolve(domain, "MX") + response = await resolve(domain, "MX") # For reporting, put them in priority order and remove the trailing dot in the qnames. mtas = sorted([(r.preference, str(r.exchange).rstrip('.')) for r in response]) @@ -84,11 +128,7 @@ def is_global_addr(address: Any) -> bool: return ipaddr.is_global try: - response = dns_resolver.resolve(domain, "A") - - if not any(is_global_addr(r.address) for r in response): - raise dns.resolver.NoAnswer # fall back to AAAA - + response = await resolve(domain, "A") deliverability_info["mx"] = [(0, domain)] deliverability_info["mx_fallback_type"] = "A" @@ -97,11 +137,7 @@ def is_global_addr(address: Any) -> bool: # If there was no A record, fall back to an AAAA record. # (It's unclear if SMTP servers actually do this.) try: - response = dns_resolver.resolve(domain, "AAAA") - - if not any(is_global_addr(r.address) for r in response): - raise dns.resolver.NoAnswer - + response = await resolve(domain, "AAAA") deliverability_info["mx"] = [(0, domain)] deliverability_info["mx_fallback_type"] = "AAAA" @@ -118,7 +154,7 @@ def is_global_addr(address: Any) -> bool: # absence of an MX record, this is probably a good sign that the # domain is not used for email. try: - response = dns_resolver.resolve(domain, "TXT") + response = await resolve(domain, "TXT") for rec in response: value = b"".join(rec.strings) if value.startswith(b"v=spf1 "): diff --git a/email_validator/validate_email.py b/email_validator/validate_email.py index 2adda2a..c448f3d 100644 --- a/email_validator/validate_email.py +++ b/email_validator/validate_email.py @@ -1,3 +1,4 @@ +from asyncio import Future from typing import Optional, Union, TYPE_CHECKING from .exceptions_types import EmailSyntaxError, ValidatedEmail @@ -11,7 +12,14 @@ _Resolver = object -def validate_email( +# This is the main function of the package. 
Through some magic, +# it can be called both non-asynchronously and, if async_loop +# is not None, also asynchronously with 'await'. If called +# asynchronously, dns_resolver may be an instance of +# dns.asyncresolver.Resolver. +def validate_email_sync_or_async( + # NOTE: Arguments other than async_loop must match + # validate_email_sync/async defined below. email: Union[str, bytes], /, # prior arguments are positional-only *, # subsequent arguments are keyword-only @@ -24,8 +32,9 @@ def validate_email( test_environment: Optional[bool] = None, globally_deliverable: Optional[bool] = None, timeout: Optional[int] = None, - dns_resolver: Optional[_Resolver] = None -) -> ValidatedEmail: + dns_resolver: Optional[_Resolver] = None, + async_loop: Optional[object] = None +) -> Union[ValidatedEmail, Future]: # Future[ValidatedEmail] works in Python 3.10+ """ Given an email address, and some options, returns a ValidatedEmail instance with information about the address if it is valid or, if the address is not @@ -139,22 +148,155 @@ def validate_email( # Check the length of the address. validate_email_length(ret) - if check_deliverability and not test_environment: - # Validate the email address's deliverability using DNS - # and update the returned ValidatedEmail object with metadata. - - if is_domain_literal: - # There is nothing to check --- skip deliverability checks. + # If no deliverability checks will be performed, return the validation + # information immediately. + if not check_deliverability or is_domain_literal or test_environment: + # When called non-asynchronously, just return --- that's easy. 
+ if not async_loop: return ret - # Lazy load `deliverability` as it is slow to import (due to dns.resolver) - from .deliverability import validate_email_deliverability - deliverability_info = validate_email_deliverability( - ret.ascii_domain, ret.domain, timeout, dns_resolver + # When this method is called asynchronously, we must return an awaitable, + # not the regular return value. Normally 'async def' handles that for you, + # but to not duplicate this entire function in an asynchronous version, we + # have a single function that works both ways, depending on whether + # async_loop is set. + # + # Wrap the ValidatedEmail object in a Future that is immediately + # done. If async_loop holds a loop object, use it to create the Future. + # Otherwise create a default Future instance. + fut: Future + if async_loop is True: + fut = Future() + elif not hasattr(async_loop, 'create_future'): # suppress typing warning + raise RuntimeError("async_loop parameter must have a create_future method.") + else: + fut = async_loop.create_future() + fut.set_result(ret) + return fut + + # Validate the email address's deliverability using DNS + # and update the returned ValidatedEmail object with metadata. + # + # Domain literals are not DNS names so deliverability checks are + # skipped (above) if is_domain_literal is set. + + # Lazy load `deliverability` as it is slow to import (due to dns.resolver) + from .deliverability import validate_email_deliverability + + # Wrap validate_email_deliverability, which is an async function, in another + # async function that merges the resulting information with the ValidatedEmail + # instance. Since this method may be used in a non-asynchronous call, it + # must not await on anything that might yield execution. + async def run_deliverability_checks(): + # Run the DNS-based deliverability checks. 
+ # + # Although validate_email_deliverability (and this local function) + # are async functions, when async_loop is None it must not yield + # execution. See below. + info = await validate_email_deliverability( + ret.ascii_domain, ret.domain, timeout, dns_resolver, + async_loop ) - mx = deliverability_info.get("mx") - if mx is not None: - ret.mx = mx - ret.mx_fallback_type = deliverability_info.get("mx_fallback_type") + # Merge deliverability info with the syntax info (if there was no exception). + for key, value in info.items(): + setattr(ret, key, value) + + return ret + + if not async_loop: + # When this function is called non-asynchronously, we will manually + # drive the coroutine returned by the async run_deliverability_checks + # function. Since we know that it does not yield execution, it will + # finish by raising StopIteration after the first 'send()' call. (If + # it doesn't, something serious went wrong.) + try: + # This call will either raise StopIteration on success or it will + # raise an EmailUndeliverableError on failure. + run_deliverability_checks().send(None) + + # If we come here, the coroutine yielded execution. We can't recover + # from this. + raise RuntimeError("Asynchronous resolver used in non-asychronous call or validate_email_deliverability mistakenly yielded.") + + except StopIteration as e: + # This is how a successful return occurs when driving a coroutine. + # The 'value' attribute on the exception holds the return value. + # Since we're in a non-asynchronous call, we can return it directly. + return e.value + + else: + # When this method is called asynchronously, return + # a coroutine. + return run_deliverability_checks() + + +# Validates an email address with DNS queries issued synchronously. +# This is exposed as the package's main validate_email method. 
+def validate_email_sync( + email: Union[str, bytes], + /, # prior arguments are positional-only + *, # subsequent arguments are keyword-only + allow_smtputf8: Optional[bool] = None, + allow_empty_local: bool = False, + allow_quoted_local: Optional[bool] = None, + allow_domain_literal: Optional[bool] = None, + allow_display_name: Optional[bool] = None, + check_deliverability: Optional[bool] = None, + test_environment: Optional[bool] = None, + globally_deliverable: Optional[bool] = None, + timeout: Optional[int] = None, + dns_resolver: Optional[object] = None +) -> ValidatedEmail: + ret = validate_email_sync_or_async( + email, + allow_smtputf8=allow_smtputf8, + allow_empty_local=allow_empty_local, + allow_quoted_local=allow_quoted_local, + allow_domain_literal=allow_domain_literal, + allow_display_name=allow_display_name, + check_deliverability=check_deliverability, + test_environment=test_environment, + globally_deliverable=globally_deliverable, + timeout=timeout, + dns_resolver=dns_resolver, + async_loop=None) + if not isinstance(ret, ValidatedEmail): # suppress typing warning + raise RuntimeError(type(ret)) return ret + + +# Validates an email address with DNS queries issued asynchronously. 
+async def validate_email_async( + email: Union[str, bytes], + /, # prior arguments are positional-only + *, # subsequent arguments are keyword-only + allow_smtputf8: Optional[bool] = None, + allow_empty_local: bool = False, + allow_quoted_local: Optional[bool] = None, + allow_domain_literal: Optional[bool] = None, + allow_display_name: Optional[bool] = None, + check_deliverability: Optional[bool] = None, + test_environment: Optional[bool] = None, + globally_deliverable: Optional[bool] = None, + timeout: Optional[int] = None, + dns_resolver: Optional[object] = None, + loop: Optional[object] = None +) -> ValidatedEmail: + coro = validate_email_sync_or_async( + email, + allow_smtputf8=allow_smtputf8, + allow_empty_local=allow_empty_local, + allow_quoted_local=allow_quoted_local, + allow_domain_literal=allow_domain_literal, + allow_display_name=allow_display_name, + check_deliverability=check_deliverability, + test_environment=test_environment, + globally_deliverable=globally_deliverable, + timeout=timeout, + dns_resolver=dns_resolver, + async_loop=loop or True) + import inspect + if not inspect.isawaitable(coro): # suppress typing warning + raise RuntimeError(type(coro)) + return await coro diff --git a/find_optimal_async_chunk_size.py b/find_optimal_async_chunk_size.py new file mode 100644 index 0000000..c31664d --- /dev/null +++ b/find_optimal_async_chunk_size.py @@ -0,0 +1,80 @@ +# Try different chunk sizes to find the optimal +# size for the fastest performance. +# +# Read in a list of email addresses on STDIN and +# draw from it random addresses for each call. + +import asyncio +import random +import sys +import time + +from email_validator import validate_email_async, EmailNotValidError, \ + caching_async_resolver + +async def wrap_validate_email(email, dns_resolver): + # Wrap validate_email_async to catch + # exceptions. 
+ try: + return await validate_email_async(email, dns_resolver=dns_resolver) + except EmailNotValidError as e: + return e + +async def is_valid(email, dns_resolver): + try: + await validate_email_async(email, dns_resolver=dns_resolver) + return True + except EmailNotValidError as e: + return False + +async def go(): + # Read in all of the test addresses from STDIN. + all_email_addreses = [line.strip() for line in sys.stdin.readlines()] + + # Sample the whole set and throw out addresses that are + # invalid. + resolver = caching_async_resolver(timeout=5) + all_email_addreses = random.sample(all_email_addreses, 10000) + all_email_addreses = [email for email in all_email_addreses + if is_valid(is_valid, resolver)] + + print("Starting...") + + # Start testing various chunk sizes. + for chunk_size in range(1, 200): + reps = max(1, int(15 / chunk_size)) + + # Draw a random sample of email addresses to use + # in this test. For low chunk sizes where we perform + # multiple reps, draw the samples for all of the + # reps ahead of time so that we don't time the + # sampling. + samples = [ + random.sample(all_email_addreses, chunk_size) + for _ in range(reps) + ] + + # Create a resolver with a short timeout. + # Use a caching resolver to better reflect real-world practice. + resolver = caching_async_resolver(timeout=5) + resolver.nameservers = ["8.8.8.8"] + + # Start timing. + t_start = time.time_ns() + + # Run the reps. + for i in range(reps): + # Run the chunk. + coros = [ + wrap_validate_email(email, dns_resolver=resolver) + for email in samples[i]] + await asyncio.gather(*coros) + + # End timing. 
+ t_end = time.time_ns() + + duration = t_end - t_start + + print(chunk_size, int(round(duration / (chunk_size * reps) / 1000)), sep='\t') + +asyncio.run(go()) diff --git a/pyproject.toml b/pyproject.toml index a92c08e..87ed5fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,4 @@ warn_unused_ignores = true markers = [ "network: marks tests as requiring Internet access", ] +asyncio_mode='auto' diff --git a/test_requirements.txt b/test_requirements.txt index d05813d..6fe5343 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,7 +1,7 @@ # This file was generated by running: # sudo docker run --rm -it --network=host python:3.8-slim /bin/bash # pip install dnspython idna # from setup.cfg -# pip install pytest pytest-cov coverage flake8 mypy +# pip install pytest pytest-asyncio pytest-cov coverage flake8 mypy # pip freeze # (Some packages' latest versions may not be compatible with # the earliest Python version we support, and some exception @@ -21,6 +21,7 @@ pluggy==1.4.0 pycodestyle==2.11.1 pyflakes==3.2.0 pytest==8.1.1 +pytest-asyncio==0.23.6 pytest-cov==5.0.0 tomli==2.0.1 typing_extensions==4.11.0 diff --git a/tests/mocked_dns_response.py b/tests/mocked_dns_response.py index c6db5cb..89356ed 100644 --- a/tests/mocked_dns_response.py +++ b/tests/mocked_dns_response.py @@ -7,7 +7,7 @@ import os.path import pytest -from email_validator.deliverability import caching_resolver +from email_validator.deliverability import caching_resolver, caching_async_resolver # To run deliverability checks without actually making # DNS queries, we use a caching resolver where the cache @@ -27,8 +27,8 @@ class MockedDnsResponseData: INSTANCE = None @staticmethod - def create_resolver() -> dns.resolver.Resolver: - if MockedDnsResponseData.INSTANCE is None: + def create_resolver(_async=False) -> dns.resolver.Resolver: + if not hasattr(MockedDnsResponseData, 'INSTANCE'): # Create a singleton instance of this class and load the saved DNS responses. 
# Except when BUILD_MOCKED_DNS_RESPONSE_DATA is true, don't load the data. singleton = MockedDnsResponseData() @@ -38,8 +38,12 @@ def create_resolver() -> dns.resolver.Resolver: # Return a new dns.resolver.Resolver configured for caching # using the singleton instance. - dns_resolver = dns.resolver.Resolver(configure=BUILD_MOCKED_DNS_RESPONSE_DATA) - return caching_resolver(cache=MockedDnsResponseData.INSTANCE, dns_resolver=dns_resolver) + if not _async: + dns_resolver = dns.resolver.Resolver(configure=BUILD_MOCKED_DNS_RESPONSE_DATA) + return caching_resolver(cache=MockedDnsResponseData.INSTANCE, dns_resolver=dns_resolver) + else: + dns_resolver = dns.asyncresolver.Resolver(configure=BUILD_MOCKED_DNS_RESPONSE_DATA) + return caching_async_resolver(cache=MockedDnsResponseData.INSTANCE, dns_resolver=dns_resolver) def __init__(self) -> None: self.data: Dict[dns.resolver.CacheKey, Optional[MockedDnsResponseData.Ans]] = {} diff --git a/tests/test_deliverability.py b/tests/test_deliverability.py index b65116b..cbe6c9c 100644 --- a/tests/test_deliverability.py +++ b/tests/test_deliverability.py @@ -5,13 +5,42 @@ from email_validator import EmailUndeliverableError, \ validate_email, caching_resolver -from email_validator.deliverability import validate_email_deliverability +from email_validator.deliverability import validate_email_deliverability as validate_email_deliverability_async from mocked_dns_response import MockedDnsResponseData, MockedDnsResponseDataCleanup # noqa: F401 RESOLVER = MockedDnsResponseData.create_resolver() +async def validate_email_deliverability(*args, **kwargs): + # The internal validate_email_deliverability method is + # asynchronous but has no awaits if not passed an + # async loop. To call it synchronously in tests, + # we can drive a manual loop. 
+ try: + validate_email_deliverability_async(*args, **kwargs).send(None) + raise RuntimeError("validate_email_deliverability did not run synchronously.") + except StopIteration as e: + sync_result = e.value + except Exception as e: + sync_result = e + + # Do the same thing again asynchronously. + try: + async_result = await validate_email_deliverability_async(*args, **kwargs) + except Exception as e: + async_result = e + + # Check that the results match. + # Not sure if repr() is really sufficient here. + assert repr(sync_result) == repr(async_result) + + # Return the synchronous result for the caller's asserts. + if isinstance(sync_result, Exception): + raise sync_result + return sync_result + + @pytest.mark.parametrize( 'domain,expected_response', [ @@ -19,8 +48,8 @@ ('pages.github.com', {'mx': [(0, 'pages.github.com')], 'mx_fallback_type': 'A'}), ], ) -def test_deliverability_found(domain: str, expected_response: str) -> None: - response = validate_email_deliverability(domain, domain, dns_resolver=RESOLVER) +async def test_deliverability_found(domain: str, expected_response: str) -> None: + response = await validate_email_deliverability(domain, domain, dns_resolver=RESOLVER) assert response == expected_response @@ -59,14 +88,14 @@ def test_email_example_reserved_domain(email_input: str) -> None: assert re.match(r"The domain name [a-z\.]+ does not (accept email|exist)\.", str(exc_info.value)) is not None -def test_deliverability_dns_timeout() -> None: - response = validate_email_deliverability('timeout.com', 'timeout.com', dns_resolver=RESOLVER) +async def test_deliverability_dns_timeout() -> None: + response = await validate_email_deliverability('timeout.com', 'timeout.com', dns_resolver=RESOLVER) assert "mx" not in response assert response.get("unknown-deliverability") == "timeout" @pytest.mark.network -def test_caching_dns_resolver() -> None: +async def test_caching_dns_resolver() -> None: class TestCache: def __init__(self) -> None: self.cache: Dict[Any, Any] 
= {} diff --git a/tests/test_main.py b/tests/test_main.py index ab8eecd..0606cba 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -41,7 +41,8 @@ def test_main_multi_input(monkeypatch: pytest.MonkeyPatch, capsys: pytest.Captur test_input = io.StringIO("\n".join(test_cases)) monkeypatch.setattr('sys.stdin', test_input) monkeypatch.setattr('sys.argv', ['email_validator']) - validator_command_line_tool(dns_resolver=RESOLVER) + ASYNC_RESOLVER = MockedDnsResponseData.create_resolver(_async=True) + validator_command_line_tool(dns_resolver=ASYNC_RESOLVER) stdout, _ = capsys.readouterr() assert test_cases[0] not in stdout assert test_cases[1] not in stdout