class LLMBaseModel(SynthesizerModel):
    """
    Base class for generating synthetic data using an LLM (Large Language Model).

    Note:
        - When using the data loader, the original data is transformed to pd.DataFrame format
          for subsequent processing.
        - It is not recommended to use this model with large data tables due to excessive token
          consumption in some expensive LLM services.
        - Generating data based on metadata is a potential way to generate data that cannot be
          made public and contains sensitive information.
    """

    use_raw_data = False
    """
    Whether the model works on the raw data table. Disabled until a subclass enables it.

    When using the data loader, due to the need of randomization operation, we currently use
    `.load_all()` to transform the original data to pd.DataFrame format for subsequent processing.

    Due to the characteristics of the OpenAI GPT service, we do not recommend running this model
    with large data tables, which will consume your tokens excessively.
    """

    use_metadata = False
    """
    In this model, we accept a data generation paradigm that only provides metadata.

    When only metadata is provided, sdgx will format the metadata of the data set into a message
    and transmit it to GPT, and GPT will generate similar data based on what it knows.

    This is a potential way to generate data that cannot be made public and contains sensitive
    information.
    """

    _metadata = None
    """
    The metadata.
    """

    off_table_features = []
    """
    * Experimental Feature

    Names of columns that do not exist in the real data table but should be inferred by the LLM;
    the effect may not be very good.
    """

    prompts = {
        "message_prefix": """Suppose you are the best data generating model in this world, we have some data samples with the following information:\n\n""",
        "message_suffix": """\nGenerate synthetic data samples based on the above information and your knowledge, each sample should be output on one line (do not output in multiple lines), the output format of the sample is the same as the example in this message, such as "column_name_1 is value_1", the count of the generated data samples is """,
        "system_role_content": "You are a powerful synthetic data generation model.",
    }
    """
    Prompt words for generating data (preliminary version, improvements welcome).
    """

    columns = []
    """
    The columns of the data set.
    """

    dataset_description = ""
    """
    The description of the data set.
    """

    _responses = []
    """
    A list to store the responses received from the LLM.
    """

    _message_list = []
    """
    A list to store the messages used to ask LLM.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Initializes the model and gives the instance its own mutable containers.

        Args:
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(*args, **kwargs)
        # The list-valued class attributes above are only defaults. Without per-instance
        # copies, every `.append()` would mutate the single list shared by all instances.
        self.off_table_features = list(self.off_table_features)
        self.columns = list(self.columns)
        self._responses = []
        self._message_list = []

    def _check_access_type(self):
        """
        Checks the data access type.

        Raises:
            SynthesizerInitError: If data access type is not specified or if duplicate data access type is found.
        """
        # `use_dataloader` is not declared in this class; fall back to False when it is absent.
        use_dataloader = getattr(self, "use_dataloader", False)
        if not (use_dataloader or self.use_raw_data or self.use_metadata):
            raise SynthesizerInitError(
                "Data access type not specified, please use `use_raw_data: bool` or `use_dataloader: bool` to specify data access type."
            )
        if use_dataloader and self.use_raw_data:
            raise SynthesizerInitError("Duplicate data access type found.")

    def _form_columns_description(self):
        """
        We believe that giving information about a column helps improve data quality.

        Currently, we leave this function to Good First Issue until March 2024, if unclaimed we will implement it quickly.

        Raises:
            NotImplementedError: Always, the feature is not implemented yet.
        """

        raise NotImplementedError

    def _form_message_with_offtable_features(self):
        """
        This function forms the part of the message that describes off-table features.

        If there are off-table columns, additional processing is executed here.

        Returns:
            str: The instruction asking the LLM to infer the extra columns, or an empty string
            when no off-table feature is configured.
        """
        if not self.off_table_features:
            logger.info("No off_table_feature needed in current model.")
            return ""

        logger.info(f"Use off_table_feature = {self.off_table_features}.")
        return f"Also, you should try to infer another {len(self.off_table_features)} columns based on your knowledge, the name of these columns are : {self.off_table_features}, attach these columns after the original table. \n"

    def _form_dataset_description(self):
        """
        This function is used to form the dataset description.

        Returns:
            str: The description of the generated table, or an empty string when no description
            is set.
        """
        if not self.dataset_description:
            logger.info("No dataset_description given in current model.")
            return ""

        logger.info(f"Use dataset_description = {self.dataset_description}.")
        return "\nThe description of the generated table is " + self.dataset_description + "\n"
class SingleTableGPTModel(LLMBaseModel):
    """
    This is a synthetic data generation model powered by OpenAI GPT, a state-of-the-art language model. This model is based on groundbreaking research presented in the ICLR paper titled "Language Models are Realistic Tabular Data Generators".

    Our model harnesses the power of GPT to generate synthetic tabular data that closely resembles real-world datasets. By utilizing the advanced capabilities of GPT, we aim to provide a reliable and efficient solution for generating simulated data that can be used for various purposes, such as testing, training, and analysis.

    With this synthetic data generation model, users can easily generate diverse and realistic tabular datasets, mimicking the characteristics and patterns found in real data.
    """

    openai_API_key = ""
    """
    The API key required to access the OpenAI GPT model. Please provide your own API key for authentication.
    """

    openai_API_url = "https://api.openai.com/v1/"
    """
    The URL endpoint for the OpenAI GPT API. Please specify the appropriate URL for accessing the API.
    """

    max_tokens = 4000
    """
    The maximum number of tokens allowed in the generated response. This parameter helps in limiting the length of the output text.
    """

    temperature = 0.1
    """
    A parameter that controls the randomness of the generated text. Lower values like 0.1 make the output more focused and deterministic, while higher values like 1.0 introduce more randomness.
    """

    timeout = 90
    """
    The maximum time (in seconds) to wait for a response from the OpenAI GPT API. If the response is not received within this time, the request will be timed out.
    """

    gpt_model = "gpt-3.5-turbo-0613"
    """
    The specific GPT model to be used for generating text. The default is a "gpt-3.5-turbo" snapshot.
    """

    query_batch = 30
    """
    This parameter is the number of samples submitted to GPT each time and the number of returned samples.

    This size has a certain relationship with the max_token parameter.

    We do not recommend setting too large a value, as this may cause potential problems or errors.
    """

    _sample_lines = []
    """
    A list to store the sample lines of generated data.
    """

    _result_list = []
    """
    A list to store the generated data samples.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Initializes the class instance.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        # Per-instance containers: the class-level lists are shared by every instance, so
        # appending to them would leak responses/results between models.
        self._sample_lines = []
        self._result_list = []
        self._responses = []
        self._message_list = []
        self._get_openai_setting_from_env()

    def check(self):
        """
        Performs various checks before a request is sent.

        Raises:
            InitializationError: If openai_API_url or openai_API_key is not found.
            SynthesizerInitError: If data access type is not specified or if duplicate data access type is found.
        """
        self._check_openAI_setting()
        self._set_openAI()
        self._check_access_type()

    def set_openAI_settings(self, API_url="https://api.openai.com/v1/", API_key=""):
        """
        Sets the OpenAI settings.

        Args:
            API_url (str): The OpenAI API URL. Defaults to "https://api.openai.com/v1/".
            API_key (str): The OpenAI API key. Defaults to an empty string.
        """
        self.openai_API_url = API_url
        self.openai_API_key = API_key
        self._set_openAI()

    def _set_openAI(self):
        """
        Sets the OpenAI API key and base URL on the `openai` module.
        """
        openai.api_key = self.openai_API_key
        openai.base_url = self.openai_API_url

    def _check_openAI_setting(self):
        """
        Checks if the OpenAI settings are properly initialized.

        Raises:
            InitializationError: If openai_API_url or openai_API_key is not found.
        """
        if not self.openai_API_url:
            raise InitializationError("openai_API_url NOT found.")
        if not self.openai_API_key:
            raise InitializationError("openai_API_key NOT found.")
        logger.debug("OpenAI setting check passed.")

    def _get_openai_setting_from_env(self):
        """
        Retrieves OpenAI settings from the OPENAI_KEY and OPENAI_URL environment variables.

        Unset or empty variables leave the current settings untouched.
        """
        env_key = os.getenv("OPENAI_KEY")
        if env_key:
            self.openai_API_key = env_key
            logger.debug("Get OPENAI_KEY from ENV.")
        env_url = os.getenv("OPENAI_URL")
        if env_url:
            self.openai_API_url = env_url
            logger.debug("Get OPENAI_URL from ENV.")

    def ask_gpt(self, question, model=None):
        """
        Sends a question to the GPT model.

        Args:
            question (str): The question to ask.
            model (str): The GPT model to use. Defaults to None, which selects `self.gpt_model`.

        Returns:
            str: The response from the GPT model.

        Raises:
            InitializationError: If the OpenAI settings are missing.
            SynthesizerInitError: If the check method fails.
        """
        self.check()
        model = model or self.gpt_model
        # Pass the configured endpoint explicitly; a client built with only an api_key
        # would ignore `openai_API_url` (and therefore OPENAI_URL).
        client = openai.OpenAI(api_key=self.openai_API_key, base_url=self.openai_API_url)
        logger.info(f"Ask GPT with temperature = {self.temperature}.")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": question,
                },
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            timeout=self.timeout,
        )
        logger.info("Ask GPT Finished.")
        # store response
        self._responses.append(response)
        # return the content of the gpt response
        return response.choices[0].message.content

    def fit(
        self, raw_data: pd.DataFrame | DataLoader = None, metadata: Metadata = None, *args, **kwargs
    ):
        """
        Fits this model to the provided data.
        Please note that no actual algorithmic training is executed here.

        Args:
            raw_data (pd.DataFrame | DataLoader): The raw data to fit the model to. It can be either a pandas DataFrame or a DataLoader object. A Metadata object passed here is treated as metadata.
            metadata (Metadata): The metadata associated with the raw data.

        Returns:
            None

        Raises:
            InitializationError: If neither raw_data nor metadata is provided.
        """
        if isinstance(raw_data, (pd.DataFrame, DataLoader)):
            if metadata is not None:
                self._metadata = metadata
            self._fit_with_data(raw_data)
            return

        # Metadata may be handed over in the first positional slot as well.
        if isinstance(raw_data, Metadata):
            self._fit_with_metadata(raw_data)
            return

        if isinstance(metadata, Metadata):
            self._fit_with_metadata(metadata)
            return

        raise InitializationError(
            "Please pass at least one valid parameter, raw_data or metadata"
        )

    def _fit_with_metadata(self, metadata):
        """
        Fit the model using metadata.

        Args:
            metadata: Metadata object.

        Returns:
            None
        """
        logger.info("Fitting model with metadata...")
        self.use_metadata = True
        self._metadata = metadata
        self.columns = list(metadata.column_list)
        logger.info("Fitting model with metadata... Finished.")

    def _fit_with_data(self, train_data):
        """
        Fit the model using data.

        Every row is rendered as one text line of "column is value" pairs in a random column
        order; these lines are later used as examples in the prompt.

        Args:
            train_data: Training data, a pd.DataFrame or a DataLoader.

        Returns:
            None
        """
        logger.info("Fitting model with raw data...")
        self.use_raw_data = True
        self.use_dataloader = False
        if isinstance(train_data, DataLoader):
            self.columns = list(train_data.columns())
            train_data = train_data.load_all()
        if not self.columns:
            self.columns = list(train_data.columns)
        # get metadata if no metadata
        if self._metadata is None:
            self._metadata = Metadata.from_dataframe(train_data)
        # here we got both raw_data and metadata
        sample_lines = []
        for _, row in train_data.iterrows():
            shuffled_columns = list(self.columns)
            random.shuffle(shuffled_columns)
            pairs = [f"{column} is {str(row[column])}" for column in shuffled_columns]
            sample_lines.append(", ".join(pairs) + "\n")
        self._sample_lines = sample_lines
        logger.info("Fitting model with raw data... Finished.")

    @staticmethod
    def _select_random_elements(input_list, cnt):
        """
        This function selects a random sample of elements from the input list.

        Args:
            input_list (list): The list from which elements will be selected.
            cnt (int): The number of elements to be selected.

        Returns:
            list: A list of randomly selected elements from the input list.

        Raises:
            ValueError: If cnt is greater than the length of the input list.
        """
        if cnt > len(input_list):
            raise ValueError("cnt should not be greater than the length of the list")
        return random.sample(input_list, cnt)

    def _iter_batch_sizes(self, count):
        """
        Yields the number of samples to request in each query until `count` is covered.

        Args:
            count (int): The total number of samples to be generated.

        Yields:
            int: The size of the next batch, at most `self.query_batch`.

        Raises:
            ValueError: If `self.query_batch` is not positive (the loop would never end).
        """
        if self.query_batch < 1:
            raise ValueError("query_batch must be a positive integer")
        remaining_cnt = count
        while remaining_cnt > 0:
            current_cnt = min(self.query_batch, remaining_cnt)
            yield current_cnt
            remaining_cnt -= current_cnt

    def _form_message_with_data(self, sample_list, current_cnt):
        """
        This function forms a message with data.

        Args:
            sample_list (list): A list of sample lines used as examples.
            current_cnt (int): The number of samples requested from GPT; at most this many
                examples from sample_list are included.

        Returns:
            str: The formed message with data.
        """
        # form sample string; the list may hold fewer than current_cnt examples
        sample_str = "".join(
            f"sample {i}: " + each_sample + "\n"
            for i, each_sample in enumerate(sample_list[:current_cnt])
        )

        # form the message sent to GPT
        message = self.prompts["message_prefix"] + sample_str

        # add dataset description
        message = message + self._form_dataset_description()

        # add off-table features
        message = message + self._form_message_with_offtable_features()

        message = (
            message
            + f"Please note that the generated table has total {len(self.columns) + len(self.off_table_features)} columns of the generated data, the column names are {self.columns + self.off_table_features}, every column should not be missed when generating the data. \n"
        )

        # add the suffix of the message
        message = message + self.prompts["message_suffix"] + str(current_cnt) + "."
        # Maybe it can be optimized here
        self._message_list.append(message)
        logger.debug("Message Generated.")
        return message

    def extract_samples_from_response(self, response_content):
        """
        Extracts samples from the response content.

        Each line of the response is searched for "<column> is <value>" (or "<column> = <value>")
        pairs; lines without any recognised column are skipped.

        Args:
            response_content (str): The text returned by GPT.

        Returns:
            list: A list of rows, each a list of values ordered like the header; columns missing
            in a line are None.
        """
        logger.info("Extracting samples from response ...")
        header = self.columns + self.off_table_features
        # Compile once per column; escape the name so regex metacharacters in a column
        # name (e.g. "." or "+") are matched literally.
        patterns = [
            (field, re.compile(r"\b" + re.escape(str(field)) + r"\s*(?:is|=)\s*([^,\n]+)"))
            for field in header
        ]
        features = []
        for line in response_content.split("\n"):
            feature = {}
            for field, pattern in patterns:
                match = pattern.search(line)
                if match:
                    feature[field] = match.group(1).strip()
            if feature:
                features.append([feature.get(field) for field in header])
        logger.info(f"Extracting samples from response ... Finished, {len(features)} extracted.")
        return features

    def sample(self, count=50, dataset_desp="", *args, **kwargs):
        """
        This function samples data from either raw data or metadata based on the given parameters.

        Args:
            count (int): The number of samples to be generated. Default is 50.
            dataset_desp (str): The description of the dataset. Default is an empty string.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            res: The sampled data.

        Raises:
            InitializationError: If the model has been fitted with neither raw data nor metadata.
        """
        logger.info("Sampling use GPT model ...")
        self.dataset_description = dataset_desp

        if self.use_raw_data:
            # If the use_raw_data flag is True, sample data using the _sample_with_data method.
            res = self._sample_with_data(count, *args, **kwargs)
        elif self.use_metadata:
            # If the use_metadata flag is True, sample data using the _sample_with_metadata method.
            res = self._sample_with_metadata(count, *args, **kwargs)
        else:
            # Fail with a clear message instead of an UnboundLocalError on `res`.
            raise InitializationError(
                "Model is not fitted, please call fit() with raw_data or metadata first."
            )
        logger.info("Sampling use GPT model ... Finished.")
        return res

    def _form_message_with_metadata(self, current_cnt):
        """
        This function forms a message with metadata for table data generation task.

        Args:
            current_cnt (int): The number of samples requested from GPT.

        Returns:
            str: The formed message with metadata.
        """
        parts = [
            # message prefix and dataset description
            self.prompts["message_prefix"],
            self._form_dataset_description(),
            # information about the table data generation task
            "This table data generation task will only have metadata and no data samples. The header (columns infomation) of the tabular data is: ",
            # column information
            str(self.columns) + ". \n",
            # off-table features information
            self._form_message_with_offtable_features(),
            # information about the generated table columns
            f"Note that the generated table has total {len(self.columns) + len(self.off_table_features)} columns, the column names are {self.columns + self.off_table_features}, every column should NOT be missed in generated data.\n",
            # message suffix and requested count
            self.prompts["message_suffix"] + str(current_cnt) + ".",
        ]
        message = "".join(parts)
        # Append the message to the message list
        self._message_list.append(message)
        return message

    def _sample_with_metadata(self, count, *args, **kwargs):
        """
        This method samples data with metadata only.

        Args:
            count (int): The number of samples to be generated.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            pd.DataFrame: A DataFrame containing the rows extracted from the GPT responses.
        """
        logger.info("Sampling with metadata.")
        result = []
        for current_cnt in self._iter_batch_sizes(count):
            # Generate a message with metadata
            message = self._form_message_with_metadata(current_cnt)
            # Send the message to GPT and get the response
            response = self.ask_gpt(message)
            # Extract rows from the response and collect them
            result += self.extract_samples_from_response(response)

        final_columns = self.columns + self.off_table_features
        return pd.DataFrame(result, columns=final_columns)

    def _sample_with_data(self, count, *args, **kwargs):
        """
        This function samples data with a given count and returns a DataFrame with the sampled data.

        Args:
            count (int): The number of data samples to be generated.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            pd.DataFrame: A DataFrame containing the rows extracted from the GPT responses.
        """
        logger.info("Sampling with raw_data.")
        result = []
        for current_cnt in self._iter_batch_sizes(count):
            # A small training table may hold fewer rows than one batch; never ask for
            # more example lines than exist.
            example_cnt = min(current_cnt, len(self._sample_lines))
            sample_list = self._select_random_elements(self._sample_lines, example_cnt)
            message = self._form_message_with_data(sample_list, current_cnt)
            # ask_gpt
            response = self.ask_gpt(message)
            # get result from response and update result
            result += self.extract_samples_from_response(response)

        # keep a history of every sampling run
        self._result_list.append(result)

        # build the frame from the rows of this run, not from the history of runs
        final_columns = self.columns + self.off_table_features
        return pd.DataFrame(result, columns=final_columns)
