From 7b82942b15eed8d2811372e1efb1e696b5d06c97 Mon Sep 17 00:00:00 2001
From: hzjane
Date: Mon, 19 Aug 2024 16:52:23 +0800
Subject: [PATCH 1/5] add asr init

---
 .../lightweight_serving.py                    |  19 ++-
 .../ipex_llm/serving/fastapi/api_server.py    |  41 +++++-
 .../ipex_llm/serving/fastapi/model_worker.py  | 127 ++++++++++++------
 .../serving/fastapi/openai_protocol.py        |  18 ++-
 4 files changed, 150 insertions(+), 55 deletions(-)

diff --git a/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py b/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
index 003307a198f..a7b62bad2e7 100644
--- a/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
+++ b/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
@@ -39,12 +39,19 @@ async def main():
     model_path = args.repo_id_or_model_path
     low_bit = args.low_bit
 
-    local_model = ModelWorker(model_path, low_bit)
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    myapp = FastApp(local_model, tokenizer)
+    processor = None
+    if "whisper" not in model_path:
+        local_model = ModelWorker(model_path, low_bit)
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+    else:
+        local_model = ModelWorker(model_path, low_bit, "audio", torch_dtype=torch.float32)
+        from transformers import WhisperProcessor
+        processor = WhisperProcessor.from_pretrained(model_path)
+        tokenizer = processor.tokenizer
+    myapp = FastApp(local_model, tokenizer, processor)
     config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
     server = uvicorn.Server(config)
     await server.serve()
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index ea2832503ef..9269da2f2a2 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -38,6 +38,8 @@
     CompletionResponse,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
+    TranscriptionRequest,
+    TranscriptionResponse,
 )
 
 result_dict: Dict[str, str] = {}
@@ -50,6 +52,7 @@ class InputsRequest(BaseModel):
     image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'
+    transcription_request: Optional[TranscriptionRequest] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -92,20 +95,27 @@ class CompletionRequest(BaseModel):
 
 global tokenizer
 global local_model
+global processor
 
 
 class FastApp():
-    def __init__(self, model, mytokenizer):
+    def __init__(self, model, mytokenizer, myprocessor = None):
        global tokenizer
        global local_model
+       global processor
        local_model = model
        tokenizer = mytokenizer
+       processor = myprocessor
        self.app = app
 
 
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
     delta_text = delta_text_queue.text_queue.get(timeout=timeout)
+    if "whisper" in local_model.model_name.lower():
+        if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
+            import re
+            delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
     if delta_text is None:
         remain = 0
     else:
@@ -384,6 +394,33 @@ async def create_completion(request: CompletionRequest):
                               model=model_name)
     return result
+from typing_extensions import Literal
+from fastapi import File, UploadFile, Form
+@app.post("/v1/audio/transcriptions")
+async def transcriptions(
+    file: UploadFile = File(...),
+    model: Optional[str] = Form("default_model"),
+    language: Optional[str] = Form("zh"),
+    prompt: Optional[str] = Form(None),
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = Form(None),
+    temperature: Optional[float] = Form(None),
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = Form(None)
+):
+    file_path = "./" + file.filename
+    if not os.path.exists(file_path):
+        with open(file_path, "wb") as f:
+            f.write(await file.read())
+    inputs_request = InputsRequest(
+        inputs="transcriptions",
+        parameters=None,
+        stream=False,
+        req_type="completion",
+        transcription_request=TranscriptionRequest(file=file_path, model=model, language=language)
+    )
+    request_id, result = await generate(inputs_request)
+    rsp = TranscriptionResponse(text=result)
+    return rsp
+
 
 
 @app.on_event("startup")
 async def startup_event():
@@ -393,4 +430,4 @@ async def startup_event():
 async def process_requests(local_model, result_dict):
     while True:
         await asyncio.sleep(0)
-        await local_model.process_step(tokenizer, result_dict)
+        await local_model.process_step(tokenizer, result_dict, processor)
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 3d8d75bfafa..4cfb4dcb2a3 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -23,37 +23,69 @@
 
 
 class ModelWorker:
-    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
+    def __init__(self, checkpoint, low_bit, model_type="normal", torch_dtype=torch.float16):
         self.dtype = torch_dtype
         start = time.perf_counter()
-        model = self.load_model(checkpoint, low_bit)
-        from ipex_llm.utils import BenchmarkWrapper
-        self.model = BenchmarkWrapper(model, do_print=True)
+        if model_type == "audio":
+            self.model = self.load_model(checkpoint, low_bit, "audio")
+        else:
+            model = self.load_model(checkpoint, low_bit)
+            from ipex_llm.utils import BenchmarkWrapper
+            self.model = BenchmarkWrapper(model, do_print=True)
         end = time.perf_counter()
         logger.info(f"Time to load weights: {end - start:.2f}s")
         self.waiting_requests = asyncio.Queue()
         self.streamer = {}
         self.model_name = checkpoint
 
-    def load_model(self, model_path, low_bit='sym_int4'):
-        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
-        try:
-            model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                         load_in_low_bit=low_bit,
-                                                         torch_dtype=self.dtype,
-                                                         optimize_model=True,
-                                                         trust_remote_code=True,
-                                                         use_cache=True,)
-        except:
-            model = AutoModel.from_pretrained(model_path,
-                                              load_in_low_bit=low_bit,
-                                              torch_dtype=self.dtype,
-                                              optimize_model=True,
-                                              trust_remote_code=True,
+    def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
+        if model_type == "audio":
+            from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
+                                                       load_in_low_bit=low_bit,
+                                                       torch_dtype=self.dtype,
+                                                       optimize_model=True,
+                                                       trust_remote_code=True,
+                                                       use_cache=True)
+        else:
+            from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+            try:
+                model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                         load_in_low_bit=low_bit,
+                                                         torch_dtype=self.dtype,
+                                                         optimize_model=True,
+                                                         trust_remote_code=True,
+                                                         use_cache=True,)
+            except:
+                model = AutoModel.from_pretrained(model_path,
+                                              load_in_low_bit=low_bit,
+                                              torch_dtype=self.dtype,
+                                              optimize_model=True,
+                                              trust_remote_code=True,
                                               use_cache=True,)
         model = model.eval().to("xpu")
         return model
+
+    async def add_asr_request(self, processor):
+        if self.waiting_requests.empty():
+            return
+        tmp_result = await self.waiting_requests.get()
+        request_id, request = tmp_result
+        transcription_request = request.transcription_request
+        forced_decoder_ids = processor.get_decoder_prompt_ids(language=transcription_request.language, task="transcribe")
+        audio_path = transcription_request.file
+        import librosa
+        raw_speech, sampling_rate = librosa.load(audio_path, sr=processor.feature_extractor.sampling_rate)
+        input_features = processor(
+            raw_speech,
+            sampling_rate=sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).input_features.to('xpu')
+        return input_features, forced_decoder_ids, request_id,
+
 
     async def add_request(self, tokenizer):
         if self.waiting_requests.empty():
             return
@@ -91,33 +123,40 @@ async def add_request(self, tokenizer):
         return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor = None):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
-            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
-
-            def model_generate():
-                generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
-                if "codegeex" in self.model_name.lower():
-                    eos_token_id = [tokenizer.eos_token_id,
-                                    tokenizer.convert_tokens_to_ids("<|user|>"),
-                                    tokenizer.convert_tokens_to_ids("<|observation|>")]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
-                    eos_token_id = [
-                        tokenizer.eos_token_id,
-                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-                    ]
-                    generate_kwargs["eos_token_id"] = eos_token_id
-                if input_ids is not None:
-                    self.model.generate(input_ids,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                elif inputs_embeds is not None:
-                    self.model.generate(inputs_embeds=inputs_embeds,
-                                        streamer=self.streamer[request_id], **generate_kwargs)
-                torch.xpu.empty_cache()
-                torch.xpu.synchronize()
+            if processor is not None and "whisper" in self.model_name.lower():
+                input_features, forced_decoder_ids, request_id = await self.add_asr_request(processor)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+                def model_generate():
+                    self.model.generate(input_features,
+                                        streamer=self.streamer[request_id],
+                                        forced_decoder_ids=forced_decoder_ids)
+            else:
+                input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
+                self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+                def model_generate():
+                    generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+                    if "codegeex" in self.model_name.lower():
+                        eos_token_id = [tokenizer.eos_token_id,
+                                        tokenizer.convert_tokens_to_ids("<|user|>"),
+                                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                        eos_token_id = [
+                            tokenizer.eos_token_id,
+                            tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                        ]
+                        generate_kwargs["eos_token_id"] = eos_token_id
+                    if input_ids is not None:
+                        self.model.generate(input_ids,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    elif inputs_embeds is not None:
+                        self.model.generate(inputs_embeds=inputs_embeds,
+                                            streamer=self.streamer[request_id], **generate_kwargs)
+                    torch.xpu.empty_cache()
+                    torch.xpu.synchronize()
             from threading import Thread
             t1 = Thread(target=model_generate)
             t1.start()
diff --git a/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
index 1bc8f1e3a69..f688ff082db 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
@@ -27,9 +27,21 @@
 # from vllm.sampling_params import SamplingParams
 
 
-def random_uuid() -> str:
-    return str(uuid.uuid4().hex)
-
+def random_uuid() -> str:BaseModel
+
+from typing_extensions import Literal
+from fastapi import File, UploadFile
+class TranscriptionRequest(BaseModel):
+    file: str = None
+    model: Optional[str] = "default_model"
+    language: Optional[str] = "zh"
+    prompt: Optional[str] = None
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = None
+    temperature: Optional[float] = None
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
+
+class TranscriptionResponse(BaseModel):
+    text: str
 
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
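Taken together, this first patch routes `/v1/audio/transcriptions` requests through `InputsRequest` into `ModelWorker.add_asr_request`, which prepares Whisper input features and decodes with `forced_decoder_ids`. The underlying transcription flow can be reproduced standalone roughly as follows — a minimal sketch, assuming a local `openai/whisper-large-v3` checkpoint, a `test.mp3` readable by `librosa`, and plain CPU `transformers` (the IPEX-LLM `load_in_low_bit`/XPU specifics are omitted):

```python
# Minimal offline sketch of the ASR flow wired up above.
# Assumptions: an "openai/whisper-large-v3" checkpoint and a local test.mp3 are available.
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "openai/whisper-large-v3"          # assumed checkpoint
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

# Fix language/task, mirroring get_decoder_prompt_ids() in add_asr_request()
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")

# librosa resamples to the feature extractor's expected rate (16 kHz for Whisper)
raw_speech, sampling_rate = librosa.load("test.mp3", sr=processor.feature_extractor.sampling_rate)
input_features = processor(raw_speech, sampling_rate=sampling_rate,
                           return_tensors="pt").input_features

predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```

The serving path differs mainly in that the model is loaded through `AutoModelForSpeechSeq2Seq` with `load_in_low_bit`, moved to `"xpu"`, and decoded through a `TextIteratorStreamer`.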
From bad3a9eeea33e0aaa804d45b718869413fe00d95 Mon Sep 17 00:00:00 2001
From: hzjane
Date: Tue, 20 Aug 2024 12:59:05 +0800
Subject: [PATCH 2/5] update for pp

---
 .../llm/example/GPU/Lightweight-Serving/lightweight_serving.py | 2 +-
 python/llm/src/ipex_llm/transformers/pipeline_parallel.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py b/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
index a7b62bad2e7..ce579213553 100644
--- a/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
+++ b/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
@@ -40,7 +40,7 @@ async def main():
     low_bit = args.low_bit
 
     processor = None
-    if "whisper" not in model_path:
+    if "whisper" not in model_path.lower():
         local_model = ModelWorker(model_path, low_bit)
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 8202b3ee0ee..ce2aea2fac5 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -800,7 +800,7 @@ async def stream_output(self, cur_batch, tokenizer, next_ids):
             _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
         await asyncio.gather(*_stream_tasks)
 
-    async def process_step(self, tokenizer, result_dict):
+    async def process_step(self, tokenizer, result_dict, processor = None):
         cur_batch = None
         torch.xpu.synchronize(self.device)
         if self.rank == 0:

From d715d604b34167c74dee77244e451bd72b32fc42 Mon Sep 17 00:00:00 2001
From: hzjane
Date: Tue, 20 Aug 2024 13:12:26 +0800
Subject: [PATCH 3/5] update style

---
 .../ipex_llm/serving/fastapi/api_server.py    | 23 +++++----
 .../ipex_llm/serving/fastapi/model_worker.py  | 51 ++++++++---------
 .../serving/fastapi/openai_protocol.py        |  9 ++--
 .../transformers/pipeline_parallel.py         |  2 +-
 4 files changed, 45 insertions(+), 40 deletions(-)

diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 9269da2f2a2..86fc6bce4ce 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -27,6 +27,8 @@
 from typing import List, Optional, Union, Dict
 from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
+from typing_extensions import Literal
+from fastapi import File, UploadFile, Form
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -99,7 +101,7 @@ class CompletionRequest(BaseModel):
 
 
 class FastApp():
-    def __init__(self, model, mytokenizer, myprocessor = None):
+    def __init__(self, model, mytokenizer, myprocessor=None):
        global tokenizer
        global local_model
        global processor
@@ -113,7 +115,7 @@ def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
     delta_text = delta_text_queue.text_queue.get(timeout=timeout)
     if "whisper" in local_model.model_name.lower():
-        if delta_text is not None and "<|" in delta_text and "|>" in delta_text: 
+        if delta_text is not None and "<|" in delta_text and "|>" in delta_text:
             import re
             delta_text = re.sub(r'<\|.*?\|>', '', delta_text)
     if delta_text is None:
@@ -394,17 +396,16 @@ async def create_completion(request: CompletionRequest):
                               model=model_name)
     return result
-from typing_extensions import Literal
-from fastapi import File, UploadFile, Form
+
 @app.post("/v1/audio/transcriptions")
 async def transcriptions(
-    file: UploadFile = File(...),
-    model: Optional[str] = Form("default_model"),
-    language: Optional[str] = Form("zh"),
-    prompt: Optional[str] = Form(None),
-    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]] = Form(None),
-    temperature: Optional[float] = Form(None),
-    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = Form(None)
+    file: UploadFile=File(...),
+    model: Optional[str]=Form("default_model"),
+    language: Optional[str]=Form("zh"),
+    prompt: Optional[str]=Form(None),
+    response_format: Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]]=Form(None),
+    temperature: Optional[float]=Form(None),
+    timestamp_granularities: Optional[List[Literal["word", "segment"]]]=Form(None)
 ):
     file_path = "./" + file.filename
     if not os.path.exists(file_path):
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 4cfb4dcb2a3..9a7b2b0be11 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -42,49 +42,49 @@ def load_model(self, model_path, low_bit='sym_int4', model_type="normal"):
         if model_type == "audio":
             from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
             model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path,
-                                                       load_in_low_bit=low_bit,
-                                                       torch_dtype=self.dtype,
-                                                       optimize_model=True,
-                                                       trust_remote_code=True,
-                                                       use_cache=True)
+                                                              load_in_low_bit=low_bit,
+                                                              torch_dtype=self.dtype,
+                                                              optimize_model=True,
+                                                              trust_remote_code=True,
+                                                              use_cache=True)
         else:
             from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
             try:
                 model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                         load_in_low_bit=low_bit,
-                                                         torch_dtype=self.dtype,
-                                                         optimize_model=True,
-                                                         trust_remote_code=True,
-                                                         use_cache=True,)
+                                                             load_in_low_bit=low_bit,
+                                                             torch_dtype=self.dtype,
+                                                             optimize_model=True,
+                                                             trust_remote_code=True,
+                                                             use_cache=True,)
             except:
                 model = AutoModel.from_pretrained(model_path,
-                                              load_in_low_bit=low_bit,
-                                              torch_dtype=self.dtype,
-                                              optimize_model=True,
-                                              trust_remote_code=True,
-                                              use_cache=True,)
+                                                  load_in_low_bit=low_bit,
+                                                  torch_dtype=self.dtype,
+                                                  optimize_model=True,
+                                                  trust_remote_code=True,
+                                                  use_cache=True,)
         model = model.eval().to("xpu")
         return model
-
     async def add_asr_request(self, processor):
         if self.waiting_requests.empty():
             return
         tmp_result = await self.waiting_requests.get()
         request_id, request = tmp_result
         transcription_request = request.transcription_request
-        forced_decoder_ids = processor.get_decoder_prompt_ids(language=transcription_request.language, task="transcribe")
+        forced_decoder_ids = processor.get_decoder_prompt_ids(
+            language=transcription_request.language, task="transcribe")
         audio_path = transcription_request.file
         import librosa
-        raw_speech, sampling_rate = librosa.load(audio_path, sr=processor.feature_extractor.sampling_rate)
+        raw_speech, sampling_rate = librosa.load(audio_path,
+                                                 sr=processor.feature_extractor.sampling_rate)
         input_features = processor(
             raw_speech,
             sampling_rate=sampling_rate,
             return_tensors="pt",
             return_attention_mask=True,
         ).input_features.to('xpu')
-        return input_features, forced_decoder_ids, request_id,
-
+        return input_features, forced_decoder_ids, request_id
 
     async def add_request(self, tokenizer):
         if self.waiting_requests.empty():
@@ -123,15 +123,16 @@ async def add_request(self, tokenizer):
         return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
-    async def process_step(self, tokenizer, result_dict, processor = None):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         if not self.waiting_requests.empty():
             if processor is not None and "whisper" in self.model_name.lower():
-                input_features, forced_decoder_ids, request_id = await self.add_asr_request(processor)
+                input_features, decoder_ids, request_id = await self.add_asr_request(processor)
                 self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
                 def model_generate():
-                    self.model.generate(input_features,
-                                        streamer=self.streamer[request_id],
-                                        forced_decoder_ids=forced_decoder_ids)
+                    self.model.generate(input_features,
+                                        streamer=self.streamer[request_id],
+                                        forced_decoder_ids=decoder_ids)
             else:
                 input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
                 self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
diff --git a/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
index f688ff082db..ca5963af1dd 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
@@ -24,13 +24,14 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated
 from ipex_llm.utils.common import invalidInputError
+from typing_extensions import Literal
 # from vllm.sampling_params import SamplingParams
 
 
-def random_uuid() -> str:BaseModel
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
+
 
-from typing_extensions import Literal
-from fastapi import File, UploadFile
 class TranscriptionRequest(BaseModel):
     file: str = None
     model: Optional[str] = "default_model"
@@ -40,9 +41,11 @@ class TranscriptionRequest(BaseModel):
     temperature: Optional[float] = None
     timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
 
+
 class TranscriptionResponse(BaseModel):
     text: str
 
+
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index ce2aea2fac5..87167d81573 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -800,7 +800,7 @@ async def stream_output(self, cur_batch, tokenizer, next_ids):
             _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
         await asyncio.gather(*_stream_tasks)
 
-    async def process_step(self, tokenizer, result_dict, processor = None):
+    async def process_step(self, tokenizer, result_dict, processor=None):
         cur_batch = None
         torch.xpu.synchronize(self.device)
         if self.rank == 0:
From 0594bd2c5e39f01dbe873da8f128e195286ebbb3 Mon Sep 17 00:00:00 2001
From: hzjane
Date: Wed, 21 Aug 2024 09:31:40 +0800
Subject: [PATCH 4/5] update readme

---
 .../example/GPU/Lightweight-Serving/README.md | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 4cb29db1efc..60f27539fb5 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -22,6 +22,10 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 # for internlm-xcomposer2-vl-7b
 pip install transformers==4.31.0
 pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required for audio processing
 ```
 
 #### 1.2 Installation on Windows
@@ -35,6 +39,14 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
+
+# for whisper-large-v3
+pip install transformers==4.36.2
+pip install datasets soundfile librosa # required for audio processing
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
@@ -219,6 +231,17 @@ curl http://localhost:8000/v1/completions \
   }'
 ```
 
+#### v1/audio/transcriptions
+
+ASR only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) now.
+```bash
+curl http://localhost:8000/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F file="@/llm/test.mp3" \
+  -F model="whisper-large-v3" \
+  -F language="zh"
+```
+
 ### 6. Benchmark with wrk
 
 Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details
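Besides the curl call documented above, the new endpoint can be exercised from Python. The following client sketch is illustrative only — it assumes the lightweight-serving server is reachable at `localhost:8000`, that a local `test.mp3` exists, and that the `requests` package is installed:

```python
# Hypothetical client for the /v1/audio/transcriptions endpoint added in this series.
# Assumptions: server at localhost:8000, local ./test.mp3, `pip install requests`.
import requests

url = "http://localhost:8000/v1/audio/transcriptions"
with open("test.mp3", "rb") as f:
    resp = requests.post(
        url,
        files={"file": ("test.mp3", f, "audio/mpeg")},
        data={"model": "whisper-large-v3", "language": "zh"},
        timeout=300,
    )
resp.raise_for_status()
print(resp.json()["text"])  # TranscriptionResponse carries a single "text" field
```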
From c4a0a06535af4365289a5c227a34a37cb20bc0a3 Mon Sep 17 00:00:00 2001
From: hzjane
Date: Wed, 21 Aug 2024 09:59:39 +0800
Subject: [PATCH 5/5] update readme

---
 python/llm/example/GPU/Lightweight-Serving/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 60f27539fb5..c21aa880bfd 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -192,7 +192,7 @@ curl http://localhost:8000/v1/chat/completions \
 image input only supports [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) now, and it must install transformers==4.31.0 to run.
 
 ```bash
-wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+wget -O /llm/lightweight_serving/test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -233,7 +233,7 @@ curl http://localhost:8000/v1/completions \
 
 #### v1/audio/transcriptions
 
-ASR only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) now.
+ASR only supports [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) now. Currently, `whisper-large-v3` can only be used for audio transcription, and the audio file type must be one supported by `librosa.load`.
 ```bash
 curl http://localhost:8000/v1/audio/transcriptions \
   -H "Content-Type: multipart/form-data" \
   -F file="@/llm/test.mp3" \