Skip to content

Commit

Permalink
Implement neural search query type
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Aug 22, 2024
1 parent 95daccf commit 296d700
Show file tree
Hide file tree
Showing 10 changed files with 691 additions and 495 deletions.
30 changes: 11 additions & 19 deletions samples/Samples/NeuralSearch/NeuralSearchSample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,18 @@ protected override async Task Run(IOpenSearchClient client)
Console.WriteLine($"Indexed {documents.Length} documents");

// Perform the neural search
// TODO: Client does not yet contain typings for neural query type
Console.WriteLine("Performing neural search for text 'wild west'");
var searchResp = await client.Http.PostAsync<SearchResponse<NeuralSearchDoc>>(
$"/{IndexName}/_search",
r => r.SerializableBody(new
{
_source = new { excludes = new[] { "passage_embedding" } },
query = new
{
neural = new
{
passage_embedding = new
{
query_text = "wild west",
model_id = _modelId,
k = 5
}
}
}
}));
var searchResp = await client.SearchAsync<NeuralSearchDoc>(s => s
.Index(IndexName)
.Source(sf => sf
.Excludes(f => f
.Field(d => d.PassageEmbedding)))
.Query(q => q
.Neural(n => n
.Field(f => f.PassageEmbedding)
.QueryText("wild west")
.ModelId(_modelId)
.K(5))));
AssertValid(searchResp);
Console.WriteLine($"Found {searchResp.Hits.Count} documents");
foreach (var hit in searchResp.Hits) Console.WriteLine($"- Document id: {hit.Source.Id}, score: {hit.Score}, text: {hit.Source.Text}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ public interface IQueryContainer
[DataMember(Name = "knn")]
IKnnQuery Knn { get; set; }

[DataMember(Name = "neural")]
INeuralQuery Neural { get; set; }

void Accept(IQueryVisitor visitor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public partial class QueryContainer : IQueryContainer, IDescriptor
private IMoreLikeThisQuery _moreLikeThis;
private IMultiMatchQuery _multiMatch;
private INestedQuery _nested;
private INeuralQuery _neural;
private IParentIdQuery _parentId;
private IPercolateQuery _percolate;
private IPrefixQuery _prefix;
Expand Down Expand Up @@ -254,6 +255,12 @@ INestedQuery IQueryContainer.Nested
set => _nested = Set(value);
}

INeuralQuery IQueryContainer.Neural
{
get => _neural;
set => _neural = Set(value);
}

IParentIdQuery IQueryContainer.ParentId
{
get => _parentId;
Expand Down

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/OpenSearch.Client/QueryDsl/Query.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ public static QueryContainer MultiMatch(Func<MultiMatchQueryDescriptor<T>, IMult
public static QueryContainer Nested(Func<NestedQueryDescriptor<T>, INestedQuery> selector) =>
new QueryContainerDescriptor<T>().Nested(selector);

public static QueryContainer Neural(Func<NeuralQueryDescriptor<T>, INeuralQuery> selector) =>
new QueryContainerDescriptor<T>().Neural(selector);

public static QueryContainer ParentId(Func<ParentIdQueryDescriptor<T>, IParentIdQuery> selector) =>
new QueryContainerDescriptor<T>().ParentId(selector);

Expand Down
75 changes: 75 additions & 0 deletions src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

using System.Runtime.Serialization;
using OpenSearch.Net.Utf8Json;

namespace OpenSearch.Client;

/// <summary>
/// A neural query.
/// </summary>
[InterfaceDataContract]
[JsonFormatter(typeof(FieldNameQueryFormatter<NeuralQuery, INeuralQuery>))]
public interface INeuralQuery : IFieldNameQuery
{
/// <summary>
/// The query text from which to produce queries.
/// </summary>
[DataMember(Name = "query_text")]
string QueryText { get; set; }

/// <summary>
/// The number of results the k-NN search returns.
/// </summary>
[DataMember(Name = "k")]
int? K { get; set; }

/// <summary>
/// The ID of the model that will be used in the embedding interface.
/// The model must be indexed in OpenSearch before it can be used in Neural Search.
/// </summary>
[DataMember(Name = "model_id")]
string ModelId { get; set; }
}

[DataContract]
public class NeuralQuery : FieldNameQueryBase, INeuralQuery
{
/// <inheritdoc />
public string QueryText { get; set; }
/// <inheritdoc />
public int? K { get; set; }
/// <inheritdoc />
public string ModelId { get; set; }

protected override bool Conditionless => IsConditionless(this);

internal override void InternalWrapInContainer(IQueryContainer container) => container.Neural = this;

internal static bool IsConditionless(INeuralQuery q) => string.IsNullOrEmpty(q.QueryText) || q.K == null || q.K == 0 || string.IsNullOrEmpty(q.ModelId) || q.Field.IsConditionless();
}

public class NeuralQueryDescriptor<T>
: FieldNameQueryDescriptorBase<NeuralQueryDescriptor<T>, INeuralQuery, T>,
INeuralQuery
where T : class
{
protected override bool Conditionless => NeuralQuery.IsConditionless(this);
string INeuralQuery.QueryText { get; set; }
int? INeuralQuery.K { get; set; }
string INeuralQuery.ModelId { get; set; }

/// <inheritdoc cref="INeuralQuery.QueryText" />
public NeuralQueryDescriptor<T> QueryText(string queryText) => Assign(queryText, (a, v) => a.QueryText = v);

/// <inheritdoc cref="INeuralQuery.K" />
public NeuralQueryDescriptor<T> K(int? k) => Assign(k, (a, v) => a.K = v);

/// <inheritdoc cref="INeuralQuery.ModelId" />
public NeuralQueryDescriptor<T> ModelId(string modelId) => Assign(modelId, (a, v) => a.ModelId = v);
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ private void WriteShape(IGeoShape shape, IFieldLookup indexedField, Field field,

public virtual void Visit(INestedQuery query) => Write("nested");

public virtual void Visit(INeuralQuery query) => Write("neural", query.Field);

public virtual void Visit(IPrefixQuery query) => Write("prefix");

public virtual void Visit(IQueryStringQuery query) => Write("query_string");
Expand Down
4 changes: 4 additions & 0 deletions src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ public interface IQueryVisitor

void Visit(INestedQuery query);

void Visit(INeuralQuery query);

void Visit(IPrefixQuery query);

void Visit(IQueryStringQuery query);
Expand Down Expand Up @@ -247,6 +249,8 @@ public virtual void Visit(IMultiMatchQuery query) { }

public virtual void Visit(INestedQuery query) { }

public virtual void Visit(INeuralQuery query) { }

public virtual void Visit(IPrefixQuery query) { }

public virtual void Visit(IQueryStringQuery query) { }
Expand Down
1 change: 1 addition & 0 deletions src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public void Walk(IQueryContainer qd, IQueryVisitor visitor)
VisitQuery(qd.Percolate, visitor, (v, d) => v.Visit(d));
VisitQuery(qd.ParentId, visitor, (v, d) => v.Visit(d));
VisitQuery(qd.TermsSet, visitor, (v, d) => v.Visit(d));
VisitQuery(qd.Neural, visitor, (v, d) => v.Visit(d));

VisitQuery(qd.Bool, visitor, (v, d) =>
{
Expand Down
104 changes: 104 additions & 0 deletions tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

using OpenSearch.Client;
using Tests.Core.ManagedOpenSearch.Clusters;
using Tests.Domain;
using Tests.Framework.EndpointTests.TestState;

namespace Tests.QueryDsl.Specialized.Neural;

public class NeuralQueryUsageTests : QueryDslUsageTestsBase
{
public NeuralQueryUsageTests(ReadOnlyCluster i, EndpointUsage usage) : base(i, usage) { }

protected override ConditionlessWhen ConditionlessWhen => new ConditionlessWhen<INeuralQuery>(a => a.Neural)
{
q =>
{
q.Field = null;
q.QueryText = "wild west";
q.K = 5;
q.ModelId = "aFcV879";
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = null;
q.K = 5;
q.ModelId = "aFcV879";
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = "";
q.K = 5;
q.ModelId = "aFcV879";
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = "wild west";
q.K = null;
q.ModelId = "aFcV879";
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = "wild west";
q.K = 0;
q.ModelId = "aFcV879";
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = "wild west";
q.K = 5;
q.ModelId = null;
},
q =>
{
q.Field = "passage_embedding";
q.QueryText = "wild west";
q.K = 5;
q.ModelId = "";
}
};

protected override QueryContainer QueryInitializer => new NeuralQuery
{
Boost = 1.1,
Field = Infer.Field<Project>(f => f.Vector),
QueryText = "wild west",
K = 5,
ModelId = "aFcV879"
};

protected override object QueryJson =>
new
{
neural = new
{
vector = new
{
boost = 1.1,
query_text = "wild west",
k = 5,
model_id = "aFcV879"
}
}
};

protected override QueryContainer QueryFluent(QueryContainerDescriptor<Project> q) => q
.Neural(n => n
.Boost(1.1)
.Field(f => f.Vector)
.QueryText("wild west")
.K(5)
.ModelId("aFcV879")
);
}

0 comments on commit 296d700

Please sign in to comment.