Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create an ILMProvider interface and have our current implementation use it #17394

Merged
merged 23 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/actions/spelling/allow/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ godbolt
gpt
hyperlinking
hyperlinks
ILM
kje
libfuzzer
liga
lje
Llast
lm
llm
Lmid
locl
lol
Expand Down
228 changes: 228 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "pch.h"
#include "AzureLLMProvider.h"
#include "../../types/inc/utils.hpp"
#include "LibraryResources.h"

#include "AzureLLMProvider.g.cpp"
#include "AzureResponse.g.cpp"

using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;

// Model identifiers this provider accepts. The "model" field of a service response
// is compared against these entries in _verifyModelIsValidHelper.
static constexpr std::wstring_view acceptedModels[] = {
L"gpt-35-turbo",
L"gpt4",
L"gpt4-32k",
L"gpt4o",
L"gpt-35-turbo-16k"
};
// The only content-filter "severity" value that _verifyModelIsValidHelper lets through.
static constexpr std::wstring_view acceptedSeverityLevel{ L"safe" };
// The domain the configured endpoint URI must report (checked in GetResponseAsync).
static constexpr std::wstring_view expectedDomain{ L"azure.com" };
// Media type used for the Accept header of outgoing requests.
static constexpr std::wstring_view applicationJson{ L"application/json" };
// Keys looked up in the ValueSet handed to SetAuthentication.
static constexpr std::wstring_view endpointString{ L"endpoint" };
static constexpr std::wstring_view keyString{ L"key" };
// JSON property names used when building chat messages and reading responses.
static constexpr std::wstring_view roleString{ L"role" };
static constexpr std::wstring_view contentString{ L"content" };
static constexpr std::wstring_view messageString{ L"message" };
static constexpr std::wstring_view errorString{ L"error" };
static constexpr std::wstring_view severityString{ L"severity" };

// Searched against the endpoint string in GetResponseAsync: the endpoint must start with
// "https" and contain "openai.azure.com" somewhere after it.
const std::wregex azureOpenAIEndpointRegex{ LR"(^https.*openai\.azure\.com)" };
PankajBhojwani marked this conversation as resolved.
Show resolved Hide resolved

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
// Method Description:
// - Reads the "endpoint" and "key" entries out of the given value set (falling back to an
//   empty string when an entry is missing or is not a property value) and builds a fresh
//   HTTP client whose default headers accept JSON and carry the key as "api-key".
// Arguments:
// - authValues: the value set holding the endpoint and key
void AzureLLMProvider::SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues)
{
    const auto endpointEntry = authValues.TryLookup(endpointString).try_as<IPropertyValue>();
    _azureEndpoint = unbox_value_or<hstring>(endpointEntry, L"");

    const auto keyEntry = authValues.TryLookup(keyString).try_as<IPropertyValue>();
    _azureKey = unbox_value_or<hstring>(keyEntry, L"");

    // Start from a brand new client so headers from a previous key do not linger
    _httpClient = WWH::HttpClient{};
    const auto defaultHeaders = _httpClient.DefaultRequestHeaders();
    defaultHeaders.Accept().TryParseAdd(applicationJson);
    defaultHeaders.Append(L"api-key", _azureKey);
}

// Method Description:
// - Empties the stored conversation (_jsonMessages). Every entry is removed, including any
//   system message previously added through SetSystemPrompt.
void AzureLLMProvider::ClearMessageHistory()
{
_jsonMessages.Clear();
}

// Method Description:
// - Adds a message with the "system" role and the given text to the stored conversation.
//   The message is appended; any system message already in the history is left in place.
// Arguments:
// - systemPrompt: the text of the system message
void AzureLLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
{
    const auto roleValue = WDJ::JsonValue::CreateStringValue(L"system");
    const auto contentValue = WDJ::JsonValue::CreateStringValue(systemPrompt);

    WDJ::JsonObject systemEntry;
    systemEntry.Insert(roleString, roleValue);
    systemEntry.Insert(contentString, contentValue);

    _jsonMessages.Append(systemEntry);
}

// Method Description:
// - Stores the context object. GetResponseAsync later reads its ActiveCommandline() and,
//   when that is not empty, mentions it in the prompt sent to the service.
// Arguments:
// - context: the context to remember
void AzureLLMProvider::SetContext(const Extension::IContext context)
{
_context = context;
}

