From fabc395d0d6c9b60b0da7e2da71a87a98523c766 Mon Sep 17 00:00:00 2001 From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Date: Fri, 24 May 2024 17:19:27 +0800 Subject: [PATCH] add langchain vllm interface (#11121) * done * fix * fix * add vllm * add langchain vllm exampels * add docs * temp --- .../DockerGuides/vllm_docker_quickstart.md | 2 +- python/llm/example/GPU/LangChain/README.md | 93 +++++-- python/llm/example/GPU/LangChain/vllm.py | 45 ++++ .../src/ipex_llm/langchain/vllm/__init__.py | 20 ++ .../llm/src/ipex_llm/langchain/vllm/vllm.py | 229 ++++++++++++++++++ 5 files changed, 371 insertions(+), 18 deletions(-) create mode 100644 python/llm/example/GPU/LangChain/vllm.py create mode 100644 python/llm/src/ipex_llm/langchain/vllm/__init__.py create mode 100644 python/llm/src/ipex_llm/langchain/vllm/vllm.py diff --git a/docs/readthedocs/source/doc/LLM/DockerGuides/vllm_docker_quickstart.md b/docs/readthedocs/source/doc/LLM/DockerGuides/vllm_docker_quickstart.md index 80f9ba657ee..56776ca9974 100644 --- a/docs/readthedocs/source/doc/LLM/DockerGuides/vllm_docker_quickstart.md +++ b/docs/readthedocs/source/doc/LLM/DockerGuides/vllm_docker_quickstart.md @@ -82,7 +82,7 @@ If the service have booted successfully, you should see the output similar to th vLLM supports to utilize multiple cards through tensor parallel. -You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-paralle) on how to utilize the `tensor-parallel` feature and start the service. +You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-parallel) on how to utilize the `tensor-parallel` feature and start the service. #### Verify After the service has been booted successfully, you can send a test request using `curl`. Here, `YOUR_MODEL` should be set equal to `served_model_name` in your booting script, e.g. `Qwen1.5`. diff --git a/python/llm/example/GPU/LangChain/README.md b/python/llm/example/GPU/LangChain/README.md index ea5638cf806..9731e715de3 100644 --- a/python/llm/example/GPU/LangChain/README.md +++ b/python/llm/example/GPU/LangChain/README.md @@ -5,15 +5,7 @@ The examples in this folder shows how to use [LangChain](https://www.langchain.c ### 1. Install ipex-llm Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm -### 2. Install Required Dependencies for langchain examples. - -```bash -pip install langchain==0.0.184 -pip install -U chromadb==0.3.25 -pip install -U pandas==2.0.3 -``` - -### 3. Configures OneAPI environment variables for Linux +### 2. Configures OneAPI environment variables for Linux > [!NOTE] > Skip this step if you are running on Windows. @@ -24,9 +16,9 @@ This is a required step on Linux for APT or offline installed oneAPI. Skip this source /opt/intel/oneapi/setvars.sh ``` -### 4. Runtime Configurations +### 3. Runtime Configurations For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device. -#### 4.1 Configurations for Linux +#### 3.1 Configurations for Linux
For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
@@ -63,7 +55,7 @@ export BIGDL_LLM_XMX_DISABLED=1
</details>
-#### 4.2 Configurations for Windows +#### 3.2 Configurations for Windows
For Intel iGPU
@@ -88,9 +80,18 @@ set SYCL_CACHE_PERSISTENT=1
 
 > [!NOTE]
 > For the first time that each model runs on Intel iGPU/Intel Arc™ A300-Series or Pro A60, it may take several minutes to compile.
 
-### 5. Run the examples
+### 4. Run the examples
+
+#### 4.1. Streaming Chat
 
-#### 5.1. Streaming Chat
+Install dependencies:
+
+```bash
+pip install langchain==0.0.184
+pip install -U pandas==2.0.3
+```
+
+Then execute:
 
 ```bash
 python chat.py -m MODEL_PATH -q QUESTION
@@ -99,7 +100,16 @@ arguments info:
 - `-m MODEL_PATH`: **required**, path to the model
 - `-q QUESTION`: question to ask. Default is `What is AI?`.
 
-#### 5.2. RAG (Retrival Augmented Generation)
+#### 4.2. RAG (Retrieval Augmented Generation)
+
+Install dependencies:
+```bash
+pip install langchain==0.0.184
+pip install -U chromadb==0.3.25
+pip install -U pandas==2.0.3
+```
+
+Then execute:
 
 ```bash
 python rag.py -m MODEL_PATH [-q QUESTION] [-i INPUT_PATH]
@@ -110,16 +120,65 @@ arguments info:
 - `-m MODEL_PATH`: **required**, path to the model
 - `-q QUESTION`: question to ask. Default is `What is AI?`.
 - `-i INPUT_PATH`: path to the input doc.
 
-#### 5.2. Low Bit
+#### 4.3. Low Bit
 
 The low_bit example ([low_bit.py](./low_bit.py)) showcases how to use langchain with a low_bit optimized model.
 By `save_low_bit` we save the weights of the low_bit model into the target folder.
 > Note: `save_low_bit` only saves the weights of the model.
 > Users could copy the tokenizer model into the target folder or specify `tokenizer_id` during initialization.
+
+Install dependencies:
+```bash
+pip install langchain==0.0.184
+pip install -U pandas==2.0.3
+```
+Then execute:
+
 ```bash
 python low_bit.py -m MODEL_PATH -t TARGET_PATH [-q QUESTION]
 ```
 **Runtime Arguments Explained**:
 - `-m MODEL_PATH`: **Required**, the path to the model
 - `-t TARGET_PATH`: **Required**, the path to save the low_bit model
-- `-q QUESTION`: the question
\ No newline at end of file
+- `-q QUESTION`: the question
+
+#### 4.4 vLLM
+
+The vLLM example ([vllm.py](./vllm.py)) showcases how to use LangChain with the ipex-llm integrated vLLM engine.
+
+Install dependencies:
+```bash
+pip install "langchain<0.2"
+```
+
+Besides, you should also install the IPEX-LLM integrated vLLM by following the instructions listed [here](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#install-vllm).
+
+**Runtime Arguments Explained**:
+- `-m MODEL_PATH`: **Required**, the path to the model
+- `-q QUESTION`: the question
+- `-t MAX_TOKENS`: max tokens to generate, default 128
+- `-p TENSOR_PARALLEL_SIZE`: number of cards to use for tensor-parallel generation
+- `-l LOAD_IN_LOW_BIT`: low-bit format for quantization
+
+##### Single card
+
+The following command shows how to run the example on a single card:
+
+```bash
+python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 1 -l sym_int4
+```
+
+##### Multiple cards
+
+To use the `-p TENSOR_PARALLEL_SIZE` option, you will need to use our docker image `intelanalytics/ipex-llm-serving-xpu:latest`. For how to use the image, check this [guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/DockerGuides/vllm_docker_quickstart.html#multi-card-serving).
+
+The following command shows how to run the example on two cards:
+
+```bash
+export CCL_WORKER_COUNT=2
+export FI_PROVIDER=shm
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+export CCL_ATL_SHM=1
+python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 2 -l sym_int4
+```
\ No newline at end of file
diff --git a/python/llm/example/GPU/LangChain/vllm.py b/python/llm/example/GPU/LangChain/vllm.py
new file mode 100644
index 00000000000..27084d45cc6
--- /dev/null
+++ b/python/llm/example/GPU/LangChain/vllm.py
@@ -0,0 +1,45 @@
+from ipex_llm.langchain.vllm.vllm import VLLM
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+import argparse
+
+def main(args):
+    llm = VLLM(
+        model=args.model_path,
+        trust_remote_code=True,  # mandatory for hf models
+        max_new_tokens=args.max_tokens,
+        top_k=10,
+        top_p=0.95,
+        temperature=0.8,
+        max_model_len=2048,
+        enforce_eager=True,
+        load_in_low_bit=args.load_in_low_bit,
+        device="xpu",
+        tensor_parallel_size=args.tensor_parallel_size,
+    )
+
+    print(llm.invoke(args.question))
+
+    template = """Question: {question}
+
+    Answer: Let's think step by step."""
+    prompt = PromptTemplate.from_template(template)
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+
+    print(llm_chain.invoke("Who was the US president in the year the first Pokemon game was released?"))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='LangChain integrated vLLM example')
+    parser.add_argument('-m', '--model-path', type=str, required=True,
+                        help='the path to the transformers model')
+    parser.add_argument('-q', '--question', type=str, default='What is the capital of France?', help='question you want to ask.')
+    parser.add_argument('-t', '--max-tokens', type=int, default=128, help='max tokens to generate')
+    parser.add_argument('-p', '--tensor-parallel-size', type=int, default=1, help="vLLM tensor parallel size")
+    parser.add_argument('-l', '--load-in-low-bit', type=str, default='sym_int4', help="low bit format")
+    args = parser.parse_args()
+
+    main(args)
+
diff --git a/python/llm/src/ipex_llm/langchain/vllm/__init__.py b/python/llm/src/ipex_llm/langchain/vllm/__init__.py
new file mode 100644
index 00000000000..dbdafd2a8c2
--- /dev/null
+++ b/python/llm/src/ipex_llm/langchain/vllm/__init__.py
@@ -0,0 +1,20 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
diff --git a/python/llm/src/ipex_llm/langchain/vllm/vllm.py b/python/llm/src/ipex_llm/langchain/vllm/vllm.py
new file mode 100644
index 00000000000..c4b2b4a7147
--- /dev/null
+++ b/python/llm/src/ipex_llm/langchain/vllm/vllm.py
@@ -0,0 +1,229 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is adapted from +# https://github.com/hwchase17/langchain/blob/master/langchain/llms/llamacpp.py + +# The MIT License + +# Copyright (c) Harrison Chase + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from typing import Any, Dict, List, Optional + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import Generation, LLMResult +from langchain_core.pydantic_v1 import Field, root_validator + +from langchain_community.llms.openai import BaseOpenAI +from langchain_community.utils.openai import is_openai_v1 + + +class VLLM(BaseLLM): + """VLLM language model.""" + + model: str = "" + """The name or path of a HuggingFace Transformers model.""" + + tensor_parallel_size: Optional[int] = 1 + """The number of GPUs to use for distributed execution with tensor parallelism.""" + + trust_remote_code: Optional[bool] = False + """Trust remote code (e.g., from HuggingFace) when downloading the model + and tokenizer.""" + + n: int = 1 + """Number of output sequences to return for the given prompt.""" + + best_of: Optional[int] = None + """Number of output sequences that are generated from the prompt.""" + + presence_penalty: float = 0.0 + """Float that penalizes new tokens based on whether they appear in the + generated text so far""" + + frequency_penalty: float = 0.0 + """Float that penalizes new tokens based on their frequency in the + generated text so far""" + + temperature: float = 1.0 + """Float that controls the randomness of the sampling.""" + + top_p: float = 1.0 + """Float that controls the cumulative probability of the top tokens to consider.""" + + top_k: int = -1 + """Integer that controls the number of top tokens to consider.""" + + use_beam_search: bool = False + """Whether to use beam search instead of sampling.""" + + stop: Optional[List[str]] = None + """List of strings that stop the generation when they are generated.""" + + ignore_eos: bool = False + """Whether to ignore the EOS token and continue generating tokens after + the EOS token is generated.""" + + 
max_new_tokens: int = 512
+    """Maximum number of tokens to generate per output sequence."""
+
+    logprobs: Optional[int] = None
+    """Number of log probabilities to return per output token."""
+
+    dtype: str = "auto"
+    """The data type for the model weights and activations."""
+
+    download_dir: Optional[str] = None
+    """Directory to download and load the weights. (Default to the default
+    cache dir of huggingface)"""
+
+    vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""
+
+    load_in_low_bit: str = "sym_int4"
+    """Load in low bit format for ipex-llm low-bit quantization."""
+
+    device: str = "xpu"
+    """The device to run the model on, e.g. ``"xpu"``."""
+
+    enforce_eager: bool = True
+    """Whether to always use eager-mode execution in vLLM."""
+
+    client: Any  #: :meta private:
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the ipex-llm vLLM engine is available in the environment."""
+        try:
+            # Use IPEX-LLM's engine class in place of the upstream `vllm.LLM`.
+            from ipex_llm.vllm.engine import IPEXLLMClass as VLLModel
+        except ImportError:
+            raise ImportError(
+                "Could not import the IPEX-LLM vLLM engine. "
+                "Please install vLLM with IPEX-LLM support by following "
+                "https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#install-vllm"
+            )
+
+        values["client"] = VLLModel(
+            model=values["model"],
+            tensor_parallel_size=values["tensor_parallel_size"],
+            trust_remote_code=values["trust_remote_code"],
+            dtype=values["dtype"],
+            download_dir=values["download_dir"],
+            load_in_low_bit=values["load_in_low_bit"],
+            device=values["device"],
+            enforce_eager=values["enforce_eager"],
+            **values["vllm_kwargs"],
+        )
+
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling vllm."""
+        return {
+            "n": self.n,
+            "best_of": self.best_of,
+            "max_tokens": self.max_new_tokens,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "stop": self.stop,
+            "ignore_eos": self.ignore_eos,
+            "use_beam_search": self.use_beam_search,
+            "logprobs": self.logprobs,
+        }
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+
+        from vllm import SamplingParams
+
+        # build sampling parameters
+        params = {**self._default_params, **kwargs, "stop": stop}
+        sampling_params = SamplingParams(**params)
+        # call the model
+        outputs = self.client.generate(prompts, sampling_params)
+
+        generations = []
+        for output in outputs:
+            text = output.outputs[0].text
+            generations.append([Generation(text=text)])
+
+        return LLMResult(generations=generations)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "vllm"
+
+
+class VLLMOpenAI(BaseOpenAI):
+    """vLLM OpenAI-compatible API client."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return False
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+
+        params: Dict[str, Any] = {
+            "model": self.model_name,
+            **self._default_params,
+            "logit_bias": None,
+        }
+        if not is_openai_v1():
+            params.update(
+                {
+                    "api_key": self.openai_api_key,
+                    "api_base": self.openai_api_base,
+                }
+            )
+
+        return params
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "vllm-openai"
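
The patch also introduces `VLLMOpenAI`, but the bundled example only exercises the in-process `VLLM` class. Below is a minimal usage sketch (not part of the patch) of how the OpenAI-compatible client could query an IPEX-LLM vLLM service such as the one booted in the Docker quickstart; the endpoint URL, API key, and served model name are placeholders to be replaced with the values of your own deployment.

```python
# Minimal sketch, assuming an ipex-llm vLLM service with an OpenAI-compatible
# endpoint is already running. The URL, key, and model name below are
# placeholders; model_name must match the `served_model_name` used at boot time.
from ipex_llm.langchain.vllm.vllm import VLLMOpenAI

llm = VLLMOpenAI(
    openai_api_key="EMPTY",                      # vLLM servers typically ignore the key
    openai_api_base="http://localhost:8000/v1",  # placeholder service address
    model_name="Qwen1.5",                        # placeholder served model name
    max_tokens=128,
    temperature=0.8,
)

print(llm.invoke("What is AI?"))
```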