Create an ILMProvider interface and have our current implementation use it #17394

Merged · 23 commits · Sep 6, 2024
Changes from 2 commits
198 changes: 198 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
@@ -0,0 +1,198 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "pch.h"
#include "AzureLLMProvider.h"
#include "../../types/inc/utils.hpp"
#include "LibraryResources.h"

#include "AzureLLMProvider.g.cpp"
#include "AzureResponse.g.cpp"

using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;

static constexpr std::wstring_view acceptedModel{ L"gpt-35-turbo" };
static constexpr std::wstring_view acceptedSeverityLevel{ L"safe" };

const std::wregex azureOpenAIEndpointRegex{ LR"(^https.*openai\.azure\.com)" };
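// In plain terms: only https endpoints on openai.azure.com pass validation (a plain-English reading of the pattern above).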

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
    AzureLLMProvider::AzureLLMProvider(winrt::hstring endpoint, winrt::hstring key)
    {
        _AIEndpoint = endpoint;
        _AIKey = key;
        _httpClient = winrt::Windows::Web::Http::HttpClient{};
        _httpClient.DefaultRequestHeaders().Accept().TryParseAdd(L"application/json");
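        // NOTE: Azure OpenAI expects the key in an "api-key" request header rather than an OAuth "Authorization: Bearer" token.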
        _httpClient.DefaultRequestHeaders().Append(L"api-key", _AIKey);
    }

    void AzureLLMProvider::ClearMessageHistory()
    {
        _jsonMessages.Clear();
    }

    void AzureLLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
    {
        WDJ::JsonObject systemMessageObject;
        winrt::hstring systemMessageContent{ systemPrompt };
        systemMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"system"));
        systemMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(systemMessageContent));
        _jsonMessages.Append(systemMessageObject);
    }

    void AzureLLMProvider::SetContext(Extension::IContext context)
    {
        _context = context;
    }

    winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt)
    {
        // Use a flag for whether the response the user receives is an error message.
        // We pass this flag back to the caller so they can handle it appropriately
        // (specifically, ExtensionPalette will send the correct telemetry event).
        // There is only one case downstream from here that sets this flag to false, so start with it being true.
        bool isError{ true };
        hstring message{};

        // If the AI key or endpoint is still empty, tell the user to fill them out in settings.
        if (_AIKey.empty() || _AIEndpoint.empty())
        {
            message = RS_(L"CouldNotFindKeyErrorMessage");
        }
        else if (!std::regex_search(_AIEndpoint.c_str(), azureOpenAIEndpointRegex))
        {
            message = RS_(L"InvalidEndpointMessage");
        }

        // If we don't have a message string, that means the endpoint exists and matches the regex
        // that we allow - now we can actually make the http request.
        if (message.empty())
        {
            // Make a copy of the prompt because we are switching threads.
            const auto promptCopy{ userPrompt };

            // Make sure we are on the background thread for the http request.
            co_await winrt::resume_background();

            WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ _AIEndpoint } };
            request.Headers().Accept().TryParseAdd(L"application/json");

            WDJ::JsonObject jsonContent;
            WDJ::JsonObject messageObject;

            // _ActiveCommandline should be set already; we request it the moment we become visible.
            winrt::hstring engineeredPrompt{ promptCopy };
            if (_context && !_context.ActiveCommandline().empty())
            {
                engineeredPrompt = promptCopy + L". The shell I am running is " + _context.ActiveCommandline();
            }
            messageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"user"));
            messageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(engineeredPrompt));
            _jsonMessages.Append(messageObject);
            jsonContent.SetNamedValue(L"messages", _jsonMessages);
            jsonContent.SetNamedValue(L"max_tokens", WDJ::JsonValue::CreateNumberValue(800));
            jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0.7));
            jsonContent.SetNamedValue(L"frequency_penalty", WDJ::JsonValue::CreateNumberValue(0));
            jsonContent.SetNamedValue(L"presence_penalty", WDJ::JsonValue::CreateNumberValue(0));
            jsonContent.SetNamedValue(L"top_p", WDJ::JsonValue::CreateNumberValue(0.95));
            jsonContent.SetNamedValue(L"stop", WDJ::JsonValue::CreateStringValue(L"None"));
            const auto stringContent = jsonContent.ToString();
            WWH::HttpStringContent requestContent{
                stringContent,
                WSS::UnicodeEncoding::Utf8,
                L"application/json"
            };

            request.Content(requestContent);

            // Send the request
            try
            {
                const auto response = _httpClient.SendRequestAsync(request).get();
                // [PR review comment, Member]: FWIW all these get()s could use co_await in the
                // future instead. That avoids the resume_background hassle.

                // Parse out the suggestion from the response
                const auto string{ response.Content().ReadAsStringAsync().get() };
                const auto jsonResult{ WDJ::JsonObject::Parse(string) };
                if (jsonResult.HasKey(L"error"))
                {
                    const auto errorObject = jsonResult.GetNamedObject(L"error");
                    message = errorObject.GetNamedString(L"message");
                }
                else
                {
                    if (_verifyModelIsValidHelper(jsonResult))
                    {
                        const auto choices = jsonResult.GetNamedArray(L"choices");
                        const auto firstChoice = choices.GetAt(0).GetObject();
                        const auto messageObject = firstChoice.GetNamedObject(L"message");
                        message = messageObject.GetNamedString(L"content");
                        isError = false;
                    }
                    else
                    {
                        message = RS_(L"InvalidModelMessage");
                    }
                }
            }
            catch (...)
            {
                message = RS_(L"UnknownErrorMessage");
            }
        }

        // Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far.
        WDJ::JsonObject responseMessageObject;
        responseMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"assistant"));
        responseMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(message));
        _jsonMessages.Append(responseMessageObject);

        co_return winrt::make<AzureResponse>(message, isError);
    }

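    // Expected shape of a valid response, inferred from the checks below (an illustrative sketch, not a captured API reply):
    // {
    //     "model": "gpt-35-turbo",
    //     "prompt_filter_results": [{ "content_filter_results": { "hate": { "severity": "safe" }, "jailbreak": { ... }, ... } }],
    //     "choices": [{ "message": { "role": "assistant", "content": "..." } }]
    // }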
    bool AzureLLMProvider::_verifyModelIsValidHelper(const WDJ::JsonObject jsonResponse)
    {
        if (jsonResponse.GetNamedString(L"model") != acceptedModel)
        {
            return false;
        }
        WDJ::JsonObject contentFiltersObject;
        // For some reason, sometimes the content filter results are in a key called "prompt_filter_results"
        // and sometimes they are in a key called "prompt_annotations". Check for either.
        if (jsonResponse.HasKey(L"prompt_filter_results"))
        {
            contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_filter_results").GetObjectAt(0);
        }
        else if (jsonResponse.HasKey(L"prompt_annotations"))
        {
            contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_annotations").GetObjectAt(0);
        }
        else
        {
            return false;
        }
        const auto contentFilters = contentFiltersObject.GetNamedObject(L"content_filter_results");
        if (Feature_TerminalChatJailbreakFilter::IsEnabled() && !contentFilters.HasKey(L"jailbreak"))
        {
            return false;
        }
        for (const auto filterPair : contentFilters)
        {
            const auto filterLevel = filterPair.Value().GetObjectW();
            if (filterLevel.HasKey(L"severity"))
            {
                if (filterLevel.GetNamedString(L"severity") != acceptedSeverityLevel)
                {
                    return false;
                }
            }
        }
        return true;
    }
}
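
For orientation, here is a minimal, hypothetical call site showing how a consumer (in this PR, ExtensionPalette) might drive the provider through the ILLMProvider projection declared below; the endpoint and key values are placeholders:

// Inside a C++/WinRT coroutine in the consuming component (hypothetical sketch):
using namespace winrt::Microsoft::Terminal::Query::Extension;

ILLMProvider provider = AzureLLMProvider{ L"https://contoso.openai.azure.com/...", L"<api-key>" };
provider.SetSystemPrompt(L"You are an assistant that helps with command-line tasks.");
const auto response = co_await provider.GetResponseAsync(L"how do I list files?");
if (!response.IsError())
{
    // Surface response.Message() in the chat UI.
}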
51 changes: 51 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.h
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#pragma once

#include "AzureLLMProvider.g.h"
#include "AzureResponse.g.h"

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
    struct AzureLLMProvider : AzureLLMProviderT<AzureLLMProvider>
    {
        AzureLLMProvider(winrt::hstring endpoint, winrt::hstring key);

        void ClearMessageHistory();
        void SetSystemPrompt(const winrt::hstring& systemPrompt);
        void SetContext(Extension::IContext context);

        winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring& userPrompt);

    private:
        winrt::hstring _AIEndpoint;
        winrt::hstring _AIKey;
        winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };

        Extension::IContext _context;

        winrt::Windows::Data::Json::JsonArray _jsonMessages;

        bool _verifyModelIsValidHelper(const Windows::Data::Json::JsonObject jsonResponse);
    };

    struct AzureResponse : AzureResponseT<AzureResponse>
    {
        AzureResponse(winrt::hstring message, bool isError) :
            _message{ message },
            _isError{ isError } {}
        winrt::hstring Message() { return _message; };
        bool IsError() { return _isError; };

    private:
        winrt::hstring _message;
        bool _isError;
    };
}

namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation
{
    BASIC_FACTORY(AzureLLMProvider);
    BASIC_FACTORY(AzureResponse);
}
17 changes: 17 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.idl
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import "ILLMProvider.idl";

namespace Microsoft.Terminal.Query.Extension
{
    [default_interface] runtimeclass AzureLLMProvider : ILLMProvider
    {
        AzureLLMProvider(String endpoint, String key);
    }

    [default_interface] runtimeclass AzureResponse : IResponse
    {
        AzureResponse(String message, Boolean isError);
    }
}
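
ILLMProvider.idl itself is not part of this two-commit view. Inferred from the C++ header above, the interface surface presumably looks something like this sketch (member and property names are taken from the implementation; the actual file may differ):

namespace Microsoft.Terminal.Query.Extension
{
    interface IContext
    {
        String ActiveCommandline { get; };
    };

    interface IResponse
    {
        String Message { get; };
        Boolean IsError { get; };
    };

    interface ILLMProvider
    {
        void ClearMessageHistory();
        void SetSystemPrompt(String systemPrompt);
        void SetContext(IContext context);
        Windows.Foundation.IAsyncOperation<IResponse> GetResponseAsync(String userPrompt);
    };
}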