[Neuron] Adding support for adding/overriding neuron configuration a… #8062
Merged
New file: offline inference example with Neuron int8 quantization
@@ -0,0 +1,50 @@
import os

from vllm import LLM, SamplingParams

# Creates XLA HLO graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# Creates XLA HLO graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
# Quantizes the neuron model weights to int8;
# the default quantization config uses the int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8"

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be the same
    # as the max sequence length when targeting the neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=2048,
    block_size=2048,
    # The device can be automatically detected when the AWS Neuron SDK is
    # installed. The device argument can be either unspecified for automated
    # detection, or explicitly assigned.
    device="neuron",
    quantization="neuron_quant",
    override_neuron_config={
        "cast_logits_dtype": "bfloat16",
    },
    tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
New file: NeuronQuantConfig, a quantization config for the Neuron backend
@@ -0,0 +1,67 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']


class NeuronQuantConfig(QuantizationConfig):
    """Int8 Quantization Config class for Neuron Backend."""

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                f"valid; the quantization datatype should match one of "
                f"the following types: {SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        return SUPPORTED_QUANT_DTYPE_LIST

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        else:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_quantization_config(self):
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
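For reference, a minimal usage sketch of the new config class (assuming NeuronQuantConfig has been imported from the new module; the argument values mirror the defaults above):

```python
import os

# The quantization dtype is read from the environment,
# not from a constructor argument.
os.environ["NEURON_QUANT_DTYPE"] = "f8e4m3fn"

cfg = NeuronQuantConfig(dequant_dtype="f16",
                        quantize_method="vector_dynamic")
print(cfg.get_name())    # neuron_quant
print(cfg.quant_dtype)   # f8e4m3fn
```

Note that quant_dtype is sourced from the NEURON_QUANT_DTYPE environment variable rather than from the config dict, which is the design choice the review thread below questions.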
Is this configurable from the existing vLLM config? Can we pass the existing vLLM config?
The datatype we use for quantization in Neuron doesn't match the torch dtypes or the existing quant configs, so we created a new quant config to support it.
If we could map torch dtypes to Neuron dtypes, we should be able to avoid inventing an environment variable specifically for the Neuron backend. For instance, `s8` can be translated to `torch.int8`, and `f8e4m3fn` can be translated to `torch.float8_e4m3fn`. Ideally, we should be able to configure the quantization data type with the existing vLLM config. Could you help map the relation between Neuron dtypes and torch dtypes?
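A sketch of the mapping suggested above (assuming `torch.float8_e4m3fn` is available, i.e. PyTorch >= 2.1):

```python
import torch

# Hypothetical two-way mapping between transformers-neuronx quant dtypes
# and torch dtypes, as proposed in the comment above.
NEURON_TO_TORCH_DTYPE = {
    "s8": torch.int8,
    "f8e4m3fn": torch.float8_e4m3fn,
}
TORCH_TO_NEURON_DTYPE = {v: k for k, v in NEURON_TO_TORCH_DTYPE.items()}
```

With such a table, the quantization dtype could come from the existing vLLM config instead of a Neuron-specific environment variable.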