Skip to content

Commit

Permalink
Sync SAL sources and dependencies
Browse files Browse the repository at this point in the history
GitOrigin-RevId: dea281a81071fed5dab34b66b0cbc861e6cb3911
  • Loading branch information
mikeknep committed Aug 11, 2023
1 parent 2456fec commit e037974
Show file tree
Hide file tree
Showing 66 changed files with 186 additions and 113 deletions.
37 changes: 21 additions & 16 deletions notebooks/conditional-generation.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import pandas as pd
from gretel_client import configure_session

from gretel_client import configure_session
from gretel_trainer import Trainer
from gretel_trainer.models import GretelLSTM, GretelACTGAN
from gretel_trainer.models import GretelACTGAN, GretelLSTM

DATASET_PATH = 'https://gretel-public-website.s3.amazonaws.com/datasets/mitre-synthea-health.csv'
DATASET_PATH = (
"https://gretel-public-website.s3.amazonaws.com/datasets/mitre-synthea-health.csv"
)
MODEL_TYPE = [GretelLSTM(), GretelACTGAN()][1]

# Create dataset to autocomplete values for
seed_df = pd.DataFrame(data=[
["black", "african", "F"],
["black", "african", "F"],
["black", "african", "F"],
["black", "african", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"]
], columns=["RACE", "ETHNICITY", "GENDER"])
seed_df = pd.DataFrame(
data=[
["black", "african", "F"],
["black", "african", "F"],
["black", "african", "F"],
["black", "african", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
["asian", "chinese", "F"],
],
columns=["RACE", "ETHNICITY", "GENDER"],
)


# Configure Gretel credentials
Expand All @@ -31,5 +36,5 @@
print(model.generate(seed_df=seed_df))

# Load a existing model and conditionally generate data
#model = Trainer.load()
#print(model.generate(seed_df=seed_df))
# model = Trainer.load()
# print(model.generate(seed_df=seed_df))
5 changes: 1 addition & 4 deletions notebooks/custom-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

# Specify underlying model and config options.
# configs can be either a string, dict, or path
model_type = GretelACTGAN(
config="synthetics/tabular-actgan",
max_rows=50000
)
model_type = GretelACTGAN(config="synthetics/tabular-actgan", max_rows=50000)

# Optionally update model parameters from a base config
model_type.update_params({"epochs": 500})
Expand Down
4 changes: 2 additions & 2 deletions notebooks/simple-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@

# Or, load and generate data from an existing model

#model = Trainer.load()
#model.generate(num_records=70)
# model = Trainer.load()
# model.generate(num_records=70)
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ plotly~=5.11
pydantic~=1.9
requests~=2.25
scikit-learn~=1.0
smart-open[s3]~=5.2
smart_open[s3]~=5.2
sqlalchemy~=1.4
typing-extensions~=4.7
typing_extensions~=4.7
unflatten==0.1.1
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pathlib
from setuptools import setup, find_packages

from setuptools import find_packages, setup

local_path = pathlib.Path(__file__).parent
install_requires = (local_path / "requirements.txt").read_text().splitlines()
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/benchmark/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import logging
import time

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/benchmark/custom/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import uuid

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/benchmark/custom/strategy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Optional

from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, Timer, run_out_path
from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, run_out_path, Timer
from gretel_trainer.benchmark.custom.models import CustomModel


Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/benchmark/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@

import logging
import shutil

from inspect import isclass
from pathlib import Path
from typing import Optional, Type, Union, cast
from typing import cast, Optional, Type, Union

import pandas as pd
from gretel_client.config import get_session_config

from gretel_client.config import get_session_config
from gretel_trainer.benchmark.core import BenchmarkConfig, BenchmarkException, Dataset
from gretel_trainer.benchmark.custom.models import CustomModel
from gretel_trainer.benchmark.gretel.models import GretelModel
from gretel_trainer.benchmark.job_spec import (
DatasetTypes,
JobSpec,
ModelTypes,
model_name,
ModelTypes,
)
from gretel_trainer.benchmark.session import Session

Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/benchmark/executor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging

from enum import Enum
from typing import Optional, Protocol

from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project

from gretel_trainer.benchmark.core import BenchmarkConfig, Dataset, log, run_out_path
from gretel_trainer.benchmark.sdk_extras import create_evaluate_model, run_evaluate

Expand Down
2 changes: 2 additions & 0 deletions src/gretel_trainer/benchmark/gretel/datasets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import json

from functools import cached_property
from typing import Optional, Union

import boto3

from botocore import UNSIGNED
from botocore.client import Config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# It can be deleted completely once we fully remove these functions.

import logging

from typing import Optional, Union

from gretel_trainer.benchmark import Datatype
Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/benchmark/gretel/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import copy

from inspect import isclass
from pathlib import Path
from typing import Optional, Type, Union, cast
from typing import cast, Optional, Type, Union

import gretel_trainer.models

from gretel_client.projects.exceptions import ModelConfigError
from gretel_client.projects.models import read_model_config

import gretel_trainer.models
from gretel_trainer.benchmark.core import BenchmarkException, Dataset, Datatype

GretelModelConfig = Union[str, Path, dict]
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_trainer/benchmark/gretel/strategy_sdk.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import copy
import gzip

from pathlib import Path
from typing import Optional

import requests

from gretel_client.projects.jobs import END_STATES, Job, RunnerMode, Status
from gretel_client.projects.models import Model, read_model_config
from gretel_client.projects.projects import Project
from gretel_client.projects.records import RecordHandler

from gretel_trainer.benchmark.core import (
BenchmarkConfig,
BenchmarkException,
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/benchmark/gretel/strategy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
BenchmarkConfig,
BenchmarkException,
Dataset,
Timer,
run_out_path,
Timer,
)
from gretel_trainer.benchmark.gretel.models import GretelModel
from gretel_trainer.benchmark.job_spec import JobSpec
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_trainer/benchmark/sdk_extras.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import time

from typing import Any

import smart_open

from gretel_client.projects.jobs import (
ACTIVE_STATES,
END_STATES,
Expand All @@ -12,7 +14,6 @@
)
from gretel_client.projects.models import Model, read_model_config
from gretel_client.projects.projects import Project

from gretel_trainer.benchmark.core import BenchmarkException, log


Expand Down
8 changes: 5 additions & 3 deletions src/gretel_trainer/benchmark/session.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

import logging

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Optional, Union

import pandas as pd
from gretel_client.helpers import poll
from gretel_client.projects import Project, create_project, search_projects
from gretel_client.projects.jobs import Job

from typing_extensions import TypeGuard

from gretel_client.helpers import poll
from gretel_client.projects import create_project, Project, search_projects
from gretel_client.projects.jobs import Job
from gretel_trainer.benchmark.core import BenchmarkConfig, BenchmarkException
from gretel_trainer.benchmark.custom.models import CustomModel
from gretel_trainer.benchmark.custom.strategy import CustomStrategy
Expand Down
14 changes: 8 additions & 6 deletions src/gretel_trainer/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Union, Optional

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -17,11 +18,12 @@


def _actgan_is_best(rows: int, cols: int) -> bool:
return \
rows > HIGH_RECORD_THRESHOLD or \
cols > HIGH_COLUMN_THRESHOLD or \
rows < LOW_RECORD_THRESHOLD or \
cols < LOW_COLUMN_THRESHOLD
return (
rows > HIGH_RECORD_THRESHOLD
or cols > HIGH_COLUMN_THRESHOLD
or rows < LOW_RECORD_THRESHOLD
or cols < LOW_COLUMN_THRESHOLD
)


def determine_best_model(df: pd.DataFrame) -> _BaseConfig:
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gretel_trainer.relational.log

from gretel_trainer.relational.connectors import (
Connector,
mariadb_conn,
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/ancestry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re

from typing import Optional

import pandas as pd
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import shutil
import tempfile

from dataclasses import dataclass
from pathlib import Path
from typing import Optional
Expand Down
2 changes: 2 additions & 0 deletions src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from __future__ import annotations

import logging

from pathlib import Path
from typing import Optional

import pandas as pd

from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError
Expand Down
3 changes: 3 additions & 0 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
import logging
import shutil
import tempfile

from dataclasses import dataclass, replace
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Union

import networkx
import pandas as pd

from networkx.algorithms.dag import dag_longest_path_length, topological_sort
from pandas.api.types import is_string_dtype

import gretel_trainer.relational.json as relational_json

from gretel_trainer.relational.json import (
IngestResponseT,
InventedTableMetadata,
Expand Down
6 changes: 4 additions & 2 deletions src/gretel_trainer/relational/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
from __future__ import annotations

import logging

from contextlib import nullcontext
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Iterator, Optional
from typing import Iterator, Optional, TYPE_CHECKING

import dask.dataframe as dd
import numpy as np
import pandas as pd
from sqlalchemy import MetaData, Table, func, inspect, select, tuple_

from sqlalchemy import func, inspect, MetaData, select, Table, tuple_

from gretel_trainer.relational.core import RelationalData

Expand Down
4 changes: 3 additions & 1 deletion src/gretel_trainer/relational/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import logging
import re

from dataclasses import dataclass
from json import JSONDecodeError, dumps, loads
from json import dumps, JSONDecodeError, loads
from typing import Any, Optional, Protocol, Union
from uuid import uuid4

import numpy as np
import pandas as pd

from unflatten import unflatten

logger = logging.getLogger(__name__)
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/log.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging

from contextlib import contextmanager

RELATIONAL = "gretel_trainer.relational"
Expand Down
1 change: 0 additions & 1 deletion src/gretel_trainer/relational/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from gretel_client.projects.exceptions import ModelConfigError
from gretel_client.projects.models import read_model_config

from gretel_trainer.relational.core import (
GretelModelConfig,
MultiTableException,
Expand Down
Loading

0 comments on commit e037974

Please sign in to comment.