Skip to content

Commit

Permalink
Strict type check for multiple providers (#11229)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlgruby authored Oct 2, 2020
1 parent c74b3ac commit 720912f
Show file tree
Hide file tree
Showing 15 changed files with 186 additions and 139 deletions.
22 changes: 12 additions & 10 deletions airflow/providers/dingding/hooks/dingding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
# under the License.

import json
from typing import Union, Optional, List

import requests
from requests import Session

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
Expand Down Expand Up @@ -49,20 +51,20 @@ class DingdingHook(HttpHook):
def __init__(
self,
dingding_conn_id='dingding_default',
message_type='text',
message=None,
at_mobiles=None,
at_all=False,
message_type: str = 'text',
message: Optional[Union[str, dict]] = None,
at_mobiles: Optional[List[str]] = None,
at_all: bool = False,
*args,
**kwargs,
):
super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs)
) -> None:
super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs) # type: ignore[misc]
self.message_type = message_type
self.message = message
self.at_mobiles = at_mobiles
self.at_all = at_all

def _get_endpoint(self):
def _get_endpoint(self) -> str:
"""
Get Dingding endpoint for sending message.
"""
Expand All @@ -74,7 +76,7 @@ def _get_endpoint(self):
)
return 'robot/send?access_token={}'.format(token)

def _build_message(self):
def _build_message(self) -> str:
"""
Build different type of Dingding message
As most commonly used type, text message just need post message content
Expand All @@ -90,7 +92,7 @@ def _build_message(self):
data = {'msgtype': self.message_type, self.message_type: self.message}
return json.dumps(data)

def get_conn(self, headers=None):
def get_conn(self, headers: Optional[dict] = None) -> Session:
"""
Overwrite HttpHook get_conn because just need base_url and headers and
not don't need generic params
Expand All @@ -105,7 +107,7 @@ def get_conn(self, headers=None):
session.headers.update(headers)
return session

def send(self):
def send(self) -> None:
"""
Send Dingding message
"""
Expand Down
15 changes: 8 additions & 7 deletions airflow/providers/dingding/operators/dingding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union, Optional, List

