Skip to content

Commit

Permalink
Added functionalities like IndexBuildServiceClient, parsing of reques…
Browse files Browse the repository at this point in the history
…t, response for create index API

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Dec 17, 2024
1 parent 6b8c826 commit 5606ac5
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 14 deletions.
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ dependencies {
api group: 'com.google.guava', name: 'failureaccess', version:'1.0.1'
api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
api group: 'commons-lang', name: 'commons-lang', version: '2.6'

api group: 'org.apache.httpcomponents', name: 'httpcore', version: "${versions.httpcore}"
api group: 'org.apache.httpcomponents', name: 'httpclient', version: "${versions.httpclient}"
api group: 'org.apache.httpcomponents', name: 'httpasyncclient', version: "${versions.httpasyncclient}"
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
Expand Down
87 changes: 84 additions & 3 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";
// Setting keys for the S3 credentials (access key, secret key, token) read by the getKnnS3* accessors.
public static final String KNN_S3_ACCESS_KEY = "knn.s3.access.key";
public static final String KNN_S3_SECRET_KEY = "knn.s3.secret.key";
public static final String KNN_S3_TOKEN_KEY = "knn.s3.token.key";
// Setting keys for the host and port of the remote index build service.
public static final String REMOTE_SERVICE_ENDPOINT = "knn.remote.index.build.service.endpoint";
public static final String REMOTE_SERVICE_PORT = "knn.remote.index.build.service.port";

/**
* Default setting values
Expand Down Expand Up @@ -160,6 +165,37 @@ public class KNNSettings {
Setting.Property.Deprecated
);

/**
 * S3 access key. Dynamic, node-scoped string setting with no explicit default.
 * Marked {@code Filtered} because it is a credential and must not be echoed back by settings APIs.
 */
public static final Setting<String> KNN_S3_ACCESS_KEY_SETTING = Setting.simpleString(
    KNN_S3_ACCESS_KEY,
    Setting.Property.Dynamic,
    Setting.Property.NodeScope,
    Setting.Property.Filtered
);

/**
 * S3 secret key. Dynamic, node-scoped string setting with no explicit default.
 * Marked {@code Filtered} because it is a credential and must not be echoed back by settings APIs.
 */
public static final Setting<String> KNN_S3_SECRET_KEY_SETTING = Setting.simpleString(
    KNN_S3_SECRET_KEY,
    Setting.Property.Dynamic,
    Setting.Property.NodeScope,
    Setting.Property.Filtered
);

/**
 * S3 token. Dynamic, node-scoped string setting with no explicit default.
 * Marked {@code Filtered} because it is a credential and must not be echoed back by settings APIs.
 */
public static final Setting<String> KNN_S3_TOKEN_KEY_SETTING = Setting.simpleString(
    KNN_S3_TOKEN_KEY,
    Setting.Property.Dynamic,
    Setting.Property.NodeScope,
    Setting.Property.Filtered
);

/**
 * Host name of the remote index build service. Dynamic, node-scoped string setting with no explicit default.
 */
public static final Setting<String> REMOTE_SERVICE_ENDPOINT_SETTING = Setting.simpleString(
    REMOTE_SERVICE_ENDPOINT,
    Setting.Property.Dynamic,
    Setting.Property.NodeScope
);

/**
 * Port of the remote index build service. Dynamic, node-scoped integer setting that defaults to 8200.
 */
public static final Setting<Integer> REMOTE_SERVICE_PORT_SETTING = Setting.intSetting(
    REMOTE_SERVICE_PORT,
    8200,
    Setting.Property.Dynamic,
    Setting.Property.NodeScope
);

/**
* build_vector_data_structure_threshold - This parameter determines when to build vector data structure for knn fields during indexing
* and merging. Setting -1 (min) will skip building graph, whereas on any other values, the graph will be built if
Expand Down Expand Up @@ -352,6 +388,19 @@ public class KNNSettings {
NodeScope
);

/**
 * Settings used by the remote index build flow (S3 credentials and remote service address),
 * keyed by setting name so that they can be looked up by key.
 * The map is immutable; it is only ever read after class initialization.
 */
private static final Map<String, Setting<?>> REMOTE_INDEX_BUILD_SERVICE_SETTINGS_MAP = Map.of(
    KNN_S3_ACCESS_KEY,
    KNN_S3_ACCESS_KEY_SETTING,
    KNN_S3_SECRET_KEY,
    KNN_S3_SECRET_KEY_SETTING,
    KNN_S3_TOKEN_KEY,
    KNN_S3_TOKEN_KEY_SETTING,
    REMOTE_SERVICE_ENDPOINT,
    REMOTE_SERVICE_ENDPOINT_SETTING,
    REMOTE_SERVICE_PORT,
    REMOTE_SERVICE_PORT_SETTING
);

/**
* Dynamic settings
*/
Expand Down Expand Up @@ -499,6 +548,10 @@ private Setting<?> getSetting(String key) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING;
}

