Add lightweight-serving whisper asr example #11847

Merged 5 commits on Aug 22, 2024
25 changes: 24 additions & 1 deletion python/llm/example/GPU/Lightweight-Serving/README.md
@@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required for audio processing
```

#### 1.2 Installation on Windows
@@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
pip install fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc

# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops

# for whisper-large-v3
pip install transformers==4.36.2
pip install datasets soundfile librosa # required for audio processing
```

### 2. Configure OneAPI environment variables for Linux
@@ -180,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \

Image input is currently only supported by [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and requires `transformers==4.31.0` to run.
```bash
wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
@@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
}'
```

#### v1/audio/transcriptions

ASR currently only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and `whisper-large-v3` can only be used to transcribe audio. The audio file type must be one supported by `librosa.load`.
```bash
curl http://localhost:8000/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@/llm/test.mp3" \
-F model="whisper-large-v3" \
-F language="zh"
```
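
The same request can also be sent from Python. Below is a minimal sketch using the `requests` library; the server address, audio file path, and language value are assumptions matching the curl example above:
```python
import requests

# Minimal client sketch for the /v1/audio/transcriptions endpoint.
# Assumes the lightweight-serving server is running locally on port 8000
# and that /llm/test.mp3 is an audio file readable by librosa.
url = "http://localhost:8000/v1/audio/transcriptions"
with open("/llm/test.mp3", "rb") as audio_file:
    response = requests.post(
        url,
        files={"file": audio_file},
        data={"model": "whisper-large-v3", "language": "zh"},
    )
print(response.json()["text"])  # transcribed text returned by the server
```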

### 6. Benchmark with wrk

Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details.
19 changes: 13 additions & 6 deletions python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
@@ -39,12 +39,19 @@ async def main():
model_path = args.repo_id_or_model_path
low_bit = args.low_bit

local_model = ModelWorker(model_path, low_bit)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
myapp = FastApp(local_model, tokenizer)
processor = None
if "whisper" not in model_path.lower():
local_model = ModelWorker(model_path, low_bit)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
else:
local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
myapp = FastApp(local_model, tokenizer, processor)
config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
server = uvicorn.Server(config)
await server.serve()
42 changes: 40 additions & 2 deletions python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -27,6 +27,8 @@
from typing import List, Optional, Union, Dict
from fastapi.middleware.cors import CORSMiddleware
from .tgi_protocol import Parameters
from typing_extensions import Literal
from fastapi import File, UploadFile, Form
from .openai_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
@@ -38,6 +40,8 @@
CompletionResponse,
CompletionResponseStreamChoice,
CompletionStreamResponse,
TranscriptionRequest,
TranscriptionResponse,
)

result_dict: Dict[str, str] = {}
@@ -50,6 +54,7 @@ class InputsRequest(BaseModel):
image_list: Optional[list] = None
stream: Optional[bool] = False
req_type: str = 'completion'
transcription_request: Optional[TranscriptionRequest] = None


class ChatCompletionRequest(BaseModel):
@@ -92,20 +97,27 @@ class CompletionRequest(BaseModel):

global tokenizer
global local_model
global processor


class FastApp():
def __init__(self, model, mytokenizer):
def __init__(self, model, mytokenizer, myprocessor=None):
global tokenizer
global local_model
global processor
local_model = model
tokenizer = mytokenizer
processor = myprocessor
self.app = app


def get_queue_next_token(delta_text_queue):
timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
delta_text = delta_text_queue.text_queue.get(timeout=timeout)
if "whisper" in local_model.model_name.lower():
if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
import re
delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
if delta_text is None:
remain = 0
else:
@@ -385,6 +397,32 @@ async def create_completion(request: CompletionRequest):
return result


@app.post("/v1/audio/transcriptions")
async def transcriptions(
file: UploadFile=File(...),
model: Optional[str]=Form("default_model"),
language: Optional[str]=Form("zh"),
prompt: Optional[str]=Form(None),
response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
temperature: Optional[float]=Form(None),
timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
):
file_path = "./" + file.filename
if not os.path.exists(file_path):
with open(file_path, "wb") as f:
f.write(await file.read())
inputs_request = InputsRequest(
inputs="transcriptions",
parameters=None,
stream=False,
req_type="completion",
transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
)
request_id, result = await generate(inputs_request)
rsp = TranscriptionResponse(text=result)
return rsp


@app.on_event("startup")
async def startup_event():
asyncio.create_task(process_requests(local_model, result_dict))
@@ -393,4 +431,4 @@ async def startup_event():
async def process_requests(local_model, result_dict):
while True:
await asyncio.sleep(0)
await local_model.process_step(tokenizer, result_dict)
await local_model.process_step(tokenizer, result_dict, processor)
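
For Whisper output, `get_queue_next_token` above strips decoder special tokens before streaming deltas back to the client. A standalone sketch of that cleanup, using the same regular expression; the sample string is illustrative, with token names following Whisper's prompt format:
```python
import re

# Whisper emits control tokens such as <|startoftranscript|>, <|zh|>,
# <|transcribe|> and <|notimestamps|> alongside the transcribed text;
# the server removes anything of the form <|...|> before returning deltas.
delta_text = "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>你好，世界。"
clean_text = re.sub(r'<\|.*?\|>', '', delta_text)
print(clean_text)  # -> 你好，世界。
```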
128 changes: 84 additions & 44 deletions python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -23,37 +23,69 @@


class ModelWorker:
def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
self.dtype = torch_dtype
start = time.perf_counter()
model = self.load_model(checkpoint, low_bit)
from ipex_llm.utils import BenchmarkWrapper
self.model = BenchmarkWrapper(model, do_print=True)
if model_type == "audio":
self.model = self.load_model(checkpoint, low_bit, "audio")
else:
model = self.load_model(checkpoint, low_bit)
from ipex_llm.utils import BenchmarkWrapper
self.model = BenchmarkWrapper(model, do_print=True)
end = time.perf_counter()
logger.info(f"Time to load weights: {end - start:.2f}s")
self.waiting_requests = asyncio.Queue()
self.streamer = {}
self.model_name = checkpoint

def load_model(self, model_path, low_bit='sym_int4'):
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
try:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
except:
model = AutoModel.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
if model_type == "audio":
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True)
else:
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
try:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
except:
model = AutoModel.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
optimize_model=True,
trust_remote_code=True,
use_cache=True,)
model = model.eval().to("xpu")
return model

async def add_asr_request(self, processor):
if self.waiting_requests.empty():
return
tmp_result = await self.waiting_requests.get()
request_id, request = tmp_result
transcription_request = request.transcription_request
forced_decoder_ids = processor.get_decoder_prompt_ids(
language=transcription_request.language, task="transcribe")
audio_path = transcription_request.file
import librosa
raw_speech, sampling_rate = librosa.load(audio_path,
sr=processor.feature_extractor.sampling_rate)
input_features = processor(
raw_speech,
sampling_rate=sampling_rate,
return_tensors="pt",
return_attention_mask=True,
).input_features.to('xpu')
return input_features, forced_decoder_ids, request_id

async def add_request(self, tokenizer):
if self.waiting_requests.empty():
return
@@ -91,33 +123,41 @@ async def add_request(self, tokenizer):
return input_ids, parameters, request_id, inputs_embeds

@torch.no_grad()
async def process_step(self, tokenizer, result_dict):
async def process_step(self, tokenizer, result_dict, processor=None):
if not self.waiting_requests.empty():
input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
if processor is not None and "whisper" in self.model_name.lower():
input_features, decoder_ids, request_id = await self.add_asr_request(processor)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

def model_generate():
generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
if "codegeex" in self.model_name.lower():
eos_token_id = [tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|user|>"),
tokenizer.convert_tokens_to_ids("<|observation|>")]
generate_kwargs["eos_token_id"] = eos_token_id
elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
]
generate_kwargs["eos_token_id"] = eos_token_id
if input_ids is not None:
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()
def model_generate():
self.model.generate(input_features,
streamer=self.streamer[request_id],
forced_decoder_ids=decoder_ids)
else:
input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

def model_generate():
generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
if "codegeex" in self.model_name.lower():
eos_token_id = [tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|user|>"),
tokenizer.convert_tokens_to_ids("<|observation|>")]
generate_kwargs["eos_token_id"] = eos_token_id
elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
]
generate_kwargs["eos_token_id"] = eos_token_id
if input_ids is not None:
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()
from threading import Thread
t1 = Thread(target=model_generate)
t1.start()
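
For reference, the Whisper preprocessing performed in `add_asr_request` can be reproduced outside the server. A minimal standalone sketch with plain `transformers` on CPU; the ipex-llm XPU loading path used in `model_worker.py` is not used here, and the audio path is an assumption:
```python
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_id = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Force Chinese transcription, mirroring language="zh", task="transcribe".
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")

# Resample the audio to the feature extractor's expected rate (16 kHz).
raw_speech, sampling_rate = librosa.load("test.mp3", sr=processor.feature_extractor.sampling_rate)
input_features = processor(raw_speech, sampling_rate=sampling_rate, return_tensors="pt").input_features

predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```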
15 changes: 15 additions & 0 deletions python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
@@ -24,13 +24,28 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from ipex_llm.utils.common import invalidInputError
from typing_extensions import Literal


# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
return str(uuid.uuid4().hex)


class TranscriptionRequest(BaseModel):
file: str = None
model: Optional[str] = "default_model"
language: Optional[str] = "zh"
prompt: Optional[str] = None
response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
temperature: Optional[float] = None
timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class TranscriptionResponse(BaseModel):
text: str


class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -800,7 +800,7 @@ async def stream_output(self, cur_batch, tokenizer, next_ids):
_stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
await asyncio.gather(*_stream_tasks)

async def process_step(self, tokenizer, result_dict):
async def process_step(self, tokenizer, result_dict, processor=None):
cur_batch = None
torch.xpu.synchronize(self.device)
if self.rank == 0: