Skip to content

Commit

Permalink
Move the ltr code in its own package,
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Nov 16, 2023
1 parent 6a1408d commit 22609c3
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,6 @@
import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorerFeature;
import org.elasticsearch.xpack.ml.inference.rescorer.LearnToRankRescorerBuilder;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
Expand Down Expand Up @@ -358,6 +356,8 @@
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
import org.elasticsearch.xpack.ml.ltr.InferenceRescorerFeature;
import org.elasticsearch.xpack.ml.ltr.LearnToRankRescorerBuilder;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down Expand Up @@ -872,8 +872,8 @@ public List<RescorerSpec<?>> getRescorers() {
return List.of(
new RescorerSpec<>(
LearnToRankRescorerBuilder.NAME,
in -> new LearnToRankRescorerBuilder(in, modelLoadingService::get, scriptService::get),
parser -> LearnToRankRescorerBuilder.fromXContent(parser, modelLoadingService::get, scriptService::get)
in -> new LearnToRankRescorerBuilder(in, modelLoadingService.get(), scriptService.get()),
parser -> LearnToRankRescorerBuilder.fromXContent(parser, modelLoadingService.get(), scriptService.get())
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.index.LeafReaderContext;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.index.mapper.MappedFieldType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.elasticsearch.common.util.FeatureFlag;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
Expand Down Expand Up @@ -42,7 +42,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG;

Expand All @@ -60,10 +59,10 @@ public class LearnToRankRescorerBuilder extends RescorerBuilder<LearnToRankResco

public static LearnToRankRescorerBuilder fromXContent(
XContentParser parser,
Supplier<ModelLoadingService> modelLoadingServiceSupplier,
Supplier<ScriptService> scriptServiceSupplier
ModelLoadingService modelLoadingService,
ScriptService scriptService
) {
return PARSER.apply(parser, null).build(modelLoadingServiceSupplier.get(), scriptServiceSupplier.get());
return PARSER.apply(parser, null).build(modelLoadingService, scriptService);
}

private final String modelId;
Expand Down Expand Up @@ -103,18 +102,15 @@ public static LearnToRankRescorerBuilder fromXContent(
this.scriptService = null;
}

public LearnToRankRescorerBuilder(
StreamInput input,
Supplier<ModelLoadingService> modelLoadingServiceSupplier,
Supplier<ScriptService> scriptServiceSupplier
) throws IOException {
public LearnToRankRescorerBuilder(StreamInput input, ModelLoadingService modelLoadingService, ScriptService scriptService)
throws IOException {
super(input);
this.modelId = input.readString();
this.params = input.readMap();
this.learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class);

this.modelLoadingService = modelLoadingServiceSupplier.get();
this.scriptService = scriptServiceSupplier.get();
this.modelLoadingService = modelLoadingService;
this.scriptService = scriptService;

this.localModel = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DisiPriorityQueue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ParsingException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.rescorer;
package org.elasticsearch.xpack.ml.ltr;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
Expand Down

0 comments on commit 22609c3

Please sign in to comment.