Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strict type check for multiple providers #11229

Merged
merged 10 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions airflow/providers/dingding/hooks/dingding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
# under the License.

import json
from typing import Union, Optional, List

import requests
from requests import Session

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
Expand Down Expand Up @@ -49,20 +51,20 @@ class DingdingHook(HttpHook):
def __init__(
self,
dingding_conn_id='dingding_default',
message_type='text',
message=None,
at_mobiles=None,
at_all=False,
message_type: str = 'text',
message: Optional[Union[str, dict]] = None,
at_mobiles: Optional[List[str]] = None,
at_all: bool = False,
*args,
**kwargs,
):
super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs)
) -> None:
super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs) # type: ignore[misc]
self.message_type = message_type
self.message = message
self.at_mobiles = at_mobiles
self.at_all = at_all

def _get_endpoint(self):
def _get_endpoint(self) -> str:
"""
Get Dingding endpoint for sending message.
"""
Expand All @@ -74,7 +76,7 @@ def _get_endpoint(self):
)
return 'robot/send?access_token={}'.format(token)

def _build_message(self):
def _build_message(self) -> str:
"""
Build different type of Dingding message
As most commonly used type, text message just need post message content
Expand All @@ -90,7 +92,7 @@ def _build_message(self):
data = {'msgtype': self.message_type, self.message_type: self.message}
return json.dumps(data)

def get_conn(self, headers=None):
def get_conn(self, headers: Optional[dict] = None) -> Session:
"""
Overwrite HttpHook get_conn because just need base_url and headers and
not don't need generic params
Expand All @@ -105,7 +107,7 @@ def get_conn(self, headers=None):
session.headers.update(headers)
return session

def send(self):
def send(self) -> None:
"""
Send Dingding message
"""
Expand Down
15 changes: 8 additions & 7 deletions airflow/providers/dingding/operators/dingding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union, Optional, List

