From f7a6a135e10403327d4fcaf4f4fc1e054df8e6e0 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 2 Apr 2024 13:21:51 -0400 Subject: [PATCH] Address some PR feedback RE: optional values --- .../core/ml/inference/results/TextExpansionResults.java | 4 ++-- .../xpack/ml/queries/SparseVectorQueryBuilder.java | 6 +++--- .../xpack/ml/queries/SparseVectorQueryBuilderTests.java | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java index de679c8e6062e..45aa4d51e0ad6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java @@ -89,8 +89,8 @@ public Object predictedValue() { @Override void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.startObject(resultsField); - for (var vectorDimension : weightedTokens) { - vectorDimension.toXContent(builder, params); + for (var weightedToken : weightedTokens) { + weightedToken.toXContent(builder, params); } builder.endObject(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java index 52e51d6f65d6b..bcd47cb9acc4e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java @@ -67,7 +67,7 @@ public SparseVectorQueryBuilder( public SparseVectorQueryBuilder( String fieldName, - String modelText, + @Nullable String modelText, @Nullable String modelId, @Nullable List weightedTokens, @Nullable TokenPruningConfig tokenPruningConfig @@ -97,7 +97,7 @@ public SparseVectorQueryBuilder( public SparseVectorQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); - this.modelText = in.readString(); + this.modelText = in.readOptionalString(); this.modelId = in.readOptionalString(); this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new); this.weightedTokens = in.readOptionalCollectionAsList(WeightedToken::new); @@ -134,7 +134,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { throw new IllegalStateException("supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?"); } out.writeString(fieldName); - out.writeString(modelText); + out.writeOptionalString(modelText); out.writeOptionalString(modelId); out.writeOptionalWriteable(tokenPruningConfig); out.writeOptionalCollection(weightedTokens, StreamOutput::writeWriteable); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java index ef2dc379b6d4a..976f151ed4eee 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java @@ -61,8 +61,8 @@ protected SparseVectorQueryBuilder doCreateTestQueryBuilder() { TokenPruningConfig tokenPruningConfig = randomBoolean() ? new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), randomBoolean()) : null; - String modelText = randomAlphaOfLength(4); String modelId = randomBoolean() ? randomAlphaOfLength(4) : null; + String modelText = modelId != null ? randomAlphaOfLength(4) : null; List weightedTokens = modelId == null ? VECTOR_DIMENSIONS : null; var builder = new SparseVectorQueryBuilder(RANK_FEATURES_FIELD, modelText, modelId, weightedTokens, tokenPruningConfig);