Skip to content

Commit

Permalink
feat: fix high memory issues in Gaussian copula fitting for high card…
Browse files Browse the repository at this point in the history
…inality discrete columns based on frequency encoding. (#233)
  • Loading branch information
jalr4ever authored Nov 7, 2024
1 parent 4e4fff1 commit 8eb395b
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
112 changes: 112 additions & 0 deletions sdgx/models/components/optimize/sdv_copulas/data_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
import pandas as pd

from sdgx.models.components.sdv_ctgan.data_transformer import (
ColumnTransformInfo,
DataTransformer,
SpanInfo,
)
from sdgx.models.components.sdv_rdt.transformers import ClusterBasedNormalizer
from sdgx.models.components.sdv_rdt.transformers.categorical import FrequencyEncoder

# TODO(Enhance) - Use different type of Encoder for discrete, like ordered columns, high cardinality columns...


class StatisticDataTransformer(DataTransformer):
    """Data Transformer for statistical models like Gaussian Copula.

    Unlike the CTGAN transformer (one-hot for discrete columns), this
    transformer frequency-encodes discrete columns into a single scalar
    column, which keeps memory bounded for high-cardinality columns.
    Every column (continuous or discrete) contributes exactly one output
    dimension.
    """

    def _fit_continuous(self, data):
        """Train a ClusterBasedNormalizer for a continuous column.

        Args:
            data (pd.DataFrame): single-column dataframe to fit.

        Returns:
            ColumnTransformInfo: metadata with a one-dimensional output.
        """
        column_name = data.columns[0]
        # max_clusters=1: a single Gaussian component is enough for the
        # copula model and avoids the cost of fitting a full BGM mixture.
        gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=1)
        gm.fit(data, column_name)

        return ColumnTransformInfo(
            column_name=column_name,
            column_type="continuous",
            transform=gm,
            output_info=[SpanInfo(1, "tanh")],
            output_dimensions=1,
        )

    def _transform_continuous(self, column_transform_info, data):
        """Transform a continuous column to its normalized representation.

        Returns:
            np.ndarray: shape ``(n, 1)`` array of normalized values; the
            component assignment produced by the normalizer is discarded.
        """
        gm = column_transform_info.transform
        transformed = gm.transform(data)
        return transformed[f"{data.columns[0]}.normalized"].to_numpy().reshape(-1, 1)

    def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
        """Inverse transform a continuous column back to the original scale.

        Args:
            column_transform_info (ColumnTransformInfo): fitted column info.
            column_data (np.ndarray): normalized values, any shape; flattened.
            sigmas (np.ndarray | None): optional per-column noise scale; when
                given, Gaussian noise with std ``sigmas[st]`` is added.
            st (int): index into ``sigmas`` for this column.

        Returns:
            pd.Series: recovered values in the original column scale.
        """
        gm = column_transform_info.transform
        column_name = column_transform_info.column_name

        # Rebuild the frame shape the normalizer expects. The component
        # column is virtual (always 0) because fitting used max_clusters=1.
        data = pd.DataFrame(
            {
                f"{column_name}.normalized": column_data.flatten(),
                f"{column_name}.component": [0] * len(column_data),  # virtual component
            }
        )

        if sigmas is not None:
            data[f"{column_name}.normalized"] = np.random.normal(
                data[f"{column_name}.normalized"], sigmas[st]
            )

        result = gm.reverse_transform(data)

        # The normalizer usually returns the original column name; fall back
        # to the first column defensively.
        if column_name in result.columns:
            return result[column_name]
        return result.iloc[:, 0]

    def _fit_discrete(self, data):
        """Fit a frequency encoder for a discrete column.

        Frequency encoding maps each category to a sub-interval of [0, 1)
        proportional to its frequency, so the output is a single scalar
        regardless of cardinality.

        Args:
            data (pd.DataFrame): single-column dataframe to fit.

        Returns:
            ColumnTransformInfo: metadata with a one-dimensional output.
        """
        column_name = data.columns[0]
        freq_encoder = FrequencyEncoder()
        freq_encoder.fit(data, column_name)

        # Record the original unique values per column (kept for
        # inspection / inverse-transform consumers).
        if not hasattr(self, "_discrete_values"):
            self._discrete_values = {}
        self._discrete_values[column_name] = data[column_name].unique().tolist()

        return ColumnTransformInfo(
            column_name=column_name,
            column_type="discrete",
            transform=freq_encoder,
            output_info=[SpanInfo(1, "tanh")],
            output_dimensions=1,
        )

    def _transform_discrete(self, column_transform_info, data):
        """Transform a discrete column using frequency encoding.

        Returns:
            np.ndarray: shape ``(n, 1)`` array of values in [0, 1).
        """
        freq_encoder = column_transform_info.transform
        return freq_encoder.transform(data).to_numpy().reshape(-1, 1)

    def _inverse_transform_discrete(self, column_transform_info, column_data):
        """Inverse transform a discrete column from frequency encoding.

        Each encoded value is mapped to the category whose interval
        ``[start_i, start_{i+1})`` *contains* it. (Picking the category with
        the nearest start point would be wrong: a value can lie inside a wide
        interval while being closer to the next interval's boundary.)

        Args:
            column_transform_info (ColumnTransformInfo): fitted column info.
            column_data (np.ndarray): encoded values, any shape; flattened.

        Returns:
            pd.Series: recovered categories with the encoder's dtype.
        """
        freq_encoder = column_transform_info.transform
        column_name = column_transform_info.column_name

        values = np.asarray(column_data, dtype=float).flatten()

        # freq_encoder.starts: DataFrame indexed by interval start points,
        # with a "category" column naming each interval's category.
        categories = freq_encoder.starts["category"].to_numpy()
        starts = freq_encoder.starts.index.to_numpy(dtype=float)

        # Sort defensively so searchsorted's precondition holds even if the
        # encoder's start index is not already ascending.
        order = np.argsort(starts)
        sorted_starts = starts[order]
        sorted_categories = categories[order]

        # Interval lookup: index of the last start <= value, clipped so
        # out-of-range values snap to the first/last category.
        idx = np.searchsorted(sorted_starts, values, side="right") - 1
        idx = np.clip(idx, 0, len(sorted_starts) - 1)

        return pd.Series(sorted_categories[idx], dtype=freq_encoder.dtype)
5 changes: 4 additions & 1 deletion sdgx/models/statistics/single_table/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.exceptions import NonParametricError, SynthesizerInitError
from sdgx.models.components.optimize.sdv_copulas.data_transformer import (
StatisticDataTransformer,
)
from sdgx.models.components.sdv_copulas import multivariate
from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_rdt.transformers import OneHotEncoder
Expand Down Expand Up @@ -138,7 +141,7 @@ def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
self.metadata = metadata

# load the original transformer
self._transformer = DataTransformer()
self._transformer = StatisticDataTransformer()

# self._transformer.fit(processed_data, self.metadata[0])
self._transformer.fit(processed_data, self.discrete_cols)
Expand Down

0 comments on commit 8eb395b

Please sign in to comment.