diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index 38b249cc229a..189ffac85e92 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -5,6 +5,8 @@
     true
+
+
@@ -30,6 +32,7 @@
+
diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln
index 2d11481810cb..fdb42fed44ae 100644
--- a/dotnet/SK-dotnet.sln
+++ b/dotnet/SK-dotnet.sln
@@ -314,6 +314,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TimePlugin", "samples\Demos
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.AzureCosmosDBNoSQL", "src\Connectors\Connectors.Memory.AzureCosmosDBNoSQL\Connectors.Memory.AzureCosmosDBNoSQL.csproj", "{B0B3901E-AF56-432B-8FAA-858468E5D0DF}"
 EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Amazon", "src\Connectors\Connectors.Amazon\Connectors.Amazon.csproj", "{E059E9B0-1302-474D-B1B5-10A6E0F1A769}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BedrockTest", "samples\Demos\BedrockTest\BedrockTest.csproj", "{ABEAACCD-CF63-4850-8ED5-E01379DBFC46}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Amazon.UnitTests", "src\Connectors\Connectors.Amazon.UnitTests\Connectors.Amazon.UnitTests.csproj", "{CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}"
+EndProject
 Global
     GlobalSection(SolutionConfigurationPlatforms) = preSolution
         Debug|Any CPU = Debug|Any CPU
@@ -771,6 +777,24 @@ Global
         {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Publish|Any CPU.Build.0 = Publish|Any CPU
         {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
         {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Release|Any CPU.Build.0 = Release|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Debug|Any CPU.Build.0 = Debug|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Publish|Any CPU.Build.0 = Debug|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Release|Any CPU.ActiveCfg = Release|Any CPU
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769}.Release|Any CPU.Build.0 = Release|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Debug|Any CPU.Build.0 = Debug|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Publish|Any CPU.Build.0 = Debug|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Release|Any CPU.ActiveCfg = Release|Any CPU
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46}.Release|Any CPU.Build.0 = Release|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Publish|Any CPU.Build.0 = Debug|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E}.Release|Any CPU.Build.0 = Release|Any CPU
     EndGlobalSection
     GlobalSection(SolutionProperties) = preSolution
         HideSolutionNode = FALSE
@@ -877,6 +901,9 @@ Global
         {1D3EEB5B-0E06-4700-80D5-164956E43D0A} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
         {F312FCE1-12D7-4DEF-BC29-2FF6618509F3} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
         {B0B3901E-AF56-432B-8FAA-858468E5D0DF} = {24503383-A8C4-4255-9998-28D70FE8E99A}
+        {E059E9B0-1302-474D-B1B5-10A6E0F1A769} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
+        {ABEAACCD-CF63-4850-8ED5-E01379DBFC46} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
+        {CCC6DC57-2AC1-4C8E-A448-2CC0537A288E} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1}
     EndGlobalSection
     GlobalSection(ExtensibilityGlobals) = postSolution
         SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
diff --git a/dotnet/samples/Demos/BedrockTest/BedrockTest.csproj b/dotnet/samples/Demos/BedrockTest/BedrockTest.csproj
new file mode 100644
index 000000000000..8501cc6541dc
--- /dev/null
+++ b/dotnet/samples/Demos/BedrockTest/BedrockTest.csproj
@@ -0,0 +1,18 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <OutputType>Exe</OutputType>
+    <TargetFramework>net8.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+  <ItemGroup>
+
+
+
+
+  </ItemGroup>
+
+</Project>
diff --git a/dotnet/samples/Demos/BedrockTest/Program.cs b/dotnet/samples/Demos/BedrockTest/Program.cs
new file mode 100644
index 000000000000..9aa5cf636d73
--- /dev/null
+++ b/dotnet/samples/Demos/BedrockTest/Program.cs
@@ -0,0 +1,233 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Connectors.Amazon.Extensions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.TextGeneration;
+
+// Display the available options
+Console.WriteLine("Choose an option:");
+Console.WriteLine("1. Chat Completion");
+Console.WriteLine("2. Text Generation");
+Console.WriteLine("3. Stream Chat Completion");
+Console.WriteLine("4. Stream Text Generation");
+
+Console.Write("Enter your choice (1-4): ");
+int choice;
+while (!int.TryParse(Console.ReadLine(), out choice) || choice < 1 || choice > 4)
+{
+    Console.WriteLine("Invalid input. Please enter a valid number from the list.");
+    Console.Write("Enter your choice (1-4): ");
+}
+
+switch (choice)
+{
+    case 1:
+        // ----------------------------CHAT COMPLETION----------------------------
+        string userInput;
+        ChatHistory chatHistory = new();
+
+        // List of available models
+        Dictionary<int, string> modelOptions = new()
+        {
+            { 1, "amazon.titan-text-premier-v1:0" },
+            { 2, "anthropic.claude-3-sonnet-20240229-v1:0" },
+            { 3, "anthropic.claude-3-haiku-20240307-v1:0" },
+            { 4, "anthropic.claude-v2:1" },
+            { 5, "ai21.jamba-instruct-v1:0" },
+            { 6, "cohere.command-r-plus-v1:0" },
+            { 7, "meta.llama3-8b-instruct-v1:0" },
+            { 8, "mistral.mistral-7b-instruct-v0:2" }
+        };
+
+        // Display the model options
+        Console.WriteLine("Available models:");
+        foreach (var option in modelOptions)
+        {
+            Console.WriteLine($"{option.Key}. {option.Value}");
+        }
+
+        Console.Write("Enter the number of the model you want to use for chat completion: ");
+        int chosenModel;
+        while (!int.TryParse(Console.ReadLine(), out chosenModel) || !modelOptions.ContainsKey(chosenModel))
+        {
+            Console.WriteLine("Invalid input. Please enter a valid number from the list.");
+            Console.Write("Enter the number of the model you want to use: ");
+        }
+
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelOptions[chosenModel]).Build();
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
+
+        do
+        {
+            Console.Write("Enter a prompt (or 'exit' to quit): ");
+            userInput = Console.ReadLine() ?? "exit";
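+            // Console.ReadLine() returns null at end of input; treating that as "exit" ends the loop.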
"exit"; + + if (!string.Equals(userInput, "exit", StringComparison.OrdinalIgnoreCase)) + { + chatHistory.AddMessage(AuthorRole.User, userInput); + var result = await chatCompletionService.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(false); + string output = ""; + foreach (var message in result) + { + output += message.Content; + Console.WriteLine($"Chat Completion Answer: {message.Content}"); + Console.WriteLine(); + } + chatHistory.AddMessage(AuthorRole.Assistant, output); + } + } while (!string.Equals(userInput, "exit", StringComparison.OrdinalIgnoreCase)); + break; + case 2: + // ----------------------------TEXT GENERATION---------------------------- + // List of available text generation models + Dictionary textGenerationModelOptions = new() + { + { 1, "amazon.titan-text-premier-v1:0" }, + { 2, "mistral.mistral-7b-instruct-v0:2" }, + { 3, "ai21.jamba-instruct-v1:0" }, + { 4, "anthropic.claude-v2:1" }, + { 5, "cohere.command-text-v14" }, + { 6, "meta.llama3-8b-instruct-v1:0" }, + { 7, "cohere.command-r-plus-v1:0" }, + { 8, "ai21.j2-ultra-v1" } + }; + + // Display the text generation model options + Console.WriteLine("Available text generation models:"); + foreach (var option in textGenerationModelOptions) + { + Console.WriteLine($"{option.Key}. {option.Value}"); + } + + Console.Write("Enter the number of the text generation model you want to use: "); + int chosenTextGenerationModel; + while (!int.TryParse(Console.ReadLine(), out chosenTextGenerationModel) || !textGenerationModelOptions.ContainsKey(chosenTextGenerationModel)) + { + Console.WriteLine("Invalid input. Please enter a valid number from the list."); + Console.Write("Enter the number of the text generation model you want to use: "); + } + + Console.Write("Text Generation Prompt: "); + string UserPrompt2 = Console.ReadLine() ?? ""; + + var kernel2 = Kernel.CreateBuilder().AddBedrockTextGenerationService(textGenerationModelOptions[chosenTextGenerationModel]).Build(); + + var textGenerationService = kernel2.GetRequiredService(); + var textGeneration = await textGenerationService.GetTextContentsAsync(UserPrompt2).ConfigureAwait(false); + if (textGeneration.Count > 0) + { + var firstTextContent = textGeneration[0]; + if (firstTextContent != null) + { + Console.WriteLine("Text Generation Answer: " + firstTextContent.Text); + } + else + { + Console.WriteLine("Text Generation Answer: (none)"); + } + } + else + { + Console.WriteLine("Text Generation Answer: (No output text)"); + } + break; + case 3: + // ----------------------------STREAM CHAT COMPLETION---------------------------- + string userInput2; + ChatHistory chatHistory2 = new(); + + // List of available stream chat completion models + Dictionary streamChatCompletionModelOptions = new() + { + { 1, "mistral.mistral-7b-instruct-v0:2" }, + { 2, "amazon.titan-text-premier-v1:0" }, + { 3, "anthropic.claude-v2" }, + { 4, "anthropic.claude-3-sonnet-20240229-v1:0" }, + { 5, "cohere.command-r-plus-v1:0" }, + { 6, "meta.llama3-8b-instruct-v1:0" } + }; + + // Display the stream chat completion model options + Console.WriteLine("Available stream chat completion models:"); + foreach (var option in streamChatCompletionModelOptions) + { + Console.WriteLine($"{option.Key}. 
{option.Value}"); + } + + Console.Write("Enter the number of the stream chat completion model you want to use: "); + int chosenStreamChatCompletionModel; + while (!int.TryParse(Console.ReadLine(), out chosenStreamChatCompletionModel) || !streamChatCompletionModelOptions.ContainsKey(chosenStreamChatCompletionModel)) + { + Console.WriteLine("Invalid input. Please enter a valid number from the list."); + Console.Write("Enter the number of the stream chat completion model you want to use: "); + } + + var kernel4 = Kernel.CreateBuilder().AddBedrockChatCompletionService(streamChatCompletionModelOptions[chosenStreamChatCompletionModel]).Build(); + var chatStreamCompletionService = kernel4.GetRequiredService(); + + do + { + Console.Write("Enter a prompt (or 'exit' to quit): "); + userInput2 = Console.ReadLine() ?? "exit"; + + if (!string.Equals(userInput2, "exit", StringComparison.OrdinalIgnoreCase)) + { + chatHistory2.AddMessage(AuthorRole.User, userInput2); + var result = chatStreamCompletionService.GetStreamingChatMessageContentsAsync(chatHistory2).ConfigureAwait(false); + string output = ""; + await foreach (var message in result) + { + Console.Write($"{message.Content}"); + Thread.Sleep(50); + output += message.Content; + } + Console.WriteLine(); + chatHistory2.AddMessage(AuthorRole.Assistant, output); + } + } while (!string.Equals(userInput2, "exit", StringComparison.OrdinalIgnoreCase)); + break; + case 4: + // ----------------------------STREAM TEXT GENERATION---------------------------- + // List of available stream text generation models + Dictionary streamTextGenerationModelOptions = new() + { + { 1, "amazon.titan-text-premier-v1:0" }, + { 2, "anthropic.claude-v2" }, + { 3, "mistral.mistral-7b-instruct-v0:2" }, + { 4, "cohere.command-text-v14" }, + { 5, "cohere.command-r-plus-v1:0" }, + { 6, "meta.llama3-8b-instruct-v1:0" } + }; + + // Display the stream text generation model options + Console.WriteLine("Available stream text generation models:"); + foreach (var option in streamTextGenerationModelOptions) + { + Console.WriteLine($"{option.Key}. {option.Value}"); + } + + Console.Write("Enter the number of the stream text generation model you want to use: "); + int chosenStreamTextGenerationModel; + while (!int.TryParse(Console.ReadLine(), out chosenStreamTextGenerationModel) || !streamTextGenerationModelOptions.ContainsKey(chosenStreamTextGenerationModel)) + { + Console.WriteLine("Invalid input. Please enter a valid number from the list."); + Console.Write("Enter the number of the stream text generation model you want to use: "); + } + + Console.Write("Stream Text Generation Prompt: "); + string UserPrompt3 = Console.ReadLine() ?? 
""; + + var kernel3 = Kernel.CreateBuilder().AddBedrockTextGenerationService(streamTextGenerationModelOptions[chosenStreamTextGenerationModel]).Build(); + + var streamTextGenerationService = kernel3.GetRequiredService(); + var streamTextGeneration = streamTextGenerationService.GetStreamingTextContentsAsync(UserPrompt3).ConfigureAwait(true); + await foreach (var textContent in streamTextGeneration) + { + Console.Write(textContent.Text); + Thread.Sleep(50); + } + + Console.WriteLine(); + break; +} diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/BedrockKernelBuilderExtensionTests.cs b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/BedrockKernelBuilderExtensionTests.cs new file mode 100644 index 000000000000..affb4d8575de --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/BedrockKernelBuilderExtensionTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Connectors.Amazon.Extensions; +using Connectors.Amazon.Services; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Amazon.Services; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace Connectors.Amazon.UnitTests; + +/// +/// Kernel Builder Extension Tests for Bedrock. +/// +public class BedrockKernelBuilderExtensionTests +{ + /// + /// Checks that AddBedrockTextGenerationService builds a proper kernel. + /// + [Fact] + public void AddBedrockTextGenerationCreatesService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + builder.AddBedrockTextGenerationService("amazon.titan-text-premier-v1:0"); + + // Act + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + // Assert + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } + /// + /// Checks that AddBedrockChatCompletionService builds a proper kernel. + /// + [Fact] + public void AddBedrockChatCompletionCreatesService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + builder.AddBedrockChatCompletionService("amazon.titan-text-premier-v1:0"); + + // Act + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + // Assert + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj new file mode 100644 index 000000000000..4732d41edf4a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj @@ -0,0 +1,23 @@ + + + + true + false + net8.0 + enable + enable + + + + + + + + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockChatCompletionServiceTests.cs new file mode 100644 index 000000000000..402201140808 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockChatCompletionServiceTests.cs @@ -0,0 +1,642 @@ +// Copyright (c) Microsoft. All rights reserved. 
+
+using Amazon.BedrockRuntime;
+using Amazon.BedrockRuntime.Model;
+using Amazon.Runtime.Endpoints;
+using Connectors.Amazon.Extensions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Services;
+using Moq;
+using Xunit;
+
+namespace Connectors.Amazon.UnitTests.Services;
+
+/// <summary>
+/// Unit tests for Bedrock Chat Completion Service.
+/// </summary>
+public class BedrockChatCompletionServiceTests
+{
+    private static ChatHistory CreateSampleChatHistory()
+    {
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("Hello");
+        chatHistory.AddAssistantMessage("Hi");
+        chatHistory.AddUserMessage("How are you?");
+        chatHistory.AddSystemMessage("You are an AI Assistant");
+        return chatHistory;
+    }
+
+    /// <summary>
+    /// Checks that the model ID is added to the list of service attributes when the service is registered.
+    /// </summary>
+    [Fact]
+    public void AttributesShouldContainModelId()
+    {
+        // Arrange & Act
+        string modelId = "amazon.titan-text-premier-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Assert
+        Assert.Equal(modelId, service.Attributes[AIServiceExtensions.ModelIdKey]);
+    }
+
+    /// <summary>
+    /// Checks that GetChatMessageContentsAsync calls and correctly handles outputs from ConverseAsync.
+    /// </summary>
+    [Fact]
+    public async Task GetChatMessageContentsAsyncShouldReturnChatMessageContentsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-embed-text-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "Hello, world!" } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true);
+
+        // Assert
+        Assert.Single(result);
+        Assert.Equal(AuthorRole.Assistant, result[0].Role);
+        Assert.Single(result[0].Items);
+        Assert.Equal("Hello, world!", result[0].Items[0].ToString());
+    }
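+
+    // The tests below share a pattern: mock IAmazonBedrockRuntime, stub the endpoint lookup
+    // and the Converse/ConverseStream call, then assert on the ConverseRequest the service built.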
+
+    /// <summary>
+    /// Checks that GetStreamingChatMessageContentsAsync calls and correctly handles outputs from ConverseStreamAsync.
+    /// </summary>
+    [Fact]
+    public async Task GetStreamingChatMessageContentsAsyncShouldReturnStreamedChatMessageContentsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-text-lite-v1";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseStreamAsync(It.IsAny<ConverseStreamRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseStreamResponse
+            {
+                Stream = new ConverseStreamOutput(new MemoryStream())
+            });
+
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        List<StreamingChatMessageContent> output = new();
+        var result = service.GetStreamingChatMessageContentsAsync(chatHistory).ConfigureAwait(true);
+
+        // Assert
+        await foreach (var item in result)
+        {
+            Assert.NotNull(item);
+            Assert.NotNull(item.Content);
+            Assert.NotNull(item.Role);
+            output.Add(item);
+        }
+        Assert.NotNull(output);
+        Assert.NotNull(service.GetModelId());
+        Assert.NotNull(service.Attributes);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task TitanGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-text-lite-v1";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.3f },
+                { "topP", 0.8f },
+                { "maxTokenCount", 510 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["topP"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["maxTokenCount"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+
+    /// <summary>
+    /// Checks that the roles from the chat history are correctly assigned and labeled for the converse calls.
+    /// </summary>
+    [Fact]
+    public async Task GetChatMessageContentsAsyncShouldAssignCorrectRolesAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-embed-text-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true);
+
+        // Assert
+        Assert.Single(result);
+        Assert.Equal(AuthorRole.Assistant, result[0].Role);
+        Assert.Single(result[0].Items);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+    }
+
+    /// <summary>
+    /// Checks that the chat history is given the correct values through calling GetChatMessageContentsAsync.
+    /// </summary>
+    [Fact]
+    public async Task GetChatMessageContentsAsyncShouldHaveProperChatHistoryAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-embed-text-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+
+        // Set up the mock ConverseAsync to return multiple responses
+        mockBedrockApi.SetupSequence(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.User,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            })
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "That's great to hear!" } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result1 = await service.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true);
+        var result2 = await service.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true);
+
+        // Assert
+        string? chatResult1 = result1[0].Content;
+        Assert.NotNull(chatResult1);
+        chatHistory.AddAssistantMessage(chatResult1);
+        string? chatResult2 = result2[0].Content;
+        Assert.NotNull(chatResult2);
+        chatHistory.AddUserMessage(chatResult2);
+        Assert.Equal(2, result1.Count + result2.Count);
+
+        // Check the first result
+        Assert.Equal(AuthorRole.User, result1[0].Role);
+        Assert.Single(result1[0].Items);
+        Assert.Equal("I'm doing well.", result1[0].Items[0].ToString());
+
+        // Check the second result
+        Assert.Equal(AuthorRole.Assistant, result2[0].Role);
+        Assert.Single(result2[0].Items);
+        Assert.Equal("That's great to hear!", result2[0].Items[0].ToString());
+
+        // Check the chat history
+        Assert.Equal(6, chatHistory.Count); // Use the Count property to get the number of messages
+
+        Assert.Equal(AuthorRole.User, chatHistory[0].Role); // Use the indexer to access individual messages
+        Assert.Equal("Hello", chatHistory[0].Items[0].ToString());
+
+        Assert.Equal(AuthorRole.Assistant, chatHistory[1].Role);
+        Assert.Equal("Hi", chatHistory[1].Items[0].ToString());
+
+        Assert.Equal(AuthorRole.User, chatHistory[2].Role);
+        Assert.Equal("How are you?", chatHistory[2].Items[0].ToString());
+
+        Assert.Equal(AuthorRole.System, chatHistory[3].Role);
+        Assert.Equal("You are an AI Assistant", chatHistory[3].Items[0].ToString());
+
+        Assert.Equal(AuthorRole.Assistant, chatHistory[4].Role);
+        Assert.Equal("I'm doing well.", chatHistory[4].Items[0].ToString());
+
+        Assert.Equal(AuthorRole.User, chatHistory[5].Role);
+        Assert.Equal("That's great to hear!", chatHistory[5].Items[0].ToString());
+    }
+
+    /// <summary>
+    /// Checks that error handling is present for an empty chat history.
+    /// </summary>
+    [Fact]
+    public async Task ShouldThrowArgumentExceptionIfChatHistoryIsEmptyAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-embed-text-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var chatHistory = new ChatHistory();
+        mockBedrockApi.SetupSequence(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "sample" } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+
+        // Act & Assert
+        await Assert.ThrowsAsync<ArgumentException>(
+            () => service.GetChatMessageContentsAsync(chatHistory)).ConfigureAwait(true);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task ClaudeGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "anthropic.claude-chat-completion";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.7f },
+                { "top_p", 0.7f },
+                { "max_tokens_to_sample", 512 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["top_p"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["max_tokens_to_sample"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task LlamaGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "meta.llama3-text-lite-v1";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.7f },
+                { "top_p", 0.6f },
+                { "max_gen_len", 256 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["top_p"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["max_gen_len"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task MistralGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "mistral.mistral-text-lite-v1";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.5f },
+                { "top_p", 0.9f },
+                { "max_tokens", 512 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["top_p"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task CommandRGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "cohere.command-r-chat-stuff";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.7f },
+                { "p", 0.9f },
+                { "max_tokens", 202 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["p"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the chat completion call.
+    /// </summary>
+    [Fact]
+    public async Task JambaGetChatMessageContentsAsyncShouldReturnChatMessageWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "ai21.jamba-chat-stuff";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.7f },
+                { "top_p", 0.9f },
+                { "max_tokens", 202 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new ConverseResponse
+            {
+                Output = new ConverseOutput
+                {
+                    Message = new Message
+                    {
+                        Role = ConversationRole.Assistant,
+                        Content = new List<ContentBlock> { new() { Text = "I'm doing well." } }
+                    }
+                },
+                Metrics = new ConverseMetrics(),
+                StopReason = StopReason.Max_tokens,
+                Usage = new TokenUsage()
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<IChatCompletionService>();
+        var chatHistory = CreateSampleChatHistory();
+
+        // Act
+        var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "ConverseAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is ConverseRequest);
+        Assert.NotNull(invocation);
+        ConverseRequest converseRequest = (ConverseRequest)invocation.Arguments[0];
+        Assert.Single(result);
+        Assert.Equal("I'm doing well.", result[0].Items[0].ToString());
+        Assert.Equal(executionSettings.ExtensionData["temperature"], converseRequest?.InferenceConfig.Temperature);
+        Assert.Equal(executionSettings.ExtensionData["top_p"], converseRequest?.InferenceConfig.TopP);
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], converseRequest?.InferenceConfig.MaxTokens);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockTextGenerationServiceTests.cs b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockTextGenerationServiceTests.cs
new file mode 100644
index 000000000000..911fd742c92a
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Services/BedrockTextGenerationServiceTests.cs
@@ -0,0 +1,600 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text;
+using System.Text.Json;
+using Amazon;
+using Amazon.BedrockRuntime;
+using Amazon.BedrockRuntime.Model;
+using Amazon.Runtime.Endpoints;
+using Connectors.Amazon.Extensions;
+using Connectors.Amazon.Models.AI21;
+using Connectors.Amazon.Models.Amazon;
+using Connectors.Amazon.Models.Anthropic;
+using Connectors.Amazon.Models.Cohere;
+using Connectors.Amazon.Models.Meta;
+using Connectors.Amazon.Models.Mistral;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Services;
+using Microsoft.SemanticKernel.TextGeneration;
+using Moq;
+using Xunit;
+
+namespace Connectors.Amazon.UnitTests.Services;
+
+/// <summary>
+/// Unit tests for BedrockTextGenerationService.
+/// </summary>
+public class BedrockTextGenerationServiceTests
+{
+    /// <summary>
+    /// Checks that the model ID is added to the list of service attributes when the service is registered.
+    /// </summary>
+    [Fact]
+    public void AttributesShouldContainModelId()
+    {
+        // Arrange & Act
+        string modelId = "amazon.titan-text-premier-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+
+        // Assert
+        Assert.Equal(modelId, service.Attributes[AIServiceExtensions.ModelIdKey]);
+    }
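+
+    // These tests mock InvokeModelAsync by serializing a provider-specific response type
+    // (TitanTextResponse, ClaudeResponse, ...) into the response body, then parse the captured
+    // request body to verify how execution settings were mapped to each provider's JSON schema.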
+
+    /// <summary>
+    /// Checks that GetTextContentsAsync calls and correctly handles outputs from InvokeModelAsync.
+    /// </summary>
+    [Fact]
+    public async Task GetTextContentsAsyncShouldReturnTextContentsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-text-premier-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new TitanTextResponse
+                {
+                    InputTextTokenCount = 5,
+                    Results = new()
+                    {
+                        new()
+                        {
+                            TokenCount = 10,
+                            OutputText = "This is a mock output.",
+                            CompletionReason = "stop"
+                        }
+                    }
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt).ConfigureAwait(true);
+
+        // Assert
+        Assert.Single(result);
+        Assert.Equal("This is a mock output.", result[0].Text);
+    }
+
+    /// <summary>
+    /// Checks that GetStreamingTextContentsAsync calls and correctly handles outputs from InvokeModelAsync.
+    /// </summary>
+    [Fact]
+    public async Task GetStreamingTextContentsAsyncShouldReturnStreamedTextContentsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-text-premier-v1:0";
+        string prompt = "Write a short greeting.";
+
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelWithResponseStreamAsync(It.IsAny<InvokeModelWithResponseStreamRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelWithResponseStreamResponse()
+            {
+                Body = new ResponseStream(new MemoryStream()),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+
+        // Act
+        List<StreamingTextContent> result = new();
+        var output = service.GetStreamingTextContentsAsync(prompt).ConfigureAwait(true);
+
+        // Assert
+        await foreach (var item in output)
+        {
+            Assert.NotNull(item);
+            Assert.NotNull(item.Text);
+            result.Add(item);
+        }
+        Assert.NotNull(result);
+        Assert.NotNull(service.GetModelId());
+    }
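+
+    // Note: the stubbed ResponseStream above wraps an empty MemoryStream, so the streaming
+    // loop completes without yielding chunks; the test exercises the call path and metadata.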
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the text generation call with Amazon Titan.
+    /// </summary>
+    [Fact]
+    public async Task TitanGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "amazon.titan-text-lite-v1";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.1f },
+                { "topP", 0.95f },
+                { "maxTokenCount", 256 },
+                { "stopSequences", new List<string> { "" } }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new TitanTextResponse
+                {
+                    InputTextTokenCount = 5,
+                    Results = new()
+                    {
+                        new()
+                        {
+                            TokenCount = 10,
+                            OutputText = "This is a mock output.",
+                            CompletionReason = "stop"
+                        }
+                    }
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("This is a mock output.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("textGenerationConfig", out var textGenerationConfig));
+        if (textGenerationConfig.TryGetProperty("temperature", out var temperatureProperty))
+        {
+            Assert.Equal(executionSettings.ExtensionData["temperature"], (float)temperatureProperty.GetDouble());
+        }
+
+        if (textGenerationConfig.TryGetProperty("topP", out var topPProperty))
+        {
+            Assert.Equal(executionSettings.ExtensionData["topP"], (float)topPProperty.GetDouble());
+        }
+
+        if (textGenerationConfig.TryGetProperty("maxTokenCount", out var maxTokenCountProperty))
+        {
+            Assert.Equal(executionSettings.ExtensionData["maxTokenCount"], maxTokenCountProperty.GetInt32());
+        }
+
+        if (textGenerationConfig.TryGetProperty("stopSequences", out var stopSequencesProperty))
+        {
+            var stopSequences = stopSequencesProperty.EnumerateArray().Select(e => e.GetString()).ToList();
+            Assert.Equal(executionSettings.ExtensionData["stopSequences"], stopSequences);
+        }
+    }
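+
+    // The TryGetProperty guards above only assert a setting when the connector actually
+    // serialized it into textGenerationConfig.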
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the text generation call with AI21 Labs Jamba.
+    /// </summary>
+    [Fact]
+    public async Task AI21JambaGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "ai21.jamba-instruct-v1:0";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.8 },
+                { "top_p", 0.95 },
+                { "max_tokens", 256 },
+                { "stop", new List<string> { "" } }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new AI21JambaResponse.AI21TextResponse
+                {
+                    Id = "my-request-id",
+                    Choices = new()
+                    {
+                        new()
+                        {
+                            Index = 0,
+                            Message = new AI21JambaResponse.AI21TextResponse.Message
+                            {
+                                Role = "assistant",
+                                Content = "Hello! This is a mock AI21 response."
+                            },
+                            FinishReason = "stop"
+                        }
+                    },
+                    Use = new AI21JambaResponse.AI21TextResponse.Usage
+                    {
+                        PromptTokens = 10,
+                        CompletionTokens = 15,
+                        TotalTokens = 25
+                    }
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("Hello! This is a mock AI21 response.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("temperature", out var temperatureProperty));
+        Assert.Equal(executionSettings.ExtensionData["temperature"], temperatureProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("top_p", out var topPProperty));
+        Assert.Equal(executionSettings.ExtensionData["top_p"], topPProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("max_tokens", out var maxTokensProperty));
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], maxTokensProperty.GetInt32());
+
+        Assert.True(requestBodyRoot.TryGetProperty("stop", out var stopProperty));
+        var stopSequences = stopProperty.EnumerateArray().Select(e => e.GetString()).ToList();
+        Assert.Equal(executionSettings.ExtensionData["stop"], stopSequences);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the text generation call with Anthropic Claude.
+    /// </summary>
+    [Fact]
+    public async Task ClaudeGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "anthropic.claude-text-generation.model-id-only-needs-proper-prefix";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.8 },
+                { "top_p", 0.95 },
+                { "max_tokens_to_sample", 256 },
+                { "stop_sequences", new List<string> { "" } }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new ClaudeResponse
+                {
+                    Completion = "Hello! This is a mock Claude response.",
+                    StopReason = "stop_sequence",
+                    Stop = ""
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.\n\nHuman: \n\nAssistant:";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("Hello! This is a mock Claude response.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("temperature", out var temperatureProperty));
+        Assert.Equal(executionSettings.ExtensionData["temperature"], temperatureProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("top_p", out var topPProperty));
+        Assert.Equal(executionSettings.ExtensionData["top_p"], topPProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("max_tokens_to_sample", out var maxTokensToSampleProperty));
+        Assert.Equal(executionSettings.ExtensionData["max_tokens_to_sample"], maxTokensToSampleProperty.GetInt32());
+
+        Assert.True(requestBodyRoot.TryGetProperty("stop_sequences", out var stopSequencesProperty));
+        var stopSequences = stopSequencesProperty.EnumerateArray().Select(e => e.GetString()).ToList();
+        Assert.Equal(executionSettings.ExtensionData["stop_sequences"], stopSequences);
+    }
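+
+    // Claude's legacy text-completion API expects the "\n\nHuman:" / "\n\nAssistant:" framing,
+    // hence the prompt format used above.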
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the text generation call with Cohere Command.
+    /// </summary>
+    [Fact]
+    public async Task CohereCommandGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "cohere.command-text-generation";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.8 },
+                { "p", 0.95 },
+                { "max_tokens", 256 },
+                { "stop_sequences", new List<string> { "" } }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new CommandTextResponse
+                {
+                    Id = "my-request-id",
+                    Prompt = "Write a greeting.",
+                    Generations = new()
+                    {
+                        new()
+                        {
+                            Id = "generation-id",
+                            Text = "Hello! This is a mock Cohere Command response.",
+                            FinishReason = "COMPLETE",
+                            IsFinished = true
+                        }
+                    }
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("Hello! This is a mock Cohere Command response.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("temperature", out var temperatureProperty));
+        Assert.Equal(executionSettings.ExtensionData["temperature"], temperatureProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("p", out var topPProperty));
+        Assert.Equal(executionSettings.ExtensionData["p"], topPProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("max_tokens", out var maxTokensProperty));
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], maxTokensProperty.GetInt32());
+
+        Assert.True(requestBodyRoot.TryGetProperty("stop_sequences", out var stopSequencesProperty));
+        var stopSequences = stopSequencesProperty.EnumerateArray().Select(e => e.GetString()).ToList();
+        Assert.Equal(executionSettings.ExtensionData["stop_sequences"], stopSequences);
+    }
+
+    /// <summary>
+    /// Checks that the prompt execution settings are correctly registered for the text generation call with Meta Llama3.
+    /// </summary>
+    [Fact]
+    public async Task LlamaGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "meta.llama3-text-generation";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.8 },
+                { "top_p", 0.95 },
+                { "max_gen_len", 256 }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ServiceOperationEndpointParameters>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new LlamaTextResponse
+                {
+                    Generation = "Hello! This is a mock Llama response.",
+                    PromptTokenCount = 10,
+                    GenerationTokenCount = 15,
+                    StopReason = "stop"
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("Hello! This is a mock Llama response.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("temperature", out var temperatureProperty));
+        Assert.Equal(executionSettings.ExtensionData["temperature"], temperatureProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("top_p", out var topPProperty));
+        Assert.Equal(executionSettings.ExtensionData["top_p"], topPProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("max_gen_len", out var maxGenLenProperty));
+        Assert.Equal(executionSettings.ExtensionData["max_gen_len"], maxGenLenProperty.GetInt32());
+    }
+    /// </summary>
+    [Fact]
+    public async Task MistralGetTextContentsAsyncShouldReturnTextContentsAsyncWithPromptExecutionSettingsAsync()
+    {
+        // Arrange
+        string modelId = "mistral.mistral-text-generation";
+        var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
+        var executionSettings = new PromptExecutionSettings()
+        {
+            ModelId = modelId,
+            ExtensionData = new Dictionary<string, object>()
+            {
+                { "temperature", 0.8 },
+                { "top_p", 0.95 },
+                { "max_tokens", 256 },
+                { "stop", new List<string> { "" } }
+            }
+        };
+        mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<InvokeModelRequest>()))
+            .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
+            {
+                URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
+            });
+        mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny<InvokeModelRequest>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(new InvokeModelResponse
+            {
+                Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new MistralTextResponse
+                {
+                    Outputs =
+                    [
+                        new() {
+                            Text = "Hello! This is a mock Mistral response.",
+                            StopReason = "stop_sequence"
+                        }
+                    ]
+                }))),
+                ContentType = "application/json"
+            });
+        var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId, mockBedrockApi.Object).Build();
+        var service = kernel.GetRequiredService<ITextGenerationService>();
+        var prompt = "Write a greeting.";
+
+        // Act
+        var result = await service.GetTextContentsAsync(prompt, executionSettings).ConfigureAwait(true);
+
+        // Assert
+        InvokeModelRequest invokeModelRequest = new();
+        var invocation = mockBedrockApi.Invocations
+            .Where(i => i.Method.Name == "InvokeModelAsync")
+            .SingleOrDefault(i => i.Arguments.Count > 0 && i.Arguments[0] is InvokeModelRequest);
+        if (invocation != null)
+        {
+            invokeModelRequest = (InvokeModelRequest)invocation.Arguments[0];
+        }
+        Assert.Single(result);
+        Assert.Equal("Hello! This is a mock Mistral response.", result[0].Text);
+        Assert.NotNull(invokeModelRequest);
+
+        using var requestBodyStream = invokeModelRequest.Body;
+        var requestBodyJson = await JsonDocument.ParseAsync(requestBodyStream).ConfigureAwait(true);
+        var requestBodyRoot = requestBodyJson.RootElement;
+        Assert.True(requestBodyRoot.TryGetProperty("temperature", out var temperatureProperty));
+        Assert.Equal(executionSettings.ExtensionData["temperature"], temperatureProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("top_p", out var topPProperty));
+        Assert.Equal(executionSettings.ExtensionData["top_p"], topPProperty.GetDouble());
+
+        Assert.True(requestBodyRoot.TryGetProperty("max_tokens", out var maxTokensProperty));
+        Assert.Equal(executionSettings.ExtensionData["max_tokens"], maxTokensProperty.GetInt32());
+
+        Assert.True(requestBodyRoot.TryGetProperty("stop", out var stopSequencesProperty));
+        var stopSequences = stopSequencesProperty.EnumerateArray().Select(e => e.GetString()).ToList();
+        Assert.Equal(executionSettings.ExtensionData["stop"], stopSequences);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientIOService.cs
new file mode 100644
index 000000000000..9884dbc22207
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientIOService.cs
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft. All rights reserved.
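+
+// Illustrative note (an editorial sketch, not part of the connector's API surface):
+// this file routes a Bedrock model ID such as "cohere.command-r-plus-v1:0" to a
+// provider-specific IO service by splitting the ID on '.' and matching known name
+// prefixes. A hypothetical call, assuming the class defined in this file:
+//
+//     var ioService = new BedrockClientIOService().GetIOService("cohere.command-r-plus-v1:0");
+//     // yields a CohereCommandRIOService; an unrecognized ID throws ArgumentException.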
+ +using Connectors.Amazon.Models; +using Connectors.Amazon.Models.AI21; +using Connectors.Amazon.Models.Amazon; +using Connectors.Amazon.Models.Anthropic; +using Connectors.Amazon.Models.Cohere; +using Connectors.Amazon.Models.Meta; +using Connectors.Amazon.Models.Mistral; + +namespace Connectors.Amazon.Bedrock.Core; + +/// +/// Utilities to get the model IO service and model provider. Used by Bedrock service clients. +/// +public class BedrockClientIOService +{ + /// + /// Gets the model IO service for body conversion. + /// + /// + /// + /// + public IBedrockModelIOService GetIOService(string modelId) + { + string[] parts = modelId.Split('.'); //modelId looks like "amazon.titan-text-premier-v1:0" + string modelProvider = parts[0]; + string modelName = parts.Length > 1 ? parts[1] : string.Empty; + + switch (modelProvider) + { + case "ai21": + if (modelName.StartsWith("jamba", StringComparison.OrdinalIgnoreCase)) + { + return new AI21JambaIOService(); + } + if (modelName.StartsWith("j2-", StringComparison.OrdinalIgnoreCase)) + { + return new AI21JurassicIOService(); + } + throw new ArgumentException($"Unsupported AI21 model: {modelId}"); + case "amazon": + if (modelName.StartsWith("titan-", StringComparison.OrdinalIgnoreCase)) + { + return new AmazonIOService(); + } + throw new ArgumentException($"Unsupported Amazon model: {modelId}"); + case "anthropic": + if (modelName.StartsWith("claude-", StringComparison.OrdinalIgnoreCase)) + { + return new AnthropicIOService(); + } + throw new ArgumentException($"Unsupported Anthropic model: {modelId}"); + case "cohere": + if (modelName.StartsWith("command-r", StringComparison.OrdinalIgnoreCase)) + { + return new CohereCommandRIOService(); + } + if (modelName.StartsWith("command-", StringComparison.OrdinalIgnoreCase)) + { + return new CohereCommandIOService(); + } + throw new ArgumentException($"Unsupported Cohere model: {modelId}"); + case "meta": + if (modelName.StartsWith("llama3-", StringComparison.OrdinalIgnoreCase)) + { + return new MetaIOService(); + } + throw new ArgumentException($"Unsupported Meta model: {modelId}"); + case "mistral": + if (modelName.StartsWith("mistral-", StringComparison.OrdinalIgnoreCase)) + { + return new MistralIOService(); + } + if (modelName.StartsWith("mixtral-", StringComparison.OrdinalIgnoreCase)) + { + return new MistralIOService(); + } + throw new ArgumentException($"Unsupported Mistral model: {modelId}"); + default: + throw new ArgumentException($"Unsupported model provider: {modelProvider}"); + } + } + /// + /// Gets the model provider from modelId. + /// + /// + /// + public string GetModelProvider(string modelId) + { + string[] parts = modelId.Split('.'); //modelId looks like "amazon.titan-text-premier-v1:0" + return parts[0]; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientUtilities.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientUtilities.cs new file mode 100644 index 000000000000..40d651471e6f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/BedrockClientUtilities.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Net; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Bedrock.Core; + +/// +/// Utility functions for the Bedrock clients. +/// +public class BedrockClientUtilities +{ + /// + /// Convert the Http Status Code in Converse Response to the Activity Status Code for Semantic Kernel activity. 
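+ /// For example, 200 (OK) maps to ActivityStatusCode.Ok, 404 or 503 map to ActivityStatusCode.Error,
+ /// and anything outside the 2xx/4xx/5xx ranges (such as a 302 redirect) maps to ActivityStatusCode.Unset.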
+ /// + /// + /// + public ActivityStatusCode ConvertHttpStatusCodeToActivityStatusCode(HttpStatusCode httpStatusCode) + { + if ((int)httpStatusCode >= 200 && (int)httpStatusCode < 300) + { + // 2xx status codes represent success + return ActivityStatusCode.Ok; + } + else if ((int)httpStatusCode >= 400 && (int)httpStatusCode < 600) + { + // 4xx and 5xx status codes represent errors + return ActivityStatusCode.Error; + } + else + { + // Any other status code is considered unset + return ActivityStatusCode.Unset; + } + } + /// + /// Map Conversation role (value) to author role to build message content for semantic kernel output. + /// + /// + /// + /// + public AuthorRole MapConversationRoleToAuthorRole(string role) + { + return role switch + { + "user" => AuthorRole.User, + "assistant" => AuthorRole.Assistant, + "system" => AuthorRole.System, + _ => throw new ArgumentOutOfRangeException(nameof(role), $"Invalid role: {role}") + }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockChatCompletionClient.cs new file mode 100644 index 000000000000..77c9a5f65432 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockChatCompletionClient.cs @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using Amazon.BedrockRuntime; +using Amazon.BedrockRuntime.Model; +using Connectors.Amazon.Bedrock.Core; +using Connectors.Amazon.Models; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Diagnostics; + +namespace Microsoft.SemanticKernel.Connectors.Amazon.Core; + +/// +/// Represents a client for interacting with the chat completion through Bedrock. +/// +internal sealed class BedrockChatCompletionClient +{ + private readonly string _modelId; + private readonly string _modelProvider; + private readonly IAmazonBedrockRuntime _bedrockApi; + private readonly IBedrockModelIOService _ioService; + private readonly BedrockClientUtilities _clientUtilities; + private Uri? _chatGenerationEndpoint; + + /// + /// Builds the client object and registers the model input-output service given the user's passed in model ID. + /// + /// + /// + /// + public BedrockChatCompletionClient(string modelId, IAmazonBedrockRuntime bedrockApi) + { + var clientService = new BedrockClientIOService(); + this._modelId = modelId; + this._bedrockApi = bedrockApi; + this._ioService = clientService.GetIOService(modelId); + this._modelProvider = clientService.GetModelProvider(modelId); + this._clientUtilities = new BedrockClientUtilities(); + } + /// + /// Generates a chat message based on the provided chat history and execution settings. + /// + /// The chat history to use for generating the chat message. + /// The execution settings for the chat generation. + /// The Semantic Kernel instance. + /// The cancellation token. + /// The generated chat message. + /// Thrown when the chat history is null or empty. + /// Thrown when an error occurs during the chat generation process. + internal async Task> GenerateChatMessageAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? 
kernel = null, + CancellationToken cancellationToken = default) + { + Verify.NotNullOrEmpty(chatHistory); + ConverseRequest converseRequest = this._ioService.GetConverseRequest(this._modelId, chatHistory, executionSettings); + var regionEndpoint = this._bedrockApi.DetermineServiceOperationEndpoint(converseRequest).URL; + this._chatGenerationEndpoint = new Uri(regionEndpoint); + ConverseResponse? response = null; + using var activity = ModelDiagnostics.StartCompletionActivity( + this._chatGenerationEndpoint, this._modelId, this._modelProvider, chatHistory, executionSettings); + ActivityStatusCode activityStatus; + try + { + response = await this._bedrockApi.ConverseAsync(converseRequest, cancellationToken).ConfigureAwait(false); + if (activity is not null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + activity.SetPromptTokenUsage(response.Usage.InputTokens); + activity.SetCompletionTokenUsage(response.Usage.OutputTokens); + } + } + catch (Exception ex) + { + Console.WriteLine($"ERROR: Can't converse with '{this._modelId}'. Reason: {ex.Message}"); + if (activity is not null) + { + activity.SetError(ex); + if (response != null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + activity.SetPromptTokenUsage(response.Usage.InputTokens); + activity.SetCompletionTokenUsage(response.Usage.OutputTokens); + } + else + { + // If response is null, set a default status or leave it unset + activity.SetStatus(ActivityStatusCode.Error); // or ActivityStatusCode.Unset + } + } + throw; + } + IReadOnlyList chatMessages = this.ConvertToMessageContent(response).ToList(); + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity?.SetStatus(activityStatus); + activity?.SetCompletionResponse(chatMessages, response.Usage.InputTokens, response.Usage.OutputTokens); + return chatMessages; + } + /// + /// Converts the ConverseResponse object as outputted by the Bedrock Runtime API call to a ChatMessageContent for the Semantic Kernel. + /// + /// ConverseResponse object outputted by Bedrock. + /// + private ChatMessageContent[] ConvertToMessageContent(ConverseResponse response) + { + if (response.Output.Message == null) + { + return []; + } + var message = response.Output.Message; + return new[] + { + new ChatMessageContent + { + Role = this._clientUtilities.MapConversationRoleToAuthorRole(message.Role.Value), + Items = CreateChatMessageContentItemCollection(message.Content) + } + }; + } + private static ChatMessageContentItemCollection CreateChatMessageContentItemCollection(List contentBlocks) + { + var itemCollection = new ChatMessageContentItemCollection(); + foreach (var contentBlock in contentBlocks) + { + itemCollection.Add(new TextContent(contentBlock.Text)); + } + return itemCollection; + } + + // Order of operations: + // 1. Start completion activity with semantic kernel + // 2. Call converse stream async with bedrock API + // 3. Convert output to semantic kernel's StreamingChatMessageContent + // 4. Yield return the streamed contents + // 5. End streaming activity with kernel + internal async IAsyncEnumerable StreamChatMessageAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? 
kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var converseStreamRequest = this._ioService.GetConverseStreamRequest(this._modelId, chatHistory, executionSettings); + var regionEndpoint = this._bedrockApi.DetermineServiceOperationEndpoint(converseStreamRequest).URL; + this._chatGenerationEndpoint = new Uri(regionEndpoint); + ConverseStreamResponse? response = null; + using var activity = ModelDiagnostics.StartCompletionActivity( + this._chatGenerationEndpoint, this._modelId, this._modelProvider, chatHistory, executionSettings); + ActivityStatusCode activityStatus; + try + { + response = await this._bedrockApi.ConverseStreamAsync(converseStreamRequest, cancellationToken).ConfigureAwait(false); + if (activity is not null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + } + } + catch (Exception ex) + { + Console.WriteLine($"ERROR: Can't converse stream with '{this._modelId}'. Reason: {ex.Message}"); + if (activity is not null) + { + activity.SetError(ex); + if (response != null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + } + else + { + // If response is null, set a default status or leave it unset + activity.SetStatus(ActivityStatusCode.Error); // or ActivityStatusCode.Unset + } + } + throw; + } + List? streamedContents = activity is not null ? [] : null; + foreach (var chunk in response.Stream.AsEnumerable()) + { + if (chunk is ContentBlockDeltaEvent) + { + var c = (chunk as ContentBlockDeltaEvent)?.Delta.Text; + var content = new StreamingChatMessageContent(AuthorRole.Assistant, c); + streamedContents?.Add(content); + yield return content; + } + } + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity?.SetStatus(activityStatus); + activity?.EndStreaming(streamedContents); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockTextGenerationClient.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockTextGenerationClient.cs new file mode 100644 index 000000000000..b04093960c09 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Core/Clients/BedrockTextGenerationClient.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime; +using Amazon.BedrockRuntime.Model; +using Connectors.Amazon.Bedrock.Core; +using Connectors.Amazon.Models; +using Microsoft.SemanticKernel.Diagnostics; + +namespace Microsoft.SemanticKernel.Connectors.Amazon.Bedrock.Core; + +/// +/// Represents a client for interacting with the text generation through Bedrock. +/// +internal sealed class BedrockTextGenerationClient +{ + private readonly string _modelId; + private readonly string _modelProvider; + private readonly IAmazonBedrockRuntime _bedrockApi; + private readonly IBedrockModelIOService _ioService; + private readonly BedrockClientUtilities _clientUtilities; + private Uri? _textGenerationEndpoint; + + /// + /// Builds the client object and registers the model input-output service given the user's passed in model ID. 
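+ /// For example, a model ID of "meta.llama3-8b-instruct-v1:0" resolves the Meta IO service
+ /// and records "meta" as the model provider for diagnostics.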
+ /// + /// + /// + /// + public BedrockTextGenerationClient(string modelId, IAmazonBedrockRuntime bedrockApi) + { + var clientService = new BedrockClientIOService(); + this._modelId = modelId; + this._bedrockApi = bedrockApi; + this._ioService = clientService.GetIOService(modelId); + this._modelProvider = clientService.GetModelProvider(modelId); + this._clientUtilities = new BedrockClientUtilities(); + } + + internal async Task> InvokeBedrockModelAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(prompt); + var requestBody = this._ioService.GetInvokeModelRequestBody(this._modelId, prompt, executionSettings); + var invokeRequest = new InvokeModelRequest + { + ModelId = this._modelId, + Accept = "*/*", + ContentType = "application/json", + Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(requestBody)) + }; + var regionEndpoint = this._bedrockApi.DetermineServiceOperationEndpoint(invokeRequest).URL; + this._textGenerationEndpoint = new Uri(regionEndpoint); + InvokeModelResponse? response = null; + using var activity = ModelDiagnostics.StartCompletionActivity( + this._textGenerationEndpoint, this._modelId, this._modelProvider, prompt, executionSettings); + ActivityStatusCode activityStatus; + try + { + response = await this._bedrockApi.InvokeModelAsync(invokeRequest, cancellationToken).ConfigureAwait(false); + if (activity is not null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + } + } + catch (Exception ex) + { + Console.WriteLine($"ERROR: Can't invoke '{this._modelId}'. Reason: {ex.Message}"); + if (activity is not null) + { + activity.SetError(ex); + if (response != null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity.SetStatus(activityStatus); + } + else + { + // If response is null, set a default status or leave it unset + activity.SetStatus(ActivityStatusCode.Error); // or ActivityStatusCode.Unset + } + } + throw; + } + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(response.HttpStatusCode); + activity?.SetStatus(activityStatus); + IReadOnlyList textResponse = this._ioService.GetInvokeResponseBody(response); + activity?.SetCompletionResponse(textResponse); + return textResponse; + } + + internal async IAsyncEnumerable StreamTextAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(prompt); + var requestBody = this._ioService.GetInvokeModelRequestBody(this._modelId, prompt, executionSettings); + var invokeRequest = new InvokeModelWithResponseStreamRequest + { + ModelId = this._modelId, + Accept = "*/*", + ContentType = "application/json", + Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(requestBody)) + }; + var regionEndpoint = this._bedrockApi.DetermineServiceOperationEndpoint(invokeRequest).URL; + this._textGenerationEndpoint = new Uri(regionEndpoint); + InvokeModelWithResponseStreamResponse? 
streamingResponse = null; + using var activity = ModelDiagnostics.StartCompletionActivity( + this._textGenerationEndpoint, this._modelId, this._modelProvider, prompt, executionSettings); + ActivityStatusCode activityStatus; + try + { + streamingResponse = await this._bedrockApi.InvokeModelWithResponseStreamAsync(invokeRequest, cancellationToken).ConfigureAwait(false); + if (activity is not null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(streamingResponse.HttpStatusCode); + activity.SetStatus(activityStatus); + } + } + catch (Exception ex) + { + Console.WriteLine($"ERROR: Can't invoke '{this._modelId}'. Reason: {ex.Message}"); + if (activity is not null) + { + activity.SetError(ex); + if (streamingResponse != null) + { + activityStatus = this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(streamingResponse.HttpStatusCode); + activity.SetStatus(activityStatus); + } + else + { + // If streamingResponse is null, set a default status or leave it unset + activity.SetStatus(ActivityStatusCode.Error); // or ActivityStatusCode.Unset + } + } + throw; + } + + List? streamedContents = activity is not null ? [] : null; + foreach (var item in streamingResponse.Body) + { + if (item is not PayloadPart payloadPart) + { + continue; + } + var chunk = JsonSerializer.Deserialize(payloadPart.Bytes); + if (chunk is null) + { + continue; + } + IEnumerable texts = this._ioService.GetTextStreamOutput(chunk); + foreach (var text in texts) + { + var content = new StreamingTextContent(text); + streamedContents?.Add(content); + yield return new StreamingTextContent(text); + } + } + activity?.SetStatus(this._clientUtilities.ConvertHttpStatusCodeToActivityStatusCode(streamingResponse.HttpStatusCode)); + activity?.EndStreaming(streamedContents); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs new file mode 100644 index 000000000000..167ca9af8db5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Amazon.BedrockRuntime; +using Connectors.Amazon.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Amazon.Services; +using Microsoft.SemanticKernel.TextGeneration; + +namespace Connectors.Amazon.Extensions; + +/// +/// Extensions for adding Bedrock services to the application. +/// +public static class BedrockKernelBuilderExtensions +{ + /// + /// Add Amazon Bedrock Chat Completion service to the kernel builder using IAmazonBedrockRuntime object. + /// + /// The kernel builder. + /// The model for chat completion. + /// The IAmazonBedrockRuntime to run inference using the respective model. 
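+ /// A minimal usage sketch (the model ID and the configured bedrockRuntime instance are
+ /// illustrative assumptions):
+ ///     var kernel = Kernel.CreateBuilder()
+ ///         .AddBedrockChatCompletionService("anthropic.claude-3-sonnet-20240229-v1:0", bedrockRuntime)
+ ///         .Build();
+ ///     var chatService = kernel.GetRequiredService<IChatCompletionService>();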
+ /// + public static IKernelBuilder AddBedrockChatCompletionService( + this IKernelBuilder builder, + string modelId, + IAmazonBedrockRuntime bedrockApi) + { + builder.Services.AddSingleton(_ => + { + try + { + return new BedrockChatCompletionService(modelId, bedrockApi); + } + catch (Exception ex) + { + throw new KernelException($"An error occurred while initializing the BedrockChatCompletionService: {ex.Message}", ex); + } + }); + + return builder; + } + + /// + /// Add Amazon Bedrock Chat Completion service to the kernel builder using new AmazonBedrockRuntimeClient(). + /// + /// The kernel builder. + /// The model for chat completion. + /// + public static IKernelBuilder AddBedrockChatCompletionService( + this IKernelBuilder builder, + string modelId) + { + // Add IAmazonBedrockRuntime service client to the DI container + builder.Services.AddAWSService(); + builder.Services.AddSingleton(services => + { + try + { + var bedrockRuntime = services.GetRequiredService(); + return new BedrockChatCompletionService(modelId, bedrockRuntime); + } + catch (Exception ex) + { + throw new KernelException($"An error occurred while initializing the BedrockChatCompletionService: {ex.Message}", ex); + } + }); + + return builder; + } + /// + /// Add Amazon Bedrock Text Generation service to the kernel builder using IAmazonBedrockRuntime object. + /// + /// The kernel builder. + /// The model for text generation. + /// The IAmazonBedrockRuntime to run inference using the respective model. + /// + public static IKernelBuilder AddBedrockTextGenerationService( + this IKernelBuilder builder, + string modelId, + IAmazonBedrockRuntime bedrockApi) + { + builder.Services.AddSingleton(_ => + { + try + { + return new BedrockTextGenerationService(modelId, bedrockApi); + } + catch (Exception ex) + { + throw new KernelException($"An error occurred while initializing the BedrockTextGenerationService: {ex.Message}", ex); + } + }); + + return builder; + } + /// + /// Add Amazon Bedrock Text Generation service to the kernel builder using new AmazonBedrockRuntimeClient(). + /// + /// The kernel builder. + /// The model for text generation. + /// + public static IKernelBuilder AddBedrockTextGenerationService( + this IKernelBuilder builder, + string modelId) + { + // Add IAmazonBedrockRuntime service client to the DI container + builder.Services.AddAWSService(); + builder.Services.AddSingleton(services => + { + try + { + var bedrockRuntime = services.GetRequiredService(); + return new BedrockTextGenerationService(modelId, bedrockRuntime); + } + catch (Exception ex) + { + throw new KernelException($"An error occurred while initializing the BedrockTextGenerationService: {ex.Message}", ex); + } + }); + + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaIOService.cs new file mode 100644 index 000000000000..029969f2b70c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaIOService.cs @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.AI21; + +/// +/// Input-output service for AI21 Labs Jamba model. 
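+/// With default settings, GetInvokeModelRequestBody below serializes to a Jamba-style chat body
+/// of roughly this shape (the values shown are the defaults defined in this class):
+///     { "messages": [{ "role": "user", "content": "..." }], "temperature": 1.0, "top_p": 0.9,
+///       "max_tokens": 4096, "stop": [], "n": 1, "frequency_penalty": 0.0, "presence_penalty": 0.0 }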
+/// +public class AI21JambaIOService : IBedrockModelIOService +{ + // Define constants for default values + private const double DefaultTemperature = 1.0; + private const double DefaultTopP = 0.9; + private const int DefaultMaxTokens = 4096; + private const int DefaultN = 1; + private const double DefaultFrequencyPenalty = 0.0; + private const double DefaultPresencePenalty = 0.0; + /// + /// Builds InvokeModel request Body parameter with structure as required by AI21 Labs Jamba model. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings = null) + { + var requestBody = new + { + messages = new[] + { + new + { + role = "user", + content = prompt + } + }, + temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", DefaultTemperature), + top_p = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_p", DefaultTopP), + max_tokens = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_tokens", DefaultMaxTokens), + stop = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stop", new List()), + n = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "n", DefaultN), + frequency_penalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty), + presence_penalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "presence_penalty", DefaultPresencePenalty) + }; + + return requestBody; + } + + /// + /// Extracts the test contents from the InvokeModelResponse as returned by the Bedrock API. + /// + /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output. + /// + public IReadOnlyList GetInvokeResponseBody(InvokeModelResponse response) + { + using var memoryStream = new MemoryStream(); + response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult(); + memoryStream.Position = 0; + using var reader = new StreamReader(memoryStream); + var responseBody = JsonSerializer.Deserialize(reader.ReadToEnd()); + var textContents = new List(); + if (responseBody?.Choices is not { Count: > 0 }) + { + return textContents; + } + textContents.AddRange(responseBody.Choices.Select(choice => new TextContent(choice.Message?.Content))); + return textContents; + } + + /// + /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by AI21 Labs Jamba. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? 
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", DefaultMaxTokens), + StopSequences = BedrockModelUtilities.GetExtensionDataValue>(settings?.ExtensionData, "stop_sequences", []), + }; + + var converseRequest = new ConverseRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = new Document + { + { "n", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "n", DefaultN) }, + { "frequency_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty) }, + { "presence_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "presence_penalty", DefaultPresencePenalty) } + }, + AdditionalModelResponseFieldPaths = [] + }; + + return converseRequest; + } + /// + /// Gets the streamed text output. + /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk) + { + var buffer = new StringBuilder(); + if (chunk["choices"]?[0]?["delta"]?["content"] != null) + { + buffer.Append(chunk["choices"]?[0]?["delta"]?["content"]); + yield return buffer.ToString(); + } + } + /// + /// Gets converse stream output. + /// + /// + /// + /// + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? 
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", DefaultMaxTokens), + StopSequences = BedrockModelUtilities.GetExtensionDataValue>(settings?.ExtensionData, "stop_sequences", []), + }; + + var converseRequest = new ConverseStreamRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = new Document + { + { "n", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "n", DefaultN) }, + { "frequency_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty) }, + { "presence_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "presence_penalty", DefaultPresencePenalty) } + }, + AdditionalModelResponseFieldPaths = [] + }; + + return converseRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaResponse.cs new file mode 100644 index 000000000000..7460c48d9a06 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JambaResponse.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.AI21; + +/// +/// AI21JambaResponse objects for Bedrock Runtime actions. +/// +public static class AI21JambaResponse +{ + /// + /// AI21 Text Generation Response object (from Invoke). + /// + [Serializable] + public class AI21TextResponse + { + /// + /// A unique ID for the request (not the message). Repeated identical requests get different IDs. However, for a streaming response, the ID will be the same for all responses in the stream. + /// + [JsonPropertyName("id")] + public string? Id { get; set; } + /// + /// One or more responses, depending on the n parameter from the request. + /// + [JsonPropertyName("choices")] + public List? Choices { get; set; } + /// + /// The token counts for this request. Per-token billing is based on the prompt token and completion token counts and rates. + /// + [JsonPropertyName("usage")] + public Usage? Use { get; set; } + /// + /// The members for the Choice class as required by AI21 Labs Jamba. + /// + [Serializable] + public class Choice + { + /// + /// Zero-based index of the message in the list of messages. Note that this might not correspond with the position in the response list. + /// + [JsonPropertyName("index")] + public int Index { get; set; } + /// + /// The message generated by the model. Same structure as the request message, with role and content members. + /// + [JsonPropertyName("message")] + public Message? Message { get; set; } + /// + /// Why the message ended. Possible reasons: + /// stop: The response ended naturally as a complete answer(due to end-of-sequence token) or because the model generated a stop sequence provided in the request. + /// length: The response ended by reaching max_tokens. 
+ /// + [JsonPropertyName("finish_reason")] + public string? FinishReason { get; set; } + } + /// + /// Message object for the model with role and content as required. + /// + [Serializable] + public class Message + { + /// + /// The role of the message author. One of the following values: + /// user: Input provided by the user.Any instructions given here that conflict with instructions given in the system prompt take precedence over the system prompt instructions. + /// assistant: Response generated by the model. + /// system: Initial instructions provided to the system to provide general guidance on the tone and voice of the generated message.An initial system message is optional but recommended to provide guidance on the tone of the chat.For example, "You are a helpful chatbot with a background in earth sciences and a charming French accent." + /// + [JsonPropertyName("role")] + public string? Role { get; set; } + /// + /// The content of the message. + /// + [JsonPropertyName("content")] + public string? Content { get; set; } + } + /// + /// The token counts for this request. Per-token billing is based on the prompt token and completion token counts and rates. + /// + [Serializable] + public class Usage + { + /// + /// Number of tokens in the prompt for this request. Note that the prompt token includes the entire message history, plus extra tokens needed by the system when combining the list of prompt messages into a single message, as required by the model. The number of extra tokens is typically proportional to the number of messages in the thread, and should be relatively small. + /// + [JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + /// + /// Number of tokens in the response message. + /// + [JsonPropertyName("completion_tokens")] + public int CompletionTokens { get; set; } + /// + /// prompt_tokens + completion_tokens. + /// + [JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicIOService.cs new file mode 100644 index 000000000000..e825fedfe039 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicIOService.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.AI21; + +/// +/// Input-output service for AI21 Labs Jurassic. +/// +public class AI21JurassicIOService : IBedrockModelIOService +{ + // Defined constants for default values + private const double DefaultTemperature = 0.5; + private const double DefaultTopP = 0.5; + private const int DefaultMaxTokens = 200; + /// + /// Builds InvokeModelRequest Body parameter to be serialized. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? 
executionSettings = null) + { + var requestBody = new + { + prompt, + temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", (double?)DefaultTemperature), + topP = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "topP", (double?)DefaultTopP), + maxTokens = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "maxTokens", (int?)DefaultMaxTokens), + stopSequences = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stopSequences", new List()), + countPenalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "countPenalty", new Dictionary()), + presencePenalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "presencePenalty", new Dictionary()), + frequencyPenalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "frequencyPenalty", new Dictionary()) + }; + + return requestBody; + } + /// + /// Extracts the test contents from the InvokeModelResponse as returned by the Bedrock API. + /// + /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output. + /// + public IReadOnlyList GetInvokeResponseBody(InvokeModelResponse response) + { + using var memoryStream = new MemoryStream(); + response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult(); + memoryStream.Position = 0; + using var reader = new StreamReader(memoryStream); + var responseBody = JsonSerializer.Deserialize(reader.ReadToEnd()); + var textContents = new List(); + if (responseBody?.Completions is not { Count: > 0 }) + { + return textContents; + } + textContents.AddRange(responseBody.Completions.Select(completion => new TextContent(completion.Data?.Text))); + return textContents; + } + /// + /// Jurassic does not support converse. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + /// + public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + throw new NotImplementedException("This model does not support chat history. Use text generation to invoke singular response to use this model."); + } + /// + /// Jurassic does not support streaming. + /// + /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk) + { + throw new NotImplementedException("Streaming not supported by this model."); + } + /// + /// Jurassic does not support converse (or streaming for that matter). + /// + /// + /// + /// + /// + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + throw new NotImplementedException("Streaming not supported by this model."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicResponse.cs new file mode 100644 index 000000000000..113c4407880a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/AI21 Labs/AI21JurassicResponse.cs @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.AI21; + +/// +/// AI21 Labs Jurassic Response object. +/// +[Serializable] +public class AI21JurassicResponse +{ + /// + /// A unique string id for the processed request. 
Repeated identical requests receive different IDs. + /// + [JsonPropertyName("id")] + public long Id { get; set; } + /// + /// The prompt includes the raw text, the tokens with their log probabilities, and the top-K alternative tokens at each position, if requested. + /// + [JsonPropertyName("prompt")] + public PromptText? Prompt { get; set; } + /// + /// A list of completions, including raw text, tokens, and log probabilities. The number of completions corresponds to the requested numResults. + /// + [JsonPropertyName("completions")] + public List? Completions { get; set; } +} +/// +/// The prompt includes the raw text, the tokens with their log probabilities, and the top-K alternative tokens at each position, if requested. +/// +[Serializable] +public class PromptText +{ + /// + /// Text string of the prompt. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } + /// + /// list of TokenData. + /// + [JsonPropertyName("tokens")] + public List? Tokens { get; set; } +} +/// +/// The token object corresponding to each prompt object. +/// +[Serializable] +public class Token +{ + /// + /// The token object generated from the token data. + /// + [JsonPropertyName("generatedToken")] + public GeneratedToken? GeneratedToken { get; set; } + /// + /// A list of the top K alternative tokens for this position, sorted by probability, according to the topKReturn request parameter. If topKReturn is set to 0, this field will be null. + /// + [JsonPropertyName("topTokens")] + public object? TopTokens { get; set; } + /// + /// Indicates the start and end offsets of the token in the decoded text string. + /// + [JsonPropertyName("textRange")] + public TextRange? TextRange { get; set; } +} +/// +/// The generated token object from the token data. +/// +[Serializable] +public class GeneratedToken +{ + /// + /// The string representation of the token. + /// + [JsonPropertyName("token")] + public string? TokenValue { get; set; } + /// + /// The predicted log probability of the token after applying the sampling parameters as a float value. + /// + [JsonPropertyName("logprob")] + public double Logprob { get; set; } + /// + /// The raw predicted log probability of the token as a float value. For the indifferent values (namely, temperature=1, topP=1) we get raw_logprob=logprob. + /// + [JsonPropertyName("raw_logprob")] + public double RawLogprob { get; set; } +} +/// +/// Indicates the start and end offsets of the token in the decoded text string. +/// +[Serializable] +public class TextRange +{ + /// + /// The starting index of the token in the decoded text string. + /// + [JsonPropertyName("start")] + public int Start { get; set; } + /// + /// The ending index of the token in the decoded text string. + /// + [JsonPropertyName("end")] + public int End { get; set; } +} +/// +/// A list of completions, including raw text, tokens, and log probabilities. The number of completions corresponds to the requested numResults. +/// +[Serializable] +public class Completion +{ + /// + /// The data, which contains the text (string) and tokens (list of TokenData) for the completion. + /// + [JsonPropertyName("data")] + public JurassicData? Data { get; set; } + /// + /// This nested data structure explains why the generation process was halted for a specific completion. + /// + [JsonPropertyName("finishReason")] + public FinishReason? 
FinishReason { get; set; } +} +/// +/// The data, which contains the text (string) and tokens (list of TokenData) for the completion +/// +[Serializable] +public class JurassicData +{ + /// + /// The text string from the data provided. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } + /// + /// The list of tokens. + /// + [JsonPropertyName("tokens")] + public List? Tokens { get; set; } +} +/// +/// This nested data structure explains why the generation process was halted for a specific completion. +/// +[Serializable] +public class FinishReason +{ + /// + /// The finish reason: length limit reached, end of text token generation, or stop sequence generated. + /// + [JsonPropertyName("reason")] + public string? Reason { get; set; } + /// + /// The max token count. + /// + [JsonPropertyName("length")] + public int Length { get; set; } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/AmazonIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/AmazonIOService.cs new file mode 100644 index 000000000000..7d6e83ad3376 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/AmazonIOService.cs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.Amazon; + +/// +/// Input-output service for Amazon Titan model. +/// +public class AmazonIOService : IBedrockModelIOService +{ + // Define constants for default values + private const float DefaultTemperature = 0.7f; + private const float DefaultTopP = 0.9f; + private const int DefaultMaxTokenCount = 512; + private static readonly List s_defaultStopSequences = new() { "User:" }; + /// + /// Builds InvokeModel request Body parameter with structure as required by Amazon Titan. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings = null) + { + float temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", DefaultTemperature); + float topP = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "topP", DefaultTopP); + int maxTokenCount = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "maxTokenCount", DefaultMaxTokenCount); + List stopSequences = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stopSequences", s_defaultStopSequences); + + var requestBody = new + { + inputText = prompt, + textGenerationConfig = new + { + temperature, + topP, + maxTokenCount, + stopSequences + } + }; + return requestBody; + } + /// + /// Extracts the test contents from the InvokeModelResponse as returned by the Bedrock API. + /// + /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output. 
+ /// + public IReadOnlyList GetInvokeResponseBody(InvokeModelResponse response) + { + using var memoryStream = new MemoryStream(); + response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult(); + memoryStream.Position = 0; + using var reader = new StreamReader(memoryStream); + var responseBody = JsonSerializer.Deserialize(reader.ReadToEnd()); + var textContents = new List(); + if (responseBody?.Results is not { Count: > 0 }) + { + return textContents; + } + string? outputText = responseBody.Results[0].OutputText; + textContents.Add(new TextContent(outputText)); + return textContents; + } + /// + /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by Amazon Titan. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "topP", DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "maxTokenCount", DefaultMaxTokenCount), + }; + + var converseRequest = new ConverseRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = new Document(), + AdditionalModelResponseFieldPaths = new List() + }; + + return converseRequest; + } + /// + /// Extracts the text generation streaming output from the Amazon Titan response object structure. + /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk) + { + var text = chunk["outputText"]?.ToString(); + if (!string.IsNullOrEmpty(text)) + { + yield return text; + } + } + /// + /// Builds the ConverseStreamRequest object for the Converse Bedrock API call, including building the Amazon Titan Request object and mapping parameters to the ConverseStreamRequest object. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? 
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "topP", DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "maxTokenCount", DefaultMaxTokenCount), + }; + + var converseStreamRequest = new ConverseStreamRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = new Document(), + AdditionalModelResponseFieldPaths = [] + }; + + return converseStreamRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/TitanResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/TitanResponse.cs new file mode 100644 index 000000000000..72a0284eca08 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Amazon/TitanResponse.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.Amazon; + +/// +/// The Amazon Titan Text response object when deserialized from Invoke Model call. +/// +[Serializable] +public class TitanTextResponse +{ + /// + /// The number of tokens in the prompt. + /// + [JsonPropertyName("inputTextTokenCount")] + public int InputTextTokenCount { get; set; } + /// + /// The list of result objects. + /// + [JsonPropertyName("results")] + public List? Results { get; set; } + /// + /// The result object. + /// + public class Result + { + /// + /// The number of tokens in the prompt. + /// + [JsonPropertyName("tokenCount")] + public int TokenCount { get; set; } + /// + /// The text in the response. + /// + [JsonPropertyName("outputText")] + public string? OutputText { get; set; } + /// + /// The reason the response finished being generated. + /// + [JsonPropertyName("completionReason")] + public string? CompletionReason { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/AnthropicIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/AnthropicIOService.cs new file mode 100644 index 000000000000..18640f43f06e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/AnthropicIOService.cs @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.Anthropic; + +/// +/// Input-output service for Anthropic Claude model. +/// +public class AnthropicIOService : IBedrockModelIOService +{ + // Define constants for default values + private const double DefaultTemperature = 1.0; + private const double DefaultTopP = 1.0; + private const int DefaultMaxTokensToSample = 4096; + private static readonly List s_defaultStopSequences = new() { "\n\nHuman:" }; + private const int DefaultTopK = 250; + /// + /// Builds InvokeModel request Body parameter with structure as required by Anthropic Claude. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. 
+ /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings = null) + { + var requestBody = new + { + prompt = $"\n\nHuman: {prompt}\n\nAssistant:", + temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", (double?)DefaultTemperature), + max_tokens_to_sample = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_tokens_to_sample", (int?)DefaultMaxTokensToSample), + stop_sequences = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stop_sequences", s_defaultStopSequences), + top_p = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_p", (double?)DefaultTopP), + top_k = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_k", (int?)DefaultTopK) + }; + return requestBody; + } + /// + /// Extracts the test contents from the InvokeModelResponse as returned by the Bedrock API. + /// + /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output. + /// + public IReadOnlyList GetInvokeResponseBody(InvokeModelResponse response) + { + using var memoryStream = new MemoryStream(); + response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult(); + memoryStream.Position = 0; + using var reader = new StreamReader(memoryStream); + var responseBody = JsonSerializer.Deserialize(reader.ReadToEnd()); + var textContents = new List(); + if (!string.IsNullOrEmpty(responseBody?.Completion)) + { + textContents.Add(new TextContent(responseBody.Completion)); + } + return textContents; + } + + /// + /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by Anthropic Claude. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + var system = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "system", systemMessages); + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens_to_sample", DefaultMaxTokensToSample) + }; + var additionalModelRequestFields = new Document(); + + var tools = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "tools", new List()); + var toolChoice = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "tool_choice", null); + + if (modelId != "anthropic.claude-instant-v1") + { + additionalModelRequestFields.Add( + "tools", new Document(tools.Select(t => new Document + { + { "name", t.Name }, + { "description", t.Description }, + { "input_schema", t.InputSchema } + }).ToList()) + ); + + additionalModelRequestFields.Add( + "tool_choice", toolChoice != null + ? 
new Document + { + { "type", toolChoice.Type }, + { "name", toolChoice.Name } + } + : new Document() + ); + } + + var converseRequest = new ConverseRequest + { + ModelId = modelId, + Messages = messages, + System = system, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = additionalModelRequestFields, + AdditionalModelResponseFieldPaths = new List(), + GuardrailConfig = null, // Set if needed + ToolConfig = null // Set if needed + }; + + return converseRequest; + } + /// + /// Extracts the text generation streaming output from the Anthropic Claude response object structure. + /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk) + { + var text = chunk["completion"]?.ToString(); + if (!string.IsNullOrEmpty(text)) + { + yield return text; + } + } + + /// + /// Builds the ConverseStreamRequest object for the Converse Bedrock API call, including building the Anthropic Claude Request object and mapping parameters to the ConverseStreamRequest object. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var system = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "system", systemMessages); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens_to_sample", DefaultMaxTokensToSample) + }; + + var additionalModelRequestFields = new Document(); + + var tools = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "tools", new List()); + var toolChoice = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "tool_choice", null); + + if (modelId != "anthropic.claude-instant-v1") + { + additionalModelRequestFields.Add( + "tools", new Document(tools.Select(t => new Document + { + { "name", t.Name }, + { "description", t.Description }, + { "input_schema", t.InputSchema } + }).ToList()) + ); + + additionalModelRequestFields.Add( + "tool_choice", toolChoice != null + ? new Document + { + { "type", toolChoice.Type }, + { "name", toolChoice.Name } + } + : new Document() + ); + } + + var converseRequest = new ConverseStreamRequest + { + ModelId = modelId, + Messages = messages, + System = system, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = additionalModelRequestFields, + AdditionalModelResponseFieldPaths = new List(), + GuardrailConfig = null, // Set if needed + ToolConfig = null // Set if needed + }; + + return converseRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeResponse.cs new file mode 100644 index 000000000000..bb5ec72eac1f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeResponse.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. 
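// NOTE (editor's illustrative sketch, not part of the original change): a Claude text-completions
// body that the ClaudeResponse class below is meant to deserialize might look like the following;
// values are made up and the exact payload shape should be confirmed against the AWS Bedrock docs:
// {
//   "completion": " 2 + 2 = 4.",
//   "stop_reason": "stop_sequence",
//   "stop": "\n\nHuman:"
// }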
+
+using System.Text.Json.Serialization;
+
+namespace Connectors.Amazon.Models.Anthropic;
+
+///
+/// Anthropic Claude completion response.
+///
+public class ClaudeResponse
+{
+ ///
+ /// The resulting completion up to and excluding the stop sequences.
+ ///
+ [JsonPropertyName("completion")]
+ public string? Completion { get; set; }
+ ///
+ /// The reason why the model stopped generating the response.
+ /// "stop_sequence" – The model reached a stop sequence, either provided by you with the stop_sequences inference parameter or built into the model.
+ /// "max_tokens" – The model exceeded max_tokens_to_sample or the model's maximum number of tokens.
+ ///
+ [JsonPropertyName("stop_reason")]
+ public string? StopReason { get; set; }
+ ///
+ /// If you specify the stop_sequences inference parameter, stop contains the stop sequence that signalled the model to stop generating text.
+ ///
+ [JsonPropertyName("stop")]
+ public string? Stop { get; set; }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeToolUse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeToolUse.cs new file mode 100644 index 000000000000..8f80e99faeea --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Anthropic/ClaudeToolUse.cs @@ -0,0 +1,54 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json.Serialization;
+using Amazon.BedrockRuntime.Model;
+
+namespace Connectors.Amazon.Models.Anthropic;
+
+///
+/// Anthropic Claude request object.
+///
+public static class ClaudeToolUse
+{
+ ///
+ /// (Optional) Definitions of tools that the model may use.
+ ///
+ public class ClaudeTool : Tool
+ {
+ ///
+ /// The name of the tool.
+ ///
+ [JsonPropertyName("name")]
+ public required string Name { get; set; }
+
+ ///
+ /// (optional, but strongly recommended) The description of the tool.
+ ///
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+
+ ///
+ /// The JSON schema for the tool.
+ ///
+ [JsonPropertyName("input_schema")]
+ public required string InputSchema { get; set; }
+ }
+
+ ///
+ /// (Optional) Specifies how the model should use the provided tools. The model can use a specific tool, any available tool, or decide by itself.
+ ///
+ public class ClaudeToolChoice
+ {
+ ///
+ /// The type of tool choice. Possible values are any (use any available tool), auto (the model decides), and tool (use the specified tool).
+ ///
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ ///
+ /// (Optional) The name of the tool to use. Required if you specify tool in the type field.
+ ///
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/BedrockModelUtilities.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/BedrockModelUtilities.cs new file mode 100644 index 000000000000..a3f4884fda69 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/BedrockModelUtilities.cs @@ -0,0 +1,85 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Amazon.BedrockRuntime;
+using Amazon.BedrockRuntime.Model;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Connectors.Amazon.Models;
+
+///
+/// Utilities class for functions all Bedrock models need to use.
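// NOTE (editor's illustrative sketch, not part of the original change): typical call pattern for
// these helpers; the "temperature" key and the 0.2 value are hypothetical caller input:
// var settings = new PromptExecutionSettings
// {
//     ExtensionData = new Dictionary<string, object> { ["temperature"] = 0.2 }
// };
// double temperature = BedrockModelUtilities.GetExtensionDataValue(settings.ExtensionData, "temperature", 1.0); // 0.2
// ConversationRole role = BedrockModelUtilities.MapAuthorRoleToConversationRole(AuthorRole.User); // ConversationRole.User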
+///
+public static class BedrockModelUtilities
+{
+ ///
+ /// Maps the AuthorRole to the corresponding ConversationRole. Only called when the role is User or Assistant (system messages are extracted separately beforehand).
+ ///
+ ///
+ ///
+ public static ConversationRole MapAuthorRoleToConversationRole(AuthorRole role)
+ {
+ if (role == AuthorRole.User)
+ {
+ return ConversationRole.User;
+ }
+
+ if (role == AuthorRole.Assistant)
+ {
+ return ConversationRole.Assistant;
+ }
+ throw new ArgumentException($"Invalid role: {role}");
+ }
+ ///
+ /// Gets the system messages from the ChatHistory and adds them to the ConverseRequest System parameter.
+ ///
+ ///
+ ///
+ public static List<SystemContentBlock> GetSystemMessages(ChatHistory chatHistory)
+ {
+ return chatHistory
+ .Where(m => m.Role == AuthorRole.System)
+ .Select(m => new SystemContentBlock { Text = m.Content })
+ .ToList();
+ }
+ ///
+ /// Creates the list of user and assistant messages for the Converse Request from the Chat History.
+ ///
+ ///
+ ///
+ public static List<Message> BuildMessageList(ChatHistory chatHistory)
+ {
+ // Check that the text from the latest message in the chat history is not empty.
+ Verify.NotNullOrEmpty(chatHistory);
+ string? text = chatHistory[^1].Content;
+ if (string.IsNullOrWhiteSpace(text))
+ {
+ throw new ArgumentException("Last message in chat history was null or whitespace.");
+ }
+ return chatHistory
+ .Where(m => m.Role != AuthorRole.System)
+ .Select(m => new Message
+ {
+ Role = MapAuthorRoleToConversationRole(m.Role),
+ Content = new List<ContentBlock> { new() { Text = m.Content } }
+ })
+ .ToList();
+ }
+ ///
+ /// Gets the prompt execution settings extension data for the model request body build.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static TValue GetExtensionDataValue<TValue>(IDictionary<string, object>? extensionData, string key, TValue defaultValue)
+ {
+ if (extensionData?.TryGetValue(key, out object? value) == true && value is TValue typedValue)
+ {
+ return typedValue;
+ }
+ return defaultValue;
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandIOService.cs new file mode 100644 index 000000000000..49148d0e8389 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandIOService.cs @@ -0,0 +1,116 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using Amazon.BedrockRuntime.Model;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Connectors.Amazon.Models.Cohere;
+
+///
+/// Input-output service for Cohere Command.
+///
+public class CohereCommandIOService : IBedrockModelIOService
+{
+ // Define constants for default values
+ private const double DefaultTemperature = 0.9;
+ private const double DefaultTopP = 0.75;
+ private const int DefaultMaxTokens = 20;
+ private const double DefaultTopK = 0.0;
+ private const string DefaultReturnLikelihoods = "NONE";
+ private const bool DefaultStream = false;
+ private const int DefaultNumGenerations = 1;
+ private const string DefaultTruncate = "END";
+ ///
+ /// Builds InvokeModel request Body parameter with structure as required by Cohere Command.
+ ///
+ /// The model ID to be used as a request parameter.
+ /// The input prompt for text generation.
+ /// Optional prompt execution settings.
+ ///
+ public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings?
executionSettings = null)
+ {
+ var requestBody = new
+ {
+ prompt,
+ temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", (double?)DefaultTemperature),
+ p = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "p", (double?)DefaultTopP),
+ k = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "k", (double?)DefaultTopK),
+ max_tokens = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_tokens", (int?)DefaultMaxTokens),
+ stop_sequences = BedrockModelUtilities.GetExtensionDataValue<List<string>>(executionSettings?.ExtensionData, "stop_sequences", []),
+ return_likelihoods = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "return_likelihoods", DefaultReturnLikelihoods),
+ stream = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stream", (bool?)DefaultStream),
+ num_generations = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "num_generations", (int?)DefaultNumGenerations),
+ logit_bias = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "logit_bias", new Dictionary<int, double>()),
+ truncate = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "truncate", DefaultTruncate)
+ };
+
+ return requestBody;
+ }
+ ///
+ /// Extracts the text contents from the InvokeModelResponse as returned by the Bedrock API.
+ ///
+ /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output.
+ /// A list of text content objects as required by the semantic kernel.
+ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse response)
+ {
+ using var memoryStream = new MemoryStream();
+ response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult();
+ memoryStream.Position = 0;
+ using var reader = new StreamReader(memoryStream);
+ var responseBody = JsonSerializer.Deserialize<CommandTextResponse>(reader.ReadToEnd());
+ var textContents = new List<TextContent>();
+ if (responseBody?.Generations is not { Count: > 0 })
+ {
+ return textContents;
+ }
+ textContents.AddRange(from generation in responseBody.Generations where !string.IsNullOrEmpty(generation.Text) select new TextContent(generation.Text));
+ return textContents;
+ }
+ ///
+ /// Extracts the text generation streaming output from the Cohere Command response object structure.
+ ///
+ ///
+ ///
+ public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
+ {
+ var generations = chunk["generations"]?.AsArray();
+ if (generations != null)
+ {
+ foreach (var generation in generations)
+ {
+ var text = generation?["text"]?.ToString();
+ if (!string.IsNullOrEmpty(text))
+ {
+ yield return text;
+ }
+ }
+ }
+ }
+
+ ///
+ /// Command does not support Converse (only Command R): "Limited. No chat support." - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null)
+ {
+ throw new NotImplementedException("Converse not supported by this model.");
+ }
+ ///
+ /// Command does not support ConverseStream (only Command R): "Limited. No chat support."
- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + /// + /// + /// + /// + /// + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + throw new NotImplementedException("Streaming not supported by this model."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandRIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandRIOService.cs new file mode 100644 index 000000000000..04fd818a89eb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CohereCommandRIOService.cs @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.Cohere; + +/// +/// Input-output service for Cohere Command R. +/// +// ReSharper disable InconsistentNaming +public class CohereCommandRIOService : IBedrockModelIOService +// ReSharper restore InconsistentNaming +{ + // Define constants for default values + private const float DefaultTemperature = 0.3f; + private const float DefaultTopP = 0.75f; + private const float DefaultTopK = 0.0f; + private const string DefaultPromptTruncation = "OFF"; + private const float DefaultFrequencyPenalty = 0.0f; + private const float DefaultPresencePenalty = 0.0f; + private const int DefaultSeed = 0; + private const bool DefaultReturnPrompt = false; + private const bool DefaultRawPrompting = false; + private const int DefaultMaxTokens = 4096; + private const bool DefaultSearchQueriesOnly = false; + /// + /// Builds InvokeModel request Body parameter with structure as required by Cohere Command R. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? 
executionSettings = null)
+ {
+ var defaultChatHistory = new List<Dictionary<string, string>>
+ {
+ new()
+ {
+ { "role", "USER" },
+ { "message", prompt }
+ }
+ };
+ var requestBody = new
+ {
+ message = prompt,
+ chat_history = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "chat_history", defaultChatHistory),
+ documents = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "documents", new List<Dictionary<string, string>>()),
+ search_queries_only = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "search_queries_only", DefaultSearchQueriesOnly),
+ preamble = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "preamble", ""),
+ max_tokens = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_tokens", DefaultMaxTokens),
+ temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", DefaultTemperature),
+ p = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "p", DefaultTopP),
+ k = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "k", DefaultTopK),
+ prompt_truncation = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "prompt_truncation", DefaultPromptTruncation),
+ frequency_penalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty),
+ presence_penalty = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "presence_penalty", DefaultPresencePenalty),
+ seed = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "seed", DefaultSeed),
+ return_prompt = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "return_prompt", DefaultReturnPrompt),
+ stop_sequences = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stop_sequences", new List<string>()),
+ raw_prompting = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "raw_prompting", DefaultRawPrompting)
+ };
+
+ return requestBody;
+ }
+ ///
+ /// Extracts the text contents from the InvokeModelResponse as returned by the Bedrock API.
+ ///
+ /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output.
+ ///
+ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse response)
+ {
+ using var memoryStream = new MemoryStream();
+ response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult();
+ memoryStream.Position = 0;
+ using var reader = new StreamReader(memoryStream);
+ var responseBody = JsonSerializer.Deserialize<CommandRTextResponse>(reader.ReadToEnd());
+ var textContents = new List<TextContent>();
+ if (!string.IsNullOrEmpty(responseBody?.Text))
+ {
+ textContents.Add(new TextContent(responseBody.Text));
+ }
+ return textContents;
+ }
+ ///
+ /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by Cohere Command R.
+ ///
+ /// The model ID.
+ /// The messages between assistant and user.
+ /// Optional prompt execution settings.
+ ///
+ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings?
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + var converseRequest = new ConverseRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "p", DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", DefaultMaxTokens) + }, + AdditionalModelRequestFields = new Document + { + { "k", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "k", DefaultTopK) }, + { "prompt_truncation", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "prompt_truncation", DefaultPromptTruncation) }, + { "frequency_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty) }, + { "presence_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "presence_penalty", DefaultPresencePenalty) }, + { "seed", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "seed", DefaultSeed) }, + { "return_prompt", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "return_prompt", DefaultReturnPrompt) }, + { "stop_sequences", new Document(BedrockModelUtilities.GetExtensionDataValue>(settings?.ExtensionData, "stop_sequences", []).Select(s => new Document(s)).ToList()) }, + { "raw_prompting", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "raw_prompting", DefaultRawPrompting) } + }, + AdditionalModelResponseFieldPaths = new List(), + GuardrailConfig = null, + ToolConfig = null + }; + + return converseRequest; + } + /// + /// Extracts the text generation streaming output from the Cohere Command R response object structure. + /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk) + { + var text = chunk["text"]?.ToString(); + if (!string.IsNullOrEmpty(text)) + { + yield return text; + } + } + /// + /// Builds the ConverseStreamRequest object for the Converse Bedrock API call, including building the Cohere Command R Request object and mapping parameters to the ConverseStreamRequest object. + /// + /// The model ID. + /// The messages between assistant and user. + /// Optional prompt execution settings. + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? 
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "p", DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", DefaultMaxTokens) + }; + + var additionalModelRequestFields = new Document + { + { "k", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "k", DefaultTopK) }, + { "prompt_truncation", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "prompt_truncation", DefaultPromptTruncation) }, + { "frequency_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "frequency_penalty", DefaultFrequencyPenalty) }, + { "presence_penalty", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "presence_penalty", DefaultPresencePenalty) }, + { "seed", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "seed", DefaultSeed) }, + { "return_prompt", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "return_prompt", DefaultReturnPrompt) }, + { "stop_sequences", new Document(BedrockModelUtilities.GetExtensionDataValue>(settings?.ExtensionData, "stop_sequences", []).Select(s => new Document(s)).ToList()) }, + { "raw_prompting", BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "raw_prompting", false) } + }; + + var converseRequest = new ConverseStreamRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = additionalModelRequestFields, + AdditionalModelResponseFieldPaths = new List(), + GuardrailConfig = null, + ToolConfig = null + }; + + return converseRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRTextResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRTextResponse.cs new file mode 100644 index 000000000000..2b7898487bd2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRTextResponse.cs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.Cohere; + +/// +/// Cohere Command R Text Generation Response body. +/// +public class CommandRTextResponse +{ + /// + /// Unique identifier for chat completion + /// + [JsonPropertyName("response_id")] + public string? ResponseId { get; set; } + /// + /// The model’s response to chat message input. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } + /// + /// Unique identifier for chat completion, used with Feedback endpoint on Cohere’s platform. + /// + [JsonPropertyName("generation_id")] + public string? GenerationId { get; set; } + /// + /// An array of inline citations and associated metadata for the generated reply. + /// + [JsonPropertyName("citations")] + public List? Citations { get; set; } + /// + /// The full prompt that was sent to the model. Specify the return_prompt field to return this field. + /// + [JsonPropertyName("prompt")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? 
Prompt { get; set; }
+ ///
+ /// The reason why the model stopped generating output.
+ ///
+ [JsonPropertyName("finish_reason")]
+ public string? FinishReason { get; set; }
+ ///
+ /// A list of appropriate tools to call. Only returned if you specify the tools input field.
+ ///
+ [JsonPropertyName("tool_calls")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public List<ToolCall>? ToolCalls { get; set; }
+ ///
+ /// API usage data (only exists for streaming).
+ ///
+ [JsonPropertyName("meta")]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public MetaCommandR? Meta { get; set; }
+}
+///
+/// Citation object for array of inline citations and associated metadata for the generated reply.
+///
+[Serializable]
+public class Citation
+{
+ ///
+ /// The index that the citation begins at, starting from 0.
+ ///
+ [JsonPropertyName("start")]
+ public int Start { get; set; }
+ ///
+ /// The index that the citation ends after, starting from 0.
+ ///
+ [JsonPropertyName("end")]
+ public int End { get; set; }
+ ///
+ /// The text that the citation pertains to.
+ ///
+ [JsonPropertyName("text")]
+ public string? Text { get; set; }
+ ///
+ /// An array of document IDs that correspond to documents that are cited for the text.
+ ///
+ [JsonPropertyName("document_ids")]
+ public List<string>? DocumentIds { get; set; }
+}
+///
+/// Components for tool calling.
+///
+[Serializable]
+public class ToolCall
+{
+ ///
+ /// Name of tool.
+ ///
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+ ///
+ /// Parameters for tool.
+ ///
+ [JsonPropertyName("parameters")]
+ public Dictionary<string, object>? Parameters { get; set; }
+}
+///
+/// API usage data (only exists for streaming).
+///
+[Serializable]
+public class MetaCommandR
+{
+ ///
+ /// The API version. The version is in the version field.
+ ///
+ [JsonPropertyName("api_version")]
+ public ApiVersion? ApiVersion { get; set; }
+ ///
+ /// The billed units.
+ ///
+ [JsonPropertyName("billed_units")]
+ public BilledUnits? BilledUnits { get; set; }
+}
+///
+/// The API version.
+///
+[Serializable]
+public class ApiVersion
+{
+ ///
+ /// The corresponding version field for the API version identification.
+ ///
+ [JsonPropertyName("version")]
+ public string? Version { get; set; }
+}
+///
+/// The billed units.
+///
+[Serializable]
+public class BilledUnits
+{
+ ///
+ /// The number of input tokens that were billed.
+ ///
+ [JsonPropertyName("input_tokens")]
+ public int InputTokens { get; set; }
+ ///
+ /// The number of output tokens that were billed.
+ ///
+ [JsonPropertyName("output_tokens")]
+ public int OutputTokens { get; set; }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRToolUse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRToolUse.cs new file mode 100644 index 000000000000..d34847380dbd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandRToolUse.cs @@ -0,0 +1,90 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json.Serialization;
+
+namespace Connectors.Amazon.Models.Cohere;
+
+///
+/// Cohere Command R Text Generation Request object for Invoke Model Bedrock API call.
+///
+public static class CommandRToolUse
+{
+ ///
+ /// Tool parameters.
+ ///
+ [Serializable]
+ public class Tool
+ {
+ ///
+ /// Name of the tool.
+ ///
+ [JsonPropertyName("name")]
+ public required string Name { get; set; }
+ ///
+ /// Description of the tool.
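// NOTE (editor's illustrative sketch, not part of the original change): serialized, a single
// Command R tool definition built from the classes in this file would look roughly like the
// following; the tool and parameter names are hypothetical:
// {
//   "name": "top_song",
//   "description": "Gets the most played song on a radio station.",
//   "parameter_definitions": {
//     "sign": { "description": "Radio station call sign.", "type": "str", "required": true }
//   }
// }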
+ /// + [JsonPropertyName("description")] + public required string Description { get; set; } + /// + /// Definitions for each tool. + /// + [JsonPropertyName("parameter_definitions")] + public required Dictionary ParameterDefinitions { get; set; } + } + /// + /// Components of each tool parameter. + /// + [Serializable] + public class ToolParameter + { + /// + /// Description of parameter. + /// + [JsonPropertyName("description")] + public required string Description { get; set; } + /// + /// Parameter type (str, int, etc.) as described in a string. + /// + [JsonPropertyName("type")] + public required string Type { get; set; } + /// + /// Whether this parameter is required. + /// + [JsonPropertyName("required")] + public required bool Required { get; set; } + } + /// + /// Cohere tool result. + /// + [Serializable] + public class ToolResult + { + /// + /// The tool call. + /// + [JsonPropertyName("call")] + public required ToolCall Call { get; set; } + /// + /// Outputs from the tool call. + /// + [JsonPropertyName("outputs")] + public required List> Outputs { get; set; } + } + /// + /// Tool call object to be passed into the tool call. + /// + [Serializable] + public class ToolCall + { + /// + /// Name of tool. + /// + [JsonPropertyName("name")] + public required string Name { get; set; } + /// + /// Parameters for the tool. + /// + [JsonPropertyName("parameters")] + public required Dictionary Parameters { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandTextResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandTextResponse.cs new file mode 100644 index 000000000000..df12d75958ea --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Cohere/CommandTextResponse.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.Cohere; + +/// +/// The Command Text Generation Response body. +/// +public class CommandTextResponse +{ + /// + /// A list of generated results along with the likelihoods for tokens requested. (Always returned). + /// + [JsonPropertyName("generations")] + public required List Generations { get; set; } + /// + /// An identifier for the request (always returned). + /// + [JsonPropertyName("id")] + public required string Id { get; set; } + /// + /// The prompt from the input request (always returned). + /// + [JsonPropertyName("prompt")] + public required string Prompt { get; set; } + /// + /// A list of generated results along with the likelihoods for tokens requested. (Always returned). Each generation object in the list contains the following fields. + /// + [Serializable] + public class Generation + { + /// + /// The reason why the model finished generating tokens. COMPLETE - the model sent back a finished reply. MAX_TOKENS – the reply was cut off because the model reached the maximum number of tokens for its context length. ERROR – something went wrong when generating the reply. ERROR_TOXIC – the model generated a reply that was deemed toxic. finish_reason is returned only when is_finished=true. (Not always returned). + /// + [JsonPropertyName("finish_reason")] + public string? FinishReason { get; set; } + /// + /// An identifier for the generation. (Always returned). + /// + [JsonPropertyName("id")] + public required string Id { get; set; } + /// + /// The generated text. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } + /// + /// The likelihood of the output. 
The value is the average of the token likelihoods in token_likelihoods. Returned if you specify the return_likelihoods input parameter.
+ ///
+ [JsonPropertyName("likelihood")]
+ public double? Likelihood { get; set; }
+ ///
+ /// An array of per token likelihoods. Returned if you specify the return_likelihoods input parameter.
+ ///
+ [JsonPropertyName("token_likelihoods")]
+ public List<TokenLikelihood>? TokenLikelihoods { get; set; }
+ ///
+ /// A boolean field used only when stream is true, signifying whether there are additional tokens that will be generated as part of the streaming response. (Not always returned)
+ ///
+ [JsonPropertyName("is_finished")]
+ public bool IsFinished { get; set; }
+ ///
+ /// In a streaming response, use to determine which generation a given token belongs to. When only one response is streamed, all tokens belong to the same generation and index is not returned. index therefore is only returned in a streaming request with a value for num_generations that is larger than one.
+ ///
+ [JsonPropertyName("index")]
+ public int? Index { get; set; }
+ }
+ ///
+ /// An array of per token likelihoods. Returned if you specify the return_likelihoods input parameter.
+ ///
+ [Serializable]
+ public class TokenLikelihood
+ {
+ ///
+ /// Token likelihood.
+ ///
+ [JsonPropertyName("token")]
+ public double Token { get; set; }
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/IBedrockModelIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/IBedrockModelIOService.cs new file mode 100644 index 000000000000..ca7503cb1eda --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/IBedrockModelIOService.cs @@ -0,0 +1,51 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json.Nodes;
+using Amazon.BedrockRuntime.Model;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Connectors.Amazon.Models;
+
+///
+/// Bedrock input-output service to build the request and response bodies as required by the given model.
+///
+public interface IBedrockModelIOService
+{
+ ///
+ /// Builds InvokeModelRequest Body parameter to be serialized. Object itself dependent on model request parameter requirements.
+ ///
+ /// The model ID to be used as a request parameter.
+ /// The input prompt for text generation.
+ /// Optional prompt execution settings.
+ ///
+ object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings = null);
+ ///
+ /// Extracts the text contents from the InvokeModelResponse as returned by the Bedrock API. Must be deserialized into the model's specific response object first.
+ ///
+ /// The InvokeModelResponse object returned from the InvokeAsync Bedrock call.
+ ///
+ IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse response);
+ ///
+ /// Builds the converse request given the chat history and model ID passed in by the user. This request is to be passed into the Bedrock Converse API call.
+ ///
+ /// The model ID to be used as a request parameter.
+ /// The messages for the converse call.
+ /// Optional prompt execution settings.
+ ///
+ ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null);
+ ///
+ /// Converts the JSON output from streaming text generation into an IEnumerable of strings.
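// NOTE (editor's illustrative sketch): an implementation of this member parses each streamed
// chunk as a JsonNode and yields any text it carries. This mirrors the Cohere Command R
// implementation elsewhere in this diff, whose chunks carry a top-level "text" field:
// public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
// {
//     var text = chunk["text"]?.ToString();
//     if (!string.IsNullOrEmpty(text)) { yield return text; }
// }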
+ /// + /// + /// + public IEnumerable GetTextStreamOutput(JsonNode chunk); + /// + /// Builds the converse stream request given the chat history and model ID passed in by the user. This request is to be passed into the Bedrock Converse API call. + /// + /// + /// + /// + /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null); +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/LlamaTextResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/LlamaTextResponse.cs new file mode 100644 index 000000000000..1cb5f70356ba --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/LlamaTextResponse.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.Meta; + +/// +/// Text generation response object for Meta Llama. +/// +public class LlamaTextResponse +{ + /// + /// The generated text. + /// + [JsonPropertyName("generation")] + public string? Generation { get; set; } + /// + /// The number of tokens in the prompt. + /// + [JsonPropertyName("prompt_token_count")] + public int PromptTokenCount { get; set; } + /// + /// The number of tokens in the generated text. + /// + [JsonPropertyName("generation_token_count")] + public int GenerationTokenCount { get; set; } + /// + /// The reason why the response stopped generating text. Possible values are stop (The model has finished generating text for the input prompt) and length (The length of the tokens for the generated text exceeds the value of max_gen_len in the call to InvokeModel (InvokeModelWithResponseStream, if you are streaming output). The response is truncated to max_gen_len tokens. Consider increasing the value of max_gen_len and trying again.). + /// + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/MetaIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/MetaIOService.cs new file mode 100644 index 000000000000..244d0d789f79 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Meta/MetaIOService.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.Meta; + +/// +/// Input-output service for Meta Llama. +/// +public class MetaIOService : IBedrockModelIOService +{ + // Define constants for default values + private const double DefaultTemperature = 0.5f; + private const double DefaultTopP = 0.9f; + private const int DefaultMaxGenLen = 512; + /// + /// Builds InvokeModel request Body parameter with structure as required by Meta Llama. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? 
executionSettings = null)
+ {
+ var requestBody = new
+ {
+ prompt,
+ temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", DefaultTemperature),
+ top_p = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_p", DefaultTopP),
+ max_gen_len = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_gen_len", (int?)DefaultMaxGenLen)
+ };
+
+ return requestBody;
+ }
+ ///
+ /// Extracts the text contents from the InvokeModelResponse as returned by the Bedrock API.
+ ///
+ /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output.
+ ///
+ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse response)
+ {
+ using var memoryStream = new MemoryStream();
+ response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult();
+ memoryStream.Position = 0;
+ using var reader = new StreamReader(memoryStream);
+ var responseBody = JsonSerializer.Deserialize<LlamaTextResponse>(reader.ReadToEnd());
+ var textContents = new List<TextContent>();
+ if (!string.IsNullOrEmpty(responseBody?.Generation))
+ {
+ textContents.Add(new TextContent(responseBody.Generation));
+ }
+ return textContents;
+ }
+ ///
+ /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by Meta Llama.
+ ///
+ /// The model ID.
+ /// The messages between assistant and user.
+ /// Optional prompt execution settings.
+ ///
+ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null)
+ {
+ var messages = BedrockModelUtilities.BuildMessageList(chatHistory);
+ var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory);
+
+ var inferenceConfig = new InferenceConfiguration
+ {
+ Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature),
+ TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP),
+ MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_gen_len", DefaultMaxGenLen)
+ };
+
+ var converseRequest = new ConverseRequest
+ {
+ ModelId = modelId,
+ Messages = messages,
+ System = systemMessages,
+ InferenceConfig = inferenceConfig,
+ AdditionalModelRequestFields = new Document(),
+ AdditionalModelResponseFieldPaths = new List<string>(),
+ GuardrailConfig = null,
+ ToolConfig = null
+ };
+
+ return converseRequest;
+ }
+ ///
+ /// Extracts the text generation streaming output from the Meta Llama response object structure.
+ ///
+ ///
+ ///
+ public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
+ {
+ var generation = chunk["generation"]?.ToString();
+ if (!string.IsNullOrEmpty(generation))
+ {
+ yield return generation;
+ }
+ }
+
+ ///
+ /// Builds the ConverseStreamRequest object for the Converse Bedrock API call, including building the Meta Llama Request object and mapping parameters to the ConverseStreamRequest object.
+ ///
+ /// The model ID.
+ /// The messages between assistant and user.
+ /// Optional prompt execution settings.
+ ///
+ public ConverseStreamRequest GetConverseStreamRequest(
+ string modelId,
+ ChatHistory chatHistory,
+ PromptExecutionSettings?
settings = null) + { + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + + var inferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", (float)DefaultTemperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", (float)DefaultTopP), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_gen_len", DefaultMaxGenLen) + }; + + var converseRequest = new ConverseStreamRequest + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = inferenceConfig, + AdditionalModelRequestFields = new Document(), + AdditionalModelResponseFieldPaths = new List(), + GuardrailConfig = null, + ToolConfig = null + }; + + return converseRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralIOService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralIOService.cs new file mode 100644 index 000000000000..0426fce3d496 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralIOService.cs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Connectors.Amazon.Models.Mistral; + +/// +/// Input-output service for Mistral. +/// +public class MistralIOService : IBedrockModelIOService +{ + // Define constants for default values + private const float DefaultTemperatureInstruct = 0.5f; + private const float DefaultTopPInstruct = 0.9f; + private const int DefaultMaxTokensInstruct = 512; + private const int DefaultTopKInstruct = 50; + private static readonly List DefaultStopSequencesInstruct = new(); + + private const float DefaultTemperatureNonInstruct = 0.7f; + private const float DefaultTopPNonInstruct = 1.0f; + private const int DefaultMaxTokensNonInstruct = 8192; + private const int DefaultTopKNonInstruct = 0; + private static readonly List DefaultStopSequencesNonInstruct = new(); + /// + /// Builds InvokeModel request Body parameter with structure as required by Mistral. + /// + /// The model ID to be used as a request parameter. + /// The input prompt for text generation. + /// Optional prompt execution settings. + /// + public object GetInvokeModelRequestBody(string modelId, string prompt, PromptExecutionSettings? executionSettings = null) + { + var isInstructModel = modelId.Contains("instruct", StringComparison.OrdinalIgnoreCase); + var temperature = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "temperature", isInstructModel ? DefaultTemperatureInstruct : (double?)DefaultTemperatureNonInstruct); + var topP = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_p", isInstructModel ? DefaultTopPInstruct : (double?)DefaultTopPNonInstruct); + var maxTokens = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "max_tokens", isInstructModel ? DefaultMaxTokensInstruct : (int?)DefaultMaxTokensNonInstruct); + var stop = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "stop", isInstructModel ? 
DefaultStopSequencesInstruct : DefaultStopSequencesNonInstruct);
+ var topK = BedrockModelUtilities.GetExtensionDataValue(executionSettings?.ExtensionData, "top_k", isInstructModel ? DefaultTopKInstruct : (int?)DefaultTopKNonInstruct);
+
+ var requestBody = new
+ {
+ prompt,
+ max_tokens = maxTokens,
+ stop,
+ temperature,
+ top_p = topP,
+ top_k = topK
+ };
+
+ return requestBody;
+ }
+ ///
+ /// Extracts the text contents from the InvokeModelResponse as returned by the Bedrock API.
+ ///
+ /// The InvokeModelResponse object provided by the Bedrock InvokeModelAsync output.
+ /// A list of text content objects as required by the semantic kernel.
+ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse response)
+ {
+ using var memoryStream = new MemoryStream();
+ response.Body.CopyToAsync(memoryStream).ConfigureAwait(false).GetAwaiter().GetResult();
+ memoryStream.Position = 0;
+ using var reader = new StreamReader(memoryStream);
+ var responseBody = JsonSerializer.Deserialize<MistralTextResponse>(reader.ReadToEnd());
+ var textContents = new List<TextContent>();
+ if (responseBody?.Outputs is not { Count: > 0 })
+ {
+ return textContents;
+ }
+ textContents.AddRange(responseBody.Outputs.Select(output => new TextContent(output.Text)));
+ return textContents;
+ }
+ ///
+ /// Builds the ConverseRequest object for the Bedrock ConverseAsync call with request parameters required by Mistral.
+ ///
+ /// The model ID.
+ /// The messages between assistant and user.
+ /// Optional prompt execution settings.
+ ///
+ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null)
+ {
+ var isInstructModel = modelId.Contains("instruct", StringComparison.OrdinalIgnoreCase);
+ var temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", isInstructModel ? DefaultTemperatureInstruct : DefaultTemperatureNonInstruct);
+ var messages = BedrockModelUtilities.BuildMessageList(chatHistory);
+ var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory);
+ var converseRequest = new ConverseRequest
+ {
+ ModelId = modelId,
+ Messages = messages,
+ System = systemMessages,
+ InferenceConfig = new InferenceConfiguration
+ {
+ Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", temperature),
+ TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", isInstructModel ? DefaultTopPInstruct : DefaultTopPNonInstruct),
+ MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", isInstructModel ? DefaultMaxTokensInstruct : DefaultMaxTokensNonInstruct)
+ },
+ AdditionalModelRequestFields = new Document(),
+ AdditionalModelResponseFieldPaths = new List<string>()
+ };
+ return converseRequest;
+ }
+ ///
+ /// Extracts the text generation streaming output from the Mistral response object structure.
+ ///
+ ///
+ ///
+ public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
+ {
+ var outputs = chunk["outputs"]?.AsArray();
+ if (outputs != null)
+ {
+ foreach (var output in outputs)
+ {
+ var text = output?["text"]?.ToString();
+ if (!string.IsNullOrEmpty(text))
+ {
+ yield return text;
+ }
+ }
+ }
+ }
+ ///
+ /// Builds the ConverseStreamRequest object for the Converse Bedrock API call, including building the Mistral Request object and mapping parameters to the ConverseStreamRequest object.
+ ///
+ /// The model ID.
+ /// The messages between assistant and user.
+ /// Optional prompt execution settings.
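// NOTE (editor's illustrative sketch, not part of the original change): the Mistral defaults
// differ by model family per the constants above (instruct: temperature 0.5, top_p 0.9,
// max_tokens 512, top_k 50; non-instruct: temperature 0.7, top_p 1.0, max_tokens 8192, top_k 0).
// A caller can override any of them through ExtensionData, e.g.:
// var settings = new PromptExecutionSettings
// {
//     ExtensionData = new Dictionary<string, object> { ["max_tokens"] = 1024 }
// };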
+ /// + public ConverseStreamRequest GetConverseStreamRequest(string modelId, ChatHistory chatHistory, PromptExecutionSettings? settings = null) + { + var isInstructModel = modelId.Contains("instruct", StringComparison.OrdinalIgnoreCase); + var temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", isInstructModel ? DefaultTemperatureInstruct : DefaultTemperatureNonInstruct); + var messages = BedrockModelUtilities.BuildMessageList(chatHistory); + var systemMessages = BedrockModelUtilities.GetSystemMessages(chatHistory); + var converseRequest = new ConverseStreamRequest() + { + ModelId = modelId, + Messages = messages, + System = systemMessages, + InferenceConfig = new InferenceConfiguration + { + Temperature = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "temperature", temperature), + TopP = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "top_p", isInstructModel ? DefaultTopPInstruct : DefaultTopPNonInstruct), + MaxTokens = BedrockModelUtilities.GetExtensionDataValue(settings?.ExtensionData, "max_tokens", isInstructModel ? DefaultMaxTokensInstruct : DefaultMaxTokensNonInstruct) + }, + AdditionalModelRequestFields = new Document(), + AdditionalModelResponseFieldPaths = new List() + }; + return converseRequest; + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralTextResponse.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralTextResponse.cs new file mode 100644 index 000000000000..b7763f838353 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Models/Mistral/MistralTextResponse.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; + +namespace Connectors.Amazon.Models.Mistral; + +/// +/// Mistral Text Response body. +/// +[Serializable] +public class MistralTextResponse +{ + /// + /// A list of outputs from the model. + /// + [JsonPropertyName("outputs")] + public List? Outputs { get; set; } + + /// + /// Output parameters for the list of outputs of the text response. + /// + public class Output + { + /// + /// The text that the model generated. + /// + [JsonPropertyName("text")] + public string? Text { get; set; } + + /// + /// The reason why the response stopped generating text. + /// + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockChatCompletionService.cs new file mode 100644 index 000000000000..80a414352292 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockChatCompletionService.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Amazon.BedrockRuntime; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Amazon.Core; +using Microsoft.SemanticKernel.Services; + +namespace Microsoft.SemanticKernel.Connectors.Amazon.Services; + +/// +/// Represents a chat completion service using Amazon Bedrock API. +/// +public class BedrockChatCompletionService : IChatCompletionService +{ + private readonly Dictionary _attributesInternal = []; + private readonly BedrockChatCompletionClient _chatCompletionClient; + + /// + /// Initializes an instance of the BedrockChatCompletionService using an IAmazonBedrockRuntime object passed in by the user. 
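// NOTE (editor's illustrative usage, not part of the original change): constructing the service
// directly with an explicit runtime client; the model ID and region below are assumptions:
// var service = new BedrockChatCompletionService(
//     "amazon.titan-text-premier-v1:0",
//     new AmazonBedrockRuntimeClient(Amazon.RegionEndpoint.USEast1));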
+ /// + /// The model to be used for chat completion. + /// The IAmazonBedrockRuntime object to be used for DI. + public BedrockChatCompletionService(string modelId, IAmazonBedrockRuntime bedrockApi) + { + this._chatCompletionClient = new BedrockChatCompletionClient(modelId, bedrockApi); + this._attributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId); + } + + /// + public IReadOnlyDictionary Attributes => this._attributesInternal; + + /// + public Task> GetChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + return this._chatCompletionClient.GenerateChatMessageAsync(chatHistory, executionSettings, kernel, cancellationToken); + } + + /// + public IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + return this._chatCompletionClient.StreamChatMessageAsync(chatHistory, executionSettings, kernel, cancellationToken); + } +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockTextGenerationService.cs new file mode 100644 index 000000000000..55cbf4f24640 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Services/BedrockTextGenerationService.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Amazon.BedrockRuntime; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Amazon.Bedrock.Core; +using Microsoft.SemanticKernel.Services; +using Microsoft.SemanticKernel.TextGeneration; + +namespace Connectors.Amazon.Services; + +/// +/// Represents a text generation service using Amazon Bedrock API. +/// +public class BedrockTextGenerationService : ITextGenerationService +{ + private readonly Dictionary _attributesInternal = []; + private readonly BedrockTextGenerationClient _textGenerationClient; + + /// + /// Initializes an instance of the BedrockTextGenerationService using an IAmazonBedrockRuntime object passed in by the user. + /// + /// + /// + public BedrockTextGenerationService(string modelId, IAmazonBedrockRuntime bedrockApi) + { + this._textGenerationClient = new BedrockTextGenerationClient(modelId, bedrockApi); + this._attributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId); + } + + /// + public IReadOnlyDictionary Attributes => this._attributesInternal; + + /// + public Task> GetTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + => this._textGenerationClient.InvokeBedrockModelAsync(prompt, executionSettings, cancellationToken); + + /// + public IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? 
kernel = null, + CancellationToken cancellationToken = default) + => this._textGenerationClient.StreamTextAsync(prompt, executionSettings, kernel, cancellationToken); +} diff --git a/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj b/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj new file mode 100644 index 000000000000..5d163c3c5425 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + + + $(NoWarn);SKEXP0001 + + + + + + + + + + + + + + + + + + + + diff --git a/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatCompletionTests.cs new file mode 100644 index 000000000000..fda468b043e4 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatCompletionTests.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Connectors.Amazon.Extensions; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Amazon.Bedrock; + +public class BedrockChatCompletionTests +{ + [Theory] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("anthropic.claude-3-sonnet-20240229-v1:0")] + [InlineData("anthropic.claude-3-haiku-20240307-v1:0")] + [InlineData("cohere.command-r-v1:0")] + [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + [InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + public async Task ChatCompletionReturnsValidResponseAsync(string modelId) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello, I'm Alexa, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("What is 2 + 2?"); + + var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId).Build(); + var chatCompletionService = kernel.GetRequiredService(); + + // Act + var response = await chatCompletionService.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true); + string output = ""; + foreach (var message in response) + { + output += message.Content; + } + chatHistory.AddAssistantMessage(output); + + // Assert + Assert.NotNull(output); + Assert.True(response.Count > 0); + Assert.Equal(4, chatHistory.Count); + Assert.Equal(AuthorRole.Assistant, chatHistory[3].Role); + } + + [Theory] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("anthropic.claude-3-sonnet-20240229-v1:0")] + [InlineData("anthropic.claude-3-haiku-20240307-v1:0")] + [InlineData("cohere.command-r-v1:0")] + [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + 
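// NOTE (editor's note, hedged): these integration tests call AddBedrockChatCompletionService with
// only a model ID, so the extension presumably builds an IAmazonBedrockRuntime client from the
// ambient AWS credential chain and region. Running them requires valid AWS credentials with
// Bedrock model access, e.g. via AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION.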
[InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + public async Task ChatStreamingReturnsValidResponseAsync(string modelId) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello, I'm Alexa, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("What is 2 + 2?"); + + var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId).Build(); + var chatCompletionService = kernel.GetRequiredService(); + + // Act + var response = chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistory).ConfigureAwait(true); + string output = ""; + await foreach (var message in response) + { + output += message.Content; + } + chatHistory.AddAssistantMessage(output); + + // Assert + Assert.NotNull(output); + Assert.Equal(4, chatHistory.Count); + Assert.Equal(AuthorRole.Assistant, chatHistory[3].Role); + Assert.False(string.IsNullOrEmpty(output)); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs new file mode 100644 index 000000000000..546f556cca0e --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Connectors.Amazon.Extensions; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Amazon.Bedrock; + +public class BedrockTextGenerationTests +{ + [Theory] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("cohere.command-text-v14")] + [InlineData("cohere.command-light-text-v14")] + [InlineData("cohere.command-r-v1:0")] + [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("ai21.j2-ultra-v1")] + [InlineData("ai21.j2-mid-v1")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + [InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + public async Task TextGenerationReturnsValidResponseAsync(string modelId) + { + // Arrange + string prompt = "What is 2 + 2?"; + var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId).Build(); + var textGenerationService = kernel.GetRequiredService(); + + // Act + var response = await textGenerationService.GetTextContentsAsync(prompt).ConfigureAwait(true); + string output = ""; + foreach (var text in response) + { + output += text; + } + + // Assert + Assert.NotNull(response); + Assert.True(response.Count > 0); + Assert.False(string.IsNullOrEmpty(output)); + } + + [Theory] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("cohere.command-text-v14")] + 
[InlineData("cohere.command-light-text-v14")] + [InlineData("cohere.command-r-v1:0")] + [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + [InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + public async Task TextStreamingReturnsValidResponseAsync(string modelId) + { + // Arrange + string prompt = "What is 2 + 2?"; + var kernel = Kernel.CreateBuilder().AddBedrockTextGenerationService(modelId).Build(); + var textGenerationService = kernel.GetRequiredService(); + + // Act + var response = textGenerationService.GetStreamingTextContentsAsync(prompt).ConfigureAwait(true); + string output = ""; + await foreach (var textContent in response) + { + output += textContent.Text; + } + + // Assert + Assert.NotNull(output); + Assert.False(string.IsNullOrEmpty(output)); + } +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 63e4ec8d28fe..31152bf3205a 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -59,6 +59,7 @@ +