Skip to content

Commit

Permalink
.Net Added support for gemini-vision (microsoft#4957)
Browse files Browse the repository at this point in the history
### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

State of connector progress
microsoft#4680

This PR introduces limited support for ImageContent. At present, the
support extends to only Uri files and is only available in VertexAI. The
associated unit tests have been updated to reflect these changes.

@RogerBarreto 

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
Krzysztof318 committed Feb 15, 2024
1 parent 21970ed commit 11420ff
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ public void IsValidWhenInlineDataIsNotNull()
Assert.True(result);
}

[Fact]
public void IsValidWhenFileDataIsNotNull()
{
// Arrange
var sut = new GeminiPart { FileData = new() };

// Act
var result = sut.IsValid();

// Assert
Assert.True(result);
}

[Fact]
public void IsValidWhenFunctionCallIsNotNull()
{
Expand Down Expand Up @@ -93,16 +106,35 @@ public GeminiPartTestData()
this.Add(new() { Text = "text", FunctionCall = new() });
this.Add(new() { Text = "text", InlineData = new() });
this.Add(new() { Text = "text", FunctionResponse = new() });
this.Add(new() { Text = "text", FileData = new() });
this.Add(new() { InlineData = new(), FunctionCall = new() });
this.Add(new() { InlineData = new(), FunctionResponse = new() });
this.Add(new() { InlineData = new(), FileData = new() });
this.Add(new() { FunctionCall = new(), FunctionResponse = new() });
this.Add(new() { FunctionCall = new(), FileData = new() });
this.Add(new() { FunctionResponse = new(), FileData = new() });

// Three properties
this.Add(new() { Text = "text", InlineData = new(), FunctionCall = new() });
this.Add(new() { Text = "text", InlineData = new(), FunctionResponse = new() });
this.Add(new() { Text = "text", InlineData = new(), FileData = new() });
this.Add(new() { Text = "text", FunctionCall = new(), FunctionResponse = new() });
this.Add(new() { Text = "text", FunctionCall = new(), FileData = new() });
this.Add(new() { Text = "text", FunctionResponse = new(), FileData = new() });
this.Add(new() { InlineData = new(), FunctionCall = new(), FunctionResponse = new() });
this.Add(new() { InlineData = new(), FunctionCall = new(), FileData = new() });
this.Add(new() { InlineData = new(), FunctionResponse = new(), FileData = new() });
this.Add(new() { FunctionCall = new(), FunctionResponse = new(), FileData = new() });

// Four properties
this.Add(new() { Text = "text", InlineData = new(), FunctionCall = new(), FunctionResponse = new() });
this.Add(new() { Text = "text", InlineData = new(), FunctionCall = new(), FileData = new() });
this.Add(new() { Text = "text", InlineData = new(), FunctionResponse = new(), FileData = new() });
this.Add(new() { Text = "text", FunctionCall = new(), FunctionResponse = new(), FileData = new() });
this.Add(new() { InlineData = new(), FunctionCall = new(), FunctionResponse = new(), FileData = new() });

// Five properties
this.Add(new() { Text = "text", InlineData = new(), FunctionCall = new(), FunctionResponse = new(), FileData = new() });
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI;
using Xunit;

Expand All @@ -9,7 +13,7 @@ namespace SemanticKernel.Connectors.GoogleVertexAI.UnitTests.Core.Gemini;
public sealed class GeminiRequestTests
{
[Fact]
public void FromPromptExecutionSettingsReturnsGeminiRequestWithConfiguration()
public void FromPromptItReturnsGeminiRequestWithConfiguration()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -31,7 +35,7 @@ public void FromPromptExecutionSettingsReturnsGeminiRequestWithConfiguration()
}

[Fact]
public void FromPromptExecutionSettingsReturnsGeminiRequestWithSafetySettings()
public void FromPromptItReturnsGeminiRequestWithSafetySettings()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -53,7 +57,7 @@ public void FromPromptExecutionSettingsReturnsGeminiRequestWithSafetySettings()
}

