From 289e496e2432c84c08c1f6f991647776ec53a6d8 Mon Sep 17 00:00:00 2001 From: cruise Date: Mon, 22 Jan 2024 21:48:24 +0800 Subject: [PATCH] format code --- .../ml/core/logging/SynapseMLLogger.py | 133 ++++++++++++------ .../python/synapsemltest/core/test_logging.py | 7 +- 2 files changed, 94 insertions(+), 46 deletions(-) diff --git a/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py b/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py index bd7b3fb06a..919e18e99d 100644 --- a/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py +++ b/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py @@ -3,7 +3,10 @@ import functools import time import uuid -from synapse.ml.core.platform.Platform import running_on_synapse_internal, running_on_synapse +from synapse.ml.core.platform.Platform import ( + running_on_synapse_internal, + running_on_synapse, +) from pyspark.sql.dataframe import DataFrame from pyspark import SparkContext from pyspark.sql import SparkSession @@ -21,17 +24,21 @@ def format(self, record): return s -class SynapseMLLogger(): +class SynapseMLLogger: def __init__( - self, - library_name: str = None, # e.g SynapseML - library_version: str = None, - uid: str = None, - log_level: int = logging.INFO + self, + library_name: str = None, # e.g SynapseML + library_version: str = None, + uid: str = None, + log_level: int = logging.INFO, ): - self.logger: logging.Logger = SynapseMLLogger._get_environment_logger(log_level=log_level) + self.logger: logging.Logger = SynapseMLLogger._get_environment_logger( + log_level=log_level + ) self.library_name = library_name if library_name else "SynapseML" - self.library_version = library_version if library_version else self._get_synapseml_version() + self.library_version = ( + library_version if library_version else self._get_synapseml_version() + ) self.uid = uid if uid else f"{self.__class__.__name__}_{uuid.uuid4()}" self.is_executor = False if SynapseMLLogger.safe_get_spark_context() else True @@ -47,17 +54,24 @@ def safe_get_spark_context(cls) -> SparkContext: @classmethod def _round_significant(cls, num, digits): from math import log10, floor - return round(num, digits-int(floor(log10(abs(num))))-1) + + return round(num, digits - int(floor(log10(abs(num)))) - 1) @classmethod def _get_synapseml_version(cls) -> Optional[str]: try: from synapse.ml.core import __spark_package_version__ + return __spark_package_version__ except Exception: return None - def get_required_log_fields(self, uid: str, class_name: str, method: str, ): + def get_required_log_fields( + self, + uid: str, + class_name: str, + method: str, + ): return { "modelUid": uid, "className": class_name, @@ -70,7 +84,12 @@ def get_required_log_fields(self, uid: str, class_name: str, method: str, ): def _get_environment_logger(cls, log_level: int) -> logging.Logger: if running_on_synapse_internal(): from synapse.ml.pymds.synapse_logger import get_mds_logger - return get_mds_logger(name=__name__, log_level=log_level, formatter=CustomFormatter(fmt='%(message)s')) + + return get_mds_logger( + name=__name__, + log_level=log_level, + formatter=CustomFormatter(fmt="%(message)s"), + ) elif running_on_synapse(): logger = logging.getLogger(__name__) logger.setLevel(log_level) @@ -84,28 +103,48 @@ def _get_environment_logger(cls, log_level: int) -> logging.Logger: def get_hadoop_conf_entries(cls): if running_on_synapse_internal(): from synapse.ml.internal_utils.session_utils import get_fabric_context + return { - "artifactId": get_fabric_context().get('trident.artifact.id'), - "workspaceId": get_fabric_context().get('trident.workspace.id'), - "capacityId": get_fabric_context().get('trident.capacity.id'), - "artifactWorkspaceId": get_fabric_context().get('trident.artifact.workspace.id'), - "livyId": get_fabric_context().get('trident.activity.id'), - "artifactType": get_fabric_context().get('trident.artifact.type'), - "tenantId": get_fabric_context().get('trident.tenant.id'), - "lakehouseId": get_fabric_context().get('trident.lakehouse.id') + "artifactId": get_fabric_context().get("trident.artifact.id"), + "workspaceId": get_fabric_context().get("trident.workspace.id"), + "capacityId": get_fabric_context().get("trident.capacity.id"), + "artifactWorkspaceId": get_fabric_context().get( + "trident.artifact.workspace.id" + ), + "livyId": get_fabric_context().get("trident.activity.id"), + "artifactType": get_fabric_context().get("trident.artifact.type"), + "tenantId": get_fabric_context().get("trident.tenant.id"), + "lakehouseId": get_fabric_context().get("trident.lakehouse.id"), } else: return {} - def _log_base(self, class_name: str, method_name: Optional[str], num_cols: int, execution_sec: float, feature_name: Optional[str] = None): - self._log_base_dict(self._get_payload(class_name, method_name, num_cols, execution_sec, None), feature_name=feature_name) + def _log_base( + self, + class_name: str, + method_name: Optional[str], + num_cols: int, + execution_sec: float, + feature_name: Optional[str] = None, + ): + self._log_base_dict( + self._get_payload(class_name, method_name, num_cols, execution_sec, None), + feature_name=feature_name, + ) def _log_base_dict(self, info: Dict[str, str], feature_name: Optional[str] = None): if feature_name is not None and running_on_synapse_internal(): from synapse.ml.fabric.telemetry_utils import report_usage_telemetry - keys_to_remove = ['libraryName', 'method'] - attributes = {key: value for key, value in info.items() if key not in keys_to_remove} - report_usage_telemetry(feature_name=self.library_name, activity_name=feature_name, attributes=attributes) + + keys_to_remove = ["libraryName", "method"] + attributes = { + key: value for key, value in info.items() if key not in keys_to_remove + } + report_usage_telemetry( + feature_name=self.library_name, + activity_name=feature_name, + attributes=attributes, + ) self.logger.info(json.dumps(info)) def log_message(self, message: str): @@ -113,32 +152,36 @@ def log_message(self, message: str): @classmethod def get_error_fields(cls, e: Exception) -> Dict[str, str]: - return { - "errorType": str(type(e)), - "errorMessage": f'{e}' - } - - def _get_payload(self, class_name: str, method_name: Optional[str], num_cols: Optional[int], execution_sec: Optional[float], exception: Optional[Exception]): + return {"errorType": str(type(e)), "errorMessage": f"{e}"} + + def _get_payload( + self, + class_name: str, + method_name: Optional[str], + num_cols: Optional[int], + execution_sec: Optional[float], + exception: Optional[Exception], + ): info = self.get_required_log_fields(self.uid, class_name, method_name) env_conf = self.get_hadoop_conf_entries() for k in env_conf.keys(): info[k] = env_conf[k] if num_cols is not None: - info['dfInfo'] = { - "input": {"numCols": str(num_cols)} - } + info["dfInfo"] = {"input": {"numCols": str(num_cols)}} if execution_sec is not None: - info['executionSeconds'] = str(execution_sec) + info["executionSeconds"] = str(execution_sec) if exception: exception_info = self.get_error_fields(exception) for k in exception_info.keys(): info[k] = exception_info[k] - info['protocolVersion'] = PROTOCOL_VERSION - info['isExecutor'] = self.is_executor + info["protocolVersion"] = PROTOCOL_VERSION + info["isExecutor"] = self.is_executor return info def _log_error_base(self, class_name: str, method_name: str, e: Exception): - self.logger.exception(json.dumps(self._get_payload(class_name, method_name, None, None, e))) + self.logger.exception( + json.dumps(self._get_payload(class_name, method_name, None, None, e)) + ) def log_verb(method_name: Optional[str] = None): def get_wrapper(func): @@ -147,23 +190,27 @@ def log_decorator_wrapper(self, *args, **kwargs): start_time = time.perf_counter() try: result = func(self, *args, **kwargs) - execution_time = SynapseMLLogger._round_significant(time.perf_counter() - start_time, 3) + execution_time = SynapseMLLogger._round_significant( + time.perf_counter() - start_time, 3 + ) self._log_base( func.__module__, method_name if method_name else func.__name__, SynapseMLLogger.get_column_number(args, kwargs), execution_time, - None + None, ) return result except Exception as e: self._log_error_base( func.__module__, method_name if method_name else func.__name__, - e + e, ) raise + return log_decorator_wrapper + return get_wrapper def log_class(self, feature_name: str): @@ -177,8 +224,8 @@ def log_fit(): @classmethod def get_column_number(cls, args, kwargs): - if kwargs and kwargs['df'] and isinstance(kwargs['df'], DataFrame): - return len(kwargs['df'].columns) + if kwargs and kwargs["df"] and isinstance(kwargs["df"], DataFrame): + return len(kwargs["df"].columns) elif args and len(args) > 0 and isinstance(args[0], DataFrame): return len(args[0].columns) return None diff --git a/core/src/test/python/synapsemltest/core/test_logging.py b/core/src/test/python/synapsemltest/core/test_logging.py index 703a9adb5d..9435593109 100644 --- a/core/src/test/python/synapsemltest/core/test_logging.py +++ b/core/src/test/python/synapsemltest/core/test_logging.py @@ -6,12 +6,13 @@ import logging from synapse.ml.core.logging.SynapseMLLogger import SynapseMLLogger +from synapsemltest.spark import * class SampleTransformer(SynapseMLLogger): def __init__(self): super().__init__(log_level=logging.DEBUG) - self.log_class('SampleTransformer') + self.log_class("SampleTransformer") @SynapseMLLogger.log_transform() def transform(self, df): @@ -37,8 +38,8 @@ def test_logging_smoke(self): try: t.test_throw() except Exception as e: - assert f'{e}' == "test exception" + assert f"{e}" == "test exception" if __name__ == "__main__": - result = unittest.main() \ No newline at end of file + result = unittest.main()