Skip to content

Commit

Permalink
adds hnsw graph based storage & query layer to events correlation engine
Browse files Browse the repository at this point in the history
Signed-off-by: Subhobrata Dey <[email protected]>
  • Loading branch information
sbcd90 committed May 17, 2023
1 parent a1e42b1 commit dcfd057
Show file tree
Hide file tree
Showing 26 changed files with 2,639 additions and 1 deletion.
2 changes: 2 additions & 0 deletions plugins/events-correlation-engine/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ opensearchplugin {
classname 'org.opensearch.plugin.correlation.EventsCorrelationPlugin'
}

// forbiddenApis.ignoreFailures = true

dependencies {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.correlation;

import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Assert;
import org.opensearch.client.Request;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexSettings;
import org.opensearch.rest.RestStatus;
import org.opensearch.test.rest.OpenSearchRestTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Correlation Vectors Engine e2e tests
*/
public class CorrelationVectorsEngineIT extends OpenSearchRestTestCase {

private static final int DIMENSION = 4;
private static final String PROPERTIES_FIELD_NAME = "properties";
private static final String TYPE_FIELD_NAME = "type";
private static final String CORRELATION_VECTOR_TYPE = "correlation_vector";
private static final String DIMENSION_FIELD_NAME = "dimension";
private static final int M = 16;
private static final int EF_CONSTRUCTION = 128;
private static final String INDEX_NAME = "test-index-1";
private static final Float[][] TEST_VECTORS = new Float[][] {
{ 1.0f, 1.0f, 1.0f, 1.0f },
{ 2.0f, 2.0f, 2.0f, 2.0f },
{ 3.0f, 3.0f, 3.0f, 3.0f } };
private static final float[][] TEST_QUERY_VECTORS = new float[][] {
{ 1.0f, 1.0f, 1.0f, 1.0f },
{ 2.0f, 2.0f, 2.0f, 2.0f },
{ 3.0f, 3.0f, 3.0f, 3.0f } };
private static final Map<VectorSimilarityFunction, Function<Float, Float>> VECTOR_SIMILARITY_TO_SCORE = Map.of(
VectorSimilarityFunction.EUCLIDEAN,
(similarity) -> 1 / (1 + similarity),
VectorSimilarityFunction.DOT_PRODUCT,
(similarity) -> (1 + similarity) / 2,
VectorSimilarityFunction.COSINE,
(similarity) -> (1 + similarity) / 2
);

/**
* test the e2e storage and query layer of events-correlation-engine
* @throws IOException IOException
*/
@SuppressWarnings("unchecked")
public void testQuery() throws IOException {
String textField = "text-field";
String luceneField = "lucene-field";
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(textField)
.field(TYPE_FIELD_NAME, "text")
.endObject()
.startObject(luceneField)
.field(TYPE_FIELD_NAME, CORRELATION_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, DIMENSION)
.startObject("correlation_ctx")
.field("similarityFunction", VectorSimilarityFunction.EUCLIDEAN.name())
.startObject("parameters")
.field("m", M)
.field("ef_construction", EF_CONSTRUCTION)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = Strings.toString(builder);
createTestIndexWithMappingJson(client(), INDEX_NAME, mapping, getCorrelationDefaultIndexSettings());

for (int idx = 0; idx < TEST_VECTORS.length; ++idx) {
addCorrelationDoc(
INDEX_NAME,
String.valueOf(idx + 1),
List.of(textField, luceneField),
List.of(java.util.UUID.randomUUID().toString(), TEST_VECTORS[idx])
);
}
refreshAllIndices();
Assert.assertEquals(TEST_VECTORS.length, getDocCount(INDEX_NAME));

int k = 2;
for (float[] query : TEST_QUERY_VECTORS) {

String correlationQuery = "{\n"
+ " \"query\": {\n"
+ " \"correlation\": {\n"
+ " \"lucene-field\": {\n"
+ " \"vector\": \n"
+ Arrays.toString(query)
+ " ,\n"
+ " \"k\": 2,\n"
+ " \"boost\": 1\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

Response response = searchCorrelationIndex(INDEX_NAME, correlationQuery, k);
Map<String, Object> responseBody = entityAsMap(response);
Assert.assertEquals(2, ((List<Object>) ((Map<String, Object>) responseBody.get("hits")).get("hits")).size());
@SuppressWarnings("unchecked")
double actualScore1 = Double.parseDouble(
((List<Map<String, Object>>) ((Map<String, Object>) responseBody.get("hits")).get("hits")).get(0).get("_score").toString()
);
@SuppressWarnings("unchecked")
double actualScore2 = Double.parseDouble(
((List<Map<String, Object>>) ((Map<String, Object>) responseBody.get("hits")).get("hits")).get(1).get("_score").toString()
);
@SuppressWarnings("unchecked")
List<Float> hit1 = ((Map<String, List<Double>>) ((List<Map<String, Object>>) ((Map<String, Object>) responseBody.get("hits"))
.get("hits")).get(0).get("_source")).get(luceneField).stream().map(Double::floatValue).collect(Collectors.toList());
float[] resultVector1 = new float[hit1.size()];
for (int i = 0; i < hit1.size(); ++i) {
resultVector1[i] = hit1.get(i);
}

@SuppressWarnings("unchecked")
List<Float> hit2 = ((Map<String, List<Double>>) ((List<Map<String, Object>>) ((Map<String, Object>) responseBody.get("hits"))
.get("hits")).get(1).get("_source")).get(luceneField).stream().map(Double::floatValue).collect(Collectors.toList());
float[] resultVector2 = new float[hit2.size()];
for (int i = 0; i < hit2.size(); ++i) {
resultVector2[i] = hit2.get(i);
}

double rawScore1 = VectorSimilarityFunction.EUCLIDEAN.compare(resultVector1, query);
Assert.assertEquals(rawScore1, actualScore1, 0.0001);
double rawScore2 = VectorSimilarityFunction.EUCLIDEAN.compare(resultVector2, query);
Assert.assertEquals(rawScore2, actualScore2, 0.0001);
}
}

private String createTestIndexWithMappingJson(RestClient client, String index, String mapping, Settings settings) throws IOException {
Request request = new Request("PUT", "/" + index);
String entity = "{\"settings\": " + Strings.toString(XContentType.JSON, settings);
if (mapping != null) {
entity = entity + ",\"mappings\" : " + mapping;
}

entity = entity + "}";
if (!settings.getAsBoolean(IndexSettings.INDEX_SOFT_DELETES_SETTING.getKey(), true)) {
expectSoftDeletesWarning(request, index);
}

request.setJsonEntity(entity);
client.performRequest(request);
return index;
}

private Settings getCorrelationDefaultIndexSettings() {
return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.correlation", true).build();
}

private void addCorrelationDoc(String index, String docId, List<String> fieldNames, List<Object> vectors) throws IOException {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
for (int i = 0; i < fieldNames.size(); i++) {
builder.field(fieldNames.get(i), vectors.get(i));
}
builder.endObject();

request.setJsonEntity(Strings.toString(builder));
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

private Response searchCorrelationIndex(String index, String correlationQuery, int resultSize) throws IOException {
Request request = new Request("POST", "/" + index + "/_search");

request.addParameter("size", Integer.toString(resultSize));
request.addParameter("explain", Boolean.toString(true));
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(correlationQuery);

Response response = client().performRequest(request);
Assert.assertEquals("Search failed", RestStatus.OK, restStatus(response));
return response;
}

private int getDocCount(String index) throws IOException {
Response response = makeRequest(
client(),
"GET",
String.format(Locale.getDefault(), "/%s/_count", index),
Collections.emptyMap(),
null
);
Assert.assertEquals(RestStatus.OK, restStatus(response));
return Integer.parseInt(entityAsMap(response).get("count").toString());
}

private Response makeRequest(
RestClient client,
String method,
String endpoint,
Map<String, String> params,
HttpEntity entity,
Header... headers
) throws IOException {
Request request = new Request(method, endpoint);
RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
options.setWarningsHandler(WarningsHandler.PERMISSIVE);

for (Header header : headers) {
options.addHeader(header.getName(), header.getValue());
}
request.setOptions(options.build());
request.addParameters(params);
if (entity != null) {
request.setEntity(entity);
}
return client.performRequest(request);
}

private RestStatus restStatus(Response response) {
return RestStatus.fromCode(response.getStatusLine().getStatusCode());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.codec.CodecServiceFactory;
import org.opensearch.index.mapper.Mapper;
import org.opensearch.plugin.correlation.core.index.codec.CorrelationCodecService;
import org.opensearch.plugin.correlation.core.index.mapper.CorrelationVectorFieldMapper;
import org.opensearch.plugin.correlation.core.index.mapper.VectorFieldMapper;
import org.opensearch.plugin.correlation.core.index.query.CorrelationQueryBuilder;
import org.opensearch.plugin.correlation.rules.action.IndexCorrelationRuleAction;
import org.opensearch.plugin.correlation.rules.resthandler.RestIndexCorrelationRuleAction;
import org.opensearch.plugin.correlation.rules.transport.TransportIndexCorrelationRuleAction;
import org.opensearch.plugin.correlation.settings.EventsCorrelationSettings;
import org.opensearch.plugin.correlation.utils.CorrelationRuleIndices;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.EnginePlugin;
import org.opensearch.plugins.MapperPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
Expand All @@ -38,13 +48,16 @@
import org.opensearch.watcher.ResourceWatcherService;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/**
* Plugin class for events-correlation-engine
*/
public class EventsCorrelationPlugin extends Plugin implements ActionPlugin {
public class EventsCorrelationPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin {

/**
* events-correlation-engine base uri
Expand Down Expand Up @@ -93,6 +106,30 @@ public List<RestHandler> getRestHandlers(
return List.of(new RestIndexCorrelationRuleAction());
}

@Override
public Map<String, Mapper.TypeParser> getMappers() {
return Collections.singletonMap(CorrelationVectorFieldMapper.CONTENT_TYPE, new VectorFieldMapper.TypeParser());
}

@Override
public Optional<CodecServiceFactory> getCustomCodecServiceFactory(IndexSettings indexSettings) {
if (indexSettings.getValue(EventsCorrelationSettings.IS_CORRELATION_INDEX_SETTING)) {
return Optional.of(CorrelationCodecService::new);
}
return Optional.empty();
}

@Override
public List<QuerySpec<?>> getQueries() {
return Collections.singletonList(
new QuerySpec<>(
CorrelationQueryBuilder.NAME_FIELD.getPreferredName(),
CorrelationQueryBuilder::new,
CorrelationQueryBuilder::parse
)
);
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return List.of(new ActionPlugin.ActionHandler<>(IndexCorrelationRuleAction.INSTANCE, TransportIndexCorrelationRuleAction.class));
Expand Down
Loading

0 comments on commit dcfd057

Please sign in to comment.