Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experiment to suggest mappings using chat completions #549

Open
wants to merge 1 commit into
base: suggest-mappings
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,11 @@ def __get_unlinked_cres(self) -> List[CRE]:
.all()
)
return cres
def get_all_nodes_and_cres(self):
return self.__get_all_nodes_and_cres()

def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]:
result = []
nodes = []
cres = []
node_ids = self.session.query(Node.id).all()
for nid in node_ids:
result.extend(self.get_nodes(db_id=nid[0]))
Expand Down
20 changes: 20 additions & 0 deletions application/prompt_client/openai_prompt_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import openai
import logging

Expand Down Expand Up @@ -58,3 +59,22 @@ def query_llm(self, raw_question: str) -> str:
messages=messages,
)
return response.choices[0].message["content"].strip()

def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str:
messages = [
{
"role": "system",
"content": f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate.",
},
{
"role": "user",
"content": f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.",
},
]
openai.api_key = self.api_key
response = openai.chat.completions.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.0,
)
return response.choices[0].message.content.strip()
20 changes: 19 additions & 1 deletion application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def get_id_of_most_similar_cre_paginated(
self,
item_embedding: List[float],
similarity_threshold: float = SIMILARITY_THRESHOLD,
refresh_embeddings: bool = False,
) -> Optional[Tuple[str, float]]:
"""this method is meant to be used when CRE runs in a web server with limited memory (e.g. firebase/heroku)
instead of loading all our embeddings in memory we take the slower approach of paginating them
Expand Down Expand Up @@ -518,3 +517,22 @@ def generate_text(self, prompt: str) -> Dict[str, str]:
table = [closest_object]
result = f"Answer: {answer}"
return {"response": result, "table": table, "accurate": accurate}

def get_id_of_most_similar_cre_using_chat(
self, item: defs.Document
) -> Optional[str]:
# load all cres
content = ""
if item.hyperlink:
content = self.embeddings_instance.get_content(item.hyperlink)
else:
content = item.__repr__()
database = self.database
res = database.get_all_nodes_and_cres()
cres = [r for r in res if r.doctype == defs.Credoctypes.CRE.value]
cres_in_export_format = [f"{c.id}|{c.name}" for c in cres]
return self.ai_client.create_mapping_completion(
prompt="",
cre_id_and_name_in_export_format=cres_in_export_format,
standard_id_or_content=content,
)
7 changes: 7 additions & 0 deletions application/prompt_client/vertex_prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,10 @@ def query_llm(self, raw_question: str) -> str:
msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant."
response = self.chat.send_message(msg, **parameters)
return response.text

def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str:
parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS}

msg= f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate."\
f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.",
return self.chat.send_message(msg, **parameters).text
41 changes: 27 additions & 14 deletions application/utils/spreadsheet_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def parse_standards(


def suggest_from_export_format(
lfile: List[Dict[str, Any]], database: db.Node_collection
lfile: List[Dict[str, Any]], database: db.Node_collection, use_llm: bool = False
) -> Dict[str, Any]:
output: List[Dict[str, Any]] = []
for line in lfile:
Expand Down Expand Up @@ -608,20 +608,33 @@ def suggest_from_export_format(
)
# find nearest CRE for standards in line
ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False)
cre = None
if use_llm:
most_similar_id = ph.get_id_of_most_similar_cre_using_chat(item=standard)
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
c = most_similar_id.split(defs.ExportFormat.separator)
cres = database.get_CREs(name=c[1])
if not cres:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
cre = cres[0]
else:
most_similar_id,_ = ph.get_id_of_most_similar_cre_paginated(item_embedding= ph.generate_embeddings_for_document(standard))
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

cre = database.get_cre_by_db_id(most_similar_id)
if not cre:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

most_similar_id, _ = ph.get_id_of_most_similar_cre_paginated(
item_embedding=ph.generate_embeddings_for_document(standard)
)
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

cre = database.get_cre_by_db_id(most_similar_id)
if not cre:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
line[f"CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}"
# add it to the line
output.append(line)
Expand Down
30 changes: 30 additions & 0 deletions application/web/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,36 @@ def suggest_from_cre_csv() -> Any:
)


@app.route("/rest/v1/cre_csv/suggest_chat", methods=["POST"])
def suggest_from_cre_csv_using_chat() -> Any:
"""Given a csv file that follows the CRE import format but has missing fields, this function will return a csv file with the missing fields filled in with suggestions.

Returns:
Any: the csv file with the missing fields filled in with suggestions
"""
database = db.Node_collection()
file = request.files.get("cre_csv")

if file is None:
abort(400, "No file provided")
contents = file.read()
csv_read = csv.DictReader(contents.decode("utf-8").splitlines())
response = spreadsheet_parsers.suggest_from_export_format(
list(csv_read), database=database, use_llm=True
)
csvVal = write_csv(docs=response).getvalue().encode("utf-8")

# Creating the byteIO object from the StringIO Object
mem = io.BytesIO()
mem.write(csvVal)
mem.seek(0)

return send_file(
mem,
as_attachment=True,
download_name="CRE-Catalogue.csv",
mimetype="text/csv",
)
# /End Importing Handlers


Expand Down
Loading