-
Notifications
You must be signed in to change notification settings - Fork 551
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mutual information metric (#101)
* test * test_v2 * no-test * pair_v1 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove_old_mi_sim * modify single&multi_table MISim * modify single_mi_sim by using pair_sim instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify multi_mi_sim by using pair_sim instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change_class_name_err * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify_paircolumn * mi only needs dataframe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify based on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * complete test_mi_sim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify test file * change_var_name * Update sdgx/metrics/multi_table/multitable_mi_sim.py Co-authored-by: MoooCat <[email protected]> * add MULTI_TABLE_DEMO_DATA * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify comments * JSD->MISIM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify base of pair_column * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add cls * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change self into cls instance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change cls * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * series2array * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test * test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add label_encoder for category in mi_sim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use series.array * change le_fit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change transform type to np.array instead of list * add astype * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * series2array * foo * change test_suit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * all right? * all right * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Z712023 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Z712023 <[email protected]>
- Loading branch information
1 parent
d29a2a0
commit dae869e
Showing
9 changed files
with
398 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.stats import entropy | ||
from sklearn.metrics.cluster import normalized_mutual_info_score | ||
|
||
from sdgx.metrics.multi_table.base import MultiTableMetric | ||
from sdgx.metrics.pair_column.mi_sim import MISim | ||
|
||
|
||
class MISim(MultiTableMetric): | ||
"""MISim : Mutual Information Similarity | ||
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. | ||
Currently, we support discrete and continuous(need to be discretized) columns as inputs. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.lower_bound = 0 | ||
self.upper_bound = 1 | ||
self.metric_name = "mutual_information_similarity" | ||
self.numerical_bins = 50 | ||
|
||
@classmethod | ||
def calculate( | ||
real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict | ||
) -> pd.DataFrame: | ||
""" | ||
Calculate the Mutual Information Similarity between a real column and a synthetic column. | ||
Args: | ||
real_data (pd.DataFrame): The real data. | ||
synthetic_data (pd.DataFrame): The synthetic data. | ||
metadata(dict): The metadata that describes the data type of each column | ||
Returns: | ||
MI_similarity (float): The metric value. | ||
""" | ||
|
||
# 传入概率分布数组 | ||
|
||
columns = synthetic_data.columns | ||
n = len(columns) | ||
mi_sim_instance = MISim() | ||
nMI_sim = np.zeros((n, n)) | ||
|
||
for i in range(len(columns)): | ||
for j in range(len(columns)): | ||
syn_data = pd.concat( | ||
[synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1 | ||
) | ||
real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1) | ||
|
||
nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata) | ||
|
||
MI_sim = np.sum(nMI_sim) / n / n | ||
# test | ||
MISim.check_output(MI_sim) | ||
|
||
return MI_sim | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. | ||
""" | ||
instance = cls() | ||
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound: | ||
raise ValueError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import pandas as pd | ||
|
||
from sdgx.log import logger | ||
|
||
|
||
class PairMetric(object): | ||
"""PairMetric | ||
Metrics used to evaluate the quality of synthetic data columns. | ||
""" | ||
|
||
upper_bound = None | ||
lower_bound = None | ||
metric_name = "Correlation" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@classmethod | ||
def check_input(cls, src_col: pd.Series, tar_col: pd.Series, metadata: dict): | ||
"""Input check for table input. | ||
Args: | ||
src_data(pd.Series ): the source data column. | ||
tar_data(pd.Series): the target data column . | ||
metadata(dict): The metadata that describes the data type of each column | ||
""" | ||
# Input parameter must not contain None value | ||
if real_data is None or synthetic_data is None: | ||
raise TypeError("Input contains None.") | ||
# check column_names | ||
tar_name = tar_col.name | ||
src_name = src_col.name | ||
|
||
# check column_types | ||
if metadata[tar_name] != metadata[src_name]: | ||
raise TypeError("Type of Pair is Conflicting.") | ||
|
||
# if type is pd.Series, return directly | ||
if isinstance(real_data, pd.Series): | ||
return src_col, tar_col | ||
|
||
# if type is not pd.Series or pd.DataFrame tranfer it to Series | ||
try: | ||
src_col = pd.Series(src_col) | ||
tar_col = pd.Series(tar_col) | ||
return src_col, tar_col | ||
except Exception as e: | ||
logger.error(f"An error occurred while converting to pd.Series: {e}") | ||
|
||
return None, None | ||
|
||
@classmethod | ||
def calculate(cls, src_col: pd.Series, tar_col: pd.Series, metadata): | ||
"""Calculate the metric value between pair-columns between real table and synthetic table. | ||
Args: | ||
src_data(pd.Series ): the source data column. | ||
tar_data(pd.Series): the target data column . | ||
metadata(dict): The metadata that describes the data type of each column | ||
""" | ||
# This method should first check the input | ||
# such as: | ||
real_data, synthetic_data = PairMetric.check_input(src_col, tar_col) | ||
|
||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the Mutual Information. | ||
""" | ||
raise NotImplementedError() | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.stats import entropy | ||
from sklearn.metrics.cluster import normalized_mutual_info_score | ||
from sklearn.preprocessing import LabelEncoder | ||
|
||
from sdgx.metrics.pair_column.base import PairMetric | ||
from sdgx.utils import time2int | ||
|
||
|
||
class MISim(PairMetric): | ||
"""MISim : Mutual Information Similarity | ||
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. | ||
Currently, we support discrete and continuous(need to be discretized) columns as inputs. | ||
""" | ||
|
||
def __init__(instance) -> None: | ||
super().__init__() | ||
instance.lower_bound = 0 | ||
instance.upper_bound = 1 | ||
instance.metric_name = "mutual_information_similarity" | ||
instance.numerical_bins = 50 | ||
|
||
@classmethod | ||
def calculate( | ||
cls, | ||
src_col: pd.Series, | ||
tar_col: pd.Series, | ||
metadata: dict, | ||
) -> float: | ||
""" | ||
Calculate the MI similarity for the source data colum and the target data column. | ||
Args: | ||
src_data(pd.Series ): the source data column. | ||
tar_data(pd.Series): the target data column . | ||
metadata(dict): The metadata that describes the data type of each columns | ||
Returns: | ||
MI_similarity (float): The metric value. | ||
""" | ||
|
||
# 传入概率分布数组 | ||
instance = cls() | ||
|
||
col_name = src_col.name | ||
data_type = metadata[col_name] | ||
|
||
if data_type == "numerical": | ||
x = np.array(src_col.array) | ||
src_col = pd.cut( | ||
x, | ||
instance.numerical_bins, | ||
labels=range(instance.numerical_bins), | ||
) | ||
x = np.array(tar_col.array) | ||
tar_col = pd.cut( | ||
x, | ||
instance.numerical_bins, | ||
labels=range(instance.numerical_bins), | ||
) | ||
src_col = src_col.to_numpy() | ||
tar_col = tar_col.to_numpy() | ||
|
||
elif data_type == "category": | ||
le = LabelEncoder() | ||
src_list = list(set(src_col.array)) | ||
tar_list = list(set(tar_col.array)) | ||
fit_list = tar_list + src_list | ||
le.fit(fit_list) | ||
|
||
src_col = le.transform(np.array(src_col.array)) | ||
tar_col = le.transform(np.array(tar_col.array)) | ||
|
||
elif data_type == "datetime": | ||
src_col = src_col.apply(time2int) | ||
tar_col = tar_col.apply(time2int) | ||
src_col = pd.cut( | ||
src_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins) | ||
) | ||
tar_col = pd.cut( | ||
tar_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins) | ||
) | ||
src_col = src_col.to_numpy() | ||
tar_col = tar_col.to_numpy() | ||
|
||
MI_sim = normalized_mutual_info_score(src_col, tar_col) | ||
return MI_sim | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the MI similarity. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.stats import entropy | ||
from sklearn.metrics.cluster import normalized_mutual_info_score | ||
|
||
from sdgx.metrics.pair_column.mi_sim import MISim | ||
from sdgx.metrics.single_table.base import SingleTableMetric | ||
|
||
|
||
class SinTabMISim(SingleTableMetric): | ||
"""MISim : Mutual Information Similarity | ||
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data. | ||
Currently, we support discrete and continuous(need to be discretized) columns as inputs. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.lower_bound = 0 | ||
self.upper_bound = 1 | ||
self.metric_name = "mutual_information_similarity" | ||
self.numerical_bins = 50 | ||
|
||
@classmethod | ||
def calculate(real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata) -> pd.DataFrame: | ||
""" | ||
Calculate the Mutual Information Similarity between a real column and a synthetic column. | ||
Args: | ||
real_data (pd.DataFrame): The real data. | ||
synthetic_data (pd.DataFrame): The synthetic data. | ||
metadata(dict): The metadata that describes the data type of each column | ||
Returns: | ||
MI_similarity (float): The metric value. | ||
""" | ||
|
||
# 传入概率分布数组 | ||
|
||
columns = synthetic_data.columns | ||
n = len(columns) | ||
mi_sim_instance = MISim() | ||
nMI_sim = np.zeros((n, n)) | ||
|
||
for i in range(len(columns)): | ||
for j in range(len(columns)): | ||
syn_data = pd.concat( | ||
[synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1 | ||
) | ||
real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1) | ||
|
||
nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata) | ||
|
||
MI_sim = np.sum(nMI_sim) / n / n | ||
MISim.check_output(MI_sim) | ||
|
||
return MI_sim | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity. | ||
""" | ||
instance = cls() | ||
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound: | ||
raise ValueError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.