Skip to content

Commit

Permalink
Move model metadata retrieval out of mapper build method (#111)
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Oct 5, 2021
1 parent 3acec17 commit 6d7eb0a
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 54 deletions.
24 changes: 24 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

package org.opensearch.knn.index;

import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -49,6 +51,8 @@
*/
public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private static Logger logger = LogManager.getLogger(KNNQueryBuilder.class);
private static ModelDao modelDao;

public static final ParseField VECTOR_FIELD = new ParseField("vector");
public static final ParseField K_FIELD = new ParseField("k");
public static int K_MAX = 10000;
Expand Down Expand Up @@ -92,6 +96,10 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
this.k = k;
}

public static void initialize(ModelDao modelDao) {
KNNQueryBuilder.modelDao = modelDao;
}

private static float[] ObjectsToFloats(List<Object> objs) {
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
Expand Down Expand Up @@ -212,6 +220,22 @@ protected Query doToQuery(QueryShardContext context) throws IOException {

int dimension = ((KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType).getDimension();

// If the dimension is not set, then the only valid route forward is if the field uses a model
if (dimension == -1) {
String modelId = ((KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType).getModelId();

if (modelId == null) {
throw new IllegalArgumentException("Field '" + this.fieldName + "' does not have dimension set.");
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);

if (modelMetadata == null) {
throw new IllegalArgumentException("Model ID \"" + modelId + "\" does not exist.");
}
dimension = modelMetadata.getDimension();
}

if (dimension != vector.length) {
throw new IllegalArgumentException("Query vector has invalid dimension: " + vector.length +
". Dimension should be: " + dimension);
Expand Down
58 changes: 41 additions & 17 deletions src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,21 @@ public KNNVectorFieldMapper build(BuilderContext context) {

String modelIdAsString = this.modelId.get();
if (modelIdAsString != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelIdAsString);

if (modelMetadata == null) {
throw new IllegalArgumentException("Model \"" + modelId + "\" does not exist");
}
// Because model information is stored in cluster metadata, we are unable to get it here. This is
// because to get the cluster metadata, you need access to the cluster state. Because this code is
// sometimes used to initialize the cluster state/update cluster state, we cannot get the state here
// safely. So, we are unable to validate the model. The model gets validated during ingestion.

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), meta.getValue(), modelMetadata.getDimension()),
new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString),
multiFieldsBuilder.build(this, context),
copyTo.build(),
ignoreMalformed(context),
stored.get(),
hasDocValues.get(),
modelDao,
modelIdAsString,
modelMetadata);
modelIdAsString);
}

// Build legacy
Expand Down Expand Up @@ -314,10 +312,16 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
public static class KNNVectorFieldType extends MappedFieldType {

int dimension;
String modelId;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension) {
this(name, meta, dimension, null);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
}

@Override
Expand Down Expand Up @@ -345,6 +349,10 @@ public int getDimension() {
return dimension;
}

public String getModelId() {
return modelId;
}

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
failIfNoDocValues();
Expand Down Expand Up @@ -385,6 +393,11 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension());
}

protected void parseCreateField(ParseContext context, int dimension) throws IOException {

if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable " +
"update knn.plugin.enabled setting to true");
Expand Down Expand Up @@ -431,9 +444,9 @@ protected void parseCreateField(ParseContext context) throws IOException {
context.parser().nextToken();
}

if (fieldType().dimension != vector.size()) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d",
fieldType().dimension, vector.size());
if (dimension != vector.size()) {
String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension,
vector.size());
throw new IllegalArgumentException(errorMessage);
}

Expand Down Expand Up @@ -606,21 +619,32 @@ protected static class ModelFieldMapper extends KNNVectorFieldMapper {

private ModelFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields,
CopyTo copyTo, Explicit<Boolean> ignoreMalformed, boolean stored,
boolean hasDocValues, ModelDao modelDao, String modelId, ModelMetadata modelMetadata) {
boolean hasDocValues, ModelDao modelDao, String modelId) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues);

this.modelId = modelId;
this.modelDao = modelDao;

this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);

this.fieldType.putAttribute(MODEL_ID, modelId);
this.fieldType.freeze();
}