from airflow.operators.bash import BaseOperator
from airflow.providers.dingding.hooks.dingding import DingdingHook
Expand Down Expand Up @@ -50,21 +51,21 @@ class DingdingOperator(BaseOperator):
def __init__(
self,
*,
dingding_conn_id='dingding_default',
message_type='text',
message=None,
at_mobiles=None,
at_all=False,
dingding_conn_id: str = 'dingding_default',
message_type: str = 'text',
message: Union[str, dict, None] = None,
at_mobiles: Optional[List[str]] = None,
at_all: bool = False,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.dingding_conn_id = dingding_conn_id
self.message_type = message_type
self.message = message
self.at_mobiles = at_mobiles
self.at_all = at_all

def execute(self, context):
def execute(self, context) -> None:
self.log.info('Sending Dingding message.')
hook = DingdingHook(
self.dingding_conn_id, self.message_type, self.message, self.at_mobiles, self.at_all
Expand Down
11 changes: 6 additions & 5 deletions airflow/providers/opsgenie/hooks/opsgenie_alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#

import json
from typing import Optional, Any

import requests

Expand All @@ -40,10 +41,10 @@ class OpsgenieAlertHook(HttpHook):
"""

def __init__(self, opsgenie_conn_id='opsgenie_default', *args, **kwargs):
super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs)
def __init__(self, opsgenie_conn_id: str = 'opsgenie_default', *args, **kwargs) -> None:
super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs) # type: ignore[misc]

def _get_api_key(self):
def _get_api_key(self) -> str:
"""
Get Opsgenie api_key for creating alert
"""
Expand All @@ -55,7 +56,7 @@ def _get_api_key(self):
)
return api_key

def get_conn(self, headers=None):
def get_conn(self, headers: Optional[dict] = None) -> requests.Session:
"""
Overwrite HttpHook get_conn because this hook just needs base_url
and headers, and does not need generic params
Expand All @@ -70,7 +71,7 @@ def get_conn(self, headers=None):
session.headers.update(headers)
return session

def execute(self, payload=None):
def execute(self, payload: Optional[dict] = None) -> Any:
"""
Execute the Opsgenie Alert call
Expand Down
38 changes: 20 additions & 18 deletions airflow/providers/opsgenie/operators/opsgenie_alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
#
from typing import Optional, List, Dict, Any

from airflow.models import BaseOperator
from airflow.providers.opsgenie.hooks.opsgenie_alert import OpsgenieAlertHook
from airflow.utils.decorators import apply_defaults
Expand Down Expand Up @@ -72,22 +74,22 @@ class OpsgenieAlertOperator(BaseOperator):
def __init__(
self,
*,
message,
opsgenie_conn_id='opsgenie_default',
alias=None,
description=None,
responders=None,
visible_to=None,
actions=None,
tags=None,
details=None,
entity=None,
source=None,
priority=None,
user=None,
note=None,
message: str,
opsgenie_conn_id: str = 'opsgenie_default',
alias: Optional[str] = None,
description: Optional[str] = None,
responders: Optional[List[dict]] = None,
visible_to: Optional[List[dict]] = None,
actions: Optional[List[dict]] = None,
tags: Optional[List[dict]] = None,
details: Optional[dict] = None,
entity: Optional[str] = None,
source: Optional[str] = None,
priority: Optional[str] = None,
user: Optional[str] = None,
note: Optional[str] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)

self.message = message
Expand All @@ -104,9 +106,9 @@ def __init__(
self.priority = priority
self.user = user
self.note = note
self.hook = None
self.hook: Optional[OpsgenieAlertHook] = None

def _build_opsgenie_payload(self):
def _build_opsgenie_payload(self) -> Dict[str, Any]:
"""
Construct the Opsgenie JSON payload. All relevant parameters are combined here
to a valid Opsgenie JSON payload.
Expand Down Expand Up @@ -135,7 +137,7 @@ def _build_opsgenie_payload(self):
payload[key] = val
return payload

def execute(self, context):
def execute(self, context) -> None:
"""
Call the OpsgenieAlertHook to post message
"""
Expand Down
46 changes: 34 additions & 12 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, Any, Iterable

import prestodb
from prestodb.exceptions import DatabaseError
from prestodb.transaction import IsolationLevel

from airflow.hooks.dbapi_hook import DbApiHook
from airflow.models import Connection


class PrestoException(Exception):
Expand All @@ -41,9 +44,11 @@ class PrestoHook(DbApiHook):
conn_name_attr = 'presto_conn_id'
default_conn_name = 'presto_default'

def get_conn(self):
def get_conn(self) -> Connection:
"""Returns a connection object"""
db = self.get_connection(self.presto_conn_id) # pylint: disable=no-member
db = self.get_connection(
self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
auth = prestodb.auth.BasicAuthentication(db.login, db.password) if db.password else None

return prestodb.dbapi.connect(
Expand All @@ -55,20 +60,22 @@ def get_conn(self):
catalog=db.extra_dejson.get('catalog', 'hive'),
schema=db.schema,
auth=auth,
isolation_level=self.get_isolation_level(),
isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value]
)

def get_isolation_level(self):
def get_isolation_level(self) -> Any:
"""Returns an isolation level"""
db = self.get_connection(self.presto_conn_id) # pylint: disable=no-member
db = self.get_connection(
self.presto_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
)
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)

@staticmethod
def _strip_sql(sql):
def _strip_sql(sql: str) -> str:
return sql.strip().rstrip(';')

def get_records(self, hql, parameters=None):
def get_records(self, hql, parameters: Optional[dict] = None):
"""
Get a set of records from Presto
"""
Expand All @@ -77,7 +84,7 @@ def get_records(self, hql, parameters=None):
except DatabaseError as e:
raise PrestoException(e)

def get_first(self, hql, parameters=None):
def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any:
"""
Returns only the first row, regardless of how many rows the query
returns.
Expand Down Expand Up @@ -107,13 +114,26 @@ def get_pandas_df(self, hql, parameters=None, **kwargs):
df = pandas.DataFrame(**kwargs)
return df

def run(self, hql, parameters=None):
def run(
self,
hql,
autocommit: bool = False,
parameters: Optional[dict] = None,
) -> None:
"""
Execute the statement against Presto. Can be used to create views.
"""
return super().run(self._strip_sql(hql), parameters)

def insert_rows(self, table, rows, target_fields=None, commit_every=0):
return super().run(sql=self._strip_sql(hql), parameters=parameters)

def insert_rows(
self,
table: str,
rows: Iterable[tuple],
target_fields: Optional[Iterable[str]] = None,
commit_every: int = 0,
replace: bool = False,
**kwargs,
) -> None:
"""
A generic way to insert a set of tuples into a table.
Expand All @@ -126,6 +146,8 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=0):
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:type commit_every: int
:param replace: Whether to replace instead of insert
:type replace: bool
"""
if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
self.log.info(
Expand Down
Loading

0 comments on commit 720912f

Please sign in to comment.