Skip to content

Commit

Permalink
Add conditional template_fields_renderers check for new SQL lexers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-fell authored Feb 8, 2022
1 parent 34d63fa commit 8f81b9a
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 10 deletions.
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -44,7 +45,10 @@ class RedshiftSQLOperator(BaseOperator):

template_fields: Sequence[str] = ('sql',)
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {"sql": "postgresql"}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {
"sql": "postgresql" if "postgresql" in wwwutils.get_attr_renderer() else "sql"
}

def __init__(
self,
Expand Down
10 changes: 9 additions & 1 deletion airflow/providers/apache/hive/transfers/hive_to_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@
from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context

# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
MYSQL_RENDERER = 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql'


class HiveToMySqlOperator(BaseOperator):
"""
Expand Down Expand Up @@ -57,7 +61,11 @@ class HiveToMySqlOperator(BaseOperator):

template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator')
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {'sql': 'hql', 'mysql_preoperator': 'mysql', 'mysql_postoperator': 'mysql'}
template_fields_renderers = {
'sql': 'hql',
'mysql_preoperator': MYSQL_RENDERER,
'mysql_postoperator': MYSQL_RENDERER,
}
ui_color = '#a0e08c'

def __init__(
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/apache/hive/transfers/mssql_to_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -65,7 +66,8 @@ class MsSqlToHiveOperator(BaseOperator):

template_fields: Sequence[str] = ('sql', 'partition', 'hive_table')
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {'sql': 'tsql'}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {'sql': 'tsql' if 'tsql' in wwwutils.get_attr_renderer() else 'sql'}
ui_color = '#a0e08c'

def __init__(
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/microsoft/mssql/operators/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.hooks.dbapi import DbApiHook
Expand Down Expand Up @@ -49,7 +50,8 @@ class MsSqlOperator(BaseOperator):

template_fields: Sequence[str] = ('sql',)
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {'sql': 'tsql'}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {'sql': 'tsql' if 'tsql' in wwwutils.get_attr_renderer() else 'sql'}
ui_color = '#ededed'

def __init__(
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/mysql/operators/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -47,7 +48,11 @@ class MySqlOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('sql', 'parameters')
template_fields_renderers = {'sql': 'mysql', 'parameters': 'json'}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {
'sql': 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql',
'parameters': 'json',
}
template_ext: Sequence[str] = ('.sql', '.json')
ui_color = '#ededed'

Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/mysql/transfers/presto_to_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.presto.hooks.presto import PrestoHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -44,7 +45,11 @@ class PrestoToMySqlOperator(BaseOperator):

template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator')
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {"sql": "sql", "mysql_preoperator": "mysql"}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {
"sql": "sql",
"mysql_preoperator": "mysql" if "mysql" in wwwutils.get_attr_renderer() else "sql",
}
ui_color = '#a0e08c'

def __init__(
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/mysql/transfers/trino_to_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.trino.hooks.trino import TrinoHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -44,7 +45,11 @@ class TrinoToMySqlOperator(BaseOperator):

template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator')
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {"sql": "sql", "mysql_preoperator": "mysql"}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {
"sql": "sql",
"mysql_preoperator": "mysql" if "mysql" in wwwutils.get_attr_renderer() else "sql",
}
ui_color = '#a0e08c'

def __init__(
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/mysql/transfers/vertica_to_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@
from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.vertica.hooks.vertica import VerticaHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context

# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
MYSQL_RENDERER = 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql'


class VerticaToMySqlOperator(BaseOperator):
"""
Expand Down Expand Up @@ -57,8 +61,8 @@ class VerticaToMySqlOperator(BaseOperator):
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {
"sql": "sql",
"mysql_preoperator": "mysql",
"mysql_postoperator": "mysql",
"mysql_preoperator": MYSQL_RENDERER,
"mysql_postoperator": MYSQL_RENDERER,
}
ui_color = '#a0e08c'

Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/postgres/operators/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models import BaseOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -40,7 +41,10 @@ class PostgresOperator(BaseOperator):
"""

template_fields: Sequence[str] = ('sql',)
template_fields_renderers = {'sql': 'postgresql'}
# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
template_fields_renderers = {
'sql': 'postgresql' if 'postgresql' in wwwutils.get_attr_renderer() else 'sql'
}
template_ext: Sequence[str] = ('.sql',)
ui_color = '#ededed'

Expand Down

0 comments on commit 8f81b9a

Please sign in to comment.