Skip to content

Commit

Permalink
Convert sftp hook to use paramiko instead of pysftp (#24512)
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldalewilliams authored Jun 19, 2022
1 parent f48112c commit f3aaceb
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 214 deletions.
254 changes: 151 additions & 103 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
# under the License.
"""This module contains SFTP hook."""
import datetime
import os
import stat
import warnings
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pysftp
import tenacity
from paramiko import SSHException
import paramiko

from airflow.exceptions import AirflowException
from airflow.providers.ssh.hooks.ssh import SSHHook


Expand All @@ -49,11 +49,9 @@ class SFTPHook(SSHHook):
Errors that may occur throughout but should be handled downstream.
For consistency reasons with SSHHook, the preferred parameter is "ssh_conn_id".
Please note that it is still possible to use the parameter "ftp_conn_id"
to initialize the hook, but it will be removed in future Airflow versions.
:param ssh_conn_id: The :ref:`sftp connection id<howto/connection:sftp>`
:param ftp_conn_id (Outdated): The :ref:`sftp connection id<howto/connection:sftp>`
:param ssh_hook: Optional SSH hook (included to support passing of an SSH hook to the SFTP operator)
"""

conn_name_attr = 'ssh_conn_id'
Expand All @@ -73,9 +71,29 @@ def get_ui_field_behaviour() -> Dict[str, Any]:
def __init__(
self,
ssh_conn_id: Optional[str] = 'sftp_default',
ssh_hook: Optional[SSHHook] = None,
*args,
**kwargs,
) -> None:
self.conn: Optional[paramiko.SFTPClient] = None

# TODO: remove support for ssh_hook when it is removed from SFTPOperator
self.ssh_hook = ssh_hook

if self.ssh_hook is not None:
warnings.warn(
'Parameter `ssh_hook` is deprecated and will be removed in a future version.',
DeprecationWarning,
stacklevel=2,
)
if not isinstance(self.ssh_hook, SSHHook):
raise AirflowException(
f'ssh_hook must be an instance of SSHHook, but got {type(self.ssh_hook)}'
)
self.log.info('ssh_hook is provided. It will be used to generate SFTP connection.')
self.ssh_conn_id = self.ssh_hook.ssh_conn_id
return

ftp_conn_id = kwargs.pop('ftp_conn_id', None)
if ftp_conn_id:
warnings.warn(
Expand All @@ -84,114 +102,47 @@ def __init__(
stacklevel=2,
)
ssh_conn_id = ftp_conn_id

kwargs['ssh_conn_id'] = ssh_conn_id
self.ssh_conn_id = ssh_conn_id

super().__init__(*args, **kwargs)

self.conn = None
self.private_key_pass = None
self.ciphers = None

# Fail for unverified hosts, unless this is explicitly allowed
self.no_host_key_check = False

if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if conn.extra is not None:
extra_options = conn.extra_dejson

# For backward compatibility
# TODO: remove in the next major provider release.

if 'private_key_pass' in extra_options:
warnings.warn(
'Extra option `private_key_pass` is deprecated.'
'Please use `private_key_passphrase` instead.'
'`private_key_passphrase` will precede if both options are specified.'
'The old option `private_key_pass` will be removed in a future release.',
DeprecationWarning,
stacklevel=2,
)
self.private_key_pass = extra_options.get(
'private_key_passphrase', extra_options.get('private_key_pass')
)
def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override]
"""
Opens an SFTP connection to the remote host
if 'ignore_hostkey_verification' in extra_options:
warnings.warn(
'Extra option `ignore_hostkey_verification` is deprecated.'
'Please use `no_host_key_check` instead.'
'This option will be removed in a future release.',
DeprecationWarning,
stacklevel=2,
)
self.no_host_key_check = (
str(extra_options['ignore_hostkey_verification']).lower() == 'true'
)

if 'no_host_key_check' in extra_options:
self.no_host_key_check = str(extra_options['no_host_key_check']).lower() == 'true'

if 'ciphers' in extra_options:
self.ciphers = extra_options['ciphers']

@tenacity.retry(
stop=tenacity.stop_after_delay(10),
wait=tenacity.wait_exponential(multiplier=1, max=10),
retry=tenacity.retry_if_exception_type(SSHException),
reraise=True,
)
def get_conn(self) -> pysftp.Connection:
"""Returns an SFTP connection object"""
:rtype: paramiko.SFTPClient
"""
if self.conn is None:
cnopts = pysftp.CnOpts()
if self.no_host_key_check:
cnopts.hostkeys = None
# TODO: remove support for ssh_hook when it is removed from SFTPOperator
if self.ssh_hook is not None:
self.conn = self.ssh_hook.get_conn().open_sftp()
else:
if self.host_key is not None:
cnopts.hostkeys.add(self.remote_host, self.host_key.get_name(), self.host_key)
else:
pass # will fallback to system host keys if none explicitly specified in conn extra

cnopts.compression = self.compress
cnopts.ciphers = self.ciphers
conn_params = {
'host': self.remote_host,
'port': self.port,
'username': self.username,
'cnopts': cnopts,
}
if self.password and self.password.strip():
conn_params['password'] = self.password
if self.pkey:
conn_params['private_key'] = self.pkey
elif self.key_file:
conn_params['private_key'] = self.key_file
if self.private_key_pass:
conn_params['private_key_pass'] = self.private_key_pass

self.conn = pysftp.Connection(**conn_params)
self.conn = super().get_conn().open_sftp()
return self.conn

def close_conn(self) -> None:
"""Closes the connection"""
"""Closes the SFTP connection"""
if self.conn is not None:
self.conn.close()
self.conn = None

def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
def describe_directory(self, path: str) -> Dict[str, Dict[str, Union[str, int, None]]]:
"""
Returns a dictionary of {filename: {attributes}} for all files
on the remote system (where the MLSD command is supported).
:param path: full path to the remote directory
"""
conn = self.get_conn()
flist = conn.listdir_attr(path)
flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename)
files = {}
for f in flist:
modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S')
modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S') # type: ignore
files[f.filename] = {
'size': f.st_size,
'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', # type: ignore
'modify': modify,
}
return files
Expand All @@ -203,9 +154,45 @@ def list_directory(self, path: str) -> List[str]:
:param path: full path to the remote directory to list
"""
conn = self.get_conn()
files = conn.listdir(path)
files = sorted(conn.listdir(path))
return files

def mkdir(self, path: str, mode: int = 777) -> None:
    """
    Create a single directory on the remote system.

    :param path: full path to the remote directory to create
    :param mode: permissions for the new directory, written as the decimal
        digits of the octal value (``755`` means ``0o755``)
    """
    sftp_client = self.get_conn()
    # Re-read the decimal digits as an octal literal: 777 -> 0o777.
    octal_mode = int(str(mode), 8)
    sftp_client.mkdir(path, mode=octal_mode)

def isdir(self, path: str) -> bool:
    """
    Tell whether the given remote path is a directory.

    A path that cannot be stat-ed (``OSError``) is reported as ``False``.

    :param path: full path to the remote directory to check
    """
    sftp_client = self.get_conn()
    try:
        attributes = sftp_client.stat(path)
    except OSError:
        return False
    return stat.S_ISDIR(attributes.st_mode)  # type: ignore

def isfile(self, path: str) -> bool:
    """
    Tell whether the given remote path is a regular file.

    A path that cannot be stat-ed (``OSError``) is reported as ``False``.

    :param path: full path to the remote file to check
    """
    sftp_client = self.get_conn()
    try:
        attributes = sftp_client.stat(path)
    except OSError:
        return False
    return stat.S_ISREG(attributes.st_mode)  # type: ignore

def create_directory(self, path: str, mode: int = 777) -> None:
    """
    Creates a directory on the remote system, together with any missing parents.

    If ``path`` is already a directory, this only logs that fact and returns.

    :param path: full path to the remote directory to create
    :param mode: int representation of octal mode for directory
        (``755`` means ``0o755``), the same convention ``mkdir`` uses
    :raises AirflowException: if ``path`` already exists and is a file
    """
    conn = self.get_conn()
    if self.isdir(path):
        self.log.info("%s already exists", path)
        return
    if self.isfile(path):
        raise AirflowException(f"{path} already exists and is a file")

    dirname, basename = os.path.split(path)
    # Build missing parents first; the un-converted mode is passed down so
    # every level applies the same decimal-digits-to-octal conversion.
    if dirname and not self.isdir(dirname):
        self.create_directory(dirname, mode)
    # basename is empty for a path with a trailing slash; in that case the
    # recursive call above has already created the directory.
    if basename:
        self.log.info("Creating %s", path)
        # paramiko expects the numeric mode itself, so the decimal digits must
        # be re-read as octal (777 -> 0o777). Passing 777 unchanged would
        # create the directory with mode 0o1411.
        conn.mkdir(path, mode=int(str(mode), 8))

def delete_directory(self, path: str) -> None:
"""
Expand All @@ -237,7 +235,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
conn = self.get_conn()
conn.get(remote_full_path, local_full_path)

def store_file(self, remote_full_path: str, local_full_path: str) -> None:
def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""
Transfers a local file to the remote location.
If local_full_path_or_buffer is a string path, the file will be read
Expand All @@ -247,7 +245,7 @@ def store_file(self, remote_full_path: str, local_full_path: str) -> None:
:param local_full_path: full path to the local file
"""
conn = self.get_conn()
conn.put(local_full_path, remote_full_path)
conn.put(local_full_path, remote_full_path, confirm=confirm)

def delete_file(self, path: str) -> None:
"""
Expand All @@ -266,7 +264,7 @@ def get_mod_time(self, path: str) -> str:
"""
conn = self.get_conn()
ftp_mdtm = conn.stat(path).st_mtime
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S') # type: ignore

def path_exists(self, path: str) -> bool:
"""
Expand All @@ -275,7 +273,11 @@ def path_exists(self, path: str) -> bool:
:param path: full path to the remote file or directory
"""
conn = self.get_conn()
return conn.exists(path)
try:
conn.stat(path)
except OSError:
return False
return True

@staticmethod
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
Expand All @@ -293,6 +295,51 @@ def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[
return False
return True

def walktree(
    self,
    path: str,
    fcallback: Callable[[str], Optional[Any]],
    dcallback: Callable[[str], Optional[Any]],
    ucallback: Callable[[str], Optional[Any]],
    recurse: bool = True,
) -> None:
    """
    Walk the remote directory tree rooted at ``path`` depth first, invoking
    a discrete callback for every regular file, directory and entry of
    any other type.

    :param str path:
        root of remote directory to descend, use '.' to start at
        :attr:`.pwd`
    :param callable fcallback:
        callback function to invoke for a regular file.
        (form: ``func(str)``)
    :param callable dcallback:
        callback function to invoke for a directory. (form: ``func(str)``)
    :param callable ucallback:
        callback function to invoke for an unknown file type.
        (form: ``func(str)``)
    :param bool recurse: *Default: True* - should it recurse
    :returns: None
    """
    sftp_client = self.get_conn()
    for name in self.list_directory(path):
        child = os.path.join(path, name)
        child_mode = sftp_client.stat(child).st_mode

        if stat.S_ISREG(child_mode):  # type: ignore
            # Regular file.
            fcallback(child)
            continue

        if not stat.S_ISDIR(child_mode):  # type: ignore
            # Neither a file nor a directory: unknown type.
            ucallback(child)
            continue

        # Directory: report it before (optionally) descending into it.
        dcallback(child)
        if recurse:
            self.walktree(child, fcallback, dcallback, ucallback)

def get_tree_map(
self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
) -> Tuple[List[str], List[str], List[str]]:
Expand All @@ -306,14 +353,15 @@ def get_tree_map(
:return: tuple with list of files, dirs and unknown items
:rtype: Tuple[List[str], List[str], List[str]]
"""
conn = self.get_conn()
files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
files: List[str] = []
dirs: List[str] = []
unknowns: List[str] = []

def append_matching_path_callback(list_):
def append_matching_path_callback(list_: List[str]) -> Callable:
return lambda item: list_.append(item) if self._is_path_match(item, prefix, delimiter) else None

conn.walktree(
remotepath=path,
self.walktree(
path=path,
fcallback=append_matching_path_callback(files),
dcallback=append_matching_path_callback(dirs),
ucallback=append_matching_path_callback(unknowns),
Expand All @@ -326,7 +374,7 @@ def test_connection(self) -> Tuple[bool, str]:
"""Test the SFTP connection by calling path with directory"""
try:
conn = self.get_conn()
conn.pwd
conn.normalize('.')
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
Expand Down
Loading

0 comments on commit f3aaceb

Please sign in to comment.