Skip to content

Commit

Permalink
[AIRFLOW-6193] Do not use asserts in Airflow main code (#6749)
Browse files Browse the repository at this point in the history
* [AIRFLOW-6193] Do not use asserts in Airflow main code
  • Loading branch information
potiuk authored Dec 9, 2019
1 parent d087925 commit 25e9047
Show file tree
Hide file tree
Showing 37 changed files with 348 additions and 149 deletions.
3 changes: 2 additions & 1 deletion airflow/api/common/experimental/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def _trigger_dag(

execution_date = execution_date if execution_date else timezone.utcnow()

assert timezone.is_localized(execution_date)
if not timezone.is_localized(execution_date):
raise ValueError("The execution_date should be localized")

if replace_microseconds:
execution_date = execution_date.replace(microsecond=0)
Expand Down
3 changes: 2 additions & 1 deletion airflow/contrib/example_dags/example_kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def use_zip_binary():
:rtype: bool
"""
return_code = os.system("zip")
assert return_code == 0
if return_code != 0:
raise SystemError("The zip binary is missing")

# You don't have to use any special KubernetesExecutor configuration if you don't want to
start_task = PythonOperator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def test_volume_mount():
foo.write('Hello')

return_code = os.system("cat /foo/volume_mount_test.txt")
assert return_code == 0
if return_code != 0:
raise ValueError(f"Error when checking volume mount. Return code {return_code}")

# You can use annotations on your kubernetes pods!
start_task = PythonOperator(
Expand Down
11 changes: 8 additions & 3 deletions airflow/example_dags/example_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,20 @@ def puller(**kwargs):

# get value_1
pulled_value_1 = ti.xcom_pull(key=None, task_ids='push')
assert pulled_value_1 == value_1
if pulled_value_1 != value_1:
raise ValueError(f'The two values differ {pulled_value_1} and {value_1}')

# get value_2
pulled_value_2 = ti.xcom_pull(task_ids='push_by_returning')
assert pulled_value_2 == value_2
if pulled_value_2 != value_2:
raise ValueError(f'The two values differ {pulled_value_2} and {value_2}')

# get both value_1 and value_2
pulled_value_1, pulled_value_2 = ti.xcom_pull(key=None, task_ids=['push', 'push_by_returning'])
assert (pulled_value_1, pulled_value_2) == (value_1, value_2)
if pulled_value_1 != value_1:
raise ValueError(f'The two values differ {pulled_value_1} and {value_1}')
if pulled_value_2 != value_2:
raise ValueError(f'The two values differ {pulled_value_2} and {value_2}')


push1 = PythonOperator(
Expand Down
26 changes: 18 additions & 8 deletions airflow/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distributed import Client, Future, as_completed
from distributed.security import Security

from airflow import AirflowException
from airflow.configuration import conf
from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
from airflow.models.taskinstance import TaskInstanceKeyType
Expand All @@ -35,7 +36,8 @@ def __init__(self, cluster_address=None):
super().__init__(parallelism=0)
if cluster_address is None:
cluster_address = conf.get('dask', 'cluster_address')
assert cluster_address, 'Please provide a Dask cluster address in airflow.cfg'
if not cluster_address:
raise ValueError('Please provide a Dask cluster address in airflow.cfg')
self.cluster_address = cluster_address
# ssl / tls parameters
self.tls_ca = conf.get('dask', 'tls_ca')
Expand Down Expand Up @@ -63,17 +65,21 @@ def execute_async(self,
command: CommandType,
queue: Optional[str] = None,
executor_config: Optional[Any] = None) -> None:
assert self.futures, NOT_STARTED_MESSAGE
if not self.futures:
raise AirflowException(NOT_STARTED_MESSAGE)

def airflow_run():
return subprocess.check_call(command, close_fds=True)

assert self.client, "The Dask executor has not been started yet!"
if not self.client:
raise AirflowException(NOT_STARTED_MESSAGE)

future = self.client.submit(airflow_run, pure=False)
self.futures[future] = key

def _process_future(self, future: Future) -> None:
assert self.futures, NOT_STARTED_MESSAGE
if not self.futures:
raise AirflowException(NOT_STARTED_MESSAGE)
if future.done():
key = self.futures[future]
if future.exception():
Expand All @@ -87,19 +93,23 @@ def _process_future(self, future: Future) -> None:
self.futures.pop(future)

def sync(self) -> None:
    """
    Process every future that is currently tracked by the executor.

    :raises AirflowException: when the futures mapping has not been created yet
    """
    # Test for None rather than truthiness: _process_future pops finished
    # futures, so the mapping is legitimately empty once everything is done and
    # an empty mapping must not be reported as "not started".
    # NOTE(review): assumes the pre-start value of self.futures is None - confirm
    # against __init__/start, which are not visible here.
    if self.futures is None:
        raise AirflowException(NOT_STARTED_MESSAGE)
    # make a copy so futures can be popped during iteration
    for future in self.futures.copy():
        self._process_future(future)

def end(self) -> None:
    """
    Cancel all tracked futures and process each one as it completes.

    :raises AirflowException: when the client or the futures mapping is missing
    """
    if not self.client:
        raise AirflowException(NOT_STARTED_MESSAGE)
    # Test for None rather than truthiness: an executor whose futures have all
    # been popped holds an empty mapping and must still be allowed to shut down.
    # NOTE(review): assumes the pre-start value of self.futures is None - confirm
    # against __init__/start, which are not visible here.
    if self.futures is None:
        raise AirflowException(NOT_STARTED_MESSAGE)
    self.client.cancel(list(self.futures.keys()))
    # iterate over a copy because _process_future pops entries from self.futures
    for future in as_completed(self.futures.copy()):
        self._process_future(future)

def terminate(self):
    """
    Cancel all tracked futures and then run the regular ``end`` sequence.

    :raises AirflowException: when the client or the futures mapping is missing
    """
    # self.client is dereferenced below, so guard it the same way end() does
    # instead of failing with an AttributeError on a missing client.
    if not self.client:
        raise AirflowException(NOT_STARTED_MESSAGE)
    # Test for None rather than truthiness: an empty mapping only means there is
    # nothing left to cancel, not that the executor was never started.
    # NOTE(review): assumes the pre-start value of self.futures is None - confirm
    # against __init__/start, which are not visible here.
    if self.futures is None:
        raise AirflowException(NOT_STARTED_MESSAGE)
    self.client.cancel(self.futures.keys())
    self.end()
9 changes: 5 additions & 4 deletions airflow/executors/executor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ def _get_executor(executor_name: str) -> BaseExecutor:
from airflow import plugins_manager
plugins_manager.integrate_executor_plugins()
executor_path = executor_name.split('.')
assert len(executor_path) == 2, f"Executor {executor_name} not supported: " \
f"please specify in format plugin_module.executor"

assert executor_path[0] in globals(), f"Executor {executor_name} not supported"
if len(executor_path) != 2:
raise ValueError(f"Executor {executor_name} not supported: "
f"please specify in format plugin_module.executor")
if executor_path[0] not in globals():
raise ValueError(f"Executor {executor_name} not supported")
return globals()[executor_path[0]].__dict__[executor_path[1]]()
44 changes: 30 additions & 14 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ def __init__(self,
def run(self) -> None:
"""Performs watching"""
kube_client: client.CoreV1Api = get_kube_client()
assert self.worker_uuid, NOT_STARTED_MESSAGE
if not self.worker_uuid:
raise AirflowException(NOT_STARTED_MESSAGE)
while True:
try:
self.resource_version = self._run(kube_client, self.resource_version,
Expand Down Expand Up @@ -657,7 +658,8 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
proper support
for State.LAUNCHED
"""
assert self.kube_client, NOT_STARTED_MESSAGE
if not self.kube_client:
raise AirflowException(NOT_STARTED_MESSAGE)
queued_tasks = session\
.query(TaskInstance)\
.filter(TaskInstance.state == State.QUEUED).all()
Expand Down Expand Up @@ -737,7 +739,8 @@ def start(self) -> None:
"""Starts the executor"""
self.log.info('Start Kubernetes executor')
self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid()
assert self.worker_uuid, "Could not get worker_uuid"
if not self.worker_uuid:
raise AirflowException("Could not get worker uuid")
self.log.debug('Start with worker_uuid: %s', self.worker_uuid)
# always need to reset resource version since we don't know
# when we last started, note for behavior below
Expand All @@ -764,7 +767,8 @@ def execute_async(self,
)

kube_executor_config = PodGenerator.from_obj(executor_config)
assert self.task_queue, NOT_STARTED_MESSAGE
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.task_queue.put((key, command, kube_executor_config))

def sync(self) -> None:
Expand All @@ -773,10 +777,16 @@ def sync(self) -> None:
self.log.debug('self.running: %s', self.running)
if self.queued_tasks:
self.log.debug('self.queued: %s', self.queued_tasks)
assert self.kube_scheduler, NOT_STARTED_MESSAGE
assert self.kube_config, NOT_STARTED_MESSAGE
assert self.result_queue, NOT_STARTED_MESSAGE
assert self.task_queue, NOT_STARTED_MESSAGE
if not self.worker_uuid:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.kube_scheduler:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.kube_config:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.result_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.kube_scheduler.sync()

last_resource_version = None
Expand Down Expand Up @@ -819,7 +829,8 @@ def sync(self) -> None:
def _change_state(self, key: TaskInstanceKeyType, state: Optional[str], pod_id: str) -> None:
if state != State.RUNNING:
if self.kube_config.delete_worker_pods:
assert self.kube_scheduler, NOT_STARTED_MESSAGE
if not self.kube_scheduler:
raise AirflowException(NOT_STARTED_MESSAGE)
self.kube_scheduler.delete_pod(pod_id)
self.log.info('Deleted pod: %s', str(key))
try:
Expand All @@ -829,7 +840,8 @@ def _change_state(self, key: TaskInstanceKeyType, state: Optional[str], pod_id:
self.event_buffer[key] = state

def _flush_task_queue(self) -> None:
assert self.task_queue, NOT_STARTED_MESSAGE
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.log.debug('Executor shutting down, task_queue approximate size=%d', self.task_queue.qsize())
while True:
try:
Expand All @@ -841,7 +853,8 @@ def _flush_task_queue(self) -> None:
break

def _flush_result_queue(self) -> None:
assert self.result_queue, NOT_STARTED_MESSAGE
if not self.result_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.log.debug('Executor shutting down, result_queue approximate size=%d', self.result_queue.qsize())
while True: # pylint: disable=too-many-nested-blocks
try:
Expand All @@ -863,9 +876,12 @@ def _flush_result_queue(self) -> None:

def end(self) -> None:
"""Called when the executor shuts down"""
assert self.task_queue, NOT_STARTED_MESSAGE
assert self.result_queue, NOT_STARTED_MESSAGE
assert self.kube_scheduler, NOT_STARTED_MESSAGE
if not self.task_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.result_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
if not self.kube_scheduler:
raise AirflowException(NOT_STARTED_MESSAGE)
self.log.info('Shutting down Kubernetes executor')
self.log.debug('Flushing task_queue...')
self._flush_task_queue()
Expand Down
22 changes: 14 additions & 8 deletions airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def execute_async(self,
:param queue: Name of the queue
:param executor_config: configuration for the executor
"""
assert self.executor.result_queue, NOT_STARTED_MESSAGE
if not self.executor.result_queue:
raise AirflowException(NOT_STARTED_MESSAGE)
local_worker = LocalWorker(self.executor.result_queue, key=key, command=command)
self.executor.workers_used += 1
self.executor.workers_active += 1
Expand Down Expand Up @@ -228,10 +229,10 @@ def __init__(self, executor: 'LocalExecutor'):
def start(self) -> None:
"""Starts limited parallelism implementation."""
if not self.executor.manager:
raise AirflowException("Executor must be started!")
raise AirflowException(NOT_STARTED_MESSAGE)
self.queue = self.executor.manager.Queue()
if not self.executor.result_queue:
raise AirflowException("Executor must be started!")
raise AirflowException(NOT_STARTED_MESSAGE)
self.executor.workers = [
QueuedLocalWorker(self.queue, self.executor.result_queue)
for _ in range(self.executor.parallelism)
Expand All @@ -257,7 +258,8 @@ def execute_async(self,
:param queue: name of the queue
:param executor_config: configuration for the executor
"""
assert self.queue, NOT_STARTED_MESSAGE
if not self.queue:
raise AirflowException(NOT_STARTED_MESSAGE)
self.queue.put((key, command))

def sync(self):
Expand Down Expand Up @@ -300,23 +302,27 @@ def execute_async(self, key: TaskInstanceKeyType,
queue: Optional[str] = None,
executor_config: Optional[Any] = None) -> None:
"""Execute asynchronously."""
assert self.impl, NOT_STARTED_MESSAGE
if not self.impl:
raise AirflowException(NOT_STARTED_MESSAGE)
self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)

def sync(self) -> None:
    """
    Sync will get called periodically by the heartbeat method.

    Delegates to the ``sync`` method of the selected implementation.

    :raises AirflowException: when no implementation has been set yet
    """
    # No assert here: it would raise AssertionError before this check could run
    # and is stripped entirely under ``python -O``.
    if not self.impl:
        raise AirflowException(NOT_STARTED_MESSAGE)
    self.impl.sync()

def end(self) -> None:
    """
    Ends the executor.

    Ends the selected implementation first and then shuts the manager down.

    :raises AirflowException: when the implementation or the manager is missing
    :return:
    """
    # No asserts here: they would raise AssertionError before these checks could
    # run and are stripped entirely under ``python -O``.
    if not self.impl:
        raise AirflowException(NOT_STARTED_MESSAGE)
    if not self.manager:
        raise AirflowException(NOT_STARTED_MESSAGE)
    self.impl.end()
    self.manager.shutdown()

Expand Down
Loading

0 comments on commit 25e9047

Please sign in to comment.