// Method Description:
// - Validates the configured endpoint, then posts the conversation so far plus the new
//   user prompt to it and turns the reply into an IResponse.
// - Whatever message is produced (reply or error text) is also appended to the stored
//   conversation as an "assistant" entry.
// Arguments:
// - userPrompt: the text the user entered
// Return Value:
// - An AzureResponse whose Message() is the model's reply (IsError() == false) or a
//   description of what went wrong (IsError() == true)
winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt)
{
    // Use a flag for whether the response the user receives is an error message
    // we pass this flag back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event)
    // there is only one case downstream from here that sets this flag to false, so start with it being true
    bool isError{ true };
    hstring message{};

    if (_azureEndpoint.empty())
    {
        message = RS_(L"CouldNotFindKeyErrorMessage");
    }
    else
    {
        // If the AI endpoint is not an azure open AI endpoint, return an error message.
        // Constructing a Uri from a malformed string throws, so the whole check is guarded:
        // an endpoint that cannot even be parsed is reported as invalid instead of failing
        // the entire async operation.
        bool endpointIsValid{ false };
        try
        {
            const Windows::Foundation::Uri parsedUri{ _azureEndpoint };
            endpointIsValid = std::regex_search(_azureEndpoint.c_str(), azureOpenAIEndpointRegex) &&
                              parsedUri.Domain() == expectedDomain;
        }
        catch (...)
        {
            endpointIsValid = false;
        }

        if (!endpointIsValid)
        {
            message = RS_(L"InvalidEndpointMessage");
        }
    }

    // If we don't have a message string, that means the endpoint exists and matches the regex
    // that we allow - now we can actually make the http request
    if (message.empty())
    {
        // Make a copy of the prompt because we are switching threads
        const auto promptCopy{ userPrompt };

        // Make sure we are on the background thread for the http request
        co_await winrt::resume_background();

        WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ _azureEndpoint } };
        request.Headers().Accept().TryParseAdd(applicationJson);

        WDJ::JsonObject jsonContent;
        WDJ::JsonObject messageObject;

        // _ActiveCommandline should be set already, we request for it the moment we become visible
        winrt::hstring engineeredPrompt{ promptCopy };
        if (_context && !_context.ActiveCommandline().empty())
        {
            engineeredPrompt = promptCopy + L". The shell I am running is " + _context.ActiveCommandline();
        }
        messageObject.Insert(roleString, WDJ::JsonValue::CreateStringValue(L"user"));
        messageObject.Insert(contentString, WDJ::JsonValue::CreateStringValue(engineeredPrompt));
        _jsonMessages.Append(messageObject);
        jsonContent.SetNamedValue(L"messages", _jsonMessages);
        jsonContent.SetNamedValue(L"max_tokens", WDJ::JsonValue::CreateNumberValue(800));
        jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0.7));
        jsonContent.SetNamedValue(L"frequency_penalty", WDJ::JsonValue::CreateNumberValue(0));
        jsonContent.SetNamedValue(L"presence_penalty", WDJ::JsonValue::CreateNumberValue(0));
        jsonContent.SetNamedValue(L"top_p", WDJ::JsonValue::CreateNumberValue(0.95));
        jsonContent.SetNamedValue(L"stop", WDJ::JsonValue::CreateStringValue(L"None"));
        const auto stringContent = jsonContent.ToString();
        WWH::HttpStringContent requestContent{
            stringContent,
            WSS::UnicodeEncoding::Utf8,
            L"application/json"
        };

        request.Content(requestContent);

        // Send the request
        try
        {
            // NOTE (from review): these blocking get() calls could use co_await in the future
            // instead, which would also avoid the resume_background hassle above.
            const auto response = _httpClient.SendRequestAsync(request).get();

            // Parse out the suggestion from the response
            const auto string{ response.Content().ReadAsStringAsync().get() };
            const auto jsonResult{ WDJ::JsonObject::Parse(string) };
            if (jsonResult.HasKey(errorString))
            {
                // The service reported an error; surface its message text to the user
                const auto errorObject = jsonResult.GetNamedObject(errorString);
                message = errorObject.GetNamedString(messageString);
            }
            else
            {
                if (_verifyModelIsValidHelper(jsonResult))
                {
                    const auto choices = jsonResult.GetNamedArray(L"choices");
                    const auto firstChoice = choices.GetAt(0).GetObject();
                    const auto replyObject = firstChoice.GetNamedObject(messageString);
                    message = replyObject.GetNamedString(contentString);
                    isError = false;
                }
                else
                {
                    message = RS_(L"InvalidModelMessage");
                }
            }
        }
        catch (...)
        {
            // Any failure while sending, reading or parsing ends up here
            message = RS_(L"UnknownErrorMessage");
        }
    }

    // Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
    WDJ::JsonObject responseMessageObject;
    responseMessageObject.Insert(roleString, WDJ::JsonValue::CreateStringValue(L"assistant"));
    responseMessageObject.Insert(contentString, WDJ::JsonValue::CreateStringValue(message));
    _jsonMessages.Append(responseMessageObject);

    co_return winrt::make<AzureResponse>(message, isError);
}

