Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Refactor kNN codec related classes (#582) #584

Merged
merged 1 commit into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
@AllArgsConstructor
@Log4j2
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
defaultMaxConnections,
defaultBeamWidth
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
);
return formatSupplier.apply(maxConnections, beamWidth);
}

private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,23 @@
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate;

/**
* Extends the Codec to support a new file format for KNN index
* based on the mappings.
*
*/
public final class KNN910Codec extends FilterCodec {

private static final String KNN910 = "KNN910Codec";
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_1_0;
private final KNNFormatFacade knnFormatFacade;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN910Codec() {
this(createKNN91DefaultDelegate());
this(VERSION.getDefaultCodecDelegate());
}

/**
Expand All @@ -36,8 +33,8 @@ public KNN910Codec() {
* @param delegate codec that will perform all operations this codec does not override
*/
public KNN910Codec(Codec delegate) {
super(KNN910, delegate);
knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,23 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate;

/**
* KNN codec that is based on Lucene92 codec
*/
@Log4j2
public final class KNN920Codec extends FilterCodec {

private static final String KNN920 = "KNN920Codec";

private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN920Codec() {
this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty()));
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -45,8 +39,8 @@ public KNN920Codec() {
*/
@Builder
public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN920, delegate);
knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,24 @@

package org.opensearch.knn.index.codec.KNN920Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH
)
);
return new Lucene92HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene92HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate;

public class KNN940Codec extends FilterCodec {
private static final String KNN940 = "KNN940Codec";
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_4_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene94 as the delegate
*/
public KNN940Codec() {
this(createKNN94DefaultDelegate(), new KNN940PerFieldKnnVectorsFormat(Optional.empty()));
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -40,8 +36,8 @@ public KNN940Codec() {
*/
@Builder
protected KNN940Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN940, delegate);
knnFormatFacade = KNNFormatFactory.createKNN940Format(delegate);
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,24 @@

package org.opensearch.knn.index.codec.KNN940Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN940PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH
)
);
return new Lucene94HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
public class KNN940PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene94HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Loading