diff --git a/email_validator/__init__.py b/email_validator/__init__.py index c3b5929..0607680 100644 --- a/email_validator/__init__.py +++ b/email_validator/__init__.py @@ -3,13 +3,14 @@ # Export the main method, helper methods, and the public data types. from .exceptions_types import ValidatedEmail, EmailNotValidError, \ EmailSyntaxError, EmailUndeliverableError -from .validate_email import validate_email +from .validate_email import validate_email_sync as validate_email, validate_email_async from .version import __version__ -__all__ = ["validate_email", +__all__ = ["validate_email", "validate_email_async", "ValidatedEmail", "EmailNotValidError", "EmailSyntaxError", "EmailUndeliverableError", - "caching_resolver", "__version__"] + "caching_resolver", "caching_async_resolver", + "__version__"] def caching_resolver(*args, **kwargs): @@ -19,6 +20,13 @@ def caching_resolver(*args, **kwargs): return caching_resolver(*args, **kwargs) +def caching_async_resolver(*args, **kwargs): + # Lazy load `deliverability` as it is slow to import (due to dns.resolver) + from .deliverability import caching_async_resolver + + return caching_async_resolver(*args, **kwargs) + + # These global attributes are a part of the library's API and can be # changed by library users. diff --git a/email_validator/__main__.py b/email_validator/__main__.py index a414ff6..045938b 100644 --- a/email_validator/__main__.py +++ b/email_validator/__main__.py @@ -5,11 +5,19 @@ # python -m email_validator test@example.org # python -m email_validator < LIST_OF_ADDRESSES.TXT # -# Provide email addresses to validate either as a command-line argument -# or in STDIN separated by newlines. Validation errors will be printed for -# invalid email addresses. When passing an email address on the command -# line, if the email address is valid, information about it will be printed. -# When using STDIN, no output will be given for valid email addresses. +# Provide email addresses to validate either as a single command-line argument +# or on STDIN separated by newlines. +# +# When passing an email address on the command line, if the email address +# is valid, information about it will be printed to STDOUT. If the email +# address is invalid, an error message will be printed to STDOUT and +# the exit code will be set to 1. +# +# When passsing email addresses on STDIN, validation errors will be printed +# for invalid email addresses. No output is given for valid email addresses. +# Validation errors are preceded by the email address that failed and a tab +# character. It is the user's responsibility to ensure email addresses +# do not contain tab or newline characters. # # Keyword arguments to validate_email can be set in environment variables # of the same name but upprcase (see below). @@ -17,12 +25,65 @@ import json import os import sys +import itertools -from .validate_email import validate_email -from .deliverability import caching_resolver +from .deliverability import caching_async_resolver from .exceptions_types import EmailNotValidError +def main_command_line(email_address, options, dns_resolver): + # Validate the email address passed on the command line. + + from . import validate_email + + try: + result = validate_email(email_address, dns_resolver=dns_resolver, **options) + print(json.dumps(result.as_dict(), indent=2, sort_keys=True, ensure_ascii=False)) + return True + except EmailNotValidError as e: + print(e) + return False + + +async def main_stdin(options, dns_resolver): + # Validate the email addresses pased line-by-line on STDIN. + # Chunk the addresses and call the async version of validate_email + # for all the addresses in the chunk, and wait for the chunk + # to complete. + + import asyncio + + from . import validate_email_async as validate_email + + dns_resolver = dns_resolver or caching_async_resolver() + + # https://stackoverflow.com/a/312467 + def split_seq(iterable, size): + it = iter(iterable) + item = list(itertools.islice(it, size)) + while item: + yield item + item = list(itertools.islice(it, size)) + + CHUNK_SIZE = 100 + + async def process_line(line): + email = line.strip() + try: + await validate_email(email, dns_resolver=dns_resolver, **options) + # If the email was valid, do nothing. + return None + except EmailNotValidError as e: + return (email, e) + + for chunk in split_seq(sys.stdin, CHUNK_SIZE): + awaitables = [process_line(line) for line in chunk] + errors = await asyncio.gather(*awaitables) + for error in errors: + if error is not None: + print(*error, sep='\t') + + def main(dns_resolver=None): # The dns_resolver argument is for tests. @@ -36,24 +97,14 @@ def main(dns_resolver=None): if varname in os.environ: options[varname.lower()] = float(os.environ[varname]) - if len(sys.argv) == 1: - # Validate the email addresses pased line-by-line on STDIN. - dns_resolver = dns_resolver or caching_resolver() - for line in sys.stdin: - email = line.strip() - try: - validate_email(email, dns_resolver=dns_resolver, **options) - except EmailNotValidError as e: - print(f"{email} {e}") + if len(sys.argv) == 2: + return main_command_line(sys.argv[1], options, dns_resolver) else: - # Validate the email address passed on the command line. - email = sys.argv[1] - try: - result = validate_email(email, dns_resolver=dns_resolver, **options) - print(json.dumps(result.as_dict(), indent=2, sort_keys=True, ensure_ascii=False)) - except EmailNotValidError as e: - print(e) + import asyncio + asyncio.run(main_stdin(options, dns_resolver)) + return True if __name__ == "__main__": - main() + if not main(): + sys.exit(1) diff --git a/email_validator/deliverability.py b/email_validator/deliverability.py index 4846091..8f03428 100644 --- a/email_validator/deliverability.py +++ b/email_validator/deliverability.py @@ -3,6 +3,7 @@ from .exceptions_types import EmailUndeliverableError import dns.resolver +import dns.asyncresolver import dns.exception @@ -16,30 +17,72 @@ def caching_resolver(*, timeout: Optional[int] = None, cache=None): return resolver -def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Optional[int] = None, dns_resolver=None): +def caching_async_resolver(*, timeout: Optional[int] = None, cache=None): + if timeout is None: + from . import DEFAULT_TIMEOUT + timeout = DEFAULT_TIMEOUT + resolver = dns.asyncresolver.Resolver() + resolver.cache = cache or dns.resolver.LRUCache() # type: ignore + resolver.lifetime = timeout # type: ignore # timeout, in seconds + return resolver + + +async def validate_email_deliverability( + domain: str, + domain_i18n: str, + timeout: Optional[int] = None, + dns_resolver=None, + _async_loop: Optional[bool] = None): + # Check that the domain resolves to an MX record. If there is no MX record, # try an A or AAAA record which is a deprecated fallback for deliverability. # Raises an EmailUndeliverableError on failure. On success, returns a dict # with deliverability information. + # When _async_loop is None, this method must return synchronously. The + # caller drives the coroutine manually to get the result synchronously, + # and consequently this call must not yield execution. Otherwise, regular + # async/await calls may be used. + # If no dns.resolver.Resolver was given, get dnspython's default resolver. - # Override the default resolver's timeout. This may affect other uses of - # dnspython in this process. if dns_resolver is None: + if not _async_loop: + dns_resolver = dns.resolver.get_default_resolver() + else: + dns_resolver = dns.asyncresolver.get_default_resolver() + + # Override the default resolver's timeout. This may affect other uses of + # dnspython in this process. from . import DEFAULT_TIMEOUT if timeout is None: timeout = DEFAULT_TIMEOUT - dns_resolver = dns.resolver.get_default_resolver() dns_resolver.lifetime = timeout + elif timeout is not None: raise ValueError("It's not valid to pass both timeout and dns_resolver.") + # Define a resolve function that works with a regular or + # asynchronous dns.resolver.Resolver instance depending + # on the _async_loop argument. + async def resolve(qname, rtype): + # When called non-asynchronously, expect a regular + # resolver that returns synchronously. Or if the + # user didn't pass a dns.asyncresolver.Resolver, + # call it synchronously. + if not _async_loop or not isinstance(dns_resolver, dns.asyncresolver.Resolver): + return dns_resolver.resolve(qname, rtype) + + # When called asynchronsouly, if given a dns.asyncresolver.Resolver, + # call it asynchronously. + else: + return await dns_resolver.resolve(qname, rtype) + deliverability_info: Dict[str, Any] = {} try: try: # Try resolving for MX records (RFC 5321 Section 5). - response = dns_resolver.resolve(domain, "MX") + response = await resolve(domain, "MX") # For reporting, put them in priority order and remove the trailing dot in the qnames. mtas = sorted([(r.preference, str(r.exchange).rstrip('.')) for r in response]) @@ -59,7 +102,7 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option except dns.resolver.NoAnswer: # If there was no MX record, fall back to an A record. (RFC 5321 Section 5) try: - response = dns_resolver.resolve(domain, "A") + response = await resolve(domain, "A") deliverability_info["mx"] = [(0, str(r)) for r in response] deliverability_info["mx_fallback_type"] = "A" @@ -68,7 +111,7 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option # If there was no A record, fall back to an AAAA record. # (It's unclear if SMTP servers actually do this.) try: - response = dns_resolver.resolve(domain, "AAAA") + response = await resolve(domain, "AAAA") deliverability_info["mx"] = [(0, str(r)) for r in response] deliverability_info["mx_fallback_type"] = "AAAA" @@ -85,7 +128,7 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option # absence of an MX record, this is probably a good sign that the # domain is not used for email. try: - response = dns_resolver.resolve(domain, "TXT") + response = await resolve(domain, "TXT") for rec in response: value = b"".join(rec.strings) if value.startswith(b"v=spf1 "): diff --git a/email_validator/validate_email.py b/email_validator/validate_email.py index d2791fe..c9ce85d 100644 --- a/email_validator/validate_email.py +++ b/email_validator/validate_email.py @@ -1,3 +1,4 @@ +from asyncio import Future from typing import Optional, Union from .exceptions_types import EmailSyntaxError, ValidatedEmail @@ -6,6 +7,7 @@ def validate_email( + # NOTE: Arguments must match validate_email_async defined below. email: Union[str, bytes], /, # prior arguments are positional-only *, # subsequent arguments are keyword-only @@ -17,8 +19,9 @@ def validate_email( test_environment: Optional[bool] = None, globally_deliverable: Optional[bool] = None, timeout: Optional[int] = None, - dns_resolver: Optional[object] = None -) -> ValidatedEmail: + dns_resolver: Optional[object] = None, + _async_loop: Optional[object] = None +) -> Union[ValidatedEmail, Future]: # Future[ValidatedEmail] works in Python 3.10+ """ Given an email address, and some options, returns a ValidatedEmail instance with information about the address if it is valid or, if the address is not @@ -127,20 +130,151 @@ def validate_email( # Check the length of the address. validate_email_length(ret) - if check_deliverability and not test_environment: - # Validate the email address's deliverability using DNS - # and update the returned ValidatedEmail object with metadata. - - if is_domain_literal: - # There is nothing to check --- skip deliverability checks. + # If no deliverability checks will be performed, return the validation + # information immediately. + if not check_deliverability or is_domain_literal or test_environment: + # When called non-asynchronously, just return --- that's easy. + if not _async_loop: return ret - # Lazy load `deliverability` as it is slow to import (due to dns.resolver) - from .deliverability import validate_email_deliverability - deliverability_info = validate_email_deliverability( - ret.ascii_domain, ret.domain, timeout, dns_resolver + # When this method is called asynchronously, we must return an awaitable, + # not the regular return value. Normally 'async def' handles that for you, + # but to not duplicate this entire function in an asynchronous version, we + # have a single function that works both both ways, depending on if + # _async_loop is set. + # + # Wrap the ValidatedEmail object in a Future that is immediately + # done. If _async_loop holds a loop object, use it to create the Future. + # Otherwise create a default Future instance. + fut: Future + if _async_loop is True: + fut = Future() + elif not hasattr(_async_loop, 'create_future'): # suppress typing warning + raise RuntimeError("_async_loop parameter must have a create_future method.") + else: + fut = _async_loop.create_future() + fut.set_result(ret) + return fut + + # Validate the email address's deliverability using DNS + # and update the returned ValidatedEmail object with metadata. + # + # Domain literals are not DNS names so deliverability checks are + # skipped (above) if is_domain_literal is set. + + # Lazy load `deliverability` as it is slow to import (due to dns.resolver) + from .deliverability import validate_email_deliverability + + # Wrap validate_email_deliverability, which is an async function, in another + # async function that merges the resulting information with the ValidatedEmail + # instance. Since this method may be used in a non-asynchronous call, it + # must not await on anything that might yield execution. + async def run_deliverability_checks(): + # Run the DNS-based deliverabiltiy checks. + # + # Although validate_email_deliverability (and this local function) + # are async functions, when _async_loop is None it must not yield + # execution. See below. + info = await validate_email_deliverability( + ret.ascii_domain, ret.domain, timeout, dns_resolver, + _async_loop ) - for key, value in deliverability_info.items(): + + # Merge deliverability info with the syntax info (if there was no exception). + for key, value in info.items(): setattr(ret, key, value) + return ret + + if not _async_loop: + # When this function is called non-asynchronously, we will manually + # drive the coroutine returned by the async run_deliverability_checks + # function. Since we know that it does not yield execution, it will + # finish by raising StopIteration after the first 'send()' call. (If + # it doesn't, something serious went wrong.) + try: + # This call will either raise StopIteration on success or it will + # raise an EmailUndeliverableError on failure. + run_deliverability_checks().send(None) + + # If we come here, the coroutine yielded execution. We can't recover + # from this. + raise RuntimeError("Asynchronous resolver used in non-asychronous call or validate_email_deliverability mistakenly yielded.") + + except StopIteration as e: + # This is how a successful return occurs when driving a coroutine. + # The 'value' attribute on the exception holds the return value. + # Since we're in a non-asynchronous call, we can return it directly. + return e.value + + else: + # When this method is called asynchronously, return + # a coroutine. + return run_deliverability_checks() + + +# Validates an email address with DNS queries issued synchronously. +# This is exposed as the package's main validate_email method. +def validate_email_sync( + email: Union[str, bytes], + /, # prior arguments are positional-only + *, # subsequent arguments are keyword-only + allow_smtputf8: Optional[bool] = None, + allow_empty_local: bool = False, + allow_quoted_local: Optional[bool] = None, + allow_domain_literal: Optional[bool] = None, + check_deliverability: Optional[bool] = None, + test_environment: Optional[bool] = None, + globally_deliverable: Optional[bool] = None, + timeout: Optional[int] = None, + dns_resolver: Optional[object] = None +) -> ValidatedEmail: + ret = validate_email( + email, + allow_smtputf8=allow_smtputf8, + allow_empty_local=allow_empty_local, + allow_quoted_local=allow_quoted_local, + allow_domain_literal=allow_domain_literal, + check_deliverability=check_deliverability, + test_environment=test_environment, + globally_deliverable=globally_deliverable, + timeout=timeout, + dns_resolver=dns_resolver, + _async_loop=None) + if not isinstance(ret, ValidatedEmail): # suppress typing warning + raise RuntimeError(type(ret)) return ret + + +# Validates an email address with DNS queries issued asynchronously. +async def validate_email_async( + email: Union[str, bytes], + /, # prior arguments are positional-only + *, # subsequent arguments are keyword-only + allow_smtputf8: Optional[bool] = None, + allow_empty_local: bool = False, + allow_quoted_local: Optional[bool] = None, + allow_domain_literal: Optional[bool] = None, + check_deliverability: Optional[bool] = None, + test_environment: Optional[bool] = None, + globally_deliverable: Optional[bool] = None, + timeout: Optional[int] = None, + dns_resolver: Optional[object] = None, + loop: Optional[object] = None +) -> ValidatedEmail: + coro = validate_email( + email, + allow_smtputf8=allow_smtputf8, + allow_empty_local=allow_empty_local, + allow_quoted_local=allow_quoted_local, + allow_domain_literal=allow_domain_literal, + check_deliverability=check_deliverability, + test_environment=test_environment, + globally_deliverable=globally_deliverable, + timeout=timeout, + dns_resolver=dns_resolver, + _async_loop=loop or True) + import inspect + if not inspect.isawaitable(coro): # suppress typing warning + raise RuntimeError() + return await coro diff --git a/tests/mocked_dns_response.py b/tests/mocked_dns_response.py index cd32796..cccd047 100644 --- a/tests/mocked_dns_response.py +++ b/tests/mocked_dns_response.py @@ -3,7 +3,7 @@ import os.path import pytest -from email_validator.deliverability import caching_resolver +from email_validator.deliverability import caching_resolver, caching_async_resolver # To run deliverability checks without actually making # DNS queries, we use a caching resolver where the cache @@ -21,7 +21,7 @@ class MockedDnsResponseData: DATA_PATH = os.path.dirname(__file__) + "/mocked-dns-answers.json" @staticmethod - def create_resolver(): + def create_resolver(_async=False): if not hasattr(MockedDnsResponseData, 'INSTANCE'): # Create a singleton instance of this class and load the saved DNS responses. # Except when BUILD_MOCKED_DNS_RESPONSE_DATA is true, don't load the data. @@ -32,7 +32,10 @@ def create_resolver(): # Return a new dns.resolver.Resolver configured for caching # using the singleton instance. - return caching_resolver(cache=MockedDnsResponseData.INSTANCE) + if not _async: + return caching_resolver(cache=MockedDnsResponseData.INSTANCE) + else: + return caching_async_resolver(cache=MockedDnsResponseData.INSTANCE) def __init__(self): self.data = {} diff --git a/tests/test_deliverability.py b/tests/test_deliverability.py index 7431668..a5448fc 100644 --- a/tests/test_deliverability.py +++ b/tests/test_deliverability.py @@ -3,13 +3,21 @@ from email_validator import EmailUndeliverableError, \ validate_email, caching_resolver -from email_validator.deliverability import validate_email_deliverability +from email_validator.deliverability import validate_email_deliverability as validate_email_deliverability_async from mocked_dns_response import MockedDnsResponseData, MockedDnsResponseDataCleanup # noqa: F401 RESOLVER = MockedDnsResponseData.create_resolver() +def validate_email_deliverability(*args, **kwargs): + try: + validate_email_deliverability_async(*args, **kwargs).send(None) + raise RuntimeError("validate_email_deliverability did not run synchronously.") + except StopIteration as e: + return e.value + + def test_deliverability_found(): response = validate_email_deliverability('gmail.com', 'gmail.com', dns_resolver=RESOLVER) assert response.keys() == {'mx', 'mx_fallback_type'} diff --git a/tests/test_main.py b/tests/test_main.py index 579163f..25dc70a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6,7 +6,8 @@ from mocked_dns_response import MockedDnsResponseData, MockedDnsResponseDataCleanup # noqa: F401 -RESOLVER = MockedDnsResponseData.create_resolver() +SYNC_RESOLVER = MockedDnsResponseData.create_resolver(_async=False) +ASYNC_RESOLVER = MockedDnsResponseData.create_resolver(_async=True) def test_dict_accessor(): @@ -20,17 +21,17 @@ def test_main_single_good_input(monkeypatch, capsys): import json test_email = "google@google.com" monkeypatch.setattr('sys.argv', ['email_validator', test_email]) - validator_command_line_tool(dns_resolver=RESOLVER) + validator_command_line_tool(dns_resolver=SYNC_RESOLVER) stdout, _ = capsys.readouterr() output = json.loads(str(stdout)) assert isinstance(output, dict) - assert validate_email(test_email, dns_resolver=RESOLVER).original == output["original"] + assert validate_email(test_email, dns_resolver=SYNC_RESOLVER).original == output["original"] def test_main_single_bad_input(monkeypatch, capsys): bad_email = 'test@..com' monkeypatch.setattr('sys.argv', ['email_validator', bad_email]) - validator_command_line_tool(dns_resolver=RESOLVER) + validator_command_line_tool(dns_resolver=SYNC_RESOLVER) stdout, _ = capsys.readouterr() assert stdout == 'An email address cannot have a period immediately after the @-sign.\n' @@ -41,7 +42,7 @@ def test_main_multi_input(monkeypatch, capsys): test_input = io.StringIO("\n".join(test_cases)) monkeypatch.setattr('sys.stdin', test_input) monkeypatch.setattr('sys.argv', ['email_validator']) - validator_command_line_tool(dns_resolver=RESOLVER) + validator_command_line_tool(dns_resolver=ASYNC_RESOLVER) stdout, _ = capsys.readouterr() assert test_cases[0] not in stdout assert test_cases[1] not in stdout