Married-civ-spouse, income is >50K, age is 40, capital-loss is 1485, hours-per-week is 40 +sample 22: income is <=50K, gender is Male, education is Assoc-acdm, native-country is ?, educational-num is 12, hours-per-week is 7, occupation is Prof-specialty, capital-gain is 0, capital-loss is 0, fnlwgt is 154164, race is White, workclass is Private, age is 66, relationship is Not-in-family, marital-status is Never-married +sample 23: relationship is Own-child, marital-status is Never-married, occupation is Transport-moving, income is <=50K, race is White, workclass is Private, hours-per-week is 30, capital-gain is 0, native-country is United-States, capital-loss is 0, fnlwgt is 283499, gender is Male, education is Some-college, age is 20, educational-num is 10 +sample 24: marital-status is Married-civ-spouse, gender is Male, capital-gain is 0, capital-loss is 0, educational-num is 9, hours-per-week is 40, fnlwgt is 170772, income is <=50K, occupation is Other-service, relationship is Husband, race is White, education is HS-grad, age is 30, native-country is United-States, workclass is Local-gov +sample 25: native-country is United-States, income is <=50K, capital-gain is 0, hours-per-week is 40, fnlwgt is 367306, educational-num is 10, capital-loss is 0, race is White, relationship is Own-child, gender is Female, occupation is Tech-support, workclass is Private, age is 25, marital-status is Never-married, education is Some-college +sample 26: occupation is Exec-managerial, fnlwgt is 81973, native-country is United-States, capital-loss is 0, educational-num is 14, relationship is Husband, marital-status is Married-civ-spouse, workclass is Private, capital-gain is 0, gender is Male, race is Asian-Pac-Islander, hours-per-week is 40, age is 59, income is >50K, education is Masters +sample 27: gender is Male, fnlwgt is 287268, workclass is Private, native-country is United-States, capital-loss is 0, age is 28, educational-num is 10, hours-per-week is 35, income is <=50K, 
relationship is Not-in-family, occupation is Other-service, capital-gain is 0, race is White, education is Some-college, marital-status is Never-married +sample 28: race is White, income is >50K, native-country is United-States, workclass is Private, capital-loss is 0, occupation is Craft-repair, hours-per-week is 60, educational-num is 13, gender is Male, fnlwgt is 176729, relationship is Husband, marital-status is Married-civ-spouse, capital-gain is 0, age is 25, education is Bachelors +sample 29: capital-loss is 0, native-country is Cambodia, educational-num is 9, marital-status is Married-civ-spouse, workclass is Self-emp-not-inc, education is HS-grad, age is 42, occupation is Farming-fishing, race is Asian-Pac-Islander, income is >50K, gender is Male, hours-per-week is 40, fnlwgt is 303044, capital-gain is 0, relationship is Husband +sample 30: age is 17, capital-loss is 0, relationship is Own-child, educational-num is 6, workclass is Private, income is <=50K, native-country is United-States, hours-per-week is 15, race is White, occupation is Other-service, fnlwgt is 202521, education is 10th, gender is Male, capital-gain is 0, marital-status is Never-married +sample 31: fnlwgt is 173858, occupation is Exec-managerial, education is Bachelors, native-country is China, hours-per-week is 35, marital-status is Married-civ-spouse, relationship is Husband, age is 38, capital-gain is 7688, workclass is Private, race is Asian-Pac-Islander, income is >50K, capital-loss is 0, gender is Male, educational-num is 13 +sample 32: age is 74, workclass is Private, marital-status is Widowed, income is <=50K, hours-per-week is 40, occupation is Priv-house-serv, education is Assoc-voc, gender is Female, race is White, relationship is Not-in-family, fnlwgt is 68326, capital-loss is 0, native-country is United-States, capital-gain is 0, educational-num is 11 +sample 33: gender is Male, race is White, age is 28, capital-gain is 0, marital-status is Married-civ-spouse, hours-per-week 
is 40, relationship is Husband, education is HS-grad, educational-num is 9, income is <=50K, workclass is Private, capital-loss is 0, fnlwgt is 66095, occupation is Sales, native-country is United-States +sample 34: education is Bachelors, occupation is Adm-clerical, race is White, marital-status is Married-civ-spouse, capital-gain is 7688, gender is Male, income is >50K, relationship is Husband, workclass is Private, fnlwgt is 206814, age is 58, hours-per-week is 50, capital-loss is 0, native-country is United-States, educational-num is 13 +sample 35: age is 44, occupation is Craft-repair, capital-loss is 0, workclass is Federal-gov, income is >50K, fnlwgt is 243636, capital-gain is 0, education is Assoc-voc, relationship is Husband, educational-num is 11, race is White, marital-status is Married-civ-spouse, hours-per-week is 40, native-country is United-States, gender is Male +sample 36: education is Masters, fnlwgt is 37070, marital-status is Married-civ-spouse, age is 33, relationship is Husband, capital-gain is 0, educational-num is 14, workclass is State-gov, gender is Male, income is <=50K, occupation is Prof-specialty, capital-loss is 0, race is White, native-country is Canada, hours-per-week is 60 +sample 37: workclass is Private, educational-num is 7, education is 11th, income is <=50K, relationship is Not-in-family, gender is Male, hours-per-week is 40, native-country is Puerto-Rico, age is 23, fnlwgt is 224217, occupation is Transport-moving, capital-gain is 0, marital-status is Never-married, capital-loss is 0, race is White +sample 38: hours-per-week is 40, capital-gain is 0, gender is Male, marital-status is Never-married, age is 25, capital-loss is 0, native-country is ?, fnlwgt is 310864, income is <=50K, educational-num is 13, race is Black, workclass is Private, education is Bachelors, relationship is Not-in-family, occupation is Tech-support +sample 39: capital-gain is 0, hours-per-week is 55, capital-loss is 0, age is 30, income is <=50K, 
education is Bachelors, relationship is Not-in-family, marital-status is Never-married, occupation is Exec-managerial, native-country is United-States, gender is Female, race is White, educational-num is 13, workclass is Private, fnlwgt is 128016 +sample 40: hours-per-week is 40, fnlwgt is 174515, capital-loss is 0, marital-status is Widowed, native-country is United-States, capital-gain is 0, age is 40, education is HS-grad, occupation is Machine-op-inspct, educational-num is 9, relationship is Unmarried, race is White, gender is Female, workclass is Private, income is <=50K +""", + """Based on the provided information, here are 15 synthetic data samples generated: + +Sample 0: income is <=50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 36, workclass is Self-emp-inc, education is HS-grad, relationship is Husband, native-country is United-States, educational-num is 9.0, occupation is Farming-fishing, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 80.0, fnlwgt is 48063 + +Sample 1: income is <=50K, race is White, marital-status is Never-married, gender is Male, age is 34, workclass is Private, education is Masters, relationship is Not-in-family, native-country is United-States, educational-num is 14.0, occupation is Prof-specialty, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 189759 + +Sample 2: income is <=50K, race is Amer-Indian-Eskimo, marital-status is Never-married, gender is Female, age is 19, workclass is Private, education is HS-grad, relationship is Own-child, native-country is United-States, educational-num is 9.0, occupation is Other-service, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 35.0, fnlwgt is 106183 + +Sample 3: income is <=50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 42, workclass is Private, education is HS-grad, relationship is Husband, native-country is United-States, educational-num is 9.0, occupation is 
Craft-repair, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 45.0, fnlwgt is 190910 + +Sample 4: income is >50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 59, workclass is Private, education is 7th-8th, relationship is Husband, native-country is United-States, educational-num is 4.0, occupation is Craft-repair, capital-loss is 0.0, capital-gain is 5178.0, hours-per-week is 50.0, fnlwgt is 107318 + +Sample 5: income is >50K, race is White, marital-status is Divorced, gender is Male, age is 51, workclass is Private, education is Some-college, relationship is Not-in-family, native-country is United-States, educational-num is 10.0, occupation is Craft-repair, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 42.0, fnlwgt is 43354 + +Sample 6: income is <=50K, race is White, marital-status is Separated, gender is Male, age is 46, workclass is Private, education is HS-grad, relationship is Not-in-family, native-country is United-States, educational-num is 9.0, occupation is Transport-moving, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 170338 + +Sample 7: income is >50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 45, workclass is Self-emp-not-inc, education is Some-college, relationship is Husband, native-country is United-States, educational-num is 10.0, occupation is Sales, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 60.0, fnlwgt is 355978 + +Sample 8: income is <=50K, race is White, marital-status is Never-married, gender is Male, age is 33, workclass is Private, education is Bachelors, relationship is Not-in-family, native-country is United-States, educational-num is 13.0, occupation is Sales, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 45.0, fnlwgt is 90409 + +Sample 9: income is <=50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 29, workclass is Private, education is 11th, 
relationship is Husband, native-country is United-States, educational-num is 7.0, occupation is Other-service, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 103634 + +Sample 10: income is >50K, race is White, marital-status is Never-married, gender is Male, age is 22, workclass is Private, education is HS-grad, relationship is Not-in-family, native-country is United-States, educational-num is 9.0, occupation is Other-service, capital-loss is 0.0, capital-gain is 14084.0, hours-per-week is 60.0, fnlwgt is 54164 + +Sample 11: income is <=50K, race is White, marital-status is Never-married, gender is Male, age is 17, workclass is ?, education is 10th, relationship is Own-child, native-country is United-States, educational-num is 6.0, occupation is ?, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 165361 + +Sample 12: income is <=50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 23, workclass is Private, education is 10th, relationship is Husband, native-country is United-States, educational-num is 6.0, occupation is Farming-fishing, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 306309 + +Sample 13: income is <=50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 38, workclass is Private, education is HS-grad, relationship is Husband, native-country is United-States, educational-num is 9.0, occupation is Farming-fishing, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 50.0, fnlwgt is 89814 + +Sample 14: income is >50K, race is White, marital-status is Married-civ-spouse, gender is Male, age is 36, workclass is Local-gov, education is Bachelors, relationship is Husband, native-country is United-States, educational-num is 13.0, occupation is Prof-specialty, capital-loss is 0.0, capital-gain is 0.0, hours-per-week is 40.0, fnlwgt is 403681 + +""", + """ +marital-status is Married-civ-spouse, relationship is 
Husband, income is <=50K, age is 45, fnlwgt is 200000, occupation is Exec-managerial, native-country is United-States, hours-per-week is 40, workclass is Private, gender is Male, educational-num is 13, capital-gain is 0, capital-loss is 0, race is White, education is Bachelors +marital-status is Never-married, relationship is Not-in-family, income is <=50K, age is 25, fnlwgt is 150000, occupation is Sales, native-country is United-States, hours-per-week is 35, workclass is Private, gender is Female, educational-num is 10, capital-gain is 0, capital-loss is 0, race is White, education is Some-college +marital-status is Divorced, relationship is Unmarried, income is <=50K, age is 35, fnlwgt is 180000, occupation is Adm-clerical, native-country is United-States, hours-per-week is 30, workclass is Private, gender is Female, educational-num is 12, capital-gain is 0, capital-loss is 0, race is White, education is HS-grad +marital-status is Married-civ-spouse, relationship is Wife, income is >50K, age is 50, fnlwgt is 250000, occupation is Prof-specialty, native-country is United-States, hours-per-week is 45, workclass is Self-emp-not-inc, gender is Female, educational-num is 14, capital-gain is 5000, capital-loss is 0, race is White, education is Masters +marital-status is Married-civ-spouse, relationship is Husband, income is >50K, age is 40, fnlwgt is 220000, occupation is Exec-managerial, native-country is United-States, hours-per-week is 50, workclass is Private, gender is Male, educational-num is 13, capital-gain is 10000, capital-loss is 0, race is White, education is Bachelors +marital-status is Never-married, relationship is Own-child, income is <=50K, age is 18, fnlwgt is 100000, occupation is Other-service, native-country is United-States, hours-per-week is 20, workclass is Private, gender is Male, educational-num is 9, capital-gain is 0, capital-loss is 0, race is Black, education is Some-college +marital-status is Married-civ-spouse, relationship is Wife, 
income is >50K, age is 35, fnlwgt is 180000, occupation is Prof-specialty, native-country is United-States, hours-per-week is 40, workclass is Private, gender is Female, educational-num is 14, capital-gain is 8000, capital-loss is 0, race is White, education is Masters +marital-status is Divorced, relationship is Unmarried, income is <=50K, age is 30, fnlwgt is 160000, occupation is Adm-clerical, native-country is United-States, hours-per-week is 35, workclass is Private, gender is Female, educational-num is 12, capital-gain is 0, capital-loss is 0, race is White, education is HS-grad +marital-status is Married-civ-spouse, relationship is Husband, income is >50K, age is 55, fnlwgt is 280000, occupation is Exec-managerial, native-country is United-States, hours-per-week is 60, workclass is Private, gender is Male, educational-num is 13, capital-gain is 15000, capital-loss is 0, race is White, education is Bachelors +marital-status is Never-married, relationship is Not-in-family, income is <=50K, age is 28, fnlwgt is 140000, occupation is Sales, native-country is United-States, hours-per-week is 40, workclass is Private, gender is Female, educational-num is 10, capital-gain is 0, capital-loss is 0, race is White, education is Some-college +marital-status is Divorced, relationship is Unmarried, income is <=50K, age is 40, fnlwgt is 200000, occupation is Adm-clerical, native-country is United-States, hours-per-week is 30, workclass is Private, gender is Female, educational-num is 12, capital-gain is 0, capital-loss is 0, race is White, education is HS-grad +marital-status is Married-civ-spouse, relationship is Wife, income is >50K, age is 45, fnlwgt is 220000, occupation is Prof-specialty, native-country is United-States, hours-per-week is 50, workclass is Self-emp-not-inc, gender is Female, educational-num is 14, capital-gain is 10000, capital-loss is 0, race is White, education is Masters +marital-status is Married-civ-spouse, relationship is Husband, income is >50K, 
age is 38, fnlwgt is 180000, occupation is Exec-managerial, native-country is United-States, hours-per-week is 45, workclass is Private, gender is Male, educational-num is 13, capital-gain is 8000, capital-loss is 0, race is White, education is Bachelors +marital-status is Never-married, relationship is Own-child, income is <=50K, age is 20, fnlwgt is 120000, occupation is Other-service, native-country is United-States, hours-per-week is 20, workclass is Private, gender is Male, educational-num is 9, capital-gain is 0, capital-loss is 0, race is Black, education is Some-college +marital-status is Married-civ-spouse, relationship is Wife, income is >50K, age is 30, fnlwgt is 160000, occupation is Prof-specialty, native-country is United-States, hours-per-week is 40, workclass is Private, gender is Female, educational-num is 14, capital-gain is 5000, capital-loss is 0, race is White, education is Masters +marital-status is Divorced, relationship is Unmarried, income is <=50K, age is 25, fnlwgt is 140000, occupation is Adm-clerical, native-country is United-States, hours-per-week is 35, workclass is Private, gender is Female, educational-num is 12, capital-gain is 0, capital-loss is 0, race is White, education is HS-grad +marital-status is Married-civ-spouse, relationship is Husband, income is >50K, age is 50, fnlwgt is 250000, occupation is Exec-managerial, native-country is United-States, hours-per-week is 60, workclass is Private, gender is Male, educational-num is 13, capital-gain is 10000, capital-loss is 0, race is White, education is Bachelors +marital-status is Never-married, relationship is Not-in-family, income is <=50K, age is 30, fnlwgt is 160000, occupation is Sales, native-country is United-States, hours-per-week is 40, workclass is Private, gender is Female, educational-num is 10, capital-gain is 0, capital-loss is 0, race is White, education is Some-college +marital-status is Divorced, relationship is Unmarried, income is <=50K, age is 35, fnlwgt is 
180000, occupation is Adm-clerical, native-country is United-States, hours-per-week is 30, workclass is Private, gender is Female, educational-num is 12, capital-gain is 0, capital-loss is 0, race is White, education is HS-grad +marital-status is Married-civ-spouse, relationship is Wife, income is >50K, age is 55, fnlwgt is 280000, occupation is Prof-specialty, native-country is United-States, hours-per-week is 50, workclass is Private, gender is Female, educational-num is 14, capital-gain is 15000, capital-loss is 0, race is White, education is Masters + +""", + """marital-status is Married-civ-spouse, capital-gain is 0, occupation is Exec-managerial, education is Bachelors, fnlwgt is 189778, age is 35, relationship is Husband, hours-per-week is 40, income is <=50K, native-country is United-States, gender is Male, capital-loss is 0, race is White, educational-num is 13, workclass is Private, has_car is True +marital-status is Never-married, capital-gain is 0, occupation is Adm-clerical, education is HS-grad, fnlwgt is 183934, age is 28, relationship is Not-in-family, hours-per-week is 35, income is <=50K, native-country is United-States, gender is Female, capital-loss is 0, race is White, educational-num is 9, workclass is Private, has_car is False +marital-status is Divorced, capital-gain is 0, occupation is Handlers-cleaners, education is 11th, fnlwgt is 234721, age is 42, relationship is Unmarried, hours-per-week is 45, income is <=50K, native-country is United-States, gender is Male, capital-loss is 0, race is Black, educational-num is 7, workclass is Private, has_car is True +marital-status is Married-civ-spouse, capital-gain is 0, occupation is Prof-specialty, education is Masters, fnlwgt is 216129, age is 41, relationship is Husband, hours-per-week is 50, income is >50K, native-country is United-States, gender is Male, capital-loss is 0, race is White, educational-num is 14, workclass is Self-emp-inc, has_car is True +marital-status is Married-civ-spouse, 
capital-gain is 0, occupation is Craft-repair, education is Some-college, fnlwgt is 112497, age is 52, relationship is Husband, hours-per-week is 60, income is >50K, native-country is United-States, gender is Male, capital-loss is 0, race is White, educational-num is 10, workclass is Private, has_car is True""", + """ +marital-status = Married-civ-spouse, capital-gain = 0, occupation = Exec-managerial, education = Bachelors, fnlwgt = 189778, age = 35, relationship = Husband, hours-per-week = 40, income = <=50K, native-country = United-States, gender = Male, capital-loss = 0, race = White, educational-num = 13, workclass = Private, has_car = True +marital-status = Never-married, capital-gain = 0, occupation = Adm-clerical, education = HS-grad, fnlwgt = 183934, age = 28, relationship = Not-in-family, hours-per-week = 35, income = <=50K, native-country = United-States, gender = Female, capital-loss = 0, race = White, educational-num = 9, workclass = Private, has_car = False +marital-status = Divorced, capital-gain = 0, occupation = Handlers-cleaners, education = 11th, fnlwgt = 234721, age = 42, relationship = Unmarried, hours-per-week = 45, income = <=50K, native-country = United-States, gender = Male, capital-loss = 0, race = Black, educational-num = 7, workclass = Private, has_car = True +marital-status = Married-civ-spouse, capital-gain = 0, occupation = Prof-specialty, education = Masters, fnlwgt = 216129, age = 41, relationship = Husband, hours-per-week = 50, income = >50K, native-country = United-States, gender = Male, capital-loss = 0, race = White, educational-num = 14, workclass = Self-emp-inc, has_car = True +marital-status = Married-civ-spouse, capital-gain = 0, occupation = Craft-repair, education = Some-college, fnlwgt = 112497, age = 52, relationship = Husband, hours-per-week = 60, income = >50K, native-country = United-States, gender = Male, capital-loss = 0, race = White, educational-num = 10, workclass = Private, has_car = True +""", +] + 
# Expected number of samples extracted from each entry of ``gpt_response_list``;
# the two lists are aligned by position (same index -> same response).
gpt_response_sample_count = [20, 15, 20, 5, 5]


def test_singletable_gpt_model(
    single_table_gpt_model: SingleTableGPTModel,
    raw_data: pd.DataFrame,
    demo_single_table_data_loader: DataLoader,
):
    """Fit on raw data, verify the learned columns and the model's settings, then fit on a data loader."""
    single_table_gpt_model.fit(raw_data)
    assert single_table_gpt_model.columns == [
        "age",
        "workclass",
        "fnlwgt",
        "education",
        "educational-num",
        "marital-status",
        "occupation",
        "relationship",
        "race",
        "gender",
        "capital-gain",
        "capital-loss",
        "hours-per-week",
        "native-country",
        "income",
    ]
    assert single_table_gpt_model.openai_API_url == "https://api.openai.com/v1/"
    # the key is not set
    assert not single_table_gpt_model.openai_API_key
    assert single_table_gpt_model.max_tokens == 4000
    assert single_table_gpt_model.temperature == 0.1
    assert single_table_gpt_model.timeout == 90
    assert "gpt-3.5" in single_table_gpt_model.gpt_model.lower()
    # Access-type flags after fitting on a raw DataFrame.
    assert single_table_gpt_model.use_raw_data is True
    assert single_table_gpt_model.use_dataloader is False
    assert single_table_gpt_model.use_metadata is False
    assert single_table_gpt_model.query_batch == 30
    assert not single_table_gpt_model.off_table_features
    assert len(single_table_gpt_model.columns) > 0
    # train with dataloader
    single_table_gpt_model.fit(demo_single_table_data_loader)


