Skip to content

Commit

Permalink
Parse template parameters field for MySQL operator (#17080)
Browse files Browse the repository at this point in the history
  • Loading branch information
oyarushe authored Jul 19, 2021
1 parent a4af964 commit a1d3b27
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/mysql/operators/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import ast
from typing import Dict, Iterable, Mapping, Optional, Union

from airflow.models import BaseOperator
Expand All @@ -37,6 +38,8 @@ class MySqlOperator(BaseOperator):
:param mysql_conn_id: Reference to :ref:`mysql connection id <howto/connection:mysql>`.
:type mysql_conn_id: str
:param parameters: (optional) the parameters to render the SQL query with.
Template reference are recognized by str ending in '.json'
(templated)
:type parameters: dict or iterable
:param autocommit: if True, each command is automatically committed.
(default value: False)
Expand Down Expand Up @@ -67,6 +70,11 @@ def __init__(
self.parameters = parameters
self.database = database

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
if isinstance(self.parameters, str):
self.parameters = ast.literal_eval(self.parameters)

def execute(self, context: Dict) -> None:
self.log.info('Executing: %s', self.sql)
hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
Expand Down
18 changes: 18 additions & 0 deletions tests/providers/mysql/operators/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import unittest
from contextlib import closing
from tempfile import NamedTemporaryFile

import pytest
from parameterized import parameterized
Expand Down Expand Up @@ -108,3 +110,19 @@ def test_overwrite_schema(self, client):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except OperationalError as e:
assert "Unknown database 'foobar'" in str(e)

def test_mysql_operator_resolve_parameters_template_json_file(self):

with NamedTemporaryFile(suffix='.json') as f:
f.write(b"{\n \"foo\": \"{{ ds }}\"}")
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)

with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
task = MySqlOperator(task_id="op1", parameters=template_file, sql="SELECT 1")

task.resolve_template_files()

assert isinstance(task.parameters, dict)
assert task.parameters["foo"] == "{{ ds }}"

0 comments on commit a1d3b27

Please sign in to comment.