From e656e1de55094e8369cab80b9b1669b1d1225f54 Mon Sep 17 00:00:00 2001 From: Alex Kruchkov <36231027+alexkruc@users.noreply.github.com> Date: Mon, 6 Jun 2022 15:54:27 +0300 Subject: [PATCH] Adding fnmatch type regex to SFTPSensor (#24084) --- airflow/providers/sftp/hooks/sftp.py | 19 ++++++++++++ airflow/providers/sftp/sensors/sftp.py | 18 +++++++++-- tests/providers/sftp/hooks/test_sftp.py | 38 +++++++++++++++++++++-- tests/providers/sftp/sensors/test_sftp.py | 28 +++++++++++++++++ 4 files changed, 97 insertions(+), 6 deletions(-) diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index 58c820b838386..d436d091b5ba6 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -19,6 +19,7 @@ import datetime import stat import warnings +from fnmatch import fnmatch from typing import Any, Dict, List, Optional, Tuple import pysftp @@ -329,3 +330,21 @@ def test_connection(self) -> Tuple[bool, str]: return True, "Connection successfully tested" except Exception as e: return False, str(e) + + def get_file_by_pattern(self, path, fnmatch_pattern) -> str: + """ + Returning the first matching file based on the given fnmatch type pattern + + :param path: path to be checked + :param fnmatch_pattern: The pattern that will be matched with `fnmatch` + :return: string containing the first found file, or an empty string if none matched + """ + files_list = self.list_directory(path) + + for file in files_list: + if not fnmatch(file, fnmatch_pattern): + pass + else: + return file + + return "" diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py index 904321e9b80c0..757a23b1d8b8b 100644 --- a/airflow/providers/sftp/sensors/sftp.py +++ b/airflow/providers/sftp/sensors/sftp.py @@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator): Waits for a file or directory to be present on SFTP. :param path: Remote file or directory path + :param file_pattern: The pattern that will be used to match the file (fnmatch format) :param sftp_conn_id: The connection to run the sensor against :param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive """ @@ -47,22 +48,33 @@ def __init__( self, *, path: str, + file_pattern: str = "", newer_than: Optional[datetime] = None, sftp_conn_id: str = 'sftp_default', **kwargs, ) -> None: super().__init__(**kwargs) self.path = path + self.file_pattern = file_pattern self.hook: Optional[SFTPHook] = None self.sftp_conn_id = sftp_conn_id self.newer_than: Optional[datetime] = newer_than + self.actual_file_to_check = self.path def poke(self, context: 'Context') -> bool: self.hook = SFTPHook(self.sftp_conn_id) - self.log.info('Poking for %s', self.path) + self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}") + + if self.file_pattern: + file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern) + if file_from_pattern: + self.actual_file_to_check = file_from_pattern + else: + return False + try: - mod_time = self.hook.get_mod_time(self.path) - self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time)) + mod_time = self.hook.get_mod_time(self.actual_file_to_check) + self.log.info('Found File %s last modified: %s', str(self.actual_file_to_check), str(mod_time)) except OSError as e: if e.errno != SFTP_NO_SUCH_FILE: raise e diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py index 9c63402054a8d..95bb971bdfc8d 100644 --- a/tests/providers/sftp/hooks/test_sftp.py +++ b/tests/providers/sftp/hooks/test_sftp.py @@ -43,6 +43,8 @@ def generate_host_key(pkey: paramiko.PKey): TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' SUB_DIR = "sub_dir" TMP_FILE_FOR_TESTS = 'test_file.txt' +ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt' +LOG_FILE_FOR_TESTS = 'test_log.log' SFTP_CONNECTION_USER = "root" @@ -60,13 +62,18 @@ def update_connection(self, login, session=None): session.commit() return old_login + def _create_additional_test_file(self, file_name): + with open(os.path.join(TMP_PATH, file_name), 'a') as file: + file.write('Test file') + def setUp(self): self.old_login = self.update_connection(SFTP_CONNECTION_USER) self.hook = SFTPHook() os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)) - with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file: - file.write('Test file') + for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]: + with open(os.path.join(TMP_PATH, file_name), 'a') as file: + file.write('Test file') with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file: file.write('Test file') @@ -353,7 +360,32 @@ def test_deprecation_ftp_conn_id(self, mock_get_connection): # Default is 'sftp_default assert SFTPHook().ssh_conn_id == 'sftp_default' + def test_get_suffix_pattern_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt") + self.assertTrue(output, TMP_FILE_FOR_TESTS) + + def test_get_prefix_pattern_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "test*") + self.assertTrue(output, TMP_FILE_FOR_TESTS) + + def test_get_pattern_not_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "*.text") + self.assertFalse(output) + + def test_get_several_pattern_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "*.log") + self.assertEqual(LOG_FILE_FOR_TESTS, output) + + def test_get_first_pattern_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt") + self.assertEqual(TMP_FILE_FOR_TESTS, output) + + def test_get_middle_pattern_match(self): + output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt") + self.assertEqual(ANOTHER_FILE_FOR_TESTS, output) + def tearDown(self): shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)) + for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]: + os.remove(os.path.join(TMP_PATH, file_name)) self.update_connection(self.old_login) diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py index f7c26495bffe6..1bb6c71068db1 100644 --- a/tests/providers/sftp/sensors/test_sftp.py +++ b/tests/providers/sftp/sensors/test_sftp.py @@ -97,3 +97,31 @@ def test_naive_datetime(self, sftp_hook_mock): output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt') assert not output + + @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') + def test_file_with_pattern_parameter_call(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt") + context = {'ds': '1970-01-01'} + output = sftp_sensor.poke(context) + sftp_hook_mock.return_value.get_file_by_pattern.assert_called_once_with('/path/to/file/', '*.txt') + assert output + + @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') + def test_file_present_with_pattern(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' + sftp_hook_mock.return_value.get_file_by_pattern.return_value = '/path/to/file/text_file.txt' + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt") + context = {'ds': '1970-01-01'} + output = sftp_sensor.poke(context) + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/text_file.txt') + assert output + + @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') + def test_file_not_present_with_pattern(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' + sftp_hook_mock.return_value.get_file_by_pattern.return_value = "" + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt") + context = {'ds': '1970-01-01'} + output = sftp_sensor.poke(context) + assert not output