Fix mypy providers #20190

Merged
merged 35 commits into from
Dec 15, 2021
Commits
a1c532e
[16185] Added LocalKubernetesExecutor to breeze supported executors
subkanthi Nov 22, 2021
6cbb506
Merge branch 'master' of github.com:subkanthi/airflow
subkanthi Nov 22, 2021
d16a343
Revert "[16185] Added LocalKubernetesExecutor to breeze supported exe…
subkanthi Nov 22, 2021
a336e17
Merge branch 'apache:main' into master
subkanthi Nov 30, 2021
382eedf
Merge branch 'apache:main' into master
subkanthi Dec 2, 2021
c3030fb
Merge branch 'apache:main' into master
subkanthi Dec 4, 2021
1c3c029
Merge branch 'apache:main' into master
subkanthi Dec 6, 2021
0eb6598
Merge branch 'apache:main' into master
subkanthi Dec 10, 2021
8ef0bcc
[19891] Fixed mypy in spark_sql
subkanthi Dec 10, 2021
c719166
Fixed mypy error in jira hook
subkanthi Dec 10, 2021
a503ef5
Fixed mypy in cassandra example_dag
subkanthi Dec 10, 2021
37e2fc6
Fixed mypy errors in asana hooks
subkanthi Dec 10, 2021
3c976c5
Fixed mypy errors in telegram provider
subkanthi Dec 10, 2021
3638cd0
Fixed mypy errors in providers/trino
subkanthi Dec 10, 2021
f8cba96
Removed adding arguments to cassandra example dag
subkanthi Dec 12, 2021
31ad8db
Removed default arguments in example_asana
subkanthi Dec 12, 2021
6a117e4
Fixed mypy errors in postgres and slack
subkanthi Dec 12, 2021
35a344e
Merge branch 'apache:main' into master
subkanthi Dec 12, 2021
0f60ae4
Fixed static checks in trino
subkanthi Dec 12, 2021
56b414c
Merge branch 'master' of github.com:subkanthi/airflow into fix_mypy_p…
subkanthi Dec 12, 2021
8d47486
Merge branch 'apache:main' into master
subkanthi Dec 12, 2021
b4daedc
Merge branch 'master' of github.com:subkanthi/airflow into fix_mypy_p…
subkanthi Dec 12, 2021
7a31cd1
Fixed unit tests in test_telegram
subkanthi Dec 12, 2021
e21a9a8
Removed unnecessary files
subkanthi Dec 12, 2021
1fa154a
Fixed mypy errors in providers/docker
subkanthi Dec 12, 2021
a997e30
Merge branch 'fix_mypy_providers' of github.com:subkanthi/airflow int…
subkanthi Dec 13, 2021
12c628e
Addressed PR review comments
subkanthi Dec 13, 2021
dcf537d
More mypy fixes
subkanthi Dec 13, 2021
f450679
Merge branch 'apache:main' into master
subkanthi Dec 14, 2021
96b7a9b
Fixed merge conflicts
subkanthi Dec 14, 2021
67c6f5e
Merge branch 'apache:main' into master
subkanthi Dec 14, 2021
dbfdcc5
Resolved conflicts and static checks
subkanthi Dec 14, 2021
a09d628
Merge branch 'apache:main' into master
subkanthi Dec 14, 2021
816b976
Merge branch 'master' of github.com:subkanthi/airflow into fix_mypy_p…
subkanthi Dec 14, 2021
fa68b40
Removed type:ignore in Docker decorator
subkanthi Dec 14, 2021
4 changes: 3 additions & 1 deletion airflow/providers/apache/spark/hooks/spark_sql.py
@@ -83,9 +83,11 @@ def __init__(
yarn_queue: Optional[str] = None,
) -> None:
super().__init__()
options: Dict = {}
conn: Optional[Connection] = None

try:
conn: "Optional[Connection]" = self.get_connection(conn_id)
Contributor

If the stringified typing can safely be removed then the TYPE_CHECKING statement can also be removed. You'll need to confirm though with some testing. MyPy should complain if it's an issue.

@kaxil @potiuk Another example of get_connection() in a hook's __init__() here. Oddly enough, this hook has a get_conn() method but it's a null method.

Member

I don’t think this is going to work at runtime. The stringified version is still needed.
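
A minimal sketch of the runtime behaviour this thread turns on. It uses a standard-library stand-in (`Fraction`) for `Connection`, and the function `load` is hypothetical. An import guarded by `TYPE_CHECKING` never runs, so the name does not exist at runtime. A quoted annotation is stored as a plain string, and according to the Python language reference, annotations on local variables inside a function body are not evaluated at all. The quoted form therefore mostly matters for annotations outside function bodies, such as signatures and module-level or class-level names.

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by mypy only; this import never runs at runtime.
    # `Fraction` is an assumed stand-in for airflow's Connection class.
    from fractions import Fraction


def load(raw: str) -> "Optional[Fraction]":
    # The quoted return annotation above is kept as a plain string at runtime.
    # The local-variable annotation below is never evaluated inside a
    # function body, so the unquoted name does not raise NameError here.
    value: Optional[Fraction] = None
    return value


def name_available_at_runtime() -> bool:
    # Reports whether the guarded import actually bound the name.
    return "Fraction" in globals()
```

Running mypy over the real hook remains the reliable way to confirm whether the quotes and the `TYPE_CHECKING` guard can be dropped there.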

Member

@potiuk potiuk Dec 11, 2021

Hm @josh-fell (and @kaxil) - something struck me when I looked more closely at #18339.

I think that while discussing it I made a silent assumption (which I now see was wrong) that the connection was created as part of the operator's __init__(), but this is about creating it in the Hook's __init__(), which IMHO is quite a legitimate use case (as long as you do not instantiate the Hook in the operator's __init__()). It is also a pretty common pattern in Airflow (and one we actually encourage).

I double-checked the Databricks code: the only place I could see the Hook being instantiated was _get_hook(), and the only places where _get_hook() is called were the execute() and on_kill() methods of the Databricks operator, so that all sounds pretty legitimate.

Wasn't the whole issue caused by a misunderstanding of whose __init__() it was? I think we have maaaaany cases where the Hook is created "on demand" in the execute() method of the operator. Also, if you decide to create the Hook inside @task-decorated functions, it should work really well.
Are my eyes cheating me? Or maybe I am missing something?
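
The pattern described in this comment can be sketched roughly as follows. `FakeHook` and `FakeOperator` are hypothetical classes, not Airflow's, and the counter stands in for a `get_connection()` call. The point is that constructing the operator does not touch the connection; the lookup happens only when `execute()` first asks for the hook.

```python
from typing import Any, Dict, Optional


class FakeHook:
    """Hypothetical hook that resolves its connection in __init__."""

    lookups = 0  # counts simulated get_connection() calls

    def __init__(self, conn_id: str) -> None:
        FakeHook.lookups += 1  # stands in for self.get_connection(conn_id)
        self.conn_id = conn_id


class FakeOperator:
    """Hypothetical operator that keeps only the conn_id until it runs."""

    def __init__(self, conn_id: str) -> None:
        self.conn_id = conn_id
        self._hook: Optional[FakeHook] = None

    def _get_hook(self) -> FakeHook:
        # Create the hook on demand and reuse it afterwards.
        if self._hook is None:
            self._hook = FakeHook(self.conn_id)
        return self._hook

    def execute(self, context: Dict[str, Any]) -> str:
        return self._get_hook().conn_id
```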

conn = self.get_connection(conn_id)
except AirflowNotFoundException:
conn = None
options: Dict = {}
6 changes: 3 additions & 3 deletions airflow/providers/asana/example_dags/example_asana.py
@@ -28,12 +28,12 @@
AsanaUpdateTaskOperator,
)

ASANA_TASK_TO_UPDATE = os.environ.get("ASANA_TASK_TO_UPDATE")
ASANA_TASK_TO_DELETE = os.environ.get("ASANA_TASK_TO_DELETE")
ASANA_TASK_TO_UPDATE = os.environ.get("ASANA_TASK_TO_UPDATE", "update_task")
ASANA_TASK_TO_DELETE = os.environ.get("ASANA_TASK_TO_DELETE", "delete_task")
# This example assumes a default project ID has been specified in the connection. If you
# provide a different id in ASANA_PROJECT_ID_OVERRIDE, it will override this default
# project ID in the AsanaFindTaskOperator example below
ASANA_PROJECT_ID_OVERRIDE = os.environ.get("ASANA_PROJECT_ID_OVERRIDE")
ASANA_PROJECT_ID_OVERRIDE = os.environ.get("ASANA_PROJECT_ID_OVERRIDE", "test_project")
# This connection should specify a personal access token and a default project ID
CONN_ID = os.environ.get("ASANA_CONNECTION_ID")
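
A small sketch of why a default value changes what mypy sees here. Without a default, `os.environ.get` is typed as returning `Optional[str]`; with a string default it returns `str`, which is what a parameter annotated as `str` expects. The variable name below is made up and assumed to be unset.

```python
import os
from typing import Optional

# Hypothetical variable name, assumed not to be set in the environment.
UNSET = "SKETCH_ASANA_TASK_THAT_IS_NOT_SET"

# No default: the result may be None, so the type is Optional[str].
maybe_task: Optional[str] = os.environ.get(UNSET)

# With a default: the result is always a str.
task: str = os.environ.get(UNSET, "update_task")
```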

15 changes: 8 additions & 7 deletions airflow/providers/asana/hooks/asana.py
@@ -17,14 +17,15 @@
# under the License.

"""Connect to Asana."""
from typing import Any, Dict
import sys
from typing import Any, Dict, Optional

from asana import Client
from asana.error import NotFoundError

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.hooks.base import BaseHook
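
A sketch of the import pattern this hunk switches to. mypy documents support for `sys.version_info` comparisons and only checks the branch that matches the target Python version, whereas a `try`/`except ImportError` pair gives it two competing definitions of the same name. `ClientHolder` is a hypothetical class used only to show the decorator working, and the `else` branch assumes the third-party `cached_property` backport is installed on older interpreters.

```python
import sys

if sys.version_info >= (3, 8):
    from functools import cached_property
else:
    # Backport package; only needed (and only imported) on Python < 3.8.
    from cached_property import cached_property


class ClientHolder:
    """Hypothetical class, not the Asana hook itself."""

    def __init__(self) -> None:
        self.builds = 0

    @cached_property
    def client(self) -> str:
        # Runs once; later attribute reads return the cached value.
        self.builds += 1
        return "client"
```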
@@ -84,7 +85,7 @@ def client(self) -> Client:

return Client.access_token(self.connection.password)

def create_task(self, task_name: str, params: dict) -> dict:
def create_task(self, task_name: str, params: Optional[dict]) -> dict:
"""
Creates an Asana task.

@@ -98,7 +99,7 @@ def create_task(self, task_name: str, params: dict) -> dict:
response = self.client.tasks.create(params=merged_params)
return response

def _merge_create_task_parameters(self, task_name: str, task_params: dict) -> dict:
def _merge_create_task_parameters(self, task_name: str, task_params: Optional[dict]) -> dict:
"""
Merge create_task parameters with default params from the connection.

@@ -144,7 +145,7 @@ def delete_task(self, task_id: str) -> dict:
self.log.info("Asana task %s not found for deletion.", task_id)
return {}

def find_task(self, params: dict) -> list:
def find_task(self, params: Optional[dict]) -> list:
"""
Retrieves a list of Asana tasks that match search parameters.

@@ -157,7 +158,7 @@ def find_task(self, params: dict) -> list:
response = self.client.tasks.find_all(params=merged_params)
return list(response)

def _merge_find_task_parameters(self, search_parameters: dict) -> dict:
def _merge_find_task_parameters(self, search_parameters: Optional[dict]) -> dict:
"""
Merge find_task parameters with default params from the connection.

2 changes: 1 addition & 1 deletion airflow/providers/jira/hooks/jira.py
@@ -42,7 +42,7 @@ def __init__(self, jira_conn_id: str = default_conn_name, proxies: Optional[Any]
super().__init__()
self.jira_conn_id = jira_conn_id
self.proxies = proxies
self.client = None
self.client: Optional[JIRA] = None
self.get_conn()

def get_conn(self) -> JIRA:
4 changes: 2 additions & 2 deletions airflow/providers/postgres/hooks/postgres.py
@@ -211,13 +211,13 @@ def get_iam_token(self, conn: Connection) -> Tuple[str, str, int]:
token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port

def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> List[str]:
def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]:
"""
Helper method that returns the table primary key

:param table: Name of the target table
:type table: str
:param table: Name of the target schema, public by default
:param schema: Name of the target schema, public by default
:type table: str
:return: Primary key columns list
:rtype: List[str]
8 changes: 4 additions & 4 deletions airflow/providers/slack/operators/slack.py
@@ -201,9 +201,9 @@ def __init__(
self,
channel: str = '#general',
initial_comment: str = 'No message has been set!',
filename: str = None,
filetype: str = None,
content: str = None,
filename: Optional[str] = None,
filetype: Optional[str] = None,
content: Optional[str] = None,
**kwargs,
) -> None:
self.method = 'files.upload'
@@ -212,7 +212,7 @@ def __init__(
self.filename = filename
self.filetype = filetype
self.content = content
self.file_params = {}
self.file_params: Dict = {}
super().__init__(method=self.method, **kwargs)

def execute(self, **kwargs):
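
A rough sketch of the annotation style this hunk adopts. Depending on its configuration, mypy rejects `filename: str = None` because `None` is not a `str`; spelling the type as `Optional[str]` makes the default legal and pushes callers to handle the `None` case explicitly. The helper `build_file_params` is hypothetical and is not part of the Slack operator.

```python
from typing import Dict, Optional


def build_file_params(
    filename: Optional[str] = None,
    content: Optional[str] = None,
) -> Dict[str, str]:
    """Hypothetical helper: collect only the arguments that were supplied."""
    params: Dict[str, str] = {}
    # The `is not None` checks narrow Optional[str] to str for mypy.
    if filename is not None:
        params["filename"] = filename
    if content is not None:
        params["content"] = content
    return params
```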
4 changes: 2 additions & 2 deletions airflow/providers/telegram/hooks/telegram.py
@@ -79,7 +79,7 @@ def get_conn(self) -> telegram.bot.Bot:
"""
return telegram.bot.Bot(token=self.token)

def __get_token(self, token: Optional[str], telegram_conn_id: str) -> str:
def __get_token(self, token: Optional[str], telegram_conn_id: Optional[str]) -> str:
"""
Returns the telegram API token

@@ -103,7 +103,7 @@ def __get_token(self, token: Optional[str], telegram_conn_id: str) -> str:

raise AirflowException("Cannot get token: No valid Telegram connection supplied.")

def __get_chat_id(self, chat_id: Optional[str], telegram_conn_id: str) -> Optional[str]:
def __get_chat_id(self, chat_id: Optional[str], telegram_conn_id: Optional[str]) -> Optional[str]:
"""
Returns the telegram chat ID for a chat/channel/group

4 changes: 2 additions & 2 deletions airflow/providers/telegram/operators/telegram.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""Operator for Telegram"""
from typing import Optional
from typing import Dict, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -70,7 +70,7 @@ def __init__(

super().__init__(**kwargs)

def execute(self, **kwargs) -> None:
def execute(self, context: Dict) -> None:
"""Calls the TelegramHook to post the provided Telegram message"""
if self.text:
self.telegram_kwargs['text'] = self.text
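
A sketch of why the `execute` signature matters beyond satisfying mypy. The classes below are hypothetical stand-ins, with `BaseOperatorSketch` assumed to mirror a base class whose `execute` takes a positional `context`. An override that accepts only `**kwargs` cannot receive that positional argument, so a caller written against the base signature fails at runtime.

```python
from typing import Any, Dict


class BaseOperatorSketch:
    """Hypothetical base class with a positional `context` parameter."""

    def execute(self, context: Dict[str, Any]) -> Any:
        raise NotImplementedError


class KwargsOnly(BaseOperatorSketch):
    # Incompatible override: it has no positional parameter for `context`.
    def execute(self, **kwargs: Any) -> None:  # type: ignore[override]
        return None


class Matching(BaseOperatorSketch):
    # Compatible override: same parameters as the base class.
    def execute(self, context: Dict[str, Any]) -> None:
        return None


def run(op: BaseOperatorSketch) -> str:
    """Call execute the way code written against the base class would."""
    try:
        op.execute({"ds": "2021-02-04"})
    except TypeError:
        return "TypeError"
    return "ok"
```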
7 changes: 1 addition & 6 deletions airflow/providers/trino/hooks/trino.py
@@ -138,12 +138,7 @@ def get_pandas_df(self, hql, parameters=None, **kwargs):
df = pandas.DataFrame(**kwargs)
return df

def run(
self,
hql,
autocommit: bool = False,
parameters: Optional[dict] = None,
) -> None:
def run(self, hql, autocommit: bool = False, parameters: Optional[dict] = None, handler=None) -> None:
"""Execute the statement against Trino. Can be used to create views."""
return super().run(sql=self._strip_sql(hql), parameters=parameters)

11 changes: 6 additions & 5 deletions tests/providers/telegram/operators/test_telegram.py
@@ -58,7 +58,7 @@ def test_should_send_message_when_all_parameters_are_provided(self, mock_telegra
task_id='telegram',
text="some non empty text",
)
hook.execute()
hook.execute(None)

mock_telegram_hook.assert_called_once_with(
telegram_conn_id='telegram_default',
@@ -89,7 +89,7 @@ def side_effect(*args, **kwargs):
task_id='telegram',
text="some non empty text",
)
hook.execute()
hook.execute(None)

assert "cosmic rays caused bit flips" == str(ctx.value)

@@ -105,7 +105,7 @@ def test_should_forward_all_args_to_telegram(self, mock_telegram_hook):
text="some non empty text",
telegram_kwargs={"custom_arg": "value"},
)
hook.execute()
hook.execute(None)

mock_telegram_hook.assert_called_once_with(
telegram_conn_id='telegram_default',
@@ -128,7 +128,7 @@ def test_should_give_precedence_to_text_passed_in_constructor(self, mock_telegra
text="some non empty text - higher precedence",
telegram_kwargs={"custom_arg": "value", "text": "some text, that will be ignored"},
)
hook.execute()
hook.execute(None)

mock_telegram_hook.assert_called_once_with(
telegram_conn_id='telegram_default',
@@ -159,7 +159,8 @@ def test_should_return_templatized_text_field(self, mock_hook):
telegram_kwargs={"custom_arg": "value", "text": "should be ignored"},
)
operator.render_template_fields({"ds": "2021-02-04"})
operator.execute()

operator.execute(None)
assert operator.text == "execution date is 2021-02-04"
assert 'text' in operator.telegram_kwargs
assert operator.telegram_kwargs['text'] == "execution date is 2021-02-04"