Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HuggingFace backend implementation #135

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.HuggingFace.Services;
using Moq;
using Moq.Protected;
using Xunit;

namespace SemanticKernelTests.AI.HuggingFace;

/// <summary>
/// Exercises <see cref="HuggingFaceBackend"/> against canned HTTP responses
/// that are served by a mocked <see cref="HttpClientHandler"/>.
/// </summary>
public class HuggingFaceBackendTests : IDisposable
{
    private const string BaseUri = "http://localhost:5000";
    private const string Model = "gpt2";

    // Single response instance reused by every test; its content is swapped in CreateBackend
    // and the instance is released in Dispose.
    private readonly HttpResponseMessage _response = new(HttpStatusCode.OK);

    /// <summary>
    /// <see cref="HuggingFaceBackend.CompleteAsync(string, CompleteRequestSettings)"/> must hand back
    /// the completion text contained in the canned completion response.
    /// </summary>
    [Fact]
    public async Task ItReturnsCompletionCorrectlyAsync()
    {
        // Arrange
        string cannedJson = this.GetTestResponse("completion_test_response.json");
        using var backend = this.CreateBackend(cannedJson);
        var settings = new CompleteRequestSettings();

        // Act
        var actual = await backend.CompleteAsync("This is test", settings);

        // Assert
        Assert.Equal("This is test completion response", actual);
    }

    /// <summary>
    /// <see cref="HuggingFaceBackend.GenerateEmbeddingsAsync(IList{string})"/> must hand back
    /// the embeddings contained in the canned embeddings response.
    /// </summary>
    [Fact]
    public async Task ItReturnsEmbeddingsCorrectlyAsync()
    {
        // Arrange
        const int embeddingCount = 1;
        const int vectorLength = 8;
        var inputs = new List<string> { "test_string_1", "test_string_2", "test_string_3" };

        string cannedJson = this.GetTestResponse("embeddings_test_response.json");
        using var backend = this.CreateBackend(cannedJson);

        // Act
        var result = await backend.GenerateEmbeddingsAsync(inputs);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(embeddingCount, result.Count);
        Assert.Equal(vectorLength, result.First().Count);
    }

    /// <summary>
    /// Loads the text of a canned response file from the test data folder.
    /// </summary>
    /// <param name="fileName">File name (without directory) of the canned response.</param>
    /// <returns>Full text content of the file.</returns>
    private string GetTestResponse(string fileName)
        => File.ReadAllText($"./AI/HuggingFace/TestData/{fileName}");

    /// <summary>
    /// Builds a <see cref="HuggingFaceBackend"/> whose HTTP traffic is answered by a mocked
    /// <see cref="HttpClientHandler"/> that always returns the shared response.
    /// </summary>
    /// <param name="testResponse">Body text the mocked handler should return.</param>
    /// <returns>Backend instance wired to the mocked handler.</returns>
    private HuggingFaceBackend CreateBackend(string testResponse)
    {
        this._response.Content = new StringContent(testResponse);

        var handlerMock = new Mock<HttpClientHandler>();

        // SendAsync is protected on the handler, so it is set up by name through Moq.Protected.
        handlerMock
            .Protected()
            .Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
            .ReturnsAsync(this._response);

        return new HuggingFaceBackend(BaseUri, Model, handlerMock.Object);
    }

