Commit: repo restructuring
AstraBert committed Dec 9, 2024
1 parent cadac12 commit 1db4596
Showing 17 changed files with 556 additions and 34 deletions.
10 changes: 8 additions & 2 deletions .gitignore
@@ -1,5 +1,11 @@
 model/
-__pycache__/
+lib/scripts/__pycache__/
+lib/docker/__pycache__/
+lib/scripts/.env
+lib/docker/.env
 .env
 virtualenv/
-qdrant_storage/
+qdrant_storage/
+lib/docker/florence-2/
+lib/docker/qwen/
+lib/docker/labse/
40 changes: 40 additions & 0 deletions compose.custom.yaml
@@ -0,0 +1,40 @@
networks:
  mynet:
    driver: bridge

services:
  db:
    image: postgres
    restart: always
    ports:
      - "5432:5432"
    networks:
      - mynet
    environment:
      POSTGRES_DB: $PG_DB
      POSTGRES_USER: $PG_USER
      POSTGRES_PASSWORD: $PG_PASSWORD
    volumes:
      - pgdata:/var/lib/postgresql/data

  semantic_memory:
    image: qdrant/qdrant
    restart: always
    ports:
      - "6333:6333"
      - "6334:6334"
    networks:
      - mynet
    volumes:
      - "./qdrant_storage:/qdrant/storage"

  adminer:
    image: adminer
    restart: always
    ports:
      - "8080:8080"
    networks:
      - mynet

volumes:
  pgdata:
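
Because compose.custom.yaml reads the Postgres credentials from the environment ($PG_DB, $PG_USER, $PG_PASSWORD), it presumably expects a .env file next to it or exported shell variables. A placeholder example (the values below are assumptions, not part of the commit):

# .env (placeholder values; substitute your own credentials)
PG_DB=praivatesearch
PG_USER=pgql_usr
PG_PASSWORD=pgql_psw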
16 changes: 12 additions & 4 deletions compose.yaml
@@ -11,9 +11,9 @@ services:
     networks:
       - mynet
     environment:
-      POSTGRES_DB: $PG_DB
-      POSTGRES_USER: $PG_USER
-      POSTGRES_PASSWORD: $PG_PASSWORD
+      POSTGRES_DB: pgql_usr
+      POSTGRES_USER: pgql_psw
+      POSTGRES_PASSWORD: pgql_psw
     volumes:
       - pgdata:/var/lib/postgresql/data
@@ -27,7 +27,15 @@ services:
       - mynet
     volumes:
       - "./qdrant_storage:/qdrant/storage"
-
+
+  praivatesearch:
+    image: astrabert/praivatesearch:latest
+    restart: always
+    ports:
+      - "7860:7860"
+    networks:
+      - mynet
+
   adminer:
     image: adminer
     restart: always
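
The new praivatesearch service publishes the Gradio UI on port 7860 next to the existing Qdrant (6333/6334) and Adminer (8080) endpoints. A minimal sketch for checking that the published ports respond once the stack is up; the script is illustrative, not part of the commit, with service names and ports taken from compose.yaml:

# check_ports.py: minimal reachability sketch, run after `docker compose up -d`
import urllib.request

ENDPOINTS = {
    "praivatesearch (Gradio UI)": "http://localhost:7860",
    "semantic_memory (Qdrant)": "http://localhost:6333",
    "adminer": "http://localhost:8080",
}

for name, url in ENDPOINTS.items():
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            print(f"{name}: HTTP {resp.status}")
    except Exception as exc:  # connection refused, timeout, etc.
        print(f"{name}: not reachable ({exc})")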
1 change: 1 addition & 0 deletions Dockerfile → lib/docker/Dockerfile
@@ -9,6 +9,7 @@ RUN python3 -m pip cache purge
 RUN python3 -m pip install --no-cache-dir -r requirements.txt
 RUN python3 -m nltk.downloader "punkt"
 RUN python3 -m nltk.downloader "stopwords"
+RUN python3 -m nltk.downloader "punkt_tab"
 
 EXPOSE 7860

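The extra punkt_tab download is presumably needed because newer NLTK releases ship the Punkt tokenizer tables as a separate punkt_tab resource, and rake_nltk tokenizes through NLTK at runtime. A minimal sketch of the call path that would raise a LookupError inside the container without it (the sample sentence is made up):

# rake_check.py: sketch, assumes nltk and rake_nltk are installed
import nltk

for pkg in ("punkt", "stopwords", "punkt_tab"):
    nltk.download(pkg, quiet=True)

from rake_nltk import Rake  # tokenizes via NLTK under the hood

raker = Rake()
raker.extract_keywords_from_text("PrAIvateSearch keeps web search fully local.")
print(raker.get_ranked_phrases())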
35 changes: 24 additions & 11 deletions app.py → lib/docker/app.py
@@ -1,37 +1,50 @@
 import warnings
 warnings.filterwarnings("ignore")
 
 import gradio as gr
 from text_inference import text_inference
 from image_gen import caption_image
 from PIL import Image
-from websearching import web_search
+from websearching import web_search, date_for_debug
 
-def reply(text_input, image_input=None, max_results=5, enable_rag=False):
+def reply(text_input, image_input=None, max_results=5, enable_rag=False, debug=True):
+    if debug:
+        print(f"[{date_for_debug()}] Started query processing...")
     if image_input is None:
-        prompt, qdrant_success = web_search(text_input, max_results, enable_rag)
-        print(qdrant_success)
+        prompt, qdrant_success = web_search(text_input, max_results, enable_rag, debug)
+        if debug:
+            print(qdrant_success)
         results = text_inference(prompt)
         results = results.replace("<|im_end|>","")
+        if debug:
+            print(f"[{date_for_debug()}] Finished query processing!")
         return results
     else:
         if text_input:
             img = Image.fromarray(image_input)
             caption = caption_image(img)
             print(caption)
             print(type(caption))
             full_query = caption +"\n\n"+text_input
             prompt, qdrant_success = web_search(full_query, max_results, enable_rag)
-            print(qdrant_success)
+            if debug:
+                print(qdrant_success)
             results = text_inference(prompt)
             results = results.replace("<|im_end|>","")
+            if debug:
+                print(f"[{date_for_debug()}] Finished query processing!")
             return results
         else:
             img = Image.fromarray(image_input)
             caption = caption_image(img)
             print(caption)
             print(type(caption))
             prompt, qdrant_success = web_search(caption, max_results, enable_rag)
-            print(qdrant_success)
+            if debug:
+                print(qdrant_success)
             results = text_inference(prompt)
             results = results.replace("<|im_end|>","")
+            if debug:
+                print(f"[{date_for_debug()}] Finished query processing!")
             return results
 
 
-iface = gr.Interface(fn=reply, inputs=[gr.Textbox(value="",label="Search Query"), gr.Image(value=None, label="Image Search Query"), gr.Slider(1,10,value=5,label="Maximum Number of Search Results"), gr.Checkbox(value=False, label="Enable RAG")], outputs=[gr.Markdown(value="Your output will be generated here", label="Search Results")], title="PrAIvateSearch")
+iface = gr.Interface(fn=reply, inputs=[gr.Textbox(value="",label="Search Query"), gr.Image(value=None, label="Image Search Query"), gr.Slider(1,10,value=5,label="Maximum Number of Search Results", step=1), gr.Checkbox(value=False, label="Enable RAG"), gr.Checkbox(value=True, label="Debug")], outputs=[gr.Markdown(value="Your output will be generated here", label="Search Results")], title="PrAIvateSearch")
 
 iface.launch(server_name="0.0.0.0", server_port=7860)
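
With the Debug checkbox wired in, the running app can be smoke-tested over HTTP. A sketch using gradio_client (an assumption: it is not in requirements.txt, and the query string is made up); the positional inputs mirror the gr.Interface components in order:

# smoke_test.py: illustrative only; `pip install gradio_client` first
from gradio_client import Client

client = Client("http://localhost:7860")
result = client.predict(
    "local vector databases",  # Search Query
    None,                      # Image Search Query
    3,                         # Maximum Number of Search Results
    False,                     # Enable RAG
    True,                      # Debug
    api_name="/predict",
)
print(result)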
47 changes: 47 additions & 0 deletions lib/docker/image_gen.py
@@ -0,0 +1,47 @@
import warnings
warnings.filterwarnings("ignore")

import einops
import timm

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from rake_nltk import Metric, Rake

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained("/app/florence-2/", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("/app/florence-2/", trust_remote_code=True)

task_prompt = "<DETAILED_CAPTION>"
raker = Rake(include_repeated_phrases=False, ranking_metric=Metric.WORD_DEGREE)

def extract_keywords_from_caption(caption: str) -> str:
    raker.extract_keywords_from_text(caption)
    keywords = raker.get_ranked_phrases()[:5]
    fnl = []
    for keyword in keywords:
        if "image" in keyword:
            continue
        else:
            fnl.append(keyword)
    return " ".join(fnl)

def caption_image(image):
    global task_prompt
    prompt = task_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))

    caption = parsed_answer["<DETAILED_CAPTION>"]
    search_words = extract_keywords_from_caption(caption)
    return search_words
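
Note that caption_image() returns the RAKE-ranked keywords extracted from the Florence-2 caption, not the caption itself. A usage sketch, assuming the weights sit under /app/florence-2/ and a local test image exists (the filename is hypothetical):

# caption_demo.py: illustrative only
from PIL import Image
from image_gen import caption_image

img = Image.open("example.jpg").convert("RGB")  # hypothetical test image
keywords = caption_image(img)
print(keywords)  # a short keyword string suitable for web search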
File renamed without changes.
File renamed without changes.
17 changes: 17 additions & 0 deletions lib/docker/requirements.txt
@@ -0,0 +1,17 @@
googlesearch-python
nltk
rake_nltk
boilerpy3
qdrant_client
trl
torch
accelerate
transformers
gradio
einops
timm
pillow
sqlalchemy
sentence_transformers
bitsandbytes
python_dotenv
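
Several of these distributions import under a different name than they install as: pillow imports as PIL, python_dotenv as dotenv, googlesearch-python as googlesearch. A quick hedged import check (a sketch; bitsandbytes is left out because CPU-only machines may fail to import it):

# import_check.py: verifies the installed modules resolve
import importlib

MODULES = [
    "googlesearch",  # from googlesearch-python
    "nltk", "rake_nltk", "boilerpy3", "qdrant_client",
    "trl", "torch", "accelerate", "transformers", "gradio",
    "einops", "timm",
    "PIL",           # from pillow
    "sqlalchemy", "sentence_transformers",
    "dotenv",        # from python_dotenv
]

for mod in MODULES:
    importlib.import_module(mod)
print("all imports resolved")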
59 changes: 59 additions & 0 deletions lib/docker/text_inference.py
@@ -0,0 +1,59 @@
import warnings
warnings.filterwarnings("ignore")

import accelerate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from dotenv import load_dotenv
from memory import ConversationHistory, PGClient
import os
import random as r
from trl import setup_chat_format

load_dotenv()

model_name = "/app/qwen/"
quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                         bnb_4bit_compute_dtype=torch.bfloat16,
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4"
)

quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.chat_template = None
quantized_model, tokenizer = setup_chat_format(model=quantized_model, tokenizer=tokenizer)



pg_db = os.getenv("PG_DB")
pg_user = os.getenv("PG_USER")
pg_psw = os.getenv("PG_PASSWORD")

pg_conn_str = f"postgresql://{pg_user}:{pg_psw}@localhost:5432/{pg_db}"
pg_client = PGClient(pg_conn_str)

usr_id = r.randint(1,10000)
convo_hist = ConversationHistory(pg_client, usr_id)
convo_hist.add_message(role="system", content="You are a web searching assistant: your task is to create a human-readable content based on a JSON representation of the keywords of several websites related to the search that the user performed and on the context that you are provided with")

def pipe(prompt: str, temperature: float, top_p: float, max_new_tokens: int, repetition_penalty: float):
    tokenized_chat = tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True, return_tensors="pt")
    outputs = quantized_model.generate(tokenized_chat, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
    results = tokenizer.decode(outputs[0])
    return results

def text_inference(message):
    convo_hist.add_message(role="user", content=message)
    prompt = convo_hist.get_conversation_history()
    res = pipe(
        prompt,
        temperature=0.1,
        top_p=1,
        max_new_tokens=512,
        repetition_penalty=1.2
    )
    ret = res.split("<|im_start|>assistant\n")[1]
    convo_hist.add_message(role="assistant", content=ret)
    return ret
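
text_inference() wraps the 4-bit Qwen model with a conversation history persisted to Postgres under a random user id, so it needs the db service reachable on localhost:5432 and PG_DB, PG_USER and PG_PASSWORD set. A usage sketch (the keyword payload is made up to resemble what the websearching module would hand over):

# inference_demo.py: illustrative only
from text_inference import text_inference

message = '{"keywords": ["qdrant", "vector database", "local-first search"]}'
print(text_inference(message))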