-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
556 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
model/ | ||
__pycache__/ | ||
lib/scripts/__pycache__/ | ||
lib/docker/__pycache__/ | ||
lib/scripts/.env | ||
lib/docker/.env | ||
.env | ||
virtualenv/ | ||
qdrant_storage/ | ||
qdrant_storage/ | ||
lib/docker/florence-2/ | ||
lib/docker/qwen/ | ||
lib/docker/labse/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
networks: | ||
mynet: | ||
driver: bridge | ||
|
||
services: | ||
db: | ||
image: postgres | ||
restart: always | ||
ports: | ||
- "5432:5432" | ||
networks: | ||
- mynet | ||
environment: | ||
POSTGRES_DB: $PG_DB | ||
POSTGRES_USER: $PG_USER | ||
POSTGRES_PASSWORD: $PG_PASSWORD | ||
volumes: | ||
- pgdata:/var/lib/postgresql/data | ||
|
||
semantic_memory: | ||
image: qdrant/qdrant | ||
restart: always | ||
ports: | ||
- "6333:6333" | ||
- "6334:6334" | ||
networks: | ||
- mynet | ||
volumes: | ||
- "./qdrant_storage:/qdrant/storage" | ||
|
||
adminer: | ||
image: adminer | ||
restart: always | ||
ports: | ||
- "8080:8080" | ||
networks: | ||
- mynet | ||
|
||
volumes: | ||
pgdata: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,50 @@ | ||
import warnings | ||
warnings.filterwarnings("ignore") | ||
|
||
import gradio as gr | ||
from text_inference import text_inference | ||
from image_gen import caption_image | ||
from PIL import Image | ||
from websearching import web_search | ||
from websearching import web_search, date_for_debug | ||
|
||
def reply(text_input, image_input=None, max_results=5, enable_rag=False, debug=True):
    """Answer a search query given as text, as an image, or as both.

    When an image is supplied it is captioned with ``caption_image`` and the
    caption becomes the search query; if text is supplied as well, it is
    appended to the caption after a blank line. The query is handed to
    ``web_search`` to build a prompt, the prompt is handed to
    ``text_inference``, and the generated text is returned with every
    ``<|im_end|>`` marker removed.

    Args:
        text_input: Text of the search query; may be empty when an image is given.
        image_input: Image as an array accepted by ``PIL.Image.fromarray``,
            or ``None`` for a text-only query.
        max_results: Forwarded to ``web_search``.
        enable_rag: Forwarded to ``web_search``.
        debug: When true, progress and diagnostic values are printed to
            stdout. The flag is also forwarded to ``web_search``.

    Returns:
        The generated answer text.
    """
    if debug:
        print(f"[{date_for_debug()}] Started query processing...")

    if image_input is None:
        query = text_input
    else:
        caption = caption_image(Image.fromarray(image_input))
        if debug:
            # Diagnostic output only; silenced together with the other debug prints.
            print(caption)
            print(type(caption))
        query = caption + "\n\n" + text_input if text_input else caption

    # Every kind of query goes through the same search -> inference path, and
    # the debug flag is forwarded for image queries as well as text queries.
    prompt, qdrant_success = web_search(query, max_results, enable_rag, debug)
    if debug:
        print(qdrant_success)

    results = text_inference(prompt).replace("<|im_end|>", "")
    if debug:
        print(f"[{date_for_debug()}] Finished query processing!")
    return results
|
||
|
||
# Gradio UI. The input widgets are listed in the same order as reply()'s
# parameters: text_input, image_input, max_results, enable_rag, debug.
iface = gr.Interface(
    fn=reply,
    inputs=[
        gr.Textbox(value="", label="Search Query"),
        gr.Image(value=None, label="Image Search Query"),
        gr.Slider(1, 10, value=5, label="Maximum Number of Search Results", step=1),
        gr.Checkbox(value=False, label="Enable RAG"),
        gr.Checkbox(value=True, label="Debug"),
    ],
    outputs=[gr.Markdown(value="Your output will be generated here", label="Search Results")],
    title="PrAIvateSearch",
)

