[ML] Infer against model deployment #71177
Conversation
Pinging @elastic/ml-core (Team:ML)
Couple of suggestions but LGTM
    super(in);
    result = new PyTorchResult(in);
}
Does the response need a writeTo method?
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    result.writeTo(out);
}
Well spotted!
double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][];
for (int i = 0; i < listOfListOfDoubles.size(); i++) {
    List<Double> row = listOfListOfDoubles.get(i);
    double[] primitiveRow = new double[row.size()];
Suggested change: replace

    double[] primitiveRow = new double[row.size()];

with

    primitiveDoubles[i] = row.toArray(new double[]{});

rather than copying the elements one by one. Or, if the elements must be copied, perhaps use System.arraycopy with row.toArray() instead.
What about using the row.stream().mapToDouble(d -> d).toArray() way?
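For readers following along, here is a minimal, self-contained sketch of the conversion being discussed, using the stream-based unboxing suggested above. The class and variable names are illustrative, not the ones in the PR.

    import java.util.List;

    public class MatrixConversion {

        // Convert a List<List<Double>> into a primitive double[][].
        static double[][] toPrimitive(List<List<Double>> listOfListOfDoubles) {
            double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][];
            for (int i = 0; i < listOfListOfDoubles.size(); i++) {
                // Unbox each row in one call rather than copying element by element.
                primitiveDoubles[i] = listOfListOfDoubles.get(i).stream().mapToDouble(d -> d).toArray();
            }
            return primitiveDoubles;
        }

        public static void main(String[] args) {
            double[][] m = toPrimitive(List.of(List.of(1.0, 2.0), List.of(3.0)));
            System.out.println(m[0][1]); // prints 2.0
        }
    }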
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
try {
    String requestId = processContext.process.get().writeInferenceRequest(inputs);
    waitForResult(processContext, requestId, listener);
A future improvement would be to not block the thread here
I think this might be worth fixing now. It's something that could be forgotten with terrible consequences. I'll work on this following the pattern we used in AutodetectCommunicator.
Actually, the pattern in AutodetectCommunicator ensures we only perform one operation at a time against the native process. This might not be a restriction for PyTorch. We could investigate multithreading capabilities in order to do inference on multiple requests in parallel in the same process. So, for now, I will just make sure we're not blocking the thread.
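As a rough sketch of the non-blocking pattern being discussed: the listener is stored under the request id before the request is written to the native process, and the thread that parses the process output completes it when the matching result (or a failure/timeout) arrives. InferenceResultRouter and its method names are made up for illustration; they are not the classes in this PR.

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;
    import org.elasticsearch.action.ActionListener;

    // Illustrative only: routes results back to waiting callers without blocking a thread per request.
    final class InferenceResultRouter<T> {

        private final ConcurrentMap<String, ActionListener<T>> pendingRequests = new ConcurrentHashMap<>();

        // Called on the thread submitting the inference request, before writing it to the native process.
        void register(String requestId, ActionListener<T> listener) {
            pendingRequests.put(requestId, listener);
        }

        // Called on the thread that parses the native process output.
        void onResult(String requestId, T result) {
            ActionListener<T> listener = pendingRequests.remove(requestId);
            if (listener != null) {
                listener.onResponse(result);
            }
        }

        // Called if the request times out or the process dies, so callers are not left hanging.
        void onFailure(String requestId, Exception e) {
            ActionListener<T> listener = pendingRequests.remove(requestId);
            if (listener != null) {
                listener.onFailure(e);
            }
        }
    }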
private static final String NAME = "pytorch_inference";

private static AtomicLong ms_RequestId = new AtomicLong(1);
What does the ms_ mean?
(m)ember (s)tatic. Not sure if we're supposed to be using this convention. In all fairness, I saw we did it this way in AutodetectControlMsgWriter.ms_FlushNumber, but perhaps that's too ancient to follow :-)
It’s the old Prelert convention that matched the C++ standards. We took these out of the Java code in 2016 but must have missed that one.
Cool, I'll change it then, and raise a tiny PR to fix the flush one too.
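For illustration, dropping the old prefix is a one-line change; something like the following (the new name is just a suggestion, not the one used in the PR):

    // Before (old Prelert-style member-static prefix):
    private static AtomicLong ms_RequestId = new AtomicLong(1);

    // After (hypothetical name; could also be made final):
    private static final AtomicLong requestIdCounter = new AtomicLong(1);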
logger.debug(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, result.getRequestId()));
PendingResult pendingResult = pendingResults.get(result.getRequestId());
if (pendingResult == null) {
    logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
Suggested change: replace

    logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));

with

    logger.warn(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));

This is interesting enough to be warn.
This should only occur if the infer request timed out. That would indicate a throughput problem. I'm not sure we'd want to fill the log with those in that case as we could be getting an entry per document we're trying to apply inference to.
The feature branch contains changes to configure PyTorch models with a TrainedModelConfig and defines a format to store the binary models. The _start and _stop deployment actions control the model lifecycle, and the model can be directly evaluated with the _infer endpoint. Two types of NLP tasks are supported: Named Entity Recognition and Fill Mask. The feature branch consists of these PRs: #73523, #72218, #71679, #71323, #71035, #71177, #70713.
This adds a temporary API for doing inference against a trained model deployment.
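As a rough usage sketch, a call against the temporary inference API from Java could look like the following, using the low-level REST client. The URL path and request body shown here are assumptions for illustration only; check the PR and the feature-branch docs for the exact endpoint and fields.

    import org.apache.http.HttpHost;
    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.client.RestClient;

    public class InferExample {
        public static void main(String[] args) throws Exception {
            try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
                // Hypothetical path and body for a deployed model named "my_model".
                Request request = new Request("POST", "/_ml/trained_models/my_model/deployment/_infer");
                request.setJsonEntity("{\"input\": \"Elasticsearch is based in Amsterdam\"}");
                Response response = client.performRequest(request);
                System.out.println(response.getStatusLine());
            }
        }
    }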