@pytest.mark.parametrize("response_index", range(len(gpt_response_list)))
def test_feature_extraction_data(
    response_index: int, single_table_gpt_model: SingleTableGPTModel, raw_data: pd.DataFrame
):
    """Fit on raw data and check sample extraction, prompt building and ``check()``.

    For each canned response, the extracted samples must match the expected
    count and have one value per fitted column. The data-based prompt must
    mention every column of ``raw_data``. ``check()`` must raise
    ``InitializationError`` until OpenAI settings are supplied.
    """
    single_table_gpt_model.fit(raw_data)
    response_content = gpt_response_list[response_index]
    res = single_table_gpt_model.extract_samples_from_response(response_content)
    assert isinstance(res, list)
    # assert shape of extracted features
    expected_rows = gpt_response_sample_count[response_index]
    assert len(res) == expected_rows
    assert len(res[0]) == len(single_table_gpt_model.columns)
    res_df = pd.DataFrame(
        res, columns=single_table_gpt_model.columns + single_table_gpt_model.off_table_features
    )
    assert res_df.shape == (
        expected_rows,
        len(single_table_gpt_model.columns),
    )

    # The prompt built from the fitted sample lines must mention every column.
    sample_list = single_table_gpt_model._sample_lines
    message = single_table_gpt_model._form_message_with_data(sample_list, 20)
    assert isinstance(message, str)
    for each_col in raw_data.columns:
        assert each_col in message
    assert isinstance(sample_list, list)
    assert len(sample_list) == len(raw_data)

    # Before any OpenAI settings are supplied, check() is expected to reject the model.
    with pytest.raises(InitializationError):
        single_table_gpt_model.check()

    # set and check again
    fake_openAI_KEY = "sk-qXCXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
    single_table_gpt_model.set_openAI_settings("https://api.openai.com/v1/", fake_openAI_KEY)
    single_table_gpt_model.check()


