Bump typing-extensions and mypy for ParamSpec #25088

Merged · 11 commits · Jul 18, 2022
4 changes: 2 additions & 2 deletions airflow/jobs/scheduler_job.py
@@ -175,7 +175,7 @@ def register_signals(self) -> None:
signal.signal(signal.SIGTERM, self._exit_gracefully)
signal.signal(signal.SIGUSR2, self._debug_dump)

-def _exit_gracefully(self, signum: int, frame: "FrameType") -> None:
+def _exit_gracefully(self, signum: int, frame: Optional["FrameType"]) -> None:
"""Helper method to clean up processor_agent to avoid leaving orphan processes."""
if not _is_parent_process():
# Only the parent process should perform the cleanup.
@@ -186,7 +186,7 @@ def _exit_gracefully(self, signum: int, frame: "FrameType") -> None:
self.processor_agent.end()
sys.exit(os.EX_OK)

-def _debug_dump(self, signum: int, frame: "FrameType") -> None:
+def _debug_dump(self, signum: int, frame: Optional["FrameType"]) -> None:
if not _is_parent_process():
# Only the parent process should perform the debug dump.
return
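The Optional is needed because typeshed (as bundled with newer mypy) types signal handlers as accepting Optional[FrameType] for the frame argument. A minimal sketch of a conforming handler, assuming only the standard library:

import signal
import sys
from types import FrameType
from typing import Optional

def handle_sigterm(signum: int, frame: Optional[FrameType]) -> None:
    # The frame argument can be None, so the annotation must allow it
    # for the handler to match signal.signal's expected callable type.
    print(f"received signal {signum}", file=sys.stderr)
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_sigterm)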
5 changes: 4 additions & 1 deletion airflow/mypy/plugin/decorators.py
@@ -68,7 +68,10 @@ def _change_decorator_function_type(
# Mark provided arguments as optional
decorator.arg_types = copy.copy(decorated.arg_types)
for argument in provided_arguments:
-index = decorated.arg_names.index(argument)
+try:
+    index = decorated.arg_names.index(argument)
+except ValueError:
+    continue
decorated_type = decorated.arg_types[index]
decorator.arg_types[index] = UnionType.make_union([decorated_type, NoneType()])
decorated.arg_kinds[index] = ARG_NAMED_OPT
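list.index raises ValueError when the name is absent (for example, arguments captured by *args/**kwargs appear in mypy's arg_names as None rather than by name), so the plugin now skips those instead of crashing. A quick illustration of the pattern with plain lists:

arg_names = ["self", "context", None]  # star-arg entries have no name

for argument in ("context", "jinja_env"):
    try:
        index = arg_names.index(argument)
    except ValueError:  # "jinja_env" is absent, so .index() raises
        continue
    print(f"{argument} found at position {index}")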
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -114,6 +114,7 @@ def execute(self, context: 'Context') -> None:

scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
+f: IO[Any]
with NamedTemporaryFile() as f:
try:
f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
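The explicit IO[Any] annotation exists because f is later rebound to the return value of _scan_dynamodb_and_upload_to_s3, and mypy wants one declared type covering both bindings. A stripped-down sketch of the pattern (the transform helper here is a hypothetical stand-in):

from tempfile import NamedTemporaryFile
from typing import IO, Any

def transform(handle: IO[Any]) -> IO[Any]:
    # Hypothetical stand-in for _scan_dynamodb_and_upload_to_s3.
    return handle

f: IO[Any]  # declared once, so both bindings below share one type
with NamedTemporaryFile() as f:
    f = transform(f)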
19 changes: 11 additions & 8 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.

+import enum
from collections import namedtuple
-from enum import Enum
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

@@ -35,10 +35,13 @@
from airflow.utils.context import Context


-FILE_FORMAT = Enum(
-    "FILE_FORMAT",
-    "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+    """Possible file formats."""
+
+    CSV = enum.auto()
+    JSON = enum.auto()
+    PARQUET = enum.auto()


FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])

@@ -118,9 +121,9 @@ def __init__(
if "path_or_buf" in self.pd_kwargs:
raise AirflowException('The argument path_or_buf is not allowed, please remove it')

-self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
-
-if self.file_format is None:
+try:
+    self.file_format = FILE_FORMAT[file_format.upper()]
+except KeyError:
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

@staticmethod
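Replacing the functional Enum(...) constructor with a class gives mypy a real type whose members it can check, and FILE_FORMAT[name] lookups raise KeyError for unknown names instead of silently returning None via getattr. A self-contained sketch of the same pattern:

import enum

class Color(enum.Enum):
    """Toy enum mirroring the FILE_FORMAT change."""

    RED = enum.auto()
    GREEN = enum.auto()

def parse(name: str) -> Color:
    try:
        return Color[name.upper()]  # lookup by member name, not value
    except KeyError:
        raise ValueError(f"unsupported color: {name}") from None

print(parse("red"))  # Color.RED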
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/cloud_sql.py
@@ -37,7 +37,7 @@
SETTINGS = 'settings'
SETTINGS_VERSION = 'settingsVersion'

-CLOUD_SQL_CREATE_VALIDATION = [
+CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [
dict(name="name", allow_empty=False),
dict(
name="settings",
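The annotation presumably keeps mypy from inferring an overly narrow element type from the first dict literals and then rejecting later, differently-shaped entries. A minimal sketch under that assumption (VALIDATION and its entries are illustrative, not the real validation spec):

from typing import Sequence

# Sequence[dict] leaves the element type loose, so entries may mix
# value types (bools, strings, nested lists) without mypy complaints.
VALIDATION: Sequence[dict] = [
    dict(name="name", allow_empty=False),
    dict(name="settings", type="dict"),
]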
62 changes: 38 additions & 24 deletions airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -23,6 +23,7 @@
login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
the default database and collection to use (see connection `azure_cosmos_default` for an example).
"""
+import json
import uuid
from typing import Any, Dict, Optional

@@ -140,14 +141,22 @@ def does_collection_exist(self, collection_name: str, database_name: str) -> bool:
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
.query_containers(
"SELECT * FROM r WHERE r.id=@id",
parameters=[json.dumps({"name": "@id", "value": collection_name})],
)
)
if len(existing_container) == 0:
return False

return True

-def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
+def create_collection(
+    self,
+    collection_name: str,
+    database_name: Optional[str] = None,
+    partition_key: Optional[str] = None,
+) -> None:
"""Creates a new collection in the CosmosDB database."""
if collection_name is None:
raise AirflowBadRequest("Collection name cannot be None.")
@@ -157,13 +166,16 @@ def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
.query_containers(
"SELECT * FROM r WHERE r.id=@id",
parameters=[json.dumps({"name": "@id", "value": collection_name})],
)
)

# Only create if we did not find it already existing
if len(existing_container) == 0:
self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
-collection_name
+collection_name, partition_key=partition_key
)

def does_database_exist(self, database_name: str) -> bool:
@@ -173,10 +185,8 @@ def does_database_exist(self, database_name: str) -> bool:

existing_database = list(
self.get_conn().query_databases(
-{
-    "query": "SELECT * FROM r WHERE r.id=@id",
-    "parameters": [{"name": "@id", "value": database_name}],
-}
+"SELECT * FROM r WHERE r.id=@id",
+parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)
if len(existing_database) == 0:
@@ -193,10 +203,8 @@ def create_database(self, database_name: str) -> None:
# to create it twice
existing_database = list(
self.get_conn().query_databases(
-{
-    "query": "SELECT * FROM r WHERE r.id=@id",
-    "parameters": [{"name": "@id", "value": database_name}],
-}
+"SELECT * FROM r WHERE r.id=@id",
+parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)

@@ -267,18 +275,28 @@ def insert_documents(
return created_documents

def delete_document(
-self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+self,
+document_id: str,
+database_name: Optional[str] = None,
+collection_name: Optional[str] = None,
+partition_key: Optional[str] = None,
) -> None:
"""Delete an existing document out of a collection in the CosmosDB database."""
if document_id is None:
raise AirflowBadRequest("Cannot delete a document without an id")

-self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
-    self.__get_collection_name(collection_name)
-).delete_item(document_id)
+(
+    self.get_conn()
+    .get_database_client(self.__get_database_name(database_name))
+    .get_container_client(self.__get_collection_name(collection_name))
+    .delete_item(document_id, partition_key=partition_key)
+)

def get_document(
-self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+self,
+document_id: str,
+database_name: Optional[str] = None,
+collection_name: Optional[str] = None,
+partition_key: Optional[str] = None,
):
"""Get a document from an existing collection in the CosmosDB database."""
if document_id is None:
Expand All @@ -289,7 +307,7 @@ def get_document(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
-.read_item(document_id)
+.read_item(document_id, partition_key=partition_key)
)
except CosmosHttpResponseError:
return None
@@ -305,17 +323,13 @@ def get_documents(
if sql_string is None:
raise AirflowBadRequest("SQL query string cannot be None")

-# Query them in SQL
-query = {'query': sql_string}

try:
result_iterable = (
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
-.query_items(query, partition_key)
+.query_items(sql_string, partition_key=partition_key)
)

return list(result_iterable)
except CosmosHttpResponseError:
return None
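These changes track the azure-cosmos 4.x API, where queries take the SQL string plus a parameters keyword and item-level calls require an explicit partition key. A hedged sketch of the same calls against the SDK directly (endpoint, key, and names below are placeholders):

from azure.cosmos import CosmosClient

client = CosmosClient("https://example.documents.azure.com:443/", credential="<master-key>")
container = client.get_database_client("mydb").get_container_client("mycoll")

# Parameterized query; the SDK documents parameters as {"name", "value"} dicts.
items = container.query_items(
    "SELECT * FROM r WHERE r.id=@id",
    parameters=[{"name": "@id", "value": "doc-1"}],
    enable_cross_partition_query=True,
)

# Point reads and deletes take an explicit partition key in the 4.x API.
doc = container.read_item("doc-1", partition_key="doc-1")
container.delete_item("doc-1", partition_key="doc-1")

The diff wraps each parameter in json.dumps, presumably to satisfy the type hints shipped with the pinned SDK version; the plain dict form above is what the SDK documentation shows.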
6 changes: 3 additions & 3 deletions airflow/utils/context.py
@@ -23,12 +23,12 @@
import functools
import warnings
from typing import (
-AbstractSet,
Any,
Container,
Dict,
ItemsView,
Iterator,
+KeysView,
List,
Mapping,
MutableMapping,
@@ -175,7 +175,7 @@ class Context(MutableMapping[str, Any]):
}

def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
-self._context = context or {}
+self._context: MutableMapping[str, Any] = context or {}
if kwargs:
self._context.update(kwargs)
self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
@@ -231,7 +231,7 @@ def __ne__(self, other: Any) -> bool:
return NotImplemented
return self._context != other._context

-def keys(self) -> AbstractSet[str]:
+def keys(self) -> KeysView[str]:
return self._context.keys()

def items(self):
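Newer typeshed stubs declare Mapping.keys() as returning KeysView[str] rather than the looser AbstractSet[str], and an override's return type must stay a subtype of what it overrides, so the signature had to tighten. A minimal sketch, assuming only the standard library:

from typing import Any, Dict, KeysView

class Wrapper:
    def __init__(self) -> None:
        self._data: Dict[str, Any] = {"a": 1}

    def keys(self) -> KeysView[str]:
        # dict.keys() returns a KeysView, which is itself set-like.
        return self._data.keys()

print(list(Wrapper().keys()))  # ['a']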
8 changes: 4 additions & 4 deletions dev/breeze/src/airflow_breeze/commands/testing_commands.py
@@ -197,9 +197,9 @@ def run_with_progress(
) -> RunCommandResult:
title = f"Running tests: {test_type}, Python: {python}, Backend: {backend}:{version}"
try:
-with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as f:
+with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
get_console().print(f"[info]Starting test = {title}[/]")
-thread = MonitoringThread(title=title, file_name=f.name)
+thread = MonitoringThread(title=title, file_name=tf.name)
thread.start()
try:
result = run_command(
Expand All @@ -208,14 +208,14 @@ def run_with_progress(
dry_run=dry_run,
env=env_variables,
check=False,
-stdout=f,
+stdout=tf,
stderr=subprocess.STDOUT,
)
finally:
thread.stop()
thread.join()
with ci_group(f"Result of {title}", message_type=message_type_from_return_code(result.returncode)):
-with open(f.name) as f:
+with open(tf.name) as f:
shutil.copyfileobj(f, sys.stdout)
finally:
os.unlink(f.name)
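Renaming the temp-file handle to tf avoids rebinding f to two different types (a NamedTemporaryFile wrapper, then a plain text file), which mypy flags as an incompatible redefinition. A small sketch of the same shape:

import shutil
import sys
import tempfile

with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
    tf.write("captured output\n")

# A fresh name for the second handle keeps each variable's type stable.
with open(tf.name) as f:
    shutil.copyfileobj(f, sys.stdout)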
1 change: 1 addition & 0 deletions scripts/in_container/run_migration_reference.py
@@ -102,6 +102,7 @@ def revision_suffix(rev: "Script"):

def ensure_airflow_version(revisions: Iterable["Script"]):
for rev in revisions:
+assert rev.module.__file__ is not None  # For Mypy.
file = Path(rev.module.__file__)
content = file.read_text()
if not has_version(content):
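In newer typeshed stubs, module.__file__ is Optional[str], so the assert narrows it to str before Path() is called. The same narrowing pattern in isolation:

from pathlib import Path
from typing import Optional

def read_source(file: Optional[str]) -> str:
    assert file is not None  # narrows Optional[str] to str for mypy
    return Path(file).read_text()

print(read_source(__file__)[:40])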
2 changes: 1 addition & 1 deletion setup.cfg
@@ -147,7 +147,7 @@ install_requires =
tabulate>=0.7.5
tenacity>=6.2.0
termcolor>=1.1.0
-typing-extensions>=3.7.4
+typing-extensions>=4.0.0
unicodecsv>=0.14.1
werkzeug>=2.0

2 changes: 1 addition & 1 deletion setup.py
@@ -325,7 +325,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_version")):
# mypyd which does not support installing the types dynamically with --install-types
mypy_dependencies = [
# TODO: upgrade to newer versions of MyPy continuously as they are released
-'mypy==0.910',
+'mypy==0.950',
'types-boto',
'types-certifi',
'types-croniter',
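These two version bumps are what enable ParamSpec: typing_extensions 4.x ships it for pre-3.10 Pythons, and mypy 0.950 can check it. A hedged sketch of the decorator-typing pattern this unlocks:

from typing import Callable, TypeVar
from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def logged(func: Callable[P, R]) -> Callable[P, R]:
    # ParamSpec carries the wrapped function's full signature through the
    # decorator, so call sites of the decorated function stay type-checked.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

@logged
def add(x: int, y: int) -> int:
    return x + y

print(add(1, 2))  # OK; add("a", "b") would be flagged by mypy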
8 changes: 6 additions & 2 deletions tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -91,7 +91,9 @@ def test_create_container(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name, self.test_database_name)
expected_calls = [
-mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+mock.call()
+.get_database_client('test_database_name')
+.create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
@@ -101,7 +103,9 @@ def test_create_container_default(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name)
expected_calls = [
-mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+mock.call()
+.get_database_client('test_database_name')
+.create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
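assert_has_calls matches the exact recorded signature, including new keyword arguments, which is why the expected calls gain partition_key=None once the hook passes it through. A small sketch of asserting a chained call with unittest.mock (names here are illustrative):

from unittest import mock

client_cls = mock.MagicMock()
client_cls().get_database_client('db').create_container('coll', partition_key=None)

# mock.call() mirrors the recorded chain: instantiate, then each method in turn.
client_cls.assert_has_calls(
    [mock.call().get_database_client('db').create_container('coll', partition_key=None)]
)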