Skip to content

Commit

Permalink
Refactor kNN codec related classes (#582)
Browse files Browse the repository at this point in the history
* Refactor codec related classes, create KNNCodecVersion abstraction

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Oct 19, 2022
1 parent 6d77882 commit 3d0a9d7
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 334 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
@AllArgsConstructor
@Log4j2
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
defaultMaxConnections,
defaultBeamWidth
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
);
return formatSupplier.apply(maxConnections, beamWidth);
}

private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,23 @@
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate;

/**
* Extends the Codec to support a new file format for KNN index
* based on the mappings.
*
*/
public final class KNN910Codec extends FilterCodec {

private static final String KNN910 = "KNN910Codec";
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_1_0;
private final KNNFormatFacade knnFormatFacade;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN910Codec() {
this(createKNN91DefaultDelegate());
this(VERSION.getDefaultCodecDelegate());
}

/**
Expand All @@ -36,8 +33,8 @@ public KNN910Codec() {
* @param delegate codec that will perform all operations this codec does not override
*/
public KNN910Codec(Codec delegate) {
super(KNN910, delegate);
knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,23 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate;

/**
* KNN codec that is based on Lucene92 codec
*/
@Log4j2
public final class KNN920Codec extends FilterCodec {

private static final String KNN920 = "KNN920Codec";

private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN920Codec() {
this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty()));
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -45,8 +39,8 @@ public KNN920Codec() {
*/
@Builder
public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN920, delegate);
knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,24 @@

package org.opensearch.knn.index.codec.KNN920Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH
)
);
return new Lucene92HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene92HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate;

public class KNN940Codec extends FilterCodec {
private static final String KNN940 = "KNN940Codec";
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_4_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene94 as the delegate
*/
public KNN940Codec() {
this(createKNN94DefaultDelegate(), new KNN940PerFieldKnnVectorsFormat(Optional.empty()));
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -40,8 +36,8 @@ public KNN940Codec() {
*/
@Builder
protected KNN940Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN940, delegate);
knnFormatFacade = KNNFormatFactory.createKNN940Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,70 +5,24 @@

package org.opensearch.knn.index.codec.KNN940Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN940PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH
);
return new Lucene94HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
public class KNN940PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene94HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Loading

0 comments on commit 3d0a9d7

Please sign in to comment.