@pytest.mark.parametrize("response_index", range(len(gpt_response_list)))
def test_feature_extraction_metadata(
    response_index: int,
    single_table_gpt_model: SingleTableGPTModel,
    demo_single_table_metadata: Metadata,
):
    """Fit on metadata only and check sample extraction with an off-table feature.

    With ``has_car`` declared as an off-table feature, every extracted sample
    must have one value per fitted column plus one per off-table feature. The
    metadata-based prompt must mention every column listed in the metadata.
    """
    single_table_gpt_model.fit(demo_single_table_metadata)
    single_table_gpt_model.off_table_features = ["has_car"]
    response_content = gpt_response_list[response_index]
    res = single_table_gpt_model.extract_samples_from_response(response_content)

    expected_rows = gpt_response_sample_count[response_index]
    expected_width = len(single_table_gpt_model.columns) + len(
        single_table_gpt_model.off_table_features
    )
    assert len(res) == expected_rows
    assert len(res[0]) == expected_width
    res_df = pd.DataFrame(
        res, columns=single_table_gpt_model.columns + single_table_gpt_model.off_table_features
    )
    assert res_df.shape == (expected_rows, expected_width)

    message = single_table_gpt_model._form_message_with_metadata(20)
    # Check the type first so a non-string message fails with a clear assertion.
    assert isinstance(message, str)
    for each_col in demo_single_table_metadata.column_list:
        assert each_col in message
    # train with metadata with another way
    single_table_gpt_model.fit(metadata=demo_single_table_metadata)