Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move model metadata retrieval out of mapper build method #111

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

package org.opensearch.knn.index;

import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -49,6 +51,8 @@
*/
public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private static Logger logger = LogManager.getLogger(KNNQueryBuilder.class);
private static ModelDao modelDao;

public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static int K_MAX = 10000;
Expand Down Expand Up @@ -92,6 +96,10 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
this.k = k;
}

public static void initialize(ModelDao modelDao) {
KNNQueryBuilder.modelDao = modelDao;
}

private static float[] ObjectsToFloats(List<Object> objs) {
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
Expand Down Expand Up @@ -212,6 +220,22 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

int dimension = ((KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType).getDimension();

// If the dimension is not set, then the only valid route forward is if the field uses a model
if (dimension == -1) {
String modelId = ((KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType).getModelId();

if (modelId == null) {
throw new IllegalArgumentException("Field '" + this.fieldName + "' does not have dimension set.");
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);

if (modelMetadata == null) {
throw new IllegalArgumentException("Model ID \"" + modelId + "\" does not exist.");
}
dimension = modelMetadata.getDimension();
}

if (dimension != vector.length) {
throw new IllegalArgumentException("Query vector has invalid dimension: " + vector.length +
". Dimension should be: " + dimension);
Expand Down
58 changes: 41 additions & 17 deletions src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,21 @@ public KNNVectorFieldMapper build(BuilderContext context) {

String modelIdAsString = this.modelId.get();
if (modelIdAsString != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelIdAsString);

if (modelMetadata == null) {
throw new IllegalArgumentException("Model \"" + modelId + "\" does not exist");
}
// Because model information is stored in cluster metadata, we are unable to get it here. This is
// because to get the cluster metadata, you need access to the cluster state. Because this code is
// sometimes used to initialize the cluster state/update cluster state, we cannot get the state here
// safely. So, we are unable to validate the model. The model gets validated during ingestion.

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), modelMetadata.getDimension()),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
stored.get(),
hasDocValues.get(),
modelDao,
modelIdAsString,
modelMetadata);
modelIdAsString);
}

// Build legacy
Expand Down Expand Up @@ -314,10 +312,16 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
public static class KNNVectorFieldType extends MappedFieldType {

int dimension;
String modelId;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
}

@Override
Expand Down Expand Up @@ -345,6 +349,10 @@ public int getDimension() {
return dimension;
}

public String getModelId() {
return modelId;
}

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
Expand Down Expand Up @@ -385,6 +393,11 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension());
}

protected void parseCreateField(ParseContext context, int dimension) throws IOException {

if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable " +
"update knn.plugin.enabled setting to true");
Expand Down Expand Up @@ -431,9 +444,9 @@ protected void parseCreateField(ParseContext context) throws IOException {
context.parser().nextToken();
}

if (fieldType().dimension != vector.size()) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d",
fieldType().dimension, vector.size());
if (dimension != vector.size()) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension,
vector.size());
throw new IllegalArgumentException(errorMessage);
}

Expand Down Expand Up @@ -606,21 +619,32 @@ protected static class ModelFieldMapper extends KNNVectorFieldMapper {

private ModelFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields,
CopyTo copyTo, Explicit<Boolean> ignoreMalformed, boolean stored,
boolean hasDocValues, ModelDao modelDao, String modelId, ModelMetadata modelMetadata) {
boolean hasDocValues, ModelDao modelDao, String modelId) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

this.modelId = modelId;
this.modelDao = modelDao;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);

this.fieldType.putAttribute(MODEL_ID, modelId);
this.fieldType.freeze();
}

this.fieldType.putAttribute(DIMENSION, String.valueOf(modelMetadata.getDimension()));
this.fieldType.putAttribute(SPACE_TYPE, modelMetadata.getSpaceType().getValue());
this.fieldType.putAttribute(KNN_ENGINE, modelMetadata.getKnnEngine().getName());
@Override
protected void parseCreateField(ParseContext context) throws IOException {
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
throw new IllegalStateException("Model \"" + modelId + "\" from " +
context.mapperService().index().getName() + "'s mapping does not exist. Because the " +
"\"" + MODEL_ID + "\" parameter is not updateable, this index will need to " +
"be recreated with a valid model.");
}

this.fieldType.freeze();
parseCreateField(context, modelMetadata.getDimension());
}
}
}
29 changes: 27 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.IOException;
Expand All @@ -60,6 +62,7 @@
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.plugin.stats.KNNCounter.GRAPH_QUERY_ERRORS;

Expand All @@ -68,6 +71,8 @@
*/
public class KNNWeight extends Weight {
private static Logger logger = LogManager.getLogger(KNNWeight.class);
private static ModelDao modelDao;

private final KNNQuery knnQuery;
private final float boost;

Expand All @@ -80,6 +85,10 @@ public KNNWeight(KNNQuery query, float boost) {
this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
}

public static void initialize(ModelDao modelDao) {
KNNWeight.modelDao = modelDao;
}

@Override
public Explanation explain(LeafReaderContext context, int doc) {
return Explanation.match(1.0f, "No Explanation");
Expand All @@ -102,8 +111,24 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

KNNEngine knnEngine = KNNEngine.getEngine(fieldInfo.getAttribute(KNN_ENGINE));
SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
knnEngine = KNNEngine.getEngine(fieldInfo.getAttribute(KNN_ENGINE));
spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
}

/*
* In case of compound file, extension would be <engine-extension> + c otherwise <engine-extension>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;

Expand Down Expand Up @@ -91,9 +90,6 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
Expand All @@ -104,44 +100,42 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
}

// Create library index either from model or from scratch
String engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
String indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
String tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
String engineFileName;
String indexPath;
String tmpEngineFileName;

if (field.attributes().containsKey(MODEL_ID)) {

String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no model with id \"" + modelId + "\"");
}
KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

if (model.getModelMetadata().getKnnEngine() != knnEngine) {
throw new RuntimeException("Model Engine \"" + model.getModelMetadata().getKnnEngine().getName()
+ "\" cannot be different than index engine \"" + knnEngine.getName() + "\"");
}

String spaceName = field.getAttribute(KNNConstants.SPACE_TYPE);
if (spaceName == null) {
throw new RuntimeException("Space Type cannot be null");
}
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

SpaceType spaceType = SpaceType.getSpace(spaceName);
if (model.getModelMetadata().getSpaceType() != spaceType) {
throw new RuntimeException("Model Space Type \"" + model.getModelMetadata().getSpaceType().getValue()
+ "\" cannot be different than index Space Type \"" + spaceType.getValue() + "\"");
}

int dimension = Integer.parseInt(field.attributes().getOrDefault(DIMENSION, "-1"));
if (model.getModelMetadata().getDimension() != dimension) {
throw new RuntimeException("Model dimension \"" + model.getModelMetadata().getDimension()
+ "\" cannot be different than index dimension \"" + dimension + "\"");
if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {

// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorFieldMapper;

import org.opensearch.knn.index.KNNWeight;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelDao;
Expand Down Expand Up @@ -170,6 +171,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
knnStats = new KNNStats(KNNStatsConfig.KNN_STATS);
return ImmutableList.of(knnStats);
}
Expand Down
Loading