diff --git a/airflow/providers/alibaba/cloud/hooks/oss.py b/airflow/providers/alibaba/cloud/hooks/oss.py index 23bbde19d41c2..08272adb25e70 100644 --- a/airflow/providers/alibaba/cloud/hooks/oss.py +++ b/airflow/providers/alibaba/cloud/hooks/oss.py @@ -43,11 +43,10 @@ def provide_bucket_name(func: T) -> T: def wrapper(*args, **kwargs) -> T: bound_args = function_signature.bind(*args, **kwargs) self = args[0] - if 'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None: - if self.oss_conn_id: - connection = self.get_connection(self.oss_conn_id) - if connection.schema: - bound_args.arguments['bucket_name'] = connection.schema + if bound_args.arguments.get('bucket_name') is None and self.oss_conn_id: + connection = self.get_connection(self.oss_conn_id) + if connection.schema: + bound_args.arguments['bucket_name'] = connection.schema return func(*bound_args.args, **bound_args.kwargs) @@ -92,10 +91,7 @@ class OSSHook(BaseHook): def __init__(self, region: Optional[str] = None, oss_conn_id='oss_default', *args, **kwargs) -> None: self.oss_conn_id = oss_conn_id self.oss_conn = self.get_connection(oss_conn_id) - if region is None: - self.region = self.get_default_region() - else: - self.region = region + self.region = self.get_default_region() if region is None else region super().__init__(*args, **kwargs) def get_conn(self) -> "Connection": @@ -148,7 +144,7 @@ def get_bucket(self, bucket_name: Optional[str] = None) -> oss2.api.Bucket: """ auth = self.get_credential() assert self.region is not None - return oss2.Bucket(auth, 'http://oss-' + self.region + '.aliyuncs.com', bucket_name) + return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name) @provide_bucket_name @unify_bucket_name_and_key @@ -352,16 +348,17 @@ def get_credential(self) -> oss2.auth.Auth: if not auth_type: raise Exception("No auth_type specified in extra_config. ") - if auth_type == 'AK': - oss_access_key_id = extra_config.get('access_key_id', None) - oss_access_key_secret = extra_config.get('access_key_secret', None) - if not oss_access_key_id: - raise Exception("No access_key_id is specified for connection: " + self.oss_conn_id) - if not oss_access_key_secret: - raise Exception("No access_key_secret is specified for connection: " + self.oss_conn_id) - return oss2.Auth(oss_access_key_id, oss_access_key_secret) - else: - raise Exception("Unsupported auth_type: " + auth_type) + if auth_type != 'AK': + raise Exception(f"Unsupported auth_type: {auth_type}") + oss_access_key_id = extra_config.get('access_key_id', None) + oss_access_key_secret = extra_config.get('access_key_secret', None) + if not oss_access_key_id: + raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}") + + if not oss_access_key_secret: + raise Exception(f"No access_key_secret is specified for connection: {self.oss_conn_id}") + + return oss2.Auth(oss_access_key_id, oss_access_key_secret) def get_default_region(self) -> Optional[str]: extra_config = self.oss_conn.extra_dejson @@ -369,11 +366,10 @@ def get_default_region(self) -> Optional[str]: if not auth_type: raise Exception("No auth_type specified in extra_config. ") - if auth_type == 'AK': - default_region = extra_config.get('region', None) - if not default_region: - raise Exception("No region is specified for connection: " + self.oss_conn_id) - else: - raise Exception("Unsupported auth_type: " + auth_type) + if auth_type != 'AK': + raise Exception(f"Unsupported auth_type: {auth_type}") + default_region = extra_config.get('region', None) + if not default_region: + raise Exception(f"No region is specified for connection: {self.oss_conn_id}") return default_region diff --git a/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/airflow/providers/alibaba/cloud/log/oss_task_handler.py index d26bfbfd048dc..bb1d801ac9821 100644 --- a/airflow/providers/alibaba/cloud/log/oss_task_handler.py +++ b/airflow/providers/alibaba/cloud/log/oss_task_handler.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import contextlib import os +import pathlib import sys if sys.version_info >= (3, 8): @@ -61,6 +63,7 @@ def hook(self): ) def set_context(self, ti): + """This function is used to set the context of the handler""" super().set_context(ti) # Local location and remote location is needed to open and # upload local log file to OSS remote storage. @@ -91,8 +94,7 @@ def close(self): remote_loc = self.log_relative_path if os.path.exists(local_loc): # read log and remove old logs to get just the latest additions - with open(local_loc) as logfile: - log = logfile.read() + log = pathlib.Path(local_loc).read_text() self.oss_write(log, remote_loc) # Mark closed so we don't double write if close is called twice @@ -114,15 +116,14 @@ def _read(self, ti, try_number, metadata=None): log_relative_path = self._render_filename(ti, try_number) remote_loc = log_relative_path - if self.oss_log_exists(remote_loc): - # If OSS remote file exists, we do not fetch logs from task instance - # local machine even if there are errors reading remote logs, as - # returned remote_log will contain error messages. - remote_log = self.oss_read(remote_loc, return_error=True) - log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' - return log, {'end_of_log': True} - else: + if not self.oss_log_exists(remote_loc): return super()._read(ti, try_number) + # If OSS remote file exists, we do not fetch logs from task instance + # local machine even if there are errors reading remote logs, as + # returned remote_log will contain error messages. + remote_log = self.oss_read(remote_loc, return_error=True) + log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' + return log, {'end_of_log': True} def oss_log_exists(self, remote_log_location): """ @@ -131,11 +132,9 @@ def oss_log_exists(self, remote_log_location): :param remote_log_location: log's location in remote storage :return: True if location exists else False """ - oss_remote_log_location = self.base_folder + '/' + remote_log_location - try: + oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' + with contextlib.suppress(Exception): return self.hook.key_exist(self.bucket_name, oss_remote_log_location) - except Exception: - pass return False def oss_read(self, remote_log_location, return_error=False): @@ -148,7 +147,7 @@ def oss_read(self, remote_log_location, return_error=False): error occurs. Otherwise returns '' when an error occurs. """ try: - oss_remote_log_location = self.base_folder + '/' + remote_log_location + oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' self.log.info("read remote log: %s", oss_remote_log_location) return self.hook.read_key(self.bucket_name, oss_remote_log_location) except Exception: @@ -168,7 +167,7 @@ def oss_write(self, log, remote_log_location, append=True): :param append: if False, any existing log file is overwritten. If True, the new log is appended to any existing logs. """ - oss_remote_log_location = self.base_folder + '/' + remote_log_location + oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' pos = 0 if append and self.oss_log_exists(oss_remote_log_location): head = self.hook.head_key(self.bucket_name, oss_remote_log_location) diff --git a/airflow/providers/alibaba/cloud/sensors/oss_key.py b/airflow/providers/alibaba/cloud/sensors/oss_key.py index a53dcbb3f7cd1..0160783178b08 100644 --- a/airflow/providers/alibaba/cloud/sensors/oss_key.py +++ b/airflow/providers/alibaba/cloud/sensors/oss_key.py @@ -66,7 +66,12 @@ def __init__( self.hook: Optional[OSSHook] = None def poke(self, context: 'Context'): - + """ + Check if the object exists in the bucket to pull key. + @param self - the object itself + @param context - the context of the object + @returns True if the object exists, False otherwise + """ if self.bucket_name is None: parsed_url = urlparse(self.bucket_key) if parsed_url.netloc == '': diff --git a/tests/providers/alibaba/cloud/hooks/test_oss.py b/tests/providers/alibaba/cloud/hooks/test_oss.py index f5659bf3f2313..b5d78c8bc972b 100644 --- a/tests/providers/alibaba/cloud/hooks/test_oss.py +++ b/tests/providers/alibaba/cloud/hooks/test_oss.py @@ -61,7 +61,7 @@ def test_get_bucket(self, mock_oss2, mock_get_credential): self.hook.get_bucket('mock_bucket_name') mock_get_credential.assert_called_once_with() mock_oss2.Bucket.assert_called_once_with( - mock_get_credential.return_value, 'http://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME + mock_get_credential.return_value, 'https://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME ) @mock.patch(OSS_STRING.format('OSSHook.get_bucket'))