From 8bfc6ae669b03d45efa49e615f199aac1de28b02 Mon Sep 17 00:00:00 2001 From: jalr4ever Date: Tue, 23 Jul 2024 15:56:49 +0800 Subject: [PATCH] fix: base_url not included when request to gpt in SingleTableGPTModel --- sdgx/models/LLM/single_table/gpt.py | 10 +++++++--- tests/models/test_singletable_gpt.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sdgx/models/LLM/single_table/gpt.py b/sdgx/models/LLM/single_table/gpt.py index 47acfa14..1e79eff2 100644 --- a/sdgx/models/LLM/single_table/gpt.py +++ b/sdgx/models/LLM/single_table/gpt.py @@ -141,6 +141,12 @@ def _get_openai_setting_from_env(self): self.openai_API_url = os.getenv("OPENAI_URL") logger.debug("Get OPENAI_URL from ENV.") + def openai_client(self): + """ + Generate a openai request client. + """ + return openai.OpenAI(api_key=self.openai_API_key, base_url=self.openai_API_url) + def ask_gpt(self, question, model=None): """ Sends a question to the GPT model. @@ -156,13 +162,11 @@ def ask_gpt(self, question, model=None): SynthesizerInitError: If the check method fails. """ self.check() - api_key = self.openai_API_key if model: model = model else: model = self.gpt_model - openai.api_key = api_key - client = openai.OpenAI(api_key=api_key) + client = self.openai_client() logger.info(f"Ask GPT with temperature = {self.temperature}.") response = client.chat.completions.create( model=model, diff --git a/tests/models/test_singletable_gpt.py b/tests/models/test_singletable_gpt.py index 63c599d7..cdf5848d 100644 --- a/tests/models/test_singletable_gpt.py +++ b/tests/models/test_singletable_gpt.py @@ -119,6 +119,18 @@ def single_table_gpt_model(): gpt_response_sample_count = [20, 15, 20, 5, 5] +def test_singletable_gpt_model_openapi_setting(single_table_gpt_model: SingleTableGPTModel): + open_ai_key = "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + open_ai_base = "https://api.mock.openai.base.com" + open_ai_model = "gpt-4o-mini" + single_table_gpt_model.set_openAI_settings(open_ai_base, open_ai_key) + single_table_gpt_model.gpt_model = open_ai_model + client = single_table_gpt_model.openai_client() + assert client.base_url == open_ai_base + assert client.api_key == open_ai_key + assert single_table_gpt_model.gpt_model == open_ai_model + + def test_singletable_gpt_model( single_table_gpt_model: SingleTableGPTModel, raw_data: pd.DataFrame,