Skip to content

Commit

Permalink
Adding fnmatch type regex to SFTPSensor (#24084)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkruc authored Jun 6, 2022
1 parent ec84ffe commit e656e1d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 6 deletions.
19 changes: 19 additions & 0 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import datetime
import stat
import warnings
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Tuple

import pysftp
Expand Down Expand Up @@ -329,3 +330,21 @@ def test_connection(self) -> Tuple[bool, str]:
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)

def get_file_by_pattern(self, path, fnmatch_pattern) -> str:
"""
Returning the first matching file based on the given fnmatch type pattern
:param path: path to be checked
:param fnmatch_pattern: The pattern that will be matched with `fnmatch`
:return: string containing the first found file, or an empty string if none matched
"""
files_list = self.list_directory(path)

for file in files_list:
if not fnmatch(file, fnmatch_pattern):
pass
else:
return file

return ""
18 changes: 15 additions & 3 deletions airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator):
Waits for a file or directory to be present on SFTP.
:param path: Remote file or directory path
:param file_pattern: The pattern that will be used to match the file (fnmatch format)
:param sftp_conn_id: The connection to run the sensor against
:param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive
"""
Expand All @@ -47,22 +48,33 @@ def __init__(
self,
*,
path: str,
file_pattern: str = "",
newer_than: Optional[datetime] = None,
sftp_conn_id: str = 'sftp_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.path = path
self.file_pattern = file_pattern
self.hook: Optional[SFTPHook] = None
self.sftp_conn_id = sftp_conn_id
self.newer_than: Optional[datetime] = newer_than
self.actual_file_to_check = self.path

def poke(self, context: 'Context') -> bool:
self.hook = SFTPHook(self.sftp_conn_id)
self.log.info('Poking for %s', self.path)
self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}")

if self.file_pattern:
file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
if file_from_pattern:
self.actual_file_to_check = file_from_pattern
else:
return False

try:
mod_time = self.hook.get_mod_time(self.path)
self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time))
mod_time = self.hook.get_mod_time(self.actual_file_to_check)
self.log.info('Found File %s last modified: %s', str(self.actual_file_to_check), str(mod_time))
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
raise e
Expand Down
38 changes: 35 additions & 3 deletions tests/providers/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def generate_host_key(pkey: paramiko.PKey):
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"
TMP_FILE_FOR_TESTS = 'test_file.txt'
ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt'
LOG_FILE_FOR_TESTS = 'test_log.log'

SFTP_CONNECTION_USER = "root"

Expand All @@ -60,13 +62,18 @@ def update_connection(self, login, session=None):
session.commit()
return old_login

def _create_additional_test_file(self, file_name):
with open(os.path.join(TMP_PATH, file_name), 'a') as file:
file.write('Test file')

def setUp(self):
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
self.hook = SFTPHook()
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')
for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
with open(os.path.join(TMP_PATH, file_name), 'a') as file:
file.write('Test file')
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
file.write('Test file')

Expand Down Expand Up @@ -353,7 +360,32 @@ def test_deprecation_ftp_conn_id(self, mock_get_connection):
# Default is 'sftp_default
assert SFTPHook().ssh_conn_id == 'sftp_default'

def test_get_suffix_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt")
self.assertTrue(output, TMP_FILE_FOR_TESTS)

def test_get_prefix_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "test*")
self.assertTrue(output, TMP_FILE_FOR_TESTS)

def test_get_pattern_not_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "*.text")
self.assertFalse(output)

def test_get_several_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "*.log")
self.assertEqual(LOG_FILE_FOR_TESTS, output)

def test_get_first_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt")
self.assertEqual(TMP_FILE_FOR_TESTS, output)

def test_get_middle_pattern_match(self):
output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt")
self.assertEqual(ANOTHER_FILE_FOR_TESTS, output)

def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
os.remove(os.path.join(TMP_PATH, file_name))
self.update_connection(self.old_login)
28 changes: 28 additions & 0 deletions tests/providers/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,31 @@ def test_naive_datetime(self, sftp_hook_mock):
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
assert not output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_with_pattern_parameter_call(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_file_by_pattern.assert_called_once_with('/path/to/file/', '*.txt')
assert output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_present_with_pattern(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
sftp_hook_mock.return_value.get_file_by_pattern.return_value = '/path/to/file/text_file.txt'
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/text_file.txt')
assert output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_not_present_with_pattern(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
sftp_hook_mock.return_value.get_file_by_pattern.return_value = ""
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
assert not output

0 comments on commit e656e1d

Please sign in to comment.