Fix some compile errors
kderusso committed May 2, 2024
1 parent bd9dff8 commit 4b30b6a
Showing 9 changed files with 116 additions and 152 deletions.

@@ -176,6 +176,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ILM_SHRINK_ENABLE_WRITE = def(8_635_00_0);
public static final TransportVersion GEOIP_CACHE_STATS = def(8_636_00_0);
public static final TransportVersion WATERMARK_THRESHOLDS_STATS = def(8_637_00_0);
+ public static final TransportVersion SPARSE_VECTOR_QUERY_ADDED = def(8_638_00_0);

/*
* STOP! READ THIS FIRST! No, really,
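The new constant follows the usual transport-version pattern: serialization code checks the version before writing fields introduced by the new query. A minimal sketch of that pattern, assuming a StreamOutput named out and a purely illustrative optionalQueryField (neither is taken from this commit):

    // Sketch only: the field being written is a placeholder, not from this diff.
    if (out.getTransportVersion().onOrAfter(TransportVersions.SPARSE_VECTOR_QUERY_ADDED)) {
        out.writeOptionalNamedWriteable(optionalQueryField);
    }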

@@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
+ import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.io.IOException;
import java.util.ArrayList;

@@ -22,10 +23,10 @@ public static ChunkedTextExpansionResults createRandomResults() {
int numChunks = randomIntBetween(1, 5);

for (int i = 0; i < numChunks; i++) {
- var tokenWeights = new ArrayList<TextExpansionResults.WeightedToken>();
+ var tokenWeights = new ArrayList<WeightedToken>();
int numTokens = randomIntBetween(1, 8);
for (int j = 0; j < numTokens; j++) {
- tokenWeights.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
+ tokenWeights.add(new WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights));
}
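These test changes swap the nested TextExpansionResults.WeightedToken for the standalone org.elasticsearch.xpack.core.ml.search.WeightedToken used throughout this refactor. Judging only from how the tests construct it and from the WeightedToken::token / WeightedToken::weight accessors used later in this diff, the relocated class appears to be a simple token/weight pair; a hedged sketch of that assumed shape (the record form is an assumption, not copied from the commit):

    package org.elasticsearch.xpack.core.ml.search;

    // Sketch only: assumed minimal shape. The real class presumably also implements
    // Writeable and ToXContent, which these tests do not exercise directly.
    public record WeightedToken(String token, float weight) {}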

@@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
+ import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.util.ArrayList;
import java.util.List;

@@ -23,9 +24,9 @@ public static TextExpansionResults createRandomResults() {

public static TextExpansionResults createRandomResults(int min, int max) {
int numTokens = randomIntBetween(min, max);
- List<TextExpansionResults.WeightedToken> tokenList = new ArrayList<>();
+ List<WeightedToken> tokenList = new ArrayList<>();
for (int i = 0; i < numTokens; i++) {
- tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false)));
+ tokenList.add(new WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false)));
}
return new TextExpansionResults(randomAlphaOfLength(4), tokenList, randomBoolean());
}

@@ -49,9 +50,7 @@ protected TextExpansionResults mutateInstance(TextExpansionResults instance) {
@SuppressWarnings("unchecked")
void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String parentField, String resultsField) {
var ingestedTokens = (Map<String, Object>) document.getFieldValue(parentField + resultsField, Map.class);
- var tokenMap = createdInstance.getWeightedTokens()
- .stream()
- .collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight));
+ var tokenMap = createdInstance.getWeightedTokens().stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight));
assertEquals(tokenMap.size(), ingestedTokens.size());

assertEquals(tokenMap, ingestedTokens);

@@ -5,7 +5,7 @@
* 2.0.
*/

- package org.elasticsearch.xpack.ml.queries;
+ package org.elasticsearch.xpack.core.ml.search;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractXContentSerializingTestCase;

@@ -5,7 +5,7 @@
* 2.0.
*/

- package org.elasticsearch.xpack.ml.queries;
+ package org.elasticsearch.xpack.core.ml.search;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.FeatureField;

@@ -26,23 +26,18 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.index.mapper.MapperService;
- import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
- import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.AbstractQueryTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
- import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
- import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.List;

- import static org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults.WeightedToken;
- import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.TOKENS_FIELD;
+ import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.TOKENS_FIELD;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.Matchers.either;

@@ -51,7 +46,7 @@
public class WeightedTokensQueryBuilderTests extends AbstractQueryTestCase<WeightedTokensQueryBuilder> {

private static final String RANK_FEATURES_FIELD = "rank";
- private static final List<WeightedToken> WEIGHTED_TOKENS = List.of(new TextExpansionResults.WeightedToken("foo", .42f));
+ private static final List<WeightedToken> WEIGHTED_TOKENS = List.of(new WeightedToken("foo", .42f));
private static final int NUM_TOKENS = WEIGHTED_TOKENS.size();

@Override

@@ -74,10 +69,10 @@ private WeightedTokensQueryBuilder createTestQueryBuilder(boolean onlyScorePrune
return builder;
}

- @Override
- protected Collection<Class<? extends Plugin>> getPlugins() {
- return List.of(MachineLearning.class, MapperExtrasPlugin.class);
- }
+ // @Override
+ // protected Collection<Class<? extends Plugin>> getPlugins() {
+ // return List.of(MachineLearning.class, MapperExtrasPlugin.class);
+ // }

@Override
protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException {
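The static import in this file now resolves TOKENS_FIELD from the relocated org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder. TOKENS_FIELD is presumably a ParseField constant on the query builder; a hedged sketch of the conventional declaration (the "tokens" field name is an assumption, not read from this commit):

    // Sketch only: conventional ParseField constant on a query builder.
    public static final ParseField TOKENS_FIELD = new ParseField("tokens");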

@@ -30,7 +30,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
- import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
+ import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.io.IOException;
import java.util.ArrayList;

@@ -119,9 +119,9 @@ public void chunkedInfer(
private SparseEmbeddingResults makeResults(List<String> input) {
var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
for (int i = 0; i < input.size(); i++) {
- var tokens = new ArrayList<SparseEmbeddingResults.WeightedToken>();
+ var tokens = new ArrayList<WeightedToken>();
for (int j = 0; j < 5; j++) {
- tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j));
+ tokens.add(new WeightedToken(Integer.toString(j), (float) j));
}
embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false));
}

@@ -131,9 +131,9 @@ private SparseEmbeddingResults makeResults(List<String> input) {
private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input) {
var chunks = new ArrayList<ChunkedTextExpansionResults.ChunkedResult>();
for (int i = 0; i < input.size(); i++) {
- var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
+ var tokens = new ArrayList<WeightedToken>();
for (int j = 0; j < 5; j++) {
- tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j));
+ tokens.add(new WeightedToken(Integer.toString(j), (float) j));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens));
}