diff --git a/llama3.1-70b-instruct-awq/.bentoignore b/llama3.1-70b-instruct-awq/.bentoignore new file mode 100644 index 0000000..d9cf115 --- /dev/null +++ b/llama3.1-70b-instruct-awq/.bentoignore @@ -0,0 +1,5 @@ +__pycache__/ +*.py[cod] +*$py.class +.ipynb_checkpoints +venv/ diff --git a/llama3.1-70b-instruct-awq/README.md b/llama3.1-70b-instruct-awq/README.md new file mode 100644 index 0000000..df155d9 --- /dev/null +++ b/llama3.1-70b-instruct-awq/README.md @@ -0,0 +1,168 @@ +
+# Self-host Llama 3.1 70B with vLLM and BentoML
+
+This is a BentoML example project, showing you how to serve and deploy Llama 3.1 70B (with AWQ quantization) using [vLLM](https://vllm.ai), a high-throughput and memory-efficient inference engine.
+
+See [here](https://github.com/bentoml/BentoML?tab=readme-ov-file#%EF%B8%8F-what-you-can-build-with-bentoml) for a full list of BentoML example projects.
+
+💡 This example serves as a basis for advanced code customization, such as custom models, inference logic, or vLLM options. For simple LLM hosting with an OpenAI-compatible endpoint and no code to write, see [OpenLLM](https://github.com/bentoml/OpenLLM).
+
+
+## Prerequisites
+
+- You have installed Python 3.8+ and `pip`. See the [Python downloads page](https://www.python.org/downloads/) to learn more.
+- You have a basic understanding of key concepts in BentoML, such as Services. We recommend you read [Quickstart](https://docs.bentoml.com/en/1.2/get-started/quickstart.html) first.
+- You have gained access to Llama 3.1 70B on [its official website](https://llama.meta.com/) and [Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct).
+- If you want to test the Service locally, you need an Nvidia GPU with at least 48 GB of VRAM.
+- (Optional) We recommend you create a virtual environment for dependency isolation for this project. See the [Conda documentation](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) or the [Python documentation](https://docs.python.org/3/library/venv.html) for details.
+
+## Install dependencies
+
+```bash
+git clone https://github.com/bentoml/BentoVLLM.git
+cd BentoVLLM/llama3.1-70b-instruct-awq
+pip install -r requirements.txt
+```
+
+## Run the BentoML Service
+
+We have defined a BentoML Service in `service.py`. Run `bentoml serve` in your project directory to start the Service.
+
+```bash
+$ bentoml serve .
+
+2024-01-18T07:51:30+0800 [INFO] [cli] Starting production HTTP BentoServer from "service:VLLM" listening on http://localhost:3000 (Press CTRL+C to quit)
+INFO 01-18 07:51:40 model_runner.py:501] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
+INFO 01-18 07:51:40 model_runner.py:505] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode.
+INFO 01-18 07:51:46 model_runner.py:547] Graph capturing finished in 6 secs.
+```
+
+The server is now active at [http://localhost:3000](http://localhost:3000/). You can interact with it using the Swagger UI or in other ways, for example:
+
+### CURL
+
+```bash
+curl -X 'POST' \
+  'http://localhost:3000/generate' \
+  -H 'accept: text/event-stream' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "prompt": "Explain superconductors like I'\''m five years old",
+  "tokens": null
+}'
+```
+
+### Python client
+
+```python
+import bentoml
+
+with bentoml.SyncHTTPClient("http://localhost:3000") as client:
+    response_generator = client.generate(
+        prompt="Explain superconductors like I'm five years old",
+        tokens=None
+    )
+    for response in response_generator:
+        print(response)
+```
+
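+The `generate` API defined in `service.py` also accepts `system_prompt` and `max_tokens` parameters (with `max_tokens` bounded between 128 and 8192). A minimal sketch of overriding both, assuming the Service is running locally on the default port; the system prompt shown here is only an illustrative value:
+
+```python
+import bentoml
+
+with bentoml.SyncHTTPClient("http://localhost:3000") as client:
+    # Override the default system prompt and cap the response length.
+    stream = client.generate(
+        prompt="Explain superconductors like I'm five years old",
+        system_prompt="You are a patient science teacher.",  # illustrative value
+        max_tokens=512,
+    )
+    for chunk in stream:
+        print(chunk, end="")
+```
+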
+### OpenAI-compatible endpoints
+
+This Service uses the `@openai_endpoints` decorator to set up OpenAI-compatible endpoints (`chat/completions` and `completions`). This means your client can interact with the backend Service (in this case, the VLLM class) as if it were communicating directly with OpenAI's API. This [utility](bentovllm_openai/) does not affect your BentoML Service code, and you can use it for other LLMs as well.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+# Use the following func to get the available models
+client.models.list()
+
+chat_completion = client.chat.completions.create(
+    model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
+    messages=[
+        {
+            "role": "user",
+            "content": "Explain superconductors like I'm five years old"
+        }
+    ],
+    stream=True,
+    stop=["<|eot_id|>", "<|end_of_text|>"],
+)
+for chunk in chat_completion:
+    # Extract and print the content of the model's reply
+    print(chunk.choices[0].delta.content or "", end="")
+```
+
+These OpenAI-compatible endpoints also support [vLLM extra parameters](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters). For example, you can force the chat completion to output a JSON object by using the `guided_json` parameter:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+# Use the following func to get the available models
+client.models.list()
+
+json_schema = {
+    "type": "object",
+    "properties": {
+        "city": {"type": "string"}
+    }
+}
+
+chat_completion = client.chat.completions.create(
+    model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    extra_body=dict(guided_json=json_schema),
+)
+print(chat_completion.choices[0].message.content)  # will return something like: {"city": "Paris"}
+```
+
+All supported extra parameters are listed in the [vLLM documentation](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters).
+
+**Note**: If your Service is deployed with [protected endpoints on BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/how-tos/manage-access-token.html#access-protected-deployments), you need to set the environment variable `OPENAI_API_KEY` to your BentoCloud API key first.
+
+```bash
+export OPENAI_API_KEY={YOUR_BENTOCLOUD_API_TOKEN}
+```
+
+You can then use the following line to replace the client in the above code snippets. Refer to [Obtain the endpoint URL](https://docs.bentoml.com/en/latest/bentocloud/how-tos/call-deployment-endpoints.html#obtain-the-endpoint-url) to retrieve the endpoint URL.
+
+```python
+client = OpenAI(base_url='your_bentocloud_deployment_endpoint_url/v1')
+```
+
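+Besides `chat/completions`, the `@openai_endpoints` decorator also exposes a plain `completions` endpoint. Below is a minimal sketch of calling it with the same OpenAI client against a local server; the prompt and `max_tokens` values are illustrative, and the model name is assumed to follow `MODEL_ID` in `service.py`:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+# Plain text completion (no chat template applied)
+completion = client.completions.create(
+    model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
+    prompt="Superconductors are materials that",
+    max_tokens=64,
+)
+print(completion.choices[0].text)
+```
+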
+ +For detailed explanations of the Service code, see [vLLM inference](https://docs.bentoml.org/en/latest/use-cases/large-language-models/vllm.html). + +## Deploy to BentoCloud + +After the Service is ready, you can deploy the application to BentoCloud for better management and scalability. [Sign up](https://www.bentoml.com/) if you haven't got a BentoCloud account. + +Make sure you have [logged in to BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/how-tos/manage-access-token.html), then run the following command to deploy it. + +```bash +bentoml deploy . +``` + +Once the application is up and running on BentoCloud, you can access it via the exposed URL. + +**Note**: For custom deployment in your own infrastructure, use [BentoML to generate an OCI-compliant image](https://docs.bentoml.com/en/latest/guides/containerization.html). diff --git a/llama3.1-70b-instruct-awq/bentofile.yaml b/llama3.1-70b-instruct-awq/bentofile.yaml new file mode 100644 index 0000000..55ebd2a --- /dev/null +++ b/llama3.1-70b-instruct-awq/bentofile.yaml @@ -0,0 +1,12 @@ +service: 'service:VLLM' +labels: + owner: bentoml-team + stage: demo +include: + - '*.py' + - 'bentovllm_openai/*.py' +python: + requirements_txt: './requirements.txt' + lock_packages: false +docker: + python_version: 3.11 diff --git a/llama3.1-70b-instruct-awq/bentovllm_openai/__init__.py b/llama3.1-70b-instruct-awq/bentovllm_openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama3.1-70b-instruct-awq/bentovllm_openai/protocol.py b/llama3.1-70b-instruct-awq/bentovllm_openai/protocol.py new file mode 100644 index 0000000..c024bbc --- /dev/null +++ b/llama3.1-70b-instruct-awq/bentovllm_openai/protocol.py @@ -0,0 +1,734 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from typing import Any, Dict, List, Literal, Optional, Union + +import torch +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Annotated + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 
+ completion_tokens: Optional[int] = 0 + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_object" or "text" + type: Literal["text", "json_object"] + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = True + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], + ChatCompletionNamedToolChoiceParam]] = "none" + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[List[Dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. 
We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "If this is not passed, the model's default chat template will be " + "used instead."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-chat-completion-extra-params + + def to_sampling_params(self) -> SamplingParams: + # We now allow logprobs being true without top_logrobs. + + logits_processors = None + if self.logit_bias: + logit_bias: Dict[int, float] = {} + try: + for token_id, bias in self.logit_bias.items(): + # Convert token_id to integer before we add to LLMEngine + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias[int(token_id)] = min(100, max(-100, bias)) + except ValueError as exc: + raise ValueError(f"Found token_id `{token_id}` in logit_bias " + f"but token_id must be an integer or string " + f"representing an integer") from exc + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=self.top_logprobs if self.echo else None, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens, + min_tokens=self.min_tokens, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, + ) + + @model_validator(mode='before') + @classmethod + def validate_stream_options(cls, values): + if (values.get('stream_options') is not None + and not 
values.get('stream')): + raise ValueError( + "stream_options can only be set if stream is true") + return values + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and "tool_choice" in data and data[ + "tool_choice"] != "none": + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_choice(cls, data): + if "tool_choice" in data and data["tool_choice"] != "none": + if not isinstance(data["tool_choice"], dict): + raise ValueError("Currently only named tools are supported.") + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "top_logprobs" in data and data["top_logprobs"] is not None: + if "logprobs" not in data or data["logprobs"] is False: + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + elif data["top_logprobs"] < 0: + raise ValueError( + "`top_logprobs` must be a value a positive value.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. 
Only {'type': 'json_object'} or {'type': 'text' } is " + "supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-completion-extra-params + + def to_sampling_params(self): + echo_without_generation = self.echo and self.max_tokens == 0 + + logits_processors = None + if self.logit_bias: + logit_bias: Dict[int, float] = {} + try: + for token_id, bias in self.logit_bias.items(): + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias[int(token_id)] = min(100, max(-100, bias)) + except ValueError as exc: + raise ValueError(f"Found token_id `{token_id}` in logit_bias " + f"but token_id must be an integer or string " + f"representing an integer") from exc + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + prompt_logprobs=self.logprobs if self.echo else None, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, + ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 
data["logprobs"] >= 0: + raise ValueError("if passed, `logprobs` must be a positive value.") + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when stream is true.") + return data + + +class EmbeddingRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class ChatMessage(OpenAIBaseModel): + role: str + content: str + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: 
str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: ChatCompletionRequest + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[ChatCompletionResponse] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. 
+ error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field(default=True) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field(default=True) + add_special_tokens: bool = Field(default=False) + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: List[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: str + tokens: List[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str diff --git a/llama3.1-70b-instruct-awq/bentovllm_openai/utils.py b/llama3.1-70b-instruct-awq/bentovllm_openai/utils.py new file mode 100644 index 0000000..1494d7e --- /dev/null +++ b/llama3.1-70b-instruct-awq/bentovllm_openai/utils.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import asyncio +import typing as t + +from _bentoml_sdk.service.factory import Service +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from .protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse + +T = t.TypeVar("T", bound=object) + +if t.TYPE_CHECKING: + from vllm import AsyncLLMEngine + +def openai_endpoints( + model_id: str, + response_role: str = "assistant", + served_model_names: t.Optional[list[str]] = None, + chat_template: t.Optional[str] = None, + chat_template_model_id: t.Optional[str] = None, + default_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None, + default_chat_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None, +): + + if served_model_names is None: + served_model_names = [model_id] + + def openai_wrapper(svc: Service[T]): + + cls = svc.inner + app = FastAPI() + + # make sure default_*_parameters are in valid format + if default_completion_parameters is not None: + assert "prompt" not in default_completion_parameters + assert CompletionRequest( + prompt="", model="", **default_completion_parameters + ) + + if default_chat_completion_parameters is not None: + assert "messages" not in default_chat_completion_parameters + assert ChatCompletionRequest( + messages=[], model="", **default_chat_completion_parameters + ) + + class new_cls(cls): + + def __init__(self): + + super().__init__() + + # we need to import bentoml before vllm so + # `prometheus_client` won't cause import troubles + # That's also why we put these codes inside class's + # `__init__` function + import bentoml + + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + + # we can do this because worker/engine_user_ray is always False for us + model_config = self.engine.engine.get_model_config() + + self.openai_serving_completion = OpenAIServingCompletion( + engine=self.engine, + served_model_names=served_model_names, + model_config=model_config, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + self.chat_template = chat_template + if self.chat_template is None and chat_template_model_id is not None: + from transformers import AutoTokenizer + _tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id) + self.chat_template = _tokenizer.chat_template + + self.openai_serving_chat = OpenAIServingChat( + engine=self.engine, + served_model_names=served_model_names, + response_role=response_role, + chat_template=self.chat_template, + 
model_config=model_config, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + @app.get("/models") + async def show_available_models(): + models = await self.openai_serving_chat.show_available_models() + return JSONResponse(content=models.model_dump()) + + @app.post("/chat/completions") + async def create_chat_completion( + request: ChatCompletionRequest, + raw_request: Request + ): + if default_chat_completion_parameters is not None: + for k, v in default_chat_completion_parameters.items(): + if k not in request.__fields_set__: + setattr(request, k, v) + generator = await self.openai_serving_chat.create_chat_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + @app.post("/completions") + async def create_completion(request: CompletionRequest, raw_request: Request): + if default_completion_parameters is not None: + for k, v in default_completion_parameters.items(): + if k not in request.__fields_set__: + setattr(request, k, v) + generator = await self.openai_serving_completion.create_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + new_cls.__name__ = "%s_OpenAI" % cls.__name__ + svc.inner = new_cls + svc.mount_asgi_app(app, path="/v1/") + return svc + + return openai_wrapper + + +# helper function to make a httpx client for BentoML service +def _make_httpx_client(url, svc): + + import httpx + from urllib.parse import urlparse + from bentoml._internal.utils.uri import uri_to_path + + timeout = svc.config["traffic"]["timeout"] + headers = {"Runner-Name": svc.name} + parsed = urlparse(url) + transport = None + target_url = url + + if parsed.scheme == "file": + uds = uri_to_path(url) + transport = httpx.HTTPTransport(uds=uds) + target_url = "http://127.0.0.1:3000" + elif parsed.scheme == "tcp": + target_url = f"http://{parsed.netloc}" + + return httpx.Client( + transport=transport, + timeout=timeout, + follow_redirects=True, + headers=headers, + ), target_url diff --git a/llama3.1-70b-instruct-awq/requirements.txt b/llama3.1-70b-instruct-awq/requirements.txt new file mode 100644 index 0000000..24e6fcb --- /dev/null +++ b/llama3.1-70b-instruct-awq/requirements.txt @@ -0,0 +1,7 @@ +accelerate==0.29.3 +autoawq==0.2.5 +bentoml>=1.2.17 +bitsandbytes==0.43.1 +fastapi==0.111.1 +openai==1.35.14 +vllm==0.5.3.post1; sys_platform == "linux" diff --git a/llama3.1-70b-instruct-awq/service.py b/llama3.1-70b-instruct-awq/service.py new file mode 100644 index 0000000..0d4c961 --- /dev/null +++ b/llama3.1-70b-instruct-awq/service.py @@ -0,0 +1,86 @@ +import uuid +from typing import AsyncGenerator, Optional + +import bentoml +from annotated_types import Ge, Le +from typing_extensions import Annotated + +from bentovllm_openai.utils import openai_endpoints + + +MAX_TOKENS = 8192 +SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" + +PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|> + +{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +""" + +MODEL_ID = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4" + + +@openai_endpoints( + model_id=MODEL_ID, + default_chat_completion_parameters=dict(stop=["<|eot_id|>"]), +) +@bentoml.service( + name="bentovllm-llama3.1-70b-insruct-awq-service", + traffic={ + "timeout": 1200, + "concurrency": 256, # Matches the default max_num_seqs in the VLLM engine + }, + resources={ + "gpu": 1, + "gpu_type": "nvidia-a100-80gb", + }, +) +class VLLM: + + def __init__(self) -> None: + from transformers import AutoTokenizer + from vllm import AsyncEngineArgs, AsyncLLMEngine + + ENGINE_ARGS = AsyncEngineArgs( + model=MODEL_ID, + max_model_len=MAX_TOKENS, + quantization="AWQ", + enable_prefix_caching=True, + ) + + self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + self.stop_token_ids = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ] + + @bentoml.api + async def generate( + self, + prompt: str = "Explain superconductors in plain English", + system_prompt: Optional[str] = SYSTEM_PROMPT, + max_tokens: Annotated[int, Ge(128), Le(MAX_TOKENS)] = MAX_TOKENS, + ) -> AsyncGenerator[str, None]: + from vllm import SamplingParams + + SAMPLING_PARAM = SamplingParams( + max_tokens=max_tokens, + stop_token_ids=self.stop_token_ids, + ) + + if system_prompt is None: + system_prompt = SYSTEM_PROMPT + prompt = PROMPT_TEMPLATE.format(user_prompt=prompt, system_prompt=system_prompt) + stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM) + + cursor = 0 + async for request_output in stream: + text = request_output.outputs[0].text + yield text[cursor:] + cursor = len(text) diff --git a/llama3.1-8b-instruct/.bentoignore b/llama3.1-8b-instruct/.bentoignore new file mode 100644 index 0000000..d9cf115 --- /dev/null +++ b/llama3.1-8b-instruct/.bentoignore @@ -0,0 +1,5 @@ +__pycache__/ +*.py[cod] +*$py.class +.ipynb_checkpoints +venv/ diff --git a/llama3.1-8b-instruct/README.md b/llama3.1-8b-instruct/README.md new file mode 100644 index 0000000..10dfb92 --- /dev/null +++ b/llama3.1-8b-instruct/README.md @@ -0,0 +1,167 @@ +
+# Self-host Llama 3.1 8B with vLLM and BentoML
+
+This is a BentoML example project, showing you how to serve and deploy Llama 3.1 8B using [vLLM](https://vllm.ai), a high-throughput and memory-efficient inference engine.
+
+See [here](https://github.com/bentoml/BentoML?tab=readme-ov-file#%EF%B8%8F-what-you-can-build-with-bentoml) for a full list of BentoML example projects.
+
+💡 This example serves as a basis for advanced code customization, such as custom models, inference logic, or vLLM options. For simple LLM hosting with an OpenAI-compatible endpoint and no code to write, see [OpenLLM](https://github.com/bentoml/OpenLLM).
+
+
+## Prerequisites
+
+- You have installed Python 3.8+ and `pip`. See the [Python downloads page](https://www.python.org/downloads/) to learn more.
+- You have a basic understanding of key concepts in BentoML, such as Services. We recommend you read [Quickstart](https://docs.bentoml.com/en/1.2/get-started/quickstart.html) first.
+- You have gained access to Llama 3.1 8B on [its official website](https://llama.meta.com/) and [Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).
+- If you want to test the Service locally, you need an Nvidia GPU with at least 16 GB of VRAM.
+- (Optional) We recommend you create a virtual environment for dependency isolation for this project. See the [Conda documentation](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) or the [Python documentation](https://docs.python.org/3/library/venv.html) for details.
+
+## Install dependencies
+
+```bash
+git clone https://github.com/bentoml/BentoVLLM.git
+cd BentoVLLM/llama3.1-8b-instruct
+pip install -r requirements.txt
+```
+
+## Run the BentoML Service
+
+We have defined a BentoML Service in `service.py`. Run `bentoml serve` in your project directory to start the Service.
+
+```bash
+$ bentoml serve .
+
+2024-01-18T07:51:30+0800 [INFO] [cli] Starting production HTTP BentoServer from "service:VLLM" listening on http://localhost:3000 (Press CTRL+C to quit)
+INFO 01-18 07:51:40 model_runner.py:501] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
+INFO 01-18 07:51:40 model_runner.py:505] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode.
+INFO 01-18 07:51:46 model_runner.py:547] Graph capturing finished in 6 secs.
+```
+
+The server is now active at [http://localhost:3000](http://localhost:3000/). You can interact with it using the Swagger UI or in other ways, for example:
+
+### CURL
+
+```bash
+curl -X 'POST' \
+  'http://localhost:3000/generate' \
+  -H 'accept: text/event-stream' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "prompt": "Explain superconductors like I'\''m five years old",
+  "tokens": null
+}'
+```
+
+### Python client
+
+```python
+import bentoml
+
+with bentoml.SyncHTTPClient("http://localhost:3000") as client:
+    response_generator = client.generate(
+        prompt="Explain superconductors like I'm five years old",
+        tokens=None
+    )
+    for response in response_generator:
+        print(response)
+```
+
+### OpenAI-compatible endpoints
+
+This Service uses the `@openai_endpoints` decorator to set up OpenAI-compatible endpoints (`chat/completions` and `completions`). This means your client can interact with the backend Service (in this case, the VLLM class) as if it were communicating directly with OpenAI's API. This [utility](bentovllm_openai/) does not affect your BentoML Service code, and you can use it for other LLMs as well.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+# Use the following func to get the available models
+client.models.list()
+
+chat_completion = client.chat.completions.create(
+    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "Explain superconductors like I'm five years old"
+        }
+    ],
+    stream=True,
+)
+for chunk in chat_completion:
+    # Extract and print the content of the model's reply
+    print(chunk.choices[0].delta.content or "", end="")
+```
+
+These OpenAI-compatible endpoints also support [vLLM extra parameters](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters). For example, you can force the chat completion to output a JSON object by using the `guided_json` parameter:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+# Use the following func to get the available models
+client.models.list()
+
+json_schema = {
+    "type": "object",
+    "properties": {
+        "city": {"type": "string"}
+    }
+}
+
+chat_completion = client.chat.completions.create(
+    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    extra_body=dict(guided_json=json_schema),
+)
+print(chat_completion.choices[0].message.content)  # will return something like: {"city": "Paris"}
+```
+
+All supported extra parameters are listed in the [vLLM documentation](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters); a `guided_choice` sketch is included at the end of this section.
+
+**Note**: If your Service is deployed with [protected endpoints on BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/how-tos/manage-access-token.html#access-protected-deployments), you need to set the environment variable `OPENAI_API_KEY` to your BentoCloud API key first.
+
+```bash
+export OPENAI_API_KEY={YOUR_BENTOCLOUD_API_TOKEN}
+```
+
+You can then use the following line to replace the client in the above code snippets. Refer to [Obtain the endpoint URL](https://docs.bentoml.com/en/latest/bentocloud/how-tos/call-deployment-endpoints.html#obtain-the-endpoint-url) to retrieve the endpoint URL.
+
+```python
+client = OpenAI(base_url='your_bentocloud_deployment_endpoint_url/v1')
+```
+
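+As referenced above, the `guided_choice` extra parameter constrains the output to one of a fixed set of strings. A minimal sketch against a local server; the question and choices are illustrative only:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
+
+chat_completion = client.chat.completions.create(
+    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "Is the sky blue on a clear day? Answer yes or no."
+        }
+    ],
+    extra_body=dict(guided_choice=["yes", "no"]),
+)
+print(chat_completion.choices[0].message.content)  # prints "yes" or "no"
+```
+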
+ +For detailed explanations of the Service code, see [vLLM inference](https://docs.bentoml.org/en/latest/use-cases/large-language-models/vllm.html). + +## Deploy to BentoCloud + +After the Service is ready, you can deploy the application to BentoCloud for better management and scalability. [Sign up](https://www.bentoml.com/) if you haven't got a BentoCloud account. + +Make sure you have [logged in to BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/how-tos/manage-access-token.html), then run the following command to deploy it. + +```bash +bentoml deploy --env HF_TOKEN= . +``` + +Once the application is up and running on BentoCloud, you can access it via the exposed URL. + +**Note**: For custom deployment in your own infrastructure, use [BentoML to generate an OCI-compliant image](https://docs.bentoml.com/en/latest/guides/containerization.html). diff --git a/llama3.1-8b-instruct/bentofile.yaml b/llama3.1-8b-instruct/bentofile.yaml new file mode 100644 index 0000000..29f4736 --- /dev/null +++ b/llama3.1-8b-instruct/bentofile.yaml @@ -0,0 +1,14 @@ +service: 'service:VLLM' +labels: + owner: bentoml-team + stage: demo +include: + - '*.py' + - 'bentovllm_openai/*.py' +python: + requirements_txt: './requirements.txt' + lock_packages: false +envs: + - name: HF_TOKEN +docker: + python_version: 3.11 diff --git a/llama3.1-8b-instruct/bentovllm_openai/__init__.py b/llama3.1-8b-instruct/bentovllm_openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llama3.1-8b-instruct/bentovllm_openai/protocol.py b/llama3.1-8b-instruct/bentovllm_openai/protocol.py new file mode 100644 index 0000000..c024bbc --- /dev/null +++ b/llama3.1-8b-instruct/bentovllm_openai/protocol.py @@ -0,0 +1,734 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from typing import Any, Dict, List, Literal, Optional, Union + +import torch +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Annotated + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 
0 + completion_tokens: Optional[int] = 0 + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_object" or "text" + type: Literal["text", "json_object"] + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = True + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], + ChatCompletionNamedToolChoiceParam]] = "none" + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[List[Dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. 
We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "If this is not passed, the model's default chat template will be " + "used instead."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-chat-completion-extra-params + + def to_sampling_params(self) -> SamplingParams: + # We now allow logprobs being true without top_logrobs. + + logits_processors = None + if self.logit_bias: + logit_bias: Dict[int, float] = {} + try: + for token_id, bias in self.logit_bias.items(): + # Convert token_id to integer before we add to LLMEngine + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias[int(token_id)] = min(100, max(-100, bias)) + except ValueError as exc: + raise ValueError(f"Found token_id `{token_id}` in logit_bias " + f"but token_id must be an integer or string " + f"representing an integer") from exc + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=self.top_logprobs if self.echo else None, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens, + min_tokens=self.min_tokens, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, + ) + + @model_validator(mode='before') + @classmethod + def validate_stream_options(cls, values): + if (values.get('stream_options') is not None + and not 
values.get('stream')):
+            raise ValueError(
+                "stream_options can only be set if stream is true")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        # you can only use one kind of guided decoding
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        # you can only either use guided decoding or tools, not both
+        if guide_count > 0 and "tool_choice" in data and data[
+                "tool_choice"] != "none":
+            raise ValueError(
+                "You can only either use guided decoding or tools, not both.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tool_choice(cls, data):
+        if "tool_choice" in data and data["tool_choice"] != "none":
+            if not isinstance(data["tool_choice"], dict):
+                raise ValueError("Currently only named tools are supported.")
+            if "tools" not in data or data["tools"] is None:
+                raise ValueError(
+                    "When using `tool_choice`, `tools` must be set.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif data["top_logprobs"] < 0:
+                raise ValueError(
+                    "`top_logprobs` must be a non-negative value.")
+        return data
+
+
+class CompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # doc: begin-completion-sampling-params
+    use_beam_search: bool = False
+    top_k: int = -1
+    min_p: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    min_tokens: int = 0
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    # doc: end-completion-sampling-params
+
+    # doc: begin-completion-extra-params
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. 
Only {'type': 'json_object'} or {'type': 'text' } is " + "supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-completion-extra-params + + def to_sampling_params(self): + echo_without_generation = self.echo and self.max_tokens == 0 + + logits_processors = None + if self.logit_bias: + logit_bias: Dict[int, float] = {} + try: + for token_id, bias in self.logit_bias.items(): + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias[int(token_id)] = min(100, max(-100, bias)) + except ValueError as exc: + raise ValueError(f"Found token_id `{token_id}` in logit_bias " + f"but token_id must be an integer or string " + f"representing an integer") from exc + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + prompt_logprobs=self.logprobs if self.echo else None, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, + ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 
data["logprobs"] >= 0: + raise ValueError("if passed, `logprobs` must be a positive value.") + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when stream is true.") + return data + + +class EmbeddingRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class ChatMessage(OpenAIBaseModel): + role: str + content: str + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: 
str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: ChatCompletionRequest + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[ChatCompletionResponse] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. 
+ error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field(default=True) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field(default=True) + add_special_tokens: bool = Field(default=False) + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: List[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: str + tokens: List[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str diff --git a/llama3.1-8b-instruct/bentovllm_openai/utils.py b/llama3.1-8b-instruct/bentovllm_openai/utils.py new file mode 100644 index 0000000..1494d7e --- /dev/null +++ b/llama3.1-8b-instruct/bentovllm_openai/utils.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import asyncio +import typing as t + +from _bentoml_sdk.service.factory import Service +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from .protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse + +T = t.TypeVar("T", bound=object) + +if t.TYPE_CHECKING: + from vllm import AsyncLLMEngine + +def openai_endpoints( + model_id: str, + response_role: str = "assistant", + served_model_names: t.Optional[list[str]] = None, + chat_template: t.Optional[str] = None, + chat_template_model_id: t.Optional[str] = None, + default_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None, + default_chat_completion_parameters: t.Optional[t.Dict[str, t.Any]] = None, +): + + if served_model_names is None: + served_model_names = [model_id] + + def openai_wrapper(svc: Service[T]): + + cls = svc.inner + app = FastAPI() + + # make sure default_*_parameters are in valid format + if default_completion_parameters is not None: + assert "prompt" not in default_completion_parameters + assert CompletionRequest( + prompt="", model="", **default_completion_parameters + ) + + if default_chat_completion_parameters is not None: + assert "messages" not in default_chat_completion_parameters + assert ChatCompletionRequest( + messages=[], model="", **default_chat_completion_parameters + ) + + class new_cls(cls): + + def __init__(self): + + super().__init__() + + # we need to import bentoml before vllm so + # `prometheus_client` won't cause import troubles + # That's also why we put these codes inside class's + # `__init__` function + import bentoml + + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + + # we can do this because worker/engine_user_ray is always False for us + model_config = self.engine.engine.get_model_config() + + self.openai_serving_completion = OpenAIServingCompletion( + engine=self.engine, + served_model_names=served_model_names, + model_config=model_config, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + self.chat_template = chat_template + if self.chat_template is None and chat_template_model_id is not None: + from transformers import AutoTokenizer + _tokenizer = AutoTokenizer.from_pretrained(chat_template_model_id) + self.chat_template = _tokenizer.chat_template + + self.openai_serving_chat = OpenAIServingChat( + engine=self.engine, + served_model_names=served_model_names, + response_role=response_role, + chat_template=self.chat_template, + 
model_config=model_config, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + @app.get("/models") + async def show_available_models(): + models = await self.openai_serving_chat.show_available_models() + return JSONResponse(content=models.model_dump()) + + @app.post("/chat/completions") + async def create_chat_completion( + request: ChatCompletionRequest, + raw_request: Request + ): + if default_chat_completion_parameters is not None: + for k, v in default_chat_completion_parameters.items(): + if k not in request.__fields_set__: + setattr(request, k, v) + generator = await self.openai_serving_chat.create_chat_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + @app.post("/completions") + async def create_completion(request: CompletionRequest, raw_request: Request): + if default_completion_parameters is not None: + for k, v in default_completion_parameters.items(): + if k not in request.__fields_set__: + setattr(request, k, v) + generator = await self.openai_serving_completion.create_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + new_cls.__name__ = "%s_OpenAI" % cls.__name__ + svc.inner = new_cls + svc.mount_asgi_app(app, path="/v1/") + return svc + + return openai_wrapper + + +# helper function to make a httpx client for BentoML service +def _make_httpx_client(url, svc): + + import httpx + from urllib.parse import urlparse + from bentoml._internal.utils.uri import uri_to_path + + timeout = svc.config["traffic"]["timeout"] + headers = {"Runner-Name": svc.name} + parsed = urlparse(url) + transport = None + target_url = url + + if parsed.scheme == "file": + uds = uri_to_path(url) + transport = httpx.HTTPTransport(uds=uds) + target_url = "http://127.0.0.1:3000" + elif parsed.scheme == "tcp": + target_url = f"http://{parsed.netloc}" + + return httpx.Client( + transport=transport, + timeout=timeout, + follow_redirects=True, + headers=headers, + ), target_url diff --git a/llama3.1-8b-instruct/requirements.txt b/llama3.1-8b-instruct/requirements.txt new file mode 100644 index 0000000..e2de545 --- /dev/null +++ b/llama3.1-8b-instruct/requirements.txt @@ -0,0 +1,5 @@ +accelerate==0.29.3 +bentoml>=1.2.20 +fastapi==0.111.1 +openai==1.35.14 +vllm==0.5.3.post1; sys_platform == "linux" diff --git a/llama3.1-8b-instruct/service.py b/llama3.1-8b-instruct/service.py new file mode 100644 index 0000000..34d5ae6 --- /dev/null +++ b/llama3.1-8b-instruct/service.py @@ -0,0 +1,83 @@ +import uuid +from typing import AsyncGenerator, Optional + +import bentoml +from annotated_types import Ge, Le +from typing_extensions import Annotated + +from bentovllm_openai.utils import openai_endpoints + + +MAX_TOKENS = 1024 +SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+"""
+
+MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+@openai_endpoints(
+    model_id=MODEL_ID,
+    default_chat_completion_parameters=dict(stop=["<|eot_id|>"]),
+)
+@bentoml.service(
+    name="bentovllm-llama3.1-8b-instruct-service",
+    traffic={
+        "timeout": 300,
+        "concurrency": 256,  # Matches the default max_num_seqs in the vLLM engine
+    },
+    resources={
+        "gpu": 1,
+        "gpu_type": "nvidia-l4",
+    },
+)
+class VLLM:
+
+    def __init__(self) -> None:
+        from transformers import AutoTokenizer
+        from vllm import AsyncEngineArgs, AsyncLLMEngine
+
+        ENGINE_ARGS = AsyncEngineArgs(
+            model=MODEL_ID,
+            max_model_len=MAX_TOKENS,
+            enable_prefix_caching=True
+        )
+
+        self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
+
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        self.stop_token_ids = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+        ]
+
+    @bentoml.api
+    async def generate(
+        self,
+        prompt: str = "Explain superconductors in plain English",
+        system_prompt: Optional[str] = SYSTEM_PROMPT,
+        max_tokens: Annotated[int, Ge(128), Le(MAX_TOKENS)] = MAX_TOKENS,
+    ) -> AsyncGenerator[str, None]:
+        from vllm import SamplingParams
+
+        SAMPLING_PARAM = SamplingParams(
+            max_tokens=max_tokens, stop_token_ids=self.stop_token_ids,
+        )
+
+        if system_prompt is None:
+            system_prompt = SYSTEM_PROMPT
+        prompt = PROMPT_TEMPLATE.format(user_prompt=prompt, system_prompt=system_prompt)
+        stream = await self.engine.add_request(uuid.uuid4().hex, prompt, SAMPLING_PARAM)
+
+        cursor = 0
+        async for request_output in stream:
+            text = request_output.outputs[0].text
+            yield text[cursor:]
+            cursor = len(text)
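+
+# The @openai_endpoints decorator above also mounts OpenAI-compatible routes
+# (/v1/models, /v1/chat/completions, /v1/completions) on this Service via
+# bentovllm_openai/utils.py. A minimal client sketch, assuming the server is
+# running locally on port 3000; the api_key value is only a placeholder, since
+# no authentication is configured in this example:
+#
+#   from openai import OpenAI
+#
+#   client = OpenAI(base_url="http://localhost:3000/v1", api_key="n/a")
+#   stream = client.chat.completions.create(
+#       model=MODEL_ID,  # the served model name defaults to the model id
+#       messages=[{"role": "user", "content": "Explain superconductors in plain English"}],
+#       stream=True,
+#   )
+#   for chunk in stream:
+#       print(chunk.choices[0].delta.content or "", end="")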