Skip to content

Commit

Permalink
Feature/PADW-50 Tembo AI Integration (#7)
Browse files Browse the repository at this point in the history
Switching to the OpenAI API standard for LLM API calls.
  • Loading branch information
analyzer1 authored Sep 16, 2024
1 parent 515c5e0 commit f70784e
Show file tree
Hide file tree
Showing 5 changed files with 497 additions and 16 deletions.
46 changes: 30 additions & 16 deletions extension/src/controller/bgw_transformer_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use serde::Deserialize;

use crate::queries;
use crate::model::source_objects;
use crate::utility::ollama_client;
// use crate::utility::ollama_client;
use crate::utility::openai_client;
use crate::utility::guc;
use regex::Regex;

Expand Down Expand Up @@ -61,8 +62,6 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

let columns = extract_column_numbers(&table_details_json_str);



// Identify BK Ordinal Location
let mut generation_json_bk_identification: Option<serde_json::Value> = None;
let mut identified_business_key_opt: Option<IdentifiedBusinessKey> = None;
Expand All @@ -71,22 +70,27 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_identification = match ollama_client::send_request(table_details_json_str.as_str(), ollama_client::PromptTemplate::BKIdentification, &0, &hints).await {
Ok(mut response_json) => {
generation_json_bk_identification = match openai_client::send_request(table_details_json_str.as_str(), openai_client::PromptTemplate::BKIdentification, &0, &hints).await {
Ok(response_json) => {

// TODO: Add a function to enable logging.
// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");
let response_json_pretty = serde_json::to_string_pretty(&response_json)
.expect("Failed to convert Response JSON to Pretty String.");
log!("Response: {}", response_json_pretty);
Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
log!("Error in transformer request, malformed or timed out: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
});
// let identified_business_key: IdentifiedBusinessKey = serde_json::from_value(generation_json_bk_identification.unwrap()).expect("Not valid JSON");

if generation_json_bk_identification.is_none() {
retries += 1;
continue; // Skip to the next iteration
}

match serde_json::from_value::<IdentifiedBusinessKey>(generation_json_bk_identification.clone().unwrap()) {
Ok(bk) => {
Expand Down Expand Up @@ -114,21 +118,26 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_name = match ollama_client::send_request(table_details_json_str.as_str(), ollama_client::PromptTemplate::BKName, &0, &hints).await {
Ok(mut response_json) => {
generation_json_bk_name = match openai_client::send_request(table_details_json_str.as_str(), openai_client::PromptTemplate::BKName, &0, &hints).await {
Ok(response_json) => {

// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");
Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
log!("Error in transformer request, malformed or timed out: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
});

if generation_json_bk_name.is_none() {
retries += 1;
continue; // Skip to the next iteration
}

match serde_json::from_value::<BusinessKeyName>(generation_json_bk_name.clone().unwrap()) {
Ok(bk) => {
business_key_name_opt = Some(bk);
Expand Down Expand Up @@ -158,27 +167,32 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
runtime.block_on(async {
// Get Generation
generation_json_descriptor_sensitive =
match ollama_client::send_request(
match openai_client::send_request(
table_details_json_str.as_str(),
ollama_client::PromptTemplate::DescriptorSensitive,
openai_client::PromptTemplate::DescriptorSensitive,
column,
&hints).await {
Ok(mut response_json) => {
Ok(response_json) => {

// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");

Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
log!("Error in transformer request, malformed or timed out: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
// generation_json_descriptors_sensitive.insert(column, generation_json_descriptor_sensitive);
});

if generation_json_descriptor_sensitive.is_none() {
retries += 1;
continue; // Skip to the next iteration
}

match serde_json::from_value::<DescriptorSensitive>(generation_json_descriptor_sensitive.clone().unwrap()) {
Ok(des) => {
// business_key_name_opt = Some(des);
Expand Down
16 changes: 16 additions & 0 deletions extension/src/utility/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ pub static PG_AUTO_DW_TRANSFORMER_SERVER_URL: GucSetting<Option<&CStr>> = GucSet
CStr::from_bytes_with_nul_unchecked(b"http://localhost:11434/api/generate\0")
}));

// Backing storage for the `pg_auto_dw.transformer_server_token` GUC (the bearer token used
// to authenticate API calls to the Transformer Server).
// Initialised to `None`: no token is set by default.
pub static PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);

// Transformer model setting. Default model is "mistral".
pub static PG_AUTO_DW_MODEL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(Some(unsafe {
// SAFETY: the byte literal ends with exactly one NUL terminator and contains no interior
// NUL bytes, which is the precondition of `from_bytes_with_nul_unchecked`.
CStr::from_bytes_with_nul_unchecked(b"mistral\0")
}));



// Default confidence level value is 0.8
// pub static PG_AUTO_DW_CONFIDENCE_LEVEL: GucSetting<f64> = GucSetting::<f64>::new(0.8);

Expand Down Expand Up @@ -55,6 +60,15 @@ pub fn init_guc() {
GucFlags::default(),
);

GucRegistry::define_string_guc(
"pg_auto_dw.transformer_server_token",
"Bearer token for authenticating API calls to the Transformer Server for the pg_auto_dw extension.",
"The Bearer token is required for authenticating API calls to the Transformer Server when interacting with the pg_auto_dw extension.",
&PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN,
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_string_guc(
"pg_auto_dw.model",
"Transformer model for the pg_auto_dw extension.",
Expand Down Expand Up @@ -83,6 +97,7 @@ pub enum PgAutoDWGuc {
DatabaseName,
DwSchema,
TransformerServerUrl,
TransformerServerToken,
Model,
// ConfidenceLevel,
}
Expand All @@ -94,6 +109,7 @@ pub fn get_guc(guc: PgAutoDWGuc) -> Option<String> {
PgAutoDWGuc::DatabaseName => PG_AUTO_DW_DATABASE_NAME.get(),
PgAutoDWGuc::DwSchema => PG_AUTO_DW_DW_SCHEMA.get(),
PgAutoDWGuc::TransformerServerUrl => PG_AUTO_DW_TRANSFORMER_SERVER_URL.get(),
PgAutoDWGuc::TransformerServerToken => PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN.get(),
PgAutoDWGuc::Model => PG_AUTO_DW_MODEL.get(),
// PgAutoDWGuc::ConfidenceLevel => return Some(PG_AUTO_DW_CONFIDENCE_LEVEL.get().to_string()),
};
Expand Down
1 change: 1 addition & 0 deletions extension/src/utility/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod ollama_client;
pub mod openai_client;
pub mod setup;
pub mod guc;
4 changes: 4 additions & 0 deletions extension/src/utility/ollama_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::time::Duration;

use crate::utility::guc;

use pgrx::prelude::*;

#[derive(Serialize, Debug)]
pub struct GenerateRequest {
pub model: String,
Expand Down Expand Up @@ -40,6 +42,8 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
.replace("{column_no}", &column_number)
.replace("{hints}", &hints);

log!("Prompt: {prompt}");

// GUC Values for the transformer server
let transformer_server_url = guc::get_guc(guc::PgAutoDWGuc::TransformerServerUrl).ok_or("GUC: Transformer Server URL is not set")?;
let model = guc::get_guc(guc::PgAutoDWGuc::Model).ok_or("MODEL GUC is not set.")?;
Expand Down
Loading

0 comments on commit f70784e

Please sign in to comment.