Add lightweight-serving whisper asr example (#11847)
* add asr init

* update for pp

* update style

* update readme

* update reamde
hzjane authored Aug 22, 2024
1 parent a8e2573 commit 5c4ed00
Showing 6 changed files with 177 additions and 54 deletions.
25 changes: 24 additions & 1 deletion python/llm/example/GPU/Lightweight-Serving/README.md
@@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required by audio processing
```

#### 1.2 Installation on Windows
@@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
pip install fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc

# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required by audio processing
```

### 2. Configure OneAPI environment variables for Linux
@@ -180,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \

Image input is currently only supported for [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and it requires transformers==4.31.0 to run.
```bash
wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
@@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
}'
```

#### v1/audio/transcriptions

Currently, ASR only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and `whisper-large-v3` can only be used to transcribe audio. The audio file type must be one that `librosa.load` supports.
```bash
curl http://localhost:8000/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@/llm/test.mp3" \
-F model="whisper-large-v3" \
  -F language="zh"
```
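
The same request can also be sent programmatically. Below is a minimal Python sketch, assuming the `requests` package is installed, the server is reachable at localhost:8000, and `/llm/test.mp3` is a placeholder audio path:
```python
import requests

# Placeholder audio path; any format readable by librosa.load works.
audio_path = "/llm/test.mp3"

with open(audio_path, "rb") as f:
    response = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},                                     # multipart file upload
        data={"model": "whisper-large-v3", "language": "zh"},  # form fields
    )

# The endpoint returns a JSON body with a "text" field.
print(response.json()["text"])
```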

### 6. Benchmark with wrk

Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details.
19 changes: 13 additions & 6 deletions python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
@@ -39,12 +39,19 @@ async def main():
model_path = args.repo_id_or_model_path
low_bit = args.low_bit

local_model = ModelWorker(model_path, low_bit)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
myapp = FastApp(local_model, tokenizer)
processor = None
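    # Text models use the default worker; Whisper checkpoints are loaded as an audio model with a WhisperProcessor.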
if "whisper" not in model_path.lower():
local_model = ModelWorker(model_path, low_bit)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
else:
local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
myapp = FastApp(local_model, tokenizer, processor)
config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
server = uvicorn.Server(config)
await server.serve()
42 changes: 40 additions & 2 deletions python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -27,6 +27,8 @@
from typing import List, Optional, Union, Dict
from fastapi.middleware.cors import CORSMiddleware
from .tgi_protocol import Parameters
from typing_extensions import Literal
from fastapi import File, UploadFile, Form
from .openai_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
@@ -38,6 +40,8 @@
CompletionResponse,
CompletionResponseStreamChoice,
CompletionStreamResponse,
TranscriptionRequest,
TranscriptionResponse,
)

result_dict: Dict[str, str] = {}
Expand All @@ -50,6 +54,7 @@ class InputsRequest(BaseModel):
image_list: Optional[list] = None
stream: Optional[bool] = False
req_type: str = 'completion'
transcription_request: Optional[TranscriptionRequest] = None


class ChatCompletionRequest(BaseModel):
@@ -92,20 +97,27 @@ class CompletionRequest(BaseModel):

global tokenizer
global local_model
global processor


class FastApp():
def __init__(self, model, mytokenizer):
def __init__(self, model, mytokenizer, myprocessor=None):
global tokenizer
global local_model
global processor
local_model = model
tokenizer = mytokenizer
processor = myprocessor
self.app = app


def get_queue_next_token(delta_text_queue):
timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
delta_text = delta_text_queue.text_queue.get(timeout=timeout)
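    # Whisper streams special tokens such as <|zh|> or <|transcribe|>; strip them before returning text.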
if "whisper" in local_model.model_name.lower():
if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
import re
delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
if delta_text is None:
remain = 0
else:
@@ -385,6 +397,32 @@ async def create_completion(request: CompletionRequest):
return result


@app.post("/v1/audio/transcriptions")
async def transcriptions(
file: UploadFile=File(...),
model: Optional[str]=Form("default_model"),
language: Optional[str]=Form("zh"),
prompt: Optional[str]=Form(None),
response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
temperature: Optional[float]=Form(None),
timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
):
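    # Persist the uploaded audio to disk so the model worker can later load it by path.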
file_path = "./" + file.filename
if not os.path.exists(file_path):
with open(file_path, "wb") as f:
f.write(await file.read())
inputs_request = InputsRequest(
inputs="transcriptions",
parameters=None,
stream=False,
req_type="completion",
transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
)
request_id, result = await generate(inputs_request)
rsp = TranscriptionResponse(text=result)
return rsp


@app.on_event("startup")
async def startup_event():
asyncio.create_task(process_requests(local_model, result_dict))
@@ -393,4 +431,4 @@ async def startup_event():
async def process_requests(local_model, result_dict):
while True:
await asyncio.sleep(0)
await local_model.process_step(tokenizer, result_dict)
await local_model.process_step(tokenizer, result_dict, processor)
128 changes: 84 additions & 44 deletions python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -23,37 +23,69 @@


class ModelWorker:
def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
self.dtype = torch_dtype
start = time.perf_counter()
model = self.load_model(checkpoint, low_bit)
from ipex_llm.utils import BenchmarkWrapper
self.model = BenchmarkWrapper(model, do_print=True)
if model_type == "audio":
self.model = self.load_model(checkpoint, low_bit, "audio")
else:
model = self.load_model(checkpoint, low_bit)
from ipex_llm.utils import BenchmarkWrapper
self.model = BenchmarkWrapper(model, do_print=True)
end = time.perf_counter()
logger.info(f"Time to load weights: {end - start:.2f}s")
self.waiting_requests = asyncio.Queue()
self.streamer = {}
self.model_name = checkpoint

def load_model(self, model_path, low_bit='sym_int4'):
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
try:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
except:
model = AutoModel.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
if model_type == "audio":
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True)
else:
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
try:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
except:
model = AutoModel.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
model = model.eval().to("xpu")
return model

async def add_asr_request(self, processor):
if self.waiting_requests.empty():
return
tmp_result = await self.waiting_requests.get()
request_id, request = tmp_result
transcription_request = request.transcription_request
forced_decoder_ids = processor.get_decoder_prompt_ids(
language=transcription_request.language, task="transcribe")
audio_path = transcription_request.file
import librosa
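        # Resample the audio to the rate the Whisper feature extractor expects (16 kHz).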
raw_speech, sampling_rate = librosa.load(audio_path,
sr=processor.feature_extractor.sampling_rate)
input_features = processor(
raw_speech,
sampling_rate=sampling_rate,
return_tensors="pt",
return_attention_mask=True,
).input_features.to('xpu')
return input_features, forced_decoder_ids, request_id

async def add_request(self, tokenizer):
if self.waiting_requests.empty():
return
@@ -91,33 +123,41 @@ async def add_request(self, tokenizer):
return input_ids, parameters, request_id, inputs_embeds

@torch.no_grad()
async def process_step(self, tokenizer, result_dict):
async def process_step(self, tokenizer, result_dict, processor=None):
if not self.waiting_requests.empty():
input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
if processor is not None and "whisper" in self.model_name.lower():
input_features, decoder_ids, request_id = await self.add_asr_request(processor)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

def model_generate():
generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
if "codegeex" in self.model_name.lower():
eos_token_id = [tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|user|>"),
tokenizer.convert_tokens_to_ids("<|observation|>")]
generate_kwargs["eos_token_id"] = eos_token_id
elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
]
generate_kwargs["eos_token_id"] = eos_token_id
if input_ids is not None:
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()
def model_generate():
self.model.generate(input_features,
streamer=self.streamer[request_id],
forced_decoder_ids=decoder_ids)
else:
input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

def model_generate():
generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
if "codegeex" in self.model_name.lower():
eos_token_id = [tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|user|>"),
tokenizer.convert_tokens_to_ids("<|observation|>")]
generate_kwargs["eos_token_id"] = eos_token_id
elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
]
generate_kwargs["eos_token_id"] = eos_token_id
if input_ids is not None:
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()
from threading import Thread
t1 = Thread(target=model_generate)
t1.start()
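
For reference, the Whisper path that `ModelWorker` wires into the serving loop above can also be exercised as a standalone script; the following is a minimal sketch built from the same calls shown in this diff, where the model path, audio file, and low-bit setting are placeholders rather than values fixed by the commit:
```python
import librosa
import torch
from transformers import WhisperProcessor
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq

model_path = "openai/whisper-large-v3"   # placeholder: local checkpoint or Hub id
audio_path = "test.mp3"                  # placeholder: any file librosa.load can read

processor = WhisperProcessor.from_pretrained(model_path)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
                                                  load_in_low_bit="sym_int4",
                                                  torch_dtype=torch.float32,
                                                  optimize_model=True,
                                                  use_cache=True)
model = model.eval().to("xpu")

# Load and resample the audio to the rate the Whisper feature extractor expects.
speech, sr = librosa.load(audio_path, sr=processor.feature_extractor.sampling_rate)
input_features = processor(speech, sampling_rate=sr, return_tensors="pt").input_features.to("xpu")

# Force transcription in Chinese, mirroring the server's default language setting.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```
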
15 changes: 15 additions & 0 deletions python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
@@ -24,13 +24,28 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from ipex_llm.utils.common import invalidInputError
from typing_extensions import Literal


# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
return str(uuid.uuid4().hex)


class TranscriptionRequest(BaseModel):
file: str = None
model: Optional[str] = "default_model"
language: Optional[str] = "zh"
prompt: Optional[str] = None
response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
temperature: Optional[float] = None
timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class TranscriptionResponse(BaseModel):
text: str


class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -800,7 +800,7 @@ async def stream_output(self, cur_batch, tokenizer, next_ids):
_stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
await asyncio.gather(*_stream_tasks)

async def process_step(self, tokenizer, result_dict):
async def process_step(self, tokenizer, result_dict, processor=None):
cur_batch = None
torch.xpu.synchronize(self.device)
if self.rank == 0:
