From d125b2b71d6c30d05f1a8b435f1d700326f83735 Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Wed, 8 May 2024 17:08:26 +1200 Subject: [PATCH] Improve test Signed-off-by: Thomas Farr --- .github/actions/build-opensearch/action.yml | 2 +- .../EphemeralClusterConfiguration.cs | 2 +- .../Tasks/InstallationTasks/InstallPlugins.cs | 19 +- .../IntegrationTestDiscoverer.cs | 177 ++++++------- .../SkipPrereleaseVersionsAttribute.cs | 23 ++ .../XunitPlumbing/SkipVersionAttribute.cs | 55 ++-- .../Products/OpenSearchPlugin.cs | 7 +- .../NeuralSearch/NeuralSearchSample.cs | 1 - .../Infer/Indices/Indices.cs | 16 +- .../Clusters/WritableCluster.cs | 5 +- tests/Tests/Ingest/ProcessorAssertions.cs | 22 +- .../Ingest/PutPipeline/PutPipelineApiTests.cs | 2 +- .../Tests/QueryDsl/QueryDslUsageTestsBase.cs | 157 ++++++------ .../Specialized/Knn/KnnQueryUsageTests.cs | 1 - .../Neural/NeuralQueryUsageTests.cs | 240 ++++++++++++++++-- 15 files changed, 485 insertions(+), 244 deletions(-) create mode 100644 abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs diff --git a/.github/actions/build-opensearch/action.yml b/.github/actions/build-opensearch/action.yml index 8ad36ee266..11b9bb2254 100644 --- a/.github/actions/build-opensearch/action.yml +++ b/.github/actions/build-opensearch/action.yml @@ -15,7 +15,7 @@ inputs: plugins_output_directory: description: The directory to output the plugins to default: "" -outputs: +outputs: distribution: description: The path to the OpenSearch distribution value: ${{ steps.determine.outputs.distribution }} diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs index 77aaef02c4..970add8fd3 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs @@ -74,7 +74,7 @@ public EphemeralClusterConfiguration(OpenSearchVersion version, ClusterFeatures /// This can be useful to fail early when subsequent operations are relying on installation /// succeeding. /// - public bool ValidatePluginsToInstall { get; } = true; + public bool ValidatePluginsToInstall { get; set; } = true; public bool EnableSsl => Features.HasFlag(ClusterFeatures.SSL); diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs index 3225476353..57a6982541 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs @@ -61,30 +61,29 @@ public override void Run(IEphemeralCluster cluste .Where(p => !p.IsValid(v)) .Select(p => p.SubProductName).ToList(); if (invalidPlugins.Any()) - throw new OpenSearchCleanExitException( - $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); - } + { + throw new OpenSearchCleanExitException( + $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); + } + } foreach (var plugin in requiredPlugins) { - var includedByDefault = plugin.IsIncludedOutOfTheBox(v); - if (includedByDefault) + if (plugin.IsIncludedOutOfTheBox(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] shipped OOTB as of: {{{plugin.ShippedByDefaultAsOf}}}"); continue; } - var validForCurrentVersion = plugin.IsValid(v); - if (!validForCurrentVersion) + if (!plugin.IsValid(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] not valid for version: {{{v}}}"); continue; } - var alreadyInstalled = AlreadyInstalled(fs, plugin.SubProductName); - if (alreadyInstalled) + if (AlreadyInstalled(fs, plugin.SubProductName)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] already installed"); @@ -92,7 +91,7 @@ public override void Run(IEphemeralCluster cluste } cluster.Writer?.WriteDiagnostic( - $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}: {{{plugin.IsValid(v)}}}"); + $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}"); var homeConfigPath = Path.Combine(fs.OpenSearchHome, "config"); diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs index 9e84b0ee01..e7ea9bdf87 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs @@ -35,102 +35,109 @@ using Xunit; using Xunit.Abstractions; using Xunit.Sdk; -using Enumerable = System.Linq.Enumerable; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped, and a reason why. +/// +public abstract class SkipTestAttributeBase : Attribute { - /// - /// An Xunit test that should be skipped, and a reason why. - /// - public abstract class SkipTestAttributeBase : Attribute - { - /// - /// Whether the test should be skipped - /// - public abstract bool Skip { get; } + /// + /// Whether the test should be skipped + /// + public abstract bool Skip { get; } + + /// + /// The reason why the test should be skipped + /// + public abstract string Reason { get; } +} - /// - /// The reason why the test should be skipped - /// - public abstract string Reason { get; } - } +/// +/// An Xunit integration test +/// +[XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", + "OpenSearch.OpenSearch.Xunit")] +public class I : FactAttribute +{ +} - /// - /// An Xunit integration test - /// - [XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", - "OpenSearch.OpenSearch.Xunit")] - public class I : FactAttribute - { - } +/// +/// A test discoverer used to discover integration tests cases attached +/// to test methods that are attributed with attribute +/// +public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer +{ + public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) + { + } - /// - /// A test discoverer used to discover integration tests cases attached - /// to test methods that are attributed with attribute - /// - public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer - { - public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) - { - } + /// + protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, + out string skipReason) + { + skipReason = null; + var runIntegrationTests = + discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); + if (!runIntegrationTests) return true; - /// - protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, - out string skipReason) - { - skipReason = null; - var runIntegrationTests = - discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); - if (!runIntegrationTests) return true; + var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); + if (cluster == null) + { + skipReason += + $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; + return true; + } - var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); - if (cluster == null) - { - skipReason += - $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; - return true; - } + var openSearchVersion = + discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.Version)); - var openSearchVersion = - discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.Version)); + // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) + var skipVersionAttribute = GetAttributes(testMethod).FirstOrDefault(); + if (skipVersionAttribute != null) + { + var skipVersionRanges = + skipVersionAttribute.GetNamedArgument>(nameof(SkipVersionAttribute.Ranges)) ?? + new List(); + if (openSearchVersion == null && skipVersionRanges.Count > 0) + { + skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + + $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; + return true; + } - // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) - var skipVersionAttribute = Enumerable.FirstOrDefault(GetAttributes(testMethod)); - if (skipVersionAttribute != null) - { - var skipVersionRanges = - skipVersionAttribute.GetNamedArgument>(nameof(SkipVersionAttribute.Ranges)) ?? - new List(); - if (openSearchVersion == null && skipVersionRanges.Count > 0) - { - skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + - $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; - return true; - } + if (openSearchVersion != null) + { + var reason = skipVersionAttribute.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); + foreach (var range in skipVersionRanges) + { + // inrange takes prereleases into account + if (!openSearchVersion.InRange(range)) continue; + skipReason = + $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } + } + } - if (openSearchVersion != null) - { - var reason = skipVersionAttribute.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); - for (var index = 0; index < skipVersionRanges.Count; index++) - { - var range = skipVersionRanges[index]; - // inrange takes prereleases into account - if (!openSearchVersion.InRange(range)) continue; - skipReason = - $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; - if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; - return true; - } - } - } + // Skip if a prerelease version and has SkipPrereleaseVersionsAttribute + var skipPrerelease = GetAttributes(testMethod).FirstOrDefault(); + if (openSearchVersion != null && openSearchVersion.IsPreRelease && skipPrerelease != null) + { + skipReason = $"{nameof(SkipPrereleaseVersionsAttribute)} has been applied to this test"; + var reason = skipPrerelease.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } - var skipTests = GetAttributes(testMethod) - .FirstOrDefault(a => a.GetNamedArgument(nameof(SkipTestAttributeBase.Skip))); + var skipTests = GetAttributes(testMethod) + .FirstOrDefault(a => a.GetNamedArgument(nameof(SkipTestAttributeBase.Skip))); - if (skipTests == null) return false; + if (skipTests == null) return false; - skipReason = skipTests.GetNamedArgument(nameof(SkipTestAttributeBase.Reason)); - return true; - } - } + skipReason = skipTests.GetNamedArgument(nameof(SkipTestAttributeBase.Reason)); + return true; + } } diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs new file mode 100644 index 0000000000..dee8646e6f --- /dev/null +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; + +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped for prerelease OpenSearch versions, and a reason why. +/// +public class SkipPrereleaseVersionsAttribute : Attribute +{ + public SkipPrereleaseVersionsAttribute(string reason) => Reason = reason; + + /// + /// The reason why the test should be skipped + /// + public string Reason { get; } +} diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs index cfeec7b8da..e885718fe2 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs @@ -31,35 +31,34 @@ using System.Linq; using SemanticVersioning; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped for given OpenSearch versions, and a reason why. +/// +public class SkipVersionAttribute : Attribute { - /// - /// An Xunit test that should be skipped for given OpenSearch versions, and a reason why. - /// - public class SkipVersionAttribute : Attribute - { - // ReSharper disable once UnusedParameter.Local - // reason is used to allow the test its used on to self document why its been put in place - public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) - { - Reason = reason; - Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) - ? new List() - : skipVersionRangesSeparatedByComma.Split(',') - .Select(r => r.Trim()) - .Where(r => !string.IsNullOrWhiteSpace(r)) - .Select(r => new Range(r)) - .ToList(); - } + // ReSharper disable once UnusedParameter.Local + // reason is used to allow the test its used on to self document why its been put in place + public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) + { + Reason = reason; + Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) + ? new List() + : skipVersionRangesSeparatedByComma.Split(',') + .Select(r => r.Trim()) + .Where(r => !string.IsNullOrWhiteSpace(r)) + .Select(r => new Range(r)) + .ToList(); + } - /// - /// The reason why the test should be skipped - /// - public string Reason { get; } + /// + /// The reason why the test should be skipped + /// + public string Reason { get; } - /// - /// The version ranges for which the test should be skipped - /// - public IList Ranges { get; } - } + /// + /// The version ranges for which the test should be skipped + /// + public IList Ranges { get; } } diff --git a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs index 536cebb3ee..f2ac62ab0e 100644 --- a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs +++ b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs @@ -27,6 +27,7 @@ */ using System; +using Version = SemanticVersioning.Version; namespace OpenSearch.Stack.ArtifactsApi.Products { @@ -81,5 +82,9 @@ public OpenSearchPlugin(string plugin, Func isValid = n public static OpenSearchPlugin DeleteByQuery { get; } = new("delete-by-query", version => version < "1.0.0"); public static OpenSearchPlugin Knn { get; } = new("opensearch-knn"); - } + + public static OpenSearchPlugin MachineLearning { get; } = new("opensearch-ml", v => v.BaseVersion() >= new Version("1.3.0") && !v.IsPreRelease); + + public static OpenSearchPlugin NeuralSearch { get; } = new("opensearch-neural-search", v => v.BaseVersion() >= new Version("2.4.0") && !v.IsPreRelease); + } } diff --git a/samples/Samples/NeuralSearch/NeuralSearchSample.cs b/samples/Samples/NeuralSearch/NeuralSearchSample.cs index aafdf6400a..aeb2c28a80 100644 --- a/samples/Samples/NeuralSearch/NeuralSearchSample.cs +++ b/samples/Samples/NeuralSearch/NeuralSearchSample.cs @@ -5,7 +5,6 @@ * compatible open source license. */ -using System.Diagnostics; using OpenSearch.Client; using OpenSearch.Net; diff --git a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs index 714f43b474..a6b83805b1 100644 --- a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs +++ b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs @@ -46,12 +46,12 @@ internal Indices(ManyIndices indices) : base(indices) { } internal Indices(IEnumerable indices) : base(new ManyIndices(indices)) { } /// All indices. Represents _all - public static Indices All { get; } = new Indices(new AllIndicesMarker()); + public static Indices All { get; } = new(new AllIndicesMarker()); /// - public static Indices AllIndices { get; } = All; + public static Indices AllIndices => All; - private string DebugDisplay => Match( + private string DebugDisplay => Match( all => "_all", types => $"Count: {types.Indices.Count} [" + string.Join(",", types.Indices.Select((t, i) => $"({i + 1}: {t.DebugDisplay})")) + "]" ); @@ -62,11 +62,13 @@ string IUrlParameter.GetString(IConnectionConfigurationValues settings) => Match all => "_all", many => { - if (!(settings is IConnectionSettingsValues oscSettings)) - throw new Exception( - "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + if (settings is not IConnectionSettingsValues oscSettings) + { + throw new Exception( + "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + } - var infer = oscSettings.Inferrer; + var infer = oscSettings.Inferrer; var indices = many.Indices.Select(i => infer.IndexName(i)).Distinct(); return string.Join(",", indices); } diff --git a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs index 1f8457c82f..145cc6b15c 100644 --- a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs +++ b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs @@ -47,10 +47,13 @@ private static ClientTestClusterConfiguration CreateConfiguration() => AnalysisIcu, AnalysisKuromoji, AnalysisNori, AnalysisPhonetic, IngestAttachment, IngestGeoIp, Knn, + MachineLearning, MapperMurmur3, + NeuralSearch, Security) { - MaxConcurrency = 4 + MaxConcurrency = 4, + ValidatePluginsToInstall = false }; protected override void SeedNode() diff --git a/tests/Tests/Ingest/ProcessorAssertions.cs b/tests/Tests/Ingest/ProcessorAssertions.cs index ec052776d8..9b70cb0fce 100644 --- a/tests/Tests/Ingest/ProcessorAssertions.cs +++ b/tests/Tests/Ingest/ProcessorAssertions.cs @@ -63,11 +63,21 @@ public abstract class ProcessorAssertion : IProcessorAssertion public static class ProcessorAssertions { public static IEnumerable All => - from t in typeof(ProcessorAssertions).GetNestedTypes() - where typeof(IProcessorAssertion).IsAssignableFrom(t) && t.IsClass - let a = t.GetCustomAttributes(typeof(SkipVersionAttribute)).FirstOrDefault() as SkipVersionAttribute - where a == null || !a.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion)) - select (IProcessorAssertion)Activator.CreateInstance(t); + typeof(ProcessorAssertions).GetNestedTypes() + .Where(t => + { + if (!t.IsClass || !typeof(IProcessorAssertion).IsAssignableFrom(t)) return false; + + var skipVersion = t.GetCustomAttributes().FirstOrDefault(); + if (skipVersion != null && skipVersion.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion))) + return false; + + var skipPrereleases = t.GetCustomAttributes().FirstOrDefault(); + if (skipPrereleases != null && TestClient.Configuration.OpenSearchVersion.IsPreRelease) return false; + + return true; + }) + .Select(t => (IProcessorAssertion)Activator.CreateInstance(t)); public static IProcessor[] Initializers => All.Select(a => a.Initializer).ToArray(); @@ -594,6 +604,8 @@ public class Pipeline : ProcessorAssertion public override string Key => "pipeline"; } + [SkipVersion("<2.4.0", "neural search plugin was released with v2.4.0")] + [SkipPrereleaseVersions("Prerelease versions of OpenSearch do not include the ML & Neural Search plugins")] public class TextEmbedding : ProcessorAssertion { private class NeuralSearchDoc diff --git a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs index 3d064a9d4e..53c06a61ed 100644 --- a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs +++ b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs @@ -55,7 +55,7 @@ public PutPipelineApiTests(WritableCluster cluster, EndpointUsage usage) : base( processors = ProcessorAssertions.AllAsJson }; -protected override int ExpectStatusCode => 200; + protected override int ExpectStatusCode => 200; protected override Func Fluent => d => d .Description("My test pipeline") diff --git a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs index eaeaff5800..b2cd56a7d6 100644 --- a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs +++ b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs @@ -34,6 +34,7 @@ using FluentAssertions; using OpenSearch.Client; using Newtonsoft.Json; +using OpenSearch.OpenSearch.Ephemeral; using Tests.Core.Client; using Tests.Core.Extensions; using Tests.Core.ManagedOpenSearch.Clusters; @@ -41,105 +42,111 @@ using Tests.Framework.EndpointTests; using Tests.Framework.EndpointTests.TestState; -namespace Tests.QueryDsl +namespace Tests.QueryDsl; + +public abstract class QueryDslUsageTestsBase + : ApiTestBase, ISearchRequest, SearchDescriptor, SearchRequest> + where TCluster : IEphemeralCluster, IOpenSearchClientTestCluster, new() + where TDocument : class { - public abstract class QueryDslUsageTestsBase - : ApiTestBase, ISearchRequest, SearchDescriptor, SearchRequest> - { - protected readonly QueryContainer ConditionlessQuery = new QueryContainer(new TermQuery()); + protected QueryDslUsageTestsBase(TCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected abstract IndexName IndexName { get; } + protected abstract string ExpectedIndexString { get; } - protected readonly QueryContainer VerbatimQuery = new QueryContainer(new TermQuery { IsVerbatim = true }); + protected virtual ConditionlessWhen ConditionlessWhen => null; - protected byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); + protected override object ExpectJson => new { query = QueryJson }; - protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + protected override Func, ISearchRequest> Fluent => s => s + .Index(IndexName) + .Query(QueryFluent); - protected virtual ConditionlessWhen ConditionlessWhen => null; + protected override HttpMethod HttpMethod => HttpMethod.POST; - protected override object ExpectJson => new { query = QueryJson }; + protected override SearchRequest Initializer => + new(IndexName) + { + Query = QueryInitializer + }; - protected override Func, ISearchRequest> Fluent => s => s - .Query(q => QueryFluent(q)); + protected virtual NotConditionlessWhen NotConditionlessWhen => null; - protected override HttpMethod HttpMethod => HttpMethod.POST; + protected abstract QueryContainer QueryInitializer { get; } - protected override SearchRequest Initializer => - new SearchRequest - { - Query = QueryInitializer - }; + protected abstract object QueryJson { get; } + protected override string UrlPath => $"/{ExpectedIndexString}/_search"; - protected virtual bool KnownParseException => false; + protected override LazyResponses ClientUsage() => Calls( + (client, f) => client.Search(f), + (client, f) => client.SearchAsync(f), + (client, r) => client.Search(r), + (client, r) => client.SearchAsync(r) + ); - protected virtual NotConditionlessWhen NotConditionlessWhen => null; + protected abstract QueryContainer QueryFluent(QueryContainerDescriptor q); - protected abstract QueryContainer QueryInitializer { get; } + [U] public void FluentIsNotConditionless() => + AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor())); - protected abstract object QueryJson { get; } - protected override string UrlPath => "/project/_search"; + [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); - protected override LazyResponses ClientUsage() => Calls( - (client, f) => client.Search(f), - (client, f) => client.SearchAsync(f), - (client, r) => client.Search(r), - (client, r) => client.SearchAsync(r) - ); + private void AssertIsNotConditionless(IQueryContainer c) + { + if (!c.IsVerbatim) + c.IsConditionless.Should().BeFalse(); + } - protected abstract QueryContainer QueryFluent(QueryContainerDescriptor q); + [U] public void SeenByVisitor() + { + var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); + var query = QueryFluent(new QueryContainerDescriptor()); + query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); + query.Accept(visitor); + var pretty = visitor.PrettyPrint; + pretty.Should().NotBeNullOrWhiteSpace(); + } - [U] public void FluentIsNotConditionless() => - AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor())); + [U] public void ConditionlessWhenExpectedToBe() + { + if (ConditionlessWhen == null) return; - [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); + foreach (var when in ConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor())); + when(QueryInitializer); + } - private void AssertIsNotConditionless(IQueryContainer c) - { - if (!c.IsVerbatim) - c.IsConditionless.Should().BeFalse(); - } + ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); + } - [U] public void SeenByVisitor() - { - var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); - var query = QueryFluent(new QueryContainerDescriptor()); - query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); - query.Accept(visitor); - var pretty = visitor.PrettyPrint; - pretty.Should().NotBeNullOrWhiteSpace(); - } + [U] public void NotConditionlessWhenExpectedToBe() + { + if (NotConditionlessWhen == null) return; - [U] public void ConditionlessWhenExpectedToBe() - { - if (ConditionlessWhen == null) return; + foreach (var when in NotConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor())); + when(QueryInitializer); + } + } - foreach (var when in ConditionlessWhen) - { - when(QueryFluent(new QueryContainerDescriptor())); - //this.JsonEquals(query, new { }); - when(QueryInitializer); - //this.JsonEquals(query, new { }); - } + [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(AssertQueryResponseValid); - ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); - } + protected virtual void AssertQueryResponseValid(ISearchResponse response) => response.ShouldBeValid(); +} + +public abstract class QueryDslUsageTestsBase + : QueryDslUsageTestsBase +{ + protected static byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); - [U] public void NotConditionlessWhenExpectedToBe() - { - if (NotConditionlessWhen == null) return; + protected static readonly QueryContainer ConditionlessQuery = new(new TermQuery()); - foreach (var when in NotConditionlessWhen) - { - var query = QueryFluent(new QueryContainerDescriptor()); - when(query); + protected static readonly QueryContainer VerbatimQuery = new(new TermQuery { IsVerbatim = true }); - query = QueryInitializer; - when(query); - } - } + protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } - [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(r => - { - r.ShouldBeValid(); - }); - } + protected override IndexName IndexName => typeof(Project); + protected override string ExpectedIndexString => "project"; } diff --git a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs index 8150004a20..98d20ac6e4 100644 --- a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs +++ b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs @@ -6,7 +6,6 @@ */ using System; -using System.Linq; using System.Threading.Tasks; using FluentAssertions; using OpenSearch.Client; diff --git a/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs index 69b6a80578..4a67c3cd08 100644 --- a/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs +++ b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs @@ -5,16 +5,70 @@ * compatible open source license. */ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using FluentAssertions; using OpenSearch.Client; +using OpenSearch.Net; +using OpenSearch.OpenSearch.Xunit.XunitPlumbing; +using Tests.Core.Extensions; using Tests.Core.ManagedOpenSearch.Clusters; -using Tests.Domain; using Tests.Framework.EndpointTests.TestState; +using Version = SemanticVersioning.Version; namespace Tests.QueryDsl.Specialized.Neural; -public class NeuralQueryUsageTests : QueryDslUsageTestsBase +public class NeuralSearchDoc { - public NeuralQueryUsageTests(ReadOnlyCluster i, EndpointUsage usage) : base(i, usage) { } + [PropertyName("id")] public string Id { get; set; } + [PropertyName("text")] public string Text { get; set; } + [PropertyName("passage_embedding")] public float[] PassageEmbedding { get; set; } +} + +[SkipVersion("<2.6.0", "Avoid the various early permutations of the ML APIs")] +[SkipPrereleaseVersions("Prerelease versions of OpenSearch do not include the ML & Neural Search plugins")] +public class NeuralQueryUsageTests + : QueryDslUsageTestsBase +{ + private static readonly string TestName = nameof(NeuralQueryUsageTests).ToLowerInvariant(); + + private string _modelGroupId; + private string _modelId = "default-for-unit-tests"; + + public NeuralQueryUsageTests(WritableCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected override IndexName IndexName => TestName; + protected override string ExpectedIndexString => TestName; + + protected override QueryContainer QueryInitializer => new NeuralQuery + { + Field = Infer.Field(d => d.PassageEmbedding), + QueryText = "wild west", + K = 5, + ModelId = _modelId + }; + + protected override object QueryJson => new + { + neural = new + { + passage_embedding = new + { + query_text = "wild west", + k = 5, + model_id = _modelId + } + } + }; + + protected override QueryContainer QueryFluent(QueryContainerDescriptor q) => q + .Neural(n => n + .Field(f => f.PassageEmbedding) + .QueryText("wild west") + .K(5) + .ModelId(_modelId)); protected override ConditionlessWhen ConditionlessWhen => new ConditionlessWhen(a => a.Neural) { @@ -69,36 +123,168 @@ public NeuralQueryUsageTests(ReadOnlyCluster i, EndpointUsage usage) : base(i, u } }; - protected override QueryContainer QueryInitializer => new NeuralQuery + protected override void IntegrationSetup(IOpenSearchClient client, CallUniqueValues values) { - Boost = 1.1, - Field = Infer.Field(f => f.Vector), - QueryText = "wild west", - K = 5, - ModelId = "aFcV879" - }; + var baseVersion = Cluster.ClusterConfiguration.Version.BaseVersion(); + var renamedToRegisterDeploy = baseVersion >= new Version("2.7.0"); + var hasModelAccessControl = baseVersion >= new Version("2.8.0"); - protected override object QueryJson => - new + var settings = new Dictionary { - neural = new + ["plugins.ml_commons.only_run_on_ml_node"] = false, + ["plugins.ml_commons.native_memory_threshold"] = 99 + }; + + if (hasModelAccessControl) + settings["plugins.ml_commons.model_access_control_enabled"] = true; + + if (settings.Count > 0) + { + var putSettingsResp = client.Cluster.PutSettings(new ClusterPutSettingsRequest { - vector = new + Transient = settings + }); + putSettingsResp.ShouldBeValid(); + } + + if (hasModelAccessControl) + { + var registerModelGroupResp = client.Http.Post( + "/_plugins/_ml/model_groups/_register", + r => r.SerializableBody(new { - boost = 1.1, - query_text = "wild west", - k = 5, - model_id = "aFcV879" - } + name = TestName, + access_mode = "public", + model_access_mode = "public" + })); + registerModelGroupResp.ShouldBeCreated(); + _modelGroupId = (string)registerModelGroupResp.Body.model_group_id; + } + + var registerModelResp = client.Http.Post( + $"/_plugins/_ml/models/{(renamedToRegisterDeploy ? "_register" : "_upload")}", + r => r.SerializableBody(new + { + name = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b", + version = "1.0.1", + model_group_id = _modelGroupId, + model_format = "TORCH_SCRIPT" + })); + registerModelResp.ShouldBeCreated(); + var modelRegistrationTaskId = (string) registerModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get($"/_plugins/_ml/tasks/{modelRegistrationTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) + { + _modelId = getTaskResp.Body.model_id; + break; } + Thread.Sleep(5000); + } + + var deployModelResp = client.Http.Post($"/_plugins/_ml/models/{_modelId}/{(renamedToRegisterDeploy ? "_deploy" : "_load")}"); + deployModelResp.ShouldBeCreated(); + var modelDeployTaskId = (string) deployModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get($"/_plugins/_ml/tasks/{modelDeployTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break; + Thread.Sleep(5000); + } + + var putIngestPipelineResp = client.Ingest.PutPipeline(TestName, p => p + .Processors(pp => pp + .TextEmbedding(te => te + .ModelId(_modelId) + .FieldMap(fm => fm + .Map(d => d.Text, d => d.PassageEmbedding))))); + putIngestPipelineResp.ShouldBeValid(); + + var createIndexResp = client.Indices.Create( + IndexName, + i => i + .Settings(s => s + .Setting("index.knn", true) + .DefaultPipeline(TestName)) + .Map(m => m + .Properties(p => p + .Text(t => t.Name(d => d.Id)) + .Text(t => t.Name(d => d.Text)) + .KnnVector(k => k + .Name(d => d.PassageEmbedding) + .Dimension(768) + .Method(km => km + .Engine("lucene") + .SpaceType("l2") + .Name("hnsw")))))); + createIndexResp.ShouldBeValid(); + + var documents = new NeuralSearchDoc[] + { + new() { Id = "4319130149.jpg", Text = "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena ." }, + new() { Id = "1775029934.jpg", Text = "A wild animal races across an uncut field with a minimal amount of trees ." }, + new() { Id = "2664027527.jpg", Text = "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco ." }, + new() { Id = "4427058951.jpg", Text = "A man who is riding a wild horse in the rodeo is very near to falling off ." }, + new() { Id = "2691147709.jpg", Text = "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse ." } }; + var bulkResp = client.Bulk(b => b + .Index(IndexName) + .IndexMany(documents) + .Refresh(Refresh.WaitFor)); + bulkResp.ShouldBeValid(); + } - protected override QueryContainer QueryFluent(QueryContainerDescriptor q) => q - .Neural(n => n - .Boost(1.1) - .Field(f => f.Vector) - .QueryText("wild west") - .K(5) - .ModelId("aFcV879") - ); + protected override void AssertQueryResponseValid(ISearchResponse response) + { + base.AssertQueryResponseValid(response); + + response.Hits.Should().HaveCount(5); + var hit = response.Hits.First(); + + hit.Id.Should().Be("4427058951.jpg"); + hit.Score.Should().BeApproximately(0.01585195, 0.00000001); + hit.Source.Text.Should().Be("A man who is riding a wild horse in the rodeo is very near to falling off ."); + hit.Source.PassageEmbedding.Should().HaveCount(768); + } + + protected override void IntegrationTeardown(IOpenSearchClient client, CallUniqueValues values) + { + client.Indices.Delete(IndexName); + client.Ingest.DeletePipeline(TestName); + + if (_modelId != "default-for-unit-tests") + { + while (true) + { + var deleteModelResp = client.Http.Delete($"/_plugins/_ml/models/{_modelId}"); + if (deleteModelResp.Success || !(((string)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false)) break; + + client.Http.Post($"/_plugins/_ml/models/{_modelId}/_undeploy"); + Thread.Sleep(5000); + } + } + + if (_modelGroupId != null) + { + client.Http.Delete($"/_plugins/_ml/model_groups/{_modelGroupId}"); + } + } +} + +internal static class Helpers +{ + public static void ShouldBeCreated(this DynamicResponse r) + { + if (!r.Success || r.Body.status != "CREATED") throw new Exception("Expected to be created, was: " + r.DebugInformation); + } + + public static void ShouldNotBeFailed(this DynamicResponse r) + { + if (!r.Success || r.Body.state == "FAILED") throw new Exception("Expected to not be failed, was: " + r.DebugInformation); + } }