[ML] allow for larger models in the inference step for data frame analytics #76116

Merged
@@ -37,9 +37,9 @@
*/
public final class InferenceToXContentCompressor {
private static final int BUFFER_SIZE = 4096;
- // Either 10% of the configured JVM heap, or 1 GB, which ever is smaller
+ // Either 25% of the configured JVM heap, or 1 GB, which ever is smaller
private static final long MAX_INFLATED_BYTES = Math.min(
-     (long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
+     (long)((0.25) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
ByteSizeValue.ofGb(1).getBytes());

private InferenceToXContentCompressor() {}
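For a sense of scale, a quick worked example of the new limit (the heap sizes below are illustrative assumptions, not values from the PR):

```java
// MAX_INFLATED_BYTES = min(0.25 * configured heap, 1 GB)
long oneGb = 1L << 30;
long limitFor1GbHeap = Math.min((long) (0.25 * (1L * oneGb)), oneGb); // 256 MB (was ~102 MB at 10%)
long limitFor8GbHeap = Math.min((long) (0.25 * (8L * oneGb)), oneGb); // 1 GB; the fixed cap wins for heaps above 4 GB
```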
@@ -49,6 +49,12 @@ public static <T extends ToXContentObject> BytesReference deflate(T objectToComp
return deflate(reference);
}

+ public static <T> T inflateUnsafe(BytesReference compressedBytes,
+                                   CheckedFunction<XContentParser, T, IOException> parserFunction,
+                                   NamedXContentRegistry xContentRegistry) throws IOException {
+     return inflate(compressedBytes, parserFunction, xContentRegistry, Long.MAX_VALUE);
+ }

public static <T> T inflate(BytesReference compressedBytes,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
@@ -286,6 +286,14 @@ public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentR
return this;
}

+ public TrainedModelConfig ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException {
+     if (definition == null) {
+         return null;
+     }
+     definition.ensureParsedDefinitionUnsafe(xContentRegistry);
+     return this;
+ }

@Nullable
public TrainedModelDefinition getModelDefinition() {
if (definition == null) {
@@ -872,6 +880,14 @@ private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) thro
}
}

+ private void ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException {
+     if (parsedDefinition == null) {
+         parsedDefinition = InferenceToXContentCompressor.inflateUnsafe(compressedRepresentation,
+             parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
+             xContentRegistry);
+     }
+ }

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport
@@ -339,7 +339,7 @@ public void testGetTrainedModelForInference() throws InterruptedException, IOExc

AtomicReference<InferenceDefinition> definitionHolder = new AtomicReference<>();
blockingCall(
- listener -> trainedModelProvider.getTrainedModelForInference(modelId, listener),
+ listener -> trainedModelProvider.getTrainedModelForInference(modelId, false, listener),
definitionHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
@@ -91,7 +91,7 @@ public void run(String modelId) {
LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId);
try {
PlainActionFuture<LocalModel> localModelPlainActionFuture = new PlainActionFuture<>();
- modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture);
+ modelLoadingService.getModelForInternalInference(modelId, localModelPlainActionFuture);
InferenceState inferenceState = restoreInferenceState();
dataCountsTracker.setTestDocsCount(inferenceState.processedTestDocsCount);
TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config,
@@ -99,7 +99,7 @@ public class ModelLoadingService implements ClusterStateListener {

// The feature requesting the model
public enum Consumer {
- PIPELINE, SEARCH
+ PIPELINE, SEARCH, INTERNAL
}

private static class ModelAndConsumer {
@@ -175,6 +175,16 @@ public void getModelForPipeline(String modelId, ActionListener<LocalModel> model
getModel(modelId, Consumer.PIPELINE, modelActionListener);
}

+ /**
+  * Load the model for internal use. Note, this decompresses the model if the stored estimate doesn't trip circuit breakers.
+  * Consequently, it assumes the model was created by an ML process
+  * @param modelId the model to get
+  * @param modelActionListener the listener to alert when the model has been retrieved
+  */
+ public void getModelForInternalInference(String modelId, ActionListener<LocalModel> modelActionListener) {
+     getModel(modelId, Consumer.INTERNAL, modelActionListener);
+ }

/**
* Load the model for use by at search. Models requested by search are always cached.
*
@@ -272,15 +282,15 @@ private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, A
return true;
}

- if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
+ if (Consumer.SEARCH != consumer && referencedModels.contains(modelId) == false) {
// The model is requested by a pipeline but not referenced by any ingest pipelines.
// This means it is a simulate call and the model should not be cached
logger.trace(() -> new ParameterizedMessage(
"[{}] (model_alias [{}]) not actively loading, eager loading without cache",
modelId,
modelIdOrAlias
));
- loadWithoutCaching(modelId, modelActionListener);
+ loadWithoutCaching(modelId, consumer, modelActionListener);
} else {
logger.trace(() -> new ParameterizedMessage(
"[{}] (model_alias [{}]) attempting to load and cache",
@@ -298,7 +308,7 @@ private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
- provider.getTrainedModelForInference(modelId, ActionListener.wrap(
+ provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(
inferenceDefinition -> {
try {
// Since we have used the previously stored estimate to help guard against OOM we need
@@ -327,14 +337,14 @@
));
}

- private void loadWithoutCaching(String modelId, ActionListener<LocalModel> modelActionListener) {
+ private void loadWithoutCaching(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
- provider.getTrainedModelForInference(modelId, ActionListener.wrap(
+ provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(
inferenceDefinition -> {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
@@ -392,13 +392,17 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
* do not.
*
* @param modelId The model tp get
+ * @param unsafe when true, the compressed bytes size is not checked and the circuit breaker is solely responsible for
+ *               preventing OOMs
* @param listener The listener
*/
- public void getTrainedModelForInference(final String modelId, final ActionListener<InferenceDefinition> listener) {
+ public void getTrainedModelForInference(final String modelId, boolean unsafe, final ActionListener<InferenceDefinition> listener) {
// TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
- TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
+ TrainedModelConfig config = loadModelFromResource(modelId, false)
+     .build()
+     .ensureParsedDefinitionUnsafe(xContentRegistry);
Review thread:

Member (davidkyle): This should respect the value of the unsafe parameter, the same as line 432.

Member Author: @davidkyle I don't think it should. Models in the resource files are provided in the jar distribution, so I don't think we should ever check their stream length on parsing.

Member:
> So, I don't think we should ever check their stream length on parsing.

Because we know how big they are? Why not enforce that with a simple check?

Member Author: It is superfluous to me: the only way to adjust this resource is to modify the resource files directly on disk, and since we control these resource models, we already know and trust their sizes.
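For illustration, a minimal sketch of the reviewer's suggestion, assuming the unsafe flag were threaded through to the resource-model branch as well (the PR instead keeps the unconditional unsafe parse for bundled models):

```java
// Hypothetical variant (not what the PR merges): honor the unsafe flag for resource-backed
// models, mirroring the index-stored path at new line 432 below.
TrainedModelConfig config = unsafe
    ? loadModelFromResource(modelId, false).build().ensureParsedDefinitionUnsafe(xContentRegistry)
    : loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
```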

assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
assert config.getModelType() == TrainedModelType.LANG_IDENT;
listener.onResponse(
@@ -425,10 +429,9 @@ public void getTrainedModelForInference(final String modelId, final ActionListen
success -> {
try {
BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
- InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
-     compressedData,
-     InferenceDefinition::fromXContent,
-     xContentRegistry);
+ InferenceDefinition inferenceDefinition = unsafe ?
+     InferenceToXContentCompressor.inflateUnsafe(compressedData, InferenceDefinition::fromXContent, xContentRegistry) :
+     InferenceToXContentCompressor.inflate(compressedData, InferenceDefinition::fromXContent, xContentRegistry);

listener.onResponse(inferenceDefinition);
} catch (Exception e) {
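Taken together, a rough sketch of the call chain these hunks introduce (assembled from the diffs above; the comments are explanatory, not from the PR):

```java
// Data frame analytics inference step (run(modelId) in the hunk above)
//   -> modelLoadingService.getModelForInternalInference(modelId, listener)          // Consumer.INTERNAL
//     -> provider.getTrainedModelForInference(modelId, /* unsafe = */ true, listener)
//       -> InferenceToXContentCompressor.inflateUnsafe(compressedData,
//              InferenceDefinition::fromXContent, xContentRegistry)
//          // no MAX_INFLATED_BYTES check; the trained-model circuit breaker guards against OOM
```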