Skip to content

Commit

Permalink
Merge pull request #23 from slyalin/paged_attention_in_openvino
Browse files Browse the repository at this point in the history
Use PagedAttentionExtension from OV without contrib dependency
  • Loading branch information
ilya-lavrenov authored Apr 4, 2024
2 parents e2ac5a2 + c790199 commit dbed638
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 35 deletions.
11 changes: 2 additions & 9 deletions Dockerfile.openvino
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ RUN apt-get update -y && \
WORKDIR /workspace

# build and install OpenVINO
RUN git clone -b feature/optimize_vllm_tput https://github.com/dmitry-gorokhov/openvino.git && \
RUN git clone https://github.com/openvinotoolkit/openvino.git && \
cd /workspace/openvino && \
git checkout 492699d0fa54ae2b2d2f6ce1b405385665edf2f6 && \
git submodule update --init -- /workspace/openvino/thirdparty/xbyak \
/workspace/openvino/thirdparty/pugixml \
/workspace/openvino/thirdparty/open_model_zoo \
Expand All @@ -26,13 +27,6 @@ RUN cmake -DCPACK_GENERATOR=DEB -DENABLE_PYTHON=ON -DENABLE_PYTHON_PACKAGING=ON
RUN cmake --build /workspace/openvino_build --parallel 8
RUN cmake -P /workspace/openvino_build/cmake_install.cmake

# build and install OpenVINO Contrib with PagedAttention
RUN git clone --branch paged-attention https://github.com/ilya-lavrenov/openvino_contrib.git
RUN cmake -DCUSTOM_OPERATIONS=paged_attention -DCMAKE_INSTALL_PREFIX=/usr \
-S /workspace/openvino_contrib/modules/custom_operations/ -B /workspace/paged_attention_build/
RUN cmake --build /workspace/paged_attention_build/ --parallel 8
RUN cmake -P /workspace/openvino_build/cmake_install.cmake

# Install OpenVINO tokenizers
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://storage.openvinotoolkit.org/simple/wheels/nightly" python3 -m pip install openvino-tokenizers
#################### BASE BUILD IMAGE ####################
Expand All @@ -52,7 +46,6 @@ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip in
COPY vllm/ /workspace/vllm/vllm
COPY setup.py /workspace/vllm/

RUN cmake -P /workspace/paged_attention_build/cmake_install.cmake
RUN python3 -m pip install --no-build-isolation /workspace/vllm/
#################### EXTENSION Build IMAGE ####################

Expand Down
15 changes: 0 additions & 15 deletions vllm/executor/openvino_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,21 +443,6 @@ def _init_distributed_environment(self) -> None:
ensure_model_parallel_initialized(self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)

def __del__(self):
    """Release the OpenVINO objects attached to the loaded model, in order.

    Nothing is done unless ``self.model_runner.model`` exists, has a
    ``model`` attribute and carries an ``ov_node_factory``. In that case the
    inference request and the model are dropped first, a garbage collection
    pass is forced, and only then is the node factory deleted.

    The finalizer never assumes a fully initialised instance: it may run
    after ``__init__`` failed part-way, so every attribute is probed before
    it is touched.
    """
    # TODO: Better to put this code in a wrapper around the optimum-based
    # model inside the OpenVINO model loader, but that needs more coding
    # because the wrapper must be a fully functional substitute for
    # torch.nn.Module. Keeping the code here is not robust:
    # self.model_runner is not an instance of our own class, and it may be
    # changed so that the model is no longer kept as
    # self.model_runner.model.
    model_runner = getattr(self, 'model_runner', None)
    pt_model = getattr(model_runner, 'model', None)
    if pt_model is None or not hasattr(pt_model, 'model'):
        return
    if not hasattr(pt_model, 'ov_node_factory'):
        return

    # Drop the objects that may still reference nodes created by the
    # factory. A missing attribute must not abort the rest of the teardown.
    for attr_name in ('_ov_request', 'model'):
        if hasattr(pt_model, attr_name):
            delattr(pt_model, attr_name)

    if gc:  # when app is being destroyed the module may not be available
        gc.collect()

    # The factory goes last, after everything collected above is gone.
    del pt_model.ov_node_factory


class OpenVINOExecutor(ExecutorBase):

