Update OpenVINO Code #721

Merged (2 commits) on Sep 12, 2023
2 changes: 1 addition & 1 deletion modules/openvino_code/README.md
@@ -13,7 +13,7 @@ OpenVINO Code provides the following features:

1. Create a new python file
2. Try typing `def main():`
3. Press shortcut buttons (TBD) for code completion
3. Press shortcut button `ctrl+alt+space` for code completion

### Checking output

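To illustrate the updated step 3, a minimal file for trying the shortcut could look like the sketch below; the suggested body is only the kind of completion the model might produce, not a guaranteed output.

```python
# example.py: a minimal file for trying the completion shortcut.
# Type the function header, leave the cursor after the colon, then press ctrl+alt+space.
def main():
    print("Hello from OpenVINO Code")  # illustrative suggestion; actual output depends on the model


if __name__ == "__main__":
    main()
```
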
4 changes: 2 additions & 2 deletions modules/openvino_code/package-lock.json

Some generated files are not rendered by default.

26 changes: 19 additions & 7 deletions modules/openvino_code/package.json
@@ -1,7 +1,7 @@
{
"publisher": "OpenVINO",
"name": "openvino-code-completion",
"version": "0.0.2",
"version": "0.0.3",
"displayName": "OpenVINO Code Completion",
"description": "VSCode extension for AI code completion with OpenVINO",
"icon": "media/logo.png",
@@ -188,38 +188,44 @@
"default": 30,
"markdownDescription": "Server request timeout in seconds after which request will be aborted."
},
"openvinoCode.fillInTheMiddleMode": {
"openvinoCode.streamInlineCompletion": {
"order": 3,
"type": "boolean",
"default": "false",
"description": "When checked inline complention will be generated in streaming mode"
},
"openvinoCode.fillInTheMiddleMode": {
"order": 4,
"type": "boolean",
"default": false,
"markdownDescription": "When checked, text before (above) and after (below) the cursor will be used for completion generation. When unckecked, only text before (above) the cursor will be used."
},
"openvinoCode.temperature": {
"order": 4,
"order": 5,
"type": "number",
"default": 0.2,
"description": "Sampling temperature."
},
"openvinoCode.topK": {
"order": 4,
"order": 5,
"type": "integer",
"default": 10,
"description": "Top K."
},
"openvinoCode.topP": {
"order": 4,
"order": 5,
"type": "number",
"default": 1,
"description": "Top P."
},
"openvinoCode.minNewTokens": {
"order": 5,
"order": 6,
"type": "number",
"default": 1,
"description": "Minimum of new generated tokens."
},
"openvinoCode.maxNewTokens": {
"order": 5,
"order": 6,
"type": "number",
"default": 100,
"description": "Maximum of new generated tokens."
@@ -280,6 +286,12 @@
"key": "ctrl+alt+space",
"mac": "ctrl+alt+space",
"when": "editorTextFocus"
},
{
"command": "openvinoCode.stopGeneration",
"key": "escape",
"mac": "escape",
"when": "openvinoCode.generating"
}
]
},
13 changes: 6 additions & 7 deletions modules/openvino_code/server/pyproject.toml
@@ -4,22 +4,21 @@ version = "0.0.1"
requires-python = ">=3.8"

dependencies = [
'fastapi==0.101.0',
'uvicorn==0.23.1',
'fastapi==0.103.1',
'uvicorn==0.23.2',
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp38-cp38-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.8"',
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp39-cp39-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.9"',
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.10"',
'torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl ; sys_platform=="linux" and python_version == "3.11"',
'torch ; sys_platform != "linux"',
'openvino==2023.1.0.dev20230811',
'optimum-intel[openvino]==1.11.0',
'transformers==4.31.0',
'optimum==1.12.0',
'optimum-intel[openvino]==1.10.1',
]

[project.optional-dependencies]
dev = [
"black",
"ruff",
]
dev = ["black", "ruff"]

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
11 changes: 6 additions & 5 deletions modules/openvino_code/server/src/app.py
@@ -1,9 +1,9 @@
from time import perf_counter
from typing import Dict, Union

from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, Request
from fastapi.responses import RedirectResponse, StreamingResponse
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter

from src.generators import GeneratorFunctor
from src.utils import get_logger
@@ -105,11 +105,12 @@ async def generate(

@app.post("/api/generate_stream", status_code=200)
async def generate_stream(
request: GenerationRequest,
request: Request,
generator: GeneratorFunctor = Depends(get_generator_dummy),
) -> StreamingResponse:
logger.info(request)
return StreamingResponse(generator.generate_stream(request.inputs, request.parameters.model_dump()))
generation_request = TypeAdapter(GenerationRequest).validate_python(await request.json())
logger.info(generation_request)
return StreamingResponse(generator.generate_stream(generation_request.inputs, generation_request.parameters.model_dump(), request))


@app.post("/api/summarize", status_code=200, response_model=GenerationResponse)
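
A minimal sketch of the pattern used in this change, assuming pydantic v2: the handler takes the raw `Request` and validates the JSON body itself with `TypeAdapter`, so the same `Request` object can later be handed to the generator to watch for client disconnects. The `GenerationRequest` below is a simplified stand-in; only the `inputs` and `parameters` fields mirror the real model.

```python
# Sketch: validating a raw FastAPI Request body with a pydantic v2 TypeAdapter.
from typing import Any, Dict

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, TypeAdapter

app = FastAPI()


class GenerationRequest(BaseModel):  # simplified stand-in for the real model
    inputs: str
    parameters: Dict[str, Any] = {}


@app.post("/api/generate_stream")
async def generate_stream(request: Request) -> StreamingResponse:
    # Validate the body manually; the raw Request is kept so downstream code
    # can watch its receive channel for client disconnects.
    body = TypeAdapter(GenerationRequest).validate_python(await request.json())

    async def stream():
        # Placeholder token stream; the real code delegates to a generator functor.
        for token in body.inputs.split():
            yield token + " "

    return StreamingResponse(stream())
```
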
40 changes: 35 additions & 5 deletions modules/openvino_code/server/src/generators.py
@@ -1,3 +1,4 @@
import asyncio
import re
from functools import lru_cache
from io import StringIO
@@ -6,6 +7,7 @@
from typing import Any, Callable, Container, Dict, Generator, List, Optional, Type, Union

import torch
from fastapi import Request
from huggingface_hub.utils import EntryNotFoundError
from optimum.intel import OVModelForCausalLM, OVModelForSeq2SeqLM
from transformers import (
@@ -61,11 +63,15 @@ def get_model(checkpoint: str, device: str = "CPU") -> OVModel:
return model


# TODO: generator needs running flag or cancellation on new generation request
# generator cannot handle concurrent requests - fails and stalls process
# RuntimeError: Exception from src/inference/src/infer_request.cpp:189:
# [ REQUEST_BUSY ]
class GeneratorFunctor:
def __call__(self, input_text: str, parameters: Dict[str, Any]) -> str:
raise NotImplementedError

async def generate_stream(self, input_text: str, parameters: Dict[str, Any]):
async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request):
raise NotImplementedError

def summarize(self, input_text: str, template: str, signature: str, style: str, parameters: Dict[str, Any]):
@@ -122,24 +128,45 @@ def __call__(
logger.info(f"Number of input tokens: {prompt_len}; generated {len(output_ids)} tokens")
return self.tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

async def generate_stream(
self, input_text: str, parameters: Dict[str, Any], stopping_criteria: Optional[StoppingCriteriaList] = None
):
async def generate_stream(self, input_text: str, parameters: Dict[str, Any], request: Request = None):
input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
parameters["streamer"] = streamer
config = GenerationConfig.from_dict({**self.generation_config.to_dict(), **parameters})

stop_on_tokens = StopOnTokens([])

generation_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
stopping_criteria=stopping_criteria,
stopping_criteria=StoppingCriteriaList([stop_on_tokens]),
**config.to_dict(),
)

# listen for the disconnect event so generation can be stopped
def listen_for_disconnect():
async def listen():
message = await request.receive()
if message.get("type") == "http.disconnect":
stop_on_tokens.cancelled = True
asyncio.create_task(listen())


listen_thread = Thread(target=listen_for_disconnect)
# thread.run() doesn't actually start a new thread;
# it runs the target function in the current thread's context.
# thread.start() doesn't work here.
listen_thread.run()

thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()

for token in streamer:
await asyncio.sleep(0.01)
yield token

thread.join()

def generate_between(
self,
input_parts: List[str],
@@ -243,7 +270,10 @@ def inner() -> GeneratorFunctor:

class StopOnTokens(StoppingCriteria):
def __init__(self, token_ids: List[int]) -> None:
self.cancelled = False
self.token_ids = torch.tensor(token_ids, requires_grad=False)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.cancelled:
return True
return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()
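
A condensed sketch of the cancellation flow these changes introduce, assuming the transformers `StoppingCriteria` API: the criteria object carries a `cancelled` flag, and a watcher on the ASGI receive channel flips it when the client disconnects, so `generate()` stops at the next decoding step. Names are simplified, and the watcher here is a plain coroutine rather than the thread-plus-task arrangement in the diff.

```python
# Sketch: cooperative cancellation of a streaming generation.
import asyncio
from typing import List

import torch
from fastapi import Request
from transformers import StoppingCriteria


class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids: List[int]) -> None:
        self.cancelled = False
        self.token_ids = torch.tensor(token_ids, requires_grad=False)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.cancelled:
            return True  # abort generation as soon as the flag is set
        if self.token_ids.numel() == 0:
            return False
        return torch.any(torch.eq(input_ids[0, -1], self.token_ids)).item()


async def watch_disconnect(request: Request, criteria: StopOnTokens) -> None:
    # http.disconnect arrives on the receive channel when the client goes away.
    while True:
        message = await request.receive()
        if message.get("type") == "http.disconnect":
            criteria.cancelled = True
            break
```
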
1 change: 1 addition & 0 deletions modules/openvino_code/src/configuration.ts
@@ -8,6 +8,7 @@ export type CustomConfiguration = {
model: ModelName;
serverUrl: string;
serverRequestTimeout: number;
streamInlineCompletion: boolean;
fillInTheMiddleMode: boolean;
temperature: number;
topK: number;
5 changes: 5 additions & 0 deletions modules/openvino_code/src/constants.ts
@@ -24,4 +24,9 @@ export const COMMANDS = {
STOP_SERVER_NATIVE: 'openvinoCode.stopServerNative',
SHOW_SERVER_LOG: 'openvinoCode.showServerLog',
SHOW_EXTENSION_LOG: 'openvinoCode.showExtensionLog',
STOP_GENERATION: 'openvinoCode.stopGeneration',
};

export const EXTENSION_CONTEXT_STATE = {
GENERATING: 'openvinoCode.generating',
};
@@ -1,7 +1,7 @@
import { InlineCompletionItem, Position, Range, TextDocument, window } from 'vscode';
import { backendService } from '../services/backend.service';
import { extensionState } from '../state';
import { EXTENSION_DISPLAY_NAME } from '../constants';
import { IGenerateRequest, backendService } from '../services/backend.service';
import { extensionState } from '../state';

const outputChannel = window.createOutputChannel(EXTENSION_DISPLAY_NAME, { log: true });
const logCompletionInput = (input: string): void => outputChannel.append(`Completion input:\n${input}\n\n`);
@@ -67,6 +67,41 @@ class CompletionService {
const completionItem = new InlineCompletionItem(generatedText, new Range(position, position.translate(0, 1)));
return [completionItem];
}

async getCompletionStream(
document: TextDocument,
position: Position,
onDataChunk: (chunk: string) => unknown,
signal?: AbortSignal
) {
const textBeforeCursor = this._getTextBeforeCursor(document, position);
const textAfterCursor = this._getTextAfterCursor(document, position);
const completionInput = this._prepareCompletionInput(textBeforeCursor, textAfterCursor);
logCompletionInput(completionInput);

const { temperature, topK, topP, minNewTokens, maxNewTokens } = extensionState.config;

const request: IGenerateRequest = {
inputs: completionInput,
parameters: {
temperature,
top_k: topK,
top_p: topP,
min_new_tokens: minNewTokens,
max_new_tokens: maxNewTokens,
},
};

outputChannel.append(`Completion output:\n`);
return backendService.generateCompletionStream(
request,
(chunk) => {
outputChannel.append(chunk);
onDataChunk(chunk);
},
signal
);
}
}

export default new CompletionService();
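
For reference, the payload assembled by `getCompletionStream` can be sent to the server's streaming endpoint directly. A rough Python sketch, assuming the server listens on http://127.0.0.1:8000 (the real address comes from the `openvinoCode.serverUrl` setting) and using the `requests` package; the parameter values are the extension defaults shown earlier.

```python
# Sketch: calling /api/generate_stream with the same payload shape the extension builds.
import requests

payload = {
    "inputs": "def main():",
    "parameters": {
        "temperature": 0.2,
        "top_k": 10,
        "top_p": 1,
        "min_new_tokens": 1,
        "max_new_tokens": 100,
    },
}

# Assumed local address; the extension reads this from the openvinoCode.serverUrl setting.
url = "http://127.0.0.1:8000/api/generate_stream"

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None):
        # Each chunk is a token (or a few tokens) emitted by the server-side streamer.
        print(chunk.decode("utf-8"), end="", flush=True)
```
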
65 changes: 28 additions & 37 deletions modules/openvino_code/src/inline-completion/index.ts
@@ -1,48 +1,39 @@
import { Disposable, ExtensionContext, commands, languages, window } from 'vscode';
import { IExtensionState } from '@shared/extension-state';
import { ExtensionContext } from 'vscode';
import { IExtensionComponent } from '../extension-component.interface';
import { CommandInlineCompletionItemProvider } from './command-inline-completion-provider';
import { COMMANDS } from '../constants';
import { extensionState } from '../state';
import { notificationService } from '../services/notification.service';
import { inlineCompletion as baseInlineCompletion } from './inline-completion-component';
import { streamingInlineCompletion } from './streaming-inline-completion-component';

class InlineCompletion implements IExtensionComponent {
activate(context: ExtensionContext): void {
// Register Inline Completion triggered by command
const commandInlineCompletionProvider = new CommandInlineCompletionItemProvider();

let commandInlineCompletionDisposable: Disposable;

const commandDisposable = commands.registerCommand(COMMANDS.GENERATE_INLINE_COPMLETION, () => {
if (!extensionState.get('isServerAvailable')) {
notificationService.showServerNotAvailableMessage(extensionState.state);
return;
}
if (extensionState.get('isLoading') && window.activeTextEditor) {
void window.showTextDocument(window.activeTextEditor.document);
return;
}

extensionState.set('isLoading', true);

if (commandInlineCompletionDisposable) {
commandInlineCompletionDisposable.dispose();
}
private _context: ExtensionContext | null = null;
private _listener = ({ config }: IExtensionState) => this.activateCompletion(config.streamInlineCompletion);

commandInlineCompletionDisposable = languages.registerInlineCompletionItemProvider(
{ pattern: '**' },
commandInlineCompletionProvider
);

void commandInlineCompletionProvider.triggerCompletion(() => {
commandInlineCompletionDisposable.dispose();
extensionState.set('isLoading', false);
});
});
activate(context: ExtensionContext): void {
this._context = context;
this.activateCompletion(extensionState.config.streamInlineCompletion);
extensionState.subscribe(this._listener);
}

context.subscriptions.push(commandDisposable);
deactivate(): void {
streamingInlineCompletion.deactivate();
baseInlineCompletion.deactivate();
extensionState.unsubscribe(this._listener);
}

deactivate(): void {}
activateCompletion(streaming: boolean) {
if (!this._context) {
return;
}
baseInlineCompletion.deactivate();
streamingInlineCompletion.deactivate();

if (streaming) {
streamingInlineCompletion.activate(this._context);
} else {
baseInlineCompletion.activate(this._context);
}
}
}

export const inlineCompletion = new InlineCompletion();