Commit 3fe948a

sftp_to_s3 stream file option (#17609)
john-jac authored Sep 8, 2021
1 parent ff64fe8 commit 3fe948a

Showing 2 changed files with 21 additions and 4 deletions.
15 changes: 12 additions & 3 deletions airflow/providers/amazon/aws/transfers/sftp_to_s3.py

@@ -47,6 +47,9 @@ class SFTPToS3Operator(BaseOperator):
     :param s3_key: The targeted s3 key. This is the specified path for
         uploading the file to S3.
     :type s3_key: str
+    :param use_temp_file: If True, copies file first to local,
+        if False streams file from SFTP to S3.
+    :type use_temp_file: bool
     """
 
     template_fields = ('s3_key', 'sftp_path')
@@ -59,6 +62,7 @@ def __init__(
         sftp_path: str,
         sftp_conn_id: str = 'ssh_default',
         s3_conn_id: str = 'aws_default',
+        use_temp_file: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -67,6 +71,7 @@ def __init__(
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
         self.s3_conn_id = s3_conn_id
+        self.use_temp_file = use_temp_file
 
     @staticmethod
     def get_s3_key(s3_key: str) -> str:
@@ -81,7 +86,11 @@ def execute(self, context) -> None:
 
         sftp_client = ssh_hook.get_conn().open_sftp()
 
-        with NamedTemporaryFile("w") as f:
-            sftp_client.get(self.sftp_path, f.name)
+        if self.use_temp_file:
+            with NamedTemporaryFile("w") as f:
+                sftp_client.get(self.sftp_path, f.name)
 
-            s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
+                s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
+        else:
+            with sftp_client.file(self.sftp_path, mode='rb') as data:
+                s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info)
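
For orientation, here is a minimal DAG sketch showing both modes of the operator. The DAG id, dates, paths, bucket, and key are placeholders invented for illustration; the connection IDs reuse the operator's defaults from the diff above.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator

with DAG(dag_id="sftp_to_s3_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    # Default behaviour (use_temp_file=True): download the remote file to a
    # local temporary file, then upload that file to S3.
    buffered = SFTPToS3Operator(
        task_id="sftp_to_s3_buffered",
        sftp_conn_id="ssh_default",
        s3_conn_id="aws_default",
        sftp_path="/remote/path/data.csv",
        s3_bucket="example-bucket",
        s3_key="incoming/data.csv",
        use_temp_file=True,
    )

    # New in this commit (use_temp_file=False): hand the open SFTP file object
    # to S3 directly, so nothing is written to the worker's local disk.
    streamed = SFTPToS3Operator(
        task_id="sftp_to_s3_streamed",
        sftp_conn_id="ssh_default",
        s3_conn_id="aws_default",
        sftp_path="/remote/path/data.csv",
        s3_bucket="example-bucket",
        s3_key="incoming/data.csv",
        use_temp_file=False,
    )

With use_temp_file=False the operator passes the SFTP file object to boto3's upload_fileobj, reporting bytes transferred through the Callback hook shown in the diff.
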
10 changes: 9 additions & 1 deletion tests/providers/amazon/aws/transfers/test_sftp_to_s3.py

@@ -20,6 +20,7 @@
 
 import boto3
 from moto import mock_s3
+from parameterized import parameterized
 
 from airflow.models import DAG
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -66,9 +67,15 @@ def setUp(self):
         self.sftp_path = SFTP_PATH
         self.s3_key = S3_KEY
 
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
     @mock_s3
     @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
-    def test_sftp_to_s3_operation(self):
+    def test_sftp_to_s3_operation(self, use_temp_file=True):
         # Setting
         test_remote_file_content = (
             "This is remote file content \n which is also multiline "
@@ -98,6 +105,7 @@ def test_sftp_to_s3_operation(self):
             sftp_path=SFTP_PATH,
             sftp_conn_id=SFTP_CONN_ID,
             s3_conn_id=S3_CONN_ID,
+            use_temp_file=use_temp_file,
             task_id='test_sftp_to_s3',
             dag=self.dag,
         )
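
The test is driven by parameterized.expand, which generates one test case per argument tuple so both code paths are exercised. A minimal standalone sketch of that mechanism (a toy test class written for illustration, not taken from the Airflow test suite):

import unittest

from parameterized import parameterized


class ExampleParameterizedTest(unittest.TestCase):
    # parameterized.expand turns this single method into two generated tests,
    # one called with use_temp_file=True and one with use_temp_file=False.
    @parameterized.expand([(True,), (False,)])
    def test_both_modes(self, use_temp_file):
        self.assertIsInstance(use_temp_file, bool)


if __name__ == "__main__":
    unittest.main()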
