Skip to content

Commit

Permalink
Fix impersonation issue with LocalTaskJob (#16852)
Browse files Browse the repository at this point in the history
Running a task with run_as_user fails because PIDs are not matched
correctly.

This change fixes it by matching the parent process ID (the `sudo`
process) of the task instance to the current process ID of the task_runner
process when we use impersonation

Co-authored-by: Ash Berlin-Taylor <[email protected]>
(cherry picked from commit feea380)
  • Loading branch information
ephraimbuddy authored and jhtimmins committed Jul 9, 2021
1 parent 24f3f63 commit 26a2beb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
5 changes: 4 additions & 1 deletion airflow/jobs/local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import signal
from typing import Optional

import psutil

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
Expand Down Expand Up @@ -188,9 +190,10 @@ def heartbeat_callback(self, session=None):
fqdn,
)
raise AirflowException("Hostname of job runner does not match")

current_pid = self.task_runner.process.pid
same_process = ti.pid == current_pid
if ti.run_as_user:
same_process = psutil.Process(ti.pid).ppid() == current_pid
if ti.pid is not None and not same_process:
self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid)
raise AirflowException("PID of job runner does not match")
Expand Down
54 changes: 54 additions & 0 deletions tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from tests.test_utils.db import clear_db_jobs, clear_db_runs
from tests.test_utils.mock_executor import MockExecutor

# pylint: skip-file

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']

Expand Down Expand Up @@ -135,6 +137,58 @@ def test_localtaskjob_heartbeat(self):
with pytest.raises(AirflowException):
job1.heartbeat_callback() # pylint: disable=no-value-for-parameter

@mock.patch('airflow.jobs.local_task_job.psutil')
def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
session = settings.Session()
dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

with dag:
op1 = DummyOperator(task_id='op1', run_as_user='myuser')

dag.clear()
dr = dag.create_dagrun(
run_id="test",
state=State.SUCCESS,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
session=session,
)

ti = dr.get_task_instance(task_id=op1.task_id, session=session)
ti.state = State.RUNNING
ti.pid = 2
ti.hostname = get_hostname()
session.commit()

job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
ti.task = op1
ti.refresh_from_task(op1)
job1.task_runner = StandardTaskRunner(job1)
job1.task_runner.process = mock.Mock()
job1.task_runner.process.pid = 2
# Here, ti.pid is 2, the parent process of ti.pid is a mock(different).
# And task_runner process is 2. Should fail
with pytest.raises(AirflowException, match='PID of job runner does not match'):
job1.heartbeat_callback()

job1.task_runner.process.pid = 1
# We make the parent process of ti.pid to equal the task_runner process id
psutil_mock.Process.return_value.ppid.return_value = 1
ti.state = State.RUNNING
ti.pid = 2
# The task_runner process id is 1, same as the parent process of ti.pid
# as seen above
assert ti.run_as_user
session.merge(ti)
session.commit()
job1.heartbeat_callback(session=None)

# Here the task_runner process id is changed to 2
# while parent process of ti.pid is kept at 1, which is different
job1.task_runner.process.pid = 2
with pytest.raises(AirflowException, match='PID of job runner does not match'):
job1.heartbeat_callback()

def test_heartbeat_failed_fast(self):
"""
Test that task heartbeat will sleep when it fails fast
Expand Down

0 comments on commit 26a2beb

Please sign in to comment.