Updating inference logic to add node level request-response logging #3874

Merged · 3 commits · Feb 2, 2022
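Before the diff: a minimal sketch of the per-node logging pattern this PR introduces, assuming heavily simplified stand-in types. In the actual change, each node-level client call (Predict, TransformInput, TransformOutput, Combine) is preceded by an optional request log and followed by an optional response log, gated on the node's `Logger.Mode`. The names below (`Node`, `Logger`, `LogMode`, `Payload`, `logPayload`, `callNode`) are illustrative only and do not match the executor's real types (`v1.PredictiveUnit`, `payload.SeldonPayload`, `p.logPayload`).

```go
package main

import (
	"errors"
	"fmt"
)

// Simplified stand-ins for the executor types used in the diff below.
type LogMode string

const (
	LogRequest  LogMode = "request"
	LogResponse LogMode = "response"
	LogAll      LogMode = "all"
)

type Logger struct{ Mode LogMode }

type Node struct {
	Name   string
	Logger *Logger
}

type Payload = string

// logPayload is a hypothetical logger; in the executor the payload is
// shipped to the configured request-logger endpoint instead.
func logPayload(nodeName, kind string, msg Payload, puid string) error {
	fmt.Printf("log %s %s puid=%s payload=%q\n", nodeName, kind, puid, msg)
	return nil
}

// callNode wraps a node-level call with request/response logging,
// mirroring the structure added in transformInput/transformOutput/aggregate.
func callNode(node *Node, msg Payload, puid string, call func(Payload) (Payload, error)) (Payload, error) {
	// Log the request if the node's logger asks for it.
	if node.Logger != nil && (node.Logger.Mode == LogRequest || node.Logger.Mode == LogAll) {
		if err := logPayload(node.Name, "request", msg, puid); err != nil {
			return "", err
		}
	}
	resp, err := call(msg)
	if err != nil {
		return resp, err
	}
	// Log the response only when the underlying call succeeded.
	if node.Logger != nil && (node.Logger.Mode == LogResponse || node.Logger.Mode == LogAll) {
		if err := logPayload(node.Name, "response", resp, puid); err != nil {
			return "", err
		}
	}
	return resp, nil
}

func main() {
	node := &Node{Name: "classifier", Logger: &Logger{Mode: LogAll}}
	out, err := callNode(node, "tensor-in", "puid-123", func(in Payload) (Payload, error) {
		if in == "" {
			return "", errors.New("empty payload")
		}
		return in + " -> prediction", nil
	})
	fmt.Println(out, err)
}
```

As in the diff, a failure to log the request aborts the call, and the response is logged only when the call itself succeeded (`tmsg != nil && err == nil`).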
111 changes: 74 additions & 37 deletions executor/predictor/predictor_process.go
@@ -82,7 +82,7 @@ func (p *PredictorProcess) getModelName(node *v1.PredictiveUnit) string {
return modelName
}

func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.SeldonPayload) (payload.SeldonPayload, error) {
func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.SeldonPayload, puid string) (tmsg payload.SeldonPayload, err error) {
callModel := false
callTransformInput := false
if (*node).Type != nil {
@@ -99,31 +99,44 @@ func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.S

modelName := p.getModelName(node)

if callModel {
msg, err := p.Client.Chain(p.Ctx, modelName, msg)
if err != nil {
return nil, err
if callModel || callTransformInput {
//Log Request
if node.Logger != nil && (node.Logger.Mode == v1.LogRequest || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceRequest, msg, puid)
if err != nil {
return nil, err
}
Comment on lines +103 to +108

Contributor:

I'm not a fan of logging inside transformInput. It is not clear from the Predict function, where transformInput is called, that the message may or may not be logged.

edit: unfortunately, due to the way the code has been structured, this is the easiest way to do what's needed, because the callTransformInput and callModel flags are determined in here.

Contributor Author:

Yeah, I also tried structuring the code in a different manner to start with, but this seemed easiest considering other aspects of the implementation.
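For context on the thread above, a rough sketch of the kind of restructuring the reviewer is hinting at: determining the `callModel` / `callTransformInput` flags in a small helper so that the caller (Predict) could own the logging decision instead of transformInput. This is purely illustrative and uses simplified stand-in types (`Node`, `UnitType`, `callFlags`), not the executor's `v1.PredictiveUnit` / `v1.PredictiveUnitType` API.

```go
package main

import "fmt"

// Simplified stand-ins; the real node type is v1.PredictiveUnit and the
// unit type is v1.PredictiveUnitType.
type UnitType string

const (
	Model            UnitType = "MODEL"
	InputTransformer UnitType = "TRANSFORMER"
)

type Node struct {
	Name string
	Type *UnitType
}

// callFlags mirrors the switch at the top of transformInput: it reports
// whether the node would be called as a model or as an input transformer.
func callFlags(node *Node) (callModel, callTransformInput bool) {
	if node.Type != nil {
		switch *node.Type {
		case Model:
			callModel = true
		case InputTransformer:
			callTransformInput = true
		}
	}
	return
}

func main() {
	t := Model
	node := &Node{Name: "classifier", Type: &t}
	callModel, callTransform := callFlags(node)
	// With the flags exposed, Predict could log the request itself whenever
	// either flag is set, keeping transformInput free of logging concerns.
	fmt.Println(callModel, callTransform)
}
```

In the merged change the flags stay internal to transformInput and the logging happens there, which the author notes above was the least invasive option given the existing structure.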

}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Predict(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else if callTransformInput {

msg, err := p.Client.Chain(p.Ctx, modelName, msg)
if err != nil {
return nil, err
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.TransformInput(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)

if callTransformInput {
tmsg, err = p.Client.TransformInput(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else {
tmsg, err = p.Client.Predict(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
}
if tmsg != nil && err == nil {
// Log Response
if node.Logger != nil && (node.Logger.Mode == v1.LogResponse || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceResponse, tmsg, puid)
if err != nil {
return nil, err
}
}
}
return tmsg, err
} else {
return msg, nil
}

}

func (p *PredictorProcess) transformOutput(node *v1.PredictiveUnit, msg payload.SeldonPayload) (payload.SeldonPayload, error) {
func (p *PredictorProcess) transformOutput(node *v1.PredictiveUnit, msg payload.SeldonPayload, puid string) (payload.SeldonPayload, error) {
callClient := false
if (*node).Type != nil {
switch *node.Type {
@@ -138,11 +151,29 @@ func (p *PredictorProcess) transformOutput(node *v1.PredictiveUnit, msg payload.
modelName := p.getModelName(node)

if callClient {
//Log Request
if node.Logger != nil && (node.Logger.Mode == v1.LogRequest || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceRequest, msg, puid)
if err != nil {
return nil, err
}
}

msg, err := p.Client.Chain(p.Ctx, modelName, msg)
if err != nil {
return nil, err
}
return p.Client.TransformOutput(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
tmsg, err := p.Client.TransformOutput(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
if tmsg != nil && err == nil {
// Log Response
if node.Logger != nil && (node.Logger.Mode == v1.LogResponse || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceResponse, tmsg, puid)
if err != nil {
return nil, err
}
}
}
return tmsg, err
} else {
return msg, nil
}
@@ -202,7 +233,7 @@ func (p *PredictorProcess) route(node *v1.PredictiveUnit, msg payload.SeldonPayl
}
}

func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, msg []payload.SeldonPayload) (payload.SeldonPayload, error) {
func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, cmsg []payload.SeldonPayload, msg payload.SeldonPayload, puid string) (payload.SeldonPayload, error) {
callClient := false
if (*node).Type != nil {
switch *node.Type {
@@ -217,24 +248,41 @@ func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, msg []payload.Seld
modelName := p.getModelName(node)

if callClient {
//Log Request
if node.Logger != nil && (node.Logger.Mode == v1.LogRequest || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceRequest, msg, puid)
if err != nil {
return nil, err
}
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Combine(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
tmsg, err := p.Client.Combine(p.Ctx, modelName, node.Endpoint.ServiceHost, p.getPort(node), cmsg, p.Meta.Meta)
if tmsg != nil && err == nil {
// Log Response
if node.Logger != nil && (node.Logger.Mode == v1.LogResponse || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceResponse, tmsg, puid)
if err != nil {
return nil, err
}
}
}
return tmsg, err
} else {
return msg[0], nil
return cmsg[0], nil
}

}

func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload.SeldonPayload) (payload.SeldonPayload, error) {
func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload.SeldonPayload, puid string) (payload.SeldonPayload, error) {
if node.Children != nil && len(node.Children) > 0 {
route, err := p.route(node, msg)
if err != nil {
return nil, err
}
var cmsgs []payload.SeldonPayload
if route == -1 {

cmsgs = make([]payload.SeldonPayload, len(node.Children))
var errs = make([]error, len(node.Children))
wg := sync.WaitGroup{}
@@ -270,7 +318,7 @@ func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload.
return cmsgs[0], err
}
}
return p.aggregate(node, cmsgs)
return p.aggregate(node, cmsgs, msg, puid)
} else {
// Don't add routing for leaf nodes
return msg, nil
@@ -365,29 +413,18 @@ func (p *PredictorProcess) Predict(node *v1.PredictiveUnit, msg payload.SeldonPa
if err != nil {
return nil, err
}
//Log Request
if node.Logger != nil && (node.Logger.Mode == v1.LogRequest || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceRequest, msg, puid)
if err != nil {
return nil, err
}
}
tmsg, err := p.transformInput(node, msg)

tmsg, err := p.transformInput(node, msg, puid)
if err != nil {
return tmsg, err
}
cmsg, err := p.predictChildren(node, tmsg)
cmsg, err := p.predictChildren(node, tmsg, puid)
if err != nil {
return cmsg, err
}
response, err := p.transformOutput(node, cmsg)
// Log Response
if err == nil && node.Logger != nil && (node.Logger.Mode == v1.LogResponse || node.Logger.Mode == v1.LogAll) {
err := p.logPayload(node.Name, node.Logger, payloadLogger.InferenceResponse, response, puid)
if err != nil {
return nil, err
}
}

response, err := p.transformOutput(node, cmsg, puid)

if envEnableRoutingInjection {
if routeResponse, err := util.InsertRouteToSeldonPredictPayload(response, &p.Routing); err == nil {
return routeResponse, err