From f713a868b083b64fa2513dc82eb9111bfa0cb07f Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 10 Jul 2024 09:48:48 -0700 Subject: [PATCH] Setting locale to Locale.ROOT while using String.format to ensure that UTs doesn't fail when validating the exception messages (#1812) Signed-off-by: Navneet Verma --- .../knn/index/query/KNNQueryBuilder.java | 58 ++++++++++++------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 86d8031bd..eb96d0937 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -44,6 +44,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -198,32 +199,38 @@ public KNNQueryBuilder build() { private void validate() { if (Strings.isNullOrEmpty(fieldName)) { - throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME)); } if (vector == null) { - throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME)); } else if (vector.length == 0) { - throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME)); } if (k == null && minScore == null && maxDistance == null) { - throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME) + ); } if ((k != null && maxDistance != null) || (maxDistance != null && minScore != null) || (k != null && minScore != null)) { - throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME) + ); } if (k != null) { if (k <= 0 || k > K_MAX) { - throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] requires k to be in the range (0, %d]", NAME, K_MAX) + ); } } if (minScore != null) { if (minScore <= 0) { - throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires minScore to be greater than 0", NAME)); } } @@ -231,7 +238,7 @@ private void validate() { ValidationException validationException = validateMethodParameters(methodParameters); if (validationException != null) { throw new IllegalArgumentException( - String.format("[%s] errors in method parameter [%s]", NAME, validationException.getMessage()) + String.format(Locale.ROOT, "[%s] errors in method parameter [%s]", NAME, validationException.getMessage()) ); } } @@ -257,19 +264,19 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) { @Deprecated public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { if (Strings.isNullOrEmpty(fieldName)) { - throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires fieldName", NAME)); } if (vector == null) { - throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires query vector", NAME)); } if (vector.length == 0) { - throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] query vector is empty", NAME)); } if (k <= 0) { - throw new IllegalArgumentException(String.format("[%s] requires k > 0", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k > 0", NAME)); } if (k > K_MAX) { - throw new IllegalArgumentException(String.format("[%s] requires k <= %d", NAME, K_MAX)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k <= %d", NAME, K_MAX)); } this.fieldName = fieldName; @@ -287,12 +294,16 @@ public static void initialize(ModelDao modelDao) { private static float[] ObjectsToFloats(List objs) { if (Objects.isNull(objs) || objs.isEmpty()) { - throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME)); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] field 'vector' requires to be non-null and non-empty", NAME) + ); } float[] vec = new float[objs.size()]; for (int i = 0; i < objs.size(); i++) { if ((objs.get(i) instanceof Number) == false) { - throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME)); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] field 'vector' requires to be an array of numbers", NAME) + ); } vec[i] = ((Number) objs.get(i)).floatValue(); } @@ -481,7 +492,7 @@ protected Query doToQuery(QueryShardContext context) { } if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) { - throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName)); } KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType; @@ -523,6 +534,7 @@ protected Query doToQuery(QueryShardContext context) { if (validationException != null) { throw new IllegalArgumentException( String.format( + Locale.ROOT, "Parameters not valid for [%s]:[%s]:[%s] combination: [%s]", knnEngine, method, @@ -575,7 +587,7 @@ protected Query doToQuery(QueryShardContext context) { if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { - throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support filters", knnEngine)); } String indexName = context.index().getName(); @@ -597,7 +609,9 @@ protected Query doToQuery(QueryShardContext context) { } if (radius != null) { if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) { - throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine)); + throw new UnsupportedOperationException( + String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine) + ); } RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) @@ -613,19 +627,19 @@ protected Query doToQuery(QueryShardContext context) { .build(); return RNNQueryFactory.create(createQueryRequest); } - throw new IllegalArgumentException(String.format("[%s] requires k or distance or score to be set", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { String modelId = knnVectorField.getModelId(); if (modelId == null) { - throw new IllegalArgumentException(String.format("Field '%s' does not have model.", this.fieldName)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName)); } ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); } return modelMetadata; } @@ -647,7 +661,7 @@ private VectorQueryType getVectorQueryType(int k, Float maxDistance, Float minSc if (k != 0) { return VectorQueryType.K; } - throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires exactly one of k, distance or score to be set", NAME)); } /**