diff --git a/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py b/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py index f1dd6d79bc..38b244b3e4 100644 --- a/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py +++ b/core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py @@ -126,9 +126,18 @@ def _log_base( num_cols: int, execution_sec: float, feature_name: Optional[str] = None, + custom_log_dict: Optional[Dict[str, str]] = None, ): + payload_dict = self._get_payload( + class_name, method_name, num_cols, execution_sec, None + ) + if custom_log_dict: + if shared_keys := set(custom_log_dict.keys()) & set(payload_dict.keys()): + raise ValueError( + f"Shared keys found in custom logger dictionary: {shared_keys}" + ) self._log_base_dict( - self._get_payload(class_name, method_name, num_cols, execution_sec, None), + payload_dict | (custom_log_dict if custom_log_dict else {}), feature_name=feature_name, ) @@ -213,6 +222,66 @@ def log_decorator_wrapper(self, *args, **kwargs): return get_wrapper + @staticmethod + def log_verb_static( + method_name: Optional[str] = None, + feature_name: Optional[str] = None, + custom_log_function=None, + ): + def get_wrapper(func): + @functools.wraps(func) + def log_decorator_wrapper(self, *args, **kwargs): + # Validate that self._logger is set + if not hasattr(self, "_logger"): + raise AttributeError( + f"{self.__class__.__name__} does not have a '_logger' attribute. " + "Ensure a _logger instance is initialized in the constructor." + ) + + # Validate custom_log_function for proper definition + if custom_log_function: + if not callable(custom_log_function): + raise ValueError("custom_log_function must be callable") + + logger = self._logger + start_time = time.perf_counter() + try: + result = func(self, *args, **kwargs) + execution_time = logger._round_significant( + time.perf_counter() - start_time, 3 + ) + # Create custom logs if necessary + custom_log_dict = None + if custom_log_function: + custom_log_dict = custom_log_function( + self, result, *args, **kwargs + ) + if not isinstance(custom_log_dict, dict): + raise TypeError( + "custom_log_function must return a Dict[str, str]" + ) + + logger._log_base( + func.__module__, + method_name if method_name else func.__name__, + logger.get_column_number(args, kwargs), + execution_time, + feature_name, + custom_log_dict, + ) + return result + except Exception as e: + logger._log_error_base( + func.__module__, + method_name if method_name else func.__name__, + e, + ) + raise + + return log_decorator_wrapper + + return get_wrapper + def log_class(self, feature_name: str): return self._log_base("constructor", None, None, None, feature_name) diff --git a/core/src/test/python/synapsemltest/core/test_logging.py b/core/src/test/python/synapsemltest/core/test_logging.py index 83a4b9b1fd..e419fd4965 100644 --- a/core/src/test/python/synapsemltest/core/test_logging.py +++ b/core/src/test/python/synapsemltest/core/test_logging.py @@ -35,6 +35,47 @@ def test_feature_name(self): return 0 +class NoInheritTransformer: + def __init__(self): + self._logger = SynapseMLLogger(log_level=logging.DEBUG) + + @SynapseMLLogger.log_verb_static(method_name="transform") + def transform(self, df): + return True + + @SynapseMLLogger.log_verb_static(method_name="fit") + def fit(self, df): + return True + + @SynapseMLLogger.log_verb_static() + def test_throw(self): + raise Exception("test exception") + + @SynapseMLLogger.log_verb_static(feature_name="test_logging") + def test_feature_name(self): + return 0 + + def custom_logging_function(self, results, *args, **kwargs): + return {"args": f"Arguments: {args}", "result": str(results)} + + @SynapseMLLogger.log_verb_static(custom_log_function=custom_logging_function) + def test_custom_function(self, df): + return 0 + + def custom_logging_function_w_collision(self, results, *args, **kwargs): + return { + "args": f"Arguments: {args}", + "result": str(results), + "className": "this is the collision key", + } + + @SynapseMLLogger.log_verb_static( + custom_log_function=custom_logging_function_w_collision + ) + def test_custom_function_w_collision(self, df): + return 0 + + class LoggingTest(unittest.TestCase): def test_logging_smoke(self): t = SampleTransformer() @@ -49,6 +90,26 @@ def test_logging_smoke(self): assert f"{e}" == "test exception" t.test_feature_name() + def test_log_verb_static(self): + t = NoInheritTransformer() + data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)] + columns = ["name", "age"] + df = sc.createDataFrame(data, columns) + t.transform(df) + t.fit(df) + t.test_feature_name() + t.test_custom_function(df) + try: + t.test_throw() + except Exception as e: + assert f"{e}" == "test exception" + try: + t.test_custom_function_w_collision(df) + except Exception as e: + assert ( + f"{e}" == "Shared keys found in custom logger dictionary: {'className'}" + ) + if __name__ == "__main__": result = unittest.main()