diff --git a/Dockerfile.openvino b/Dockerfile.openvino
index 1a19140b8ee5a..ea260d30d926e 100644
--- a/Dockerfile.openvino
+++ b/Dockerfile.openvino
@@ -9,8 +9,9 @@ RUN apt-get update -y && \
 WORKDIR /workspace
 
 # build and install OpenVINO
-RUN git clone -b feature/optimize_vllm_tput https://github.com/dmitry-gorokhov/openvino.git && \
+RUN git clone https://github.com/openvinotoolkit/openvino.git && \
     cd /workspace/openvino && \
+    git checkout 492699d0fa54ae2b2d2f6ce1b405385665edf2f6 && \
     git submodule update --init -- /workspace/openvino/thirdparty/xbyak \
         /workspace/openvino/thirdparty/pugixml \
         /workspace/openvino/thirdparty/open_model_zoo \
@@ -26,13 +27,6 @@ RUN cmake -DCPACK_GENERATOR=DEB -DENABLE_PYTHON=ON -DENABLE_PYTHON_PACKAGING=ON
 RUN cmake --build /workspace/openvino_build --parallel 8
 RUN cmake -P /workspace/openvino_build/cmake_install.cmake
 
-# build and install OpenVINO Contrib with PagedAttention
-RUN git clone --branch paged-attention https://github.com/ilya-lavrenov/openvino_contrib.git
-RUN cmake -DCUSTOM_OPERATIONS=paged_attention -DCMAKE_INSTALL_PREFIX=/usr \
-    -S /workspace/openvino_contrib/modules/custom_operations/ -B /workspace/paged_attention_build/
-RUN cmake --build /workspace/paged_attention_build/ --parallel 8
-RUN cmake -P /workspace/openvino_build/cmake_install.cmake
-
 # Install OpenVINO tokenizers
 RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://storage.openvinotoolkit.org/simple/wheels/nightly" python3 -m pip install openvino-tokenizers
 
 #################### BASE BUILD IMAGE ####################
@@ -52,7 +46,6 @@ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip in
 COPY vllm/ /workspace/vllm/vllm
 COPY setup.py /workspace/vllm/
 
-RUN cmake -P /workspace/paged_attention_build/cmake_install.cmake
 RUN python3 -m pip install --no-build-isolation /workspace/vllm/
 
 #################### EXTENSION Build IMAGE ####################
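Note on the Dockerfile change: OpenVINO is now built from the upstream repository pinned to a fixed commit rather than a personal fork branch, and the separate openvino_contrib build of the PagedAttention custom op is dropped, because the pinned revision ships PagedAttention as a core op. A minimal smoke test one could run inside the resulting image (hypothetical, not part of the patch) to confirm that assumption:

```python
# Hypothetical smoke test for the image build, not part of the patch.
# It checks that the pinned OpenVINO commit ships PagedAttention as a core
# op, which is what allows dropping libuser_ov_extensions.so below.
import openvino as ov

print("OpenVINO:", ov.get_version())
assert hasattr(ov.runtime.op, "_PagedAttentionExtension"), \
    "core PagedAttention binding missing; the loader changes below would fail"
```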
diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py
index 3cf397b2e1bba..0b3e626594782 100644
--- a/vllm/executor/openvino_executor.py
+++ b/vllm/executor/openvino_executor.py
@@ -443,21 +443,6 @@ def _init_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(self.parallel_config.tensor_parallel_size,
                                           self.parallel_config.pipeline_parallel_size)
 
-    def __del__(self):
-        # TODO: Better to put this code in a wrapper around optimum-based model inside OpenVINO model loader
-        # but it requires more coding because it should be a full-functional substitution of torch.nn.Module.
-        # The current solution to put the code here is not robust enough: self.model_runner is not our class instance
-        # and it can be modified in a way that model is no longer kept as self.model_runner.model attribute.
-        if not (hasattr(self.model_runner, 'model') and hasattr(self.model_runner.model, 'model')):
-            return
-        pt_model = self.model_runner.model
-        if hasattr(pt_model, 'ov_node_factory'):
-            del pt_model._ov_request
-            del pt_model.model
-            if gc:  # when app is being destroyed the module may not be available
-                gc.collect()
-            del pt_model.ov_node_factory
-
 
 class OpenVINOExecutor(ExecutorBase):
diff --git a/vllm/model_executor/openvino_model_loader.py b/vllm/model_executor/openvino_model_loader.py
index 118e791f93cd4..be48bcdb2bca5 100644
--- a/vllm/model_executor/openvino_model_loader.py
+++ b/vllm/model_executor/openvino_model_loader.py
@@ -57,10 +57,17 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:
     self._ov_request.wait()
     return torch.from_numpy(self._ov_request.get_tensor("logits").data)
 
+def arguments_as_outputs(arguments):
+    outputs = []
+    for argument in arguments:
+        if issubclass(type(argument), ov.runtime.Output):
+            outputs.append(argument)
+        else:
+            outputs.extend(argument.outputs())
+    return outputs
 def patch_stateful_model(
         model: ov.Model,
-        factory,
         kv_cache_dtype: Type,
         is_cpu: bool):
     print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
@@ -217,7 +224,7 @@ def take_4d(option1, option2, option3):
     else:
         alibi_slopes = opset13.constant(np.array([], np.float32))
 
-    paged_attention = factory.create("PagedAttentionExtension", [
+    paged_attention = ov.runtime.op._PagedAttentionExtension(arguments_as_outputs([
         q_reshape,
         k_reshape,
         v_reshape,
@@ -227,7 +234,7 @@ def take_4d(option1, option2, option3):
         scale,
         alibi_slopes,
         sliding_window
-    ])
+    ]))
     pa_shape = opset13.concat([
         opset13.constant([0]),
         opset13.constant([0]),
@@ -384,7 +391,7 @@ def _patch_model_with_openvino(
         is_cpu: bool):
     print(' ============= PATCHING MODEL =============')
     from vllm.model_executor.layers.attention.attention import Attention
-    from openvino.frontend.pytorch import ModuleExtension
+    from openvino.frontend.pytorch import ModuleExtension, ConversionExtension
     from openvino import Core, convert_model, Type, PartialShape
 
     # Avoid usage of vllm._C.ops
@@ -488,6 +495,11 @@ def wrapper(module, target_op, *args, **kwargs):
             torch.tensor(module.backend.sliding_window if module.backend.sliding_window is not None else 0, dtype=torch.int32)  # sliding_window
         )
 
+    def paged_attention_conversion(context):
+        inputs = [context.get_input(i) for i in range(context.get_input_size())]
+        pa = ov.runtime.op._PagedAttentionExtension(inputs)
+        return pa.outputs()
+
     with torch.no_grad():
         print('>>>>>>>>>>>>> CONVERTING OV MODEL')
         ov_model = convert_model(
@@ -500,7 +512,7 @@ def wrapper(module, target_op, *args, **kwargs):
                 evaluate=lambda module, *args, **kwargs: args[0],  # need this because PagedAttention module fails in torch.jit.trace
                 convert=wrapper
             ),
-            "libuser_ov_extensions.so"
+            ConversionExtension('PagedAttentionExtension', paged_attention_conversion),
         ]
     )
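Note on the loader changes above: the traced `PagedAttentionExtension` op is no longer materialized through a `NodeFactory` backed by `libuser_ov_extensions.so`; a `ConversionExtension` callback now rebuilds it with the core `_PagedAttentionExtension` binding during `convert_model`, and `arguments_as_outputs` normalizes a mix of nodes and outputs into the flat `Output` list the low-level constructor expects. A self-contained sketch of the same `ConversionExtension` pattern, shown on `aten::sin` so it runs standalone (`SinNet` and `convert_sin` are illustrative names, not from the patch):

```python
# Illustrative sketch of the ConversionExtension mechanism used above,
# demonstrated on aten::sin instead of PagedAttention so it runs standalone.
import torch
import openvino as ov
from openvino import convert_model
from openvino.frontend.pytorch import ConversionExtension
from openvino.runtime import opset13

class SinNet(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

def convert_sin(context):
    # context wraps the traced op: pull its inputs and return replacement
    # outputs, the same contract paged_attention_conversion fulfills for
    # PagedAttentionExtension in the hunk above.
    x = context.get_input(0)
    return opset13.sin(x).outputs()

ov_model = convert_model(SinNet(), example_input=(torch.ones(2, 2),),
                         extension=[ConversionExtension("aten::sin", convert_sin)])
```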
"cpu") # For deployment outside vLLM model_file_name = os.environ.get('VLLM_OPENVINO_EXPORTED_IR_NAME', '')