if (REMOTE_INDEX_BUILD_SERVICE_SETTINGS_MAP.containsKey(key)) {
return REMOTE_INDEX_BUILD_SERVICE_SETTINGS_MAP.get(key);
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -522,10 +575,18 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING,
KNN_S3_ACCESS_KEY_SETTING,
KNN_S3_SECRET_KEY_SETTING,
KNN_S3_TOKEN_KEY_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
final List<Stream<Setting<?>>> streamList = Arrays.asList(
settings.stream(),
getFeatureFlags().stream(),
dynamicCacheSettings.values().stream(),
REMOTE_INDEX_BUILD_SERVICE_SETTINGS_MAP.values().stream()
);
return streamList.stream().flatMap(stream -> stream).collect(Collectors.toList());
}

public static boolean isKNNPluginEnabled() {
Expand Down Expand Up @@ -585,6 +646,26 @@ public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String inde
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false);
}

/**
 * Looks up the value stored under {@link #KNN_S3_ACCESS_KEY}.
 *
 * @return the value returned by {@code getSettingValue} for the S3 access key setting
 */
public static String getKnnS3AccessKey() {
    final String accessKey = KNNSettings.state().getSettingValue(KNN_S3_ACCESS_KEY);
    return accessKey;
}

/**
 * Looks up the value stored under {@link #KNN_S3_SECRET_KEY}.
 *
 * @return the value returned by {@code getSettingValue} for the S3 secret key setting
 */
public static String getKnnS3SecretKey() {
    final String secretKey = KNNSettings.state().getSettingValue(KNN_S3_SECRET_KEY);
    return secretKey;
}

/**
 * Looks up the value stored under {@link #KNN_S3_TOKEN_KEY}.
 *
 * @return the value returned by {@code getSettingValue} for the S3 token setting
 */
public static String getKnnS3Token() {
    final String token = KNNSettings.state().getSettingValue(KNN_S3_TOKEN_KEY);
    return token;
}

/**
 * Looks up the host of the remote index build service, stored under {@link #REMOTE_SERVICE_ENDPOINT}.
 *
 * @return the value returned by {@code getSettingValue} for the remote service endpoint setting
 */
public static String getRemoteServiceEndpoint() {
    // Must read the endpoint key; reading REMOTE_SERVICE_PORT here would hand back the Integer port
    // where a String host name is expected.
    return KNNSettings.state().getSettingValue(REMOTE_SERVICE_ENDPOINT);
}

/**
 * Looks up the value stored under {@link #REMOTE_SERVICE_PORT}.
 *
 * @return the value returned by {@code getSettingValue} for the remote service port setting
 */