    /// <summary>
    /// Releases the shared response and suppresses finalization.
    /// </summary>
    public void Dispose()
    {
        this.Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Disposes the shared response when called from <see cref="Dispose()"/>.
    /// </summary>
    /// <param name="disposing"><c>true</c> to release managed resources; otherwise nothing is done.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (!disposing)
        {
            return;
        }

        this._response.Dispose();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"choices": [
{
"finish_reason": "test",
"index": 0,
"logprobs": "",
"text": "This is test completion response"
}
],
"created": "Tue, 21 Mar 2023 11:18:04 GMT",
"id": "",
"model": "gpt2",
"object": "text_completion",
"usage": {
"completion_tokens": 32,
"prompt_tokens": 3,
"total_tokens": 35
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"data": [
{
"embedding": [
-0.08541165292263031,
0.08639130741357803,
-0.12805694341659546,
-0.2877824902534485,
0.2114177942276001,
-0.29374566674232483,
-0.10496602207422256,
0.009402364492416382
],
"index": 0,
"object": "embedding"
}
],
"object": "list",
"usage": {
"prompt_tokens": 15,
"total_tokens": 15
}
}
9 changes: 9 additions & 0 deletions dotnet/src/SemanticKernel.Test/SemanticKernel.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,13 @@
<ProjectReference Include="..\SemanticKernel\SemanticKernel.csproj" />
</ItemGroup>

<ItemGroup>
<!-- HuggingFace test fixtures: copied next to the test binaries so the tests can read them from -->
<!-- the relative path ./AI/HuggingFace/TestData/ at run time. -->
<!-- NOTE(review): "Always" re-copies on every build; PreserveNewest may be sufficient - confirm. -->
<None Update="AI\HuggingFace\TestData\completion_test_response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="AI\HuggingFace\TestData\embeddings_test_response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.AI.HuggingFace.HttpSchema;

/// <summary>
/// Body of a text completion HTTP request. Serialized to JSON with the
/// property names <c>prompt</c> and <c>model</c>.
/// </summary>
[Serializable]
public sealed class CompletionRequest
{
    /// <summary>
    /// Text that should be completed. Empty string unless set.
    /// </summary>
    [JsonPropertyName("prompt")]
    public string Prompt { get; set; } = "";

    /// <summary>
    /// Name of the model that should produce the completion. Empty string unless set.
    /// </summary>
    [JsonPropertyName("model")]
    public string Model { get; set; } = "";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.AI.HuggingFace.HttpSchema;

/// <summary>
/// Body of a text completion HTTP response. Only the <c>choices</c> array
/// and the <c>text</c> of each entry are mapped.
/// </summary>
public sealed class CompletionResponse
{
    /// <summary>
    /// Completion candidates read from the <c>choices</c> JSON property;
    /// <c>null</c> unless set.
    /// </summary>
    [JsonPropertyName("choices")]
    public IList<Choice>? Choices { get; set; }

    /// <summary>
    /// A single completion candidate.
    /// </summary>
    public sealed class Choice
    {
        /// <summary>
        /// Completed text read from the <c>text</c> JSON property; <c>null</c> unless set.
        /// </summary>
        [JsonPropertyName("text")]
        public string? Text { get; set; }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.AI.HuggingFace.HttpSchema;

/// <summary>
/// Body of an embedding generation HTTP request. Serialized to JSON with the
/// property names <c>input</c> and <c>model</c>.
/// </summary>
[Serializable]
public sealed class EmbeddingRequest
{
    /// <summary>
    /// Strings for which embeddings are requested. Starts as an empty list.
    /// </summary>
    [JsonPropertyName("input")]
    public IList<string> Input { get; set; } = new List<string>(0);

    /// <summary>
    /// Name of the model that should generate the embeddings. Empty string unless set.
    /// </summary>
    [JsonPropertyName("model")]
    public string Model { get; set; } = "";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.AI.HuggingFace.HttpSchema;

/// <summary>
/// Body of an embedding generation HTTP response. Only the <c>data</c> array
/// and the <c>embedding</c> of each entry are mapped.
/// </summary>
public sealed class EmbeddingResponse
{
    /// <summary>
    /// Embedding entries read from the <c>data</c> JSON property; <c>null</c> unless set.
    /// </summary>
    [JsonPropertyName("data")]
    public IList<EmbeddingVector>? Embeddings { get; set; }

    /// <summary>
    /// A single entry of the <c>data</c> array.
    /// </summary>
    public sealed class EmbeddingVector
    {
        /// <summary>
        /// Vector components read from the <c>embedding</c> JSON property; <c>null</c> unless set.
        /// </summary>
        [JsonPropertyName("embedding")]
        public IList<float>? Embedding { get; set; }
    }
}
Loading