forked from apache/airflow
Added tests for the MarquezDag library (apache#3)
* Added tests for the MarqueDag library
Showing 3 changed files with 189 additions and 63 deletions.
marquez/airflow.py (modified):

@@ -1,93 +1,85 @@
 import json
 import pendulum
-import airflow.models
+from airflow.models import DAG, Log
 from airflow.utils.db import provide_session
 from marquez_client.marquez import MarquezClient
 
+from marquez.utils import JobIdMapping
 
-class MarquezDag(airflow.models.DAG):
+
+class MarquezDag(DAG):
+    _job_id_mapping = None
+    _mqz_client = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.mqz_client = MarquezClient()
         self.mqz_namespace = kwargs['default_args'].get('mqz_namespace', 'unknown')
         self.mqz_location = kwargs['default_args'].get('mqz_location', 'unknown')
         self.mqz_input_datasets = kwargs['default_args'].get('mqz_input_datasets', [])
         self.mqz_output_datasets = kwargs['default_args'].get('mqz_output_datasets', [])
+        self._job_id_mapping = JobIdMapping()
 
     def create_dagrun(self, *args, **kwargs):
-        job_name = self.dag_id
-        job_run_args = "{}"  # TODO retrieve from DAG/tasks
-        start_time = pendulum.instance(kwargs['execution_date']).to_datetime_string()
-        end_time = None
-
-        self.mqz_client.set_namespace(self.mqz_namespace)
-        self.mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets, self.mqz_output_datasets,
-                                   self.description)
-        mqz_job_run_id = self.mqz_client.create_job_run(job_name, job_run_args=job_run_args,
-                                                        nominal_start_time=start_time,
-                                                        nominal_end_time=end_time).run_id
-        self.mqz_client.mark_job_run_running(mqz_job_run_id)
-
-        self.marquez_log('job_running', json.dumps(
-            {'namespace': self.mqz_namespace,
-             'name': job_name,
-             'description': self.description,
-             'location': self.mqz_location,
-             'runArgs': job_run_args,
-             'nominal_start_time': start_time,
-             'nominal_end_time': end_time,
-             'jobrun_id': mqz_job_run_id,
-             'inputDatasetUrns': self.mqz_input_datasets,
-             'outputDatasetUrns': self.mqz_output_datasets
-             }))
-
-        run = super().create_dagrun(*args, **kwargs)
-        airflow.models.Variable.set(run.run_id, mqz_job_run_id)
-
+        run_args = "{}"  # TODO extract the run Args from the tasks
+        mqz_job_run_id = self.report_jobrun(run_args, kwargs['execution_date'])
+        run = super(MarquezDag, self).create_dagrun(*args, **kwargs)
+        self._job_id_mapping.set(JobIdMapping.make_key(run.dag_id, run.run_id), mqz_job_run_id)
         return run
 
     def handle_callback(self, *args, **kwargs):
-        mqz_job_run_id = self.get_and_delete(args[0].run_id)
-        if mqz_job_run_id:
-            if kwargs.get('success'):
-                self.mqz_client.mark_job_run_completed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'COMPLETED'}))
-            else:
-                self.mqz_client.mark_job_run_failed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'FAILED'}))
-        else:
-            # TODO warn that the jobrun_id couldn't be found
-            pass
-
-        return super().handle_callback(*args, **kwargs)
-
-    @provide_session
-    def get_and_delete(self, key, session=None):
-        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
-        if q.first() is None:
-            return
-        else:
-            val = q.first().val
-            q.delete(synchronize_session=False)
-            return val
+        self.report_jobrun_change(args[0], **kwargs)
+        return super().handle_callback(*args, **kwargs)
+
+    def report_jobrun(self, run_args, execution_date):
+        job_name = self.dag_id
+        job_run_args = run_args
+        start_time = pendulum.instance(execution_date).to_datetime_string()
+        end_time = pendulum.instance(self.following_schedule(execution_date)).to_datetime_string()
+        mqz_client = self.get_mqz_client()
+        mqz_client.set_namespace(self.mqz_namespace)
+        mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets,
+                              self.mqz_output_datasets, self.description)
+        mqz_job_run_id = str(mqz_client.create_job_run(
+            job_name, job_run_args=job_run_args, nominal_start_time=start_time, nominal_end_time=end_time).run_id)
+        mqz_client.mark_job_run_running(mqz_job_run_id)
+        if mqz_job_run_id:
+            self.log_marquez_event('job_running',
+                                   namespace=self.mqz_namespace,
+                                   name=job_name,
+                                   description=self.description,
+                                   location=self.mqz_location,
+                                   runArgs=job_run_args,
+                                   nominal_start_time=start_time,
+                                   nominal_end_time=end_time,
+                                   jobrun_id=mqz_job_run_id,
+                                   inputDatasetUrns=self.mqz_input_datasets,
+                                   outputDatasetUrns=self.mqz_output_datasets)
+        return mqz_job_run_id
+
+    def report_jobrun_change(self, dagrun, **kwargs):
+        mqz_job_run_id = self._job_id_mapping.pop(JobIdMapping.make_key(dagrun.dag_id, dagrun.run_id))
+        if mqz_job_run_id:
+            if kwargs.get('success'):
+                self.get_mqz_client().mark_job_run_completed(mqz_job_run_id)
+            else:
+                self.get_mqz_client().mark_job_run_failed(mqz_job_run_id)
+        self.log_marquez_event('job_state_change' if mqz_job_run_id else 'job_state_change_LOST',
+                               job_name=self.dag_id,
+                               jobrun_id=mqz_job_run_id,
+                               state='COMPLETED' if kwargs.get('success') else 'FAILED',
+                               reason=kwargs['reason'])
 
     @provide_session
-    def marquez_log(self, event, extras, session=None):
-        session.add(airflow.models.Log(
+    def log_marquez_event(self, event, session=None, **kwargs):
+        session.add(Log(
             event=event,
             task_instance=None,
             owner="marquez",
-            extra=extras,
+            extra=json.dumps(kwargs),
             task_id=None,
             dag_id=self.dag_id))
+
+    def get_mqz_client(self):
+        if not self._mqz_client:
+            self._mqz_client = MarquezClient()
+        return self._mqz_client
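For context, a DAG author opts into Marquez lineage reporting simply by constructing a MarquezDag instead of a plain DAG. A minimal sketch of such a definition follows; the dag_id, namespace, location, and dataset URNs are illustrative placeholders (they are not part of this commit), while the mqz_* keys are the ones __init__ reads out of default_args:

    from datetime import datetime

    from marquez.airflow import MarquezDag

    # Sketch only: namespace, location, and dataset URNs are made-up examples.
    dag = MarquezDag(
        'example_etl',  # dag_id; also reused as the Marquez job name
        schedule_interval='@daily',
        default_args={
            'mqz_namespace': 'example-namespace',
            'mqz_location': 'github://org/repo/dags/example_etl.py',
            'mqz_input_datasets': ['s3://example/raw/orders'],
            'mqz_output_datasets': ['s3://example/clean/orders'],
            'owner': 'data-eng',
            'start_date': datetime(2019, 1, 1),
        },
        description='An example DAG reporting lineage to Marquez')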
marquez/utils.py (new file):

@@ -0,0 +1,28 @@
import airflow
from airflow.utils.db import provide_session


class JobIdMapping:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(JobIdMapping, cls).__new__(cls, *args, **kwargs)
        return cls._instance

    def set(self, key, val):
        airflow.models.Variable.set(key, val)

    @provide_session
    def pop(self, key, session=None):
        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
        if not q.first():
            return
        else:
            val = q.first().val
            q.delete(synchronize_session=False)
            return val

    @staticmethod
    def make_key(job_name, run_id):
        return "mqz_id_mapping-{}-{}".format(job_name, run_id)
New test module for MarquezDag (new file):

@@ -0,0 +1,106 @@
from datetime import datetime
from unittest.mock import Mock, create_autospec, patch

import pytest

import airflow.models
import marquez.utils
import pendulum
from airflow.utils.state import State
from croniter import croniter
from marquez.airflow import MarquezDag
from marquez_client.marquez import MarquezClient


class Context:
    location = 'github://test_dag_location'
    dag_id = 'test-dag-1'
    namespace = 'test-namespace-1'
    data_inputs = ["s3://data_input_1", "s3://data_input_2"]
    data_outputs = ["s3://some_output_data"]
    owner = 'test_owner'
    description = 'this is a test DAG'
    airflow_run_id = 'airflow_run_id_123456'
    mqz_run_id = '71d29487-0b54-4ae1-9295-efd87f190c57'
    start_date = datetime(2019, 1, 31, 0, 0, 0)
    execution_date = datetime(2019, 2, 2, 0, 0, 0)
    schedule_interval = '*/10 * * * *'

    dag = None

    def __init__(self):
        self.dag = MarquezDag(
            self.dag_id,
            schedule_interval=self.schedule_interval,
            default_args={'mqz_namespace': self.namespace,
                          'mqz_location': self.location,
                          'mqz_input_datasets': self.data_inputs,
                          'mqz_output_datasets': self.data_outputs,
                          'owner': self.owner,
                          'depends_on_past': False,
                          'start_date': self.start_date},
            description=self.description)


@pytest.fixture(scope="module")
def context():
    return Context()


@patch.object(airflow.models.DAG, 'create_dagrun')
@patch.object(marquez.utils.JobIdMapping, 'set')
def test_create_dagrun(mock_set, mock_dag_run, context):
    dag = context.dag
    mock_mqz_client = make_mock_mqz_client(context.mqz_run_id)
    dag._mqz_client = mock_mqz_client  # Use a mock marquez-python client
    mock_dag_run.return_value = make_mock_airflow_jobrun(dag.dag_id, context.airflow_run_id)

    # trigger an airflow DagRun
    dag.create_dagrun(state=State.RUNNING, run_id=context.airflow_run_id, execution_date=context.execution_date)

    # check Marquez client was called with expected arguments
    mock_mqz_client.set_namespace.assert_called_with(context.namespace)
    mock_mqz_client.create_job.assert_called_once_with(context.dag_id, context.location, context.data_inputs,
                                                       context.data_outputs, context.description)
    mock_mqz_client.create_job_run.assert_called_once_with(
        context.dag_id,
        "{}",
        to_airflow_datetime_str(context.execution_date),
        to_airflow_datetime_str(compute_end_time(context.schedule_interval, context.execution_date)))

    # Test if airflow's create_dagrun() is called with the expected arguments
    mock_dag_run.assert_called_once_with(state=State.RUNNING,
                                         run_id=context.airflow_run_id,
                                         execution_date=context.execution_date)

    # Assert there is a job_id mapping being created
    mock_set.assert_called_once_with(marquez.utils.JobIdMapping.make_key(context.dag_id, context.airflow_run_id),
                                     context.mqz_run_id)


def make_mock_mqz_client(run_id):
    mock_mqz_run = Mock()
    mock_mqz_run.run_id = run_id
    mock_mqz_client = create_autospec(MarquezClient)
    mock_mqz_client.create_job_run.return_value = mock_mqz_run
    return mock_mqz_client


def make_mock_airflow_jobrun(dag_id, airflow_run_id):
    mock_airflow_jobrun = Mock()
    mock_airflow_jobrun.run_id = airflow_run_id
    mock_airflow_jobrun.dag_id = dag_id
    return mock_airflow_jobrun


def compute_end_time(schedule_interval, start_time):
    return datetime.utcfromtimestamp(croniter(schedule_interval, start_time).get_next())


def to_airflow_datetime_str(dt):
    return pendulum.instance(dt).to_datetime_string()


if __name__ == "__main__":
    pytest.main()
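The commit only exercises the create_dagrun() path; a companion test for the callback path could follow the same mocking pattern. A sketch reusing the fixtures and helpers above (the test name, success/reason values are illustrative, and, like test_create_dagrun, it assumes an initialized Airflow metadata database for the Log write in log_marquez_event):

    @patch.object(airflow.models.DAG, 'handle_callback')
    @patch.object(marquez.utils.JobIdMapping, 'pop')
    def test_handle_callback_failure(mock_pop, mock_callback, context):
        dag = context.dag
        mock_mqz_client = make_mock_mqz_client(context.mqz_run_id)
        dag._mqz_client = mock_mqz_client
        mock_pop.return_value = context.mqz_run_id  # pretend a mapping was stored earlier

        dagrun = make_mock_airflow_jobrun(dag.dag_id, context.airflow_run_id)
        dag.handle_callback(dagrun, success=False, reason='a task failed')

        # the failed run is reported to Marquez, and the callback is forwarded to Airflow
        mock_mqz_client.mark_job_run_failed.assert_called_once_with(context.mqz_run_id)
        mock_callback.assert_called_once_with(dagrun, success=False, reason='a task failed')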