Skip to content

Commit

Permalink
Fix assignment of template field in __init__ in `CloudDataTransferS…
Browse files Browse the repository at this point in the history
…erviceCreateJobOperator` (#36909)

* fix initialization of templated field in constructor

* remove file from exclude

* add test for templated field

* change body to be realistic
  • Loading branch information
romsharon98 authored Feb 4, 2024
1 parent caec4c7 commit 46470ab
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ repos:
^airflow\/providers\/google\/cloud\/operators\/bigquery\.py$|
^airflow\/providers\/amazon\/aws\/transfers\/gcs_to_s3\.py$|
^airflow\/providers\/databricks\/operators\/databricks\.py$|
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$|
^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$|
^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$|
^airflow\/providers\/google\/cloud\/operators\/compute\.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
self.body = deepcopy(body)
self.body = body
if isinstance(self.body, dict):
self.body = deepcopy(body)
self.aws_conn_id = aws_conn_id
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@
SCHEDULE: SCHEDULE_DICT,
TRANSFER_SPEC: {GCS_DATA_SINK: {BUCKET_NAME: GCS_BUCKET_NAME, PATH: DESTINATION_PATH}},
}
VALID_TRANSFER_JOB_JINJA = deepcopy(VALID_TRANSFER_JOB_BASE)
VALID_TRANSFER_JOB_JINJA[NAME] = "{{ dag.dag_id }}"
VALID_TRANSFER_JOB_JINJA_RENDERED = deepcopy(VALID_TRANSFER_JOB_JINJA)
VALID_TRANSFER_JOB_JINJA_RENDERED[NAME] = "TestGcpStorageTransferJobCreateOperator"
VALID_TRANSFER_JOB_GCS = deepcopy(VALID_TRANSFER_JOB_BASE)
VALID_TRANSFER_JOB_GCS[TRANSFER_SPEC].update(deepcopy(SOURCE_GCS))
VALID_TRANSFER_JOB_AWS = deepcopy(VALID_TRANSFER_JOB_BASE)
Expand Down Expand Up @@ -324,21 +328,25 @@ def test_job_create_multiple(self, aws_hook, gcp_hook):
# (could be anything else) just to test if the templating works for all
# fields
@pytest.mark.db_test
@pytest.mark.parametrize(
"body, excepted",
[(VALID_TRANSFER_JOB_JINJA, VALID_TRANSFER_JOB_JINJA_RENDERED)],
)
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
def test_templates(self, _, create_task_instance_of_operator):
dag_id = "TestGcpStorageTransferJobCreateOperator_test_templates"
def test_templates(self, _, create_task_instance_of_operator, body, excepted):
dag_id = "TestGcpStorageTransferJobCreateOperator"
ti = create_task_instance_of_operator(
CloudDataTransferServiceCreateJobOperator,
dag_id=dag_id,
body={"description": "{{ dag.dag_id }}"},
body=body,
gcp_conn_id="{{ dag.dag_id }}",
aws_conn_id="{{ dag.dag_id }}",
task_id="task-id",
)
ti.render_templates()
assert dag_id == getattr(ti.task, "body")[DESCRIPTION]
assert excepted == getattr(ti.task, "body")
assert dag_id == getattr(ti.task, "gcp_conn_id")
assert dag_id == getattr(ti.task, "aws_conn_id")

Expand Down

0 comments on commit 46470ab

Please sign in to comment.