Skip to content

Commit

Permalink
SSL Bucket, Light Logic Refactor and Docstring Update for Alibaba Provider (#23891)
Browse the repository at this point in the history
  • Loading branch information
Vincent Koc authored May 31, 2022
1 parent 8804b1a commit d19cb86
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 44 deletions.
48 changes: 22 additions & 26 deletions airflow/providers/alibaba/cloud/hooks/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,10 @@ def provide_bucket_name(func: T) -> T:
def wrapper(*args, **kwargs) -> T:
    """Call ``func``, defaulting ``bucket_name`` to the connection schema when it is not given."""
    bound_args = function_signature.bind(*args, **kwargs)
    self = args[0]
    # Only consult the connection when the caller supplied no bucket name
    # (argument omitted or explicitly None) and a connection id is set.
    if bound_args.arguments.get('bucket_name') is None and self.oss_conn_id:
        connection = self.get_connection(self.oss_conn_id)
        if connection.schema:
            bound_args.arguments['bucket_name'] = connection.schema

    return func(*bound_args.args, **bound_args.kwargs)


Expand Down Expand Up @@ -92,10 +91,7 @@ class OSSHook(BaseHook):
def __init__(self, region: Optional[str] = None, oss_conn_id='oss_default', *args, **kwargs) -> None:
    """
    Store the connection id, load the connection and resolve the region.

    :param region: region to use; when ``None`` the value returned by
        ``get_default_region()`` is used instead
    :param oss_conn_id: id of the connection passed to ``get_connection``
    """
    self.oss_conn_id = oss_conn_id
    # get_default_region() reads self.oss_conn, so the connection is loaded first.
    self.oss_conn = self.get_connection(oss_conn_id)
    self.region = self.get_default_region() if region is None else region
    super().__init__(*args, **kwargs)

def get_conn(self) -> "Connection":
Expand Down Expand Up @@ -148,7 +144,7 @@ def get_bucket(self, bucket_name: Optional[str] = None) -> oss2.api.Bucket:
"""
auth = self.get_credential()
assert self.region is not None
return oss2.Bucket(auth, 'http://oss-' + self.region + '.aliyuncs.com', bucket_name)
return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name)

@provide_bucket_name
@unify_bucket_name_and_key
Expand Down Expand Up @@ -352,28 +348,28 @@ def get_credential(self) -> oss2.auth.Auth:
if not auth_type:
raise Exception("No auth_type specified in extra_config. ")

if auth_type == 'AK':
oss_access_key_id = extra_config.get('access_key_id', None)
oss_access_key_secret = extra_config.get('access_key_secret', None)
if not oss_access_key_id:
raise Exception("No access_key_id is specified for connection: " + self.oss_conn_id)
if not oss_access_key_secret:
raise Exception("No access_key_secret is specified for connection: " + self.oss_conn_id)
return oss2.Auth(oss_access_key_id, oss_access_key_secret)
else:
raise Exception("Unsupported auth_type: " + auth_type)
if auth_type != 'AK':
raise Exception(f"Unsupported auth_type: {auth_type}")
oss_access_key_id = extra_config.get('access_key_id', None)
oss_access_key_secret = extra_config.get('access_key_secret', None)
if not oss_access_key_id:
raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}")

if not oss_access_key_secret:
raise Exception(f"No access_key_secret is specified for connection: {self.oss_conn_id}")

return oss2.Auth(oss_access_key_id, oss_access_key_secret)

def get_default_region(self) -> Optional[str]:
    """
    Return the region configured in the connection's extra JSON.

    Reads ``auth_type`` and ``region`` from ``self.oss_conn.extra_dejson``.

    :return: the value of the ``region`` extra field
    :raises Exception: if ``auth_type`` is missing, if it is not ``'AK'``,
        or if ``region`` is missing or empty
    """
    extra_config = self.oss_conn.extra_dejson
    auth_type = extra_config.get('auth_type', None)
    if not auth_type:
        raise Exception("No auth_type specified in extra_config. ")

    # 'AK' is the only auth_type this method accepts.
    if auth_type != 'AK':
        raise Exception(f"Unsupported auth_type: {auth_type}")

    default_region = extra_config.get('region', None)
    if not default_region:
        raise Exception(f"No region is specified for connection: {self.oss_conn_id}")
    return default_region
31 changes: 15 additions & 16 deletions airflow/providers/alibaba/cloud/log/oss_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import os
import pathlib
import sys

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -61,6 +63,7 @@ def hook(self):
)

def set_context(self, ti):
"""This function is used to set the context of the handler"""
super().set_context(ti)
# Local location and remote location is needed to open and
# upload local log file to OSS remote storage.
Expand Down Expand Up @@ -91,8 +94,7 @@ def close(self):
remote_loc = self.log_relative_path
if os.path.exists(local_loc):
# read log and remove old logs to get just the latest additions
with open(local_loc) as logfile:
log = logfile.read()
log = pathlib.Path(local_loc).read_text()
self.oss_write(log, remote_loc)

# Mark closed so we don't double write if close is called twice
Expand All @@ -114,15 +116,14 @@ def _read(self, ti, try_number, metadata=None):
log_relative_path = self._render_filename(ti, try_number)
remote_loc = log_relative_path

if self.oss_log_exists(remote_loc):
# If OSS remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.oss_read(remote_loc, return_error=True)
log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
return log, {'end_of_log': True}
else:
if not self.oss_log_exists(remote_loc):
return super()._read(ti, try_number)
# If OSS remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.oss_read(remote_loc, return_error=True)
log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
return log, {'end_of_log': True}

def oss_log_exists(self, remote_log_location):
"""
Expand All @@ -131,11 +132,9 @@ def oss_log_exists(self, remote_log_location):
:param remote_log_location: log's location in remote storage
:return: True if location exists else False
"""
oss_remote_log_location = self.base_folder + '/' + remote_log_location
try:
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
with contextlib.suppress(Exception):
return self.hook.key_exist(self.bucket_name, oss_remote_log_location)
except Exception:
pass
return False

def oss_read(self, remote_log_location, return_error=False):
Expand All @@ -148,7 +147,7 @@ def oss_read(self, remote_log_location, return_error=False):
error occurs. Otherwise returns '' when an error occurs.
"""
try:
oss_remote_log_location = self.base_folder + '/' + remote_log_location
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
self.log.info("read remote log: %s", oss_remote_log_location)
return self.hook.read_key(self.bucket_name, oss_remote_log_location)
except Exception:
Expand All @@ -168,7 +167,7 @@ def oss_write(self, log, remote_log_location, append=True):
:param append: if False, any existing log file is overwritten. If True,
the new log is appended to any existing logs.
"""
oss_remote_log_location = self.base_folder + '/' + remote_log_location
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
pos = 0
if append and self.oss_log_exists(oss_remote_log_location):
head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/alibaba/cloud/sensors/oss_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def __init__(
self.hook: Optional[OSSHook] = None

def poke(self, context: 'Context'):

"""
Check if the object exists in the bucket to pull key.

:param context: the context of the object
:return: True if the object exists, False otherwise
"""
if self.bucket_name is None:
parsed_url = urlparse(self.bucket_key)
if parsed_url.netloc == '':
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/alibaba/cloud/hooks/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_get_bucket(self, mock_oss2, mock_get_credential):
self.hook.get_bucket('mock_bucket_name')
mock_get_credential.assert_called_once_with()
mock_oss2.Bucket.assert_called_once_with(
mock_get_credential.return_value, 'http://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME
mock_get_credential.return_value, 'https://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME
)

@mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
Expand Down

0 comments on commit d19cb86

Please sign in to comment.