diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74c5a06b..5171bff2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.1.1 hooks: - id: black args: ["--line-length", "100"] diff --git a/example/1_ctgan_example.py b/example/1_ctgan_example.py index 002b976d..5156f3ca 100644 --- a/example/1_ctgan_example.py +++ b/example/1_ctgan_example.py @@ -1,6 +1,7 @@ """ Example for CTGAN """ + from sdgx.data_connectors.csv_connector import CsvConnector from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel from sdgx.synthesizer import Synthesizer diff --git a/example/extension/dummycache/dummycache/dummycache.py b/example/extension/dummycache/dummycache/dummycache.py index f826c6dd..2c10e083 100644 --- a/example/extension/dummycache/dummycache/dummycache.py +++ b/example/extension/dummycache/dummycache/dummycache.py @@ -3,8 +3,7 @@ from sdgx.cachers.base import Cacher -class MyOwnCache(Cacher): - ... +class MyOwnCache(Cacher): ... from sdgx.cachers.extension import hookimpl diff --git a/example/extension/dummydataconnector/dummydataconnector/dataconnector.py b/example/extension/dummydataconnector/dummydataconnector/dataconnector.py index ab3891ae..e4bddff2 100644 --- a/example/extension/dummydataconnector/dummydataconnector/dataconnector.py +++ b/example/extension/dummydataconnector/dummydataconnector/dataconnector.py @@ -3,8 +3,7 @@ from sdgx.data_connectors.base import DataConnector -class MyOwnDataConnector(DataConnector): - ... +class MyOwnDataConnector(DataConnector): ... from sdgx.data_connectors.extension import hookimpl diff --git a/example/extension/dummydataprocessor/dummydataprocessor/dataprocessor.py b/example/extension/dummydataprocessor/dummydataprocessor/dataprocessor.py index b6028877..3da7c831 100644 --- a/example/extension/dummydataprocessor/dummydataprocessor/dataprocessor.py +++ b/example/extension/dummydataprocessor/dummydataprocessor/dataprocessor.py @@ -3,8 +3,7 @@ from sdgx.data_processors.base import DataProcessor -class MyOwnDataProcessor(DataProcessor): - ... +class MyOwnDataProcessor(DataProcessor): ... from sdgx.data_processors.extension import hookimpl diff --git a/example/extension/dummyexporter/dummyexporter/dummyexporter.py b/example/extension/dummyexporter/dummyexporter/dummyexporter.py index a9b2e9e9..48612a22 100644 --- a/example/extension/dummyexporter/dummyexporter/dummyexporter.py +++ b/example/extension/dummyexporter/dummyexporter/dummyexporter.py @@ -3,8 +3,7 @@ from sdgx.data_exporters.base import DataExporter -class MyOwnExporter(DataExporter): - ... +class MyOwnExporter(DataExporter): ... from sdgx.data_exporters.extension import hookimpl diff --git a/example/extension/dummymetadatainspector/dummymetadatainspector/inspector.py b/example/extension/dummymetadatainspector/dummymetadatainspector/inspector.py index 005fe8ea..20719c10 100644 --- a/example/extension/dummymetadatainspector/dummymetadatainspector/inspector.py +++ b/example/extension/dummymetadatainspector/dummymetadatainspector/inspector.py @@ -4,8 +4,7 @@ from sdgx.data_models.inspectors.extension import hookimpl -class MyOwnInspector(Inspector): - ... +class MyOwnInspector(Inspector): ... @hookimpl diff --git a/example/extension/dummymodel/dummymodel/model.py b/example/extension/dummymodel/dummymodel/model.py index 714e7b3b..120ed54d 100644 --- a/example/extension/dummymodel/dummymodel/model.py +++ b/example/extension/dummymodel/dummymodel/model.py @@ -3,8 +3,7 @@ from sdgx.models.base import SynthesizerModel -class MyOwnModel(SynthesizerModel): - ... +class MyOwnModel(SynthesizerModel): ... from sdgx.models.extension import hookimpl diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py index fc69c9c4..a7355eb8 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_sampler.py @@ -1,4 +1,5 @@ """DataSampler module.""" + from __future__ import annotations import numpy as np diff --git a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py index 695d05fc..c1cfa162 100644 --- a/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py +++ b/sdgx/models/components/optimize/sdv_ctgan/data_transformer.py @@ -1,4 +1,5 @@ """DataTransformer module.""" + from __future__ import annotations from collections import namedtuple diff --git a/sdgx/models/components/sdv_copulas/bivariate/clayton.py b/sdgx/models/components/sdv_copulas/bivariate/clayton.py index 820febad..0e1bcbec 100644 --- a/sdgx/models/components/sdv_copulas/bivariate/clayton.py +++ b/sdgx/models/components/sdv_copulas/bivariate/clayton.py @@ -82,11 +82,14 @@ def cumulative_distribution(self, X): else: cdfs = [ - np.power( - np.power(U[i], -self.theta) + np.power(V[i], -self.theta) - 1, -1.0 / self.theta + ( + np.power( + np.power(U[i], -self.theta) + np.power(V[i], -self.theta) - 1, + -1.0 / self.theta, + ) + if (U[i] > 0 and V[i] > 0) + else 0 ) - if (U[i] > 0 and V[i] > 0) - else 0 for i in range(len(U)) ] diff --git a/sdgx/models/components/sdv_rdt/transformers/base.py b/sdgx/models/components/sdv_rdt/transformers/base.py index 8a1188f3..4b6e4f57 100644 --- a/sdgx/models/components/sdv_rdt/transformers/base.py +++ b/sdgx/models/components/sdv_rdt/transformers/base.py @@ -1,4 +1,5 @@ """BaseTransformer module.""" + import abc import inspect diff --git a/sdgx/models/components/sdv_rdt/transformers/datetime.py b/sdgx/models/components/sdv_rdt/transformers/datetime.py index 3a744826..1776e452 100644 --- a/sdgx/models/components/sdv_rdt/transformers/datetime.py +++ b/sdgx/models/components/sdv_rdt/transformers/datetime.py @@ -1,4 +1,5 @@ """Transformer for datetime data.""" + import numpy as np import pandas as pd from pandas.api.types import is_datetime64_dtype diff --git a/sdgx/models/components/sdv_rdt/transformers/numerical.py b/sdgx/models/components/sdv_rdt/transformers/numerical.py index 8d0a274d..47c9b38c 100644 --- a/sdgx/models/components/sdv_rdt/transformers/numerical.py +++ b/sdgx/models/components/sdv_rdt/transformers/numerical.py @@ -1,4 +1,5 @@ """Transformers for numerical data.""" + import copy import sys import warnings diff --git a/sdgx/models/components/sdv_rdt/transformers/text.py b/sdgx/models/components/sdv_rdt/transformers/text.py index 6a98ca61..0d15b51c 100644 --- a/sdgx/models/components/sdv_rdt/transformers/text.py +++ b/sdgx/models/components/sdv_rdt/transformers/text.py @@ -1,4 +1,5 @@ """Transformers for text data.""" + import warnings import numpy as np diff --git a/sdgx/models/statistics/single_table/copula.py b/sdgx/models/statistics/single_table/copula.py index 6226a2ec..0a9912f4 100644 --- a/sdgx/models/statistics/single_table/copula.py +++ b/sdgx/models/statistics/single_table/copula.py @@ -2,6 +2,7 @@ Wrappers around copulas models. 需要修改: fit接口以适应性能优化措施 """ + import logging import warnings from copy import deepcopy diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py index 4c5531a1..d17b34f5 100644 --- a/sdgx/synthesizer.py +++ b/sdgx/synthesizer.py @@ -103,10 +103,12 @@ def __init__( data_processors = [] self.data_processors_manager = DataProcessorManager() self.data_processors = [ - d - if isinstance(d, DataProcessor) - else self.data_processors_manager.init_data_processor( - d, **(data_processors_kwargs or {}) + ( + d + if isinstance(d, DataProcessor) + else self.data_processors_manager.init_data_processor( + d, **(data_processors_kwargs or {}) + ) ) for d in data_processors ]