Fix plugins (#3017)
Plugins were compatible only with `text-generation-inference`-based workers and
therefore worked on Dragan's machines but not on OA prod. This resolves the
incompatibility.

---------

Co-authored-by: draganjovanovich <[email protected]>
olliestanley and draganjovanovich authored May 4, 2023
1 parent 9bcc916 commit 5dd1025
Showing 9 changed files with 173 additions and 106 deletions.
22 changes: 19 additions & 3 deletions inference/worker/basic_hf_server.py
@@ -11,6 +11,7 @@
import transformers
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from hf_stopping import SequenceStoppingCriteria
from loguru import logger
from oasst_shared import model_configs
from settings import settings
@@ -60,8 +61,12 @@ def model_thread():
prompt = request.inputs
params = request.parameters.dict()
seed = params.pop("seed")
params.pop("stop")
stop_sequences = params.pop("stop")
params.pop("details")
params.pop("plugins")

if seed is not None:
torch.manual_seed(seed)

last_token_id = None # need to delay by 1 to simulate tgi

@@ -79,7 +84,18 @@ def print_text(token_id: int):
ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text)
ids = ids.to(model.device)
output = model.generate(ids, **params, streamer=streamer, eos_token_id=eos_token_id)
stopping_criteria = (
transformers.StoppingCriteriaList([SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)])
if stop_sequences
else None
)
output = model.generate(
ids,
**params,
streamer=streamer,
eos_token_id=eos_token_id,
stopping_criteria=stopping_criteria,
)
output = output.cpu()
output_ids = output[0][len(ids[0]) :]
decoded = tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -130,7 +146,7 @@ def decode_token(token_id):
return result[special_decode_token_length:]

config_dtype = hf_config.torch_dtype if hasattr(hf_config, "torch_dtype") else torch.float32
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else config_dtype
dtype = torch.bfloat16 if torch.has_cuda and torch.cuda.is_bf16_supported() else config_dtype

model = transformers.AutoModelForCausalLM.from_pretrained(
model_config.model_id,
7 changes: 6 additions & 1 deletion inference/worker/chat_chain.py
@@ -2,6 +2,7 @@

import interface
import transformers
import utils
from chat_chain_prompts import (
ASSISTANT_PREFIX,
HUMAN_PREFIX,
@@ -65,6 +66,7 @@ def handle_plugin_usage(
plugin: inference.PluginEntry | None,
worker_config: inference.WorkerConfig,
tokenizer: transformers.PreTrainedTokenizer,
parameters: interface.GenerateStreamParameters,
) -> tuple[str, inference.PluginUsed]:
execution_details = inference.PluginExecutionDetails(
inner_monologue=[],
@@ -115,6 +117,8 @@ def handle_plugin_usage(
# NOTE: Do not strip() any of the outputs ever, as it will degrade the
# instruction following performance, at least with
# `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model`

init_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, init_prompt, True)
chain_response = (
llm.generate(prompts=[init_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"])
.generations[0][0]
@@ -159,6 +163,7 @@ def handle_plugin_usage(
# NOTE: Do not strip() any of the outputs ever, as it will degrade the
# instruction following performance, at least with
# `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model`
new_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, new_prompt, True)
chain_response = (
llm.generate(prompts=[new_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"])
.generations[0][0]
@@ -311,7 +316,7 @@ def handle_conversation(
# using sampling settings derived from frontend UI
if plugin_enabled:
return handle_plugin_usage(
original_prompt, prompt_template, language, tools, memory, plugin, worker_config, tokenizer
original_prompt, prompt_template, language, tools, memory, plugin, worker_config, tokenizer, parameters
)

# Just regular prompt template without plugin chain.
10 changes: 4 additions & 6 deletions inference/worker/chat_chain_utils.py
@@ -1,6 +1,5 @@
import json
import re
import threading

import requests
import transformers
@@ -11,10 +10,9 @@
from langchain.prompts import PromptTemplate
from loguru import logger
from oasst_shared.schemas import inference
from opeanapi_parser import prepare_plugin_for_llm
from openapi_parser import prepare_plugin_for_llm
from settings import settings

tokenizer_lock = threading.Lock()
from utils import shared_tokenizer_lock

RESPONSE_MAX_LENGTH = 2048

@@ -343,7 +341,7 @@ def prepare_prompt(

out_prompt = prompt_template.format(**args)

with tokenizer_lock:
with shared_tokenizer_lock:
ids = tokenizer.encode(out_prompt)

# soft truncation
@@ -362,7 +360,7 @@

out_prompt = prompt_template.format(**args)

with tokenizer_lock:
with shared_tokenizer_lock:
ids = tokenizer.encode(out_prompt)
logger.warning(f"Prompt too long, deleting chat history. New length: {len(ids)}")

41 changes: 25 additions & 16 deletions inference/worker/hf_langchain_inference.py
@@ -1,5 +1,6 @@
import interface
import utils
from langchain.llms.base import LLM
from text_generation import Client


class HFInference(LLM):
@@ -23,22 +24,30 @@ def _call(self, prompt: str, stop: list[str] | None = None) -> str:
else:
stop += self.stop_sequences

print(stop)
client = Client(self.inference_server_url, timeout=1000)
res = client.generate(
prompt,
stop_sequences=stop,
max_new_tokens=self.max_new_tokens,
top_k=self.top_k,
top_p=self.top_p,
typical_p=self.typical_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
seed=self.seed,
request = interface.GenerateStreamRequest(
inputs=prompt,
parameters=interface.GenerateStreamParameters(
stop=stop,
max_new_tokens=self.max_new_tokens,
top_k=self.top_k,
top_p=self.top_p,
typical_p=self.typical_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
seed=self.seed,
),
)

for event in utils.get_inference_server_stream_events(request):
stream_response = event

generated_text = stream_response.generated_text
if generated_text is None:
generated_text = ""

# remove stop sequences from the end of the generated text
for stop_seq in stop:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[: res.generated_text.index(stop_seq)]
if stop_seq in generated_text:
generated_text = generated_text[: generated_text.index(stop_seq)]

return res.generated_text
return generated_text
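
The rewritten `_call` above builds an `interface.GenerateStreamRequest` and streams it through the worker's own `utils.get_inference_server_stream_events`, so the langchain wrapper no longer depends on the `text_generation` client and works with any worker-compatible inference server. A minimal usage sketch, assuming the fields visible in the diff are settable pydantic fields on the LLM and that the worker settings point at a reachable inference server (all values below are illustrative):

from hf_langchain_inference import HFInference

# Hypothetical sampling values; field names are taken from the diff above.
llm = HFInference(
    inference_server_url="http://localhost:8000",
    max_new_tokens=256,
    stop_sequences=[],
    top_k=50,
    top_p=0.95,
    typical_p=0.5,
    temperature=0.8,
    repetition_penalty=1.2,
    seed=42,
)

# Same consumption pattern as chat_chain.py uses above.
result = llm.generate(
    prompts=["<|prompter|>Hello!<|endoftext|><|assistant|>"],
    stop=["<|prompter|>"],
)
text = result.generations[0][0].text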
31 changes: 31 additions & 0 deletions inference/worker/hf_stopping.py
@@ -0,0 +1,31 @@
import torch
from tokenizers import Tokenizer
from transformers import StoppingCriteria


class SequenceStoppingCriteria(StoppingCriteria):
def __init__(
self,
tokenizer: Tokenizer,
stop_texts: list[str],
input_prompt: str,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.stop_texts = stop_texts
self.tokenizer = tokenizer
self.input_length = len(tokenizer.encode(input_prompt))

def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
# Assumes batch size 1, sufficient for our use case
generated_ids = input_ids[0, self.input_length :].tolist()
# TODO: optimise this. Inefficient to decode whole sequence every time
# but can't encode stop sequences as they don't always tokenize the same
generated_text = self.tokenizer.decode(generated_ids)
return any(text in generated_text for text in self.stop_texts)
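
`SequenceStoppingCriteria` decodes only the tokens produced after the input prompt and stops generation as soon as any of the stop texts appears, which is how `basic_hf_server.py` above now honours the request's `stop` sequences without `text-generation-inference`. A minimal sketch of wiring it into `generate()` (model name, prompt, and stop strings are illustrative):

import transformers
from hf_stopping import SequenceStoppingCriteria

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Question: What is a plugin?\nAnswer:"
stop_sequences = ["\nQuestion:", "\n\n"]

ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
stopping_criteria = transformers.StoppingCriteriaList(
    [SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)]
)
# Generation halts early once the decoded continuation contains any stop text.
output = model.generate(ids, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0][ids.shape[1]:], skip_special_tokens=True))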
File renamed without changes.
1 change: 0 additions & 1 deletion inference/worker/requirements.txt
@@ -10,6 +10,5 @@ pydantic
requests
sentencepiece
sseclient-py
text-generation
git+https://github.com/huggingface/transformers@main#egg=transformers
websocket-client
73 changes: 73 additions & 0 deletions inference/worker/utils.py
@@ -8,9 +8,15 @@
import lorem
import pydantic
import requests
import sseclient
import transformers
import websocket
from chat_chain_prompts import V2_PROMPTER_PREFIX
from loguru import logger
from oasst_shared.schemas import inference
from settings import settings

shared_tokenizer_lock = threading.Lock()


class TokenBuffer:
@@ -58,6 +64,42 @@ def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Ite
yield from self.tokens


def truncate_prompt(
tokenizer: transformers.PreTrainedTokenizer,
worker_config: inference.WorkerConfig,
parameters: interface.GenerateStreamParameters,
prompt: str,
plugin_used: bool,
):
with shared_tokenizer_lock:
ids = tokenizer.encode(prompt)

max_input_length = worker_config.model_config.max_input_length

# make room for prompter prefix
if plugin_used:
max_input_length = max_input_length - 1

max_total_tokens = worker_config.model_config.max_total_length
if len(ids) > max_input_length:
logger.warning(f"Prompt too long, left-truncating to {max_input_length} tokens")
ids = ids[-(max_input_length - 1) :]
with shared_tokenizer_lock:
prompt = tokenizer.decode(ids)
# If there is no prompter prefix, due to truncation, add it back.
if V2_PROMPTER_PREFIX not in prompt:
prompt = V2_PROMPTER_PREFIX + prompt

input_length = len(ids)
spare = max_total_tokens - input_length - 1
if not parameters.max_new_tokens:
parameters.max_new_tokens = spare
elif parameters.max_new_tokens > spare:
logger.warning(f"Max new tokens too high, reducing to {spare}")
parameters.max_new_tokens = spare
return prompt


def wait_for_inference_server(http: "HttpClient", timeout: int = 600):
time_limit = time.time() + timeout
while True:
@@ -136,3 +178,34 @@ def get(self, path: str, **kwargs):

def post(self, path: str, **kwargs):
return requests.post(self.base_url + path, auth=self.auth, **kwargs)


def get_inference_server_stream_events(request: interface.GenerateStreamRequest):
http = HttpClient(
base_url=settings.inference_server_url,
basic_auth_username=settings.basic_auth_username,
basic_auth_password=settings.basic_auth_password,
)
response = http.post(
"/generate_stream",
json=request.dict(),
stream=True,
headers={"Accept": "text/event-stream"},
)
try:
response.raise_for_status()
except requests.HTTPError:
logger.exception("Failed to get response from inference server")
logger.error(f"Response: {response.text}")
raise

client = sseclient.SSEClient(response)
for event in client.events():
if event.event == "error":
logger.error(f"Error from inference server: {event.data}")
yield interface.GenerateStreamResponse(error=event.data)
raise RuntimeError(f"Error from inference server: {event.data}")
if event.event == "ping":
continue
stream_response = interface.GenerateStreamResponse.parse_raw(event.data)
yield stream_response
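
`get_inference_server_stream_events` opens a server-sent-events stream against the worker's configured inference server and yields parsed `GenerateStreamResponse` objects; in the tgi-style protocol the worker targets, the final event carries the complete `generated_text`. A hedged consumption sketch matching the pattern used in `hf_langchain_inference.py` above (the prompt and parameters are illustrative, and the remaining `GenerateStreamParameters` fields are assumed to have defaults):

import interface
import utils

request = interface.GenerateStreamRequest(
    inputs="<|prompter|>Hi there!<|endoftext|><|assistant|>",  # illustrative V2-style prompt
    parameters=interface.GenerateStreamParameters(
        max_new_tokens=128,
        stop=["<|prompter|>"],
    ),
)

final_text = ""
for event in utils.get_inference_server_stream_events(request):
    # Intermediate events stream tokens; the last event carries the full text.
    if event.generated_text is not None:
        final_text = event.generated_text
print(final_text)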