From 7def84e40f85380956dd1362ae76e476a77f5953 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 19 Oct 2022 09:34:13 -0700 Subject: [PATCH] Refactor kNN codec related classes (#582) * Refactor codec related classes, create KNNCodecVersion abstraction Signed-off-by: Martin Gaievski (cherry picked from commit 3d0a9d7ed6b1c609a9ab1ed2eff897dbe05fca63) --- .../codec/BasePerFieldKnnVectorsFormat.java | 79 ++++++++++++++++ .../index/codec/KNN910Codec/KNN910Codec.java | 13 +-- .../index/codec/KNN920Codec/KNN920Codec.java | 16 +--- .../KNN920PerFieldKnnVectorsFormat.java | 70 ++------------ .../index/codec/KNN940Codec/KNN940Codec.java | 14 +-- .../KNN940PerFieldKnnVectorsFormat.java | 70 ++------------ .../knn/index/codec/KNNCodecFactory.java | 51 ----------- .../knn/index/codec/KNNCodecService.java | 7 +- .../knn/index/codec/KNNCodecVersion.java | 91 +++++++++++++++++++ .../knn/index/codec/KNNFormatFactory.java | 53 ----------- .../codec/KNN920Codec/KNN920CodecTests.java | 6 +- .../codec/KNN940Codec/KNN940CodecTests.java | 8 +- .../knn/index/codec/KNNCodecFactoryTests.java | 43 ++++----- .../knn/index/codec/KNNCodecTestCase.java | 3 +- .../index/codec/KNNFormatFactoryTests.java | 51 ----------- 15 files changed, 237 insertions(+), 338 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java new file mode 100644 index 000000000..d10ad9821 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; + +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version + */ +@AllArgsConstructor +@Log4j2 +public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { + + private final Optional mapperService; + private final int defaultMaxConnections; + private final int defaultBeamWidth; + private final Supplier defaultFormatSupplier; + private final BiFunction formatSupplier; + + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { + if (isKnnVectorFieldType(field) == false) { + log.debug( + "Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + field, + defaultMaxConnections, + defaultBeamWidth + ); + return defaultFormatSupplier.get(); + } + var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( + () -> new IllegalStateException( + String.format("Cannot read field type for field [%s] because mapper service is not available", field) + ) + ).fieldType(field); + var params = type.getKnnMethodContext().getMethodComponent().getParameters(); + int maxConnections = getMaxConnections(params); + int beamWidth = getBeamWidth(params); + log.debug( + "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + field, + maxConnections, + beamWidth + ); + return formatSupplier.apply(maxConnections, beamWidth); + } + + private boolean isKnnVectorFieldType(final String field) { + return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType; + } + + private int getMaxConnections(final Map params) { + if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_M); + } + return defaultMaxConnections; + } + + private int getBeamWidth(final Map params) { + if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); + } + return defaultBeamWidth; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java index 0acaccfbf..77783dc29 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN910Codec/KNN910Codec.java @@ -8,10 +8,8 @@ import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; +import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate; /** * Extends the Codec to support a new file format for KNN index @@ -19,15 +17,14 @@ * */ public final class KNN910Codec extends FilterCodec { - - private static final String KNN910 = "KNN910Codec"; + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_1_0; private final KNNFormatFacade knnFormatFacade; /** * No arg constructor that uses Lucene91 as the delegate */ public KNN910Codec() { - this(createKNN91DefaultDelegate()); + this(VERSION.getDefaultCodecDelegate()); } /** @@ -36,8 +33,8 @@ public KNN910Codec() { * @param delegate codec that will perform all operations this codec does not override */ public KNN910Codec(Codec delegate) { - super(KNN910, delegate); - knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java index 26abcea60..b79c1b4f2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java @@ -12,21 +12,15 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import java.util.Optional; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate; /** * KNN codec that is based on Lucene92 codec */ @Log4j2 public final class KNN920Codec extends FilterCodec { - - private static final String KNN920 = "KNN920Codec"; - + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; @@ -34,7 +28,7 @@ public final class KNN920Codec extends FilterCodec { * No arg constructor that uses Lucene91 as the delegate */ public KNN920Codec() { - this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty())); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); } /** @@ -45,8 +39,8 @@ public KNN920Codec() { */ @Builder public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { - super(KNN920, delegate); - knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java index 0286e829a..ae1ef206c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java @@ -5,74 +5,24 @@ package org.opensearch.knn.index.codec.KNN920Codec; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; -import java.util.Map; import java.util.Optional; /** * Class provides per field format implementation for Lucene Knn vector type */ -@AllArgsConstructor -@Log4j2 -public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { - - private final Optional mapperService; - - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { - if (isNotKnnVectorFieldType(field)) { - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH - ) - ); - return new Lucene92HnswVectorsFormat(); - } - var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( - () -> new IllegalStateException( - String.format("Cannot read field type for field [%s] because mapper service is not available", field) - ) - ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponent().getParameters(); - int maxConnections = getMaxConnections(params); - int beamWidth = getBeamWidth(params); - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - maxConnections, - beamWidth - ) +public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + + public KNN920PerFieldKnnVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + () -> new Lucene92HnswVectorsFormat(), + (maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth) ); - return new Lucene92HnswVectorsFormat(maxConnections, beamWidth); - } - - private boolean isNotKnnVectorFieldType(final String field) { - return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType); - } - - private int getMaxConnections(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_M); - } - return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN; - } - - private int getBeamWidth(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); - } - return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java index 43a348cee..a056581d6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940Codec.java @@ -12,15 +12,11 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; -import org.opensearch.knn.index.codec.KNNFormatFactory; - -import java.util.Optional; - -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate; public class KNN940Codec extends FilterCodec { - private static final String KNN940 = "KNN940Codec"; + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_4_0; private final KNNFormatFacade knnFormatFacade; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; @@ -28,7 +24,7 @@ public class KNN940Codec extends FilterCodec { * No arg constructor that uses Lucene94 as the delegate */ public KNN940Codec() { - this(createKNN94DefaultDelegate(), new KNN940PerFieldKnnVectorsFormat(Optional.empty())); + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); } /** @@ -40,8 +36,8 @@ public KNN940Codec() { */ @Builder protected KNN940Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { - super(KNN940, delegate); - knnFormatFacade = KNNFormatFactory.createKNN940Format(delegate); + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java index 5b717106f..d80c757c9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java @@ -5,74 +5,24 @@ package org.opensearch.knn.index.codec.KNN940Codec; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; -import java.util.Map; import java.util.Optional; /** * Class provides per field format implementation for Lucene Knn vector type */ -@AllArgsConstructor -@Log4j2 -public class KNN940PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat { - - private final Optional mapperService; - - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { - if (isNotKnnVectorFieldType(field)) { - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH - ) - ); - return new Lucene94HnswVectorsFormat(); - } - var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( - () -> new IllegalStateException( - String.format("Cannot read field type for field [%s] because mapper service is not available", field) - ) - ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponent().getParameters(); - int maxConnections = getMaxConnections(params); - int beamWidth = getBeamWidth(params); - log.debug( - String.format( - "Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"", - field, - maxConnections, - beamWidth - ) +public class KNN940PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + + public KNN940PerFieldKnnVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + () -> new Lucene94HnswVectorsFormat(), + (maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth) ); - return new Lucene94HnswVectorsFormat(maxConnections, beamWidth); - } - - private boolean isNotKnnVectorFieldType(final String field) { - return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType); - } - - private int getMaxConnections(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_M); - } - return Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN; - } - - private int getBeamWidth(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); - } - return Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java deleted file mode 100644 index e53e1dd2a..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.knn.index.codec; - -import lombok.AllArgsConstructor; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; -import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; -import org.apache.lucene.codecs.lucene94.Lucene94Codec; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940PerFieldKnnVectorsFormat; - -import java.util.Optional; - -/** - * Factory abstraction for KNN codec - */ -@AllArgsConstructor -public class KNNCodecFactory { - - private final MapperService mapperService; - - public Codec createKNNCodec(final Codec userCodec) { - var codec = KNN940Codec.builder() - .delegate(userCodec) - .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService))) - .build(); - return codec; - } - - /** - * Factory abstraction for codec delegate - */ - public static class CodecDelegateFactory { - - public static Codec createKNN91DefaultDelegate() { - return new Lucene91Codec(); - } - - public static Codec createKNN92DefaultDelegate() { - return new Lucene92Codec(); - } - - public static Codec createKNN94DefaultDelegate() { - return new Lucene94Codec(); - } - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 8ce5e6928..d56e09a3f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -8,17 +8,18 @@ import org.opensearch.index.codec.CodecServiceConfig; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.index.mapper.MapperService; /** * KNNCodecService to inject the right KNNCodec version */ public class KNNCodecService extends CodecService { - private final KNNCodecFactory knnCodecFactory; + private final MapperService mapperService; public KNNCodecService(CodecServiceConfig codecServiceConfig) { super(codecServiceConfig.getMapperService(), codecServiceConfig.getLogger()); - knnCodecFactory = new KNNCodecFactory(codecServiceConfig.getMapperService()); + mapperService = codecServiceConfig.getMapperService(); } /** @@ -29,6 +30,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return knnCodecFactory.createKNNCodec(super.codec(name)); + return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java new file mode 100644 index 000000000..adbbb01ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; +import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene94.Lucene94Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; +import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; +import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; +import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; +import org.opensearch.knn.index.codec.KNN940Codec.KNN940PerFieldKnnVectorsFormat; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * Abstraction for k-NN codec version, aggregates all details for specific version such as codec name, corresponding + * Lucene codec, formats including one for k-NN vector etc. + */ +@AllArgsConstructor +@Getter +public enum KNNCodecVersion { + + V_9_1_0( + "KNN910Codec", + new Lucene91Codec(), + null, + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> new KNN910Codec(userCodec), + KNN910Codec::new + ), + + V_9_2_0( + "KNN920Codec", + new Lucene92Codec(), + new KNN920PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN920Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService))) + .build(), + KNN920Codec::new + ), + + V_9_4_0( + "KNN940Codec", + new Lucene94Codec(), + new KNN940PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN940Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService))) + .build(), + KNN940Codec::new + ); + + private static final KNNCodecVersion CURRENT = V_9_4_0; + + private final String codecName; + private final Codec defaultCodecDelegate; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final Function knnFormatFacadeSupplier; + private final BiFunction knnCodecSupplier; + private final Supplier defaultKnnCodecSupplier; + + public static final KNNCodecVersion current() { + return CURRENT; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java b/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java deleted file mode 100644 index ee17189e3..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/KNNFormatFactory.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.knn.index.codec; - -import org.apache.lucene.codecs.Codec; -import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; -import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; - -/** - * Factory abstraction for KNN format facade creation - */ -public class KNNFormatFactory { - - /** - * Return facade class that abstracts format specific to KNN910 codec - * @param delegate delegate codec that is wrapped by KNN codec - * @return - */ - public static KNNFormatFacade createKNN910Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } - - /** - * Return facade class that abstracts format specific to KNN920 codec - * @param delegate delegate codec that is wrapped by KNN codec - * @return - */ - public static KNNFormatFacade createKNN920Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } - - /** - * Return facade class that abstracts format specific to KNN940 codec - * @param delegate delegate codec that is wrapped by KNN codec - */ - public static KNNFormatFacade createKNN940Format(final Codec delegate) { - final KNNFormatFacade knnFormatFacade = new KNNFormatFacade( - new KNN80DocValuesFormat(delegate.docValuesFormat()), - new KNN80CompoundFormat(delegate.compoundFormat()) - ); - return knnFormatFacade; - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java index 06cc7fad8..8cdfc2d69 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java @@ -9,15 +9,15 @@ import java.io.IOException; import java.util.concurrent.ExecutionException; -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_2_0; public class KNN920CodecTests extends KNNCodecTestCase { public void testMultiFieldsKnnIndex() throws Exception { - testMultiFieldsKnnIndex(KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).build()); + testMultiFieldsKnnIndex(KNN920Codec.builder().delegate(V_9_2_0.getDefaultCodecDelegate()).build()); } public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { - testBuildFromModelTemplate((KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).build())); + testBuildFromModelTemplate((KNN920Codec.builder().delegate(V_9_2_0.getDefaultCodecDelegate()).build())); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java index 578f88f9f..1101d93bb 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940CodecTests.java @@ -14,16 +14,16 @@ import java.util.concurrent.ExecutionException; import java.util.function.Function; -import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_4_0; public class KNN940CodecTests extends KNNCodecTestCase { public void testMultiFieldsKnnIndex() throws Exception { - testMultiFieldsKnnIndex(KNN940Codec.builder().delegate(createKNN94DefaultDelegate()).build()); + testMultiFieldsKnnIndex(KNN940Codec.builder().delegate(V_9_4_0.getDefaultCodecDelegate()).build()); } public void testBuildFromModelTemplate() throws InterruptedException, ExecutionException, IOException { - testBuildFromModelTemplate((KNN940Codec.builder().delegate(createKNN94DefaultDelegate()).build())); + testBuildFromModelTemplate((KNN940Codec.builder().delegate(V_9_4_0.getDefaultCodecDelegate()).build())); } public void testKnnVectorIndex() throws Exception { @@ -31,7 +31,7 @@ public void testKnnVectorIndex() throws Exception { mapperService) -> new KNN940PerFieldKnnVectorsFormat(Optional.of(mapperService)); Function knnCodecProvider = (knnVectorFormat) -> KNN940Codec.builder() - .delegate(createKNN94DefaultDelegate()) + .delegate(V_9_4_0.getDefaultCodecDelegate()) .knnVectorsFormat(knnVectorFormat) .build(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java index 6e1c96bcb..d918f5439 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecFactoryTests.java @@ -5,42 +5,39 @@ package org.opensearch.knn.index.codec; +import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; -import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.codecs.lucene94.Lucene94Codec; -import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; -import static org.mockito.Mockito.mock; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_1_0; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_2_0; +import static org.opensearch.knn.index.codec.KNNCodecVersion.V_9_4_0; public class KNNCodecFactoryTests extends KNNTestCase { - public void testKNN91DefaultDelegate() { - Codec knn91DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); - assertNotNull(knn91DefaultDelegate); - assertTrue(knn91DefaultDelegate instanceof Lucene91Codec); + public void testKNN910Codec() { + assertDelegateForVersion(V_9_1_0, Lucene91Codec.class); + assertNull(V_9_1_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_1_0.getKnnFormatFacadeSupplier().apply(V_9_1_0.getDefaultCodecDelegate())); } - public void testKNN92DefaultDelegate() { - Codec knn92DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate(); - assertNotNull(knn92DefaultDelegate); - assertTrue(knn92DefaultDelegate instanceof Lucene92Codec); + public void testKNN920Codec() { + assertDelegateForVersion(V_9_2_0, Lucene92Codec.class); + assertNotNull(V_9_2_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_2_0.getKnnFormatFacadeSupplier().apply(V_9_2_0.getDefaultCodecDelegate())); } - public void testKNN94DefaultDelegate() { - Codec knn94DefaultDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate(); - assertNotNull(knn94DefaultDelegate); - assertTrue(knn94DefaultDelegate instanceof Lucene94Codec); + public void testKNN940Codec() { + assertDelegateForVersion(V_9_4_0, Lucene94Codec.class); + assertNotNull(V_9_4_0.getPerFieldKnnVectorsFormat()); + assertNotNull(V_9_4_0.getKnnFormatFacadeSupplier().apply(V_9_4_0.getDefaultCodecDelegate())); } - public void testKNNDefaultCodec() { - MapperService mapperService = mock(MapperService.class); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - Codec knnCodec = knnCodecFactory.createKNNCodec(KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate()); - assertNotNull(knnCodec); - assertTrue(knnCodec instanceof KNN940Codec); - assertEquals("KNN940Codec", knnCodec.getName()); + private void assertDelegateForVersion(final KNNCodecVersion codecVersion, final Class expectedCodecClass) { + final Codec defaultDelegate = codecVersion.getDefaultCodecDelegate(); + assertNotNull(defaultDelegate); + assertTrue(defaultDelegate.getClass().isAssignableFrom(expectedCodecClass)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 623f2dc74..43ae19320 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -19,7 +19,6 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.query.KNNQueryFactory; -import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.KNNSettings; @@ -79,7 +78,7 @@ */ public class KNNCodecTestCase extends KNNTestCase { - private static final KNN940Codec ACTUAL_CODEC = new KNN940Codec(); + private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); private static FieldType sampleFieldType; static { sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java deleted file mode 100644 index 1b076f10b..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/KNNFormatFactoryTests.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec; - -import org.apache.lucene.codecs.Codec; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.KNNTestCase; - -import static org.mockito.Mockito.mock; - -public class KNNFormatFactoryTests extends KNNTestCase { - - public void testKNN91Format() { - final Codec lucene91CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate(); - MapperService mapperService = mock(MapperService.class); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - final Codec knnCodec = knnCodecFactory.createKNNCodec(lucene91CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN910Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } - - public void testKNN92Format() { - MapperService mapperService = mock(MapperService.class); - final Codec lucene92CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate(); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - final Codec knnCodec = knnCodecFactory.createKNNCodec(lucene92CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN920Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } - - public void testKNN94Format() { - MapperService mapperService = mock(MapperService.class); - final Codec lucene94CodecDelegate = KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate(); - KNNCodecFactory knnCodecFactory = new KNNCodecFactory(mapperService); - final Codec knnCodec = knnCodecFactory.createKNNCodec(lucene94CodecDelegate); - KNNFormatFacade knnFormatFacade = KNNFormatFactory.createKNN940Format(knnCodec); - - assertNotNull(knnFormatFacade); - assertNotNull(knnFormatFacade.compoundFormat()); - assertNotNull(knnFormatFacade.docValuesFormat()); - } -}