From b20838224c4431ae9e772bdeba32205e7c4ae4e6 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Tue, 27 Jun 2023 10:58:53 +1200 Subject: [PATCH] Implement k-NN approximate search and mappings bindings (#215) (#239) * Implement KNN bindings Signed-off-by: Thomas Farr * Minimal docs Signed-off-by: Thomas Farr * Add to changelog Signed-off-by: Thomas Farr --------- Signed-off-by: Thomas Farr (cherry picked from commit 38330056f3a9e056a98f71e7f3e78dc81d5dc614) --- CHANGELOG.md | 4 + .../OpenSearchNode.cs | 18 ++ .../Mapping/DynamicTemplate/SingleMapping.cs | 4 + .../Mapping/Types/FieldType.cs | 5 +- .../Mapping/Types/Properties.cs | 6 + .../Mapping/Types/PropertyFormatter.cs | 1 + .../Specialized/Knn/KnnVectorProperty.cs | 163 ++++++++++++++++++ .../Mapping/Visitor/IPropertyVisitor.cs | 2 + .../Mapping/Visitor/NoopPropertyVisitor.cs | 5 + .../Abstractions/Container/IQueryContainer.cs | 3 + .../Container/QueryContainer-Assignments.cs | 7 + .../Container/QueryContainerDescriptor.cs | 3 + src/OpenSearch.Client/QueryDsl/Query.cs | 3 + .../QueryDsl/Specialized/Knn/KnnQuery.cs | 76 ++++++++ .../QueryDsl/Visitor/DslPrettyPrintVisitor.cs | 2 + .../QueryDsl/Visitor/QueryVisitor.cs | 4 + .../QueryDsl/Visitor/QueryWalker.cs | 5 + .../NodeSeeders/DefaultSeeder.cs | 19 +- tests/Tests.Domain/Project.cs | 9 +- .../PutMapping/PutMappingApiTest.cs | 45 ++++- .../Specialized/Knn/KnnVectorPropertyTests.cs | 79 +++++++++ .../Specialized/Knn/KnnQueryUsageTests.cs | 162 +++++++++++++++++ 22 files changed, 618 insertions(+), 7 deletions(-) create mode 100644 src/OpenSearch.Client/Mapping/Types/Specialized/Knn/KnnVectorProperty.cs create mode 100644 src/OpenSearch.Client/QueryDsl/Specialized/Knn/KnnQuery.cs create mode 100644 tests/Tests/Mapping/Types/Specialized/Knn/KnnVectorPropertyTests.cs create mode 100644 tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index b8491cd538..7f45f5125e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 
+2,10 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] + +### Added +- Added support for approximate k-NN search queries and k-NN vector index properties ([#215](https://github.com/opensearch-project/opensearch-net/pull/215)) + ### Dependencies - Bumps `System.Reflection.Emit` from 4.3.0 to 4.7.0 - Bumps `Argu` from 5.5.0 to 6.1.1 diff --git a/abstractions/src/OpenSearch.OpenSearch.Managed/OpenSearchNode.cs b/abstractions/src/OpenSearch.OpenSearch.Managed/OpenSearchNode.cs index e5177337a1..a65aa880c5 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Managed/OpenSearchNode.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Managed/OpenSearchNode.cs @@ -29,6 +29,8 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; using System.Threading; using OpenSearch.OpenSearch.Managed.Configuration; using OpenSearch.OpenSearch.Managed.ConsoleWriters; @@ -93,9 +95,25 @@ private static Dictionary EnvVars(NodeConfiguration config) if (!string.IsNullOrWhiteSpace(config.FileSystem.OpenSearchHome)) environmentVariables.Add("OPENSEARCH_HOME", config.FileSystem.OpenSearchHome); + /* The k-NN native library directory is resolved under OpenSearchHome. The guard above shows the home may be unset; in that case leave knnLibDir null and skip the library-path append, because Path.Combine throws ArgumentNullException on a null segment. */ var knnLibDir = string.IsNullOrWhiteSpace(config.FileSystem.OpenSearchHome) ? null : Path.Combine(config.FileSystem.OpenSearchHome, "plugins", "opensearch-knn", config.Version.Major >= 2 ? "lib" : "knnlib"); + if (knnLibDir != null && RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + AppendPathEnvVar("JAVA_LIBRARY_PATH", knnLibDir); + else if (knnLibDir != null && RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + AppendPathEnvVar("LD_LIBRARY_PATH", knnLibDir); + return environmentVariables; } + private static void AppendPathEnvVar(string name, string value) + { + var previous = Environment.GetEnvironmentVariable(name); + Environment.SetEnvironmentVariable(name, + string.IsNullOrWhiteSpace(previous) + ?
value + : $"{previous}{Path.PathSeparator}{value}" + ); + } + private bool AssumedStartedStateChecker(string section, string message) { if (AssumeStartedOnNotEnoughMasterPing diff --git a/src/OpenSearch.Client/Mapping/DynamicTemplate/SingleMapping.cs b/src/OpenSearch.Client/Mapping/DynamicTemplate/SingleMapping.cs index 01d3ae0472..42e4440d8d 100644 --- a/src/OpenSearch.Client/Mapping/DynamicTemplate/SingleMapping.cs +++ b/src/OpenSearch.Client/Mapping/DynamicTemplate/SingleMapping.cs @@ -149,6 +149,10 @@ public IProperty Generic(Func, IGenericProperty> se public IProperty SearchAsYouType(Func, ISearchAsYouTypeProperty> selector) => selector?.Invoke(new SearchAsYouTypePropertyDescriptor()); + /// + public IProperty KnnVector(Func, IKnnVectorProperty> selector) => + selector?.Invoke(new KnnVectorPropertyDescriptor()); + #pragma warning disable CS3001 // Argument type is not CLS-compliant public IProperty Scalar(Expression> field, Func, INumberProperty> selector = null) => selector.InvokeOrDefault(new NumberPropertyDescriptor().Name(field).Type(NumberType.Integer)); diff --git a/src/OpenSearch.Client/Mapping/Types/FieldType.cs b/src/OpenSearch.Client/Mapping/Types/FieldType.cs index 2cd98bac7f..526813c5d1 100644 --- a/src/OpenSearch.Client/Mapping/Types/FieldType.cs +++ b/src/OpenSearch.Client/Mapping/Types/FieldType.cs @@ -159,6 +159,9 @@ public enum FieldType RankFeature, [EnumMember(Value = "rank_features")] - RankFeatures + RankFeatures, + + [EnumMember(Value = "knn_vector")] + KnnVector } } diff --git a/src/OpenSearch.Client/Mapping/Types/Properties.cs b/src/OpenSearch.Client/Mapping/Types/Properties.cs index 7631ccfdbd..8713d4aedc 100644 --- a/src/OpenSearch.Client/Mapping/Types/Properties.cs +++ b/src/OpenSearch.Client/Mapping/Types/Properties.cs @@ -157,6 +157,9 @@ TReturnType Nested(Func, INestedProp /// TReturnType SearchAsYouType(Func, ISearchAsYouTypeProperty> selector); + + /// + TReturnType KnnVector(Func, IKnnVectorProperty> selector); } public partial 
class PropertiesDescriptor where T : class @@ -252,6 +255,9 @@ public PropertiesDescriptor Object(Func public PropertiesDescriptor RankFeatures(Func, IRankFeaturesProperty> selector) => SetProperty(selector); + /// + public PropertiesDescriptor KnnVector(Func, IKnnVectorProperty> selector) => SetProperty(selector); + /// /// Map a custom property. /// diff --git a/src/OpenSearch.Client/Mapping/Types/PropertyFormatter.cs b/src/OpenSearch.Client/Mapping/Types/PropertyFormatter.cs index 50280c72b8..7ab1221b48 100644 --- a/src/OpenSearch.Client/Mapping/Types/PropertyFormatter.cs +++ b/src/OpenSearch.Client/Mapping/Types/PropertyFormatter.cs @@ -118,6 +118,7 @@ public IProperty Deserialize(ref JsonReader reader, IJsonFormatterResolver forma case FieldType.Alias: return Deserialize(ref segmentReader, formatterResolver); case FieldType.RankFeature: return Deserialize(ref segmentReader, formatterResolver); case FieldType.RankFeatures: return Deserialize(ref segmentReader, formatterResolver); + case FieldType.KnnVector: return Deserialize(ref segmentReader, formatterResolver); case FieldType.None: // no "type" field in the property mapping, or FieldType enum could not be parsed from typeString return Deserialize(ref segmentReader, formatterResolver); diff --git a/src/OpenSearch.Client/Mapping/Types/Specialized/Knn/KnnVectorProperty.cs b/src/OpenSearch.Client/Mapping/Types/Specialized/Knn/KnnVectorProperty.cs new file mode 100644 index 0000000000..3d3e371684 --- /dev/null +++ b/src/OpenSearch.Client/Mapping/Types/Specialized/Knn/KnnVectorProperty.cs @@ -0,0 +1,163 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. 
+*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +[ReadAs(typeof(KnnVectorProperty))] +[InterfaceDataContract] +public interface IKnnVectorProperty : IDocValuesProperty +{ + /// <summary> + /// The dimension of the vector. + /// </summary> + [DataMember(Name = "dimension")] + int? Dimension { get; set; } + + /// <summary> + /// The model to use when the underlying Approximate k-NN algorithm requires a training step. + /// </summary> + [DataMember(Name = "model_id")] + string ModelId { get; set; } + + /// <summary> + /// The method to use when the underlying Approximate k-NN algorithm does not require training. + /// </summary> + [DataMember(Name = "method")] + IKnnMethod Method { get; set; } +} + +[ReadAs(typeof(KnnMethod))] +[InterfaceDataContract] +public interface IKnnMethod +{ + /// <summary> + /// The identifier for the nearest neighbor method. + /// </summary> + [DataMember(Name = "name")] + string Name { get; set; } + + /// <summary> + /// The approximate k-NN library to use for indexing and search. + /// </summary> + [DataMember(Name = "engine")] + string Engine { get; set; } + + /// <summary> + /// The vector space used to calculate the distance between vectors. + /// </summary> + [DataMember(Name = "space_type")] + string SpaceType { get; set; } + + /// <summary> + /// The parameters used for the nearest neighbor method.
+ /// </summary> + [DataMember(Name = "parameters")] + IDictionary Parameters { get; set; } +} + +public class KnnMethod : IKnnMethod +{ + /// <inheritdoc /> + public string Name { get; set; } + /// <inheritdoc /> + public string Engine { get; set; } + /// <inheritdoc /> + public string SpaceType { get; set; } + /// <inheritdoc /> + public IDictionary Parameters { get; set; } +} + +[InterfaceDataContract] +[JsonFormatter(typeof(VerbatimDictionaryKeysFormatter))] +public interface IKnnMethodParameters : IIsADictionary { } + +public class KnnMethodParameters : IsADictionaryBase, IKnnMethodParameters +{ + public KnnMethodParameters() { } + + public KnnMethodParameters(IDictionary container) : base(container) { } + + public KnnMethodParameters(Dictionary container) : base(container) { } + + public void Add(string name, object value) => BackingDictionary.Add(name, value); +} + +[DebuggerDisplay("{DebugDisplay}")] +public class KnnVectorProperty : DocValuesPropertyBase, IKnnVectorProperty +{ + public KnnVectorProperty() : base(FieldType.KnnVector) { } + + /// <inheritdoc /> + public int? Dimension { get; set; } + /// <inheritdoc /> + public string ModelId { get; set; } + /// <inheritdoc /> + public IKnnMethod Method { get; set; } +} + +[DebuggerDisplay("{DebugDisplay}")] +public class KnnVectorPropertyDescriptor + : DocValuesPropertyDescriptorBase, IKnnVectorProperty, T>, IKnnVectorProperty + where T : class +{ + public KnnVectorPropertyDescriptor() : base(FieldType.KnnVector) { } + + int? IKnnVectorProperty.Dimension { get; set; } + string IKnnVectorProperty.ModelId { get; set; } + IKnnMethod IKnnVectorProperty.Method { get; set; } + + /// <inheritdoc cref="IKnnVectorProperty.Dimension" /> + public KnnVectorPropertyDescriptor Dimension(int?
dimension) => + Assign(dimension, (p, v) => p.Dimension = v); + + /// <inheritdoc cref="IKnnVectorProperty.ModelId" /> + public KnnVectorPropertyDescriptor ModelId(string modelId) => + Assign(modelId, (p, v) => p.ModelId = v); + + /// <inheritdoc cref="IKnnVectorProperty.Method" /> + public KnnVectorPropertyDescriptor Method(Func selector) => + Assign(selector, (p, v) => p.Method = v?.Invoke(new KnnMethodDescriptor())); +} + +public class KnnMethodDescriptor + : DescriptorBase, IKnnMethod +{ + string IKnnMethod.Name { get; set; } + string IKnnMethod.Engine { get; set; } + string IKnnMethod.SpaceType { get; set; } + IDictionary IKnnMethod.Parameters { get; set; } + + /// <inheritdoc cref="IKnnMethod.Name" /> + public KnnMethodDescriptor Name(string name) => + Assign(name, (c, v) => c.Name = v); + + /// <inheritdoc cref="IKnnMethod.Engine" /> + public KnnMethodDescriptor Engine(string engine) => + Assign(engine, (c, v) => c.Engine = v); + + /// <inheritdoc cref="IKnnMethod.SpaceType" /> + public KnnMethodDescriptor SpaceType(string spaceType) => + Assign(spaceType, (c, v) => c.SpaceType = v); + + /// <inheritdoc cref="IKnnMethod.Parameters" /> + public KnnMethodDescriptor Parameters(Func> selector) => + Assign(selector, (c, v) => c.Parameters = v?.Invoke(new KnnMethodParametersDescriptor())?.Value); +} + +public class KnnMethodParametersDescriptor : IsADictionaryDescriptorBase +{ + public KnnMethodParametersDescriptor() : base(new KnnMethodParameters()) { } + + public KnnMethodParametersDescriptor Parameter(string name, object value) => + Assign(name, value); +} diff --git a/src/OpenSearch.Client/Mapping/Visitor/IPropertyVisitor.cs b/src/OpenSearch.Client/Mapping/Visitor/IPropertyVisitor.cs index 80fec03979..1e7d775fef 100644 --- a/src/OpenSearch.Client/Mapping/Visitor/IPropertyVisitor.cs +++ b/src/OpenSearch.Client/Mapping/Visitor/IPropertyVisitor.cs @@ -88,6 +88,8 @@ public interface IPropertyVisitor void Visit(IFieldAliasProperty type, PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute); + void Visit(IKnnVectorProperty type, PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute); + IProperty Visit(PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute); bool
SkipProperty(PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute); diff --git a/src/OpenSearch.Client/Mapping/Visitor/NoopPropertyVisitor.cs b/src/OpenSearch.Client/Mapping/Visitor/NoopPropertyVisitor.cs index 62b9dcf13e..d913020247 100644 --- a/src/OpenSearch.Client/Mapping/Visitor/NoopPropertyVisitor.cs +++ b/src/OpenSearch.Client/Mapping/Visitor/NoopPropertyVisitor.cs @@ -89,6 +89,8 @@ public virtual void Visit(ISearchAsYouTypeProperty type, PropertyInfo propertyIn public virtual void Visit(IFieldAliasProperty type, PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute) { } + public virtual void Visit(IKnnVectorProperty type, PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute) { } + public virtual IProperty Visit(PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute) => null; public void Visit(IProperty type, PropertyInfo propertyInfo, OpenSearchPropertyAttributeBase attribute) @@ -176,6 +178,9 @@ public void Visit(IProperty type, PropertyInfo propertyInfo, OpenSearchPropertyA case IFieldAliasProperty fieldAlias: Visit(fieldAlias, propertyInfo, attribute); break; + case IKnnVectorProperty knnVector: + Visit(knnVector, propertyInfo, attribute); + break; } } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs index d56f9fefa1..3468f49ee4 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs @@ -200,6 +200,9 @@ public interface IQueryContainer [DataMember(Name = "distance_feature")] IDistanceFeatureQuery DistanceFeature { get; set; } + [DataMember(Name = "knn")] + IKnnQuery Knn { get; set; } + void Accept(IQueryVisitor visitor); } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs 
b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs index f865f234fd..a7b9c79fdb 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs @@ -51,6 +51,7 @@ public partial class QueryContainer : IQueryContainer, IDescriptor private IHasParentQuery _hasParent; private IIdsQuery _ids; private IIntervalsQuery _intervals; + private IKnnQuery _knn; private IMatchQuery _match; private IMatchAllQuery _matchAllQuery; private IMatchBoolPrefixQuery _matchBoolPrefixQuery; @@ -193,6 +194,12 @@ IIntervalsQuery IQueryContainer.Intervals set => _intervals = Set(value); } + IKnnQuery IQueryContainer.Knn + { + get => _knn; + set => _knn = Set(value); + } + IMatchQuery IQueryContainer.Match { get => _match; diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs index ac43b0c980..419e41d869 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs @@ -237,6 +237,9 @@ public QueryContainer HasChild(Func, IHa public QueryContainer HasParent(Func, IHasParentQuery> selector) where TParent : class => WrapInContainer(selector, (query, container) => container.HasParent = query); + public QueryContainer Knn(Func, IKnnQuery> selector) => + WrapInContainer(selector, (query, container) => container.Knn = query); + /// /// A query that generates the union of documents produced by its subqueries, and that scores each document /// with the maximum score for that document as produced by any subquery, plus a tie breaking increment for diff --git a/src/OpenSearch.Client/QueryDsl/Query.cs b/src/OpenSearch.Client/QueryDsl/Query.cs index 2572529356..84796d0636 100644 --- 
a/src/OpenSearch.Client/QueryDsl/Query.cs +++ b/src/OpenSearch.Client/QueryDsl/Query.cs @@ -92,6 +92,9 @@ public static QueryContainer Ids(Func selector) = public static QueryContainer Intervals(Func, IIntervalsQuery> selector) => new QueryContainerDescriptor().Intervals(selector); + public static QueryContainer Knn(Func, IKnnQuery> selector) => + new QueryContainerDescriptor().Knn(selector); + public static QueryContainer Match(Func, IMatchQuery> selector) => new QueryContainerDescriptor().Match(selector); diff --git a/src/OpenSearch.Client/QueryDsl/Specialized/Knn/KnnQuery.cs b/src/OpenSearch.Client/QueryDsl/Specialized/Knn/KnnQuery.cs new file mode 100644 index 0000000000..88dcdadd3c --- /dev/null +++ b/src/OpenSearch.Client/QueryDsl/Specialized/Knn/KnnQuery.cs @@ -0,0 +1,76 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +/// <summary> +/// An approximate k-NN query. +/// </summary> +[InterfaceDataContract] +[JsonFormatter(typeof(FieldNameQueryFormatter))] +public interface IKnnQuery : IFieldNameQuery +{ + /// <summary> + /// The vector to search for. + /// </summary> + [DataMember(Name = "vector")] + float[] Vector { get; set; } + + /// <summary> + /// The number of neighbors the search of each graph will return. + /// </summary> + [DataMember(Name = "k")] + int? K { get; set; } + + /// <summary> + /// The result restriction filter query. + /// </summary> + [DataMember(Name = "filter")] + IQueryContainer Filter { get; set; } +} + +[DataContract] +public class KnnQuery : FieldNameQueryBase, IKnnQuery +{ + /// <inheritdoc /> + public float[] Vector { get; set; } + /// <inheritdoc /> + public int?
K { get; set; } + /// <inheritdoc /> + public IQueryContainer Filter { get; set; } + + protected override bool Conditionless => IsConditionless(this); + + internal override void InternalWrapInContainer(IQueryContainer container) => container.Knn = this; + + internal static bool IsConditionless(IKnnQuery q) => q.Vector == null || q.Vector.Length == 0 || q.K == null || q.K == 0 || q.Field.IsConditionless(); +} + +public class KnnQueryDescriptor + : FieldNameQueryDescriptorBase, IKnnQuery, T>, + IKnnQuery + where T : class +{ + protected override bool Conditionless => KnnQuery.IsConditionless(this); + float[] IKnnQuery.Vector { get; set; } + int? IKnnQuery.K { get; set; } + IQueryContainer IKnnQuery.Filter { get; set; } + + /// <inheritdoc cref="IKnnQuery.Vector" /> + public KnnQueryDescriptor Vector(params float[] vector) => Assign(vector, (a, v) => a.Vector = v); + + /// <inheritdoc cref="IKnnQuery.K" /> + public KnnQueryDescriptor K(int? k) => Assign(k, (a, v) => a.K = v); + + /// <inheritdoc cref="IKnnQuery.Filter" /> + public KnnQueryDescriptor Filter(Func, QueryContainer> filterSelector) => + Assign(filterSelector, (a, v) => a.Filter = v?.Invoke(new QueryContainerDescriptor())); +} diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs index c33825d727..2608c09ac9 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs @@ -159,6 +159,8 @@ private void WriteShape(IGeoShape shape, IFieldLookup indexedField, Field field, public virtual void Visit(IIntervalsQuery query) => Write("intervals"); + public virtual void Visit(IKnnQuery query) => Write("knn", query.Field); + public virtual void Visit(IMatchQuery query) => Write("match", query.Field); public virtual void Visit(IMatchPhraseQuery query) => Write("match_phrase", query.Field); diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs index 056187bd0e..4440578ab7 100644 ---
a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs @@ -80,6 +80,8 @@ public interface IQueryVisitor void Visit(IIntervalsQuery query); + void Visit(IKnnQuery query); + void Visit(IMatchQuery query); void Visit(IMatchPhraseQuery query); @@ -225,6 +227,8 @@ public virtual void Visit(IIdsQuery query) { } public virtual void Visit(IIntervalsQuery query) { } + public virtual void Visit(IKnnQuery query) { } + public virtual void Visit(IMatchQuery query) { } public virtual void Visit(IMatchPhraseQuery query) { } diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs index 6d2e255153..2ff147331b 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs @@ -130,6 +130,11 @@ public void Walk(IQueryContainer qd, IQueryVisitor visitor) v.Visit(d); Accept(v, d.Query); }); + VisitQuery(qd.Knn, visitor, (v, d) => + { + v.Visit(d); + Accept(v, d.Filter); + }); VisitQuery(qd.Nested, visitor, (v, d) => { v.Visit(d); diff --git a/tests/Tests.Core/ManagedOpenSearch/NodeSeeders/DefaultSeeder.cs b/tests/Tests.Core/ManagedOpenSearch/NodeSeeders/DefaultSeeder.cs index 99cb3f8879..c766f1cd3f 100644 --- a/tests/Tests.Core/ManagedOpenSearch/NodeSeeders/DefaultSeeder.cs +++ b/tests/Tests.Core/ManagedOpenSearch/NodeSeeders/DefaultSeeder.cs @@ -220,7 +220,10 @@ private Task CreateDeveloperIndexAsync() => Client.Indices. 
#pragma warning disable 618 private Task CreateProjectIndexAsync() => Client.Indices.CreateAsync(typeof(Project), c => c - .Settings(settings => settings.Analysis(ProjectAnalysisSettings)) + .Settings(settings => settings + .Analysis(ProjectAnalysisSettings) + .Setting("index.knn", true) + .Setting("index.knn.algo_param.ef_search", 100)) .Mappings(ProjectMappings) .Aliases(aliases => aliases .Alias(ProjectsAliasName) @@ -386,7 +389,19 @@ public static PropertiesDescriptor ProjectProperties(Propert .RankFeature(rf => rf .Name(p => p.Rank) .PositiveScoreImpact() - ); + ) + .KnnVector(k => k + .Name(p => p.Vector) + .Dimension(2) + .Method(m => m + .Name("hnsw") + .SpaceType("l2") + .Engine("nmslib") + .Parameters(p => p + .Parameter("ef_construction", 128) + .Parameter("m", 24) + ) + )); return props; } diff --git a/tests/Tests.Domain/Project.cs b/tests/Tests.Domain/Project.cs index 8ae336f02d..0f47a1676d 100644 --- a/tests/Tests.Domain/Project.cs +++ b/tests/Tests.Domain/Project.cs @@ -83,7 +83,7 @@ public class Project public StateOfBeing State { get; set; } public CompletionField Suggest { get; set; } public IEnumerable Tags { get; set; } - + public string Type => TypeName; //the first applies when using internal source serializer the latter when using JsonNetSourceSerializer @@ -91,6 +91,8 @@ public class Project public string VersionControl { get; set; } + public float[] Vector { get; set; } + // @formatter:off — enable formatter after this line public static Faker Generator { get; } = new Faker() @@ -123,7 +125,8 @@ public class Project { "color", new[] { "red", "blue", "green", "violet", "yellow" }.Take(Gimme.Random.Number(1, 4)) } } }) - .RuleFor(p => p.VersionControl, VersionControlConstant); + .RuleFor(p => p.VersionControl, VersionControlConstant) + .RuleFor(p => p.Vector, f => new[] { Gimme.Random.Float(0f, 5f), Gimme.Random.Float(0f, 5f)}); public static IList Projects { get; } = Generator.Clone().Generate(100); @@ -198,7 +201,7 @@ public class Metadata 
public class ProjectTransform { public double? AverageCommits { get; set; } - + public long WeekStartedOnMillis { get; set; } public DateTime WeekStartedOnDate { get; set; } diff --git a/tests/Tests/Indices/MappingManagement/PutMapping/PutMappingApiTest.cs b/tests/Tests/Indices/MappingManagement/PutMapping/PutMappingApiTest.cs index c7f3fcffd9..b5cf2f7d07 100644 --- a/tests/Tests/Indices/MappingManagement/PutMapping/PutMappingApiTest.cs +++ b/tests/Tests/Indices/MappingManagement/PutMapping/PutMappingApiTest.cs @@ -154,6 +154,21 @@ public PutMappingApiTests(WritableCluster cluster, EndpointUsage usage) : base(c versionControl = new { type = "keyword" + }, + vector = new + { + type = "knn_vector", + dimension = 2, + method = new { + name = "hnsw", + space_type = "l2", + engine = "nmslib", + parameters = new + { + ef_construction = 128, + m = 24 + } + } } } }; @@ -216,6 +231,19 @@ public PutMappingApiTests(WritableCluster cluster, EndpointUsage usage) : base(c .Keyword(k => k .Name(n => n.VersionControl) ) + .KnnVector(k => k + .Name(p => p.Vector) + .Dimension(2) + .Method(m => m + .Name("hnsw") + .SpaceType("l2") + .Engine("nmslib") + .Parameters(p => p + .Parameter("ef_construction", 128) + .Parameter("m", 24) + ) + ) + ) ); protected override HttpMethod HttpMethod => HttpMethod.PUT; @@ -329,7 +357,22 @@ public PutMappingApiTests(WritableCluster cluster, EndpointUsage usage) : base(c } }, { p => p.Rank, new RankFeatureProperty() }, - { p => p.VersionControl, new KeywordProperty() } + { p => p.VersionControl, new KeywordProperty() }, + { p => p.Vector, new KnnVectorProperty + { + Dimension = 2, + Method = new KnnMethod + { + Name = "hnsw", + SpaceType = "l2", + Engine = "nmslib", + Parameters = new KnnMethodParameters + { + {"ef_construction", 128}, + {"m", 24} + } + } + } } } }; diff --git a/tests/Tests/Mapping/Types/Specialized/Knn/KnnVectorPropertyTests.cs b/tests/Tests/Mapping/Types/Specialized/Knn/KnnVectorPropertyTests.cs new file mode 100644 index 
0000000000..47b070ed23 --- /dev/null +++ b/tests/Tests/Mapping/Types/Specialized/Knn/KnnVectorPropertyTests.cs @@ -0,0 +1,79 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using OpenSearch.Client; +using Tests.Core.ManagedOpenSearch.Clusters; +using Tests.Domain; +using Tests.Framework.EndpointTests.TestState; + +namespace Tests.Mapping.Types.Specialized.Knn +{ + public class KnnVectorPropertyTests : PropertyTestsBase + { + public KnnVectorPropertyTests(WritableCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected override object ExpectJson => new + { + properties = new + { + name = new + { + type = "knn_vector", + dimension = 2, + method = new + { + name = "hnsw", + space_type = "l2", + engine = "nmslib", + parameters = new + { + ef_construction = 128, + m = 24 + } + } + } + } + }; + + protected override Func, IPromise> FluentProperties => f => f + .KnnVector(k => k + .Name(p => p.Name) + .Dimension(2) + .Method(m => m + .Name("hnsw") + .SpaceType("l2") + .Engine("nmslib") + .Parameters(p => p + .Parameter("ef_construction", 128) + .Parameter("m", 24) + ) + ) + ); + + protected override IProperties InitializerProperties => new Properties + { + { + "name", new KnnVectorProperty + { + Dimension = 2, + Method = new KnnMethod + { + Name = "hnsw", + SpaceType = "l2", + Engine = "nmslib", + Parameters = new KnnMethodParameters + { + {"ef_construction", 128}, + {"m", 24} + } + } + } + } + }; + } +} diff --git a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs new file mode 100644 index 0000000000..8150004a20 --- /dev/null +++ b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs @@ -0,0 +1,162 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made 
to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions; +using OpenSearch.Client; +using OpenSearch.OpenSearch.Xunit.XunitPlumbing; +using Tests.Core.Extensions; +using Tests.Core.ManagedOpenSearch.Clusters; +using Tests.Domain; +using Tests.Framework.EndpointTests.TestState; + +namespace Tests.QueryDsl.Specialized.Knn +{ + public class KnnQueryUsageTests : QueryDslUsageTestsBase + { + public KnnQueryUsageTests(ReadOnlyCluster i, EndpointUsage usage) : base(i, usage) { } + + protected override ConditionlessWhen ConditionlessWhen => new ConditionlessWhen(a => a.Knn) + { + q => + { + q.Field = null; + q.Vector = new[] { 1.5f, -2.6f }; + q.K = 30; + }, + q => + { + q.Field = "knn_vector"; + q.Vector = null; + q.K = 30; + }, + q => + { + q.Field = "knn_vector"; + q.Vector = Array.Empty(); + q.K = 30; + }, + q => + { + q.Field = "knn_vector"; + q.Vector = new[] { 1.5f, 2.6f }; + q.K = null; + }, + q => + { + q.Field = "knn_vector"; + q.Vector = new[] { 1.5f, 2.6f }; + q.K = 0; + } + }; + + protected override QueryContainer QueryInitializer => new KnnQuery + { + Boost = 1.1, Field = Infer.Field(f => f.Vector), Vector = new[] { 1.5f, -2.6f }, K = 30 + }; + + protected override object QueryJson => + new { knn = new { vector = new { boost = 1.1, vector = new[] { 1.5f, -2.6f }, k = 30 } } }; + + protected override QueryContainer QueryFluent(QueryContainerDescriptor q) => q + .Knn(knn => knn + .Boost(1.1) + .Field(f => f.Vector) + .Vector(1.5f, -2.6f) + .K(30) + ); + } + + public class KnnIntegrationTests : IClusterFixture + { + private readonly WritableCluster _cluster; + + public KnnIntegrationTests(WritableCluster cluster) => _cluster = cluster; + + [I] public async Task KnnQuery() + { + var client = _cluster.Client; + const string index = "knn-index"; + + var createIndexResponse = await client.Indices.CreateAsync(index, c => c + 
.Settings(s => s + .Setting("index.knn", true) + .Setting("index.knn.algo_param.ef_search", 100)) + .Map(m => m + .Properties(p => p + .KnnVector(k => k + .Name(d => d.Vector) + .Dimension(4) + .Method(m => m + .Name("hnsw") + .SpaceType("innerproduct") + .Engine("nmslib") + .Parameters(p => p + .Parameter("ef_construction", 256) + .Parameter("m", 48) + ) + ) + ) + ) + ) + ); + + createIndexResponse.ShouldBeValid(); + + var bulkResponse = await client.BulkAsync(b => b + .Index(index) + .IndexMany(new object[] + { + new Doc(new[] { 1.5f, 5.5f, 4.5f, 6.4f }, 10.3f), + new Doc(new[] { 2.5f, 3.5f, 5.6f, 6.7f }, 5.5f), + new Doc(new[] { 4.5f, 5.5f, 6.7f, 3.7f }, 4.4f), + new Doc(new[] { 1.5f, 5.5f, 4.5f, 6.4f }, 8.9f) + })); + + bulkResponse.ShouldBeValid(); + + var refreshResponse = await client.Indices.RefreshAsync(index); + refreshResponse.ShouldBeValid(); + + var searchResponse = await client.SearchAsync(s => s + .Index(index) + .Size(2) + .Query(q => q + .Knn(k => k + .Field(d => d.Vector) + .Vector(2.0f, 3.0f, 5.0f, 6.0f) + .K(2) + ) + ) + ); + + searchResponse.ShouldBeValid(); + searchResponse + .Documents + .Should() + .BeEquivalentTo(new[] + { + new Doc(new[] { 2.5f, 3.5f, 5.6f, 6.7f }, 5.5f), + new Doc(new[] { 4.5f, 5.5f, 6.7f, 3.7f }, 4.4f), + }); + } + + public class Doc + { + public Doc(float[] vector, float price) + { + Vector = vector; + Price = price; + } + + public float Price { get; set; } + public float[] Vector { get; set; } + } + } +}