Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor kNN codec related classes #582

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
@AllArgsConstructor
@Log4j2
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
String.format(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to use String.format here.

Suggested change
String.format(
log.debug("Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"", field, defaultMaxConnections, defaultBeamWidth);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
defaultMaxConnections,
defaultBeamWidth
)
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to use String.format

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
);
return formatSupplier.apply(maxConnections, beamWidth);
}

private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,23 @@
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN91DefaultDelegate;

/**
* Extends the Codec to support a new file format for KNN index
* based on the mappings.
*
*/
public final class KNN910Codec extends FilterCodec {

private static final String KNN910 = "KNN910Codec";
private static final KNNCodecVersion version = KNNCodecVersion.V_9_1_0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For static final variable, the naming convention should be VERSION I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree, good catch

private final KNNFormatFacade knnFormatFacade;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN910Codec() {
this(createKNN91DefaultDelegate());
this(version.getDefaultCodecDelegate());
}

/**
Expand All @@ -36,8 +33,8 @@ public KNN910Codec() {
* @param delegate codec that will perform all operations this codec does not override
*/
public KNN910Codec(Codec delegate) {
super(KNN910, delegate);
knnFormatFacade = KNNFormatFactory.createKNN910Format(delegate);
super(version.getCodecName(), delegate);
knnFormatFacade = version.getKnnFormatFacadeSupplier().apply(delegate);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,23 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate;

/**
* KNN codec that is based on Lucene92 codec
*/
@Log4j2
public final class KNN920Codec extends FilterCodec {

private static final String KNN920 = "KNN920Codec";

private static final KNNCodecVersion version = KNNCodecVersion.V_9_2_0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private static final KNNCodecVersion version = KNNCodecVersion.V_9_2_0;
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_2_0;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene91 as the delegate
*/
public KNN920Codec() {
this(createKNN92DefaultDelegate(), new KNN920PerFieldKnnVectorsFormat(Optional.empty()));
this(version.getDefaultCodecDelegate(), version.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -45,8 +39,8 @@ public KNN920Codec() {
*/
@Builder
public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN920, delegate);
knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate);
super(version.getCodecName(), delegate);
knnFormatFacade = version.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,24 @@

package org.opensearch.knn.index.codec.KNN920Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN920PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with default params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH
)
);
return new Lucene92HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
String.format(
"Initialize KNN vector format for field [%s] with params [max_connections] = \"%d\" and [beam_width] = \"%d\"",
field,
maxConnections,
beamWidth
)
public class KNN920PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene92HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,19 @@
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

import java.util.Optional;

import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN94DefaultDelegate;

public class KNN940Codec extends FilterCodec {
private static final String KNN940 = "KNN940Codec";
private static final KNNCodecVersion version = KNNCodecVersion.V_9_4_0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private static final KNNCodecVersion version = KNNCodecVersion.V_9_4_0;
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_4_0;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene94 as the delegate
*/
public KNN940Codec() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FAR: why not pass KNNCodecVersion as parameter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to have default constructor as per SPI contract, all codec classes are defined in corresponding registry file. In such case approach with parameter will not work

this(createKNN94DefaultDelegate(), new KNN940PerFieldKnnVectorsFormat(Optional.empty()));
this(version.getDefaultCodecDelegate(), version.getPerFieldKnnVectorsFormat());
}

/**
Expand All @@ -40,8 +36,8 @@ public KNN940Codec() {
*/
@Builder
protected KNN940Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN940, delegate);
knnFormatFacade = KNNFormatFactory.createKNN940Format(delegate);
super(version.getCodecName(), delegate);
knnFormatFacade = version.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,70 +5,24 @@

package org.opensearch.knn.index.codec.KNN940Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;

import java.util.Map;
import java.util.Optional;

/**
* Class provides per field format implementation for Lucene Knn vector type
*/
@AllArgsConstructor
@Log4j2
public class KNN940PerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isNotKnnVectorFieldType(field)) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH
);
return new Lucene94HnswVectorsFormat();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponent().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
public class KNN940PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {

public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
);
return new Lucene94HnswVectorsFormat(maxConnections, beamWidth);
}

private boolean isNotKnnVectorFieldType(final String field) {
return !mapperService.isPresent() || !(mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType);
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
}
}
Loading