from airflow.operators.bash import BaseOperator
from airflow.providers.dingding.hooks.dingding import DingdingHook
Expand Down Expand Up @@ -50,21 +51,21 @@ class DingdingOperator(BaseOperator):
def __init__(
self,
*,
dingding_conn_id='dingding_default',
message_type='text',
message=None,
at_mobiles=None,
at_all=False,
dingding_conn_id: str = 'dingding_default',
message_type: str = 'text',
message: Union[str, dict, None] = None,
at_mobiles: Optional[List[str]] = None,
at_all: bool = False,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.dingding_conn_id = dingding_conn_id
self.message_type = message_type
self.message = message
self.at_mobiles = at_mobiles
self.at_all = at_all

def execute(self, context):
def execute(self, context) -> None:
self.log.info('Sending Dingding message.')
hook = DingdingHook(
self.dingding_conn_id, self.message_type, self.message, self.at_mobiles, self.at_all
Expand Down
11 changes: 6 additions & 5 deletions airflow/providers/opsgenie/hooks/opsgenie_alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#

import json
from typing import Optional, Any

import requests

Expand All @@ -40,10 +41,10 @@ class OpsgenieAlertHook(HttpHook):

"""

def __init__(self, opsgenie_conn_id='opsgenie_default', *args, **kwargs):
super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs)
def __init__(self, opsgenie_conn_id: str = 'opsgenie_default', *args, **kwargs) -> None:
super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs) # type: ignore[misc]

def _get_api_key(self):
def _get_api_key(self) -> str:
"""
Get Opsgenie api_key for creating alert
"""
Expand All @@ -55,7 +56,7 @@ def _get_api_key(self):
)
return api_key

def get_conn(self, headers=None):
def get_conn(self, headers: Optional[dict] = None) -> requests.Session:
"""
Overwrite HttpHook get_conn because this hook just needs base_url
and headers, and does not need generic params
Expand All @@ -70,7 +71,7 @@ def get_conn(self, headers=None):
session.headers.update(headers)
return session

def execute(self, payload=None):
def execute(self, payload: Optional[dict] = None) -> Any:
"""
Execute the Opsgenie Alert call

Expand Down
38 changes: 20 additions & 18 deletions airflow/providers/opsgenie/operators/opsgenie_alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
#
from typing import Optional, List, Dict, Any

from airflow.models import BaseOperator
from airflow.providers.opsgenie.hooks.opsgenie_alert import OpsgenieAlertHook
from airflow.utils.decorators import apply_defaults
Expand Down Expand Up @@ -72,22 +74,22 @@ class OpsgenieAlertOperator(BaseOperator):
def __init__(
self,
*,
message,
opsgenie_conn_id='opsgenie_default',
alias=None,
description=None,
responders=None,
visible_to=None,
actions=None,
tags=None,
details=None,
entity=None,
source=None,
priority=None,
user=None,
note=None,
message: str,
opsgenie_conn_id: str = 'opsgenie_default',
alias: Optional[str] = None,
description: Optional[str] = None,
responders: Optional[List[dict]] = None,
visible_to: Optional[List[dict]] = None,
actions: Optional[List[dict]] = None,
tags: Optional[List[dict]] = None,
details: Optional[dict] = None,
entity: Optional[str] = None,
source: Optional[str] = None,
priority: Optional[str] = None,
user: Optional[str] = None,
note: Optional[str] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)

self.message = message
Expand All @@ -104,9 +106,9 @@ def __init__(
self.priority = priority
self.user = user
self.note = note
self.hook = None
self.hook: Optional[OpsgenieAlertHook] = None

def _build_opsgenie_payload(self):
def _build_opsgenie_payload(self) -> Dict[str, Any]:
"""
Construct the Opsgenie JSON payload. All relevant parameters are combined here
to a valid Opsgenie JSON payload.
Expand Down Expand Up @@ -135,7 +137,7 @@ def _build_opsgenie_payload(self):
payload[key] = val
return payload

def execute(self, context):
def execute(self, context) -> None:
"""
Call the OpsgenieAlertHook to post message
"""
Expand Down
46 changes: 34 additions & 12 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, Any, Iterable

import prestodb
from prestodb.exceptions import DatabaseError
from prestodb.transaction import IsolationLevel

from airflow.hooks.dbapi_hook import DbApiHook
from airflow.models import Connection


class PrestoException(Exception):
Expand All @@ -41,9 +44,11 @@ class PrestoHook(DbApiHook):
conn_name_attr = 'presto_conn_id'
default_conn_name = 'presto_default'

def get_conn(self):
def get_conn(self) -> Connection:
"""Returns a connection object"""
db = self.get_connection(self.presto_conn_id) # pylint: disable=no-member
db = self.get_connection(
self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
auth = prestodb.auth.BasicAuthentication(db.login, db.password) if db.password else None

return prestodb.dbapi.connect(
Expand All @@ -55,20 +60,22 @@ def get_conn(self):
catalog=db.extra_dejson.get('catalog', 'hive'),
schema=db.schema,
auth=auth,
isolation_level=self.get_isolation_level(),
isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value]
)

def get_isolation_level(self):
def get_isolation_level(self) -> Any:
"""Returns an isolation level"""
db = self.get_connection(self.presto_conn_id) # pylint: disable=no-member
db = self.get_connection(
self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)

@staticmethod
def _strip_sql(sql):
def _strip_sql(sql: str) -> str:
return sql.strip().rstrip(';')

def get_records(self, hql, parameters=None):
def get_records(self, hql, parameters: Optional[dict] = None):
"""
Get a set of records from Presto
"""
Expand All @@ -77,7 +84,7 @@ def get_records(self, hql, parameters=None):
except DatabaseError as e:
raise PrestoException(e)

def get_first(self, hql, parameters=None):
def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any:
"""
Returns only the first row, regardless of how many rows the query
returns.
Expand Down Expand Up @@ -107,13 +114,26 @@ def get_pandas_df(self, hql, parameters=None, **kwargs):
df = pandas.DataFrame(**kwargs)
return df

def run(self, hql, parameters=None):
def run(
self,
hql,
autocommit: bool = False,
parameters: Optional[dict] = None,
) -> None:
"""
Execute the statement against Presto. Can be used to create views.
"""
return super().run(self._strip_sql(hql), parameters)

def insert_rows(self, table, rows, target_fields=None, commit_every=0):
return super().run(sql=self._strip_sql(hql), parameters=parameters)

def insert_rows(
self,
table: str,
rows: Iterable[tuple],
target_fields: Optional[Iterable[str]] = None,
commit_every: int = 0,
replace: bool = False,
**kwargs,
) -> None:
"""
A generic way to insert a set of tuples into a table.

Expand All @@ -126,6 +146,8 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=0):
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:type commit_every: int
:param replace: Whether to replace instead of insert
:type replace: bool
"""
if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
self.log.info(
Expand Down
Loading