Skip to content

Commit

Permalink
Refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed May 10, 2024
1 parent e41bda9 commit aeb4d75
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ public void infer(
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(
makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity())
);
listener.onResponse(makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
Expand Down Expand Up @@ -175,7 +173,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
private static double[] generateEmbedding(String input, int dimensions, SimilarityMeasure similarityMeasure) {
double[] embedding = new double[dimensions];
for (int j = 0; j < dimensions; j++) {
embedding[j] = input.hashCode() + (double)j;
embedding[j] = input.hashCode() + (double) j;
}

return embedding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,34 +130,34 @@ public void testBulkOperations() throws Exception {
}

private void storeSparseModel() throws Exception {
ModelRegistry modelRegistry = new ModelRegistry(client());

String inferenceEntityId = TestSparseInferenceServiceExtension.TestInferenceService.NAME;
Model model = new TestSparseInferenceServiceExtension.TestSparseModel(
inferenceEntityId,
new TestSparseInferenceServiceExtension.TestServiceSettings(inferenceEntityId, null, false)
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
new TestSparseInferenceServiceExtension.TestServiceSettings(
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
null,
false
)
);
AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);

assertThat(storeModelHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
storeModel(model);
}

private void storeDenseModel() throws Exception {
ModelRegistry modelRegistry = new ModelRegistry(client());

String inferenceEntityId = TestDenseInferenceServiceExtension.TestInferenceService.NAME;
Model model = new TestDenseInferenceServiceExtension.TestDenseModel(
inferenceEntityId,
TestDenseInferenceServiceExtension.TestInferenceService.NAME,
new TestDenseInferenceServiceExtension.TestServiceSettings(
inferenceEntityId,
TestDenseInferenceServiceExtension.TestInferenceService.NAME,
randomIntBetween(1, 100),
// dot product means that we need normalized vectors; it's not worth doing that in this test
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values()))
)
);

storeModel(model);
}

private void storeModel(Model model) throws Exception {
ModelRegistry modelRegistry = new ModelRegistry(client());

AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

Expand Down

0 comments on commit aeb4d75

Please sign in to comment.