[Fact]
public void FromPromptExecutionSettingsReturnsGeminiRequestWithPrompt()
public void FromPromptItReturnsGeminiRequestWithPrompt()
{
// Arrange
var prompt = "prompt-example";
Expand All @@ -65,4 +69,142 @@ public void FromPromptExecutionSettingsReturnsGeminiRequestWithPrompt()
// Assert
Assert.Equal(prompt, request.Contents[0].Parts[0].Text);
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithConfiguration()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage("user-message2");
var executionSettings = new GeminiPromptExecutionSettings
{
Temperature = 1.5,
MaxTokens = 10,
TopP = 0.9,
};

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.NotNull(request.Configuration);
Assert.Equal(executionSettings.Temperature, request.Configuration.Temperature);
Assert.Equal(executionSettings.MaxTokens, request.Configuration.MaxOutputTokens);
Assert.Equal(executionSettings.TopP, request.Configuration.TopP);
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage("user-message2");
var executionSettings = new GeminiPromptExecutionSettings
{
SafetySettings = new List<GeminiSafetySetting>
{
new(GeminiSafetyCategory.Derogatory, GeminiSafetyThreshold.BlockNone)
}
};

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.NotNull(request.SafetySettings);
Assert.Equal(executionSettings.SafetySettings[0].Category, request.SafetySettings[0].Category);
Assert.Equal(executionSettings.SafetySettings[0].Threshold, request.SafetySettings[0].Threshold);
}

[Fact]
public void FromChatHistoryItReturnsGeminiRequestWithChatHistory()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage("user-message2");
var executionSettings = new GeminiPromptExecutionSettings();

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[2].Content, c.Parts[0].Text));
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Role, c.Role),
c => Assert.Equal(chatHistory[1].Role, c.Role),
c => Assert.Equal(chatHistory[2].Role, c.Role));
}

[Fact]
public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistory()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage(contentItems: [new TextContent("user-message2")]);
var executionSettings = new GeminiPromptExecutionSettings();

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[2].Items!.Cast<TextContent>().Single().Text, c.Parts[0].Text));
}

[Fact]
public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHistory()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage(contentItems:
[new ImageContent(new Uri("https://example-image.com/"), metadata: new Dictionary<string, object?> { ["mimeType"] = "value" })]);
var executionSettings = new GeminiPromptExecutionSettings();

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.Collection(request.Contents,
c => Assert.Equal(chatHistory[0].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[1].Content, c.Parts[0].Text),
c => Assert.Equal(chatHistory[2].Items!.Cast<ImageContent>().Single().Uri, c.Parts[0].FileData!.FileUri));
}

[Fact]
public void FromChatHistoryUnsupportedContentItThrowsNotSupportedException()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage(contentItems: [new DummyContent("unsupported-content")]);
var executionSettings = new GeminiPromptExecutionSettings();

// Act
void Act() => GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.Throws<NotSupportedException>(Act);
}

private sealed class DummyContent : KernelContent
{
public DummyContent(object? innerContent, string? modelId = null, IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata) { }
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand All @@ -20,12 +21,19 @@ public sealed class GeminiPart : IJsonOnDeserialized
public string? Text { get; set; }

/// <summary>
/// Gets or sets the image or video data.
/// Gets or sets the image or video as binary data.
/// </summary>
[JsonPropertyName("inlineData")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public InlineDataPart? InlineData { get; set; }

/// <summary>
/// Gets or sets the image or video as file uri.
/// </summary>
[JsonPropertyName("fileData")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public FileDataPart? FileData { get; set; }

/// <summary>
/// Function call data.
/// </summary>
Expand All @@ -42,13 +50,14 @@ public sealed class GeminiPart : IJsonOnDeserialized

/// <summary>
/// Checks whether only one property of the GeminiPart instance is not null.
/// Returns true if only one property among Text, InlineData, FunctionCall, and FunctionResponse is not null,
/// Returns true if only one property among Text, InlineData, FileData, FunctionCall, and FunctionResponse is not null,
/// Otherwise, it returns false.
/// </summary>
public bool IsValid()
{
return (this.Text != null ? 1 : 0) +
(this.InlineData != null ? 1 : 0) +
(this.FileData != null ? 1 : 0) +
(this.FunctionCall != null ? 1 : 0) +
(this.FunctionResponse != null ? 1 : 0) == 1;
}
Expand All @@ -59,7 +68,7 @@ public void OnDeserialized()
if (!this.IsValid())
{
throw new JsonException(
"GeminiPart is invalid. One and only one property among Text, InlineData, FunctionCall, and FunctionResponse should be set.");
"GeminiPart is invalid. One and only one property among Text, InlineData, FileData, FunctionCall, and FunctionResponse should be set.");
}
}

Expand All @@ -72,7 +81,7 @@ public sealed class InlineDataPart
/// The IANA standard MIME type of the source data.
/// </summary>
/// <remarks>
/// Accepted types include: "image/png", "image/jpeg", "image/heic", "image/heif", "image/webp".
/// Acceptable values include: "image/png", "image/jpeg", "image/heic", "image/heif", "image/webp".
/// </remarks>
[JsonPropertyName("mimeType")]
[JsonRequired]
Expand All @@ -86,6 +95,30 @@ public sealed class InlineDataPart
public string InlineData { get; set; } = null!;
}

