Skip to content

Commit

Permalink
Remove SimilarityMeasure from methods
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed May 16, 2024
1 parent 20efb16 commit 11e2bc5
Showing 1 changed file with 7 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public void infer(
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
listener.onResponse(makeResults(input, modelServiceSettings.dimensions()));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -125,7 +125,7 @@ public void chunkedInfer(
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions()));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -136,10 +136,10 @@ public void chunkedInfer(
}
}

private TextEmbeddingResults makeResults(List<String> input, int dimensions, SimilarityMeasure similarityMeasure) {
private TextEmbeddingResults makeResults(List<String> input, int dimensions) {
List<TextEmbeddingResults.Embedding> embeddings = new ArrayList<>();
for (int i = 0; i < input.size(); i++) {
double[] doubleEmbeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
double[] doubleEmbeddings = generateEmbedding(input.get(i), dimensions);
List<Float> floatEmbeddings = new ArrayList<>(dimensions);
for (int j = 0; j < dimensions; j++) {
floatEmbeddings.add((float) doubleEmbeddings[j]);
Expand All @@ -149,14 +149,10 @@ private TextEmbeddingResults makeResults(List<String> input, int dimensions, Sim
return new TextEmbeddingResults(embeddings);
}

private List<ChunkedInferenceServiceResults> makeChunkedResults(
List<String> input,
int dimensions,
SimilarityMeasure similarityMeasure
) {
private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input, int dimensions) {
var results = new ArrayList<ChunkedInferenceServiceResults>();
for (int i = 0; i < input.size(); i++) {
double[] embeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
double[] embeddings = generateEmbedding(input.get(i), dimensions);
results.add(
new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults(
List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), embeddings))
Expand All @@ -170,7 +166,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
return TestServiceSettings.fromMap(serviceSettingsMap);
}

private static double[] generateEmbedding(String input, int dimensions, SimilarityMeasure similarityMeasure) {
private static double[] generateEmbedding(String input, int dimensions) {
double[] embedding = new double[dimensions];
for (int j = 0; j < dimensions; j++) {
embedding[j] = input.hashCode() + 1 + j;
Expand Down

0 comments on commit 11e2bc5

Please sign in to comment.