Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setting locale to Locale.ROOT while using String.format to ensure that UTs doesn't fail when validating the exception messages #1812

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -198,40 +199,46 @@ public KNNQueryBuilder build() {

private void validate() {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME));
}

if (vector == null) {
throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME));
} else if (vector.length == 0) {
throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME));
}

if (k == null && minScore == null && maxDistance == null) {
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME)
);
}

if ((k != null && maxDistance != null) || (maxDistance != null && minScore != null) || (k != null && minScore != null)) {
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME)
);
}

if (k != null) {
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires k to be in the range (0, %d]", NAME, K_MAX)
);
}
}

if (minScore != null) {
if (minScore <= 0) {
throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires minScore to be greater than 0", NAME));
}
}

if (methodParameters != null) {
ValidationException validationException = validateMethodParameters(methodParameters);
if (validationException != null) {
throw new IllegalArgumentException(
String.format("[%s] errors in method parameter [%s]", NAME, validationException.getMessage())
String.format(Locale.ROOT, "[%s] errors in method parameter [%s]", NAME, validationException.getMessage())
);
}
}
Expand All @@ -257,19 +264,19 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
@Deprecated
public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME));
}
if (vector == null) {
throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME));
}
if (vector.length == 0) {
throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME));
}
if (k <= 0) {
throw new IllegalArgumentException(String.format("[%s] requires k > 0", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k > 0", NAME));
}
if (k > K_MAX) {
throw new IllegalArgumentException(String.format("[%s] requires k <= %d", NAME, K_MAX));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k <= %d", NAME, K_MAX));
}

this.fieldName = fieldName;
Expand All @@ -287,12 +294,16 @@ public static void initialize(ModelDao modelDao) {

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME)
);
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME)
);
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
Expand Down Expand Up @@ -481,7 +492,7 @@ protected Query doToQuery(QueryShardContext context) {
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName));
}

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType;
Expand Down Expand Up @@ -523,6 +534,7 @@ protected Query doToQuery(QueryShardContext context) {
if (validationException != null) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Parameters not valid for [%s]:[%s]:[%s] combination: [%s]",
knnEngine,
method,
Expand Down Expand Up @@ -575,7 +587,7 @@ protected Query doToQuery(QueryShardContext context) {
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine));
}

String indexName = context.index().getName();
Expand All @@ -597,7 +609,9 @@ protected Query doToQuery(QueryShardContext context) {
}
if (radius != null) {
if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) {
throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine));
throw new UnsupportedOperationException(
String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine)
);
}
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
Expand All @@ -613,19 +627,19 @@ protected Query doToQuery(QueryShardContext context) {
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException(String.format("[%s] requires k or distance or score to be set", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME));
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(String.format("Field '%s' does not have model.", this.fieldName));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName));
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
Expand All @@ -647,7 +661,7 @@ private VectorQueryType getVectorQueryType(int k, Float maxDistance, Float minSc
if (k != 0) {
return VectorQueryType.K;
}
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME));
}

/**
Expand Down
Loading