# Serve on all network interfaces, port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import warnings | ||
warnings.filterwarnings("ignore") | ||
|
||
import einops | ||
import timm | ||
|
||
import torch | ||
from transformers import AutoProcessor, AutoModelForCausalLM | ||
from rake_nltk import Metric, Rake | ||
|
||
# Use the first CUDA device in half precision when a GPU is available;
# otherwise fall back to the CPU in single precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Model and processor are both loaded from the local directory /app/florence-2/.
model = AutoModelForCausalLM.from_pretrained(
    "/app/florence-2/",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained("/app/florence-2/", trust_remote_code=True)

# Task token given to the processor both when encoding and when post-processing.
task_prompt = "<DETAILED_CAPTION>"

# RAKE keyword extractor: phrases ranked by word degree, repeated phrases excluded.
raker = Rake(include_repeated_phrases=False, ranking_metric=Metric.WORD_DEGREE)
|
||
def extract_keywords_from_caption(caption: str) -> str:
    """Return the caption's top-ranked RAKE phrases joined by single spaces.

    The five highest-ranked phrases are selected first; any of those that
    contains the substring "image" is then discarded, so the result may be
    built from fewer than five phrases (and may be an empty string).
    """
    raker.extract_keywords_from_text(caption)
    top_phrases = raker.get_ranked_phrases()[:5]
    return " ".join(phrase for phrase in top_phrases if "image" not in phrase)
|
||
def caption_image(image):
    """Caption an image and reduce the caption to search keywords.

    Args:
        image: Image passed to the processor; it must expose ``width`` and
            ``height`` attributes, which are used during post-processing.

    Returns:
        The string produced by ``extract_keywords_from_caption`` for the
        generated detailed caption.
    """
    model_inputs = processor(
        text=task_prompt,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)

    output_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )

    # Only the first decoded sequence is used; special tokens are kept in the
    # text that is handed to the processor's post-processing step.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        decoded,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    return extract_keywords_from_caption(parsed["<DETAILED_CAPTION>"])
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
googlesearch-python | ||
nltk | ||
rake_nltk | ||
boilerpy3 | ||
qdrant_client | ||
trl | ||
torch | ||
accelerate | ||
transformers | ||
gradio | ||
einops | ||
timm | ||
pillow | ||
sqlalchemy | ||
sentence_transformers | ||
bitsandbytes | ||
python_dotenv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import warnings | ||
warnings.filterwarnings("ignore") | ||
|
||
import accelerate | ||
|
||
import torch | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | ||
from dotenv import load_dotenv | ||
from memory import ConversationHistory, PGClient | ||
import os | ||
import random as r | ||
from trl import setup_chat_format | ||
|
||
# Runs before any of the os.getenv lookups further down in this module.
load_dotenv()

# Local directory the model and tokenizer are loaded from.
model_name = "/app/qwen/"

# 4-bit NF4 quantization with double quantization, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The tokenizer's own chat template is cleared before setup_chat_format is
# applied to the model/tokenizer pair.
# NOTE(review): presumably so setup_chat_format can install its template - confirm.
tokenizer.chat_template = None
quantized_model, tokenizer = setup_chat_format(model=quantized_model, tokenizer=tokenizer)
|
||
|
||
|
||
from urllib.parse import quote

# Database settings are read from the environment.
pg_db = os.getenv("PG_DB")
pg_user = os.getenv("PG_USER")
pg_psw = os.getenv("PG_PASSWORD")

# Fail fast with a clear message instead of silently building a URL that
# contains the literal text "None" for an unset variable.
_missing_pg_vars = [
    name
    for name, value in (("PG_DB", pg_db), ("PG_USER", pg_user), ("PG_PASSWORD", pg_psw))
    if value is None
]
if _missing_pg_vars:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_pg_vars)}"
    )

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# user name or password cannot be mistaken for URL delimiters. safe="" makes
# quote() encode every reserved character (space becomes %20, '+' becomes %2B).
pg_conn_str = (
    f"postgresql://{quote(pg_user, safe='')}:{quote(pg_psw, safe='')}"
    f"@localhost:5432/{pg_db}"
)
pg_client = PGClient(pg_conn_str)

# A random id in [1, 10000] is drawn once at import time and identifies this
# process's conversation history.
usr_id = r.randint(1, 10000)
convo_hist = ConversationHistory(pg_client, usr_id)
convo_hist.add_message(role="system", content="You are a web searching assistant: your task is to create a human-readable content based on a JSON representation of the keywords of several websites related to the search that the user performed and on the context that you are provided with")
|
||
def pipe(prompt: str, temperature: float, top_p: float, max_new_tokens: int, repetition_penalty: float):
    """Run the chat model on ``prompt`` and return the decoded output text.

    ``prompt`` is rendered and tokenized with the tokenizer's chat template
    (generation prompt appended), passed to ``quantized_model.generate`` with
    the given generation settings, and the first returned sequence is decoded.

    NOTE(review): ``prompt`` is annotated ``str`` but ``text_inference`` passes
    the value of ``convo_hist.get_conversation_history()`` - confirm the
    actual type.
    """
    input_ids = tokenizer.apply_chat_template(
        prompt,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    generated = quantized_model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return tokenizer.decode(generated[0])
|
||
def text_inference(message):
    """Generate the assistant's reply to ``message`` and record the exchange.

    The message is appended to the conversation history as a user turn, the
    whole history is sent through ``pipe``, and the text following the last
    ``<|im_start|>assistant\\n`` marker of the decoded output is stored as the
    assistant turn and returned.

    Args:
        message: Content of the new user turn.

    Returns:
        The newest assistant reply as decoded by ``pipe`` (any trailing
        end-of-turn marker is left in place for the caller to strip).
    """
    convo_hist.add_message(role="user", content=message)
    history = convo_hist.get_conversation_history()
    decoded = pipe(
        history,
        temperature=0.1,
        top_p=1,
        max_new_tokens=512,
        repetition_penalty=1.2,
    )
    # The decoded text covers the whole conversation, so once earlier assistant
    # turns are part of the history the marker occurs more than once. The new
    # reply is whatever follows the LAST marker; taking index 1 would return
    # the first answer of the conversation again on every later query. Index
    # -1 also falls back to the full text when the marker is absent instead of
    # raising IndexError.
    answer = decoded.split("<|im_start|>assistant\n")[-1]
    convo_hist.add_message(role="assistant", content=answer)
    return answer
Oops, something went wrong.