Skip to content

Commit

Permalink
Separate neural query cluster configuration
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed May 14, 2024
1 parent d125b2b commit 66c91c3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ jobs:
name: Integration Tests
working-directory: client

# Run neural query integration tests separately as they use a significant amount of memory on their own
- run: "./build.sh integrate ${{ matrix.version }} neuralquery random:test_only_one --report"
name: Neural Query Integration Tests
working-directory: client

- name: Upload test report
if: failure()
uses: actions/upload-artifact@v3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ private static ClientTestClusterConfiguration CreateConfiguration() =>
Knn,
MachineLearning,
MapperMurmur3,
NeuralSearch,
Security)
{
MaxConcurrency = 4,
Expand Down
49 changes: 27 additions & 22 deletions tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,44 @@
*/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using FluentAssertions;
using OpenSearch.Client;
using OpenSearch.Net;
using OpenSearch.OpenSearch.Xunit.XunitPlumbing;
using OpenSearch.Stack.ArtifactsApi.Products;
using Tests.Core.Extensions;
using Tests.Core.ManagedOpenSearch.Clusters;
using Tests.Framework.EndpointTests.TestState;
using Version = SemanticVersioning.Version;

namespace Tests.QueryDsl.Specialized.Neural;

public class NeuralQueryCluster : ClientTestClusterBase
{
public NeuralQueryCluster() : base(CreateConfiguration()) { }

private static ClientTestClusterConfiguration CreateConfiguration()
{
var config = new ClientTestClusterConfiguration(
OpenSearchPlugin.Knn,
OpenSearchPlugin.MachineLearning,
OpenSearchPlugin.NeuralSearch,
OpenSearchPlugin.Security)
{
MaxConcurrency = 4,
ValidatePluginsToInstall = false,
};

config.DefaultNodeSettings.Add("plugins.ml_commons.only_run_on_ml_node", "false");
config.DefaultNodeSettings.Add("plugins.ml_commons.native_memory_threshold", "99");
config.DefaultNodeSettings.Add("plugins.ml_commons.model_access_control_enabled", "true", ">=2.8.0");

return config;
}
}

public class NeuralSearchDoc
{
[PropertyName("id")] public string Id { get; set; }
Expand All @@ -28,16 +52,15 @@ public class NeuralSearchDoc
}

[SkipVersion("<2.6.0", "Avoid the various early permutations of the ML APIs")]
[SkipPrereleaseVersions("Prerelease versions of OpenSearch do not include the ML & Neural Search plugins")]
public class NeuralQueryUsageTests
: QueryDslUsageTestsBase<WritableCluster, NeuralSearchDoc>
: QueryDslUsageTestsBase<NeuralQueryCluster, NeuralSearchDoc>
{
private static readonly string TestName = nameof(NeuralQueryUsageTests).ToLowerInvariant();

private string _modelGroupId;
private string _modelId = "default-for-unit-tests";

public NeuralQueryUsageTests(WritableCluster cluster, EndpointUsage usage) : base(cluster, usage) { }
public NeuralQueryUsageTests(NeuralQueryCluster cluster, EndpointUsage usage) : base(cluster, usage) { }

protected override IndexName IndexName => TestName;
protected override string ExpectedIndexString => TestName;
Expand Down Expand Up @@ -129,24 +152,6 @@ protected override void IntegrationSetup(IOpenSearchClient client, CallUniqueVal
var renamedToRegisterDeploy = baseVersion >= new Version("2.7.0");
var hasModelAccessControl = baseVersion >= new Version("2.8.0");

var settings = new Dictionary<string, object>
{
["plugins.ml_commons.only_run_on_ml_node"] = false,
["plugins.ml_commons.native_memory_threshold"] = 99
};

if (hasModelAccessControl)
settings["plugins.ml_commons.model_access_control_enabled"] = true;

if (settings.Count > 0)
{
var putSettingsResp = client.Cluster.PutSettings(new ClusterPutSettingsRequest
{
Transient = settings
});
putSettingsResp.ShouldBeValid();
}

if (hasModelAccessControl)
{
var registerModelGroupResp = client.Http.Post<DynamicResponse>(
Expand Down

0 comments on commit 66c91c3

Please sign in to comment.