Skip to content

Commit

Permalink
Merge pull request #1 from sss04/shyamsai/customLogging
Browse files Browse the repository at this point in the history
chore: Add custom logging to SynapseMLLogger
  • Loading branch information
sss04 authored Aug 8, 2024
2 parents 24b72c6 + e54bb18 commit 8809c08
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
35 changes: 16 additions & 19 deletions core/src/main/python/synapse/ml/core/logging/SynapseMLLogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pyspark import SparkContext
from pyspark.sql import SparkSession
import json
import inspect

PROTOCOL_VERSION = "0.0.1"

Expand Down Expand Up @@ -127,10 +126,14 @@ def _log_base(
num_cols: int,
execution_sec: float,
feature_name: Optional[str] = None,
custom_log_info: Optional[str] = None,
custom_log_dict: Optional[Dict[str, str]] = None,
):
payload_dict = self._get_payload(class_name, method_name, num_cols, execution_sec, None)
if custom_log_dict:
if shared_keys := set(custom_log_dict.keys()) & set(payload_dict.keys()):
raise ValueError(f"Shared keys found in custom logger dictionary: {shared_keys}")
self._log_base_dict(
self._get_payload(class_name, method_name, num_cols, execution_sec, None, custom_log_info),
payload_dict | (custom_log_dict if custom_log_dict else {}),
feature_name=feature_name,
)

Expand Down Expand Up @@ -163,7 +166,6 @@ def _get_payload(
num_cols: Optional[int],
execution_sec: Optional[float],
exception: Optional[Exception],
custom_log_info: Optional[str],
):
info = self.get_required_log_fields(self.uid, class_name, method_name)
env_conf = self.get_hadoop_conf_entries()
Expand All @@ -173,8 +175,6 @@ def _get_payload(
info["dfInfo"] = {"input": {"numCols": str(num_cols)}}
if execution_sec is not None:
info["executionSeconds"] = str(execution_sec)
if custom_log_info is not None:
info["custom_log_info"] = custom_log_info
if exception:
exception_info = self.get_error_fields(exception)
for k in exception_info.keys():
Expand Down Expand Up @@ -219,7 +219,7 @@ def log_decorator_wrapper(self, *args, **kwargs):
return get_wrapper

@staticmethod
def log_verb_static(method_name: Optional[str] = None, custom_log_function = None):
def log_verb_static(method_name: Optional[str] = None, feature_name: Optional[str] = None, custom_log_function = None):
def get_wrapper(func):
@functools.wraps(func)
def log_decorator_wrapper(self, *args, **kwargs):
Expand All @@ -235,30 +235,27 @@ def log_decorator_wrapper(self, *args, **kwargs):
if not callable(custom_log_function):
raise ValueError("custom_log_function must be callable")

# Check if custom_log_function can accept the required parameters
sig = inspect.signature(custom_log_function)
params = list(sig.parameters.values())
if len(params) != 4:
raise TypeError("custom_log_function must accept four parameters: "
"self, *args, **kwargs, and result")

logger = self.logger
start_time = time.perf_counter()
try:
result = func(self, *args, **kwargs)
execution_time = logger._round_significant(
time.perf_counter() - start_time, 3
)
custom_log_info = custom_log_function(self, *args, **kwargs, result)
if not isinstance(custom_log_info, str):
raise TypeError("custom_log_function must return a string")
# Create custom logs if necessary
custom_log_dict = None
if custom_log_function:
custom_log_dict = custom_log_function(self, result, *args, **kwargs)
if not isinstance(custom_log_dict, dict):
raise TypeError("custom_log_function must return a Dict[str, str]")

logger._log_base(
func.__module__,
method_name if method_name else func.__name__,
logger.get_column_number(args, kwargs),
execution_time,
None,
custom_log_info
feature_name,
custom_log_dict
)
return result
except Exception as e:
Expand Down
29 changes: 28 additions & 1 deletion core/src/test/python/synapsemltest/core/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ def fit(self, df):
def test_throw(self):
raise Exception("test exception")

@SynapseMLLogger.log_verb_static(feature_name="test_logging")
def test_feature_name(self):
return 0

def custom_logging_function(self, results, *args, **kwargs):
return {"args": f"Arguments: {args}",
"result": str(results)}

@SynapseMLLogger.log_verb_static(custom_log_function=custom_logging_function)
def test_custom_function(self, df):
return 0

def custom_logging_function_w_collision(self, results, *args, **kwargs):
return {"args": f"Arguments: {args}",
"result": str(results),
"className": "this is the collision key"}

@SynapseMLLogger.log_verb_static(custom_log_function=custom_logging_function_w_collision)
def test_custom_function_w_collision(self, df):
return 0


class LoggingTest(unittest.TestCase):
def test_logging_smoke(self):
Expand All @@ -66,17 +87,23 @@ def test_logging_smoke(self):
assert f"{e}" == "test exception"
t.test_feature_name()

def test_logging_smoke_no_inheritance(self):
def test_log_verb_static(self):
t = NoInheritTransformer()
data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]
columns = ["name", "age"]
df = sc.createDataFrame(data, columns)
t.transform(df)
t.fit(df)
t.test_feature_name()
t.test_custom_function(df)
try:
t.test_throw()
except Exception as e:
assert f"{e}" == "test exception"
try:
t.test_custom_function_w_collision(df)
except Exception as e:
assert f"{e}" == "Shared keys found in custom logger dictionary: {'className'}"


if __name__ == "__main__":
Expand Down

0 comments on commit 8809c08

Please sign in to comment.