Expand Down
29 changes: 18 additions & 11 deletions vllm/model_executor/openvino_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:
self._ov_request.wait()
return torch.from_numpy(self._ov_request.get_tensor("logits").data)

def arguments_as_outputs(arguments):
    """Flatten a mix of outputs and nodes into a flat list of outputs.

    Args:
        arguments: Iterable whose items are either ``ov.runtime.Output``
            instances or objects exposing an ``outputs()`` method.

    Returns:
        A new list, in input order. An ``ov.runtime.Output`` item is kept
        as is; any other item contributes everything its ``outputs()``
        call returns.
    """
    outputs = []
    for argument in arguments:
        # isinstance is the idiomatic form of issubclass(type(x), T).
        if isinstance(argument, ov.runtime.Output):
            outputs.append(argument)
        else:
            outputs.extend(argument.outputs())
    return outputs

def patch_stateful_model(
model: ov.Model,
factory,
kv_cache_dtype: Type,
is_cpu: bool):
print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
Expand Down Expand Up @@ -217,7 +224,7 @@ def take_4d(option1, option2, option3):
else:
alibi_slopes = opset13.constant(np.array([], np.float32))

paged_attention = factory.create("PagedAttentionExtension", [
paged_attention = ov.runtime.op._PagedAttentionExtension(arguments_as_outputs([
q_reshape,
k_reshape,
v_reshape,
Expand All @@ -227,7 +234,7 @@ def take_4d(option1, option2, option3):
scale,
alibi_slopes,
sliding_window
])
]))
pa_shape = opset13.concat([
opset13.constant([0]),
opset13.constant([0]),
Expand Down Expand Up @@ -384,7 +391,7 @@ def _patch_model_with_openvino(
is_cpu: bool):
print(' ============= PATCHING MODEL =============')
from vllm.model_executor.layers.attention.attention import Attention
from openvino.frontend.pytorch import ModuleExtension
from openvino.frontend.pytorch import ModuleExtension, ConversionExtension
from openvino import Core, convert_model, Type, PartialShape

# Avoid usage of vllm._C.ops
Expand Down Expand Up @@ -488,6 +495,11 @@ def wrapper(module, target_op, *args, **kwargs):
torch.tensor(module.backend.sliding_window if module.backend.sliding_window is not None else 0, dtype=torch.int32) # sliding_window
)

def paged_attention_convertion(context):
    """Build a ``_PagedAttentionExtension`` node from a conversion context.

    Every input of ``context`` is collected in index order and passed to
    ``ov.runtime.op._PagedAttentionExtension``; the outputs of the
    resulting node are returned.
    """
    input_count = context.get_input_size()
    node_inputs = []
    for index in range(input_count):
        node_inputs.append(context.get_input(index))
    attention_node = ov.runtime.op._PagedAttentionExtension(node_inputs)
    return attention_node.outputs()

with torch.no_grad():
print('>>>>>>>>>>>>> CONVERTING OV MODEL')
ov_model = convert_model(
Expand All @@ -500,7 +512,7 @@ def wrapper(module, target_op, *args, **kwargs):
evaluate=lambda module, *args, **kwargs: args[0], # need this because PagedAttention module fails in torch.jit.trace
convert=wrapper
),
"libuser_ov_extensions.so"
ConversionExtension('PagedAttentionExtension', paged_attention_convertion),
]
)

Expand Down Expand Up @@ -593,12 +605,7 @@ def get_model(model_config: ModelConfig,
compile=False,
trust_remote_code=model_config.trust_remote_code
)
if not hasattr(pt_model, 'ov_node_factory'):
from openvino.runtime.utils.node_factory import NodeFactory
# Keep the factory so it can be destroyed at a specific moment, after all other objects referencing custom nodes are destroyed
pt_model.ov_node_factory = NodeFactory()
pt_model.ov_node_factory.add_extension('libuser_ov_extensions.so')
patch_stateful_model(pt_model.model, pt_model.ov_node_factory, kv_cache_dtype, device_config.device.type == "cpu")
patch_stateful_model(pt_model.model, kv_cache_dtype, device_config.device.type == "cpu")

# For deployment outside vLLM
model_file_name = os.environ.get('VLLM_OPENVINO_EXPORTED_IR_NAME', '')
Expand Down

0 comments on commit dbed638

Please sign in to comment.