From e037974a192f554a2177fb93f260fc35cdb266d8 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Fri, 11 Aug 2023 10:39:31 -0500 Subject: [PATCH] Sync SAL sources and dependencies GitOrigin-RevId: dea281a81071fed5dab34b66b0cbc861e6cb3911 --- notebooks/conditional-generation.py | 37 +++++++++++-------- notebooks/custom-example.py | 5 +-- notebooks/simple-example.py | 4 +- requirements.txt | 4 +- setup.py | 3 +- src/gretel_trainer/benchmark/core.py | 1 + .../benchmark/custom/datasets.py | 1 + .../benchmark/custom/strategy.py | 2 +- src/gretel_trainer/benchmark/entrypoints.py | 7 ++-- src/gretel_trainer/benchmark/executor.py | 2 +- .../benchmark/gretel/datasets.py | 2 + .../datasets_backwards_compatibility.py | 1 + src/gretel_trainer/benchmark/gretel/models.py | 7 ++-- .../benchmark/gretel/strategy_sdk.py | 3 +- .../benchmark/gretel/strategy_trainer.py | 2 +- src/gretel_trainer/benchmark/sdk_extras.py | 3 +- src/gretel_trainer/benchmark/session.py | 8 ++-- src/gretel_trainer/models.py | 14 ++++--- src/gretel_trainer/relational/__init__.py | 1 + src/gretel_trainer/relational/ancestry.py | 1 + src/gretel_trainer/relational/artifacts.py | 1 + src/gretel_trainer/relational/connectors.py | 2 + src/gretel_trainer/relational/core.py | 3 ++ src/gretel_trainer/relational/extractor.py | 6 ++- src/gretel_trainer/relational/json.py | 4 +- src/gretel_trainer/relational/log.py | 1 + src/gretel_trainer/relational/model_config.py | 1 - src/gretel_trainer/relational/multi_table.py | 11 +++--- .../relational/report/figures.py | 1 + .../relational/report/report.py | 4 +- src/gretel_trainer/relational/sdk_extras.py | 3 +- .../relational/strategies/ancestral.py | 4 +- .../relational/strategies/common.py | 4 +- .../relational/strategies/independent.py | 4 +- .../relational/table_evaluation.py | 3 +- src/gretel_trainer/relational/task_runner.py | 2 +- .../relational/tasks/classify.py | 6 ++- src/gretel_trainer/relational/tasks/common.py | 2 +- .../relational/tasks/synthetics_evaluate.py | 4 +- .../relational/tasks/synthetics_run.py | 6 ++- .../relational/tasks/synthetics_train.py | 4 +- .../relational/tasks/transforms_run.py | 5 ++- .../relational/tasks/transforms_train.py | 4 +- src/gretel_trainer/runner.py | 19 ++++++---- src/gretel_trainer/strategy.py | 4 +- src/gretel_trainer/trainer.py | 28 ++++++++++---- test-requirements.txt | 7 ++-- tests/benchmark/conftest.py | 1 + tests/benchmark/test_bad_setup.py | 2 +- tests/benchmark/test_benchmark.py | 7 ++-- tests/benchmark/test_custom_datasets.py | 2 +- tests/relational/conftest.py | 2 + tests/relational/test_ancestral_strategy.py | 2 + tests/relational/test_artifacts.py | 3 +- tests/relational/test_common_strategy.py | 1 + tests/relational/test_extractor.py | 3 +- tests/relational/test_independent_strategy.py | 1 + tests/relational/test_model_config.py | 1 - .../test_multi_table_config_options.py | 1 + tests/relational/test_multi_table_restore.py | 2 + .../test_relational_data_with_json.py | 5 +-- tests/relational/test_synthetics_run_task.py | 5 ++- tests/relational/test_task_runner.py | 4 +- tests/relational/test_train_synthetics.py | 1 + tests/relational/test_train_transforms.py | 1 + tests/test_strategy.py | 4 +- 66 files changed, 186 insertions(+), 113 deletions(-) diff --git a/notebooks/conditional-generation.py b/notebooks/conditional-generation.py index d8833c40..573ae94b 100644 --- a/notebooks/conditional-generation.py +++ b/notebooks/conditional-generation.py @@ -1,24 +1,29 @@ import pandas as pd -from gretel_client import configure_session +from gretel_client import configure_session from gretel_trainer import Trainer -from gretel_trainer.models import GretelLSTM, GretelACTGAN +from gretel_trainer.models import GretelACTGAN, GretelLSTM -DATASET_PATH = 'https://gretel-public-website.s3.amazonaws.com/datasets/mitre-synthea-health.csv' +DATASET_PATH = ( + "https://gretel-public-website.s3.amazonaws.com/datasets/mitre-synthea-health.csv" +) MODEL_TYPE = [GretelLSTM(), GretelACTGAN()][1] # Create dataset to autocomplete values for -seed_df = pd.DataFrame(data=[ - ["black", "african", "F"], - ["black", "african", "F"], - ["black", "african", "F"], - ["black", "african", "F"], - ["asian", "chinese", "F"], - ["asian", "chinese", "F"], - ["asian", "chinese", "F"], - ["asian", "chinese", "F"], - ["asian", "chinese", "F"] -], columns=["RACE", "ETHNICITY", "GENDER"]) +seed_df = pd.DataFrame( + data=[ + ["black", "african", "F"], + ["black", "african", "F"], + ["black", "african", "F"], + ["black", "african", "F"], + ["asian", "chinese", "F"], + ["asian", "chinese", "F"], + ["asian", "chinese", "F"], + ["asian", "chinese", "F"], + ["asian", "chinese", "F"], + ], + columns=["RACE", "ETHNICITY", "GENDER"], +) # Configure Gretel credentials @@ -31,5 +36,5 @@ print(model.generate(seed_df=seed_df)) # Load a existing model and conditionally generate data -#model = Trainer.load() -#print(model.generate(seed_df=seed_df)) +# model = Trainer.load() +# print(model.generate(seed_df=seed_df)) diff --git a/notebooks/custom-example.py b/notebooks/custom-example.py index cdf1ea7a..d9e605aa 100644 --- a/notebooks/custom-example.py +++ b/notebooks/custom-example.py @@ -9,10 +9,7 @@ # Specify underlying model and config options. # configs can be either a string, dict, or path -model_type = GretelACTGAN( - config="synthetics/tabular-actgan", - max_rows=50000 -) +model_type = GretelACTGAN(config="synthetics/tabular-actgan", max_rows=50000) # Optionally update model parameters from a base config model_type.update_params({"epochs": 500}) diff --git a/notebooks/simple-example.py b/notebooks/simple-example.py index 42c2fddb..110190a8 100644 --- a/notebooks/simple-example.py +++ b/notebooks/simple-example.py @@ -13,5 +13,5 @@ # Or, load and generate data from an existing model -#model = Trainer.load() -#model.generate(num_records=70) +# model = Trainer.load() +# model.generate(num_records=70) diff --git a/requirements.txt b/requirements.txt index 0ee5a4fc..2fc352dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ plotly~=5.11 pydantic~=1.9 requests~=2.25 scikit-learn~=1.0 -smart-open[s3]~=5.2 +smart_open[s3]~=5.2 sqlalchemy~=1.4 -typing-extensions~=4.7 +typing_extensions~=4.7 unflatten==0.1.1 diff --git a/setup.py b/setup.py index 1db3c8e4..74827a4b 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import pathlib -from setuptools import setup, find_packages + +from setuptools import find_packages, setup local_path = pathlib.Path(__file__).parent install_requires = (local_path / "requirements.txt").read_text().splitlines() diff --git a/src/gretel_trainer/benchmark/core.py b/src/gretel_trainer/benchmark/core.py index b5450e27..abfa4881 100644 --- a/src/gretel_trainer/benchmark/core.py +++ b/src/gretel_trainer/benchmark/core.py @@ -1,6 +1,7 @@ import csv import logging import time + from dataclasses import dataclass, field from datetime import datetime from enum import Enum diff --git a/src/gretel_trainer/benchmark/custom/datasets.py b/src/gretel_trainer/benchmark/custom/datasets.py index efa9d00c..766a27f9 100644 --- a/src/gretel_trainer/benchmark/custom/datasets.py +++ b/src/gretel_trainer/benchmark/custom/datasets.py @@ -1,6 +1,7 @@ import logging import os import uuid + from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union diff --git a/src/gretel_trainer/benchmark/custom/strategy.py b/src/gretel_trainer/benchmark/custom/strategy.py index 66271599..ac12aa6e 100644 --- a/src/gretel_trainer/benchmark/custom/strategy.py +++ b/src/gretel_trainer/benchmark/custom/strategy.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Optional -from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, Timer, run_out_path +from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, run_out_path, Timer from gretel_trainer.benchmark.custom.models import CustomModel diff --git a/src/gretel_trainer/benchmark/entrypoints.py b/src/gretel_trainer/benchmark/entrypoints.py index 29b78772..5686fc4e 100644 --- a/src/gretel_trainer/benchmark/entrypoints.py +++ b/src/gretel_trainer/benchmark/entrypoints.py @@ -2,21 +2,22 @@ import logging import shutil + from inspect import isclass from pathlib import Path -from typing import Optional, Type, Union, cast +from typing import cast, Optional, Type, Union import pandas as pd -from gretel_client.config import get_session_config +from gretel_client.config import get_session_config from gretel_trainer.benchmark.core import BenchmarkConfig, BenchmarkException, Dataset from gretel_trainer.benchmark.custom.models import CustomModel from gretel_trainer.benchmark.gretel.models import GretelModel from gretel_trainer.benchmark.job_spec import ( DatasetTypes, JobSpec, - ModelTypes, model_name, + ModelTypes, ) from gretel_trainer.benchmark.session import Session diff --git a/src/gretel_trainer/benchmark/executor.py b/src/gretel_trainer/benchmark/executor.py index 33c42b90..2ad1099b 100644 --- a/src/gretel_trainer/benchmark/executor.py +++ b/src/gretel_trainer/benchmark/executor.py @@ -1,10 +1,10 @@ import logging + from enum import Enum from typing import Optional, Protocol from gretel_client.projects.models import Model from gretel_client.projects.projects import Project - from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, log, run_out_path from gretel_trainer.benchmark.sdk_extras import create_evaluate_model, run_evaluate diff --git a/src/gretel_trainer/benchmark/gretel/datasets.py b/src/gretel_trainer/benchmark/gretel/datasets.py index cbaed7ae..bec12d9a 100644 --- a/src/gretel_trainer/benchmark/gretel/datasets.py +++ b/src/gretel_trainer/benchmark/gretel/datasets.py @@ -1,10 +1,12 @@ from __future__ import annotations import json + from functools import cached_property from typing import Optional, Union import boto3 + from botocore import UNSIGNED from botocore.client import Config diff --git a/src/gretel_trainer/benchmark/gretel/datasets_backwards_compatibility.py b/src/gretel_trainer/benchmark/gretel/datasets_backwards_compatibility.py index f9a7fdbc..87b18831 100644 --- a/src/gretel_trainer/benchmark/gretel/datasets_backwards_compatibility.py +++ b/src/gretel_trainer/benchmark/gretel/datasets_backwards_compatibility.py @@ -2,6 +2,7 @@ # It can be deleted completely once we fully remove these functions. import logging + from typing import Optional, Union from gretel_trainer.benchmark import Datatype diff --git a/src/gretel_trainer/benchmark/gretel/models.py b/src/gretel_trainer/benchmark/gretel/models.py index 6d7163ce..5a38ed55 100644 --- a/src/gretel_trainer/benchmark/gretel/models.py +++ b/src/gretel_trainer/benchmark/gretel/models.py @@ -1,12 +1,13 @@ import copy + from inspect import isclass from pathlib import Path -from typing import Optional, Type, Union, cast +from typing import cast, Optional, Type, Union + +import gretel_trainer.models from gretel_client.projects.exceptions import ModelConfigError from gretel_client.projects.models import read_model_config - -import gretel_trainer.models from gretel_trainer.benchmark.core import BenchmarkException, Dataset, Datatype GretelModelConfig = Union[str, Path, dict] diff --git a/src/gretel_trainer/benchmark/gretel/strategy_sdk.py b/src/gretel_trainer/benchmark/gretel/strategy_sdk.py index 6de97408..eae193d5 100644 --- a/src/gretel_trainer/benchmark/gretel/strategy_sdk.py +++ b/src/gretel_trainer/benchmark/gretel/strategy_sdk.py @@ -1,14 +1,15 @@ import copy import gzip + from pathlib import Path from typing import Optional import requests + from gretel_client.projects.jobs import END_STATES, Job, RunnerMode, Status from gretel_client.projects.models import Model, read_model_config from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler - from gretel_trainer.benchmark.core import ( BenchmarkConfig, BenchmarkException, diff --git a/src/gretel_trainer/benchmark/gretel/strategy_trainer.py b/src/gretel_trainer/benchmark/gretel/strategy_trainer.py index a958606d..549780db 100644 --- a/src/gretel_trainer/benchmark/gretel/strategy_trainer.py +++ b/src/gretel_trainer/benchmark/gretel/strategy_trainer.py @@ -6,8 +6,8 @@ BenchmarkConfig, BenchmarkException, Dataset, - Timer, run_out_path, + Timer, ) from gretel_trainer.benchmark.gretel.models import GretelModel from gretel_trainer.benchmark.job_spec import JobSpec diff --git a/src/gretel_trainer/benchmark/sdk_extras.py b/src/gretel_trainer/benchmark/sdk_extras.py index 92863980..76471533 100644 --- a/src/gretel_trainer/benchmark/sdk_extras.py +++ b/src/gretel_trainer/benchmark/sdk_extras.py @@ -1,8 +1,10 @@ import json import time + from typing import Any import smart_open + from gretel_client.projects.jobs import ( ACTIVE_STATES, END_STATES, @@ -12,7 +14,6 @@ ) from gretel_client.projects.models import Model, read_model_config from gretel_client.projects.projects import Project - from gretel_trainer.benchmark.core import BenchmarkException, log diff --git a/src/gretel_trainer/benchmark/session.py b/src/gretel_trainer/benchmark/session.py index 824ebb09..2154f923 100644 --- a/src/gretel_trainer/benchmark/session.py +++ b/src/gretel_trainer/benchmark/session.py @@ -1,15 +1,17 @@ from __future__ import annotations import logging + from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, Optional, Union import pandas as pd -from gretel_client.helpers import poll -from gretel_client.projects import Project, create_project, search_projects -from gretel_client.projects.jobs import Job + from typing_extensions import TypeGuard +from gretel_client.helpers import poll +from gretel_client.projects import create_project, Project, search_projects +from gretel_client.projects.jobs import Job from gretel_trainer.benchmark.core import BenchmarkConfig, BenchmarkException from gretel_trainer.benchmark.custom.models import CustomModel from gretel_trainer.benchmark.custom.strategy import CustomStrategy diff --git a/src/gretel_trainer/models.py b/src/gretel_trainer/models.py index b2d5356a..df2a88c7 100644 --- a/src/gretel_trainer/models.py +++ b/src/gretel_trainer/models.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Union, Optional + +from typing import Optional, TYPE_CHECKING, Union if TYPE_CHECKING: import pandas as pd @@ -17,11 +18,12 @@ def _actgan_is_best(rows: int, cols: int) -> bool: - return \ - rows > HIGH_RECORD_THRESHOLD or \ - cols > HIGH_COLUMN_THRESHOLD or \ - rows < LOW_RECORD_THRESHOLD or \ - cols < LOW_COLUMN_THRESHOLD + return ( + rows > HIGH_RECORD_THRESHOLD + or cols > HIGH_COLUMN_THRESHOLD + or rows < LOW_RECORD_THRESHOLD + or cols < LOW_COLUMN_THRESHOLD + ) def determine_best_model(df: pd.DataFrame) -> _BaseConfig: diff --git a/src/gretel_trainer/relational/__init__.py b/src/gretel_trainer/relational/__init__.py index ff9c215f..810d51f8 100644 --- a/src/gretel_trainer/relational/__init__.py +++ b/src/gretel_trainer/relational/__init__.py @@ -1,4 +1,5 @@ import gretel_trainer.relational.log + from gretel_trainer.relational.connectors import ( Connector, mariadb_conn, diff --git a/src/gretel_trainer/relational/ancestry.py b/src/gretel_trainer/relational/ancestry.py index 95e6730b..6f8d9567 100644 --- a/src/gretel_trainer/relational/ancestry.py +++ b/src/gretel_trainer/relational/ancestry.py @@ -1,4 +1,5 @@ import re + from typing import Optional import pandas as pd diff --git a/src/gretel_trainer/relational/artifacts.py b/src/gretel_trainer/relational/artifacts.py index 33d1843b..7d284dd9 100644 --- a/src/gretel_trainer/relational/artifacts.py +++ b/src/gretel_trainer/relational/artifacts.py @@ -1,5 +1,6 @@ import shutil import tempfile + from dataclasses import dataclass from pathlib import Path from typing import Optional diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py index 6b461392..916f7e2d 100644 --- a/src/gretel_trainer/relational/connectors.py +++ b/src/gretel_trainer/relational/connectors.py @@ -9,10 +9,12 @@ from __future__ import annotations import logging + from pathlib import Path from typing import Optional import pandas as pd + from sqlalchemy import create_engine from sqlalchemy.engine.base import Engine from sqlalchemy.exc import OperationalError diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 93e6e457..1154e270 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -17,6 +17,7 @@ import logging import shutil import tempfile + from dataclasses import dataclass, replace from enum import Enum from pathlib import Path @@ -24,10 +25,12 @@ import networkx import pandas as pd + from networkx.algorithms.dag import dag_longest_path_length, topological_sort from pandas.api.types import is_string_dtype import gretel_trainer.relational.json as relational_json + from gretel_trainer.relational.json import ( IngestResponseT, InventedTableMetadata, diff --git a/src/gretel_trainer/relational/extractor.py b/src/gretel_trainer/relational/extractor.py index b0be69d3..7365763c 100644 --- a/src/gretel_trainer/relational/extractor.py +++ b/src/gretel_trainer/relational/extractor.py @@ -4,17 +4,19 @@ from __future__ import annotations import logging + from contextlib import nullcontext from dataclasses import asdict, dataclass from enum import Enum from pathlib import Path from threading import Lock -from typing import TYPE_CHECKING, Iterator, Optional +from typing import Iterator, Optional, TYPE_CHECKING import dask.dataframe as dd import numpy as np import pandas as pd -from sqlalchemy import MetaData, Table, func, inspect, select, tuple_ + +from sqlalchemy import func, inspect, MetaData, select, Table, tuple_ from gretel_trainer.relational.core import RelationalData diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py index 533dd0d5..2a1edf3f 100644 --- a/src/gretel_trainer/relational/json.py +++ b/src/gretel_trainer/relational/json.py @@ -2,13 +2,15 @@ import logging import re + from dataclasses import dataclass -from json import JSONDecodeError, dumps, loads +from json import dumps, JSONDecodeError, loads from typing import Any, Optional, Protocol, Union from uuid import uuid4 import numpy as np import pandas as pd + from unflatten import unflatten logger = logging.getLogger(__name__) diff --git a/src/gretel_trainer/relational/log.py b/src/gretel_trainer/relational/log.py index ea50840f..75e9a53d 100644 --- a/src/gretel_trainer/relational/log.py +++ b/src/gretel_trainer/relational/log.py @@ -1,4 +1,5 @@ import logging + from contextlib import contextmanager RELATIONAL = "gretel_trainer.relational" diff --git a/src/gretel_trainer/relational/model_config.py b/src/gretel_trainer/relational/model_config.py index 400c7161..b43d3aee 100644 --- a/src/gretel_trainer/relational/model_config.py +++ b/src/gretel_trainer/relational/model_config.py @@ -3,7 +3,6 @@ from gretel_client.projects.exceptions import ModelConfigError from gretel_client.projects.models import read_model_config - from gretel_trainer.relational.core import ( GretelModelConfig, MultiTableException, diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index ca6cd753..d0924f77 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -11,6 +11,7 @@ import shutil import tarfile import tempfile + from collections import defaultdict from contextlib import suppress from dataclasses import replace @@ -20,15 +21,15 @@ import pandas as pd import smart_open -from gretel_client.config import RunnerMode, get_session_config -from gretel_client.projects import Project, create_project, get_project + +from gretel_client.config import get_session_config, RunnerMode +from gretel_client.projects import create_project, get_project, Project from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status from gretel_client.projects.records import RecordHandler - from gretel_trainer.relational.artifacts import ( - ArtifactCollection, archive_items, archive_nested_dir, + ArtifactCollection, ) from gretel_trainer.relational.backup import ( Backup, @@ -43,8 +44,8 @@ MultiTableException, RelationalData, Scope, - UserFriendlyDataT, skip_table, + UserFriendlyDataT, ) from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata from gretel_trainer.relational.log import silent_logs diff --git a/src/gretel_trainer/relational/report/figures.py b/src/gretel_trainer/relational/report/figures.py index 0fa95c89..46695824 100644 --- a/src/gretel_trainer/relational/report/figures.py +++ b/src/gretel_trainer/relational/report/figures.py @@ -1,4 +1,5 @@ import math + from typing import Optional import plotly.graph_objects as go diff --git a/src/gretel_trainer/relational/report/report.py b/src/gretel_trainer/relational/report/report.py index acea7675..ac2f188f 100644 --- a/src/gretel_trainer/relational/report/report.py +++ b/src/gretel_trainer/relational/report/report.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime + from dataclasses import dataclass from functools import cached_property from math import ceil @@ -8,12 +9,13 @@ from typing import Optional import plotly.graph_objects as go + from jinja2 import Environment, FileSystemLoader from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope from gretel_trainer.relational.report.figures import ( - PRIVACY_LEVEL_VALUES, gauge_and_needle_chart, + PRIVACY_LEVEL_VALUES, ) from gretel_trainer.relational.table_evaluation import TableEvaluation diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index b636cf59..6ea05285 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -1,5 +1,6 @@ import logging import shutil + from contextlib import suppress from pathlib import Path from typing import Any, Optional, Union @@ -7,11 +8,11 @@ import pandas as pd import requests import smart_open + from gretel_client.projects.jobs import Job, Status from gretel_client.projects.models import Model from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler - from gretel_trainer.relational.core import MultiTableException logger = logging.getLogger(__name__) diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py index 06b60cf6..d53a0e3d 100644 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ b/src/gretel_trainer/relational/strategies/ancestral.py @@ -1,12 +1,14 @@ import logging + from pathlib import Path from typing import Any, Optional, Union import pandas as pd -from gretel_client.projects.models import Model import gretel_trainer.relational.ancestry as ancestry import gretel_trainer.relational.strategies.common as common + +from gretel_client.projects.models import Model from gretel_trainer.relational.core import MultiTableException, RelationalData from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK from gretel_trainer.relational.table_evaluation import TableEvaluation diff --git a/src/gretel_trainer/relational/strategies/common.py b/src/gretel_trainer/relational/strategies/common.py index dfce7680..e9d0726e 100644 --- a/src/gretel_trainer/relational/strategies/common.py +++ b/src/gretel_trainer/relational/strategies/common.py @@ -1,14 +1,16 @@ import json import logging import random + from pathlib import Path from typing import Optional import pandas as pd import smart_open -from gretel_client.projects.models import Model + from sklearn import preprocessing +from gretel_client.projects.models import Model from gretel_trainer.relational.core import MultiTableException, RelationalData from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index aa195e93..9e8c04ca 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -1,13 +1,15 @@ import logging import random + from pathlib import Path from typing import Any, Optional import pandas as pd -from gretel_client.projects.models import Model import gretel_trainer.relational.ancestry as ancestry import gretel_trainer.relational.strategies.common as common + +from gretel_client.projects.models import Model from gretel_trainer.relational.core import RelationalData from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK from gretel_trainer.relational.table_evaluation import TableEvaluation diff --git a/src/gretel_trainer/relational/table_evaluation.py b/src/gretel_trainer/relational/table_evaluation.py index 5f8fad44..0993577b 100644 --- a/src/gretel_trainer/relational/table_evaluation.py +++ b/src/gretel_trainer/relational/table_evaluation.py @@ -1,6 +1,7 @@ import json + from dataclasses import dataclass, field -from typing import Literal, Optional, Union, overload +from typing import Literal, Optional, overload, Union _SQS = "synthetic_data_quality_score" _PPL = "privacy_protection_level" diff --git a/src/gretel_trainer/relational/task_runner.py b/src/gretel_trainer/relational/task_runner.py index 7f038785..f251b452 100644 --- a/src/gretel_trainer/relational/task_runner.py +++ b/src/gretel_trainer/relational/task_runner.py @@ -1,10 +1,10 @@ import logging + from collections import defaultdict from typing import Protocol from gretel_client.projects.jobs import END_STATES, Job, Status from gretel_client.projects.projects import Project - from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK MAX_REFRESH_ATTEMPTS = 3 diff --git a/src/gretel_trainer/relational/tasks/classify.py b/src/gretel_trainer/relational/tasks/classify.py index 6a455ea5..a8ba37f8 100644 --- a/src/gretel_trainer/relational/tasks/classify.py +++ b/src/gretel_trainer/relational/tasks/classify.py @@ -1,13 +1,15 @@ import shutil + from pathlib import Path import smart_open + +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import Job from gretel_client.projects.models import Model from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler - -import gretel_trainer.relational.tasks.common as common from gretel_trainer.relational.workflow_state import Classify diff --git a/src/gretel_trainer/relational/tasks/common.py b/src/gretel_trainer/relational/tasks/common.py index 5eb9a70a..45ab051b 100644 --- a/src/gretel_trainer/relational/tasks/common.py +++ b/src/gretel_trainer/relational/tasks/common.py @@ -1,10 +1,10 @@ import logging import time + from typing import Protocol, Union from gretel_client.projects.jobs import Job, Status from gretel_client.projects.projects import Project - from gretel_trainer.relational.core import RelationalData from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK from gretel_trainer.relational.strategies.ancestral import AncestralStrategy diff --git a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py index 747e8ae0..8943c417 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py +++ b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py @@ -1,9 +1,9 @@ +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import Job from gretel_client.projects.models import Model from gretel_client.projects.projects import Project -import gretel_trainer.relational.tasks.common as common - ACTION = "synthetic data evaluation" diff --git a/src/gretel_trainer/relational/tasks/synthetics_run.py b/src/gretel_trainer/relational/tasks/synthetics_run.py index 82221ad6..27007eab 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_run.py +++ b/src/gretel_trainer/relational/tasks/synthetics_run.py @@ -1,13 +1,15 @@ import logging + from pathlib import Path from typing import Optional import pandas as pd + +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import ACTIVE_STATES, Job, Status from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler - -import gretel_trainer.relational.tasks.common as common from gretel_trainer.relational.workflow_state import SyntheticsRun, SyntheticsTrain logger = logging.getLogger(__name__) diff --git a/src/gretel_trainer/relational/tasks/synthetics_train.py b/src/gretel_trainer/relational/tasks/synthetics_train.py index a15eb32d..67b4fc86 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_train.py +++ b/src/gretel_trainer/relational/tasks/synthetics_train.py @@ -1,7 +1,7 @@ +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import Job from gretel_client.projects.projects import Project - -import gretel_trainer.relational.tasks.common as common from gretel_trainer.relational.workflow_state import SyntheticsTrain ACTION = "synthetics model training" diff --git a/src/gretel_trainer/relational/tasks/transforms_run.py b/src/gretel_trainer/relational/tasks/transforms_run.py index 49951e2c..8c6fcaa9 100644 --- a/src/gretel_trainer/relational/tasks/transforms_run.py +++ b/src/gretel_trainer/relational/tasks/transforms_run.py @@ -1,12 +1,13 @@ from typing import Optional import pandas as pd + +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import Job from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler -import gretel_trainer.relational.tasks.common as common - ACTION = "transforms run" diff --git a/src/gretel_trainer/relational/tasks/transforms_train.py b/src/gretel_trainer/relational/tasks/transforms_train.py index 3e896b90..b16a7443 100644 --- a/src/gretel_trainer/relational/tasks/transforms_train.py +++ b/src/gretel_trainer/relational/tasks/transforms_train.py @@ -1,7 +1,7 @@ +import gretel_trainer.relational.tasks.common as common + from gretel_client.projects.jobs import Job from gretel_client.projects.projects import Project - -import gretel_trainer.relational.tasks.common as common from gretel_trainer.relational.workflow_state import TransformsTrain ACTION = "transforms model training" diff --git a/src/gretel_trainer/runner.py b/src/gretel_trainer/runner.py index fdd539f0..50133985 100644 --- a/src/gretel_trainer/runner.py +++ b/src/gretel_trainer/runner.py @@ -27,6 +27,7 @@ import math import tempfile import time + from collections import Counter from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait from copy import deepcopy @@ -38,13 +39,13 @@ import pandas as pd import smart_open + from gretel_client.projects import Project from gretel_client.projects.jobs import ACTIVE_STATES from gretel_client.projects.models import Model, Status from gretel_client.projects.records import RecordHandler from gretel_client.rest import ApiException from gretel_client.users.users import get_me - from gretel_trainer.strategy import Partition, PartitionConstraints, PartitionStrategy MODEL_ID = "model_id" @@ -84,7 +85,9 @@ class RemoteDFPayload: artifact_type: str -def _remote_dataframe_fetcher(payload: RemoteDFPayload) -> Tuple[RemoteDFPayload, pd.DataFrame]: +def _remote_dataframe_fetcher( + payload: RemoteDFPayload, +) -> Tuple[RemoteDFPayload, pd.DataFrame]: # We need the model object no matter what model = Model(payload.project, model_id=payload.uid) job = model @@ -405,7 +408,8 @@ def train_partition( } model = self._project.create_model_obj( - data_source=artifact.id, model_config=model_config, + data_source=artifact.id, + model_config=model_config, ) model.name = artifact.id.split("_")[-1] @@ -678,10 +682,7 @@ def get_synthetic_data(self) -> pd.DataFrame: return self._maybe_restore_df_headers(df) def get_sqs_information(self) -> List[dict]: - return [ - partition.ctx[SQS] - for partition in self._strategy.partitions - ] + return [partition.ctx[SQS] for partition in self._strategy.partitions] def generate_data( self, @@ -736,7 +737,9 @@ def generate_data( # NOTE: This payload will be used to create a new payload object per # partition, so this will get passed in more as a template to the next # routine - partition_num_records = math.ceil(num_records / self._strategy.row_partition_count) + partition_num_records = math.ceil( + num_records / self._strategy.row_partition_count + ) gen_payload = GenPayload( seed_df=seed_df, num_records=partition_num_records, max_invalid=max_invalid ) diff --git a/src/gretel_trainer/strategy.py b/src/gretel_trainer/strategy.py index 4d90b401..c9949966 100644 --- a/src/gretel_trainer/strategy.py +++ b/src/gretel_trainer/strategy.py @@ -68,9 +68,7 @@ def _build_partitions( seed_headers = constraints.seed_headers partitions.append( Partition( - rows=RowPartition( - start=curr_start, end=curr_start + chunk_size - ), + rows=RowPartition(start=curr_start, end=curr_start + chunk_size), columns=ColumnPartition( headers=list(df.columns), seed_headers=seed_headers ), diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index b8355078..8b2267f5 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -5,13 +5,14 @@ import json import logging import os.path + from pathlib import Path from typing import Optional import pandas as pd + from gretel_client.config import get_session_config, RunnerMode from gretel_client.projects import create_or_get_unique_project - from gretel_trainer import runner, strategy from gretel_trainer.models import _BaseConfig, determine_best_model @@ -22,11 +23,13 @@ DEFAULT_CACHE = f"{DEFAULT_PROJECT}-runner.json" _ACCEPTABLE_CHARS = set( - [chr(c) for c in range(ord("a"), ord("z")+1)] + - [chr(c) for c in range(ord("A"), ord("Z")+1)] + - [chr(c) for c in range(ord("0"), ord("9")+1)] + - ["-"] + [chr(c) for c in range(ord("a"), ord("z") + 1)] + + [chr(c) for c in range(ord("A"), ord("Z") + 1)] + + [chr(c) for c in range(ord("0"), ord("9") + 1)] + + ["-"] ) + + def _sanitize_name(name: str): """Replace unacceptable characters for Gretel API project or model names.""" # Does not account for the following requirements: @@ -35,6 +38,7 @@ def _sanitize_name(name: str): # - minimum and maximum length return "".join(c if c in _ACCEPTABLE_CHARS else "-" for c in name) + class Trainer: """Automated model training and synthetic data generation tool @@ -63,7 +67,9 @@ def __init__( if self.overwrite: if self.model_type is None: - logger.debug("Deferring model configuration to optimize based on training data.") + logger.debug( + "Deferring model configuration to optimize based on training data." + ) else: logger.debug(json.dumps(self.model_type.config, indent=2)) @@ -109,7 +115,11 @@ def load( return trainer def train( - self, dataset_path: str, delimiter: str = ",", round_decimals: int = 4, seed_fields: Optional[list] = None, + self, + dataset_path: str, + delimiter: str = ",", + round_decimals: int = 4, + seed_fields: Optional[list] = None, ): """Train a model on the dataset @@ -121,7 +131,9 @@ def train( """ self.dataset_path = Path(dataset_path) self.df = self._preprocess_data( - dataset_path=dataset_path, delimiter=delimiter, round_decimals=round_decimals + dataset_path=dataset_path, + delimiter=delimiter, + round_decimals=round_decimals, ) self.run = self._initialize_run( df=self.df, overwrite=self.overwrite, seed_fields=seed_fields diff --git a/test-requirements.txt b/test-requirements.txt index faf2dfc6..2d09773d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,4 +1,5 @@ -lxml -pandas-stubs +lxml==4.6.5 +pandas-stubs~=1.5.3.230321 +pandas~=1.5 pyright==1.1.305 -pytest +pytest==6.1.2 diff --git a/tests/benchmark/conftest.py b/tests/benchmark/conftest.py index ee0adf10..01158281 100644 --- a/tests/benchmark/conftest.py +++ b/tests/benchmark/conftest.py @@ -1,5 +1,6 @@ import json import tempfile + from unittest.mock import Mock, patch import pandas as pd diff --git a/tests/benchmark/test_bad_setup.py b/tests/benchmark/test_bad_setup.py index 198d135d..67fe04e7 100644 --- a/tests/benchmark/test_bad_setup.py +++ b/tests/benchmark/test_bad_setup.py @@ -2,7 +2,7 @@ import pytest -from gretel_trainer.benchmark import GretelLSTM, compare +from gretel_trainer.benchmark import compare, GretelLSTM from gretel_trainer.benchmark.core import BenchmarkConfig, BenchmarkException diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py index 9fb80683..0508a7da 100644 --- a/tests/benchmark/test_benchmark.py +++ b/tests/benchmark/test_benchmark.py @@ -1,21 +1,22 @@ import gzip import os import tempfile + from unittest.mock import Mock, patch import pandas as pd import pandas.testing as pdtest import pytest + from gretel_client.projects.jobs import Status from gretel_client.projects.models import read_model_config - from gretel_trainer.benchmark import ( BenchmarkConfig, + compare, + create_dataset, Datatype, GretelGPTX, GretelLSTM, - compare, - create_dataset, launch, ) from gretel_trainer.benchmark.core import Dataset diff --git a/tests/benchmark/test_custom_datasets.py b/tests/benchmark/test_custom_datasets.py index a0b66569..f2c728c1 100644 --- a/tests/benchmark/test_custom_datasets.py +++ b/tests/benchmark/test_custom_datasets.py @@ -2,7 +2,7 @@ import pytest -from gretel_trainer.benchmark import Datatype, create_dataset +from gretel_trainer.benchmark import create_dataset, Datatype from gretel_trainer.benchmark.core import BenchmarkException diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py index a18ceaa2..4a06cd0c 100644 --- a/tests/relational/conftest.py +++ b/tests/relational/conftest.py @@ -1,12 +1,14 @@ import itertools import sqlite3 import tempfile + from pathlib import Path from typing import Callable, Generator from unittest.mock import Mock, patch import pandas as pd import pytest + from sqlalchemy import create_engine from gretel_trainer.relational.connectors import Connector diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py index f10c3cd2..92834f93 100644 --- a/tests/relational/test_ancestral_strategy.py +++ b/tests/relational/test_ancestral_strategy.py @@ -1,6 +1,7 @@ import json import os import tempfile + from pathlib import Path from unittest.mock import Mock, patch @@ -9,6 +10,7 @@ import pytest import gretel_trainer.relational.ancestry as ancestry + from gretel_trainer.relational.core import MultiTableException from gretel_trainer.relational.strategies.ancestral import AncestralStrategy from gretel_trainer.relational.table_evaluation import TableEvaluation diff --git a/tests/relational/test_artifacts.py b/tests/relational/test_artifacts.py index 9b411782..3a14ca1e 100644 --- a/tests/relational/test_artifacts.py +++ b/tests/relational/test_artifacts.py @@ -2,15 +2,16 @@ import shutil import tarfile import tempfile + from pathlib import Path from unittest.mock import Mock import pytest from gretel_trainer.relational.artifacts import ( - ArtifactCollection, archive_items, archive_nested_dir, + ArtifactCollection, ) diff --git a/tests/relational/test_common_strategy.py b/tests/relational/test_common_strategy.py index 2d7dc4f9..86051d89 100644 --- a/tests/relational/test_common_strategy.py +++ b/tests/relational/test_common_strategy.py @@ -1,6 +1,7 @@ import pandas as pd import gretel_trainer.relational.strategies.common as common + from gretel_trainer.relational.core import RelationalData diff --git a/tests/relational/test_extractor.py b/tests/relational/test_extractor.py index 0bdd2e1f..761fb88f 100644 --- a/tests/relational/test_extractor.py +++ b/tests/relational/test_extractor.py @@ -1,5 +1,6 @@ import sqlite3 import tempfile + from pathlib import Path from typing import Iterable @@ -7,9 +8,9 @@ from gretel_trainer.relational.connectors import Connector, sqlite_conn from gretel_trainer.relational.extractor import ( + _determine_sample_size, ExtractorConfig, TableExtractor, - _determine_sample_size, ) diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py index 662e1f32..03547b95 100644 --- a/tests/relational/test_independent_strategy.py +++ b/tests/relational/test_independent_strategy.py @@ -1,6 +1,7 @@ import json import os import tempfile + from collections import defaultdict from pathlib import Path from unittest.mock import Mock, patch diff --git a/tests/relational/test_model_config.py b/tests/relational/test_model_config.py index c52a6549..80018162 100644 --- a/tests/relational/test_model_config.py +++ b/tests/relational/test_model_config.py @@ -1,5 +1,4 @@ from gretel_client.projects.models import read_model_config - from gretel_trainer.relational.model_config import ( get_model_key, make_evaluate_config, diff --git a/tests/relational/test_multi_table_config_options.py b/tests/relational/test_multi_table_config_options.py index f703e81d..3571e710 100644 --- a/tests/relational/test_multi_table_config_options.py +++ b/tests/relational/test_multi_table_config_options.py @@ -1,4 +1,5 @@ import tempfile + from unittest.mock import patch import pytest diff --git a/tests/relational/test_multi_table_restore.py b/tests/relational/test_multi_table_restore.py index 31e7262e..5d67e92d 100644 --- a/tests/relational/test_multi_table_restore.py +++ b/tests/relational/test_multi_table_restore.py @@ -3,6 +3,7 @@ import shutil import tarfile import tempfile + from pathlib import Path from typing import Optional from unittest.mock import Mock, patch @@ -10,6 +11,7 @@ import pytest import gretel_trainer.relational.backup as b + from gretel_trainer.relational.artifacts import ArtifactCollection from gretel_trainer.relational.core import MultiTableException, RelationalData from gretel_trainer.relational.multi_table import MultiTable, SyntheticsRun diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index 5cd8d48c..83cfb1ff 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -817,7 +817,6 @@ def test_handles_missing_interior_invented_tables(nested_lists_of_objects): "Records>userAgent": ["abc", "def"], } ), - # Since demo_invented_2 is missing, Independent strategy post-processing # will set all foreign key values on this table to None. "demo_invented_3": pd.DataFrame( @@ -847,9 +846,7 @@ def test_handles_missing_interior_invented_tables(nested_lists_of_objects): }, { "userAgent": "def", - "responseElements": { - "accountAttributes": [] - }, + "responseElements": {"accountAttributes": []}, }, ] } diff --git a/tests/relational/test_synthetics_run_task.py b/tests/relational/test_synthetics_run_task.py index 250ae358..39298594 100644 --- a/tests/relational/test_synthetics_run_task.py +++ b/tests/relational/test_synthetics_run_task.py @@ -1,4 +1,5 @@ import tempfile + from dataclasses import dataclass from pathlib import Path from typing import Optional, Union @@ -7,13 +8,13 @@ import pandas as pd import pandas.testing as pdtest import pytest + from gretel_client.projects.jobs import Status from gretel_client.projects.projects import Project - from gretel_trainer.relational.core import RelationalData from gretel_trainer.relational.sdk_extras import ( - MAX_PROJECT_ARTIFACTS, ExtendedGretelSDK, + MAX_PROJECT_ARTIFACTS, ) from gretel_trainer.relational.strategies.ancestral import AncestralStrategy from gretel_trainer.relational.strategies.independent import IndependentStrategy diff --git a/tests/relational/test_task_runner.py b/tests/relational/test_task_runner.py index a8742e9d..ae18158c 100644 --- a/tests/relational/test_task_runner.py +++ b/tests/relational/test_task_runner.py @@ -1,8 +1,8 @@ -from unittest.mock import Mock, PropertyMock, patch +from unittest.mock import Mock, patch, PropertyMock import pytest -from gretel_client.projects.jobs import Job, Status +from gretel_client.projects.jobs import Job, Status from gretel_trainer.relational.sdk_extras import MAX_PROJECT_ARTIFACTS from gretel_trainer.relational.task_runner import run_task diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py index fac0bcef..cb07bf45 100644 --- a/tests/relational/test_train_synthetics.py +++ b/tests/relational/test_train_synthetics.py @@ -1,4 +1,5 @@ import tempfile + from unittest.mock import ANY, patch import pytest diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index 4de5166e..94478976 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -1,4 +1,5 @@ import tempfile + from unittest.mock import ANY, patch import pytest diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 23673ce3..51c219f3 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -33,7 +33,9 @@ def test_strategy_all_columns(constraints: PartitionConstraints, test_df): ) # partitions are of roughly equal size - extracted_df_lengths = [len(partition.extract_df(test_df)) for partition in strategy.partitions] + extracted_df_lengths = [ + len(partition.extract_df(test_df)) for partition in strategy.partitions + ] assert max(extracted_df_lengths) - min(extracted_df_lengths) <= 1 # re-assemble all partitions and compare