public static Integer getRemoteServicePort() {
    final Integer port = KNNSettings.state().getSettingValue(REMOTE_SERVICE_PORT);
    return port;
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.remote.index.client.IndexBuildServiceClient;
import org.opensearch.knn.remote.index.model.CreateIndexRequest;
import org.opensearch.knn.remote.index.model.CreateIndexResponse;
import org.opensearch.knn.remote.index.s3.S3Client;

import java.io.IOException;
Expand All @@ -59,6 +62,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private boolean finished;
private final Integer approximateThreshold;
private final S3Client s3Client;
private final IndexBuildServiceClient indexBuildServiceClient;

public NativeEngines990KnnVectorsWriter(
SegmentWriteState segmentWriteState,
Expand All @@ -73,6 +77,7 @@ public NativeEngines990KnnVectorsWriter(
} catch (Exception e) {
throw new RuntimeException(e);
}
indexBuildServiceClient = IndexBuildServiceClient.getInstance();
}

/**
Expand Down Expand Up @@ -125,7 +130,17 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
);
continue;
}

uploadToS3(fieldInfo, knnVectorValuesSupplier);
log.info("Creating the IndexRequest...");
CreateIndexRequest createIndexRequest = buildCreateIndexRequest(fieldInfo);
log.info("Submitting request to remote indexbuildService");
try {
CreateIndexResponse response = indexBuildServiceClient.createIndex(createIndexRequest);
log.info("Request completed with response : {}", response);
} catch (Exception e) {
log.error("Failed to call indexBuildServiceClient.createIndex for input: {}", createIndexRequest, e);
}
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

Expand Down Expand Up @@ -202,6 +217,14 @@ private void uploadToS3(final FieldInfo fieldInfo, final Supplier<KNNVectorValue
}
}

/**
 * Assembles the {@link CreateIndexRequest} for one vector field of the segment being written.
 * The object location is {@code <segment name>_<field name>.s3vec} and the bucket is
 * {@code S3Client.BUCKET_NAME}; the number of vectors is not set here.
 *
 * @param fieldInfo field whose name and vector dimension go into the request
 * @return the built request
 */
private CreateIndexRequest buildCreateIndexRequest(final FieldInfo fieldInfo) {
    final String objectKey = segmentWriteState.segmentInfo.name + "_" + fieldInfo.getName() + ".s3vec";
    final int vectorDimension = fieldInfo.getVectorDimension();
    return CreateIndexRequest.builder()
        .bucketName(S3Client.BUCKET_NAME)
        .objectLocation(objectKey)
        .dimensions(vectorDimension)
        .build();
}

/**
* Called once at the end before close
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.remote.index.client;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.remote.index.model.CreateIndexRequest;
import org.opensearch.knn.remote.index.model.CreateIndexResponse;
import org.opensearch.knn.remote.index.s3.S3Client;

import java.io.IOException;

/**
 * Client used to call the remote index build service APIs over HTTP.
 * <p>
 * A single lazily created instance is shared through {@link #getInstance()}. The target host and port are
 * read from {@link KNNSettings} once, when that instance is constructed, and are not re-read afterwards.
 */
public class IndexBuildServiceClient {
    private static volatile IndexBuildServiceClient INSTANCE;
    private static final String CONTENT_TYPE = "Content-Type";
    private static final String APPLICATION_JSON = "application/json";
    private static final String ACCEPT = "Accept";
    // Charset used to encode the JSON request body; StringEntity(String) alone would use ISO-8859-1.
    private static final String UTF_8 = "UTF-8";
    private final HttpClient httpClient;
    private final HttpHost httpHost;

    /**
     * Returns the shared client, creating it on first use.
     *
     * @return the shared {@link IndexBuildServiceClient}
     * @throws IOException kept in the signature for existing callers
     */
    public static IndexBuildServiceClient getInstance() throws IOException {
        IndexBuildServiceClient result = INSTANCE;
        if (result == null) {
            // Lock on this class: the monitor guards this class's own INSTANCE field, so it must not be
            // shared with S3Client's initialization lock.
            synchronized (IndexBuildServiceClient.class) {
                result = INSTANCE;
                if (result == null) {
                    INSTANCE = result = new IndexBuildServiceClient();
                }
            }
        }
        return result;
    }

    private IndexBuildServiceClient() {
        this.httpClient = HttpClientBuilder.create().build();
        this.httpHost = new HttpHost(KNNSettings.getRemoteServiceEndpoint(), KNNSettings.getRemoteServicePort(), "http");
    }

    /**
     * API to be called to create the Vector Index using remote endpoint
     * @param createIndexRequest {@link CreateIndexRequest}
     * @return the parsed {@link CreateIndexResponse}
     * @throws IOException if the request fails, the service answers with a status code of 400 or above,
     * or the response body cannot be parsed
     */
    public CreateIndexResponse createIndex(final CreateIndexRequest createIndexRequest) throws IOException {
        HttpPost request = new HttpPost();
        request.setHeader(CONTENT_TYPE, APPLICATION_JSON);
        request.setHeader(ACCEPT, APPLICATION_JSON);
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder = createIndexRequest.toXContent(builder, null);
        // Encode explicitly as UTF-8 so non-ASCII characters (e.g. in the object location) survive.
        request.setEntity(new StringEntity(builder.toString(), UTF_8));

        HttpResponse response = makeHTTPRequest(request);
        HttpEntity httpEntity = response.getEntity();
        String responseString = EntityUtils.toString(httpEntity);
        return parseCreateIndexResponse(responseString);
    }

    // TODO: To be implemented
    public void checkIndexBuildStatus() {

    }

    /**
     * Executes the request against the configured host.
     *
     * @param request request to send
     * @return the response when its status code is below 400
     * @throws IOException if execution fails or the status code is 400 or above
     */
    private HttpResponse makeHTTPRequest(final HttpRequest request) throws IOException {
        HttpResponse response = httpClient.execute(httpHost, request);
        HttpEntity entity = response.getEntity();
        int statusCode = response.getStatusLine().getStatusCode();

        if (statusCode >= 400) {
            String errorBody = entity != null ? EntityUtils.toString(entity) : "No response body";
            throw new IOException("Request failed with status code: " + statusCode + ", body: " + errorBody);
        }

        return response;
    }

    // Keeping it package private for doing the unit testing for now.
    static CreateIndexResponse parseCreateIndexResponse(final String responseString) throws IOException {
        final XContent xContent = MediaTypeRegistry.getDefaultMediaType().xContent();
        // The parser is closeable; release it once the response has been read.
        try (
            XContentParser parser = xContent.createParser(
                NamedXContentRegistry.EMPTY,
                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                responseString
            )
        ) {
            return CreateIndexResponse.fromXContent(parser);
        }
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.remote.index.model;

import lombok.Builder;
import lombok.Value;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;

/**
 * Request body for the create index call of the remote index build service.
 * Serialized as a JSON object with the keys {@code bucket_name}, {@code object_location},
 * {@code number_of_vectors} and {@code dimensions}.
 */
@Value
@Builder
public class CreateIndexRequest implements ToXContentObject {
    String bucketName;
    String objectLocation;
    long numberOfVectors;
    int dimensions;

    /**
     * Writes this request as a single object into the given builder.
     *
     * @param builder builder to write into
     * @param params not used by this implementation
     * @return the same builder, after the object has been closed
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field("bucket_name", bucketName);
        builder.field("object_location", objectLocation);
        builder.field("number_of_vectors", numberOfVectors);
        builder.field("dimensions", dimensions);
        builder.endObject();
        return builder;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.remote.index.model;

import lombok.Builder;
import lombok.Value;
import org.opensearch.core.ParseField;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

/**
 * Response of the create index call of the remote index build service.
 * Parsed from a flat JSON object with the fields {@code indexCreationRequestId} and {@code status}.
 */
@Value
@Builder
public class CreateIndexResponse {
    private static final ParseField INDEX_CREATION_REQUEST_ID = new ParseField("indexCreationRequestId");
    private static final ParseField STATUS = new ParseField("status");
    String indexCreationRequestId;
    String status;

    /**
     * Parses a response from the given parser, which must be positioned before the opening object token.
     *
     * @param parser parser over the response body
     * @return the parsed response; a field that is absent from the body is left {@code null}
     * @throws IOException if the body is not an object, ends prematurely, contains an unknown field,
     * or contains a nested object or array
     */
    public static CreateIndexResponse fromXContent(XContentParser parser) throws IOException {
        final CreateIndexResponseBuilder builder = CreateIndexResponse.builder();
        XContentParser.Token token = parser.nextToken();
        if (token != XContentParser.Token.START_OBJECT) {
            throw new IOException("Invalid response format, was expecting a " + XContentParser.Token.START_OBJECT);
        }
        String currentFieldName = null;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == null) {
                // No more tokens but the object was never closed; stop instead of spinning.
                throw new IOException("Invalid response format, unexpected end of input");
            }
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token.isValue()) {
                if (INDEX_CREATION_REQUEST_ID.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.indexCreationRequestId(parser.text());
                } else if (STATUS.match(currentFieldName, parser.getDeprecationHandler())) {
                    builder.status(parser.text());
                } else {
                    throw new IOException("Invalid response format, unknown field: " + currentFieldName);
                }
            } else if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) {
                // Only flat string fields are expected. Without this check a nested object's END_OBJECT
                // would end the loop early and silently drop every field that follows it.
                throw new IOException("Invalid response format, unexpected nested value for field: " + currentFieldName);
            }
        }
        return builder.build();
    }
}
Loading

0 comments on commit 5606ac5

Please sign in to comment.