Skip to content

Commit

Permalink
.Net: Fix for Chroma boolean values processing (#2072)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

With the latest updates in Chroma ([Release
0.4.0](https://github.com/chroma-core/chroma/releases/tag/0.4.0)),
boolean values are automatically converted to their numeric representation
(False = 0, True = 1). This PR handles that case so that boolean values
are converted back correctly.

Should fix #2049 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
1. Implemented `ChromaBooleanConverter`, which is used when deserializing
Chroma responses.
2. Aligned with the new Chroma version: updated the error message that is
matched when trying to delete a non-existent collection.
3. Added an integration test to verify that boolean values are processed correctly.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#dev-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Abby Harrison <[email protected]>
  • Loading branch information
dmytrostruk and awharrison-28 authored Jul 20, 2023
1 parent 9bdeb42 commit b79f3a0
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.Memory.Chroma;

/// <summary>
/// JSON converter for Chroma boolean values.
/// </summary>
/// <remarks>
/// Chroma 0.4.0 stores boolean values in their numeric representation (False = 0, True = 1).
/// This converter reads both that numeric form and native JSON <c>true</c>/<c>false</c> literals,
/// and always writes the numeric form.
/// </remarks>
public class ChromaBooleanConverter : JsonConverter<bool>
{
    /// <summary>
    /// Reads a boolean from either a JSON boolean literal or a JSON number.
    /// </summary>
    /// <param name="reader">Reader positioned on the value to convert.</param>
    /// <param name="typeToConvert">Type being converted (always <see cref="bool"/>).</param>
    /// <param name="options">Serializer options in effect.</param>
    /// <returns>
    /// The literal value for <c>true</c>/<c>false</c> tokens; for numbers, <c>true</c> when the value
    /// fits into a 16-bit integer and is non-zero, otherwise <c>false</c>.
    /// </returns>
    /// <exception cref="JsonException">The current token is neither a boolean literal nor a number.</exception>
    public override bool Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        // TryGetInt16 throws when the token is not a number, so the token type has to be
        // checked first; otherwise plain JSON booleans could not be read by this converter.
        return reader.TokenType switch
        {
            JsonTokenType.True => true,
            JsonTokenType.False => false,
            // Numbers that are not valid 16-bit integers (e.g. 1.5) keep mapping to false.
            JsonTokenType.Number => reader.TryGetInt16(out short number) && number != 0,
            _ => throw new JsonException($"Cannot convert JSON token '{reader.TokenType}' to a boolean value."),
        };
    }

    /// <summary>
    /// Writes the boolean in numeric form: 1 for <c>true</c>, 0 for <c>false</c>.
    /// </summary>
    /// <param name="writer">Writer that receives the value.</param>
    /// <param name="value">Boolean value to write.</param>
    /// <param name="options">Serializer options in effect.</param>
    public override void Write(Utf8JsonWriter writer, bool value, JsonSerializerOptions options)
    {
        writer.WriteNumberValue(value ? 1 : 0);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Chroma;
public class ChromaClientException : Exception
{
private const string CollectionDoesNotExistErrorFormat = "Collection {0} does not exist";
private const string DeleteNonExistentCollectionErrorMessage = "list index out of range";

/// <summary>
/// Initializes a new instance of the <see cref="ChromaClientException"/> class.
Expand Down Expand Up @@ -43,10 +42,4 @@ public ChromaClientException(string message, Exception innerException) : base(me
/// <param name="collectionName">Collection name.</param>
public bool CollectionDoesNotExistException(string collectionName)
{
    // Build the error text expected for this particular collection name,
    // then check whether the exception message contains it.
    string expectedFragment = string.Format(
        CultureInfo.InvariantCulture,
        CollectionDoesNotExistErrorFormat,
        collectionName);

    return this.Message.Contains(expectedFragment);
}

/// <summary>
/// Checks if Chroma API error means that there was an attempt to delete non-existent collection.
/// </summary>
public bool DeleteNonExistentCollectionException() =>
this.Message.Contains(DeleteNonExistentCollectionErrorMessage);
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken
{
await this._chromaClient.DeleteCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false);
}
catch (ChromaClientException e) when (e.DeleteNonExistentCollectionException())
catch (ChromaClientException e) when (e.CollectionDoesNotExistException(collectionName))
{
this._logger.LogError("Cannot delete non-existent collection {0}", collectionName);
throw new ChromaMemoryStoreException($"Cannot delete non-existent collection {collectionName}", e);
Expand Down Expand Up @@ -230,6 +230,11 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IE
private readonly IChromaClient _chromaClient;
private readonly List<string> _defaultEmbeddingIncludeTypes = new() { IncludeMetadatas };

private readonly JsonSerializerOptions _jsonSerializerOptions = new()
{
Converters = { new ChromaBooleanConverter() }
};

private async Task<ChromaCollectionModel> GetCollectionOrThrowAsync(string collectionName, CancellationToken cancellationToken)
{
return
Expand Down Expand Up @@ -292,15 +297,19 @@ private MemoryRecord GetMemoryRecordFromModel(List<Dictionary<string, object>>?
var embeddingsVector = this.GetEmbeddingForMemoryRecord(embeddings, recordIndex);
var key = ids?[recordIndex];

return MemoryRecord.FromJsonMetadata(
json: metadata,
return MemoryRecord.FromMetadata(
metadata: metadata,
embedding: embeddingsVector,
key: key);
}

private string GetMetadataForMemoryRecord(List<Dictionary<string, object>>? metadatas, int recordIndex)
private MemoryRecordMetadata GetMetadataForMemoryRecord(List<Dictionary<string, object>>? metadatas, int recordIndex)
{
return metadatas != null ? JsonSerializer.Serialize(metadatas[recordIndex]) : string.Empty;
var serializedMetadata = metadatas != null ? JsonSerializer.Serialize(metadatas[recordIndex]) : string.Empty;

return
JsonSerializer.Deserialize<MemoryRecordMetadata>(serializedMetadata, this._jsonSerializerOptions) ??
throw new ChromaMemoryStoreException("Unable to deserialize memory record metadata.");
}

private Embedding<float> GetEmbeddingForMemoryRecord(List<float[]>? embeddings, int recordIndex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public sealed class ChromaMemoryStoreTests : IDisposable
private readonly HttpMessageHandlerStub _messageHandlerStub;
private readonly HttpClient _httpClient;
private readonly Mock<IChromaClient> _chromaClientMock;
private readonly JsonSerializerOptions _serializerOptions;

public ChromaMemoryStoreTests()
{
Expand All @@ -37,6 +38,11 @@ public ChromaMemoryStoreTests()
this._chromaClientMock
.Setup(client => client.GetCollectionAsync(CollectionName, CancellationToken.None))
.ReturnsAsync(new ChromaCollectionModel { Id = CollectionId, Name = CollectionName });

this._serializerOptions = new JsonSerializerOptions
{
Converters = { new ChromaBooleanConverter() }
};
}

[Fact]
Expand Down Expand Up @@ -102,12 +108,12 @@ public async Task ItThrowsExceptionOnNonExistentCollectionDeletionAsync()
{
// Arrange
const string collectionName = "non-existent-collection";
const string deleteNonExistentCollectionErrorMessage = "list index out of range";
const string collectionDoesNotExistErrorMessage = $"Collection {collectionName} does not exist";
const string expectedExceptionMessage = $"Cannot delete non-existent collection {collectionName}";

this._chromaClientMock
.Setup(client => client.DeleteCollectionAsync(collectionName, CancellationToken.None))
.Throws(new ChromaClientException(deleteNonExistentCollectionErrorMessage));
.Throws(new ChromaClientException(collectionDoesNotExistErrorMessage));

var store = new ChromaMemoryStore(this._chromaClientMock.Object);

Expand Down Expand Up @@ -310,7 +316,7 @@ private MemoryRecord GetRandomMemoryRecord(Embedding<float>? embedding = null)

private Dictionary<string, object> GetEmbeddingMetadataFromMemoryRecord(MemoryRecord memoryRecord)
{
var serialized = JsonSerializer.Serialize(memoryRecord.Metadata);
var serialized = JsonSerializer.Serialize(memoryRecord.Metadata, this._serializerOptions);
return JsonSerializer.Deserialize<Dictionary<string, object>>(serialized)!;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,28 @@ public async Task ItCanUpsertDifferentMemoryRecordsWithSameKeyMultipleTimesAsync
this.AssertMemoryRecordEqual(expectedRecord2, actualRecord2);
}

/// <summary>
/// Upserts a record whose metadata carries the given <c>IsReference</c> flag, reads it back
/// from the store and verifies that the flag has the same value.
/// </summary>
[Theory(Skip = SkipReason)]
[InlineData(true)]
[InlineData(false)]
public async Task ItProcessesBooleanValuesCorrectlyAsync(bool isReference)
{
    // Arrange
    var collection = this.GetRandomCollectionName();
    var sourceRecord = this.GetRandomMemoryRecord(
        metadata: this.GetRandomMemoryRecordMetadata(isReference: isReference));

    await this._chromaMemoryStore.CreateCollectionAsync(collection);

    // Act
    var storedKey = await this._chromaMemoryStore.UpsertAsync(collection, sourceRecord);
    var storedRecord = await this._chromaMemoryStore.GetAsync(collection, storedKey, true);

    // Assert
    Assert.NotNull(storedRecord);
    Assert.Equal(sourceRecord.Metadata.IsReference, storedRecord.Metadata.IsReference);
}

public void Dispose()
{
this.Dispose(true);
Expand Down Expand Up @@ -429,5 +451,28 @@ private MemoryRecord GetRandomMemoryRecord(string? key = null, Embedding<float>?
key: recordKey);
}

/// <summary>
/// Builds a memory record from the supplied metadata, keyed by the metadata id.
/// Falls back to a fixed three-element embedding when none is provided.
/// </summary>
private MemoryRecord GetRandomMemoryRecord(MemoryRecordMetadata metadata, Embedding<float>? embedding = null) =>
    MemoryRecord.FromMetadata(
        metadata: metadata,
        embedding: embedding ?? new Embedding<float>(new[] { 1f, 3f, 5f }),
        key: metadata.Id);

/// <summary>
/// Creates record metadata with the given reference flag and id; a new GUID is used as the id
/// when <paramref name="key"/> is null. Every text field gets a fresh GUID suffix.
/// </summary>
private MemoryRecordMetadata GetRandomMemoryRecordMetadata(bool isReference = false, string? key = null)
{
    // Appends a new GUID to the prefix so each field value is unique per call.
    static string WithRandomSuffix(string prefix) => prefix + Guid.NewGuid().ToString();

    return new MemoryRecordMetadata(
        isReference: isReference,
        id: key ?? Guid.NewGuid().ToString(),
        text: WithRandomSuffix("text-"),
        description: WithRandomSuffix("description-"),
        externalSourceName: WithRandomSuffix("source-name-"),
        additionalMetadata: WithRandomSuffix("metadata-"));
}

#endregion
}

0 comments on commit b79f3a0

Please sign in to comment.