// Method Description:
// - Decides whether a (non-error) service response may be shown to the user.
// - The response passes only when all of the following hold:
//   * its "model" value is one of acceptedModels
//   * it carries content filter results, under either "prompt_filter_results" or
//     "prompt_annotations"
//   * when the jailbreak-filter feature is enabled, those results contain a "jailbreak" entry
//   * every filter entry that reports a "severity" reports the accepted one
// - The JSON accessors used here throw when an expected key/element is missing; the caller
//   (GetResponseAsync) invokes this helper inside its try block.
// Arguments:
// - jsonResponse: the parsed response body
// Return Value:
// - true if the response passes every check, false otherwise
bool AzureLLMProvider::_verifyModelIsValidHelper(const WDJ::JsonObject jsonResponse)
{
    const auto model = jsonResponse.GetNamedString(L"model");
    bool modelIsAccepted{ false };
    // NOTE (from review): a container with a smarter contains() could replace this scan,
    // but with only a handful of accepted models a linear search is fine.
    for (const auto acceptedModel : acceptedModels)
    {
        if (model == acceptedModel)
        {
            modelIsAccepted = true;
            // Stop only once a match is found, so every accepted model gets compared
            break;
        }
    }
    if (!modelIsAccepted)
    {
        return false;
    }
    WDJ::JsonObject contentFiltersObject;
    // For some reason, sometimes the content filter results are in a key called "prompt_filter_results"
    // and sometimes they are in a key called "prompt_annotations". Check for either.
    if (jsonResponse.HasKey(L"prompt_filter_results"))
    {
        contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_filter_results").GetObjectAt(0);
    }
    else if (jsonResponse.HasKey(L"prompt_annotations"))
    {
        contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_annotations").GetObjectAt(0);
    }
    else
    {
        // No filter results at all: refuse rather than show unfiltered content
        return false;
    }
    const auto contentFilters = contentFiltersObject.GetNamedObject(L"content_filter_results");
    if (Feature_TerminalChatJailbreakFilter::IsEnabled() && !contentFilters.HasKey(L"jailbreak"))
    {
        return false;
    }
    for (const auto filterPair : contentFilters)
    {
        const auto filterLevel = filterPair.Value().GetObjectW();
        // Entries without a "severity" key are skipped
        if (filterLevel.HasKey(severityString))
        {
            if (filterLevel.GetNamedString(severityString) != acceptedSeverityLevel)
            {
                return false;
            }
        }
    }
    return true;
}
}
54 changes: 54 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#pragma once

#include "AzureLLMProvider.g.h"
#include "AzureResponse.g.h"

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
// Implementation type for the AzureLLMProvider runtime class: keeps the endpoint/key,
// the HTTP client and the running conversation, and exposes the operations used to
// configure the provider and to request a response.
struct AzureLLMProvider : AzureLLMProviderT<AzureLLMProvider>
{
AzureLLMProvider() = default;

// Conversation management
void ClearMessageHistory();
void SetSystemPrompt(const winrt::hstring& systemPrompt);
void SetContext(const Extension::IContext context);

// Sends the prompt to the service and yields the reply (or an error description)
winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring& userPrompt);

// Reads the endpoint and key from the value set and (re)creates the HTTP client
void SetAuthentication(const Windows::Foundation::Collections::ValueSet& authValues);
// NOTE(review): TYPED_EVENT is a macro defined elsewhere; presumably it declares the
// AuthChanged event and its handler registration members - confirm against the macro.
TYPED_EVENT(AuthChanged, winrt::Microsoft::Terminal::Query::Extension::ILMProvider, winrt::hstring);

private:
winrt::hstring _azureEndpoint;
winrt::hstring _azureKey;
// Stays null until SetAuthentication creates the client
winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };

Extension::IContext _context;

// The conversation so far, sent as the "messages" array of every request
winrt::Windows::Data::Json::JsonArray _jsonMessages;

bool _verifyModelIsValidHelper(const Windows::Data::Json::JsonObject jsonResponse);
};

// Implementation type for the AzureResponse runtime class: a small holder for the text
// produced by a request together with a flag saying whether that text is an error
// description. Both values are fixed at construction and only exposed through
// read-only accessors.
struct AzureResponse : AzureResponseT<AzureResponse>
{
    // Arguments:
    // - message: the text to expose through Message()
    // - isError: the flag to expose through IsError()
    AzureResponse(const winrt::hstring& message, const bool isError) :
        _message{ message },
        _isError{ isError } {}

    // The stored text. The accessors only read state, so they are const-qualified.
    winrt::hstring Message() const { return _message; }
    // Whether the stored text describes an error.
    bool IsError() const { return _isError; }

private:
    winrt::hstring _message;
    bool _isError;
};
}

// Factory declarations for the two runtime classes declared in this header.
// NOTE(review): BASIC_FACTORY is a macro defined elsewhere; presumably it emits the
// activation factory type for the named class - confirm against the macro definition.
namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation
{
BASIC_FACTORY(AzureLLMProvider);
BASIC_FACTORY(AzureResponse);
}
17 changes: 17 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.idl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import "ILMProvider.idl";

namespace Microsoft.Terminal.Query.Extension
{
// Runtime class for the Azure-backed provider. It implements ILMProvider (imported from
// ILMProvider.idl) and declares only a parameterless constructor of its own.
[default_interface] runtimeclass AzureLLMProvider : ILMProvider
PankajBhojwani marked this conversation as resolved.
Show resolved Hide resolved
{
AzureLLMProvider();
}

// Runtime class for a provider reply. It implements IResponse and is constructed from
// the message text and a flag saying whether that text is an error.
[default_interface] runtimeclass AzureResponse : IResponse
PankajBhojwani marked this conversation as resolved.
Show resolved Hide resolved
{
AzureResponse(String message, Boolean isError);
}
}
Loading
Loading