/// <summary>
/// File media bytes like image or video data.
/// </summary>
public sealed class FileDataPart
{
/// <summary>
/// The IANA standard MIME type of the source data.
/// </summary>
/// <remarks>
/// Acceptable values include: "image/png", "image/jpeg", "video/mov", "video/mpeg", "video/mp4", "video/mpg", "video/avi", "video/wmv", "video/mpegps", "video/flv".
/// </remarks>
[JsonPropertyName("mimeType")]
[JsonRequired]
public string MimeType { get; set; } = null!;

/// <summary>
/// The Cloud Storage URI of the image or video to include in the prompt.
/// The bucket that stores the file must be in the same Google Cloud project that's sending the request.
/// </summary>
[JsonPropertyName("fileUri")]
[JsonRequired]
public Uri FileUri { get; set; } = null!;
}

/// <summary>
/// A predicted FunctionCall returned from the model that contains a
/// string representing the FunctionDeclaration.name with the arguments and their values.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -73,20 +74,46 @@ private static GeminiRequest CreateGeminiRequest(ChatHistory chatHistory)
{
Contents = chatHistory.Select(c => new GeminiContent
{
Parts = new List<GeminiPart>
{
new()
{
Text = (c.Items?.SingleOrDefault(content => content is TextContent)
as TextContent)?.Text ?? c.Content ?? string.Empty,
}
},
Parts = CreateGeminiParts(c),
Role = c.Role
}).ToList()
};
return obj;
}

private static List<GeminiPart> CreateGeminiParts(ChatMessageContent content)
{
var list = content.Items?.Select(item => item switch
{
TextContent textContent => new GeminiPart { Text = textContent.Text },
ImageContent imageContent => new GeminiPart
{
FileData = new GeminiPart.FileDataPart
{
MimeType = GetMimeTypeFromImageContent(imageContent),
FileUri = imageContent.Uri ?? throw new InvalidOperationException("Image content URI is empty.")
}
},
_ => throw new NotSupportedException($"Unsupported content type. {item.GetType().Name} is not supported by Gemini.")
}).ToList() ?? new List<GeminiPart>();

if (list.Count == 0)
{
list.Add(new GeminiPart { Text = content.Content ?? string.Empty });
}

return list;
}

private static string GetMimeTypeFromImageContent(ImageContent imageContent)
{
var key = imageContent.Metadata?.Keys.SingleOrDefault(key =>
key.Equals("mimeType", StringComparison.OrdinalIgnoreCase)
|| key.Equals("mime_type", StringComparison.OrdinalIgnoreCase))
?? throw new InvalidOperationException("Mime type is not found in the image content metadata.");
return imageContent.Metadata[key]!.ToString();
}

private static void AddConfiguration(GeminiPromptExecutionSettings executionSettings, GeminiRequest obj)
{
obj.Configuration = new ConfigurationElement
Expand Down

0 comments on commit 11420ff

Please sign in to comment.