this.fieldType.putAttribute(DIMENSION, String.valueOf(modelMetadata.getDimension()));
this.fieldType.putAttribute(SPACE_TYPE, modelMetadata.getSpaceType().getValue());
this.fieldType.putAttribute(KNN_ENGINE, modelMetadata.getKnnEngine().getName());
@Override
protected void parseCreateField(ParseContext context) throws IOException {
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
throw new IllegalStateException("Model \"" + modelId + "\" from " +
context.mapperService().index().getName() + "'s mapping does not exist. Because the " +
"\"" + MODEL_ID + "\" parameter is not updateable, this index will need to " +
"be recreated with a valid model.");
}

this.fieldType.freeze();
parseCreateField(context, modelMetadata.getDimension());
}
}
}
29 changes: 27 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.IOException;
Expand All @@ -60,6 +62,7 @@
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.plugin.stats.KNNCounter.GRAPH_QUERY_ERRORS;

Expand All @@ -68,6 +71,8 @@
*/
public class KNNWeight extends Weight {
private static Logger logger = LogManager.getLogger(KNNWeight.class);
private static ModelDao modelDao;

private final KNNQuery knnQuery;
private final float boost;

Expand All @@ -80,6 +85,10 @@ public KNNWeight(KNNQuery query, float boost) {
this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
}

public static void initialize(ModelDao modelDao) {
KNNWeight.modelDao = modelDao;
}

@Override
public Explanation explain(LeafReaderContext context, int doc) {
return Explanation.match(1.0f, "No Explanation");
Expand All @@ -102,8 +111,24 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

KNNEngine knnEngine = KNNEngine.getEngine(fieldInfo.getAttribute(KNN_ENGINE));
SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
knnEngine = KNNEngine.getEngine(fieldInfo.getAttribute(KNN_ENGINE));
spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
}

/*
* In case of compound file, extension would be <engine-extension> + c otherwise <engine-extension>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.index.codec.KNNCodecUtil.buildEngineFileName;

Expand Down Expand Up @@ -91,9 +90,6 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
if (field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
Expand All @@ -104,44 +100,42 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
}

// Create library index either from model or from scratch
String engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
String indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
String tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;
String engineFileName;
String indexPath;
String tmpEngineFileName;

if (field.attributes().containsKey(MODEL_ID)) {

String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no model with id \"" + modelId + "\"");
}
KNNEngine knnEngine = model.getModelMetadata().getKnnEngine();

if (model.getModelMetadata().getKnnEngine() != knnEngine) {
throw new RuntimeException("Model Engine \"" + model.getModelMetadata().getKnnEngine().getName()
+ "\" cannot be different than index engine \"" + knnEngine.getName() + "\"");
}

String spaceName = field.getAttribute(KNNConstants.SPACE_TYPE);
if (spaceName == null) {
throw new RuntimeException("Space Type cannot be null");
}
engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

SpaceType spaceType = SpaceType.getSpace(spaceName);
if (model.getModelMetadata().getSpaceType() != spaceType) {
throw new RuntimeException("Model Space Type \"" + model.getModelMetadata().getSpaceType().getValue()
+ "\" cannot be different than index Space Type \"" + spaceType.getValue() + "\"");
}

int dimension = Integer.parseInt(field.attributes().getOrDefault(DIMENSION, "-1"));
if (model.getModelMetadata().getDimension() != dimension) {
throw new RuntimeException("Model dimension \"" + model.getModelMetadata().getDimension()
+ "\" cannot be different than index dimension \"" + dimension + "\"");
if (model.getModelBlob() == null) {
throw new RuntimeException("There is no trained model with id \"" + modelId + "\"");
}

createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, tempIndexPath);
} else {

// Get engine to be used for indexing
String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
KNNEngine knnEngine = KNNEngine.getEngine(engineName);

engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(),
field.name, knnEngine.getExtension());
indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName).toString();
tmpEngineFileName = engineFileName + TEMP_SUFFIX;
String tempIndexPath = indexPath + TEMP_SUFFIX;

createKNNIndexFromScratch(field, pair, knnEngine, tempIndexPath);
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorFieldMapper;

import org.opensearch.knn.index.KNNWeight;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelDao;
Expand Down Expand Up @@ -173,6 +174,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
knnStats = new KNNStats(KNNStatsConfig.KNN_STATS);
return ImmutableList.of(knnStats);
}
Expand Down
Loading

0 comments on commit 6d7eb0a

Please sign in to comment.