New impl for synthesizing composite keys

mikeknep committed Jul 18, 2023
1 parent 1a59fe0 commit b2d907b
Showing 5 changed files with 220 additions and 73 deletions.
20 changes: 14 additions & 6 deletions src/gretel_trainer/relational/strategies/ancestral.py
@@ -242,7 +242,6 @@ def post_process_individual_synthetic_result(
"""
processed = synthetic_table

primary_key = rel_data.get_primary_key(table_name)
multigenerational_primary_key = ancestry.get_multigenerational_primary_key(
rel_data, table_name
)
@@ -254,15 +253,24 @@ def post_process_individual_synthetic_result(
i for i in range(len(synthetic_table))
]
else:
synthetic_pk_columns = common.make_composite_pk_columns(
synthetic_pk_columns = common.make_composite_pks(
table_name=table_name,
rel_data=rel_data,
primary_key=primary_key,
primary_key=multigenerational_primary_key,
synth_row_count=len(synthetic_table),
record_size_ratio=record_size_ratio,
)
for index, col in enumerate(multigenerational_primary_key):
processed[col] = synthetic_pk_columns[index]

# make_composite_pks may not have created as many unique keys as we have
# synthetic rows, so we truncate the table to avoid inserting NaN PKs.
processed = pd.concat(
[
pd.DataFrame.from_records(synthetic_pk_columns),
processed.drop(multigenerational_primary_key, axis="columns").head(
len(synthetic_pk_columns)
),
],
axis=1,
)

for fk_map in ancestry.get_ancestral_foreign_key_maps(rel_data, table_name):
fk_col, parent_pk_col = fk_map
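A minimal standalone sketch of the truncate-via-concat step above, outside the strategy class; the synthetic table and the composite keys are hypothetical stand-ins written by hand rather than produced by the model or by make_composite_pks.

import pandas as pd

# Hypothetical synthetic output whose composite key spans two multigenerational columns.
synthetic_table = pd.DataFrame(
    {
        "self|l_partkey": [10, 20, 30, 40],
        "self|l_suppkey": [10, 20, 30, 40],
        "self|l_quantity": [42, 42, 42, 42],
    }
)
multigenerational_primary_key = ["self|l_partkey", "self|l_suppkey"]

# Stand-in for make_composite_pks output: only 3 unique keys for 4 synthetic rows.
synthetic_pk_columns = [
    {"self|l_partkey": 0, "self|l_suppkey": 0},
    {"self|l_partkey": 0, "self|l_suppkey": 1},
    {"self|l_partkey": 1, "self|l_suppkey": 0},
]

# Same pattern as the diff: drop the original key columns, truncate the remaining
# columns to the number of keys produced, and attach the new key columns.
processed = pd.concat(
    [
        pd.DataFrame.from_records(synthetic_pk_columns),
        synthetic_table.drop(multigenerational_primary_key, axis="columns").head(
            len(synthetic_pk_columns)
        ),
    ],
    axis=1,
)
print(processed)  # 3 rows, no NaN primary keys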
114 changes: 96 additions & 18 deletions src/gretel_trainer/relational/strategies/common.py
@@ -1,6 +1,5 @@
import json
import logging
import math
import random
from pathlib import Path
from typing import Optional
@@ -10,7 +9,7 @@
from gretel_client.projects.models import Model
from sklearn import preprocessing

from gretel_trainer.relational.core import RelationalData
from gretel_trainer.relational.core import MultiTableException, RelationalData
from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK

logger = logging.getLogger(__name__)
@@ -107,28 +106,107 @@ def label_encode_keys(
return tables


def make_composite_pk_columns(
def make_composite_pks(
table_name: str,
rel_data: RelationalData,
primary_key: list[str],
synth_row_count: int,
record_size_ratio: float,
) -> list[tuple]:
source_pk_columns = rel_data.get_table_data(table_name)[primary_key]
unique_counts = source_pk_columns.nunique(axis=0)
new_key_columns_values = []
for col in primary_key:
synth_values_count = math.ceil(unique_counts[col] * record_size_ratio)
new_key_columns_values.append(range(synth_values_count))

results = set()
while len(results) < synth_row_count:
key_combination = tuple(
[random.choice(vals) for vals in new_key_columns_values]
) -> list[dict]:
# Given the randomness involved in this process, it is possible for this function to generate
# fewer unique composite keys than desired to completely fill the dataframe (i.e. the length
# of the returned list may be < synth_row_count). It is the client's
# responsibility to check for this and drop synthetic records if necessary to fit.
table_data = rel_data.get_table_data(table_name)
original_primary_key = rel_data.get_primary_key(table_name)

pk_component_frequencies = {
col: get_frequencies(table_data, [col]) for col in original_primary_key
}

# each key in new_cols is a column name, and each value is a list of
# column values. The values are a contiguous list of integers, with
# each integer value appearing 1-N times to match the frequencies of
# (original source) values' appearances in the source data.
new_cols: dict[str, list] = {}
for i, col in enumerate(primary_key):
freqs = pk_component_frequencies[original_primary_key[i]]
next_freq = 0
next_key = 0
new_col_values = []

while len(new_col_values) < synth_row_count:
for i in range(freqs[next_freq]):
new_col_values.append(next_key)
next_key += 1
next_freq += 1
if next_freq == len(freqs):
next_freq = 0

# A large frequency may have added more values than we need,
# so trim to synth_row_count
new_cols[col] = new_col_values[0:synth_row_count]

# Shuffle for realism
for col_name, col_values in new_cols.items():
random.shuffle(col_values)

# Zip the individual columns into a list of records.
# Each element in the list is a composite key dict.
composite_keys: list[dict] = []
for i in range(synth_row_count):
comp_key = {}
for col_name, col_values in new_cols.items():
comp_key[col_name] = col_values[i]
composite_keys.append(comp_key)

# The zip above may not have produced unique composite key dicts.
# Using the most unique column (to give us the most options), try
# changing a value to "resolve" candidate composite keys to unique combinations.
cant_resolve = 0
seen: set[str] = set()
final_synthetic_composite_keys: list[dict] = []
most_unique_column = _get_most_unique_column(primary_key, pk_component_frequencies)

for i in range(synth_row_count):
y = i + 1
if y == len(composite_keys):
y = 0

comp_key = composite_keys[i]

while str(comp_key) in seen and y != i:
last_val = new_cols[most_unique_column][y]
y += 1
if y == len(composite_keys):
y = 0
comp_key[most_unique_column] = last_val
if str(comp_key) in seen:
cant_resolve += 1
else:
final_synthetic_composite_keys.append(comp_key)
seen.add(str(comp_key))

return final_synthetic_composite_keys
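A simplified, standalone sketch of the frequency-preserving construction described in the comments above, not the library function itself: it builds one integer column per key component from hand-written frequency lists, shuffles, zips the columns into candidate keys, and simply drops duplicates rather than performing the most-unique-column resolution step.

import random

def sketch_composite_pks(col_freqs: dict[str, list[int]], synth_row_count: int) -> list[dict]:
    # Build one integer column per key component, repeating each new key value as many
    # times as the corresponding source frequency, cycling through the frequency list
    # until the column is long enough, then trimming to synth_row_count.
    new_cols: dict[str, list[int]] = {}
    for col, freqs in col_freqs.items():
        values: list[int] = []
        next_key, next_freq = 0, 0
        while len(values) < synth_row_count:
            values.extend([next_key] * freqs[next_freq])
            next_key += 1
            next_freq = (next_freq + 1) % len(freqs)
        new_cols[col] = values[:synth_row_count]

    # Shuffle each column independently, then zip the columns into candidate composite keys.
    for values in new_cols.values():
        random.shuffle(values)
    candidates = [
        {col: new_cols[col][i] for col in new_cols} for i in range(synth_row_count)
    ]

    # Keep only unique combinations; like the real helper, this may return fewer keys
    # than requested, leaving truncation to the caller.
    seen: set[str] = set()
    unique_keys: list[dict] = []
    for key in candidates:
        if str(key) not in seen:
            seen.add(str(key))
            unique_keys.append(key)
    return unique_keys

# Hand-written frequencies: "letter" has two source values appearing 4x each,
# "number" has four source values appearing 2x each.
keys = sketch_composite_pks({"letter": [4, 4], "number": [2, 2, 2, 2]}, synth_row_count=8)
print(len(keys), keys)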


def _get_most_unique_column(pk: list[str], col_freqs: dict[str, list]) -> str:
most_unique = None
max_length = 0
for col, freqs in col_freqs.items():
if len(freqs) > max_length:
most_unique = col
max_length = len(freqs)

if most_unique is None:
raise MultiTableException(
f"Failed to identify most unique column from column frequencies: {col_freqs}"
)
results.add(key_combination)

return list(zip(*results))
# The keys in col_freqs are always the source column names from the original primary key.
# Meanwhile, `pk` could be either the same (independent strategy) or in multigenerational
# format (ancestral strategy). We need to return the column name in the format matching
# the rest of the synthetic data undergoing post-processing.
idx = list(col_freqs.keys()).index(most_unique)
return pk[idx]
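A toy illustration of the column selection and name mapping above; the frequency lists are hand-written, and it is assumed that get_frequencies (whose body is collapsed in this diff) yields one count per distinct source value.

# Hand-written frequencies keyed by source column name: "number" has four distinct
# values while "letter" has two, so "number" is the most unique column.
col_freqs = {"letter": [4, 4], "number": [2, 2, 2, 2]}

# The ancestral strategy passes the key in multigenerational format, in the same order
# as the source key, so the chosen source column maps to its synthetic name by position.
multigen_pk = ["self|letter", "self|number"]

most_unique = max(col_freqs, key=lambda col: len(col_freqs[col]))
idx = list(col_freqs.keys()).index(most_unique)
print(most_unique, multigen_pk[idx])  # number self|number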


def get_frequencies(table_data: pd.DataFrame, cols: list[str]) -> list[int]:
19 changes: 14 additions & 5 deletions src/gretel_trainer/relational/strategies/independent.py
@@ -52,7 +52,9 @@ def prepare_training_data(

pd.DataFrame(columns=use_columns).to_csv(path, index=False)
source_path = rel_data.get_table_source(table)
for chunk in pd.read_csv(source_path, usecols=use_columns, chunksize=10_000):
for chunk in pd.read_csv(
source_path, usecols=use_columns, chunksize=10_000
):
chunk.to_csv(path, index=False, mode="a", header=False)

return table_paths
@@ -229,15 +231,22 @@ def _synthesize_primary_keys(
elif len(primary_key) == 1:
processed[table_name][primary_key[0]] = [i for i in range(synth_row_count)]
else:
synthetic_pk_columns = common.make_composite_pk_columns(
synthetic_pk_columns = common.make_composite_pks(
table_name=table_name,
rel_data=rel_data,
primary_key=primary_key,
synth_row_count=synth_row_count,
record_size_ratio=record_size_ratio,
)
for index, col in enumerate(primary_key):
processed[table_name][col] = synthetic_pk_columns[index]

# make_composite_pks may not have created as many unique keys as we have
# synthetic rows, so we truncate the table to avoid inserting NaN PKs.
processed[table_name] = pd.concat(
[
processed[table_name].head(len(synthetic_pk_columns)),
pd.DataFrame.from_records(synthetic_pk_columns),
],
axis=1,
)

return processed

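A rough standalone sketch of the branch above for the independent strategy, where key columns are excluded from training so the synthetic table arrives without them; the table and key dicts below are hypothetical stand-ins, not real model or helper output.

import pandas as pd

def assign_synthetic_pks_sketch(
    table: pd.DataFrame, primary_key: list[str], synthetic_keys: list[dict]
) -> pd.DataFrame:
    # No primary key: leave the table untouched.
    if len(primary_key) == 0:
        return table
    # Single-column key: a zero-based range is sufficient.
    if len(primary_key) == 1:
        table[primary_key[0]] = list(range(len(table)))
        return table
    # Composite key: truncate to the number of unique keys produced, then attach the
    # synthesized key columns (mirrors the concat in the diff).
    return pd.concat(
        [table.head(len(synthetic_keys)), pd.DataFrame.from_records(synthetic_keys)],
        axis=1,
    )

synthetic = pd.DataFrame({"value": [9, 9, 9, 9]})
keys = [
    {"letter": 0, "number": 0},
    {"letter": 0, "number": 1},
    {"letter": 1, "number": 0},
]
print(assign_synthetic_pks_sketch(synthetic, ["letter", "number"], keys))  # 3 rows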
74 changes: 64 additions & 10 deletions tests/relational/test_ancestral_strategy.py
@@ -528,23 +528,77 @@ def test_post_processing_individual_synthetic_result_composite_keys(tpch):
strategy = AncestralStrategy()
synth_lineitem = pd.DataFrame(
data={
"self|l_partkey": [10, 20, 30, 40],
"self|l_suppkey": [10, 20, 30, 40],
"self|l_quantity": [42, 42, 42, 42],
"self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5],
"self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9],
"self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80],
"self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5],
"self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"],
"self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9],
"self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"],
"self|l_partkey": [10, 20, 30, 40] * 3,
"self|l_suppkey": [10, 20, 30, 40] * 3,
"self|l_quantity": [42, 42, 42, 42] * 3,
"self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3,
}
)

processed_lineitem = strategy.post_process_individual_synthetic_result(
"lineitem", tpch, synth_lineitem, 1
)

expected_post_processing = pd.DataFrame(
data={
"self|l_partkey": [2, 3, 4, 5] * 3,
"self|l_suppkey": [6, 7, 8, 9] * 3,
"self|l_quantity": [42, 42, 42, 42] * 3,
"self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3,
}
)

pdtest.assert_frame_equal(expected_post_processing, processed_lineitem)


def test_post_processing_individual_composite_too_few_keys_created(tpch):
strategy = AncestralStrategy()
synth_lineitem = pd.DataFrame(
data={
"self|l_partkey": [10, 20, 30, 40] * 3,
"self|l_suppkey": [10, 20, 30, 40] * 3,
"self|l_quantity": [42, 42, 42, 42] * 3,
"self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3,
"self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3,
"self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3,
}
)

# Given inherent randomness, make_composite_pks can fail to produce enough
# unique composite keys to fill the entire synthetic dataframe. In such situations,
# the client drops records from the raw record handler output and the resulting
# synthetic table only has as many records as keys produced.
with patch(
"gretel_trainer.relational.strategies.ancestral.common.make_composite_pks"
) as make_keys:
make_keys.return_value = [
{"self|l_partkey": 55, "self|l_suppkey": 55},
{"self|l_partkey": 66, "self|l_suppkey": 66},
{"self|l_partkey": 77, "self|l_suppkey": 77},
{"self|l_partkey": 88, "self|l_suppkey": 88},
]
processed_lineitem = strategy.post_process_individual_synthetic_result(
"lineitem", tpch, synth_lineitem, 1
)

# make_composite_pks only created 4 unique keys, so the table is truncated.
# The values (2-5 and 6-9) come from the subsequent foreign key step.
expected_post_processing = pd.DataFrame(
data={
"self|l_partkey": [2, 3, 4, 5],
66 changes: 32 additions & 34 deletions tests/relational/test_common_strategy.py
@@ -18,31 +18,30 @@ def test_composite_pk_columns(tmpdir):
data=df,
)

result = common.make_composite_pk_columns(
result = common.make_composite_pks(
table_name="table",
rel_data=rel_data,
primary_key=["letter", "number"],
synth_row_count=8,
record_size_ratio=1.0,
)

# There is a tuple of values for each primary key column
assert len(result) == 2
# Label-encoding turns the keys into zero-indexed contiguous integers.
# It is absolutely required that all composite keys returned are unique.
# We also ideally recreate the original data frequencies (in this case,
# two unique letters and four unique numbers).
expected_keys = [
{"letter": 0, "number": 0},
{"letter": 0, "number": 1},
{"letter": 0, "number": 2},
{"letter": 0, "number": 3},
{"letter": 1, "number": 0},
{"letter": 1, "number": 1},
{"letter": 1, "number": 2},
{"letter": 1, "number": 3},
]

# Each tuple has enough values for the synthetic result
for t in result:
assert len(t) == 8

# Each combination is unique
synthetic_pks = set(zip(*result))
assert len(synthetic_pks) == 8

# The set of unique values in each synthetic column roughly matches
# the set of unique values in the source columns.
# In this example they match exactly because there are no other possible combinations,
# but in practice it's possible to randomly not-select some values.
assert len(set(result[0])) == 2
assert len(set(result[1])) == 4
for expected_key in expected_keys:
assert expected_key in result


def test_composite_pk_columns_2(tmpdir):
@@ -59,28 +58,27 @@ def test_composite_pk_columns_2(tmpdir):
data=df,
)

result = common.make_composite_pk_columns(
result = common.make_composite_pks(
table_name="table",
rel_data=rel_data,
primary_key=["letter", "number"],
synth_row_count=8,
record_size_ratio=1.0,
)

# There is a tuple of values for each primary key column
assert len(result) == 2

# Each tuple has enough values for the synthetic result
for t in result:
assert len(t) == 8
# We create as many keys as we need
assert len(result) == 8

# Each combination is unique
synthetic_pks = set(zip(*result))
assert len(synthetic_pks) == 8
assert len(set([str(composite_key) for composite_key in result])) == 8

# In this case, there are more potential unique combinations than there are synthetic rows,
# so we can't say for sure what the exact composite values will be. However, we do expect
# the original frequencies to be maintained.
synthetic_letters = [key["letter"] for key in result]
assert len(synthetic_letters) == 8
assert set(synthetic_letters) == {0, 1}
assert len([x for x in synthetic_letters if x != 0]) == 4

# The set of unique values in each synthetic column roughly matches
# the set of unique values in the source columns.
# In this example, there are more potential combinations than there are synthetic rows,
# so our assertions are not as strict.
assert len(set(result[0])) <= 2
assert len(set(result[1])) <= 8
synthetic_numbers = [key["number"] for key in result]
assert len(synthetic_numbers) == 8
assert set(synthetic_numbers) == {0, 1, 2, 3, 4, 5, 6, 7}
