-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #250 from ksgr5566/dsp
Added DSP.
- Loading branch information
Showing
14 changed files
with
2,082 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# DSP (Demonstrate, Search, Predict) | ||
|
||
[DSP](https://github.com/stanfordnlp/dspy) to further augment RAG. This module is to specifically make DSP work for long passage answers required for questions rather than short factoid answers. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .local import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Use an official Python runtime as a parent image | ||
FROM python:3.9-slim | ||
|
||
WORKDIR /app | ||
|
||
|
||
#install requirements | ||
COPY requirements.txt requirements.txt | ||
RUN pip3 install -r requirements.txt | ||
|
||
# Copy the rest of the application code to the working directory | ||
COPY . /app/ | ||
EXPOSE 8000 | ||
# Set the entrypoint for the container | ||
CMD ["hypercorn", "--bind", "0.0.0.0:8000", "api:app"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# DSP | ||
|
||
## Test Deployment | ||
|
||
- Git clone the repo and cd to the project location. | ||
- cd to `local`, i.e., `cd ./src/dsp/local`. | ||
- Use openai api key. | ||
- Start your docker engine and `docker build -t dsp .`. | ||
- Do `docker run -p 8000:8000 dsp`. | ||
- `curl -X POST -H "Content-Type: application/json" -d '{"text": TEXT, "train": TRAIN, "server": SERVER, "model": MODEL}' http://0.0.0.0:8000`. | ||
|
||
`TEXT` is the question. `TRAIN` is the labeled samples required in list format. Ex: `[("Question1", ["Answer1"]), ("Question2", ["Answer2"])]`. `SERVER` is the retrieval model server's api endpoint. Make sure to implement the server so the endpoints work as required for [this](https://github.com/stanfordnlp/dspy/blob/main/dsp/modules/colbertv2.py). `MODEL` is the hugging face model that you may want to use instead of gpt-3+, it is optional. Leave it blank if you want to use gpt-3+. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .request import ModelRequest | ||
from .model import Model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from model import Model | ||
from request import ModelRequest | ||
from quart import Quart, request | ||
import aiohttp | ||
|
||
app = Quart(__name__) | ||
|
||
model = None | ||
|
||
@app.before_serving | ||
async def startup(): | ||
app.client = aiohttp.ClientSession() | ||
global model | ||
model = Model(app) | ||
|
||
@app.route('/', methods=['POST']) | ||
async def embed(): | ||
global model | ||
data = await request.get_json() | ||
req = ModelRequest(**data) | ||
return model.inference(req) | ||
|
||
if __name__ == "__main__": | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from request import ModelRequest | ||
import dsp | ||
from utils import DSP | ||
|
||
|
||
class Model(): | ||
def __new__(cls, context): | ||
cls.context = context | ||
if not hasattr(cls, 'instance'): | ||
cls.dsp = DSP() | ||
cls.instance = super(Model, cls).__new__(cls) | ||
return cls.instance | ||
|
||
|
||
def inference(self, request: ModelRequest): | ||
train = [dsp.Example(question=question, answer=answer) for question, answer in request.train] | ||
answer, history = self.dsp(request.text, train, request.server, request.hf_model) | ||
return {"text": answer, "history": history} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import json | ||
|
||
|
||
class ModelRequest(): | ||
def __init__(self, text, train, server: str, model: str=None): | ||
self.text = text | ||
self.train = train | ||
self.server = server | ||
self.hf_model = model | ||
|
||
def to_json(self): | ||
return json.dumps(self, default=lambda o: o.__dict__, | ||
sort_keys=True, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
dsp-ml | ||
accelerate | ||
torch | ||
tiktoken |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import os | ||
import dsp | ||
import tiktoken | ||
from dsp.utils import deduplicate | ||
|
||
|
||
openai_key = os.getenv("OPENAI_API_KEY") | ||
|
||
|
||
class DSP(): | ||
def __init__(self): | ||
self.lm = dsp.GPT3(model='gpt-3.5-turbo-16k', api_key=openai_key, model_type='chat') | ||
self.sbert_reranker = dsp.SentenceTransformersCrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2") | ||
|
||
self.encoding = tiktoken.get_encoding("cl100k_base") | ||
|
||
self.qa_template_with_CoT = None | ||
self.rewrite_template = None | ||
self.hop_template = None | ||
self.train = None | ||
self.__generate_templates() | ||
|
||
def __call__(self, server: str, hf_model: str, question: str, train) -> str: | ||
self.rm = dsp.ColBERTv2(url=server) | ||
try: | ||
self.lm = dsp.HFModel(hf_model) | ||
except: | ||
self.lm = dsp.GPT3(model=hf_model, api_key=openai_key, model_type='chat') | ||
dsp.settings.configure(lm=self.lm, rm=self.rm, reranker=self.sbert_reranker) | ||
self.train = train | ||
return self.multihop_QA_v2(question), self.lm.inspect_history(n=3) | ||
|
||
def __generate_templates(self): | ||
Question = dsp.Type(prefix="Question:", desc="${the question to be answered}") | ||
Answer = dsp.Type(prefix="Answer:", desc="${a crisp and concise answer without redundant information}", format=dsp.format_answers) | ||
|
||
qa_template = dsp.Template(instructions="Answer questions with concise answers without redundant information.", question=Question(), answer=Answer()) | ||
|
||
Rationale = dsp.Type( | ||
prefix="Rationale: Let's think step by step.", | ||
desc="${a step-by-step deduction that identifies the correct response, which will be provided below}" | ||
) | ||
|
||
self.qa_template_with_CoT = dsp.Template( | ||
instructions=qa_template.instructions, | ||
context=Context(), question=Question(), rationale=Rationale(), answer=Answer() | ||
) | ||
|
||
Context = dsp.Type( | ||
prefix="Context:\n", | ||
desc="${sources that may contain relevant content}", | ||
format=dsp.passages2text | ||
) | ||
|
||
SearchRationale = dsp.Type( | ||
prefix="Rationale: Let's think step by step. To answer this question, we first need to find out", | ||
desc="${the missing information}" | ||
) | ||
|
||
SearchQuery = dsp.Type( | ||
prefix="Search Query:", | ||
desc="${a simple question for seeking the missing information}" | ||
) | ||
|
||
self.rewrite_template = dsp.Template( | ||
instructions="Write a search query that will help answer a complex question.", | ||
question=Question(), rationale=SearchRationale(), query=SearchQuery() | ||
) | ||
|
||
CondenseRationale = dsp.Type( | ||
prefix="Rationale: Let's think step by step. Based on the context, we have learned the following.", | ||
desc="${information from the context that provides useful clues}" | ||
) | ||
|
||
self.hop_template = dsp.Template( | ||
instructions=self.rewrite_template.instructions, | ||
context=Context(), question=Question(), rationale=CondenseRationale(), query=SearchQuery() | ||
) | ||
|
||
True_Answer = dsp.Type(prefix="The correct answer: ", desc="${The true answer}") | ||
Predicted_Answer = dsp.Type(prefix="Another answer: ", desc="${The answer to compare to}") | ||
Resp = dsp.Type(prefix="Response: ", desc="${True (or) False}") | ||
|
||
self.answer_match_template = dsp.Template( | ||
instructions="Return True if the essence of both answers are same else return False. Respond with only 'True' and 'False'.", | ||
true_answer=True_Answer(), predicted_answer=Predicted_Answer(), response=Resp() | ||
) | ||
|
||
def __answer_match(self, true_ans, pred_ans): | ||
match_example = dsp.Example(true_answer=true_ans, predicted_answer=pred_ans, demos=[]) | ||
_, completions = dsp.generate(self.answer_match_template)(match_example, stage='answer_match') | ||
return completions[0].response | ||
|
||
def __count_tokens(self, text): | ||
num_tokens = len(self.encoding.encode(text)) | ||
return num_tokens | ||
|
||
@dsp.transformation | ||
def QA_predict(self, example: dsp.Example, sc=True, return_store=False): | ||
temp = str(example) | ||
if sc: | ||
example, completions = dsp.generate(self.qa_template_with_CoT, n=20, temperature=0.7)(example, stage='qa') | ||
completions = dsp.majority(completions) | ||
else: | ||
example, completions = dsp.generate(self.qa_template_with_CoT)(example, stage='qa') | ||
|
||
if return_store: | ||
len_tokens = self.__count_tokens(temp) | ||
store = { | ||
"question": example.question, | ||
"context": example.context, | ||
"rationale": completions[0].rationale, | ||
"len_tokens": len_tokens | ||
} | ||
return example.copy(answer=completions.answer), store | ||
return example.copy(answer=completions.answer) | ||
|
||
@dsp.transformation | ||
def multihop_search_v1(self, example: dsp.Example, max_hops=2, k=2) -> dsp.Example: | ||
example.context = [] | ||
|
||
for hop in range(max_hops): | ||
# Generate a query based | ||
template = self.rewrite_template if hop == 0 else self.hop_template | ||
example, completions = dsp.generate(template)(example, stage=f'h{hop}') | ||
|
||
# Retrieve k results based on the query generated | ||
passages = dsp.retrieve(completions.query, k=k) | ||
|
||
# Update the context by concatenating old and new passages | ||
example.context = deduplicate(example.context + passages) | ||
|
||
return example | ||
|
||
@dsp.transformation | ||
def multihop_attempt(self, d: dsp.Example) -> dsp.Example: | ||
# Prepare unaugmented demonstrations for the example. | ||
x = dsp.Example(question=d.question, demos=dsp.all_but(self.train, d)) | ||
|
||
# Search. | ||
# Annotate demonstrations for multihop_search_v2 with the simpler multihop_search_v1 pipeline. | ||
x = self.multihop_search_v1(x) | ||
|
||
# Predict. And skip examples where predict fails. | ||
x = self.QA_predict(x, sc=False) | ||
if not self.__answer_match(x.answer, d.answer) == "True": return None | ||
|
||
return d.copy(**x) | ||
|
||
@dsp.transformation | ||
def multihop_demonstrate(self, x: dsp.Example) -> dsp.Example: | ||
demos = dsp.sample(self.train, k=7) | ||
x.demos = dsp.annotate(self.multihop_attempt)(demos, k=3, return_all=True) | ||
return x | ||
|
||
@dsp.transformation | ||
def multihop_search_v2(self, example: dsp.Example, max_hops=2, k=5) -> dsp.Example: | ||
example.context = [] | ||
store = [] | ||
|
||
for hop in range(max_hops): | ||
# Generate queries | ||
template = self.rewrite_template if hop == 0 else self.hop_template | ||
|
||
len_tokens = self.__count_tokens(str(example)) | ||
example, completions = dsp.generate(template, n=10, temperature=0.7)(example, stage=f'h{hop}') | ||
|
||
# Collect the queries and search with result fusion | ||
queries = [c.query for c in completions] + [example.question] | ||
example.context = dsp.retrieveEnsemble(queries, k=k) | ||
|
||
# Arrange the passages for the next hop | ||
if hop > 0: | ||
example.context = [completions[0].rationale] + example.context | ||
|
||
store.append({ | ||
"question": queries, | ||
"rationale": completions[0].rationale, | ||
"context": example.context, | ||
"len_tokens": len_tokens | ||
}) | ||
|
||
return example, store | ||
|
||
def multihop_QA_v2(self, question: str) -> str: | ||
x = dsp.Example(question=question) | ||
x = self.multihop_demonstrate(x) | ||
x, stores = self.multihop_search_v2(x) | ||
x, store = self.QA_predict(x, return_store=True) | ||
stores.append(store) | ||
return x.answer, stores |