Skip to content

Commit

Permalink
Setting locale to Locale.ROOT while using String.format to ensure tha…
Browse files Browse the repository at this point in the history
…t UTs doesn't fail when validating the exception messages (#1812) (#1813)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Jul 10, 2024
1 parent c47fba1 commit 67b4fa9
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -199,41 +200,46 @@ public KNNQueryBuilder build() {

private void validate() {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME));
}

if (vector == null) {
throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME));
} else if (vector.length == 0) {
throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME));
}

if (k == null && minScore == null && maxDistance == null) {
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME)
);
}

if ((k != null && maxDistance != null) || (maxDistance != null && minScore != null) || (k != null && minScore != null)) {
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME)
);
}

if (k != null) {
if (k <= 0 || k > K_MAX) {
final String errorMessage = "[" + NAME + "] requires k to be in the range (0, " + K_MAX + "]";
throw new IllegalArgumentException(errorMessage);
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] requires k to be in the range (0, %d]", NAME, K_MAX)
);
}
}

if (minScore != null) {
if (minScore <= 0) {
throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires minScore to be greater than 0", NAME));
}
}

if (methodParameters != null) {
ValidationException validationException = validateMethodParameters(methodParameters);
if (validationException != null) {
throw new IllegalArgumentException(
String.format("[%s] errors in method parameter [%s]", NAME, validationException.getMessage())
String.format(Locale.ROOT, "[%s] errors in method parameter [%s]", NAME, validationException.getMessage())
);
}
}
Expand All @@ -259,19 +265,19 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) {
@Deprecated
public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME));
}
if (vector == null) {
throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME));
}
if (vector.length == 0) {
throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME));
}
if (k <= 0) {
throw new IllegalArgumentException(String.format("[%s] requires k > 0", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k > 0", NAME));
}
if (k > K_MAX) {
throw new IllegalArgumentException(String.format("[%s] requires k <= %d", NAME, K_MAX));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k <= %d", NAME, K_MAX));
}

this.fieldName = fieldName;
Expand All @@ -289,12 +295,16 @@ public static void initialize(ModelDao modelDao) {

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME)
);
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME)
);
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
Expand Down Expand Up @@ -511,7 +521,7 @@ protected Query doToQuery(QueryShardContext context) {
}

if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName));
}

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType;
Expand Down Expand Up @@ -553,6 +563,7 @@ protected Query doToQuery(QueryShardContext context) {
if (validationException != null) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Parameters not valid for [%s]:[%s]:[%s] combination: [%s]",
knnEngine,
method,
Expand Down Expand Up @@ -605,7 +616,7 @@ protected Query doToQuery(QueryShardContext context) {
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine));
}

String indexName = context.index().getName();
Expand All @@ -627,7 +638,9 @@ protected Query doToQuery(QueryShardContext context) {
}
if (radius != null) {
if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) {
throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine));
throw new UnsupportedOperationException(
String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine)
);
}
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
Expand All @@ -643,19 +656,19 @@ protected Query doToQuery(QueryShardContext context) {
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException(String.format("[%s] requires k or distance or score to be set", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME));
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(String.format("Field '%s' does not have model.", this.fieldName));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName));
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
Expand All @@ -677,7 +690,7 @@ private VectorQueryType getVectorQueryType(int k, Float maxDistance, Float minSc
if (k != 0) {
return VectorQueryType.K;
}
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME));
}

/**
Expand Down

0 comments on commit 67b4fa9

Please sign in to comment.